In [None]:
import os
import cv2
import random
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import platform
import json

# Configuration and dataset split

In [None]:
# ============================================================
# CONFIG
# ============================================================
ROOT_DIR = "/home/khdp-user/workspace/Infla_patch"
TASK_TYPE = "binary"      # "binary" | "multiclass"
LAYER_IDS = [1]           # mask layer ids; binary requires exactly one
PATCH_SIZE = 512
BATCH_SIZE = 16
EPOCHS = 50
LR = 1e-4
TEST_RATIO = 0.2          # fraction of slides held out for test
VAL_RATIO = 0.2           # fraction of the remaining (non-test) slides used for val
SEED = 42                 # seeds the global python / numpy / torch RNGs
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUT_DIR = "/home/khdp-user/workspace/Infla_run_seg"
CSV_PATH = os.path.join(OUT_DIR, "dataset.csv")

# Validate the configuration before touching the filesystem.
assert TASK_TYPE in ["binary", "multiclass"], f"unknown TASK_TYPE: {TASK_TYPE!r}"
assert len(LAYER_IDS) >= 1, "LAYER_IDS must not be empty"
assert len(set(LAYER_IDS)) == len(LAYER_IDS), "LAYER_IDS must not contain duplicates"
if TASK_TYPE == "binary":
    assert len(LAYER_IDS) == 1, "binary TASK_TYPE needs exactly one layer id"
assert 0.0 <= TEST_RATIO < 1.0 and 0.0 <= VAL_RATIO < 1.0, "split ratios must be in [0, 1)"

# Seed the global RNGs so Restart & Run All reproduces the slide shuffle,
# the weight initialisation and the torch-side randomness of the run.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

os.makedirs(OUT_DIR, exist_ok=True)

In [None]:
# ============================================================
# DATASET SPLIT
# ============================================================
def build_dataset_csv(root_dir=None, csv_path=None, layer_ids=None,
                      test_ratio=None, val_ratio=None, seed=42):
    """Split slides into train/val/test and write a patch-level CSV.

    Splitting happens at the slide level: every patch of a slide lands in the
    same split, so patches of one slide never end up in two splits. A patch is
    kept only if at least one of ``layer_ids`` has a mask file with the same
    file name.

    Expected layout::

        <root_dir>/<slide_id>/images/*.png
        <root_dir>/<slide_id>/masks/layer<id>/<same name>.png

    Args:
        root_dir:   dataset root; defaults to ``ROOT_DIR``.
        csv_path:   output CSV path; defaults to ``CSV_PATH``.
        layer_ids:  mask layer ids to look for; defaults to ``LAYER_IDS``.
        test_ratio: fraction of slides for test; defaults to ``TEST_RATIO``.
        val_ratio:  fraction of the non-test slides for val; defaults to ``VAL_RATIO``.
        seed:       seed for the slide shuffle so the split is reproducible
                    across runs; ``None`` gives a non-deterministic shuffle.

    Returns:
        DataFrame with columns ``name``, ``path``, ``split`` (also written to
        ``csv_path``).
    """
    root_dir = ROOT_DIR if root_dir is None else root_dir
    csv_path = CSV_PATH if csv_path is None else csv_path
    layer_ids = LAYER_IDS if layer_ids is None else layer_ids
    test_ratio = TEST_RATIO if test_ratio is None else test_ratio
    val_ratio = VAL_RATIO if val_ratio is None else val_ratio

    root = Path(root_dir)
    slides = sorted(p for p in root.iterdir() if p.is_dir())
    slide_ids = [s.name for s in slides]

    # Local seeded RNG: reproducible split without touching global random state.
    random.Random(seed).shuffle(slide_ids)

    n = len(slide_ids)
    n_test = int(n * test_ratio)
    n_val = int((n - n_test) * val_ratio)

    test_slides = set(slide_ids[:n_test])
    val_slides = set(slide_ids[n_test:n_test + n_val])

    def split_of(slide_id):
        # Everything that is neither test nor val is train.
        if slide_id in test_slides:
            return "test"
        if slide_id in val_slides:
            return "val"
        return "train"

    rows = []
    for slide in slides:
        split = split_of(slide.name)
        mask_root = slide / "masks"
        # glob() order is filesystem-dependent; sort for a stable CSV.
        for img_path in sorted((slide / "images").glob("*.png")):
            # Keep the patch if ANY requested layer has a mask for it.
            has_mask = any(
                (mask_root / f"layer{lid}" / img_path.name).exists()
                for lid in layer_ids
            )
            if not has_mask:
                continue
            rows.append({
                "name": img_path.name,
                "path": str(img_path),
                "split": split,
            })

    # Explicit columns so an empty result still has the schema callers expect.
    df = pd.DataFrame(rows, columns=["name", "path", "split"])
    df.to_csv(csv_path, index=False)
    print(f"[OK] CSV saved: {csv_path}  (patches={len(df)})")
    return df

In [None]:
# Reuse the saved split when it exists. Re-running this cell (or the whole
# notebook) must not reshuffle slides; otherwise the test split no longer
# matches the one the saved model was trained and validated against.
# Set REBUILD_SPLIT = True (or delete the CSV) to force a fresh split,
# e.g. after changing ROOT_DIR or LAYER_IDS.
REBUILD_SPLIT = False
if REBUILD_SPLIT or not os.path.exists(CSV_PATH):
    df = build_dataset_csv()
else:
    df = pd.read_csv(CSV_PATH)
    print(f"[OK] Reusing existing split: {CSV_PATH}  (patches={len(df)})")

In [None]:
# Sanity-check the saved split. Each slide must belong to exactly one split,
# otherwise patches of the same slide leak between train/val/test.
# Patch paths look like <slide>/images/<name>, so parents[1] is the slide dir.
split_check_df = pd.read_csv(CSV_PATH)
patch_slide = split_check_df["path"].map(lambda p: Path(p).parents[1].name)
splits_per_slide = split_check_df.groupby(patch_slide)["split"].nunique()
assert (splits_per_slide <= 1).all(), (
    f"slides present in more than one split: {list(splits_per_slide[splits_per_slide > 1].index)}"
)
split_check_df.value_counts("split")

# Training loop

In [None]:
# ============================================================
# DATASET
# ============================================================
class SegDataset(Dataset):
    """Patch segmentation dataset driven by the split CSV.

    Each row supplies ``name`` and ``path`` of an image stored as
    ``<slide>/images/<name>``. For every id in ``LAYER_IDS`` the mask is read
    from ``<slide>/masks/layer<id>/<name>``; a missing mask file means "no
    foreground" for that layer.

    ``__getitem__`` returns ``(img, mask)``:
      * img  - float tensor, CHW (normalised only if ``transform`` normalises;
               without a transform it is raw 0..255 RGB).
      * mask - binary:     float tensor with values {0, 1};
               multiclass: long tensor, 0 = background, i + 1 = LAYER_IDS[i]
               (later layers overwrite earlier ones where they overlap).
    """

    def __init__(self, df, transform=None):
        # Fresh 0..N-1 index so positional iloc[idx] lookups are always valid.
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    @staticmethod
    def _read_layer(path, hw):
        """Return the bool (H, W) mask stored at ``path``, or all-False if absent.

        Raises:
            FileNotFoundError: the file exists but OpenCV cannot decode it.
            ValueError: the mask size differs from the image size ``hw``.
        """
        if not path.exists():
            return np.zeros(hw, bool)
        raw = cv2.imread(str(path), cv2.IMREAD_GRAYSCALE)
        if raw is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read mask: {path}")
        if raw.shape != hw:
            raise ValueError(f"mask {path} has size {raw.shape}, image has {hw}")
        return raw > 127

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = cv2.imread(row["path"])
        if img is None:
            raise FileNotFoundError(f"could not read image: {row['path']}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        hw = img.shape[:2]

        # <slide>/images/<name>  ->  <slide>/masks/layer<id>/<name>
        mask_root = Path(row["path"]).parents[1] / "masks"
        layers = [
            self._read_layer(mask_root / f"layer{lid}" / row["name"], hw)
            for lid in LAYER_IDS
        ]

        if TASK_TYPE == "binary":
            mask = layers[0].astype(np.float32)
        else:
            # Label map, 0 = background. uint8 instead of int64 because
            # OpenCV-backed mask transforms (e.g. a Resize that actually
            # rescales) do not accept int64 arrays; fits up to 254 layers.
            mask = np.zeros(hw, np.uint8)
            for i, m in enumerate(layers):
                mask[m] = i + 1

        if self.transform:
            out = self.transform(image=img, mask=mask)
            img, mask = out["image"], out["mask"]

        # Without a transform (or one lacking ToTensorV2) the arrays are still
        # numpy; convert here so the return types never depend on the pipeline.
        if not torch.is_tensor(img):
            img = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
        if not torch.is_tensor(mask):
            mask = torch.from_numpy(np.ascontiguousarray(mask))

        img = img.float()
        mask = mask.float() if TASK_TYPE == "binary" else mask.long()
        return img, mask

    
# ============================================================
# TRANSFORMS
# ============================================================
def get_transforms(patch_size=None):
    """Build the albumentations pipelines for training and evaluation.

    Both pipelines resize first (on the raw uint8 image) and normalise
    afterwards, so train and val/test images receive identical preprocessing.
    The train pipeline additionally applies random horizontal and vertical
    flips.

    Args:
        patch_size: output side length in pixels; defaults to ``PATCH_SIZE``.

    Returns:
        ``(train_tf, val_tf)``: ``A.Compose`` pipelines ending in ``ToTensorV2``,
        which turns the image into a CHW tensor and the mask into a tensor.
    """
    size = PATCH_SIZE if patch_size is None else patch_size
    train_tf = A.Compose([
        A.Resize(size, size),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(),
        ToTensorV2(),
    ])
    val_tf = A.Compose([
        A.Resize(size, size),
        A.Normalize(),
        ToTensorV2(),
    ])
    return train_tf, val_tf

# ============================================================
# TRAINING
# ============================================================
class EarlyStopping:
    """Signal a stop once a monitored score has stopped improving.

    Call ``step(score)`` once per epoch. It returns True when ``score`` is a new
    best; the training loop checkpoints on that signal. After ``patience``
    consecutive non-improving steps, ``early_stop`` becomes True.

    A NaN score never counts as an improvement. A diverged epoch therefore can
    neither become the saved "best" nor block later genuine improvements.
    """

    _MODES = ("min", "max")

    def __init__(self, patience=5, min_delta=0.0, mode="min"):
        """
        patience : number of consecutive non-improving steps tolerated
        min_delta: minimum change that counts as an improvement
        mode     : 'min' (e.g. a loss) or 'max' (e.g. a metric)

        Raises ValueError for any other ``mode`` (it used to be silently
        treated as 'max').
        """
        if mode not in self._MODES:
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def _beats_best(self, score):
        """True if ``score`` improves on ``best_score`` by more than ``min_delta``."""
        if self.mode == "min":
            return score < self.best_score - self.min_delta
        return score > self.best_score + self.min_delta

    def step(self, score):
        """Record one score; return True if it is a new best."""
        score_is_nan = score != score  # only NaN compares unequal to itself
        best_unset = self.best_score is None or self.best_score != self.best_score
        if not score_is_nan and (best_unset or self._beats_best(score)):
            self.best_score = score
            self.counter = 0
            return True

        if self.best_score is None:
            # First score was NaN: keep best_score numeric so callers can
            # still format it; the next real score replaces it.
            self.best_score = score
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False

@torch.no_grad()
def validate(model, loader, criterion, task_type=None, device=None):
    """Evaluate ``model`` on ``loader``; return ``(loss, dice, iou)`` as floats.

    All three numbers are averaged per *sample*, so a smaller final batch is not
    over-weighted (previously Dice/IoU were averaged per batch). Dice and IoU
    use hard predictions:
      * binary     - sigmoid(logit) > 0.5 against the 0/1 mask;
      * multiclass - argmax over classes, scored for foreground classes
                     1..C-1 (class 0 = background excluded) and averaged.
    With the 1e-6 smoothing term, a class empty in both prediction and ground
    truth scores 1.

    Args:
        model:     network returning logits (B, C, H, W); C == 1 for binary.
        loader:    iterable of (img, mask) batches, mask shaped (B, H, W).
        criterion: loss fed (B, H, W) logits for binary, (B, C, H, W) otherwise.
        task_type: "binary" | "multiclass"; defaults to ``TASK_TYPE``.
        device:    device to evaluate on; defaults to ``DEVICE``.

    Returns (0.0, 0.0, 0.0) for an empty loader.
    """
    task_type = TASK_TYPE if task_type is None else task_type
    device = DEVICE if device is None else device
    model.eval()

    eps = 1e-6

    def dice_iou(p, g):
        # p, g: (B, P) 0/1 float tensors -> per-sample Dice and IoU, each (B,)
        inter = (p * g).sum(dim=1)
        total = p.sum(dim=1) + g.sum(dim=1)
        return (2 * inter + eps) / (total + eps), (inter + eps) / (total - inter + eps)

    loss_sum = dice_sum = iou_sum = 0.0
    n = 0

    for img, mask in loader:
        img, mask = img.to(device), mask.to(device)
        logits = model(img)
        bs = img.size(0)

        if task_type == "binary":
            loss = criterion(logits.squeeze(1), mask)
            pred = (torch.sigmoid(logits) > 0.5).float()
            dice, iou = dice_iou(pred.reshape(bs, -1), mask.float().reshape(bs, -1))
        else:
            loss = criterion(logits, mask)
            # softmax is monotonic, so argmax over logits == argmax over probabilities
            labels = logits.argmax(dim=1)
            per_class = [
                dice_iou((labels == c).float().reshape(bs, -1),
                         (mask == c).float().reshape(bs, -1))
                for c in range(1, logits.shape[1])  # skip background class 0
            ]
            # (num_fg_classes, B) -> mean over classes -> (B,)
            dice = torch.stack([d for d, _ in per_class]).mean(dim=0)
            iou = torch.stack([j for _, j in per_class]).mean(dim=0)

        loss_sum += loss.item() * bs
        dice_sum += dice.sum().item()
        iou_sum += iou.sum().item()
        n += bs

    if n == 0:
        return 0.0, 0.0, 0.0
    return loss_sum / n, dice_sum / n, iou_sum / n


def _build_loaders(df):
    """Create the (train, val) DataLoaders from the split DataFrame."""
    train_tf, val_tf = get_transforms()
    df_tr = df[df["split"] == "train"]
    df_va = df[df["split"] == "val"]
    # Fail early with a clear message. An empty train split would otherwise
    # raise an opaque sampler error; an empty val split makes validate()
    # return a constant 0 loss, which turns early stopping meaningless.
    if df_tr.empty or df_va.empty:
        raise ValueError(
            f"train and val splits must be non-empty (train={len(df_tr)}, val={len(df_va)})"
        )
    loader_kw = dict(
        batch_size=BATCH_SIZE,
        num_workers=4,
        pin_memory=str(DEVICE).startswith("cuda"),  # pinning only helps GPU transfer
    )
    dl_tr = DataLoader(SegDataset(df_tr, train_tf), shuffle=True, **loader_kw)
    dl_va = DataLoader(SegDataset(df_va, val_tf), shuffle=False, **loader_kw)
    return dl_tr, dl_va


def _build_model_and_criterion():
    """U-Net (ResNet-50 encoder, ImageNet weights) emitting raw logits, plus a Dice loss."""
    binary = TASK_TYPE == "binary"
    model = smp.Unet(
        encoder_name="resnet50",
        encoder_weights="imagenet",
        in_channels=3,
        classes=1 if binary else len(LAYER_IDS) + 1,  # +1: background class 0
        activation=None,
    )
    criterion = smp.losses.DiceLoss(
        mode="binary" if binary else "multiclass", from_logits=True
    )
    return model, criterion


def _train_one_epoch(model, loader, criterion, optimizer, epoch):
    """Run one optimisation pass over ``loader``; return the mean batch loss."""
    model.train()
    losses = []
    pbar = tqdm(loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for img, mask in pbar:
        img, mask = img.to(DEVICE), mask.to(DEVICE)
        pred = model(img)
        # binary: drop the channel axis so logits match the (B, H, W) mask
        loss = criterion(pred.squeeze(1) if TASK_TYPE == "binary" else pred, mask)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_postfix(train_loss=f"{loss_value:.4f}")
    return float(np.mean(losses))


def train_model(df):
    """Train a U-Net on the 'train' split with early stopping on val loss.

    Whenever ``EarlyStopping`` reports a new best validation loss, the weights
    are written to ``OUT_DIR/best_model.pt``. Those best weights are loaded
    back before the model is returned. If no checkpoint was written in this
    run (e.g. ``EPOCHS == 0``), the current weights are returned rather than a
    stale file from an earlier run.

    Raises:
        ValueError: if the train or val split is empty.
    """
    dl_tr, dl_va = _build_loaders(df)
    model, criterion = _build_model_and_criterion()
    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    early_stopper = EarlyStopping(patience=5, min_delta=1e-5, mode="min")
    best_path = os.path.join(OUT_DIR, "best_model.pt")
    saved_this_run = False

    for epoch in range(EPOCHS):
        train_loss = _train_one_epoch(model, dl_tr, criterion, optimizer, epoch)
        val_loss, val_dice, val_iou = validate(model, dl_va, criterion)

        if early_stopper.step(val_loss):
            torch.save(model.state_dict(), best_path)
            saved_this_run = True

        print(
            f"Epoch {epoch+1}/{EPOCHS} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"val_dice={val_dice:.4f} | "
            f"val_iou={val_iou:.4f} | "
            f"best={early_stopper.best_score:.4f} | "
            f"patience={early_stopper.counter}/{early_stopper.patience}"
        )

        if early_stopper.early_stop:
            print("[Early Stop] Training stopped.")
            break

    if not saved_this_run:
        print("[WARN] No checkpoint saved in this run; returning current weights.")
        return model

    print(f"[DONE] Best model saved to {best_path}")
    # map_location: place the tensors on the training device regardless of origin.
    model.load_state_dict(torch.load(best_path, map_location=DEVICE))
    return model

In [None]:
# Train with early stopping; `model` ends up holding the restored best checkpoint.
model = train_model(df=df)

In [None]:
def save_training_env(out_dir):
    """Write the run's hyper-parameters and software versions to JSON.

    The file ``<out_dir>/training_env.json`` records:
      * the CONFIG constants;
      * the data locations;
      * the versions of the libraries that define the model and preprocessing.
    Together these are needed to reproduce or reload the run. Keys written by
    earlier versions of this function are kept unchanged.

    Args:
        out_dir: existing directory to write ``training_env.json`` into.
    """
    def _version(module):
        # Not every package exposes __version__; record "unknown" rather than fail.
        return getattr(module, "__version__", "unknown")

    cuda_ok = torch.cuda.is_available()
    env = {
        "TASK_TYPE": TASK_TYPE,
        "LAYER_IDS": LAYER_IDS,
        "PATCH_SIZE": PATCH_SIZE,
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "LR": LR,
        "TEST_RATIO": TEST_RATIO,
        "VAL_RATIO": VAL_RATIO,
        # NOTE(review): hard-coded and not used elsewhere in this notebook;
        # presumably the magnification the patches were extracted at - confirm.
        "target_mag": 10.0,
        "DEVICE": DEVICE,
        "cuda_available": cuda_ok,
        "torch_version": torch.__version__,
        "python_version": platform.python_version(),
        # Additional provenance: data locations and library versions.
        "ROOT_DIR": ROOT_DIR,
        "CSV_PATH": CSV_PATH,
        "gpu_name": torch.cuda.get_device_name(0) if cuda_ok else None,
        "numpy_version": np.__version__,
        "pandas_version": pd.__version__,
        "opencv_version": _version(cv2),
        "albumentations_version": _version(A),
        "smp_version": _version(smp),
    }

    save_path = os.path.join(out_dir, "training_env.json")
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(env, f, indent=2)

    print(f"[OK] Training environment saved: {save_path}")
save_training_env(OUT_DIR)