# DINO - self-DIstillation with NO labels - https://arxiv.org/pdf/2104.14294v2

# Import bibliotek

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from src.dataset import *
from src.early_stopping import *
from src.self_supervised_modules import *
from src.utils import *
import torch.nn.functional  as F

# Wczytanie danych do trenowania końcowego klasyfikatora

In [None]:
# Labelled training subset (subset0). np.load on an .npz returns a lazy NpzFile that
# keeps the file handle open until it is closed; load the arrays eagerly inside `with`
# so the handle is released immediately. The result is indexed the same way
# (data_train['images'], data_train['labels']).
with np.load('data_transform/subset0/data.npz') as npz:
    data_train = dict(npz)
data_train.keys()

In [None]:
# Training volumes of the labelled subset (subset0); the bare expression below is the
# cell output and displays their shape.
X_train = data_train['images']
X_train.shape

In [None]:
# Labels for X_train (presumably one per volume — confirm against the shapes shown);
# the bare expression below is the cell output.
y_train = data_train['labels']
y_train.shape

# Wczytanie danych do trenowania enkodera

In [None]:
# Unlabelled volumes for self-supervised encoder training: the 'images' arrays of
# subsets 1-3, loaded in order and stacked along the first axis.
images1, images2, images3 = (
    np.load(npz_path)['images']
    for npz_path in (
        'data_transform/subset1/data.npz',
        'data_transform/subset2/data.npz',
        'data_transform/subset3/data.npz',
    )
)

all_images = np.concatenate([images1, images2, images3])
all_images.shape

# Wczytanie danych do walidacji klasyfikatora

In [None]:
# Labelled validation subset (subset4). Loaded eagerly inside `with` so the .npz file
# handle is closed right away (np.load on an .npz returns a lazy NpzFile that keeps
# the file open). The stray `images4` alias from the original copy-paste was unused.
with np.load('data_transform/subset4/data.npz') as npz:
    data_val = dict(npz)
data_val.keys()

In [None]:
# Validation volumes (subset4); the bare expression below is the cell output.
X_val = data_val['images']
X_val.shape

In [None]:
# Labels for X_val (presumably one per volume — confirm against the shapes shown);
# the bare expression below is the cell output.
y_val = data_val['labels']
y_val.shape

# Definicja augmentacji
DINO wymaga przycinania (cropowania) obrazów na wycinki o różnych rozmiarach, z których korzystają nauczyciel i student — tutaj wycinki globalne (8×28×28) i lokalne (4×12×12). Zaimplementujmy ten typ augmentacji.

In [None]:
class RandomResizedCrop3D(object):
    """Randomly crop a fixed-size sub-volume out of a 3D volume.

    Despite the name, no resizing is performed: the result is a plain random crop
    (a slice of the input) of shape ``output_size``.

    Args:
        output_size: crop size as ``(depth, height, width)``. An int gives a cube;
            any length-3 tuple or list is accepted.

    Raises:
        TypeError: if ``output_size`` is neither an int nor a tuple/list.
        ValueError: if a tuple/list ``output_size`` does not have exactly 3 entries.
    """

    def __init__(self, output_size):
        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if isinstance(output_size, int):
            self.output_size = (output_size, output_size, output_size)
        elif isinstance(output_size, (tuple, list)):
            if len(output_size) != 3:
                raise ValueError(
                    f"output_size must have 3 entries (depth, height, width), got {output_size!r}"
                )
            self.output_size = tuple(output_size)
        else:
            raise TypeError(
                f"output_size must be an int or a 3-element tuple/list, got {type(output_size).__name__}"
            )

    def __call__(self, volume):
        """Return a random ``output_size`` crop of ``volume``.

        Args:
            volume: 3D volume indexed as (depth, height, width); anything with a
                3-element ``.shape`` that supports slicing (NumPy array, torch tensor).

        Returns:
            The cropped slice, or ``volume`` itself if it already has exactly
            ``output_size``.

        Raises:
            ValueError: if any crop dimension exceeds the corresponding input dimension.
        """
        d, h, w = volume.shape
        new_d, new_h, new_w = self.output_size

        if d == new_d and h == new_h and w == new_w:
            return volume  # Nothing to crop.

        if d < new_d or h < new_h or w < new_w:
            raise ValueError(f"Requested crop size ({self.output_size}) larger than input size ({volume.shape})")

        # Corner offsets come from the stdlib `random` module, so crops follow
        # Python's global RNG state.
        top_d = random.randint(0, d - new_d)
        top_h = random.randint(0, h - new_h)
        top_w = random.randint(0, w - new_w)

        return volume[top_d: top_d + new_d,
                      top_h: top_h + new_h,
                      top_w: top_w + new_w]

def _crop_batch(images, output_size):
    """Apply an independent random 3D crop of ``output_size`` to every sample of a batch.

    Each sample is ``squeeze()``-d before cropping (dropping all size-1 dims) and a
    channel dimension is re-inserted at position 1 of the stacked result.
    NOTE(review): assumes `images` is a tensor batch such as (B, 1, D, H, W) from the
    DataLoader and that no spatial dimension has size 1 — confirm against the dataset.
    """
    transform = RandomResizedCrop3D(output_size=output_size)
    return torch.stack([transform(img.squeeze()) for img in images]).unsqueeze(1)


def global_augment(images, output_size=(8, 28, 28)):
    """Return a batch of random global crops (default 8x28x28), shape (B, 1, *output_size)."""
    return _crop_batch(images, output_size)


def local_augment(images, output_size=(4, 12, 12)):
    """Return a batch of random local crops (default 4x12x12), shape (B, 1, *output_size)."""
    return _crop_batch(images, output_size)

# Definicja DINO

In [None]:
class DINO(nn.Module):
    """Holds the DINO student/teacher pair and the running center of teacher outputs.

    The teacher is initialised from the student's weights, never receives gradients
    (``requires_grad=False``) and is only updated via :meth:`teacher_update`, an
    exponential moving average of the student weights.
    """

    def __init__(self, student_arch, teacher_arch, device: torch.device, out_dim: int = 128):
        """
        Args:
            student_arch (type[nn.Module]): class used to build the student network.
            teacher_arch (type[nn.Module]): class used to build the teacher network. Its
                state dict must be compatible with the student's, because the teacher is
                loaded from the student's weights.
            device: torch.device or device string ('cuda' or 'cpu') on which both
                networks and the ``center`` buffer are placed.
            out_dim: length of the network output vector; sets the shape (1, out_dim)
                of the ``center`` buffer. Defaults to 128 (MEncoder's default out_size).
        """
        super(DINO, self).__init__()

        self.student = student_arch(image_planes=8, images_width=28, images_height=28).to(device)
        # NOTE(review): the teacher is built with the local-crop geometry while the student
        # uses the global-crop geometry. load_state_dict below only works if parameter
        # shapes do not depend on these arguments (true for MEncoder, which ignores them).
        self.teacher = teacher_arch(image_planes=4, images_width=12, images_height=12).to(device)
        self.teacher.load_state_dict(self.student.state_dict())

        # Running mean of teacher outputs used for centering. Created directly on
        # `device` so the module also works without CUDA.
        self.register_buffer('center', torch.zeros(1, out_dim, device=device))

        for param in self.teacher.parameters():
            param.requires_grad = False

    @staticmethod
    def distillation_loss(student_output, teacher_output, center, tau_s, tau_t):
        """Cross-entropy between centered/sharpened teacher and sharpened student outputs.

        This is function H from the DINO pseudocode. Argument order matters: only
        ``student_output`` carries gradients; ``teacher_output`` is detached.

        Args:
            student_output: student logits, softmaxed over dim 1.
            teacher_output: teacher logits, softmaxed over dim 1 after centering.
            center: running center subtracted from the teacher logits.
            tau_s: student temperature.
            tau_t: teacher temperature.

        Returns:
            Scalar tensor: per-sample cross-entropy averaged over dim 0.
        """
        # Stop gradients through the teacher.
        teacher_output = teacher_output.detach()

        # Center and sharpen the teacher's outputs.
        teacher_probs = F.softmax((teacher_output - center) / tau_t, dim=1)

        # Sharpen the student's outputs (log-probabilities for the cross-entropy).
        student_log_probs = F.log_softmax(student_output / tau_s, dim=1)

        return - (teacher_probs * student_log_probs).sum(dim=1).mean()

    @torch.no_grad()
    def teacher_update(self, beta: float):
        """EMA update of the teacher: ``teacher = beta * teacher + (1 - beta) * student``."""
        for teacher_param, student_param in zip(self.teacher.parameters(), self.student.parameters()):
            teacher_param.mul_(beta).add_(student_param, alpha=(1 - beta))

# Definicja enkodera
Niestety DINO nie współpracuje poprawnie z enkoderem zawartym w `self_supervised_modules`, dlatego musimy zaimplementować osobny enkoder na potrzeby DINO (będący modyfikacją uniwersalnego enkodera).

Mimo użycia innego enkodera końcowy model klasyfikacyjny pozostaje jednak taki sam.

In [None]:
class MEncoder(nn.Module):
    """3D convolutional encoder used as the DINO backbone.

    For each consecutive pair in ``channels`` it stacks
    ``Conv3d(kernel_size=3, padding=1) -> ReLU -> MaxPool3d(2)``, then applies global
    average pooling to 1x1x1, flattens, and projects linearly to ``out_size``.

    Args:
        channels: channel counts, the first being the input channel count.
            Defaults to ``[1, 16, 32]``.
        out_size: length of the output embedding (also stored as ``self.out_size``).
        image_planes, images_width, images_height: accepted for signature compatibility
            only; they are not used when building the network.
    """

    def __init__(self, channels=None, out_size=128, image_planes=10, images_width=32, images_height=32):
        super().__init__()
        chans = [1, 16, 32] if channels is None else channels
        self.out_size = out_size

        body = nn.Sequential()
        for idx, (c_in, c_out) in enumerate(zip(chans, chans[1:])):
            body.add_module(f'conv_{idx}', nn.Conv3d(c_in, c_out, kernel_size=3, padding=1))
            body.add_module(f'relu_{idx}', nn.ReLU())
            body.add_module(f'maxpool_{idx}', nn.MaxPool3d(2))

        # Collapse whatever spatial extent remains to 1x1x1, so the linear head's input
        # width depends only on the last channel count, not on the crop size.
        body.add_module("adaptive_pool", nn.AdaptiveAvgPool3d((1, 1, 1)))
        body.add_module("flatten", nn.Flatten())
        body.add_module("linear", nn.Linear(chans[-1], out_size))
        self.model = body

    def forward(self, x):
        """Encode a batch of volumes into ``out_size``-dimensional embeddings."""
        return self.model(x)

# Trenowanie enkodera używanego w DINO

In [None]:
def train_dino(dino: DINO,
               data_loader: DataLoader,
               optimizer: optim.Optimizer,
               device: torch.device,
               num_epochs,
               tps=0.99,
               tpt=0.04,
               beta=0.9,
               m=0.9,
               ):
    """Train the DINO student by self-distillation and update the teacher by EMA.

    For each batch, a global crop (x1) and a local crop (x2) are produced. Each teacher
    output is paired with the student output of the *other* crop, and the two
    distillation losses are averaged.

    Args:
        dino: DINO module holding the student, the teacher and the ``center`` buffer.
        data_loader (DataLoader): yields batches of unlabelled volumes.
        optimizer (torch.optim.Optimizer): optimizer over the student's parameters
            (Adam, SGD, ...).
        device (torch.device): 'cuda' or 'cpu'.
        num_epochs: number of passes over ``data_loader``.
        tps (float): temperature for sharpening the student logits.
        tpt (float): temperature for sharpening the teacher logits.
        beta (float): EMA decay for the teacher weights.
        m (float): EMA decay for the center.

    Returns:
        list[float]: the sum of batch losses for each epoch (also printed and plotted).
    """
    losses = []
    for epoch in range(num_epochs):
        dino.student.train()
        dino.teacher.eval()

        epoch_loss = 0.0
        for x in tqdm(data_loader):
            # Crop once and move once; student and teacher see the same crops.
            x1 = global_augment(x).to(device)
            x2 = local_augment(x).to(device)

            student_output1, student_output2 = dino.student(x1), dino.student(x2)
            with torch.no_grad():
                teacher_output1, teacher_output2 = dino.teacher(x1), dino.teacher(x2)

            # distillation_loss takes (student_output, teacher_output, ...). Passing them
            # in this order lets gradients reach the student; the teacher side is
            # detached inside the loss. (Swapping them yields a graph-less loss.)
            loss = (dino.distillation_loss(student_output2, teacher_output1, dino.center, tps, tpt) +
                    dino.distillation_loss(student_output1, teacher_output2, dino.center, tps, tpt)) / 2

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Plain float: keeps no tensors (or autograd graph) alive across the epoch.
            epoch_loss += loss.item()

            dino.teacher_update(beta)

            with torch.no_grad():
                batch_center = torch.cat([teacher_output1, teacher_output2], dim=0).mean(dim=0)
                dino.center = m * dino.center + (1 - m) * batch_center

        print(f"Epoch: {epoch}, Loss: {epoch_loss}")
        losses.append(epoch_loss)

    plt.plot(losses)
    plt.title("Loss", fontsize=18)
    plt.xlabel("Epoch", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.show()
    return losses

In [None]:
# Seed RNGs before anything random happens (weight init inside DINO(...), DataLoader
# shuffling, random crops). Statement order matters for reproducibility.
# seed_everything comes from a star import of src.* — presumably it seeds
# Python/NumPy/torch; confirm there.
seed_everything()

# 50 - default, 3 - for testing the code
no_epochs = 50
# no_epochs = 3

# Unlabelled volumes from subsets 1-3 are used for self-supervised pre-training.
unlabelled_dataset = UnlabelledDataset(all_images)
# Student and teacher share the MEncoder architecture; the device is hard-coded to
# 'cuda', so this cell requires a GPU.
dino = DINO(MEncoder, MEncoder, 'cuda')
# drop_last=True: every training batch contains exactly 256 volumes.
unlabelled_dataloader = DataLoader(unlabelled_dataset, batch_size=256, shuffle=True, drop_last=True)
# dino.parameters() also yields the frozen teacher weights (requires_grad=False in
# DINO.__init__); they never get a .grad, so Adam skips them.
optimizer = optim.Adam(dino.parameters(), lr=0.00001)

train_dino(dino, unlabelled_dataloader, optimizer,'cuda', no_epochs)

# Trenowanie klasyfikatora

In [None]:
# Re-seed so the classifier stage is reproducible on its own.
seed_everything()

train_dataset = LabelledDataset(X_train, y_train)
val_dataset = LabelledDataset(X_val, y_val)

train_dl = DataLoader(train_dataset, batch_size=64, shuffle=True)
# NOTE(review): shuffling the validation loader is not needed. It is left as is
# because a shuffled loader draws from torch's global RNG, so disabling it would change
# the training shuffle order of this seeded run.
val_dl = DataLoader(val_dataset, batch_size=64, shuffle=True)

# Classifier on top of the DINO-pretrained student encoder (the same module object as
# dino.student, not a copy). Presumably 2 = number of classes — confirm
# SelfSupervisedClassifier's signature. `.cuda()` means this cell requires a GPU.
model = SelfSupervisedClassifier(dino.student, 2).cuda()
optimiser = optim.Adam(model.parameters(), lr=0.001)
loss_fn = torch.nn.CrossEntropyLoss()

# Presumably EarlyStopping checkpoints into 'model/dino' and stops after 5 epochs
# without improvement — confirm in src.early_stopping.
fit_classifier(
    model=model, optimiser=optimiser, loss_fn=loss_fn,
    train_dl=train_dl, val_dl=val_dl, epochs=50, early_stop=EarlyStopping(model_dir='model/dino', patience=5), print_metrics=True
)

# Liczenie metryk klasyfikacyjnych

In [None]:
# Final evaluation on the validation subset. The loader is not shuffled: evaluation
# needs no random order, and a fixed order makes runs comparable batch-by-batch
# (assumes calculate_metrics aggregates over the whole loader — confirm in src.utils).
val_dataset = LabelledDataset(X_val, y_val)
val_dl = DataLoader(val_dataset, batch_size=64, shuffle=False)
calculate_metrics(model, val_dl)