<a href="https://colab.research.google.com/github/KyuhyoJeon/BYOL/blob/master/BYOL_spijkervet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

https://github.com/Spijkervet/BYOL

In [None]:
import easydict
import os
from datetime import datetime

# Experiment configuration. EasyDict gives attribute access (args.batch_size)
# in place of argparse, which is awkward to drive from a notebook.
args = easydict.EasyDict({
    'image_size':224,
    'learning_rate':3e-4,
    'momentum':None,  # only read by the LARS optimizer branches
    'weight_decay':1.5e-6,
    'batch_size':256,
    'num_epochs':100,
    'resnet_version':'resnet18',
    'optim':'adam', # 'lars', 'adam', 'sgd'
    'checkpoint_epochs':10,
    'dataset_dir':'./datasets',
    'ckpt_dir':'./ckpt',
    'num_workers':8,
    'nodes':1,
    'gpus':1,
    'nr':0,
    'device':'cuda',
    'eval':True,
    'dryrun':True,
    'debug':False
})

# Dry run: 32x32 images, 10 epochs, data loading in the main process.
# NOTE: dryrun_subset_size is set here, but the subset code in the dataset
# cell is commented out, so it currently has no effect.
if args.dryrun:
  args.image_size=32
  args.num_epochs = 10
  args.batch_size = 256
  args.num_workers = 0
  args.dryrun_subset_size = 100
  args.resnet_version = 'resnet18'

# Debug: 32x32 images, a single epoch, batches of 2 from the first 8 samples.
if args.debug:
  args.image_size=32
  args.num_epochs = 1
  args.batch_size = 2
  args.num_workers = 0
  args.debug_subset_size = 8
  args.resnet_version = 'resnet18'

# Checkpoints are written to <ckpt_dir>/<resnet_version>/<optim>/<MMDDHH>/.
tmp_dir = os.path.join(args.ckpt_dir, f"{args.resnet_version}", f"{args.optim}", f"{datetime.now().strftime('%m%d%H')}")
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(tmp_dir, exist_ok=True)

In [None]:
import torchvision

class TransformsSimCLR:
  """SimCLR-style stochastic augmentation producing two views of one image.

  Calling an instance runs the random training pipeline twice on the same
  input, giving two correlated views (x̃_i, x̃_j) that form a positive pair.
  A deterministic ``test_transform`` (resize + normalise) is also built, but
  ``__call__`` does not use it.
  """
  imagenet_mean_std = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]

  def __init__(self, size, mean_std=imagenet_mean_std):
    T = torchvision.transforms
    strength = 1
    jitter = T.ColorJitter(
        0.8 * strength,  # brightness
        0.8 * strength,  # contrast
        0.8 * strength,  # saturation
        0.2 * strength,  # hue
    )

    augment_steps = [
        T.RandomResizedCrop(size=size),
        T.RandomHorizontalFlip(),  # p = 0.5
        T.RandomApply([jitter], p=0.8),
        T.RandomGrayscale(p=0.2),
        T.ToTensor(),
        T.Normalize(*mean_std),
    ]
    self.train_transform = T.Compose(augment_steps)

    eval_steps = [
        T.Resize(size=size),
        T.ToTensor(),
        T.Normalize(*mean_std),
    ]
    self.test_transform = T.Compose(eval_steps)

  def __call__(self, x):
    view_i = self.train_transform(x)
    view_j = self.train_transform(x)
    return view_i, view_j

In [None]:
from torchvision import datasets
import torch

# dataset
# Self-supervised training set: the CIFAR-10 train split, where every sample
# is ((view_i, view_j), label) with the views produced by TransformsSimCLR.
train_dataset = datasets.CIFAR10(
    args.dataset_dir,
    transform=TransformsSimCLR(size=args.image_size),  # the paper uses 224
    download=True,
)

# Debug runs keep only the first `debug_subset_size` samples. Subset does not
# forward `classes`/`targets`, so they are copied from the wrapped dataset.
if args.debug:
  train_dataset = torch.utils.data.Subset(train_dataset, range(args.debug_subset_size))
  train_dataset.classes = train_dataset.dataset.classes
  train_dataset.targets = train_dataset.dataset.targets

# Dry-run subsetting is currently disabled:
# if args.dryrun:
#   train_dataset = torch.utils.data.Subset(train_dataset, range(0, args.dryrun_subset_size))
#   train_dataset.classes = train_dataset.dataset.classes
#   train_dataset.targets = train_dataset.dataset.targets

# Shuffled batches; drop_last keeps every batch at exactly batch_size.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)

Files already downloaded and verified


In [None]:
import copy
import random
from functools import wraps

import torch
from torch import nn
import torch.nn.functional as F

# helper functions


def default(val, def_val):
  """Return ``val``, or ``def_val`` when ``val`` is ``None`` (falsy values are kept)."""
  if val is None:
    return def_val
  return val


def flatten(t):
  """Collapse every dimension after the first: (N, ...) -> (N, prod(...))."""
  batch = t.shape[0]
  return t.reshape(batch, -1)


def singleton(cache_key):
  """Decorator factory that caches a method's result on ``self.<cache_key>``.

  The attribute must already exist. While it is ``None`` the wrapped method is
  called and its result stored there. Later calls return the stored object
  without calling the method again.
  """
  def decorator(method):
    @wraps(method)
    def cached(self, *args, **kwargs):
      existing = getattr(self, cache_key)
      if existing is None:
        existing = method(self, *args, **kwargs)
        setattr(self, cache_key, existing)
      return existing

    return cached

  return decorator


# loss fn


def loss_fn(x, y):
  """BYOL regression loss: ``2 - 2 * cos(x, y)`` computed along the last dim.

  For non-zero inputs this equals the squared L2 distance between the
  L2-normalised vectors.
  """
  x_unit = F.normalize(x, p=2, dim=-1)
  y_unit = F.normalize(y, p=2, dim=-1)
  cosine = (x_unit * y_unit).sum(dim=-1)
  return 2 - 2 * cosine


# augmentation utils


class RandomApply(nn.Module):
  """Apply ``fn`` to the input with probability ``p``; otherwise return it unchanged."""

  def __init__(self, fn, p):
    super().__init__()
    self.fn = fn
    self.p = p

  def forward(self, x):
    # One draw per call; the transform fires when the draw does not exceed p.
    should_apply = random.random() <= self.p
    return self.fn(x) if should_apply else x


# exponential moving average


class EMA:
  """Exponential moving average: ``avg = old * beta + (1 - beta) * new``."""

  def __init__(self, beta):
    super().__init__()
    self.beta = beta

  def update_average(self, old, new):
    if old is not None:
      return old * self.beta + (1 - self.beta) * new
    # With no previous average, adopt the new value as-is.
    return new


def update_moving_average(ema_updater, ma_model, current_model):
  """Blend ``current_model``'s parameters into ``ma_model`` via ``ema_updater``.

  Parameters are paired positionally, so both models must list them in the
  same order. Only ``parameters()`` are updated; buffers are not touched.
  """
  for ma_param, cur_param in zip(ma_model.parameters(), current_model.parameters()):
    ma_param.data = ema_updater.update_average(ma_param.data, cur_param.data)


# MLP class for projector and predictor


class MLP(nn.Module):
  """Two-layer head (Linear -> BatchNorm1d -> ReLU -> Linear) for projector/predictor.

  Args:
    dim: input feature size.
    projection_size: output feature size.
    hidden_size: width of the hidden layer (default 4096).
  """

  def __init__(self, dim, projection_size, hidden_size=4096):
    super().__init__()
    layers = [
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    return self.net(x)


# a wrapper class for the base neural network
# will manage the interception of the hidden layer output
# and pipe it into the projecter and predictor nets


class NetWrapper(nn.Module):
  """Wrap a backbone, capture one hidden layer's output and project it.

  A forward hook on the chosen layer stores its flattened output in
  ``self.hidden``. ``forward`` feeds that representation through an MLP
  projector that is built lazily on first use, sized from the representation.

  Args:
    net: backbone network.
    projection_size: output size of the projector MLP.
    projection_hidden_size: hidden width of the projector MLP.
    layer: module name (str, looked up in ``net.named_modules()``) or index
      (int, into ``net.children()``) of the layer to tap.
  """

  def __init__(self, net, projection_size, projection_hidden_size, layer=-2):
    super().__init__()
    self.net = net
    self.layer = layer

    # Built on first use by _get_projector, once the input size is known.
    self.projector = None
    self.projection_size = projection_size
    self.projection_hidden_size = projection_hidden_size

    # Written by the forward hook, consumed by get_representation.
    self.hidden = None
    self.hook_registered = False

  def _find_layer(self):
    """Return the module selected by ``self.layer``, or None if it cannot be found."""
    if isinstance(self.layer, str):
      return dict(self.net.named_modules()).get(self.layer)
    if isinstance(self.layer, int):
      return list(self.net.children())[self.layer]
    return None

  def _hook(self, _, __, output):
    # Forward-hook signature: (module, inputs, output).
    self.hidden = flatten(output)

  def _register_hook(self):
    layer = self._find_layer()
    assert layer is not None, f"hidden layer ({self.layer}) not found"
    layer.register_forward_hook(self._hook)
    self.hook_registered = True

  @singleton("projector")
  def _get_projector(self, hidden):
    _, dim = hidden.shape
    projector = MLP(dim, self.projection_size, self.projection_hidden_size)
    # Module.to(tensor) matches both the device and dtype of the representation.
    return projector.to(hidden)

  def get_representation(self, x):
    """Run the backbone and return the tapped layer's flattened output.

    With ``layer == -1`` the backbone's own output is returned instead.
    """
    if not self.hook_registered:
      self._register_hook()

    output = self.net(x)
    hidden = self.hidden
    # Always clear the hook buffer so no stale activation survives the call.
    self.hidden = None

    if self.layer == -1:
      return output

    assert hidden is not None, f"hidden layer {self.layer} never emitted an output"
    return hidden

  def forward(self, x):
    representation = self.get_representation(x)
    projector = self._get_projector(representation)
    return projector(representation)


# main class


class BYOL(nn.Module):
  """Bootstrap Your Own Latent (BYOL) learner.

  Online branch: ``NetWrapper`` (backbone + projector) followed by a predictor
  MLP. Target branch: a deep copy of the online encoder, created lazily on
  the first forward pass and changed only by ``update_moving_average`` (EMA
  with ``moving_average_decay``). ``forward`` returns the symmetric BYOL loss
  for two views. ``augment_fn`` is accepted but not used by this class.
  """

  def __init__(
      self,
      net,
      image_size,
      hidden_layer=-2,
      projection_size=256,
      projection_hidden_size=4096,
      augment_fn=None,
      moving_average_decay=0.99,
  ):
    super().__init__()

    self.online_encoder = NetWrapper(
        net, projection_size, projection_hidden_size, layer=hidden_layer
    )
    self.target_encoder = None  # filled in by _get_target_encoder on first forward
    self.target_ema_updater = EMA(moving_average_decay)
    self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

    # A dummy pass builds the lazily created projector and target encoder, so
    # their parameters exist before an optimizer is constructed.
    # NOTE: this runs in training mode, so BatchNorm layers see one random batch.
    mock_shape = (2, 3, image_size, image_size)
    self.forward(torch.randn(mock_shape), torch.randn(mock_shape))

  @singleton("target_encoder")
  def _get_target_encoder(self):
    return copy.deepcopy(self.online_encoder)

  def reset_moving_average(self):
    del self.target_encoder
    self.target_encoder = None

  def update_moving_average(self):
    assert self.target_encoder is not None, "target encoder has not been created yet"
    # Calls the module-level helper of the same name.
    update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

  def forward(self, image_one, image_two):
    proj_one = self.online_encoder(image_one)
    proj_two = self.online_encoder(image_two)
    pred_one, pred_two = self.online_predictor(proj_one), self.online_predictor(proj_two)

    with torch.no_grad():
      target = self._get_target_encoder()
      target_one, target_two = target(image_one), target(image_two)

    # Symmetric objective: each view's prediction regresses the other view's target.
    symmetric = loss_fn(pred_one, target_two.detach()) + loss_fn(pred_two, target_one.detach())
    return symmetric.mean()

In [None]:
from torchvision import models

# model
# Randomly initialised torchvision ResNet backbone (pretrained=False).
if args.resnet_version not in ("resnet18", "resnet50"):
  raise NotImplementedError("ResNet not implemented")
resnet = (models.resnet18 if args.resnet_version == "resnet18" else models.resnet50)(pretrained=False)

# BYOL learner tapping the global-average-pool output as the representation,
# moved to the GPU and wrapped in DataParallel.
model = BYOL(resnet, image_size=args.image_size, hidden_layer="avgpool")
model = torch.nn.DataParallel(model.cuda())

In [None]:
from torch.optim import Adam, SGD
from torch.optim.optimizer import Optimizer

class LARS(Optimizer):
  """Layer-wise Adaptive Rate Scaling (LARS) with heavy-ball momentum.

  ``exclude_from_model`` splits parameters into a ``'base'`` group and an
  ``'exclude'`` group (biases and BatchNorm parameters). The exclude group
  gets neither weight decay nor trust-ratio scaling. For each parameter ``p``
  with gradient ``g``::

      g = g + weight_decay * p                                 (base only)
      r = trust_coef * ||p|| / ||g||  if both norms > 0 else 1  (base only)
      v = momentum * v + lr * r * g
      p = p - v

  Args:
    named_modules: iterable of ``(name, module)`` pairs, e.g. ``model.named_modules()``.
    lr: global learning rate.
    momentum: momentum coefficient; ``None`` selects the default 0.9.
    trust_coef: LARS trust coefficient.
    weight_decay: L2 penalty added to the gradient of base parameters.
    exclude_bias_from_adaption: keep biases/BN parameters in a separate group
      that skips weight decay and adaptation; otherwise everything is 'base'.
  """

  def __init__(self, named_modules, lr, momentum=0.9, trust_coef=1e-3, weight_decay=1.5e-6, exclude_bias_from_adaption=True):
    if momentum is None:
      # The notebook config stores momentum=None; fall back to the default
      # rather than failing later with `None * velocity` in step().
      momentum = 0.9
    defaults = dict(momentum=momentum, lr=lr, weight_decay=weight_decay, trust_coef=trust_coef)
    parameters = self.exclude_from_model(named_modules, exclude_bias_from_adaption)
    super(LARS, self).__init__(parameters, defaults)

  @torch.no_grad()
  def step(self, closure=None):
    """Perform one optimisation step; return ``closure()``'s loss if a closure is given."""
    loss = None
    if closure is not None:
      with torch.enable_grad():
        loss = closure()

    for group in self.param_groups:
      weight_decay = group['weight_decay']
      momentum = group['momentum']
      lr = group['lr']
      trust_coef = group['trust_coef']
      use_decay = self._use_weight_decay(group)
      use_adaptation = self._do_layer_adaptation(group)
      for p in group['params']:
        if p.grad is None:
          continue
        # Out-of-place so the stored .grad is not altered by weight decay.
        grad = p.grad
        if use_decay:
          grad = grad + weight_decay * p

        trust_ratio = 1.0
        if use_adaptation:
          w_norm = torch.norm(p, p=2)
          g_norm = torch.norm(grad, p=2)
          if w_norm > 0 and g_norm > 0:
            trust_ratio = trust_coef * w_norm / g_norm  # layer-wise local lr

        state = self.state[p]
        next_v = momentum * state.get('velocity', 0) + (lr * trust_ratio) * grad
        # Persist the velocity; without this the momentum term is always zero.
        state['velocity'] = next_v
        p.sub_(next_v)
    return loss

  def _use_weight_decay(self, group):
    """Weight decay applies to every group except the one named 'exclude'."""
    return group['name'] != 'exclude'

  def _do_layer_adaptation(self, group):
    """Trust-ratio scaling applies to every group except the one named 'exclude'."""
    return group['name'] != 'exclude'

  def exclude_from_model(self, named_modules, exclude_bias_from_adaption=True):
    """Build the optimizer param groups.

    BatchNorm parameters and every parameter named ``bias`` go to 'exclude'.
    All other parameters owned directly by a module go to 'base'; parameters
    with other names were previously dropped and never optimised. Each
    parameter is assigned at most once, even if it is shared between modules.
    """
    base, exclude = [], []
    seen = set()
    batchnorm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for _, module in named_modules:
      is_batchnorm = isinstance(module, batchnorm_types)
      # recurse=False: only this module's own parameters; children are
      # visited separately by the named_modules walk.
      for param_name, param in module.named_parameters(recurse=False):
        if id(param) in seen:
          continue  # tied/shared parameter already assigned to a group
        seen.add(id(param))
        if is_batchnorm or param_name == 'bias':
          exclude.append(param)
        else:
          base.append(param)

    if exclude_bias_from_adaption:
      return [{'name': 'base', 'params': base}, {'name': 'exclude', 'params': exclude}]
    return [{'name': 'base', 'params': base + exclude}]

In [None]:
# Optimizer for the BYOL-1 model. Only 'adam' and 'sgd' are special-cased;
# 'lars' and any other value fall through to the LARS optimizer.
# NOTE(review): args.momentum is None in the config cell; confirm LARS accepts
# that before selecting 'lars'.
if args.optim == 'adam':
  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
elif args.optim == 'sgd':
  optimizer = torch.optim.SGD(model.parameters(), lr=args.learning_rate, momentum=0.9)
else:  # 'lars' and the default case
  optimizer = LARS(model.named_modules(), lr=args.learning_rate, momentum=args.momentum, weight_decay=args.weight_decay)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
from tqdm import tqdm
import numpy as np

writer = SummaryWriter()


global_step = 0
for epoch in tqdm(range(args.num_epochs), desc='Training'):
  # knn_monitor() puts the backbone in eval mode; make every epoch train with
  # BatchNorm in training mode even if cells are re-run out of order.
  model.train()
  metrics = defaultdict(list)
  for (x_i, x_j), _ in train_loader:
    x_i = x_i.cuda(non_blocking=True)
    x_j = x_j.cuda(non_blocking=True)

    loss = model(x_i, x_j)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.module.update_moving_average()  # EMA update of the target encoder

    loss_value = loss.item()  # single host sync per step
    writer.add_scalar("Loss/train_step", loss_value, global_step)
    metrics["Loss/train"].append(loss_value)
    global_step += 1

  # Per-epoch means of the collected metrics.
  for k, v in metrics.items():
    writer.add_scalar(k, np.array(v).mean(), epoch)

  if epoch % args.checkpoint_epochs == 0:
    ckpt_path = os.path.join(tmp_dir, f"byol1_{args.optim}_{epoch}.pt")
    print(f"Saving model at epoch {epoch}")
    torch.save(resnet.state_dict(), ckpt_path)

writer.close()

# Save the trained backbone (ResNet state_dict, classifier head included).
ckpt_path = os.path.join(tmp_dir, f"byol1_{args.optim}_final.pt")
torch.save(resnet.state_dict(), ckpt_path)


Training:   0%|          | 0/10 [00:00<?, ?it/s][A
Training:  10%|█         | 1/10 [02:20<21:07, 140.85s/it][A

Saving model at epoch 0



Training:  20%|██        | 2/10 [04:41<18:45, 140.64s/it][A
Training:  30%|███       | 3/10 [07:02<16:26, 140.87s/it][A
Training:  40%|████      | 4/10 [09:25<14:08, 141.46s/it][A
Training:  50%|█████     | 5/10 [11:47<11:48, 141.63s/it][A
Training:  60%|██████    | 6/10 [14:09<09:26, 141.72s/it][A
Training:  70%|███████   | 7/10 [16:30<07:04, 141.54s/it][A
Training:  80%|████████  | 8/10 [18:51<04:42, 141.43s/it][A
Training:  90%|█████████ | 9/10 [21:12<02:21, 141.33s/it][A
Training: 100%|██████████| 10/10 [23:32<00:00, 141.24s/it]


In [None]:
from tqdm import tqdm
import torch.nn.functional as F 
# code copied from https://colab.research.google.com/github/facebookresearch/moco/blob/colab-notebook/colab/moco_cifar10_demo.ipynb#scrollTo=RI1Y8bSImD7N
# test using a knn monitor
def knn_monitor(net, memory_data_loader, test_data_loader, k=200, t=0.1, hide_progress=False):
  """Top-1 accuracy (%) of a weighted kNN classifier on ``net``'s features.

  Features of ``memory_data_loader`` form a labelled bank. Each test sample is
  classified by ``knn_predict`` from its ``k`` most similar bank entries
  (L2-normalised features) with temperature ``t``.

  Args:
    net: feature extractor; its train/eval mode is restored on exit.
    memory_data_loader: loader over the labelled reference set; its dataset
      must expose ``classes``.
    test_data_loader: loader over the evaluation set.
    k: number of neighbours, clamped to the bank size.
    t: temperature for the neighbour weights.
    hide_progress: disable the tqdm progress bars.

  Returns:
    Top-1 accuracy in percent (0.0 if the test loader yields no samples).
  """
  was_training = net.training
  net.eval()
  # Run on whatever device the network lives on (GPU in this notebook).
  device = next(net.parameters()).device
  classes = len(memory_data_loader.dataset.classes)
  total_top1, total_num = 0.0, 0
  feature_bank, label_bank = [], []
  try:
    with torch.no_grad():
      # generate feature bank
      for data, target in tqdm(memory_data_loader, desc='Feature extracting', leave=False, disable=hide_progress):
        feature = F.normalize(net(data.to(device, non_blocking=True)), dim=1)
        feature_bank.append(feature)
        # Collect labels from the loader so they stay aligned with the bank
        # even when the loader shuffles or drops its last batch.
        label_bank.append(target.to(device))
      feature_bank = torch.cat(feature_bank, dim=0).t().contiguous()  # [D, N]
      feature_labels = torch.cat(label_bank, dim=0)  # [N]
      k = min(k, feature_bank.size(1))

      # loop test data to predict the label by weighted knn search
      test_bar = tqdm(test_data_loader, desc='kNN', disable=hide_progress)
      for data, target in test_bar:
        data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)
        feature = F.normalize(net(data), dim=1)

        pred_labels = knn_predict(feature, feature_bank, feature_labels, classes, k, t)

        total_num += data.size(0)
        total_top1 += (pred_labels[:, 0] == target).float().sum().item()
        test_bar.set_postfix({'Accuracy': total_top1 / total_num * 100})
  finally:
    net.train(was_training)
  return total_top1 / total_num * 100 if total_num else 0.0

# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
  """Rank classes by temperature-weighted votes of the ``knn_k`` nearest bank entries.

  Args:
    feature: [B, D] query features (L2-normalised for cosine similarity).
    feature_bank: [D, N] bank features (L2-normalised for cosine similarity).
    feature_labels: [N] integer labels of the bank entries.
    classes: number of classes C.
    knn_k: number of neighbours K.
    knn_t: temperature; each neighbour votes with weight exp(similarity / knn_t).

  Returns:
    [B, C] class indices sorted from highest to lowest weighted vote.
  """
  batch = feature.size(0)
  similarity = feature @ feature_bank  # [B, N]
  top_sim, top_idx = similarity.topk(k=knn_k, dim=-1)  # [B, K]
  neighbour_labels = feature_labels.expand(batch, -1).gather(-1, top_idx)  # [B, K]
  weights = torch.exp(top_sim / knn_t)

  # One-hot neighbour labels [B*K, C] -> weight and sum over K -> [B, C].
  one_hot = torch.zeros(batch * knn_k, classes, device=neighbour_labels.device)
  one_hot.scatter_(-1, neighbour_labels.reshape(-1, 1), 1.0)
  scores = (one_hot.view(batch, -1, classes) * weights.unsqueeze(-1)).sum(dim=1)
  return scores.argsort(dim=-1, descending=True)

In [None]:
from PIL import Image

class Transform_single():
  """Single-view transform for evaluation pipelines.

  With ``train == True``: random resized crop (scale 0.08-1, aspect 3/4-4/3,
  bicubic) followed by a horizontal flip. Otherwise: bicubic resize to
  ``int(size * 8/7)`` (e.g. 224 -> 256) followed by a ``size`` centre crop.
  Both variants finish with ToTensor and Normalize(*normalize).
  """
  imagenet_mean_std = [[0.485, 0.456, 0.406],[0.229, 0.224, 0.225]]

  def __init__(self, size, train, normalize=imagenet_mean_std):
    T = torchvision.transforms
    is_train = train == True
    if is_train:
      spatial = [
          T.RandomResizedCrop(
              size,
              scale=(0.08, 1.0),
              ratio=(3.0/4.0, 4.0/3.0),
              interpolation=Image.BICUBIC,
          ),
          T.RandomHorizontalFlip(),
      ]
    else:
      spatial = [
          T.Resize(int(size*(8/7)), interpolation=Image.BICUBIC),  # 224 -> 256
          T.CenterCrop(size),
      ]
    self.transform = T.Compose(spatial + [T.ToTensor(), T.Normalize(*normalize)])

  def __call__(self, x):
    return self.transform(x)

# Reference ("memory") and query sets for the kNN monitor. Both use the
# deterministic resize + centre-crop transform. drop_last=False keeps every
# sample, so the feature bank covers all training images and accuracy is
# measured on the full test split (drop_last=True silently discarded the
# final partial batch of each).
memory_dataset = datasets.CIFAR10(
    root=args.dataset_dir,
    train=True,
    download=False,
    transform=Transform_single(size=args.image_size, train=False),
    )
memory_loader = torch.utils.data.DataLoader(
    memory_dataset,
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True,
    )

test_dataset = datasets.CIFAR10(
    root=args.dataset_dir,
    train=False,
    download=False,
    transform=Transform_single(size=args.image_size, train=False),
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=args.batch_size,
    num_workers=args.num_workers,
    drop_last=False,
    pin_memory=True,
)

In [None]:
# kNN accuracy of the BYOL-1 backbone (k capped at the memory-set size).
accuracy = knn_monitor(
    model.module.online_encoder.net,
    memory_loader,
    test_loader,
    k=min(200, len(memory_loader.dataset)),
    hide_progress=True,
)
print('Accuracy:', accuracy)


Feature extracting:   0%|          | 0/195 [00:00<?, ?it/s][A
Feature extracting:   1%|          | 1/195 [00:00<00:21,  9.07it/s][A
Feature extracting:   2%|▏         | 3/195 [00:00<00:19,  9.73it/s][A
Feature extracting:   3%|▎         | 5/195 [00:00<00:17, 10.59it/s][A
Feature extracting:   4%|▎         | 7/195 [00:00<00:17, 11.06it/s][A
Feature extracting:   5%|▍         | 9/195 [00:00<00:16, 11.54it/s][A
Feature extracting:   6%|▌         | 11/195 [00:00<00:15, 12.10it/s][A
Feature extracting:   7%|▋         | 13/195 [00:01<00:14, 12.45it/s][A
Feature extracting:   8%|▊         | 15/195 [00:01<00:14, 12.65it/s][A
Feature extracting:   9%|▊         | 17/195 [00:01<00:14, 12.62it/s][A
Feature extracting:  10%|▉         | 19/195 [00:01<00:13, 12.63it/s][A
Feature extracting:  11%|█         | 21/195 [00:01<00:13, 12.62it/s][A
Feature extracting:  12%|█▏        | 23/195 [00:01<00:13, 12.48it/s][A
Feature extracting:  13%|█▎        | 25/195 [00:02<00:14, 11.81it/s][A
Featu

Accuracy: 36.44831730769231





In [None]:
class AverageMeter():
  """Track the latest value and the count-weighted running average of a metric.

  ``reset`` appends the current average to ``self.log`` before zeroing.
  """

  def __init__(self, name, fmt=':f'):
    self.name = name
    self.fmt = fmt
    self.log = []
    self._zero()

  def _zero(self):
    """Clear the running statistics (the history in ``log`` is kept)."""
    self.val = 0
    self.avg = 0
    self.sum = 0
    self.count = 0

  def reset(self):
    self.log.append(self.avg)
    self._zero()

  def update(self, val, n=1):
    self.val = val
    self.sum += val * n
    self.count += n
    self.avg = self.sum / self.count

  def __str__(self):
    # Produces e.g. '{name} {val:f} ({avg:f})' and fills it from the instance.
    template = f'{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})'
    return template.format(**self.__dict__)

class LR_Scheduler(object):
  """Per-iteration learning-rate schedule: linear warm-up, then cosine decay.

  The precomputed schedule has ``iter_per_epoch * num_epochs`` entries. It
  ramps linearly from ``warmup_lr`` to ``base_lr`` over ``warmup_epochs``,
  then decays on a cosine from ``base_lr`` towards ``final_lr``. Each
  ``step()`` writes the next value into every optimizer param group, except
  groups named ``'predictor'`` when ``constant_predictor_lr`` is set (those
  are pinned to ``base_lr``).
  """

  def __init__(self, optimizer, warmup_epochs, warmup_lr, num_epochs, base_lr, final_lr, iter_per_epoch, constant_predictor_lr=False):
    self.base_lr = base_lr
    self.constant_predictor_lr = constant_predictor_lr
    warmup_iter = iter_per_epoch * warmup_epochs
    warmup_lr_schedule = np.linspace(warmup_lr, base_lr, warmup_iter)
    decay_iter = iter_per_epoch * (num_epochs - warmup_epochs)
    cosine_lr_schedule = final_lr+0.5*(base_lr-final_lr)*(1+np.cos(np.pi*np.arange(decay_iter)/decay_iter))

    self.lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    self.optimizer = optimizer
    self.iter = 0
    self.current_lr = 0

  def step(self):
    """Advance one iteration and return the scheduled learning rate.

    After the schedule is exhausted the final value is held (previously this
    raised IndexError). The scheduled value is returned even if every group is
    a pinned predictor group (previously an UnboundLocalError).
    """
    if len(self.lr_schedule) == 0:
      lr = self.base_lr
    else:
      lr = self.lr_schedule[min(self.iter, len(self.lr_schedule) - 1)]

    for param_group in self.optimizer.param_groups:
      if self.constant_predictor_lr and param_group['name'] == 'predictor':
        param_group['lr'] = self.base_lr
      else:
        param_group['lr'] = lr

    self.iter += 1
    self.current_lr = lr
    return lr

  def get_lr(self):
    return self.current_lr

def _load_backbone_weights(backbone, ckpt_file):
  """Load a backbone checkpoint saved by this notebook into ``backbone`` (strict).

  Accepts a raw ResNet ``state_dict`` (as written by the training cells) or a
  SimSiam-style ``{'state_dict': {'backbone.<key>': ...}}`` file. ``fc.*``
  entries are dropped because the classification head is replaced by
  ``nn.Identity`` before loading.
  NOTE: torch.load unpickles the file; only use it on trusted checkpoints.
  """
  checkpoint = torch.load(ckpt_file, map_location='cpu')
  if 'state_dict' in checkpoint:
    prefix = 'backbone.'
    checkpoint = {k[len(prefix):]: v for k, v in checkpoint['state_dict'].items() if k.startswith(prefix)}
  state = {k: v for k, v in checkpoint.items() if not k.startswith('fc.')}
  return backbone.load_state_dict(state, strict=True)


def _make_cifar10_loader(args, train, shuffle, drop_last):
  """Wrap one CIFAR-10 split (with Transform_single) in a DataLoader."""
  dataset = torchvision.datasets.CIFAR10(
      root=args.dataset_dir,
      train=train,
      download=False,
      transform=Transform_single(size=args.image_size, train=train),
  )
  return torch.utils.data.DataLoader(
      dataset,
      shuffle=shuffle,
      batch_size=args.batch_size,
      num_workers=args.num_workers,
      drop_last=drop_last,
      pin_memory=True,
  )


def linear_eval(args, eval_from):
  """Linear-probe evaluation of a saved backbone on CIFAR-10.

  Loads the backbone weights from ``eval_from`` and keeps the backbone frozen
  (eval mode, no gradients). A linear classifier is trained on its features
  with SGD (momentum 0.9, initial group lr ``eval_base_lr``, then a cosine
  schedule from ``eval_base_lr * batch_size / 256``). Top-1 test accuracy is
  printed and returned.

  Args:
    args: config namespace (dataset_dir, image_size, batch_size, num_workers,
      resnet_version, device). Optional ``eval_num_epochs`` and
      ``eval_base_lr`` override the defaults of 30 and 30.
    eval_from: path of the backbone checkpoint.

  Returns:
    Test accuracy in percent.
  """
  assert eval_from is not None
  num_epochs = getattr(args, 'eval_num_epochs', 30)
  base_lr = getattr(args, 'eval_base_lr', 30)

  eval_train_loader = _make_cifar10_loader(args, train=True, shuffle=True, drop_last=True)
  # Keep every test image so the reported accuracy covers the whole split.
  eval_test_loader = _make_cifar10_loader(args, train=False, shuffle=False, drop_last=False)
  num_classes = len(eval_test_loader.dataset.classes)

  # getattr instead of eval(): the config string cannot execute arbitrary code.
  eval_model = getattr(models, args.resnet_version)()
  eval_model.output_dim = eval_model.fc.in_features
  eval_model.fc = torch.nn.Identity()
  eval_classifier = nn.Linear(in_features=eval_model.output_dim, out_features=num_classes, bias=True).to(args.device)

  # Previously the checkpoint was read but never applied, so a randomly
  # initialised backbone was being probed.
  _load_backbone_weights(eval_model, eval_from)
  eval_model = torch.nn.DataParallel(eval_model.to(args.device))
  eval_classifier = torch.nn.DataParallel(eval_classifier)

  predictor_prefix = ('module.predictor', 'predictor')
  parameters = [{
      'name': 'base',
      'params': [param for name, param in eval_classifier.named_parameters() if not name.startswith(predictor_prefix)],
      'lr': base_lr,
  }, {
      'name': 'predictor',
      'params': [param for name, param in eval_classifier.named_parameters() if name.startswith(predictor_prefix)],
      'lr': base_lr,
  }]
  eval_optimizer = torch.optim.SGD(parameters, lr=base_lr, momentum=0.9, weight_decay=0)

  # No warm-up; cosine decay from base_lr * batch_size / 256 down to 0.
  eval_lr_scheduler = LR_Scheduler(
      eval_optimizer,
      0, 0*args.batch_size/256,
      num_epochs, base_lr*args.batch_size/256, 0*args.batch_size/256,
      len(eval_train_loader),
  )

  eval_loss_meter = AverageMeter(name='Loss')
  eval_acc_meter = AverageMeter(name='Accuracy')

  for epoch in tqdm(range(num_epochs), desc='Evaluating'):
    eval_loss_meter.reset()
    eval_model.eval()  # frozen feature extractor (BatchNorm uses running stats)
    eval_classifier.train()
    eval_local_progress = tqdm(eval_train_loader, desc=f'Epoch {epoch}/{num_epochs}', disable=True)

    for images, labels in eval_local_progress:
      with torch.no_grad():
        eval_feature = eval_model(images.to(args.device))
      eval_preds = eval_classifier(eval_feature)
      eval_loss = F.cross_entropy(eval_preds, labels.to(args.device))

      eval_optimizer.zero_grad()
      eval_loss.backward()
      eval_optimizer.step()
      eval_loss_meter.update(eval_loss.item())
      eval_lr = eval_lr_scheduler.step()
      eval_local_progress.set_postfix({'lr':eval_lr, "loss":eval_loss_meter.val, 'loss_avg':eval_loss_meter.avg})

  eval_classifier.eval()
  eval_acc_meter.reset()
  with torch.no_grad():
    for images, labels in eval_test_loader:
      labels = labels.to(args.device)
      eval_preds = eval_classifier(eval_model(images.to(args.device))).argmax(dim=1)
      batch_size = labels.size(0)
      # Weight each batch by its size so avg is the exact overall accuracy,
      # including a smaller final batch.
      eval_acc_meter.update((eval_preds == labels).sum().item() / batch_size, n=batch_size)
  accuracy = eval_acc_meter.avg * 100
  print(f'Accuracy = {accuracy:.2f}')
  return accuracy

In [None]:
# Linear-probe the final BYOL-1 checkpoint unless evaluation is switched off.
if not (args.eval is False):
  linear_eval(args, ckpt_path)


Evaluating:   0%|          | 0/30 [00:00<?, ?it/s][A
Evaluating:   3%|▎         | 1/30 [00:25<12:11, 25.24s/it][A
Evaluating:   7%|▋         | 2/30 [00:50<11:49, 25.33s/it][A
Evaluating:  10%|█         | 3/30 [01:16<11:25, 25.38s/it][A
Evaluating:  13%|█▎        | 4/30 [01:41<10:58, 25.33s/it][A
Evaluating:  17%|█▋        | 5/30 [02:06<10:29, 25.20s/it][A
Evaluating:  20%|██        | 6/30 [02:32<10:09, 25.40s/it][A
Evaluating:  23%|██▎       | 7/30 [02:58<09:47, 25.56s/it][A
Evaluating:  27%|██▋       | 8/30 [03:23<09:23, 25.62s/it][A
Evaluating:  30%|███       | 9/30 [03:50<09:01, 25.76s/it][A
Evaluating:  33%|███▎      | 10/30 [04:16<08:37, 25.87s/it][A
Evaluating:  37%|███▋      | 11/30 [04:42<08:13, 25.96s/it][A
Evaluating:  40%|████      | 12/30 [05:08<07:47, 25.99s/it][A
Evaluating:  43%|████▎     | 13/30 [05:34<07:23, 26.11s/it][A
Evaluating:  47%|████▋     | 14/30 [06:02<07:04, 26.53s/it][A
Evaluating:  50%|█████     | 15/30 [06:28<06:35, 26.38s/it][A
Evaluatin

Accuracy = 34.48


In [None]:
import torch.nn as nn
from torchvision import models

# Backbone for the BYOL-2 implementation: a fresh torchvision ResNet whose
# classification head is replaced by Identity, so forward() returns the pooled
# feature vector of size `output_dim`.
if args.resnet_version is None or not hasattr(models, args.resnet_version):
  raise NotImplementedError("Backbone is not implemented!")
# getattr instead of eval(): the config string cannot execute arbitrary code.
resnet2 = getattr(models, args.resnet_version)()
resnet2.output_dim = resnet2.fc.in_features
resnet2.fc = nn.Identity()

In [None]:
import copy
import math
from torch.nn import functional

class MLP2(nn.Module):
  """Projector/predictor head: Linear -> BatchNorm1d -> ReLU -> Linear.

  Args:
    input_dim: input feature size.
    hidden_dim: hidden width (default 4096).
    output_dim: output size (default 256).
    bn_momentum: BatchNorm momentum in PyTorch's convention (weight given to
      the new batch statistic); the default ``1 - 0.9`` corresponds to a
      running-average decay of 0.9.
  """

  def __init__(self, input_dim, hidden_dim=4096, output_dim=256, bn_momentum=1 - 0.9):
    super().__init__()

    self.net = nn.Sequential(
        nn.Linear(input_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim, momentum=bn_momentum, eps=1e-5),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_dim, output_dim),
    )

  def forward(self, x):
    return self.net(x)

class BYOL2(nn.Module):
  """BYOL with an explicit backbone + projector and a cosine-scheduled EMA target.

  ``online_encoder = backbone -> projector``, with ``predictor`` on top.
  ``target_encoder`` is a deep copy of the online encoder; its parameters do
  not require gradients and change only through ``update_moving_average``.

  Args:
    backbone: feature extractor exposing ``output_dim`` (its feature size).
  """

  def __init__(self, backbone):
    super().__init__()

    self.backbone = backbone
    # Size the projector from the backbone passed in (it previously read the
    # global `resnet2`, ignoring this argument).
    self.projector = MLP2(backbone.output_dim)
    self.online_encoder = nn.Sequential(
        self.backbone,
        self.projector,
    )
    self.predictor = MLP2(256)
    self.target_encoder = self._copy_online_encoder()

  def _copy_online_encoder(self):
    """Deep-copy the online encoder as a gradient-free target network."""
    target = copy.deepcopy(self.online_encoder)
    for param in target.parameters():
      param.requires_grad_(False)  # updated by EMA only
    return target

  def target_ema(self, k, K, base_tau=0.996):
    """EMA decay at step k of K: rises from base_tau (k=0) to 1 (k=K) along a cosine."""
    return 1-(1-base_tau)*(math.cos(math.pi*k/K)+1)/2

  def reset_moving_average(self):
    del self.target_encoder
    self.target_encoder = self._copy_online_encoder()

  @torch.no_grad()
  def update_moving_average(self, global_step, max_steps):
    """In place: target = tau * target + (1 - tau) * online, tau = target_ema(global_step, max_steps)."""
    tau = self.target_ema(global_step, max_steps)
    for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
      target.mul_(tau).add_(online, alpha=1 - tau)

  def loss_function(self, p, z):
    """2 - 2 * cosine similarity between p and z along the last dim."""
    p=functional.normalize(p, dim=-1, p=2)
    z=functional.normalize(z, dim=-1, p=2)
    return 2 - 2*(p*z).sum(dim=-1)

  def forward(self, x1, x2):
    z1_online, z2_online = self.online_encoder(x1), self.online_encoder(x2)
    p1_online, p2_online = self.predictor(z1_online), self.predictor(z2_online)
    with torch.no_grad():
      z1_target, z2_target = self.target_encoder(x1), self.target_encoder(x2)

    # Symmetric loss: each view's prediction regresses the other view's target.
    loss1, loss2 = self.loss_function(p1_online, z2_target.detach()), self.loss_function(p2_online, z1_target.detach())

    loss = loss1+loss2
    return loss.mean()

In [None]:
# BYOL-2 on the second backbone, moved to the configured device and wrapped
# in DataParallel.
byol = torch.nn.DataParallel(BYOL2(resnet2).to(args.device))

In [None]:
# Optimizer for BYOL-2, with the learning rate scaled linearly by batch size.
# Adam/SGD receive two named param groups (predictor vs. the rest); LARS
# builds its own groups from the module tree.
scaled_lr = args.learning_rate*args.batch_size/256
predictor_prefix2 = ('module.predictor', 'predictor')
parameters2 = [{
    'name': 'base',
    'params': [param for name, param in byol.named_parameters() if not name.startswith(predictor_prefix2)],
    # A per-group 'lr' overrides the optimizer-level lr, so it must carry the
    # scaled value (previously the unscaled lr was used here).
    'lr': scaled_lr
},{
    'name': 'predictor',
    'params': [param for name, param in byol.named_parameters() if name.startswith(predictor_prefix2)],
    'lr': scaled_lr
}]
# The config stores momentum=None; LARS needs a number (0.9 is its default).
lars_momentum = 0.9 if args.momentum is None else args.momentum
if args.optim == 'lars':
  optimizer2 = LARS(byol.named_modules(), lr=scaled_lr, momentum=lars_momentum, weight_decay=args.weight_decay)
elif args.optim == 'adam':
  optimizer2 = Adam(parameters2, lr=scaled_lr)
elif args.optim == 'sgd':
  optimizer2 = SGD(parameters2, lr=scaled_lr, momentum=0.9)
else: # default is LARS
  optimizer2 = LARS(byol.named_modules(), lr=scaled_lr, weight_decay=args.weight_decay)

In [None]:
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
from datetime import datetime
import os

# if not os.path.exists('./ckpt'):
#   os.makedirs('./ckpt')

# tmp_dir = os.path.join('./ckpt', f"{args.resnet_version}")
# if not os.path.exists(tmp_dir):
#   os.makedirs(tmp_dir)
# tmp_dir = os.path.join(tmp_dir, f"{args.optim}")
# if not os.path.exists(tmp_dir):
#   os.makedirs(tmp_dir)
# tmp_dir = os.path.join(tmp_dir, f"{datetime.now().strftime('%m%d%H')}")
# if not os.path.exists(tmp_dir):
#   os.makedirs(tmp_dir)

writer2 = SummaryWriter()

# Total optimisation steps: the EMA decay of the target encoder follows a
# cosine schedule over steps (update_moving_average(global_step, max_steps)).
total_steps = args.num_epochs * len(train_loader)
global_step = 0
for epoch in tqdm(range(0, args.num_epochs)):
  byol.train()  # knn_monitor() may have left the backbone in eval mode
  metrics = defaultdict(list)

  for (x1, x2), _ in train_loader:
    x1, x2 = x1.cuda(non_blocking=True), x2.cuda(non_blocking=True)

    main_loss = byol(x1, x2)
    optimizer2.zero_grad()
    main_loss.backward()
    optimizer2.step()
    # Pass the step counter, not the epoch index, as the schedule position.
    byol.module.update_moving_average(global_step, total_steps)

    loss_value = main_loss.item()  # single host sync per step
    writer2.add_scalar("Loss/train_step", loss_value, global_step)
    metrics["Loss/train"].append(loss_value)
    global_step += 1

  # Per-epoch means of the collected metrics.
  for k, v in metrics.items():
    writer2.add_scalar(k, np.array(v).mean(), epoch)

  if epoch%args.checkpoint_epochs == 0:
    ckpt_path = os.path.join(tmp_dir, f"byol2_{args.optim}_{epoch}.pt")
    print(f'Saving model at epoch {epoch}')
    torch.save(resnet2.state_dict(), ckpt_path)

writer2.close()

ckpt_path = os.path.join(tmp_dir, f"byol2_{args.optim}_final.pt")
print('Saving final model')
torch.save(resnet2.state_dict(), ckpt_path)


  0%|          | 0/10 [00:00<?, ?it/s][A
 10%|█         | 1/10 [02:23<21:27, 143.01s/it][A

Saving model at epoch 0



 20%|██        | 2/10 [04:44<19:00, 142.56s/it][A
 30%|███       | 3/10 [07:03<16:31, 141.61s/it][A
 40%|████      | 4/10 [09:24<14:08, 141.40s/it][A
 50%|█████     | 5/10 [11:45<11:46, 141.31s/it][A
 60%|██████    | 6/10 [14:06<09:24, 141.09s/it][A
 70%|███████   | 7/10 [16:26<07:02, 140.73s/it][A
 80%|████████  | 8/10 [18:48<04:42, 141.06s/it][A
 90%|█████████ | 9/10 [21:11<02:21, 141.79s/it][A
100%|██████████| 10/10 [23:33<00:00, 141.33s/it]

Saving final model





In [None]:
# kNN accuracy of the BYOL-2 backbone (k capped at the memory-set size).
accuracy2 = knn_monitor(
    byol.module.backbone,
    memory_loader,
    test_loader,
    k=min(200, len(memory_loader.dataset)),
    hide_progress=True,
)
print('Accuracy:', accuracy2)


Feature extracting:   0%|          | 0/195 [00:00<?, ?it/s][A
Feature extracting:   1%|          | 2/195 [00:00<00:16, 11.91it/s][A
Feature extracting:   2%|▏         | 4/195 [00:00<00:15, 11.97it/s][A
Feature extracting:   3%|▎         | 6/195 [00:00<00:15, 12.15it/s][A
Feature extracting:   4%|▍         | 8/195 [00:00<00:15, 11.78it/s][A
Feature extracting:   5%|▌         | 10/195 [00:00<00:15, 11.61it/s][A
Feature extracting:   6%|▌         | 12/195 [00:01<00:16, 11.36it/s][A
Feature extracting:   7%|▋         | 13/195 [00:01<00:17, 10.64it/s][A
Feature extracting:   7%|▋         | 14/195 [00:01<00:17, 10.32it/s][A
Feature extracting:   8%|▊         | 16/195 [00:01<00:16, 10.55it/s][A
Feature extracting:   9%|▉         | 18/195 [00:01<00:16, 10.77it/s][A
Feature extracting:  10%|█         | 20/195 [00:01<00:15, 11.26it/s][A
Feature extracting:  11%|█▏        | 22/195 [00:01<00:14, 11.69it/s][A
Feature extracting:  12%|█▏        | 24/195 [00:02<00:14, 12.08it/s][A
Feat

Accuracy: 35.266426282051285





In [None]:
# Linear-probe the final BYOL-2 checkpoint unless evaluation is switched off.
if not (args.eval is False):
  linear_eval(args, ckpt_path)


Evaluating:   0%|          | 0/30 [00:00<?, ?it/s][A
Evaluating:   3%|▎         | 1/30 [00:25<12:25, 25.71s/it][A
Evaluating:   7%|▋         | 2/30 [00:50<11:55, 25.54s/it][A
Evaluating:  10%|█         | 3/30 [01:16<11:27, 25.47s/it][A
Evaluating:  13%|█▎        | 4/30 [01:41<11:02, 25.49s/it][A
Evaluating:  17%|█▋        | 5/30 [02:07<10:36, 25.45s/it][A
Evaluating:  20%|██        | 6/30 [02:32<10:12, 25.53s/it][A
Evaluating:  23%|██▎       | 7/30 [02:58<09:45, 25.45s/it][A
Evaluating:  27%|██▋       | 8/30 [03:23<09:22, 25.57s/it][A
Evaluating:  30%|███       | 9/30 [03:49<08:57, 25.60s/it][A
Evaluating:  33%|███▎      | 10/30 [04:14<08:30, 25.54s/it][A
Evaluating:  37%|███▋      | 11/30 [04:40<08:04, 25.50s/it][A
Evaluating:  40%|████      | 12/30 [05:06<07:40, 25.56s/it][A
Evaluating:  43%|████▎     | 13/30 [05:32<07:21, 25.94s/it][A
Evaluating:  47%|████▋     | 14/30 [05:59<06:56, 26.04s/it][A
Evaluating:  50%|█████     | 15/30 [06:25<06:31, 26.12s/it][A
Evaluatin

Accuracy = 33.63
