In [None]:
import os
import sys
import numpy as np
from math import pi, cos 


import torch
import torchvision
import torch.nn as nn
from logger import Logger
from torch import allclose
from datetime import datetime
import torch.nn.functional as tf 

from tqdm.notebook import tqdm
import torchvision.transforms as T
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from torch.testing import assert_allclose
from torchvision import datasets, transforms

import kornia
from kornia import augmentation as K
import kornia.augmentation.functional as F
import kornia.augmentation.random_generator as rg
from torchvision.transforms import functional as tvF

In [None]:
# Run configuration: identifiers and the directories this run reads/writes.
uid = 'byol'
dataset_name = 'stl10'
data_dir = 'dataset'

# Format the timestamp once so the checkpoint and log directories of a run
# always share the same suffix (two separate now() calls can straddle a
# second boundary and produce different names).
run_stamp = datetime.now().strftime('%m%d%H%M%S')
ckpt_dir = "./ckpt/" + run_stamp
log_dir = "runs/" + run_stamp

# exist_ok=True makes the cell safe to re-run and avoids the race between an
# explicit existence check and the directory creation.
for run_dir in (data_dir, ckpt_dir, log_dir):
    os.makedirs(run_dir, exist_ok=True)

In [None]:
# Transformations: normalization statistics (three values each) that
# InitalTransformation passes to transforms.Normalize.

_MEAN =  [0.5, 0.5, 0.5]
_STD  =  [0.2, 0.2, 0.2]


# Disabled leftovers: tensor versions of the statistics. They refer to
# CIFAR_MEAN / CIFAR_STD, which are not defined in the visible notebook cells.
# _MEAN_ =  torch.FloatTensor([CIFAR_MEAN])
# CIFAR_STD_  =  torch.FloatTensor([CIFAR_STD])

class InitalTransformation():
    """Per-sample preprocessing handed to the dataset as its ``transform``.

    The input is converted with ``T.ToTensor()`` and then normalized with the
    module-level ``_MEAN`` / ``_STD`` statistics.
    """

    def __init__(self):
        to_tensor = T.ToTensor()
        normalize = transforms.Normalize(_MEAN,_STD)
        self.transform = T.Compose([to_tensor, normalize])

    def __call__(self, x):
        """Return ``x`` after tensor conversion and normalization."""
        return self.transform(x)


def gpu_transformer(image_size,s=.2):
    """Build the kornia augmentation pipelines used for BYOL pre-training.

    Args:
        image_size: output size given to ``RandomResizedCrop``.
        s: strength factor multiplying the four ``ColorJitter`` arguments.

    Returns:
        ``(train_transform, test_transform)``: two independent ``nn.Sequential``
        pipelines built with identical settings (resized crop, horizontal flip,
        color jitter, random grayscale).
    """

    def _build_pipeline():
        # A fresh set of modules per call, so the two pipelines share nothing.
        crop = kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0))
        flip = kornia.augmentation.RandomHorizontalFlip(p=0.5)
        jitter = kornia.augmentation.ColorJitter(0.8*s,0.8*s,0.8*s,0.2*s,p=0.3)
        gray = kornia.augmentation.RandomGrayscale(p=0.05)
        return nn.Sequential(crop, flip, jitter, gray)

    train_transform = _build_pipeline()
    test_transform = _build_pipeline()
    return train_transform , test_transform
                
def get_clf_train_test_transform(image_size,s=.2):
    """Build the lighter kornia pipelines (crop + flip only).

    Args:
        image_size: output size given to ``RandomResizedCrop``.
        s: accepted for signature parity with ``gpu_transformer``; not read here.

    Returns:
        ``(train_transform, test_transform)``: two independent ``nn.Sequential``
        pipelines, each a random resized crop followed by a horizontal flip.
    """

    def _crop_and_flip():
        # New module instances on every call; nothing is shared between pipelines.
        return nn.Sequential(
            kornia.augmentation.RandomResizedCrop(image_size,scale=(0.5,1.0)),
            kornia.augmentation.RandomHorizontalFlip(p=0.5),
        )

    train_transform = _crop_and_flip()
    test_transform = _crop_and_flip()
    return train_transform , test_transform

#### The BYOL model is trained using the LARS optimizer.

#### After training, only the encoder of the online network is kept, and a classifier can be trained on top of it.

In [None]:
def get_train_test_dataloaders(dataset = "stl10", data_dir="./dataset", batch_size = 16,num_workers = 4, download=True): 
    """Create shuffled STL10 data loaders.

    Args:
        dataset: not read by this function; STL10 is always used.
        data_dir: root directory handed to ``datasets.STL10``.
        batch_size: batch size of both loaders.
        num_workers: worker count of both loaders.
        download: forwarded to ``datasets.STL10``.

    Returns:
        ``(train_loader, test_loader)``: loaders over the ``'unlabeled'`` and
        ``'test'`` splits, both using ``InitalTransformation`` and both shuffled.
    """
    # Options shared by both loaders.
    loader_options = dict(shuffle=True, batch_size=batch_size, num_workers=num_workers)

    unlabeled_set = datasets.STL10(
        root=data_dir, split='unlabeled', transform=InitalTransformation(), download=download
    )
    train_loader = torch.utils.data.DataLoader(dataset=unlabeled_set, **loader_options)

    test_set = datasets.STL10(
        root=data_dir, split='test', transform=InitalTransformation(), download=download
    )
    test_loader = torch.utils.data.DataLoader(dataset=test_set, **loader_options)

    return train_loader, test_loader

In [None]:
class BYOL(nn.Module):
    """BYOL model: an online encoder + predictor trained against an EMA target encoder.

    Args:
        backbone: feature extractor exposing an ``output_dim`` attribute. When
            ``None``, a torchvision ResNet-50 without pretrained weights is
            built and its ``fc`` layer is replaced by ``Identity``.
        base_target_ema: base coefficient of the target-network moving average
            (see ``update_moving_average``).
        **kwargs: accepted and ignored.
    """

    def __init__(self, backbone=None,base_target_ema=0.996,**kwargs):
        super().__init__()
        self.base_ema = base_target_ema

        if backbone is None:
            backbone = models.resnet50(pretrained=False)
            # Remember the feature width before dropping the classification layer.
            backbone.output_dim = backbone.fc.in_features
            backbone.fc = torch.nn.Identity()

        projector = MLPHead(in_dim=backbone.output_dim)

        self.online_encoder = nn.Sequential(
            backbone,
            projector)

        # The target network starts as an exact copy of the online network and
        # is afterwards only changed through update_moving_average().
        self.target_encoder = copy.deepcopy(self.online_encoder)
        # in_dim matches MLPHead's default projection_size (256).
        self.online_predictor = MLPHead(in_dim=256,hidden_size=1024, projection_size=256)

    @torch.no_grad()
    def update_moving_average(self, global_step, max_steps):
        """Move the target encoder parameters towards the online encoder.

        The coefficient follows a cosine schedule: it equals ``base_ema`` at
        ``global_step == 0`` and reaches 1 at ``global_step == max_steps``.

        Args:
            global_step: current position in the schedule.
            max_steps: length of the schedule (must be non-zero).
        """
        tau = 1- ((1 - self.base_ema)* (cos(pi*global_step/max_steps)+1)/2)

        for online, target in zip(self.online_encoder.parameters(), self.target_encoder.parameters()):
            target.data = tau * target.data + (1 - tau) * online.data

    def forward(self,x1,x2):
        """Return the symmetric BYOL loss for two views ``x1`` and ``x2``.

        Each view goes through the online encoder and predictor; the target
        encoder is evaluated without gradients.
        """
        z1 = self.online_encoder(x1)
        z2 = self.online_encoder(x2)

        q1 = self.online_predictor(z1)
        q2 = self.online_predictor(z2)

        with torch.no_grad():
            z1_t = self.target_encoder(x1)
            z2_t = self.target_encoder(x2)

        # Cross-view pairing: the online prediction of one view must regress the
        # target projection of the *other* view. loss_fn matches its 1st argument
        # with its 3rd and its 2nd with its 4th, hence the swapped targets.
        loss = loss_fn(q1, q2, z2_t, z1_t)

        return loss

In [None]:
import copy
from torch import nn
import torchvision.models as models

def loss_fn(q1,q2, z1t,z2t):
    """Average negative cosine similarity of two (prediction, target) pairs.

    ``q1`` is compared with ``z1t`` and ``q2`` with ``z2t`` along the last
    dimension; the targets are detached so no gradient reaches them. Each
    pair contributes the negated mean similarity, and the two contributions
    are averaged.
    """
    pair_losses = [
        -tf.cosine_similarity(prediction, target.detach(), dim=-1).mean()
        for prediction, target in ((q1, z1t), (q2, z2t))
    ]
    first, second = pair_losses
    return (first + second) / 2


class MLPHead(nn.Module):
    """Two-layer MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_dim: width of the input features.
        hidden_size: width of the hidden layer.
        projection_size: width of the output.
    """

    def __init__(self, in_dim, hidden_size=4096, projection_size=256):
        super().__init__()
        # Layers are created in forward order and stored under ``net`` so the
        # state-dict keys stay ``net.0.*``, ``net.1.*`` and ``net.3.*``.
        layers = [
            nn.Linear(in_dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_size, projection_size),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        projected = self.net(x)
        return projected
    



In [None]:
# Pick the compute device once: CUDA when available, CPU otherwise. ``dtype``
# is the matching legacy float tensor type.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

print(device)

cuda


In [None]:
# Optimizer settings (passed to LARS).
lr = 0.002
momentum = 0.9
weight_decay = 1.5e-6

# Learning-rate schedule (passed to LR_Scheduler; the three rates are
# multiplied by batch_size/8 where the scheduler is created).
warmup_epochs = 10
warmup_lr = 0
final_lr = 0
epochs = 50
stop_at_epoch = 100  # not referenced by the visible cells

# Data.
batch_size = 256
image_size = (92, 92)

# kNN monitoring options; not referenced by the visible cells.
knn_monitor = False
knn_interval = 5
knn_k = 200

In [None]:
# STL10 loaders (with the default download=True the data is fetched into
# ./dataset when missing).
train_loader, test_loader = get_train_test_dataloaders(
    batch_size=batch_size,
)

# Batch-level kornia augmentation pipelines.
train_transform, test_transform = gpu_transformer(image_size=image_size)

Downloading http://ai.stanford.edu/~acoates/stl10/stl10_binary.tar.gz to ./dataset/stl10_binary.tar.gz


0it [00:00, ?it/s]

Extracting ./dataset/stl10_binary.tar.gz to ./dataset
Files already downloaded and verified


In [None]:

from lr_scheduler import LR_Scheduler
from lars import LARS

# Per-step history containers (loss_ls is filled by the training loop).
loss_ls = []
acc_ls = []

model = BYOL().to(device)

optimizer = LARS(
    model.named_modules(),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
)

# The configured learning rates are scaled by batch_size/8 before being handed
# to the scheduler.
scaled_warmup_lr = warmup_lr*batch_size/8
scaled_base_lr = lr*batch_size/8
scaled_final_lr = final_lr*batch_size/8
iterations_per_epoch = len(train_loader)

scheduler = LR_Scheduler(
    optimizer,
    warmup_epochs,
    scaled_warmup_lr,
    epochs,
    scaled_base_lr,
    scaled_final_lr,
    iterations_per_epoch,
    constant_predictor_lr=True,
)

In [None]:
# BYOL pre-training loop: one optimizer step, one target-network EMA update and
# one scheduler step per mini-batch. A checkpoint is written whenever the mean
# loss of an epoch reaches a new minimum.
min_loss = np.inf
accuracy = 0  # not used in this cell


# start training
logger = Logger(log_dir=log_dir, tensorboard=True, matplotlib=True)
global_progress = tqdm(range(0, epochs), desc=f'Training')
data_dict = {"loss": 100}

# update_moving_average(global_step, max_steps) expects a step counter; feeding
# it the epoch index would change the EMA coefficient only once per epoch.
steps_per_epoch = len(train_loader)
total_steps = epochs * steps_per_epoch

for epoch in global_progress:
    model.train()
    local_progress = tqdm(train_loader, desc=f'Epoch {epoch}/{epochs}')

    epoch_loss_sum = 0.0
    epoch_batches = 0

    for idx, (image, label) in enumerate(local_progress):
        image = image.to(device)
        # NOTE(review): only the second view is augmented; the first view is
        # the batch exactly as it comes from the loader. Confirm this is intended.
        aug_image = train_transform(image)

        model.zero_grad()
        # Call the module itself (not .forward) so nn.Module hooks run. `image`
        # is already on `device`; the augmented batch is assumed to stay on the
        # device of its input.
        loss = model(image, aug_image)

        loss_scaler = loss.item()
        data_dict['loss'] = loss_scaler
        loss_ls.append(loss_scaler)
        epoch_loss_sum += loss_scaler
        epoch_batches += 1
        loss.backward()

        optimizer.step()
        global_step = epoch * steps_per_epoch + idx
        model.update_moving_average(global_step, total_steps)

        scheduler.step()

        data_dict.update({'lr': scheduler.get_last_lr()})
        local_progress.set_postfix(data_dict)
        logger.update_scalers(data_dict)

    # Rank epochs by their mean loss; the loss of the last mini-batch alone is
    # a noisy criterion. Fall back to the last recorded value for an empty loader.
    current_loss = epoch_loss_sum / epoch_batches if epoch_batches else data_dict['loss']

    global_progress.set_postfix(data_dict)
    logger.update_scalers(data_dict)

    if min_loss > current_loss:
        min_loss = current_loss

        model_path = os.path.join(ckpt_dir, f"{uid}_{datetime.now().strftime('%m%d%H%M%S')}.pth")
        torch.save({
        'epoch':epoch+1,
        'online_network': model.online_encoder.state_dict(),
        'target_network': model.target_encoder.state_dict()}, model_path)
        print(f'Model saved at: {model_path}')
        print(f'Model saved at: {model_path}')

Training:   0%|          | 0/50 [00:00<?, ?it/s]

Epoch 0/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828173557.pth


Epoch 1/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828174250.pth


Epoch 2/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828174939.pth


Epoch 3/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828175636.pth


Epoch 4/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828180334.pth


Epoch 5/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828181024.pth


Epoch 6/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828181718.pth


Epoch 7/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828182410.pth


Epoch 8/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828183105.pth


Epoch 9/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828183803.pth


Epoch 10/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 11/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828185149.pth


Epoch 12/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828185841.pth


Epoch 13/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828190535.pth


Epoch 14/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 15/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828191928.pth


Epoch 16/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 17/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 18/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 19/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828194726.pth


Epoch 20/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 21/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 22/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828200810.pth


Epoch 23/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 24/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828202155.pth


Epoch 25/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828202850.pth


Epoch 26/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 27/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 28/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 29/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 30/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 31/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 32/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 33/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 34/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 35/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 36/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 37/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 38/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 39/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 40/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 41/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 42/50:   0%|          | 0/391 [00:00<?, ?it/s]

Model saved at: ./ckpt/0828172719/byol_0828222734.pth


Epoch 43/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 44/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 45/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 46/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 47/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 48/50:   0%|          | 0/391 [00:00<?, ?it/s]

Epoch 49/50:   0%|          | 0/391 [00:00<?, ?it/s]

In [None]:
from matplotlib import pyplot as plt
# Training-loss curve: loss_ls holds one value per optimizer step.
# plt.subplots() replaces plt.figure(111), which merely created a figure
# numbered 111 rather than a subplot layout.
fig, ax = plt.subplots()
ax.plot(loss_ls)
ax.set_xlabel('training step')
ax.set_ylabel('loss')
ax.set_title('BYOL training loss')
plt.show()