In [None]:
import os
import time
import datetime

import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torch.utils.data import DataLoader
import wandb

from config.config import load_config
from utils.utils import *
from datasets.transforms import build_unified_transforms

from datasets import get_dataset
from models import get_model

from utils.EarlyStopping import EarlyStopping
from utils.optimizer_factory import get_optimizer
from utils.scheduler_factory import get_scheduler
from trainer.train_loop import training_loop
from trainer.wandb_logger import WandbLogger

In [2]:
# Fix the random seed once, before any dataset/model objects are created below.
# set_seed is presumably provided by `from utils.utils import *`; exactly which
# RNGs it seeds is defined there.
SEED: int = 42
set_seed(SEED)

seed 고정 완료!


In [None]:
# Load the experiment configuration and resolve every component it names.
cfg = load_config("config/main_config.yaml")

# Separate transform pipelines for the training and validation splits.
transform_cfg = cfg["transforms"]
train_transform = build_unified_transforms(transform_cfg["train"])
val_transform = build_unified_transforms(transform_cfg["val"])

# Dataset / model classes are looked up by the names given in the config.
DatasetClass = get_dataset(cfg['DATASET'])
ModelClass = get_model(cfg['MODEL'])

# Optimizer / scheduler sub-configs, consumed when the training components are built.
cfg_optimizer = cfg["optimizer"]
cfg_scheduler = cfg["scheduler"]

In [4]:
# Compute device: use the GPU when torch can see one, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data config
data_path = './data'

# output config
output_root = './output'

# training config
# os.cpu_count() returns None when the core count cannot be determined; treat
# that as 0 so the integer division cannot raise a TypeError. On a single-core
# machine this yields 0 workers (data is then loaded in the main process).
num_workers = (os.cpu_count() or 0) // 2
num_classes = 17
meta_df = pd.read_csv(f"{data_path}/meta_kr.csv")
class_names = meta_df["class_name"].tolist()
# The model is built with `num_classes` outputs while `class_names` is handed to
# the training loop; fail fast if the hard-coded count and the metadata disagree.
assert len(class_names) == num_classes, (
    f"meta_kr.csv lists {len(class_names)} classes but num_classes={num_classes}"
)

In [5]:
# Run name is "<MODEL>_<YYYY-mm-dd_HH-MM>", so each W&B run is identifiable by
# model and start time (to the minute).
date = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M")
filename = "{}_{}".format(cfg['MODEL'], date)

# wandb: the full config dict is attached to the run; save_path points at the
# same checkpoint file used by early stopping.
logger = WandbLogger(
    config=cfg,
    run_name=filename,
    project_name="document-type-classification",
    save_path=f"{output_root}/checkpoint.pth",
)

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /data/ephemeral/home/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mfkjy132[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


✅ 적용된 폰트: 'NanumGothic'
✅ PDF/PS 폰트 타입 42로 설정 완료 (한글 깨짐 방지)


In [6]:
# Dataset definitions. Both splits read images from data/train/ and differ only
# in their label CSV (from data/train_valid_set/) and their transform pipeline.
train_image_dir = f"{data_path}/train/"
split_dir = f"{data_path}/train_valid_set"

train_dataset = DatasetClass(f"{split_dir}/train-label-fix-v1.csv", train_image_dir, transform=train_transform)
val_dataset = DatasetClass(f"{split_dir}/val-v1.csv", train_image_dir, transform=val_transform)

print(len(train_dataset), len(val_dataset))

1255 315


In [7]:
# DataLoader definitions (settings shared by both splits).
# persistent_workers=True is only valid with num_workers > 0 (DataLoader raises
# ValueError otherwise), so it is enabled conditionally.
loader_kwargs = dict(
    batch_size=cfg["BATCH_SIZE"],
    num_workers=num_workers,
    pin_memory=True,
    drop_last=False,
    persistent_workers=num_workers > 0,
)

# Training batches are reshuffled every epoch.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)

# Validation is not shuffled. Epoch-level metrics should not depend on sample
# order, a fixed order makes evaluation deterministic, and a non-shuffling
# sampler does not consume torch's global RNG each epoch.
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# Build the model and move it to the selected device.
model: nn.Module = ModelClass(num_classes=num_classes).to(device)

# 1. Freeze every parameter of the model first.
for param in model.parameters():
    param.requires_grad = False

# 2. Unfreeze the last `num_blocks_to_unfreeze` blocks of the feature extractor (backbone).
# NOTE(review): assumes the selected model exposes `backbone.blocks` (e.g. a timm
# EfficientNet; efficientnet_b3 is described as having 7 blocks, 0-6) — confirm
# for whatever cfg['MODEL'] resolves to.
num_blocks_to_unfreeze = 2
for i in range(num_blocks_to_unfreeze):
    for param in model.backbone.blocks[-(i + 1)].parameters():
        param.requires_grad = True

# 3. Unfreeze the classifier head.
for param in model.head.parameters():
    param.requires_grad = True

# 4. Hand only the trainable (requires_grad=True) parameters to the optimizer.
# Materialized as a list because a `filter` object is a one-shot iterator.
params_to_update = [p for p in model.parameters() if p.requires_grad]

# EarlyStopping saves its checkpoint here; make sure the directory exists so the
# first save does not fail after an epoch of training.
checkpoint_path = f'{output_root}/checkpoint.pth'
os.makedirs(output_root, exist_ok=True)
early_stopping = EarlyStopping(patience=cfg["patience"], delta=cfg["delta"], verbose=True, save_path=checkpoint_path)

# Loss function
criterion = nn.CrossEntropyLoss()

# Optimizer: built only over the parameters unfrozen above (previously it was
# given model.parameters(), contradicting step 4).
optimizer = get_optimizer(cfg_optimizer["name"], params_to_update, cfg_optimizer["params"])

# Scheduler
Scheduler = get_scheduler(cfg_scheduler["name"], optimizer, cfg_scheduler['params'])

model, valid_max_accuracy = training_loop(model, train_loader, val_loader, train_dataset, val_dataset, criterion, optimizer, device, cfg["EPOCHS"], early_stopping, logger, class_names, Scheduler)

Epoch [1/100], Train Loss: 2.4129: 100%|██████████| 79/79 [04:31<00:00,  3.44s/it]  
Epoch [1/100], Valid Loss: 1.5527: 100%|██████████| 20/20 [00:59<00:00,  2.99s/it]


  ✅ Validation loss improved. Saving model...
Epoch [1/100]
Train Loss: 2.2065, Train Accuracy: 0.4375, Train f1: 0.3735941857291823
Valid Loss: 1.7674, Valid Accuracy: 0.4889, Valid f1: 0.43185355122945435
Current LR: 0.00009758


Epoch [2/100], Train Loss: 1.4025: 100%|██████████| 79/79 [04:25<00:00,  3.36s/it]  
Epoch [2/100], Valid Loss: 1.8165: 100%|██████████| 20/20 [00:55<00:00,  2.76s/it]


  ✅ Validation loss improved. Saving model...
Epoch [2/100]
Train Loss: 1.1019, Train Accuracy: 0.6821, Train f1: 0.6151495665693703
Valid Loss: 1.4045, Valid Accuracy: 0.5619, Valid f1: 0.4849615010924865
Current LR: 0.00009055


Epoch [3/100], Train Loss: 0.8733: 100%|██████████| 79/79 [04:24<00:00,  3.35s/it]  
Epoch [3/100], Valid Loss: 0.6819: 100%|██████████| 20/20 [00:57<00:00,  2.86s/it]


  ✅ Validation loss improved. Saving model...
Epoch [3/100]
Train Loss: 0.6748, Train Accuracy: 0.7928, Train f1: 0.7493872047152701
Valid Loss: 1.0614, Valid Accuracy: 0.6730, Valid f1: 0.6121873785869021
Current LR: 0.00007960


Epoch [4/100], Train Loss: 0.8111: 100%|██████████| 79/79 [04:29<00:00,  3.41s/it]  
Epoch [4/100], Valid Loss: 1.0775: 100%|██████████| 20/20 [00:58<00:00,  2.94s/it]


  ✅ Validation loss improved. Saving model...
Epoch [4/100]
Train Loss: 0.5047, Train Accuracy: 0.8327, Train f1: 0.7977290070643736
Valid Loss: 0.8886, Valid Accuracy: 0.6825, Valid f1: 0.6294141081096463
Current LR: 0.00006580


Epoch [5/100], Train Loss: 0.2577: 100%|██████████| 79/79 [04:23<00:00,  3.33s/it]  
Epoch [5/100], Valid Loss: 1.8878: 100%|██████████| 20/20 [00:57<00:00,  2.88s/it]


  ✅ Validation loss improved. Saving model...
Epoch [5/100]
Train Loss: 0.4295, Train Accuracy: 0.8422, Train f1: 0.8195096244784446
Valid Loss: 0.8621, Valid Accuracy: 0.7365, Valid f1: 0.6914490172850715
Current LR: 0.00005050


Epoch [6/100], Train Loss: 0.2721: 100%|██████████| 79/79 [04:25<00:00,  3.36s/it]  
Epoch [6/100], Valid Loss: 0.7130: 100%|██████████| 20/20 [00:57<00:00,  2.89s/it]


  ✅ Validation loss improved. Saving model...
Epoch [6/100]
Train Loss: 0.3723, Train Accuracy: 0.8645, Train f1: 0.8410866116820318
Valid Loss: 0.7961, Valid Accuracy: 0.7556, Valid f1: 0.7113769145592281
Current LR: 0.00003520


Epoch [7/100], Train Loss: 0.0996: 100%|██████████| 79/79 [04:18<00:00,  3.28s/it]  
Epoch [7/100], Valid Loss: 0.8385: 100%|██████████| 20/20 [00:58<00:00,  2.94s/it]


  ↪️ No improvement. EarlyStopping counter: 1/10
Epoch [7/100]
Train Loss: 0.3348, Train Accuracy: 0.8932, Train f1: 0.872431325378357
Valid Loss: 0.8331, Valid Accuracy: 0.7238, Valid f1: 0.6823344326590107
Current LR: 0.00002140


Epoch [8/100], Train Loss: 0.6107: 100%|██████████| 79/79 [04:27<00:00,  3.38s/it]  
Epoch [8/100], Valid Loss: 1.5385: 100%|██████████| 20/20 [00:57<00:00,  2.87s/it]


  ↪️ No improvement. EarlyStopping counter: 2/10
Epoch [8/100]
Train Loss: 0.3056, Train Accuracy: 0.8932, Train f1: 0.8806902177514631
Valid Loss: 0.8463, Valid Accuracy: 0.7206, Valid f1: 0.6833077236978355
Current LR: 0.00001045


Epoch [9/100], Train Loss: 0.1541: 100%|██████████| 79/79 [04:25<00:00,  3.36s/it]  
Epoch [9/100], Valid Loss: 0.5305: 100%|██████████| 20/20 [00:54<00:00,  2.75s/it]


  ✅ Validation loss improved. Saving model...
Epoch [9/100]
Train Loss: 0.3171, Train Accuracy: 0.8980, Train f1: 0.8856337208105752
Valid Loss: 0.6777, Valid Accuracy: 0.7746, Valid f1: 0.727568585859798
Current LR: 0.00000342


Epoch [10/100], Train Loss: 1.0061: 100%|██████████| 79/79 [04:32<00:00,  3.45s/it]  
Epoch [10/100], Valid Loss: 1.1944: 100%|██████████| 20/20 [00:57<00:00,  2.88s/it]


  ↪️ No improvement. EarlyStopping counter: 1/10
Epoch [10/100]
Train Loss: 0.3317, Train Accuracy: 0.8884, Train f1: 0.8705404712908728
Valid Loss: 0.8318, Valid Accuracy: 0.7397, Valid f1: 0.6915810281409082
Current LR: 0.00010000


Epoch [11/100], Train Loss: 0.6819: 100%|██████████| 79/79 [04:28<00:00,  3.40s/it]  
Epoch [11/100], Valid Loss: 0.3487: 100%|██████████| 20/20 [00:59<00:00,  3.00s/it]


  ↪️ No improvement. EarlyStopping counter: 2/10
Epoch [11/100]
Train Loss: 0.3059, Train Accuracy: 0.8924, Train f1: 0.8825468660807672
Valid Loss: 0.6965, Valid Accuracy: 0.7619, Valid f1: 0.7250257092258924
Current LR: 0.00009939


Epoch [12/100], Train Loss: 0.1336: 100%|██████████| 79/79 [04:18<00:00,  3.27s/it]  
Epoch [12/100], Valid Loss: 0.4404: 100%|██████████| 20/20 [00:58<00:00,  2.91s/it]


  ↪️ No improvement. EarlyStopping counter: 3/10
Epoch [12/100]
Train Loss: 0.3005, Train Accuracy: 0.8940, Train f1: 0.8829126909628291
Valid Loss: 0.7236, Valid Accuracy: 0.7651, Valid f1: 0.7342819624259297
Current LR: 0.00009758


Epoch [13/100], Train Loss: 0.9527: 100%|██████████| 79/79 [04:21<00:00,  3.31s/it]  
Epoch [13/100], Valid Loss: 1.1181: 100%|██████████| 20/20 [00:56<00:00,  2.84s/it]


  ↪️ No improvement. EarlyStopping counter: 4/10
Epoch [13/100]
Train Loss: 0.3027, Train Accuracy: 0.8924, Train f1: 0.886746697060083
Valid Loss: 0.7327, Valid Accuracy: 0.7556, Valid f1: 0.7234456850370471
Current LR: 0.00009460


Epoch [14/100], Train Loss: 0.5508: 100%|██████████| 79/79 [04:21<00:00,  3.31s/it]  
Epoch [14/100], Valid Loss: 0.4983: 100%|██████████| 20/20 [00:54<00:00,  2.72s/it]


  ↪️ No improvement. EarlyStopping counter: 5/10
Epoch [14/100]
Train Loss: 0.2685, Train Accuracy: 0.9108, Train f1: 0.9027263691889961
Valid Loss: 0.6702, Valid Accuracy: 0.7778, Valid f1: 0.7482602777715213
Current LR: 0.00009055


Epoch [15/100], Train Loss: 0.4119: 100%|██████████| 79/79 [04:28<00:00,  3.40s/it]  
Epoch [15/100], Valid Loss: 0.6636: 100%|██████████| 20/20 [00:58<00:00,  2.95s/it]


  ✅ Validation loss improved. Saving model...
Epoch [15/100]
Train Loss: 0.2167, Train Accuracy: 0.9195, Train f1: 0.9152946917645293
Valid Loss: 0.5768, Valid Accuracy: 0.8032, Valid f1: 0.7724679353200735
Current LR: 0.00008550


Epoch [16/100], Train Loss: 0.2964: 100%|██████████| 79/79 [04:23<00:00,  3.34s/it]  
Epoch [16/100], Valid Loss: 0.6890: 100%|██████████| 20/20 [01:01<00:00,  3.06s/it]


  ↪️ No improvement. EarlyStopping counter: 1/10
Epoch [16/100]
Train Loss: 0.1999, Train Accuracy: 0.9291, Train f1: 0.9229340400191794
Valid Loss: 0.7037, Valid Accuracy: 0.7841, Valid f1: 0.7438687972926402
Current LR: 0.00007960


Epoch [17/100], Train Loss: 0.3382: 100%|██████████| 79/79 [04:26<00:00,  3.38s/it]  
Epoch [17/100], Valid Loss: 0.8014: 100%|██████████| 20/20 [00:57<00:00,  2.87s/it]


  ↪️ No improvement. EarlyStopping counter: 2/10
Epoch [17/100]
Train Loss: 0.1713, Train Accuracy: 0.9394, Train f1: 0.9353672454712026
Valid Loss: 0.6879, Valid Accuracy: 0.8000, Valid f1: 0.7713393518993449
Current LR: 0.00007297


Epoch [18/100], Train Loss: 0.0141: 100%|██████████| 79/79 [04:12<00:00,  3.19s/it] 
Epoch [18/100], Valid Loss: 0.8212: 100%|██████████| 20/20 [00:58<00:00,  2.94s/it]


  ↪️ No improvement. EarlyStopping counter: 3/10
Epoch [18/100]
Train Loss: 0.1731, Train Accuracy: 0.9339, Train f1: 0.928834945524712
Valid Loss: 0.7449, Valid Accuracy: 0.7619, Valid f1: 0.7401948949487995
Current LR: 0.00006580


Epoch [19/100], Train Loss: 0.1647: 100%|██████████| 79/79 [04:23<00:00,  3.34s/it]  
Epoch [19/100], Valid Loss: 0.8294: 100%|██████████| 20/20 [00:57<00:00,  2.89s/it]


  ↪️ No improvement. EarlyStopping counter: 4/10
Epoch [19/100]
Train Loss: 0.1564, Train Accuracy: 0.9482, Train f1: 0.9445373127110787
Valid Loss: 0.6376, Valid Accuracy: 0.7873, Valid f1: 0.7607243854242469
Current LR: 0.00005824


Epoch [20/100], Train Loss: 0.1736: 100%|██████████| 79/79 [04:30<00:00,  3.43s/it]  
Epoch [20/100], Valid Loss: 1.2778: 100%|██████████| 20/20 [00:56<00:00,  2.82s/it]


  ↪️ No improvement. EarlyStopping counter: 5/10
Epoch [20/100]
Train Loss: 0.1488, Train Accuracy: 0.9474, Train f1: 0.9444919278001753
Valid Loss: 0.6153, Valid Accuracy: 0.8190, Valid f1: 0.8040889369721427
Current LR: 0.00005050


Epoch [21/100], Train Loss: 0.1712: 100%|██████████| 79/79 [04:23<00:00,  3.33s/it]  
Epoch [21/100], Valid Loss: 0.6177: 100%|██████████| 20/20 [00:56<00:00,  2.82s/it]


  ✅ Validation loss improved. Saving model...
Epoch [21/100]
Train Loss: 0.1475, Train Accuracy: 0.9466, Train f1: 0.9443971566202315
Valid Loss: 0.5532, Valid Accuracy: 0.8286, Valid f1: 0.8072891583872777
Current LR: 0.00004276


Epoch [22/100], Train Loss: 0.0531: 100%|██████████| 79/79 [04:25<00:00,  3.36s/it]  
Epoch [22/100], Valid Loss: 0.0904: 100%|██████████| 20/20 [00:54<00:00,  2.72s/it]


  ↪️ No improvement. EarlyStopping counter: 1/10
Epoch [22/100]
Train Loss: 0.1203, Train Accuracy: 0.9594, Train f1: 0.9571341526963981
Valid Loss: 0.6060, Valid Accuracy: 0.8159, Valid f1: 0.8010094282834002
Current LR: 0.00003520


Epoch [23/100], Train Loss: 0.2394: 100%|██████████| 79/79 [04:23<00:00,  3.34s/it]  
Epoch [23/100], Valid Loss: 0.0992: 100%|██████████| 20/20 [00:55<00:00,  2.79s/it]


  ↪️ No improvement. EarlyStopping counter: 2/10
Epoch [23/100]
Train Loss: 0.1086, Train Accuracy: 0.9649, Train f1: 0.9635050532138815
Valid Loss: 0.6643, Valid Accuracy: 0.7937, Valid f1: 0.7804984270306425
Current LR: 0.00002803


Epoch [24/100], Train Loss: 0.0087: 100%|██████████| 79/79 [04:20<00:00,  3.30s/it]  
Epoch [24/100], Valid Loss: 1.0649: 100%|██████████| 20/20 [00:54<00:00,  2.70s/it]


  ↪️ No improvement. EarlyStopping counter: 3/10
Epoch [24/100]
Train Loss: 0.1092, Train Accuracy: 0.9665, Train f1: 0.9649947618875162
Valid Loss: 0.7601, Valid Accuracy: 0.7937, Valid f1: 0.7536609968811661
Current LR: 0.00002140


Epoch [25/100], Train Loss: 0.0200: 100%|██████████| 79/79 [04:24<00:00,  3.35s/it]  
Epoch [25/100], Valid Loss: 0.1419: 100%|██████████| 20/20 [00:57<00:00,  2.88s/it]


  ↪️ No improvement. EarlyStopping counter: 4/10
Epoch [25/100]
Train Loss: 0.1010, Train Accuracy: 0.9689, Train f1: 0.967166795112225
Valid Loss: 0.6710, Valid Accuracy: 0.7905, Valid f1: 0.7587770057348325
Current LR: 0.00001550


Epoch [26/100], Train Loss: 0.0306: 100%|██████████| 79/79 [04:19<00:00,  3.28s/it]  
Epoch [26/100], Valid Loss: 0.1482: 100%|██████████| 20/20 [00:56<00:00,  2.83s/it]


  ↪️ No improvement. EarlyStopping counter: 5/10
Epoch [26/100]
Train Loss: 0.0966, Train Accuracy: 0.9745, Train f1: 0.9733583434991698
Valid Loss: 0.6462, Valid Accuracy: 0.8032, Valid f1: 0.79182177585485
Current LR: 0.00001045


Epoch [27/100], Train Loss: 0.4152: 100%|██████████| 79/79 [04:29<00:00,  3.41s/it]  
Epoch [27/100], Valid Loss: 1.5689: 100%|██████████| 20/20 [00:58<00:00,  2.94s/it]


  ↪️ No improvement. EarlyStopping counter: 6/10
Epoch [27/100]
Train Loss: 0.1069, Train Accuracy: 0.9649, Train f1: 0.9626297632943988
Valid Loss: 0.6657, Valid Accuracy: 0.8000, Valid f1: 0.7755470395589347
Current LR: 0.00000640


Epoch [28/100], Train Loss: 0.0722: 100%|██████████| 79/79 [04:24<00:00,  3.35s/it]  
Epoch [28/100], Valid Loss: 0.2431: 100%|██████████| 20/20 [00:55<00:00,  2.80s/it]


  ↪️ No improvement. EarlyStopping counter: 7/10
Epoch [28/100]
Train Loss: 0.1040, Train Accuracy: 0.9689, Train f1: 0.9651045460935258
Valid Loss: 0.6196, Valid Accuracy: 0.8159, Valid f1: 0.8060987801463346
Current LR: 0.00000342


Epoch [29/100], Train Loss: 0.5965: 100%|██████████| 79/79 [04:20<00:00,  3.29s/it]  
Epoch [29/100], Valid Loss: 0.4320: 100%|██████████| 20/20 [00:59<00:00,  2.98s/it]


  ↪️ No improvement. EarlyStopping counter: 8/10
Epoch [29/100]
Train Loss: 0.0924, Train Accuracy: 0.9689, Train f1: 0.9668890905196079
Valid Loss: 0.5545, Valid Accuracy: 0.8286, Valid f1: 0.8136813728888397
Current LR: 0.00000161


Epoch [30/100], Train Loss: 0.4242: 100%|██████████| 79/79 [04:25<00:00,  3.35s/it]  
Epoch [30/100], Valid Loss: 0.0440: 100%|██████████| 20/20 [00:58<00:00,  2.94s/it]


  ↪️ No improvement. EarlyStopping counter: 9/10
Epoch [30/100]
Train Loss: 0.0960, Train Accuracy: 0.9665, Train f1: 0.9634987659362285
Valid Loss: 0.5600, Valid Accuracy: 0.8095, Valid f1: 0.7893058772517595
Current LR: 0.00010000


Epoch [31/100], Train Loss: 0.1204: 100%|██████████| 79/79 [04:22<00:00,  3.33s/it]  
Epoch [31/100], Valid Loss: 0.3688: 100%|██████████| 20/20 [00:54<00:00,  2.73s/it]


  ↪️ No improvement. EarlyStopping counter: 10/10
Epoch [31/100]
Train Loss: 0.1224, Train Accuracy: 0.9610, Train f1: 0.9594282446493352
Valid Loss: 0.6159, Valid Accuracy: 0.8159, Valid f1: 0.7906598194029503
Current LR: 0.00009985
🛑 Early stopping at epoch 31


0,1
lr,█▇▇▆▄▃▂▂▁████▇▇▇▆▆▅▄▄▃▃▂▂▂▁▁▁██
train/acc,▁▄▆▆▆▇▇▇▇▇▇▇▇▇▇▇█▇█████████████
train/f1,▁▄▅▆▆▆▇▇▇▇▇▇▇▇▇▇█▇█████████████
train/loss,█▄▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/acc,▁▃▅▅▆▆▆▆▇▆▇▇▆▇▇▇▇▇▇███▇▇▇▇▇████
val/f1,▁▂▄▅▆▆▆▆▆▆▆▇▆▇▇▇▇▇▇███▇▇▇█▇████
val/loss,█▆▄▃▃▂▃▃▂▃▂▂▂▂▁▂▂▂▁▁▁▁▂▂▂▂▂▁▁▁▁

0,1
lr,0.0001
train/acc,0.96096
train/f1,0.95943
train/loss,0.12236
val/acc,0.81587
val/f1,0.79066
val/loss,0.61593


# 6. Inference & Save File
* 테스트 이미지에 대한 추론을 진행하고, 결과 파일을 저장합니다.

In [9]:
import os

import pandas as pd
import torch
from tqdm import tqdm
from torch.utils.data import DataLoader

from config.config import load_config
from models import get_model
from datasets import get_dataset
from datasets.transforms import build_unified_transforms

In [10]:
# Re-create the configuration so this inference section can run on its own
# (e.g. after a kernel restart) without re-running training.
cfg = load_config("config/main_config.yaml")
ModelClass = get_model(cfg['MODEL'])
DatasetClass = get_dataset(cfg['DATASET'])
num_classes = 17
# os.cpu_count() returns None when the core count is unknown; treat it as 0
# instead of raising TypeError.
num_workers = (os.cpu_count() or 0) // 2
output_root = './output'
data_path = './data'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

test_transform = build_unified_transforms(cfg["transforms"]["test"])

# The test dataset is indexed by sample_submission.csv, with images from data/test/.
tst_dataset = DatasetClass(
    f"{data_path}/sample_submission.csv",
    f"{data_path}/test/",
    transform=test_transform
)

# shuffle=False keeps predictions in the same order as the submission rows.
# persistent_workers requires num_workers > 0, so it is enabled conditionally.
tst_loader = DataLoader(
    tst_dataset,
    batch_size=cfg["BATCH_SIZE"],
    shuffle=False,
    num_workers=num_workers,
    pin_memory=True,
    persistent_workers=num_workers > 0,
)

In [11]:
# Rebuild the model with the same constructor call used for training (keyword
# num_classes) and load the checkpoint written during training. map_location="cpu"
# lets a GPU-saved checkpoint load anywhere; load_state_dict copies the values
# into the model's existing parameters on `device`.
preds_list = []
model = ModelClass(num_classes=num_classes).to(device)
model.load_state_dict(torch.load(f"{output_root}/checkpoint.pth", map_location="cpu"))
model.eval()

# No gradients are needed anywhere in this loop, so a single no_grad context
# covers it (which also makes a .detach() before .cpu() unnecessary).
with torch.no_grad():
    for image, _, _ in tqdm(tst_loader):
        logits = model(image.to(device))
        # Keep the arg-max class index per image, in loader (i.e. row) order.
        preds_list.extend(logits.argmax(dim=1).cpu().numpy())

100%|██████████| 197/197 [00:22<00:00,  8.67it/s]


In [12]:
# Keep the ID/target columns of the test metadata and overwrite `target` with
# the predicted class indices (preds_list follows the non-shuffled loader order).
pred_df = pd.DataFrame(tst_dataset.df, columns=['ID', 'target']).assign(target=preds_list)

In [13]:
# Sanity check: prediction rows must line up one-to-one with the submission template IDs.
sample_submission_df = pd.read_csv(f"{data_path}/sample_submission.csv")
ids_aligned = sample_submission_df['ID'] == pred_df['ID']
assert ids_aligned.all()

In [14]:
# Write the submission (ID, target) without the DataFrame index.
# NOTE: the path is relative to the working directory, not output_root.
pred_df.to_csv(path_or_buf="pred.csv", index=False)

In [None]:
# Preview the first five rows of the submission.
pred_df.head(n=5)

Unnamed: 0,ID,target
0,0008fdb22ddce0ce.jpg,2
1,00091bffdffd83de.jpg,6
2,00396fbc1f6cc21d.jpg,5
3,00471f8038d9c4b6.jpg,6
4,00901f504008d884.jpg,2


: 