# SK-Hynix Project Code - Pretrain

### SK-Hynix 프로젝트에서 진행한 연구의 실험 중, encoder 학습코드를 정리한 Jupyter notebook 파일입니다.    


####  1. 'Mixup 기법으로 얻은 representation을 contrastive task에 사용하는 것'이 본 연구의 핵심 아이디어입니다.

#### 2. 현재는 [MoCo](https://openaccess.thecvf.com/content_CVPR_2020/papers/He_Momentum_Contrast_for_Unsupervised_Visual_Representation_Learning_CVPR_2020_paper.pdf) 논문을 기반으로 아이디어를 구현했습니다.

#### 3. 아쉽게도 multi-gpu 실험은 현 jupyter notebook에서는 불가능합니다. 각 함수들의 기능만 봐주시면 될 것 같습니다.

In [1]:
# basic library setting
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.multiprocessing as mp
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import Dataset

import os, math, random, time, shutil, builtins, argparse, warnings, json, glob
import numpy as np
from PIL import Image, ImageFilter

### [Section 1] 데이터 불러오기

#### CIFAR-10과 CIFAR-100 데이터의 경우, torchvision.datasets library에서 받아옵니다.
#### 반면 Tiny-ImageNet 실험의 경우, 현 directory 안에 data/tiny-imagenet-200 이 저장되어 있어야합니다. Tiny-ImageNet-200의 경우 [Tiny-ImageNet](https://tiny-imagenet.herokuapp.com/)에서 다운로드 가능합니다.

### <span style="color:red">Hynix 데이터에 적용하기 위해선 다음 두가지의 코드 구현이 필요해보입니다.</span>
- **pytorch library의 DataLoader에 데이터를 옮기는 코드**
- **데이터에 맞는 augmentation 코드**

In [2]:
# CIFAR-10, CIFAR-100 dataset
from torchvision.datasets import CIFAR10, CIFAR100


# Tiny-ImageNet dataset
EXTENSION = 'JPEG'
NUM_IMAGES_PER_CLASS = 500
CLASS_LIST_FILE = 'wnids.txt'
VAL_ANNOTATION_FILE = 'val_annotations.txt'


class TinyImageNet(Dataset):
    """Tiny ImageNet data set available from `http://cs231n.stanford.edu/tiny-imagenet-200.zip`.

    Parameters
    ----------
    root: string
        Root directory including the `train` and `val` subdirectories and the
        class list file (`wnids.txt`). A leading `~` is expanded.
    train: bool
        If True the `train` split is used, otherwise the `val` split.
    transform: callable, optional
        Transformation applied to the RGB PIL image every time a sample is fetched.
    target_transform: callable, optional
        Transformation applied to the integer label every time a sample is fetched.
    in_memory: bool
        If True, the decoded RGB images are kept in memory to minimize disk IO.
        The transform is still applied per access, so random augmentations
        keep producing fresh views.
    download: bool
        Accepted only for signature compatibility with the torchvision CIFAR
        datasets; it is not used.
    """
    def __init__(self, root, train=True, transform=None, target_transform=None, in_memory=False, download=False):
        self.root = os.path.expanduser(root)
        self.train = train
        self.split = 'train' if train else 'val'
        self.transform = transform
        self.target_transform = target_transform
        self.in_memory = in_memory
        # Build the split directory from the *expanded* root so that paths such
        # as '~/data/tiny-imagenet-200' resolve to the same place as self.root.
        self.split_dir = os.path.join(self.root, self.split)
        self.image_paths = sorted(glob.iglob(os.path.join(self.split_dir, '**', '*.%s' % EXTENSION), recursive=True))
        self.labels = {}  # fname - label number mapping
        self.images = []  # decoded RGB images, filled only when in_memory is True

        # build class label - number mapping
        with open(os.path.join(self.root, CLASS_LIST_FILE), 'r') as fp:
            self.label_texts = sorted([text.strip() for text in fp.readlines()])
        self.label_text_to_number = {text: i for i, text in enumerate(self.label_texts)}

        if self.split == 'train':
            # training files are named '<wnid>_<index>.JPEG'
            for label_text, i in self.label_text_to_number.items():
                for cnt in range(NUM_IMAGES_PER_CLASS):
                    self.labels['%s_%d.%s' % (label_text, cnt, EXTENSION)] = i
        elif self.split == 'val':
            # validation labels come from a tab-separated annotation file
            with open(os.path.join(self.split_dir, VAL_ANNOTATION_FILE), 'r') as fp:
                for line in fp:
                    if not line.strip():
                        continue  # tolerate blank / trailing lines
                    terms = line.split('\t')
                    file_name, label_text = terms[0], terms[1]
                    self.labels[file_name] = self.label_text_to_number[label_text]

        # Cache the raw (untransformed) images. Caching transformed outputs
        # would freeze random augmentations to a single draw per image.
        if self.in_memory:
            self.images = [self._load_image(path) for path in self.image_paths]

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        file_path = self.image_paths[index]

        if self.in_memory:
            img = self.images[index]
            if self.transform:
                img = self.transform(img)
        else:
            img = self.read_image(file_path)

        if self.split == 'test':
            # kept for callers that set `split` manually; __init__ never selects it
            return img

        label = self.labels[os.path.basename(file_path)]
        if self.target_transform is not None:
            label = self.target_transform(label)
        return img, label

    def __repr__(self):
        fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
        fmt_str += '    Number of datapoints: {}\n'.format(self.__len__())
        tmp = self.split
        fmt_str += '    Split: {}\n'.format(tmp)
        fmt_str += '    Root Location: {}\n'.format(self.root)
        tmp = '    Transforms (if any): '
        fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        tmp = '    Target Transforms (if any): '
        fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
        return fmt_str

    @staticmethod
    def _load_image(path):
        """Decode the file at `path` into an RGB PIL image and close the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def read_image(self, path):
        """Load the image at `path` and apply `self.transform` if one is set."""
        img = self._load_image(path)
        return self.transform(img) if self.transform else img

#### Augmentation 관련 코드입니다.
- 다른 augmentation이 적용된 query와 key를 뽑기 위한 **TwoCropsTransform** 코드
- MoCo 방법론에서 사용했던 **GaussianBlur**(SimCLR paper에서 사용한 것을 차용) augmentation 코드 

In [3]:
# Augmentation
class TwoCropsTransform:
    """Apply the same base transform twice to one image to get a (query, key) pair."""

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Two independent calls: the first result is the query, the second the key.
        return [self.base_transform(x) for _ in range(2)]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Parameters
    ----------
    sigma: sequence of two floats
        Lower and upper bound of the blur radius; a radius is drawn uniformly
        from this range on every call.
    """

    def __init__(self, sigma=(.1, 2.)):
        # An immutable default avoids sharing one list object between instances.
        self.sigma = sigma

    def __call__(self, x):
        # `x` must provide `.filter`, as PIL images do.
        sigma = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=sigma))

#### Encoder 학습을 위해 선택한 데이터셋을 pytorch library의 DataLoader로 옮기는 코드입니다.

In [4]:
# Data loader
# Lookup tables keyed by dataset name: dataset class and per-channel normalization statistics.
DATASETS = {'cifar10': CIFAR10, 'cifar100': CIFAR100, 'tiny-imagenet': TinyImageNet}
MEAN = {'cifar10': [0.4914, 0.4822, 0.4465], 'cifar100': [0.5071, 0.4867, 0.4408], 'tiny-imagenet': [0.485, 0.456, 0.406]}
STD = {'cifar10': [0.2023, 0.1994, 0.2010], 'cifar100':[0.2675, 0.2565, 0.2761], 'tiny-imagenet': [0.229, 0.224, 0.225]}

def data_loader(dataset, data_path, batch_size, num_workers, download=False, distributed=True, aug_plus=True,
                image_size=224):
    """Build the training DataLoader that yields two augmented views per image.

    Parameters
    ----------
    dataset: str
        Key of DATASETS / MEAN / STD ('cifar10', 'cifar100' or 'tiny-imagenet').
    data_path: str
        Root directory handed to the dataset class.
    batch_size, num_workers: int
        Forwarded to `torch.utils.data.DataLoader`.
    download: bool
        Forwarded to the dataset class.
    distributed: bool
        If True a `DistributedSampler` is used (and shuffling is left to it).
    aug_plus: bool
        If True use the MoCo v2 style augmentation, otherwise the MoCo v1 style.
    image_size: int
        Output size of the random resized crop. Defaults to 224, the value
        that was previously hard-coded.

    Returns
    -------
    (train_loader, train_sampler); `train_sampler` is None when not distributed.
    """
    normalize = transforms.Normalize(MEAN[dataset], STD[dataset])

    # Every pipeline starts with the random resized crop.
    crop = transforms.RandomResizedCrop(image_size, scale=(0.2, 1.))

    if aug_plus:
        # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
        augmentation = [
            crop,
            transforms.RandomApply([
                transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
            ], p=0.8),
            transforms.RandomGrayscale(p=0.2),
            transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]
    else:
        # MoCo v1's aug: the same as InstDisc https://arxiv.org/abs/1805.01978
        augmentation = [
            crop,
            transforms.RandomGrayscale(p=0.2),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize
        ]

    train_transform = TwoCropsTransform(transforms.Compose(augmentation))
    train_dataset = DATASETS[dataset](data_path, train=True, download=download, transform=train_transform)

    # for distributed learning
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if distributed else None

    # shuffle and sampler are mutually exclusive in DataLoader, hence the conditional shuffle
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=(train_sampler is None),
        num_workers=num_workers, pin_memory=True, sampler=train_sampler, drop_last=True)

    return train_loader, train_sampler

### [Section 2] 모델 불러오기

#### MoCo 방법론을 기반으로 코드를 구현했고, MoCo나 MixCo에 관한 argument 설정을 통해 두 방법론을 사용하실 수 있습니다.
#### Encoder의 기본 model로는 ResNet을 사용하였습니다.


### <span style="color:red">다양한 모델 적용을 위해선, pytorch에 구현된 모델 코드들을 가져와 사용하시면 될 것 같습니다.</span>

In [5]:
# ResNet code
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """Return a bias-free 3x3 convolution whose padding equals its dilation."""
    options = dict(
        kernel_size=3,
        stride=stride,
        padding=dilation,
        groups=groups,
        bias=False,
        dilation=dilation,
    )
    return nn.Conv2d(in_planes, out_planes, **options)


def conv1x1(in_planes, out_planes, stride=1):
    """Return a bias-free 1x1 (pointwise) convolution."""
    pointwise = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
    return pointwise


class Block(nn.Module):
    """Residual block that acts as a basic (two 3x3 convs) or bottleneck (1x1-3x3-1x1) block.

    When `norm_layer` is given it is called as `norm_layer(channels, num_splits)`;
    when it is None, `nn.BatchNorm2d(channels)` is used instead.
    """
    __constants__ = ['downsample']

    def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                 base_width=64, dilation=1, norm_layer=None, num_splits=64, expansion=1, block_type='basic'):
        super().__init__()
        if block_type not in ('basic', 'bottleneck'):
            raise ValueError('Block_Type only supports basic and bottleneck')
        self.block_type = block_type
        self.expansion = expansion

        # A user-supplied norm layer is treated as a "split" norm taking num_splits.
        self.split = norm_layer is not None
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d

        is_basic = block_type == 'basic'
        if is_basic:
            if not (groups == 1 and base_width == 64):
                raise ValueError('BasicBlock only supports groups=1 and base_width=64')
            if dilation > 1:
                raise NotImplementedError("Dilation > 1 not supported in BasicBlock")

        def make_norm(channels):
            # one place to decide between the split and the plain constructor call
            if self.split:
                return norm_layer(channels, num_splits)
            return norm_layer(channels)

        width = int(planes * (base_width / 64.)) * groups
        # The strided 3x3 conv and self.downsample both reduce resolution when stride != 1.
        if is_basic:
            self.conv1 = conv3x3(inplanes, width, stride)
            self.conv2 = conv3x3(width, width)
        else:
            out_channels = planes * self.expansion
            self.conv1 = conv1x1(inplanes, width)
            self.conv2 = conv3x3(width, width, stride, groups, dilation)
            self.conv3 = conv1x1(width, out_channels)
            self.bn3 = make_norm(out_channels)

        self.bn1 = make_norm(width)
        self.bn2 = make_norm(width)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        if self.block_type == 'bottleneck':
            out = self.bn3(self.conv3(self.relu(out)))

        # project the shortcut only when a downsample module was provided
        shortcut = x if self.downsample is None else self.downsample(x)

        out += shortcut
        return self.relu(out)
    
    
class ResNet(nn.Module):
    """ResNet backbone built from `Block` modules.

    Parameters
    ----------
    arch: str
        Architecture name; selects basic blocks (expansion 1) or bottleneck
        blocks (expansion 4) according to the two class-level lists.
    repeats: list of 4 ints
        Number of blocks in each of the four stages.
    num_classes: int
        Output dimension of the final linear layer.
    zero_init_residual: bool
        If True, zero-initialize the last norm weight of every residual branch.
    groups, width_per_group: int
        Forwarded to every `Block` as `groups` / `base_width`.
    norm_layer: callable, optional
        If given, it is called as `norm_layer(channels, num_splits)`;
        if None, `nn.BatchNorm2d(channels)` is used.
    num_splits: int, optional
        Second argument for a custom `norm_layer`; ignored when `norm_layer` is None.

    Raises
    ------
    NotImplementedError
        If `arch` is in neither architecture list.
    """
    BasicBlock_arch = ['resnet10', 'resnet18', 'resnet34']
    Bottleneck_arch = ['resnet50', 'resnet101', 'resnet152', 'resnext50_32x4d', 'resnext101_32x8d', 
                      'wide_resnet50_2', 'wide_resnet101_2']

    def __init__(self, arch, repeats, num_classes=100, zero_init_residual=False,
                 groups=1, width_per_group=64, norm_layer=None, num_splits=None):
        super(ResNet, self).__init__()
        self.split = False if norm_layer is None else True
        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
        self.num_splits = num_splits
        self.inplanes = 64
        self.dilation = 1
        self.groups = groups
        self.base_width = width_per_group
        if arch in self.BasicBlock_arch:
            self.expansion = 1
            self.block_type = 'basic'
        elif arch in self.Bottleneck_arch:
            self.expansion = 4
            self.block_type = 'bottleneck'
        else:
            raise NotImplementedError('%s arch is not supported in ResNet' % arch)
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                                   bias=False)
        self.bn1 = self._norm_layer(self.inplanes) if not self.split else self._norm_layer(self.inplanes, self.num_splits)
        self.relu = nn.ReLU(inplace=True)
        # NOTE: maxpool is constructed here but conv_stem() below does not apply it.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)     
            
        planes = [64, 128, 256, 512]
        # self.planes attributes is needed to match with EP_module channels
        self.planes = [p * self.expansion for p in planes]
        strides = [1, 2, 2, 2]
        self.block_layers = self._make_layer(planes, repeats, strides)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(planes[-1] * self.expansion, num_classes)

        # weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Block):
                    if self.block_type == 'basic':
                        nn.init.constant_(m.bn2.weight, 0)
                    elif self.block_type == 'bottleneck':
                        nn.init.constant_(m.bn3.weight, 0)
                        
    def _make_layer(self, planes, repeats, strides):
        """Build the four residual stages and return them as one nn.Sequential."""
        assert len(planes) == len(repeats) == len(strides) == 4, 'Number of Block should be 4'
        
        block_layers = []
        norm_layer = self._norm_layer
        # Block treats any non-None norm_layer as a split norm and calls it with
        # (channels, num_splits). Passing plain nn.BatchNorm2d would therefore turn
        # num_splits into BatchNorm's `eps`, so hand over None in the non-split case.
        block_norm_layer = norm_layer if self.split else None
        for plane, repeat, stride in zip(planes, repeats, strides):
            out_channels = plane * self.expansion

            downsample = None
            if stride != 1 or self.inplanes != out_channels:
                # 1x1 projection so the shortcut matches the main branch's shape
                if not self.split:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, out_channels, stride),
                        norm_layer(out_channels),
                    )
                else:
                    downsample = nn.Sequential(
                        conv1x1(self.inplanes, out_channels, stride),
                        norm_layer(out_channels, self.num_splits),
                    )

            # only the first block of a stage carries the stride and the projection
            layers = [Block(self.inplanes, plane, stride, downsample, self.groups,
                            self.base_width, self.dilation, block_norm_layer, self.num_splits,
                            self.expansion, self.block_type)]
            self.inplanes = out_channels
            for _ in range(1, repeat):
                layers.append(Block(self.inplanes, plane, groups=self.groups,
                                    base_width=self.base_width, dilation=self.dilation,
                                    norm_layer=block_norm_layer, num_splits=self.num_splits,
                                    expansion=self.expansion, block_type=self.block_type))
            block_layers.append(nn.Sequential(*layers))
        
        return nn.Sequential(*block_layers)
    
    def conv_stem(self, x):
        """Apply the stem: 7x7 conv, norm and ReLU."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        
        return x
    
    def pool_linear(self, x):
        """Global-average-pool, flatten and apply the final linear layer."""
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        
        return x
        
    def forward(self, x):
        x = self.conv_stem(x)
        x = self.block_layers(x)
        x = self.pool_linear(x)

        return x
    
    
def _resnet(arch, repeats, **kwargs):
    """Instantiate `ResNet` for `arch` with the given per-stage block counts."""
    return ResNet(arch, repeats, **kwargs)


def resnet10(**kwargs):
    """Build a 'resnet10' model: one block in each of the four stages."""
    stage_repeats = [1, 1, 1, 1]
    return _resnet('resnet10', stage_repeats, **kwargs)


def resnet18(**kwargs):
    """Build a 'resnet18' model: two blocks in each of the four stages."""
    stage_repeats = [2, 2, 2, 2]
    return _resnet('resnet18', stage_repeats, **kwargs)


def resnet34(**kwargs):
    """Build a 'resnet34' model with stage depths 3-4-6-3."""
    stage_repeats = [3, 4, 6, 3]
    return _resnet('resnet34', stage_repeats, **kwargs)


def resnet50(**kwargs):
    """Build a 'resnet50' model with stage depths 3-4-6-3."""
    stage_repeats = [3, 4, 6, 3]
    return _resnet('resnet50', stage_repeats, **kwargs)


def resnet101(**kwargs):
    """Build a 'resnet101' model with stage depths 3-4-23-3."""
    stage_repeats = [3, 4, 23, 3]
    return _resnet('resnet101', stage_repeats, **kwargs)


def resnet152(**kwargs):
    """Build a 'resnet152' model with stage depths 3-8-36-3."""
    stage_repeats = [3, 8, 36, 3]
    return _resnet('resnet152', stage_repeats, **kwargs)


def resnet50_32x4d(**kwargs):
    """Build the 32x4d grouped variant of the 3-4-6-3 bottleneck network.

    `groups` and `width_per_group` are always overridden to 32 and 4.
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 4
    # ResNet lists this architecture under the name 'resnext50_32x4d';
    # any other name is rejected with NotImplementedError.
    return _resnet('resnext50_32x4d', [3, 4, 6, 3], **kwargs)


def resnet101_32x8d(**kwargs):
    r"""ResNeXt-101 32x8d model from
    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_

    `groups` and `width_per_group` are always overridden to 32 and 8.
    """
    kwargs['groups'] = 32
    kwargs['width_per_group'] = 8
    # ResNet lists this architecture under the name 'resnext101_32x8d';
    # any other name is rejected with NotImplementedError.
    return _resnet('resnext101_32x8d', [3, 4, 23, 3], **kwargs)


def wide_resnet50_2(**kwargs):
    r"""Wide ResNet-50-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    Same layout as the 3-4-6-3 bottleneck network, but `width_per_group` is
    forced to 128 (twice the default 64), which doubles the inner bottleneck
    width while the outer 1x1 convolutions keep their channel counts.
    """
    kwargs.update(width_per_group=64 * 2)
    stage_repeats = [3, 4, 6, 3]
    return _resnet('wide_resnet50_2', stage_repeats, **kwargs)


def wide_resnet101_2(**kwargs):
    r"""Wide ResNet-101-2 model from
    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_

    Same layout as the 3-4-23-3 bottleneck network, but `width_per_group` is
    forced to 128 (twice the default 64), which doubles the inner bottleneck
    width while the outer 1x1 convolutions keep their channel counts.
    """
    kwargs.update(width_per_group=64 * 2)
    stage_repeats = [3, 4, 23, 3]
    return _resnet('wide_resnet101_2', stage_repeats, **kwargs)

In [6]:
# Registry mapping an architecture name to its constructor.
# To support another architecture, add its constructor here.
ARCHITECTURE = {
    'resnet10': resnet10,
    'resnet18': resnet18,
    'resnet34': resnet34,
    'resnet50': resnet50,
}

#### MoCo 논문에서 multi-gpu 상황에서 성능 향상을 위해 Batch Shuffling이란 technique을 사용했습니다.
#### 하지만 Single-gpu에서는 shuffling 적용이 불가해, encoder에서 사용하는 BN을 다음의 SplitBatchNorm으로 바꿀 것을 권장하였습니다.

In [7]:
# SplitBatchNorm: Same effect with Batch Shuffling in MoCo
class SplitBatchNorm(nn.BatchNorm2d):
    """BatchNorm2d that, in training mode, normalizes `num_splits` sub-batches separately.

    The batch is viewed as (N / num_splits, C * num_splits, H, W), so each
    split gets its own statistics; the running statistics are then averaged
    over the splits. In eval mode it behaves like regular batch norm using the
    running statistics.
    """

    def __init__(self, num_features, num_splits, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits

    def forward(self, input):
        N, C, H, W = input.shape
        batch_norm = nn.functional.batch_norm
        splits = self.num_splits

        use_batch_stats = self.training or not self.track_running_stats
        if not use_batch_stats:
            return batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)

        # replicate the per-channel state once per split
        mean_per_split = self.running_mean.repeat(splits)
        var_per_split = self.running_var.repeat(splits)
        folded = input.view(-1, C * splits, H, W)
        normalized = batch_norm(
            folded, mean_per_split, var_per_split,
            self.weight.repeat(splits), self.bias.repeat(splits),
            True, self.momentum, self.eps)
        outcome = normalized.view(N, C, H, W)

        # batch_norm updated the replicated stats in place; collapse them back to C entries
        self.running_mean.data.copy_(mean_per_split.view(splits, C).mean(dim=0))
        self.running_var.data.copy_(var_per_split.view(splits, C).mean(dim=0))
        return outcome

#### Encoder class의 mixco argument를 통해, MoCo 방법론과 MixCo 방법론을 선택할 수 있습니다.

In [8]:
class Encoder(nn.Module):
    """
    Build a MoCo model with: a query encoder, a key encoder, and a queue
    https://arxiv.org/abs/1911.05722

    With algo='mixco' the forward pass additionally encodes mixed-up query
    images and returns soft-label logits for them.
    """
    SUPPORTED_ALGOS = ('moco', 'mixco')

    def __init__(self, base_encoder, algo, dim=128, num_splits=64, K=65536, m=0.999, T=0.2, mix_T=0.05, mlp=False, single_gpu=False):
        """
        base_encoder: callable building an encoder; called with num_classes, norm_layer and num_splits
        algo: 'moco' or 'mixco'
        dim: feature dimension (default: 128)
        num_splits: number of splits handed to the encoder's norm layers (default: 64)
        K: queue size; number of negative keys (default: 65536)
        m: moco momentum of updating key encoder (default: 0.999)
        T: softmax temperature for the standard logits (default: 0.2)
        mix_T: softmax temperature for the mixed logits (default: 0.05)
        mlp: if True, replace each encoder's fc by a 2-layer MLP head
        single_gpu: if True, use SplitBatchNorm and skip all distributed calls

        Raises ValueError if `algo` is not supported.
        """
        super(Encoder, self).__init__()

        # Fail at construction time instead of with an UnboundLocalError in forward().
        if algo not in self.SUPPORTED_ALGOS:
            raise ValueError("algo must be one of %s, got %r" % (self.SUPPORTED_ALGOS, algo))

        self.algo = algo
        self.single_gpu = single_gpu
        
        self.K = K
        self.m = m
        self.T = T
        self.mix_T = mix_T

        # create the encoders
        # num_classes is the output fc dimension
        
        norm_layer = SplitBatchNorm if single_gpu else None
        self.encoder_q = base_encoder(num_classes=dim, norm_layer=norm_layer, num_splits=num_splits)
        self.encoder_k = base_encoder(num_classes=dim, norm_layer=norm_layer, num_splits=num_splits)

        if mlp:  # hack: brute-force replacement
            dim_mlp = self.encoder_q.fc.weight.shape[1]
            self.encoder_q.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_q.fc)
            self.encoder_k.fc = nn.Sequential(nn.Linear(dim_mlp, dim_mlp), nn.ReLU(), self.encoder_k.fc)

        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data.copy_(param_q.data)  # initialize
            param_k.requires_grad = False  # not update by gradient

        # create the queue
        self.register_buffer("queue", torch.randn(dim, K))
        self.queue = nn.functional.normalize(self.queue, dim=0)

        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))

    @torch.no_grad()
    def concat_all_gather(self, tensor):
        """
        Performs all_gather operation on the provided tensors.
        *** Warning ***: torch.distributed.all_gather has no gradient.
        """
        tensors_gather = [torch.ones_like(tensor)
            for _ in range(torch.distributed.get_world_size())]
        torch.distributed.all_gather(tensors_gather, tensor, async_op=False)

        output = torch.cat(tensors_gather, dim=0)
        return output

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """
        Momentum update of the key encoder
        """
        for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, keys):
        """Overwrite the oldest queue columns with `keys` and advance the pointer."""
        # gather keys before updating queue
        if not self.single_gpu:
            keys = self.concat_all_gather(keys)

        batch_size = keys.shape[0]

        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0  # for simplicity

        # replace the keys at ptr (dequeue and enqueue)
        self.queue[:, ptr:ptr + batch_size] = keys.T
        ptr = (ptr + batch_size) % self.K  # move pointer

        self.queue_ptr[0] = ptr
        
    @torch.no_grad()
    def _batch_shuffle_ddp(self, x):
        """
        Batch shuffle, for making use of BatchNorm.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = self.concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # random shuffle index
        idx_shuffle = torch.randperm(batch_size_all).cuda()

        # broadcast to all gpus
        torch.distributed.broadcast(idx_shuffle, src=0)

        # index for restoring
        idx_unshuffle = torch.argsort(idx_shuffle)

        # shuffled index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_shuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this], idx_unshuffle
    
    @torch.no_grad()
    def _batch_unshuffle_ddp(self, x, idx_unshuffle):
        """
        Undo batch shuffle.
        *** Only support DistributedDataParallel (DDP) model. ***
        """
        # gather from all gpus
        batch_size_this = x.shape[0]
        x_gather = self.concat_all_gather(x)
        batch_size_all = x_gather.shape[0]

        num_gpus = batch_size_all // batch_size_this

        # restored index for this gpu
        gpu_idx = torch.distributed.get_rank()
        idx_this = idx_unshuffle.view(num_gpus, -1)[gpu_idx]

        return x_gather[idx_this]
    
    @torch.no_grad()
    def img_mixer(self, im_q):
        """Mix the first half of the batch with the second half.

        Returns (imgs_mix, lbls_mix): B/2 mixed images and a (B/2, B) soft-label
        matrix holding lambda at column i and (1 - lambda) at column i + B/2.
        The batch size B must be even.
        """
        B = im_q.size(0)
        assert B % 2 == 0
        sid = B // 2
        im_q1, im_q2 = im_q[:sid], im_q[sid:]
        
        # each image get different lambda
        lam = torch.from_numpy(np.random.uniform(0, 1, size=(sid,1,1,1))).float().to(im_q.device)
        imgs_mix = lam * im_q1 + (1-lam) * im_q2
        # view(-1) keeps a 1-D vector even when sid == 1 (squeeze() would yield a
        # 0-dim tensor, which torch.diag rejects)
        lam_vec = lam.view(-1)
        lbls_mix = torch.cat((torch.diag(lam_vec), torch.diag(1 - lam_vec)), dim=1)
        
        return imgs_mix, lbls_mix
    
    def forward(self, im_q, im_k):
        """
        Input:
            im_q: a batch of query images
            im_k: a batch of key images
        Output:
            'moco':  logits, labels
            'mixco': logits, labels, logits_mix, lbls_mix
        """

        if self.algo == 'moco':
            q = self.encoder_q(im_q) # queries: NxC
            q = nn.functional.normalize(q, dim=1)
            
        elif self.algo == 'mixco':
            imgs_mix, lbls_mix = self.img_mixer(im_q)
            # compute query features
            q = self.encoder_q(torch.cat((im_q, imgs_mix))) # queries: NxC
            q = nn.functional.normalize(q, dim=1)

            q_mix = q[im_q.size(0):]
            q = q[:im_q.size(0)]

        # compute key features
        with torch.no_grad():  # no gradient to keys
            self._momentum_update_key_encoder()  # update the key encoder

            # shuffle for making use of BN
            if not self.single_gpu:
                im_k, idx_unshuffle = self._batch_shuffle_ddp(im_k)

            k = self.encoder_k(im_k)  # keys: NxC
            k = nn.functional.normalize(k, dim=1)

            # undo shuffle
            if not self.single_gpu:
                k = self._batch_unshuffle_ddp(k, idx_unshuffle)
            

        # compute logits
        # Einstein sum is more intuitive
        # positive logits: Nx1
        l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
        # negative logits: NxK
        l_neg = torch.einsum('nc,ck->nk', [q, self.queue.clone().detach()])

        # logits: Nx(1+K)
        logits = torch.cat([l_pos, l_neg], dim=1)

        # labels: positive key indicators (column 0), created on the logits' device
        # so the module is not tied to CUDA
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        
        if self.algo == 'moco':
            # apply temperature
            logits /= self.T

            # dequeue and enqueue
            self._dequeue_and_enqueue(k)

            return logits, labels
            
        elif self.algo == 'mixco':
            # mixed logits: N/2 x N
            logits_mix_pos = torch.mm(q_mix, k.transpose(0, 1)) 
            # mixed negative logits: N/2 x K
            logits_mix_neg = torch.mm(q_mix, self.queue.clone().detach())
            logits_mix = torch.cat([logits_mix_pos, logits_mix_neg], dim=1) # N/2 x (N+K)
            lbls_mix = torch.cat([lbls_mix, torch.zeros_like(logits_mix_neg)], dim=1)

            # apply temperature
            logits /= self.T
            logits_mix /= self.mix_T

            # dequeue and enqueue
            self._dequeue_and_enqueue(k)

            return logits, labels, logits_mix, lbls_mix

### [Section 3] 손실함수 정의하기

#### Section 2의 Encoder가 반환해준 logit과 label에 대해서, contrastive 손실 함수를 통해 모델을 학습할 것입니다.
#### 기존의 MoCo에서 사용하는 contrastive loss는 다음과 같습니다.

\begin{equation*}
L_{MoCo} = -\sum_{i=1}^{n} \log\left(\frac{\exp\left(\frac{v_i \cdot v_{i}^{'}}{\tau}\right)}{\sum_{j=0}^{r} \exp\left(\frac{v_i \cdot v_{j}^{'}}{\tau}\right)}\right)
\end{equation*}

#### MixCo 방법론의 경우, 다음의 손실함수를 통해 학습합니다.

\begin{equation*}
L_{MixCo} = L_{MoCo} - \gamma \sum_{i=1}^{n} \left[ \lambda_{i} \log\left(\frac{\exp\left(\frac{v_{i}^{mix_{ij}} \cdot v_{i}^{'}}{\tau_{mix}}\right)}{\sum_{k=0}^{r} \exp\left(\frac{v_{i}^{mix_{ij}} \cdot v_{k}^{'}}{\tau_{mix}}\right)}\right) + (1 - \lambda_{i}) \log\left(\frac{\exp\left(\frac{v_{i}^{mix_{ij}} \cdot v_{j}^{'}}{\tau_{mix}}\right)}{\sum_{k=0}^{r} \exp\left(\frac{v_{i}^{mix_{ij}} \cdot v_{k}^{'}}{\tau_{mix}}\right)}\right) \right]
\end{equation*}

#### i번째 이미지와 j번째 이미지의 mixup으로 얻은 representation에 대한 contrastive 손실함수를 기존의 MoCo 손실함수에 더해주는 방법으로 학습합니다.

In [9]:
class SoftCrossEntropy(nn.Module):
    """Cross entropy between logits and a soft (probability-vector) target.

    forward(logits, target) returns the batch mean of
    `-(target * log_softmax(logits, dim=1)).sum(dim=1)`.
    """
    def __init__(self):
        super(SoftCrossEntropy, self).__init__()
        
    def forward(self, logits, target):
        # log_softmax stays finite where log(softmax(x)) underflows to -inf,
        # which would turn `0 * -inf` into NaN for zero-weight classes.
        log_probs = F.log_softmax(logits, 1)
        nll_loss = (- target * log_probs).sum(1).mean()

        return nll_loss

    
class MixcoLoss(nn.Module):
    """Contrastive loss for MoCo, optionally extended with the MixCo soft-label term.

    gamma: weight of the mixed (soft-label) loss. When it is falsy (0 / None)
    only the standard cross entropy on (logits, labels) is computed.
    """
    def __init__(self, gamma):
        super(MixcoLoss, self).__init__()
        self.loss_fn = nn.CrossEntropyLoss()
        self.soft_loss = SoftCrossEntropy()
        self.gamma = gamma

    def forward(self, outputs):
        """Compute the loss from the encoder outputs.

        outputs: (logits, labels) or (logits, labels, logits_mix, lbls_mix).
        The 4-tuple form is required when gamma is non-zero.
        """
        if not self.gamma:
            # Only the first two entries are needed, so a 4-tuple from a mixco
            # encoder is also accepted when the mix term is switched off.
            logits, labels = outputs[:2]
            return self.loss_fn(logits, labels)

        logits, labels, logits_mix, lbls_mix = outputs
        loss = self.loss_fn(logits, labels)
        loss += self.gamma * self.soft_loss(logits_mix, lbls_mix)
        
        return loss  

### [Section 4] 학습 함수 구현하기

#### 먼저 학습 과정에서 사용될 seed 고정, learning_rate 조절, checkpoint 저장 등의 utils 함수들을 구현하였습니다. 
#### 그 다음 encoder의 한 epoch 단위의 학습 함수를 구현하였습니다.

In [10]:
#utils
def fix_seed(seed):
    """Seed numpy, random and torch, and configure cuDNN for reproducible runs.

    Emits a warning because the deterministic setting can slow training down.
    """
    # fix seed for reproducibility
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    cudnn.deterministic = True
    # The auto-tuner may pick different convolution algorithms from run to run,
    # which defeats the purpose of seeding, so it is switched off here.
    cudnn.benchmark = False
    warnings.warn('You have chosen to seed training. '
                  'This will turn on the CUDNN deterministic setting, '
                  'which can slow down your training considerably! '
                  'You may see unexpected behavior when restarting '
                  'from checkpoints.')


def save_checkpoint(state, is_best, filename='test'):
    """Serialize `state` with torch.save to ./results/pretrained/<filename>.

    The target directory is created if it does not exist yet.
    `is_best` is accepted for interface compatibility but is not used.
    """
    target_dir = './results/pretrained'
    os.makedirs(target_dir, exist_ok=True)  # torch.save does not create directories
    filename = os.path.join(target_dir, filename)
    torch.save(state, filename)
    
    
class AverageMeter(object):
    """Track the latest value and the running (weighted) average of a metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear the latest value, the accumulated sum, the count and the average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, counted `n` times, and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Yields e.g. '{name} {val:6.3f} ({avg:6.3f})', then fills it from the attributes.
        template = '{{name}} {{val{0}}} ({{avg{0}}})'.format(self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print one tab-separated progress line made of a batch counter and several meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the '[batch/total]' counter and the string form of each meter."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        columns = [header]
        columns.extend(str(meter) for meter in self.meters)
        print('\t'.join(columns))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the running index to the width of the total, e.g. '[{:3d}/100]'.
        width = len(str(num_batches // 1))
        slot = '{:%dd}' % width
        return '[%s/%s]' % (slot, slot.format(num_batches))


def adjust_learning_rate(optimizer, epoch, lr, cos, num_epochs, schedule):
    """Set the learning rate of every param group for the given epoch.

    With `cos` true, the base `lr` follows a half-cosine over `num_epochs`;
    otherwise it is multiplied by 0.1 for every milestone in `schedule` that
    `epoch` has reached.
    """
    new_lr = lr
    if cos:  # cosine lr schedule
        new_lr *= 0.5 * (1. + math.cos(math.pi * epoch / num_epochs))
    else:  # stepwise lr schedule
        for milestone in schedule:
            factor = 0.1 if epoch >= milestone else 1.
            new_lr *= factor

    for group in optimizer.param_groups:
        group['lr'] = new_lr
        
        
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy, in percent, for every k in ``topk``.

    Args:
        output: Score tensor; the top-k indices are taken along dim 1
            (assumed to be shaped (batch, num_classes)).
        target: Tensor of target indices; its first dimension is used as the
            batch size.
        topk: Iterable of k values to evaluate.

    Returns:
        A list with one single-element tensor per requested k, holding the
        percentage of samples whose target is among the k highest scores.
    """
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        # pred: (maxk, batch) after the transpose; row i is the i-th best guess
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))

        res = []
        for k in topk:
            # ``correct`` can inherit the non-contiguous layout of the
            # transposed ``pred``, in which case ``view(-1)`` raises for k > 1;
            # ``reshape`` copies only when a view is impossible.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res
    
    
def update_json(exp_name, part='pretrain', acc=(0, 0), path='./results/results.json'):
    """Store a pair of accuracies for one experiment in a JSON results file.

    The file maps ``exp_name -> part -> [acc0, acc1]``. Entries that are
    already present for other experiments or parts are preserved.

    Args:
        exp_name: Experiment name used as the top-level key.
        part: Sub-key under the experiment (e.g. 'pretrain').
        acc: Two numbers; each is rounded to 3 decimals before being stored.
        path: Location of the JSON file. It and its parent directory are
            created when missing.
    """
    acc = [round(acc[0], 3), round(acc[1], 3)]

    # Make sure the directory holding the results file exists.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    # A missing file simply means no results have been recorded yet.
    try:
        with open(path, 'r') as f:
            result_dict = json.load(f)
    except FileNotFoundError:
        result_dict = {}

    result_dict.setdefault(exp_name, {})[part] = acc

    with open(path, 'w') as f:
        json.dump(result_dict, f)

    print('results updated to %s' % path)

In [11]:
def train(train_loader, model, optimizer, criterion, epoch, print_freq, gpu):
    """Run one training epoch and return the averaged Acc@1 and Acc@5.

    Args:
        train_loader: Iterable yielding ``(images, _)`` where ``images`` holds
            the two views passed to the model as ``im_q`` and ``im_k``.
        model: Called as ``model(im_q=..., im_k=...)``.
        optimizer: Optimizer stepped once per batch.
        criterion: Called on the model outputs to obtain the loss.
        epoch: Epoch number shown in the progress prefix.
        print_freq: A progress line is printed every ``print_freq`` batches.
        gpu: CUDA device index for the inputs, or None to leave them in place.

    Returns:
        Tuple ``(top1.avg, top5.avg)`` accumulated over the epoch.
    """
    time_meter = AverageMeter('Time', ':6.3f')
    load_meter = AverageMeter('Data', ':6.3f')
    loss_meter = AverageMeter('Loss', ':.4e')
    acc1_meter = AverageMeter('Acc@1', ':6.2f')
    acc5_meter = AverageMeter('Acc@5', ':6.2f')
    progress = ProgressMeter(
        len(train_loader),
        [time_meter, load_meter, loss_meter, acc1_meter, acc5_meter],
        prefix="Epoch: [{}]".format(epoch))

    # put the model into training mode
    model.train()

    tick = time.time()
    for step, (views, _) in enumerate(train_loader):
        # time spent waiting for the loader
        load_meter.update(time.time() - tick)

        if gpu is not None:
            for idx in (0, 1):
                views[idx] = views[idx].cuda(gpu, non_blocking=True)

        # forward pass
        outputs = model(im_q=views[0], im_k=views[1])
        loss = criterion(outputs)

        # acc1/acc5 are (K+1)-way contrast classifier accuracy;
        # outputs[0] / outputs[1] are used as scores / targets
        n_samples = views[0].size(0)
        batch_acc1, batch_acc5 = accuracy(outputs[0], outputs[1], topk=(1, 5))
        loss_meter.update(loss.item(), n_samples)
        acc1_meter.update(batch_acc1[0], n_samples)
        acc5_meter.update(batch_acc5[0], n_samples)

        # backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # wall-clock time for the whole iteration
        now = time.time()
        time_meter.update(now - tick)
        tick = time.time()

        if step % print_freq == 0:
            progress.display(step)

    return acc1_meter.avg, acc5_meter.avg

### [Section 5] 분산학습 환경 설정하기

#### 이전 Section들에서 정의했던 함수와 class들을 이용해, 전체적인 main_worker 함수를 구현하였습니다. 
#### pytorch에서 제공하는 분산환경 구축 코드를 참조하여 'main' 함수를 구현하였습니다. 'func' argument로는 'main_worker' function을 넣으면 됩니다.
#### 참고로, 분산환경을 사용할 때는 데이터를 불러오는 과정에서 DistributedSampler를 사용해야합니다. (Section 1 참고)

In [12]:
def main_worker(gpu, ngpus_per_node, exp_name, distributed_kwargs, algo, arch, arch_kwargs, train_kwargs, data_kwargs):
    """Build the encoder, pretrain it, then save the checkpoint and metrics.

    Args:
        gpu: CUDA device index used by this process, or None.
        ngpus_per_node: Number of GPUs on this node.
        exp_name: Experiment name; used for the checkpoint file name and as
            the key in the JSON results file.
        distributed_kwargs: Dict read for 'multiprocessing_distributed',
            'distributed', 'dist_url', 'rank', 'dist_backend', 'world_size'.
        algo: Forwarded to ``Encoder``.
        arch: Key into ``ARCHITECTURE`` selecting the backbone.
        arch_kwargs: Extra keyword arguments for ``Encoder``.
        train_kwargs: Dict read for 'opt_kwargs', 'gamma', 'num_epochs',
            'cos', 'schedule' and 'print_freq'.
        data_kwargs: Keyword arguments for ``data_loader``. In the distributed
            one-GPU-per-process case 'batch_size' and 'num_workers' are
            divided by ``ngpus_per_node`` in place.
    """
    multiprocessing_distributed = distributed_kwargs['multiprocessing_distributed']

    # suppress printing if not master
    if multiprocessing_distributed and gpu != 0:
        def print_pass(*args, **kwargs):
            pass
        builtins.print = print_pass

    if gpu is not None:
        print(gpu)
        print("Use GPU: {} for training".format(gpu))

    # ``rank`` must always be bound: it is needed by init_process_group and by
    # the master check at the end. Previously it was only assigned in the
    # env:// branch, so every other distributed setup raised NameError.
    rank = 0
    if distributed_kwargs['distributed']:
        rank = distributed_kwargs['rank']
        if distributed_kwargs['dist_url'] == "env://" and rank == -1:
            rank = int(os.environ["RANK"])
        if multiprocessing_distributed:
            # For multiprocessing distributed training, rank needs to be the
            # global rank among all the processes
            rank = rank * ngpus_per_node + gpu
        dist.init_process_group(backend=distributed_kwargs['dist_backend'],
                                init_method=distributed_kwargs['dist_url'],
                                world_size=distributed_kwargs['world_size'],
                                rank=rank)
    # create model
    print("=> creating model '{}'".format(arch))

    model = Encoder(ARCHITECTURE[arch], algo, **arch_kwargs)
    print(model)

    if distributed_kwargs['distributed']:
        # For multiprocessing distributed, DistributedDataParallel constructor
        # should always set the single device scope, otherwise,
        # DistributedDataParallel will use all available devices.
        if gpu is not None:
            torch.cuda.set_device(gpu)
            model.cuda(gpu)
            # When using a single GPU per process and per
            # DistributedDataParallel, we need to divide the batch size
            # ourselves based on the total number of GPUs we have
            data_kwargs['batch_size'] = int(data_kwargs['batch_size'] / ngpus_per_node)
            data_kwargs['num_workers'] = int((data_kwargs['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)
            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
        else:
            model.cuda()
            # DistributedDataParallel will divide and allocate batch_size to all
            # available GPUs if device_ids are not set
            model = torch.nn.parallel.DistributedDataParallel(model)
    elif gpu is not None:
        torch.cuda.set_device(gpu)
        model = model.cuda(gpu)

    # define loss function (criterion) and optimizer
    optimizer = torch.optim.SGD(model.parameters(), **train_kwargs['opt_kwargs'])
    criterion = MixcoLoss(train_kwargs['gamma']).cuda(gpu)

    # get train_loader
    train_loader, train_sampler = data_loader(**data_kwargs)

    num_epochs = train_kwargs['num_epochs']
    acc1 = acc5 = None
    for epoch in range(num_epochs):
        if distributed_kwargs['distributed']:
            # reseed the sampler's shuffling for this epoch
            train_sampler.set_epoch(epoch)
        adjust_learning_rate(optimizer, epoch, train_kwargs['opt_kwargs']['lr'], train_kwargs['cos'], num_epochs, train_kwargs['schedule'])

        # train for one epoch
        acc1, acc5 = train(train_loader, model, optimizer, criterion, epoch+1, train_kwargs['print_freq'], gpu)

    # Always save at the end of training. With multiprocessing only the first
    # process of each node writes, so workers do not clobber the same file.
    if not multiprocessing_distributed or rank % ngpus_per_node == 0:
        save_checkpoint({
            'epoch': num_epochs,
            'arch': arch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, is_best=False, filename='{}.pth.tar'.format(exp_name))

        # No accuracies exist when no epoch was run; skip the results entry.
        if acc1 is not None:
            update_json(exp_name, 'pretrain', [float(acc1), float(acc5)])

In [13]:
def main(func, exp_name, distributed_kwargs, algo, arch, arch_kwargs, train_kwargs, data_kwargs):
    """Prepare the (distributed) environment and launch the worker ``func``.

    Args:
        func: Worker callable invoked as ``func(gpu, ngpus_per_node, exp_name,
            distributed_kwargs, algo, arch, arch_kwargs, train_kwargs,
            data_kwargs)``; spawned once per GPU when
            ``distributed_kwargs['multiprocessing_distributed']`` is set.
        exp_name: Experiment name forwarded to ``func``.
        distributed_kwargs: Dict with 'gpu' (list of device ids), 'dist_url',
            'world_size' and 'multiprocessing_distributed'. It is updated in
            place: 'gpu' becomes None (several GPUs) or [0] (otherwise), and
            'distributed' / 'ngpus_per_node' are filled in.
        algo, arch, arch_kwargs, train_kwargs: Forwarded to ``func`` unchanged.
        data_kwargs: Forwarded to ``func``; 'distributed' is set here.
    """
    os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(str(gpu) for gpu in distributed_kwargs['gpu'])
    # After masking the devices, a single requested GPU is always index 0;
    # several GPUs mean "no specific device".
    if len(distributed_kwargs['gpu']) > 1:
        distributed_kwargs['gpu'] = None
    else:
        distributed_kwargs['gpu'] = [0]

    if distributed_kwargs['gpu'] is not None:
        warnings.warn('You have chosen a specific GPU. This will completely '
                      'disable data parallelism.')

    if distributed_kwargs['dist_url'] == "env://" and distributed_kwargs['world_size'] == -1:
        distributed_kwargs['world_size'] = int(os.environ["WORLD_SIZE"])

    distributed_kwargs['distributed'] = distributed_kwargs['world_size'] > 1 or distributed_kwargs['multiprocessing_distributed']
    distributed_kwargs['ngpus_per_node'] = torch.cuda.device_count()

    data_kwargs['distributed'] = distributed_kwargs['distributed']

    worker_args = (distributed_kwargs['ngpus_per_node'], exp_name, distributed_kwargs,
                   algo, arch, arch_kwargs, train_kwargs, data_kwargs)

    if distributed_kwargs['multiprocessing_distributed']:
        # Since we have ngpus_per_node processes per node, the total world_size
        # needs to be adjusted accordingly
        distributed_kwargs['world_size'] *= distributed_kwargs['ngpus_per_node']
        # Use torch.multiprocessing.spawn to launch distributed processes: the
        # main_worker process function
        mp.spawn(func, nprocs=distributed_kwargs['ngpus_per_node'], args=worker_args)
    else:
        # Simply call the worker. 'gpu' is None when several GPUs were
        # requested, so it must not be indexed unconditionally.
        gpus = distributed_kwargs['gpu']
        gpu = gpus[0] if gpus is not None else None
        func(gpu, *worker_args)

### [Section 6] Encoder 학습하기

#### ResNet encoder 모델을 라벨이 없는 Tiny-ImageNet 데이터를 이용해 학습해볼 것입니다.
### <span style="color:red">원하는 실험 셋팅에 필요한 argument를 정의하시면 됩니다.</span>
- 아래 세팅은 single-gpu 상황에서, mixco 방법론의 실험입니다.
- **moco** 방법론으로 실험하려면, **train_kwargs['gamma'] = 0.0** 으로 설정하시면 됩니다.
- **multiple-gpu** 상황에서의 실험을 원하시면, **distributed_kwargs['multiprocessing_distributed'] = True** 와 **distributed_kwargs['gpu'] = [gpu_number1, gpu_number2, ...]** 로 설정하시면 됩니다.
- <span style="color:red">하지만, jupyter notebook에서는 multi-gpu를 위한 torch.multiprocessing.spawn을 사용할 수 없어 여기서는 불가능합니다.</span>

In [14]:
# make directories
!mkdir -p './results/pretrained'

# experiment identity
seed = 0
exp_name = 'mixco_resnet10'

# keyword arguments forwarded to the data loader
data_kwargs = {
    'dataset': 'tiny-imagenet',
    'data_path': './data/tiny-imagenet-200',
    'aug_plus': True,
    'batch_size': 128,
    'num_workers': 32,
    'download': False,
}

# process / device layout (single process on GPU 0 here)
distributed_kwargs = {
    'multiprocessing_distributed': False,
    'dist_url': 'tcp://localhost:13311',
    'world_size': 1,
    'rank': 0,
    'dist_backend': 'nccl',
    'gpu': [0],
}

# encoder definition
algo = 'mixco'
arch = 'resnet10'
arch_kwargs = {
    'dim': 128,
    'K': 65536,
    'm': 0.999,
    'T': 0.2,
    'mix_T': 0.05,
    'mlp': True,
    # single-GPU mode unless more than one device id was requested
    'single_gpu': not len(distributed_kwargs['gpu']) > 1,
}
# in single-GPU mode the number of splits is half the batch size
if arch_kwargs['single_gpu']:
    arch_kwargs['num_splits'] = int(data_kwargs['batch_size'] / 2)
else:
    arch_kwargs['num_splits'] = None

# optimisation schedule
train_kwargs = {
    'print_freq': 10,
    'gamma': 1.0,
    'num_epochs': 100,
    'schedule': [60, 80],
    'cos': True,
    'opt_kwargs': {
        'lr': 0.015,
        'momentum': 0.9,
        'weight_decay': 1e-4,
    },
}

In [15]:
# Call fix_seed (defined earlier in the notebook) with the configured seed,
# then start pretraining through main(), using main_worker as the worker.
fix_seed(seed)
main(main_worker, exp_name, distributed_kwargs, algo, arch, arch_kwargs, train_kwargs, data_kwargs)

0
Use GPU: 0 for training
=> creating model 'resnet10'


  if __name__ == '__main__':
  if __name__ == '__main__':


Encoder(
  (encoder_q): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (block_layers): Sequential(
      (0): Sequential(
        (0): Block(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (bn2): SplitBatchNorm(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
        )
      )
      (1): Sequential(
        (0): Block(
          (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False

KeyboardInterrupt: 