In [None]:
import random
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import timm

In [None]:
# Training annotations; the 'label' column holds each image's class.
data = pd.read_csv("train.csv")
classes = np.unique(data["label"])  # sorted unique labels
# Label -> integer id (despite the name, keys are labels and values are indices).
class_name = dict(zip(classes, range(len(classes))))

In [None]:
class Augmentset(Dataset):
    """Dataset yielding two independently transformed views of the same image (SimCLR pairs).

    Args:
        img: DataFrame whose first column is read as the image file path
            (assumed to be the path column of train.csv — TODO confirm).
        transform: optional callable applied to the first view.
        transform_: optional callable applied to the second view.
    """

    def __init__(self, img, transform = None, transform_ = None):
        self.img = img
        self.transform = transform
        self.transform_ = transform_

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(view1, view2)`` for row ``idx`` (positional, 0..len-1).

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        # Positional lookup: samplers yield 0..len-1 regardless of the frame's index
        # labels, and Series[int] positional fallback is deprecated/removed in pandas.
        path = self.img.iloc[idx, 0]
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f"Could not read image at {path!r} (row {idx})")
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        # Both views come from the same file: decode once, give each view its own array.
        image1 = image
        image2 = image.copy()
        if self.transform:
            image1 = self.transform(image1)
        if self.transform_:
            image2 = self.transform_(image2)
        return image1, image2

# Augmentation passed as `transform_`: random rescaled 224x224 crop, random flips, blur.
transform_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.RandomResizedCrop(224),
    # Disabled colour augmentation. NOTE: if re-enabled it must go before Normalize —
    # ColorJitter assumes [0, 1] inputs and clamps, which would destroy normalized values.
    # transforms.RandomApply(
    #     [# brightness, contrast, saturation, hue
    #     transforms.ColorJitter(0.5, 0.5, 0.5, 0.2)
    #     ], p = 0.5
    # ),
    # p = 1: the group always runs; each flip still fires with its own default p = 0.5.
    transforms.RandomApply(
        [
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.GaussianBlur(kernel_size = 3, sigma = (1.0, 2.0))
        ], p = 1
    ),
    ]
)


# Augmentation passed as `transform`: deterministic resize to 224x224 plus blur.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    # An (h, w) tuple forces an exact 224x224 output. An int only rescales the shorter
    # side, so non-square images would give tensors of varying size that the default
    # collate cannot stack and that would not match the model's 1024 * 7 * 7 head.
    transforms.Resize((224, 224)),
    transforms.GaussianBlur(kernel_size = 3, sigma = (1.0, 2.0))
]
)

# Wrap the annotation frame in the paired-view dataset. `data` is rebound from the
# DataFrame to the Dataset here, so only capture `images` while `data` is still a
# DataFrame; otherwise re-running this cell would wrap the Dataset a second time.
if isinstance(data, pd.DataFrame):
    images = data
data = Augmentset(img = images, transform = transform, transform_ = transform_)

In [None]:
def visualization(flag: bool = False, idx: int = 24):
    """Plot both augmented views of one sample from the module-level ``data`` dataset.

    Args:
        flag: when False (the default) the function does nothing.
        idx: dataset index of the sample to show; 24 reproduces the original choice.
    """
    if not flag:
        return
    image1, image2 = data[idx]
    print(image1.size())
    _, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, view, title in zip(axes, (image1, image2), ("Image 1", "Image 2")):
        # CHW tensor -> HWC array, the layout imshow expects.
        arr = view.numpy().transpose((1, 2, 0))
        lo, hi = arr.min(), arr.max()
        # Views are normalized tensors, not [0, 1] images, so min-max rescale for display;
        # a constant image would otherwise divide by zero and render as NaNs.
        arr = (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)
        ax.imshow(arr)
        ax.set_title(title)
    plt.show()
visualization(flag=True)

In [None]:
# Loss Function
def nt_xent(x1, x2, t):
    """NT-Xent (normalized temperature-scaled cross-entropy) loss used by SimCLR.

    The two views are stacked into 2N embeddings. For each anchor the positive is
    the other view of the same sample, and the remaining 2N - 2 embeddings are
    negatives. Only the anchor's similarity with itself is excluded from the
    softmax denominator; the positive stays in it.

    Args:
        x1: (N, D) embeddings of the first view.
        x2: (N, D) embeddings of the second view; row i pairs with x1[i].
        t: softmax temperature (> 0).

    Returns:
        Scalar tensor: the loss averaged over all 2N anchors.

    Raises:
        ValueError: if x1 and x2 have different shapes.
    """
    if x1.shape != x2.shape:
        raise ValueError(
            f"x1 and x2 must have the same shape, got {tuple(x1.shape)} and {tuple(x2.shape)}"
        )
    n = x1.size(0)
    z = F.normalize(torch.cat([x1, x2], dim = 0), dim = 1)
    # (2N, 2N) cosine similarities scaled by the temperature.
    logits = z @ z.T / t
    # An embedding compared with itself is neither a positive nor a negative.
    # Build the mask on the inputs' device instead of relying on a global `device`.
    self_mask = torch.eye(2 * n, dtype = torch.bool, device = z.device)
    logits = logits.masked_fill(self_mask, float("-inf"))
    # Row i < N is x1[i], whose positive x2[i] sits at column i + N, and vice versa.
    idx = torch.arange(n, device = z.device)
    targets = torch.cat([idx + n, idx])
    # cross_entropy applies a log-softmax, which is stable where exp()/sum() would overflow.
    return F.cross_entropy(logits, targets)

In [None]:
# drop_last=True: every step sees a full batch of 32 pairs, so the contrastive loss always
# has the same number of in-batch negatives. A trailing batch of size 1 would have no
# negatives at all, which degenerates the contrastive loss.
trainloader = DataLoader(data, batch_size = 32, shuffle = True, drop_last = True)

In [None]:
# ConvNeXt-Base pretrained on ImageNet-22k, created with num_classes=0. Every child
# module except the last is kept, which presumably strips timm's pooling/classifier head
# so the projection head receives the raw feature map (simCLR's fc1 expects 1024 * 7 * 7
# values per sample for 224x224 input — verify with a dummy forward pass).
convnext = torch.nn.Sequential(
    *list(
        timm.create_model("convnext_base.fb_in22k", pretrained = True, num_classes = 0).children()
    )[:-1]
)


In [None]:
# random_tensor = torch.ones([256, 3, 224, 224])
# convnext(random_tensor).size()

In [None]:
class simCLR(nn.Module):
    """SimCLR network: a feature-map backbone followed by a two-layer MLP projection head.

    Args:
        embedding_size: dimension of the embedding returned by ``forward``.
        model: backbone module; its output is flattened per sample and must contain
            exactly ``feature_dim`` values per sample.
        feature_dim: flattened backbone output size. The default 1024 * 7 * 7 matches a
            1024-channel 7x7 feature map (what the ConvNeXt trunk in this notebook is
            assumed to produce for 224x224 input).
        hidden_dim: width of the hidden projection layer.
    """

    def __init__(self, embedding_size, model, feature_dim = 1024 * 7 * 7, hidden_dim = 2048):
        super(simCLR, self).__init__()
        self.backbone = model
        self.fc1 = nn.Linear(feature_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, embedding_size)
        torch.nn.init.kaiming_normal_(self.fc1.weight)
        torch.nn.init.kaiming_normal_(self.fc2.weight)

    def forward(self, x):
        """Map a batch of images to (batch, embedding_size) projections."""
        x = self.backbone(x)
        # Flatten per sample, keeping the batch dimension. reshape(-1, K) could silently fold
        # spatial positions into the batch when the feature map has an unexpected size;
        # here a mismatch surfaces as a shape error in fc1.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [None]:
# Attach a 128-d projection head to the ConvNeXt trunk.
# To resume from a saved checkpoint instead: model = torch.load("resolution_ssl_11_.pt")
model = simCLR(model = convnext, embedding_size = 128)

In [None]:
# Display the module tree to check how the backbone and projection head are wired.
model

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Move parameters before building the optimizer, as the torch.optim docs advise.
model.to(device)
optimizer = optim.AdamW(params = model.parameters(), lr = 1e-4)
# scheduler.step() is called once per epoch and the training loop below runs 10 epochs.
# T_max must match that count; with T_max = 400 the final LR would still be ~99.8% of
# the initial one, i.e. the cosine decay would effectively never happen.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = 10)

In [None]:
model.to(device)
best_loss = float("inf")
temperature = 0.5  # NT-Xent temperature, constant across epochs
for epoch in range(1, 10 + 1):
    # Switch modules to training mode (affects layers such as dropout).
    model.train()
    train_loss = []
    for img1, img2 in tqdm(trainloader, desc = f"epoch {epoch}"):
        img1 = img1.to(device)
        img2 = img2.to(device)
        optimizer.zero_grad()
        output1 = model(img1)
        output2 = model(img2)
        loss = nt_xent(output1, output2, temperature)
        # A non-finite loss would push NaNs into every weight on optimizer.step() (and into
        # later checkpoints); stop loudly instead of silently corrupting the model.
        if not torch.isfinite(loss):
            raise FloatingPointError(f"Non-finite loss {loss.item()} at epoch {epoch}")
        loss.backward()
        optimizer.step()
        train_loss.append(loss.item())
    _train_loss = np.mean(train_loss)
    scheduler.step()
    if best_loss > _train_loss:
        best_loss = _train_loss
        # NOTE: an alias of `model`, not a snapshot — it keeps training afterwards.
        # The file written below is the actual record of this epoch's weights.
        best_model = model
        torch.save(best_model, "res_ssl_{}_.pt".format(epoch))
        print("Best Loss: {}".format(best_loss))