In [1]:
import torchvision
from torchvision import transforms
import torch
import torch.nn.functional as F
import torchvision.transforms as T

from sklearn.svm import LinearSVC

import numpy as np
import random
from tqdm import tqdm
from matplotlib import pyplot as plt
import os

from view_transform import ViewTransform

## Hyperparameters

In [2]:
'''
TODO
- We use a learning rate warm-up period of 10 epochs, after which we reduce the learning rate by a factor of 1000 using a co- sine decay schedule 
- LARS Optimizer there are implementations on github etc. 
'''

'\nTODO\n- We use a learning rate warm-up period of 10 epochs, after which we reduce the learning rate by a factor of 1000 using a co- sine decay schedule \n- LARS Optimizer there are implementations on github etc. \n'

In [3]:
torch.manual_seed(42)
random.seed(42)

epochs = 1 # Original set to 1000 
dim = 1000 # depends on specific encoder architecture ie. modifications of basic, as-is ResNet50

batch_size = 16 
num_workers = 4
device = 'cpu' # or 'cuda' for faster training

# VicREG
base_lr = 0.2
learning_rate = batch_size/256 * base_lr #  for barlow twins 
weight_decay = 1e-6

# BarlowTwins
# learning_rate = base_lr * batch_size / 256
# weight_decay = 1.5*1e-6

## Data 

In [4]:
linear_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ])  

In [5]:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())  
trainset.transform = ViewTransform()
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

linear_trainset = trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=linear_transform)  
linear_trainloader = torch.utils.data.DataLoader(linear_trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=linear_transform)  
testnset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


## Model

In [6]:
# encoder = torchvision.models.resnet50()

def projector():
    proj_layers = []

    for i in range(4):
        proj_layers.append(torch.nn.Linear(dim, dim))
        proj_layers.append(torch.nn.ReLU(dim))
        proj_layers.append(torch.nn.BatchNorm1d(dim))
    
    return torch.nn.Sequential(*proj_layers)

## VicREG

In [7]:
#VicReg Paper - with modifications
def VIC_Reg(z1, z2):
    
    N = z1.shape[0]
    D = z1.shape[1]

    mu = 25 # as per VIC-Reg Paper -- Subject to change
    la = 25 # "   "
    nu = 1 #  "   "
    
    # invariance loss
    sim_loss = F.mse_loss(z1, z2)
    
    # variance loss
    std_z_a = torch.sqrt(z1.var(dim=0) + 1e-04)
    std_z_b = torch.sqrt(z2.var(dim=0) + 1e-04)
    std_loss = torch.mean(torch.relu(1 - std_z_a)) + torch.mean(torch.relu(1 - std_z_b))
    
    # covariance loss
    z1 = z1 - z1.mean(dim=0)
    z2 = z2 - z2.mean(dim=0)
    
    cov_z_a = (z1.T @ z1) / (N - 1)
    cov_z_b = (z2.T @ z2) / (N - 1)

    cov_z_a = cov_z_a[~torch.eye(cov_z_a.shape[0], dtype=bool)] # Off diags
    cov_z_b = cov_z_b[~torch.eye(cov_z_b.shape[0], dtype=bool)] # " "

    cov_loss = cov_z_a.pow_(2).sum() / D + cov_z_b.pow_(2).sum() / D
    loss = la * sim_loss + mu * std_loss + nu * cov_loss
    
    return loss

## Barlow Twins

In [36]:
#Barlow Twins Paper - with modifications

def barlow_twins(z1, z2):
    N = z1.shape[0]
    D = z1.shape[1]
    
    z1_norm = (z1 - z1.mean(0)) / z1.std(0) # NxD
    z2_norm = (z2 - z2.mean(0)) / z2.std(0) # NxD

    z1_norm = z1_norm[:, None , :]  
    z2_norm = z2_norm[:, None , :] 

    c = F.conv1d(z1_norm.T, z2_norm.T)

    
    # loss
    c_diff = (c - torch.eye(D)).pow(2) # DxD #multiplyoff-diagonalelemsofc_diffbylambda off_diagonal(c_diff).mul_(lambda)
    loss = c_diff.sum()
    
    return loss

In [35]:
barlow_twins(torch.rand((16,1000)), torch.rand((16,1000)))

torch.Size([16, 1, 1000])
torch.Size([1000, 1000, 1])


tensor(1.5011e+10)

## Train

In [37]:
def train(trainset, loss_mode = "VicReg"): 
    
    mode = loss_mode ==  "VicReg"

    encoder = torchvision.models.resnet50()
    rpoj = projector()
    model = torch.nn.Sequential(encoder, rpoj)

    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay) # TODO LARS

    for i in range(epochs):
        losses = []
        for (x1, x2), _ in tqdm(trainloader):

            optimizer.zero_grad()

            x1, x2 = x1.to(device), x2.to(device)
            z1, z2 = model(x1), model(x2)

            if mode: 
                loss = VIC_Reg(z1, z2)
            else:
                loss = barlow_twins(z1, z2)
            
            loss.backward()
            optimizer.step()
            losses.append(loss.detach().item())


        print(f"Epoch: {i}, loss: {np.mean(losses)}")
        #DL 1 Homework 1 
        os.makedirs('models', exist_ok=True)
        if mode:
            os.makedirs('VicReg', exist_ok=True)
            torch.save(encoder.state_dict(), f'models/VicReg/model_{batch_size}_epoch_{i}.pt')
        else:
            os.makedirs('BarlowTwins', exist_ok=True)
            torch.save(encoder.state_dict(), f'models/BarlowTwins/model_{batch_size}_epoch_{i}.pt')

    
    return encoder


In [19]:
encoder_vicreg = train(trainloader, loss_mode="VicReg")

100%|██████████| 3125/3125 [44:56<00:00,  1.16it/s]


Epoch: 0, loss: 39.53907440063477


In [20]:
encoder_barlow = train(trainloader)

100%|██████████| 3125/3125 [45:47<00:00,  1.14it/s]  


Epoch: 0, loss: 39.946143161621094


## Linear Head 

In [21]:
encoder_vicreg.eval()
encoder_barlow.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [55]:
# path duplicate - keep for now as convenience loader
def load_models(path="models/VicReg/model_16_epoch_0.pt"):
    encoder = torchvision.models.resnet50()
    saved = torch.load(path)
    encoder.load_state_dict(saved)
    return encoder

def linear_train(path="models/VicReg/model_16_epoch_0.pt"):
    
    encoder = load_models(path)
    encoder.eval()

    # I dont understand how the guys from the vicreg paper combinded 
    # this LinearSVC and pytorch, to my understanding indicated
    # by the fact that they use optimiser (but I am no expert, by any means) ^^

    clf = LinearSVC(random_state=0, tol=1e-5)

    model = torch.nn.Sequential(encoder, clf)
    model = model.to(device)

    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    for i in range(epochs):
        losses = []
        for x, y in tqdm(linear_trainloader):
            
            optimizer.zero_grad()
            output = model(x)
            
            loss = torch.nn.MSELoss(output, y)

            loss.backward()
            optimizer.step()
            losses.append(loss.detach().item())

    
    print(f"Epoch: {i}, loss: {np.mean(losses)}")


In [56]:
linear_train()

TypeError: sklearn.svm._classes.LinearSVC is not a Module subclass


Linear classification. We follow standard protocols Misra & Maaten (2020); Caron et al. (2020); Zbontar et al. (2021) and train linear models on top of the frozen representations. For VOC07 Everingham et al. (2010), we train a linear SVM with LIBLINEAR Fan et al. (2008). The images are center cropped and resized to 224 × 224, and the C values are computed with cross-validation. For Places205 Zhou et al. (2014) we use SGD with a learning rate of 0.003, a weight decay of 0.0001, a momentum of 0.9 and a batch size of 256, for 28 epochs. The learning rate is divided by 10 at epochs 4, 8 and 12. For Inaturalist2018 Horn et al. (2018), we use SGD with a learning rate of 0.005, a weight decay of 0.0001, a momentum of 0.9 and a batch size of 256, for 84 epochs. The learning rate is divided by 10 at epochs 24, 48 and 72.

In [None]:
linear_train("models/BarlowTwins/model_16_epoch_0.pt")