## **Import Modules and Packages**

In [1]:
import warnings

warnings.filterwarnings("ignore")

import os
import random
import gc
import easydict
import glob
import multiprocessing
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import cv2

import numpy as np
import pandas as pd
from tqdm import tqdm

# Transform을 위한 라이브러리
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Model을 위한 라이브러리
import timm

# Fold를 위한 라이브러리
from sklearn.model_selection import GroupKFold, KFold, StratifiedKFold

# loss, optimizer, scheduler 를 위한 라이브러리
from pytorch_toolbelt import losses
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import autocast, GradScaler
from madgrad import MADGRAD

# Weight & bias
import wandb

# 이미지 시각화를 위한 라이브러리
from PIL import Image
import webcolors
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
import seaborn as sns

sns.set()

%matplotlib inline

plt.rcParams["axes.grid"] = False

print("Pytorch version: {}".format(torch.__version__))
print("GPU: {}".format(torch.cuda.is_available()))

print("Device name: ", torch.cuda.get_device_name(0))
print("Device count: ", torch.cuda.device_count())

# GPU 사용 가능 여부에 따라 device 정보 저장
device = "cuda" if torch.cuda.is_available() else "cpu"

Pytorch version: 1.10.0+cu102
GPU: True
Device name:  Tesla V100-PCIE-32GB
Device count:  1


## **Set Configs**

In [2]:
CFG = {}

CFG['experiment_number'] = "Dec11"

CFG["seed"] = 21
CFG["data_root"] = "/opt/ml/Workspace/Dev_Matching_ArtPaint_Classification"
CFG["saved_dir"] = os.path.join(CFG["data_root"],"saved")
CFG["lr"] = 1e-4
CFG["weight_decay"] = 1e-6

CFG["mean"] = [0.5556861, 0.50740065, 0.45690217]
CFG["std"] = [0.22876642, 0.21754766, 0.22090458]

CFG["timm_model"] = "swin_base_patch4_window7_224"
# CFG["timm_model"] = "vit_base_patch16_224"

CFG["n_epoch"] = 20
CFG["n_Folds"] = 5
CFG["n_iter"] = 3
CFG["patience"] = 3
args = easydict.EasyDict(CFG)

In [3]:
train_path = os.path.join(CFG["data_root"], "data/train")
test_path = os.path.join(CFG["data_root"], "data/test")

In [4]:
label = {}

label["dog"] = 0
label["elephant"] = 1
label["giraffe"] = 2
label["guitar"] = 3
label["horse"] = 4
label["house"] = 5
label["person"] = 6
label["0"] = -1  # for test label

## **Utils**

In [5]:
def seed_everything(seed):
    random_seed = seed
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)


seed_everything(CFG["seed"])

In [6]:
def collate_fn(batch):
    return tuple(zip(*batch))

## **Dataset**

In [7]:
def data_frame(data_path):
    labels = []
    for path in os.walk(data_path):
        label_type = path[0].split("/")[-1]
        if label_type == "train" or label_type == "test":
            continue
        path_root = path[0]
        images = path[-1]

        idx = label[label_type]
        for image in images:
            img_path = os.path.join(path_root, image)
            label_idx = {"img_path": img_path, "label": idx}
            labels.append(label_idx)

    data_frame = pd.DataFrame(labels)
    data_frame = data_frame.sort_values(["label", "img_path"])
    data_frame = data_frame.reset_index(drop=True)
    return data_frame

In [8]:
train_df = data_frame(data_path=train_path)
test_df = data_frame(data_path=test_path)

In [9]:
class ArtDataset(Dataset):
    def __init__(self, df, transform=None):
        super().__init__()
        self.df = df.reset_index()
        self.image_id = self.df.img_path
        self.label = self.df.label
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_id = self.image_id[idx]
        label = self.label[idx]

        image = cv2.imread(image_id)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            image = self.transform(image=np.array(image))["image"]

        image = torch.tensor(image, dtype=torch.float)
        label = torch.tensor(label, dtype=torch.long)

        return image, label

## **Transform**

In [10]:
def get_augmentation(data_type):
    if data_type == "train":
        return A.Compose(
            [
                A.OneOf(
                    [
                        A.GridDistortion(p=1.0),
                        A.RandomGridShuffle(p=1.0),
                        A.ElasticTransform(p=1.0),
                        A.GridDropout(),
                    ],
                    p=1.0,
                ),
                A.Resize(224, 224),
                A.HorizontalFlip(p=0.5),
                A.Normalize(CFG["mean"], CFG["std"]),
                ToTensorV2(),
            ],
            p=1.0,
        )
    else:  # "Valid and Test"
        return A.Compose(
            [
                A.Resize(224, 224),
                A.Normalize(CFG["mean"], CFG["std"]),
                ToTensorV2(),
            ],
            p=1.0,
        )

## **Define Dataset**

In [11]:
def fold_df(data_frame, folds=CFG["n_Folds"]):
    skf = StratifiedKFold(n_splits=folds)

    X = data_frame.img_path.values
    y = data_frame.label.values

    split_df = []
    for _, (train_index, valid_index) in enumerate(skf.split(X, y)):
        train_df = data_frame.iloc[train_index].copy().reset_index(drop=True)
        valid_df = data_frame.iloc[valid_index].copy().reset_index(drop=True)

        split_df.append((train_df, valid_df))
    return split_df

## **Model**

In [12]:
model = timm.create_model(
    model_name=CFG["timm_model"], pretrained=True, num_classes=7
)

In [13]:
x = torch.randn([1, 3, 224, 224])
out = model(x).to(device)
print(f"input : {x.shape} | output : {out.size()}")

input : torch.Size([1, 3, 224, 224]) | output : torch.Size([1, 7])


## **Loss & Optimizer**

In [14]:
# criterion = nn.CrossEntropyLoss()
criterion = losses.SoftCrossEntropyLoss()
optimizer = MADGRAD(params=model.parameters(), lr=CFG["lr"], weight_decay=CFG["weight_decay"])
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=CFG["n_epoch"], T_mult=1)

## **Get Dataloader**

In [15]:
def get_dataloader(train_df, valid_df, test_df):
    train_dataset = ArtDataset(
        train_df, transform=get_augmentation(data_type="train")
    )
    valid_dataset = ArtDataset(
        valid_df, transform=get_augmentation(data_type="valid")
    )
    test_dataset = ArtDataset(
        test_df, transform=get_augmentation(data_type="test")
    )

    train_loader = DataLoader(
        train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate_fn
    )
    valid_loader = DataLoader(
        valid_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn,
    )
    test_loader = DataLoader(
        test_dataset, batch_size=16, shuffle=False, num_workers=0, collate_fn=collate_fn
    )

    return train_loader, valid_loader, test_loader

## **Train/Valid One Epoch**

In [16]:
def train_one_epoch(epoch, model, data_loader, criterion, optimizer, scheduler, device):
    model.train()

    cnt = 0
    correct = 0
    scaler = GradScaler()

    pbar = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar:

        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)

        with autocast(enabled=True):
            model = model.to(device)

            output = model(image)
            loss = criterion(output, label)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        optimizer.zero_grad()

        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == label.data)
        cnt += 1

        description = f"| # Epoch : {epoch + 1} Loss : {(loss.item()):.4f}"
        pbar.set_description(description)

    acc = correct / cnt
    scheduler.step()

In [17]:
def valid_one_epoch(model, data_loader, split_df, device):
    print(f"Start Validation")

    model.eval()
    correct = 0

    pbar_valid = tqdm(enumerate(data_loader), total=len(data_loader))
    for step, (image, label) in pbar_valid:

        image = torch.stack(image).float()
        label = torch.stack(label).long()

        image = image.to(device)
        label = label.to(device)
        model = model.to(device)

        output = model(image)

        _, preds = torch.max(output, 1)
        correct += torch.sum(preds == label.data)

        description_valid = f"| Acc : {(correct.item()/len(split_df)):.4f}"
        pbar_valid.set_description(description_valid)
    acc = correct / len(split_df)

    return acc, output

## **Pseudo_labeling**

In [18]:
def pseudo_labeling(model, train_dataset, test_loader):
    print(f"Start Pseudo")
    pseudo_dataset = copy.deepcopy(test_df)

    model.eval()
    pseudo = []
    for step, (image, label) in tqdm(enumerate(test_loader), total=len(test_loader)):
        image = torch.stack(image).float()
        label = torch.stack(label)

        image = image.to(device)
        label = label.to(device)

        output = model(image).to(device)

        _, preds = torch.max(output, 1)
        pseudo.extend(list(preds.cpu().numpy()))

    pseudo_dataset.label = pseudo
    pseudo_dataset = pd.concat([train_dataset, pseudo_dataset]).reset_index(drop=True)

    return pseudo_dataset

## **Make Save Directory**

In [19]:
def save_model(model, file_name):
    saved_dir = CFG["saved_dir"]

    if not os.path.isdir(saved_dir):
        os.mkdir(saved_dir)
        
    output_path = os.path.join(saved_dir, file_name)
    torch.save(model, output_path)
    print(f"model saved {file_name}")

## **Run !!**

In [20]:
def run(
    epochs, model, train_df, test_df, optimizer, criterion, scheduler, device, Folds=5
):
    torch.cuda.empty_cache()
    gc.collect()
    for _ in range(CFG["n_iter"]):
        for fold in range(CFG["n_Folds"]):

            print(f"{fold+1} fold start")
            split_df = fold_df(train_df, folds=CFG["n_Folds"])
            train_split = split_df[fold][0]
            valid_split = split_df[fold][1]
            train_loader, valid_loader, test_loader = get_dataloader(
                train_split, valid_split, test_df
            )

            early_stopping_cnt = 0
            patience = CFG["patience"]
            best_acc = 0.5

            for epoch in range(epochs):
                train_one_epoch(
                    epoch, model, train_loader, criterion, optimizer, scheduler, device
                )
                with torch.no_grad():
                    valid_acc, outputs = valid_one_epoch(
                        model, valid_loader, valid_split, device
                    )

                if valid_acc > best_acc:
                    best_acc = valid_acc
                    early_stopping_cnt = 0
                    save_file_name = f"{best_acc:.4f}_{CFG['timm_model']}_fold{fold+1}.pt"
                    save_model(model, save_file_name)
                else:
                    early_stopping_cnt += 1
                    print(f"Early Stopping Counter is {early_stopping_cnt}")
                    if early_stopping_cnt >= patience:
                        print(
                            f"Early Stopping Counter: {early_stopping_cnt} out of {patience}"
                        )
                        break

            scheduler.step(best_acc)
        train_df = pseudo_labeling(model, train_df, test_loader)
        train_loader, _, _ = get_dataloader(train_df, train_df, test_df)

In [21]:
run(
    epochs=CFG["n_epoch"],
    model=model,
    train_df=train_df,
    test_df=test_df,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    Folds=CFG["n_Folds"],
)

1 fold start


| # Epoch : 1 Loss : 0.2482: 100%|██████████| 43/43 [00:43<00:00,  1.01s/it]


Start Validation


| Acc : 0.9559: 100%|██████████| 22/22 [00:03<00:00,  6.07it/s]


model saved 0.9559_swin_base_patch4_window7_224_fold1.pt


| # Epoch : 2 Loss : 0.4427: 100%|██████████| 43/43 [00:44<00:00,  1.05s/it]


Start Validation


| Acc : 0.9441: 100%|██████████| 22/22 [00:03<00:00,  6.79it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0461: 100%|██████████| 43/43 [00:43<00:00,  1.01s/it]


Start Validation


| Acc : 0.9735: 100%|██████████| 22/22 [00:02<00:00,  7.64it/s]


model saved 0.9735_swin_base_patch4_window7_224_fold1.pt


| # Epoch : 4 Loss : 0.2032: 100%|██████████| 43/43 [00:42<00:00,  1.02it/s]


Start Validation


| Acc : 0.9647: 100%|██████████| 22/22 [00:03<00:00,  7.18it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.8954: 100%|██████████| 43/43 [00:42<00:00,  1.00it/s]


Start Validation


| Acc : 0.9559: 100%|██████████| 22/22 [00:02<00:00,  8.01it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.0384: 100%|██████████| 43/43 [00:43<00:00,  1.02s/it]


Start Validation


| Acc : 0.9588: 100%|██████████| 22/22 [00:03<00:00,  7.21it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
2 fold start


| # Epoch : 1 Loss : 0.6402: 100%|██████████| 43/43 [00:44<00:00,  1.04s/it]


Start Validation


| Acc : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  8.07it/s]


model saved 0.9941_swin_base_patch4_window7_224_fold2.pt


| # Epoch : 2 Loss : 0.0095: 100%|██████████| 43/43 [00:41<00:00,  1.04it/s]


Start Validation


| Acc : 0.9853: 100%|██████████| 22/22 [00:02<00:00,  7.60it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.2007: 100%|██████████| 43/43 [00:43<00:00,  1.00s/it]


Start Validation


| Acc : 0.9882: 100%|██████████| 22/22 [00:03<00:00,  7.28it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.0045: 100%|██████████| 43/43 [00:41<00:00,  1.03it/s]


Start Validation


| Acc : 0.9735: 100%|██████████| 22/22 [00:03<00:00,  7.20it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
3 fold start


| # Epoch : 1 Loss : 0.1645: 100%|██████████| 43/43 [00:45<00:00,  1.07s/it]


Start Validation


| Acc : 0.9912: 100%|██████████| 22/22 [00:03<00:00,  6.92it/s]


model saved 0.9912_swin_base_patch4_window7_224_fold3.pt


| # Epoch : 2 Loss : 0.4886: 100%|██████████| 43/43 [00:45<00:00,  1.05s/it]


Start Validation


| Acc : 0.9912: 100%|██████████| 22/22 [00:02<00:00,  7.98it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0595: 100%|██████████| 43/43 [00:42<00:00,  1.00it/s]


Start Validation


| Acc : 0.9882: 100%|██████████| 22/22 [00:03<00:00,  7.11it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.6418: 100%|██████████| 43/43 [00:42<00:00,  1.00it/s]


Start Validation


| Acc : 0.9882: 100%|██████████| 22/22 [00:03<00:00,  7.28it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
4 fold start


| # Epoch : 1 Loss : 0.0547: 100%|██████████| 43/43 [00:42<00:00,  1.00it/s]


Start Validation


| Acc : 0.9941: 100%|██████████| 22/22 [00:02<00:00,  7.75it/s]


model saved 0.9941_swin_base_patch4_window7_224_fold4.pt


| # Epoch : 2 Loss : 0.0024: 100%|██████████| 43/43 [00:42<00:00,  1.01it/s]


Start Validation


| Acc : 0.9912: 100%|██████████| 22/22 [00:02<00:00,  8.12it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.1048: 100%|██████████| 43/43 [00:41<00:00,  1.03it/s]


Start Validation


| Acc : 0.9794: 100%|██████████| 22/22 [00:03<00:00,  6.64it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.0338: 100%|██████████| 43/43 [00:43<00:00,  1.02s/it]


Start Validation


| Acc : 0.9882: 100%|██████████| 22/22 [00:03<00:00,  7.02it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
5 fold start


| # Epoch : 1 Loss : 0.0665: 100%|██████████| 43/43 [00:42<00:00,  1.01it/s]


Start Validation


| Acc : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.98it/s]


model saved 0.9971_swin_base_patch4_window7_224_fold5.pt


| # Epoch : 2 Loss : 0.1136: 100%|██████████| 43/43 [00:43<00:00,  1.00s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.63it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold5.pt


| # Epoch : 3 Loss : 0.0547: 100%|██████████| 43/43 [00:44<00:00,  1.03s/it]


Start Validation


| Acc : 0.9971: 100%|██████████| 22/22 [00:02<00:00,  7.43it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.0055: 100%|██████████| 43/43 [00:42<00:00,  1.00it/s]


Start Validation


| Acc : 1.0000: 100%|██████████| 22/22 [00:02<00:00,  7.73it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.0514: 100%|██████████| 43/43 [00:42<00:00,  1.02it/s]


Start Validation


| Acc : 0.9794: 100%|██████████| 22/22 [00:02<00:00,  7.75it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  7.38it/s]


1 fold start


| # Epoch : 1 Loss : 0.0516: 100%|██████████| 52/52 [00:53<00:00,  1.04s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 26/26 [00:03<00:00,  7.95it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold1.pt


| # Epoch : 2 Loss : 0.1770: 100%|██████████| 52/52 [00:52<00:00,  1.00s/it]


Start Validation


| Acc : 0.9976: 100%|██████████| 26/26 [00:03<00:00,  7.32it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.1401: 100%|██████████| 52/52 [00:50<00:00,  1.04it/s]


Start Validation


| Acc : 1.0000: 100%|██████████| 26/26 [00:03<00:00,  7.57it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.6362: 100%|██████████| 52/52 [00:54<00:00,  1.05s/it]


Start Validation


| Acc : 0.9829: 100%|██████████| 26/26 [00:04<00:00,  6.34it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
2 fold start


| # Epoch : 1 Loss : 0.5222: 100%|██████████| 52/52 [00:54<00:00,  1.06s/it]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  6.79it/s]


model saved 0.9927_swin_base_patch4_window7_224_fold2.pt


| # Epoch : 2 Loss : 0.0041: 100%|██████████| 52/52 [00:52<00:00,  1.01s/it]


Start Validation


| Acc : 0.9780: 100%|██████████| 26/26 [00:03<00:00,  7.21it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.3560: 100%|██████████| 52/52 [00:52<00:00,  1.01s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 26/26 [00:03<00:00,  7.18it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold2.pt


| # Epoch : 4 Loss : 0.0130: 100%|██████████| 52/52 [00:54<00:00,  1.04s/it]


Start Validation


| Acc : 0.9976: 100%|██████████| 26/26 [00:03<00:00,  7.36it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.0268: 100%|██████████| 52/52 [00:52<00:00,  1.01s/it]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  7.31it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.0052: 100%|██████████| 52/52 [00:52<00:00,  1.00s/it]


Start Validation


| Acc : 0.9951: 100%|██████████| 26/26 [00:03<00:00,  7.28it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
3 fold start


| # Epoch : 1 Loss : 0.0644: 100%|██████████| 52/52 [00:52<00:00,  1.01s/it]


Start Validation


| Acc : 0.9951: 100%|██████████| 26/26 [00:03<00:00,  6.70it/s]


model saved 0.9951_swin_base_patch4_window7_224_fold3.pt


| # Epoch : 2 Loss : 0.0005: 100%|██████████| 52/52 [00:51<00:00,  1.01it/s]


Start Validation


| Acc : 0.9902: 100%|██████████| 26/26 [00:03<00:00,  7.24it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0008: 100%|██████████| 52/52 [00:49<00:00,  1.05it/s]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  7.61it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.0605: 100%|██████████| 52/52 [00:54<00:00,  1.05s/it]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  7.55it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
4 fold start


| # Epoch : 1 Loss : 0.0029: 100%|██████████| 52/52 [00:53<00:00,  1.04s/it]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  7.55it/s]


model saved 0.9927_swin_base_patch4_window7_224_fold4.pt


| # Epoch : 2 Loss : 0.0150: 100%|██████████| 52/52 [00:52<00:00,  1.01s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 26/26 [00:03<00:00,  7.51it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold4.pt


| # Epoch : 3 Loss : 0.0041: 100%|██████████| 52/52 [00:52<00:00,  1.02s/it]


Start Validation


| Acc : 0.9976: 100%|██████████| 26/26 [00:03<00:00,  8.05it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.0014: 100%|██████████| 52/52 [00:54<00:00,  1.04s/it]


Start Validation


| Acc : 0.9976: 100%|██████████| 26/26 [00:03<00:00,  7.74it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.0001: 100%|██████████| 52/52 [00:52<00:00,  1.00s/it]


Start Validation


| Acc : 0.9976: 100%|██████████| 26/26 [00:03<00:00,  7.30it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
5 fold start


| # Epoch : 1 Loss : 0.8015: 100%|██████████| 52/52 [00:54<00:00,  1.04s/it]


Start Validation


| Acc : 0.9878: 100%|██████████| 26/26 [00:03<00:00,  7.17it/s]


model saved 0.9878_swin_base_patch4_window7_224_fold5.pt


| # Epoch : 2 Loss : 0.0041: 100%|██████████| 52/52 [00:56<00:00,  1.08s/it]


Start Validation


| Acc : 0.9902: 100%|██████████| 26/26 [00:03<00:00,  7.81it/s]


model saved 0.9902_swin_base_patch4_window7_224_fold5.pt


| # Epoch : 3 Loss : 0.0073: 100%|██████████| 52/52 [00:53<00:00,  1.02s/it]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  7.69it/s]


model saved 0.9927_swin_base_patch4_window7_224_fold5.pt


| # Epoch : 4 Loss : 0.2962: 100%|██████████| 52/52 [00:53<00:00,  1.03s/it]


Start Validation


| Acc : 0.9756: 100%|██████████| 26/26 [00:03<00:00,  6.92it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.0001: 100%|██████████| 52/52 [00:52<00:00,  1.00s/it]


Start Validation


| Acc : 0.9927: 100%|██████████| 26/26 [00:03<00:00,  7.19it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.0994: 100%|██████████| 52/52 [00:53<00:00,  1.02s/it]


Start Validation


| Acc : 0.9853: 100%|██████████| 26/26 [00:03<00:00,  7.24it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  8.62it/s]


1 fold start


| # Epoch : 1 Loss : 0.1111: 100%|██████████| 60/60 [01:02<00:00,  1.03s/it]


Start Validation


| Acc : 0.9979: 100%|██████████| 30/30 [00:04<00:00,  6.91it/s]


model saved 0.9979_swin_base_patch4_window7_224_fold1.pt


| # Epoch : 2 Loss : 0.0061: 100%|██████████| 60/60 [01:01<00:00,  1.03s/it]


Start Validation


| Acc : 0.9875: 100%|██████████| 30/30 [00:04<00:00,  6.79it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0089: 100%|██████████| 60/60 [01:02<00:00,  1.04s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:04<00:00,  6.95it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold1.pt


| # Epoch : 4 Loss : 0.0069: 100%|██████████| 60/60 [01:04<00:00,  1.08s/it]


Start Validation


| Acc : 0.9979: 100%|██████████| 30/30 [00:04<00:00,  7.14it/s]


Early Stopping Counter is 1


| # Epoch : 5 Loss : 0.1964: 100%|██████████| 60/60 [01:01<00:00,  1.03s/it]


Start Validation


| Acc : 0.9979: 100%|██████████| 30/30 [00:04<00:00,  7.27it/s]


Early Stopping Counter is 2


| # Epoch : 6 Loss : 0.0444: 100%|██████████| 60/60 [01:04<00:00,  1.07s/it]


Start Validation


| Acc : 0.9979: 100%|██████████| 30/30 [00:03<00:00,  7.50it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
2 fold start


| # Epoch : 1 Loss : 0.1526: 100%|██████████| 60/60 [01:02<00:00,  1.04s/it]


Start Validation


| Acc : 0.9958: 100%|██████████| 30/30 [00:04<00:00,  7.16it/s]


model saved 0.9958_swin_base_patch4_window7_224_fold2.pt


| # Epoch : 2 Loss : 0.0545: 100%|██████████| 60/60 [01:03<00:00,  1.05s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:04<00:00,  6.84it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold2.pt


| # Epoch : 3 Loss : 0.0169: 100%|██████████| 60/60 [00:59<00:00,  1.01it/s]


Start Validation


| Acc : 0.9938: 100%|██████████| 30/30 [00:04<00:00,  7.42it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.0372: 100%|██████████| 60/60 [01:02<00:00,  1.05s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:04<00:00,  7.29it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.2766: 100%|██████████| 60/60 [01:02<00:00,  1.04s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:04<00:00,  6.96it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
3 fold start


| # Epoch : 1 Loss : 0.0837: 100%|██████████| 60/60 [01:03<00:00,  1.06s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:04<00:00,  7.30it/s]


model saved 1.0000_swin_base_patch4_window7_224_fold3.pt


| # Epoch : 2 Loss : 0.0873: 100%|██████████| 60/60 [01:03<00:00,  1.06s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:04<00:00,  6.74it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0930: 100%|██████████| 60/60 [01:02<00:00,  1.04s/it]


Start Validation


| Acc : 1.0000: 100%|██████████| 30/30 [00:03<00:00,  7.90it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.0605: 100%|██████████| 60/60 [01:03<00:00,  1.05s/it]


Start Validation


| Acc : 0.9854: 100%|██████████| 30/30 [00:03<00:00,  7.81it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
4 fold start


| # Epoch : 1 Loss : 0.0155: 100%|██████████| 60/60 [01:01<00:00,  1.03s/it]


Start Validation


| Acc : 0.9916: 100%|██████████| 30/30 [00:03<00:00,  7.58it/s]


model saved 0.9916_swin_base_patch4_window7_224_fold4.pt


| # Epoch : 2 Loss : 0.2101: 100%|██████████| 60/60 [01:01<00:00,  1.02s/it]


Start Validation


| Acc : 0.9937: 100%|██████████| 30/30 [00:04<00:00,  7.14it/s]


model saved 0.9937_swin_base_patch4_window7_224_fold4.pt


| # Epoch : 3 Loss : 0.0658: 100%|██████████| 60/60 [01:01<00:00,  1.02s/it]


Start Validation


| Acc : 0.9937: 100%|██████████| 30/30 [00:04<00:00,  6.82it/s]


Early Stopping Counter is 1


| # Epoch : 4 Loss : 0.0109: 100%|██████████| 60/60 [01:02<00:00,  1.04s/it]


Start Validation


| Acc : 0.9937: 100%|██████████| 30/30 [00:04<00:00,  7.47it/s]


Early Stopping Counter is 2


| # Epoch : 5 Loss : 0.0132: 100%|██████████| 60/60 [01:04<00:00,  1.07s/it]


Start Validation


| Acc : 0.9916: 100%|██████████| 30/30 [00:04<00:00,  7.08it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
5 fold start


| # Epoch : 1 Loss : 0.0364: 100%|██████████| 60/60 [01:00<00:00,  1.02s/it]


Start Validation


| Acc : 0.9875: 100%|██████████| 30/30 [00:04<00:00,  7.37it/s]


model saved 0.9875_swin_base_patch4_window7_224_fold5.pt


| # Epoch : 2 Loss : 0.0161: 100%|██████████| 60/60 [01:01<00:00,  1.03s/it]


Start Validation


| Acc : 0.9749: 100%|██████████| 30/30 [00:04<00:00,  6.98it/s]


Early Stopping Counter is 1


| # Epoch : 3 Loss : 0.0241: 100%|██████████| 60/60 [01:01<00:00,  1.02s/it]


Start Validation


| Acc : 0.9875: 100%|██████████| 30/30 [00:04<00:00,  6.89it/s]


Early Stopping Counter is 2


| # Epoch : 4 Loss : 0.0608: 100%|██████████| 60/60 [01:01<00:00,  1.03s/it]


Start Validation


| Acc : 0.9875: 100%|██████████| 30/30 [00:04<00:00,  7.39it/s]


Early Stopping Counter is 3
Early Stopping Counter: 3 out of 3
Start Pseudo


100%|██████████| 22/22 [00:02<00:00,  8.29it/s]


In [22]:
for model in os.walk(CFG["saved_dir"]):
    models=model[-1][-5:]
for model in models:
    torch.load(os.path.join(CFG["saved_dir"],model)).eval()