In [2]:
import torchvision
from torchvision import transforms
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR

import numpy as np
import random
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
import math
import logging
import time
import psutil

from view_transform import ViewTransform
from LARS import LARS
import resnet

## Hyperparameters

In [3]:
'''
TODO
- try with pretrained weights for resnet18
-->> It might be interesting to investigate the efficiency frontier between max_batch and views 
'''

'\nTODO\n- try with pretrained weights for resnet18\n-->> It might be interesting to investigate the efficiency frontier between max_batch and views \n'

In [4]:
""" 
Experiment: Change number of Views from 2 to 16 and record loss after 100 epochs
"""

' \nExperiment: Change number of Views from 2 to 16 and record loss after 100 epochs\n'

In [5]:
# Reproducibility: seed every RNG the notebook touches (numpy is used for the
# per-epoch loss averages, so it is seeded too).
SEED = 42
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

# NOTE(review): adjust_learning_rate uses its own 10-epoch warmup and does not
# read this value; kept so later cells that may reference it still work.
warmup_steps = 10
epochs = 100  # original setting was 1000
output_enc = 1000  # projector input width; matches torchvision resnet18()'s default 1000-way fc head
dim = 8192  # width of every projector layer
num_views = 2  # number of augmented views per image handed to ViewTransform

num_workers = 4
device = 'cpu'  # or 'cuda' for faster training

batch_size = 2048

# VICReg: linear LR scaling rule, lr = base_lr * batch_size / 256
base_lr_ = 0.2
learning_rate = batch_size / 256 * base_lr_
weight_decay = 1e-6

# Barlow Twins alternative (same scaling rule, using base_lr_ from above):
# learning_rate = base_lr_ * batch_size / 256
# weight_decay = 1.5e-6

## Data 

In [6]:
# Number of CIFAR-10 classes; presumably for a later linear-evaluation head
# (the self-supervised training below ignores labels).
num_classes = 10

# Self-supervised training set. ViewTransform presumably maps each image to
# `num_views` augmented views (see view_transform.py) -- TODO confirm.
#trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
trainset.transform = ViewTransform(num_views)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Linear-evaluation data (disabled). `linear_transform` must be defined before
# enabling these lines. Bind each dataset to its own name only: chaining
# `linear_trainset = trainset = ...` would overwrite the SSL training set above
# (and drop its ViewTransform).
#linear_trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=linear_transform)
#linear_trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=linear_transform)
#linear_trainloader = torch.utils.data.DataLoader(linear_trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

#testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=linear_transform)
#testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=linear_transform)
#testset_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

Files already downloaded and verified


## Model

In [7]:
def projector(in_dim=None, hidden_dim=None):
    """Build the 3-layer MLP projector head placed on top of the encoder.

    Layout:
        Linear(in_dim, hidden) -> ReLU -> BatchNorm1d(hidden)
        -> Linear(hidden, hidden) -> ReLU -> BatchNorm1d(hidden)
        -> Linear(hidden, hidden, bias=False)

    Args:
        in_dim: input width; defaults to the global ``output_enc``.
        hidden_dim: width of every projector layer; defaults to the global ``dim``.

    Returns:
        torch.nn.Sequential containing the 7 modules above.

    NOTE(review): the VICReg reference projector orders Linear -> BatchNorm -> ReLU;
    this keeps the notebook's Linear -> ReLU -> BatchNorm order unchanged.
    """
    if in_dim is None:
        in_dim = output_enc
    if hidden_dim is None:
        hidden_dim = dim

    layers = []
    width_in = in_dim
    for _ in range(2):
        layers += [
            torch.nn.Linear(width_in, hidden_dim),
            # ReLU's only argument is `inplace`; pass it explicitly instead of a
            # layer width that would merely be read as a truthy flag.
            torch.nn.ReLU(inplace=True),
            torch.nn.BatchNorm1d(hidden_dim),
        ]
        width_in = hidden_dim

    layers.append(torch.nn.Linear(hidden_dim, hidden_dim, bias=False))
    return torch.nn.Sequential(*layers)

## VicREG

In [8]:
#VicReg Paper - with modifications
def VIC_Reg(Z, mu=25, la=25, nu=1):
    """VICReg-style loss generalized to any number of views.

    Args:
        Z: list of embedding tensors, one per view, each 2-D with shape (N, D)
           (N rows = samples, D columns = features); all views share N and D.
        mu: weight of the variance (std hinge) term.
        la: weight of the invariance (pairwise MSE) term.
        nu: weight of the covariance (off-diagonal decorrelation) term.

    Returns:
        Scalar tensor ``la * sim + mu * std + nu * cov`` where
        - sim: sum of MSE over all unordered view pairs, divided by 2 * len(Z)
        - std: mean over views of mean_d relu(1 - sqrt(Var[z_d] + 1e-4))
        - cov: mean over views of (sum of squared off-diagonal covariance entries) / D

    NOTE(review): sim sums V*(V-1)/2 pairs but divides by 2*V, so its weight
    relative to std/cov grows with the number of views V (for V=2 it is 1/4 of
    a plain MSE). Confirm this is intended for the "vary views" experiment.
    """
    N = Z[0].shape[0]
    D = Z[0].shape[1]

    # Invariance: every unordered pair of views.
    sim_loss = 0
    for i in range(len(Z)):
        for j in range(i + 1, len(Z)):
            sim_loss += F.mse_loss(Z[i], Z[j])

    # Variance: hinge on the per-feature standard deviation of each view.
    std_loss = 0
    for zi in Z:
        std_zi = torch.sqrt(zi.var(dim=0) + 1e-04)
        std_loss += torch.mean(torch.relu(1 - std_zi))

    # Covariance: build the (D, D) off-diagonal mask once, on the embeddings'
    # device, instead of allocating it on the CPU for every view on every step.
    off_diag = ~torch.eye(D, dtype=torch.bool, device=Z[0].device)
    cov_loss = 0
    for zi in Z:
        zc = zi - zi.mean(dim=0)
        cov_zi = (zc.T @ zc) / (N - 1)
        cov_loss += cov_zi[off_diag].pow(2).sum() / D

    sim_loss /= 2 * len(Z)
    std_loss /= len(Z)
    cov_loss /= len(Z)

    return la * sim_loss + mu * std_loss + nu * cov_loss

In [9]:
# Adapted from https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py

def adjust_learning_rate(optimizer, loader, step, warmup_epochs=10, total_epochs=None, base_lr=None):
    """Set and return the learning rate for a *global* training step.

    Linear warmup from 0 to ``base_lr`` over ``warmup_epochs`` epochs, then a
    cosine decay down to ``0.001 * base_lr`` at the end of ``total_epochs``.

    Args:
        optimizer: object with a ``param_groups`` list of dicts; each group's
            ``"lr"`` entry is overwritten.
        loader: the training loader; only ``len(loader)`` (steps per epoch) is used.
        step: global step index counted from the start of training
            (epoch * len(loader) + batch index), not the per-epoch index.
        warmup_epochs: length of the linear warmup, in epochs (default 10).
        total_epochs: schedule length in epochs; defaults to the global ``epochs``.
        base_lr: peak learning rate; defaults to ``base_lr_ * batch_size / 256``
            computed from the globals (linear scaling rule).

    Returns:
        The learning rate that was written into every param group.
    """
    if total_epochs is None:
        total_epochs = epochs
    if base_lr is None:
        base_lr = base_lr_ * batch_size / 256

    steps_per_epoch = len(loader)
    max_steps = total_epochs * steps_per_epoch
    warmup_steps = warmup_epochs * steps_per_epoch

    if step < warmup_steps:
        lr = base_lr * step / warmup_steps
    else:
        # Guard against a zero-length decay phase (total_epochs <= warmup_epochs)
        # and clamp progress so the cosine does not climb back up after max_steps.
        decay_steps = max(max_steps - warmup_steps, 1)
        progress = min((step - warmup_steps) / decay_steps, 1.0)
        q = 0.5 * (1 + math.cos(math.pi * progress))
        end_lr = base_lr * 0.001
        lr = base_lr * q + end_lr * (1 - q)

    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr

## Train

In [10]:
def train(trainset, offset):
    """Pre-train a ResNet-18 encoder + projector with the VICReg-style loss.

    Args:
        trainset: the DataLoader to iterate over (callers pass a DataLoader,
            despite the name). Each batch is unpacked as ``(X, labels)``, where
            ``X`` is iterated view by view; presumably one image batch per view
            as produced by ViewTransform -- TODO confirm.
        offset: number of consecutive batches whose gradients are accumulated
            before each optimizer step (1 = update after every batch).

    Returns:
        The trained encoder (the projector is discarded). The encoder's
        state_dict is saved to ``models/VicReg/`` after every epoch, and
        per-step metrics are logged to ``logs/example.log``.
    """
    loader = trainset
    accum_steps = max(int(offset), 1)

    encoder = torchvision.models.resnet18()  # also try with pretrained weights
    # Small-image stem: 3x3 stride-1 conv replaces torchvision's default first conv.
    encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)

    proj = projector()
    model = torch.nn.Sequential(encoder, proj).to(device)
    model.train()
    optimizer = LARS(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    ckpt_dir = os.path.join('models', 'VicReg')
    os.makedirs('logs', exist_ok=True)
    os.makedirs(ckpt_dir, exist_ok=True)  # torch.save does not create directories
    logging.basicConfig(filename='logs/example.log', filemode='w', level=logging.INFO)
    starttime = time.time()

    steps_per_epoch = len(loader)
    for i in range(epochs):
        losses = []
        optimizer.zero_grad()
        for step, (X, _) in tqdm(enumerate(loader), total=steps_per_epoch):
            # The schedule expects a global step; the per-epoch index would
            # restart warmup every epoch and never reach the cosine phase.
            # NOTE(review): adjust_learning_rate derives the lr from the global
            # `batch_size` (the per-step batch), while the effective batch is
            # batch_size * offset -- confirm the lr should shrink on OOM retries.
            global_step = i * steps_per_epoch + step
            lr = adjust_learning_rate(optimizer, loader, global_step)

            Z = [model(xi.to(device)) for xi in X]
            loss = VIC_Reg(Z)

            # Gradient accumulation: every batch contributes a scaled gradient;
            # the optimizer steps at the end of each window (and of the epoch).
            (loss / accum_steps).backward()
            if (step + 1) % accum_steps == 0 or step + 1 == steps_per_epoch:
                optimizer.step()
                optimizer.zero_grad()

            loss_value = loss.item()
            losses.append(loss_value)
            logging.info('%s, Epoch: %d, Step: %d, Loss: %.4f, Elapsed: %.1fs, lr: %.6g',
                         time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())),
                         i, step, loss_value, time.time() - starttime, lr)

        print(f"Epoch: {i}, loss: {np.mean(losses)}")

        torch.save(encoder.state_dict(), os.path.join(ckpt_dir, f'model_{batch_size}_epoch_{i}.pt'))

    return encoder

In [12]:
import gc

# Gradient-accumulation factor handed to train() as `offset`. Each OOM halves
# the per-step batch and doubles this, so batch_size * step stays constant.
step = 1

while True:
    try:
        encoder_vicreg = train(trainloader, step)
        break  # training finished; leave the retry loop
    except RuntimeError as e:
        # NOTE(review): only errors whose message contains "out of memory" are
        # retried; confirm the wording of allocation failures on device='cpu'.
        if "out of memory" not in str(e):
            raise
        if batch_size <= 1:
            raise  # cannot shrink the batch any further
        print("Out of memory error occurred. Reducing batch size and retrying...")

    # Only reached after an OOM. `e` (and the traceback frames holding the failed
    # run's model) is released once the except block ends, so memory can be
    # reclaimed here before retrying.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Reduce batch size & conversely increase the accumulation factor.
    batch_size //= 2
    step *= 2

    # Rebuild the loader with the same settings as the original one.
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

  0%|          | 0/25 [00:00<?, ?it/s]