In [2]:
import torchvision
from torchvision import transforms
import torch
import torch.nn.functional as F
import torchvision.transforms as T

import numpy as np
import random
from tqdm import tqdm
from matplotlib import pyplot as plt
import os

from view_transform import ViewTransform

## Hyperparameters

In [3]:
'''
TODO
- We use a learning rate warm-up period of 10 epochs, after which we reduce the learning rate by a factor of 1000 using a co- sine decay schedule 
- LARS Optimizer there are implementations on github etc. 
- TOP-5 Acc 
- encoder = torchvision.models.resnet18() # also try with pretrained=true

-->> It might be interesting investigate the efficency frontier between max_batch and num_positives 

'''

'\nTODO\n- We use a learning rate warm-up period of 10 epochs, after which we reduce the learning rate by a factor of 1000 using a co- sine decay schedule \n- LARS Optimizer there are implementations on github etc. \n- TOP-5 Acc \n- encoder = torchvision.models.resnet18() # also try with pretrained=true\n\n-->> It might be interesting investigate the efficency frontier between max_batch and num_positives \n\n'

In [None]:
torch.manual_seed(42)
random.seed(42)

epochs = 1 # Original set to 1000 
dim = 1000 # depends on specific encoder architecture ie. modifications of basic, as-is resnet18
num_positives = 2

max_batch = 64
batch_size = max_batch / num_positives 
num_workers = 4
device = 'cpu' # or 'cuda' for faster training

# VicREG
base_lr = 0.2
learning_rate = batch_size/256 * base_lr 
weight_decay = 1e-6

# BarlowTwins
# learning_rate = base_lr * batch_size / 256
# weight_decay = 1.5*1e-6

## Data 

In [16]:
linear_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
        ])  

In [17]:
num_classes = 10

#trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())  
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())  
trainset.transform = ViewTransform(num_positives)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

#linear_trainset = trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=linear_transform)  
linear_trainset = trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=linear_transform)  
linear_trainloader = torch.utils.data.DataLoader(linear_trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

#testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=linear_transform)  
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=linear_transform)  
testset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

## Model

In [18]:
# encoder = torchvision.models.resnet18()

def projector():
    proj_layers = []

    for i in range(2):
        proj_layers.append(torch.nn.Linear(dim, dim))
        proj_layers.append(torch.nn.ReLU(dim))
        proj_layers.append(torch.nn.BatchNorm1d(dim))
    
    return torch.nn.Sequential(*proj_layers)

## VicREG

In [45]:
#VicReg Paper - with modifications
def VIC_Reg(Z):
    
    N = Z[0].shape[0]
    D = Z[0].shape[1]

    mu = 25 # as per VIC-Reg Paper -- Subject to change
    la = 25 # "   "
    nu = 1 #  "   "
    
    # invariance loss
    sim_loss = 0
    std_loss = 0
    cov_loss = 0

    for i in range(len(Z)): 
        for j in range(i+1, len(Z)): 
            sim_loss += F.mse_loss(Z[i], Z[j])
    
    for zi in Z: 
        std_zi = torch.sqrt(zi.var(dim=0) + 1e-04)
        std_loss += torch.mean(torch.relu(1 - std_zi)) 
    
    for zi in Z: 
        zi = zi - zi.mean(dim=0)
        cov_zi = (zi.T @ zi) / (N - 1)
        cov_zi = cov_zi[~torch.eye(cov_zi.shape[0], dtype=bool)]
        cov_loss += cov_zi.pow_(2).sum() / D


    # variance loss

    #    st_loss = torch.mean(torch.relu(1 - torch.stack([torch.sqrt(zi.var(dim=0) + 1e-04) for zi in Z])))
    #    cov_loss = torch.stack([(zi - zi.mean(dim=0)).T @ (zi - zi.mean(dim=0)) for zi in Z])
    #    cov_loss = torch.sum(cov_loss[~torch.eye(cov_loss.shape[0], dtype=bool)]) / (D * (N - 1))
    # 

    loss = la * sim_loss + mu * std_loss + nu * cov_loss
    
    return loss/len(Z)

## Barlow Twins

In [46]:
#Barlow Twins Paper - with modifications

def barlow_twins(Z):
    la = 0.005 # I am not entirely conviced la < 1  is sensible? 
    
    #input is [batch_size, 1000]
    #conv1d requires 3 dimensions, target CC is DxD i.e. 1000x1000

    N = Z[0].shape[0]
    D = Z[0].shape[1]
    
    loss = 0

    for i in range(len(Z)): 
        for j in range(len(Z)): 
            zi = Z[i] - Z[i].mean(dim=0)
            zj = Z[j] - Z[j].mean(dim=0)

            c = torch.matmul(zi.T, zj)
            c_diff = (c - torch.eye(D)).pow(2)
            
            off_diags = (torch.ones(c_diff.shape).fill_diagonal_(0))*la
            c_diff = c_diff*off_diags

            loss += c_diff.sum()
    
    return loss / len(Z)

## Train

In [47]:
def train(trainset, loss_mode = "VicReg"): 
    
    mode = loss_mode ==  "VicReg"

    encoder = torchvision.models.resnet18() # also try with pretrained=true
    encoder.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    rpoj = projector()
    model = torch.nn.Sequential(encoder, rpoj)

    model = model.to(device)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=weight_decay) # TODO LARS

    for i in range(epochs):
        losses = []
        for X, _ in tqdm(trainloader):
            Z = []
            for xi in X: 
                xi = xi.to(device)
                Z.append(model(xi))
                
            if mode: 
                loss = VIC_Reg(Z)
            else:
                loss = barlow_twins(Z)
            
            loss.backward()
            optimizer.step()
            losses.append(loss.detach().item())


        print(f"Epoch: {i}, loss: {np.mean(losses)}")
        #DL 1 Homework 1 
        os.makedirs('models', exist_ok=True)
        if mode:
            os.makedirs('models/VicReg', exist_ok=True)
            torch.save(encoder.state_dict(), f'models/VicReg/model_{batch_size}_epoch_{i}.pt')
        else:
            os.makedirs('models/BarlowTwins', exist_ok=True)
            torch.save(encoder.state_dict(), f'models/BarlowTwins/model_{batch_size}_epoch_{i}.pt')
    
    return encoder


In [41]:
encoder_vicreg = train(trainloader, loss_mode="VicReg")

  0%|          | 0/938 [00:00<?, ?it/s]

torch.Size([4, 1000, 1000])


  0%|          | 0/938 [00:05<?, ?it/s]


IndexError: The shape of the mask [4, 4] at index 1 does not match the shape of the indexed tensor [4, 1000, 1000] at index 1

In [48]:
encoder_barlow = train(trainloader)

  3%|▎         | 32/938 [01:06<31:18,  2.07s/it] 


KeyboardInterrupt: 

## Linear Head 

In [None]:
encoder_vicreg.eval()
encoder_barlow.eval()

In [None]:
# path duplicate - keep for now as convenience loader
def load_models(path="models/VicReg/model_16_epoch_0.pt"):
    encoder = torchvision.models.resnet18()
    saved = torch.load(path)
    encoder.load_state_dict(saved)
    return encoder

def linear_train(path="models/VicReg/model_16_epoch_0.pt"):
    # I dont understand how the guys from the vicreg paper combinded 
    # this LinearSVC and pytorch, to my understanding indicated
    # by the fact that they use optimiser (but I am no expert, by any means) ^^
    # clf = LinearSVC(random_state=0, tol=1e-5)
    
    encoder = load_models(path)
    encoder.eval()

    linear_classifier = torch.nn.Linear(dim, num_classes)
    linear_classifier.to(device)

    optimizer = torch.optim.SGD(linear_classifier.parameters(), lr=learning_rate, weight_decay=weight_decay)
    criterion = torch.nn.CrossEntropyLoss()

    for i in range(epochs):
        losses = []
        for x, y in tqdm(linear_trainloader):
            
            optimizer.zero_grad()
            
            latent_space = encoder(x)
            output = linear_classifier(latent_space)
            
            loss = criterion(output, y)
            loss.backward()
            optimizer.step()
            
            losses.append(loss.detach().item())

    
    print(f"Epoch: {i}, loss: {np.mean(losses)}")
    os.makedirs('models/LC', exist_ok=True)
    torch.save(encoder.state_dict(), f'models/LC/model_{batch_size}_epoch_{i}.pt')

    
    return linear_classifier


In [None]:
model = linear_train()

In [None]:
def test(head):
    total_samples = 0
    total_correct = 0

    for x,y in tqdm(testset_loader):
        head.eval()
        model = torch.nn.Sequential(load_models(), head)

        with torch.no_grad():
            outputs = model(x)
        
        predicted_labels = torch.argmax(outputs, dim=1)

        # Update evaluation metrics
        total_samples += y.size(0)
        total_correct += (predicted_labels == y).sum().item()

        return total_correct / total_samples


In [None]:
test(model)


Linear classification. We follow standard protocols Misra & Maaten (2020); Caron et al. (2020); Zbontar et al. (2021) and train linear models on top of the frozen representations. For VOC07 Everingham et al. (2010), we train a linear SVM with LIBLINEAR Fan et al. (2008). The images are center cropped and resized to 224 × 224, and the C values are computed with cross-validation. For Places205 Zhou et al. (2014) we use SGD with a learning rate of 0.003, a weight decay of 0.0001, a momentum of 0.9 and a batch size of 256, for 28 epochs. The learning rate is divided by 10 at epochs 4, 8 and 12. For Inaturalist2018 Horn et al. (2018), we use SGD with a learning rate of 0.005, a weight decay of 0.0001, a momentum of 0.9 and a batch size of 256, for 84 epochs. The learning rate is divided by 10 at epochs 24, 48 and 72.

In [None]:
linear_train("models/BarlowTwins/model_16_epoch_0.pt")