## **Import Modules and Packages**

In [1]:
import warnings

warnings.filterwarnings("ignore")

import os
import random
import gc
import easydict
import glob
import multiprocessing
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

# Transform을 위한 라이브러리
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model을 위한 라이브러리
import timm

# Fold를 위한 라이브러리
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold

# loss, optimizer, scheduler 를 위한 라이브러리
from pytorch_toolbelt import losses
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from madgrad import MADGRAD

# Weight & bias
import wandb

# 이미지 시각화를 위한 라이브러리
from PIL import Image
import webcolors
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

sns.set()

%matplotlib inline

plt.rcParams["axes.grid"] = False

print("Pytorch version: {}".format(torch.__version__))
print("GPU: {}".format(torch.cuda.is_available()))

print("Device name: ", torch.cuda.get_device_name(0))
print("Device count: ", torch.cuda.device_count())

# GPU 사용 가능 여부에 따라 device 정보 저장
device = "cuda" if torch.cuda.is_available() else "cpu"

Pytorch version: 1.10.0+cu102
GPU: True
Device name:  Tesla V100-PCIE-32GB
Device count:  1


## **Set Configs**

In [2]:
CFG = {}

CFG['experiment_number'] = "Dec11"

CFG["seed"] = 21
CFG["data_root"] = "/opt/ml/Workspace/Dev_Matching_ArtPaint_Classification"
CFG["saved_dir"] = os.path.join(CFG["data_root"],"saved")
CFG["lr"] = 1e-4
CFG["weight_decay"] = 1e-6

CFG["mean"] = [0.5556861, 0.50740065, 0.45690217]
CFG["std"] = [0.22876642, 0.21754766, 0.22090458]

# CFG["timm_model"] = "swin_base_patch4_window7_224"
CFG["timm_model"] = "vit_base_patch16_224"

CFG["n_epoch"] = 20
CFG["n_Folds"] = 5
CFG["n_iter"] = 3
CFG["patience"] = 3
args = easydict.EasyDict(CFG)

In [3]:
train_path = os.path.join(CFG["data_root"], "data/train")
test_path = os.path.join(CFG["data_root"], "data/test")

In [4]:
label = {}

label["dog"] = 0
label["elephant"] = 1
label["giraffe"] = 2
label["guitar"] = 3
label["horse"] = 4
label["house"] = 5
label["person"] = 6
label["0"] = -1  # for test label

## **Utils**

In [5]:
def seed_everything(seed):
    random_seed = seed
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)


seed_everything(CFG["seed"])

In [6]:
def collate_fn(batch):
    return tuple(zip(*batch))

## **Dataset**

In [7]:
def data_frame(data_path):
    labels = []
    for path in os.walk(data_path):
        label_type = path[0].split("/")[-1]
        if label_type == "train" or label_type == "test":
            continue
        path_root = path[0]
        images = path[-1]

        idx = label[label_type]
        for image in images:
            img_path = os.path.join(path_root, image)
            label_idx = {"img_path": img_path, "label": idx}
            labels.append(label_idx)

    data_frame = pd.DataFrame(labels)
    data_frame = data_frame.sort_values(["label", "img_path"])
    data_frame = data_frame.reset_index(drop=True)
    return data_frame

In [8]:
train_df = data_frame(data_path=train_path)
test_df = data_frame(data_path=test_path)

In [9]:
class ArtDataset(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df.reset_index()
        self.image_id = self.df.img_path
        self.label = self.df.label
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.image_id[idx]
        label = self.label[idx]

        image = cv2.imread(image_id)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=np.array(image))["image"]

        image = torch.tensor(image, dtype=torch.float)
        label = torch.tensor(label, dtype=torch.long)

        return image, label

## **Transform**

In [10]:
def get_augmentation(data_type):
    if data_type == "train":
        return A.Compose(
            [
                A.OneOf(
                    [
                        A.GridDistortion(p=1.0),
                        A.RandomGridShuffle(p=1.0),
                        A.ElasticTransform(p=1.0),
                        A.GridDropout(),
                    ],
                    p=1.0,
                ),
                A.Resize(224, 224),
                A.HorizontalFlip(p=0.5),
                A.Normalize(CFG["mean"], CFG["std"]),
                ToTensorV2(),
            ],
            p=1.0,
        )
    else:  # "Valid and Test"
        return A.Compose(
            [
                A.Resize(224, 224),
                A.Normalize(CFG["mean"], CFG["std"]),
                ToTensorV2(),
            ],
            p=1.0,
        )

## **Define Dataset**

In [11]:
def fold_df(data_frame, folds=CFG["n_Folds"]):
    skf = StratifiedKFold(n_splits=folds)

    X = data_frame.img_path.values
    y = data_frame.label.values

    split_df = []
    for _, (train_index, valid_index) in enumerate(skf.split(X, y)):
        train_df = data_frame.iloc[train_index].copy().reset_index(drop=True)
        valid_df = data_frame.iloc[valid_index].copy().reset_index(drop=True)

        split_df.append((train_df, valid_df))
    return split_df

## **Model**

In [12]:
model = timm.create_model(
    model_name=CFG["timm_model"], pretrained=True, num_classes=7
)

In [13]:
x = torch.randn([1, 3, 224, 224])
out = model(x).to(device)
print(f"input : {x.shape} | output : {out.size()}")

input : torch.Size([1, 3, 224, 224]) | output : torch.Size([1, 7])


## **Loss & Optimizer**

In [14]:
# criterion = nn.CrossEntropyLoss()
criterion = losses.SoftCrossEntropyLoss()
optimizer = MADGRAD(params=model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG["n_epoch"], T_mult=1)

## **Get Dataloader**

In [15]:
def get_dataloader(train_df, valid_df, test_df):
    train_dataset = ArtDataset(
        train_df, transform=get_augmentation(data_type="train")
    )
    valid_dataset = ArtDataset(
        valid_df, transform=get_augmentation(data_type="valid")
    )
    test_dataset = ArtDataset(
        test_df, transform=get_augmentation(data_type="test")
    )

    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn
    )

    return train_loader, valid_loader, test_loader

## **Train/Valid One Epoch**

In [16]:
def train_one_epoch(epoch, model, data_loader, criterion, optimizer, scheduler, device):
    model.train()

    cnt = 0
    correct = 0
    scaler = GradScaler()

    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar:

        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)

        with autocast(enabled=True):
            model = model.to(device)

            output = model(image)
            loss = criterion(output, label)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        optimizer.zero_grad()

        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == label.data)
        cnt += 1

        description = f"| # Epoch : {epoch + 1} Loss : {(loss.item()):.4f}"
        pbar.set_description(description)

    acc = correct / cnt
    scheduler.step()

In [17]:
def valid_one_epoch(model, data_loader, split_df, device):
    print(f"Start Validation")

    model.eval()
    correct = 0

    pbar_valid = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar_valid:

        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)
        model = model.to(device)

        output = model(image)

        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == label.data)

        description_valid = f"| Acc : {(correct.item()/len(split_df)):.4f}"
        pbar_valid.set_description(description_valid)
    acc = correct / len(split_df)

    return acc, output

## **Pseudo_labeling**

In [18]:
def pseudo_labeling(model, train_dataset, test_loader):
    print(f"Start Pseudo")
    pseudo_dataset = copy.deepcopy(test_df)

    model.eval()
    pseudo = []
    for step, (image, label) in tqdm(enumerate(test_loader), total=len(test_loader)):
        image = torch.stack(image).float()
        label = torch.stack(label)

        image = image.to(device)
        label = label.to(device)

        output = model(image).to(device)

        _, preds = torch.max(output, 1)
        pseudo.extend(list(preds.cpu().numpy()))

    pseudo_dataset.label = pseudo
    pseudo_dataset = pd.concat([train_dataset, pseudo_dataset]).reset_index(drop=True)

    return pseudo_dataset

## **Make Save Directory**

In [19]:
def save_model(model, file_name):
    saved_dir = CFG["saved_dir"]

    if not os.path.isdir(saved_dir):
        os.mkdir(saved_dir)
        
    output_path = os.path.join(saved_dir, file_name)
    torch.save(model, output_path)
    print(f"model saved {file_name}")

## **Run !!**

In [20]:
def run(
    epochs, model, train_df, test_df, optimizer, criterion, scheduler, device, Folds=5
):
    torch.cuda.empty_cache()
    gc.collect()
    for _ in range(CFG["n_iter"]):
        for fold in range(CFG["n_Folds"]):

            print(f"{fold+1} fold start")
            split_df = fold_df(train_df, folds=CFG["n_Folds"])
            train_split = split_df[fold][0]
            valid_split = split_df[fold][1]
            train_loader, valid_loader, test_loader = get_dataloader(
                train_split, valid_split, test_df
            )

            early_stopping_cnt = 0
            patience = CFG["patience"]
            best_acc = 0.5

            for epoch in range(epochs):
                train_one_epoch(
                    epoch, model, train_loader, criterion, optimizer, scheduler, device
                )
                with torch.no_grad():
                    valid_acc, outputs = valid_one_epoch(
                        model, valid_loader, valid_split, device
                    )

                if valid_acc > best_acc:
                    best_acc = valid_acc
                    early_stopping_cnt = 0
                    save_file_name = f"{best_acc:.4f}_{CFG['timm_model']}_fold{fold+1}.pt"
                    save_model(model, save_file_name)
                else:
                    early_stopping_cnt += 1
                    print(f"Early Stopping Counter is {early_stopping_cnt}")
                    if early_stopping_cnt >= patience:
                        print(
                            f"Early Stopping Counter: {early_stopping_cnt} out of {patience}"
                        )
                        break

            scheduler.step(best_acc)
        train_df = pseudo_labeling(model, train_df, test_loader)
        train_loader, _, _ = get_dataloader(train_df, train_df, test_df)

In [21]:
run(
    epochs=CFG["n_epoch"],
    model=model,
    train_df=train_df,
    test_df=test_df,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    Folds=CFG["n_Folds"],
)

1 fold start


| # Epoch : 1 Loss : 2.0842: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.6206: 100%|██████████| 22/22 [00:02<00:00,  8.04it/s]


model saved 0.6206_vit_base_patch16_224_fold1.pt


| # Epoch : 2 Loss : 1.0928: 100%|██████████| 43/43 [00:40<00:00,  1.05it/s]


Start Validation


| Acc : 0.6412: 100%|██████████| 22/22 [00:03<00:00,  7.30it/s]


model saved 0.6412_vit_base_patch16_224_fold1.pt


| # Epoch : 3 Loss : 0.3662: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.7206: 100%|██████████| 22/22 [00:02<00:00,  8.69it/s]


model saved 0.7206_vit_base_patch16_224_fold1.pt


| # Epoch : 4 Loss : 1.2928: 100%|██████████| 43/43 [00:37<00:00,  1.15it/s]


Start Validation


| Acc : 0.6794: 100%|██████████| 22/22 [00:02<00:00,  7.94it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 1.0297: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


| Acc : 0.8088: 100%|██████████| 22/22 [00:02<00:00,  8.66it/s]


model saved 0.8088_vit_base_patch16_224_fold1.pt


| # Epoch : 6 Loss : 0.8944: 100%|██████████| 43/43 [00:39<00:00,  1.08it/s]


Start Validation


| Acc : 0.8147: 100%|██████████| 22/22 [00:02<00:00,  8.70it/s]


model saved 0.8147_vit_base_patch16_224_fold1.pt


| # Epoch : 7 Loss : 0.3157: 100%|██████████| 43/43 [00:40<00:00,  1.06it/s]


Start Validation


| Acc : 0.7735: 100%|██████████| 22/22 [00:02<00:00,  8.87it/s]


Early Stopping Counter is 1


| # Epoch : 8 Loss : 0.0872: 100%|██████████| 43/43 [00:36<00:00,  1.18it/s]


Start Validation


| Acc : 0.8059: 100%|██████████| 22/22 [00:02<00:00,  8.34it/s]


Early Stopping Counter is 2


| # Epoch : 9 Loss : 0.4268: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.7912: 100%|██████████| 22/22 [00:02<00:00,  8.42it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
2 fold start


| # Epoch : 1 Loss : 1.3865: 100%|██████████| 43/43 [00:37<00:00,  1.13it/s]


Start Validation


| Acc : 0.8235: 100%|██████████| 22/22 [00:02<00:00,  8.73it/s]


model saved 0.8235_vit_base_patch16_224_fold2.pt


| # Epoch : 2 Loss : 0.6694: 100%|██████████| 43/43 [00:42<00:00,  1.01it/s]


Start Validation


| Acc : 0.9000: 100%|██████████| 22/22 [00:02<00:00,  8.43it/s]


model saved 0.9000_vit_base_patch16_224_fold2.pt


| # Epoch : 3 Loss : 0.5351: 100%|██████████| 43/43 [00:41<00:00,  1.04it/s]


Start Validation


| Acc : 0.8824: 100%|██████████| 22/22 [00:02<00:00,  8.13it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.6417: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.8794: 100%|██████████| 22/22 [00:02<00:00,  8.65it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.3873: 100%|██████████| 43/43 [00:39<00:00,  1.10it/s]


Start Validation


| Acc : 0.8912: 100%|██████████| 22/22 [00:02<00:00,  8.89it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
3 fold start


| # Epoch : 1 Loss : 0.7972: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.9029: 100%|██████████| 22/22 [00:02<00:00,  8.22it/s]


model saved 0.9029_vit_base_patch16_224_fold3.pt


| # Epoch : 2 Loss : 0.7739: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.8794: 100%|██████████| 22/22 [00:02<00:00,  8.99it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0668: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


| Acc : 0.9147: 100%|██████████| 22/22 [00:02<00:00,  8.43it/s]


model saved 0.9147_vit_base_patch16_224_fold3.pt


| # Epoch : 4 Loss : 1.1207: 100%|██████████| 43/43 [00:39<00:00,  1.08it/s]


Start Validation


| Acc : 0.7824: 100%|██████████| 22/22 [00:02<00:00,  8.58it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.3170: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.9088: 100%|██████████| 22/22 [00:02<00:00,  8.02it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.1459: 100%|██████████| 43/43 [00:38<00:00,  1.11it/s]


Start Validation


| Acc : 0.9059: 100%|██████████| 22/22 [00:02<00:00,  8.43it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
4 fold start


| # Epoch : 1 Loss : 0.6899: 100%|██████████| 43/43 [00:38<00:00,  1.12it/s]


Start Validation


| Acc : 0.9174: 100%|██████████| 22/22 [00:02<00:00,  8.22it/s]


model saved 0.9174_vit_base_patch16_224_fold4.pt


| # Epoch : 2 Loss : 0.1594: 100%|██████████| 43/43 [00:38<00:00,  1.13it/s]


Start Validation


| Acc : 0.9381: 100%|██████████| 22/22 [00:02<00:00,  9.13it/s]


model saved 0.9381_vit_base_patch16_224_fold4.pt


| # Epoch : 3 Loss : 0.1732: 100%|██████████| 43/43 [00:38<00:00,  1.13it/s]


Start Validation


| Acc : 0.9469: 100%|██████████| 22/22 [00:02<00:00,  8.13it/s]


model saved 0.9469_vit_base_patch16_224_fold4.pt


| # Epoch : 4 Loss : 0.2651: 100%|██████████| 43/43 [00:39<00:00,  1.09it/s]


Start Validation


| Acc : 0.9056: 100%|██████████| 22/22 [00:02<00:00,  8.68it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.1589: 100%|██████████| 43/43 [00:38<00:00,  1.13it/s]


Start Validation


| Acc : 0.8997: 100%|██████████| 22/22 [00:02<00:00,  8.06it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.2725: 100%|██████████| 43/43 [00:39<00:00,  1.09it/s]


Start Validation


| Acc : 0.9351: 100%|██████████| 22/22 [00:02<00:00,  8.01it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
5 fold start


| # Epoch : 1 Loss : 0.3298: 100%|██████████| 43/43 [00:37<00:00,  1.14it/s]


Start Validation


| Acc : 0.9528: 100%|██████████| 22/22 [00:02<00:00,  8.99it/s]


model saved 0.9528_vit_base_patch16_224_fold5.pt


| # Epoch : 2 Loss : 0.0900: 100%|██████████| 43/43 [00:39<00:00,  1.09it/s]


Start Validation


| Acc : 0.9381: 100%|██████████| 22/22 [00:02<00:00,  8.43it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.2146: 100%|██████████| 43/43 [00:40<00:00,  1.05it/s]


Start Validation


| Acc : 0.8938: 100%|██████████| 22/22 [00:02<00:00,  7.88it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.5816: 100%|██████████| 43/43 [00:39<00:00,  1.10it/s]


Start Validation


| Acc : 0.9499: 100%|██████████| 22/22 [00:02<00:00,  8.43it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  8.47it/s]


1 fold start


| # Epoch : 1 Loss : 0.5561: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s]


Start Validation


| Acc : 0.9610: 100%|██████████| 26/26 [00:03<00:00,  7.79it/s]


model saved 0.9610_vit_base_patch16_224_fold1.pt


| # Epoch : 2 Loss : 0.2354: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s]


Start Validation


| Acc : 0.9634: 100%|██████████| 26/26 [00:03<00:00,  8.36it/s]


model saved 0.9634_vit_base_patch16_224_fold1.pt


| # Epoch : 3 Loss : 0.1527: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s]


Start Validation


| Acc : 0.9317: 100%|██████████| 26/26 [00:03<00:00,  8.31it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.1045: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s]


Start Validation


| Acc : 0.9610: 100%|██████████| 26/26 [00:03<00:00,  8.66it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.1107: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s]


Start Validation


| Acc : 0.9366: 100%|██████████| 26/26 [00:03<00:00,  8.52it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
2 fold start


| # Epoch : 1 Loss : 1.1391: 100%|██████████| 52/52 [00:45<00:00,  1.15it/s]


Start Validation


| Acc : 0.9585: 100%|██████████| 26/26 [00:03<00:00,  8.00it/s]


model saved 0.9585_vit_base_patch16_224_fold2.pt


| # Epoch : 2 Loss : 0.4629: 100%|██████████| 52/52 [00:45<00:00,  1.14it/s]


Start Validation


| Acc : 0.9415: 100%|██████████| 26/26 [00:03<00:00,  8.18it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0728: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s]


Start Validation


| Acc : 0.9805: 100%|██████████| 26/26 [00:03<00:00,  8.14it/s]


model saved 0.9805_vit_base_patch16_224_fold2.pt


| # Epoch : 4 Loss : 0.9254: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s]


Start Validation


| Acc : 0.9488: 100%|██████████| 26/26 [00:03<00:00,  8.33it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.1401: 100%|██████████| 52/52 [00:45<00:00,  1.15it/s]


Start Validation


| Acc : 0.9366: 100%|██████████| 26/26 [00:03<00:00,  8.20it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.7825: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s]


Start Validation


| Acc : 0.9537: 100%|██████████| 26/26 [00:03<00:00,  8.57it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
3 fold start


| # Epoch : 1 Loss : 0.0479: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s]


Start Validation


| Acc : 0.9341: 100%|██████████| 26/26 [00:03<00:00,  8.42it/s]


model saved 0.9341_vit_base_patch16_224_fold3.pt


| # Epoch : 2 Loss : 0.6193: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s]


Start Validation


| Acc : 0.9049: 100%|██████████| 26/26 [00:03<00:00,  8.10it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0951: 100%|██████████| 52/52 [00:48<00:00,  1.07it/s]


Start Validation


| Acc : 0.9561: 100%|██████████| 26/26 [00:03<00:00,  7.88it/s]


model saved 0.9561_vit_base_patch16_224_fold3.pt


| # Epoch : 4 Loss : 0.0545: 100%|██████████| 52/52 [00:48<00:00,  1.08it/s]


Start Validation


| Acc : 0.9659: 100%|██████████| 26/26 [00:02<00:00,  8.89it/s]


model saved 0.9659_vit_base_patch16_224_fold3.pt


| # Epoch : 5 Loss : 0.0021: 100%|██████████| 52/52 [00:46<00:00,  1.11it/s]


Start Validation


| Acc : 0.9390: 100%|██████████| 26/26 [00:03<00:00,  8.18it/s]


Early Stopping Counter is 1


| # Epoch : 6 Loss : 0.7162: 100%|██████████| 52/52 [00:47<00:00,  1.10it/s]


Start Validation


| Acc : 0.9585: 100%|██████████| 26/26 [00:02<00:00,  8.84it/s]


Early Stopping Counter is 2


| # Epoch : 7 Loss : 0.0208: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s]


Start Validation


| Acc : 0.9707: 100%|██████████| 26/26 [00:03<00:00,  8.60it/s]


model saved 0.9707_vit_base_patch16_224_fold3.pt


| # Epoch : 8 Loss : 0.0196: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s]


Start Validation


| Acc : 0.9561: 100%|██████████| 26/26 [00:03<00:00,  8.45it/s]


Early Stopping Counter is 1


| # Epoch : 9 Loss : 0.0068: 100%|██████████| 52/52 [00:47<00:00,  1.09it/s]


Start Validation


| Acc : 0.9683: 100%|██████████| 26/26 [00:03<00:00,  8.21it/s]


Early Stopping Counter is 2


| # Epoch : 10 Loss : 0.0601: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s]


Start Validation


| Acc : 0.9707: 100%|██████████| 26/26 [00:03<00:00,  8.20it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
4 fold start


| # Epoch : 1 Loss : 0.1691: 100%|██████████| 52/52 [00:46<00:00,  1.13it/s]


Start Validation


| Acc : 0.9829: 100%|██████████| 26/26 [00:03<00:00,  8.61it/s]


model saved 0.9829_vit_base_patch16_224_fold4.pt


| # Epoch : 2 Loss : 0.1156: 100%|██████████| 52/52 [00:47<00:00,  1.08it/s]


Start Validation


| Acc : 0.9291: 100%|██████████| 26/26 [00:03<00:00,  8.43it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.3186: 100%|██████████| 52/52 [00:49<00:00,  1.06it/s]


Start Validation


| Acc : 0.9731: 100%|██████████| 26/26 [00:03<00:00,  8.52it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.1113: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s]


Start Validation


| Acc : 0.9658: 100%|██████████| 26/26 [00:03<00:00,  8.34it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
5 fold start


| # Epoch : 1 Loss : 0.5348: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s]


Start Validation


| Acc : 0.9560: 100%|██████████| 26/26 [00:02<00:00,  8.73it/s]


model saved 0.9560_vit_base_patch16_224_fold5.pt


| # Epoch : 2 Loss : 0.0748: 100%|██████████| 52/52 [00:48<00:00,  1.06it/s]


Start Validation


| Acc : 0.9658: 100%|██████████| 26/26 [00:03<00:00,  8.38it/s]


model saved 0.9658_vit_base_patch16_224_fold5.pt


| # Epoch : 3 Loss : 0.4672: 100%|██████████| 52/52 [00:49<00:00,  1.05it/s]


Start Validation


| Acc : 0.9193: 100%|██████████| 26/26 [00:02<00:00,  8.80it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.0105: 100%|██████████| 52/52 [00:45<00:00,  1.13it/s]


Start Validation


| Acc : 0.9609: 100%|██████████| 26/26 [00:03<00:00,  8.47it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.9517: 100%|██████████| 52/52 [00:46<00:00,  1.12it/s]


Start Validation


| Acc : 0.8851: 100%|██████████| 26/26 [00:03<00:00,  8.42it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  7.77it/s]


1 fold start


| # Epoch : 1 Loss : 0.7991: 100%|██████████| 60/60 [00:55<00:00,  1.07it/s]


Start Validation


| Acc : 0.9708: 100%|██████████| 30/30 [00:03<00:00,  8.29it/s]


model saved 0.9708_vit_base_patch16_224_fold1.pt


| # Epoch : 2 Loss : 0.2926: 100%|██████████| 60/60 [00:55<00:00,  1.09it/s]


Start Validation


| Acc : 0.9729: 100%|██████████| 30/30 [00:03<00:00,  8.15it/s]


model saved 0.9729_vit_base_patch16_224_fold1.pt


| # Epoch : 3 Loss : 0.3387: 100%|██████████| 60/60 [00:56<00:00,  1.06it/s]


Start Validation


| Acc : 0.9854: 100%|██████████| 30/30 [00:03<00:00,  8.23it/s]


model saved 0.9854_vit_base_patch16_224_fold1.pt


| # Epoch : 4 Loss : 0.1813: 100%|██████████| 60/60 [00:59<00:00,  1.01it/s]


Start Validation


| Acc : 0.9854: 100%|██████████| 30/30 [00:03<00:00,  7.68it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.5893: 100%|██████████| 60/60 [00:53<00:00,  1.12it/s]


Start Validation


| Acc : 0.9917: 100%|██████████| 30/30 [00:03<00:00,  7.63it/s]


model saved 0.9917_vit_base_patch16_224_fold1.pt


| # Epoch : 6 Loss : 0.2555: 100%|██████████| 60/60 [00:56<00:00,  1.06it/s]


Start Validation


| Acc : 0.9667: 100%|██████████| 30/30 [00:03<00:00,  7.66it/s]


Early Stopping Counter is 1


| # Epoch : 7 Loss : 0.2445: 100%|██████████| 60/60 [00:54<00:00,  1.09it/s]


Start Validation


| Acc : 0.9979: 100%|██████████| 30/30 [00:03<00:00,  8.04it/s]


model saved 0.9979_vit_base_patch16_224_fold1.pt


| # Epoch : 8 Loss : 0.0191: 100%|██████████| 60/60 [00:54<00:00,  1.09it/s]


Start Validation


| Acc : 0.9854: 100%|██████████| 30/30 [00:03<00:00,  7.76it/s]


Early Stopping Counter is 1


| # Epoch : 9 Loss : 0.1804: 100%|██████████| 60/60 [00:56<00:00,  1.05it/s]


Start Validation


| Acc : 0.9875: 100%|██████████| 30/30 [00:03<00:00,  7.90it/s]


Early Stopping Counter is 2


| # Epoch : 10 Loss : 0.1390: 100%|██████████| 60/60 [00:56<00:00,  1.07it/s]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:03<00:00,  8.09it/s]


model saved 1.0000_vit_base_patch16_224_fold1.pt


| # Epoch : 11 Loss : 0.0881: 100%|██████████| 60/60 [00:56<00:00,  1.07it/s]


Start Validation


| Acc : 0.9958: 100%|██████████| 30/30 [00:03<00:00,  7.67it/s]


Early Stopping Counter is 1


| # Epoch : 12 Loss : 0.0049: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:03<00:00,  8.29it/s]


Early Stopping Counter is 2


| # Epoch : 13 Loss : 0.0045: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:03<00:00,  8.65it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
2 fold start


| # Epoch : 1 Loss : 0.1196: 100%|██████████| 60/60 [00:52<00:00,  1.13it/s]


Start Validation


| Acc : 0.9875: 100%|██████████| 30/30 [00:03<00:00,  8.22it/s]


model saved 0.9875_vit_base_patch16_224_fold2.pt


| # Epoch : 2 Loss : 0.4007: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 0.9979: 100%|██████████| 30/30 [00:03<00:00,  8.15it/s]


model saved 0.9979_vit_base_patch16_224_fold2.pt


| # Epoch : 3 Loss : 0.4358: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 0.9792: 100%|██████████| 30/30 [00:03<00:00,  8.32it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.2004: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 0.9604: 100%|██████████| 30/30 [00:03<00:00,  7.79it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.3411: 100%|██████████| 60/60 [00:53<00:00,  1.12it/s]


Start Validation


| Acc : 0.9667: 100%|██████████| 30/30 [00:03<00:00,  8.54it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
3 fold start


| # Epoch : 1 Loss : 0.0787: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 0.9792: 100%|██████████| 30/30 [00:03<00:00,  8.01it/s]


model saved 0.9792_vit_base_patch16_224_fold3.pt


| # Epoch : 2 Loss : 0.1999: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 0.9729: 100%|██████████| 30/30 [00:03<00:00,  8.60it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.1492: 100%|██████████| 60/60 [00:55<00:00,  1.07it/s]


Start Validation


| Acc : 0.9750: 100%|██████████| 30/30 [00:03<00:00,  7.91it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.4493: 100%|██████████| 60/60 [00:56<00:00,  1.07it/s]


Start Validation


| Acc : 0.9688: 100%|██████████| 30/30 [00:03<00:00,  8.41it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
4 fold start


| # Epoch : 1 Loss : 0.3722: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s]


Start Validation


| Acc : 0.9478: 100%|██████████| 30/30 [00:03<00:00,  8.48it/s]


model saved 0.9478_vit_base_patch16_224_fold4.pt


| # Epoch : 2 Loss : 0.0799: 100%|██████████| 60/60 [00:54<00:00,  1.11it/s]


Start Validation


| Acc : 0.9374: 100%|██████████| 30/30 [00:03<00:00,  8.34it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.1012: 100%|██████████| 60/60 [00:55<00:00,  1.07it/s]


Start Validation


| Acc : 0.9499: 100%|██████████| 30/30 [00:03<00:00,  8.06it/s]


model saved 0.9499_vit_base_patch16_224_fold4.pt


| # Epoch : 4 Loss : 0.0827: 100%|██████████| 60/60 [00:57<00:00,  1.04it/s]


Start Validation


| Acc : 0.9541: 100%|██████████| 30/30 [00:03<00:00,  8.31it/s]


model saved 0.9541_vit_base_patch16_224_fold4.pt


| # Epoch : 5 Loss : 0.1030: 100%|██████████| 60/60 [00:53<00:00,  1.12it/s]


Start Validation


| Acc : 0.9541: 100%|██████████| 30/30 [00:03<00:00,  7.58it/s]


Early Stopping Counter is 1


| # Epoch : 6 Loss : 0.1928: 100%|██████████| 60/60 [00:54<00:00,  1.10it/s]


Start Validation


| Acc : 0.9478: 100%|██████████| 30/30 [00:03<00:00,  8.15it/s]


Early Stopping Counter is 2


| # Epoch : 7 Loss : 0.2420: 100%|██████████| 60/60 [00:51<00:00,  1.15it/s]


Start Validation


| Acc : 0.9624: 100%|██████████| 30/30 [00:03<00:00,  8.19it/s]


model saved 0.9624_vit_base_patch16_224_fold4.pt


| # Epoch : 8 Loss : 0.4195: 100%|██████████| 60/60 [00:56<00:00,  1.06it/s]


Start Validation


| Acc : 0.9562: 100%|██████████| 30/30 [00:04<00:00,  7.37it/s]


Early Stopping Counter is 1


| # Epoch : 9 Loss : 0.0815: 100%|██████████| 60/60 [00:55<00:00,  1.09it/s]


Start Validation


| Acc : 0.9562: 100%|██████████| 30/30 [00:03<00:00,  8.42it/s]


Early Stopping Counter is 2


| # Epoch : 10 Loss : 0.1061: 100%|██████████| 60/60 [00:56<00:00,  1.06it/s]


Start Validation


| Acc : 0.9415: 100%|██████████| 30/30 [00:03<00:00,  8.20it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
5 fold start


| # Epoch : 1 Loss : 0.2803: 100%|██████████| 60/60 [00:57<00:00,  1.04it/s]


Start Validation


| Acc : 0.9269: 100%|██████████| 30/30 [00:03<00:00,  8.19it/s]


model saved 0.9269_vit_base_patch16_224_fold5.pt


| # Epoch : 2 Loss : 0.0801: 100%|██████████| 60/60 [00:57<00:00,  1.03it/s]


Start Validation


| Acc : 0.9228: 100%|██████████| 30/30 [00:03<00:00,  8.16it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.2871: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s]


Start Validation


| Acc : 0.9248: 100%|██████████| 30/30 [00:03<00:00,  8.35it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.2831: 100%|██████████| 60/60 [00:55<00:00,  1.08it/s]


Start Validation


| Acc : 0.9165: 100%|██████████| 30/30 [00:03<00:00,  8.08it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  7.61it/s]


In [22]:
for model in os.walk(CFG["saved_dir"]):
    models=model[-1][-5:]
for model in models:
    torch.load(os.path.join(CFG["saved_dir"],model)).eval()