# [BarlowTwins](https://arxiv.org/abs/2103.03230) on CIFAR-10

In [None]:
# For Colab: uncomment the two lines below to clone the repository and install the package
#!git clone https://github.com/FloCF/SSL_pytorch.git
#!pip install SSL_pytorch/

In [None]:
from os import makedirs

import torch
from torch.optim import SGD

from torchvision.models import resnet18

from torchselfsup.models import BarlowTwins
#from torchselfsup.optimizer import LARS
from torchselfsup.scheduler import CosineAnnealingWithWarmupLR
from torchselfsup.trainer import SSL_Trainer
from torchselfsup.utils import SSL_CIFAR10

# ---- Paths ------------------------------------------------------------------
data_root = '../data/'
save_root = '../results/barlow_twins/'

# ---- Data -------------------------------------------------------------------
dl_kwargs = dict(batch_size=512, shuffle=True, num_workers=2)

# NOTE(review): 'BYOL' presumably selects the BYOL-style view augmentations;
# confirm against SSL_CIFAR10.
ssl_data = SSL_CIFAR10(data_root, 'BYOL', dl_kwargs)

# ---- General training parameters --------------------------------------------
train_params = dict(
    save_root=save_root,
    num_epochs=800,
    optimizer=SGD,
    scheduler=CosineAnnealingWithWarmupLR,
    iter_scheduler=True,
    evaluate_at=[1, 50, 100, 200, 400, 600],
    verbose=True,
)

# ---- Optimizer parameters ---------------------------------------------------
# For reference, the original paper's ImageNet setting with the LARS optimizer:
#   lr = 0.2 * batch_size / 256, weight_decay = 1.5e-6,
#   exclude_bias_and_norm = True
# The SGD setting below is taken from
# https://docs.lightly.ai/getting_started/benchmarks.html#cifar10
optim_params = dict(lr=6e-2, momentum=0.9, weight_decay=5e-4)

# ---- Scheduler parameters ---------------------------------------------------
# Epoch count and per-iteration stepping are shared with train_params so the
# two configurations cannot drift apart.
scheduler_params = dict(
    num_epochs=train_params['num_epochs'],
    len_traindl=len(ssl_data.train_dl),
    warmup_epochs=10,
    iter_scheduler=train_params['iter_scheduler'],
    min_lr=1e-4,  # the original setup uses min_lr = 1e-3
)

# ---- Linear evaluation protocol ---------------------------------------------
eval_params = dict(lr=1e-2, num_epochs=25, milestones=[12, 20])

# ---- Device -----------------------------------------------------------------
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Report which device will be used.
print(f"Program running on {torch.cuda.get_device_name(device)}"
      if device.type == 'cuda' else "Program running on CPU")

# Create the results folder if it does not exist yet.
makedirs(save_root, exist_ok=True)

## Training

In [None]:
# Define the encoder: a torchvision ResNet-18 adapted to CIFAR images.
resnet = resnet18(zero_init_residual=True)
# CIFAR specifics: a 3x3, stride-1 first convolution and no max-pooling, so
# the small input images keep more spatial resolution.
resnet.conv1 = torch.nn.Conv2d(3, 64, 3, 1, 1, bias=False)
resnet.maxpool = torch.nn.Identity()
# Deactivate the fully connected head so the encoder returns representations.
resnet.fc = torch.nn.Identity()

# In case you don't know the size of the encoder's representation, probe it
# with a single training sample. The probe must run in eval mode: in train
# mode BatchNorm updates its running statistics on every forward pass, even
# under torch.no_grad(), which would leak this sample into the encoder's
# state before training starts.
# NOTE(review): assumes each batch is (views, labels) with views[0] an image
# tensor; TODO confirm against SSL_CIFAR10.
was_training = resnet.training
resnet.eval()
with torch.no_grad():
    test_out = resnet(next(iter(ssl_data.train_dl))[0][0][:1])
    repre_dim = test_out.shape[1]
resnet.train(was_training)

# Define the SSL model.
model = BarlowTwins(resnet, repre_dim, projector_hidden=(2048, 2048, 2048))

# Define the trainer.
cifar10_trainer = SSL_Trainer(model, ssl_data, device)

# Train.
cifar10_trainer.train(**train_params, optim_params=optim_params,
                      scheduler_params=scheduler_params, eval_params=eval_params)