In [None]:
import random
import pandas as pd
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
# import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.models as models
from torchvision.transforms import InterpolationMode
import timm

def seed_everything(seed: int = 42):
    """Seed every RNG this notebook relies on and put cuDNN in deterministic mode.

    Args:
        seed: Value used for Python's ``random``, NumPy, ``PYTHONHASHSEED`` and
            PyTorch (CPU and every CUDA device).
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    # manual_seed_all seeds every GPU, not only the current one; it is a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN auto-tune and pick kernels per run, which defeats
    # the deterministic setting above, so it has to be off for reproducible runs.
    torch.backends.cudnn.benchmark = False

seed_everything()

# Goal: observe the effect of self-distillation.
# Rather than handling it inside a single model, try a 2-stage setup.
# The teacher is fully frozen.
# The student is optimized as much as possible on low-resolution images.
# Teacher and student both start from the same saved checkpoint, loaded onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
teacher = torch.load("best_xlarge_model.pt", map_location = "cpu")
student = torch.load("best_xlarge_model.pt", map_location = "cpu")
# student = timm.create_model("convnext_large.fb_in22k", pretrained = True, num_classes = 25)

In [None]:
# Split each loaded model into a feature extractor (every child module except the last)
# and its `.head` attribute.
# NOTE(review): assumes the last child returned by .children() is the `.head` module — confirm for this checkpoint.
extraction_teacher = torch.nn.Sequential(*list(teacher.children())[:-1])
extraction_student = torch.nn.Sequential(*list(student.children())[:-1])
head_teacher = teacher.head
head_student = student.head

class ConvNext(nn.Module):
    """Pair a feature extractor with a head and expose both outputs.

    Args:
        extraction: Module applied first to the input.
        head: Module applied to the extractor's output.
    """

    def __init__(self, extraction, head):
        super().__init__()
        self.extraction = extraction
        self.head = head

    def forward(self, x):
        """Return ``(extractor_output, head_output)`` for input ``x``."""
        features = self.extraction(x)
        return features, self.head(features)

def freeze(model):
    """Turn off gradient tracking for every parameter of ``model`` (in place)."""
    for parameter in model.parameters():
        parameter.requires_grad = False
# Freeze both halves of the teacher, then wrap teacher and student so that a
# forward pass returns (extractor output, head output).
freeze(extraction_teacher)
freeze(head_teacher)
teacher = ConvNext(extraction = extraction_teacher, head = head_teacher)
student = ConvNext(extraction = extraction_student, head = head_student)

In [None]:
# Smoke test: push a random batch through the student and display the sizes of its two outputs.
random_tensor = torch.rand([64, 3, 224, 224])
x = student(random_tensor)
x[0].size(), x[1].size()
# x = x.view(x.size(0), -1)
# x.size()

In [None]:
# List every student parameter with its requires_grad flag to check what will be trained.
for name, param in student.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

In [None]:
# best learning rate: 5e-5
# Hyperparameters read by the cells below; DEVICE falls back to the CPU when CUDA is unavailable.
CFG = {
    "LEARNING_RATE": 5e-5,
    "EPOCHS": 30,
    "BATCH_SIZE": 16,
    "DEVICE": torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
}

In [None]:
class ImageSet(Dataset):
    """Dataset of paired low-/high-resolution images with class labels.

    Args:
        img_low: Indexable collection of paths to the low-resolution images.
        img_high: Indexable collection of paths to the matching high-resolution images.
        transform: Optional callable invoked as ``transform(image=array)["image"]``
            on both images. When ``None`` the RGB arrays are returned unchanged.
        class_name: Optional mapping from raw label to class index. When ``None``
            the raw label is returned as-is.
        label: Indexable collection of raw labels aligned with the image paths.
    """

    def __init__(self, img_low, img_high, transform = None, class_name = None, label = None):
        self.img_low = img_low
        self.img_high = img_high
        self.label = label
        self.transform = transform
        self.class_name = class_name

    def __len__(self):
        return len(self.label)

    def _read_rgb(self, path):
        """Read the image at ``path`` and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None; fail with the path
            # instead of an opaque cvtColor assertion.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(image_low, image_high, label)`` for sample ``idx``."""
        image_low = self._read_rgb(self.img_low[idx])
        image_high = self._read_rgb(self.img_high[idx])
        if self.transform is not None:
            # NOTE(review): the transform is called once per image, so a pipeline
            # with random ops may augment the two images differently — confirm
            # this is acceptable for the feature-matching loss.
            image_low = self.transform(image = image_low)["image"]
            image_high = self.transform(image = image_high)["image"]
        label = self.label[idx]
        if self.class_name is not None:
            # Use the mapping handed to this instance, not a notebook-level global.
            label = self.class_name[label]
        return image_low, image_high, label
    

class AugmentSet(Dataset):
    """Paired-image dataset that augments only the low-resolution image.

    Args:
        img_low: Indexable collection of paths to the low-resolution images.
        img_high: Indexable collection of paths to the matching high-resolution images.
        transform: Optional callable invoked as ``transform(image=array)["image"]``
            on the high-resolution image. When ``None`` the RGB array is returned unchanged.
        transform_augment: Optional callable with the same calling convention,
            applied to the low-resolution image. When ``None`` the RGB array is
            returned unchanged.
        class_name: Optional mapping from raw label to class index. When ``None``
            the raw label is returned as-is.
        label: Indexable collection of raw labels aligned with the image paths.
    """

    def __init__(self, img_low, img_high, transform = None, transform_augment = None, class_name = None, label = None):
        self.img_low = img_low
        self.img_high = img_high
        self.label = label
        self.transform = transform
        # Previously never stored, which made __getitem__ fail with AttributeError.
        self.transform_augment = transform_augment
        self.class_name = class_name

    def __len__(self):
        return len(self.label)

    def _read_rgb(self, path):
        """Read the image at ``path`` and return it as an RGB array.

        Raises:
            FileNotFoundError: if OpenCV cannot read the file.
        """
        image = cv2.imread(path)
        if image is None:
            # cv2.imread signals failure by returning None.
            raise FileNotFoundError(f"Could not read image: {path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def __getitem__(self, idx):
        """Return ``(image_low, image_high, label)`` for sample ``idx``."""
        image_low = self._read_rgb(self.img_low[idx])
        image_high = self._read_rgb(self.img_high[idx])
        if self.transform_augment is not None:
            image_low = self.transform_augment(image = image_low)["image"]
        if self.transform is not None:
            image_high = self.transform(image = image_high)["image"]
        label = self.label[idx]
        if self.class_name is not None:
            # Use the mapping handed to this instance, not a notebook-level global.
            label = self.class_name[label]
        return image_low, image_high, label

In [None]:
# Deterministic pipeline: resize, ImageNet normalisation, HWC array -> CHW tensor.
transform = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
    ToTensorV2()
])
# Augmentation pipeline: one geometric op always, one photometric op half of the time.
transform_augment = A.Compose([
    A.Resize(height = 224, width = 224),
    A.OneOf([A.HorizontalFlip(), A.VerticalFlip(), A.RandomRotate90()], p = 1),
    # A.OneOf([A.GaussianBlur(blur_limit = (1, 5)), A.MedianBlur(blur_limit = (1, 5))], p = 0.2),
    A.OneOf([A.RandomBrightnessContrast(brightness_limit = 0.1, contrast_limit = 0.1),
             A.CLAHE()], p = 0.5),
    # A.OneOf([A.ElasticTransform(), A.GridDistortion()], p = 0.1),
    # Same normalisation as `transform`, so augmented and plain samples share one
    # input scale. It is placed after the photometric ops so they still see the
    # raw image values.
    A.Normalize(mean = [0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
    ToTensorV2()
])


# Read the training table, hold out 10% for validation (stratified by label),
# then take a stratified half of the training part as the pool to augment.
data = pd.read_csv("train_.csv")
trainset, valset = train_test_split(data, test_size = 0.1, stratify = data["label"], random_state = 42)
augmentset = train_test_split(trainset, test_size = 0.5, stratify = trainset["label"], random_state = 42)[1]
# Renumber the rows from 0 and discard the leftover "Unnamed: 0" column.
trainset = trainset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])
valset = valset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])
augmentset = augmentset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])
# augmentset = trainset

In [None]:
# Show the distinct training labels together with how often each occurs.
np.unique(trainset["label"], return_counts = True)

In [None]:
# Map each distinct label to an integer index, following the sorted order returned by np.unique.
classes = np.unique(data["label"])
class_name = {name: i for i, name in enumerate(classes)}

In [None]:
# Wrap each split in an ImageSet. `trainset` and `augmentset` are rebound from
# DataFrames to datasets here, so this cell cannot be re-run without re-running
# the split cell first.
# The augment split passes `transform_augment` as its transform, which ImageSet
# applies to both the low- and the high-resolution image.
trainset = ImageSet(img_low = trainset["img_path"], img_high = trainset["upscale_img_path"], transform = transform, class_name = class_name, label = trainset["label"])
validset = ImageSet(img_low = valset["img_path"], img_high = valset["upscale_img_path"], transform = transform, class_name = class_name, label = valset["label"])
augmentset = ImageSet(img_low = augmentset["img_path"], img_high = augmentset["upscale_img_path"], transform = transform_augment, class_name = class_name, label = augmentset["label"])
# augmentset2 = ImageSet(img_low = augmentset["img_path"], img_high = augmentset["upscale_img_path"], transform = transform_augment, class_name = class_name, label = augmentset["label"])

In [None]:
def visualization(flag: bool = False):
    """Plot the first low-resolution sample of ``trainset`` next to that of ``augmentset``.

    Does nothing unless ``flag`` is true. Each image is min-max rescaled to
    [0, 1] before it is shown.
    """
    if not flag:
        return
    image1, image2 = trainset[0][0], augmentset[0][0]
    print(image1.size())
    panels = []
    for tensor in (image1, image2):
        # CHW tensor -> HWC array, then stretch the value range to [0, 1] for imshow.
        array = tensor.numpy().transpose((1, 2, 0))
        panels.append((array - array.min()) / (array.max() - array.min()))
    plt.figure(figsize=(10, 5))
    for position, array in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.imshow(array)
        plt.title(f'Image {position}')
    plt.show()
# Render the side-by-side comparison of a plain and an augmented sample.
visualization(True)

In [None]:
# Training draws from the plain and the augmented dataset combined (`+` concatenates
# the two datasets) in shuffled order; validation keeps a fixed order.
trainloader = DataLoader(trainset + augmentset, batch_size = CFG["BATCH_SIZE"], shuffle = True, num_workers = 0)
validloader = DataLoader(validset, batch_size = CFG["BATCH_SIZE"], shuffle = False, num_workers = 0)

In [None]:
# Number of samples in the plain training dataset and in the validation dataset.
len(trainset), len(validset)

In [None]:
# Make sure no teacher parameter receives gradients. The teacher's extractor and
# head were already frozen before being wrapped, so this repeats that step.
for i, (name, param) in enumerate(teacher.named_parameters()):
    param.requires_grad = False

In [None]:
# Best public-leaderboard score so far: alpha = 0.2, T = 3.
# The lower alpha is, the more the loss focuses on the student's own learning.
# The higher the temperature is, the more it focuses on learning the hard cases.
# Early on the temperature is low, so start from a low alpha and let the model
# learn the easy cases by itself first.
# Later the temperature is high, so try a relatively higher alpha so that the
# student follows the teacher's output.
def distillation_loss(logits, labels, teacher_logits, student_rprs, teacher_rprs, mse_loss, temperature, alpha = 0.1):
    """Combine hard-label, soft-label and feature-matching losses for distillation.

    Args:
        logits: Student class logits; softmax is taken over dim 1.
        labels: Ground-truth targets passed to ``F.cross_entropy``.
        teacher_logits: Teacher class logits; softmax is taken over dim 1.
        student_rprs: Student representation compared against the teacher's.
        teacher_rprs: Teacher representation.
        mse_loss: Callable invoked as ``mse_loss(teacher_rprs, student_rprs)``.
        temperature: Softening temperature ``T`` applied to both sets of logits.
        alpha: Weight of the soft-label (KL) term; the cross-entropy term is
            weighted ``1 - alpha``. Defaults to 0.1, the previously hard-coded
            value (an earlier idea was ``temperature / 10``).

    Returns:
        ``(1 - alpha) * CE + alpha * T^2 * KL + MSE`` as a scalar tensor.
    """
    T = temperature
    student_loss = F.cross_entropy(input = logits, target = labels)
    soft_student = F.log_softmax(logits / T, dim = 1)
    soft_teacher = F.softmax(teacher_logits / T, dim = 1)
    # The T*T factor keeps the soft-target gradient scale comparable across temperatures.
    kl_div = F.kl_div(soft_student, soft_teacher, reduction = "batchmean") * (T * T)
    feature_mse = mse_loss(teacher_rprs, student_rprs)
    return (1 - alpha) * student_loss + alpha * kl_div + feature_mse

In [None]:
# AdamW updates the student's parameters. The scheduler halves the learning rate
# after 5 epochs without improvement; mode='max' because it is stepped with the
# validation F1 score in train().
optimizer = optim.AdamW(params = student.parameters(), lr = CFG["LEARNING_RATE"], weight_decay = 1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5,\
    threshold_mode='abs', min_lr=1e-8, verbose=True)

def train(teacher, student, optimizer, train_loader, val_loader, scheduler, device):
    """Distil ``teacher`` into ``student`` and restore the best-scoring student weights.

    Runs ``CFG['EPOCHS']`` epochs. Each epoch trains the student on low-resolution
    images against the teacher's outputs on the high-resolution images, validates,
    steps the scheduler on the validation F1 score, saves the student to
    ``xlarge_distill_best.pt`` whenever the score improves, and writes the score
    curve to ``xlarge_best.png``.

    Args:
        teacher: Model returning ``(representation, logits)``; kept in eval mode.
        student: Model returning ``(representation, logits)``; the one being trained.
        optimizer: Optimizer over the student's parameters.
        train_loader: Iterable of ``(low_res, high_res, labels)`` training batches.
        val_loader: Iterable of ``(low_res, high_res, labels)`` validation batches.
        scheduler: Optional scheduler stepped with the validation score, or ``None``.
        device: Device the models and batches are moved to.

    Returns:
        ``(best_model, val_score)``: ``best_model`` is ``student`` with the weights
        of its best epoch loaded back in (``None`` if no epoch scored above 0);
        ``val_score`` lists the per-epoch validation F1 scores rounded to 5 decimals.
    """
    teacher.to(device)
    student.to(device)
    teacher.eval()
    # criterion = nn.CrossEntropyLoss().to(device)
    temperature = 3
    best_score = 0
    best_state = None
    val_score = []
    iteration_cnt = 0
    mse_loss = nn.MSELoss()
    for epoch in range(1, CFG['EPOCHS']+1):
        student.train()
        train_loss = []
        iteration_cnt += 1
        # Passing the loader itself (not iter(loader)) lets tqdm show the batch total.
        for i_low, i_high, labels in tqdm(train_loader):
            i_low = i_low.float().to(device)
            i_high = i_high.float().to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            student_rprs, student_output = student(i_low)
            # The teacher only provides targets, so skip autograd bookkeeping for it.
            with torch.no_grad():
                teacher_rprs, teacher_output = teacher(i_high)
            loss = distillation_loss(student_output, labels, teacher_output, student_rprs, teacher_rprs, mse_loss, temperature)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        _val_loss, _val_score = validation(teacher, student, val_loader, device, temperature)

        # Every 10 epochs raise the temperature by one while it is still <= 5.
        # Before, the message was printed but the temperature itself never changed.
        if iteration_cnt == 10 and temperature <= 5:
            temperature += 1
            iteration_cnt = 0
            print("Temperature upscaled: {}".format(temperature))

        _train_loss = np.mean(train_loss)
        print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}] Val Loss : [{_val_loss:.5f}] Val F1 Score : [{_val_score:.5f}]')

        val_score.append(np.round(_val_score, 5))

        if scheduler is not None:
            # Step the scheduler based on the validation score.
            scheduler.step(_val_score)

        if best_score < _val_score:
            best_score = _val_score
            # Keep a CPU copy of the weights: a bare reference to `student` would
            # keep changing as training continues and end up holding last-epoch weights.
            best_state = {key: value.detach().cpu().clone() for key, value in student.state_dict().items()}
            torch.save(student, "xlarge_distill_best.pt")
            print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}], Best Val F1 Score : [{_val_score:.5f}]')

        fig = plt.figure(figsize=(20, 10))
        plt.plot(val_score)
        for i, value in enumerate(val_score):
            plt.text(i, value, str(value), fontsize=12, ha='center')
        plt.savefig('xlarge_best.png', dpi=300, format='png')
        # Close the figure once it is on disk so one figure per epoch does not pile up in memory.
        plt.close(fig)

    if best_state is None:
        return None, val_score
    # Put the best epoch's weights back so the returned model really is the best one.
    student.load_state_dict(best_state)
    return student, val_score

In [None]:
def validation(teacher, student, val_loader, device, temperature):
    """Evaluate the student on ``val_loader`` without gradient tracking.

    Args:
        teacher: Model returning ``(representation, logits)`` for high-resolution input.
        student: Model returning ``(representation, logits)`` for low-resolution input;
            switched to eval mode here.
        val_loader: Iterable of ``(low_res, high_res, labels)`` batches.
        device: Device the batches are moved to.
        temperature: Temperature forwarded to ``distillation_loss``.

    Returns:
        ``(mean distillation loss over batches, macro F1 of the student's predictions)``.
    """
    student.eval()
    feature_criterion = nn.MSELoss()
    batch_losses = []
    predicted, expected = [], []
    with torch.no_grad():
        for low_res, high_res, targets in tqdm(iter(val_loader)):
            low_res = low_res.float().to(device)
            high_res = high_res.float().to(device)
            targets = targets.to(device)

            student_features, student_logits = student(low_res)
            teacher_features, teacher_logits = teacher(high_res)
            batch_loss = distillation_loss(
                student_logits, targets, teacher_logits,
                student_features, teacher_features, feature_criterion, temperature,
            )

            # Predicted class = index of the largest student logit.
            predicted.extend(student_logits.argmax(1).detach().cpu().numpy().tolist())
            expected.extend(targets.detach().cpu().numpy().tolist())
            batch_losses.append(batch_loss.item())

    return np.mean(batch_losses), f1_score(expected, predicted, average='macro')

In [None]:
# Run the distillation and unpack train()'s (best_model, val_score) return value.
infer_model = train(teacher, student, optimizer, trainloader, validloader,\
    scheduler = scheduler, device = CFG["DEVICE"])
best_model = infer_model[0]
scores = infer_model[1]
# torch.save(best_model, "best_model_distillation_fix.pt")