# **📄 Document type classification baseline code**
> 문서 타입 분류 대회에 오신 여러분 환영합니다! 🎉     
> 아래 baseline에서는 timm의 사전학습 모델(현재 설정은 `tf_efficientnet_b4`)을 로드하여, 모델을 학습하고 예측 파일을 생성하는 프로세스에 대해 알아보겠습니다.

## Contents
- Prepare Environments
- Import Library & Define Functions
- Hyper-parameters
- Load Data
- Train Model
- Inference & Save File


## 1. Prepare Environments

* 데이터 로드를 위한 구글 드라이브를 마운트합니다.
* 필요한 라이브러리를 설치합니다.

In [None]:
# 구글 드라이브 마운트, Colab을 이용하지 않는다면 패스해도 됩니다.
# from google.colab import drive
# drive.mount('/gdrive', force_remount=True)
# drive.mount('/content/drive')

In [None]:
# 구글 드라이브에 업로드된 대회 데이터를 압축 해제하고 로컬에 저장합니다.
# !tar -xvf drive/MyDrive/datasets_fin.tar > /dev/null

In [None]:
# 필요한 라이브러리를 설치합니다.
# !pip install timm
!pip install augraphy albumentations tqdm

In [None]:
!pip install opencv-python

In [None]:
!apt-get update -y

In [None]:
!apt-get install -y libgl1-mesa-glx
# !apt install libgl1-mesa-glx

In [None]:
# !apt-get update && apt-get install -y python3-opencv

## 2. Import Library & Define Functions
* 학습 및 추론에 필요한 라이브러리를 로드합니다.
* 학습 및 추론에 필요한 함수와 클래스를 정의합니다.

In [1]:
import os
import time
import random

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

from sklearn.model_selection import train_test_split # train-validation-test set 나누는 라이브러리


In [2]:
# Fix every random seed so that repeated runs are reproducible.
SEED = 100
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here presumably only affects child processes — confirm if hash ordering matters.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN autotuning (benchmark=True) may select different kernels from run to
# run, which undoes the seeding above. Disable it and request deterministic
# kernels instead; this can cost some speed but keeps results repeatable.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [3]:
# Dataset that pairs image files with the labels listed in a CSV file.
class ImageDataset(Dataset):
    """Image classification dataset backed by a CSV index.

    Each CSV row must hold exactly two values, ``(file name, target)``;
    the rows are kept as a NumPy array in ``self.df``.

    Args:
        csv: Path of the CSV file listing file names and targets.
        path: Directory that contains the image files.
        transform: Optional callable invoked as ``transform(image=array)``
            whose result is indexed with ``['image']``.
    """

    def __init__(self, csv, path, transform=None):
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load the ``idx``-th image and return ``(image, target)``."""
        name, target = self.df[idx]
        # Open inside a context manager so the file handle is released at
        # once, and force 3-channel RGB so grayscale / RGBA / palette files
        # do not break the 3-channel normalization applied downstream.
        with Image.open(os.path.join(self.path, name)) as pil_img:
            img = np.array(pil_img.convert('RGB'))
        # Compare against None explicitly instead of relying on the
        # truthiness of the transform object.
        if self.transform is not None:
            img = self.transform(image=img)['image']
        return img, target

In [4]:
# Runs a single training epoch.
def train_one_epoch(loader, model, optimizer, loss_fn, device):
    """Train ``model`` for one pass over ``loader``.

    Args:
        loader: Iterable yielding ``(image, targets)`` batches.
        model: Network to optimize; switched to train mode here.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable computing the loss from ``(predictions, targets)``.
        device: Device the batches are moved to.

    Returns:
        Dict with ``train_loss`` (mean of the per-batch losses),
        ``train_acc`` and ``train_f1`` (macro-averaged F1).
    """
    model.train()
    running_loss = 0
    epoch_preds, epoch_targets = [], []

    progress = tqdm(loader)
    for image, targets in progress:
        image, targets = image.to(device), targets.to(device)

        # Clear gradients left over from the previous batch.
        model.zero_grad(set_to_none=True)

        logits = model(image)
        loss = loss_fn(logits, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_preds.extend(logits.argmax(dim=1).detach().cpu().numpy())
        epoch_targets.extend(targets.detach().cpu().numpy())

        progress.set_description(f"Train Loss: {batch_loss:.4f}")

    return {
        "train_loss": running_loss / len(loader),
        "train_acc": accuracy_score(epoch_targets, epoch_preds),
        "train_f1": f1_score(epoch_targets, epoch_preds, average='macro'),
    }

In [5]:
# Runs a single validation epoch.
def val_one_epoch(loader, model, optimizer, loss_fn, device):
    """Evaluate ``model`` on one pass over ``loader`` without updating it.

    Args:
        loader: Iterable yielding ``(image, targets)`` batches.
        model: Network to evaluate; switched to eval mode here.
        optimizer: Unused; kept so existing call sites keep working.
        loss_fn: Callable computing the loss from ``(predictions, targets)``.
        device: Device the batches are moved to.

    Returns:
        Dict with ``valid_loss`` (mean of the per-batch losses),
        ``val_acc`` and ``val_f1`` (macro-averaged F1).
    """
    model.eval()  # put the model in evaluation mode
    valid_loss = 0
    preds_list = []
    targets_list = []

    pbar = tqdm(loader)
    # No backward pass happens during validation, so disable autograd to
    # avoid building a computation graph for every batch.
    with torch.no_grad():
        for image, targets in pbar:
            image = image.to(device)
            targets = targets.to(device)

            preds = model(image)
            loss = loss_fn(preds, targets)

            valid_loss += loss.item()
            preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
            targets_list.extend(targets.detach().cpu().numpy())

            pbar.set_description(f"Val Loss: {loss.item():.4f}")

    valid_loss /= len(loader)
    val_acc = accuracy_score(targets_list, preds_list)
    val_f1 = f1_score(targets_list, preds_list, average='macro')

    ret = {
        "valid_loss": valid_loss,
        "val_acc": val_acc,
        "val_f1": val_f1,
    }

    return ret

## 3. Hyper-parameters
* 학습 및 추론에 필요한 하이퍼파라미터들을 정의합니다.

In [6]:
# Compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Root directory of the competition data.
data_path = 'data/'

# Backbone name handed to timm (other options: 'resnet34', 'resnet50', ...).
model_name = 'tf_efficientnet_b4'

# Training configuration.
img_size = 224       # side length images are resized to
LR = 1e-3            # optimizer learning rate
EPOCHS = 50          # maximum number of training epochs
BATCH_SIZE = 32
num_workers = 0      # worker count for the training DataLoader

# Regularization / early stopping.
patience = 10        # epochs without validation-loss improvement before stopping
weight_decay = 2e-4

In [7]:
# all_pretrained_models_available = timm.list_models('tf_efficientnet*', pretrained=True)
# all_pretrained_models_available

## 4. Load Data
* 학습, 검증, 테스트 데이터셋과 로더를 정의합니다.

In [7]:
def _resize_normalize_pipeline(side):
    """Build the resize -> normalize -> tensor pipeline for square images."""
    return A.Compose([
        # Resize to a fixed square resolution.
        A.Resize(height=side, width=side),
        # Per-channel normalization (three-channel mean / std).
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # Convert the array into a PyTorch tensor.
        ToTensorV2(),
    ])


# Transform applied to training images (currently no random augmentation).
trn_transform = _resize_normalize_pipeline(img_size)

# Transform applied to test images; same steps as the training transform.
tst_transform = _resize_normalize_pipeline(img_size)

In [9]:
# Merge the original label CSV with the three augmented-image CSVs.
_csv_sources = (
    "data/train.csv",
    "data/aug_img.csv",
    "data/aug_img2.csv",
    "data/aug_img3.csv",
)
train_df, aug_df, aug2_df, aug3_df = map(pd.read_csv, _csv_sources)

# Stack the frames in the order listed above, renumbering the index from 0.
combine_df = pd.concat([train_df, aug_df, aug2_df, aug3_df], ignore_index=True)

# Persist the combined table.
combine_df.to_csv("data/multi_aug3_combine.csv", index=False)

In [10]:
len(combine_df)

64370

In [8]:
# Data split. (Original author note: run after fixing the image file names
# inside aug_img2.csv and aug_img3.csv, 2024-08-08.)
aug_data_csv = pd.read_csv("data/multi_aug3_combine.csv")

# Hold out 20% of the rows for validation (an 80/20 split; no separate test
# split is produced here), then renumber both parts from 0.
# NOTE(review): if the aug_img*.csv rows are augmented variants of the
# train.csv images (presumably — confirm), a random row-level split puts
# variants of the same source image in both parts and inflates validation
# scores; consider splitting by source image. The split is also not
# stratified by label.
train_csv, val_csv = (
    part.reset_index(drop=True)
    for part in train_test_split(aug_data_csv, test_size=0.2, random_state=100)
)

print("Train 개수: ", len(train_csv))
print("Validation 개수: ", len(val_csv))

# Persist both splits.
train_csv.to_csv("data/aug_train.csv", index=False)
val_csv.to_csv("data/valid.csv", index=False)

Train 개수:  51496
Validation 개수:  12874


In [8]:
# Training split: rows of aug_train.csv, images under data/aug_img/.
trn_dataset = ImageDataset(
    "data/aug_train.csv",
    "data/aug_img/",
    transform=trn_transform
)

# Validation split: uses the evaluation transform rather than the training
# one, so augmentation later added to trn_transform cannot leak into the
# validation metrics.
val_dataset = ImageDataset(
    "data/valid.csv",
    "data/aug_img/",
    transform=tst_transform
)

# Test split: file names come from sample_submission.csv; the labels read
# from that file are ignored at inference time.
tst_dataset = ImageDataset(
    "data/sample_submission.csv",
    "data/test/",
    transform=tst_transform
)
print(len(trn_dataset), len(val_dataset), len(tst_dataset))

51496 12874 3140


In [9]:
# Settings shared by the two evaluation loaders: fixed order, loading done
# in the main process.
_eval_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)

# Training loader: reshuffled every epoch, keeps the final partial batch.
trn_loader = DataLoader(
    trn_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
)

# Validation loader.
val_loader = DataLoader(val_dataset, **_eval_loader_kwargs)

# Test loader.
tst_loader = DataLoader(tst_dataset, **_eval_loader_kwargs)

## 5. Train Model
* 모델을 로드하고, 학습을 진행합니다.

In [11]:
# Build the pretrained backbone with a 17-way classification head and move
# it to the selected device.
model = timm.create_model(model_name, pretrained=True, num_classes=17)
model = model.to(device)

# Cross-entropy loss over the class logits.
loss_fn = nn.CrossEntropyLoss()

# Adam over all model parameters, with weight decay.
optimizer = Adam(
    model.parameters(),
    lr=LR,
    weight_decay=weight_decay,
)

In [12]:
best_valid_loss = float('inf')  # lowest validation loss seen so far
early_stop_counter = 0          # consecutive epochs without improvement
valid_max_accuracy = -1         # highest validation accuracy seen so far

for epoch in range(EPOCHS):
    ret = train_one_epoch(trn_loader, model, optimizer, loss_fn, device=device)
    retv = val_one_epoch(val_loader, model, optimizer, loss_fn, device=device)

    # Track the best validation accuracy.
    valid_max_accuracy = max(valid_max_accuracy, retv['val_acc'])

    improved = retv['valid_loss'] < best_valid_loss
    if not improved:
        # Validation loss stayed the same or got worse.
        early_stop_counter += 1
    else:
        # New best validation loss: checkpoint the weights and reset the counter.
        best_valid_loss = retv['valid_loss']
        torch.save(model.state_dict(), f"./model_{model_name}.pt")
        early_stop_counter = 0

    epoch_report = (
        f"Epoch [{epoch + 1}/{EPOCHS}]",
        f"Train Loss: {ret['train_loss']:.4f}, Train Accuracy: {ret['train_acc']:.4f} Train F1: {ret['train_f1']:.4f}",
        f"Valid Loss: {retv['valid_loss']:.4f}, Valid Accuracy: {retv['val_acc']:.4f} Valid F1: {retv['val_f1']:.4f}",
        '-'*80,
    )
    for report_line in epoch_report:
        print(report_line)

    # Stop once the counter reaches the configured patience.
    if early_stop_counter >= patience:
        print("Early stopping")
        break

Train Loss: 0.1723: 100%|██████████| 1610/1610 [04:36<00:00,  5.82it/s]
Val Loss: 0.2736: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [1/50]
Train Loss: 0.3403, Train Accuracy: 0.8829 Train F1: 0.8761
Valid Loss: 0.1894, Valid Accuracy: 0.9289 Valid F1: 0.9274
--------------------------------------------------------------------------------


Train Loss: 0.4879: 100%|██████████| 1610/1610 [04:35<00:00,  5.84it/s]
Val Loss: 0.0961: 100%|██████████| 403/403 [00:37<00:00, 10.86it/s]


Epoch [2/50]
Train Loss: 0.1880, Train Accuracy: 0.9328 Train F1: 0.9294
Valid Loss: 0.1388, Valid Accuracy: 0.9468 Valid F1: 0.9455
--------------------------------------------------------------------------------


Train Loss: 0.0572: 100%|██████████| 1610/1610 [04:35<00:00,  5.85it/s]
Val Loss: 0.4256: 100%|██████████| 403/403 [00:37<00:00, 10.80it/s]


Epoch [3/50]
Train Loss: 0.1585, Train Accuracy: 0.9447 Train F1: 0.9424
Valid Loss: 0.1768, Valid Accuracy: 0.9348 Valid F1: 0.9294
--------------------------------------------------------------------------------


Train Loss: 0.1861: 100%|██████████| 1610/1610 [04:35<00:00,  5.85it/s]
Val Loss: 0.2370: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [4/50]
Train Loss: 0.1362, Train Accuracy: 0.9525 Train F1: 0.9504
Valid Loss: 0.1677, Valid Accuracy: 0.9415 Valid F1: 0.9382
--------------------------------------------------------------------------------


Train Loss: 0.1670: 100%|██████████| 1610/1610 [04:35<00:00,  5.84it/s]
Val Loss: 0.4419: 100%|██████████| 403/403 [00:37<00:00, 10.80it/s]


Epoch [5/50]
Train Loss: 0.1122, Train Accuracy: 0.9609 Train F1: 0.9593
Valid Loss: 0.1648, Valid Accuracy: 0.9439 Valid F1: 0.9410
--------------------------------------------------------------------------------


Train Loss: 0.0692: 100%|██████████| 1610/1610 [04:35<00:00,  5.85it/s]
Val Loss: 0.1547: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [6/50]
Train Loss: 0.0968, Train Accuracy: 0.9674 Train F1: 0.9662
Valid Loss: 0.0948, Valid Accuracy: 0.9672 Valid F1: 0.9643
--------------------------------------------------------------------------------


Train Loss: 1.4615: 100%|██████████| 1610/1610 [04:35<00:00,  5.85it/s]
Val Loss: 0.2994: 100%|██████████| 403/403 [00:37<00:00, 10.87it/s]


Epoch [7/50]
Train Loss: 0.0873, Train Accuracy: 0.9711 Train F1: 0.9699
Valid Loss: 0.1138, Valid Accuracy: 0.9619 Valid F1: 0.9596
--------------------------------------------------------------------------------


Train Loss: 0.0370: 100%|██████████| 1610/1610 [04:35<00:00,  5.85it/s]
Val Loss: 0.0442: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [8/50]
Train Loss: 0.0763, Train Accuracy: 0.9751 Train F1: 0.9744
Valid Loss: 0.0660, Valid Accuracy: 0.9762 Valid F1: 0.9756
--------------------------------------------------------------------------------


Train Loss: 0.1115: 100%|██████████| 1610/1610 [04:34<00:00,  5.86it/s]
Val Loss: 0.2891: 100%|██████████| 403/403 [00:37<00:00, 10.79it/s]


Epoch [9/50]
Train Loss: 0.0733, Train Accuracy: 0.9761 Train F1: 0.9754
Valid Loss: 0.0639, Valid Accuracy: 0.9796 Valid F1: 0.9764
--------------------------------------------------------------------------------


Train Loss: 0.2377: 100%|██████████| 1610/1610 [04:34<00:00,  5.86it/s]
Val Loss: 0.0213: 100%|██████████| 403/403 [00:37<00:00, 10.80it/s]


Epoch [10/50]
Train Loss: 0.0648, Train Accuracy: 0.9799 Train F1: 0.9791
Valid Loss: 0.0944, Valid Accuracy: 0.9702 Valid F1: 0.9679
--------------------------------------------------------------------------------


Train Loss: 0.1644: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0707: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [11/50]
Train Loss: 0.0643, Train Accuracy: 0.9798 Train F1: 0.9792
Valid Loss: 0.0567, Valid Accuracy: 0.9814 Valid F1: 0.9807
--------------------------------------------------------------------------------


Train Loss: 0.0170: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0165: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [12/50]
Train Loss: 0.0564, Train Accuracy: 0.9823 Train F1: 0.9815
Valid Loss: 0.0545, Valid Accuracy: 0.9818 Valid F1: 0.9811
--------------------------------------------------------------------------------


Train Loss: 0.0778: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0417: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [13/50]
Train Loss: 0.0572, Train Accuracy: 0.9824 Train F1: 0.9820
Valid Loss: 0.0420, Valid Accuracy: 0.9862 Valid F1: 0.9857
--------------------------------------------------------------------------------


Train Loss: 0.1749: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0789: 100%|██████████| 403/403 [00:37<00:00, 10.75it/s]


Epoch [14/50]
Train Loss: 0.0532, Train Accuracy: 0.9842 Train F1: 0.9836
Valid Loss: 0.0617, Valid Accuracy: 0.9789 Valid F1: 0.9775
--------------------------------------------------------------------------------


Train Loss: 0.5219: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.1208: 100%|██████████| 403/403 [00:37<00:00, 10.79it/s]


Epoch [15/50]
Train Loss: 0.0544, Train Accuracy: 0.9828 Train F1: 0.9822
Valid Loss: 0.1211, Valid Accuracy: 0.9637 Valid F1: 0.9601
--------------------------------------------------------------------------------


Train Loss: 0.2792: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.1634: 100%|██████████| 403/403 [00:37<00:00, 10.80it/s]


Epoch [16/50]
Train Loss: 0.0515, Train Accuracy: 0.9846 Train F1: 0.9839
Valid Loss: 0.0743, Valid Accuracy: 0.9756 Valid F1: 0.9753
--------------------------------------------------------------------------------


Train Loss: 0.0042: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0364: 100%|██████████| 403/403 [00:37<00:00, 10.79it/s]


Epoch [17/50]
Train Loss: 0.0524, Train Accuracy: 0.9833 Train F1: 0.9826
Valid Loss: 0.0451, Valid Accuracy: 0.9852 Valid F1: 0.9845
--------------------------------------------------------------------------------


Train Loss: 0.0078: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.1335: 100%|██████████| 403/403 [00:37<00:00, 10.83it/s]


Epoch [18/50]
Train Loss: 0.0458, Train Accuracy: 0.9856 Train F1: 0.9850
Valid Loss: 0.0371, Valid Accuracy: 0.9875 Valid F1: 0.9871
--------------------------------------------------------------------------------


Train Loss: 0.3399: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0908: 100%|██████████| 403/403 [00:37<00:00, 10.83it/s]


Epoch [19/50]
Train Loss: 0.0474, Train Accuracy: 0.9857 Train F1: 0.9852
Valid Loss: 0.0585, Valid Accuracy: 0.9808 Valid F1: 0.9798
--------------------------------------------------------------------------------


Train Loss: 0.0353: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0232: 100%|██████████| 403/403 [00:37<00:00, 10.88it/s]


Epoch [20/50]
Train Loss: 0.0469, Train Accuracy: 0.9851 Train F1: 0.9845
Valid Loss: 0.0417, Valid Accuracy: 0.9866 Valid F1: 0.9856
--------------------------------------------------------------------------------


Train Loss: 0.0285: 100%|██████████| 1610/1610 [04:34<00:00,  5.88it/s]
Val Loss: 0.0166: 100%|██████████| 403/403 [00:37<00:00, 10.79it/s]


Epoch [21/50]
Train Loss: 0.0456, Train Accuracy: 0.9852 Train F1: 0.9847
Valid Loss: 0.0490, Valid Accuracy: 0.9839 Valid F1: 0.9832
--------------------------------------------------------------------------------


Train Loss: 0.0098: 100%|██████████| 1610/1610 [04:34<00:00,  5.88it/s]
Val Loss: 0.0226: 100%|██████████| 403/403 [00:37<00:00, 10.83it/s]


Epoch [22/50]
Train Loss: 0.0430, Train Accuracy: 0.9870 Train F1: 0.9865
Valid Loss: 0.0613, Valid Accuracy: 0.9803 Valid F1: 0.9786
--------------------------------------------------------------------------------


Train Loss: 0.3247: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0179: 100%|██████████| 403/403 [00:37<00:00, 10.83it/s]


Epoch [23/50]
Train Loss: 0.0461, Train Accuracy: 0.9857 Train F1: 0.9853
Valid Loss: 0.0642, Valid Accuracy: 0.9791 Valid F1: 0.9780
--------------------------------------------------------------------------------


Train Loss: 0.1496: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0558: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [24/50]
Train Loss: 0.0410, Train Accuracy: 0.9879 Train F1: 0.9872
Valid Loss: 0.0680, Valid Accuracy: 0.9773 Valid F1: 0.9753
--------------------------------------------------------------------------------


Train Loss: 0.1626: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0374: 100%|██████████| 403/403 [00:37<00:00, 10.80it/s]


Epoch [25/50]
Train Loss: 0.0445, Train Accuracy: 0.9859 Train F1: 0.9855
Valid Loss: 0.0554, Valid Accuracy: 0.9820 Valid F1: 0.9816
--------------------------------------------------------------------------------


Train Loss: 0.0509: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0185: 100%|██████████| 403/403 [00:37<00:00, 10.81it/s]


Epoch [26/50]
Train Loss: 0.0405, Train Accuracy: 0.9876 Train F1: 0.9871
Valid Loss: 0.0677, Valid Accuracy: 0.9772 Valid F1: 0.9764
--------------------------------------------------------------------------------


Train Loss: 0.0400: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0337: 100%|██████████| 403/403 [00:37<00:00, 10.83it/s]


Epoch [27/50]
Train Loss: 0.0435, Train Accuracy: 0.9869 Train F1: 0.9864
Valid Loss: 0.0310, Valid Accuracy: 0.9901 Valid F1: 0.9893
--------------------------------------------------------------------------------


Train Loss: 0.0380: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0566: 100%|██████████| 403/403 [00:37<00:00, 10.86it/s]


Epoch [28/50]
Train Loss: 0.0408, Train Accuracy: 0.9879 Train F1: 0.9874
Valid Loss: 0.0418, Valid Accuracy: 0.9858 Valid F1: 0.9852
--------------------------------------------------------------------------------


Train Loss: 0.0134: 100%|██████████| 1610/1610 [04:33<00:00,  5.89it/s]
Val Loss: 0.0642: 100%|██████████| 403/403 [00:37<00:00, 10.84it/s]


Epoch [29/50]
Train Loss: 0.0401, Train Accuracy: 0.9875 Train F1: 0.9873
Valid Loss: 0.0396, Valid Accuracy: 0.9876 Valid F1: 0.9871
--------------------------------------------------------------------------------


Train Loss: 0.0264: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0092: 100%|██████████| 403/403 [00:37<00:00, 10.81it/s]


Epoch [30/50]
Train Loss: 0.0395, Train Accuracy: 0.9883 Train F1: 0.9878
Valid Loss: 0.0406, Valid Accuracy: 0.9869 Valid F1: 0.9859
--------------------------------------------------------------------------------


Train Loss: 0.0029: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.2164: 100%|██████████| 403/403 [00:37<00:00, 10.81it/s]


Epoch [31/50]
Train Loss: 0.0389, Train Accuracy: 0.9883 Train F1: 0.9879
Valid Loss: 0.0432, Valid Accuracy: 0.9864 Valid F1: 0.9860
--------------------------------------------------------------------------------


Train Loss: 0.0469: 100%|██████████| 1610/1610 [04:33<00:00,  5.89it/s]
Val Loss: 0.0436: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [32/50]
Train Loss: 0.0379, Train Accuracy: 0.9887 Train F1: 0.9885
Valid Loss: 0.0554, Valid Accuracy: 0.9817 Valid F1: 0.9810
--------------------------------------------------------------------------------


Train Loss: 0.0427: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0389: 100%|██████████| 403/403 [00:37<00:00, 10.82it/s]


Epoch [33/50]
Train Loss: 0.0416, Train Accuracy: 0.9873 Train F1: 0.9868
Valid Loss: 0.0610, Valid Accuracy: 0.9809 Valid F1: 0.9798
--------------------------------------------------------------------------------


Train Loss: 0.0141: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0084: 100%|██████████| 403/403 [00:37<00:00, 10.79it/s]


Epoch [34/50]
Train Loss: 0.0394, Train Accuracy: 0.9883 Train F1: 0.9878
Valid Loss: 0.0413, Valid Accuracy: 0.9869 Valid F1: 0.9865
--------------------------------------------------------------------------------


Train Loss: 0.1459: 100%|██████████| 1610/1610 [04:34<00:00,  5.87it/s]
Val Loss: 0.0223: 100%|██████████| 403/403 [00:37<00:00, 10.83it/s]


Epoch [35/50]
Train Loss: 0.0384, Train Accuracy: 0.9887 Train F1: 0.9885
Valid Loss: 0.0512, Valid Accuracy: 0.9817 Valid F1: 0.9815
--------------------------------------------------------------------------------


Train Loss: 0.0141: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0256: 100%|██████████| 403/403 [00:37<00:00, 10.86it/s]


Epoch [36/50]
Train Loss: 0.0391, Train Accuracy: 0.9883 Train F1: 0.9879
Valid Loss: 0.0426, Valid Accuracy: 0.9866 Valid F1: 0.9851
--------------------------------------------------------------------------------


Train Loss: 0.3382: 100%|██████████| 1610/1610 [04:33<00:00,  5.88it/s]
Val Loss: 0.0125: 100%|██████████| 403/403 [00:37<00:00, 10.84it/s]

Epoch [37/50]
Train Loss: 0.0361, Train Accuracy: 0.9892 Train F1: 0.9887
Valid Loss: 0.0483, Valid Accuracy: 0.9849 Valid F1: 0.9848
--------------------------------------------------------------------------------
Early stopping





## 6. Inference & Save File
* 테스트 이미지에 대한 추론을 진행하고, 결과 파일을 저장합니다.

In [13]:
# Reload the checkpoint written by the training loop.
# map_location remaps the saved tensors onto the current device, so a
# checkpoint saved on a GPU can also be loaded on a CPU-only machine.
new_model_parameters = torch.load(f"./model_{model_name}.pt", map_location=device)
model.load_state_dict(new_model_parameters)  # apply the saved parameters to the model

<All keys matched successfully>

In [14]:
# Predicted class index for every test image, in loader order.
preds_list = []

model.eval()
# Inference only: keep autograd disabled for the whole pass.
with torch.no_grad():
    for image, _ in tqdm(tst_loader):
        image = image.to(device)
        preds = model(image)
        preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())

100%|██████████| 99/99 [00:15<00:00,  6.36it/s]


In [15]:
# Rebuild the (ID, target) table from the test dataset rows and overwrite
# the target column with the model predictions.
pred_df = pd.DataFrame(tst_dataset.df, columns=['ID', 'target']).assign(
    target=preds_list
)

In [16]:
# Sanity check: the prediction rows must list the same IDs, in the same
# order, as the sample submission file.
sample_submission_df = pd.read_csv("data/sample_submission.csv")
ids_in_order = sample_submission_df['ID'] == pred_df['ID']
assert ids_in_order.all()

In [17]:
from datetime import datetime

# Timestamp the file name so that repeated runs do not overwrite each other.
now = datetime.now()
timestamp = now.strftime('%Y-%m-%d-%H%M%S')
pred_df.to_csv(
    f"pred14_eff_b4_aug62800_dc0002_Stop10_{timestamp}.csv",
    index=False,
)

In [18]:
pred_df.head(20)

Unnamed: 0,ID,target
0,0008fdb22ddce0ce.jpg,2
1,00091bffdffd83de.jpg,12
2,00396fbc1f6cc21d.jpg,5
3,00471f8038d9c4b6.jpg,12
4,00901f504008d884.jpg,2
5,009b22decbc7220c.jpg,15
6,00b33e0ee6d59427.jpg,0
7,00bbdcfbbdb3e131.jpg,8
8,00c03047e0fbef40.jpg,15
9,00c0dabb63ca7a16.jpg,11


In [16]:
pred_df.head(20)

Unnamed: 0,ID,target
0,0008fdb22ddce0ce.jpg,2
1,00091bffdffd83de.jpg,12
2,00396fbc1f6cc21d.jpg,5
3,00471f8038d9c4b6.jpg,12
4,00901f504008d884.jpg,2
5,009b22decbc7220c.jpg,15
6,00b33e0ee6d59427.jpg,0
7,00bbdcfbbdb3e131.jpg,8
8,00c03047e0fbef40.jpg,15
9,00c0dabb63ca7a16.jpg,11


In [None]:
# 저장