<a href="https://colab.research.google.com/github/KyuhyoJeon/BYOL/blob/master/BYOL.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [14]:
### Google drive mount ###
# from google.colab import drive 
# drive.mount('/content/gdrive/')
### ------------------------------------------ ###

In [15]:
### Arguments define ###
import easydict
import os
from datetime import datetime

args = easydict.EasyDict({
    'image_size':32, # original = 224
    'learning_rate':0.2, # original lr = 0.2, others = 0.3 or 3e-4
    'momentum':0, 
    'weight_decay':1.5e-6, 
    'batch_size':1024, 
    'num_epochs':1000, 
    'resnet_version':'resnet18', # original = resnet50
    'optim':'lars', # 'lars', 'adam', 'sgd' 
    'checkpoint_epochs':10, 
    # ********************MUST CHECK********************** #
    'dataset_dir':'/content/gdrive/MyDrive/Colab Notebooks/datasets', # dataset directory
    'ckpt_dir':'/content/gdrive/MyDrive/Colab Notebooks/byol/ckpt',   # Network checkpoint directory
    'num_workers':8, 
    'nodes':1, 
    'gpus':1, 
    'nr':0, 
    'device':'cuda', 
    'eval':True, 
    'eval_epochs':30, 
    # ********************MUST CHECK********************** #
    'dryrun':True, # check line 47~53
    'debug':True, # check line 56~62
    'current_epochs':0
})

# ********************MUST CHECK********************** #
# dryrun setting
if args.dryrun:
  args.image_size=32
  args.num_epochs = 100
  args.batch_size = 256
  args.num_workers = 4
  args.dryrun_subset_size = 100
  args.resnet_version = 'resnet18'

# ********************MUST CHECK********************** #
# debug setting
if args.debug:
  args.image_size=32
  args.num_epochs = 1
  args.batch_size = 2
  args.num_workers = 0
  args.debug_subset_size = 8
  args.resnet_version = 'resnet18'
  

# make check point directory ex: "ckpt_dir/resnet18/lars/021805"
tmp_dir = os.path.join(args.ckpt_dir, f"{args.resnet_version}", f"{args.optim}", f"{datetime.now().strftime('%m%d%H')}")
if not os.path.exists(tmp_dir):
  os.makedirs(tmp_dir)
### ------------------------------------------ ###

In [16]:
### Image augmentation define
import torch
import torchvision
from torchvision import datasets, transforms

class simclr_transform:
  # augmentations: 
  # random patch, 224 resize, random hrizontal flip, color distortion, 
  # random swquence brightness, contrast, saturation, hue adjustment, 
  # and optional gray scale conversion, Gaussian blur, solarization
  imagenet_mean_std = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]
  def __init__(self, size, mean_std=imagenet_mean_std, s=1.0):
    self.transform = transforms.Compose(
        [
        transforms.RandomSizedCrop(size=size), 
        transforms.RandomHorizontalFlip(), 
        transforms.RandomApply(
            [transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s)], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomApply(
            [transforms.GaussianBlur(kernel_size=size//20*2+1, sigma=(0.1, 2.0))], p=0.5), 
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
        ]
    )
  def __call__(self, x):
    x1 = self.transform(x)
    x2 = self.transform(x)
    return x1, x2

from PIL import Image

class Transform_single():
  imagenet_mean_std = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]
  def __init__(self, size, train, normalize=imagenet_mean_std):
    if train == True:
      self.transform = transforms.Compose(
          [
           transforms.RandomResizedCrop(size, scale=(0.08, 1.0), 
                                        ratio=(3.0/4.0,4.0/3.0), 
                                        interpolation=Image.BICUBIC
                                        ),
           transforms.RandomHorizontalFlip(),
           transforms.ToTensor(),
           transforms.Normalize(*normalize)
          ]
      )
    else:
      self.transform = transforms.Compose(
          [
           transforms.Resize(int(size*(8/7)), 
                             interpolation=Image.BICUBIC
                             ), # 224 -> 256 
           transforms.CenterCrop(size),
           transforms.ToTensor(),
           transforms.Normalize(*normalize)
          ]
      )

  def __call__(self, x):
    return self.transform(x)
### ------------------------------------------ ###

In [17]:
### dataset load ###
cifar_train = datasets.CIFAR10(
    root=args.dataset_dir, 
    train=True, 
    transform=simclr_transform(args.image_size), 
    download=True
)

cifar_memory = datasets.CIFAR10(
    root=args.dataset_dir, 
    train=True, 
    download=False, 
    transform=Transform_single(size=args.image_size, train=False), 
    )

cifar_test = datasets.CIFAR10(
    root=args.dataset_dir, 
    train=False, 
    download=False, 
    transform=Transform_single(size=args.image_size, train=False), 
)

if args.debug:
  cifar_train = torch.utils.data.Subset(cifar_train, range(0, args.debug_subset_size))
  cifar_train.classes = cifar_train.dataset.classes
  cifar_train.targets = cifar_train.dataset.targets
  cifar_memory = torch.utils.data.Subset(cifar_memory, range(0, args.debug_subset_size))
  cifar_memory.classes = cifar_memory.dataset.classes
  cifar_memory.targets = cifar_memory.dataset.targets
  cifar_test = torch.utils.data.Subset(cifar_test, range(0, args.debug_subset_size))
  cifar_test.classes = cifar_test.dataset.classes
  cifar_test.targets = cifar_test.dataset.targets
# elif args.dryrun:
#   cifar_train = torch.utils.data.Subset(cifar_train, range(0, args.dryrun_subset_size))
#   cifar_train.classes = cifar_train.dataset.classes
#   cifar_train.targets = cifar_train.dataset.targets

train_loader = torch.utils.data.DataLoader(
    cifar_train, 
    batch_size=args.batch_size, 
    shuffle=True, 
    num_workers=args.num_workers, 
    drop_last=True, 
    pin_memory=True
)

memory_loader = torch.utils.data.DataLoader(
    cifar_memory, 
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    pin_memory=True,
    )

test_loader = torch.utils.data.DataLoader(
    cifar_test, 
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=True,
    pin_memory=True,
)
### ------------------------------------------ ###

  "please use transforms.RandomResizedCrop instead.")


Files already downloaded and verified


In [18]:
### dataset load check ###
import matplotlib.pyplot as plt
import numpy as np

# def imshow(img):
#   img = img / 2 + 0.5     # unnormalize
#   npimg = img.numpy()
#   plt.imshow(np.transpose(npimg, (1, 2, 0)))
#   plt.show()

# dataiter = iter(train_loader)
# (images1, images2), labels = dataiter.next()

# imshow(torchvision.utils.make_grid(images1))
# imshow(torchvision.utils.make_grid(images2))
# print(' '.join('%5s' % train_loader.dataset.classes[labels[j]] for j in range(len(labels))))
### ------------------------------------------ ###

In [19]:
### ResNet18 for CIFAR10 define ###
import torch
import torch.nn as nn
import os
# https://raw.githubusercontent.com/huyvnphan/PyTorch_CIFAR10/master/cifar10_models/resnet.py
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d']

def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
  """3x3 convolution with padding"""
  return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                    padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
  """1x1 convolution"""
  return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class BasicBlock(nn.Module):
  expansion = 1

  def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                base_width=64, dilation=1, norm_layer=None):
    super(BasicBlock, self).__init__()
    if norm_layer is None:
      norm_layer = nn.BatchNorm2d
    if groups != 1 or base_width != 64:
      raise ValueError('BasicBlock only supports groups=1 and base_width=64')
    if dilation > 1:
      raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
    # Both self.conv1 and self.downsample layers downsample the input when stride != 1
    self.conv1 = conv3x3(inplanes, planes, stride)
    self.bn1 = norm_layer(planes)
    self.relu = nn.ReLU(inplace=True)
    self.conv2 = conv3x3(planes, planes)
    self.bn2 = norm_layer(planes)
    self.downsample = downsample
    self.stride = stride

  def forward(self, x):
    identity = x

    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu(out)

    out = self.conv2(out)
    out = self.bn2(out)

    if self.downsample is not None:
        identity = self.downsample(x)

    out += identity
    out = self.relu(out)

    return out


class Bottleneck(nn.Module):
  expansion = 4

  def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
                base_width=64, dilation=1, norm_layer=None):
    super(Bottleneck, self).__init__()
    if norm_layer is None:
        norm_layer = nn.BatchNorm2d
    width = int(planes * (base_width / 64.)) * groups
    # Both self.conv2 and self.downsample layers downsample the input when stride != 1
    self.conv1 = conv1x1(inplanes, width)
    self.bn1 = norm_layer(width)
    self.conv2 = conv3x3(width, width, stride, groups, dilation)
    self.bn2 = norm_layer(width)
    self.conv3 = conv1x1(width, planes * self.expansion)
    self.bn3 = norm_layer(planes * self.expansion)
    self.relu = nn.ReLU(inplace=True)
    self.downsample = downsample
    self.stride = stride

    def forward(self, x):
      identity = x

      out = self.conv1(x)
      out = self.bn1(out)
      out = self.relu(out)

      out = self.conv2(out)
      out = self.bn2(out)
      out = self.relu(out)

      out = self.conv3(out)
      out = self.bn3(out)

      if self.downsample is not None:
          identity = self.downsample(x)

      out += identity
      out = self.relu(out)

      return out


class ResNet(nn.Module):

  def __init__(self, block, layers, num_classes=10, zero_init_residual=False,
                groups=1, width_per_group=64, replace_stride_with_dilation=None,
                norm_layer=None):
    super(ResNet, self).__init__()
    if norm_layer is None:
      norm_layer = nn.BatchNorm2d
    self._norm_layer = norm_layer

    self.inplanes = 64
    self.dilation = 1
    if replace_stride_with_dilation is None:
      # each element in the tuple indicates if we should replace
      # the 2x2 stride with a dilated convolution instead
      replace_stride_with_dilation = [False, False, False]
    if len(replace_stride_with_dilation) != 3:
      raise ValueError("replace_stride_with_dilation should be None "
                        "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
    self.groups = groups
    self.base_width = width_per_group
    
    ## CIFAR10: kernel_size 7 -> 3, stride 2 -> 1, padding 3->1
    self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
    ## END
    
    self.bn1 = norm_layer(self.inplanes)
    self.relu = nn.ReLU(inplace=True)
    # self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(block, 64, layers[0])
    self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
                                    dilate=replace_stride_with_dilation[0])
    self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
                                    dilate=replace_stride_with_dilation[1])
    self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
                                    dilate=replace_stride_with_dilation[2])
    self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(512 * block.expansion, num_classes)

    for m in self.modules():
      if isinstance(m, nn.Conv2d):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
      elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)

    # Zero-initialize the last BN in each residual branch,
    # so that the residual branch starts with zeros, and each residual block behaves like an identity.
    # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
    if zero_init_residual:
      for m in self.modules():
        if isinstance(m, Bottleneck):
          nn.init.constant_(m.bn3.weight, 0)
        elif isinstance(m, BasicBlock):
          nn.init.constant_(m.bn2.weight, 0)

  def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
    norm_layer = self._norm_layer
    downsample = None
    previous_dilation = self.dilation
    if dilate:
        self.dilation *= stride
        stride = 1
    if stride != 1 or self.inplanes != planes * block.expansion:
        downsample = nn.Sequential(
            conv1x1(self.inplanes, planes * block.expansion, stride),
            norm_layer(planes * block.expansion),
        )

    layers = []
    layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
                        self.base_width, previous_dilation, norm_layer))
    self.inplanes = planes * block.expansion
    for _ in range(1, blocks):
        layers.append(block(self.inplanes, planes, groups=self.groups,
                            base_width=self.base_width, dilation=self.dilation,
                            norm_layer=norm_layer))

    return nn.Sequential(*layers)

  def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    # x = self.maxpool(x)

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = self.avgpool(x)
    x = x.reshape(x.size(0), -1)
    x = self.fc(x)

    return x


def _resnet(arch, block, layers, pretrained, progress, device, **kwargs):
  model = ResNet(block, layers, **kwargs)
  if pretrained:
      script_dir = os.path.dirname(__file__)
      state_dict = torch.load(script_dir + '/state_dicts/'+arch+'.pt', map_location=device)
      model.load_state_dict(state_dict)
  return model


def resnet18(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNet-18 model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress, device,
                  **kwargs)


def resnet34(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNet-34 model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress, device,
                  **kwargs)


def resnet50(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNet-50 model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress, device,
                  **kwargs)


def resnet101(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNet-101 model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress, device,
                  **kwargs)


def resnet152(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNet-152 model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress, device,
                  **kwargs)


def resnext50_32x4d(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNeXt-50 32x4d model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  kwargs['groups'] = 32
  kwargs['width_per_group'] = 4
  return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
                  pretrained, progress, device, **kwargs)


def resnext101_32x8d(pretrained=False, progress=True, device='cpu', **kwargs):
  """Constructs a ResNeXt-101 32x8d model.
  Args:
      pretrained (bool): If True, returns a model pre-trained on ImageNet
      progress (bool): If True, displays a progress bar of the download to stderr
  """
  kwargs['groups'] = 32
  kwargs['width_per_group'] = 8
  return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
                  pretrained, progress, device, **kwargs)

### ------------------------------------------ ###

In [20]:
### resnet call ###
import torch.nn as nn
from torchvision import models

if args.resnet_version is not None:
  resnet = eval(f'{args.resnet_version}()')
  # resnet = eval(f'models.{args.resnet_version}()') # Original torchvision models
  resnet.output_dim = resnet.fc.in_features
  resnet.fc = nn.Identity()
else:
  raise NotImplementedError("Backbone is not implemented!")
### ------------------------------------------ ###

In [21]:
### byol network define ###
import copy
import math
from torch.nn import functional

hidden_size=4096
projection_size=256

class MLP(nn.Module):
  def __init__(self, input_dim):
    super().__init__()

    self.net = nn.Sequential(
        nn.Linear(input_dim, hidden_size), 
        nn.BatchNorm1d(hidden_size, momentum=1-0.9, eps=1e-5), 
        nn.ReLU(inplace=True), 
        nn.Linear(hidden_size, projection_size)
    )
  def forward(self, x):
    return self.net(x)

class BYOL(nn.Module):
  def __init__(self, backbone):
    super().__init__()

    self.backbone=backbone
    self.projector = MLP(resnet.output_dim)
    self.online_encoder = nn.Sequential(
        self.backbone, 
        self.projector,
    )
    self.predictor = MLP(projection_size)
    self.target_encoder = copy.deepcopy(self.online_encoder)

  def target_ema(self, k, K, base_tau=0.996):
    return 1-(1-base_tau)*(math.cos(math.pi*k/K)+1)/2

  def reset_moving_average(self):
    del self.target_encoder
    self.target_encoder = copy.deepcopy(self.online_encoder)

  def update_moving_average(self, global_step, max_steps):
    tau = self.target_ema(global_step, max_steps)
    for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
      target.data = tau*target.data + (1-tau)*online.data
  
  def loss_function(self, p, z):
    p=functional.normalize(p, dim=-1, p=2)
    z=functional.normalize(z, dim=-1, p=2)
    return 2 - 2*(p*z).sum(dim=-1)

  def forward(self, x1, x2):
    z1_online, z2_online = self.online_encoder(x1), self.online_encoder(x2)
    p1_online, p2_online = self.predictor(z1_online), self.predictor(z2_online)
    with torch.no_grad():
      z1_target, z2_target = self.target_encoder(x1), self.target_encoder(x2)
    
    loss1, loss2 = self.loss_function(p1_online, z2_target.detach()), self.loss_function(p2_online, z1_target.detach())

    loss = loss1+loss2
    return loss.mean()
### ------------------------------------------ ###

In [22]:
### byol network call ###
byol = BYOL(resnet)
byol = byol.to(args.device)
byol = torch.nn.DataParallel(byol)
### ------------------------------------------ ###

In [23]:
### Lars optimizer define ###
from torch.optim import Adam, SGD
from torch.optim.optimizer import Optimizer

class LARS(Optimizer):
  def __init__(self, named_modules, lr, momentum=0.9, trust_coef=1e-3, weight_decay=1.5e-6, exclude_bias_from_adaption=True):
    defaults = dict(momentum=momentum, lr=lr, weight_decay=weight_decay, trust_coef=trust_coef)
    parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption)
    super(LARS, self).__init__(parameters, defaults)

  @torch.no_grad() 
  def step(self):
    for group in self.param_groups: # only 1 group in most cases 
      weight_decay = group['weight_decay']
      momentum = group['momentum']
      lr = group['lr']
      trust_coef = group['trust_coef']
      # print(group['name'])
      # eps = group['eps']
      for p in group['params']:
        # breakpoint()
        if p.grad is None:
          continue
        global_lr = lr
        velocity = self.state[p].get('velocity', 0)  
        # if name in self.exclude_from_layer_adaptation:
        if self._use_weight_decay(group):
          p.grad.data += weight_decay * p.data 

        trust_ratio = 1.0 
        if self._do_layer_adaptation(group):
          w_norm = torch.norm(p.data, p=2)
          g_norm = torch.norm(p.grad.data, p=2)
          trust_ratio = trust_coef * w_norm / g_norm if w_norm > 0 and g_norm > 0 else 1.0 
        scaled_lr = global_lr * trust_ratio # trust_ratio is the local_lr 
        next_v = momentum * velocity + scaled_lr * p.grad.data 
        update = next_v
        p.data = p.data - update 

  def _use_weight_decay(self, group):
    return False if group['name'] == 'exclude' else True
  def _do_layer_adaptation(self, group):
    return False if group['name'] == 'exclude' else True

  def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True):
    base = [] 
    exclude = []
    for name, module in named_modules:
      if type(module) in [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d]:
        # if isinstance(module, torch.nn.modules.batchnorm._BatchNorm)
        for name2, param in module.named_parameters():
          exclude.append(param)
      else:
        for name2, param in module.named_parameters():
          if name2 == 'bias':
            exclude.append(param)
          elif name2 == 'weight':
            base.append(param)
          else:
            pass # non leaf modules 
    return [{
        'name': 'base',
        'params': base
        },{
        'name': 'exclude',
        'params': exclude
    }] if exclude_bias_from_adaption == True else [{
        'name': 'base',
        'params': base+exclude 
    }]

LARS optimizer is from Github PatrickHua/SimSiam


> Link: https://github.com/PatrickHua/SimSiam/blob/main/optimizers/lars_simclr.py


### ------------------------------------------ ###

In [24]:
### optimizer call ###
predictor_prefix = ('module.predictor', 'predictor')
parameters = [{
    'name': 'base',
    'params': [param for name, param in byol.named_parameters() if not name.startswith(predictor_prefix)],
    'lr': args.learning_rate
},{
    'name': 'predictor',
    'params': [param for name, param in byol.named_parameters() if name.startswith(predictor_prefix)],
    'lr': args.learning_rate
}]
if args.optim == 'lars':
  optimizer = LARS(byol.named_modules(), lr=args.learning_rate*args.batch_size/256, weight_decay=args.weight_decay)
elif args.optim == 'adam':
  optimizer = Adam(parameters, lr=args.learning_rate*args.batch_size/256)
elif args.optim == 'sgd':
  optimizer = SGD(parameters, lr=args.learning_rate*args.batch_size/256, momentum=0.9)
else: # default is LARS
  optimizer = LARS(byol.named_modules(), lr=args.learning_rate*args.batch_size/256, weight_decay=args.weight_decay)
### ------------------------------------------ ###

In [25]:
### knn monitor define ###
from tqdm import tqdm
import torch.nn.functional as F 
# # code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N
# # test using a knn monitor
# def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1, hide_progress=False):
#   net.eval()
#   classes = len(memory_data_loader.dataset.classes)
#   total_top1, total_top5, total_num, feature_bank = 0.0, 0.0, 0, []
#   with torch.no_grad():
#     # generate feature bank
#     for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
#       feature = net(data.cuda(non_blocking=True))
#       feature = F.normalize(feature, dim=1)
#       feature_bank.append(feature)
#     # [D, N]
#     feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()
#     # [N]
#     feature_labels = torch.tensor(memory_data_loader.dataset.targets, device=feature_bank.device)
#     # loop test data to predict the label by weighted knn search
#     test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
#     for data, target in test_bar:
#       data, target = data.cuda(non_blocking=True), target.cuda(non_blocking=True)
#       feature = net(data)
#       feature = F.normalize(feature, dim=1)
      
#       pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)

#       total_num += data.size(0)
#       total_top1 += (pred_labels[:, 0] == target).float().sum().item()
#       test_bar.set_postfix({'Accuracy':total_top1 / total_num * 100})
#   return total_top1 / total_num * 100

# # knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# # implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
# def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
#   # compute cos similarity between each feature vector and feature bank ---> [B, N]
#   sim_matrix = torch.mm(feature, feature_bank)
#   # [B, K]
#   sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
#   # [B, K]
#   sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
#   sim_weight = (sim_weight / knn_t).exp()

#   # counts for each class
#   one_hot_label = torch.zeros(feature.size(0) * knn_k, classes, device=sim_labels.device)
#   # [B*K, C]
#   one_hot_label = one_hot_label.scatter(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
#   # weighted score ---> [B, C]
#   pred_scores = torch.sum(one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1), dim=1)

#   pred_labels = pred_scores.argsort(dim=-1, descending=True)
#   return pred_labels
### ------------------------------------------ ###

In [26]:
### Training ###
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
from datetime import datetime
import os

writer = SummaryWriter()

global_step = 0
for epoch in tqdm(range(0, args.num_epochs), desc=f'Training'):
  metrics = defaultdict(list)
  
  for step, ((x1, x2), labels) in enumerate(tqdm(train_loader, desc=f'Epoch {epoch+1+args.current_epochs}/{args.num_epochs}')):
    x1, x2 = x1.cuda(non_blocking=True), x2.cuda(non_blocking=True)

    main_loss = byol(x1, x2)
    optimizer.zero_grad()
    main_loss.backward()
    optimizer.step()
    byol.module.update_moving_average(epoch+1+args.current_epochs, args.num_epochs)
    
    writer.add_scalar("Loss/train_step", main_loss, global_step)
    metrics["Loss/train"].append(main_loss.item())
    global_step += 1
  
  for k, v in metrics.items():
    writer.add_scalar(k, np.array(v).mean(), epoch+1+args.current_epochs)

  if epoch+1%args.checkpoint_epochs == 0:
    ckpt_path = os.path.join(tmp_dir, f"byol_{args.optim}_{epoch+1+args.current_epochs}.pt")
    print(f'Saving model at epoch {epoch+1+args.current_epochs}')
    torch.save({
        'epoch':epoch+1+args.current_epochs, 
        'state_dict':byol.module.state_dict()
    }, ckpt_path)

ckpt_path = os.path.join(tmp_dir, f"byol_{args.optim}_final.pt")
print(f'Saving final model at epoch {epoch+1+args.current_epochs}')
torch.save({
    'epoch':epoch+1+args.current_epochs, 
    'state_dict':byol.module.state_dict()
}, ckpt_path)
### ------------------------------------------ ###



Training:   0%|          | 0/1 [00:00<?, ?it/s][A[A


Epoch 1/1:   0%|          | 0/4 [00:00<?, ?it/s][A[A[A


Epoch 1/1:  25%|██▌       | 1/4 [00:00<00:02,  1.42it/s][A[A[A


Epoch 1/1:  50%|█████     | 2/4 [00:01<00:01,  1.48it/s][A[A[A


Epoch 1/1:  75%|███████▌  | 3/4 [00:01<00:00,  1.53it/s][A[A[A


Epoch 1/1: 100%|██████████| 4/4 [00:02<00:00,  1.58it/s]


Training: 100%|██████████| 1/1 [00:02<00:00,  2.54s/it]


Saving final model at epoch 1


In [27]:
### knn check ###
# main_accuracy = knn_monitor(byol.module.backbone, memory_loader, test_loader, k=min(200, len(memory_loader.dataset)), hide_progress=True)
# print('Accuracy:', main_accuracy)
### ------------------------------------------ ###

In [28]:
### learning rate scheduler define ###
import numpy as np

class LR_Scheduler(object):
  def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
    self.base_lr = base_lr
    self.constant_predictor_lr = constant_predictor_lr
    warmup_iter = iter_per_epoch * warmup_epochs
    warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
    decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
    cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))
    
    self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    self.optimizer = optimizer
    self.iter = 0
    self.current_lr = 0
  def step(self):
    for param_group in self.optimizer.param_groups:

      if self.constant_predictor_lr and param_group['name'] == 'predictor':
        param_group['lr'] = self.base_lr
      else:
        lr = param_group['lr'] = self.lr_schedule[self.iter]
    
    self.iter += 1
    self.current_lr = lr
    return lr
  def get_lr(self):
    return self.current_lr
### ------------------------------------------ ###

In [29]:
### Linear Evaluation define ###
class AverageMeter():
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.log = []
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def reset(self):
        self.log.append(self.avg)
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

def linear_eval(args, eval_from):
  eval_train_loader = torch.utils.data.DataLoader(
      torchvision.datasets.CIFAR10(
          root=args.dataset_dir, 
          train=True, 
          download=False, 
          transform=Transform_single(size=args.image_size, train=True), 
      ), 
      shuffle=True,
      batch_size=args.batch_size,
      num_workers=args.num_workers,
      drop_last=True,
      pin_memory=True,
  )

  eval_test_loader = torch.utils.data.DataLoader(
      torchvision.datasets.CIFAR10(
          root=args.dataset_dir, 
          train=False, 
          download=False, 
          transform=Transform_single(size=args.image_size, train=False), 
      ), 
      shuffle=False,
      batch_size=args.batch_size,
      num_workers=args.num_workers,
      drop_last=True,
      pin_memory=True,
  )

  eval_model = eval(f"{args.resnet_version}()")
  eval_model.output_dim = eval_model.fc.in_features
  eval_model.fc = torch.nn.Identity()
  eval_classifier = nn.Linear(in_features=eval_model.output_dim, out_features=10, bias=True).to(args.device)

  ###
  assert eval_from is not None
  eval_save_dict = torch.load(eval_from, map_location='cuda')
  eval_msg = eval_model.load_state_dict({k[9:]:v for k, v in eval_save_dict['state_dict'].items() if k.startswith('backbone.')}, strict=True)
  
  print(eval_msg)
  eval_model = eval_model.to(args.device)
  eval_model = torch.nn.DataParallel(eval_model)

  # if torch.cuda.device_count() > 1: eval_classifier = torch.nn.SyncBatchNorm.convert_sync_batchnorm(eval_classifier)
  eval_classifier = torch.nn.DataParallel(eval_classifier)
  # define optimizer 'sgd', eval_classifier, lr=eval_base_lr=30, momentum=eval_optim_momentum-0.9, weight_decay=eval_optim_weight_decay=0
  predictor_prefix = ('module.predictor', 'predictor')
  parameters = [{
      'name': 'base',
      'params': [param for name, param in eval_classifier.named_parameters() if not name.startswith(predictor_prefix)],
      'lr': 30
  },{
      'name': 'predictor',
      'params': [param for name, param in eval_classifier.named_parameters() if name.startswith(predictor_prefix)],
      'lr': 30
  }]
  eval_optimizer = torch.optim.SGD(parameters, lr=30, momentum=0.9, weight_decay=0)

  # define lr scheduler
  eval_lr_scheduler = LR_Scheduler(
      eval_optimizer,
      0, 0*args.batch_size/256, 
      30, 30*args.batch_size/256, 0*args.batch_size/256, 
      len(eval_train_loader),
  )

  eval_loss_meter = AverageMeter(name='Loss')
  eval_acc_meter = AverageMeter(name='Accuracy')

  # Start training
  eval_global_progress = tqdm(range(0, args.eval_epochs), desc=f'Evaluating')
  for epoch in eval_global_progress:
    eval_loss_meter.reset()
    eval_model.eval()
    eval_classifier.train()
    eval_local_progress = tqdm(eval_train_loader, desc=f'Epoch {epoch}/{args.eval_epochs}', disable=True)
    
    for idx, (images, labels) in enumerate(eval_local_progress):

      eval_classifier.zero_grad()
      with torch.no_grad():
        eval_feature = eval_model(images.to(args.device))

      eval_preds = eval_classifier(eval_feature)

      eval_loss = F.cross_entropy(eval_preds, labels.to(args.device))

      eval_loss.backward()
      eval_optimizer.step()
      eval_loss_meter.update(eval_loss.item())
      eval_lr = eval_lr_scheduler.step()
      eval_local_progress.set_postfix({'lr':eval_lr, "loss":eval_loss_meter.val, 'loss_avg':eval_loss_meter.avg})

  eval_classifier.eval()
  eval_correct, eval_total = 0, 0
  eval_acc_meter.reset()
  for idx, (images, labels) in enumerate(eval_test_loader):
    with torch.no_grad():
      eval_feature = eval_model(images.to(args.device))
      eval_preds = eval_classifier(eval_feature).argmax(dim=1)
      eval_correct = (eval_preds == labels.to(args.device)).sum().item()
      eval_acc_meter.update(eval_correct/eval_preds.shape[0])
  print(f'Accuracy = {eval_acc_meter.avg*100:.2f}')
### ------------------------------------------ ###

In [None]:
### liner evaluation ###
if args.eval is not False:
  linear_eval(args, ckpt_path)
### ------------------------------------------ ###



Evaluating:   0%|          | 0/30 [00:00<?, ?it/s][A[A

<All keys matched successfully>




Evaluating:   3%|▎         | 1/30 [53:28<25:50:38, 3208.22s/it][A[A