In [2]:
import os
import time

import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn

from albumentations.pytorch import ToTensorV2
from torch.optim import Adam
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import CosineAnnealingLR

from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

import matplotlib.pyplot as plt
import seaborn as sns

import wandb

In [10]:
# device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# PATH
TRAIN_AUG_CSV_PATH = '/upstage-cv-classification-cv2/data/train_aug_gaussian.csv'
TRAIN_AUG_IMAGE_PATH = '/upstage-cv-classification-cv2/data/train_aug_gaussian'

VALID_CSV_PATH = '/upstage-cv-classification-cv2/data/valid.csv'
VALID_IMAGE_PATH = '/upstage-cv-classification-cv2/data/valid'

TEST_CSV_PATH = '/upstage-cv-classification-cv2/data/sample_submission.csv'
TEST_IMAGE_PATH = '/upstage-cv-classification-cv2/data/test'

RESULT_CSV_PATH = '/upstage-cv-classification-cv2'

WANDB_PROJECT_NAME = 'cv_competition_batchval'

# HyperParameter

In [11]:
# training config
img_size = 380
LR = 1e-3
EPOCHS = 100
BATCH_SIZE = 32
num_workers = 0

patience = 5
min_delta = 0.001 # 성능 개선의 최소 변화량

# 1. DATA LOAD

In [12]:
# test image 변환
data_transform = A.Compose([
    A.Resize(height = img_size, width = img_size),
    A.Normalize(mean=[0.485, 0.456, 0.406], std = [0.229, 0.224, 0.225]),
    ToTensorV2()
])

class ImageDataset(Dataset):
    def __init__(self, csv, path, transform=None):
        self.df = pd.read_csv(csv).values
        self.path = path
        self.transform = transform

    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        name, target = self.df[idx]
        img = np.array(Image.open(os.path.join(self.path, name)))
        if self.transform:
            img = self.transform(image = img)['image']
    
        return img, target

    def get_labels(self):
        return self.df[:, 1] 

trn_dataset = ImageDataset(
    TRAIN_AUG_CSV_PATH,
    TRAIN_AUG_IMAGE_PATH,
    transform = data_transform
)

val_dataset = ImageDataset(
    VALID_CSV_PATH,
    VALID_IMAGE_PATH,
    transform = data_transform
)

tst_dataset = ImageDataset(
    TEST_CSV_PATH,
    TEST_IMAGE_PATH,
    transform = data_transform
)

labels = trn_dataset.get_labels()
labels = labels.astype(int)

# DataLoader
trn_loader = DataLoader(
    trn_dataset,
    batch_size = BATCH_SIZE,
    shuffle = True,
    num_workers = num_workers,
    pin_memory = True,
    drop_last = False
)

val_loader = DataLoader(
    val_dataset,
    batch_size = BATCH_SIZE,
    num_workers = 0,
    pin_memory = True,
    drop_last = False
)

tst_loader = DataLoader(
    tst_dataset,
    batch_size = BATCH_SIZE,
    shuffle = False,
    num_workers = 0,
    pin_memory = True
)

print(len(trn_dataset), len(tst_dataset))

37680 3140


# 2. Model Train

In [19]:
# model
model = timm.create_model('efficientnet_b4',
                        pretrained=True,
                        num_classes = 17).to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr = LR)


INFO:timm.models._builder:Loading pretrained weights from Hugging Face hub (timm/efficientnet_b4.ra2_in1k)
INFO:timm.models._hub:[timm/efficientnet_b4.ra2_in1k] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.
INFO:timm.models._builder:Missing keys (classifier.weight, classifier.bias) discovered while loading pretrained weights. This is expected if model is being adapted.


In [20]:
def valid_one_epoch(loader, model, loss_fn, device, epoch):
    model.eval()
    valid_loss = 0

    preds_list =[]
    targets_list = []

    with torch.no_grad():
        pbar = tqdm(loader)
        for step, (image, targets) in enumerate(pbar):
            image = image.to(device)
            targets = targets.to(device)

            preds = model(image)
            loss = loss_fn(preds, targets)
       
            valid_loss += loss.item()
        
            preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
            targets_list.extend(targets.detach().cpu().numpy())

            pbar.set_description(f"Loss: {loss.item():.4f}")

            wandb.log({
                "valid_step" : epoch * len(loader) + step,
                "valid_loss_step" : loss.item()
            })

    valid_loss /= len(loader)
    valid_acc = accuracy_score(targets_list, preds_list)
    valid_f1 = f1_score(targets_list, preds_list, average = 'macro')

    ret = {
        "epoch" : epoch,
        "valid_loss" : valid_loss,
        "valid_acc" : valid_acc,
        "valid_f1" : valid_f1
    }

    wandb.log({
        "valid_epoch" : epoch,
        "val_loss_epoch" : valid_loss,
        "val_acc" : valid_acc,
        "val_f1" : valid_f1
    })

    return ret

In [21]:
# one epoch 학습
def train_one_epoch(train_loader, valid_loader, model, optimizer, loss_fn, device, epoch):
    global patience_counter, best_valid_loss, f1_scores, valid_losses, trained_models
    model.train()

    train_loss = 0
    preds_list =[]
    targets_list = []

    is_earlystop = False

    pbar = tqdm(train_loader)
    for step, (image, targets) in enumerate(pbar):
        image = image.to(device)
        targets = targets.to(device)

        model.zero_grad(set_to_none = True)

        preds = model(image)
        loss = loss_fn(preds, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        
        preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())
        targets_list.extend(targets.detach().cpu().numpy())

        pbar.set_description(f"Loss: {loss.item():.4f}")

        wandb.log({
            "train_step" : epoch * len(train_loader) + step,
            "train_loss_step" : loss.item()
        })

        # 100 step 마다 validation 하기
        if (step + 1) % 100 == 0:
            print(f"-------------- step : { epoch * len(train_loader) + step} --------------")
            val_ret =  valid_one_epoch(val_loader, model, loss_fn, device, epoch)

            f1_scores.append(val_ret['valid_f1'])
            valid_losses.append(val_ret['valid_loss'])
            trained_models.append(model)

            print(f"valid loss : {val_ret['valid_loss']}")
            print(f"valid f1 : {val_ret['valid_f1']}")

            # 성능 개선 됨
            if val_ret['valid_loss'] < best_valid_loss - min_delta:
                best_valid_loss = val_ret['valid_loss']
                patience_counter = 0  
                
            # 성능 개선 되지 않음
            else:
                patience_counter += 1  

            # 성능 개선이 patience 만큼 안되면 학습 중단
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                is_earlystop = True
                break
        
    train_loss /= len(train_loader)
    train_acc = accuracy_score(targets_list, preds_list)
    train_f1 = f1_score(targets_list, preds_list, average = 'macro')

    ret = {
        "isEarlyStop" : is_earlystop,
        "model" : model,
        "train_epoch" : epoch,
        "train_loss" : train_loss,
        "tarin_acc" : train_acc,
        "train_f1" : train_f1
    }

    wandb.log({
        "train_epoch" : epoch,
        "train_loss_epoch" : train_loss,
        "train_acc" : train_acc,
        "train_f1" : train_f1
    })

    return ret

In [22]:
os.environ['WANDB_SILENT'] = 'true'

f1_scores = []
valid_losses = []
trained_models = []

patient_counter = 0
best_valid_loss = float('inf')

wandb.init(project=WANDB_PROJECT_NAME, name="gaussian")

for epoch in range(EPOCHS):
    print(f"{epoch} epoch")

    trn_ret = train_one_epoch(trn_loader, val_loader, model, optimizer, loss_fn, device, epoch)
    is_earlystop = trn_ret['isEarlyStop']

    if is_earlystop:
        break
    

best_model_idx = np.argmin(np.array(valid_losses))
best_model = trained_models[best_model_idx]

wandb.finish()

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112243102656471, max=1.0…

0 epoch


Loss: 0.2951:   8%|▊         | 99/1178 [00:39<07:01,  2.56it/s]

-------------- step : 99 --------------


Loss: 0.5840: 100%|██████████| 10/10 [00:01<00:00,  6.44it/s]
Loss: 0.2951:   8%|▊         | 100/1178 [00:40<15:24,  1.17it/s]

valid loss : 0.5221376433968544
valid f1 : 0.8052011986217086


Loss: 0.3149:  17%|█▋        | 199/1178 [01:18<06:11,  2.64it/s]

-------------- step : 199 --------------


Loss: 0.3598: 100%|██████████| 10/10 [00:01<00:00,  6.50it/s]
Loss: 0.3149:  17%|█▋        | 200/1178 [01:20<13:43,  1.19it/s]

valid loss : 0.5180131688714027
valid f1 : 0.8282333615209465


Loss: 0.2019:  25%|██▌       | 299/1178 [01:57<05:30,  2.66it/s]

-------------- step : 299 --------------


Loss: 1.2534: 100%|██████████| 10/10 [00:01<00:00,  6.33it/s]
Loss: 0.2019:  25%|██▌       | 300/1178 [01:59<12:30,  1.17it/s]

valid loss : 0.6410200595855713
valid f1 : 0.8227452789452518


Loss: 0.3793:  34%|███▍      | 399/1178 [02:37<04:53,  2.65it/s]

-------------- step : 399 --------------


Loss: 0.4563: 100%|██████████| 10/10 [00:01<00:00,  6.53it/s]
Loss: 0.3793:  34%|███▍      | 400/1178 [02:38<10:52,  1.19it/s]

valid loss : 0.518449530005455
valid f1 : 0.8292597277506627


Loss: 0.0342:  42%|████▏     | 499/1178 [03:16<04:17,  2.64it/s]

-------------- step : 499 --------------


Loss: 0.9336: 100%|██████████| 10/10 [00:01<00:00,  6.44it/s]
Loss: 0.0342:  42%|████▏     | 500/1178 [03:17<09:33,  1.18it/s]

valid loss : 0.47978816032409666
valid f1 : 0.8478118753098304


Loss: 0.0205:  51%|█████     | 599/1178 [03:55<03:38,  2.66it/s]

-------------- step : 599 --------------


Loss: 0.3530: 100%|██████████| 10/10 [00:01<00:00,  6.60it/s]
Loss: 0.0205:  51%|█████     | 600/1178 [03:57<08:02,  1.20it/s]

valid loss : 0.5024144098162651
valid f1 : 0.8823394159177549


Loss: 0.3234:  59%|█████▉    | 699/1178 [04:34<03:02,  2.62it/s]

-------------- step : 699 --------------


Loss: 0.4726: 100%|██████████| 10/10 [00:01<00:00,  6.62it/s]
Loss: 0.3234:  59%|█████▉    | 700/1178 [04:36<06:39,  1.20it/s]

valid loss : 0.5111183792352676
valid f1 : 0.8612082297865067


Loss: 0.2608:  68%|██████▊   | 799/1178 [05:14<02:22,  2.67it/s]

-------------- step : 799 --------------


Loss: 0.0783: 100%|██████████| 10/10 [00:01<00:00,  6.43it/s]
Loss: 0.2608:  68%|██████▊   | 800/1178 [05:15<05:19,  1.18it/s]

valid loss : 0.37972619459033014
valid f1 : 0.8776853531244752


Loss: 0.0605:  76%|███████▋  | 899/1178 [05:53<01:46,  2.63it/s]

-------------- step : 899 --------------


Loss: 0.3123: 100%|██████████| 10/10 [00:01<00:00,  6.25it/s]
Loss: 0.0605:  76%|███████▋  | 900/1178 [05:55<03:59,  1.16it/s]

valid loss : 0.6480358809232711
valid f1 : 0.821977716659528


Loss: 0.1843:  85%|████████▍ | 999/1178 [06:32<01:07,  2.65it/s]

-------------- step : 999 --------------


Loss: 0.0286: 100%|██████████| 10/10 [00:01<00:00,  6.22it/s]
Loss: 0.1843:  85%|████████▍ | 1000/1178 [06:34<02:33,  1.16it/s]

valid loss : 0.48168935514986516
valid f1 : 0.8179085487027761


Loss: 0.0677:  93%|█████████▎| 1099/1178 [07:12<00:29,  2.67it/s]

-------------- step : 1099 --------------


Loss: 0.1016: 100%|██████████| 10/10 [00:01<00:00,  6.35it/s]
Loss: 0.0677:  93%|█████████▎| 1100/1178 [07:13<01:06,  1.18it/s]

valid loss : 0.6673544995486737
valid f1 : 0.8331898957873983


Loss: 0.2374: 100%|██████████| 1178/1178 [07:43<00:00,  2.54it/s]


1 epoch


Loss: 0.1732:   8%|▊         | 99/1178 [00:39<07:00,  2.56it/s]

-------------- step : 1277 --------------


Loss: 0.4323: 100%|██████████| 10/10 [00:01<00:00,  6.57it/s]
Loss: 0.1732:   8%|▊         | 100/1178 [00:40<15:16,  1.18it/s]

valid loss : 0.4134209968149662
valid f1 : 0.9063142432019478


Loss: 0.0894:  17%|█▋        | 199/1178 [01:18<06:07,  2.67it/s]

-------------- step : 1377 --------------


Loss: 0.0446: 100%|██████████| 10/10 [00:01<00:00,  6.60it/s]
Loss: 0.0894:  17%|█▋        | 199/1178 [01:19<06:32,  2.49it/s]


valid loss : 0.38251737877726555
valid f1 : 0.9025657436670329
Early stopping at epoch 1


# TEST

In [24]:
preds_list = []

best_model.eval()

for image, _ in tqdm(tst_loader):
    image = image.to(device)

    with torch.no_grad():
        preds = best_model(image)
        
    preds_list.extend(preds.argmax(dim=1).detach().cpu().numpy())

pred_df = pd.DataFrame(tst_dataset.df, columns=['ID', 'target'])
pred_df['target'] = preds_list
pred_df.to_csv(f"{RESULT_CSV_PATH}/base_batch_valid.csv", index=False)

100%|██████████| 99/99 [00:15<00:00,  6.27it/s]
