In [4]:
class AverageMeter(object):
    """Track the latest value of a metric together with its weighted running mean.

    Attributes (all set to 0 by ``reset``):
        name:  label shown by ``__str__``.
        fmt:   format spec (e.g. ``':.4f'``) applied to ``val`` and ``avg`` in ``__str__``.
        val:   value passed to the most recent ``update``.
        sum:   running total of ``val * n`` across updates.
        count: running total of the weights ``n``.
        avg:   ``sum / count`` as of the most recent update.
    """

    def __init__(self, name, fmt=':f'):
        self.name, self.fmt = name, fmt
        self.reset()

    def reset(self):
        """Zero the current value, running sum, running count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and recompute the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using ``fmt`` for both numbers."""
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))

Implementation of https://arxiv.org/pdf/2108.11458.pdf

In [3]:
import torch
import torch.optim as optim
import torch.nn as nn

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader

import math

## SimSiam Encoder

### Model Definition

In [5]:
class BackBoneEncoder(nn.Module):
  '''
  SimSiam - Simple Siamese https://arxiv.org/pdf/2011.10566.pdf

  Wraps a classification backbone with the SimSiam projection MLP (installed as
  the backbone's ``fc``) and a separate prediction MLP.

  Args:
    base_encoder: callable taking ``num_classes`` and ``zero_init_residual``
      keyword arguments and returning a module whose ``fc`` attribute is an
      ``nn.Linear`` (its ``in_features`` sizes the projector's hidden layers).
    dim: output width of the projector, and of the predictor.
    pred_dim: hidden (bottleneck) width of the predictor.
  '''

  def __init__(self, base_encoder, dim, pred_dim):
    super().__init__()

    self.encoder = base_encoder(num_classes=dim, zero_init_residual=True)

    # The backbone's own head (in_features -> dim) is reused as the projector's
    # final linear layer; keep a handle on it instead of indexing later.
    head = self.encoder.fc
    width = head.in_features

    # Projector: two (Linear -> BN -> Mish) hidden layers, then `head`, then a
    # non-affine BN. `head` sits at index 6 of this Sequential.
    self.encoder.fc = nn.Sequential(
        nn.Linear(width, width, bias=False),
        nn.BatchNorm1d(width),
        nn.Mish(inplace=True),
        nn.Linear(width, width, bias=False),
        nn.BatchNorm1d(width),
        nn.Mish(inplace=True),
        head,
        nn.BatchNorm1d(dim, affine=False),
    )
    # `head` is immediately followed by BatchNorm, so its bias is frozen.
    head.bias.requires_grad = False

    # Predictor: dim -> pred_dim bottleneck -> dim.
    self.predictor = nn.Sequential(
        nn.Linear(dim, pred_dim, bias=False),
        nn.BatchNorm1d(pred_dim),
        nn.Mish(inplace=True),
        nn.Linear(pred_dim, dim),
    )

  def forward(self, x1, x2):
    """Encode both views, run the predictor on each, and return (p1, p2, z1, z2).

    z1 and z2 are returned detached: this is SimSiam's stop-gradient on the
    target branch.
    """
    z1 = self.encoder(x1)
    z2 = self.encoder(x2)
    p1, p2 = self.predictor(z1), self.predictor(z2)
    return p1, p2, z1.detach(), z2.detach()

### Data

Below we define the augmentation pipeline that produces two independently augmented views of each image, which SimSiam uses for non-contrastive training.

In [6]:
# Augmentation pipeline applied independently to each of the two views.
augmentation = [
    # NOTE(review): CIFAR-10 images are 32x32; cropping to 28 is presumably a
    # speed trade-off — confirm it is intended.
    transforms.RandomResizedCrop(28, scale=(.2, 1)),
    transforms.RandomApply([transforms.ColorJitter(.4, .4, .4, .1)], p = .6),
    # Some transforms are omitted here to speed up training a bit
    # (presumably the Gaussian blur of the SimSiam recipe — confirm).
    # TODO: add more transforms
    transforms.RandomHorizontalFlip(),
    # p is a probability in [0, 1]. The previous value of 2 turned EVERY view
    # grayscale, removing colour information entirely; 0.2 matches the SimSiam
    # reference recipe.
    transforms.RandomGrayscale(p = .2),
    transforms.ToTensor(),
]

class TwoCropTransform:
  """Produce two independently transformed views of a single input image.

  ``base_transform`` is invoked twice per call. When it is random, the two
  results generally differ, giving the pair of views the model compares.
  """

  def __init__(self, base_transform):
    self.base_transform = base_transform

  def __call__(self, image):
    """Return ``[view_a, view_b]``, each from its own ``base_transform(image)`` call."""
    return [self.base_transform(image) for _ in range(2)]
  

We will use the CIFAR10 Dataset. 

In [7]:
# CIFAR-10 training split. Each sample's image is turned into a pair of
# augmented views by TwoCropTransform.
train_data = datasets.CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=TwoCropTransform(transforms.Compose(augmentation)),
)
train_data_size = len(train_data)

Files already downloaded and verified


In [8]:
# Shuffled mini-batches of two-view pairs for SimSiam pre-training.
# - drop_last=True: the projector/predictor contain BatchNorm1d, which raises in
#   train mode on a batch of size 1, so a ragged final batch must never reach
#   the model (guards against future batch-size changes).
# - pin_memory: required for the `.to(device, non_blocking=True)` copies in the
#   training loop to actually be asynchronous; only useful when a GPU exists.
loader = DataLoader(
    train_data,
    batch_size=100,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    pin_memory=torch.cuda.is_available(),
)



In [9]:
# ResNet-18 backbone with a 2048-d projector output and a 512-d predictor bottleneck.
model = BackBoneEncoder(models.resnet18, dim=2048, pred_dim=512)

In [10]:
# Cosine similarity along the feature axis (dim=1); the training loop negates it.
# The module holds no parameters or buffers, so it needs no device placement —
# the former hard-coded `.to("cuda")` contradicted the CPU fallback used later.
criterion = nn.CosineSimilarity(dim=1)

# SGD with momentum 0.9 and weight decay 1e-4. NOTE: lr=0.05 must stay in sync
# with the `init_lr` passed to adjust_learning_rate in the training loop.
optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)

In [11]:
NUM_EPOCH = 70

# Cosine Decay Schedule for lr decay.
def adjust_learning_rate(optimizer, init_lr, epoch, num_epochs=None):
    """Set each param group's lr from a half-cosine decay schedule.

    lr(epoch) = init_lr * 0.5 * (1 + cos(pi * epoch / num_epochs)), which equals
    ``init_lr`` at epoch 0 and decays toward 0 as ``epoch`` approaches
    ``num_epochs``. Param groups carrying a truthy ``'fix_lr'`` entry are set to
    ``init_lr`` instead.

    Args:
        optimizer: object exposing ``param_groups``, a list of dicts with an ``'lr'`` key.
        init_lr: learning rate at epoch 0.
        epoch: index of the epoch the new lr should apply to.
        num_epochs: schedule length; when None, the global ``NUM_EPOCH`` is read
            at call time (so reassigning it later in the notebook is honoured).

    Returns:
        The decayed learning rate assigned to the non-``fix_lr`` groups.
    """
    if num_epochs is None:
        num_epochs = NUM_EPOCH
    cur_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / num_epochs))
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr if param_group.get('fix_lr') else cur_lr
    return cur_lr

In [18]:
losses = AverageMeter("Loss", ":.4f")
device = "cuda" if torch.cuda.is_available() else "cpu"
model.train()

model.to(device)

for epoch in range(NUM_EPOCH):
  # Set this epoch's lr BEFORE training on it. Calling the schedule after the
  # epoch with the finished epoch's index made every epoch run at the previous
  # epoch's lr (epochs 0 and 1 both used the initial lr).
  adjust_learning_rate(optimizer, .05, epoch)
  # Report a per-epoch average instead of one accumulated over all epochs so far.
  losses.reset()

  for images, _ in loader:
    view1 = images[0].to(device, non_blocking=True)
    view2 = images[1].to(device, non_blocking=True)

    p1, p2, z1, z2 = model(view1, view2)

    # Symmetrized negative cosine similarity; z1/z2 are returned detached by the
    # model (stop-gradient), so gradients flow only through the predictions.
    loss = -(criterion(p1, z2).mean() + criterion(p2, z1).mean()) * 0.5

    # Weight by the actual batch size so a smaller batch is averaged correctly.
    losses.update(loss.item(), view1.size(0))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

  print(f"Epoch : {epoch} , {losses}")


torch.save(model.state_dict(), "resnet18_70_V0.pt")



Epoch : 0 , Loss -0.7421 (-0.7689)


KeyboardInterrupt: ignored