In [17]:
import os
import time

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn

from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import CosineAnnealingLR

from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

import matplotlib.pyplot as plt
import seaborn as sns

import wandb

In [27]:
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PATH
TRAIN_AUG_CSV_PATH = '/upstage-cv-classification-cv2/data/train_aug.csv'
TRAIN_AUG_IMAGE_PATH = '/upstage-cv-classification-cv2/data/train_aug'

VALID_CSV_PATH = '/upstage-cv-classification-cv2/data/valid.csv'
VALID_IMAGE_PATH = '/upstage-cv-classification-cv2/data/valid'

TEST_CSV_PATH = '/upstage-cv-classification-cv2/data/sample_submission.csv'
TEST_IMAGE_PATH = '/upstage-cv-classification-cv2/data/test'

RESULT_CSV_PATH = '/upstage-cv-classification-cv2'

WANDB_PROJECT_NAME = 'cv_competition_base'

# HyperParameter

In [28]:
# training config
img_size = 380
LR = 1e-3
EPOCHS = 100
BATCH_SIZE = 32
num_workers = 0

patience = 5
min_delta = 0.001 # 성능 개선의 최소 변화량

# 1. DATA LOAD

In [29]:
# test image 변환
data_transform = A.Compose([
    A.Resize(height = img_size, width = img_size),
    A.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
    ToTensorV2()
])

class ImageDataset(Dataset):
    def __init__(self, csv, path, transform=None):
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        name, target = self.df[idx]
        img = np.array(Image.open(os.path.join(self.path, name)))
        if self.transform:
            img = self.transform(image = img)['image']
    
        return img, target

    def get_labels(self):
        return self.df[:, 1] 

trn_dataset = ImageDataset(
    TRAIN_AUG_CSV_PATH,
    TRAIN_AUG_IMAGE_PATH,
    transform = data_transform
)

val_dataset = ImageDataset(
    VALID_CSV_PATH,
    VALID_IMAGE_PATH,
    transform = data_transform
)

tst_dataset = ImageDataset(
    TEST_CSV_PATH,
    TEST_IMAGE_PATH,
    transform = data_transform
)

labels = trn_dataset.get_labels()
labels = labels.astype(int)

# DataLoader
trn_loader = DataLoader(
    trn_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = False
)

val_loader = DataLoader(
    val_dataset,
    batch_size = BATCH_SIZE,
    num_workers = 0,
    pin_memory = True,
    drop_last = False
)

tst_loader = DataLoader(
    tst_dataset,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = 0,
    pin_memory = True
)

print(len(trn_dataset), len(tst_dataset))

26690 3140


# 2. Model Train

In [22]:
# model
model = timm.create_model('efficientnet_b4',
                        pretrained=True,
                        num_classes = 17).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr = LR)


INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b4.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b4.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


In [23]:
# one epoch 학습
def train_one_epoch(loader, model, optimizer, loss_fn, device, epoch):
    model.train()
    train_loss = 0
    preds_list =[]
    targets_list = []

    pbar = tqdm(loader)
    for step, (image, targets) in enumerate(pbar):
        image = image.to(device)
        targets = targets.to(device)

        model.zero_grad(set_to_none = True)

        preds = model(image)
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
        preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
        targets_list.extend(targets.detach().cpu().numpy())

        pbar.set_description(f"Loss: {loss.item():.4f}")

        wandb.log({
            "train_step" : epoch * len(loader) + step,
            "train_loss_step" : loss.item()
        })
        
    train_loss /= len(loader)
    train_acc = accuracy_score(targets_list, preds_list)
    train_f1 = f1_score(targets_list, preds_list, average = 'macro')

    ret = {
        "model" : model,
        "train_epoch" : epoch,
        "train_loss" : train_loss,
        "tarin_acc" : train_acc,
        "train_f1" : train_f1
    }

    wandb.log({
        "train_epoch" : epoch,
        "train_loss_epoch" : train_loss,
        "train_acc" : train_acc,
        "train_f1" : train_f1
    })

    return ret

In [24]:
def valid_one_epoch(loader, model, loss_fn, device, epoch):
    model.eval()
    valid_loss = 0

    preds_list =[]
    targets_list = []

    with torch.no_grad():
        pbar = tqdm(loader)
        for step, (image, targets) in enumerate(pbar):
            image = image.to(device)
            targets = targets.to(device)

            preds = model(image)
            loss = loss_fn(preds, targets)
       
            valid_loss += loss.item()
        
            preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
            targets_list.extend(targets.detach().cpu().numpy())

            pbar.set_description(f"Loss: {loss.item():.4f}")

            wandb.log({
                "valid_step" : epoch * len(loader) + step,
                "valid_loss_step" : loss.item()
            })

    valid_loss /= len(loader)
    valid_acc = accuracy_score(targets_list, preds_list)
    valid_f1 = f1_score(targets_list, preds_list, average = 'macro')

    ret = {
        "epoch" : epoch,
        "valid_loss" : valid_loss,
        "valid_acc" : valid_acc,
        "valid_f1" : valid_f1
    }

    wandb.log({
        "valid_epoch" : epoch,
        "val_loss_epoch" : valid_loss,
        "val_acc" : valid_acc,
        "val_f1" : valid_f1
    })

    return ret

In [25]:
os.environ['WANDB_SILENT'] = 'true'

f1_scores = []
valid_losses = []
trained_models = []
patient_counter = 0
best_valid_loss = float('inf')


wandb.init(project=WANDB_PROJECT_NAME, name="efficientenet_b4_base")

for epoch in range(EPOCHS):
    print(f"{epoch} epoch")
    trn_ret = train_one_epoch(trn_loader, model, optimizer, loss_fn, device, epoch)
    val_ret =  valid_one_epoch(val_loader, model, loss_fn, device, epoch)

    f1_scores.append(val_ret['valid_f1'])
    valid_losses.append(val_ret['valid_loss'])
    trained_models.append(trn_ret['model'])

    print(f"valid loss : {val_ret['valid_loss']}")
    print(f"valid f1 : {val_ret['valid_f1']}")

    # 성능 개선 됨
    if val_ret['valid_loss'] < best_valid_loss - min_delta:
        best_valid_loss = val_ret['valid_loss']
        patience_counter = 0  
        
    # 성능 개선 되지 않음
    else:
        patience_counter += 1  

    # 성능 개선이 patience 만큼 안되면 학습 중단
    if patience_counter >= patience:
        print(f"Early stopping at epoch {epoch}")
        break

best_model_idx = np.argmin(np.array(valid_losses))
best_model = trained_models[best_model_idx]

wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112665178047286, max=1.0…

0 epoch


Loss: 0.1951: 100%|██████████| 835/835 [02:25<00:00,  5.75it/s]
Loss: 0.8145: 100%|██████████| 10/10 [00:00<00:00, 11.73it/s]


valid loss : 0.9209675282239914
valid f1 : 0.7973231943020411
1 epoch


Loss: 3.9228: 100%|██████████| 835/835 [02:25<00:00,  5.73it/s]
Loss: 0.0254: 100%|██████████| 10/10 [00:00<00:00, 11.80it/s]


valid loss : 0.5328848786652088
valid f1 : 0.8515226187326626
2 epoch


Loss: 0.1151: 100%|██████████| 835/835 [02:25<00:00,  5.74it/s]
Loss: 0.4211: 100%|██████████| 10/10 [00:00<00:00, 12.06it/s]


valid loss : 0.574249729514122
valid f1 : 0.8638497228897944
3 epoch


Loss: 1.0336: 100%|██████████| 835/835 [02:26<00:00,  5.69it/s]
Loss: 1.4203: 100%|██████████| 10/10 [00:00<00:00, 11.46it/s]


valid loss : 5.000928801298142
valid f1 : 0.8413033567983912
4 epoch


Loss: 0.0041: 100%|██████████| 835/835 [02:25<00:00,  5.74it/s]
Loss: 0.4983: 100%|██████████| 10/10 [00:00<00:00, 11.87it/s]


valid loss : 0.7800617843866349
valid f1 : 0.8546008669597089
5 epoch


Loss: 0.0206: 100%|██████████| 835/835 [02:26<00:00,  5.69it/s]
Loss: 1.8408: 100%|██████████| 10/10 [00:00<00:00, 11.40it/s]


valid loss : 1.0815544307231904
valid f1 : 0.8488437379635189
6 epoch


Loss: 2.4601: 100%|██████████| 835/835 [02:28<00:00,  5.62it/s]
Loss: 1.4670: 100%|██████████| 10/10 [00:00<00:00, 12.03it/s]


valid loss : 1.1022712230682372
valid f1 : 0.8122842810393381
Early stopping at epoch 6


# TEST

In [30]:
preds_list = []

best_model.eval()

for image, _ in tqdm(tst_loader):
    image = image.to(device)

    with torch.no_grad():
        preds = best_model(image)
        
    preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())

pred_df = pd.DataFrame(tst_dataset.df, columns=['ID', 'target'])
pred_df['target'] = preds_list
pred_df.to_csv(f"{RESULT_CSV_PATH}/base.csv", index=False)

100%|██████████| 99/99 [00:08<00:00, 11.10it/s]


# 결과 분석

In [31]:
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.metrics import confusion_matrix

# 예측 결과 생성
valid_preds_list = []

best_model.eval()

for image, _ in tqdm(val_loader):
    image = image.to(device)

    with torch.no_grad():
        preds = best_model(image)
        
    valid_preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())

# 실제 레이블과 예측 레이블 준비
true_labels = val_dataset.get_labels()
pred_labels = np.array(valid_preds_list)

# Confusion Matrix 생성
cm = confusion_matrix(true_labels, pred_labels)

# 클래스의 최대값 확인 (히트맵의 크기를 결정하기 위해)
n_classes = max(cm.shape)

# 히트맵 생성
plt.figure(figsize=(12, 10))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=range(n_classes), 
            yticklabels=range(n_classes))

plt.title('Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('True')

# x축과 y축의 눈금을 1단위로 설정
plt.xticks(np.arange(0, n_classes, 1))
plt.yticks(np.arange(0, n_classes, 1))

plt.tight_layout()
plt.show()


100%|██████████| 10/10 [00:00<00:00, 11.86it/s]


ValueError: Classification metrics can't handle a mix of unknown and multiclass targets