In [2]:
import os
import random
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import json
import platform

In [3]:
# ---- Reproducibility -------------------------------------------------------
# Seed every RNG the notebook touches so that Restart & Run All reproduces the
# same weight initialisation and augmentation draws. Previously SEED was
# defined but never applied.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---- Inputs ----------------------------------------------------------------
LABEL_CSV = Path("/home/khdp-user/workspace/dataset/CSV/m0m1_label.csv")
PATCH_ROOT = Path("/home/khdp-user/workspace/dataset/Glom_M0M1")

# ---- Hyper-parameters ------------------------------------------------------
PATCH_SIZE = 512   # images are resized to PATCH_SIZE x PATCH_SIZE by the transforms
BATCH_SIZE = 64
EPOCHS = 50
LR = 1e-4

# ---- Outputs ---------------------------------------------------------------
OUT_DIR = Path("/home/khdp-user/workspace/m0m1_run_cls")
OUT_DIR.mkdir(exist_ok=True, parents=True)
BEST_MODEL_PATH = OUT_DIR / "best_model.pt"
SPLIT_CSV_PATH  = OUT_DIR / "dataset.csv"
# CSV_PATH is the same file as SPLIT_CSV_PATH, kept as a plain string for
# callers that expect one; deriving it keeps the two names from drifting apart.
CSV_PATH = os.fspath(SPLIT_CSV_PATH)

In [4]:
def build_patch_index(root: Path, exts=(".png", ".jpg", ".jpeg", ".tif", ".tiff")):
    """Map the file name of every image found under ``root`` to its full path.

    The directory tree is walked once. A path is indexed when it is a regular
    file and its name ends with one of ``exts`` (compared case-insensitively,
    so ``.PNG`` is found as well as ``.png``).

    Args:
        root: Directory to scan recursively.
        exts: File-name endings to accept.

    Returns:
        dict mapping ``path.name`` to the ``Path`` of that file.

    Warns:
        UserWarning: when two or more files share the same name. The index is
            keyed by bare file name, so only one of them can be kept; paths
            are visited in sorted order and the last one wins.
    """
    import warnings  # local import keeps this cell self-contained

    suffixes = tuple(ext.lower() for ext in exts)
    idx = {}
    duplicates = []

    # sorted() makes the visiting order -- and therefore the winner among
    # duplicate names -- independent of filesystem enumeration order.
    for p in sorted(Path(root).rglob("*")):
        if not p.name.lower().endswith(suffixes):
            continue
        if not p.is_file():
            # e.g. a directory that happens to be called "something.png"
            continue
        if p.name in idx:
            duplicates.append(p.name)
        idx[p.name] = p

    if duplicates:
        unique_dups = sorted(set(duplicates))
        warnings.warn(
            f"{len(unique_dups)} patch file name(s) occur more than once under "
            f"{root}; only the last path is indexed for each "
            f"(e.g. {unique_dups[:5]})"
        )
    return idx

# Build the file-name -> path lookup once; the label-loading cell below
# resolves each CSV row against it via patch_index.get(name).
print("[Index] scanning patches...")
patch_index = build_patch_index(PATCH_ROOT)
print("Total patches indexed:", len(patch_index))

[Index] scanning patches...
Total patches indexed: 1190


In [5]:
# Join the label CSV against the patch index and encode the target as y
# (1 for "m1", 0 otherwise).
label_df = pd.read_csv(LABEL_CSV)

rows = []
missing = 0        # CSV rows whose patch file was not found under PATCH_ROOT
unlabeled = 0      # CSV rows whose target cell is empty / NaN
unexpected = {}    # normalised target text -> count, for anything not m0/m1
for name, target in zip(label_df["patch_name"], label_df["target"]):
    p = patch_index.get(name)
    if p is None:
        missing += 1
        continue

    # An empty target cell is read as NaN (a float); calling .lower() on it
    # used to raise AttributeError. Skip and count such rows instead.
    if pd.isna(target):
        unlabeled += 1
        continue

    # Tolerate stray whitespace and mixed case in the label text.
    label = str(target).strip().lower()
    # NOTE(review): the negative label is assumed to be spelled "m0" (from the
    # CSV file name) -- confirm. Other values still map to 0 as before, but
    # are now reported rather than absorbed silently.
    if label not in ("m0", "m1"):
        unexpected[label] = unexpected.get(label, 0) + 1

    y = 1 if label == "m1" else 0
    rows.append({
        "name": name,
        "path": str(p),
        "y": y
    })

# Explicit columns keep df["y"] valid even when no row survived the filters.
df = pd.DataFrame(rows, columns=["name", "path", "y"])
print("df shape:", df.shape, "missing:", missing)
if unlabeled:
    print(f"[WARN] skipped {unlabeled} row(s) with an empty target")
if unexpected:
    print(f"[WARN] target values other than m0/m1 were mapped to y=0: {unexpected}")
print(df["y"].value_counts())


df shape: (1190, 3) missing: 0
y
0    661
1    529
Name: count, dtype: int64


In [6]:
def get_slide_id_from_patch_name(patch_name: str):
    """Return the part of ``patch_name`` that precedes the first ``"_PAS"``.

    When the marker does not occur, the whole name is returned unchanged.
    """
    slide_id, _marker, _remainder = patch_name.partition("_PAS")
    return slide_id

def stratified_split_slide(
    slide_df,
    y_col="slide_y",
    train_ratio=0.8,
    val_ratio=0.1,
    seed=42,
):
    """Assign every slide to "train", "val" or "test", stratified by ``y_col``.

    Slides are grouped by their value in ``y_col``. Within each group the
    slide ids are shuffled with a ``RandomState(seed)`` and cut into three
    consecutive chunks.

    Args:
        slide_df: DataFrame with one row per slide, holding a ``slide_id``
            column and the ``y_col`` column. It is not modified.
        y_col: Name of the column to stratify on.
        train_ratio: Fraction of each group assigned to "train".
        val_ratio: Fraction of each group assigned to "val".
        seed: Seed for the shuffle.

    Returns:
        dict mapping slide id to ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: if a ratio is negative or the two ratios sum to more
            than 1.

    Note:
        Chunk sizes are ``int(n * ratio)``, i.e. rounded down per group, and
        "test" receives whatever is left. Small groups can therefore end up
        with no "val" slides and a larger-than-nominal "test" share.
    """
    if train_ratio < 0 or val_ratio < 0 or train_ratio + val_ratio > 1.0 + 1e-9:
        raise ValueError(
            "train_ratio and val_ratio must be non-negative and sum to at most 1 "
            f"(got train_ratio={train_ratio}, val_ratio={val_ratio})"
        )

    rng = np.random.RandomState(seed)
    slide_split = {}

    for cls, sub in slide_df.groupby(y_col):
        # Shuffle a private copy. `.values` can hand back memory owned by the
        # DataFrame, which is read-only under pandas Copy-on-Write (in-place
        # shuffle then raises) and otherwise reorders pandas' own data.
        slide_ids = sub["slide_id"].to_numpy(copy=True)
        rng.shuffle(slide_ids)

        n = len(slide_ids)
        n_tr = int(n * train_ratio)
        n_va = int(n * val_ratio)

        for sid in slide_ids[:n_tr]:
            slide_split[sid] = "train"
        for sid in slide_ids[n_tr:n_tr + n_va]:
            slide_split[sid] = "val"
        for sid in slide_ids[n_tr + n_va:]:
            slide_split[sid] = "test"

    return slide_split


# Tag every patch with the slide it was cut from.
df["slide_id"] = df["name"].map(get_slide_id_from_patch_name)

# One row per slide; a slide's label is the maximum patch label it contains.
slide_df = df.groupby("slide_id")["y"].max().reset_index()
slide_df = slide_df.rename(columns={"y": "slide_y"})

# Split whole slides, then propagate each slide's split to its patches.
slide_split = stratified_split_slide(slide_df)
df["split"] = df["slide_id"].map(slide_split)
assert not df["split"].isna().any()

print("Patch-level distribution")
print(pd.crosstab(df["split"], df["y"]))

# slide_id -> split, one row per slide
slide_view = df[["slide_id", "split"]].drop_duplicates("slide_id")
print("\nSlide-level distribution (#slides)")
print(slide_view["split"].value_counts())

# Final slide_view: slide_id, slide_y and split side by side.
slide_view = slide_df.merge(slide_view, on="slide_id")

df.to_csv(CSV_PATH, index=False)
print(f"[OK] CSV saved: {CSV_PATH}  (patches={len(df)})")

Patch-level distribution
y        0    1
split          
test    98   51
train  465  413
val     98   65

Slide-level distribution (#slides)
split
train    81
test     12
val       9
Name: count, dtype: int64
[OK] CSV saved: /home/khdp-user/workspace/m0m1_run_cls/dataset.csv  (patches=1190)


In [7]:
class PatchClsDataset(Dataset):
    """Dataset that yields ``(image, label)`` pairs from a patch table.

    Each row of ``df`` must provide ``path`` (image file location) and ``y``
    (label). The frame's index is reset so positional access starts at 0.
    ``transform``, when given, is called as ``transform(image=img)`` and its
    ``"image"`` entry is returned in place of the raw array.
    """

    def __init__(self, df, transform=None):
        self.transform = transform
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        record = self.df.iloc[idx]
        image_path = record["path"]

        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread signals failure by returning None rather than raising
            raise RuntimeError(image_path)
        img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=img)
            img = augmented["image"]

        label = torch.tensor(record["y"]).to(torch.long)
        return img, label
    
def get_transforms():
    """Build the albumentations pipelines for training and validation.

    Both pipelines end with resize to ``PATCH_SIZE`` x ``PATCH_SIZE``,
    ``A.Normalize()`` with its default arguments, and ``ToTensorV2()``.
    The training pipeline additionally starts with random horizontal and
    vertical flips (probability 0.5 each).

    Returns:
        tuple ``(train_tf, val_tf)``.
    """

    def _resize_normalize_to_tensor():
        # Fresh transform objects on every call so the two pipelines do not
        # share instances.
        return [
            A.Resize(PATCH_SIZE, PATCH_SIZE),
            A.Normalize(),
            ToTensorV2(),
        ]

    random_flips = [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)]

    train_tf = A.Compose(random_flips + _resize_normalize_to_tensor())
    val_tf = A.Compose(_resize_normalize_to_tensor())
    return train_tf, val_tf

class EarlyStopping:
    """Track a monitored score and flag when it has stopped improving.

    Args:
        patience: Number of consecutive non-improving ``step`` calls after
            which ``early_stop`` is set to True.
        min_delta: Margin by which a score must beat the best one to count
            as an improvement.
        mode: ``"min"`` treats lower scores as better; any other value
            treats higher scores as better.
    """

    def __init__(self, patience=5, min_delta=0.0, mode="min"):
        self.patience, self.min_delta, self.mode = patience, min_delta, mode

        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def _beats_best(self, score):
        # Strict comparison against the best score shifted by min_delta.
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def step(self, score):
        """Record ``score``; return True when it is a new best, else False."""
        # The very first score is always accepted as the baseline.
        if self.best_score is None:
            self.best_score = score
            return True

        if not self._beats_best(score):
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False

        self.best_score = score
        self.counter = 0
        return True


def build_model():
    """Create a pretrained timm ``resnet50`` with a 2-class head on ``DEVICE``."""
    net = timm.create_model("resnet50", pretrained=True, num_classes=2)
    net = net.to(DEVICE)
    return net

@torch.no_grad()
def validate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    The model is put into eval mode. Each batch loss is weighted by the batch
    size before averaging (assumes ``criterion`` returns a per-batch mean --
    TODO confirm for non-default reductions).

    Returns:
        tuple ``(val_loss, val_acc)``; both are 0 when the loader is empty.
    """
    model.eval()

    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)

        outputs = model(inputs)
        batch_size = inputs.size(0)

        loss_sum += criterion(outputs, targets).item() * batch_size
        n_correct += (outputs.argmax(dim=1) == targets).sum().item()
        n_seen += batch_size

    # max(..., 1) avoids dividing by zero on an empty loader
    denom = max(n_seen, 1)
    return loss_sum / denom, n_correct / denom


def train(df):
    """Fine-tune the classifier on the "train" split with early stopping.

    Rows of ``df`` with ``split == "train"`` are used for optimisation and
    rows with ``split == "val"`` for model selection. After every epoch the
    validation loss is passed to an ``EarlyStopping`` tracker (patience 5,
    min_delta 1e-4); whenever it improves, the weights are written to
    ``BEST_MODEL_PATH``. Training ends after ``EPOCHS`` epochs or when the
    tracker signals a stop.

    Args:
        df: Patch table with ``path``, ``y`` and ``split`` columns.

    Returns:
        The model with the best checkpoint's weights loaded.

    Raises:
        ValueError: if the train or val split contains no rows.
    """
    train_tf, val_tf = get_transforms()

    df_tr = df[df["split"] == "train"]
    df_va = df[df["split"] == "val"]
    if len(df_tr) == 0 or len(df_va) == 0:
        # An empty val split would report val_loss=0 and "win" on epoch 1;
        # an empty train split would yield a NaN loss. Fail loudly instead.
        raise ValueError(
            f"train/val splits must be non-empty (train={len(df_tr)}, val={len(df_va)})"
        )

    # Pinned host memory only helps host->GPU copies; on CPU it just warns.
    pin_memory = DEVICE == "cuda"

    # Dedicated generator so the shuffle order is a function of SEED alone.
    shuffle_rng = torch.Generator()
    shuffle_rng.manual_seed(SEED)

    dl_tr = DataLoader(
        PatchClsDataset(df_tr, train_tf),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=4,
        pin_memory=pin_memory,
        generator=shuffle_rng,
    )
    dl_va = DataLoader(
        PatchClsDataset(df_va, val_tf),
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=4,
        pin_memory=pin_memory,
    )

    model = build_model()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)

    early_stopper = EarlyStopping(
        patience=5,
        min_delta=1e-4,
        mode="min"
    )

    for epoch in range(EPOCHS):
        # ==================
        # Train
        # ==================
        model.train()  # validate() leaves the model in eval mode
        loss_sum = 0.0
        n_seen = 0

        pbar = tqdm(dl_tr, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for x, y in pbar:
            x, y = x.to(DEVICE), y.to(DEVICE)

            optimizer.zero_grad()
            logits = model(x)
            loss = criterion(logits, y)
            loss.backward()
            optimizer.step()

            # Weight by batch size so the (smaller) last batch does not skew
            # the epoch mean -- same convention as validate().
            bs = x.size(0)
            loss_sum += loss.item() * bs
            n_seen += bs
            pbar.set_postfix(train_loss=f"{loss.item():.4f}")

        train_loss = loss_sum / max(n_seen, 1)

        # ==================
        # Validation
        # ==================
        val_loss, val_acc = validate(model, dl_va, criterion)

        # ==================
        # Early Stopping
        # ==================
        is_best = early_stopper.step(val_loss)
        if is_best:
            torch.save(model.state_dict(), BEST_MODEL_PATH)

        # ==================
        # Logging
        # ==================
        print(
            f"Epoch {epoch+1}/{EPOCHS} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"val_acc={val_acc:.4f} | "
            f"best_val_loss={early_stopper.best_score:.4f} | "
            f"patience={early_stopper.counter}/{early_stopper.patience}"
        )

        if early_stopper.early_stop:
            print("[Early Stop] Training stopped.")
            break

    print(f"[DONE] Best model saved to {BEST_MODEL_PATH}")
    # Return the best checkpoint, not the weights from the final epoch.
    model.load_state_dict(torch.load(BEST_MODEL_PATH, map_location=DEVICE))
    return model

In [8]:
# Run the full training loop; `model` holds the best-checkpoint weights afterwards.
model = train(df)

Epoch 1/50: 100%|██████████| 14/14 [00:17<00:00,  1.26s/it, train_loss=0.6778]


Epoch 1/50 | train_loss=0.6869 | val_loss=0.6752 | val_acc=0.6319 | best_val_loss=0.6752 | patience=0/5


Epoch 2/50: 100%|██████████| 14/14 [00:12<00:00,  1.12it/s, train_loss=0.6861]


Epoch 2/50 | train_loss=0.6754 | val_loss=0.6552 | val_acc=0.7117 | best_val_loss=0.6552 | patience=0/5


Epoch 3/50: 100%|██████████| 14/14 [00:09<00:00,  1.49it/s, train_loss=0.6820]


Epoch 3/50 | train_loss=0.6623 | val_loss=0.6258 | val_acc=0.7669 | best_val_loss=0.6258 | patience=0/5


Epoch 4/50: 100%|██████████| 14/14 [00:09<00:00,  1.50it/s, train_loss=0.6203]


Epoch 4/50 | train_loss=0.6449 | val_loss=0.5998 | val_acc=0.7546 | best_val_loss=0.5998 | patience=0/5


Epoch 5/50: 100%|██████████| 14/14 [00:09<00:00,  1.51it/s, train_loss=0.6429]


Epoch 5/50 | train_loss=0.6289 | val_loss=0.5780 | val_acc=0.7546 | best_val_loss=0.5780 | patience=0/5


Epoch 6/50: 100%|██████████| 14/14 [00:09<00:00,  1.52it/s, train_loss=0.5897]


Epoch 6/50 | train_loss=0.6104 | val_loss=0.5504 | val_acc=0.7546 | best_val_loss=0.5504 | patience=0/5


Epoch 7/50: 100%|██████████| 14/14 [00:09<00:00,  1.52it/s, train_loss=0.5635]


Epoch 7/50 | train_loss=0.5925 | val_loss=0.5460 | val_acc=0.7607 | best_val_loss=0.5460 | patience=0/5


Epoch 8/50: 100%|██████████| 14/14 [00:09<00:00,  1.49it/s, train_loss=0.5705]


Epoch 8/50 | train_loss=0.5685 | val_loss=0.5394 | val_acc=0.7546 | best_val_loss=0.5394 | patience=0/5


Epoch 9/50: 100%|██████████| 14/14 [00:09<00:00,  1.50it/s, train_loss=0.6061]


Epoch 9/50 | train_loss=0.5486 | val_loss=0.5320 | val_acc=0.7485 | best_val_loss=0.5320 | patience=0/5


Epoch 10/50: 100%|██████████| 14/14 [00:09<00:00,  1.50it/s, train_loss=0.5460]


Epoch 10/50 | train_loss=0.5179 | val_loss=0.5058 | val_acc=0.7423 | best_val_loss=0.5058 | patience=0/5


Epoch 11/50: 100%|██████████| 14/14 [00:09<00:00,  1.51it/s, train_loss=0.5245]


Epoch 11/50 | train_loss=0.4859 | val_loss=0.5132 | val_acc=0.7546 | best_val_loss=0.5058 | patience=1/5


Epoch 12/50: 100%|██████████| 14/14 [00:09<00:00,  1.52it/s, train_loss=0.4353]


Epoch 12/50 | train_loss=0.4422 | val_loss=0.4752 | val_acc=0.7853 | best_val_loss=0.4752 | patience=0/5


Epoch 13/50: 100%|██████████| 14/14 [00:09<00:00,  1.50it/s, train_loss=0.4964]


Epoch 13/50 | train_loss=0.4034 | val_loss=0.4846 | val_acc=0.7791 | best_val_loss=0.4752 | patience=1/5


Epoch 14/50: 100%|██████████| 14/14 [00:09<00:00,  1.53it/s, train_loss=0.2954]


Epoch 14/50 | train_loss=0.3624 | val_loss=0.4606 | val_acc=0.7607 | best_val_loss=0.4606 | patience=0/5


Epoch 15/50: 100%|██████████| 14/14 [00:09<00:00,  1.52it/s, train_loss=0.3350]


Epoch 15/50 | train_loss=0.3210 | val_loss=0.5048 | val_acc=0.7669 | best_val_loss=0.4606 | patience=1/5


Epoch 16/50: 100%|██████████| 14/14 [00:09<00:00,  1.50it/s, train_loss=0.3751]


Epoch 16/50 | train_loss=0.2782 | val_loss=0.4559 | val_acc=0.8037 | best_val_loss=0.4559 | patience=0/5


Epoch 17/50: 100%|██████████| 14/14 [00:09<00:00,  1.51it/s, train_loss=0.1873]


Epoch 17/50 | train_loss=0.2269 | val_loss=0.4972 | val_acc=0.7607 | best_val_loss=0.4559 | patience=1/5


Epoch 18/50: 100%|██████████| 14/14 [00:09<00:00,  1.49it/s, train_loss=0.2058]


Epoch 18/50 | train_loss=0.1843 | val_loss=0.5871 | val_acc=0.7546 | best_val_loss=0.4559 | patience=2/5


Epoch 19/50: 100%|██████████| 14/14 [00:09<00:00,  1.50it/s, train_loss=0.1711]


Epoch 19/50 | train_loss=0.1703 | val_loss=0.4917 | val_acc=0.7975 | best_val_loss=0.4559 | patience=3/5


Epoch 20/50: 100%|██████████| 14/14 [00:09<00:00,  1.54it/s, train_loss=0.1680]


Epoch 20/50 | train_loss=0.1559 | val_loss=0.5876 | val_acc=0.7669 | best_val_loss=0.4559 | patience=4/5


Epoch 21/50: 100%|██████████| 14/14 [00:09<00:00,  1.52it/s, train_loss=0.1626]


Epoch 21/50 | train_loss=0.1443 | val_loss=0.6707 | val_acc=0.7362 | best_val_loss=0.4559 | patience=5/5
[Early Stop] Training stopped.
[DONE] Best model saved to /home/khdp-user/workspace/m0m1_run_cls/best_model.pt


In [17]:
def save_training_env(out_dir, test_ratio=0.1, val_ratio=0.1, target_mag=10.0):
    """Write the run's settings and runtime versions to ``training_env.json``.

    Args:
        out_dir: Directory that receives ``training_env.json``; created if it
            does not exist.
        test_ratio: Value stored under ``"TEST_RATIO"``.
        val_ratio: Value stored under ``"VAL_RATIO"``.
        target_mag: Value stored under ``"target_mag"`` (presumably the slide
            magnification the patches were extracted at -- TODO confirm).

    Note:
        The ratios are recorded as given; they are not read back from the
        split. ``stratified_split_slide`` rounds chunk sizes down per class,
        so the realised fractions can differ from these nominal values.
    """
    env = {
        "TASK_TYPE": 'classification',
        "SEED": SEED,  # needed to reproduce the slide split and training run
        "PATCH_SIZE": PATCH_SIZE,
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "LR": LR,
        "TEST_RATIO": test_ratio,
        "VAL_RATIO": val_ratio,
        "target_mag": target_mag,
        "DEVICE": DEVICE,
        "cuda_available": torch.cuda.is_available(),
        "torch_version": torch.__version__,
        "python_version": platform.python_version(),
    }

    os.makedirs(out_dir, exist_ok=True)
    save_path = os.path.join(out_dir, "training_env.json")
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(env, f, indent=2)

    print(f"[OK] Training environment saved: {save_path}")


save_training_env(OUT_DIR)

[OK] Training environment saved: /home/khdp-user/workspace/m0m1_run_cls/training_env.json
