In [None]:
import copy
import pytorch_lightning as pl
import torch
import torchvision
from torch import nn
import math
import numpy as np
import torch.distributed as dist
from torch.nn import functional as F
from torchvision import transforms
from torchvision.datasets import STL10

In [None]:
class DINOTransform:
    """Multi-crop augmentation producing eight augmented views of one image.

    Calling an instance returns a list of tensors: the first two entries come
    from ``global_transforms`` (224 px crops, ``scale=(0.4, 1.0)``) and the
    remaining six from ``local_transforms`` (96 px crops, ``scale=(0.05, 0.4)``).
    """

    @staticmethod
    def _build_pipeline(crop_size, crop_scale, blur_kernel):
        """Compose the augmentation chain shared by the global and local crops."""
        return transforms.Compose([
            transforms.RandomResizedCrop(crop_size, scale=crop_scale),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1),
            transforms.RandomGrayscale(p=0.2),
            transforms.GaussianBlur(kernel_size=blur_kernel),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __init__(self):
        # The two pipelines differ only in crop size, crop scale and blur kernel.
        self.global_transforms = self._build_pipeline(224, (0.4, 1.0), 23)
        self.local_transforms = self._build_pipeline(96, (0.05, 0.4), 9)

    def __call__(self, img):
        """Return ``[global_0, global_1, local_0, ..., local_5]`` for ``img``."""
        # Global crops are drawn first so the random-number stream is consumed
        # in the same order as the returned list.
        global_crops = [self.global_transforms(img) for _ in range(2)]
        local_crops = [self.local_transforms(img) for _ in range(6)]
        return global_crops + local_crops

In [None]:
def cosine_schedule(current_epoch, max_epochs, base_value, final_value):
    """Interpolate from ``base_value`` to ``final_value`` along a half cosine.

    Returns ``base_value`` when ``current_epoch == 0`` and ``final_value`` when
    ``current_epoch == max_epochs``.
    """
    angle = math.pi * current_epoch / max_epochs
    span = base_value - final_value
    return final_value + 0.5 * span * (1 + math.cos(angle))

def update_momentum(student, teacher, m):
    """Blend the student's parameters into the teacher's as a moving average.

    Each teacher parameter becomes ``m * teacher + (1 - m) * student``.
    Only the tensors yielded by ``.parameters()`` are touched; the two modules
    are paired positionally, so they must yield parameters in the same order.
    """
    student_weight = 1.0 - m
    paired_params = zip(student.parameters(), teacher.parameters())
    for source, target in paired_params:
        target.data = target.data * m + source.data * student_weight

def deactivate_requires_grad(model):
    """Turn off gradient tracking for every parameter of ``model``."""
    for weight in model.parameters():
        weight.requires_grad_(False)
        
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with truncated-normal samples and return it.

    Public entry point with default arguments; the sampling itself is done by
    ``_no_grad_trunc_normal_``, which draws from a normal distribution with the
    given ``mean`` and ``std`` and keeps the values inside ``[a, b]``.
    """
    filled = _no_grad_trunc_normal_(tensor, mean=mean, std=std, a=a, b=b)
    return filled

def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        return (1. + math.erf(x / math.sqrt(2.))) / 2.

    with torch.no_grad():
        l = norm_cdf((a - mean) / std)
        u = norm_cdf((b - mean) / std)

        tensor.uniform_(2 * l - 1, 2 * u - 1)
        tensor.erfinv_()
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        tensor.clamp_(min=a, max=b)
        return tensor

In [None]:
class DINOHead(nn.Module):
    """Projection head: an MLP, L2 normalisation, then a weight-normalised linear layer.

    Args:
        in_dim: Size of the input features.
        out_dim: Size of the output produced by the last layer.
        use_bn: Insert ``BatchNorm1d`` after every hidden linear layer.
        norm_last_layer: When true, the magnitude term (``weight_g``) of the
            weight-normalised last layer is frozen at 1.
        nlayers: Number of linear layers in the MLP (values below 1 are
            treated as 1, which gives a single linear projection).
        hidden_dim: Width of the hidden MLP layers.
        bottleneck_dim: Width of the MLP output fed to the last layer.
    """

    def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True, nlayers=3, hidden_dim=2048, bottleneck_dim=256):
        super().__init__()
        depth = max(nlayers, 1)
        if depth == 1:
            self.mlp = nn.Linear(in_dim, bottleneck_dim)
        else:
            # depth - 1 hidden blocks (Linear [+ BatchNorm] + GELU) followed by
            # the projection down to the bottleneck width.
            widths = [in_dim] + [hidden_dim] * (depth - 1)
            blocks = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                blocks.append(nn.Linear(fan_in, fan_out))
                if use_bn:
                    blocks.append(nn.BatchNorm1d(fan_out))
                blocks.append(nn.GELU())
            blocks.append(nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = nn.Sequential(*blocks)

        # Custom init is applied before the last layer exists, so only the MLP
        # layers receive it; the last layer keeps its default initialisation.
        self.apply(self._init_weights)

        final_linear = nn.Linear(bottleneck_dim, out_dim, bias=False)
        self.last_layer = nn.utils.weight_norm(final_linear)
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        """Initialise linear layers: truncated-normal weights (std 0.02), zero bias."""
        if not isinstance(m, nn.Linear):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Project ``x`` through the MLP, L2-normalise it, and apply the last layer."""
        bottleneck = self.mlp(x)
        unit_norm = nn.functional.normalize(bottleneck, dim=-1, p=2)
        return self.last_layer(unit_norm)

In [None]:
class DINOLoss(nn.Module):
    """Cross-entropy between centered, sharpened teacher outputs and student outputs.

    The teacher distribution is ``softmax((teacher - center) / temp)`` where
    ``center`` is a running mean of teacher outputs and ``temp`` follows a
    per-epoch schedule. Every teacher crop is compared with every student crop
    that has a different index, and the resulting terms are averaged.

    Args:
        out_dim: Size of the last dimension of the student/teacher outputs.
        ncrops: Number of crops per image seen by the student.
        warmup_teacher_temp: Teacher temperature at epoch 0.
        teacher_temp: Teacher temperature once the warm-up is over.
        warmup_teacher_temp_epochs: Number of epochs over which the teacher
            temperature moves linearly from ``warmup_teacher_temp`` to
            ``teacher_temp``.
        nepochs: Total length of the temperature schedule.
        student_temp: Temperature dividing the student outputs.
        center_momentum: Weight of the previous center in its moving average.
    """

    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        # Kept as an attribute so callers can read the warm-up length back.
        self.warmup_teacher_temp_epochs = warmup_teacher_temp_epochs
        # Running mean of teacher outputs; a buffer so it follows the module
        # across devices and into checkpoints without being trained.
        self.register_buffer("center", torch.zeros(1, out_dim))
        # Linear warm-up followed by a constant tail, one entry per epoch.
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))

    @staticmethod
    def _as_batch(output):
        """Concatenate a list/tuple of per-crop tensors along dim 0; pass tensors through."""
        if isinstance(output, (list, tuple)):
            return torch.cat(list(output), dim=0)
        return output

    def forward(self, student_output, teacher_output, epoch):
        """Return the scalar DINO loss and update the running center.

        Args:
            student_output: Student outputs for all ``ncrops`` crops, either
                one tensor with the crops concatenated along dim 0 or a
                list/tuple with one tensor per crop.
            teacher_output: Teacher outputs for the two global crops, in the
                same tensor-or-list form.
            epoch: Index into the teacher temperature schedule. Values past
                the end of the schedule use its last entry.
        """
        student_output = self._as_batch(student_output)
        teacher_output = self._as_batch(teacher_output)

        student_out = (student_output / self.student_temp).chunk(self.ncrops)

        # Clamp so that training for more epochs than `nepochs` keeps using
        # the final temperature instead of raising IndexError.
        last_index = len(self.teacher_temp_schedule) - 1
        temp = self.teacher_temp_schedule[min(epoch, last_index)]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # Skip the pair where student and teacher saw the same crop.
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """Move ``center`` towards the mean teacher output of the current batch.

        When a ``torch.distributed`` process group is initialised, the batch
        sum is all-reduced and averaged over all processes; otherwise only the
        local batch is used.
        """
        teacher_output = self._as_batch(teacher_output)
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)

        world_size = 1
        # Collectives raise when no process group exists (e.g. single-device
        # training), so they are only used when one has been initialised.
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(batch_center)
            world_size = dist.get_world_size()
        batch_center = batch_center / (len(teacher_output) * world_size)

        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)

In [None]:
class DINO(pl.LightningModule):
    """Student/teacher self-distillation module built on a ResNet-18 backbone.

    The student (backbone + head) is trained by the optimizer. The teacher is
    a gradient-free copy whose parameters are moved towards the student's by a
    momentum update at the start of every training step.
    """

    def __init__(self):
        super().__init__()
        resnet = torchvision.models.resnet18(pretrained=True)
        # Drop the final classification layer and keep the feature extractor.
        backbone = nn.Sequential(*list(resnet.children())[:-1])
        input_dim = 512
        warmup_epochs = 5

        self.student_backbone = backbone
        self.student_head = DINOHead(input_dim, 512, use_bn=True, norm_last_layer=True)
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOHead(input_dim, 512, use_bn=True, norm_last_layer=True)
        deactivate_requires_grad(self.teacher_backbone)
        deactivate_requires_grad(self.teacher_head)

        self.criterion = DINOLoss(out_dim=512, ncrops=8, warmup_teacher_temp=0.04, teacher_temp=0.07,
                                  warmup_teacher_temp_epochs=warmup_epochs, nepochs=100)
        # Number of initial epochs during which on_after_backward cancels the
        # gradient of the student's last-layer magnitude. Stored on the module
        # itself so the hook does not depend on attributes of the criterion.
        self.last_layer_freeze_epochs = warmup_epochs

    def forward(self, x):
        """Run ``x`` through the student backbone and head."""
        y = self.student_backbone(x).flatten(start_dim=1)
        z = self.student_head(y)
        return z

    def forward_teacher(self, x):
        """Run ``x`` through the teacher backbone and head."""
        y = self.teacher_backbone(x).flatten(start_dim=1)
        z = self.teacher_head(y)
        return z

    def training_step(self, batch, batch_idx):
        """Update the teacher by momentum, then return the DINO loss for ``batch``.

        ``batch[0]`` is expected to be the list of views for each image, with
        the two global views first.
        """
        # Momentum follows a cosine schedule from 0.996 towards 1.0.
        momentum = cosine_schedule(self.current_epoch, self.trainer.max_epochs, 0.996, 1.0)
        update_momentum(self.student_backbone, self.teacher_backbone, m=momentum)
        update_momentum(self.student_head, self.teacher_head, m=momentum)
        views = batch[0]
        views = [view.to(self.device) for view in views]
        global_views = views[:2]
        # The criterion divides and chunks its inputs as tensors, so the
        # per-view outputs are concatenated along the batch dimension
        # (view-major order) instead of being passed as Python lists.
        teacher_out = torch.cat([self.forward_teacher(view) for view in global_views], dim=0)
        student_out = torch.cat([self.forward(view) for view in views], dim=0)
        loss = self.criterion(student_out, teacher_out, epoch=self.current_epoch)
        return loss

    def on_after_backward(self):
        """Zero the last-layer magnitude gradient during the first epochs.

        With ``norm_last_layer=True`` the magnitude parameter ``weight_g`` is
        frozen and therefore has no gradient; in that case there is nothing to
        cancel and the hook returns immediately.
        """
        # NOTE(review): only weight_g is targeted, as in the original code.
        # If the intent was to freeze the whole last layer during warm-up,
        # weight_v needs the same treatment - confirm.
        grad = self.student_head.last_layer.weight_g.grad
        if grad is None:
            return
        if self.current_epoch < self.last_layer_freeze_epochs:
            grad.zero_()

    def configure_optimizers(self):
        """Use Adam with a fixed learning rate of 1e-3 over the module's parameters."""
        return torch.optim.Adam(self.parameters(), lr=0.001)

In [None]:
# Build the student/teacher module.
model = DINO()

# Unlabeled STL10 split; every sample is expanded into multi-crop views by
# the transform. The dataset is downloaded on first use.
transform = DINOTransform()
dataset = STL10(root="data/unlabeled-dino", split='unlabeled', download=True, transform=transform)

dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

# Train on a single device, using the GPU when one is available.
trainer = pl.Trainer(
    max_epochs=10,
    devices=1,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
)
trainer.fit(model=model, train_dataloaders=dataloader)