In [1]:
# train_swin_tiny_v3.py
import os
import time
import random
import timm
import torch
import albumentations as A
import pandas as pd
import numpy as np
import torch.nn as nn
from albumentations.pytorch import ToTensorV2
from torch.optim import AdamW
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split
import glob
import cv2

# ------------------------
# ‚öôÔ∏è Device & Seed
# ------------------------
SEED = 42
os.environ["PYTHONHASHSEED"] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"üî• Using {device} ({torch.cuda.get_device_name(0) if device.type=='cuda' else 'CPU'})")

# ------------------------
# üìÅ Paths
# ------------------------
BASE = "/data/ephemeral/home/data"
DATA_DIR = f"{BASE}/processed/stage0_5_train/"
META_PATH = f"{BASE}/meta_stage0_5_train.csv"
TRAIN_CSV = f"{BASE}/raw/train.csv"
TEST_PATH = f"{BASE}/processed/stage0_5_test/"
SUB_PATH = f"{BASE}/raw/sample_submission.csv"

# ------------------------
# üß© Dataset
# ------------------------
class ImageDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = np.array(Image.open(row["filepath"]).convert("RGB"))
        target = int(row["target"])
        if self.transform:
            img = self.transform(image=img)["image"]
        return img, target

# ------------------------
# ‚öôÔ∏è Config
# ------------------------
model_name = "swin_tiny_patch4_window7_224"
IMG_SIZE = 224
NUM_CLASSES = 17
LR = 3e-4
EPOCHS = 30
BATCH_SIZE = 64
num_workers = 8

# ------------------------
# üß© Albumentations Transforms
# ------------------------
trn_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.OneOf([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.2),
    ], p=0.6),
    A.Rotate(limit=25, border_mode=cv2.BORDER_REFLECT_101, p=0.4),
    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.08, rotate_limit=10, p=0.4),
    A.Perspective(scale=(0.05, 0.10), p=0.2),
    A.OneOf([
        A.MotionBlur(blur_limit=5, p=0.3),
        A.GaussianBlur(blur_limit=3, p=0.3),
        A.GaussNoise(var_limit=(5, 25), p=0.3),
    ], p=0.3),
    A.OneOf([
        A.RandomBrightnessContrast(brightness_limit=0.15, contrast_limit=0.15, p=0.4),
        A.RandomShadow(p=0.2),
    ], p=0.3),
    A.HueSaturationValue(hue_shift_limit=8, sat_shift_limit=10, val_shift_limit=10, p=0.25),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

tst_transform = A.Compose([
    A.Resize(IMG_SIZE, IMG_SIZE),
    A.Normalize(mean=(0.485, 0.456, 0.406),
                std=(0.229, 0.224, 0.225)),
    ToTensorV2(),
])

# ------------------------
# üìà Data Split
# ------------------------
meta = pd.read_csv(META_PATH)
train_csv = pd.read_csv(TRAIN_CSV)

meta["basename"] = meta["filepath"].apply(lambda x: os.path.basename(x).split("_", 1)[-1])
train_csv["basename"] = train_csv["ID"].apply(lambda x: os.path.basename(x))
meta_joined = pd.merge(meta, train_csv[["basename", "target"]], on="basename", how="left")
meta_joined = meta_joined.dropna(subset=["target"]).reset_index(drop=True)

trn_df, val_df = train_test_split(
    meta_joined, test_size=0.2, stratify=meta_joined["target"], random_state=SEED
)
print("‚úÖ Split ÏôÑÎ£å:", trn_df.shape, val_df.shape)

# ------------------------
# üß± Dataloaders
# ------------------------
trn_loader = DataLoader(ImageDataset(trn_df, trn_transform),
                        batch_size=BATCH_SIZE, shuffle=True,
                        num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(ImageDataset(val_df, tst_transform),
                        batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=num_workers, pin_memory=True)

# ------------------------
# üß† Model
# ------------------------
model = timm.create_model(model_name, pretrained=True, num_classes=NUM_CLASSES,
                          in_chans=3, drop_path_rate=0.1).to(device)

loss_fn = nn.CrossEntropyLoss(label_smoothing=0.05)
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-5)
scaler = torch.cuda.amp.GradScaler()

# ------------------------
# üßÆ Train / Validate
# ------------------------
def train_one_epoch(loader, model, optimizer, loss_fn, device, scaler=None):
    model.train()
    total_loss, preds_list, targets_list = 0, [], []
    for images, targets in tqdm(loader, desc="Train", leave=False):
        images, targets = images.to(device), targets.to(device)
        optimizer.zero_grad(set_to_none=True)
        with torch.amp.autocast("cuda", enabled=(device.type == "cuda")):
            preds = model(images)
            loss = loss_fn(preds, targets)
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        total_loss += loss.item()
        preds_list.extend(preds.argmax(1).detach().cpu().numpy())
        targets_list.extend(targets.cpu().numpy())
    return {
        "loss": total_loss / len(loader),
        "acc": accuracy_score(targets_list, preds_list),
        "f1": f1_score(targets_list, preds_list, average="macro"),
    }

@torch.no_grad()
def validate(loader, model, loss_fn, device):
    model.eval()
    total_loss, preds_list, targets_list = 0, [], []
    for images, targets in tqdm(loader, desc="Valid", leave=False):
        images, targets = images.to(device), targets.to(device)
        with torch.amp.autocast("cuda"):
            preds = model(images)
            loss = loss_fn(preds, targets)
        total_loss += loss.item()
        preds_list.extend(preds.argmax(1).cpu().numpy())
        targets_list.extend(targets.cpu().numpy())
    return {
        "loss": total_loss / len(loader),
        "acc": accuracy_score(targets_list, preds_list),
        "f1": f1_score(targets_list, preds_list, average="macro"),
    }

# ------------------------
# üöÄ Training Loop
# ------------------------
best_f1 = 0.0
for epoch in range(EPOCHS):
    train_metrics = train_one_epoch(trn_loader, model, optimizer, loss_fn, device, scaler)
    val_metrics = validate(val_loader, model, loss_fn, device)
    scheduler.step()

    print(f"\n[Epoch {epoch+1}/{EPOCHS}] "
          f"Train | Loss {train_metrics['loss']:.4f} Acc {train_metrics['acc']:.4f} F1 {train_metrics['f1']:.4f} | "
          f"Valid | Loss {val_metrics['loss']:.4f} Acc {val_metrics['acc']:.4f} F1 {val_metrics['f1']:.4f}")

    if val_metrics["f1"] > best_f1:
        best_f1 = val_metrics["f1"]
        torch.save(model.state_dict(), f"./best_swin_tiny_v3_{best_f1:.4f}.pt")
        print(f"‚úÖ Best model saved (F1={best_f1:.4f})")

# ------------------------
# üßæ Inference (TTA)
# ------------------------
print("\nüöÄ Inference with rotation TTA...")
model.load_state_dict(torch.load(f"./best_swin_tiny_v3_{best_f1:.4f}.pt"))
model.eval()

sub = pd.read_csv(SUB_PATH)
preds_list = []

angles = [0, 90, 180, 270]

for name in tqdm(sub["ID"], desc="Inference"):
    matches = glob.glob(os.path.join(TEST_PATH, "**", name), recursive=True)
    if not matches:
        raise FileNotFoundError(f"ÌÖåÏä§Ìä∏ Ïù¥ÎØ∏ÏßÄ ÏóÜÏùå: {name}")
    img = np.array(Image.open(matches[0]).convert("RGB"))

    tta_preds = []
    for ang in angles:
        rotated = cv2.rotate(img, {
            0: cv2.ROTATE_90_CLOCKWISE,
            90: cv2.ROTATE_180,
            180: cv2.ROTATE_90_COUNTERCLOCKWISE,
            270: None
        }[ang]) if ang in [0, 90, 180] else img

        tensor = tst_transform(image=rotated)["image"].unsqueeze(0).to(device)
        with torch.no_grad():
            pred = torch.softmax(model(tensor), dim=1)
        tta_preds.append(pred.cpu().numpy())

    mean_pred = np.mean(tta_preds, axis=0)
    preds_list.append(np.argmax(mean_pred))

sub["target"] = preds_list
sub.to_csv("v3_swin_tiny_split.csv", index=False)
print("üéØ Inference complete! Saved to v3_swin_tiny_split.csv")


üî• Using cuda (NVIDIA GeForce RTX 3090)
‚úÖ Split ÏôÑÎ£å: (1256, 4) (314, 4)


  original_init(self, **validated_kwargs)
  A.GaussNoise(var_limit=(5, 25), p=0.3),
  scaler = torch.cuda.amp.GradScaler()
                                                                                                                                                


[Epoch 1/30] Train | Loss 1.8524 Acc 0.4443 F1 0.4057 | Valid | Loss 1.0766 Acc 0.6879 F1 0.6227
‚úÖ Best model saved (F1=0.6227)


                                                                                                                                                


[Epoch 2/30] Train | Loss 1.0728 Acc 0.7110 F1 0.6830 | Valid | Loss 0.7481 Acc 0.8248 F1 0.7832
‚úÖ Best model saved (F1=0.7832)


                                                                                                                                                


[Epoch 3/30] Train | Loss 0.8732 Acc 0.7914 F1 0.7625 | Valid | Loss 0.7052 Acc 0.8376 F1 0.7981
‚úÖ Best model saved (F1=0.7981)


                                                                                                                                                


[Epoch 4/30] Train | Loss 0.7720 Acc 0.8153 F1 0.7926 | Valid | Loss 0.6354 Acc 0.8662 F1 0.8310
‚úÖ Best model saved (F1=0.8310)


                                                                                                                                                


[Epoch 5/30] Train | Loss 0.7071 Acc 0.8527 F1 0.8317 | Valid | Loss 0.6611 Acc 0.8631 F1 0.8167


                                                                                                                                                


[Epoch 6/30] Train | Loss 0.6565 Acc 0.8662 F1 0.8472 | Valid | Loss 0.5771 Acc 0.8854 F1 0.8784
‚úÖ Best model saved (F1=0.8784)


                                                                                                                                                


[Epoch 7/30] Train | Loss 0.6229 Acc 0.8830 F1 0.8698 | Valid | Loss 0.5655 Acc 0.9013 F1 0.8849
‚úÖ Best model saved (F1=0.8849)


                                                                                                                                                


[Epoch 8/30] Train | Loss 0.5965 Acc 0.8973 F1 0.8857 | Valid | Loss 0.5953 Acc 0.8949 F1 0.8840


                                                                                                                                                


[Epoch 9/30] Train | Loss 0.5649 Acc 0.9021 F1 0.8911 | Valid | Loss 0.5516 Acc 0.9045 F1 0.8915
‚úÖ Best model saved (F1=0.8915)


                                                                                                                                                


[Epoch 10/30] Train | Loss 0.5439 Acc 0.9100 F1 0.9015 | Valid | Loss 0.5328 Acc 0.9076 F1 0.8995
‚úÖ Best model saved (F1=0.8995)


                                                                                                                                                


[Epoch 11/30] Train | Loss 0.5452 Acc 0.9116 F1 0.9021 | Valid | Loss 0.5511 Acc 0.8949 F1 0.8813


                                                                                                                                                


[Epoch 12/30] Train | Loss 0.5233 Acc 0.9252 F1 0.9158 | Valid | Loss 0.5534 Acc 0.9013 F1 0.8905


                                                                                                                                                


[Epoch 13/30] Train | Loss 0.5314 Acc 0.9108 F1 0.9025 | Valid | Loss 0.4924 Acc 0.9236 F1 0.9111
‚úÖ Best model saved (F1=0.9111)


                                                                                                                                                


[Epoch 14/30] Train | Loss 0.5042 Acc 0.9244 F1 0.9180 | Valid | Loss 0.5248 Acc 0.9172 F1 0.9068


                                                                                                                                                


[Epoch 15/30] Train | Loss 0.4979 Acc 0.9283 F1 0.9205 | Valid | Loss 0.5234 Acc 0.9204 F1 0.9152
‚úÖ Best model saved (F1=0.9152)


                                                                                                                                                


[Epoch 16/30] Train | Loss 0.4691 Acc 0.9371 F1 0.9339 | Valid | Loss 0.5650 Acc 0.9108 F1 0.9015


                                                                                                                                                


[Epoch 17/30] Train | Loss 0.4550 Acc 0.9411 F1 0.9335 | Valid | Loss 0.5068 Acc 0.9204 F1 0.9137


                                                                                                                                                


[Epoch 18/30] Train | Loss 0.4399 Acc 0.9554 F1 0.9521 | Valid | Loss 0.5462 Acc 0.9076 F1 0.8980


                                                                                                                                                


[Epoch 19/30] Train | Loss 0.4422 Acc 0.9634 F1 0.9606 | Valid | Loss 0.4925 Acc 0.9268 F1 0.9249
‚úÖ Best model saved (F1=0.9249)


                                                                                                                                                


[Epoch 20/30] Train | Loss 0.4467 Acc 0.9522 F1 0.9489 | Valid | Loss 0.5119 Acc 0.9172 F1 0.9102


                                                                                                                                                


[Epoch 21/30] Train | Loss 0.4234 Acc 0.9626 F1 0.9596 | Valid | Loss 0.5146 Acc 0.9172 F1 0.9094


                                                                                                                                                


[Epoch 22/30] Train | Loss 0.4342 Acc 0.9538 F1 0.9499 | Valid | Loss 0.5023 Acc 0.9140 F1 0.9065


                                                                                                                                                


[Epoch 23/30] Train | Loss 0.3904 Acc 0.9801 F1 0.9792 | Valid | Loss 0.5158 Acc 0.9108 F1 0.9039


                                                                                                                                                


[Epoch 24/30] Train | Loss 0.4059 Acc 0.9705 F1 0.9684 | Valid | Loss 0.5157 Acc 0.9236 F1 0.9185


                                                                                                                                                


[Epoch 25/30] Train | Loss 0.4054 Acc 0.9689 F1 0.9668 | Valid | Loss 0.5054 Acc 0.9268 F1 0.9237


                                                                                                                                                


[Epoch 26/30] Train | Loss 0.3916 Acc 0.9785 F1 0.9785 | Valid | Loss 0.5070 Acc 0.9140 F1 0.9092


                                                                                                                                                


[Epoch 27/30] Train | Loss 0.3700 Acc 0.9865 F1 0.9859 | Valid | Loss 0.5080 Acc 0.9236 F1 0.9211


                                                                                                                                                


[Epoch 28/30] Train | Loss 0.3885 Acc 0.9777 F1 0.9743 | Valid | Loss 0.5111 Acc 0.9204 F1 0.9171


                                                                                                                                                


[Epoch 29/30] Train | Loss 0.3889 Acc 0.9777 F1 0.9768 | Valid | Loss 0.5128 Acc 0.9268 F1 0.9237


                                                                                                                                                


[Epoch 30/30] Train | Loss 0.4085 Acc 0.9729 F1 0.9708 | Valid | Loss 0.5132 Acc 0.9204 F1 0.9178

üöÄ Inference with rotation TTA...


Inference: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3140/3140 [02:06<00:00, 24.78it/s]

üéØ Inference complete! Saved to v3_swin_tiny_split.csv



