## **Import Modules and Packages**

In [28]:
import warnings

warnings.filterwarnings("ignore")

import os
import random
import gc
import easydict
import glob
import multiprocessing
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

# Transform을 위한 라이브러리
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model을 위한 라이브러리
import timm

# Fold를 위한 라이브러리
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold

# loss, optimizer, scheduler 를 위한 라이브러리
from pytorch_toolbelt import losses
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from madgrad import MADGRAD

# Weight & bias
import wandb

# 이미지 시각화를 위한 라이브러리
from PIL import Image
import webcolors
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

sns.set()

%matplotlib inline

plt.rcParams["axes.grid"] = False

print("Pytorch version: {}".format(torch.__version__))
print("GPU: {}".format(torch.cuda.is_available()))

print("Device name: ", torch.cuda.get_device_name(0))
print("Device count: ", torch.cuda.device_count())

# GPU 사용 가능 여부에 따라 device 정보 저장
device = "cuda" if torch.cuda.is_available() else "cpu"

Pytorch version: 1.10.0+cu102
GPU: True
Device name:  Tesla V100-PCIE-32GB
Device count:  1


## **Set Configs**

In [29]:
CFG = {}

CFG["seed"] = 21
CFG["data_root"] = "/opt/ml/Workspace/Dev_Matching_ArtPaint_Classification/data"
args = easydict.EasyDict(CFG)

In [31]:
train_path = os.path.join(CFG["data_root"], "train")
test_path = os.path.join(CFG["data_root"], "test")

In [30]:
label = {}

label["dog"] = 0
label["elephant"] = 1
label["giraffe"] = 2
label["guitar"] = 3
label["horse"] = 4
label["house"] = 5
label["person"] = 6
label["0"] = -1  # for test label

In [32]:
def data_frame(data_path, data_type):
    labels = []
    for path in os.walk(data_path):
        label_type = path[0].split("/")[-1]
        if label_type == "train" or label_type == "test":
            continue
        path_root = path[0]
        images = path[-1]

        idx = label[label_type]
        for image in images:
            img_path = os.path.join(path_root, image)
            label_idx = {"img_path": img_path, "label": idx}
            labels.append(label_idx)

    data_frame = pd.DataFrame(labels)
    data_frame = data_frame.sort_values(["label", "img_path"])
    data_frame = data_frame.reset_index(drop=True)
    return data_frame

## **Utils**

In [33]:
def seed_everything(seed):
    random_seed = seed
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)


seed_everything(CFG["seed"])

In [37]:
def collate_fn(batch):
    return tuple(zip(*batch))

## **Dataset & Dataloader**

In [35]:
class ArtDataset(Dataset):
    def __init__(self, df, mode="train", transform=None):
        super().__init__()
        self.mode = mode
        self.df = df.reset_index()
        self.image_id = self.df.img_path
        self.label = self.df.label
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        image_id = self.image_id[index]
        label = self.label[index]

        image = cv2.imread(image_id)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.mode in ("train", "valid"):
            if self.transform:
                image = self.transform(image=np.array(image))["image"]

            return torch.tensor(image, dtype=torch.float), torch.tensor(
                label, dtype=torch.long
            )

        elif self.mode in ("test"):
            if self.transform:
                image = self.transform(image=np.array(image))["image"]
                image /= 255.0

            return torch.tensor(image, dtype=torch.float), torch.tensor(
                0, dtype=torch.long
            )

## **Transform**

In [34]:
def get_augmentation(data_type):
    if data_type == "train":
        return A.Compose(
            [
                A.OneOf(
                    [
                        A.GridDistortion(p=1.0),
                        A.RandomGridShuffle(p=1.0),
                        A.ElasticTransform(p=1.0),
                        A.GridDropout(),
                    ],
                    p=1.0,
                ),
                A.Resize(224, 224),
                A.HorizontalFlip(p=0.5),
                A.Normalize(
                    mean=[0.5556861, 0.50740065, 0.45690217],
                    std=[0.22876642, 0.21754766, 0.22090458],
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ],
            p=1.0,
        )
    else:  # "Valid and Test"
        return A.Compose(
            [
                A.Resize(224, 224),
                A.Normalize(
                    mean=[0.5556861, 0.50740065, 0.45690217],
                    std=[0.22876642, 0.21754766, 0.22090458],
                    max_pixel_value=255,
                ),
                ToTensorV2(),
            ],
            p=1.0,
        )

## **Define Dataset**

In [38]:
def fold_df(data_frame, folds=5):
    skf = StratifiedKFold(n_splits=folds)

    X = data_frame.img_path.values
    y = data_frame.label.values

    split_df = []
    for _, (train_index, valid_index) in enumerate(skf.split(X, y)):
        train_df = data_frame.iloc[train_index].copy().reset_index(drop=True)
        valid_df = data_frame.iloc[valid_index].copy().reset_index(drop=True)

        split_df.append((train_df, valid_df))
    return split_df


split_df = fold_df(train_df)
split_df[4][0].groupby("label").count()

Unnamed: 0_level_0,img_path
label,Unnamed: 1_level_1
0,264
1,164
2,188
3,107
4,121
5,196
6,319


In [39]:
split_dfd = fold_df(train_df)
for i in range(len(split_dfd)):
    X, y = split_dfd[i]
    print(X)
    break

                                               img_path  label
0     /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      0
1     /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      0
2     /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      0
3     /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      0
4     /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      0
...                                                 ...    ...
1353  /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      6
1354  /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      6
1355  /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      6
1356  /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      6
1357  /opt/ml/Workspace/Dev_Matching_ArtPaint_Classi...      6

[1358 rows x 2 columns]


## **Model**

In [40]:
model = timm.create_model(
    model_name="swin_base_patch4_window7_224", pretrained=True, num_classes=7
)

In [41]:
x = torch.randn([1, 3, 224, 224])
out = model(x).to(device)
print(f"input : {x.shape} | output : {out.size()}")

input : torch.Size([1, 3, 224, 224]) | output : torch.Size([1, 7])


## **Loss & Optimizer**

In [42]:
# criterion = nn.CrossEntropyLoss()
criterion = losses.SoftCrossEntropyLoss()
optimizer = MADGRAD(params=model.parameters(), lr=1e-4, weight_decay=1e-6)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=25, T_mult=1)

## **Get Dataloader**

In [43]:
def get_dataloader(train_df, valid_df, test_df):
    train_dataset = ArtDataset(
        train_df, mode="train", transform=get_augmentation(data_type="train")
    )
    valid_dataset = ArtDataset(
        valid_df, mode="valid", transform=get_augmentation(data_type="valid")
    )
    test_dataset = ArtDataset(
        test_df, mode="test", transform=get_augmentation(data_type="test")
    )

    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=16,
        shuffle=False,
        num_workers=0,
        collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn
    )

    return train_loader, valid_loader, test_loader

## **Train/Valid One Epoch**

In [44]:
def train_one_epoch(epoch, model, data_loader, criterion, optimizer, scheduler, device):
    model.train()

    cnt = 0
    correct = 0
    scaler = GradScaler()

    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar:

        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)

        with autocast(enabled=True):
            model = model.to(device)

            output = model(image)
            loss = criterion(output, label)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        optimizer.zero_grad()

        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == label.data)
        cnt += 1

        description = f"# epoch : {epoch + 1} Loss : {(loss.item()):.4f}"
        pbar.set_description(description)

    acc = correct / cnt
    scheduler.step()

    return acc

In [45]:
def valid_one_epoch(model, data_loader, split_df, device):
    print(f"Start Validation")

    model.eval()
    correct = 0

    pbar_valid = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar_valid:

        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)
        model = model.to(device)

        output = model(image)

        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == label.data)

        description_valid = f" correct : {(correct.item()/len(split_df)):.4f}"
        pbar_valid.set_description(description_valid)
    acc = correct / len(split_df)
    print(f"Validation acc: {acc: .4f}")

    return acc, output

## **Pseudo_labeling**

In [46]:
def pseudo_labeling(model, train_dataset, test_loader):
    print(f"Start Pseudo")
    pseudo_dataset = copy.deepcopy(test_df)

    model.eval()
    pseudo = []
    for step, (image, label) in tqdm(enumerate(test_loader), total=len(test_loader)):
        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)

        output = model(image).to(device)

        _, preds = torch.max(output, 1)
        pseudo.extend(list(preds.cpu().numpy()))

    pseudo_dataset.label = pseudo
    pseudo_dataset = pd.concat([train_dataset, pseudo_dataset]).reset_index(drop=True)

    return pseudo_dataset

In [47]:
train_data = split_df[4][0]
valid_data = split_df[4][1]

In [58]:
iter_n = 3

## **Run !!**

In [59]:
def run(
    epoch, model, train_df, test_df, optimizer, criterion, scheduler, device, Folds=5
):
    torch.cuda.empty_cache()
    gc.collect()
    for _ in range(iter_n):
        for fold in range(Folds):

            print(f"{fold+1} fold start")
            split_df = fold_df(train_df, folds=5)
            train_loader, valid_loader, test_loader = get_dataloader(
                split_df[fold][0], split_df[fold][1], test_df
            )

            early_stopping_cnt = 0
            patience = 5
            best_acc = 0

            num_epochs = epoch
            for epoch in range(num_epochs):
                train_acc = train_one_epoch(
                    epoch, model, train_loader, criterion, optimizer, scheduler, device
                )
                with torch.no_grad():
                    valid_acc, outputs = valid_one_epoch(
                        model, valid_loader, split_df[fold][1], device
                    )

                if valid_acc > best_acc:
                    best_acc = valid_acc
                    early_stopping_cnt = 0
                    print(f"Best acc is {best_acc}")
                else:
                    early_stopping_cnt += 1
                    print(f"Early Stopping Counter is {early_stopping_cnt}")
                    if early_stopping_cnt >= patience:
                        print(
                            f"Early Stopping Counter: {early_stopping_cnt} out of {patience}"
                        )
                        break

        scheduler.step(best_acc)
        train_df = pseudo_labeling(model, train_df, test_loader)
        train_loader, _, _ = get_dataloader(train_df, train_df, test_df)

In [54]:
train_df.img_path[0]

'/opt/ml/Workspace/Dev_Matching_ArtPaint_Classification/data/train/dog/pic_001.jpg'

In [60]:
run(
    epoch=20,
    model=model,
    train_df=train_df,
    test_df=test_df,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    Folds=5,
)

1 fold start


# epoch : 1 Loss : 0.0034: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.93it/s]


Validation acc:  0.9941
Best acc is 0.9941176772117615


# epoch : 2 Loss : 0.0036: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  8.12it/s]


Validation acc:  0.9971
Best acc is 0.9970588684082031


# epoch : 3 Loss : 0.0048: 100%|██████████| 43/43 [00:36<00:00,  1.19it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.64it/s]


Validation acc:  0.9971
Early Stopping Counter is 1


# epoch : 4 Loss : 0.0464: 100%|██████████| 43/43 [00:38<00:00,  1.13it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.69it/s]


Validation acc:  0.9971
Early Stopping Counter is 2


# epoch : 5 Loss : 0.2137: 100%|██████████| 43/43 [00:37<00:00,  1.16it/s]


Start Validation


 correct : 0.9912: 100%|██████████| 22/22 [00:02<00:00,  7.50it/s]


Validation acc:  0.9912
Early Stopping Counter is 3


# epoch : 6 Loss : 0.0402: 100%|██████████| 43/43 [00:36<00:00,  1.18it/s]


Start Validation


 correct : 0.9735: 100%|██████████| 22/22 [00:02<00:00,  8.06it/s]


Validation acc:  0.9735
Early Stopping Counter is 4


# epoch : 7 Loss : 0.0236: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  8.20it/s]


Validation acc:  0.9941
Early Stopping Counter is 5
Early Stopping Counter: 5 out of 5
2 fold start


# epoch : 1 Loss : 0.2373: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  8.12it/s]


Validation acc:  0.9941
Best acc is 0.9941176772117615


# epoch : 2 Loss : 0.0042: 100%|██████████| 43/43 [00:34<00:00,  1.23it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.66it/s]


Validation acc:  0.9941
Early Stopping Counter is 1


# epoch : 3 Loss : 0.0654: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.77it/s]


Validation acc:  1.0000
Best acc is 1.0


# epoch : 4 Loss : 0.0012: 100%|██████████| 43/43 [00:36<00:00,  1.18it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.14it/s]


Validation acc:  1.0000
Early Stopping Counter is 1


# epoch : 5 Loss : 0.0013: 100%|██████████| 43/43 [00:35<00:00,  1.20it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.77it/s]


Validation acc:  1.0000
Early Stopping Counter is 2


# epoch : 6 Loss : 0.0005: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.09it/s]


Validation acc:  1.0000
Early Stopping Counter is 3
3 fold start


# epoch : 1 Loss : 0.0020: 100%|██████████| 43/43 [00:37<00:00,  1.16it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.92it/s]


Validation acc:  1.0000
Best acc is 1.0


# epoch : 2 Loss : 0.0002: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.81it/s]


Validation acc:  0.9971
Early Stopping Counter is 1


# epoch : 3 Loss : 0.0002: 100%|██████████| 43/43 [00:34<00:00,  1.26it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.99it/s]


Validation acc:  1.0000
Early Stopping Counter is 2


# epoch : 4 Loss : 0.0012: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.05it/s]


Validation acc:  1.0000
Early Stopping Counter is 3


# epoch : 5 Loss : 0.0002: 100%|██████████| 43/43 [00:36<00:00,  1.16it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.30it/s]


Validation acc:  1.0000
Early Stopping Counter is 4
4 fold start


# epoch : 1 Loss : 0.0002: 100%|██████████| 43/43 [00:36<00:00,  1.18it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.46it/s]


Validation acc:  1.0000
Best acc is 1.0


# epoch : 2 Loss : 0.0002: 100%|██████████| 43/43 [00:37<00:00,  1.16it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.99it/s]


Validation acc:  1.0000
Early Stopping Counter is 1


# epoch : 3 Loss : 0.0001: 100%|██████████| 43/43 [00:36<00:00,  1.19it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.02it/s]


Validation acc:  1.0000
Early Stopping Counter is 2


# epoch : 4 Loss : 0.0001: 100%|██████████| 43/43 [00:36<00:00,  1.18it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.23it/s]


Validation acc:  1.0000
Early Stopping Counter is 3
5 fold start


# epoch : 1 Loss : 0.0002: 100%|██████████| 43/43 [00:36<00:00,  1.17it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.93it/s]


Validation acc:  1.0000
Best acc is 1.0


# epoch : 2 Loss : 0.0003: 100%|██████████| 43/43 [00:36<00:00,  1.19it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.21it/s]


Validation acc:  1.0000
Early Stopping Counter is 1


# epoch : 3 Loss : 0.0018: 100%|██████████| 43/43 [00:36<00:00,  1.19it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.21it/s]


Validation acc:  1.0000
Early Stopping Counter is 2
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  8.25it/s]


1 fold start


# epoch : 1 Loss : 0.3672: 100%|██████████| 52/52 [00:43<00:00,  1.20it/s]


Start Validation


 correct : 0.9829: 100%|██████████| 26/26 [00:03<00:00,  7.98it/s]


Validation acc:  0.9829
Best acc is 0.9829268455505371


# epoch : 2 Loss : 0.2379: 100%|██████████| 52/52 [00:44<00:00,  1.16it/s]


Start Validation


 correct : 0.9805: 100%|██████████| 26/26 [00:03<00:00,  8.29it/s]


Validation acc:  0.9805
Early Stopping Counter is 1
2 fold start


# epoch : 1 Loss : 0.0884: 100%|██████████| 52/52 [00:45<00:00,  1.15it/s]


Start Validation


 correct : 0.9073: 100%|██████████| 26/26 [00:03<00:00,  7.93it/s]


Validation acc:  0.9073
Best acc is 0.9073171019554138
3 fold start
4 fold start
5 fold start
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  8.43it/s]


1 fold start
2 fold start
3 fold start
4 fold start
5 fold start
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  8.57it/s]
