# BYOL - https://arxiv.org/pdf/2006.07733

BYOL został już omówiony na liście 3, nie ma więc sensu przedstawiać go tutaj po raz kolejny.

# Import bibliotek

In [None]:
import copy
import numpy as np
import matplotlib.pyplot as plt
from torch import Tensor
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.dataset import *
from src.early_stopping import *
from src.self_supervised_modules import *
from src.utils import *

# Wczytanie danych do trenowania końcowego klasyfikatora

In [None]:
# Open the .npz archive of subset 0; its 'images' and 'labels' entries are
# read in the cells below and used to train the final classifier.
data_train = np.load('data_transform/subset0/data.npz')
# Bare expression: shown as the cell output when run in a notebook.
data_train

In [None]:
# Classifier training inputs: the 'images' array of subset 0.
X_train = data_train['images']
# Bare expression: displays the array shape as the cell output.
X_train.shape

In [None]:
# Classifier training targets: the 'labels' array of subset 0.
y_train = data_train['labels']
# Bare expression: displays the array shape as the cell output.
y_train.shape

# Wczytanie danych do trenowania enkodera

In [None]:
# Only the 'images' arrays of subsets 1-3 are read here (labels are not
# loaded); they form the unlabelled pool for self-supervised training.
images1 = np.load('data_transform/subset1/data.npz')['images']
images2 = np.load('data_transform/subset2/data.npz')['images']
images3 = np.load('data_transform/subset3/data.npz')['images']

# Stack the three subsets along the first axis (np.concatenate's default).
all_images = np.concatenate((images1, images2, images3))
# Bare expression: displays the combined shape as the cell output.
all_images.shape

# Wczytanie danych do walidacji klasyfikatora

In [None]:
# Open the .npz archive of subset 4, used as the classifier validation set.
# NOTE(review): `images4` is a second name bound to the same object; it is not
# referenced anywhere else in the visible notebook -- confirm before removing.
data_val = images4 = np.load('data_transform/subset4/data.npz')
# Bare expression: shown as the cell output when run in a notebook.
data_val

In [None]:
# Validation inputs: the 'images' array of subset 4.
X_val = data_val['images']
# Bare expression: displays the array shape as the cell output.
X_val.shape

In [None]:
# Validation targets: the 'labels' array of subset 4.
y_val = data_val['labels']
# Bare expression: displays the array shape as the cell output.
y_val.shape

# Definicja augmentacji
Ponieważ nasz zbiór danych to 10 najważniejszych obrazów, to augmentacje, które zastosujemy, muszą być identyczne dla każdego obrazu, żeby nie zatracić żadnej informacji.

Pomysły na augmentację:
1. Odbicie obrazu w poziomie i w pionie (horizontal i vertical flip)
2. Rotacja

Trudno wprowadzić więcej augmentacji, gdyż ryzykujemy wówczas zbytnie zmienienie obrazów.

In [None]:
def apply_transformations(images, chance_flip_horizontal=0.5, chance_flip_vertical=0.5, chance_rotate=0.5, rotate_angle_max=45, rotate_angle_min=-45):
    """Return a randomly augmented copy of a batch of image stacks.

    For every sample one set of random decisions (horizontal flip, vertical
    flip, rotation and its angle) is drawn and then applied identically to
    all images belonging to that sample.

    The input tensor is NOT modified: augmentation is done on a clone, so
    calling this function twice on the same batch yields two independent
    views. (Augmenting in place and returning the input would make both
    "views" the very same tensor.)

    Args:
        images: Tensor indexed as ``images[sample, 0, image]`` with each such
            entry being one 2-D image -- presumably of shape
            ``(batch, 1, no_images, H, W)``; TODO confirm against the dataset.
        chance_flip_horizontal: Probability of mirroring along the last axis.
        chance_flip_vertical: Probability of mirroring along the
            second-to-last axis.
        chance_rotate: Probability of rotating the sample.
        rotate_angle_max: Largest rotation angle in degrees (inclusive).
        rotate_angle_min: Smallest rotation angle in degrees (inclusive).

    Returns:
        A new tensor with the same shape as ``images``.
    """
    # Local import: `random` is not among the notebook's explicit imports.
    import random

    augmented = images.clone()

    for sample_idx in range(augmented.shape[0]):
        # Draw order (h-flip, v-flip, rotate, angle) is kept stable so that a
        # fixed seed keeps producing the same sequence of decisions.
        flip_horizontal = random.random() < chance_flip_horizontal
        flip_vertical = random.random() < chance_flip_vertical
        rotate = random.random() < chance_rotate

        if not (flip_horizontal or flip_vertical or rotate):
            continue

        # All images of the sample at once: (no_images, H, W). Every operation
        # below acts on the last two axes only, so this equals transforming
        # the images one by one.
        stack = augmented[sample_idx, 0]
        if flip_horizontal:
            # Same result as torchvision's hflip for tensors.
            stack = stack.flip(-1)
        if flip_vertical:
            # Same result as torchvision's vflip for tensors.
            stack = stack.flip(-2)
        if rotate:
            rotate_angle = random.randint(rotate_angle_min, rotate_angle_max)
            # Pixels uncovered by the rotation are filled with -1000 --
            # presumably the data's background intensity; TODO confirm.
            stack = torchvision.transforms.functional.rotate(stack, rotate_angle, fill=-1000)

        augmented[sample_idx, 0] = stack

    return augmented

# Definicja BYOL

In [None]:
class BYOL(nn.Module):
    """BYOL (https://arxiv.org/pdf/2006.07733) training wrapper around an encoder.

    Holds two networks:

    * online network: encoder -> projector -> predictor, trained by gradients;
    * target network: frozen copies of the encoder and projector, updated only
      through an exponential moving average of the online weights
      (see ``update_target_network``).

    Args:
        encoder: Backbone module exposing an ``out_size`` attribute, which is
            used as the input, hidden and output size of the MLP heads.
        tau: EMA decay; the share of the old target weights kept per update.
    """

    def __init__(self, encoder, tau: float = 0.999):
        super().__init__()

        # Online network (trainable).
        self.online_encoder = encoder
        self.online_projector = MLP(encoder.out_size, encoder.out_size, encoder.out_size, plain_last=False)
        self.online_predictor = MLP(encoder.out_size, encoder.out_size, encoder.out_size, plain_last=True)
        self.online_net = nn.Sequential(
            self.online_encoder,
            self.online_projector,
            self.online_predictor,
        )

        # Target network: frozen deep copies of the online encoder/projector.
        self.target_encoder = self.copy_and_freeze_module(self.online_encoder)
        self.target_projector = self.copy_and_freeze_module(self.online_projector)
        self.target_net = nn.Sequential(self.target_encoder, self.target_projector)

        self.tau = tau

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Build two augmented views of ``x`` and run both networks on them.

        Returns:
            ``(q, z_prim)``: online predictions and matching target
            projections, each the concatenation (along dim 0) of both view
            orderings, so row ``i`` of ``q`` pairs with row ``i`` of ``z_prim``.
        """
        # Follow the device the model lives on instead of hard-coding CUDA.
        x = x.to(next(self.parameters()).device)

        # Each view gets its own copy: the augmentation may modify its
        # argument in place, which would otherwise make both views (and `x`)
        # the same tensor and collapse the two-view objective.
        t = apply_transformations(x.clone())
        t_prim = apply_transformations(x.clone())

        q = self.online_net(t)
        q_sym = self.online_net(t_prim)

        # The target network only provides regression targets: no gradients.
        with torch.no_grad():
            z_prim = self.target_net(t_prim)
            z_prim_sym = self.target_net(t)

        # Symmetrised pairs: online(t) <-> target(t'), online(t') <-> target(t).
        q = torch.cat([q, q_sym], dim=0)
        z_prim = torch.cat([z_prim, z_prim_sym], dim=0)

        return q, z_prim

    @staticmethod
    def byol_loss(q: Tensor, z_prim: Tensor) -> Tensor:
        """Return the mean of ``2 - 2 * cos(q, z_prim)`` over all rows."""
        q = F.normalize(q, dim=-1)
        z_prim = F.normalize(z_prim, dim=-1)
        cosine = (q * z_prim).sum(dim=-1)
        return (2 - 2 * cosine).mean()

    @torch.no_grad()
    def update_target_network(self) -> None:
        """Move target weights towards online weights: ``t = tau*t + (1-tau)*o``."""
        # `online_net` lists encoder, projector, predictor and `target_net`
        # lists encoder, projector, so zip pairs matching parameters and stops
        # before the predictor, which has no target counterpart.
        for target_param, online_param in zip(self.target_net.parameters(), self.online_net.parameters()):
            # In-place update avoids allocating a new tensor per parameter.
            target_param.mul_(self.tau).add_(online_param, alpha=1 - self.tau)

    @staticmethod
    def copy_and_freeze_module(model: nn.Module) -> nn.Module:
        """Return a deep copy of ``model`` whose parameters do not require grad."""
        model_copy = copy.deepcopy(model)
        for param in model_copy.parameters():
            param.requires_grad = False

        return model_copy

# Trenowanie enkodera używanego w BYOL

In [None]:
def train_byol(model, optimiser, dataloader, no_epochs=100, augment_params=None):
    """Train a BYOL model and plot the summed loss of every epoch.

    Args:
        model: Callable returning ``(q, z_prim)`` for a batch and exposing
            ``update_target_network()``.
        optimiser: Optimiser stepping the model's trainable parameters.
        dataloader: Iterable yielding batches of images.
        no_epochs: Number of passes over ``dataloader``.
        augment_params: Accepted for interface compatibility but currently
            NOT used -- the model is called with the batch only.
            NOTE(review): wire this through or drop it at the call sites.
    """
    losses = []
    for epoch in range(no_epochs):
        sum_loss = 0
        for images in tqdm(dataloader):
            q, z_prim = model(images)
            loss = BYOL.byol_loss(q, z_prim)

            optimiser.zero_grad()
            loss.backward()
            optimiser.step()

            # The target network moves only after the online weights changed.
            model.update_target_network()

            sum_loss += loss.item()

        # The reported value is the SUM of batch losses, so it scales with
        # the number of batches per epoch.
        print(f"Epoch {epoch}: loss = {sum_loss:.3f}")
        losses.append(sum_loss)

    plt.plot(losses)
    plt.title("Loss", fontsize=18)
    # One point is recorded per epoch, so the axis counts epochs.
    plt.xlabel("Epoch", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.show()

In [None]:
# Seed first: the encoder/MLP construction and the shuffling DataLoader below
# presumably draw from the RNGs seeded here -- keep this call at the top.
seed_everything()

# Number of BYOL epochs: 50 for a real run; switch to the commented-out value
# of 3 for a quick smoke test of the code.
no_epochs = 50
# no_epochs = 3

# Unlabelled pool built from subsets 1-3; drop_last=True discards the final
# incomplete batch so every training batch has exactly 256 samples.
unlabelled_dataset = UnlabelledDataset(all_images)
unlabelled_dataloader = DataLoader(unlabelled_dataset, batch_size=256, shuffle=True, drop_last=True)

# `encoder` is the very module stored as byol.online_encoder (not a copy),
# so training BYOL trains this object.
encoder = Encoder()
byol = BYOL(encoder).cuda()
optimiser = optim.SGD(byol.parameters(), lr=0.001)

train_byol(byol, optimiser, unlabelled_dataloader, no_epochs)

# Trenowanie klasyfikatora

In [None]:
# Re-seed so the classifier stage does not depend on how much randomness the
# BYOL stage consumed (presumably the intent -- seed_everything is defined in src).
seed_everything()

train_dataset = LabelledDataset(X_train, y_train)
val_dataset = LabelledDataset(X_val, y_val)

train_dl = DataLoader(train_dataset, batch_size=64, shuffle=True)
# NOTE(review): shuffling the validation loader is not needed for evaluation.
val_dl = DataLoader(val_dataset, batch_size=64, shuffle=True)

# The classifier is built on the BYOL online encoder object itself (no copy);
# all of model.parameters() is handed to Adam, so the encoder is presumably
# fine-tuned too unless SelfSupervisedClassifier freezes it -- TODO confirm.
# The second argument (2) is presumably the number of classes.
model = SelfSupervisedClassifier(byol.online_encoder, 2).cuda()
optimiser = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

# Up to 50 epochs with early stopping (patience=5, model_dir='model/byol').
fit_classifier(
    model=model, optimiser=optimiser, loss_fn=loss_fn,
    train_dl=train_dl, val_dl=val_dl, epochs=50, early_stop=EarlyStopping(model_dir='model/byol', patience=5), print_metrics=True
)

# Liczenie metryk klasyfikacyjnych

In [None]:
# Rebuild the validation dataset and loader (same arguments as in the training
# cell) and compute classification metrics for `model` on it.
# NOTE(review): `model` holds whatever weights fit_classifier left in it;
# whether that is the best early-stopping checkpoint is not visible here.
val_dataset = LabelledDataset(X_val, y_val)
val_dl = DataLoader(val_dataset, batch_size=64, shuffle=True)
calculate_metrics(model, val_dl)