### Imports & Configuration

In [16]:
# --- Imports ---
from pathlib import Path
import os, re, csv, numpy as np
from skimage import io
import torch
from torch.utils.data import Dataset
import torchvision.transforms as T

# --- Paths & constants (edit these) ---
# Patch tree crawled by iter_patches: <PATCH_ROOT>/slide*_core*/512/<mag>/*.png
PATCH_ROOT = Path("/home/user01/MS-RGCN-Plus/data/VPC/multiscale_patches_Train")
# One annotation directory per expert: Maps1_T ... Maps6_T
MAP_DIRS = [Path(f"/home/user01/MS-RGCN-Plus/data/VPC/Maps/Maps{i}_T") for i in range(1, 7)]

# File name of an expert mask inside each MAP_DIRS entry
MASK_BASENAME = "{scid}_classimg_nonconvex.png"
OUT_CSV = "/home/user01/MS-RGCN-Plus/data/VPC/patch_labels_majority.csv"

# Mask values excluded from every label histogram
BACKGROUND_VALUES = {255}
PATCH_SIZE = 512
MAGS = (10, 20, 40)

# --- Remap (dataset class ids -> training class ids) ---
# Keys are original mask ids (as strings), values are training ids.
# Original id '2' has no entry, so it never reaches any training class.
MAP_VPC = {'1': 0, '3': 1, '4': 2, '5': 3, '0': 4, '6': 5}

# Number of training classes: highest target id + 1 (-> 6)
NUM_CLASSES = max(MAP_VPC.values()) + 1

# Length of the original-id histogram: highest original id in MAP_VPC + 1 (-> 7)
NUM_CLASSES_ORIG = max(map(int, MAP_VPC)) + 1

# {(r, g, b): class_id} table; set it only when masks are RGB rather than single-channel ids
COLOR_TO_CLASS = None


### Helpers for Path Parsing & Mask Loading


In [17]:
# --- Regex parsers for slide/core and coordinates ---
slide_core_re = re.compile(r"slide(\d{3})_core(\d{3})", re.IGNORECASE)
xy_re = re.compile(r'(?P<x>\d+)_(?P<y>\d+)\.png$', re.IGNORECASE)


def slide_core_id_from_path(p: str) -> str:
    """Return the canonical lower-case ``slideNNN_coreMMM`` id found anywhere in ``p``.

    Raises:
        ValueError: if ``p`` contains no slide/core token.
    """
    match = slide_core_re.search(p)
    if match is None:
        raise ValueError(f"slide/core id not found in: {p}")
    slide_no, core_no = match.groups()
    return f"slide{slide_no}_core{core_no}"


def coords_from_name(name: str):
    """Parse ``(x, y)`` ints from a file name ending in ``<x>_<y>.png``.

    Raises:
        ValueError: if the name does not end with that pattern.
    """
    match = xy_re.search(name)
    if match is None:
        raise ValueError(f"cannot parse coords from {name}")
    return tuple(int(match.group(axis)) for axis in ("x", "y"))

def find_annotator_masks_for_core(scid: str):
    """List the expert mask files that exist for core ``scid``.

    Looks for ``MASK_BASENAME`` (formatted with ``scid``) in every MAP_DIRS
    entry and returns the existing paths, in MAP_DIRS order.
    """
    filename = MASK_BASENAME.format(scid=scid)
    candidates = (directory / filename for directory in MAP_DIRS)
    return [path for path in candidates if path.exists()]

def mask_to_ids(arr, color_to_class=None):
    """Convert a loaded annotation mask into a 2-D array of class ids.

    Single-channel masks (``arr.ndim == 2``) are returned unchanged.  Colour
    masks (``arr.ndim == 3``) are translated pixel-by-pixel through an
    ``{(r, g, b): class_id}`` table; colours missing from the table become 255.
    Only the first three channels are used, so RGBA PNGs are accepted too.

    Args:
        arr: mask array, e.g. as returned by ``skimage.io.imread``.
        color_to_class: optional ``{(r, g, b): class_id}`` table.  When None,
            the module-level ``COLOR_TO_CLASS`` is used.

    Returns:
        ``arr`` itself for 2-D input, otherwise a new ``uint8`` (H, W) array.

    Raises:
        ValueError: if the array is neither 2-D nor 3-D, if it is a colour mask
            but no colour table is configured, or if it has fewer than 3 channels.
    """
    if arr.ndim == 2:
        return arr
    if arr.ndim != 3:
        raise ValueError(f"Unsupported mask shape {arr.shape}; expected (H, W) or (H, W, C).")
    if color_to_class is None:
        color_to_class = COLOR_TO_CLASS  # only needed (and looked up) for colour masks
    if color_to_class is None:
        raise ValueError("Mask appears RGB but COLOR_TO_CLASS is None.")
    if arr.shape[2] < 3:
        raise ValueError(f"Colour mask needs at least 3 channels, got shape {arr.shape}.")

    h, w = arr.shape[:2]
    # Drop any alpha channel and pack each (r, g, b) into one integer key.
    rgb = arr[..., :3].reshape(-1, 3).astype(np.int64)
    keys = (rgb[:, 0] << 16) + (rgb[:, 1] << 8) + rgb[:, 2]
    lut = {(int(r) << 16) + (int(g) << 8) + int(b): c for (r, g, b), c in color_to_class.items()}
    # Look each distinct colour up once, then scatter the ids back to all pixels.
    unique_keys, inv = np.unique(keys, return_inverse=True)
    map_vals = np.array([lut.get(int(k), 255) for k in unique_keys], dtype=np.uint8)
    return map_vals[inv.reshape(-1)].reshape(h, w)


### Histogram & Patch Label Computation

In [18]:
def valid_hist(arr, num_classes=NUM_CLASSES_ORIG, ignore=BACKGROUND_VALUES):
    """Normalised class histogram of a label crop in ORIGINAL id space.

    Values in ``ignore`` and values outside ``[0, num_classes)`` are not
    counted.  The remaining counts are divided by their total, so the result
    sums to 1, or is all zeros when nothing was counted.

    Returns:
        float64 array of length ``num_classes``.
    """
    hist = np.zeros(num_classes, dtype=np.float64)
    labels, freqs = np.unique(arr, return_counts=True)
    for label, freq in zip(labels, freqs):
        if label in ignore or not 0 <= label < num_classes:
            continue
        hist[label] += freq
    total = hist.sum()
    return hist / total if total > 0 else hist


def patch_label_probs_from_maps(patch_path: str):
    """Compute soft and hard labels for one patch from its core's expert masks (CPU path).

    The slide/core id and top-left ``(x, y)`` are parsed from ``patch_path``.
    A PATCH_SIZE x PATCH_SIZE window at that offset is cropped from every
    expert mask of the core; experts whose mask does not fully cover the window
    are skipped.  Each crop becomes a normalised ORIGINAL-id histogram
    (``valid_hist``, background ignored).  Histograms of experts with at least
    one labelled pixel are averaged and then remapped to training ids via
    ``MAP_VPC``.  Original ids without a MAP_VPC entry contribute nothing, so
    the returned probabilities may sum to less than 1.

    NOTE(review): the same PATCH_SIZE window is cropped for every
    magnification; confirm the filename coordinates and mask resolution agree
    for 10x/20x patches.

    Returns:
        ``(probs_tgt, hard, used)``:
        - ``probs_tgt`` is a float64 array of length NUM_CLASSES.
        - ``hard`` is its argmax, with ties going to the higher id.
        - ``used`` is the number of contributing experts.
        Returns ``(None, None, 0)`` when no expert labels the window, or when
        every labelled pixel belongs to an unmapped original class.
    """
    scid = slide_core_id_from_path(patch_path)
    x, y = coords_from_name(os.path.basename(patch_path))
    expert_masks = find_annotator_masks_for_core(scid)
    if not expert_masks:
        return None, None, 0

    acc_orig = np.zeros(NUM_CLASSES_ORIG, dtype=np.float64)
    used = 0
    for mp in expert_masks:
        m = mask_to_ids(io.imread(str(mp)))
        crop = m[y:y + PATCH_SIZE, x:x + PATCH_SIZE]
        if crop.shape[:2] != (PATCH_SIZE, PATCH_SIZE):
            continue  # window runs past this expert's mask
        h_orig = valid_hist(crop, num_classes=NUM_CLASSES_ORIG)
        if h_orig.sum() > 0:  # expert labelled at least one pixel here
            acc_orig += h_orig
            used += 1

    if used == 0:
        return None, None, 0

    # Average across experts in ORIGINAL id space.
    probs_orig = acc_orig / used

    # Remap to the training class order.
    probs_tgt = np.zeros(NUM_CLASSES, dtype=np.float64)
    for orig_id in range(NUM_CLASSES_ORIG):
        tgt = MAP_VPC.get(str(orig_id))
        if tgt is not None:
            probs_tgt[tgt] += probs_orig[orig_id]

    if probs_tgt.max() <= 0:
        # Only unmapped classes were labelled.  The argmax of an all-zero
        # vector would silently yield the highest id, so report "no label".
        return None, None, 0

    # Hard label after remap (tie -> higher target id).
    hard = int(np.flatnonzero(probs_tgt == probs_tgt.max()).max())
    return probs_tgt, hard, used



### Crawling Patches & Building the CSV

In [19]:
def iter_patches(root=PATCH_ROOT, mags=MAGS):
    """Yield patch PNG paths (as str) laid out as ``<root>/slide*_core*/512/<mag>/*.png``.

    Core directories are visited in sorted order.  Within a core,
    magnifications follow ``mags`` order and files are sorted by name.
    Missing magnification directories are skipped.
    """
    for core_dir in sorted(root.glob("slide*_core*")):
        leaves = (core_dir / "512" / str(mag) for mag in mags)
        for leaf in leaves:
            if leaf.exists():
                yield from map(str, sorted(leaf.glob("*.png")))

def build_csv(out_csv=OUT_CSV):
    """Label every patch with the CPU path and write the label CSV.

    Each patch under PATCH_ROOT (magnifications MAGS) is labelled with
    ``patch_label_probs_from_maps``.  Patches without a label are left out.
    The CSV columns are ``path, n_experts_used, hard_label, p0..p{NUM_CLASSES-1}``.
    Probabilities are written as float32.  Parent directories of ``out_csv``
    are created.

    Returns:
        ``out_csv``.
    """
    import tqdm

    patch_paths = list(iter_patches(PATCH_ROOT, MAGS))
    rows = []
    for patch_path in tqdm.tqdm(patch_paths, total=len(patch_paths), desc="Labeling patches"):
        probs, hard, used = patch_label_probs_from_maps(patch_path)
        if probs is not None:
            rows.append([patch_path, used, hard, *probs.astype(np.float32)])

    header = ["path", "n_experts_used", "hard_label", *(f"p{c}" for c in range(NUM_CLASSES))]
    Path(out_csv).parent.mkdir(parents=True, exist_ok=True)
    with open(out_csv, "w", newline="") as handle:
        writer = csv.writer(handle)
        writer.writerow(header)
        writer.writerows(rows)
    print(f"[OK] Wrote {len(rows)} rows to {out_csv}")
    return out_csv


### GPU‑accelerated labeling (per‑core caching)

In [24]:
import torch
import torch.nn.functional as F
from collections import defaultdict
from skimage import io
from tqdm import tqdm

# Device used for labeling: CUDA when torch can see a GPU, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("Labeling device:", DEVICE)

def _bucket_patches_by_core(patch_paths):
    """Group patch paths by slide/core id, parsing each patch's (x, y) once.

    Returns:
        defaultdict mapping ``'slideNNN_coreMMM'`` to a list of
        ``(path, x, y)`` tuples.  Each list is sorted by y, then x; the sort is
        stable, so equal positions keep their input order.
    """
    grouped = defaultdict(list)
    for path in patch_paths:
        core_id = slide_core_id_from_path(path)
        px, py = coords_from_name(os.path.basename(path))
        grouped[core_id].append((path, px, py))

    def _row_major(entry):
        _, px, py = entry
        return py, px

    for entries in grouped.values():
        entries.sort(key=_row_major)
    return grouped

def _load_masks_for_core_to_device(scid):
    """Read every expert mask of core ``scid`` and return them stacked on DEVICE.

    Each mask goes through ``mask_to_ids``, so slices hold ORIGINAL class ids
    (plus background values).  ``np.stack`` requires all expert masks of a
    core to share one shape.

    Returns:
        ``(tensor [E, H, W], E)``, or ``(None, 0)`` when the core has no masks.
    """
    mask_paths = find_annotator_masks_for_core(scid)
    if not mask_paths:
        return None, 0
    stacked = np.stack([mask_to_ids(io.imread(str(path))) for path in mask_paths], axis=0)
    return torch.from_numpy(stacked).to(DEVICE, non_blocking=True), stacked.shape[0]

def _patch_probs_from_masks_gpu(masks_EHW, x, y):
    """Compute patch soft/hard labels from expert masks already on DEVICE.

    This is the tensor counterpart of ``patch_label_probs_from_maps``.  It
    applies the same rules so both builders label a patch alike, up to float32
    rounding:
    - pixels in BACKGROUND_VALUES or outside ``[0, NUM_CLASSES_ORIG)`` are ignored;
    - each expert is normalised separately;
    - experts without labelled pixels are dropped;
    - the averaged ORIGINAL-id distribution is remapped through MAP_VPC.

    Args:
        masks_EHW: integer tensor [E, H, W] of original class ids, one slice per expert.
        x, y: top-left corner of the PATCH_SIZE window.

    Returns:
        ``(probs[NUM_CLASSES] as numpy float32, hard_label, used_experts)``.
        The hard label is the argmax, with ties going to the higher id.
        Returns ``(None, None, 0)`` in three cases:
        - the window exceeds the masks;
        - the window has no labelled pixels;
        - only classes unmapped by MAP_VPC are labelled.
    """
    E, H, W = masks_EHW.shape
    if y + PATCH_SIZE > H or x + PATCH_SIZE > W:
        return None, None, 0

    crop = masks_EHW[:, y:y + PATCH_SIZE, x:x + PATCH_SIZE].long()  # [E,Ps,Ps]

    # Same validity rule as valid_hist: in range and not background.  Clamping
    # instead would fold stray ids >= NUM_CLASSES_ORIG into the last class.
    valid = (crop >= 0) & (crop < NUM_CLASSES_ORIG)
    for bg in BACKGROUND_VALUES:
        valid &= crop != bg
    if not bool(valid.any()):
        return None, None, 0

    # Per-expert class counts from one bincount over (expert, class) pairs.
    # This avoids materialising an [E,Ps,Ps,C] one-hot tensor.
    expert_idx = torch.arange(E, device=crop.device).view(E, 1, 1).expand_as(crop)
    flat_ids = (expert_idx * NUM_CLASSES_ORIG + crop)[valid]
    counts_EC = torch.bincount(flat_ids, minlength=E * NUM_CLASSES_ORIG)
    counts_EC = counts_EC.view(E, NUM_CLASSES_ORIG).to(torch.float32)  # [E,C_orig]

    # Normalise per expert, keeping only experts with labelled pixels.
    pix_per_exp = counts_EC.sum(dim=1)  # [E]
    good = pix_per_exp > 0
    used = int(good.sum().item())
    if used == 0:
        return None, None, 0
    probs_orig = (counts_EC[good] / pix_per_exp[good].unsqueeze(1)).mean(dim=0)  # [C_orig]

    # Remap to the training class order.
    probs_tgt = torch.zeros(NUM_CLASSES, device=probs_orig.device, dtype=probs_orig.dtype)
    for key, tgt in MAP_VPC.items():
        orig = int(key)
        if 0 <= orig < NUM_CLASSES_ORIG:
            probs_tgt[tgt] += probs_orig[orig]

    top = probs_tgt.max()
    if top.item() <= 0:
        # Only unmapped classes were labelled.  The argmax of zeros would
        # silently pick the highest id.
        return None, None, 0

    # Tie -> higher target id.
    hard = int(torch.nonzero(probs_tgt == top, as_tuple=False).max().item())
    return probs_tgt.detach().cpu().numpy(), hard, used

def build_csv_gpu(out_csv=OUT_CSV, mags=MAGS, flush_cuda_each_core=True, max_cores=None):
    """Label all patches and write the CSV, loading each core's masks to DEVICE once.

    Steps:
      - collect patch paths under PATCH_ROOT for the requested magnifications;
      - bucket them per slide/core id (rows come out per core, sorted by y then x);
      - optionally keep only the first ``max_cores`` cores, in sorted order;
      - per core: load all expert masks once, label each patch with
        ``_patch_probs_from_masks_gpu``, then release the masks;
      - write ``path, n_experts_used, hard_label, p0..p{NUM_CLASSES-1}``.

    Patches whose core has no expert masks, or whose window yields no label,
    are left out of the CSV.  tqdm bars report progress per core and per patch.

    NOTE(review): every magnification is labelled from the same PATCH_SIZE
    window at the filename's (x, y); confirm this matches the masks for 10x/20x.

    Args:
        out_csv: destination CSV path; its parent directories are created.
        mags: magnification sub-directories to include.
        flush_cuda_each_core: call ``torch.cuda.empty_cache()`` after each core.
        max_cores: optional cap on the number of cores processed.

    Returns:
        ``out_csv``.
    """
    # Gather all patch paths and group them per core.
    all_patches = list(iter_patches(PATCH_ROOT, mags=mags))
    buckets = _bucket_patches_by_core(all_patches)

    # Keep only a subset of cores (deterministic order).
    core_ids = sorted(buckets.keys())
    if max_cores is not None:
        core_ids = core_ids[:int(max_cores)]

    print(f"Cores to process: {len(core_ids)}")

    rows = []
    for scid in tqdm(core_ids, desc="Processing cores", unit="core"):
        items = buckets[scid]

        masks_EHW, _ = _load_masks_for_core_to_device(scid)
        if masks_EHW is None:
            continue  # no expert masks for this core; its patches stay unlabeled

        for (ppath, x, y) in tqdm(items, desc=f"{scid} patches", unit="patch", leave=False):
            probs, hard, used = _patch_probs_from_masks_gpu(masks_EHW, x, y)
            if probs is None:
                continue
            rows.append([ppath, used, hard] + [float(p) for p in probs])

        # Free GPU memory before the next core's masks are loaded.
        del masks_EHW
        if DEVICE == "cuda" and flush_cuda_each_core:
            torch.cuda.empty_cache()

    Path(out_csv).parent.mkdir(parents=True, exist_ok=True)
    with open(out_csv, "w", newline="") as f:
        w = csv.writer(f)
        header = ["path", "n_experts_used", "hard_label"] + [f"p{c}" for c in range(NUM_CLASSES)]
        w.writerow(header)
        w.writerows(rows)

    print(f"[OK] Wrote {len(rows)} labeled patches to {out_csv}")
    return out_csv


Labeling device: cuda


### Dataset Class with Soft or Hard Labels

In [25]:
class PatchCSVWithSoftLabels(Dataset):
    """Dataset of patch images labelled by the CSV from ``build_csv`` / ``build_csv_gpu``.

    Every item is a dict ``{"img": ..., "label": ..., "path": str}``:
    - ``img`` is the RGB image passed through ``transform`` (``T.ToTensor()``
      by default);
    - ``label`` is the float32 vector ``p0..p{NUM_CLASSES-1}`` when
      ``target_kind == "soft"``, or a 0-d long tensor holding ``hard_label``
      when ``target_kind == "hard"``.
    CSV rows with ``n_experts_used == 0`` are skipped.

    Raises:
        ValueError: if ``target_kind`` is not ``"soft"`` or ``"hard"``.
    """

    def __init__(self, csv_file, transform=None, target_kind="soft"):
        # Validate before reading the CSV, and with a real exception: an
        # ``assert`` is stripped when Python runs with -O.
        if target_kind not in ("soft", "hard"):
            raise ValueError(f"target_kind must be 'soft' or 'hard', got {target_kind!r}")
        self.target_kind = target_kind
        self.tf = transform or T.ToTensor()

        prob_cols = [f"p{c}" for c in range(NUM_CLASSES)]
        self.items = []
        # newline="" is what the csv module expects for files it parses.
        with open(csv_file, newline="") as f:
            for row in csv.DictReader(f):
                if int(row["n_experts_used"]) == 0:
                    continue
                probs = torch.tensor([float(row[col]) for col in prob_cols], dtype=torch.float32)
                self.items.append((row["path"], probs, int(row["hard_label"])))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, i):
        from PIL import Image

        p, probs, hard = self.items[i]
        # The context manager closes the file handle; convert() returns an
        # independent, fully loaded image.
        with Image.open(p) as im:
            img = im.convert("RGB")
        img = self.tf(img)
        if self.target_kind == "soft":
            y = probs
        else:
            y = torch.tensor(hard, dtype=torch.long)
        return {"img": img, "label": y, "path": p}


### Run Label Building & Quick Dataset Test

In [31]:
# Build the label CSV on DEVICE for all three magnifications.
# Alternatives: build_csv(OUT_CSV) for the CPU path, or mags=(40,) for 40x only.
csv_path = build_csv_gpu(OUT_CSV, mags=(10, 20, 40))

# Example usage: ToTensor only; flips/rotations can be appended to this list.
AUG = T.Compose([T.ToTensor()])
train_ds = PatchCSVWithSoftLabels(csv_path, transform=AUG, target_kind="soft")
# NOTE(review): val_ds reads the same training CSV as train_ds, so it is not a
# held-out split. Point it at a separate validation CSV before evaluating.
val_ds = PatchCSVWithSoftLabels(csv_path, transform=T.ToTensor(), target_kind="soft")

print("Train samples:", len(train_ds))
if len(train_ds) > 0:
    s = train_ds[0]
    print("Example shapes:", s["img"].shape, s["label"].shape, s["path"])


Cores to process: 244


Processing cores: 100%|██████████| 244/244 [03:20<00:00,  1.22core/s]


[OK] Wrote 71244 labeled patches to /home/user01/MS-RGCN-Plus/data/VPC/patch_labels_majority.csv
Train samples: 71244
Example shapes: torch.Size([3, 512, 512]) torch.Size([6]) /home/user01/MS-RGCN-Plus/data/VPC/multiscale_patches_Train/slide001_core003/512/10/0_0.png


### 📄 Patch Labels CSV – Column Description (with `map_vpc` Applied)

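This CSV contains **patch-level labels** derived from the expert annotation masks. For each patch window:
1. Each expert's pixel-class distribution is computed.
2. These distributions are averaged across experts.
3. The result is **remapped** to match the classification scheme used in the ResNet training code.

The hard label is the class with the highest averaged probability; it is not a per-pixel majority vote.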

| Column Name        | Type     | Description |
|--------------------|----------|-------------|
| **path**           | `str`    | Absolute path to the patch image file (e.g., `/home/user01/MS-RGCN-Plus/data/VPC/multiscale_patches_Train/slide001_core003/512/40/0_1024.png`). Identifies the specific patch at a given magnification. |
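| **n_experts_used** | `int`    | Number of pathologists’ maps that contributed to this patch's label. A map contributes if it exists for the core, covers the full patch window, and has at least one non-background pixel inside that window. This can be less than the number of annotators if some masks are missing or unlabeled at this location. |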
| **hard_label**     | `int`    | **Remapped class index** based on `map_vpc` (`{'1':0, '3':1, '4':2, '5':3, '0':4, '6':5}`). This is the integer class ID you should use for training. It corresponds to Gleason grades but in the custom index order used by your model. |
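| **p0 ... p5**      | `float`  | Soft-label probabilities for each remapped class index (`0`–`5`). Each value is the fraction of labeled (non-background) pixels in the patch window assigned to that class, averaged over the contributing experts. They sum to 1.0 unless some labeled pixels carry an original class that `map_vpc` does not map (e.g. `2`). That share is dropped, so the sum is then below 1.0. |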

---

#### 🔄 **Mapping Logic**
The `map_vpc` remap converts original Gleason grade labels into model-specific class IDs:

| Original Gleason Grade | Mapped Class ID |
|------------------------|-----------------|
| 1                      | 0 |
| 3                      | 1 |
| 4                      | 2 |
| 5                      | 3 |
| 0 (normal)             | 4 |
| 6                      | 5 |

> Original class ids absent from `map_vpc` (e.g. `2`) have no column in the CSV. Pixels with those ids add no probability mass to any of `p0`–`p5`.

---

#### ✅ **Usage Notes**
- `hard_label` is the **final class ID** you should feed into your training pipeline.
- `p0...p5` can be used for **soft-label training** or as confidence scores.
- The CSV has no magnification column. The magnification is encoded in `path` as the directory after `512/` (e.g. `.../512/40/0_1024.png` is a 40× patch). Filter on it if you want to train at a single magnification.
- The CSV can be regenerated at any time by running the GPU CSV builder with `map_vpc` applied inside the loop.


In [None]:
import pandas as pd

# Reload the label CSV written by the builder above.
csv_path = "/home/user01/MS-RGCN-Plus/data/VPC/patch_labels_majority.csv"
df = pd.read_csv(csv_path)

df_shape = df.shape  # (rows, columns)
df.head()  # first rows, shown as the cell output


Unnamed: 0,path,n_experts_used,hard_label,p0,p1,p2,p3,p4,p5
0,/home/user01/MS-RGCN-Plus/data/VPC/multiscale_...,4,4,0.0,0.0,0.0,0.0,1.0,0.0
1,/home/user01/MS-RGCN-Plus/data/VPC/multiscale_...,4,4,0.0,0.0,0.0,0.0,1.0,0.0
2,/home/user01/MS-RGCN-Plus/data/VPC/multiscale_...,4,4,0.0,0.0,0.067184,0.0,0.932816,0.0
3,/home/user01/MS-RGCN-Plus/data/VPC/multiscale_...,4,4,0.0,0.0,0.137395,0.0,0.862605,0.0
4,/home/user01/MS-RGCN-Plus/data/VPC/multiscale_...,4,4,0.0,0.0,0.128264,0.0,0.871736,0.0


In [30]:
# Class balance of the remapped hard labels.
df["hard_label"].value_counts()

hard_label
4    13175
2     5693
1     3651
0     1060
3      169
Name: count, dtype: int64

### Patch & Majority‑Vote Mask Visualization

In [1]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap, BoundaryNorm

# ---- discrete colors for training classes 0..5 (edit if you like)
# Training ids come from MAP_VPC: 0<-'1', 1<-'3', 2<-'4', 3<-'5', 4<-'0', 5<-'6' (original mask ids).
# NOTE(review): only class 0 carries a descriptive name; confirm the meaning of each id for your scheme.
CLASS_NAMES = {0: "0/Benign", 1: "1", 2: "2", 3: "3", 4: "4", 5: "5"}
PALETTE = np.array(
    [
        [0, 170, 0],      # 0
        [30, 144, 255],   # 1
        [255, 140, 0],    # 2
        [220, 20, 60],    # 3
        [148, 0, 211],    # 4
        [128, 128, 128],  # 5
    ],
    dtype=np.uint8,
)

# One colour per class.  Bin edges 0, 1, ..., 6 send integer id c to colour c.
cmap = ListedColormap(PALETTE / 255.0)
norm = BoundaryNorm(list(range(len(PALETTE) + 1)), cmap.N)

def _majority_vote_crop(expert_mask_paths, x, y):
    """Per-pixel majority vote over experts for one PATCH_SIZE window, in TRAINING ids.

    Expert masks hold ORIGINAL class ids.  They are mapped through MAP_VPC
    before voting, so the result uses the same ids (and PALETTE colours) as
    the CSV labels.  Background pixels and original ids without a MAP_VPC
    entry cast no vote.  Experts whose mask does not fully cover the window
    are skipped.

    Returns:
        uint8 (PATCH_SIZE, PATCH_SIZE) array of training ids.  Ties go to the
        higher id, and pixels with no vote are 255.  Returns None when no
        pixel received a vote.
    """
    H = W = PATCH_SIZE
    counts = np.zeros((NUM_CLASSES, H, W), dtype=np.uint16)
    any_vote = np.zeros((H, W), dtype=bool)

    for mp in expert_mask_paths:
        crop = mask_to_ids(io.imread(str(mp)))[y:y + H, x:x + W]
        if crop.shape[:2] != (H, W):
            continue
        # Translate original ids to training ids so votes land on the right class.
        for orig_key, tgt in MAP_VPC.items():
            votes = crop == int(orig_key)
            counts[tgt] += votes
            any_vote |= votes

    if not any_vote.any():
        return None  # no label coverage for this window

    # argmax returns the first maximum, so searching the reversed class axis
    # resolves ties in favour of the highest class id.
    maj = (NUM_CLASSES - 1 - np.argmax(counts[::-1], axis=0)).astype(np.uint8)

    # Mark pixels with no votes as 255 (ignored).
    maj[~any_vote] = 255
    return maj

def _colorize_mask(mask_uint8):
    """Paint a 2-D class-id mask with PALETTE colours.

    Returns:
        ``(rgb, valid)``: an (H, W, 3) uint8 image in which pixels equal to 255
        stay black, and the boolean mask of all other pixels.  Callers use
        ``valid`` to make the 255 pixels transparent.
    """
    height, width = mask_uint8.shape
    valid = mask_uint8 != 255
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    rgb[valid] = PALETTE[mask_uint8[valid]]
    return rgb, valid

def visualize_patch_with_vote(patch_path, alpha=0.45, show_probs=True):
    """Plot a patch, its per-pixel majority-vote mask and an overlay of the two.

    Args:
        patch_path: patch PNG named ``<x>_<y>.png`` inside a ``slideNNN_coreMMM`` directory.
        alpha: opacity of the coloured mask in the overlay panel.
        show_probs: also print the soft label, hard label and number of
            contributing experts from ``patch_label_probs_from_maps``.

    Prints a warning and returns early when the core has no expert masks.
    """
    import matplotlib.patches as mpatches

    patch_img = io.imread(patch_path)
    scid = slide_core_id_from_path(patch_path)
    x, y = coords_from_name(os.path.basename(patch_path))
    expert_masks = find_annotator_masks_for_core(scid)

    if len(expert_masks) == 0:
        print(f"[WARN] No expert masks found for {scid}")
        return

    vote = _majority_vote_crop(expert_masks, x, y)
    probs, hard, used = patch_label_probs_from_maps(patch_path)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    axs[0].imshow(patch_img)
    axs[0].set_title("Patch")
    axs[0].axis('off')

    if vote is None:
        axs[1].text(0.5, 0.5, "No label coverage\nfor this window", ha='center', va='center', fontsize=12)
        axs[1].axis('off')
        axs[2].imshow(patch_img)
        axs[2].axis('off')
    else:
        # Masking the unvoted pixels (255) makes them render transparent
        # instead of in the colormap's "over" colour.  `norm` alone fixes the
        # colour range: matplotlib rejects vmin/vmax combined with a
        # Normalize instance.
        axs[1].imshow(np.ma.masked_equal(vote, 255), cmap=cmap, norm=norm)
        axs[1].set_title("Majority‑vote mask")
        axs[1].axis('off')

        color_mask, valid = _colorize_mask(vote)
        # RGBA overlay: labelled pixels get opacity `alpha`, all others are fully transparent.
        overlay = np.dstack([color_mask, (valid.astype(np.float32) * alpha * 255).astype(np.uint8)])
        axs[2].imshow(patch_img)
        axs[2].imshow(overlay)
        axs[2].set_title("Overlay")
        axs[2].axis('off')

    # Simple legend with one swatch per class.
    handles = [
        mpatches.Patch(color=PALETTE[cid] / 255.0, label=f"{cid}: {name}")
        for cid, name in CLASS_NAMES.items()
    ]
    fig.legend(handles=handles, loc="lower center", ncol=min(6, len(handles)), bbox_to_anchor=(0.5, -0.02))
    plt.tight_layout()
    plt.show()

    if show_probs and probs is not None:
        print(f"Experts used: {used}   Hard label: {hard}")
        for c, p in enumerate(probs):
            print(f"  p[{c}] = {p:.3f}")
    elif show_probs:
        print("No per‑patch probabilities (no valid expert crops).")

# ---- Example usage: pick a patch and visualize
# p = next(iter_patches(PATCH_ROOT, mags=(40,)))   # e.g., only 40×
# visualize_patch_with_vote(p)
