In [1]:
import os
# Process-level environment tweaks, set before torch is imported in a later cell.
os.environ.update(
    {
        # Tolerate more than one OpenMP runtime being loaded into the process.
        "KMP_DUPLICATE_LIB_OK": "TRUE",
        # Expose only the GPU with index 3 to this process.
        "CUDA_VISIBLE_DEVICES": "3",
    }
)

In [2]:
import warnings
# Suppress every warning for the rest of the session: no category, message or
# module filter is given, so the "ignore" action applies to all of them.
warnings.filterwarnings("ignore")

In [3]:
import timm
import numpy as np
import math
from lightly.data import LightlyDataset

In [4]:
import copy

import torch
import torchvision
from torch import nn

from my_dino_components_best import My_DINOLoss
from lightly.models.modules import DINOProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule
from iqa_distortions_best import Transform_Global_WD

from torch.nn.utils import clip_grad_value_

In [5]:
import random  # local import: only needed here, for seeding the stdlib RNG

# Seed every random number generator this notebook can reach so that a
# restarted kernel reproduces the same run.
SEED=3407
# Python's own `random` module was previously left unseeded; any code that
# draws from it (e.g. augmentation helpers — TODO confirm which ones do)
# would otherwise differ between runs.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic=True
torch.backends.cudnn.benchmark = False
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
# Run configuration read by the cells below.
batch_size=32      # passed to torch.utils.data.DataLoader
num_workers=50     # number of DataLoader worker processes
input_size = 224   # not referenced by any other cell visible in this notebook
out_dim = 256      # output width of both DINO projection heads and of My_DINOLoss
detype = 'alt'     # forwarded to Transform_Global_WD (meaning defined in iqa_distortions_best)
level = 5          # initial Transform_Global_WD level; the training cell rebinds this name
margin = 1.0       # forwarded to My_DINOLoss
tri_type = 'tri'   # forwarded to My_DINOLoss

In [9]:
# Root directory handed to LightlyDataset as `input_dir`.
# NOTE(review): absolute, machine-specific path; the doubled slashes are
# redundant on POSIX systems but harmless.
path_to_data = "//home//lt//my_complex_sector_240301//"


In [10]:
from lightly.data._image_loaders import pil_loader

# Sanity check: load one entry of the data directory and verify its PIL mode.
_dir_entries = os.listdir(path_to_data)
if not _dir_entries:
    # Previously an empty directory surfaced as a bare IndexError.
    raise FileNotFoundError(f"No files found in {path_to_data!r}")
# os.path.join works whether or not path_to_data ends with a separator.
_probe_image = pil_loader(os.path.join(path_to_data, _dir_entries[0]))
# NOTE(review): if pil_loader converts every image to RGB on load, this
# condition can never be true — confirm against the lightly implementation.
if _probe_image.mode != 'RGB':
    raise ValueError("Wrong channel!")

In [11]:
class DINO(torch.nn.Module):
    """Student/teacher pair of backbone + projection head.

    The student uses the backbone that is passed in; the teacher starts
    from a deep copy of it. Each branch gets its own freshly constructed
    ``DINOProjectionHead``. All teacher parameters are handed to
    ``deactivate_requires_grad``.

    Args:
        backbone: feature extractor whose output is flattened from dim 1
            onwards before it enters a head.
        input_dim: first positional argument of both projection heads
            (the width of the flattened backbone output).

    The head output width comes from the module-level ``out_dim``.
    """

    def __init__(self, backbone, input_dim):
        super().__init__()

        # --- student branch -------------------------------------------
        self.student_backbone = backbone
        # 512 and 64 are passed through positionally (presumably hidden
        # and bottleneck widths — confirm against the lightly API).
        self.student_head = DINOProjectionHead(
            input_dim,
            512,
            64,
            out_dim,
            freeze_last_layer=1,
        )

        # --- teacher branch -------------------------------------------
        self.teacher_backbone = copy.deepcopy(backbone)
        self.teacher_head = DINOProjectionHead(input_dim, 512, 64, out_dim)
        for teacher_part in (self.teacher_backbone, self.teacher_head):
            deactivate_requires_grad(teacher_part)

    def forward(self, x):
        """Run ``x`` through the student backbone and student head."""
        features = self.student_backbone(x)
        return self.student_head(features.flatten(start_dim=1))

    def forward_teacher(self, x):
        """Run ``x`` through the teacher backbone and teacher head."""
        features = self.teacher_backbone(x)
        return self.teacher_head(features.flatten(start_dim=1))


In [12]:
# ResNet-18 from timm with a single input channel and no pretrained weights.
backbone =  timm.create_model('resnet18', pretrained = False,in_chans = 1)
# Read the classifier's input width before reset_classifier(0) is called on
# the next line (presumably that call strips the classification layer —
# confirm against the timm docs).
hidden_dim = backbone.fc.in_features
backbone.reset_classifier(0)
# hidden_dim becomes the input_dim of both DINO projection heads.
temp_model = DINO(backbone, hidden_dim)

# NOTE(review): rebinds `device` to a plain string; the seeding cell bound it
# to a torch.device. Everything from here on sees the string.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Last expression of the cell, so the notebook displays the module repr.
temp_model.to(device)

DINO(
  (student_backbone): ResNet(
    (conv1): Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (act1): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (drop_block): Identity()
        (act1): ReLU(inplace=True)
        (aa): Identity()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act2): ReLU(inplace=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding

In [14]:
# Initial data pipeline built from the configuration cell's `level` / `detype`.
# NOTE(review): the training cell rebuilds all three objects for every level,
# so the ones created here are replaced before the first batch is drawn.
transform = Transform_Global_WD(detype=detype, level=level)
dataset_train = LightlyDataset(transform=transform, input_dir=path_to_data)
dataloader_train = torch.utils.data.DataLoader(
    dataset_train,
    num_workers=num_workers,
    batch_size=batch_size,
    drop_last=True,  # drop the final incomplete batch
    shuffle=True,
)

In [16]:
def my_clip_grad_value_(parameters, clip_value):
    """Sanitise and clip gradients element-wise, in place.

    For every parameter that has a gradient, NaN entries become 0 and all
    other entries (including +/-inf) are limited to
    ``[-clip_value, clip_value]``.

    Unlike the previous version this never touches ``.data`` and never
    allocates a replacement tensor: the existing gradient storage is
    modified directly. Non-finite values are replaced *before* clamping, so
    the result does not depend on how ``clamp_`` treats NaN.

    Args:
        parameters: a single tensor or an iterable of tensors whose ``.grad``
            should be processed. Entries with ``grad is None`` are skipped.
        clip_value: bound of the allowed gradient range; converted with
            ``float()``.

    Returns:
        None. Gradients are modified in place.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    clip_value = float(clip_value)
    with torch.no_grad():
        for param in parameters:
            grad = param.grad
            if grad is None:
                continue
            # Map NaN -> 0 and +/-inf -> +/-clip_value, then bound the rest.
            grad.nan_to_num_(nan=0.0, posinf=clip_value, neginf=-clip_value)
            grad.clamp_(min=-clip_value, max=clip_value)

In [17]:
# Training schedule: the outer loop walks the Transform_Global_WD level from
# `level_range` down to 1 (5, 4, 3, 2, 1). For every level the data pipeline,
# loss, optimizer and LR scheduler are rebuilt from scratch, while `temp_model`
# is NOT re-created and therefore carries its weights over from level to level.
level_range = 5
for level_i in range(level_range):
    # NOTE: rebinds the global `level` set in the configuration cell.
    level = level_range-level_i
    transform = Transform_Global_WD(level = level,detype = detype)
    dataset_train = LightlyDataset(input_dir=path_to_data, transform=transform)
    dataloader_train = torch.utils.data.DataLoader(
        dataset_train,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=num_workers,
    ) 
    epochs = 120
    criterion = My_DINOLoss(device = device,output_dim=out_dim,tri_type = tri_type,margin = margin)
    optim = torch.optim.Adam(temp_model.parameters(), lr=0.0001, weight_decay=0.05)
    # LR multiplier: linear ramp epoch/20 for epochs 0-19, then a half cosine
    # that starts at 1.0 for epoch 20 and approaches 0 towards epoch 120.
    # NOTE(review): the multiplier is exactly 0 at epoch 0, so the first epoch
    # of every level presumably runs with a learning rate of 0 — confirm
    # against LambdaLR's behaviour and decide whether that is intended.
    lambda_w_warm = lambda epoch: (epoch / 20) if epoch < 20 else 0.5 * (math.cos((epoch - 20)/(100) * math.pi) + 1)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=lambda_w_warm)
    for epoch in range(epochs):
        total_loss = 0
        # Constant momentum passed to update_momentum for the whole run.
        momentum_val = 0.999
        for batch in dataloader_train:
            # First element of the batch holds the list of views produced by the transform.
            views = batch[0]
            # Called before the forward passes with (student, teacher) module
            # pairs (presumably an EMA update of the teacher — confirm against lightly).
            update_momentum(temp_model.student_backbone, temp_model.teacher_backbone, m=momentum_val)
            update_momentum(temp_model.student_head, temp_model.teacher_head, m=momentum_val)
            views = [view.to(device) for view in views]
            # View 0 goes to the teacher, views 1 and 2 go to the student.
            # The names suggest raw / non-degraded / degraded — TODO confirm
            # against Transform_Global_WD.
            raw_views = views[0:1]
            nonde_views = views[1:2]
            de_views = views[2:3]
            # Each result is a one-element list because each slice above holds one view.
            teacher_out = [temp_model.forward_teacher(view) for view in raw_views] 
            student_nonde_out = [temp_model.forward(view) for view in nonde_views]
            student_de_out = [temp_model.forward(view) for view in de_views]
            # Leftover debugging lines (note: `student_out` is not defined in this cell):
            # print(teacher_out[0].shape,student_out[0].shape)

            # teacher_out.to(device)
            # student_out.to(device)
            loss = criterion(teacher_out, student_nonde_out, student_de_out, epoch=epoch)
            # Accumulate a detached copy so the running sum keeps no autograd graph.
            total_loss += loss.detach()
            loss.backward()
            # We only cancel gradients of student head.
            temp_model.student_head.cancel_last_layer_gradients(current_epoch=epoch)
            # Clamp every gradient element to [-2, 2] and turn NaN entries into 0.
            my_clip_grad_value_(temp_model.parameters(),2.0)
            optim.step()
            optim.zero_grad()
        # Mean loss over the batches of this epoch.
        avg_loss = total_loss / len(dataloader_train)
        # The LR schedule advances once per epoch, not per batch.
        scheduler.step()

        print(f"level: {level:>01} epoch: {epoch:>02}, loss: {avg_loss:.5f}")



level: 5 epoch: 00, loss: 13.74884
level: 5 epoch: 01, loss: 12.38223
level: 5 epoch: 02, loss: 8.71642
level: 5 epoch: 03, loss: 6.27088
level: 5 epoch: 04, loss: 4.52726
level: 5 epoch: 05, loss: 3.05764
level: 5 epoch: 06, loss: 2.05166
level: 5 epoch: 07, loss: 1.33708
level: 5 epoch: 08, loss: 0.93956
level: 5 epoch: 09, loss: 0.72083
level: 5 epoch: 10, loss: 0.42041
level: 5 epoch: 11, loss: 0.28374
level: 5 epoch: 12, loss: 0.20895
level: 5 epoch: 13, loss: 0.22952
level: 5 epoch: 14, loss: 0.15959
level: 5 epoch: 15, loss: 0.15340
level: 5 epoch: 16, loss: 0.13845
level: 5 epoch: 17, loss: 0.12644
level: 5 epoch: 18, loss: 0.08482
level: 5 epoch: 19, loss: 0.10641
level: 5 epoch: 20, loss: 0.07852
level: 5 epoch: 21, loss: 0.15067
level: 5 epoch: 22, loss: 0.07923
level: 5 epoch: 23, loss: 0.07334
level: 5 epoch: 24, loss: 0.07524
level: 5 epoch: 25, loss: 0.10780
level: 5 epoch: 26, loss: 0.05676
level: 5 epoch: 27, loss: 0.05139
level: 5 epoch: 28, loss: 0.03888
level: 5 epo