In [None]:
import os, shutil
from pathlib import Path
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
from matplotlib import cm

from tiatoolbox.wsicore.wsireader import WSIReader
from tiatoolbox.models.engine.patch_predictor import PatchPredictor
from tiatoolbox.utils.visualization import overlay_prediction_mask

import matplotlib.patches as mpatches
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D

import re
import glob


In [None]:
# Example slide pair (case CRC37): a real H&E scan and its virtual H&E counterpart.
real_he_path = Path(
    'CRC2/data_CRC37_19510_C37_US_SCAN_OR_001__161733-registered.ome.tif'
)
virtual_he_path = Path(
    'CRC2/H&E/data_CRC37_P37_S80_Full_A24_C59qX_E15_20220307_235159_333000-zlib.ome.tiff'
)

# Tile geometry, expressed as (H, W) in pixels of the image read at INPUT_RES_MPP.
# Stride equal to the patch size gives a non-overlapping grid.
PATCH_SIZE = (224, 224)
STRIDE     = (224, 224)

# Resolutions in microns-per-pixel, kept as (x, y) tuples:
#   INPUT_RES_MPP - resolution at which patches are read for inference
#   MERGE_RES_MPP - resolution of the merged prediction map / overlay
INPUT_RES_MPP = (0.5, 0.5)
MERGE_RES_MPP = (4.0, 4.0)

# Number of patches accumulated before each call to predictor.predict(...)
BATCH_INFER = 256

# Kather100K 9-class CRC model: short class code -> integer id (ids follow tuple order)
label_dict = {
    code: class_id
    for class_id, code in enumerate(
        ("BACK", "NORM", "DEB", "TUM", "ADI", "MUC", "MUS", "STR", "LYM")
    )
}
# Reverse lookup: integer id -> short class code
id_to_name = dict(zip(label_dict.values(), label_dict.keys()))

# Overlay palette: class id -> (display name, RGB colour), colours taken from Set1
label_info = {
    0: ("Background",       (153, 153, 153)),  # gray   #999999
    1: ("Normal mucosa",    (255, 127,   0)),  # orange #ff7f00
    2: ("Debris",           (166,  86,  40)),  # brown  #a65628
    3: ("Tumor epithelium", (228,  26,  28)),  # red    #e41a1c
    4: ("Adipose tissue",   (255, 255,  51)),  # yellow #ffff33
    5: ("Mucus",            ( 77, 175,  74)),  # green  #4daf4a
    6: ("Smooth muscle",    (247, 129, 191)),  # pink   #f781bf
    7: ("Stroma",           ( 55, 126, 184)),  # blue   #377eb8
    8: ("Lymphocytes",      (152,  78, 163)),  # purple #984ea3
}

# -------------------- MODEL --------------------
# Optional arguments left disabled: pretrained_weights=weights_path, device="cuda"
predictor = PatchPredictor(
    pretrained_model="resnet34-kather100k",
    batch_size=64,
    num_loader_workers=4,
)

# -------------------- HELPERS --------------------
def _res_tuple(res):
    if isinstance(res, (int, float)):
        return (float(res), float(res))
    return (float(res[0]), float(res[1]))

def _slide_dims_at_res(wsi_reader, res_mpp_xy):
    """Ask the reader for the slide's (W, H) size at a resolution given in mpp.

    The (x, y) tuple form is tried first; if the reader raises ``TypeError``
    the call is repeated with the x-axis value alone as a scalar.
    """
    mpp_xy = _res_tuple(res_mpp_xy)
    try:
        dims = wsi_reader.slide_dimensions(resolution=mpp_xy, units="mpp")
    except TypeError:
        # Fallback for readers that only accept a scalar resolution
        dims = wsi_reader.slide_dimensions(resolution=float(mpp_xy[0]), units="mpp")
    return dims

def generate_coords_resolution_space(img_size_wh_res, patch_hw, stride_hw):
    """Build the grid of patch top-left corners in RESOLUTION space.

    Only patches that fit entirely inside the image are generated, so the
    number of columns is ``(W - pw) // sw + 1`` and the number of rows is
    ``(H - ph) // sh + 1``.

    Parameters
    ----------
    img_size_wh_res : (W, H)
        Image size in pixels at the read resolution.
    patch_hw : (H, W)
        Patch size in the same pixels.
    stride_hw : (H, W)
        Step between neighbouring patches; values below 1 are treated as 1.

    Returns
    -------
    np.ndarray of shape (N, 2), dtype int32
        ``(x, y)`` corners in row-major order (x varies fastest). The array
        is empty when the patch size is not positive or when the image is
        smaller than one patch along either axis.
    """
    W_res, H_res = map(int, img_size_wh_res)
    ph, pw = map(int, patch_hw)    # (H,W)
    sh, sw = map(int, stride_hw)   # (H,W)

    no_tiles = np.zeros((0, 2), dtype=np.int32)
    if ph <= 0 or pw <= 0:
        return no_tiles
    # Previously the start range was clamped to 0, which produced a tile at
    # (0, 0) even though no full patch fits and made the empty-grid check dead.
    if W_res < pw or H_res < ph:
        return no_tiles

    xs = np.arange(0, W_res - pw + 1, max(1, sw), dtype=np.int32)
    ys = np.arange(0, H_res - ph + 1, max(1, sh), dtype=np.int32)

    gx, gy = np.meshgrid(xs, ys)
    return np.stack([gx.ravel(), gy.ravel()], axis=1).astype(np.int32)

class OnTheFlyWSIPatchDataset:
    """Lazy patch source: each item is read from the slide when it is indexed.

    Patch corners are given in RESOLUTION space, i.e. in pixels of the slide as
    seen at ``resolution_mpp_xy``.
    """

    def __init__(self, wsi_reader, coords_res_xy, patch_size_hw, resolution_mpp_xy):
        patch_h, patch_w = patch_size_hw[0], patch_size_hw[1]
        self.wsi = wsi_reader
        # (N, 2) array of (x, y) top-left corners
        self.coords = np.asarray(coords_res_xy, dtype=np.int32)
        # read_rect takes size as (W, H), the reverse of the (H, W) input
        self.size_wh = (int(patch_w), int(patch_h))
        self.res_mpp_xy = _res_tuple(resolution_mpp_xy)

    def __len__(self):
        return int(self.coords.shape[0])

    def __getitem__(self, i):
        left, top = (int(v) for v in self.coords[i])
        patch = self.wsi.read_rect(
            location=(left, top),
            size=self.size_wh,
            resolution=self.res_mpp_xy,
            units="mpp",
            # location is interpreted at the requested resolution, not at baseline
            coord_space="resolution",
        )
        return {"image": patch}

def infer_patches_in_batches(predictor, dataset, batch_size=BATCH_INFER):
    """Classify every item of ``dataset`` and return the class ids as int16.

    Images are pulled from the dataset one at a time and handed to
    ``predictor.predict`` in groups of ``batch_size``; the final group may be
    smaller. Results keep dataset order.
    """
    total = len(dataset)
    collected = []
    pending = []
    for idx in range(total):
        pending.append(dataset[idx]["image"])
        reached_end = (idx + 1) == total
        if len(pending) != batch_size and not reached_end:
            continue
        # Flush the accumulated images through the model
        result = predictor.predict(imgs=pending, mode="patch", return_probabilities=False)
        collected.extend(list(result["predictions"]))
        pending = []
    return np.asarray(collected, dtype=np.int16)

# -------------------- CORE PIPELINE --------------------
def run_wsi(img_path: Path, tag: str):
    """Tile one slide, classify every tile, and merge the result into a label map.

    Uses the module-level ``predictor``, ``PATCH_SIZE``, ``STRIDE``,
    ``INPUT_RES_MPP``, ``MERGE_RES_MPP``, ``BATCH_INFER`` and ``label_info``.
    Also displays a thumbnail overlay as a side effect.

    Parameters
    ----------
    img_path : Path
        Slide file, opened with ``WSIReader.open``.
    tag : str
        Short label used in the figure title and the progress message.

    Returns
    -------
    dict
        ``"pred_map"``   - array returned by ``predictor.merge_predictions``
        at ``MERGE_RES_MPP``;
        ``"grid_shape"`` - ``(Ny, Nx)`` tile counts;
        ``"pred_grid"``  - ``(Ny, Nx)`` int16 per-tile class ids, row-major.

    Raises
    ------
    FileNotFoundError
        If ``img_path`` does not exist.
    ValueError
        If no full patch fits in the slide, or the tile grid and the number of
        coordinates / predictions disagree.
    """
    if not img_path.exists():
        raise FileNotFoundError(f"WSI not found: {img_path.resolve()}")

    # 1) Open reader & get slide size in RESOLUTION space at INPUT_RES_MPP
    wsi_reader = WSIReader.open(img_path)
    slide_wh_res = _slide_dims_at_res(wsi_reader, INPUT_RES_MPP)   # (W_res, H_res)
    W_res, H_res = map(int, slide_wh_res)

    # 2) Compute grid shape in RESOLUTION space
    ph, pw = map(int, PATCH_SIZE)                     # (H,W) in output pixels
    # Clamp the stride the same way the coordinate generator does, so a zero
    # stride cannot divide by zero here while producing a grid there.
    sh, sw = (max(1, int(s)) for s in STRIDE)         # (H,W) in output pixels
    Nx = 0 if W_res < pw else ((W_res - pw) // sw + 1)
    Ny = 0 if H_res < ph else ((H_res - ph) // sh + 1)
    if Nx == 0 or Ny == 0:
        raise ValueError(
            f"[{tag}] slide is {W_res}x{H_res} px at {INPUT_RES_MPP} mpp, "
            f"smaller than one {pw}x{ph} patch; nothing to classify."
        )

    # 3) Build grid coordinates (row-major: x varies fastest)
    coords_res = generate_coords_resolution_space(
        img_size_wh_res=slide_wh_res,
        patch_hw=PATCH_SIZE,
        stride_hw=STRIDE,
    )
    # Explicit check rather than `assert`, which is skipped under `python -O`
    if coords_res.shape[0] != Ny * Nx:
        raise ValueError(
            f"Grid size mismatch: coords={coords_res.shape[0]} vs Ny*Nx={Ny*Nx}"
        )

    # 4) Dataset & inference
    dataset = OnTheFlyWSIPatchDataset(
        wsi_reader=wsi_reader,
        coords_res_xy=coords_res,
        patch_size_hw=PATCH_SIZE,
        resolution_mpp_xy=INPUT_RES_MPP,
    )
    pred_ids = infer_patches_in_batches(predictor, dataset, batch_size=BATCH_INFER)
    if pred_ids.size != Ny * Nx:
        raise ValueError(
            f"[{tag}] predictor returned {pred_ids.size} predictions for {Ny * Nx} tiles"
        )

    # 5) Boxes for merge in RESOLUTION space (x0,y0,x1,y1)
    x0y0 = coords_res.astype(np.int32)
    x1y1 = x0y0 + np.asarray([pw, ph], dtype=np.int32)  # add (W,H)
    coords_boxes = np.hstack([x0y0, x1y1]).astype(np.int32)

    # 6) Merge to dense map
    out_for_merge = {
        "coordinates": coords_boxes,
        "predictions": pred_ids,
        "resolution": _res_tuple(INPUT_RES_MPP),
        "units": "mpp",
    }
    # NOTE(review): confirm whether merge_predictions shifts class ids (e.g. by +1
    # to reserve 0 for "no prediction") when given predictions without
    # probabilities; if it does, label_info's ids are off by one for pred_map.
    pred_map = predictor.merge_predictions(
        wsi_reader,
        out_for_merge,
        resolution=_res_tuple(MERGE_RES_MPP),
        units="mpp",
    )

    # 7) Overlay background aligned to pred_map
    h, w = pred_map.shape[:2]
    rgb_for_overlay = wsi_reader.read_rect(
        location=(0, 0),
        size=(int(w), int(h)),
        resolution=_res_tuple(MERGE_RES_MPP),
        units="mpp",
        coord_space="resolution",
    )

    # 8) Quicklooks: shrink so the longer side is at most max_side pixels
    max_side = 1024
    Hs, Ws = rgb_for_overlay.shape[:2]
    scale = max(Hs, Ws) / float(max_side)
    new_w, new_h = (int(round(Ws / scale)), int(round(Hs / scale))) if scale > 1.0 else (Ws, Hs)
    thumb_rgb = cv2.resize(rgb_for_overlay, (new_w, new_h), interpolation=cv2.INTER_AREA)
    # Nearest-neighbour keeps class ids intact (no blending between labels)
    pm_small  = cv2.resize(pred_map.astype(np.int16), (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    ax1 = overlay_prediction_mask(
        img=thumb_rgb, prediction=pm_small, alpha=0.45, label_info=label_info, return_ax=True
    )
    fig1 = ax1.get_figure()
    fig1.suptitle(f"{tag} — Thumbnail overlay", y=1.02)
    plt.show()
    plt.close(fig1)

    print(f"[{tag}] tiles: Ny={Ny}, Nx={Nx}")

    # Reshape predictions into a (Ny, Nx) grid using the same row-major order
    pred_grid = pred_ids.reshape(Ny, Nx)

    return {
        "pred_map": pred_map,
        "grid_shape": (Ny, Nx),
        "pred_grid": pred_grid,
    }

# -------------------- RUN ON BOTH WSIs --------------------
# Classify the example pair defined at the top of this cell. Each call also shows a
# thumbnail overlay; the returned dicts hold "pred_map", "grid_shape" and "pred_grid".
virt_res = run_wsi(virtual_he_path, tag="virtual_HE")
real_res = run_wsi(real_he_path,    tag="real_HE")


In [None]:
# Typography shared by the comparison figure
TITLE_FSIZE  = 16   # font size of each panel title
LEGEND_FSIZE = 14   # font size of legend labels
SQUARE_SIZE  = 14   # marker size of the legend squares

# One square, line-less marker per class (ordered by class id), labelled with
# the full class name and filled with the class colour scaled to 0..1.
legend_handles = [
    Line2D(
        [0], [0],
        linestyle='None',
        marker='s',
        markersize=SQUARE_SIZE,
        markerfacecolor=np.array(rgb) / 255.0,
        markeredgecolor='black',
        markeredgewidth=0.8,
        label=name,
    )
    for _cid, (name, rgb) in sorted(label_info.items())
]

def _color_lut_from_label_info(label_info):
    max_id = max(label_info.keys())
    lut = np.zeros((max_id + 1, 3), dtype=np.uint8)
    for cid, (_, rgb) in label_info.items():
        lut[int(cid)] = np.array(rgb, dtype=np.uint8)
    return lut

def _labels_rgb_from_pred(pred_map, label_info):
    """Colourise an integer label map (no H&E underneath).

    Ids are cast to int32 and clipped into the lookup table's range before
    indexing, so out-of-range values take the first or last colour.
    Returns an ``[H, W, 3]`` uint8 array.
    """
    palette = _color_lut_from_label_info(label_info)
    highest_id = palette.shape[0] - 1
    class_ids = np.clip(pred_map.astype(np.int32), 0, highest_id)
    return palette[class_ids]

def _read_rgb_at_merge(wsi_path, target_wh, res_mpp=(4.0, 4.0)):
    """Read the top-left ``target_wh`` = (W, H) region of a slide at ``res_mpp``.

    The region starts at (0, 0) and both location and size are expressed in
    pixels at the requested resolution.
    """
    width = int(target_wh[0])
    height = int(target_wh[1])
    reader = WSIReader.open(wsi_path)
    return reader.read_rect(
        location=(0, 0),
        size=(width, height),
        resolution=res_mpp,
        units="mpp",
        coord_space="resolution",
    )

def _resize_pair(rgb, pred_map, target_wh):
    """Resize an RGB image and its label map to the same ``(W, H)``.

    The image uses area interpolation; the label map is cast to int16 and
    resized with nearest-neighbour so class ids are never blended.
    Returns ``(resized_rgb, resized_labels)``.
    """
    dsize = (int(target_wh[0]), int(target_wh[1]))  # cv2 expects (W, H)
    resized_rgb = cv2.resize(rgb, dsize, interpolation=cv2.INTER_AREA)
    labels_i16 = pred_map.astype(np.int16)
    resized_labels = cv2.resize(labels_i16, dsize, interpolation=cv2.INTER_NEAREST)
    return resized_rgb, resized_labels

def show_wsi_comparison_clean(
    real_res, virt_res, real_path, virt_path, label_info, merge_res_mpp=(4.0, 4.0), target="min",
    save_path=None,
):
    """Show a 2x2 figure: real / virtual H&E next to their colourised predictions.

    Parameters
    ----------
    real_res, virt_res : dict
        Results holding a ``"pred_map"`` array (as returned by ``run_wsi``).
    real_path, virt_path : path-like
        Slides to read the H&E panels from.
    label_info : dict
        ``{class_id: (name, rgb)}`` palette.
    merge_res_mpp : tuple
        Resolution of the prediction maps; the H&E panels are read at the
        same resolution so pixels correspond one-to-one.
    target : {"min", "max"}
        Common canvas size: the smaller (default) or larger of the two maps.
    save_path : path-like, optional
        When given, the figure is also written there (parent directories are
        created). Default ``None`` keeps the previous display-only behaviour.

    Notes
    -----
    Uses the module-level ``TITLE_FSIZE``, ``LEGEND_FSIZE`` and ``legend_handles``.
    """
    def _fit_to_canvas(pred_map, width, height):
        """Crop / zero-pad a label map to (height, width), anchored at the top-left."""
        canvas = np.zeros((height, width) + pred_map.shape[2:], dtype=pred_map.dtype)
        keep_h = min(height, pred_map.shape[0])
        keep_w = min(width, pred_map.shape[1])
        canvas[:keep_h, :keep_w] = pred_map[:keep_h, :keep_w]
        return canvas

    # Choose a common canvas size at MERGE resolution
    Hr, Wr = real_res["pred_map"].shape[:2]
    Hv, Wv = virt_res["pred_map"].shape[:2]
    if target == "max":
        Wt, Ht = max(Wr, Wv), max(Hr, Hv)
    else:
        Wt, Ht = min(Wr, Wv), min(Hr, Hv)

    # Read H&E images at the common size (a top-left window of each slide)
    rgb_real = _read_rgb_at_merge(real_path, (Wt, Ht), res_mpp=merge_res_mpp)
    rgb_virt = _read_rgb_at_merge(virt_path, (Wt, Ht), res_mpp=merge_res_mpp)

    # The H&E panels are a top-left *window*, so the label maps must be cropped /
    # padded the same way. Rescaling them (as before) stretched the labels
    # relative to the image whenever the two slides differed in size.
    # Padded areas (target="max") get id 0 and are drawn in that class's colour.
    real_pm = _fit_to_canvas(real_res["pred_map"], Wt, Ht)
    virt_pm = _fit_to_canvas(virt_res["pred_map"], Wt, Ht)

    # Labels-only (colorized) images
    labels_real_rgb = _labels_rgb_from_pred(real_pm, label_info)
    labels_virt_rgb = _labels_rgb_from_pred(virt_pm, label_info)

    # Compact 2x2 layout: rows = real / virtual, columns = H&E / predictions
    fig, axes = plt.subplots(
        2, 2, figsize=(10, 8),
        gridspec_kw={'hspace': 0.05, 'wspace': 0.05}
    )

    axes[0, 0].imshow(rgb_real)
    axes[0, 0].set_title("Real H&E", fontsize=TITLE_FSIZE, pad=4)
    axes[0, 0].axis("off")
    axes[0, 1].imshow(labels_real_rgb, interpolation="nearest")
    axes[0, 1].set_title("Predictions on real H&E", fontsize=TITLE_FSIZE, pad=4)
    axes[0, 1].axis("off")

    axes[1, 0].imshow(rgb_virt)
    axes[1, 0].set_title("Virtual H&E", fontsize=TITLE_FSIZE, pad=4)
    axes[1, 0].axis("off")
    axes[1, 1].imshow(labels_virt_rgb, interpolation="nearest")
    axes[1, 1].set_title("Predictions on virtual H&E", fontsize=TITLE_FSIZE, pad=4)
    axes[1, 1].axis("off")

    for ax in axes.flat:
        ax.set_xticks([])
        ax.set_yticks([])
        # axis("off") hides the spines, so draw a frame the size of the axes instead
        ax.add_patch(Rectangle(
            (0, 0), 1, 1, transform=ax.transAxes,
            fill=False, color="black", linewidth=1.5
        ))

    # Legend anchored to the right of the panels
    fig.legend(
        handles=legend_handles,
        loc="upper left",
        bbox_to_anchor=(0.91, 0.88),
        fontsize=LEGEND_FSIZE,
        frameon=False,
        ncol=3,
        handlelength=0.0,
        labelspacing=0.6,
        borderpad=0.2
    )

    if save_path is not None:
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        # "tight" so the legend, which sits outside the axes area, is not cut off
        fig.savefig(save_path, bbox_inches="tight")

    plt.show()

# Compare the example pair. The H&E panels must be read at the resolution the
# prediction maps were merged at, so reuse MERGE_RES_MPP instead of repeating
# its value here (a hard-coded copy would silently misalign if it changed).
show_wsi_comparison_clean(
    real_res=real_res,
    virt_res=virt_res,
    real_path=real_he_path,
    virt_path=virtual_he_path,
    label_info=label_info,
    merge_res_mpp=MERGE_RES_MPP,
    target="min",  # or "max"
)


In [None]:
# -------- Fail fast if the inference cell (which defines run_wsi) has not been run --------
try:
    run_wsi  # noqa: F401
except NameError:
    raise RuntimeError("run_wsi(...) not found. Run your TIAToolbox inference block first.")

# -------- Output directory for per-pair CSVs and heatmaps (created if missing) --------
SAVE_DIR = Path("evaluation_data")
SAVE_DIR.mkdir(parents=True, exist_ok=True)

# -------- Kather100K classes: short code -> integer id (ids follow tuple order) --------
label_dict = {
    code: class_id
    for class_id, code in enumerate(
        ("BACK", "NORM", "DEB", "TUM", "ADI", "MUC", "MUS", "STR", "LYM")
    )
}
num_classes = 1 + max(label_dict.values())
BACK_ID = label_dict["BACK"]

# Short code -> full tissue name used on axes and in CSV headers
full_label_names = dict(
    BACK="Background",
    NORM="Normal mucosa",
    DEB="Debris",
    TUM="Tumor epithelium",
    ADI="Adipose tissue",
    MUC="Mucus",
    MUS="Smooth muscle",
    STR="Stroma",
    LYM="Lymphocytes",
)
# Integer id -> full tissue name
id_to_full = {label_dict[code]: pretty for code, pretty in full_label_names.items()}

# -------- Define the list of image pairs (virtual, real) --------
# Each entry is (virtual H&E path, real H&E path); the batch loop unpacks them in
# that order. The paths are relative, so they resolve against the working directory.
image_pairs = [
    ('CRC2/H&E/data_CRC01_P37_S29_A24_C59kX_E15_20220106_014304_946511-zlib.ome.tiff', 'CRC2/data_CRC01_18459_LSP10353_US_SCAN_OR_001__093059-registered.ome.tif'),
    ('CRC2/H&E/data_CRC02_P37_S30_A24_C59kX_E15_20220106_014319_409148-zlib.ome.tiff', 'CRC2/data_CRC02_18459_LSP10364_US_SCAN_OR_001__092347-registered.ome.tif'),
    ('CRC2/H&E/data_CRC03_P37_S31_A24_C59kX_E15_20220106_014409_014236-zlib.ome.tiff', 'CRC2/data_CRC03_18459_LSP10375_US_SCAN_OR_001__092147-registered.ome.tif'),
    ('CRC2/H&E/data_CRC04_P37_S32_A24_C59kX_E15_20220106_014630_553652-zlib.ome.tiff', 'CRC2/data_CRC04_18459_LSP10388_US_SCAN_OR_001__091155-registered.ome.tif'),
    ('CRC2/H&E/data_CRC05_P37_S33_A24_C59kX_E15_20220107_180446_881530-zlib.ome.tiff', 'CRC2/data_CRC05_18459_LSP10397_US_SCAN_OR_001__091631-registered.ome.tif'),
    ('CRC2/H&E/data_CRC06_P37_S34_A24_C59kX_E15_20220107_202112_212579-zlib.ome.tiff', 'CRC2/data_CRC06_18459_LSP10408_US_SCAN_OR_001__092559-registered.ome.tif'),
    ('CRC2/H&E/data_CRC07_P37_S35_A24_C59kX_E15_20220108_012037_490594-zlib.ome.tiff', 'CRC2/data_CRC07_18459_LSP10419_US_SCAN_OR_001__090907-registered.ome.tif'),
    ('CRC2/H&E/data_CRC08_P37_S57_Full_A24_C59nX_E15_20220224_011032_774034-zlib.ome.tiff', 'CRC2/data_CRC08_19510_C8_US_SCAN_OR_001__150825-registered.ome.tif'),
    ('CRC2/H&E/data_CRC09_P37_S37_A24_C59kX_E15_20220108_012113_953544-zlib.ome.tiff', 'CRC2/data_CRC09_18459_LSP10441_US_SCAN_OR_001__091844-registered.ome.tif'),
    ('CRC2/H&E/data_CRC10_P37_S38_A24_C59kX_E15_20220108_012130_664519-zlib.ome.tiff', 'CRC2/data_CRC10_18459_LSP10452_US_SCAN_OR_001__091355-registered.ome.tif'),
    ('CRC2/H&E/data_CRC11_P37_S43_Full_A24_C59mX_E15_20220128_171510_544056-zlib.ome.tiff', 'CRC2/data_CRC11_19510_C11_US_SCAN_OR_001__151039-registered.ome.tif'),
    ('CRC2/H&E/data_CRC12_P37_S44_Full_A24_C59mX_E15_20220128_171448_903938-zlib.ome.tiff', 'CRC2/data_CRC12_19510_C12_US_SCAN_OR_001__151249-registered.ome.tif'),
    ('CRC2/H&E/data_CRC13_P37_S45_Full_A24_C59mX_E15_20220128_171409_633341-zlib.ome.tiff', 'CRC2/data_CRC13_19510_C13_US_SCAN_OR_001__151503-registered.ome.tif'),
    ('CRC2/H&E/data_CRC14_P37_S46_Full_A24_C59mX_E15_20220128_013821_398547-zlib.ome.tiff', 'CRC2/data_CRC14_19510_C14_US_SCAN_OR_001__151737-registered.ome.tif'),
    ('CRC2/H&E/data_CRC15_P37_S47_Full_A24_C59mX_E15_20220128_020654_901143-zlib.ome.tiff', 'CRC2/data_CRC15_19510_C15_US_SCAN_OR_001__152234-registered.ome.tif'),
    ('CRC2/H&E/data_CRC16_P37_S48_Full_A24_C59mX_E15_20220129_015105_865195-zlib.ome.tiff', 'CRC2/data_CRC16_19510_C16_US_SCAN_OR_001__152020-registered.ome.tif'),
    ('CRC2/H&E/data_CRC17_P37_S49_Full_A24_C59mX_E15_20220129_015121_911264-zlib.ome.tiff', 'CRC2/data_CRC17_19510_C17_US_SCAN_OR_001__152525-registered.ome.tif'),
    ('CRC2/H&E/data_CRC18_P37_S50_Full_A24_C59mX_E15_20220129_015242_755602-zlib.ome.tiff', 'CRC2/data_CRC18_19510_C18_US_SCAN_OR_001__152757-registered.ome.tif'),
    ('CRC2/H&E/data_CRC19_P37_S51_Full_A24_C59mX_E15_20220129_015300_669681-zlib.ome.tiff', 'CRC2/data_CRC19_19510_C19_US_SCAN_OR_001__153041-registered.ome.tif'),
    ('CRC2/H&E/data_CRC20_P37_S52_Full_A24_C59mX_E15_20220129_015324_574779-zlib.ome.tiff', 'CRC2/data_CRC20_19510_C20_US_SCAN_OR_001__153341-registered.ome.tif'),
    ('CRC2/H&E/data_CRC21_P37_S58_Full_A24_C59nX_E15_20220224_011058_014787-zlib.ome.tiff', 'CRC2/data_CRC21_19510_C21_US_SCAN_OR_001__153607-registered.ome.tif'),
    ('CRC2/H&E/data_CRC22_P37_S59_Full_A24_C59nX_E15_20220224_011113_455637-zlib.ome.tiff', 'CRC2/data_CRC22_19510_C22_US_SCAN_OR_001__092420-registered.ome.tif'),
    ('CRC2/H&E/data_CRC23_P37_S60_Full_A24_C59nX_E15_20220224_011127_971497-zlib.ome.tiff', 'CRC2/data_CRC23_19510_C23_US_SCAN_OR_001__154147-registered.ome.tif'),
    ('CRC2/H&E/data_CRC24_P37_S61_Full_A24_C59nX_E15_20220224_011149_079291-zlib.ome.tiff', 'CRC2/data_CRC24_19510_C24_US_SCAN_OR_001__091904-registered.ome.tif'),
    ('CRC2/H&E/data_CRC25_P37_S62_Full_A24_C59nX_E15_20220224_011204_784145-zlib.ome.tiff', 'CRC2/data_CRC25_19510_C25_US_SCAN_OR_001__154712-registered.ome.tif'),
    ('CRC2/H&E/data_CRC26_P37_S63_Full_A24_C59nX_E15_20220224_011246_458738-zlib.ome.tiff', 'CRC2/data_CRC26_19510_C26_US_SCAN_OR_001__092131-registered.ome.tif'),
    ('CRC2/H&E/data_CRC27_P37_S64_Full_A24_C59nX_E15_20220224_011259_841605-zlib.ome.tiff', 'CRC2/data_CRC27_19510_C27_US_SCAN_OR_001__155205-registered.ome.tif'),
    ('CRC2/H&E/data_CRC28_P37_S65_Full_A24_C59nX_E15_20220224_011333_386280-zlib.ome.tiff', 'CRC2/data_CRC28_19510_C28_US_SCAN_OR_001__155413-registered.ome.tif'),
    ('CRC2/H&E/data_CRC29_P37_S66_Full_A24_C59nX_E15_20220224_011348_519133-zlib.ome.tiff', 'CRC2/data_CRC29_19510_C29_US_SCAN_OR_001__155859-registered.ome.tif'),
    ('CRC2/H&E/data_CRC30_P37_S67_Full_A24_C59nX_E15_20220224_011408_506939-zlib.ome.tiff', 'CRC2/data_CRC30_19510_C30_US_SCAN_OR_001__155702-registered.ome.tif'),
    ('CRC2/H&E/data_CRC31_P37_S74_Full_A24_C59qX_E15_20220302_234837_137590-zlib.ome.tiff', 'CRC2/data_CRC31_19510_C31_US_SCAN_OR_001__160203-registered.ome.tif'),
    ('CRC2/H&E/data_CRC32_P37_S75_Full_A24_C59qX_E15_20220302_235001_586560-zlib.ome.tiff', 'CRC2/data_CRC32_19510_C32_US_SCAN_OR_001__160434-registered.ome.tif'),
    ('CRC2/H&E/data_CRC33_01_P37_S76_01_A24_C59qX_E15_20220302_235136_561323-zlib.ome.tiff', 'CRC2/data_CRC33_01_19510_C33_US_SCAN_OR_001__160715-2-registered.ome.tif'),
    ('CRC2/H&E/data_CRC33_02_P37_S76_02_A24_C59qX_E15_20220302_235158_533766-zlib.ome.tiff', 'CRC2/data_CRC33_02_19510_C33_US_SCAN_OR_001__160715-registered.ome.tif'),
    ('CRC2/H&E/data_CRC34_P37_S77_Full_A24_C59qX_E15_20220302_235222_359806-zlib.ome.tiff', 'CRC2/data_CRC34_19510_C34_US_SCAN_OR_001__160949-registered.ome.tif'),
    ('CRC2/H&E/data_CRC35_P37_S78_Full_A24_C59qX_E15_20220302_235239_498836-zlib.ome.tiff', 'CRC2/data_CRC35_19510_C35_US_SCAN_OR_001__161209-registered.ome.tif'),
    ('CRC2/H&E/data_CRC36_P37_S79_Full_A24_C59qX_E15_20220302_235254_496641-zlib.ome.tiff', 'CRC2/data_CRC36_19510_C36_US_SCAN_OR_001__161442-registered.ome.tif'),
    ('CRC2/H&E/data_CRC37_P37_S80_Full_A24_C59qX_E15_20220307_235159_333000-zlib.ome.tiff', 'CRC2/data_CRC37_19510_C37_US_SCAN_OR_001__161733-registered.ome.tif'),
    ('CRC2/H&E/data_CRC38_P37_S81_Full_A24_C59qX_E15_20220302_235331_704703-zlib.ome.tiff', 'CRC2/data_CRC38_19510_C38_US_SCAN_OR_001__162018-registered.ome.tif'),
    ('CRC2/H&E/data_CRC39_P37_S82_Full_A24_C59qX_E15_20220304_200614_832683-zlib.ome.tiff', 'CRC2/data_CRC39_19510_C39_US_SCAN_OR_001__162343-registered.ome.tif'),
    ('CRC2/H&E/data_CRC40_P37_S83_Full_A24_C59qX_E15_20220304_200429_490805-zlib.ome.tiff', 'CRC2/data_CRC40_19510_P37-S83_C40_US_SCAN_OR_001__163912-registered.ome.tif'),
]

# -------- Helpers --------
def _pair_tag_from_paths(vpath: Path, rpath: Path, idx: int) -> str:
    """Create a short, unique tag for saving files."""
    m = re.search(r"(CRC\d+(_\d+)?)", vpath.name)
    base = m.group(1) if m else f"pair_{idx+1:02d}"
    return base

def compute_counts_all_tiles(real_res, virt_res):
    """Cross-tabulate virtual vs real tile predictions on the overlapping tile grid.

    Both grids are cut to their common top-left ``(Ny_o, Nx_o)`` extent and
    every tile in it is counted, background included.

    Parameters
    ----------
    real_res, virt_res : dict
        Results holding ``"grid_shape"`` ``(Ny, Nx)`` and ``"pred_grid"``
        (integer class ids of that shape), as returned by ``run_wsi``.

    Returns
    -------
    counts : np.ndarray, shape (T, T)
        ``counts[v, r]`` = number of tiles predicted ``v`` on the virtual
        slide and ``r`` on the real slide.
    T : int
        Number of classes (module-level ``num_classes``).
    tissue_names_full : list of str
        Full class names in id order (from module-level ``id_to_full``).
    stats : dict
        Grid sizes, tile counts, the number of agreeing tiles and the overall
        agreement (``nan`` when there are no tiles).

    Raises
    ------
    ValueError
        If a prediction lies outside ``0 .. T-1``.
    """
    Ny_r, Nx_r = real_res["grid_shape"]
    Ny_v, Nx_v = virt_res["grid_shape"]
    Ny_o, Nx_o = min(Ny_r, Ny_v), min(Nx_r, Nx_v)

    # Widen to int64 so the pair index below cannot overflow a narrow input dtype
    real_t = np.asarray(real_res["pred_grid"])[:Ny_o, :Nx_o].ravel().astype(np.int64)
    virt_t = np.asarray(virt_res["pred_grid"])[:Ny_o, :Nx_o].ravel().astype(np.int64)

    # Class ids are already the positions 0..T-1, so no id -> position lookup is needed
    T = int(num_classes)
    tissue_names_full = [id_to_full[cid] for cid in range(T)]

    if virt_t.size == 0:
        counts = np.zeros((T, T), dtype=int)
    else:
        for side, ids in (("virtual", virt_t), ("real", real_t)):
            out_of_range = (ids < 0) | (ids >= T)
            if out_of_range.any():
                raise ValueError(
                    f"{side} predictions contain class ids outside 0..{T - 1}: "
                    f"{np.unique(ids[out_of_range]).tolist()}"
                )
        # Encode each (virtual, real) pair as one integer and histogram it
        counts = np.bincount(virt_t * T + real_t, minlength=T * T).reshape(T, T)

    agree = int(np.trace(counts))
    total = int(counts.sum())
    acc = (agree / total) if total > 0 else float('nan')

    stats = {
        "Ny_r": Ny_r, "Nx_r": Nx_r,
        "Ny_v": Ny_v, "Nx_v": Nx_v,
        "Ny_o": Ny_o, "Nx_o": Nx_o,
        "tiles_total_overlap": Ny_o * Nx_o,
        "tiles_used": int(virt_t.size),   # every overlapping tile is used
        "agree_all": agree,
        "total_all": total,
        "acc_alltiles": acc,
    }
    return counts, T, tissue_names_full, stats

def save_pair_outputs(pair_tag, counts, tissue_names_full, acc_all):
    """Write the confusion counts, row percentages and a heatmap for one slide pair.

    Rows are virtual-slide classes and columns are real-slide classes; all
    tiles (background included) are expected in ``counts``. Four files are
    written under ``SAVE_DIR``: a counts CSV, an integer row-percent CSV, and
    the heatmap as PDF and PNG.

    Returns a dict with the two CSV paths (``"counts_csv"``, ``"rowpct_csv"``).
    """
    n_classes = len(tissue_names_full)
    names = tissue_names_full

    # --- raw counts ---
    counts_csv = SAVE_DIR / f"{pair_tag}_ALL_TILES_virtual_rows_vs_real_cols_COUNTS.csv"
    pd.DataFrame(counts, index=names, columns=names).to_csv(counts_csv)

    # --- row-normalised fractions (each virtual row sums to 1; empty rows stay 0) ---
    per_row_total = counts.sum(axis=1, keepdims=True)
    with np.errstate(divide="ignore", invalid="ignore"):
        row_norm = counts / np.clip(per_row_total, 1, None)
    row_pct_int = np.rint(row_norm * 100).astype(int)

    rowpct_csv = SAVE_DIR / f"{pair_tag}_ALL_TILES_virtual_to_real_ROW_PERCENT_INT.csv"
    pd.DataFrame(row_pct_int, index=names, columns=names).to_csv(rowpct_csv)

    # --- heatmap ---
    title_fs, label_fs, tick_fs, cell_fs, cbar_fs = 16, 14, 12, 14, 12
    fig, ax = plt.subplots(figsize=(8, 8))
    image = ax.imshow(
        row_norm, vmin=0.0, vmax=1.0, cmap="viridis", aspect="equal", interpolation="nearest"
    )

    heading = (
        f"{pair_tag} — Virtual (rows) → Real (cols) — ALL tiles (incl. BACK) — "
        f"Overall accuracy: {acc_all*100:.1f}"
    )
    ax.set_title(heading, fontsize=title_fs)
    ax.set_xlabel("Real H&E classes", fontsize=label_fs)
    ax.set_ylabel("Virtual H&E classes", fontsize=label_fs)
    ax.set_xticks(np.arange(n_classes))
    ax.set_yticks(np.arange(n_classes))
    ax.set_xticklabels(names, rotation=45, ha="right", fontsize=tick_fs)
    ax.set_yticklabels(names, fontsize=tick_fs)

    # Colour bar labelled in percent
    cbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Percent of virtual class tiles", fontsize=cbar_fs)
    tick_values = np.linspace(0, 1, 6)
    cbar.set_ticks(tick_values)
    cbar.set_ticklabels([f"{int(t*100)}" for t in tick_values])
    for tick_label in cbar.ax.get_yticklabels():
        tick_label.set_fontsize(cbar_fs)

    # Integer percent in every cell; diagonal in bold. Low cells get white
    # text and cells at 0.5 or above get black text.
    for row in range(n_classes):
        for col in range(n_classes):
            if row_norm[row, col] < 0.5:
                ink = "white"
            else:
                ink = "black"
            weight = "bold" if row == col else "normal"
            ax.text(
                col, row, f"{row_pct_int[row, col]}",
                ha="center", va="center",
                fontsize=cell_fs, fontweight=weight,
                color=ink,
            )

    fig.tight_layout()
    fig.savefig(SAVE_DIR / f"{pair_tag}_ALL_TILES_virtual_to_real_row_percent_heatmap.pdf")
    fig.savefig(SAVE_DIR / f"{pair_tag}_ALL_TILES_virtual_to_real_row_percent_heatmap.png", dpi=300)
    plt.close(fig)

    return {"counts_csv": counts_csv, "rowpct_csv": rowpct_csv}

# -------- Run all pairs --------
# Each pair is processed independently; a failure is recorded in pair_results and
# the loop continues with the next pair (deliberate best-effort batch run).
pair_results = []
for idx, (virt_path_str, real_path_str) in enumerate(image_pairs):
    vpath, rpath = Path(virt_path_str), Path(real_path_str)
    pair_tag = _pair_tag_from_paths(vpath, rpath, idx)

    try:
        print(f"\n=== Processing {pair_tag} ===")
        # Loop-local names: do not reuse `virt_res` / `real_res`, which hold the
        # example pair shown by the comparison cell. Overwriting them would make
        # that cell display the last pair of this loop when re-run.
        pair_virt_res = run_wsi(vpath, tag=f"{pair_tag}_virtual")
        pair_real_res = run_wsi(rpath, tag=f"{pair_tag}_real")

        counts, T, tissue_names_full, stats = compute_counts_all_tiles(pair_real_res, pair_virt_res)
        outs = save_pair_outputs(pair_tag, counts, tissue_names_full, stats["acc_alltiles"])

        print(f"Tiles (real Ny×Nx): {stats['Ny_r']}×{stats['Nx_r']} | (virt Ny×Nx): {stats['Ny_v']}×{stats['Nx_v']}")
        if (stats['Ny_r'], stats['Nx_r']) != (stats['Ny_v'], stats['Nx_v']):
            print(f"Overlap used: {stats['Ny_o']}×{stats['Nx_o']}")
        print(f"Tiles used (ALL, incl. BACK): {stats['tiles_used']} / {stats['tiles_total_overlap']}")
        print(f"Accuracy (diag/total, ALL): {stats['agree_all']}/{stats['total_all']} = {stats['acc_alltiles']*100:.1f}")

        pair_results.append({
            "pair_tag": pair_tag,
            "counts": counts,
            "tissue_names_full": tissue_names_full,
            "stats": stats,
            "files": outs,
        })
    except Exception as e:
        # Include the exception type: str(e) alone can be empty or ambiguous
        print(f"[ERROR] {pair_tag}: {type(e).__name__}: {e}")
        pair_results.append({
            "pair_tag": pair_tag,
            "error": str(e),
        })

# Failed pairs are stored too, so report successes and failures separately
n_failed = sum("error" in result for result in pair_results)
print(
    f"\nDone. Processed {len(pair_results) - n_failed} of {len(pair_results)} pairs "
    f"({n_failed} failed)."
)
# ==================== end batch block ====================


In [None]:

# Output locations for the aggregated heatmap cell
OUT_PDF = "figures/Heatmap.pdf"
OUT_CSV = "PerClass_Recall_column_normalized.csv"  # diagonal (recall per real class, %)

# Appearance of the aggregated heatmap (read by _plot_heatmap)
CMAP = "Blues"
FIGSIZE = (8, 8)
TITLE_FS = 18   # figure title
LABEL_FS = 16   # axis labels
TICK_FS = 14    # tick labels
ANN_FS = 16     # per-cell annotations
CB_FS = 14      # colour-bar label and ticks

def _plot_heatmap(mat, class_names, title, save_pdf_path=None, show=False):
    """Draw a square heatmap of fractions in ``mat`` with integer-percent annotations.

    Rows are labelled as predictions on the virtual H&E and columns as
    predictions on the real H&E. Styling comes from the module-level
    ``CMAP``, ``FIGSIZE`` and ``*_FS`` constants.

    Parameters
    ----------
    mat : array indexed as ``mat[i, j]``
        Values shown on a fixed 0..1 colour scale.
    class_names : sequence of str
        Tick labels for both axes; its length sets the number of cells per side.
    title : str
        Figure title.
    save_pdf_path : path-like, optional
        When given, the figure is saved there as PDF (parent directories are created).
    show : bool
        Display the figure if true, otherwise close it.
    """
    n = len(class_names)
    tick_pos = np.arange(n)
    cell_pct = np.rint(mat * 100).astype(int)

    fig, ax = plt.subplots(figsize=FIGSIZE)
    image = ax.imshow(
        mat, vmin=0.0, vmax=1.0, cmap=CMAP, aspect="equal", interpolation="nearest"
    )
    ax.set_title(title, fontsize=TITLE_FS)
    ax.set_xlabel("Predictions on real H&E", fontsize=LABEL_FS)
    ax.set_ylabel("Predictions on virtual H&E", fontsize=LABEL_FS)
    ax.set_xticks(tick_pos)
    ax.set_yticks(tick_pos)
    ax.set_xticklabels(class_names, rotation=45, ha="right", fontsize=TICK_FS)
    ax.set_yticklabels(class_names, fontsize=TICK_FS)

    # Colour bar labelled in percent (values are column-normalised by the caller)
    cbar = fig.colorbar(image, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label("Percent of real H&E prediction", fontsize=CB_FS)
    tick_values = np.linspace(0, 1, 6)
    cbar.set_ticks(tick_values)
    cbar.set_ticklabels([f"{int(t*100)}" for t in tick_values])
    for tick_label in cbar.ax.get_yticklabels():
        tick_label.set_fontsize(CB_FS)

    # Integer percent in every cell; diagonal in bold. Cells below 0.5 get
    # black text, cells at 0.5 or above get white text.
    for row, col in np.ndindex(n, n):
        if mat[row, col] < 0.5:
            ink = "black"
        else:
            ink = "white"
        weight = "bold" if row == col else "normal"
        ax.text(
            col, row, f"{cell_pct[row, col]}",
            ha="center", va="center",
            fontsize=ANN_FS, fontweight=weight,
            color=ink,
        )

    fig.tight_layout()

    if save_pdf_path is not None:
        # Create the target directory (possibly nested) before saving
        Path(save_pdf_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_pdf_path, format="pdf", bbox_inches="tight")
        print(f"Saved {save_pdf_path}")

    if show:
        plt.show()
    else:
        plt.close(fig)

# Plot & save
# `col_norm_micro` and `class_names` are not defined by this cell. Fail with a
# clear message instead of a bare NameError when the cell that computes them
# has not been run in this kernel session.
_missing_inputs = [
    name for name in ("col_norm_micro", "class_names") if name not in globals()
]
if _missing_inputs:
    raise RuntimeError(
        "Heatmap inputs not found: " + ", ".join(_missing_inputs)
        + ". Run the cell that aggregates the per-pair counts first."
    )

_plot_heatmap(
    col_norm_micro, class_names,
    title="Column-normalized correspondence",
    save_pdf_path=OUT_PDF,
    show=True,  # set False if you don't want inline display
)


In [None]:
from pypdf import PdfReader, PdfWriter, Transformation

# Compose the two figure PDFs into one page: the heatmap is drawn on top of the
# WSI comparison page.
# NOTE(review): no visible cell writes "figures/wsi_comparison.pdf" — confirm it
# exists before this cell runs.
reader1 = PdfReader("figures/wsi_comparison.pdf")
reader2 = PdfReader("figures/Heatmap.pdf")

page1 = reader1.pages[0]   # base page (receives the overlay)
page2 = reader2.pages[0]   # page drawn on top of the base

# Transformation applied to the overlay: scale to 75%, then shift 580 units
# along x (presumably PDF points — verify against the page sizes).
transformation = (
    Transformation()
    .scale(sx=0.75, sy=0.75)      # shrink to 75% in both directions
    .translate(tx=580, ty=0)  # shift right only; ty=0 leaves the vertical position unchanged
)

# Draw the transformed overlay onto the base page (page1 is modified in place)
# NOTE(review): check that the shifted heatmap still falls inside page1's
# visible area; the call below does not ask pypdf to enlarge the page.
page1.merge_transformed_page(page2, transformation)

# Write the combined single-page PDF
writer = PdfWriter()
writer.add_page(page1)

with open("overlay.pdf", "wb") as f:
    writer.write(f)