# [VICReg](https://arxiv.org/abs/2105.04906) on CIFAR-10

In [1]:
from os import makedirs

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchvision.models import resnet18

from models import VICReg
#from optimizer import LARS
from trainer import SSL_Trainer
from utils import SSL_CIFAR10

# Reproducibility: seed torch's global RNG before the data pipeline and the
# model are created, so repeated "Restart & Run All" executions start from
# the same random state. Only torch's generator is seeded here; Python's
# `random` and NumPy are not touched by this cell.
seed = 0
torch.manual_seed(seed)

# Define hyperparameters
data_root = './data/'            # root folder handed to SSL_CIFAR10
save_root = './results/vic_reg/' # output folder (created at the end of this cell)

# Keyword arguments handed to SSL_CIFAR10
# (presumably forwarded to its DataLoaders -- confirm in utils.SSL_CIFAR10)
dl_kwargs = {'batch_size': 512, 'shuffle': True, 'num_workers': 2}

# Define data
ssl_data = SSL_CIFAR10(data_root,'VICReg', dl_kwargs)

# General training params, unpacked into SSL_Trainer.train(**train_params, ...)
train_params = {'save_root': save_root, 'num_epochs': 800, 'optimizer': SGD,
                'scheduler': CosineAnnealingLR, 'warmup_epochs': 10, 'iter_scheduler':True,
                'evaluate_at': [100,200,400,600], 'verbose':True}

# Params of optimizer
## In the original paper (ImageNet, LARS optimizer):
#optim_params = {'lr':0.2 * dl_kwargs['batch_size']/256,
#                'weight_decay': 1.5e-6, 'exclude_bias_and_norm': True}

# SGD settings taken from: https://github.com/IgorSusmelj/barlowtwins/blob/main/main.py
optim_params = {'lr': 1e-3, 'momentum': 0.9, 'weight_decay': 5e-4}

# Params of scheduler
# T_max = (epochs remaining after warm-up) * len(ssl_data.train_dl), i.e. it is
# expressed in units of train_dl items rather than epochs. This matches
# 'iter_scheduler': True above (presumably the trainer steps the scheduler once
# per iteration -- confirm in trainer.SSL_Trainer).
scheduler_params = {'T_max': (train_params['num_epochs']-train_params['warmup_epochs'])*len(ssl_data.train_dl)}
                    # 'eta_min': 1e-3} in original implementation

# Parameters for fitting the linear evaluation protocol
eval_params  = {'lr':1e-2, 'num_epochs': 25, 'milestones': [12,20]}

# Get device: first CUDA GPU if one is available, otherwise the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Print device type
if torch.cuda.is_available():
    print(f"Program running on {torch.cuda.get_device_name(device)}")
else:
    print("Program running on CPU")

# Create the output folder if it does not exist yet
makedirs(save_root, exist_ok=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Program running on NVIDIA A100-PCIE-40GB


# VICReg

In [None]:
# Backbone: torchvision ResNet-18 built with zero_init_residual=True
resnet = resnet18(zero_init_residual=True)

# CIFAR-specific stem: swap the first convolution for a 3x3, stride-1,
# padding-1 convolution without bias, and turn the max-pool into a no-op.
resnet.conv1 = torch.nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
resnet.maxpool = torch.nn.Identity()

# Wrap the backbone in VICReg with three projector hidden sizes of 2048
vicreg = VICReg(resnet, projector_hidden=(2048,) * 3)

# Trainer bound to the model, the SSL data object and the selected device
cifar10_trainer = SSL_Trainer(vicreg, ssl_data, device)

# Launch training: optimizer / scheduler / evaluation settings are passed
# explicitly, the remaining options are unpacked from train_params.
cifar10_trainer.train(
    optim_params=optim_params,
    scheduler_params=scheduler_params,
    eval_params=eval_params,
    **train_params,
)

  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 0, Loss: 39.735168221070595, Time epoch: 204.89725589752197
Epoch: 1, Loss: 39.090835964556824, Time epoch: 205.3641095161438
Epoch: 2, Loss: 38.87787211310003, Time epoch: 203.807523727417
Epoch: 3, Loss: 38.71969789328034, Time epoch: 203.9682457447052
Epoch: 4, Loss: 38.579558736270236, Time epoch: 204.86553359031677
Epoch: 5, Loss: 38.505296451529276, Time epoch: 204.0614631175995
Epoch: 6, Loss: 38.40654510812661, Time epoch: 202.76366353034973
Epoch: 7, Loss: 38.33438889021726, Time epoch: 207.45458030700684
Epoch: 8, Loss: 38.282954186508334, Time epoch: 208.33787631988525
Epoch: 9, Loss: 38.23350371527918, Time epoch: 208.6044797897339
Epoch: 10, Loss: 38.18447337199732, Time epoch: 213.4902629852295
