## Contents
- Import Library & Define Functions
- Hyper-parameters
- Load Data
- Train Model
- 평가
- Inference & Save File


## Import Library & Define Functions
* 학습 및 추론에 필요한 라이브러리를 로드합니다.
* 학습 및 추론에 필요한 함수와 클래스를 정의합니다.

In [None]:
import os
import random

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from copy import deepcopy

from datetime import datetime
import time
from zoneinfo import ZoneInfo
import wandb

In [None]:
# Timestamp of this run in Seoul local time (YYYYmmdd-HHMMSS). It is reused
# below for the W&B run name, the checkpoint file names and the submission file.
run_started_at = datetime.now(tz=ZoneInfo("Asia/Seoul"))
train_time = run_started_at.strftime("%Y%m%d-%H%M%S")
train_time

# Open the Weights & Biases run that the rest of the notebook logs into.
wandb.init(project="document-classification", name=f"run-{train_time}")

In [None]:
# Fix the random seeds of every RNG the notebook uses (Python, NumPy, PyTorch
# CPU and CUDA) so that runs are repeatable.
SEED = 42
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN autotuning (benchmark=True) may pick different convolution algorithms
# from run to run, which defeats the seeding above. Ask for deterministic
# kernels and turn autotuning off instead.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [None]:
# 데이터셋 클래스를 정의합니다.
class ImageDataset(Dataset):
    """Image dataset indexed by a CSV file.

    The CSV is loaded with pandas. For each row, column 0 is joined onto
    ``img_dir`` to locate the image file and column 1 is returned as the label.

    Args:
        csv_file: Path of the CSV index file.
        img_dir: Directory that contains the image files.
        transform: Optional callable invoked as ``transform(image=array)``;
            its ``'image'`` entry becomes the returned image.
        augraphy_pipeline: Optional callable applied to the RGB array before
            ``transform``.
    """

    def __init__(self, csv_file, img_dir, transform=None, augraphy_pipeline=None):
        self.data = pd.read_csv(csv_file)
        self.img_dir = img_dir
        self.transform = transform
        self.augraphy_pipeline = augraphy_pipeline

    def __len__(self):
        # One sample per CSV row.
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``."""
        file_name = self.data.iloc[idx, 0]
        pil_image = Image.open(os.path.join(self.img_dir, file_name))
        image = np.array(pil_image.convert('RGB'))
        label = self.data.iloc[idx, 1]

        # Augraphy runs first on the raw RGB array, then the transform.
        if self.augraphy_pipeline:
            image = self.augraphy_pipeline(image)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, label

In [None]:
class EarlyStopping:
    """Track validation loss, checkpoint improvements and flag when to stop.

    Each call scores the epoch as ``-val_loss``. The first call always
    checkpoints. Afterwards, a score below ``best_score + delta`` increments
    ``counter`` (and sets ``early_stop`` once it reaches ``patience``); any
    other score becomes the new best, is checkpointed and resets ``counter``.

    Args:
        patience: Number of non-improving calls after which ``early_stop`` is set.
        verbose: When true, print a message every time a checkpoint is saved.
        delta: Margin added to the best score in the improvement comparison.
        path: File the model ``state_dict`` is written to.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        # Configuration.
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        # Running state.
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float('inf')

    def __call__(self, val_loss, model):
        score = -val_loss

        # First observation: nothing to compare against yet.
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            return

        # Not good enough: count it and possibly raise the stop flag.
        if score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # New best score.
        self.best_score = score
        self.save_checkpoint(val_loss, model)
        self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Write ``model.state_dict()`` to ``self.path`` and record ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
def validate(val_loader, model, loss_fn, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        val_loader: Iterable of ``(inputs, labels)`` tensor batches.
        model: Module to evaluate; it is switched to eval mode.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)``: the sum of per-batch losses divided by
        the number of batches, and correct predictions divided by samples seen.
    """
    model.eval()
    batch_loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss_sum += loss_fn(logits, batch_labels).item()

            # Index of the largest logit along dim 1 is the predicted class.
            predicted = logits.max(1)[1]
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    mean_batch_loss = batch_loss_sum / len(val_loader)
    accuracy = n_correct / n_seen
    return mean_batch_loss, accuracy

## Hyper-parameters
* 학습 및 추론에 필요한 하이퍼파라미터들을 정의합니다.

In [None]:
# device: use CUDA when it is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data config
data_path = 'data/'

# model config (name handed to timm.create_model)
model_name = 'efficientnet_b0'

# training config
img_size = 224      # images are resized/cropped to img_size x img_size
LR = 1e-3           # Adam learning rate
EPOCHS = 1          # maximum number of epochs per fold
BATCH_SIZE = 32
num_workers = 16    # DataLoader worker processes
PATIENCE = 10       # epochs without a better validation F1 before stopping
FOLD = 2            # number of stratified cross-validation splits
CLASS = 17          # number of output classes of the model

# Log the run configuration to W&B. The class count is logged as well so the
# recorded config describes the model head completely.
wandb.config.update({
    "model": model_name,
    "img_size": img_size,
    "learning_rate": LR,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "num_workers": num_workers,
    'patience': PATIENCE,
    'fold': FOLD,
    'num_classes': CLASS,
})

## Load Data
* 학습, 테스트 데이터셋과 augmentation transform을 정의합니다. (데이터 로더는 이후 Train Model / Inference 단계에서 생성합니다.)

In [None]:
import augraphy
from augraphy import *
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# augmentation을 위한 transform 코드
def get_augraphy_pipeline():
    """Build the Augraphy document-degradation pipeline used for training.

    Returns:
        An ``AugraphyPipeline`` constructed from the augmentations listed
        below, each with its own application probability ``p``.
    """
    degradations = [
        BleedThrough(p=0.5),
        DirtyRollers(p=0.5),
        InkBleed(p=0.5),
        Faxify(p=0.3),
        NoiseTexturize(p=0.5),
        Letterpress(p=0.5),
        LowInkPeriodicLines(p=0.5),
        LowInkRandomLines(p=0.5),
        Folding(p=0.5),
        Markup(p=0.3),  # used instead of PencilScribbles
        Stains(p=0.3),  # used instead of Watermark
    ]
    return AugraphyPipeline(degradations)

def get_train_transforms(height, width):
    """Build the Albumentations pipeline applied to training images.

    Args:
        height: Output image height passed to the random resized crop.
        width: Output image width passed to the random resized crop.

    Returns:
        An ``A.Compose`` that crops, rotates/flips, perturbs brightness,
        noise, blur, geometry and JPEG quality, then normalizes with the
        mean/std below and converts the image to a tensor.
    """
    crop = A.RandomResizedCrop(height=height, width=width, scale=(0.8, 1.0), ratio=(0.75, 1.3333333333333333))

    # At most one of the two rotation styles is applied per sample.
    rotation = A.OneOf([
        A.RandomRotate90(p=0.5),
        A.Rotate(limit=180, p=0.5),
    ], p=0.7)

    # At most one of the two geometric distortions is applied per sample.
    distortion = A.OneOf([
        A.OpticalDistortion(distort_limit=0.1, shift_limit=0.1, p=1.0),
        A.GridDistortion(num_steps=5, distort_limit=0.1, p=1.0),
    ], p=0.5)

    steps = [
        crop,
        rotation,
        A.Flip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=(-0.2, 0.2), contrast_limit=(-0.2, 0.2), p=0.7),
        A.GaussNoise(var_limit=(10.0, 150.0), p=0.5),
        A.GaussianBlur(blur_limit=(3, 15), p=0.5),
        distortion,
        A.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

def get_pred_transforms(height, width):
    """Build the deterministic Albumentations pipeline used for prediction.

    Args:
        height: Target image height.
        width: Target image width.

    Returns:
        An ``A.Compose`` that resizes, normalizes with the mean/std below and
        converts the image to a tensor. No random augmentation is applied.
    """
    steps = [
        A.Resize(height=height, width=width),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
# Dataset 정의
full_dataset = ImageDataset(
        'data/train.csv',
        'data/train/',
        transform=get_train_transforms(img_size, img_size),
        augraphy_pipeline=get_augraphy_pipeline()
    )

pred_dataset = ImageDataset(
    "data/sample_submission.csv",
    "data/test/",
    transform=get_pred_transforms(img_size, img_size)
)

## Train Model
* 모델을 로드하고, 학습을 진행합니다.

In [None]:
# load model
model = timm.create_model(
    model_name,
    pretrained=True,
    num_classes=CLASS
).to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LR)

In [None]:
from sklearn.model_selection import StratifiedKFold
from torch.utils.data import Subset, DataLoader

def create_kfold_loaders(dataset, batch_size, num_workers, n_splits=5):
    """Split ``dataset`` into stratified folds and build a loader pair per fold.

    Args:
        dataset: Dataset whose items are ``(image, label)`` pairs.
        batch_size: Batch size used for every loader.
        num_workers: Number of DataLoader worker processes.
        n_splits: Number of stratified folds.

    Returns:
        ``(train_loaders, val_loaders)``: two lists of length ``n_splits``.
        Training loaders shuffle, validation loaders do not. Both subsets of a
        fold are views on the same ``dataset`` object, so they share its
        transforms.
    """
    if isinstance(dataset, ImageDataset):
        # Read the labels straight from the CSV frame. Indexing the dataset
        # would decode and augment every image just to fetch its label.
        labels = dataset.data.iloc[:, 1].to_numpy()
    else:
        labels = [dataset[i][1] for i in range(len(dataset))]

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    def _make_loader(subset, shuffle):
        # Same loader settings for both splits; only shuffling differs.
        return DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
        )

    train_loaders = []
    val_loaders = []

    # StratifiedKFold only needs the labels; the feature argument is a dummy.
    for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(labels)), labels)):
        train_subset = Subset(dataset, train_idx)
        val_subset = Subset(dataset, val_idx)

        train_loaders.append(_make_loader(train_subset, shuffle=True))
        val_loaders.append(_make_loader(val_subset, shuffle=False))

        print(f"Fold {fold + 1}:")
        print(f"  Training samples: {len(train_subset)}")
        print(f"  Validation samples: {len(val_subset)}")

    return train_loaders, val_loaders


train_loaders, val_loaders = create_kfold_loaders(full_dataset, BATCH_SIZE, num_workers, n_splits=FOLD)

In [None]:
def calculate_metrics(y_true, y_pred):
    """Return ``(accuracy, macro_f1)`` for the given label arrays.

    Args:
        y_true: Array of ground-truth labels (must support element-wise ``==``
            with a ``.mean()`` on the result, e.g. a NumPy array).
        y_pred: Array of predicted labels, same length as ``y_true``.
    """
    matches = y_true == y_pred
    return matches.mean(), f1_score(y_true, y_pred, average='macro')

def train_one_fold(fold, train_loader, val_loader, model, criterion, optimizer, device, num_epochs, patience):
    """Train and validate ``model`` on one cross-validation fold.

    Every epoch runs a training pass and a validation pass, logs the metrics
    to W&B, and saves the model weights whenever the validation macro F1
    improves. Training stops early after ``patience`` consecutive epochs
    without improvement.

    Args:
        fold: Zero-based fold index (used in messages, log keys and file names).
        train_loader: Loader yielding ``(inputs, labels)`` training batches.
        val_loader: Loader yielding ``(inputs, labels)`` validation batches.
        model: Model to train (already on ``device``).
        criterion: Loss function ``criterion(outputs, labels)``.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the batches are moved to.
        num_epochs: Maximum number of epochs.
        patience: Epochs without a better validation F1 before stopping.

    Returns:
        Six per-epoch lists: train losses, val losses, train accuracies,
        val accuracies, train F1 scores, val F1 scores.
    """
    # Start below any reachable F1 so that the first epoch always writes a
    # checkpoint, even when its validation F1 is exactly 0. Otherwise later
    # code that loads the fold checkpoint would find no file.
    best_val_f1 = float('-inf')
    early_stopping_counter = 0
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []
    train_f1_scores, val_f1_scores = [], []

    # torch.save does not create missing directories.
    os.makedirs('model', exist_ok=True)
    checkpoint_path = f'model/{train_time}_fold{fold+1}_best.pt'

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_preds, train_labels = [], []

        for inputs, labels in tqdm(train_loader, desc=f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs} - Training"):
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch loss is a
            # per-sample average.
            train_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            train_preds.extend(predicted.cpu().numpy())
            train_labels.extend(labels.cpu().numpy())

        train_loss /= len(train_loader.dataset)
        train_accuracy, train_f1 = calculate_metrics(np.array(train_labels), np.array(train_preds))

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_preds, val_labels = [], []

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc=f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs} - Validation"):
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                val_loss += loss.item() * inputs.size(0)
                _, predicted = torch.max(outputs, 1)
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(labels.cpu().numpy())

        val_loss /= len(val_loader.dataset)
        val_accuracy, val_f1 = calculate_metrics(np.array(val_labels), np.array(val_preds))

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accuracies.append(train_accuracy)
        val_accuracies.append(val_accuracy)
        train_f1_scores.append(train_f1)
        val_f1_scores.append(val_f1)

        print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}, Train F1: {train_f1:.4f}")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}, Val F1: {val_f1:.4f}")

        wandb.log({
            f"fold_{fold+1}/train_loss": train_loss,
            f"fold_{fold+1}/train_accuracy": train_accuracy,
            f"fold_{fold+1}/train_f1": train_f1,
            f"fold_{fold+1}/val_loss": val_loss,
            f"fold_{fold+1}/val_accuracy": val_accuracy,
            f"fold_{fold+1}/val_f1": val_f1,
            "epoch": epoch
        })

        # Checkpoint on a strictly better validation F1; otherwise count the
        # epoch toward early stopping.
        if val_f1 > best_val_f1:
            best_val_f1 = val_f1
            torch.save(model.state_dict(), checkpoint_path)
            early_stopping_counter = 0
        else:
            early_stopping_counter += 1

        if early_stopping_counter >= patience:
            print(f"Early stopping triggered after epoch {epoch+1}")
            break

    return train_losses, val_losses, train_accuracies, val_accuracies, train_f1_scores, val_f1_scores

def train_kfold(train_loaders, val_loaders, model_name, num_classes, device, num_epochs, patience):
    """Train one freshly initialized model per fold and gather the histories.

    For every ``(train_loader, val_loader)`` pair a new pretrained timm model,
    a cross-entropy loss and an Adam optimizer (learning rate taken from the
    module-level ``LR``) are created and handed to ``train_one_fold``.

    Returns:
        A 6-tuple of lists, each holding one per-epoch history per fold, in
        the order: train losses, val losses, train accuracies, val
        accuracies, train F1 scores, val F1 scores.
    """
    # One bucket per metric, in the order train_one_fold returns them.
    histories = tuple([] for _ in range(6))

    for fold, loaders in enumerate(zip(train_loaders, val_loaders)):
        train_loader, val_loader = loaders
        print(f"\nTraining Fold {fold+1}")

        model = timm.create_model(model_name, pretrained=True, num_classes=num_classes).to(device)
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=LR)

        fold_history = train_one_fold(
            fold, train_loader, val_loader, model, criterion, optimizer, device, num_epochs, patience
        )

        # Append this fold's six series to their matching buckets.
        for bucket, series in zip(histories, fold_history):
            bucket.append(series)

    return histories


results = train_kfold(train_loaders, val_loaders, model_name, CLASS, device, EPOCHS, PATIENCE)

## 평가

In [None]:
# 필요한 추가 라이브러리 import
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# 결과 분석 및 시각화 함수들

def plot_learning_curves(all_train_losses, all_val_losses, all_train_f1_scores, all_val_f1_scores):
    """Plot per-fold loss and macro-F1 curves, save them and log them to W&B.

    Each argument is a list with one per-epoch series per fold. Folds may
    have different lengths (for example when early stopping ended them at
    different epochs); every curve is drawn against its own epoch range.

    The figure is written to ``learning_curves.png`` and logged to W&B under
    the key ``learning_curves``.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 16))

    def _plot_series(ax, series, label):
        # The x axis is derived from the series itself so folds that stopped
        # at different epochs do not cause an x/y length mismatch.
        ax.plot(range(1, len(series) + 1), series, label=label)

    # Loss curves
    for fold, (train_series, val_series) in enumerate(zip(all_train_losses, all_val_losses)):
        _plot_series(ax1, train_series, f'Fold {fold+1} Train')
        _plot_series(ax1, val_series, f'Fold {fold+1} Val')

    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.set_title('Training and Validation Loss Curves')
    ax1.legend()
    ax1.grid(True)

    # F1 score curves
    for fold, (train_series, val_series) in enumerate(zip(all_train_f1_scores, all_val_f1_scores)):
        _plot_series(ax2, train_series, f'Fold {fold+1} Train')
        _plot_series(ax2, val_series, f'Fold {fold+1} Val')

    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Macro F1 Score')
    ax2.set_title('Training and Validation Macro F1 Score Curves')
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()

    # Save before showing so the file never depends on the figure's state
    # after display.
    fig.savefig('learning_curves.png')
    plt.show()

    wandb.log({"learning_curves": wandb.Image('learning_curves.png')})

def plot_fold_comparison(all_val_accuracies, all_val_f1_scores):
    """Draw bar charts of each fold's last-epoch validation accuracy and F1.

    Each argument is a list with one per-epoch series per fold; only the last
    value of every series is plotted. The figure is shown, written to
    ``fold_comparison.png`` and logged to W&B under ``fold_comparison``.
    """
    fold_numbers = range(1, len(all_val_accuracies) + 1)
    final_accuracies = [series[-1] for series in all_val_accuracies]
    final_f1_scores = [series[-1] for series in all_val_f1_scores]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # (axis, bar heights, y label, title) for the accuracy and F1 panels.
    panels = (
        (ax1, final_accuracies, 'Validation Accuracy', 'Final Validation Accuracy by Fold'),
        (ax2, final_f1_scores, 'Validation Macro F1 Score', 'Final Validation Macro F1 Score by Fold'),
    )
    for ax, heights, y_label, title in panels:
        ax.bar(fold_numbers, heights)
        ax.set_xlabel('Fold')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.set_ylim(0, 1)

    plt.tight_layout()
    plt.show()

    # Save the figure
    fig.savefig('fold_comparison.png')
    wandb.log({"fold_comparison": wandb.Image('fold_comparison.png')})

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Plot the confusion matrix as a heatmap, save it and log it to W&B.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        class_names: Tick labels for both axes.
            NOTE(review): the matrix only covers labels that occur in
            ``y_true``/``y_pred``; if some classes are absent its size may
            not match ``len(class_names)`` - confirm against the data.

    The figure is written to ``confusion_matrix.png`` and logged to W&B under
    the key ``confusion_matrix``.
    """
    cm = confusion_matrix(y_true, y_pred)
    fig = plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.tight_layout()

    # Save through the figure handle *before* showing: after plt.show() the
    # current pyplot figure may be a new empty one, so a later plt.savefig()
    # can write a blank image.
    fig.savefig('confusion_matrix.png')
    plt.show()

    wandb.log({"confusion_matrix": wandb.Image('confusion_matrix.png')})

In [None]:

def interpret_results(all_val_accuracies, all_val_f1_scores, final_accuracy, final_macro_f1, class_f1_scores):
    """Build a (Korean) text report that interprets the training results.

    Args:
        all_val_accuracies: One per-epoch validation accuracy series per fold;
            only the last value of each series is used.
        all_val_f1_scores: One per-epoch validation macro-F1 series per fold;
            only the last value of each series is used.
        final_accuracy: Accuracy reported as the final test accuracy.
        final_macro_f1: Macro F1 reported as the final test score.
        class_f1_scores: Per-class F1 scores, indexed by class. Any sequence
            is accepted; it is converted to a NumPy array internally.

    Returns:
        The report as a single string. Thresholds used: 0.7 for "good" macro
        F1, 0.05 for fold spread and validation/test gap, 0.2 for the spread
        between the best and the worst class.
    """
    # Accept plain lists/tuples as well as arrays: .max()/.min() are used below.
    class_f1_scores = np.asarray(class_f1_scores, dtype=float)

    # Last-epoch validation metrics of every fold, computed once.
    last_val_accuracies = [acc[-1] for acc in all_val_accuracies]
    last_val_f1_scores = [f1[-1] for f1 in all_val_f1_scores]

    mean_val_accuracy = np.mean(last_val_accuracies)
    std_val_accuracy = np.std(last_val_accuracies)
    mean_val_f1 = np.mean(last_val_f1_scores)
    std_val_f1 = np.std(last_val_f1_scores)

    # Fold numbers are reported 1-based.
    best_fold = np.argmax(last_val_f1_scores) + 1
    worst_fold = np.argmin(last_val_f1_scores) + 1

    best_class = np.argmax(class_f1_scores)
    worst_class = np.argmin(class_f1_scores)

    interpretation = f"""
    결과 해석:
    
    1. 모델 성능:
       - 평균 검증 정확도: {mean_val_accuracy:.4f} ± {std_val_accuracy:.4f}
       - 평균 검증 Macro F1 점수: {mean_val_f1:.4f} ± {std_val_f1:.4f}
       - 최종 테스트 정확도: {final_accuracy:.4f}
       - 최종 테스트 Macro F1 점수: {final_macro_f1:.4f}
       
       해석: {'모델이 양호한 성능을 보입니다.' if final_macro_f1 > 0.7 else '모델 성능에 개선의 여지가 있습니다.'}
       {'검증 세트와 테스트 세트의 성능 차이가 크지 않아 일반화가 잘 되었습니다.' if abs(mean_val_f1 - final_macro_f1) < 0.05 else '검증 세트와 테스트 세트의 성능 차이가 있어 추가적인 일반화 작업이 필요할 수 있습니다.'}
    
    2. 모델 일관성:
       - 폴드 간 Macro F1 점수 표준편차: {std_val_f1:.4f}
       - 최고 성능 폴드: {best_fold}, 최저 성능 폴드: {worst_fold}
       
       해석: {'폴드 간 성능 차이가 작아 모델이 안정적입니다.' if std_val_f1 < 0.05 else '폴드 간 성능 차이가 있어 모델의 안정성을 개선할 필요가 있습니다.'}
    
    3. 클래스별 성능:
       - 최고 성능 클래스 (F1 점수): 클래스 {best_class} ({class_f1_scores[best_class]:.4f})
       - 최저 성능 클래스 (F1 점수): 클래스 {worst_class} ({class_f1_scores[worst_class]:.4f})
       
       해석: {'클래스 간 성능 차이가 크지 않습니다.' if (class_f1_scores.max() - class_f1_scores.min()) < 0.2 else '클래스 간 성능 차이가 큽니다. 클래스 불균형 문제를 해결할 필요가 있습니다.'}
    
    4. 과적합 여부:
       {'검증 세트와 테스트 세트의 성능이 비슷하여 과적합 문제는 크지 않아 보입니다.' if abs(mean_val_f1 - final_macro_f1) < 0.05 else '검증 세트와 테스트 세트의 성능 차이가 있어 과적합의 가능성이 있습니다.'}
    
    5. 개선 방안:
    """

    # Append improvement suggestions that match the observed weaknesses.
    if final_macro_f1 < 0.7:
        interpretation += "   - 모델의 복잡도를 높이거나 학습률을 조정해 보세요.\n"
        interpretation += "   - 더 많은 에폭 동안 학습을 진행해 보세요.\n"

    if std_val_f1 > 0.05:
        interpretation += "   - 더 강력한 정규화 기법을 적용해 보세요 (예: L2 정규화, 드롭아웃 증가).\n"
        interpretation += "   - 데이터 증강 기법을 다양화하거나 강화해 보세요.\n"

    if abs(mean_val_f1 - final_macro_f1) > 0.05:
        interpretation += "   - 교차 검증 과정에서 조기 종료(early stopping)를 적용해 보세요.\n"
        interpretation += "   - 검증 세트의 크기를 늘려 보세요.\n"

    if (class_f1_scores.max() - class_f1_scores.min()) > 0.2:
        interpretation += "   - 클래스 가중치를 조정하거나 언더/오버 샘플링 기법을 적용해 보세요.\n"
        interpretation += "   - 성능이 낮은 클래스에 대해 추가 데이터를 수집하거나 데이터 증강을 강화해 보세요.\n"

    # This suggestion is always included.
    interpretation += "   - 앙상블 기법을 적용하여 여러 모델의 예측을 결합해 보세요.\n"

    return interpretation

In [None]:
# Unpack the six per-fold history lists returned by train_kfold.
all_train_losses, all_val_losses, all_train_accuracies, all_val_accuracies, all_train_f1_scores, all_val_f1_scores = results

# Plot the learning curves
plot_learning_curves(all_train_losses, all_val_losses, all_train_f1_scores, all_val_f1_scores)

# Plot the fold-by-fold performance comparison
plot_fold_comparison(all_val_accuracies, all_val_f1_scores)

# Run predictions over the whole prediction set (using the last fold's model):
# the best checkpoint of fold number FOLD is loaded into the module-level model.
model.load_state_dict(torch.load(f'model/{train_time}_fold{FOLD}_best.pt'))
model.eval()

all_preds = []
all_labels = []

pred_loader = DataLoader(pred_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=num_workers, pin_memory=True)

# NOTE(review): the labels come from pred_dataset, i.e. the second column of
# sample_submission.csv - presumably placeholder targets rather than ground
# truth. If so, every metric computed from all_labels below is not a real test
# score; confirm before relying on it.
with torch.no_grad():
    for inputs, labels in tqdm(pred_loader, desc="Generating predictions"):
        inputs = inputs.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Plot the confusion matrix
class_names = [str(i) for i in range(CLASS)]  # use real class names here if available
plot_confusion_matrix(all_labels, all_preds, class_names)

# Compute and print the per-class F1 scores (each one is logged to W&B with
# its own wandb.log call)
class_f1_scores = f1_score(all_labels, all_preds, average=None)
for i, f1 in enumerate(class_f1_scores):
    print(f"Class {i} F1 Score: {f1:.4f}")
    wandb.log({f"class_{i}_f1_score": f1})

# Print the final performance over the whole prediction set
final_accuracy = accuracy_score(all_labels, all_preds)
final_macro_f1 = f1_score(all_labels, all_preds, average='macro')

print(f"Final Test Accuracy: {final_accuracy:.4f}")
print(f"Final Test Macro F1 Score: {final_macro_f1:.4f}")

wandb.log({
    "final_test_accuracy": final_accuracy,
    "final_test_macro_f1": final_macro_f1
})

# Interpret the results: build the text report, print it and log it to W&B
interpretation = interpret_results(all_val_accuracies, all_val_f1_scores, final_accuracy, final_macro_f1, class_f1_scores)
print(interpretation)
wandb.log({"results_interpretation": interpretation})

## Inference & Save File
* 테스트 이미지에 대한 추론을 진행하고, 결과 파일을 저장합니다.

In [None]:
def k_fold_inference(model_name, num_folds, pred_loader, device):
    """Ensemble the per-fold best models and return the predicted classes.

    For each fold, the checkpoint ``model/{train_time}_fold{k}_best.pt`` is
    loaded into a fresh timm model, softmax probabilities are computed for
    every batch, and the probabilities are averaged across folds.

    Args:
        model_name: timm model name used to rebuild the architecture.
        num_folds: Number of fold checkpoints to ensemble.
        pred_loader: Either a single loader used for every fold, or a
            list/tuple with one loader per fold. Batches may be bare image
            tensors or ``(image, ...)`` tuples/lists.
        device: Device used for inference.

    Returns:
        1-D NumPy array with the arg-max class index of the fold-averaged
        probabilities for every sample.
    """
    print("Generating predictions for submission using k-fold models...")
    fold_probabilities = []

    for fold in range(num_folds):
        print(f"Inferencing with fold {fold + 1} model")

        # Load this fold's model; map_location lets a checkpoint saved on one
        # device be restored on another (e.g. GPU-trained, CPU inference).
        model = timm.create_model(model_name, pretrained=False, num_classes=CLASS).to(device)
        model.load_state_dict(
            torch.load(f'model/{train_time}_fold{fold+1}_best.pt', map_location=device)
        )
        model.eval()

        # Use the loader that was passed in rather than a module-level global.
        if isinstance(pred_loader, (list, tuple)):
            loader = pred_loader[fold]
        else:
            loader = pred_loader

        batch_probabilities = []
        with torch.no_grad():
            for batch in tqdm(loader, desc=f"Fold {fold+1} - Inference"):
                # The loader may yield a bare tensor or an (image, label) pair.
                image = batch[0] if isinstance(batch, (list, tuple)) else batch
                logits = model(image.to(device))
                batch_probabilities.append(logits.softmax(dim=1).cpu().numpy())

        fold_probabilities.append(np.concatenate(batch_probabilities, axis=0))

    # Average the class probabilities over the folds, then take the arg-max.
    final_preds = np.mean(fold_probabilities, axis=0)
    return final_preds.argmax(axis=1)

# The test data is identical for every fold, so each fold simply gets its own
# loader over the same pred_dataset.
pred_loaders = []
for _fold in range(FOLD):
    pred_loaders.append(
        DataLoader(
            pred_dataset,
            batch_size=BATCH_SIZE,
            shuffle=False,
            num_workers=0,
            pin_memory=True
        )
    )

# Class predictions from the k-fold ensemble, one per test sample.
preds_list = k_fold_inference(model_name, FOLD, pred_loaders, device)

In [None]:
# Build the submission frame from the ID/target columns of the submission
# template held by pred_dataset, then overwrite the targets with the
# ensembled predictions.
pred_df = pd.DataFrame(pred_dataset.data, columns=['ID', 'target'])
pred_df['target'] = preds_list 

In [None]:
# 결과 검증
sample_submission_df = pd.read_csv("data/sample_submission.csv")
assert (sample_submission_df['ID'] == pred_df['ID']).all()

In [None]:
# Write the submission as output/<train_time>.csv. DataFrame.to_csv does not
# create missing directories, so make sure the output folder exists first.
os.makedirs('output', exist_ok=True)
submission_file_path = os.path.join('output', f'{train_time}.csv')
pred_df.to_csv(submission_file_path, index=False)

In [None]:
# Preview the first rows of the submission.
pred_df.head()

In [None]:
# Close the W&B run that was opened with wandb.init at the top of the notebook.
wandb.finish()