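# Trying Barlow Twins on CIFAR-10

This notebook pre-trains a ResNet-18 backbone with Barlow Twins (self-supervised) on CIFAR-10. The learned representation is evaluated at the pre-training epochs listed in `eval_params` (100, 200, 400 and 600).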

In [1]:
from os import environ, makedirs

import torch
from torchvision.models import resnet18

from models import BarlowTwins
from utils import SSL_CIFAR10
from trainer import cifar10_trainer

# Pick the first GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print(f"Program running on {torch.cuda.get_device_name(device)}")
else:
    print("Program running on CPU")

# Seed torch's RNGs (on all devices) so weight init and data-loader shuffling
# repeat across Restart & Run All. cuDNN kernels may still be
# non-deterministic, so per-epoch losses can differ slightly between runs.
SEED = 0
torch.manual_seed(SEED)

# Define hyperparameters
# The dataset root can be overridden with the DATA_ROOT environment variable
# instead of editing this machine-specific default path.
data_root = environ.get('DATA_ROOT', '/home/space/datasets/')
save_root = './results/barlow_twins/'
makedirs(save_root, exist_ok=True)

dl_kwargs = {'batch_size': 512, 'shuffle': True, 'num_workers': 2}

# Linear LR scaling rule: a base LR of 0.2 per 256 samples, scaled by batch size.
BASE_LR = 0.2
optim_params = {'lr': BASE_LR * dl_kwargs['batch_size'] / 256,
                'weight_decay': 1.5e-6,
                'exclude_bias_and_norm': True}
# NOTE(review): 'warmup_epchs' looks like a typo of 'warmup_epochs'. The key is
# kept verbatim because the trainer's expected spelling is not visible here.
# Confirm in trainer.cifar10_trainer: if it looks up 'warmup_epochs' (e.g. via
# .get with a default), warm-up is currently being skipped silently.
train_params = {'num_epochs': 800, 'warmup_epchs': 10, 'eta_min': 1e-3, 'start_epoch': 0}
# Pre-training epochs at which the representation is evaluated, plus the
# settings handed to the trainer for that evaluation run.
eval_params = {'evaluate_at': [100, 200, 400, 600], 'lr': 1e-2, 'num_epochs': 50, 'milestones': [30, 40]}

Program running on Tesla P100-PCIE-12GB


## Barlow Twins

In [None]:
# Data: CIFAR-10 prepared for self-supervised training.
# NOTE(review): the 'BYOL' argument presumably selects the BYOL augmentation
# pipeline (which the Barlow Twins paper reuses). Confirm in utils.SSL_CIFAR10.
byol_cifar10 = SSL_CIFAR10(data_root, 'BYOL', dl_kwargs)

# Model: ResNet-18 backbone. zero_init_residual starts the last BatchNorm of
# each residual branch at zero, so every block initially acts as an identity.
resnet = resnet18(zero_init_residual=True)
# Backbone feature width. It is read before the wrapper is built, in case the
# wrapper replaces `fc`.
repre_dim = resnet.fc.in_features
# NOTE(review): torchvision's resnet18 keeps the ImageNet stem (7x7 stride-2
# conv followed by max-pool). The max_pool2d warning in the output shows the
# pool is active on 32x32 inputs. CIFAR SSL setups often switch to a 3x3
# stride-1 conv and drop the pool. Doing so would stop any checkpoint already
# saved in save_root from loading.
barlow_twins = BarlowTwins(resnet, projector_hidden=(2048,) * 3).to(device)

# Train. Evaluation settings travel with eval_params.
cifar10_trainer(save_root, barlow_twins, byol_cifar10,
                optim_params, train_params, eval_params)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Pretrained model available, use it?[y/n]: n


  return torch.max_pool2d(input, kernel_size, stride, padding, dilation, ceil_mode)


Epoch: 0, Loss: 1336.685903018283, Time epoch: 68.75380325317383
Epoch: 1, Loss: 1111.0566494341979, Time epoch: 68.52011156082153
Epoch: 2, Loss: 1041.6986209830059, Time epoch: 69.0090320110321
Epoch: 3, Loss: 999.1219041961984, Time epoch: 68.926194190979
Epoch: 4, Loss: 952.2877625140948, Time epoch: 69.13982725143433
Epoch: 5, Loss: 920.8228929657297, Time epoch: 68.73075151443481
Epoch: 6, Loss: 891.1156270135309, Time epoch: 69.18457841873169
Epoch: 7, Loss: 858.2364137000644, Time epoch: 68.72125816345215
Epoch: 8, Loss: 831.5743250896021, Time epoch: 68.97410011291504
Epoch: 9, Loss: 811.6565520296392, Time epoch: 68.74092316627502
Epoch: 10, Loss: 791.2377243828528, Time epoch: 68.36446738243103
Epoch: 11, Loss: 772.7735142658667, Time epoch: 68.90442776679993
Epoch: 12, Loss: 755.084187615778, Time epoch: 68.60586786270142
Epoch: 13, Loss: 740.379050343307, Time epoch: 68.72429847717285
Epoch: 14, Loss: 723.5792582403753, Time epoch: 68.73918724060059
Epoch: 15, Loss: 715.03