In [None]:
import os

# Set before importing torch so CUDA only exposes physical GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

print(f"현재 사용 가능한 GPU 개수: {torch.cuda.device_count()}")
# get_device_name(0) raises when no CUDA device is available; guard it so the
# notebook can still be executed (e.g. for inspection) on a CPU-only kernel.
if torch.cuda.is_available():
    print(f"현재 GPU 이름: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA device not available; running on CPU.")

In [None]:
import os, random
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, confusion_matrix
from sklearn.model_selection import train_test_split


In [None]:
# ---- Paths -------------------------------------------------------------
LABEL_CSV = Path("/home/khdp-user/workspace/dataset/CSV/GT_label.csv")
PATCH_ROOT = Path("/home/khdp-user/workspace/MIL_patch")
OUT_DIR = Path("/home/khdp-user/workspace/MIL_run_cls")
OUT_DIR.mkdir(exist_ok=True, parents=True)
SPLIT_CSV_PATH = OUT_DIR / "dataset.csv"
# Same file as SPLIT_CSV_PATH; kept as a str for any code that expects one.
CSV_PATH = str(SPLIT_CSV_PATH)
BEST_MODEL_PATH = OUT_DIR / "best_model.pt"

# ---- Training hyper-parameters -----------------------------------------
PATCH_SIZE = 256
# NOTE(review): train_mil below always uses one slide per batch and does not
# read BATCH_SIZE; kept in case other code depends on it.
BATCH_SIZE = 1
EPOCHS = 50
LR = 1e-4

# ---- Reproducibility -----------------------------------------------------
# SEED was previously defined but never applied; seed every RNG the notebook
# uses (python `random`, numpy and torch on CPU/GPU).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def resplit_mil_dataframe(
    df: pd.DataFrame,
    train_ratio: float = 0.8,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    random_state: int = 42,
    gt_col: str = "GT",
    split_col: str = "split",
    out_path=None,
):
    """Re-split the labelled training slides into train / val / test.

    Only rows whose ``split_col`` equals ``"train"`` and whose ``gt_col`` is
    not null are kept; every other row is dropped from the result. The kept
    rows are split with ``train_test_split`` and their ``split_col`` is
    overwritten with ``"train"``, ``"val"`` or ``"test"``. The result is
    written to CSV and returned.

    Splits are stratified by ``gt_col`` when there are at least two classes
    and every class has at least two members (sklearn raises otherwise);
    when that condition fails, a plain shuffled split is used and a note is
    printed.

    Args:
        df: Slide table containing at least ``gt_col`` and ``split_col``.
        train_ratio: Fraction of kept rows assigned to "train" (> 0).
        val_ratio: Fraction assigned to "val" (>= 0; may be 0).
        test_ratio: Fraction assigned to "test" (>= 0; may be 0).
        random_state: Seed passed to every ``train_test_split`` call.
        gt_col: Label column used for filtering and stratification.
        split_col: Column used for the filter and overwritten with the new split.
        out_path: Destination of the resplit CSV. ``None`` (default) writes to
            the notebook-level ``SPLIT_CSV_PATH``.

    Returns:
        A copy of the kept rows with the new split assignment.

    Raises:
        AssertionError: if the three ratios do not sum to 1.
        ValueError: if a ratio is negative, ``train_ratio`` or
            ``val_ratio + test_ratio`` is 0, or no rows qualify.
    """
    assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, \
        "Ratios must sum to 1."
    if min(train_ratio, val_ratio, test_ratio) < 0:
        raise ValueError("Ratios must be non-negative.")
    holdout_ratio = val_ratio + test_ratio
    if train_ratio <= 0 or holdout_ratio <= 0:
        raise ValueError("train_ratio and val_ratio + test_ratio must both be > 0.")

    mil_df = df[
        (df[split_col] == "train") &
        (df[gt_col].notna())
    ].copy()

    if len(mil_df) == 0:
        raise ValueError("No samples found with split=='train' and GT not null.")

    def _stratify_labels(frame):
        # sklearn can only stratify when every class has >= 2 members;
        # fall back to an unstratified split instead of crashing.
        counts = frame[gt_col].value_counts()
        if len(counts) < 2:
            return None
        if counts.min() < 2:
            print(f"[MIL Resplit] class counts {counts.to_dict()} too small to stratify; "
                  "using an unstratified split.")
            return None
        return frame[gt_col]

    # test_size is kept as (1 - train_ratio) so that existing splits are
    # reproduced exactly for the same random_state.
    train_df, temp_df = train_test_split(
        mil_df,
        test_size=(1 - train_ratio),
        random_state=random_state,
        shuffle=True,
        stratify=_stratify_labels(mil_df),
    )

    # A zero val/test ratio means the whole remainder goes to the other split
    # (train_test_split itself rejects test_size of 0 or 1).
    empty = temp_df.iloc[0:0]
    if test_ratio == 0:
        val_df, test_df = temp_df, empty
    elif val_ratio == 0:
        val_df, test_df = empty, temp_df
    else:
        val_size = val_ratio / holdout_ratio
        val_df, test_df = train_test_split(
            temp_df,
            test_size=(1 - val_size),
            random_state=random_state,
            shuffle=True,
            stratify=_stratify_labels(temp_df),
        )

    mil_df.loc[train_df.index, split_col] = "train"
    mil_df.loc[val_df.index,   split_col] = "val"
    mil_df.loc[test_df.index,  split_col] = "test"

    mil_df.to_csv(SPLIT_CSV_PATH if out_path is None else out_path, index=False)

    print("[MIL Resplit Summary]")
    print("Total used:", len(mil_df))
    print("Train:", len(train_df))
    print("Val  :", len(val_df))
    print("Test :", len(test_df))
    return mil_df


In [None]:
# Read the slide-level GT table, then re-split its labelled "train" rows into
# train/val/test (resplit_mil_dataframe also writes the result to disk).
label_df = pd.read_csv(LABEL_CSV)
slide_split = resplit_mil_dataframe(df=label_df)

In [None]:
def get_transforms():
    """Build the albumentations pipelines applied to every MIL patch.

    Both pipelines resize to ``PATCH_SIZE`` x ``PATCH_SIZE``, apply
    albumentations' default ``Normalize`` and convert the image with
    ``ToTensorV2``. The training pipeline additionally applies random
    horizontal and vertical flips (each with probability 0.5) first.

    Returns:
        (train_tf, val_tf): the training and validation ``A.Compose`` objects.
    """
    def _resize_normalize_to_tensor():
        # Fresh transform instances for each pipeline.
        return [
            A.Resize(PATCH_SIZE, PATCH_SIZE),
            A.Normalize(),
            ToTensorV2(),
        ]

    flips = [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]

    train_pipeline = A.Compose(flips + _resize_normalize_to_tensor())
    val_pipeline = A.Compose(_resize_normalize_to_tensor())
    return train_pipeline, val_pipeline

class EarlyStopping:
    """Track a monitored score and signal when it stops improving.

    Args:
        patience: Number of consecutive non-improving ``step`` calls after
            which ``early_stop`` is set to True.
        min_delta: Margin the score must beat ``best_score`` by to count as
            an improvement.
        mode: ``"min"`` when lower scores are better; any other value is
            treated as "higher is better".
    """

    def __init__(self, patience=5, min_delta=0.0, mode="min"):
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best_score = None   # best score so far; None until the first step
        self.counter = 0         # consecutive steps without improvement
        self.early_stop = False  # becomes True once counter reaches patience

    def _beats_best(self, score):
        # Strict comparison against best_score shifted by min_delta.
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def step(self, score):
        """Record ``score``; return True when it is a new best.

        The first call always returns True. A non-improving call increments
        ``counter`` and may set ``early_stop``; an improving call resets
        ``counter`` to 0.
        """
        if self.best_score is None:
            self.best_score = score
            return True

        if not self._beats_best(score):
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False

        self.best_score = score
        self.counter = 0
        return True


# ============================================================
# MIL Dataset: Bag = slide, Instance = patch
# Folder: PATCH_ROOT / SlideName / images / *.png
# ============================================================
class MILDataset(Dataset):
    """Slide-level MIL dataset: each item is the bag of patch images of one slide.

    Patches are read from ``patch_root / SlideName / "images" / *.png``.
    """

    _SAMPLE_MODES = ("random", "first")

    def __init__(self, df_slide, patch_root: Path, transform=None,
                 max_patches=None, sample_mode="random"):
        """
        Args:
            df_slide: DataFrame with ``SlideName``, ``GT`` and ``split``
                columns, one row per slide.
            patch_root: Directory containing one sub-folder per slide.
            transform: Optional callable used as ``transform(image=img)["image"]``
                on each RGB patch. If falsy, the patch is converted to a float
                tensor in [0, 1] with channels first.
            max_patches: Maximum number of patches used per slide
                (None = all patches).
            sample_mode: How patches are chosen when a slide has more than
                ``max_patches``: ``"random"`` (uniform, without replacement)
                or ``"first"`` (first ones in sorted filename order).

        Raises:
            ValueError: on a missing required column, an unknown
                ``sample_mode`` or ``max_patches < 1``.
        """
        self.df = df_slide.reset_index(drop=True)
        self.patch_root = Path(patch_root)
        self.transform = transform
        self.max_patches = max_patches
        self.sample_mode = sample_mode

        # Validate up front so a bad configuration fails here rather than
        # inside a DataLoader worker (or with a misleading error later).
        for c in ["SlideName", "GT", "split"]:
            if c not in self.df.columns:
                raise ValueError(f"Missing required column: {c}")
        if sample_mode not in self._SAMPLE_MODES:
            raise ValueError(
                f"sample_mode must be one of {self._SAMPLE_MODES}, got {sample_mode!r}"
            )
        if max_patches is not None and max_patches < 1:
            raise ValueError(f"max_patches must be None or >= 1, got {max_patches}")

    def __len__(self):
        return len(self.df)

    def _collect_patch_paths(self, slide_name: str):
        """Sorted list of ``*.png`` files for a slide ([] if its folder is missing)."""
        img_dir = self.patch_root / slide_name / "images"
        if not img_dir.exists():
            return []
        return sorted(img_dir.glob("*.png"))

    def _select_patches(self, patch_paths):
        """Cap ``patch_paths`` at ``max_patches`` according to ``sample_mode``."""
        if self.max_patches is None or len(patch_paths) <= self.max_patches:
            return patch_paths
        if self.sample_mode == "first":
            return patch_paths[:self.max_patches]
        # random.sample works directly on the list of Path objects.
        return random.sample(patch_paths, self.max_patches)

    def _load_patch(self, path):
        """Read one patch as RGB and transform it; return None if it cannot be read."""
        img = cv2.imread(str(path))
        if img is None:
            return None
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.transform:
            return self.transform(image=img)["image"]
        # Fallback: HWC uint8 -> CHW float in [0, 1].
        return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0

    def __getitem__(self, idx):
        """Return ``(x, y, slide_name)`` for the slide at row ``idx``.

        ``x`` stacks the loaded patches along dim 0 and ``y`` is the GT label
        as a 0-d long tensor.

        Raises:
            RuntimeError: if GT is NaN, the slide has no patch files, or none
                of its selected patches could be read.
        """
        row = self.df.iloc[idx]
        slide_name = str(row["SlideName"])
        gt_val = row["GT"]

        # A NaN label cannot be used; such slides should be filtered out upstream.
        if pd.isna(gt_val):
            raise RuntimeError(f"GT is NaN for slide {slide_name}. This should not be in MIL split set.")

        y = torch.tensor(int(gt_val)).long()

        patch_paths = self._collect_patch_paths(slide_name)
        if len(patch_paths) == 0:
            # A slide without patches is an empty bag and cannot be trained on -> fail loudly.
            raise RuntimeError(f"No patches found for slide: {slide_name} at {self.patch_root/slide_name/'images'}")

        # Unreadable files are skipped (best effort); only an all-failed bag is an error.
        imgs = [
            img
            for img in map(self._load_patch, self._select_patches(patch_paths))
            if img is not None
        ]
        if len(imgs) == 0:
            raise RuntimeError(f"All patches failed to load for slide: {slide_name}")

        x = torch.stack(imgs, dim=0)  # (N,3,H,W)
        return x, y, slide_name


class MILTopKMean(nn.Module):
    """Binary MIL classifier that averages the top-k positive patch logits.

    Every patch of a bag is embedded by a timm backbone and mapped to two
    instance logits, of which only the class-1 logit is used. The slide
    output is ``[0, mean(top-k class-1 logits)]``, so ``softmax`` over it
    equals ``sigmoid(mean_topk)`` for the positive class.

    Args:
        backbone: timm model name; created with ImageNet-pretrained weights,
            no classification head and average pooling.
        num_classes: Must be 2 (``forward`` hard-codes class index 1).
        top_k: Number of highest-scoring patches averaged per slide (>= 1).

    Raises:
        ValueError: if ``num_classes != 2`` or ``top_k < 1``.
    """

    def __init__(self, backbone="resnet50", num_classes=2, top_k=10):
        super().__init__()
        if num_classes != 2:
            raise ValueError(f"MILTopKMean only supports num_classes=2, got {num_classes}")
        if int(top_k) < 1:
            # top_k=0 would average an empty tensor and yield NaN logits.
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.top_k = int(top_k)

        self.backbone = timm.create_model(
            backbone,
            pretrained=True,
            num_classes=0,
            global_pool="avg",
        )
        feat_dim = self.backbone.num_features
        self.classifier = nn.Linear(feat_dim, num_classes)

    def forward(self, x):
        """Score one bag.

        Args:
            x: Patches of a single slide stacked along dim 0, i.e. (N, 3, H, W)
               with N >= 1.

        Returns:
            (1, 2) slide logits ``[[0, mean_topk_pos]]``.

        Raises:
            ValueError: if the bag is empty.
        """
        if x.size(0) == 0:
            raise ValueError("MILTopKMean.forward received an empty bag (N=0).")

        feats = self.backbone(x)              # (N, D)
        inst_logits = self.classifier(feats)  # (N, 2)
        pos = inst_logits[:, 1]               # (N,)

        k = min(self.top_k, pos.size(0))
        slide_pos = torch.topk(pos, k=k).values.mean()  # scalar

        # zeros_like keeps the fixed class-0 logit on the same device/dtype as
        # slide_pos (e.g. under autocast or double precision).
        slide_logits = torch.stack(
            [torch.zeros_like(slide_pos), slide_pos]
        ).unsqueeze(0)  # (1, 2)

        return slide_logits

@torch.no_grad()
def evaluate_mil(model, loader, labels=(0, 1)):
    """Slide-level evaluation of a MIL model.

    Each loader item is one bag ``(x, y, slide_name)`` produced with batch
    size 1, so ``x`` is (1, N, 3, H, W) and is squeezed to (N, 3, H, W)
    before the forward pass. ``model`` must return (1, 2) logits; column 1
    is treated as the positive class.

    Args:
        model: MIL model; switched to eval mode (and left there).
        loader: DataLoader yielding one slide per batch.
        labels: Class labels for the confusion matrix, so ``cm`` always has
            shape (len(labels), len(labels)) even when a class is absent
            from both targets and predictions.

    Returns:
        dict with ``acc``, ``macro_f1``, ``auc`` (None when ``y_true`` holds
        a single class, since ROC AUC is undefined then) and ``cm``.

    Raises:
        ValueError: if ``loader`` yields no bags.
    """
    model.eval()

    y_true, y_prob, y_pred = [], [], []

    for x, y, _slide in tqdm(loader, desc="Eval", leave=False):
        x = x.squeeze(0).to(DEVICE)  # (1,N,3,H,W) -> (N,3,H,W)

        logits = model(x)                          # (1,2)
        prob = torch.softmax(logits, dim=1)[:, 1]  # (1,)
        pred = torch.argmax(logits, dim=1)         # (1,)

        # y is only read back as a Python int, so it never needs to go to DEVICE.
        y_true.append(int(y.item()))
        y_prob.append(float(prob.item()))
        y_pred.append(int(pred.item()))

    if not y_true:
        raise ValueError("evaluate_mil received an empty loader; nothing to evaluate.")

    acc = accuracy_score(y_true, y_pred)
    macro_f1 = f1_score(y_true, y_pred, average="macro")

    auc = None
    if len(set(y_true)) > 1:
        auc = roc_auc_score(y_true, y_prob)

    return {
        "acc": acc,
        "macro_f1": macro_f1,
        "auc": auc,
        "cm": confusion_matrix(y_true, y_pred, labels=list(labels)),
    }

# ============================================================
# Train
# ============================================================
@torch.no_grad()
def _mil_val_loss(model, loader, criterion):
    """Mean slide-level loss over ``loader`` with autograd disabled.

    Without ``no_grad`` every validation forward pass would record an autograd
    graph over the whole bag of patches for no reason.
    Returns 0.0 for an empty loader.
    """
    model.eval()
    losses = []
    for x, y, _slide in loader:
        x = x.squeeze(0).to(DEVICE)  # (1,N,3,H,W) -> (N,3,H,W)
        y = y.to(DEVICE)             # (1,)
        losses.append(criterion(model(x), y).item())
    return float(np.mean(losses)) if losses else 0.0


def train_mil(slide_split_df, top_k=10, max_patches=300, sample_mode="random",
              val_sample_mode=None, num_workers=4):
    """Train a ``MILTopKMean`` slide classifier with early stopping on val loss.

    Args:
        slide_split_df: Slide table with ``SlideName``, ``GT`` and ``split``
            columns; rows with ``split == "train"`` / ``"val"`` are used.
        top_k: Number of highest-scoring patches averaged per slide.
        max_patches: Per-slide patch cap passed to ``MILDataset``.
        sample_mode: Patch sampling mode for the training bags.
        val_sample_mode: Patch sampling mode for the validation bags.
            ``None`` (default) reuses ``sample_mode``. ``"first"`` makes the
            validation bags identical every epoch, so the early-stopping
            signal is not affected by patch-sampling noise.
        num_workers: DataLoader worker processes for both loaders.

    Returns:
        The model with the lowest-val-loss weights reloaded from
        ``BEST_MODEL_PATH``.

    Raises:
        ValueError: if ``slide_split_df`` has no train or no val slides.
    """
    train_tf, val_tf = get_transforms()
    if val_sample_mode is None:
        val_sample_mode = sample_mode

    df_tr = slide_split_df[slide_split_df["split"] == "train"].reset_index(drop=True)
    df_va = slide_split_df[slide_split_df["split"] == "val"].reset_index(drop=True)
    if len(df_tr) == 0 or len(df_va) == 0:
        # Otherwise np.mean([]) / a fake 0.0 val loss would silently drive early stopping.
        raise ValueError(
            f"train_mil needs train and val slides (got train={len(df_tr)}, val={len(df_va)})."
        )

    # One slide (bag) per batch: bags can contain different numbers of patches.
    SLIDE_BATCH_SIZE = 1
    loader_kwargs = dict(batch_size=SLIDE_BATCH_SIZE, num_workers=num_workers, pin_memory=True)

    dl_tr = DataLoader(
        MILDataset(df_tr, PATCH_ROOT, transform=train_tf,
                   max_patches=max_patches, sample_mode=sample_mode),
        shuffle=True,
        **loader_kwargs,
    )
    dl_va = DataLoader(
        MILDataset(df_va, PATCH_ROOT, transform=val_tf,
                   max_patches=max_patches, sample_mode=val_sample_mode),
        shuffle=False,
        **loader_kwargs,
    )

    model = MILTopKMean(backbone="resnet50", num_classes=2, top_k=top_k).to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    early = EarlyStopping(patience=5, min_delta=1e-4, mode="min")

    for epoch in range(EPOCHS):
        model.train()
        train_losses = []

        pbar = tqdm(dl_tr, desc=f"[Train] Epoch {epoch+1}/{EPOCHS}")
        for x, y, _slide in pbar:
            x = x.squeeze(0).to(DEVICE)  # (1,N,3,H,W) -> (N,3,H,W)
            y = y.to(DEVICE)             # (1,)

            logits = model(x)            # (1,2)
            loss = criterion(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            train_losses.append(loss_value)
            pbar.set_postfix(loss=f"{loss_value:.4f}", n_patches=x.size(0))

        train_loss = float(np.mean(train_losses))

        # Validation loss drives early stopping / checkpointing.
        val_loss = _mil_val_loss(model, dl_va, criterion)
        if early.step(val_loss):
            torch.save(model.state_dict(), BEST_MODEL_PATH)

        val_metrics = evaluate_mil(model, dl_va)
        auc_str = "None" if val_metrics["auc"] is None else f"{val_metrics['auc']:.4f}"

        print(
            f"Epoch {epoch+1:02d} | "
            f"train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
            f"val_acc={val_metrics['acc']:.4f} | val_macro_f1={val_metrics['macro_f1']:.4f} | val_auc={auc_str} | "
            f"best_val_loss={early.best_score:.4f} | patience={early.counter}/{early.patience}"
        )

        if early.early_stop:
            print("[Early Stop] Training stopped.")
            break

    # Restore the checkpoint with the lowest validation loss.
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
    print(f"[DONE] Best model saved: {BEST_MODEL_PATH}")
    return model

In [None]:
# Train on the resplit slides; train_mil returns the model with the
# lowest-validation-loss checkpoint reloaded.
model = train_mil(
    slide_split,
    top_k=10,             # patches averaged per slide (top-k mean pooling)
    max_patches=300,      # per-slide patch cap
    sample_mode="random",
)