# **📄 Document type classification baseline code**
> 문서 타입 분류 대회에 오신 여러분 환영합니다! 🎉     
> 아래 baseline에서는 timm의 사전학습 모델(기본값: EfficientNet-B3)을 로드하여, 모델을 학습하고 예측 파일을 생성하는 프로세스에 대해 알아보겠습니다.

## Contents
- Prepare Environments
- Import Library & Define Functions
- Hyper-parameters
- Load Data
- Train Model
- Inference & Save File


## 1. Prepare Environments

* (Colab 사용 시) 데이터 로드를 위해 구글 드라이브를 마운트합니다. 이 노트북에서는 로컬 `data/` 디렉터리의 데이터를 사용합니다.
* 필요한 라이브러리를 설치합니다.

In [1]:
# Install the required libraries.
!pip install timm
!pip install wandb
!pip install python-dotenv

Collecting python-dotenv
  Obtaining dependency information for python-dotenv from https://files.pythonhosted.org/packages/5f/ed/539768cf28c661b5b068d66d96a2f155c4971a5d55684a514c1a0e0dec2f/python_dotenv-1.1.1-py3-none-any.whl.metadata
  Downloading python_dotenv-1.1.1-py3-none-any.whl.metadata (24 kB)
Downloading python_dotenv-1.1.1-py3-none-any.whl (20 kB)
Installing collected packages: python-dotenv
Successfully installed python-dotenv-1.1.1
[0m

## 2. Import Library & Define Functions
* 학습 및 추론에 필요한 라이브러리를 로드합니다.
* 학습 및 추론에 필요한 함수와 클래스를 정의합니다.

In [29]:
import os
import time
import random

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

In [38]:
import os
from dotenv import load_dotenv

# Load the variables defined in the .env file.
load_dotenv()

# Read the W&B API key from the environment (the variable name ends with an
# underscore). os.getenv returns None when the variable is not set.
wandb_api_key = os.getenv("WANDB_API_KEY_")

# print("API Key from env:", os.getenv("WANDB_API_KEY_"))

import wandb
wandb.login(key=wandb_api_key)


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /data/ephemeral/home/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mjunegood[0m ([33mjunegood-[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
# Placeholder metric values; they are only used by the initial log_metrics()
# call made after the run has been created.
train_loss = 0.0
train_acc = 0.0
train_f1 = 0.0


# Start the W&B run. `config` holds hyperparameters only: metric values are
# reported through wandb.log(), so they are no longer stored in the config.
# NOTE(review): "model" and "input_size" below differ from `model_name`
# ('efficientnet_b3') and `img_size` (224) set in the hyper-parameters
# section -- confirm which values this run is supposed to record.
wandb.init(
    project="doc-classification",   # project name
    name="efficientnetv2-exp1",     # run (experiment) name
    config={                        # hyperparameters recorded for the run
        "learning_rate": 1e-4,
        "epochs": 30,
        "batch_size": 16,
        "model": "efficientnetv2_s",
        "input_size": 512,
    },
)

def log_metrics(epoch, train_loss=None, val_loss=None, train_acc=None, val_acc=None, lr=None, train_f1=None, val_f1=None):
    """Log the supplied metrics to the active W&B run.

    Only values that are not None are sent. This also applies to ``epoch``,
    so a caller that does not know the epoch no longer records
    ``"epoch": None``. When every value is None nothing is logged.

    Args:
        epoch: Epoch index, or None when it is not known.
        train_loss: Training loss.
        val_loss: Validation loss.
        train_acc: Training accuracy.
        val_acc: Validation accuracy.
        lr: Learning rate; logged under the key "learning_rate".
        train_f1: Training F1 score.
        val_f1: Validation F1 score.
    """
    candidates = {
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "train_acc": train_acc,
        "val_acc": val_acc,
        "train_f1": train_f1,
        "val_f1": val_f1,
        "learning_rate": lr,
    }
    log_dict = {key: value for key, value in candidates.items() if value is not None}

    # Skip the call entirely when there is nothing to report.
    if log_dict:
        wandb.log(log_dict)

# Log the placeholder (all-zero) metrics once for epoch 0.
# NOTE(review): this adds a synthetic 0.0 loss/accuracy/F1 point to the run --
# confirm it is intended and not a leftover smoke test.
log_metrics(0, train_loss=train_loss, train_acc=train_acc, train_f1=train_f1)


In [45]:
# Fix every random seed so that runs are repeatable.
SEED = 42
# Exported for child processes; the hash seed of the already-running
# interpreter is not changed by setting this at runtime.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
# cuDNN benchmark mode auto-tunes the convolution algorithm and may pick a
# different one on each run, which works against seeding. Disable it and
# request deterministic cuDNN kernels instead.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [46]:
# Dataset that pairs each image file listed in a CSV with its target value.
class ImageDataset(Dataset):
    """Image dataset backed by a two-column CSV (file name, target).

    Args:
        csv: Path of the CSV file; every row must hold exactly two values,
            the image file name and its target.
        path: Directory that contains the image files.
        transform: Optional callable invoked as ``transform(image=img)`` that
            returns a mapping with an ``'image'`` entry.
    """

    def __init__(self, csv, path, transform=None):
        # Rows are kept as a NumPy array; code outside the class reads
        # `self.df` directly, so the attribute name and type are preserved.
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for row ``idx``."""
        name, target = self.df[idx]
        # Always decode to 3-channel RGB so grayscale, RGBA or palette files
        # produce the same array layout as ordinary colour images, and close
        # the file handle as soon as the pixels have been read.
        with Image.open(os.path.join(self.path, name)) as im:
            img = np.array(im.convert("RGB"))
        if self.transform:
            img = self.transform(image=img)['image']
        return img, target

In [47]:
# Train the model for one epoch.
def train_one_epoch(loader, model, optimizer, loss_fn, device, epoch=None):
    """Run one training pass over ``loader`` and report the epoch metrics.

    Args:
        loader: Iterable yielding ``(image, targets)`` batches.
        model: Model to train; it is switched to training mode.
        optimizer: Optimizer stepped once per batch.
        loss_fn: Callable computing the loss from ``(preds, targets)``.
        device: Device the batches are moved to.
        epoch: Optional epoch index forwarded to ``log_metrics`` so that the
            logged metrics carry the real epoch. Defaults to None, which
            matches the previous behaviour.

    Returns:
        dict with ``"train_loss"`` (mean of the per-batch losses),
        ``"train_acc"`` and ``"train_f1"`` (macro-averaged F1).
    """
    model.train()
    train_loss = 0
    preds_list = []
    targets_list = []

    pbar = tqdm(loader)
    for image, targets in pbar:
        image = image.to(device)
        targets = targets.to(device)

        # Clear gradients left over from the previous step.
        model.zero_grad(set_to_none=True)

        preds = model(image)
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Collect predicted class indices and targets for the epoch metrics.
        preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
        targets_list.extend(targets.detach().cpu().numpy())

        pbar.set_description(f"Loss: {loss.item():.4f}")

    # Average over the number of batches.
    train_loss /= len(loader)
    train_acc = accuracy_score(targets_list, preds_list)
    train_f1 = f1_score(targets_list, preds_list, average='macro')
    log_metrics(epoch=epoch, train_loss=train_loss, train_acc=train_acc, train_f1=train_f1)

    ret = {
        "train_loss": train_loss,
        "train_acc": train_acc,
        "train_f1": train_f1,
    }

    return ret

## 3. Hyper-parameters
* 학습 및 추론에 필요한 하이퍼파라미터들을 정의합니다.

In [48]:
# Training settings.
BATCH_SIZE = 16
EPOCHS = 30
LR = 1e-4
img_size = 224
num_workers = 0

# Name of the model architecture (alternatives noted: 'resnet50', 'efficientnet-b0', resnet34 ...).
model_name = "efficientnet_b3"

# Data location.
data_path = "datasets_fin/"

# Compute device: CUDA when it is available, otherwise the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

## 4. Load Data
* 학습, 테스트 데이터셋과 로더를 정의합니다.

In [49]:
# Transform for training-time augmentation (the original baseline, disabled by
# wrapping it in a string literal).
'''
trn_transform = A.Compose([
    # 이미지 크기 조정
    A.Resize(height=img_size, width=img_size),
    # images normalization
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    # numpy 이미지나 PIL 이미지를 PyTorch 텐서로 변환
    ToTensorV2(),
])
'''

# Transform for test images (the original baseline, likewise disabled).
'''
tst_transform = A.Compose([
    A.Resize(height=img_size, width=img_size),
    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ToTensorV2(),
])
'''

'\ntst_transform = A.Compose([\n    A.Resize(height=img_size, width=img_size),\n    A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n    ToTensorV2(),\n])\n'

In [50]:
def _base_preprocessing():
    """Return fresh resize / normalize / to-tensor steps shared by both pipelines."""
    return [
        # Bring every image to the same size.
        A.Resize(height=img_size, width=img_size),
        # Normalize with the ImageNet statistics.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        # Convert to a tensor.
        ToTensorV2(),
    ]


# Training pipeline: random augmentations followed by the shared preprocessing.
trn_transform = A.Compose([
    # Random rotation (limit of 15 degrees).
    A.Rotate(limit=15, p=0.5),
    # Horizontal flip, applied with 50% probability.
    A.HorizontalFlip(p=0.5),
    # Brightness / contrast adjustment (limits of 0.2 each).
    A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),
    # Random crop followed by a resize (scale range 0.8-1.0).
    A.RandomResizedCrop(
        height=img_size, width=img_size, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=0.7
    ),
    # Colour jitter on hue, saturation and value.
    A.HueSaturationValue(
        hue_shift_limit=10, sat_shift_limit=20, val_shift_limit=20, p=0.3
    ),
    # Additive Gaussian noise.
    A.GaussNoise(var_limit=(10.0, 50.0), p=0.2),
    *_base_preprocessing(),
])

# Test pipeline: basic preprocessing only, without augmentation.
tst_transform = A.Compose(_base_preprocessing())


In [51]:
# Build the training and test datasets.
trn_dataset = ImageDataset(csv="data/train.csv", path="data/train/", transform=trn_transform)
tst_dataset = ImageDataset(csv="data/sample_submission.csv", path="data/test/", transform=tst_transform)

# Show the number of samples in each dataset.
n_trn, n_tst = len(trn_dataset), len(tst_dataset)
print(n_trn, n_tst)

1570 3140


In [52]:
# Build the data loaders; both use the same batch size and pinned memory.
_shared_loader_kwargs = {"batch_size": BATCH_SIZE, "pin_memory": True}

# Training loader: shuffled, and the final partial batch is kept.
trn_loader = DataLoader(
    trn_dataset, shuffle=True, num_workers=num_workers, drop_last=False, **_shared_loader_kwargs
)
# Test loader: not shuffled, so predictions come out in dataset order.
tst_loader = DataLoader(
    tst_dataset, shuffle=False, num_workers=0, **_shared_loader_kwargs
)

## 5. Train Model
* 모델을 로드하고, 학습을 진행합니다.

In [53]:
# Create the pretrained model with 17 output classes and move it to the device.
model = timm.create_model(model_name, pretrained=True, num_classes=17)
model = model.to(device)

# Optimizer over all model parameters, and the classification loss.
optimizer = Adam(model.parameters(), lr=LR)
loss_fn = nn.CrossEntropyLoss()

In [None]:
# Train for EPOCHS epochs and print the metrics after each one.
for epoch in range(EPOCHS):
    ret = train_one_epoch(trn_loader, model, optimizer, loss_fn, device=device)
    ret['epoch'] = epoch

    # Floats are shown with four decimals; integers (the epoch index) are
    # printed as-is so the epoch reads "3" instead of "3.0000".
    log = "".join(
        f"{k}: {v}\n" if isinstance(v, int) else f"{k}: {v:.4f}\n"
        for k, v in ret.items()
    )
    print(log)

Loss: 1.3532: 100%|██████████| 99/99 [00:31<00:00,  3.16it/s]


train_loss: 1.6414
train_acc: 0.5178
train_f1: 0.4821
epoch: 0.0000



Loss: 1.3008: 100%|██████████| 99/99 [00:29<00:00,  3.36it/s]


train_loss: 0.6130
train_acc: 0.8070
train_f1: 0.7854
epoch: 1.0000



Loss: 0.7521: 100%|██████████| 99/99 [00:29<00:00,  3.37it/s]


train_loss: 0.3661
train_acc: 0.8771
train_f1: 0.8603
epoch: 2.0000



Loss: 1.0236: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.2921
train_acc: 0.9045
train_f1: 0.8965
epoch: 3.0000



Loss: 1.1784: 100%|██████████| 99/99 [00:29<00:00,  3.36it/s]


train_loss: 0.2220
train_acc: 0.9280
train_f1: 0.9206
epoch: 4.0000



Loss: 1.7269: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.2053
train_acc: 0.9369
train_f1: 0.9325
epoch: 5.0000



Loss: 3.1618: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.1803
train_acc: 0.9497
train_f1: 0.9449
epoch: 6.0000



Loss: 0.2248: 100%|██████████| 99/99 [00:29<00:00,  3.33it/s]


train_loss: 0.1414
train_acc: 0.9529
train_f1: 0.9479
epoch: 7.0000



Loss: 0.3038: 100%|██████████| 99/99 [00:29<00:00,  3.36it/s]


train_loss: 0.1038
train_acc: 0.9688
train_f1: 0.9654
epoch: 8.0000



Loss: 1.4279: 100%|██████████| 99/99 [00:29<00:00,  3.36it/s]


train_loss: 0.1140
train_acc: 0.9637
train_f1: 0.9618
epoch: 9.0000



Loss: 0.0899: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.0740
train_acc: 0.9752
train_f1: 0.9735
epoch: 10.0000



Loss: 1.7015: 100%|██████████| 99/99 [00:29<00:00,  3.32it/s]


train_loss: 0.0845
train_acc: 0.9783
train_f1: 0.9772
epoch: 11.0000



Loss: 0.3139: 100%|██████████| 99/99 [00:29<00:00,  3.37it/s]


train_loss: 0.0591
train_acc: 0.9803
train_f1: 0.9773
epoch: 12.0000



Loss: 1.7098: 100%|██████████| 99/99 [00:29<00:00,  3.33it/s]


train_loss: 0.0858
train_acc: 0.9739
train_f1: 0.9720
epoch: 13.0000



Loss: 0.1860: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.0488
train_acc: 0.9822
train_f1: 0.9812
epoch: 14.0000



Loss: 3.6091: 100%|██████████| 99/99 [00:29<00:00,  3.34it/s]


train_loss: 0.0669
train_acc: 0.9904
train_f1: 0.9906
epoch: 15.0000



Loss: 0.0760: 100%|██████████| 99/99 [00:29<00:00,  3.33it/s]


train_loss: 0.0326
train_acc: 0.9911
train_f1: 0.9906
epoch: 16.0000



Loss: 0.3464: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.0441
train_acc: 0.9917
train_f1: 0.9909
epoch: 17.0000



Loss: 0.1360: 100%|██████████| 99/99 [00:29<00:00,  3.38it/s]


train_loss: 0.0350
train_acc: 0.9911
train_f1: 0.9906
epoch: 18.0000



Loss: 0.2796: 100%|██████████| 99/99 [00:29<00:00,  3.36it/s]


train_loss: 0.0244
train_acc: 0.9943
train_f1: 0.9938
epoch: 19.0000



Loss: 1.2123: 100%|██████████| 99/99 [00:29<00:00,  3.33it/s]


train_loss: 0.0427
train_acc: 0.9892
train_f1: 0.9878
epoch: 20.0000



Loss: 1.0203: 100%|██████████| 99/99 [00:29<00:00,  3.35it/s]


train_loss: 0.0433
train_acc: 0.9892
train_f1: 0.9890
epoch: 21.0000



Loss: 0.0336: 100%|██████████| 99/99 [00:29<00:00,  3.36it/s]


train_loss: 0.0360
train_acc: 0.9892
train_f1: 0.9882
epoch: 22.0000



Loss: 0.1764: 100%|██████████| 99/99 [00:30<00:00,  3.29it/s]


train_loss: 0.0345
train_acc: 0.9885
train_f1: 0.9869
epoch: 23.0000



Loss: 0.2488: 100%|██████████| 99/99 [00:29<00:00,  3.34it/s]


train_loss: 0.0330
train_acc: 0.9866
train_f1: 0.9859
epoch: 24.0000



Loss: 0.0501: 100%|██████████| 99/99 [00:29<00:00,  3.34it/s]


train_loss: 0.0164
train_acc: 0.9955
train_f1: 0.9955
epoch: 25.0000



Loss: 1.3075: 100%|██████████| 99/99 [00:29<00:00,  3.37it/s]


train_loss: 0.0358
train_acc: 0.9917
train_f1: 0.9917
epoch: 26.0000



Loss: 0.0123:  32%|███▏      | 32/99 [00:09<00:20,  3.26it/s]
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.sen

KeyboardInterrupt: 

socket.send() raised exception.
socket.send() raised exception.


socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.send() raised exception.
socket.s

## 6. Inference & Save File
* 테스트 이미지에 대한 추론을 진행하고, 결과 파일을 저장합니다.

In [61]:
# Predicted class index for every test image, in loader order.
preds_list = []

model.eval()
# Gradients are not needed for inference, so disable them for the whole loop.
with torch.no_grad():
    for batch, _ in tqdm(tst_loader):
        logits = model(batch.to(device))
        preds_list.extend(logits.argmax(dim=1).cpu().numpy())

100%|██████████| 197/197 [00:14<00:00, 13.35it/s]


In [62]:
# Rebuild the (ID, target) table from the test dataset rows and replace the
# 'target' column with the model predictions.
_test_rows = pd.DataFrame(tst_dataset.df, columns=['ID', 'target'])
pred_df = _test_rows.assign(target=preds_list)

In [63]:
# Load the sample submission to compare it with the predictions.
sample_submission_df = pd.read_csv("data/sample_submission.csv")
# Sanity check: every ID must match the sample submission row for row.
assert (sample_submission_df['ID'] == pred_df['ID']).all()

In [64]:
# Write the predictions to pred.csv without the index column.
pred_df.to_csv("pred.csv", index=False)

In [68]:
# Preview the first rows of the prediction table.
pred_df.head()

Unnamed: 0,ID,target
0,0008fdb22ddce0ce.jpg,2
1,00091bffdffd83de.jpg,6
2,00396fbc1f6cc21d.jpg,8
3,00471f8038d9c4b6.jpg,13
4,00901f504008d884.jpg,2
