In [None]:
import torchvision
from torchvision import transforms
import torch
import torch.nn.functional as F
import torchvision.transforms as T
from torch.optim.lr_scheduler import CosineAnnealingLR

import numpy as np
import random
from tqdm import tqdm
from matplotlib import pyplot as plt
import os
import math
import glob 
import logging
import time
import psutil
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

from view_transform import ViewTransform
from LARS import LARS

## Hyperparameters

In [None]:
'''
TODO
-->> It might be interesting to investigate the efficiency frontier between max_batch and views
'''

In [None]:
# Seed every random number generator imported by this notebook so that a
# Restart & Run All produces repeatable results.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

warmup_steps = 10
start_epoch = 1
epochs = 100 # Original set to 1000
output_enc = 512  # width of the encoder output fed to the projector / linear head
dim = 8192        # width of every projector layer
num_views = 4
offset = 1        # number of batches per optimizer update

num_workers = 2
# Use the GPU when one is present and fall back to the CPU otherwise.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

batch_size = 2048

# VicREG
base_lr_ = 0.2
lr_head = 0.3
learning_rate = batch_size/256 * base_lr_
weight_decay = 1e-6

# The log directory must exist before basicConfig opens the file handler,
# otherwise a fresh checkout fails here with FileNotFoundError.
# NOTE(review): ':' is not a legal filename character on Windows.
os.makedirs('logs', exist_ok=True)
logging.basicConfig(filename=f'logs/b:{batch_size}.log', filemode='w', level=logging.INFO)


## Data 

In [None]:
# Number of target classes, used later to size the linear classification head.
num_classes = 10

# Alternative dataset kept for reference:
#trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True)
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
)
# The multi-view transform is attached after construction.
# presumably ViewTransform(n) yields n augmented views per image — confirm in view_transform.py
trainset.transform = ViewTransform(2)

trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
)

## Model

In [None]:
def projector(in_dim=None, hidden_dim=None):
    """Build the projector MLP that maps encoder features to the embedding space.

    Layout (module indices are unchanged, so existing checkpoints still load):
    Flatten -> [Linear -> ReLU -> BatchNorm1d] x 2 -> Linear (no bias).

    Args:
        in_dim: width of the flattened encoder output. Defaults to the
            notebook-level ``output_enc``.
        hidden_dim: width of every projector layer, including the output.
            Defaults to the notebook-level ``dim``.

    Returns:
        A ``torch.nn.Sequential`` holding the eight layers above.
    """
    in_dim = output_enc if in_dim is None else in_dim
    hidden_dim = dim if hidden_dim is None else hidden_dim

    return torch.nn.Sequential(
        torch.nn.Flatten(),

        torch.nn.Linear(in_dim, hidden_dim),
        # ReLU's only argument is the `inplace` flag. The previous `ReLU(dim)`
        # passed the layer width there, which merely happened to be truthy;
        # spell the flag out so the in-place behaviour is explicit.
        torch.nn.ReLU(inplace=True),
        torch.nn.BatchNorm1d(hidden_dim),

        torch.nn.Linear(hidden_dim, hidden_dim),
        torch.nn.ReLU(inplace=True),
        torch.nn.BatchNorm1d(hidden_dim),

        torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
    )

## VicREG

In [None]:
#VicReg Paper - with modifications
def VIC_Reg(Z, la=25, mu=25, nu=1):
    """Multi-view VICReg loss.

    Args:
        Z: sequence of 2-D embedding tensors, one per view, all of shape
            (N, D) and on the same device.
        la: weight of the invariance (pairwise MSE) term.
        mu: weight of the variance (hinge on per-dimension std) term.
        nu: weight of the covariance (squared off-diagonal) term.

    Returns:
        Scalar tensor ``la * sim + mu * std + nu * cov`` where ``sim`` is
        averaged over all unordered view pairs and ``std`` / ``cov`` are
        averaged over views. With a single view the invariance term is 0.
    """
    n_views = len(Z)
    N, D = Z[0].shape[0], Z[0].shape[1]

    # Invariance: mean squared distance between every unordered pair of views.
    sim_loss = 0
    for i in range(n_views):
        for j in range(i + 1, n_views):
            sim_loss = sim_loss + F.mse_loss(Z[i], Z[j])
    n_pairs = n_views * (n_views - 1) // 2
    if n_pairs > 0:  # a single view has no pairs; avoid dividing by zero
        sim_loss = sim_loss / n_pairs

    # Build the off-diagonal mask once, on the embeddings' own device rather
    # than on whatever the notebook-level `device` happens to be.
    off_diagonal = ~torch.eye(D, dtype=torch.bool, device=Z[0].device)

    std_loss = 0
    cov_loss = 0
    for zi in Z:
        # Variance: push the std of every embedding dimension up towards 1.
        std_zi = torch.sqrt(zi.var(dim=0) + 1e-04)
        std_loss = std_loss + torch.mean(torch.relu(1 - std_zi))

        # Covariance: penalise squared off-diagonal entries of the covariance.
        centered = zi - zi.mean(dim=0)
        cov_zi = (centered.T @ centered) / (N - 1)
        cov_loss = cov_loss + cov_zi[off_diagonal].pow(2).sum() / D

    std_loss = std_loss / n_views
    cov_loss = cov_loss / n_views

    logging.info('IL: %.3f, STDL: %.3f, CVL: %.3f',la * sim_loss, mu * std_loss, nu * cov_loss)

    loss = la * sim_loss + mu * std_loss + nu * cov_loss

    return loss

In [None]:
# Copied from https://github.com/facebookresearch/vicreg/blob/main/main_vicreg.py

def adjust_learning_rate(optimizer, loader, step):
    """Set and return the learning rate for training step `step`.

    The rate ramps linearly from 0 over the first ten epochs' worth of steps
    (10 * len(loader)) and then follows a half-cosine from the peak rate
    down to 0.1% of it at `epochs * len(loader)` steps. The peak rate is
    `base_lr_ * batch_size / 256`, read from the notebook-level settings.

    Every entry in `optimizer.param_groups` receives the same rate.
    """
    batches_per_epoch = len(loader)
    total_steps = epochs * batches_per_epoch
    ramp_steps = 10 * batches_per_epoch
    peak_lr = base_lr_ * batch_size / 256

    if step < ramp_steps:
        # Linear warmup.
        new_lr = peak_lr * step / ramp_steps
    else:
        # Cosine decay measured from the end of the warmup.
        done = step - ramp_steps
        span = total_steps - ramp_steps
        cosine = 0.5 * (1 + math.cos(math.pi * done / span))
        floor_lr = peak_lr * 0.001
        new_lr = peak_lr * cosine + floor_lr * (1 - cosine)

    for group in optimizer.param_groups:
        group["lr"] = new_lr
    return new_lr

## Train

In [None]:
def train(trainset, offset, batch_size=2048, load_from_checkpoint=""):
    """Train a ResNet-18 encoder plus projector with the multi-view VICReg loss.

    Args:
        trainset: the data to train on. The callers in this notebook pass an
            already-built ``DataLoader`` here (despite the parameter name),
            and that loader is what gets iterated. A plain dataset is also
            accepted and is wrapped in a shuffling ``DataLoader`` of
            ``batch_size``. Each batch must unpack as ``(X, _)`` where
            iterating ``X`` yields one tensor per view.
        offset: number of consecutive batches whose gradients are accumulated
            before a single optimizer update (1 = update on every batch).
        batch_size: used in the checkpoint file name, and as the loader batch
            size when ``trainset`` is not already a ``DataLoader``.
        load_from_checkpoint: path of a checkpoint written by this function;
            training resumes from the epoch after the saved one. Empty string
            starts from scratch.

    Returns:
        The encoder module (it shares its weights with the trained model).

    Side effects:
        Logs one line per batch, prints the mean loss per epoch, and writes a
        checkpoint to ``models/`` every 10 epochs.
    """
    # Iterate the loader that was passed in rather than the notebook-level
    # `trainloader`, which may belong to a different experiment.
    if isinstance(trainset, torch.utils.data.DataLoader):
        loader = trainset
    else:
        loader = torch.utils.data.DataLoader(
            trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
        )
    steps_per_epoch = len(loader)
    accumulation = max(1, int(offset))

    encoder = torchvision.models.resnet18() # also try with pretrained=true
    # 3x3 stride-1 stem in place of the stock conv1.
    encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Drop the final classification layer; the projector consumes the pooled features.
    encoder = torch.nn.Sequential(*list(encoder.children())[:-1])

    # Move the model to the device BEFORE the optimizer state is loaded, so the
    # restored optimizer tensors end up on the same device as the parameters.
    model = torch.nn.Sequential(encoder, projector()).to(device)
    optimizer = LARS(model.parameters(),lr=0,weight_decay=weight_decay)

    first_epoch = start_epoch
    if load_from_checkpoint != "":
        checkpoint = torch.load(load_from_checkpoint, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Continue after the last completed epoch instead of starting over.
        first_epoch = checkpoint['epoch'] + 1
    model.train()

    os.makedirs('models', exist_ok=True)
    views_seen = num_views  # replaced by the real view count once a batch has run
    last_time = time.time()

    for i in range(first_epoch, epochs+1):
        losses = []
        optimizer.zero_grad()
        for step, (X, _) in tqdm(enumerate(loader), total=steps_per_epoch):

            if step % accumulation == 0:
                # The schedule expects the step counted from the start of
                # training (epochs are 1-indexed), not the step within the epoch.
                adjust_learning_rate(optimizer, loader, (i - 1) * steps_per_epoch + step)

            Z = [model(xi.to(device)) for xi in X]
            views_seen = len(Z)

            loss = VIC_Reg(Z)
            loss_value = loss.item()
            logging.info('%s ,Epoch: %d, Step: %d, Loss: %.3f, Elapsed: %d, View: %d', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), i, step, loss_value, time.time()-last_time, len(Z))

            # Accumulate gradients on every batch (scaled so the sum is an
            # average) and update once per `accumulation` batches. A trailing
            # partial group at the end of the epoch is applied as well.
            (loss / accumulation).backward()
            if (step + 1) % accumulation == 0 or step + 1 == steps_per_epoch:
                optimizer.step()
                optimizer.zero_grad()
            losses.append(loss_value)

            last_time = time.time()

        print(f"Epoch: {i}, loss: {np.mean(losses)}")

        if i%10==0:
            # Saved under models/ (where the evaluation cells look) and named
            # with the actual view count so runs with different view counts
            # do not overwrite each other.
            torch.save({
                'epoch': i,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': np.mean(losses),
                }, f'models/model_{views_seen}_{batch_size}_epoch_{i}.pt')

    return encoder

### Fixed Batch - Increasing Views

In [None]:
# Experiment: keep the nominal batch size fixed and increase the number of views.
view_counts = [2, 4, 6, 8]

os.makedirs('logs', exist_ok=True)
os.makedirs('models', exist_ok=True)

for i in view_counts:
    print(f"View: {i} / {view_counts[-1]}")
    view = i

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
    trainset.transform = ViewTransform(view)

    # Per-run copies: an out-of-memory back-off for one view count must not
    # leak into the next one or into later cells through the notebook-level
    # `batch_size` / `offset`.
    run_batch_size = batch_size
    run_offset = offset

    while True:
        trainloader = torch.utils.data.DataLoader(trainset, batch_size=run_batch_size, shuffle=True, num_workers=num_workers)
        hit_oom = False
        try:
            encoder_vicreg = train(trainloader, run_offset, batch_size)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise
            hit_oom = True

        if not hit_oom:
            break  #  successful break out of the loop

        # Handled outside the `except` block so the failed attempt's
        # traceback no longer pins its tensors when the cache is released.
        print("Out of memory error occurred. Reducing batch size and retrying...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        # Reduce batch size & conversely increase the number of batches per update
        run_batch_size //= 2
        run_offset *= 2
        if run_batch_size < 1:
            raise RuntimeError(f"Cannot fit even a single sample per batch for {view} views")

### Fixed Total Gradients - Increase Views & Reduce Batch Size

In [None]:
# Experiment: keep (views x batch size) fixed, so more views means a smaller batch.
view_counts = [4, 6, 8]

os.makedirs('logs', exist_ok=True)
os.makedirs('models', exist_ok=True)

# Every view count is derived from the SAME reference batch size. Updating
# `batch_size` from its own previous value compounded the shrink from one
# iteration to the next and produced a float, which DataLoader rejects.
base_batch_size = int(batch_size)

try:
    for i in view_counts:
        print(f"View: {i} / {view_counts[-1]}")
        view = i

        # `batch_size` stays a notebook-level name because adjust_learning_rate
        # reads it to scale the learning rate.
        batch_size = max(1, (2 * base_batch_size) // i)

        trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
        trainset.transform = ViewTransform(view)

        # Per-run copies used for the out-of-memory back-off only.
        run_batch_size = batch_size
        run_offset = offset

        while True:
            trainloader = torch.utils.data.DataLoader(trainset, batch_size=run_batch_size, shuffle=True, num_workers=num_workers)
            hit_oom = False
            try:
                encoder_vicreg = train(trainloader, run_offset, batch_size)
            except RuntimeError as e:
                if "out of memory" not in str(e):
                    raise
                hit_oom = True

            if not hit_oom:
                break  #  successful break out of the loop

            # Handled outside the `except` block so the failed attempt's
            # traceback no longer pins its tensors when the cache is released.
            print("Out of memory error occurred. Reducing batch size and retrying...")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            # Reduce batch size & conversely increase the number of batches per update
            run_batch_size //= 2
            run_offset *= 2
            if run_batch_size < 1:
                raise RuntimeError(f"Cannot fit even a single sample per batch for {view} views")
finally:
    # Leave the configured batch size as it was so later cells (and a re-run
    # of this one) start from the same value.
    batch_size = base_batch_size

## Model Evaluations - Beyond Loss (kNN, LinearHead)

In [None]:
# Channel-wise normalisation shared by the evaluation train and test pipelines.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

# Training split: random crop + horizontal flip before tensor conversion.
train_eval_transform = transforms.Compose([
    transforms.RandomResizedCrop(32),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=train_eval_transform
)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers
)

# Test split: no augmentation, fixed order.
test_eval_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=test_eval_transform
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)


In [None]:
def load_all_models(path="./models/"):
    """Return the paths of all ``.pt`` checkpoints directly inside `path`.

    `path` may be given with or without a trailing separator. The result is
    sorted so that repeated runs process checkpoints in the same order.
    """
    return sorted(glob.glob(os.path.join(path, "*.pt")))

def load_final_models(path="./models/", final_epoch=100):
    """Return the checkpoints in `path` that were saved at epoch `final_epoch`.

    Checkpoint names are expected to end in ``_epoch_<N>.pt``. Matching on
    that suffix, rather than on "100" appearing anywhere in the name, keeps
    out files such as ``..._epoch_1000.pt`` or a batch size containing "100".
    `path` may be given with or without a trailing separator; the result is
    sorted.
    """
    return sorted(glob.glob(os.path.join(path, f"*_epoch_{final_epoch}.pt")))
# ChatGPT &/ PyTorch topK
def top_k_accuracy(output, target, k=1):
    """Return the top-`k` accuracy, in percent, as a Python float.

    `output` holds one row of class scores per sample and `target` the
    matching class indices. A sample counts as correct when its target is
    among the `k` highest-scoring classes of its row.
    """
    # Indices of the k best classes per sample, one row per sample.
    best_classes = output.topk(k, dim=1, largest=True, sorted=True).indices
    # Compare every candidate against that sample's target.
    matches = best_classes.eq(target.view(-1, 1).expand_as(best_classes))
    n_correct = matches.reshape(-1).float().sum()
    percent = n_correct * (100.0 / target.size(0))
    return percent.item()


In [None]:
# k-nearest-neighbour evaluation of every saved checkpoint on frozen encoder features.
model_paths = load_all_models()

for model_path in tqdm(model_paths, total=len(model_paths)):

    encoder = torchvision.models.resnet18()
    encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    encoder = torch.nn.Sequential(*list(encoder.children())[:-1])

    checkpoint = torch.load(model_path, map_location=device)
    # train() saves the state of Sequential(encoder, projector), so the
    # encoder's weights are stored under the "0." prefix. Keep only those
    # and strip the prefix so the keys match a bare encoder.
    encoder_state = {
        key[len("0."):]: value
        for key, value in checkpoint['model_state_dict'].items()
        if key.startswith("0.")
    }
    encoder.load_state_dict(encoder_state)
    encoder.eval()
    encoder.to(device)

    X_train_embedding = []
    X_test_embedding = []
    y_train = []
    y_test = []

    # scikit-learn needs 2-D CPU arrays: disable autograd, flatten the pooled
    # (batch, features, 1, 1) output and move each batch off the device.
    # NOTE(review): train_loader applies random crops/flips and shuffles, so
    # the training features differ between runs.
    with torch.no_grad():
        for x,y in train_loader:
            x = x.to(device)
            X_train_embedding.append(encoder(x).flatten(start_dim=1).cpu().numpy())
            y_train.append(y.numpy())

        for x,y in test_loader:
            x = x.to(device)
            X_test_embedding.append(encoder(x).flatten(start_dim=1).cpu().numpy())
            y_test.append(y.numpy())

    X_train_embedding = np.concatenate(X_train_embedding)
    X_test_embedding = np.concatenate(X_test_embedding)
    y_train = np.concatenate(y_train)
    y_test = np.concatenate(y_test)

    knn = KNeighborsClassifier()
    knn.fit(X_train_embedding, y_train)
    X_test_predicted = knn.predict(X_test_embedding)
    accuracy = accuracy_score(y_test, X_test_predicted)
    logging.info("Model: %s, kNN-Accuracy: %.3f", model_path, accuracy)

In [None]:
# Linear-probe evaluation: train a linear classifier on top of the frozen encoder.
final_model_paths = load_final_models()

for model_path in tqdm(final_model_paths, total=len(final_model_paths)):

    encoder = torchvision.models.resnet18()
    encoder.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    # Same architecture as in training: without the final fc layer, so the
    # encoder emits `output_enc` features instead of 1000 logits.
    encoder = torch.nn.Sequential(*list(encoder.children())[:-1])

    checkpoint = torch.load(model_path, map_location=device)
    # The checkpoint stores Sequential(encoder, projector); the encoder's
    # weights live under the "0." prefix.
    encoder_state = {
        key[len("0."):]: value
        for key, value in checkpoint['model_state_dict'].items()
        if key.startswith("0.")
    }
    encoder.load_state_dict(encoder_state)
    encoder.requires_grad_(False)
    encoder.eval()
    encoder.to(device)

    head = torch.nn.Linear(output_enc, num_classes)
    head.weight.data.normal_(mean=0.0, std=0.01)
    head.bias.data.zero_()
    # Flatten turns the pooled (batch, features, 1, 1) output into (batch, features).
    model = torch.nn.Sequential(encoder, torch.nn.Flatten(), head).to(device)

    param_groups = [dict(params=head.parameters(), lr=lr_head)]

    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(param_groups, 0, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)

    for i in range(epochs):
        # Train on the single-view labelled loader, not the multi-view
        # `trainloader` used for self-supervised training.
        for step, (X, Y) in tqdm(enumerate(train_loader), total=len(train_loader), leave=False):
            X = X.to(device)
            Y = Y.to(device)
            output = model(X)
            loss = criterion(output, Y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        scheduler.step()

        # Evaluate once per epoch. Batches are weighted by their size so a
        # short final batch does not skew the average.
        with torch.no_grad():
            weighted_accuracy = 0.0
            n_samples = 0
            for images, targets in test_loader:
                images = images.to(device)
                targets = targets.to(device)
                output = model(images)
                weighted_accuracy += top_k_accuracy(output, targets) * targets.size(0)
                n_samples += targets.size(0)

            accuracy = weighted_accuracy / max(n_samples, 1)

        logging.info("Model: %s, Accuracy: %.3f, Epoch: %d", model_path, accuracy, i)

## Visualization

In [None]:
# Parse the per-step training log into typed columns.
# NOTE(review): this path differs from the file configured by logging.basicConfig
# in the hyperparameter cell (logs/b:<batch_size>.log) — confirm which log is meant.
columns = ["Info", "Epoch", "Step", "Loss", "Elapsed", "Views"]

# Passing `names` fixes the column count at six. Without it pandas takes the
# column count from the first row, and the three-field loss-component lines
# that VIC_Reg writes to the same logger make the parse fail.
df = pd.read_csv("logs/b_2048_v_2.log", delimiter=',', header=None, names=columns)
# Shorter rows are padded with NaN; drop them so only per-step rows remain.
df = df.dropna(subset=columns).reset_index(drop=True)

# Each field looks like "Label: value"; keep just the number.
df["Epoch"] = df["Epoch"].str.extract(r"(\d+)", expand=False).astype(int)
df["Step"] = df["Step"].str.extract(r"(\d+)", expand=False).astype(int)
df["Loss"] = df["Loss"].str.extract(r"(\d+\.\d+)", expand=False).astype(float)
df["Elapsed"] = df["Elapsed"].str.extract(r"(\d+)", expand=False).astype(float)
df["Views"] = df["Views"].str.extract(r"(\d+)", expand=False).astype(int)

In [None]:
# Figure 1: raw per-step loss, one panel per view count (2, 3, ..., 11).
fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(10, 8))

for i, ax in enumerate(axes.flatten()):
    grouped_data = df[df["Views"] == i+2]
    ax.plot(grouped_data["Loss"])
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.set_title(f"View {i+2}")

fig.tight_layout()

# Figure 2: mean loss per epoch for each view count, on its own axes.
# Previously these curves were drawn on whichever panel `ax` was left pointing
# at by the loop above, and the un-averaged per-step loss was plotted under an
# "Average Loss" label.
fig_summary, ax_summary = plt.subplots(figsize=(10, 4))

for i in range(2,num_views+1):
    epoch_mean = df[df["Views"] == i].groupby("Epoch")["Loss"].mean()
    if epoch_mean.empty:
        continue  # this view count does not occur in the log
    ax_summary.plot(epoch_mean.index, epoch_mean.values, label=f"{i} Views")

ax_summary.set_xlabel("Epoch")
ax_summary.set_ylabel("Average Loss")

# Get the handles and labels from the plots
handles, labels = ax_summary.get_legend_handles_labels()

# Place the legend outside the plot
if handles:
    fig_summary.legend(handles, labels, loc='center left', bbox_to_anchor=(1.0, 0.5))

plt.show()

In [None]:
# Report the final logged loss for each view count.
for view in range(2,num_views+1):
    view_losses = df.loc[df["Views"] == view, "Loss"]
    if view_losses.empty:
        # The loaded log may not cover every view count; indexing the last
        # value of an empty selection would raise IndexError.
        print(f"No log entries for View {view}")
        continue
    last_loss = view_losses.iloc[-1]
    print(f"Last loss for View {view}: {last_loss}")

In [None]:
# Load the batch-size-128 log without parsing its fields and name the columns.
columns = ["Info", "Epoch", "Step", "Loss", "Elapsed", "Views"]
df = pd.read_csv("logs/b_128_v_2.log", sep=',', header=None)
df.columns = columns

In [None]:
# Parse the batch-size-128 training log into typed columns.
columns = ["Info", "Epoch", "Step", "Loss", "Elapsed", "Views"]

# Passing `names` fixes the column count at six. Without it pandas takes the
# column count from the first row, and the three-field loss-component lines
# that VIC_Reg writes to the same logger make the parse fail.
df = pd.read_csv("logs/b_128_v_2.log", delimiter=',', header=None, names=columns)
# Shorter rows are padded with NaN; drop them so only per-step rows remain.
df = df.dropna(subset=columns).reset_index(drop=True)

# Each field looks like "Label: value"; keep just the number.
df["Epoch"] = df["Epoch"].str.extract(r"(\d+)", expand=False).astype(int)
df["Step"] = df["Step"].str.extract(r"(\d+)", expand=False).astype(int)
df["Loss"] = df["Loss"].str.extract(r"(\d+\.\d+)", expand=False).astype(float)
df["Elapsed"] = df["Elapsed"].str.extract(r"(\d+)", expand=False).astype(float)
df["Views"] = df["Views"].str.extract(r"(\d+)", expand=False).astype(int)

In [None]:
# Figure 1: raw per-step loss, one panel per view count (2, 3, 4, 5).
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(10, 8))

for i, ax in enumerate(axes.flatten()):
    grouped_data = df[df["Views"] == i+2]
    ax.plot(grouped_data["Loss"])
    ax.set_xlabel("Steps")
    ax.set_ylabel("Loss")
    ax.set_title(f"View {i+2}")

fig.tight_layout()

# Figure 2: mean loss per epoch for 2 and 3 views, on its own axes.
# Previously these curves were drawn on whichever panel `ax` was left pointing
# at by the loop above, and the un-averaged per-step loss was plotted under an
# "Average Loss" label.
fig_summary, ax_summary = plt.subplots(figsize=(10, 4))

for i in range(2,3+1):
    epoch_mean = df[df["Views"] == i].groupby("Epoch")["Loss"].mean()
    if epoch_mean.empty:
        continue  # this view count does not occur in the log
    ax_summary.plot(epoch_mean.index, epoch_mean.values, label=f"{i} Views")

ax_summary.set_xlabel("Epoch")
ax_summary.set_ylabel("Average Loss")

# Get the handles and labels from the plots
handles, labels = ax_summary.get_legend_handles_labels()

# Place the legend outside the plot
if handles:
    fig_summary.legend(handles, labels, loc='center left', bbox_to_anchor=(1.0, 0.5))

plt.show()
