In [14]:
import os
import cv2
import random
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import platform
import json

# Datasets split

In [3]:
# ============================================================
# CONFIG
# ============================================================
ROOT_DIR = "/home/khdp-user/workspace/Infla_patch"
TASK_TYPE = "binary"      # "binary" | "multiclass"
LAYER_IDS = [1]           # binary면 Len == 1
PATCH_SIZE = 512
BATCH_SIZE = 16
EPOCHS = 50
LR = 1e-4
TEST_RATIO = 0.2
VAL_RATIO = 0.2   # train 중 val 비율
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
OUT_DIR = "/home/khdp-user/workspace/Infla_run_seg"
CSV_PATH = os.path.join(OUT_DIR, "dataset.csv")
os.makedirs(OUT_DIR, exist_ok=True)

assert TASK_TYPE in ["binary", "multiclass"]
if TASK_TYPE == "binary":
    assert len(LAYER_IDS) == 1

In [4]:
# ============================================================
# DATASET SPLIT
# ============================================================
def build_dataset_csv():
    root = Path(ROOT_DIR)
    slides = sorted([p for p in root.iterdir() if p.is_dir()])
    slide_ids = [s.name for s in slides]

    random.shuffle(slide_ids)

    n = len(slide_ids)
    n_test = int(n * TEST_RATIO)
    n_val = int((n - n_test) * VAL_RATIO)

    test_slides = set(slide_ids[:n_test])
    val_slides = set(slide_ids[n_test:n_test+n_val])
    train_slides = set(slide_ids[n_test+n_val:])

    def split_of(slide):
        if slide in train_slides: return "train"
        if slide in val_slides: return "val"
        return "test"

    rows = []

    for slide in slides:
        slide_id = slide.name
        split = split_of(slide_id)

        img_dir = slide / "images"
        mask_root = slide / "masks"

        for img_path in img_dir.glob("*.png"):
            # LAYER_IDS 중 하나라도 mask 있으면 포함
            valid = False
            for lid in LAYER_IDS:
                mp = mask_root / f"layer{lid}" / img_path.name
                if mp.exists():
                    valid = True
                    break
            if not valid:
                continue

            rows.append({
                "name": img_path.name,
                "path": str(img_path),
                "split": split,
            })

    df = pd.DataFrame(rows)
    df.to_csv(CSV_PATH, index=False)
    print(f"[OK] CSV saved: {CSV_PATH}  (patches={len(df)})")
    return df

In [5]:
df = build_dataset_csv()

[OK] CSV saved: /home/khdp-user/workspace/Infla_run_seg/dataset.csv  (patches=328)


In [7]:
pd.read_csv(CSV_PATH).value_counts('split')

split
train    232
test      70
val       26
Name: count, dtype: int64

# Training loop

In [8]:
# ============================================================
# DATASET
# ============================================================
class SegDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img = cv2.imread(row["path"])
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        slide_dir = Path(row["path"]).parents[1]
        mask_root = slide_dir / "masks"

        masks = []
        for lid in LAYER_IDS:
            mp = mask_root / f"layer{lid}" / row["name"]
            if mp.exists():
                m = cv2.imread(str(mp), 0) > 127
            else:
                m = np.zeros(img.shape[:2], bool)
            masks.append(m)

        if TASK_TYPE == "binary":
            mask = masks[0].astype(np.float32)
        else:
            # multiclass (background=0)
            mask = np.zeros(img.shape[:2], np.int64)
            for i, m in enumerate(masks):
                mask[m] = i + 1

        if self.transform:
            out = self.transform(image=img, mask=mask)
            img, mask = out["image"], out["mask"]
        
        img = img.float()
        if TASK_TYPE == "binary":
            mask = mask.float()     
        else:
            mask = mask.long()
        return img, mask

    
# ============================================================
# TRANSFORMS
# ============================================================
def get_transforms():
    train_tf = A.Compose([
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(),
        A.Resize(PATCH_SIZE,PATCH_SIZE),
        ToTensorV2(),
    ])
    val_tf = A.Compose([
        A.Resize(PATCH_SIZE,PATCH_SIZE),
        A.Normalize(),
        ToTensorV2(),
    ])
    return train_tf, val_tf

# ============================================================
# TRAINING
# ============================================================
class EarlyStopping:
    def __init__(self, patience=5, min_delta=0.0, mode="min"):
        """
        patience : 개선 없이 기다릴 epoch 수
        min_delta: 개선으로 인정할 최소 변화량
        mode     : 'min' (loss) or 'max' (metric)
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def step(self, score):
        if self.best_score is None:
            self.best_score = score
            return True  # best 갱신

        improved = (
            score < self.best_score - self.min_delta
            if self.mode == "min"
            else score > self.best_score + self.min_delta
        )

        if improved:
            self.best_score = score
            self.counter = 0
            return True
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return False

@torch.no_grad()
def validate(model, loader, criterion):
    model.eval()

    total_loss = 0.0
    n = 0

    dice_list = []
    iou_list = []

    for img, mask in loader:
        img, mask = img.to(DEVICE), mask.to(DEVICE)
        pred = model(img)

        if TASK_TYPE == "binary":
            loss = criterion(pred.squeeze(1), mask)
            prob = torch.sigmoid(pred)
            pred_bin = (prob > 0.5).float()
            gt = mask.unsqueeze(1)
        else:
            loss = criterion(pred, mask)
            prob = torch.softmax(pred, dim=1)
            pred_bin = torch.argmax(prob, dim=1)
            gt = mask

        total_loss += loss.item() * img.size(0)
        n += img.size(0)

        # ---- Dice / IoU (batch 평균)
        if TASK_TYPE == "binary":
            p = pred_bin.view(pred_bin.size(0), -1)
            g = gt.view(gt.size(0), -1)

            inter = (p * g).sum(dim=1)
            union = p.sum(dim=1) + g.sum(dim=1)

            dice = (2 * inter + 1e-6) / (union + 1e-6)
            iou = (inter + 1e-6) / (p.sum(dim=1) + g.sum(dim=1) - inter + 1e-6)

            dice_list.append(dice.mean().item())
            iou_list.append(iou.mean().item())
        else:
            num_classes = pred.shape[1]
            dice_per_class = []
            iou_per_class = []

            for c in range(1, num_classes):  # background 제외
                p = (pred_bin == c).float().view(pred_bin.size(0), -1)
                g = (gt == c).float().view(gt.size(0), -1)

                inter = (p * g).sum(dim=1)
                union = p.sum(dim=1) + g.sum(dim=1)

                dice = (2 * inter + 1e-6) / (union + 1e-6)
                iou  = (inter + 1e-6) / (p.sum(dim=1) + g.sum(dim=1) - inter + 1e-6)

                dice_per_class.append(dice.mean())
                iou_per_class.append(iou.mean())
            dice_list.append(torch.stack(dice_per_class).mean().item())
            iou_list.append(torch.stack(iou_per_class).mean().item())



    val_loss = total_loss / max(n, 1)
    val_dice = float(np.mean(dice_list)) if dice_list else 0.0
    val_iou  = float(np.mean(iou_list)) if iou_list else 0.0

    return val_loss, val_dice, val_iou


def train_model(df):
    train_tf, val_tf = get_transforms()

    df_tr = df[df.split == "train"]
    df_va = df[df.split == "val"]

    ds_tr = SegDataset(df_tr, train_tf)
    ds_va = SegDataset(df_va, val_tf)

    dl_tr = DataLoader(
        ds_tr, batch_size=BATCH_SIZE, shuffle=True,
        num_workers=4, pin_memory=True
    )
    dl_va = DataLoader(
        ds_va, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=4, pin_memory=True
    )

    # model / loss
    if TASK_TYPE == "binary":
        model = smp.Unet(
            encoder_name="resnet50",
            encoder_weights="imagenet",
            in_channels=3,
            classes=1,
            activation=None,
        )
        criterion = smp.losses.DiceLoss(mode="binary", from_logits=True)
    else:
        model = smp.Unet(
            encoder_name="resnet50",
            encoder_weights="imagenet",
            in_channels=3,
            classes=len(LAYER_IDS) + 1,
            activation=None,
        )
        criterion = smp.losses.DiceLoss(mode="multiclass", from_logits=True)

    model.to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
    early_stopper = EarlyStopping(patience=5, min_delta=1e-5, mode="min")
    best_path = os.path.join(OUT_DIR, "best_model.pt")
    for epoch in range(EPOCHS):
        # ==================
        # Train
        # ==================
        model.train()
        train_losses = []
        pbar = tqdm(dl_tr, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for img, mask in pbar:
            img, mask = img.to(DEVICE), mask.to(DEVICE)

            pred = model(img)
            if TASK_TYPE == "binary":
                loss = criterion(pred.squeeze(1), mask)
            else:
                loss = criterion(pred, mask)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_losses.append(loss.item())
            pbar.set_postfix(train_loss=f"{loss.item():.4f}")

        train_loss = float(np.mean(train_losses))
        # ==================
        # Validation
        # ==================
        val_loss, val_dice, val_iou = validate(model, dl_va, criterion)
        # ==================
        # Early Stopping
        # ==================
        is_best = early_stopper.step(val_loss)
        if is_best:
            torch.save(model.state_dict(), best_path)

        # ==================
        # Logging (핵심)
        # ==================
        print(
            f"Epoch {epoch+1}/{EPOCHS} | "
            f"train_loss={train_loss:.4f} | "
            f"val_loss={val_loss:.4f} | "
            f"val_dice={val_dice:.4f} | "
            f"val_iou={val_iou:.4f} | "
            f"best={early_stopper.best_score:.4f} | "
            f"patience={early_stopper.counter}/{early_stopper.patience}"
        )

        if early_stopper.early_stop:
            print("[Early Stop] Training stopped.")
            break


    print(f"[DONE] Best model saved to {best_path}")
    model.load_state_dict(torch.load(best_path))
    return model

In [9]:
model = train_model(df)

Epoch 1/50: 100%|██████████| 15/15 [00:16<00:00,  1.07s/it, train_loss=0.6762]


Epoch 1/50 | train_loss=0.6407 | val_loss=0.6762 | val_dice=0.4515 | val_iou=0.3049 | best=0.6762 | patience=0/5


Epoch 2/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.5026]


Epoch 2/50 | train_loss=0.5302 | val_loss=0.5913 | val_dice=0.5588 | val_iou=0.4072 | best=0.5913 | patience=0/5


Epoch 3/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.3818]


Epoch 3/50 | train_loss=0.4677 | val_loss=0.5521 | val_dice=0.6208 | val_iou=0.4695 | best=0.5521 | patience=0/5


Epoch 4/50: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s, train_loss=0.4231]


Epoch 4/50 | train_loss=0.4326 | val_loss=0.5573 | val_dice=0.5660 | val_iou=0.4127 | best=0.5521 | patience=1/5


Epoch 5/50: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s, train_loss=0.5151]


Epoch 5/50 | train_loss=0.4179 | val_loss=0.5194 | val_dice=0.6717 | val_iou=0.5332 | best=0.5194 | patience=0/5


Epoch 6/50: 100%|██████████| 15/15 [00:13<00:00,  1.08it/s, train_loss=0.3963]


Epoch 6/50 | train_loss=0.3881 | val_loss=0.5501 | val_dice=0.5786 | val_iou=0.4384 | best=0.5194 | patience=1/5


Epoch 7/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.3971]


Epoch 7/50 | train_loss=0.3623 | val_loss=0.5297 | val_dice=0.6166 | val_iou=0.4638 | best=0.5194 | patience=2/5


Epoch 8/50: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s, train_loss=0.3031]


Epoch 8/50 | train_loss=0.3420 | val_loss=0.5142 | val_dice=0.6428 | val_iou=0.4973 | best=0.5142 | patience=0/5


Epoch 9/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2825]


Epoch 9/50 | train_loss=0.3215 | val_loss=0.5049 | val_dice=0.6202 | val_iou=0.4677 | best=0.5049 | patience=0/5


Epoch 10/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2339]


Epoch 10/50 | train_loss=0.3017 | val_loss=0.4844 | val_dice=0.6196 | val_iou=0.4688 | best=0.4844 | patience=0/5


Epoch 11/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.3752]


Epoch 11/50 | train_loss=0.2926 | val_loss=0.4862 | val_dice=0.6260 | val_iou=0.4845 | best=0.4844 | patience=1/5


Epoch 12/50: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s, train_loss=0.2416]


Epoch 12/50 | train_loss=0.2705 | val_loss=0.5043 | val_dice=0.6055 | val_iou=0.4564 | best=0.4844 | patience=2/5


Epoch 13/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2638]


Epoch 13/50 | train_loss=0.2671 | val_loss=0.4777 | val_dice=0.6263 | val_iou=0.4860 | best=0.4777 | patience=0/5


Epoch 14/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2148]


Epoch 14/50 | train_loss=0.2460 | val_loss=0.4634 | val_dice=0.6360 | val_iou=0.4843 | best=0.4634 | patience=0/5


Epoch 15/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2305]


Epoch 15/50 | train_loss=0.2468 | val_loss=0.4755 | val_dice=0.5944 | val_iou=0.4441 | best=0.4634 | patience=1/5


Epoch 16/50: 100%|██████████| 15/15 [00:14<00:00,  1.07it/s, train_loss=0.2584]


Epoch 16/50 | train_loss=0.2320 | val_loss=0.4680 | val_dice=0.6002 | val_iou=0.4518 | best=0.4634 | patience=2/5


Epoch 17/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1692]


Epoch 17/50 | train_loss=0.2067 | val_loss=0.4549 | val_dice=0.6240 | val_iou=0.5019 | best=0.4549 | patience=0/5


Epoch 18/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2084]


Epoch 18/50 | train_loss=0.1979 | val_loss=0.4611 | val_dice=0.6002 | val_iou=0.4604 | best=0.4549 | patience=1/5


Epoch 19/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.3750]


Epoch 19/50 | train_loss=0.1990 | val_loss=0.4365 | val_dice=0.6316 | val_iou=0.5018 | best=0.4365 | patience=0/5


Epoch 20/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2378]


Epoch 20/50 | train_loss=0.1871 | val_loss=0.4715 | val_dice=0.5876 | val_iou=0.4623 | best=0.4365 | patience=1/5


Epoch 21/50: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s, train_loss=0.1629]


Epoch 21/50 | train_loss=0.1806 | val_loss=0.4420 | val_dice=0.6161 | val_iou=0.4654 | best=0.4365 | patience=2/5


Epoch 22/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1487]


Epoch 22/50 | train_loss=0.1706 | val_loss=0.4220 | val_dice=0.6301 | val_iou=0.5172 | best=0.4220 | patience=0/5


Epoch 23/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.2270]


Epoch 23/50 | train_loss=0.1689 | val_loss=0.4054 | val_dice=0.6598 | val_iou=0.5349 | best=0.4054 | patience=0/5


Epoch 24/50: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s, train_loss=0.1423]


Epoch 24/50 | train_loss=0.1602 | val_loss=0.4394 | val_dice=0.6112 | val_iou=0.4605 | best=0.4054 | patience=1/5


Epoch 25/50: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s, train_loss=0.1112]


Epoch 25/50 | train_loss=0.1533 | val_loss=0.4270 | val_dice=0.6321 | val_iou=0.5088 | best=0.4054 | patience=2/5


Epoch 26/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1278]


Epoch 26/50 | train_loss=0.1521 | val_loss=0.4144 | val_dice=0.6393 | val_iou=0.5046 | best=0.4054 | patience=3/5


Epoch 27/50: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s, train_loss=0.1435]


Epoch 27/50 | train_loss=0.1462 | val_loss=0.4010 | val_dice=0.6499 | val_iou=0.5215 | best=0.4010 | patience=0/5


Epoch 28/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1144]


Epoch 28/50 | train_loss=0.1376 | val_loss=0.4247 | val_dice=0.6065 | val_iou=0.4521 | best=0.4010 | patience=1/5


Epoch 29/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1504]


Epoch 29/50 | train_loss=0.1326 | val_loss=0.4229 | val_dice=0.6093 | val_iou=0.4819 | best=0.4010 | patience=2/5


Epoch 30/50: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s, train_loss=0.0983]


Epoch 30/50 | train_loss=0.1289 | val_loss=0.4331 | val_dice=0.5927 | val_iou=0.4471 | best=0.4010 | patience=3/5


Epoch 31/50: 100%|██████████| 15/15 [00:14<00:00,  1.05it/s, train_loss=0.1124]


Epoch 31/50 | train_loss=0.1222 | val_loss=0.3879 | val_dice=0.6473 | val_iou=0.5051 | best=0.3879 | patience=0/5


Epoch 32/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1394]


Epoch 32/50 | train_loss=0.1157 | val_loss=0.3671 | val_dice=0.6718 | val_iou=0.5323 | best=0.3671 | patience=0/5


Epoch 33/50: 100%|██████████| 15/15 [00:13<00:00,  1.09it/s, train_loss=0.1874]


Epoch 33/50 | train_loss=0.1173 | val_loss=0.3950 | val_dice=0.6380 | val_iou=0.4954 | best=0.3671 | patience=1/5


Epoch 34/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1053]


Epoch 34/50 | train_loss=0.1136 | val_loss=0.3744 | val_dice=0.6689 | val_iou=0.5252 | best=0.3671 | patience=2/5


Epoch 35/50: 100%|██████████| 15/15 [00:14<00:00,  1.01it/s, train_loss=0.1383]


Epoch 35/50 | train_loss=0.1134 | val_loss=0.3872 | val_dice=0.6544 | val_iou=0.5182 | best=0.3671 | patience=3/5


Epoch 36/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1115]


Epoch 36/50 | train_loss=0.1029 | val_loss=0.3609 | val_dice=0.6673 | val_iou=0.5276 | best=0.3609 | patience=0/5


Epoch 37/50: 100%|██████████| 15/15 [00:14<00:00,  1.03it/s, train_loss=0.1072]


Epoch 37/50 | train_loss=0.1005 | val_loss=0.3879 | val_dice=0.6421 | val_iou=0.5090 | best=0.3609 | patience=1/5


Epoch 38/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1027]


Epoch 38/50 | train_loss=0.0987 | val_loss=0.3816 | val_dice=0.6500 | val_iou=0.5088 | best=0.3609 | patience=2/5


Epoch 39/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.1182]


Epoch 39/50 | train_loss=0.0910 | val_loss=0.4038 | val_dice=0.6219 | val_iou=0.4890 | best=0.3609 | patience=3/5


Epoch 40/50: 100%|██████████| 15/15 [00:14<00:00,  1.02it/s, train_loss=0.0652]


Epoch 40/50 | train_loss=0.0904 | val_loss=0.3927 | val_dice=0.6288 | val_iou=0.4965 | best=0.3609 | patience=4/5


Epoch 41/50: 100%|██████████| 15/15 [00:14<00:00,  1.04it/s, train_loss=0.0769]


Epoch 41/50 | train_loss=0.0925 | val_loss=0.3752 | val_dice=0.6541 | val_iou=0.5157 | best=0.3609 | patience=5/5
[Early Stop] Training stopped.
[DONE] Best model saved to /home/khdp-user/workspace/Infla_run_seg/best_model.pt


In [15]:
def save_training_env(out_dir):
    env = {
        "TASK_TYPE": TASK_TYPE,
        "LAYER_IDS": LAYER_IDS,
        "PATCH_SIZE": PATCH_SIZE,
        "BATCH_SIZE": BATCH_SIZE,
        "EPOCHS": EPOCHS,
        "LR": LR,
        "TEST_RATIO": TEST_RATIO,
        "VAL_RATIO": VAL_RATIO,
        "target_mag": 10.0,
        "DEVICE": DEVICE,
        "cuda_available": torch.cuda.is_available(),
        "torch_version": torch.__version__,
        "python_version": platform.python_version(),
    }

    save_path = os.path.join(out_dir, "training_env.json")
    with open(save_path, "w") as f:
        json.dump(env, f, indent=2)

    print(f"[OK] Training environment saved: {save_path}")
save_training_env(OUT_DIR)

[OK] Training environment saved: /home/khdp-user/workspace/Infla_run_seg/training_env.json
