In [None]:
import cv2
import numpy as np
import pandas as pd
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import segmentation_models_pytorch as smp

In [None]:
# Paths (environment-specific; adjust for your workspace).
CSV_PATH = "/home/khdp-user/workspace/IFTA_run_seg/dataset.csv"
WEIGHT_PATH = "/home/khdp-user/workspace/IFTA_run_seg/best_model.pt"
SVS_DIR = Path("/home/khdp-user/workspace/dataset/Slide")

# Annotation layers to evaluate. A single id builds a binary (1-channel, sigmoid)
# model; several ids build a multiclass model with background (class 0) plus one
# class per id, in this order (class k <-> LAYER_IDS[k-1]).
LAYER_IDS = [1, 2, 3]
# NOTE(review): TASK_TYPE is not read anywhere in this notebook; the task type is
# inferred from len(LAYER_IDS). Keep the two consistent by hand.
TASK_TYPE = "multiclass"
PATCH_SIZE = 256   # patch side length in pixels
TARGET_MAG = 10.0  # magnification of the stitched slide canvas (presumably the patch extraction magnification)
BATCH_SIZE = 1     # NOTE(review): currently unused; inference runs one patch at a time
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
THRESH = 0.5       # sigmoid probability threshold, used only in binary mode
AGG = "max"        # NOTE(review): unused; stitching always merges overlaps with a pixel-wise max

In [None]:
# -------------------------
# Utils
# -------------------------
def parse_xy_from_name(patch_name: str):
    """Return the integer (x, y) origin encoded in a patch file name.

    The last two underscore-separated fields of the file stem are the
    coordinates, e.g. ``..._X0Y0_034816_010240.png`` -> ``(34816, 10240)``.
    """
    tokens = Path(patch_name).stem.split("_")
    return int(tokens[-2]), int(tokens[-1])


def dice_iou(pred01: np.ndarray, gt01: np.ndarray, eps=1e-6):
    """Dice coefficient and IoU between two masks (non-zero = foreground).

    ``eps`` is added to numerator and denominator, so two empty masks score
    (1.0, 1.0) instead of dividing by zero.

    Returns:
        (dice, iou) as Python floats.
    """
    p = pred01.astype(bool)
    g = gt01.astype(bool)

    overlap = (p & g).sum()
    combined = (p | g).sum()
    mass = p.sum() + g.sum()

    return float((2 * overlap + eps) / (mass + eps)), float((overlap + eps) / (combined + eps))


def safe_img_path(row):
    """Resolve the image file for a CSV row.

    The ``path`` column may already point at the image file itself, or at the
    directory that contains it; in the latter case ``name`` is appended.
    """
    base = Path(row["path"])
    return base if base.is_file() else base / row["name"]


def get_slide_id_from_name(patch_name: str):
    """Return everything before the first ``_PAS`` in a patch name.

    e.g. ``11_01_0144_PAS_...`` -> ``11_01_0144``; names without ``_PAS`` are
    returned unchanged.
    """
    head, _sep, _tail = patch_name.partition("_PAS")
    return head


# -------------------------
# Model
# -------------------------
def build_model(layer_ids):
    """Build the ResNet-50 U-Net and load the trained weights from WEIGHT_PATH.

    Args:
        layer_ids: annotation layer ids the model was trained on. A single id
            selects a binary head (1 output channel); several ids select a
            multiclass head with one background channel plus one per id.

    Returns:
        (model, is_binary): the model moved to DEVICE and set to eval mode,
        and whether the single-channel binary head was built.
    """
    is_binary = len(layer_ids) == 1
    # Multiclass: channel 0 = background, channels 1..K follow layer_ids order.
    n_out = 1 if is_binary else len(layer_ids) + 1

    net = smp.Unet(
        encoder_name="resnet50",
        # No pretrained encoder here: even if training started from ImageNet,
        # the full trained weights are loaded from the checkpoint below.
        encoder_weights=None,
        in_channels=3,
        classes=n_out,
        activation=None,  # raw logits; the eval code applies sigmoid / argmax
    )
    state_dict = torch.load(WEIGHT_PATH, map_location=DEVICE)
    net.load_state_dict(state_dict)
    net.to(DEVICE)
    net.eval()
    return net, is_binary


# -------------------------
# Patch-level eval (optional)
# -------------------------
@torch.no_grad()
def eval_patch_level(model, is_binary, df_test):
    """Compute Dice/IoU for every test patch independently.

    For each row of ``df_test`` the patch image is read, run through ``model``
    and compared with the per-layer GT masks under
    ``<img_path.parents[1]>/masks/layer{lid}/<patch name>``. A missing mask
    file is treated as all background; unreadable patch images are skipped.

    Args:
        model: segmentation network; expected to return logits shaped
            (1, C, H, W) for a (1, 3, H, W) input.
        is_binary: True for a 1-channel sigmoid head scored against
            LAYER_IDS[0]; False for a multiclass head whose class k (1..K)
            corresponds to LAYER_IDS[k-1] (class 0 is background).
        df_test: DataFrame with at least ``path`` and ``name`` columns.

    Returns:
        DataFrame with one row per (patch, layer) and columns
        slide_id, layer, dice, iou (empty if no patch could be read).

    Raises:
        IOError: if a GT mask file exists but cannot be decoded.
    """

    def _load_gt(mask_path, shape):
        # Missing mask file => this layer has no foreground in the patch.
        if not mask_path.exists():
            # Match the prediction's shape rather than assuming PATCH_SIZE,
            # so non-square / non-256 patches still compare correctly.
            return np.zeros(shape, bool)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            # Fail loudly with the path instead of an opaque `None > 127` TypeError.
            raise IOError(f"Cannot read GT mask: {mask_path}")
        return mask > 127

    rows = []

    for _, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Patch-level eval"):
        img_path = safe_img_path(row)
        patch_name = row["name"]
        slide_id = get_slide_id_from_name(patch_name)

        img = cv2.imread(str(img_path))
        if img is None:
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC uint8 -> 1x3xHxW float in [0, 1]
        x = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) / 255.
        logits = model(x.to(DEVICE))

        # Presumed layout: .../<SLIDE_ID>_PAS/<img dir>/<name> with masks in .../<SLIDE_ID>_PAS/masks
        mask_root = img_path.parents[1] / "masks"

        if is_binary:
            lid = LAYER_IDS[0]
            pred = torch.sigmoid(logits)[0, 0].cpu().numpy() > THRESH
            gt = _load_gt(mask_root / f"layer{lid}" / patch_name, pred.shape)
            d, i = dice_iou(pred, gt)
            rows.append({"slide_id": slide_id, "layer": lid, "dice": d, "iou": i})
        else:
            # softmax is monotonic, so argmax over logits equals argmax over probabilities.
            pred_cls = torch.argmax(logits, dim=1)[0].cpu().numpy()  # 0 = background, k -> LAYER_IDS[k-1]
            for k, lid in enumerate(LAYER_IDS, start=1):
                gt = _load_gt(mask_root / f"layer{lid}" / patch_name, pred_cls.shape)
                d, i = dice_iou(pred_cls == k, gt)
                rows.append({"slide_id": slide_id, "layer": lid, "dice": d, "iou": i})

    return pd.DataFrame(rows)


# -------------------------
# Slide-level stitching eval (true WSI-level at TARGET_MAG)
# -------------------------
@torch.no_grad()
def eval_slide_level_stitching(model, is_binary, df_slide):
    """Stitch one slide's patch predictions and GT masks into slide canvases and score them.

    The slide's level-0 size and ``aperio.AppMag`` (default 40.0) are read from
    ``SVS_DIR/<slide_id>_PAS.svs``. Canvases are allocated at TARGET_MAG. Each
    patch is placed at its (x0, y0) origin, parsed from the file name and
    divided by ``AppMag / TARGET_MAG``. Overlapping patches are merged with a
    pixel-wise max. Patches that overhang the canvas edge are clipped rather
    than dropped, so their in-bounds pixels still count.

    Args:
        model: segmentation network as returned by ``build_model``.
        is_binary: True for a 1-channel sigmoid head (layer LAYER_IDS[0]);
            False for a multiclass head where class k maps to LAYER_IDS[k-1].
        df_slide: non-empty rows of the test CSV that belong to one slide.

    Returns:
        None if the .svs file is missing. Otherwise, in binary mode, one dict
        ``{"slide_id", "layer", "dice", "iou"}``; in multiclass mode, a list of
        such dicts with one entry per LAYER_IDS element.

    Raises:
        IOError: if a GT mask file exists but cannot be decoded.
    """
    import openslide  # used here but not imported in the notebook's import cell

    def _load_gt(mask_path, shape):
        # Missing mask file => all background; size it like the prediction.
        if not mask_path.exists():
            return np.zeros(shape, np.uint8)
        mask = cv2.imread(str(mask_path), cv2.IMREAD_GRAYSCALE)
        if mask is None:
            raise IOError(f"Cannot read GT mask: {mask_path}")
        return (mask > 127).astype(np.uint8)

    def _paste_max(canvas, patch, xt, yt):
        # Max-merge `patch` into `canvas` with its top-left corner at (yt, xt), clipped to the canvas.
        y_lo, x_lo = max(yt, 0), max(xt, 0)
        y_hi = min(yt + patch.shape[0], canvas.shape[0])
        x_hi = min(xt + patch.shape[1], canvas.shape[1])
        if y_hi <= y_lo or x_hi <= x_lo:
            return
        region = canvas[y_lo:y_hi, x_lo:x_hi]  # view: updated in place
        np.maximum(region, patch[y_lo - yt:y_hi - yt, x_lo - xt:x_hi - xt], out=region)

    slide_id = get_slide_id_from_name(df_slide.iloc[0]["name"])
    svs_path = SVS_DIR / f"{slide_id}_PAS.svs"
    if not svs_path.exists():
        return None

    # Only metadata is needed from the WSI: release the handle before the
    # inference loop, and also if reading the metadata fails.
    slide = openslide.OpenSlide(str(svs_path))
    try:
        base_mag = float(slide.properties.get("aperio.AppMag", 40.0))
        W0, H0 = slide.level_dimensions[0]
    finally:
        slide.close()

    down = base_mag / float(TARGET_MAG)  # level-0 pixels per TARGET_MAG pixel
    Wt, Ht = int(W0 / down), int(H0 / down)

    # Binary mode scores only LAYER_IDS[0]; multiclass class k maps to LAYER_IDS[k-1].
    layer_ids = LAYER_IDS[:1] if is_binary else list(LAYER_IDS)
    # NOTE(review): two full-slide uint8 canvases per layer; for large WSIs this can take several GB.
    pred_canvas = {lid: np.zeros((Ht, Wt), np.uint8) for lid in layer_ids}
    gt_canvas = {lid: np.zeros((Ht, Wt), np.uint8) for lid in layer_ids}

    for _, row in df_slide.iterrows():
        img_path = safe_img_path(row)
        patch_name = row["name"]

        img = cv2.imread(str(img_path))
        if img is None:
            continue

        x0, y0 = parse_xy_from_name(patch_name)
        xt, yt = int(x0 / down), int(y0 / down)
        if xt >= Wt or yt >= Ht:
            continue  # patch starts past the canvas: nothing to paste, skip inference

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # HWC uint8 -> 1x3xHxW float in [0, 1]
        x = torch.from_numpy(img).permute(2, 0, 1).float().unsqueeze(0) / 255.
        logits = model(x.to(DEVICE))

        if is_binary:
            prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
            pred_masks = {layer_ids[0]: (prob > THRESH).astype(np.uint8)}
        else:
            # softmax is monotonic, so argmax over logits equals argmax over probabilities.
            pred_cls = torch.argmax(logits, dim=1)[0].cpu().numpy()  # 0 = background
            pred_masks = {
                lid: (pred_cls == k).astype(np.uint8)
                for k, lid in enumerate(layer_ids, start=1)
            }

        # Presumed layout: masks live in .../<SLIDE_ID>_PAS/masks/layer{lid}/<name>
        mask_root = img_path.parents[1] / "masks"
        for lid, pred_patch in pred_masks.items():
            gt_patch = _load_gt(mask_root / f"layer{lid}" / patch_name, pred_patch.shape)
            _paste_max(pred_canvas[lid], pred_patch, xt, yt)
            _paste_max(gt_canvas[lid], gt_patch, xt, yt)

    results = []
    for lid in layer_ids:
        d, i = dice_iou(pred_canvas[lid], gt_canvas[lid])
        results.append({"slide_id": slide_id, "layer": lid, "dice": d, "iou": i})

    # Return shape expected by main(): a single dict in binary mode, a list otherwise.
    return results[0] if is_binary else results




In [None]:
# -------------------------
# Main
# -------------------------
def main():
    """Evaluate the model on the CSV's test split and print summary tables.

    Prints per-layer mean Dice/IoU at patch level, then at slide level after
    stitching patches onto a TARGET_MAG canvas. In multiclass mode it also
    prints a macro average: mean over layers within each slide, then over slides.

    Raises:
        RuntimeError: if no CSV row has split == "test" (case-insensitive).
    """
    df = pd.read_csv(CSV_PATH)
    is_test = df["split"].astype(str).str.lower() == "test"
    df_test = df[is_test].reset_index(drop=True)
    if not len(df_test):
        raise RuntimeError("No test rows found in CSV (split == test).")

    model, is_binary = build_model(LAYER_IDS)
    task_label = 'binary(sigmoid)' if is_binary else 'multiclass(softmax)'
    print(f"[INFO] TASK: {task_label} | LAYER_IDS={LAYER_IDS}")
    print(f"[INFO] DEVICE: {DEVICE} | #test patches: {len(df_test)}")

    # ---- patch level ----
    patch_df = eval_patch_level(model, is_binary, df_test)
    # Optional: patch_df.to_csv("test_patch_metrics.csv", index=False)
    print("\n[PATCH-LEVEL MEAN]")
    if patch_df.empty:
        print("No patch metrics computed (check paths).")
    else:
        print(patch_df.groupby("layer")[["dice","iou"]].mean())

    # ---- slide level (stitched) ----
    slide_keys = df_test["name"].astype(str).apply(get_slide_id_from_name)
    slide_rows = []
    for _slide_id, df_slide in tqdm(df_test.groupby(slide_keys), desc="Slide stitching"):
        res = eval_slide_level_stitching(model, is_binary, df_slide)
        if res is not None:
            # binary -> one dict per slide; multiclass -> list of per-layer dicts
            slide_rows.extend([res] if is_binary else res)

    slide_df = pd.DataFrame(slide_rows)
    # Optional: slide_df.to_csv("test_slide_stitching_metrics.csv", index=False)

    print("\n[SLIDE-LEVEL (STITCHING) MEAN]")
    if slide_df.empty:
        print("No slide metrics computed (check SVS paths / naming).")
    else:
        print(slide_df.groupby("layer")[["dice","iou"]].mean())
        if not is_binary:
            per_slide = slide_df.groupby("slide_id")[["dice","iou"]].mean()
            print("\n[SLIDE-LEVEL MACRO (mean over layers, then over slides)]")
            print(per_slide.mean())

    print("\n[OK] Done.")


if __name__ == "__main__":
    main()