## **Import Modules and Packages**

In [1]:
import warnings
warnings.filterwarnings("ignore")

import os
import random
import time
import json
import zipfile
import gc
import easydict
import glob
import multiprocessing
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

# Transform을 위한 라이브러리
from torchvision import transforms, models
from torchvision.transforms import Normalize

import timm
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model을 위한 라이브러리
import segmentation_models_pytorch as smp
from segmentation_models_pytorch.encoders import get_preprocessing_fn

# Fold를 위한 라이브러리
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold

# loss, optimizer, scheduler 를 위한 라이브러리
from pytorch_toolbelt import losses
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from madgrad import MADGRAD

# Weight & bias
import wandb

# 이미지 시각화를 위한 라이브러리
from PIL import Image
import webcolors
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns
sns.set()

%matplotlib inline

plt.rcParams["axes.grid"] = False

print("Pytorch version: {}".format(torch.__version__))
print("GPU: {}".format(torch.cuda.is_available()))

print("Device name: ", torch.cuda.get_device_name(0))
print("Device count: ", torch.cuda.device_count())

# GPU 사용 가능 여부에 따라 device 정보 저장
device = "cuda" if torch.cuda.is_available() else "cpu"

Pytorch version: 1.10.0+cu102
GPU: True
Device name:  Tesla V100-PCIE-32GB
Device count:  1


## **Set Configs**

In [2]:
CFG = {}

CFG["seed"] = 21
CFG["data_root"] = '/opt/ml/Workspace/Art_classification/data'
args = easydict.EasyDict(CFG)

In [3]:
label = {}

label["dog"]=0
label["elephant"]=1
label["giraffe"]=2
label["guitar"]=3
label["horse"]=4
label["house"]=5
label["person"]=6
label["0"]=-1 # for test label

In [4]:
train_path = os.path.join(CFG["data_root"],"train")
test_path = os.path.join(CFG["data_root"],"test")

In [5]:
def data_frame(data_path, data_type):
    labels = []
    for path in os.walk(data_path):
        label_type = path[0].split('/')[-1]
        if label_type == "train" or label_type == 'test':
            continue
        path_root = path[0]
        images = path[-1]

        idx = label[label_type]
        for image in images:
            img_path = os.path.join(path_root, image)
            label_idx = {"img_path": img_path, "label":idx}
            labels.append(label_idx)
            
    data_frame = pd.DataFrame(labels)
    data_frame = data_frame.sort_values(["label", "img_path"])
    data_frame = data_frame.reset_index(drop=True)
    return data_frame

## **Utils**

In [6]:
# Fix Seed
def seed_everything(seed):
    random_seed = seed
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

seed_everything(CFG["seed"])

## **Dataset & Dataloader**

In [7]:
def get_augmentation(data_type):
    if data_type == "train":
        return A.Compose(
            [
                A.OneOf(
                    [
                        A.GridDistortion(p=1.0),
                        A.RandomGridShuffle(p=1.0),
                        A.HorizontalFlip(p=1.0),
                        A.GridDropout(),
                        A.ElasticTransform(p=1.0)
                    ],
                    p=1.0,
                ),
                A.Resize(224,224),
                A.Normalize(mean=[0.5556861, 0.50740065, 0.45690217],std=[0.22876642, 0.21754766, 0.22090458], max_pixel_value=255),
                ToTensorV2()
            ]
        )
    else:
        return A.Compose(
            [
                A.Resize(224,224),
                A.Normalize(mean=[0.5556861, 0.50740065, 0.45690217],std=[0.22876642, 0.21754766, 0.22090458], max_pixel_value=255),
                ToTensorV2()
            ], p=1.0)

In [8]:
class ArtDataset(Dataset):
    def __init__(self, df, mode="train", transform=None):
        super().__init__()
        self.mode = mode
        self.df = df.reset_index()
        self.image_id = self.df.img_path
        self.label = self.df.label
        self.transform = transform
        
    def __len__(self):
        return len(self.df)    
    
    def __getitem__(self, index):
        image_id = self.image_id[index]
        label = self.label[index]
        
        image = cv2.imread(image_id)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        if self.mode in ("train", "valid"):
            
            if self.transform:
                image = self.transform(image=np.array(image))["image"]

            return torch.tensor(image, dtype=torch.float), torch.tensor(label, dtype=torch.long)

        elif self.mode in ("test"):
            if self.transform:
                image = self.transform(image=np.array(image))["image"]
                image/=255.0

            return torch.tensor(image, dtype=torch.float), torch.tensor(0, dtype=torch.long)


In [9]:
# train_df = pd.read_csv("/opt/ml/Workspace/Art_classification/code/train_df.csv")
# test_df = pd.read_csv("/opt/ml/Workspace/Art_classification/code/test_df.csv")

train_df = data_frame(data_path=train_path, data_type="train")
test_df = data_frame(data_path=test_path, data_type="test")

## **Transform**

In [10]:
def collate_fn(batch):
    return tuple(zip(*batch))

## **Define Dataset**

In [11]:
def fold_df(data_frame, folds=5):
    skf = StratifiedKFold(n_splits=folds)
    
    X = data_frame.img_path.values
    y = data_frame.label.values
    
    split_df = []
    for _, (train_index, valid_index) in enumerate(skf.split(X,y)):
        train_df = data_frame.iloc[train_index].copy().reset_index(drop=True)
        valid_df = data_frame.iloc[valid_index].copy().reset_index(drop=True)
        
        split_df.append((train_df, valid_df))
    return split_df

split_df = fold_df(train_df)
split_df[4][0].groupby('label').count()

Unnamed: 0_level_0,img_path
label,Unnamed: 1_level_1
0,264
1,164
2,188
3,107
4,121
5,196
6,319


In [12]:
split_dfd = fold_df(train_df)
for i in range(len(split_dfd)):
    X,y = split_dfd[i]
    print(X)
    break


                                               img_path  label
0     /opt/ml/Workspace/Art_classification/data/trai...      0
1     /opt/ml/Workspace/Art_classification/data/trai...      0
2     /opt/ml/Workspace/Art_classification/data/trai...      0
3     /opt/ml/Workspace/Art_classification/data/trai...      0
4     /opt/ml/Workspace/Art_classification/data/trai...      0
...                                                 ...    ...
1353  /opt/ml/Workspace/Art_classification/data/trai...      6
1354  /opt/ml/Workspace/Art_classification/data/trai...      6
1355  /opt/ml/Workspace/Art_classification/data/trai...      6
1356  /opt/ml/Workspace/Art_classification/data/trai...      6
1357  /opt/ml/Workspace/Art_classification/data/trai...      6

[1358 rows x 2 columns]


## **Model**

In [13]:
model = timm.create_model(model_name='swin_base_patch4_window7_224', pretrained=True, num_classes = 7)

In [14]:
x = torch.randn([1, 3, 224, 224])
out = model(x).to(device)
print(f"input : {x.shape} | output : {out.size()}")

input : torch.Size([1, 3, 224, 224]) | output : torch.Size([1, 7])


## **Loss & Optimizer**

In [15]:
criterion = nn.CrossEntropyLoss()
optimizer = MADGRAD(params=model.parameters(), lr=1e-4, weight_decay=1e-6)
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=25,T_mult=1)

## **Get Dataloader**

In [16]:
def get_dataloader(train_df, valid_df, test_df):
    train_dataset = ArtDataset(train_df, mode="train", transform=get_augmentation(data_type="train"))
    valid_dataset = ArtDataset(valid_df, mode="valid", transform=get_augmentation(data_type="valid"))
    test_dataset = ArtDataset(test_df, mode="test", transform=get_augmentation(data_type="test"))
    
    train_loader = DataLoader(train_dataset, batch_size = 32, shuffle=True, num_workers=0, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size = 16, shuffle=False, num_workers=0, collate_fn=collate_fn)
    test_loader = DataLoader(test_dataset, batch_size = 16, shuffle=False, num_workers=0, collate_fn=collate_fn)
    
    return train_loader, valid_loader, test_loader

## **Train/Valid One Epoch**

In [19]:
def train_one_epoch(epoch, model, data_loader, criterion, optimizer, scheduler, device):
    model.train()
    
    cnt = 0
    correct = 0
    scaler = GradScaler()

    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar:
        
        image = torch.stack(image).float()
        label = torch.stack(label).long()
        
        image = image.to(device)
        label = label.to(device)
        
        with autocast(enabled=True):
            model = model.to(device)
                
            output = model(image)
            loss = criterion(output, label)
        
        scaler.scale(loss).backward()  
        scaler.step(optimizer)
        scaler.update()
        
        optimizer.zero_grad()
        
        _, preds = torch.max(output, 1)
        correct += torch.sum(preds==label.data)
        cnt += 1
        
        description = f"# epoch : {epoch + 1} Loss : {(loss.item()):.4f}"
        pbar.set_description(description)
        
    acc = correct / cnt
    scheduler.step()


    return acc

In [20]:
def valid_one_epoch(model, data_loader, split_df, device):
    print(f"Start Validation")
    
    model.eval()
    correct = 0
    
    pbar_valid = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar_valid:

        image = torch.stack(image).float()
        label = torch.stack(label).long()        
        
        image = image.to(device)
        label = label.to(device)
        model = model.to(device)
        
        output = model(image)
        
        _, preds = torch.max(output, 1)
        correct += torch.sum(preds==label.data)

        description_valid = f" correct : {(correct.item()/len(split_df)):.4f}"
        pbar_valid.set_description(description_valid)
    acc = correct / len(split_df)
    print(f"Validation acc: {acc: .4f}")

    return acc, output

## **Pseudo_labeling**

In [21]:
def pseudo_labeling(model, train_dataset, test_loader):
    print(f"Start Pseudo")
    pseudo_dataset = copy.deepcopy(test_df)
    
    model.eval()
    pseudo = []
    for step, (image, label) in tqdm(enumerate(test_loader), total=len(test_loader)):
        image = torch.stack(image).float()
        label = torch.stack(label).long()    
        
        image = image.to(device)
        label = label.to(device)
        
        output = model(image).to(device)
        
        _, preds = torch.max(output, 1)
        pseudo.extend(list(preds.cpu().numpy()))
        
    pseudo_dataset.label = pseudo
    pseudo_dataset = pd.concat([train_dataset, pseudo_dataset]).reset_index(drop=True)
    
    return pseudo_dataset

In [22]:
train_data = split_df[4][0]
valid_data = split_df[4][1]

## **Run !!**

In [23]:
def run(epoch, model, train_df, test_df, optimizer, criterion, scheduler, device, Folds=5):
    torch.cuda.empty_cache()
    gc.collect()

    for fold in range(Folds):
        
        print(f"{fold+1} fold start")
        split_df = fold_df(train_df, folds=5)
        train_loader, valid_loader, test_loader = get_dataloader(split_df[fold][0] ,split_df[fold][1], test_df)
        
        early_stopping_cnt = 0
        patience = 5
        best_acc = 0
        
        num_epochs = epoch
        for epoch in range(num_epochs):
            train_acc = train_one_epoch(epoch, model, train_loader, criterion, optimizer, scheduler, device)
            with torch.no_grad():
                valid_acc, outputs = valid_one_epoch(model, valid_loader, split_df[fold][1], device)
            
            if valid_acc > best_acc:
                best_acc = valid_acc
                early_stopping_cnt = 0
                print(f"Best acc is {best_acc}")
                print(f"Early Stopping Counter is {early_stopping_cnt}")
            else:
                early_stopping_cnt += 1
                if early_stopping_cnt >= patience:
                    print(f"Early Stopping Counter: {early_stopping_cnt} out of {patience}")
                    break
                
            scheduler.step(best_acc)
            # train_df = pseudo_labeling(model, train_df, test_loader)
            # train_loader, _, _ = get_dataloader(train_df, test_df)

In [24]:
train_df.img_path[0]

'/opt/ml/Workspace/Art_classification/data/train/dog/pic_001.jpg'

In [25]:
run(epoch=20, model=model, train_df=train_df, test_df=test_df, optimizer=optimizer, criterion=criterion, scheduler=scheduler, device=device, Folds=5)

1 fold start


# epoch : 1 Loss : 0.1812: 100%|██████████| 43/43 [00:35<00:00,  1.20it/s]


Start Validation


 correct : 0.9529: 100%|██████████| 22/22 [00:03<00:00,  7.27it/s]


Validation acc:  0.9529
Best acc is 0.9529411792755127
Early Stopping Counter is 0


# epoch : 2 Loss : 0.2607: 100%|██████████| 43/43 [00:39<00:00,  1.08it/s]


Start Validation


 correct : 0.9412: 100%|██████████| 22/22 [00:03<00:00,  7.22it/s]


Validation acc:  0.9412


# epoch : 3 Loss : 0.4038: 100%|██████████| 43/43 [00:38<00:00,  1.13it/s]


Start Validation


 correct : 0.9588: 100%|██████████| 22/22 [00:02<00:00,  7.76it/s]


Validation acc:  0.9588
Best acc is 0.958823561668396
Early Stopping Counter is 0


# epoch : 4 Loss : 0.1355: 100%|██████████| 43/43 [00:37<00:00,  1.16it/s]


Start Validation


 correct : 0.9500: 100%|██████████| 22/22 [00:02<00:00,  7.55it/s]


Validation acc:  0.9500


# epoch : 5 Loss : 0.3073: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9441: 100%|██████████| 22/22 [00:02<00:00,  7.70it/s]


Validation acc:  0.9441


# epoch : 6 Loss : 0.0319: 100%|██████████| 43/43 [00:37<00:00,  1.13it/s]


Start Validation


 correct : 0.9794: 100%|██████████| 22/22 [00:02<00:00,  7.35it/s]


Validation acc:  0.9794
Best acc is 0.979411780834198
Early Stopping Counter is 0


# epoch : 7 Loss : 0.1744: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9235: 100%|██████████| 22/22 [00:02<00:00,  7.46it/s]


Validation acc:  0.9235


# epoch : 8 Loss : 0.0124: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9676: 100%|██████████| 22/22 [00:02<00:00,  7.61it/s]


Validation acc:  0.9676


# epoch : 9 Loss : 0.0539: 100%|██████████| 43/43 [00:35<00:00,  1.21it/s]


Start Validation


 correct : 0.9647: 100%|██████████| 22/22 [00:02<00:00,  8.10it/s]


Validation acc:  0.9647


# epoch : 10 Loss : 0.0239: 100%|██████████| 43/43 [00:37<00:00,  1.16it/s]


Start Validation


 correct : 0.9529: 100%|██████████| 22/22 [00:02<00:00,  8.12it/s]


Validation acc:  0.9529


# epoch : 11 Loss : 0.1377: 100%|██████████| 43/43 [00:37<00:00,  1.13it/s]


Start Validation


 correct : 0.9618: 100%|██████████| 22/22 [00:02<00:00,  8.24it/s]


Validation acc:  0.9618
Early Stopping Counter: 5 out of 5
2 fold start


# epoch : 1 Loss : 0.0322: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9824: 100%|██████████| 22/22 [00:02<00:00,  7.73it/s]


Validation acc:  0.9824
Best acc is 0.9823529720306396
Early Stopping Counter is 0


# epoch : 2 Loss : 0.0058: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


 correct : 0.9735: 100%|██████████| 22/22 [00:02<00:00,  7.86it/s]


Validation acc:  0.9735


# epoch : 3 Loss : 0.0215: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9882: 100%|██████████| 22/22 [00:02<00:00,  7.83it/s]


Validation acc:  0.9882
Best acc is 0.9882352948188782
Early Stopping Counter is 0


# epoch : 4 Loss : 0.0398: 100%|██████████| 43/43 [00:37<00:00,  1.16it/s]


Start Validation


 correct : 0.9706: 100%|██████████| 22/22 [00:02<00:00,  7.66it/s]


Validation acc:  0.9706


# epoch : 5 Loss : 0.0162: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:03<00:00,  7.14it/s]


Validation acc:  0.9941
Best acc is 0.9941176772117615
Early Stopping Counter is 0


# epoch : 6 Loss : 0.0105: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:03<00:00,  7.12it/s]


Validation acc:  0.9941


# epoch : 7 Loss : 0.0589: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9912: 100%|██████████| 22/22 [00:02<00:00,  7.56it/s]


Validation acc:  0.9912


# epoch : 8 Loss : 0.3664: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.44it/s]


Validation acc:  0.9941


# epoch : 9 Loss : 0.0028: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9824: 100%|██████████| 22/22 [00:02<00:00,  7.82it/s]


Validation acc:  0.9824


# epoch : 10 Loss : 0.0691: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:03<00:00,  6.98it/s]


Validation acc:  0.9941
Early Stopping Counter: 5 out of 5
3 fold start


# epoch : 1 Loss : 0.0034: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.39it/s]


Validation acc:  1.0000
Best acc is 1.0
Early Stopping Counter is 0


# epoch : 2 Loss : 0.0506: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:03<00:00,  7.06it/s]


Validation acc:  0.9971


# epoch : 3 Loss : 0.3703: 100%|██████████| 43/43 [00:38<00:00,  1.10it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.49it/s]


Validation acc:  0.9941


# epoch : 4 Loss : 0.0033: 100%|██████████| 43/43 [00:38<00:00,  1.10it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  8.14it/s]


Validation acc:  0.9941


# epoch : 5 Loss : 0.0097: 100%|██████████| 43/43 [00:36<00:00,  1.16it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:03<00:00,  7.06it/s]


Validation acc:  0.9971


# epoch : 6 Loss : 0.0012: 100%|██████████| 43/43 [00:39<00:00,  1.09it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:03<00:00,  7.23it/s]


Validation acc:  0.9941
Early Stopping Counter: 5 out of 5
4 fold start


# epoch : 1 Loss : 0.0134: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.68it/s]


Validation acc:  0.9941
Best acc is 0.9941002726554871
Early Stopping Counter is 0


# epoch : 2 Loss : 0.0164: 100%|██████████| 43/43 [00:38<00:00,  1.13it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.75it/s]


Validation acc:  0.9971
Best acc is 0.9970501661300659
Early Stopping Counter is 0


# epoch : 3 Loss : 0.0251: 100%|██████████| 43/43 [00:39<00:00,  1.09it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.86it/s]


Validation acc:  0.9971


# epoch : 4 Loss : 0.1361: 100%|██████████| 43/43 [00:36<00:00,  1.17it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.60it/s]


Validation acc:  1.0000
Best acc is 1.0
Early Stopping Counter is 0


# epoch : 5 Loss : 0.0105: 100%|██████████| 43/43 [00:37<00:00,  1.13it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.99it/s]


Validation acc:  0.9941
5 fold start


# epoch : 1 Loss : 0.3527: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


 correct : 0.9971: 100%|██████████| 22/22 [00:03<00:00,  7.28it/s]


Validation acc:  0.9971
Best acc is 0.9970501661300659
Early Stopping Counter is 0


# epoch : 2 Loss : 0.0321: 100%|██████████| 43/43 [00:35<00:00,  1.22it/s]


Start Validation


 correct : 0.9912: 100%|██████████| 22/22 [00:02<00:00,  8.10it/s]


Validation acc:  0.9912


# epoch : 3 Loss : 0.0051: 100%|██████████| 43/43 [00:37<00:00,  1.13it/s]


Start Validation


 correct : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  8.11it/s]


Validation acc:  1.0000
Best acc is 1.0
Early Stopping Counter is 0


# epoch : 4 Loss : 0.0144: 100%|██████████| 43/43 [00:39<00:00,  1.10it/s]


Start Validation


 correct : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.57it/s]

Validation acc:  0.9941



