In [None]:
import random
import pandas as pd
import numpy as np
import cv2
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
# import torchvision.transforms as transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision.models as models
import timm
import os


def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED``, and switches cuDNN to deterministic mode.

    Args:
        seed: Value used for every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Also cover every additional GPU, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick convolution algorithms by timing them,
    # which can differ from run to run; it must be off for reproducible runs.
    torch.backends.cudnn.benchmark = False

# Seed all generators (default seed 42) before anything random happens.
seed_everything()


# Build the "convnext_xlarge.fb_in22k" model through timm with pretrained
# weights requested and a 25-class output head.
model = timm.create_model("convnext_xlarge.fb_in22k", pretrained = True, num_classes = 25)
# Ask PyTorch to release cached CUDA memory that is currently unused.
torch.cuda.empty_cache()

In [None]:
# random_tensor = torch.rand([64, 3, 224, 224])
# x = model(random_tensor)
# x.size()
# # x = x.view(x.size(0), -1)
# # x.size()

In [None]:
# Training hyper-parameters and the target device, gathered in one place.
# The first CUDA device is used when available, the CPU otherwise.
CFG = dict(
    LEARNING_RATE=4e-5,
    EPOCHS=20,
    BATCH_SIZE=32,
    DEVICE=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
)

In [None]:
class ImageSet(Dataset):
    """Dataset that reads an image from disk and pairs it with its class id.

    Args:
        img: Indexable collection of image file paths.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result holds the transformed image under ``"image"``.
        class_name: Optional mapping from raw label to integer class id.
            When omitted, the raw label is returned unchanged.
        label: Indexable collection of raw labels aligned with ``img``.
    """

    def __init__(self, img, transform = None, class_name = None, label = None):
        self.img = img
        self.label = label
        self.transform = transform
        self.class_name = class_name

    def __len__(self):
        # The length follows the image paths, so it is defined even when no
        # labels were supplied.
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        path = self.img[idx]
        label = self.label[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure with None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads BGR; the rest of the pipeline works on RGB.
        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image = image)["image"]
        # Use the mapping handed to this instance, not a module-level global.
        if self.class_name is not None:
            label = self.class_name[label]
        return image, label

In [None]:
class AugmentSet(Dataset):
    """Dataset that reads an image, applies its transform, and returns the class id.

    Args:
        img: Indexable collection of image file paths.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result holds the transformed image under ``"image"``.
        class_name: Optional mapping from raw label to integer class id.
            When omitted, the raw label is returned unchanged.
        label: Indexable collection of raw labels aligned with ``img``.
    """

    def __init__(self, img, transform = None, class_name = None, label = None):
        self.img = img
        self.label = label
        self.transform = transform
        self.class_name = class_name

    def __len__(self):
        return len(self.img)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for position ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        path = self.img[idx]
        label = self.label[idx]
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure with None instead of raising.
            raise FileNotFoundError(f"Could not read image: {path}")
        # OpenCV loads BGR; the rest of the pipeline works on RGB.
        image = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image = image)["image"]
        # Use the mapping handed to this instance, not a module-level global.
        if self.class_name is not None:
            label = self.class_name[label]
        return image, label

In [None]:
# Channel statistics shared by both pipelines so that plain and augmented
# samples reach the model on the same scale.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Deterministic pipeline: resize, normalise, convert to a CHW tensor.
transform = A.Compose([
    A.Resize(height=224, width=224),
    A.Normalize(mean = _NORM_MEAN, std = _NORM_STD),
    ToTensorV2()
])
# Augmentation pipeline: resize, then one random geometric, blur and
# brightness/contrast operation (each group applied with its own probability).
transform_augment = A.Compose([
    A.Resize(height = 224, width = 224),
    A.OneOf([A.HorizontalFlip(), A.VerticalFlip(), A.RandomRotate90()], p = 0.8),
    A.OneOf([A.GaussianBlur(blur_limit = (1, 5)), A.MedianBlur(blur_limit = (1, 5))], p = 0.2),
    A.OneOf([A.RandomBrightnessContrast(brightness_limit = 0.1, contrast_limit = 0.1),
             A.CLAHE()], p = 0.5),
    # A.OneOf([A.ElasticTransform(), A.GridDistortion()], p = 0.1),
    # Normalise exactly like `transform`; without this step augmented images
    # were fed to the model as raw 0-255 values. It stays after the colour
    # operations because they run on the un-normalised image.
    A.Normalize(mean = _NORM_MEAN, std = _NORM_STD),
    ToTensorV2()
])
    
    
    

# Load the training index and carve out a stratified 10% validation split.
data = pd.read_csv("train_.csv")
trainset, valset = train_test_split(
    data,
    test_size = 0.1,
    stratify = data["label"],
    random_state = 42,
)
# Give each split a fresh 0..n-1 index and discard the stray "Unnamed: 0" column.
trainset = trainset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])
valset = valset.reset_index(drop = True).drop(columns = ["Unnamed: 0"])
# A separate stratified 50% sub-split of trainset for augmentation is disabled;
# the augmented view is built from the very same rows as the training split.
augmentset = trainset

In [None]:
# Sampling For Augmentation
# Labels of the training split.
label = trainset["label"]
# nums[0] holds the distinct labels, nums[1] the number of rows of each.
nums = np.unique(label, return_counts = True)
# Row count of the most frequent class.
max_num = np.max(nums[1])
def balancing(df, nums, source = None):
    """Draw extra rows so every class reaches the size of the largest one.

    For each label whose count is below the largest count in ``nums``, the
    missing number of rows is sampled from that label's rows in ``source``.

    Args:
        df: Optional DataFrame the sampled rows are appended to; ``None``
            starts from nothing.
        nums: ``(labels, counts)`` pair, as produced by
            ``np.unique(..., return_counts=True)``.
        source: DataFrame with a ``"label"`` column to sample from. Defaults
            to the module-level ``trainset``.

    Returns:
        ``df`` followed by the sampled rows. When ``df`` is ``None`` and no
        class needs extra rows, an empty frame with ``source``'s columns.
    """
    if source is None:
        source = trainset
    labels, counts = nums
    # Derive the target from `nums` itself instead of a module-level global.
    target = int(np.max(counts)) if len(counts) else 0

    pieces = [] if df is None else [df]
    for label_name, count in zip(labels, counts):
        missing = target - int(count)
        if missing <= 0:
            continue
        pool = source[source["label"] == label_name]
        # Sampling without replacement cannot return more rows than the class
        # holds, so fall back to replacement only when it is unavoidable.
        pieces.append(pool.sample(missing, replace = missing > len(pool)))

    if not pieces:
        # Nothing to add: hand back an empty frame rather than None so callers
        # can keep treating the result as a DataFrame.
        return source.iloc[0:0]
    if len(pieces) == 1:
        return pieces[0]
    # One concat at the end instead of growing the frame inside the loop.
    return pd.concat(pieces, axis = 0)
# Oversampled rows for the minority classes, re-indexed from 0 without
# keeping the old index as a column.
balance_aug = balancing(None, nums).reset_index(drop = True)
# augmentset = pd.concat([augmentset, balance_aug], axis = 0)
# augmentset.reset_index(inplace = True)
# augmentset.drop(["index"], axis = 1, inplace = True)
# augmentset = balance_aug

In [None]:
# Distinct label values of the full table; a label's position in this array
# is used as its integer class id.
classes = np.unique(data["label"])
class_name = dict(zip(classes, range(len(classes))))

In [None]:
# Wrap the split DataFrames in dataset objects: plain pipeline for the training
# and validation rows, augmentation pipeline for the augmented view.
# NOTE(review): `trainset` and `augmentset` are rebound here from DataFrames to
# dataset objects, so this cell (and any cell indexing `trainset["label"]`)
# appears not to be re-runnable without re-running the loading cell first.
trainset = ImageSet(img = trainset["upscale_img_path"], transform = transform, class_name = class_name, label = trainset["label"])
validset = ImageSet(img = valset["upscale_img_path"], transform = transform, class_name = class_name, label = valset["label"])
augmentset = AugmentSet(img = augmentset["upscale_img_path"], transform = transform_augment, class_name = class_name, label = augmentset["label"])

In [None]:
# Number of samples in the training, augmented and validation datasets.
len(trainset), len(augmentset), len(validset)

In [None]:
def visualization(flag: bool = False, indices = (103, 2000)):
    """Show two samples of the module-level ``augmentset`` side by side.

    Args:
        flag: Nothing is drawn unless this is truthy.
        indices: Positions of the two samples to display. The default keeps
            the previously hard-coded positions 103 and 2000; pass smaller
            values for datasets that do not have that many samples.
    """
    if not flag:
        return

    def _to_display(tensor):
        # Channel-first tensor -> channel-last float array rescaled to [0, 1].
        arr = tensor.numpy().transpose((1, 2, 0)).astype(np.float32)
        lo, hi = arr.min(), arr.max()
        if hi == lo:
            # Constant image: avoid a 0/0 division and show it as uniform black.
            return np.zeros_like(arr)
        return (arr - lo) / (hi - lo)

    first, second = (augmentset[i][0] for i in indices)
    print(first.size())

    plt.figure(figsize=(10, 5))
    panels = ((first, 'Image 1'), (second, 'Image 2'))
    for position, (tensor, title) in enumerate(panels, start = 1):
        plt.subplot(1, 2, position)
        plt.imshow(_to_display(tensor))
        plt.title(title)

    plt.show()
# Preview two samples of the augmented dataset.
visualization(True)

In [None]:
# trainset = trainset + augmentset
# Training batches are drawn, shuffled, from `trainset + augmentset`
# (presumably the concatenation of both datasets, so each training row appears
# once plain and once augmented — confirm against torch's Dataset `+`).
trainloader = DataLoader(trainset + augmentset, batch_size = CFG["BATCH_SIZE"], shuffle = True, num_workers = 0)
# Validation batches keep their order; num_workers = 0 loads in the main process.
validloader = DataLoader(validset, batch_size = CFG["BATCH_SIZE"], shuffle = False, num_workers = 0)

In [None]:
# AdamW over every model parameter at the configured learning rate.
optimizer = optim.AdamW(params = model.parameters(), lr = CFG["LEARNING_RATE"])
# Plateau scheduler in "max" mode: it is stepped with the validation score
# and multiplies the learning rate by 0.5, never going below 1e-8.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
    threshold_mode='abs',
    min_lr=1e-8,
    verbose=True,
)

def train(model, optimizer, train_loader, val_loader, scheduler, device,
          epochs = None, save_path = "best_xlarge_model_v2.pt", plot_path = 'xlarge_backbone.png'):
    """Train ``model`` and return it carrying the weights of its best epoch.

    After every epoch the model is scored with ``validation``; whenever the
    validation F1 improves, the weights are snapshotted and the whole model is
    saved to ``save_path``. A curve of all validation scores so far is written
    to ``plot_path`` after each epoch.

    Args:
        model: Network to train; moved to ``device``.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Iterable of ``(images, labels)`` validation batches.
        scheduler: Optional scheduler stepped with the validation score, or ``None``.
        device: Device the batches and the model are moved to.
        epochs: Number of epochs; defaults to ``CFG['EPOCHS']``.
        save_path: File the best model is saved to.
        plot_path: File the validation-score curve is saved to.

    Returns:
        ``model`` with the best-scoring weights loaded, or ``None`` when no
        epoch produced a comparable score (for example ``epochs == 0``).
    """
    if epochs is None:
        epochs = CFG['EPOCHS']

    model.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    val_score = []
    # Start below any reachable score so the first finished epoch is always kept.
    best_score = float("-inf")
    best_state = None

    for epoch in range(1, epochs + 1):
        model.train()
        train_loss = []
        for imgs, labels in tqdm(iter(train_loader)):
            imgs = imgs.float().to(device)
            labels = labels.to(device)

            optimizer.zero_grad()

            output = model(imgs)
            loss = criterion(output, labels)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        _val_loss, _val_score = validation(model, criterion, val_loader, device)
        _train_loss = np.mean(train_loss)
        print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}] Val Loss : [{_val_loss:.5f}] Val F1 Score : [{_val_score:.5f}]')
        val_score.append(np.round(_val_score, 5))

        if scheduler is not None:
            # Adjust the scheduler based on the validation score.
            scheduler.step(_val_score)

        if best_score < _val_score:
            best_score = _val_score
            # Keep an independent CPU copy of the weights: holding a reference
            # to `model` would silently track later epochs as training goes on.
            best_state = {
                name: tensor.detach().cpu().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(model, save_path)
            print(f'Epoch [{epoch}], Train Loss : [{_train_loss:.5f}], Best Val F1 Score : [{_val_score:.5f}]')

        fig = plt.figure(figsize=(20, 10))
        plt.plot(val_score)
        for i, value in enumerate(val_score):
            plt.text(i, value, str(value), fontsize=12, ha='center')
        plt.savefig(plot_path, dpi=300, format='png')
        # Close the figure so one open figure per epoch does not pile up.
        plt.close(fig)

    if best_state is None:
        return None
    # Restore the best epoch's weights before handing the model back.
    model.load_state_dict(best_state)
    return model

In [None]:
def validation(model, criterion, val_loader, device):
    """Score ``model`` on ``val_loader`` with gradient tracking switched off.

    Args:
        model: Network to evaluate; put into eval mode.
        criterion: Loss called as ``criterion(predictions, labels)``.
        val_loader: Iterable of ``(images, labels)`` batches.
        device: Device every batch is moved to.

    Returns:
        Tuple ``(mean of the per-batch losses, macro-averaged F1 score)``.
    """
    model.eval()
    batch_losses = []
    predicted, expected = [], []

    with torch.no_grad():
        for imgs, labels in tqdm(iter(val_loader)):
            imgs = imgs.float().to(device)
            labels = labels.to(device)

            logits = model(imgs)
            batch_losses.append(criterion(logits, labels).item())

            # The arg-max over the class dimension is the predicted class id.
            predicted.extend(logits.argmax(1).detach().cpu().numpy().tolist())
            expected.extend(labels.detach().cpu().numpy().tolist())

    mean_loss = np.mean(batch_losses)
    macro_f1 = f1_score(expected, predicted, average='macro')
    return mean_loss, macro_f1

In [None]:
# Run the training loop on the configured device and keep the model it returns.
infer_model = train(model, optimizer, trainloader, validloader,\
    scheduler, device = CFG["DEVICE"])
# torch.save(infer_model, "best_model_CONVNEXT_30epochs_.pt")

In [None]:
# # model = torch.load("best_model_CONVNEXT_30epochs_super.pt")
# # model.to(CFG["DEVICE"])
# class TestSet(Dataset):
#     def __init__(self, img, transform = None):
#         self.img = img
#         self.transform = transform
        
#     def __len__(self):
#         return len(self.img)
    
#     def __getitem__(self, idx):
#         image = self.img[idx]
#         image = cv2.imread(image)
#         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#         if self.transform:
#             image = self.transform(image)
#         return image
# transform_ = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
#     transforms.Resize([224, 224], interpolation = InterpolationMode.BICUBIC)                    
# ])
# test = pd.read_csv("test_.csv")
# test_set = TestSet(img = test["img_path"], transform = transform_)
# test_loader = DataLoader(test_set, batch_size = 1, shuffle = False)

In [None]:
# def inference(model, test_loader, device):
#     model.to(device)
#     model.eval()
#     preds = []
#     with torch.no_grad():
#         for imgs in tqdm(iter(test_loader)):
#             imgs = imgs.float().to(CFG["DEVICE"])
#             pred = model(imgs)
#             preds += pred.argmax(1).detach().cpu().numpy().tolist()
    
#     return preds


In [None]:
# preds = inference(model, test_loader, device = CFG["DEVICE"])
# classes = list(class_name.keys())
# final = []
# for pred in preds:
#     final.append(classes[pred])
# submit = pd.read_csv("./sample_submission.csv")
# submit["label"] = final
# submit.to_csv("./submit_30epochs_xlarge_.csv", index = False)