# Trying Barlow Twins on CIFAR-10

In [None]:
from os import makedirs

import torch
from torchvision.models import resnet18

from models import BarlowTwins
from utils import SSL_CIFAR10
from trainer import cifar10_trainer

# Select the compute device and report it: the first CUDA GPU when CUDA is
# available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print(f"Program running on {torch.cuda.get_device_name(device)}")
else:
    device = torch.device('cpu')
    print("Program running on CPU")

# ---- Paths -----------------------------------------------------------------
data_root = '/home/space/datasets/'
save_root = './results/barlow_twins/'
# Create the output directory up front; an already existing one is fine.
makedirs(save_root, exist_ok=True)

# ---- Data loading ----------------------------------------------------------
dl_kwargs = {
    'batch_size': 512,
    'shuffle': True,
    'num_workers': 2,
}

# ---- Optimisation ----------------------------------------------------------
# The learning rate is scaled linearly with the batch size: 0.2 per 256 samples.
optim_params = {
    'lr': 0.2 * dl_kwargs['batch_size'] / 256,
    'weight_decay': 1.5e-6,
    'exclude_bias_and_norm': True,
}

# ---- Training / evaluation schedule ----------------------------------------
# NOTE(review): the key 'warmup_epchs' looks misspelled but is kept verbatim;
# presumably the trainer looks up exactly this key -- confirm before renaming.
train_params = {
    'num_epochs': 800,
    'warmup_epchs': 10,
    'eta_min': 1e-3,
    'start_epoch': 0,
}
eval_params = {
    'evaluate_at': [100, 200, 400, 600],
    'lr': 1e-2,
    'num_epochs': 50,
    'milestones': [30, 40],
}

# Barlow Twins: data, model and training

In [None]:
# Data: CIFAR-10 wrapped by the project's SSL_CIFAR10 helper, built with the
# 'BYOL' option (presumably selects the BYOL augmentation set -- confirm in utils).
byol_cifar10 = SSL_CIFAR10(data_root, 'BYOL', dl_kwargs)

# Backbone: a ResNet-18 constructed with zero_init_residual=True.
resnet = resnet18(zero_init_residual=True)
# Input width of the backbone's final fully connected layer
# (not used again within this cell).
repre_dim = resnet.fc.in_features

# Model: wrap the backbone in BarlowTwins with three 2048-wide projector
# layers, then move it to the selected device.
barlow_twins = BarlowTwins(resnet, projector_hidden=(2048, 2048, 2048))
barlow_twins = barlow_twins.to(device)

# Train: hand everything to the project's CIFAR-10 trainer.
cifar10_trainer(
    save_root,
    barlow_twins,
    byol_cifar10,
    optim_params,
    train_params,
    eval_params,
)