In [None]:
import random
import pandas as pd
import numpy as np
import cv2
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import timm

def seed_everything(seed: int = 42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices), exports ``PYTHONHASHSEED`` and
    switches cuDNN to its deterministic configuration.

    Args:
        seed: Value used for every generator. Defaults to 42.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Only affects hashing in Python processes started after this point
    # (e.g. worker subprocesses); the running interpreter is already fixed.
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Covers every visible GPU, not just the current one; a no-op without CUDA.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN auto-tune and pick a possibly different
    # convolution algorithm on each run, which undermines the deterministic
    # flag above, so it must be disabled for reproducible results.
    torch.backends.cudnn.benchmark = False

class ConvNext(nn.Module):
    """Two-stage wrapper: a feature extractor followed by a head.

    The attribute names ``extraction`` and ``head`` are kept as-is because a
    pickled instance of this class is restored later in the notebook.

    Args:
        extraction: Module applied to the input first.
        head: Module applied to the output of ``extraction``.
    """

    def __init__(self, extraction, head):
        super().__init__()
        self.extraction, self.head = extraction, head

    def forward(self, x):
        """Return ``(extraction(x), head(extraction(x)))`` as a tuple."""
        features = self.extraction(x)
        return features, self.head(features)

    
# Fix every RNG seed before anything else runs.
seed_everything()

# Use the first GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load checkpoints from a trusted source. The object is used
# below as a callable module, so this is presumably a whole pickled model
# (not a state_dict), in which case its class must already be defined in
# this session -- TODO confirm.
# Tensors are mapped to the CPU here; the model is moved to `device` later.
model = torch.load("xlarge_distill_best_v2.pt", map_location = "cpu")

In [None]:
class TestSet(Dataset):
    """Unlabelled image dataset that reads image files with OpenCV.

    Args:
        img: Collection of image file paths, indexable by the integer
            positions ``0 .. len(img) - 1``.
        transform: Optional callable applied to the RGB image before it is
            returned.
    """

    def __init__(self, img, transform = None):
        self.img = img
        self.transform = transform

    def __len__(self):
        """Return the number of image paths."""
        return len(self.img)

    def __getitem__(self, idx):
        """Load the image at ``idx``, convert BGR -> RGB and transform it.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file at that path.
        """
        path = self.img[idx]
        image = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising; fail here with the offending path instead of
        # with an opaque error inside cvtColor.
        if image is None:
            raise FileNotFoundError(f"Could not read image at index {idx}: {path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.transform:
            image = self.transform(image)
        return image

class ImageSet(Dataset):
    """Labelled image dataset that reads image files with OpenCV.

    Args:
        img: Collection of image file paths, indexable by the integer
            positions ``0 .. len(label) - 1``.
        transform: Optional callable applied to the RGB image before it is
            returned.
        class_name: Optional mapping from raw label to class index. When it
            is ``None`` the raw label is returned unchanged.
        label: Collection of raw labels aligned with ``img``.
    """

    def __init__(self, img, transform = None, class_name = None, label = None):
        self.img = img
        self.label = label
        self.transform = transform
        self.class_name = class_name

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.label)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at ``idx``.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file at that path.
        """
        path = self.img[idx]
        label = self.label[idx]
        image = cv2.imread(path)
        # cv2.imread signals a missing/unreadable file by returning None
        # rather than raising; fail here with the offending path.
        if image is None:
            raise FileNotFoundError(f"Could not read image at index {idx}: {path!r}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Reuse the same variable so an image is returned even when no
        # transform was supplied.
        if self.transform:
            image = self.transform(image)
        # Use the mapping given to this instance, not a module-level global
        # that merely happens to share the name.
        if self.class_name is not None:
            label = self.class_name[label]
        return image, label
    

# Preprocessing shared by the validation and test images: convert to a float
# tensor, normalise each channel, then resize to 224x224.
# NOTE(review): the mean/std values look like the usual ImageNet channel
# statistics -- confirm they match what the model was trained with.
transform_ = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    transforms.Resize([224, 224])
])

data = pd.read_csv("train.csv")
# train_test_split returns (X_train, X_test, y_train, y_test); only the two
# frames are kept because they already contain the "label" column.
trainset, valset, _, _ = train_test_split(data, data["label"], test_size = 0.2, stratify = data["label"], random_state = 0)
# The dataset classes index their path/label Series with 0..len-1, so the
# shuffled index left by the split must be replaced by a fresh RangeIndex.
# drop=True discards the old index directly instead of materialising it as
# an "index" column and deleting it in place afterwards.
valset = valset.reset_index(drop = True)
# Map every distinct label (in np.unique's sorted order) to an integer id.
classes = np.unique(data["label"])
class_name = {name: i for i, name in enumerate(classes)}
validset = ImageSet(img = valset["img_path"], transform = transform_, class_name = class_name, label = valset["label"])
valloader = DataLoader(validset, batch_size = 1, shuffle = False)
test = pd.read_csv("test_.csv")
# Kept at module level: a later cell rebuilds the class list from it.
label = data["label"]
test_set = TestSet(img = test["img_path"], transform = transform_)
# shuffle=False keeps predictions in the same order as the rows of the CSV.
test_loader = DataLoader(test_set, batch_size = 1, shuffle = False)

In [None]:
def inference(model, test_loader, device):
    """Run the model over a loader and collect the predicted class indices.

    Args:
        model: Module whose forward pass returns a pair; the second element
            is scored with ``argmax`` over dimension 1.
        test_loader: Iterable yielding batches of image tensors.
        device: Device the model and every batch are moved to.

    Returns:
        A flat list of ``int`` predictions, in the order the loader yields
        the samples.
    """
    model.to(device)
    model.eval()
    preds = []
    with torch.no_grad():
        # Pass the loader itself, not iter(loader): a bare iterator has no
        # __len__, so tqdm could not show a total or an ETA.
        for imgs in tqdm(test_loader):
            imgs = imgs.float().to(device)
            _, logits = model(imgs)
            # Tensor.tolist() already copies to host memory and yields
            # Python ints; no detach/cpu/numpy round trip is needed.
            preds.extend(logits.argmax(1).tolist())

    return preds

In [None]:
# Rebuild the label -> index mapping from the training labels so predicted
# indices can be translated back to label names.
classes = np.unique(label)
class_name = dict(zip(classes, range(len(classes))))

# Predicted class index for every test image, in loader order.
preds = inference(model, test_loader, device = device)

# Dict keys keep insertion order, so position i holds the label of index i.
classes = list(class_name)
final = [classes[pred] for pred in preds]

# Fill the sample submission's "label" column and save it without the index.
submit = pd.read_csv("./sample_submission.csv")
submit["label"] = final
submit.to_csv("./distill_best.csv", index = False)

In [None]:
# wrong_list = []
# for i, name in enumerate(final):
#     if valset["label"][i] != final[i]:
#         wrong_list.append([valset["img_path"][i], valset["label"][i]])
# wrong = pd.DataFrame(wrong_list)
# wrong.columns = ["img_path", "label"]
# wrong.to_csv("wrong.csv")

In [None]:
# wrong