# Trying Bootstrap Your Own Latent (BYOL) on CIFAR-10

In [8]:
import copy
import time

from os import listdir, makedirs, path
from typing import Union

import torch
import torch.nn as nn
import torch.optim as opt

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

from augmentation import BYOL_transform
from models import BYOL
from optimizer import LARS
from utils import Linear_Protocoler

# Use the first CUDA device when one is available, otherwise fall back to CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Print Device Type
if torch.cuda.is_available():
    print(f"Program running on {torch.cuda.get_device_name(device)}")
else:
    print("Program running on CPU")

# Define hyperparameters
# data_root: where torchvision stores/loads CIFAR-10
# save_root: where training checkpoints ('epoch_XXX.tar') are written
data_root = '/home/fcfschulz/Documents/workspace/data/Vision/torchvision_ds/'
save_root = '/home/fcfschulz/Documents/workspace/data/saved_models/byol_cifar10/'

### For COLAB ############################################################
#from google.colab import drive
#drive.mount('/content/drive')
#
#data_root = './'
#save_root ='./drive/MyDrive/Colab_Notebooks/PhD/data/saved_models/byol/'
##########################################################################

# Make save root
makedirs(save_root, exist_ok=True)

# Keyword arguments shared by every DataLoader in this notebook
dl_kwargs = {'batch_size': 256, 'shuffle': True, 'num_workers': 2}

# Keyword arguments forwarded to the LARS optimizer
optim_params = {'lr':0.2, 'weight_decay': 1.5e-6,'exclude_bias_and_norm':True}
# 'eta_min' is the lower bound handed to CosineAnnealingLR by the resume and
# training cells; it was missing here, so those lookups raised KeyError.
# 0.0 equals CosineAnnealingLR's own default.
train_params = {'num_epochs': 250, 'warmup_epchs': 10, 'eta_min': 0.0}

# Settings for the linear evaluation protocol
eval_params = {'lr':1e-2, 'num_epochs': 5}

Program running on CPU


# BYOL
### Define Augmentation and Datasets

In [9]:
# --- Augmentations -----------------------------------------------------------
# Transform for BYOL pre-training; the training loop unpacks each sample
# produced with it as a pair of views (x1, x2).
train_transf = BYOL_transform(image_size=32, normalize=(0.5,0.5))

# Mild augmentation used while fitting the linear probe on the training split
train_eval_transf = transforms.Compose([
    transforms.RandomResizedCrop(
        32,
        (0.8, 1.),
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5),
])

# No augmentation at test time: tensor conversion and normalisation only
test_transf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(0.5,0.5),
])

# --- Datasets ----------------------------------------------------------------
# The CIFAR-10 training split is wrapped twice, once per transform
train_ds = CIFAR10(root=data_root, train=True, transform=train_transf, download=True)
train_eval_ds = CIFAR10(root=data_root, train=True, transform=train_eval_transf, download=True)
test_ds = CIFAR10(root=data_root, train=False, transform=test_transf, download=True)

# --- Dataloaders -------------------------------------------------------------
# Only the pre-training loader drops the last incomplete batch
train_dl = DataLoader(train_ds, **dl_kwargs, drop_last=True)
train_eval_dl = DataLoader(train_eval_ds, **dl_kwargs, drop_last=False)
test_dl = DataLoader(test_ds, **dl_kwargs, drop_last=False)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


### Define model

In [10]:
# Randomly initialised ResNet-18 used as the encoder
backbone_net = resnet18()
# Width of the features that feed the original classification layer
repre_dim = backbone_net.fc.in_features
# Swap the classification layer for a flatten so the backbone emits raw features
backbone_net.fc = nn.Flatten()

# Wrap the backbone in the BYOL model and move it to the selected device
byol = BYOL(backbone_net, repre_dim, 32, 1024)
byol = byol.to(device)

### Define Optimizer & Scheduler

In [20]:
optimizer = LARS(byol.parameters(), **optim_params)


def _warmup_factor(it):
    """Return the learning-rate multiplier for scheduler step ``it``.

    The factor grows linearly with ``it`` and is normalised by the number of
    iterations in the warmup phase (warmup epochs times batches per epoch).
    """
    warmup_iters = train_params['warmup_epchs'] * len(train_dl)
    return (it + 1) / warmup_iters


# Define scheduler for warmup
scheduler = opt.lr_scheduler.LambdaLR(optimizer, lr_lambda=_warmup_factor)

### Check for trained model

In [None]:
# Init: start from scratch unless a checkpoint is restored below
epoch_start = 0
lp_acc = []
loss_hist = []
lr_hist = []
tau_hist = []

# Collect the epochs of all checkpoints named 'epoch_<digits>.tar'.
# Any other file in save_root is ignored instead of crashing int().
saved_epochs = [
    int(file[len('epoch_'):-len('.tar')])
    for file in listdir(save_root)
    if file.startswith('epoch_')
    and file.endswith('.tar')
    and file[len('epoch_'):-len('.tar')].isdecimal()
]

if saved_epochs:
    # Ask until the answer starts with 'y' or 'n'.
    # Slicing with [:1] keeps an empty answer from raising IndexError.
    user_answer = ""
    while user_answer not in ("y", "n"):
        user_answer = input("Pretrained model available, use it?[y/n]: ").strip().lower()[:1]
    if user_answer == "y":
        # Resume from the most recent checkpoint
        epoch_start = max(saved_epochs)
        # Load data
        saved_data = torch.load(path.join(save_root, f'epoch_{epoch_start:03}.tar'), map_location=device)
        # Extract data
        byol.load_state_dict(saved_data['model'])
        optimizer.load_state_dict(saved_data['optim'])
        # Checkpoints written once warmup is over hold a cosine scheduler state,
        # so the matching scheduler type must exist before its state is loaded.
        if epoch_start >= train_params['warmup_epchs']:
            iters_left = (train_params['num_epochs']-train_params['warmup_epchs'])*len(train_dl)
            # Fall back to 0.0 (CosineAnnealingLR's default) when no floor is configured
            scheduler = opt.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           iters_left,
                                                           eta_min=train_params.get('eta_min', 0.0))
        scheduler.load_state_dict(saved_data['sched'])
        # Restore the logged training history
        lp_acc = saved_data['lp_acc']
        loss_hist = saved_data['loss_hist']
        lr_hist = saved_data['lr_hist']
        tau_hist = saved_data['tau_hist']

### Training

In [None]:
# Run linear protocol for random init model (skipped when a history was restored)
if not lp_acc:
    linear_proto = Linear_Protocoler(byol.backbone_net, out_dim=repre_dim)
    # eval_params stores the probe learning rate under 'lr'
    linear_proto.train(train_eval_dl, eval_params['num_epochs'], eval_params['lr'])
    lp_acc.append(linear_proto.get_accuracy(test_dl))

# get total number of iterations
steps_per_epoch = len(train_dl)
total_iters = train_params['num_epochs'] * steps_per_epoch

# Run Training
for epoch in range(epoch_start, train_params['num_epochs']):
    epoch_loss = 0
    start_time = time.time()
    for i, ((x1,x2), _) in enumerate(train_dl):
        x1,x2 = x1.to(device), x2.to(device)

        # Forward pass
        loss = byol(x1,x2)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update momentum encoder
        # tau is requested for the 1-based global iteration index
        tau = byol.get_tau(1+i+steps_per_epoch*epoch,total_iters)
        byol.update_moving_average(tau)
        tau_hist.append(tau)

        # Scheduler every iteration for cosine decay
        scheduler.step()

        # Save loss and LR
        epoch_loss += loss.item()
        lr_hist.extend(scheduler.get_last_lr())

    # Switch to Cosine Decay after warmup period
    if epoch+1==train_params['warmup_epchs']:
        iters_left = (train_params['num_epochs']-train_params['warmup_epchs'])*steps_per_epoch
        # Fall back to 0.0 (CosineAnnealingLR's default) when no floor is configured
        scheduler = opt.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       iters_left,
                                                       eta_min=train_params.get('eta_min', 0.0))

    # Log
    loss_hist.append(epoch_loss/steps_per_epoch)
    print(f'Epoch: {epoch}, Loss: {loss_hist[-1]}, Time epoch: {time.time() - start_time}')

    # Run linear protocol and save stats every 5 epochs
    if (epoch+1)%5==0:
        # Linear protocol
        linear_proto = Linear_Protocoler(byol.backbone_net, out_dim=repre_dim)
        linear_proto.train(train_eval_dl, eval_params['num_epochs'], eval_params['lr'])
        lp_acc.append(linear_proto.get_accuracy(test_dl))

        # Checkpoint everything the resume cell reads back
        torch.save({'model':byol.state_dict(),
                    'optim': optimizer.state_dict(),
                    'sched': scheduler.state_dict(),
                    'lp_acc': lp_acc,
                    'loss_hist': loss_hist,
                    'lr_hist': lr_hist,
                    'tau_hist': tau_hist},
                   path.join(save_root, f'epoch_{epoch+1:03}.tar'))

### Visualize loss, learning rate and $\tau$

In [None]:
import matplotlib.pyplot as plt

# One panel per logged quantity, filled row by row on a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(16,4))
panels = (
    (loss_hist, "Loss"),
    (lp_acc, "Linear Evaluation Protocol"),
    (lr_hist, "Learning rate"),
    (tau_hist, "Tau"),
)
for ax, (series, title) in zip(axes.flat, panels):
    ax.plot(series)
    ax.set_title(title)

plt.show()

### Colab transfer saved models

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

#!cp ./saved/* 