# Trying Bootstrap Your Own Latent (BYOL) on CIFAR-10

In [10]:
import copy
import time

from os import listdir, makedirs, path
from typing import Union

import torch
import torch.nn as nn
import torch.optim as opt

from PIL import Image

from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Report the selected device.
print(f"Program running on {torch.cuda.get_device_name(device)}"
      if torch.cuda.is_available()
      else "Program running on CPU")

# Locations of the torchvision datasets and of the training checkpoints.
data_root = '/home/fcfschulz/Documents/workspace/data/Vision/torchvision_ds/'
save_root = '/home/fcfschulz/Documents/workspace/data/saved_models/byol_cifar10/'

### For COLAB (uncomment to override the paths above) #####################
#from google.colab import drive
#drive.mount('/content/drive')
#
#data_root = './'
#save_root ='./drive/MyDrive/Colab_Notebooks/PhD/data/saved_models/byol/'
##########################################################################

# Create the checkpoint folder; no error if it already exists.
makedirs(save_root, exist_ok=True)

# Keyword arguments shared by the DataLoaders.
dl_kwargs = dict(batch_size=256, shuffle=True, num_workers=2)

# Pretraining hyperparameters.
# NOTE: the warm-up key is spelled 'warmup_epchs'; later cells read it under that name.
train_params = dict(
    base_lr=0.2,
    num_epochs=250,
    warmup_epchs=10,
    weight_decay=1.5e-6,
    eta_min=1e-6,
)

Program running on CPU


# BYOL
### Define Augmentation and Datasets

In [11]:
from augmentation import BYOL_transform

# CIFAR-10 training split for self-supervised pretraining. BYOL_transform produces
# the two augmented views that the training loop unpacks as (x1, x2).
train_ds = CIFAR10(
    root=data_root,
    train=True,
    download=True,
    transform=BYOL_transform(image_size=32, normalize=(0.5, 0.5)),
)

# drop_last keeps every batch at the full batch size.
train_dl = DataLoader(train_ds, drop_last=True, **dl_kwargs)

Files already downloaded and verified


### Define model

In [12]:
from models import BYOL

from torchvision.models import resnet18

# Backbone: torchvision ResNet-18 whose final fully-connected layer is swapped for
# nn.Flatten, so the network returns its feature vector instead of class logits.
backbone_net = resnet18()
repre_dim = backbone_net.fc.in_features  # feature width, read off the original fc layer
backbone_net.fc = nn.Flatten()
# NOTE(review): this keeps torchvision's ImageNet-style stem (7x7 stride-2 conv + max-pool);
# CIFAR setups often use a 3x3 stride-1 conv without the pool instead — consider.

# NOTE(review): the positional arguments 32 and 1024 are interpreted by models.BYOL
# (presumably projection / hidden sizes) — confirm there.
byol = BYOL(backbone_net, repre_dim, 32, 1024)
byol = byol.to(device)

### Define Optimizer & Scheduler

In [None]:
# Adam over all BYOL parameters.
# NOTE(review): a base LR of 0.2 is typical of LARS/SGD recipes but very large for Adam —
# confirm this is intended.
optimizer = opt.Adam(
    byol.parameters(),
    lr=train_params['base_lr'],
    weight_decay=train_params['weight_decay'],
)

# Linear warm-up, meant to be stepped once per iteration: the LR multiplier grows from
# 1 / (warmup epochs * iterations per epoch) at the first step up to 1 at the end of warm-up.
scheduler = opt.lr_scheduler.LambdaLR(
    optimizer,
    lambda step: (step + 1) / (train_params['warmup_epchs'] * len(train_dl)),
)

### Check for a saved checkpoint (optionally resume training)

In [None]:
epoch_start = 0
loss_hist = []
lr_hist = []
tau_hist = []

# Epochs for which a checkpoint 'epoch_XXX.tar' exists. Other files that may end up in
# save_root (e.g. notebook checkpoints) are ignored instead of breaking the int() parse.
saved_epochs = [int(f[len('epoch_'):-len('.tar')]) for f in listdir(save_root)
                if f.startswith('epoch_') and f.endswith('.tar')
                and f[len('epoch_'):-len('.tar')].isdigit()]

if saved_epochs:
    user_answer = "Users_answer"
    while user_answer not in ["y", "n"]:
        # [:1] rather than [0]: an empty reply re-prompts instead of raising IndexError
        user_answer = input("Pretrained model available, use it?[y/n]: ").strip().lower()[:1]
    if user_answer == "y":
        epoch_start = max(saved_epochs)
        # Load data
        saved_data = torch.load(path.join(save_root, f'epoch_{epoch_start:03}.tar'), map_location=device)
        # Extract data
        byol.load_state_dict(saved_data['model'])
        optimizer.load_state_dict(saved_data['optim'])
        # Checkpoints taken at/after the end of warm-up store the cosine scheduler's state,
        # so the matching scheduler type must exist before its state is restored.
        if epoch_start >= train_params['warmup_epchs']:
            iters_left = (train_params['num_epochs']-train_params['warmup_epchs'])*len(train_dl)
            scheduler = opt.lr_scheduler.CosineAnnealingLR(optimizer,
                                                           iters_left,
                                                           eta_min=train_params['eta_min'])
        scheduler.load_state_dict(saved_data['sched'])
        loss_hist = saved_data['loss_hist']
        lr_hist = saved_data['lr_hist']
        tau_hist = saved_data['tau_hist']

### Training

In [None]:
total_iters = train_params['num_epochs'] * len(train_dl)

# Ensure training-mode behaviour of dropout/batch-norm, e.g. when this cell is re-run
# after the model was switched to eval mode.
byol.train()

for epoch in range(epoch_start, train_params['num_epochs']):
    epoch_loss = 0.0
    start_time = time.time()
    for i, ((x1, x2), _) in enumerate(train_dl):
        x1, x2 = x1.to(device), x2.to(device)

        # Forward pass on the two augmented views
        loss = byol(x1, x2)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update momentum encoder with the tau for this (1-based) global iteration
        τ = byol.get_tau(1 + i + len(train_dl) * epoch, total_iters)
        byol.update_moving_average(τ)
        tau_hist.append(τ)

        # Step the scheduler every iteration (warm-up first, then cosine decay)
        scheduler.step()

        # Save loss and LR (get_last_lr() is the LR the next iteration will use)
        epoch_loss += loss.item()
        lr_hist.extend(scheduler.get_last_lr())

    # Switch to cosine decay after the warm-up period; it spans all remaining iterations
    if epoch + 1 == train_params['warmup_epchs']:
        iters_left = (train_params['num_epochs'] - train_params['warmup_epchs']) * len(train_dl)
        scheduler = opt.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       iters_left,
                                                       eta_min=train_params['eta_min'])

    # Log the mean loss of the epoch
    loss_hist.append(epoch_loss / len(train_dl))
    print(f'Epoch: {epoch}, Loss: {loss_hist[-1]}, Time epoch: {time.time() - start_time}')

    # Save a checkpoint every 5 epochs (taken after a possible scheduler switch above)
    if (epoch + 1) % 5 == 0:
        torch.save({'model': byol.state_dict(),
                    'optim': optimizer.state_dict(),
                    'sched': scheduler.state_dict(),
                    'loss_hist': loss_hist,
                    'lr_hist': lr_hist,
                    'tau_hist': tau_hist},
                   path.join(save_root, f'epoch_{epoch+1:03}.tar'))

### Visualize loss, learning rate and $\tau$

In [None]:
import matplotlib.pyplot as plt

# Three side-by-side panels: mean loss per epoch, learning rate and tau per iteration.
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16,4))

for ax, values, title in ((ax1, loss_hist, "Loss"),
                          (ax2, lr_hist, "Learning rate"),
                          (ax3, tau_hist, "Tau")):
    ax.plot(values)
    ax.set_title(title)

plt.show()

### Colab: copy saved models to Google Drive

In [None]:
#from google.colab import drive
#drive.mount('/content/drive')

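# NOTE(review): this cp has no destination argument — append the target folder on the
# mounted Drive before uncommenting. Checkpoints are written to save_root, so also
# confirm that './saved/' is the right source.
#!cp ./saved/* 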

# Linear Evaluation Protocol

### Get data

In [None]:
# `transforms` is not imported anywhere else in the notebook; without this the cell
# fails with NameError.
from torchvision import transforms

# Linear-evaluation training augmentation: mild random-resized crops + horizontal flips.
train_transf = transforms.Compose([
    transforms.RandomResizedCrop(32, (0.6, 1.), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)])
# Test images are only converted and normalized the same way.
test_transf = transforms.Compose([transforms.ToTensor(), transforms.Normalize(0.5, 0.5)])

le_train_ds = CIFAR10(root=data_root, train=True, transform=train_transf, download=True)
test_ds = CIFAR10(root=data_root, train=False, transform=test_transf, download=True)

le_train_dl = DataLoader(le_train_ds, drop_last=True, **dl_kwargs)
# Evaluation does not need shuffling; keep the test order fixed for reproducible runs.
test_dl = DataLoader(test_ds, drop_last=False, **{**dl_kwargs, 'shuffle': False})

### Train

In [None]:
def extract_from_byol(byol_net, num_classes=10, feat_dim=512):
    """Build a linear-evaluation classifier from a BYOL network's online encoder.

    The online encoder is deep-copied (``byol_net`` itself is not modified) and every
    copied parameter is frozen. Then the module at index 1 of the copy is replaced by
    a fresh, trainable linear layer mapping ``feat_dim`` features to ``num_classes``
    outputs.

    Args:
        byol_net: object exposing ``encoder_online``, an indexable module container
            (e.g. ``nn.Sequential``). Assumes entry 0 yields ``feat_dim``-dimensional
            features — confirm against models.BYOL.
        num_classes: number of output classes (default 10, CIFAR-10).
        feat_dim: input width of the new linear head (default 512, ResNet-18 features).

    Returns:
        The frozen encoder copy with its new trainable linear head at index 1.
    """
    classifier = copy.deepcopy(byol_net.encoder_online)

    # Freeze everything inherited from the pretrained encoder ...
    for p in classifier.parameters():
        p.requires_grad = False
    # ... and swap index 1 for a new linear head. It is created after the freeze loop,
    # so its parameters keep requires_grad=True and are the only ones trained.
    classifier[1] = nn.Sequential(nn.Linear(feat_dim, num_classes))

    return classifier


def train_model(dataloader, model, optim, sched, num_epochs, criterion, device):
    """Train ``model`` for ``num_epochs`` full passes over ``dataloader``.

    Args:
        dataloader: iterable of ``(x, y)`` batches.
        model: module to train; it is put into training mode first.
        optim: optimizer over the trainable parameters of ``model``.
        sched: LR scheduler, stepped once per epoch.
        num_epochs: number of passes over ``dataloader``.
        criterion: loss function called as ``criterion(model(x), y)``.
        device: device the batches are moved to.
    """
    # Make training-mode behaviour explicit (batch-norm / dropout), regardless of the
    # mode the model was left in by earlier code.
    model.train()
    for _ in range(num_epochs):
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            # forward
            loss = criterion(model(x), y)
            # backward
            optim.zero_grad()
            loss.backward()
            optim.step()

        # epoch-based scheduler: one step per full pass
        sched.step()
            

def get_accuracy(datalaoder, model, device = 'cpu'):
    """Return the top-1 accuracy of ``model`` over ``datalaoder``, in percent.

    The model is run in eval mode, so batch-norm uses (and does not overwrite) its
    running statistics and dropout is disabled. Its previous train/eval mode is
    restored afterwards.

    Args:
        datalaoder: iterable of ``(x, y)`` batches with integer class labels ``y``.
            The misspelt name is kept for keyword-argument compatibility.
        model: module mapping ``x`` to per-class scores, classes along dim 1.
        device: device the batches are moved to before the forward pass.

    Raises:
        ValueError: if ``datalaoder`` yields no samples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    try:
        # since we're not training, we don't need to calculate the gradients
        with torch.no_grad():
            for x, y in datalaoder:
                x, y = x.to(device), y.to(device)
                # the class with the highest score is the prediction
                predicted = model(x).argmax(dim=1)
                total += y.size(0)
                correct += (predicted == y).sum().item()
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("dataloader yielded no samples")
    return 100 * correct / total

In [None]:
eval_params = {'base_lr':1e-2, 'num_epochs': 20, 'schedule_step': 8}

ce_loss = nn.CrossEntropyLoss()

# --- Random-init baseline ---------------------------------------------------
# Build a *fresh* backbone. `backbone_net` was passed to the BYOL model trained above;
# if models.BYOL keeps a reference to it rather than a copy, it now holds trained (or
# restored) weights, which would make this "random" baseline a trained one.
random_backbone = resnet18()
random_backbone.fc = nn.Flatten()
byol = BYOL(random_backbone, repre_dim, 32, 1024).to(device)
# Define encoder (frozen) + linear head (trainable)
classifier = extract_from_byol(byol).to(device)
# Optimize only the parameters that can receive gradients (the linear head)
le_optimizer = opt.Adam([p for p in classifier.parameters() if p.requires_grad],
                        eval_params['base_lr'])
# Step-decay the LR every `schedule_step` epochs
le_scheduler = opt.lr_scheduler.StepLR(le_optimizer, eval_params['schedule_step'])
# Train model
train_model(le_train_dl, classifier, le_optimizer, le_scheduler,
            eval_params['num_epochs'], ce_loss, device)

# Check accuracy
print(f'Accuracy of random init: {get_accuracy(test_dl, classifier, device)}')

# --- Saved checkpoints ------------------------------------------------------
# Only the '.tar' checkpoint files; the zero-padded 'epoch_XXX' names sort in epoch order.
models_avail = sorted(f for f in listdir(save_root)
                      if f.endswith('.tar') and path.isfile(path.join(save_root, f)))

for model_name in models_avail:
    # Load model
    saved_data = torch.load(path.join(save_root, model_name), map_location=device)
    # extract weights
    byol.load_state_dict(saved_data['model'])

    # Define encoder (frozen) + linear head (trainable)
    classifier = extract_from_byol(byol).to(device)
    # Optimize only the trainable linear head
    le_optimizer = opt.Adam([p for p in classifier.parameters() if p.requires_grad],
                            eval_params['base_lr'])
    # Define scheduler
    le_scheduler = opt.lr_scheduler.StepLR(le_optimizer, eval_params['schedule_step'])

    # Train model
    train_model(le_train_dl, classifier, le_optimizer, le_scheduler,
                eval_params['num_epochs'], ce_loss, device)

    # Check accuracy
    print(f'Accuracy of model {model_name}: {get_accuracy(test_dl, classifier, device)}')