# Notebook based on `swin-moe.py`

In [5]:
from tutel import system

import os
import time
import json
import random
import argparse
import datetime
import numpy as np
from functools import partial
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import NativeScalerWithGradNormCount, reduce_tensor
from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_resume_helper, hook_scale_grad



In [6]:
try:
    from tutel import moe as tutel_moe
except:
    tutel_moe = None
    print("Tutel has not been installed. To use Swin-MoE, please install Tutel; otherwise, just ignore this.")
    assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0"

In [14]:
def parse_option(argv=None):
    """Parse the Swin Transformer command-line options and build the config.

    Args:
        argv: Optional list of command-line tokens to parse, for example
            ``['--cfg', 'path/to/config.yaml', '--batch-size', '32']``.
            ``None`` (the default) parses an empty list, as this function did
            before the parameter existed; pass the tokens explicitly when
            calling from a notebook cell.

    Returns:
        A ``(args, config)`` tuple: the parsed ``argparse.Namespace`` and the
        value returned by ``get_config(args)``.

    Raises:
        SystemExit: raised by argparse when ``--cfg`` is missing or an option
            value is invalid.
    """
    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    # No longer mandatory on the command line: fall back to the LOCAL_RANK
    # environment variable (exported by launchers such as torchrun) and to 0
    # when neither is given, so a single-process session can parse at all.
    parser.add_argument("--local_rank", type=int, default=int(os.environ.get('LOCAL_RANK', 0)),
                        help='local rank for DistributedDataParallel')

    # Tokens the parser does not recognise are tolerated and discarded.
    tokens = [] if argv is None else list(argv)
    args, _ = parser.parse_known_args(tokens)

    config = get_config(args)

    return args, config

## GShard loss and load/importance loss

In [13]:
import torch
from torch.distributions.normal import Normal

def _one_hot_with_dtype(data, num_classes, dtype, hot_value=1):
    result = torch.zeros([data.size(0), num_classes], device=data.device, dtype=dtype)
    result.scatter_(1, data.unsqueeze(-1), hot_value)
    return result

def gshard_loss(scores_w_noise, top_ids):
    """Compute the auxiliary loss from gate scores and selected expert ids.

    ``scores_w_noise`` is indexed as (sample, expert); only the first column of
    ``top_ids`` is used. Each sample contributes ``num_experts / num_samples``
    to the column of its first selected expert; those per-expert totals are
    multiplied by the per-expert score sums, summed, and divided by the number
    of samples.
    """
    sample_count = int(scores_w_noise.size(0))
    expert_count = int(scores_w_noise.size(1))

    # Scaled one-hot of each sample's first selected expert.
    assignment = _one_hot_with_dtype(
        top_ids[:, 0],
        expert_count,
        dtype=scores_w_noise.dtype,
        hot_value=expert_count / sample_count,
    )

    score_per_expert = scores_w_noise.sum(dim=0)
    weight_per_expert = assignment.sum(dim=0)
    return (score_per_expert * weight_per_expert).sum() / sample_count

def load_importance_loss(scores_wo_noise, topk_logits, num_global_experts, gate_noise):
    """Return the mean of an importance term and a load term.

    Both terms are ``variance / (mean ** 2 + 1e-10)`` of a per-expert vector:

    * importance: the scores summed over the first dimension;
    * load: for each score, the Normal CDF (mean 0, scale
      ``gate_noise / num_global_experts``) of its distance above the last
      column of ``topk_logits`` for the same row, summed over the first
      dimension.

    Raises:
        AssertionError: if ``gate_noise`` is not strictly positive.
    """
    def _variance_over_squared_mean(per_expert):
        values = per_expert.float()
        return values.var() / (values.mean() ** 2 + 1e-10)

    # Importance term (evaluated first, before gate_noise is validated).
    importance_term = _variance_over_squared_mean(scores_wo_noise.float().sum(0))

    # Load term.
    assert gate_noise > 0, "`gate_noise` must be > 0 for normalization in load_importance_loss()."
    device = scores_wo_noise.device
    noise_distribution = Normal(
        torch.tensor([0.0], device=device),
        torch.tensor([gate_noise / num_global_experts], device=device),
    )
    # Last selected logit of every row, as a column so it broadcasts over experts.
    row_threshold = topk_logits[:, -1].view(-1, 1).float()
    margin = scores_wo_noise.float() - row_threshold
    load_term = _variance_over_squared_mean(noise_distribution.cdf(margin).sum(0))

    return (importance_term + load_term) / 2.0

In [None]:
# Rendezvous settings for the `env://` init method. Values already exported by
# a launcher take precedence; otherwise use local single-node defaults.
os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
os.environ.setdefault('MASTER_PORT', '8888')

# A notebook kernel is a single process, so default to a one-rank group:
# init_process_group waits for `world_size` processes to join, and a hard-coded
# larger value would block this cell when no peers are started. Export
# WORLD_SIZE / RANK before starting the kernel to join a larger job.
world_size = int(os.environ.get('WORLD_SIZE', 1))
rank = int(os.environ.get('RANK', 0))

# Guard so the cell can be re-run without trying to create the default group twice.
if not dist.is_initialized():
    dist.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)

In [8]:
from lr_scheduler import build_scheduler

In [15]:
# Parse the options and build the config.
# NOTE(review): parse_option() is called without any arguments while `--cfg`
# is declared as required, so argparse exits here (see the recorded
# `SystemExit: 2` in this cell's output).
args, config = parse_option()

usage: Swin Transformer training and evaluation script --cfg FILE
                                                       [--opts OPTS [OPTS ...]]
                                                       [--batch-size BATCH_SIZE]
                                                       [--data-path DATA_PATH]
                                                       [--zip]
                                                       [--cache-mode {no,full,part}]
                                                       [--pretrained PRETRAINED]
                                                       [--resume RESUME]
                                                       [--accumulation-steps ACCUMULATION_STEPS]
                                                       [--use-checkpoint]
                                                       [--disable_amp]
                                                       [--amp-opt-level {O0,O1,O2}]
                                                       [--output P

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [16]:
# Display `config`.
# NOTE(review): the recorded output is an "Available objects for config" listing
# rather than a config value, which suggests `config` was not bound to a
# variable when this cell ran — confirm after the preceding cell succeeds.
config

Available objects for config:
     AliasManager
     DisplayFormatter
     HistoryManager
     IPCompleter
     IPKernelApp
     LoggingMagics
     MagicsManager
     OSMagics
     PrefilterManager
     ScriptMagics
     StoreMagics
     ZMQInteractiveShell


In [17]:
# IPython introspection: `??` prints the signature and source of build_scheduler.
??build_scheduler

[0;31mSignature:[0m [0mbuild_scheduler[0m[0;34m([0m[0mconfig[0m[0;34m,[0m [0moptimizer[0m[0;34m,[0m [0mn_iter_per_epoch[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m <no docstring>
[0;31mSource:[0m   
[0;32mdef[0m [0mbuild_scheduler[0m[0;34m([0m[0mconfig[0m[0;34m,[0m [0moptimizer[0m[0;34m,[0m [0mn_iter_per_epoch[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0mnum_steps[0m [0;34m=[0m [0mint[0m[0;34m([0m[0mconfig[0m[0;34m.[0m[0mTRAIN[0m[0;34m.[0m[0mEPOCHS[0m [0;34m*[0m [0mn_iter_per_epoch[0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0mwarmup_steps[0m [0;34m=[0m [0mint[0m[0;34m([0m[0mconfig[0m[0;34m.[0m[0mTRAIN[0m[0;34m.[0m[0mWARMUP_EPOCHS[0m [0;34m*[0m [0mn_iter_per_epoch[0m[0;34m)[0m[0;34m[0m
[0;34m[0m    [0mdecay_steps[0m [0;34m=[0m [0mint[0m[0;34m([0m[0mconfig[0m[0;34m.[0m[0mTRAIN[0m[0;34m.[0m[0mLR_SCHEDULER[0m[0;34m.[0m[0mDECAY_EPOCHS[0m [0;34m*[0m [0mn_i