In [None]:
import os
import cv2
import torch
import numpy as np
import pandas as pd
import openslide
import albumentations as A
import segmentation_models_pytorch as smp
from tqdm import tqdm
from pathlib import Path
from albumentations.pytorch import ToTensorV2

In [None]:
# Evaluation inputs: the patch index CSV (filtered to split == "test" by run()),
# the trained state_dict loaded by build_model(), and the directory that holds
# the whole-slide images looked up as "<slide_id>_PAS.svs".
CSV_PATH    = "/home/khdp-user/workspace/dataset/Models/Infla_run_seg/dataset.csv"
WEIGHT_PATH = "/home/khdp-user/workspace/dataset/Models/Infla_run_seg/best_model.pt"
SVS_DIR     = Path("/home/khdp-user/workspace/dataset/Slide")

LAYER_IDS   = [1]          # binary: len==1, multiclass: several ids
TASK_TYPE   = "binary"     # "binary" | "multiclass"; any other value is resolved from len(LAYER_IDS)
PATCH_SIZE  = 512          # side length (px) every patch is resized to before inference
TARGET_MAG  = 10.0         # magnification of the slide-level stitching canvas
THRESH      = 0.5          # binary only: sigmoid probability cut-off

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# ============================================================
# Utils
# ============================================================
def parse_xy_from_name(patch_name: str):
    """Extract the ``(x0, y0)`` origin encoded at the end of a patch file name.

    The last two underscore-separated fields of the file stem are parsed as
    integers, e.g. ``..._X0Y0_034816_010240.png`` -> ``(34816, 10240)``.

    Raises:
        IndexError: the stem contains no underscore.
        ValueError: one of the two trailing fields is not an integer.
    """
    # Only the two trailing fields matter, so split from the right.
    fields = Path(patch_name).stem.rsplit("_", 2)
    x_field, y_field = fields[-2], fields[-1]
    return int(x_field), int(y_field)


def dice_iou(pred01: np.ndarray, gt01: np.ndarray, eps=1e-6):
    """Compute the Dice coefficient and IoU between two binary masks.

    Both inputs are cast to bool first, so any non-zero value counts as
    foreground. ``eps`` is added to numerator and denominator, which makes two
    empty masks score 1.0 on both metrics instead of dividing by zero.

    Returns:
        ``(dice, iou)`` as plain Python floats.
    """
    pred_mask = pred01.astype(bool)
    gt_mask = gt01.astype(bool)

    n_pred = pred_mask.sum()
    n_gt = gt_mask.sum()
    n_inter = (pred_mask & gt_mask).sum()
    n_union = (pred_mask | gt_mask).sum()

    dice = (2 * n_inter + eps) / (n_pred + n_gt + eps)
    iou = (n_inter + eps) / (n_union + eps)
    return float(dice), float(iou)


def safe_img_path(row):
    """Resolve the image file for a CSV row.

    ``row["path"]`` may point either at the image file itself or at the
    directory containing it; in the latter case ``row["name"]`` is appended.
    """
    candidate = Path(row["path"])
    return candidate if candidate.is_file() else candidate / row["name"]


def get_slide_id_from_name(patch_name: str):
    """Return the slide id, i.e. everything before the first ``"_PAS"``.

    ``"SLIDEID_PAS_..."`` -> ``"SLIDEID"``. A name without the marker is
    returned unchanged.
    """
    slide_id, _marker, _rest = patch_name.partition("_PAS")
    return slide_id


def infer_task_type(task_type: str, layer_ids):
    """Resolve the task type, auto-detecting it when not given explicitly.

    ``"binary"`` and ``"multiclass"`` are returned as-is. Any other value
    selects ``"binary"`` for exactly one layer id and ``"multiclass"``
    otherwise.
    """
    if task_type not in ("binary", "multiclass"):
        # auto mode: decide from the number of layers
        task_type = "multiclass" if len(layer_ids) != 1 else "binary"
    return task_type


# ============================================================
# Transform (MUST match training/val)
# ============================================================
def get_infer_transform():
    """Build the deterministic inference pipeline: resize, normalize, to tensor.

    The image is resized to ``PATCH_SIZE`` x ``PATCH_SIZE``, normalized with the
    ``A.Normalize`` defaults and converted to a tensor by ``ToTensorV2``.
    """
    resize = A.Resize(PATCH_SIZE, PATCH_SIZE)
    normalize = A.Normalize()
    to_tensor = ToTensorV2()
    return A.Compose([resize, normalize, to_tensor])


# ============================================================
# Model
# ============================================================
def build_model(layer_ids, task_type: str):
    """Create the ResNet50 U-Net, load the trained weights and switch to eval mode.

    Args:
        layer_ids: layer ids the model predicts; only their count is used here.
        task_type: ``"binary"``, ``"multiclass"``, or anything else for
            auto-detection via ``infer_task_type``.

    Returns:
        ``(model, is_binary, task_type)`` where ``task_type`` is the resolved
        value and ``model`` lives on ``DEVICE``.
    """
    resolved_type = infer_task_type(task_type, layer_ids)
    binary = resolved_type == "binary"

    # Binary: a single logit channel. Multiclass: background + one channel per layer.
    n_classes = 1 if binary else len(layer_ids) + 1

    net = smp.Unet(
        encoder_name="resnet50",
        encoder_weights=None,   # weights come from the state_dict loaded below
        in_channels=3,
        classes=n_classes,
        activation=None,
    )
    state_dict = torch.load(WEIGHT_PATH, map_location=DEVICE)
    net.load_state_dict(state_dict)
    net.to(DEVICE)
    net.eval()
    return net, binary, resolved_type


# ============================================================
# Patch-level eval
# ============================================================
def _load_gt_mask(gt_path, shape):
    """Read a ground-truth mask and return it as a bool array of ``shape``.

    A missing or undecodable file yields an all-background mask. A mask stored
    at a different resolution than the prediction is resized with
    nearest-neighbour interpolation so both arrays can be compared pixel-wise.
    """
    height, width = shape
    gt = cv2.imread(str(gt_path), 0) if gt_path.exists() else None
    if gt is None:
        # cv2.imread signals a decode failure by returning None, not by raising.
        return np.zeros((height, width), bool)
    if gt.shape[:2] != (height, width):
        gt = cv2.resize(gt, (width, height), interpolation=cv2.INTER_NEAREST)
    return gt > 127


@torch.no_grad()
def eval_patch_level(model, task_type: str, df_test: pd.DataFrame):
    """Compute per-patch Dice / IoU for every readable test patch.

    For each row the image is loaded, run through the inference transform and
    the model, and compared with the mask found at
    ``<image_path.parents[1]>/masks/layer<id>/<patch name>``.

    Args:
        model: segmentation network returning logits of shape (1, C, H, W).
        task_type: ``"binary"`` thresholds the sigmoid of channel 0 at
            ``THRESH`` for ``LAYER_IDS[0]``; anything else takes the argmax
            over channels, where class ``k`` (1-based) maps to ``LAYER_IDS[k-1]``.
        df_test: rows providing ``name`` and ``path`` columns.

    Returns:
        DataFrame with columns ``slide_id``, ``layer``, ``dice``, ``iou``;
        one row per (patch, layer). Empty if no image could be read.
    """
    tf = get_infer_transform()
    is_binary = (task_type == "binary")
    rows = []
    n_unreadable = 0

    for _, row in tqdm(df_test.iterrows(), total=len(df_test), desc="Patch-level eval"):
        img_path   = safe_img_path(row)
        patch_name = row["name"]
        slide_id   = get_slide_id_from_name(patch_name)

        img = cv2.imread(str(img_path))
        if img is None:
            # Keep going (best effort) but remember it so the skip is reported.
            n_unreadable += 1
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        x = tf(image=img)["image"].unsqueeze(0).to(DEVICE)  # (1,3,H,W)
        logits = model(x)

        # Masks live next to the image folder: <slide_root>/masks/layer<id>/<patch>
        mask_root = img_path.parents[1] / "masks"

        if is_binary:
            prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
            layer_preds = [(LAYER_IDS[0], prob > THRESH)]
        else:
            # multiclass: logits (1,C,H,W) -> argmax (H,W), class 0..K
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
            pred_cls = np.argmax(probs, axis=0)
            layer_preds = [(lid, pred_cls == k) for k, lid in enumerate(LAYER_IDS, start=1)]

        for lid, pred in layer_preds:
            gt = _load_gt_mask(mask_root / f"layer{lid}" / patch_name, pred.shape)
            d, i = dice_iou(pred, gt)
            rows.append({"slide_id": slide_id, "layer": lid, "dice": d, "iou": i})

    if n_unreadable:
        print(f"[WARN] Patch-level eval skipped {n_unreadable} unreadable image(s).")

    return pd.DataFrame(rows)


# ============================================================
# Slide-level stitching eval (TARGET_MAG canvas)
# ============================================================
def _load_gt_small(gt_path, size):
    """Read a ground-truth mask as a uint8 {0,1} array resized to ``size``.

    ``size`` is ``(width, height)``. A missing or undecodable file yields an
    all-zero mask of that size.
    """
    width, height = size
    gt = cv2.imread(str(gt_path), 0) if gt_path.exists() else None
    if gt is None:
        # cv2.imread signals a decode failure by returning None, not by raising.
        return np.zeros((height, width), np.uint8)
    gt01 = (gt > 127).astype(np.uint8)
    return cv2.resize(gt01, (width, height), interpolation=cv2.INTER_NEAREST)


@torch.no_grad()
def eval_slide_level_stitching(model, task_type: str, df_slide: pd.DataFrame):
    """Stitch patch predictions of one slide onto a canvas and score it.

    The canvas has the slide's level-0 size divided by
    ``aperio.AppMag / TARGET_MAG`` (AppMag falls back to 40). Each patch's
    prediction and ground truth are shrunk by the same factor, placed at the
    (x0, y0) parsed from the patch name divided by that factor, and merged with
    a pixel-wise maximum. Patches that do not fit entirely inside the canvas
    are skipped.

    NOTE(review): this treats the coordinates in the file name as level-0
    pixels and each patch as covering ``PATCH_SIZE`` level-0 pixels; confirm
    against the tiling script.

    Args:
        model: segmentation network returning logits of shape (1, C, H, W).
        task_type: ``"binary"`` or anything else for multiclass.
        df_slide: rows (``name``, ``path``) belonging to a single slide.

    Returns:
        ``None`` if ``df_slide`` is empty or the ``.svs`` file is missing;
        a single dict (binary) or a list of dicts (multiclass) with keys
        ``slide_id``, ``layer``, ``dice``, ``iou``.
    """
    if len(df_slide) == 0:
        return None

    tf = get_infer_transform()
    is_binary = (task_type == "binary")

    slide_id = get_slide_id_from_name(df_slide.iloc[0]["name"])
    svs_path = SVS_DIR / f"{slide_id}_PAS.svs"
    if not svs_path.exists():
        return None

    # Only metadata is needed from the slide, so release the handle right away;
    # the context manager also closes it if reading the properties fails.
    with openslide.OpenSlide(str(svs_path)) as slide:
        # base mag from Aperio property (fallback 40)
        base_mag = float(slide.properties.get("aperio.AppMag", 40.0))
        W0, H0 = slide.level_dimensions[0]

    down = base_mag / float(TARGET_MAG)
    Wt, Ht = int(W0 / down), int(H0 / down)

    # patch footprint on the target-magnification canvas
    scale = 1.0 / down
    tw = max(1, int(PATCH_SIZE * scale))
    th = max(1, int(PATCH_SIZE * scale))

    # Binary mode scores only the first layer; multiclass scores all of them.
    layers = list(LAYER_IDS[:1]) if is_binary else list(LAYER_IDS)
    pred_canvas = {lid: np.zeros((Ht, Wt), np.uint8) for lid in layers}
    gt_canvas   = {lid: np.zeros((Ht, Wt), np.uint8) for lid in layers}

    for _, row in df_slide.iterrows():
        img_path   = safe_img_path(row)
        patch_name = row["name"]

        img = cv2.imread(str(img_path))
        if img is None:
            continue

        x0, y0 = parse_xy_from_name(patch_name)
        xt, yt = int(x0 / down), int(y0 / down)
        if yt + th > Ht or xt + tw > Wt:
            # Would not fit on the canvas: skip before paying for inference.
            continue

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        x = tf(image=img)["image"].unsqueeze(0).to(DEVICE)
        logits = model(x)

        if is_binary:
            prob = torch.sigmoid(logits)[0, 0].cpu().numpy()
            layer_preds = [(layers[0], (prob > THRESH).astype(np.uint8))]
        else:
            probs = torch.softmax(logits, dim=1)[0].cpu().numpy()
            pred_cls = np.argmax(probs, axis=0).astype(np.uint8)  # 0..K
            layer_preds = [
                (lid, (pred_cls == k).astype(np.uint8))
                for k, lid in enumerate(LAYER_IDS, start=1)
            ]

        # Masks live next to the image folder: <slide_root>/masks/layer<id>/<patch>
        mask_root = img_path.parents[1] / "masks"
        window = (slice(yt, yt + th), slice(xt, xt + tw))

        for lid, pred_patch in layer_preds:
            pred_small = cv2.resize(pred_patch, (tw, th), interpolation=cv2.INTER_NEAREST)
            gt_small   = _load_gt_small(mask_root / f"layer{lid}" / patch_name, (tw, th))

            # Pixel-wise maximum keeps foreground where overlapping patches disagree.
            pred_canvas[lid][window] = np.maximum(pred_canvas[lid][window], pred_small)
            gt_canvas[lid][window]   = np.maximum(gt_canvas[lid][window],   gt_small)

    out_rows = []
    for lid in layers:
        d, i = dice_iou(pred_canvas[lid], gt_canvas[lid])
        out_rows.append({"slide_id": slide_id, "layer": lid, "dice": d, "iou": i})

    return out_rows[0] if is_binary else out_rows


# ============================================================
# Main (run the selected evaluation mode)
# ============================================================
def run(mode: str):
    """Evaluate the model on the test split and print mean Dice / IoU.

    Args:
        mode: ``"patch"`` for per-patch metrics, ``"slide"`` for stitched
            slide-level metrics, ``"both"`` for the two in sequence.

    Raises:
        ValueError: ``mode`` is not one of the three supported values.
        RuntimeError: the CSV contains no row whose ``split`` is ``test``.
    """
    valid_modes = ("patch", "slide", "both")
    if mode not in valid_modes:
        # Fail before loading data and weights; an unknown mode would otherwise
        # evaluate nothing and still report success.
        raise ValueError(f"Unknown mode {mode!r}; expected one of {valid_modes}.")

    df = pd.read_csv(CSV_PATH)
    df_test = df[df["split"].astype(str).str.lower() == "test"].reset_index(drop=True)
    if len(df_test) == 0:
        raise RuntimeError("No test rows found in CSV (split == test).")

    model, _, task_type = build_model(LAYER_IDS, TASK_TYPE)

    print(f"[INFO] TASK: {task_type} | LAYER_IDS={LAYER_IDS} | THRESH={THRESH if task_type=='binary' else 'N/A'}")
    print(f"[INFO] DEVICE: {DEVICE} | #test patches: {len(df_test)}")

    if mode in ("patch", "both"):
        patch_df = eval_patch_level(model, task_type, df_test)

        print("\n[PATCH-LEVEL MEAN]")
        if not patch_df.empty:
            print(patch_df.groupby("layer")[["dice", "iou"]].mean())
        else:
            print("No patch metrics computed (check paths).")

    if mode in ("slide", "both"):
        slide_rows = []
        # Group patches by the slide id embedded in their file name.
        group_key = df_test["name"].astype(str).apply(get_slide_id_from_name)

        for slide_id, df_slide in tqdm(df_test.groupby(group_key), desc="Slide stitching"):
            res = eval_slide_level_stitching(model, task_type, df_slide)
            if res is None:
                continue
            # Binary returns one dict per slide, multiclass a list of dicts.
            if task_type == "binary":
                slide_rows.append(res)
            else:
                slide_rows.extend(res)

        slide_df = pd.DataFrame(slide_rows)

        print("\n[SLIDE-LEVEL (STITCHING) MEAN]")
        if not slide_df.empty:
            print(slide_df.groupby("layer")[["dice", "iou"]].mean())
            if task_type != "binary":
                macro = slide_df.groupby("slide_id")[["dice", "iou"]].mean().mean()
                print("\n[SLIDE-LEVEL MACRO (mean over layers, then over slides)]")
                print(macro)
        else:
            print("No slide metrics computed (check SVS paths / naming).")

    print("\n[OK] Done.")

In [None]:
# Entry point: evaluate the test split. Supported modes are "patch", "slide" and "both".
run(mode="both") # patch, slide, both