# Trying BarlowTwins on CIFAR-10

In [1]:
from os import makedirs

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

from torchvision.models import resnet18

from models import BarlowTwins
from optimizer import LARS
from trainer import SSL_Trainer
from utils import SSL_CIFAR10

# ---------------------------------------------------------------------------
# Configuration for the Barlow Twins / CIFAR-10 run
# ---------------------------------------------------------------------------
# Where the dataset is read from and where results are written to.
data_root = '/home/space/datasets/'
#data_root = '/home/fcfschulz/Documents/workspace/data/Vision/'
save_root = './results/barlow_twins/'

# Keyword arguments handed to SSL_CIFAR10 for its data loaders.
dl_kwargs = dict(batch_size=512, shuffle=True, num_workers=2)

# Data wrapper. It must exist before scheduler_params is built, because
# T_max below depends on len(ssl_data.train_dl).
# NOTE(review): the augmentation preset is named 'BYOL' although the model is
# Barlow Twins -- presumably intentional (shared augmentations); confirm.
ssl_data = SSL_CIFAR10(data_root, 'BYOL', dl_kwargs)

# Arguments unpacked into SSL_Trainer.train(...).
train_params = dict(
    save_root=save_root,
    num_epochs=800,
    optimizer=LARS,
    scheduler=CosineAnnealingLR,
    warmup_epochs=10,
    iter_scheduler=True,
    evaluate_at=[1, 100, 200, 400, 600],
    verbose=True,
)

# Optimizer settings: the learning rate is 0.2 scaled by batch_size / 256.
optim_params = dict(
    lr=0.2 * dl_kwargs['batch_size'] / 256,
    weight_decay=1.5e-6,
    exclude_bias_and_norm=True,
)

# Cosine schedule length: (epochs after warm-up) * len(train_dl).
# Presumably len(train_dl) is the number of batches per epoch, which would make
# T_max a count of iterations (matching iter_scheduler=True) -- TODO confirm.
epochs_after_warmup = train_params['num_epochs'] - train_params['warmup_epochs']
scheduler_params = dict(
    T_max=epochs_after_warmup * len(ssl_data.train_dl),
    eta_min=1e-3,
)

# Settings passed to the trainer's evaluation routine.
# NOTE(review): milestones (30, 40) lie beyond num_epochs (5); if they are
# LR-decay epochs they would never be reached -- verify against SSL_Trainer.
eval_params = dict(lr=1e-2, num_epochs=5, milestones=[30, 40])

# Pick the first CUDA device when one is available, otherwise run on the CPU,
# and report which one is in use.
cuda_available = torch.cuda.is_available()
device = torch.device('cuda:0' if cuda_available else 'cpu')
if cuda_available:
    print(f"Program running on {torch.cuda.get_device_name(device)}")
else:
    print("Program running on CPU")

# Create the output folder if it does not exist yet.
makedirs(save_root, exist_ok=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Program running on Tesla P100-PCIE-12GB


# BarlowTwins

In [None]:
# Define Model
# ResNet-18 backbone wrapped by the project's BarlowTwins module, using a
# three-layer projector with 2048 hidden units per layer, moved to `device`.
resnet = resnet18(zero_init_residual=True)

barlow_twins = BarlowTwins(resnet, projector_hidden = (2048,2048,2048)).to(device)

# Define Trainer
cifar10_trainer = SSL_Trainer(barlow_twins, ssl_data, device)

# Train
# NOTE(review): the recorded output of this cell shows an interactive
# "Pretrained model available, use it?[y/n]" prompt and training resuming at
# epoch 17, so this cell presumably blocks on stdin when a checkpoint exists
# under save_root -- keep that in mind for unattended "Run All".
cifar10_trainer.train(**train_params, optim_params=optim_params,
                      scheduler_params=scheduler_params, eval_params=eval_params)

# Evaluate
cifar10_trainer.evaluate(eval_params)

# Report the most recent KNN and linear-probe accuracies.
# Both fragments must be f-strings: adjacent literals are concatenated at
# compile time, but the `f` prefix applies to each fragment separately, so an
# unprefixed second fragment would print the placeholder text verbatim.
print(f'Accuracy after Training: KNN:{cifar10_trainer.eval_acc["knn"][-1]},'
      f'Linear: {cifar10_trainer.eval_acc["lin"][-1]}')

Pretrained model available, use it?[y/n]: y
Epoch: 17, Loss: 648.5309328688788, Time epoch: 70.1365442276001
Epoch: 18, Loss: 638.0418782971569, Time epoch: 69.61742854118347
Epoch: 19, Loss: 633.3371336632168, Time epoch: 69.57933115959167
Epoch: 20, Loss: 621.7218137131524, Time epoch: 70.32259368896484
Epoch: 21, Loss: 618.272109198816, Time epoch: 70.41536068916321
Epoch: 22, Loss: 612.5374743274807, Time epoch: 70.00139546394348
Epoch: 23, Loss: 611.9205171250806, Time epoch: 69.61311101913452
Epoch: 24, Loss: 605.7155528904236, Time epoch: 69.81289267539978
Epoch: 25, Loss: 602.1865284713273, Time epoch: 69.87919759750366
Epoch: 26, Loss: 592.8337830219073, Time epoch: 69.83821249008179
Epoch: 27, Loss: 591.414676626933, Time epoch: 69.79781031608582
Epoch: 28, Loss: 587.5868523981153, Time epoch: 69.73955416679382
Epoch: 29, Loss: 583.9519124768444, Time epoch: 69.47623229026794
Epoch: 30, Loss: 579.3783311352288, Time epoch: 69.75653409957886
Epoch: 31, Loss: 571.871519737637, 