<a href="https://colab.research.google.com/github/2025-02-FML-team/WV-Team/blob/train-pipeline/model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
"""
k 브랜치용 학습 파이프라인 (PyTorch)
- 목적: "k" 브랜치에 들어갈 재현 가능한 이미지 분류 학습 코드
- 입력: uploads 폴더(또는 CSV에서 지정한 경로). 예: photos/..., whiskies_recategorized.csv
- 출력: 체크포인트(.pt), 학습 로그, confusion matrix (numpy)

사용법 예시:
$ pip install -r requirements.txt
$ python k_branch_training.py --csv /mnt/data/whiskies_recategorized.csv --img-root /mnt/data/photos --epochs 30

작성자: (자동생성)
"""

import argparse
import os
import random
import math
from collections import Counter

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
import torchvision.transforms as T
import torchvision.models as models

# -------------- Dataset --------------
class WhiskyImageDataset(Dataset):
    """Image-classification dataset backed by a DataFrame of image paths and labels.

    Each row provides an image path (column ``img_col``; relative paths are
    joined onto ``img_root``) and a label (column ``label_col``). Items are
    ``(image, label_index)`` where ``image`` is the RGB PIL image, passed
    through ``transform`` when one is given.

    Args:
        df: DataFrame with at least ``img_col`` and ``label_col``. Its index is
            reset, so positional ``idx`` values map to rows in order.
        img_root: Directory that relative image paths are resolved against.
        transform: Optional callable applied to the PIL image.
        img_col: Name of the image-path column.
        label_col: Name of the label column.
        classes: Optional explicit, ordered class list defining the label
            indices. When omitted, the sorted unique labels of ``df`` are used;
            those differ between splits whenever a split lacks some class, so
            pass one shared list to keep indices consistent across splits.
            Labels missing from ``classes`` raise KeyError in ``__getitem__``.
    """

    def __init__(self, df, img_root, transform=None, img_col='local_full_path', label_col='category', classes=None):
        self.df = df.reset_index(drop=True)
        self.img_root = img_root
        self.transform = transform
        self.img_col = img_col
        self.label_col = label_col
        if classes is None:
            classes = sorted(self.df[self.label_col].unique())
        self.classes = list(classes)
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.loc[idx]
        rel_path = row[self.img_col]
        img_path = rel_path if os.path.isabs(rel_path) else os.path.join(self.img_root, rel_path)
        # Use a context manager so the file handle is released even if decoding
        # fails; convert() returns a fully loaded copy before the file closes.
        with Image.open(img_path) as src:
            img = src.convert('RGB')
        if self.transform:
            img = self.transform(img)
        label = self.class_to_idx[row[self.label_col]]
        return img, label

# -------------- Utilities --------------

def set_seed(seed=42):
    """Seed every RNG the pipeline touches so runs are repeatable.

    Seeds, in order: Python's ``random``, NumPy's global RNG, PyTorch's CPU
    generator, and the generators of all CUDA devices.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)


def stratified_split(df, label_col='category', train_frac=0.64, val_frac=0.16, test_frac=0.2, seed=42):
    """Split ``df`` into train/val/test DataFrames, shuffling within each class.

    For every class with ``n`` rows, ``floor(train_frac * n)`` rows go to train,
    ``floor(val_frac * n)`` to val and the remainder to test. Very small
    classes can therefore end up with no train or val rows.

    Args:
        df: DataFrame to split. Rows are selected by index label, so the index
            is assumed to be unique (NOTE(review): confirm for callers that do
            not reset the index).
        label_col: Column whose values define the strata.
        train_frac, val_frac, test_frac: Split fractions; must sum to 1.
        seed: Seed for the per-class shuffles, making the split reproducible.

    Returns:
        ``(train_df, val_df, test_df)``, each re-indexed from 0. Rows appear
        grouped by class in the order produced by ``df.groupby(label_col)``.

    Raises:
        AssertionError: if the fractions do not sum to 1 (within 1e-6).
    """
    assert abs(train_frac + val_frac + test_frac - 1.0) < 1e-6
    # A private RandomState gives exactly the permutations that
    # np.random.seed(seed) + np.random.shuffle produced, without resetting the
    # global NumPy RNG as a hidden side effect.
    rng = np.random.RandomState(seed)
    train_idx, val_idx, test_idx = [], [], []
    for _, grp in df.groupby(label_col):
        # Copy: the array is shuffled in place and an index buffer may be read-only.
        idxs = grp.index.to_numpy(copy=True)
        rng.shuffle(idxs)
        n = len(idxs)
        n_train = int(np.floor(train_frac * n))
        n_val = int(np.floor(val_frac * n))
        train_idx.extend(idxs[:n_train].tolist())
        val_idx.extend(idxs[n_train:n_train + n_val].tolist())
        test_idx.extend(idxs[n_train + n_val:].tolist())
    return (
        df.loc[train_idx].reset_index(drop=True),
        df.loc[val_idx].reset_index(drop=True),
        df.loc[test_idx].reset_index(drop=True),
    )


# -------------- Model helpers --------------

def _load_backbone(builder, pretrained):
    """Call a torchvision model builder, loading ImageNet-1K V1 weights if ``pretrained``.

    ``pretrained=`` is deprecated since torchvision 0.13 in favour of
    ``weights=``. ``'IMAGENET1K_V1'`` is the checkpoint that ``pretrained=True``
    maps to, so the loaded weights are unchanged. Releases that predate
    ``weights=`` reject it with TypeError, so fall back to the legacy keyword.
    """
    try:
        return builder(weights='IMAGENET1K_V1' if pretrained else None)
    except TypeError:
        return builder(pretrained=pretrained)


def get_model(num_classes, model_name='resnet50', pretrained=True):
    """Build a torchvision classifier with a fresh ``num_classes``-way output layer.

    Args:
        num_classes: Number of output logits of the replaced final linear layer.
        model_name: One of ``'resnet50'``, ``'resnet18'``, ``'efficientnet_b0'``.
        pretrained: Load ImageNet-1K V1 backbone weights when True.

    Returns:
        The ``nn.Module`` with its head replaced (``fc`` for ResNets,
        ``classifier[1]`` for EfficientNet-B0).

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    if model_name not in ('resnet50', 'resnet18', 'efficientnet_b0'):
        raise ValueError('Unsupported model_name')
    # Resolve the builder lazily so a missing, unrelated architecture in an old
    # torchvision does not break the supported ones.
    m = _load_backbone(getattr(models, model_name), pretrained)
    if model_name == 'efficientnet_b0':
        in_f = m.classifier[1].in_features
        m.classifier[1] = nn.Linear(in_f, num_classes)
    else:
        m.fc = nn.Linear(m.fc.in_features, num_classes)
    return m


# -------------- Training loop --------------

def train_one_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Each batch is moved to ``device``, forwarded, and back-propagated, and then
    ``optimizer`` takes a step. Returns ``(mean_loss, accuracy)``, both averaged
    over samples: per-batch losses are weighted by batch size. Accuracy uses
    the predictions made before the step. Raises ZeroDivisionError if the
    loader yields no samples.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for batch_imgs, batch_labels in dataloader:
        batch_imgs, batch_labels = batch_imgs.to(device), batch_labels.to(device)
        optimizer.zero_grad()
        logits = model(batch_imgs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_size = batch_imgs.size(0)
        loss_sum += batch_loss.item() * batch_size
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_size
    return loss_sum / n_seen, n_correct / n_seen


def validate(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` with gradients disabled.

    Args:
        model: Classifier returning per-class logits for a batch.
        dataloader: Iterable of ``(inputs, integer_labels)`` batches.
        criterion: Loss applied per batch; per-batch losses are weighted by
            batch size before averaging.
        device: Device the batches are moved to.

    Returns:
        ``(avg_loss, accuracy, preds, labels)``. ``preds`` and ``labels`` are
        int64 NumPy arrays in loader order. When the loader yields no samples
        (e.g. an empty val/test split), ``avg_loss`` and ``accuracy`` are NaN
        and both arrays are empty, instead of raising ZeroDivisionError.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for imgs, labels in dataloader:
            imgs = imgs.to(device)
            labels = labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * imgs.size(0)
            _, preds = torch.max(outputs, 1)
            correct += (preds == labels).sum().item()
            total += imgs.size(0)
            all_preds.extend(preds.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())
    # Fixed dtype so callers can index class lists even when the arrays are empty.
    pred_arr = np.asarray(all_preds, dtype=np.int64)
    label_arr = np.asarray(all_labels, dtype=np.int64)
    if total == 0:
        return float('nan'), float('nan'), pred_arr, label_arr
    return running_loss / total, correct / total, pred_arr, label_arr


# -------------- Main --------------

def _build_transforms(img_size):
    """Return ``(train_transform, eval_transform)`` for square ``img_size`` inputs.

    Both resize to ``(img_size, img_size)`` and normalise with the standard
    ImageNet mean/std. Only the training pipeline adds random flips, rotation
    and colour jitter.
    """
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomRotation(20),
        T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.02),
        T.ToTensor(),
        normalize,
    ])
    eval_transform = T.Compose([
        T.Resize((img_size, img_size)),
        T.ToTensor(),
        normalize,
    ])
    return train_transform, eval_transform


def _build_criterion(train_df, label_col, classes, class_weighting, device):
    """Return cross-entropy loss, optionally weighted by inverse training-class frequency.

    Weights follow the order of ``classes`` and are rescaled to average 1.
    Classes with no training rows are counted as 1 so their weight stays finite.
    """
    if not class_weighting:
        return nn.CrossEntropyLoss()
    class_counts = (
        train_df[label_col].value_counts().reindex(classes).fillna(0).to_numpy(dtype=np.float64)
    )
    # Clamp to 1: an absent class would otherwise get a ~1e6 weight that
    # dominates the val/test loss.
    class_weights = torch.tensor(1.0 / np.maximum(class_counts, 1.0), dtype=torch.float32)
    class_weights = class_weights / class_weights.sum() * len(class_weights)
    return nn.CrossEntropyLoss(weight=class_weights.to(device))


def _save_test_outputs(test_df, preds, labels, classes, save_dir):
    """Write per-image test predictions (CSV) and a confusion matrix (.npy) to ``save_dir``.

    The confusion matrix is ``len(classes) x len(classes)`` integers. Rows are
    indexed by true label and columns by predicted label, both in ``classes``
    order.
    """
    out_df = test_df.copy()
    out_df['pred_idx'] = preds
    out_df['label_idx'] = labels
    out_df['pred'] = out_df['pred_idx'].map(lambda x: classes[x])
    out_df['label'] = out_df['label_idx'].map(lambda x: classes[x])
    out_csv = os.path.join(save_dir, 'test_predictions.csv')
    out_df.to_csv(out_csv, index=False)
    print('Saved predictions to', out_csv)

    num_classes = len(classes)
    confusion = np.zeros((num_classes, num_classes), dtype=np.int64)
    np.add.at(confusion, (np.asarray(labels, dtype=np.int64), np.asarray(preds, dtype=np.int64)), 1)
    cm_path = os.path.join(save_dir, 'confusion_matrix.npy')
    np.save(cm_path, confusion)
    print('Saved confusion matrix to', cm_path)


def main(args):
    """Train, validate and test an image classifier as configured by ``args``.

    Steps:
    1. Seed the RNGs and read the CSV.
    2. Optionally drop rare classes, then split per class into train/val/test.
    3. Train for ``args.epochs`` epochs. Whenever validation accuracy improves,
       checkpoint the weights and keep them.
    4. Evaluate the best weights on the test split.
    5. Write the test predictions CSV and a confusion matrix to ``args.save_dir``.

    Args:
        args: ``argparse.Namespace`` produced by this script's argument parser.
    """
    set_seed(args.seed)
    device = torch.device('cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
    # Create the output directory up front. Otherwise it only exists after a
    # validation improvement, and writing the predictions CSV can fail.
    os.makedirs(args.save_dir, exist_ok=True)

    # Load CSV
    df = pd.read_csv(args.csv)
    # Ensure the expected columns exist
    assert args.img_col in df.columns and args.label_col in df.columns, f"CSV must contain columns: {args.img_col}, {args.label_col}"

    # Optionally drop classes with fewer than --min-samples rows.
    if args.min_samples > 0:
        counts = df[args.label_col].value_counts()
        keep = counts[counts >= args.min_samples].index.tolist()
        df = df[df[args.label_col].isin(keep)].reset_index(drop=True)

    # Stratified split
    train_df, val_df, test_df = stratified_split(df, label_col=args.label_col,
                                                 train_frac=args.train_frac, val_frac=args.val_frac, test_frac=args.test_frac,
                                                 seed=args.seed)
    print(f"Split sizes: train={len(train_df)}, val={len(val_df)}, test={len(test_df)}")

    # Use one class list for every split. Each dataset would otherwise index
    # only the labels present in its own rows. Flooring in stratified_split can
    # leave a small class out of val or test, which would shift later label
    # indices and silently corrupt val/test metrics.
    classes = sorted(df[args.label_col].unique())
    class_to_idx = {c: i for i, c in enumerate(classes)}
    absent_from_train = sorted(set(classes) - set(train_df[args.label_col]))
    if absent_from_train:
        print('Warning: classes without training samples:', absent_from_train)
    num_classes = len(classes)
    print('Classes:', classes)

    # Datasets
    train_transform, eval_transform = _build_transforms(args.img_size)
    train_ds = WhiskyImageDataset(train_df, args.img_root, transform=train_transform,
                                  img_col=args.img_col, label_col=args.label_col)
    val_ds = WhiskyImageDataset(val_df, args.img_root, transform=eval_transform,
                                img_col=args.img_col, label_col=args.label_col)
    test_ds = WhiskyImageDataset(test_df, args.img_root, transform=eval_transform,
                                 img_col=args.img_col, label_col=args.label_col)
    for ds in (train_ds, val_ds, test_ds):
        ds.classes = classes
        ds.class_to_idx = class_to_idx

    # Weighted sampler to mitigate class imbalance
    if args.use_weighted_sampler:
        counts = train_df[args.label_col].value_counts().to_dict()
        weights = [1.0 / counts[c] for c in train_df[args.label_col]]
        sampler = WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
        shuffle = False
    else:
        sampler = None
        shuffle = True

    # Pinned host memory only helps (and only avoids a warning) with a CUDA device.
    pin_memory = device.type == 'cuda'
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, sampler=sampler, shuffle=shuffle, num_workers=4, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=pin_memory)

    # Model, loss, optimizer
    model = get_model(num_classes=num_classes, model_name=args.model, pretrained=not args.no_pretrained)
    model = model.to(device)
    criterion = _build_criterion(train_df, args.label_col, classes, args.class_weighting, device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    best_val_acc = 0.0
    best_state = None
    for epoch in range(1, args.epochs + 1):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Epoch {epoch:02d}/{args.epochs} | train_loss: {train_loss:.4f} acc: {train_acc:.4f} | val_loss: {val_loss:.4f} acc: {val_acc:.4f}")

        # save best
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            save_path = os.path.join(args.save_dir, f'best_model_epoch{epoch}_acc{val_acc:.4f}.pt')
            torch.save({'model_state_dict': model.state_dict(), 'epoch': epoch, 'val_acc': val_acc}, save_path)
            print('Saved', save_path)
            # Keep a CPU copy (parameters and buffers) so the test split is
            # scored with these weights rather than the last epoch's weights.
            best_state = {k: v.detach().to('cpu', copy=True) for k, v in model.state_dict().items()}

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f'Loaded best weights (val_acc={best_val_acc:.4f}) for testing')

    # Final evaluation on test set
    test_loss, test_acc, preds, labels = validate(model, test_loader, criterion, device)
    print(f"Test: loss={test_loss:.4f} acc={test_acc:.4f}")

    _save_test_outputs(test_df, preds, labels, classes, args.save_dir)


if __name__ == '__main__':
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument('--csv', type=str, default='/mnt/data/whiskies_recategorized.csv', help='CSV file with local_full_path and category columns')
    parser.add_argument('--img-root', type=str, default='/mnt/data/photos', help='Root folder that images are relative to')
    parser.add_argument('--img-col', type=str, default='local_full_path')
    parser.add_argument('--label-col', type=str, default='category')
    parser.add_argument('--img-size', type=int, default=256)
    parser.add_argument('--batch-size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=20)
    parser.add_argument('--lr', type=float, default=3e-4)
    parser.add_argument('--weight-decay', type=float, default=1e-4)
    parser.add_argument('--model', type=str, default='resnet50', choices=['resnet18','resnet50','efficientnet_b0'])
    parser.add_argument('--no-pretrained', action='store_true')
    parser.add_argument('--use-weighted-sampler', action='store_true')
    parser.add_argument('--class-weighting', action='store_true')
    parser.add_argument('--min-samples', type=int, default=0, help='Remove classes with fewer than this many samples')
    parser.add_argument('--train-frac', type=float, default=0.64)
    parser.add_argument('--val-frac', type=float, default=0.16)
    parser.add_argument('--test-frac', type=float, default=0.20)
    parser.add_argument('--save-dir', type=str, default='./k_branch_outputs')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--no-cuda', action='store_true')
    # Jupyter/Colab kernels launch with extra argv such as `-f <kernel.json>`,
    # which a strict parse_args() rejects with SystemExit: 2. Tolerate unknown
    # arguments only inside an IPython kernel; from the command line, keep the
    # original strict "unrecognized arguments" error.
    args, unknown = parser.parse_known_args()
    if unknown and 'ipykernel' not in sys.modules:
        parser.error('unrecognized arguments: ' + ' '.join(unknown))
    main(args)


usage: colab_kernel_launcher.py [-h] [--csv CSV] [--img-root IMG_ROOT]
                                [--img-col IMG_COL] [--label-col LABEL_COL]
                                [--img-size IMG_SIZE]
                                [--batch-size BATCH_SIZE] [--epochs EPOCHS]
                                [--lr LR] [--weight-decay WEIGHT_DECAY]
                                [--model {resnet18,resnet50,efficientnet_b0}]
                                [--no-pretrained] [--use-weighted-sampler]
                                [--class-weighting]
                                [--min-samples MIN_SAMPLES]
                                [--train-frac TRAIN_FRAC]
                                [--val-frac VAL_FRAC] [--test-frac TEST_FRAC]
                                [--save-dir SAVE_DIR] [--seed SEED]
                                [--no-cuda]
colab_kernel_launcher.py: error: unrecognized arguments: -f /root/.local/share/jupyter/runtime/kernel-8c4b0f9a-9251-4edc-a63b-75773adc84e4

SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
