In [1]:
# Minimal deps for MedSAM-based zero-shot labeling
%pip install -q --no-deps "git+https://github.com/facebookresearch/segment-anything.git" scikit-image opencv-python seaborn huggingface_hub

import torch
print("Torch:", torch.__version__)

Note: you may need to restart the kernel to use updated packages.
Torch: 2.8.0+cu128


In [2]:
# Run this ONCE per kernel, before other tqdm imports
import os, sys
os.environ["TQDM_NOTEBOOK"] = "0"   # tell tqdm to not use notebook widgets
os.environ["TQDM_DISABLE"] = ""     # ensure tqdm is enabled (not disabled)

from tqdm import tqdm

# Optional: clear any previous widget placeholders
try:
    from IPython.display import clear_output
    clear_output(wait=True)
except Exception:
    pass

print("tqdm configured for text output. Type:", type(tqdm))

tqdm configured for text output. Type: <class 'type'>


In [3]:
from pathlib import Path

class CFG:
    # Input images: directly under each class folder
    ORIG_ROOT = "/home/philipdt/IKT-project/data/EndoscopicBladderTissue"
    CLASSES   = ["HGC", "LGC", "NST", "NTL"]

    # Choose class set: "basic" or "extended"
    CLASS_SET = "extended"

    CLASS_SETS = {
        "basic":    {0:"background", 1:"vessel", 2:"mucosa"},
        "extended": {0:"background", 1:"vessel", 2:"mucosa", 3:"tumor", 4:"instrument", 5:"specular"}
    }
    CLASS_COLORS = {
        "basic":    {0:(0,0,0), 1:(0,255,255), 2:(255,165,0)},
        "extended": {0:(0,0,0), 1:(0,255,255), 2:(255,165,0), 3:(255,0,0), 4:(128,128,128), 5:(255,255,255)}
    }

    # MedSAM settings (keep as is)
    SAM_TYPE = "vit_b"
    MEDSAM_CKPT = "/home/philipdt/IKT-project/weights/medsam_vit_b.pth"
    AUTO_DOWNLOAD = True
    OUTPUT_LABEL_ROOT = "/home/philipdt/IKT-project/auto_labels_medsam"

    # SAM generator knobs (keep as is)
    POINTS_PER_SIDE = 24
    PRED_IOU_THRESH = 0.88
    STABILITY_THRESH = 0.90
    CROP_N_LAYERS = 1
    CROP_DOWNSCALE = 2
    MIN_MASK_REGION_AREA = 256
    COVERAGE_LIMIT = 0.70

    # THESE ARE THE ONES TO CHANGE - make them more permissive:
    VESSEL_SIGMA_RANGE = (1.0, 3.0)   # keep same
    VESSEL_PCT = 95.0                 # CHANGED from 98.0
    SPECULAR_V_MIN = 200              # CHANGED from 230
    SPECULAR_S_MAX = 50               # CHANGED from 40
    INSTR_EDGE_PCT = 0.08             # CHANGED from 0.12
    TUMOR_RED_RATIO = 1.15            # CHANGED from 1.25
    MIN_REGION_AREA = 256             # keep same

# Validate paths
if not Path(CFG.ORIG_ROOT).is_dir():
    raise SystemExit(f"ORIG_ROOT not found: {CFG.ORIG_ROOT}")
Path(CFG.OUTPUT_LABEL_ROOT).mkdir(parents=True, exist_ok=True)

print("Input:", CFG.ORIG_ROOT)
print("Output labels:", CFG.OUTPUT_LABEL_ROOT)
print("Class set:", CFG.CLASS_SET, CFG.CLASS_SETS[CFG.CLASS_SET])

Input: /home/philipdt/IKT-project/data/EndoscopicBladderTissue
Output labels: /home/philipdt/IKT-project/auto_labels_medsam
Class set: extended {0: 'background', 1: 'vessel', 2: 'mucosa', 3: 'tumor', 4: 'instrument', 5: 'specular'}


In [4]:
# Utilities (force console tqdm to avoid "Loading widget..." hang)
import os, numpy as np, cv2
from pathlib import Path
from typing import Dict, List, Tuple
from skimage.filters import frangi

# Force non-widget tqdm
os.environ["TQDM_NOTEBOOK"] = "0"
try:
    from tqdm.std import tqdm  # text mode
except ImportError:
    from tqdm import tqdm

IMG_EXTS = {".png",".jpg",".jpeg",".bmp",".tif",".tiff",".PNG",".JPG"}

def list_images(d: Path) -> List[Path]:
    if not d.is_dir(): return []
    return sorted([p for p in d.iterdir() if p.is_file() and p.suffix in IMG_EXTS])

def ensure_dir(p: Path):
    p.mkdir(parents=True, exist_ok=True)

def colorize_mask(mask: np.ndarray, colors: Dict[int, Tuple[int,int,int]]) -> np.ndarray:
    out = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    for k, c in colors.items():
        out[mask == k] = c
    return out

def vesselness(gray_u8: np.ndarray, sig_range=(1.0,3.0)) -> np.ndarray:
    g = gray_u8.astype(np.float32) / 255.0
    v = frangi(g, sigmas=np.linspace(sig_range[0], sig_range[1], 4), black_ridges=False)
    v = np.nan_to_num(v)
    v = (v - v.min()) / (v.max() - v.min() + 1e-6)
    return v

def specular_mask(bgr: np.ndarray, v_min=230, s_max=40) -> np.ndarray:
    hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
    h,s,v = cv2.split(hsv)
    return ((v >= v_min) & (s <= s_max)).astype(np.uint8)

def instrument_mask(bgr: np.ndarray, edge_pct=0.12) -> np.ndarray:
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    edges = cv2.Canny(gray, 80, 160)
    return edges, edges.mean() >= edge_pct*255

def tumor_likelihood(bgr: np.ndarray, red_ratio=1.25) -> np.ndarray:
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) + 1e-6
    R, G, B = rgb[...,0], rgb[...,1], rgb[...,2]
    return ((R / ((G+B)/2.0)) >= red_ratio).astype(np.uint8)

def remove_small_components(mask_bin: np.ndarray, min_area: int) -> np.ndarray:
    num, labels, stats, _ = cv2.connectedComponentsWithStats(mask_bin.astype(np.uint8), 8)
    keep = np.zeros_like(mask_bin, dtype=np.uint8)
    for i in range(1, num):
        if stats[i, cv2.CC_STAT_AREA] >= min_area:
            keep[labels == i] = 1
    return keep

In [5]:
# MedSAM load cell: (1) optional Google Drive download → (2) load weights → (3) init mask generator

from pathlib import Path
import os, sys, re, subprocess
import torch

# 1) CONFIG: set your local destination and (optionally) paste your Google Drive share URL
DEST = "/home/philipdt/IKT-project/weights/medsam_vit_b.pth"  # where to store/find the checkpoint

# Segment-Anything backbone type for MedSAM
SAM_TYPE = "vit_b"  # ViT-B checkpoint

# Mask generator knobs (tweak as you like)
POINTS_PER_SIDE = 24
PRED_IOU_THRESH = 0.88
STABILITY_THRESH = 0.90
CROP_N_LAYERS = 0
CROP_DOWNSCALE = 2
MIN_MASK_REGION_AREA = 64

# 2) Ensure segment-anything is installed and import
try:
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
except ImportError:
    print("[INFO] Installing segment-anything...")
    subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "git+https://github.com/facebookresearch/segment-anything.git"])
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator


ckpt_path = Path(DEST)

# 4) Load MedSAM weights into the SAM ViT-B architecture
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Build the architecture first
sam = sam_model_registry[SAM_TYPE](checkpoint=None)

# Load weights (handle different dict formats)
state = torch.load(str(ckpt_path), map_location="cpu")
if isinstance(state, dict) and "state_dict" in state:
    state = state["state_dict"]
missing, unexpected = sam.load_state_dict(state, strict=False)
print(f"[INFO] MedSAM weights loaded. Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")

sam.to(device=DEVICE)

# 5) Create the automatic mask generator
mask_gen = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=POINTS_PER_SIDE,
    pred_iou_thresh=PRED_IOU_THRESH,
    stability_score_thresh=STABILITY_THRESH,
    crop_n_layers=CROP_N_LAYERS,
    crop_n_points_downscale_factor=CROP_DOWNSCALE,
    min_mask_region_area=MIN_MASK_REGION_AREA,
)

print("[READY] MedSAM loaded and mask generator initialized.")

Using device: cuda
[INFO] MedSAM weights loaded. Missing keys: 0, Unexpected keys: 0
[READY] MedSAM loaded and mask generator initialized.


In [6]:
import matplotlib.pyplot as plt
import sys, time

# def classify_masks(bgr: np.ndarray, masks: List[np.ndarray], class_set="extended", debug=False) -> np.ndarray:
#     H, W = bgr.shape[:2]
#     labels = np.zeros((H, W), dtype=np.uint8)
    
#     # Precompute detection maps
#     gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
#     vmap = vesselness(gray, sig_range=CFG.VESSEL_SIGMA_RANGE)
#     spec = specular_mask(bgr, v_min=CFG.SPECULAR_V_MIN, s_max=CFG.SPECULAR_S_MAX)
#     edges, _ = instrument_mask(bgr, edge_pct=CFG.INSTR_EDGE_PCT)
#     edge_density = edges.astype(np.float32) / 255.0
    
#     # Improved tumor detection
#     rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) + 1e-6
#     R, G, B = rgb[...,0], rgb[...,1], rgb[...,2]
#     red_ratio_mask = ((R / ((G+B)/2.0)) >= CFG.TUMOR_RED_RATIO).astype(np.uint8)
    
#     # Use adaptive threshold for vesselness (try multiple percentiles)
#     vessel_thresholds = [np.percentile(vmap, p) for p in [90, 95, 98]]
    
#     if debug:
#         debug_classification(bgr, masks[:5])
    
#     # Sort masks by area (largest first) 
#     masks_with_area = [(m, m.sum()) for m in masks if m.sum() >= CFG.MIN_REGION_AREA]
#     masks_with_area.sort(key=lambda x: x[1], reverse=True)
    
#     covered = np.zeros((H, W), dtype=np.uint8)
    
#     for m, area in masks_with_area:
#         # Remove overlap
#         m_clean = (m & (~covered.astype(bool))).astype(np.uint8)
#         if m_clean.sum() == 0:
#             continue
            
#         m_bool = m_clean.astype(bool)
        
#         # Calculate ratios
#         spec_ratio = float(spec[m_bool].mean()) if m_bool.sum() > 0 else 0
#         edge_ratio = float(edge_density[m_bool].mean()) if m_bool.sum() > 0 else 0
#         tumor_ratio = float(red_ratio_mask[m_bool].mean()) if m_bool.sum() > 0 else 0
        
#         # Try multiple vessel thresholds
#         vessel_ratios = [float((vmap[m_bool] > thr).mean()) for thr in vessel_thresholds]
#         vessel_ratio = max(vessel_ratios)  # Use the most permissive
        
#         if class_set == "basic":
#             cls_id = 1 if vessel_ratio > 0.05 else 2
#         else:
#             if spec_ratio > 0.02:       # CHANGED: more sensitive specular detection
#                 cls_id = 5
#             elif edge_ratio > 0.05:     # CHANGED: more sensitive instrument detection
#                 cls_id = 4
#             elif vessel_ratio > 0.06:   # CHANGED: more sensitive vessel detection
#                 cls_id = 1
#             elif tumor_ratio > 0.08:    # CHANGED: more sensitive tumor detection
#                 cls_id = 3
        
#         labels[m_bool] = cls_id
#         covered[m_bool] = 1
        
#         if covered.mean() >= CFG.COVERAGE_LIMIT:
#             break
            
#     return labels

def classify_masks(bgr: np.ndarray, masks: List[np.ndarray], class_set="extended") -> np.ndarray:
    H, W = bgr.shape[:2]
    labels = np.zeros((H, W), dtype=np.uint8)

    # Precompute helper maps
    gray = cv2.cvtColor(bgr, cv2.COLOR_BGR2GRAY)
    vmap = vesselness(gray, sig_range=CFG.VESSEL_SIGMA_RANGE)
    spec = specular_mask(bgr, v_min=CFG.SPECULAR_V_MIN, s_max=CFG.SPECULAR_S_MAX).astype(np.uint8)
    edges, _ = instrument_mask(bgr, edge_pct=CFG.INSTR_EDGE_PCT)
    edge_density = (edges.astype(np.float32) / 255.0)
    
    # Tumor detection
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB).astype(np.float32) + 1e-6
    R, G, B = rgb[...,0], rgb[...,1], rgb[...,2]
    red_ratio_mask = ((R / ((G+B)/2.0)) >= CFG.TUMOR_RED_RATIO).astype(np.uint8)
    lab = cv2.cvtColor(bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
    a_star = lab[...,1]
    a_thresh = np.percentile(a_star, 75)
    red_lab_mask = (a_star >= a_thresh).astype(np.uint8)
    tumor_like = np.maximum(red_ratio_mask, red_lab_mask)

    # Global vesselness threshold
    thr_vessel_img = np.percentile(vmap, CFG.VESSEL_PCT)
    
    print(f"\n=== DEBUG INFO ===")
    print(f"Image size: {H}x{W}")
    print(f"Number of masks: {len(masks)}")
    print(f"Vesselness range: {vmap.min():.4f} to {vmap.max():.4f}")
    print(f"Vesselness threshold ({CFG.VESSEL_PCT}th percentile): {thr_vessel_img:.4f}")
    print(f"Specular pixels: {spec.sum()} ({100*spec.mean():.2f}%)")
    print(f"High-edge pixels: {(edge_density > 0.05).sum()} ({100*(edge_density > 0.05).mean():.2f}%)")
    print(f"Tumor-like pixels: {tumor_like.sum()} ({100*tumor_like.mean():.2f}%)")
    print(f"Lab a* threshold: {a_thresh:.2f}")

    # Filter masks by size
    valid_masks = [m for m in masks if m.sum() >= CFG.MIN_REGION_AREA]
    print(f"Valid masks (>= {CFG.MIN_REGION_AREA} pixels): {len(valid_masks)}")
    
    if len(valid_masks) == 0:
        print("No valid masks found!")
        return labels

    # Priority scoring and sorting
    mask_scores = []
    for i, m in enumerate(valid_masks):
        m = m.astype(np.uint8)
        m_bool = m.astype(bool)
        area = m.sum()
        
        spec_r = float(spec[m_bool].mean()) if m_bool.sum() > 0 else 0
        edge_r = float(edge_density[m_bool].mean()) if m_bool.sum() > 0 else 0
        vess_r = float((vmap[m_bool] > thr_vessel_img).mean()) if m_bool.sum() > 0 else 0
        tum_r = float(tumor_like[m_bool].mean()) if m_bool.sum() > 0 else 0
        
        priority = 3.0*spec_r + 2.0*edge_r + 1.5*vess_r + 1.0*tum_r
        mask_scores.append((priority, m, i, area, spec_r, edge_r, vess_r, tum_r))

    # Sort by priority then area
    mask_scores.sort(key=lambda t: (t[0], t[3]), reverse=True)
    covered = np.zeros((H, W), dtype=np.uint8)

    print(f"\nProcessing top {min(10, len(mask_scores))} masks:")
    assigned_classes = {}
    
    for priority, m, orig_idx, area, spec_r, edge_r, vess_r, tum_r in mask_scores[:10]:
        # Remove overlap
        m_clean = (m & (~covered.astype(bool))).astype(np.uint8)
        remaining_area = m_clean.sum()
        
        if remaining_area == 0:
            print(f"Mask {orig_idx}: SKIPPED (fully covered)")
            continue

        # Recalculate ratios for clean mask
        m_bool = m_clean.astype(bool)
        spec_ratio = float(spec[m_bool].mean()) if m_bool.sum() > 0 else 0
        edge_ratio = float(edge_density[m_bool].mean()) if m_bool.sum() > 0 else 0
        vessel_ratio = float((vmap[m_bool] > thr_vessel_img).mean()) if m_bool.sum() > 0 else 0
        tumor_ratio = float(tumor_like[m_bool].mean()) if m_bool.sum() > 0 else 0

        # Classify with very permissive thresholds for debugging
        if class_set == "basic":
            cls_id = 1 if vessel_ratio > 0.01 else 2  # Very low threshold
        else:
            # EXTREMELY permissive thresholds for debugging
            if spec_ratio > 0.01:  # 1% specular pixels
                cls_id = 5
            elif edge_ratio > 0.02:  # 2% high-edge pixels  
                cls_id = 4
            elif vessel_ratio > 0.02:  # 2% vessel pixels
                cls_id = 1
            elif tumor_ratio > 0.03:  # 3% tumor-like pixels
                cls_id = 3
            else:
                cls_id = 2

        print(f"Mask {orig_idx}: area={remaining_area}, spec={spec_ratio:.3f}, edge={edge_ratio:.3f}, "
              f"vessel={vessel_ratio:.3f}, tumor={tumor_ratio:.3f} → class {cls_id}")
        
        labels[m_bool] = cls_id
        covered[m_bool] = 1
        
        assigned_classes[cls_id] = assigned_classes.get(cls_id, 0) + 1

        if covered.mean() >= CFG.COVERAGE_LIMIT:
            break

    print(f"\nFinal class distribution: {dict(sorted(assigned_classes.items()))}")
    print(f"Unique classes in output: {np.unique(labels).tolist()}")
    return labels

def generate_labels_for_dir(in_dir: Path, out_lbl: Path, out_vis: Path, class_set="basic", max_vis=8):
    ensure_dir(out_lbl); ensure_dir(out_vis)
    paths = list_images(in_dir)
    total = len(paths)

    skipped = 0
    fallbacks = 0

    pbar = tqdm(
        total=total,
        desc=in_dir.name,   # keep static
        unit="img",
        ncols=80,
        dynamic_ncols=False,
        ascii=True,
        leave=False,
        file=sys.stderr     # same stream as default tqdm
    )

    for i, p in enumerate(paths, 1):
        bgr = cv2.imread(str(p), cv2.IMREAD_COLOR)
        if bgr is None:
            skipped += 1
            pbar.update(1)
            continue

        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        try:
            anns = mask_gen.generate(rgb)
        except IndexError as e:
            if "too many indices" in str(e) or "box_area" in str(e):
                from segment_anything import SamAutomaticMaskGenerator
                mask_gen_fallback = SamAutomaticMaskGenerator(
                    model=sam,
                    points_per_side=CFG.POINTS_PER_SIDE,
                    pred_iou_thresh=CFG.PRED_IOU_THRESH,
                    stability_score_thresh=CFG.STABILITY_THRESH,
                    crop_n_layers=0,
                    crop_n_points_downscale_factor=CFG.CROP_DOWNSCALE,
                    min_mask_region_area=CFG.MIN_MASK_REGION_AREA,
                )
                anns = mask_gen_fallback.generate(rgb)
                fallbacks += 1
            else:
                pbar.close()
                raise

        masks = [a["segmentation"].astype(np.uint8) for a in anns if a.get("area", 0) >= CFG.MIN_MASK_REGION_AREA]
        if not masks:
            hsv = cv2.cvtColor(bgr, cv2.COLOR_BGR2HSV)
            Hch, Sch, Vch = cv2.split(hsv)
            tissue = ((Vch > 40) & (Sch > 20)).astype(np.uint8)
            tissue = cv2.morphologyEx(tissue, cv2.MORPH_OPEN, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3)))
            labels = np.zeros(bgr.shape[:2], dtype=np.uint8)
            labels[tissue.astype(bool)] = 2
        else:
            labels = classify_masks(bgr, masks, class_set=class_set)

        cv2.imwrite(str(out_lbl / (p.stem + ".png")), labels)
        if i <= max_vis:
            colors = CFG.CLASS_COLORS[class_set]
            overlay = cv2.addWeighted(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), 1.0, colorize_mask(labels, colors), 0.45, 0.0)
            cv2.imwrite(str(out_vis / (p.stem + "_overlay.jpg")), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))

        pbar.update(1)  # no set_postfix/desc updates here

    pbar.close()
    print(f"[DONE] {in_dir.name}: {total} imgs | skipped={skipped}, fallbacks={fallbacks} → labels: {out_lbl} | vis: {out_vis}")

In [7]:
in_root = Path(CFG.ORIG_ROOT)
out_root = Path(CFG.OUTPUT_LABEL_ROOT)
class_set = CFG.CLASS_SET

for cls in CFG.CLASSES:  # no outer tqdm
    src_dir = in_root / cls
    if not src_dir.is_dir():
        # Avoid printing inside the progress run; skip quietly or log after
        continue
    out_lbl = out_root / cls / "labels"
    out_vis = out_root / cls / "vis"
    generate_labels_for_dir(src_dir, out_lbl, out_vis, class_set=class_set, max_vis=8)
print("[ALL DONE] Labels saved to:", CFG.OUTPUT_LABEL_ROOT)

HGC:   0%|                                             | 0/469 [00:00<?, ?img/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 768.00 MiB. GPU 0 has a total capacity of 31.73 GiB of which 647.00 MiB is free. Process 456913 has 3.42 GiB memory in use. Process 1906032 has 17.66 GiB memory in use. Process 2699547 has 8.38 GiB memory in use. Process 3830224 has 1.62 GiB memory in use. Of the allocated memory 1.21 GiB is allocated by PyTorch, and 48.79 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import random, matplotlib.pyplot as plt, numpy as np, cv2
from pathlib import Path


def _colorize_mask(mask, colors):
    out = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for k, c in colors.items():
        out[mask == k] = c
    return out

def show_examples(cls="HGC", n=6):
    vis_dir = Path(CFG.OUTPUT_LABEL_ROOT) / cls / "vis"
    lbl_dir = Path(CFG.OUTPUT_LABEL_ROOT) / cls / "labels"
    img_dir = Path(CFG.ORIG_ROOT) / cls

    if not vis_dir.is_dir():
        print("No vis folder yet:", vis_dir); return
    files = sorted([p for p in vis_dir.iterdir() if p.suffix.lower() in [".jpg",".png"]])
    if not files:
        print("No images in", vis_dir); return

    picks = random.sample(files, k=min(n, len(files)))
    rows = len(picks)   # one row per sample
    cols = 3            # Original | Label | Vis

    fig, axs = plt.subplots(rows, cols, figsize=(14, 3.8*rows))
    if rows == 1:
        axs = np.array([axs])  # make it 2D for uniform indexing

    # Choose color map from config
    colors = CFG.CLASS_COLORS[CFG.CLASS_SET] if hasattr(CFG, "CLASS_SET") else CFG.CLASS_COLORS
    IMG_EXTS = [".png",".jpg",".jpeg",".bmp",".tif",".tiff",".PNG",".JPG"]

    for row_idx, vis_path in enumerate(picks):
        base_stem = vis_path.stem[:-8] if vis_path.stem.endswith("_overlay") else vis_path.stem

        # Find original by stem
        orig_path = None
        for ext in IMG_EXTS:
            cand = img_dir / f"{base_stem}{ext}"
            if cand.exists():
                orig_path = cand; break

        # Load original
        orig_rgb = None
        if orig_path is not None:
            bgr = cv2.imread(str(orig_path), cv2.IMREAD_COLOR)
            if bgr is not None:
                orig_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Load label (indexed PNG)
        lbl_path = lbl_dir / f"{base_stem}.png"
        lbl = cv2.imread(str(lbl_path), cv2.IMREAD_UNCHANGED) if lbl_path.exists() else None
        lbl_ids = np.unique(lbl) if isinstance(lbl, np.ndarray) else np.array([])

        # Load existing vis overlay
        vis_bgr = cv2.imread(str(vis_path))
        vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB) if vis_bgr is not None else None

        # Original
        ax = axs[row_idx, 0]
        if orig_rgb is not None:
            ax.imshow(orig_rgb); ax.set_title(f"Original\n{orig_path.name}", fontsize=9)
        else:
            ax.text(0.5, 0.5, "Original not found", ha="center", va="center")
        ax.axis("off")

        # Colorized Label
        ax = axs[row_idx, 1]
        if isinstance(lbl, np.ndarray):
            col = _colorize_mask(lbl, colors)
            ax.imshow(col); ax.set_title(f"Label IDs: {lbl_ids.tolist()}", fontsize=9)
        else:
            ax.text(0.5, 0.5, "Label not found", ha="center", va="center")
        ax.axis("off")

        # Vis overlay
        ax = axs[row_idx, 2]
        if vis_rgb is not None:
            ax.imshow(vis_rgb); ax.set_title(vis_path.name, fontsize=9)
        else:
            ax.text(0.5, 0.5, "Vis not found", ha="center", va="center")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

show_examples("HGC", n=6)
show_examples("LGC", n=6)
show_examples("NST", n=6)
show_examples("NTL", n=6)

In [None]:
import random, matplotlib.pyplot as plt, numpy as np, cv2
from pathlib import Path
from collections import defaultdict, Counter

IMG_EXTS = [".png",".jpg",".jpeg",".bmp",".tif",".tiff",".PNG",".JPG"]

def _find_original(img_dir: Path, stem: str):
    for ext in IMG_EXTS:
        cand = img_dir / f"{stem}{ext}"
        if cand.exists():
            return cand
    return None

def _colorize_mask(mask, colors):
    out = np.zeros((*mask.shape, 3), dtype=np.uint8)
    for k, c in colors.items():
        out[mask == k] = c
    return out

def dataset_class_summary(cls="HGC"):
    lbl_dir = Path(CFG.OUTPUT_LABEL_ROOT) / cls / "labels"
    if not lbl_dir.is_dir():
        print("No labels folder:", lbl_dir); return
    counts = Counter()
    file_counts = Counter()
    for lp in sorted(lbl_dir.glob("*.png")):
        lbl = cv2.imread(str(lp), cv2.IMREAD_UNCHANGED)
        if lbl is None: 
            continue
        ids = set(np.unique(lbl).tolist())
        for cid in ids:
            counts[cid] += int((lbl == cid).sum())
        for cid in ids:
            file_counts[cid] += 1
    print(f"[SUMMARY] Files with class (by presence): {dict(sorted(file_counts.items()))}")
    print(f"[SUMMARY] Pixel counts by class: {dict(sorted(counts.items()))}")

def show_examples_diverse(cls="HGC", n=6, prefer_ids=(1,3,4,5,2), seed=None):
    if seed is not None:
        random.seed(seed)
        np.random.seed(seed)

    lbl_dir = Path(CFG.OUTPUT_LABEL_ROOT) / cls / "labels"
    vis_dir = Path(CFG.OUTPUT_LABEL_ROOT) / cls / "vis"
    img_dir = Path(CFG.ORIG_ROOT) / cls

    if not lbl_dir.is_dir():
        print("No labels folder:", lbl_dir); return

    # 1) Index labels -> which classes each image contains (exclude 0 from selection)
    idx = []  # list of (stem, present_ids_set_without_bg)
    present_ids_overall = set()
    for lp in sorted(lbl_dir.glob("*.png")):
        stem = lp.stem
        lbl = cv2.imread(str(lp), cv2.IMREAD_UNCHANGED)
        if lbl is None:
            continue
        ids = set(np.unique(lbl).tolist())
        ids_nobg = {cid for cid in ids if cid != 0}
        if ids_nobg:
            present_ids_overall |= ids_nobg
        idx.append((stem, ids_nobg))

    if not idx:
        print("No readable label PNGs found in:", lbl_dir)
        return

    # 2) Build candidates per class
    by_class = defaultdict(list)
    for stem, s in idx:
        for cid in s:
            by_class[cid].append(stem)

    # 3) Greedy selection to maximize diversity: pick one per preferred class if available
    chosen = []
    chosen_set = set()

    # Ensure we only request classes that actually exist
    effective_prefer = [cid for cid in prefer_ids if cid in present_ids_overall]

    for cid in effective_prefer:
        candidates = [st for st in by_class[cid] if st not in chosen_set]
        if candidates:
            pick = random.choice(candidates)
            chosen.append(pick)
            chosen_set.add(pick)
            if len(chosen) >= n:
                break

    # 4) If still need more, fill with images that add the most new preferred classes
    def score(stem):
        s = next((s for st, s in idx if st == stem), set())
        return len({cid for cid in s if cid in effective_prefer})
    if len(chosen) < n:
        remaining = [st for st, s in idx if st not in chosen_set]
        remaining.sort(key=score, reverse=True)
        for st in remaining:
            if len(chosen) >= n:
                break
            chosen.append(st); chosen_set.add(st)

    # 5) Plot: Original | Colorized Label | Vis overlay (compute overlay if missing)
    rows, cols = len(chosen), 3
    colors = CFG.CLASS_COLORS[CFG.CLASS_SET] if hasattr(CFG, "CLASS_SET") else CFG.CLASS_COLORS

    fig, axs = plt.subplots(rows, cols, figsize=(14, 3.8*rows))
    if rows == 1:
        axs = np.array([axs])

    for row_idx, stem in enumerate(chosen):
        lbl_path = lbl_dir / f"{stem}.png"
        vis_path = vis_dir / f"{stem}_overlay.jpg"
        orig_path = _find_original(img_dir, stem)

        # Load label
        lbl = cv2.imread(str(lbl_path), cv2.IMREAD_UNCHANGED)
        lbl_ids = np.unique(lbl) if isinstance(lbl, np.ndarray) else np.array([])

        # Load/create display images
        # Original
        orig_rgb = None
        if orig_path is not None:
            bgr = cv2.imread(str(orig_path), cv2.IMREAD_COLOR)
            if bgr is not None:
                orig_rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        # Colorized label
        col = _colorize_mask(lbl, colors) if isinstance(lbl, np.ndarray) else None

        # Vis overlay: use saved one if exists, else build on-the-fly
        vis_rgb = None
        if vis_path.exists():
            vis_bgr = cv2.imread(str(vis_path))
            if vis_bgr is not None:
                vis_rgb = cv2.cvtColor(vis_bgr, cv2.COLOR_BGR2RGB)
        elif orig_rgb is not None and col is not None:
            vis_rgb = cv2.addWeighted(orig_rgb, 1.0, col, 0.45, 0.0)

        # Plot
        ax = axs[row_idx, 0]
        if orig_rgb is not None:
            ax.imshow(orig_rgb); ax.set_title(f"Original\n{orig_path.name}", fontsize=9)
        else:
            ax.text(0.5, 0.5, "Original not found", ha="center", va="center")
        ax.axis("off")

        ax = axs[row_idx, 1]
        if col is not None:
            ax.imshow(col); ax.set_title(f"Label IDs: {lbl_ids.tolist()}", fontsize=9)
        else:
            ax.text(0.5, 0.5, "Label not found", ha="center", va="center")
        ax.axis("off")

        ax = axs[row_idx, 2]
        if vis_rgb is not None:
            ax.imshow(vis_rgb); ax.set_title(f"{stem}_overlay.jpg", fontsize=9)
        else:
            ax.text(0.5, 0.5, "Vis overlay not found", ha="center", va="center")
        ax.axis("off")

    plt.tight_layout()
    plt.show()

# Usage:
# Show up to 6 examples prioritizing vessel(1), tumor(3), instrument(4), specular(5), then mucosa(2).
# If some classes aren't present in HGC, it will skip them and still fill the grid.
dataset_class_summary("HGC")
show_examples_diverse("HGC", n=6, prefer_ids=(1,3,4,5,2), seed=0)

In [None]:
from pathlib import Path
import numpy as np, cv2, os
from PIL import Image
from tqdm.auto import tqdm

# Originals and MedSAM outputs
ORIG_ROOT = Path("/home/philipdt/IKT-project/data/EndoscopicBladderTissue")  # adjust if needed
SEG_ROOT  = Path("/home/philipdt/IKT-project/auto_labels_medsam")

# Classes and label IDs (update if your mapping differs)
CLASSES = ["HGC", "LGC", "NST", "NTL"]
# 0:bg, 1:vessel, 2:mucosa, 3:tumor, 4:instrument, 5:specular
KEEP_SETS = {
    "masked_all":        {1,2,3,4,5},
    "masked_tumor_only": {3},
}
GEN_OVERLAY_ALL = True  # set False if you don’t want full overlays for every image

# Overlay palette (BGR)
PALETTE = {
    1: (255, 64, 64),    # vessel - blue-ish (OpenCV uses BGR)
    2: (64, 255, 64),    # mucosa - green
    3: (64, 64, 255),    # tumor - red-ish
    4: (255, 255, 64),   # instrument - cyan-ish
    5: (255, 64, 255),   # specular - magenta
}
OVERLAY_ALPHA = 0.45  # blending factor (0..1)
IMG_EXTS = {".png",".jpg",".jpeg",".bmp",".tif",".tiff",".PNG",".JPG",".JPEG",".BMP",".TIF",".TIFF"}

In [None]:
def find_orig_by_stem(root: Path, cls: str, stem: str):
    d = root / cls
    for e in IMG_EXTS:
        p = d / f"{stem}{e}"
        if p.exists(): return p
    return None

def load_label(path: Path):
    # Expect indexed PNG (uint8/16)
    lbl = cv2.imread(str(path), cv2.IMREAD_UNCHANGED)
    if lbl is None:
        raise FileNotFoundError(f"Could not read label: {path}")
    return lbl

def ensure_label_size(lbl, w, h):
    if lbl.shape[1] != w or lbl.shape[0] != h:
        lbl = cv2.resize(lbl, (w, h), interpolation=cv2.INTER_NEAREST)
    return lbl

def apply_mask(img_rgb: np.ndarray, lbl: np.ndarray, keep_ids: set):
    mask = np.isin(lbl, list(keep_ids)).astype(np.uint8)
    out = img_rgb.copy()
    out[mask == 0] = 0
    return out

def make_overlay(img_rgb: np.ndarray, lbl: np.ndarray, alpha=0.45):
    overlay = img_rgb.copy()
    base = img_rgb.copy()
    for cls_id, color_bgr in PALETTE.items():
        if cls_id == 0: continue
        m = (lbl == cls_id)
        if not np.any(m): continue
        overlay[m] = color_bgr
    blended = cv2.addWeighted(overlay, alpha, base, 1-alpha, 0)
    return blended[..., ::-1]  # convert BGR->RGB appearance if needed

In [None]:
def export_masked_variants():
    total = 0
    made = {k: 0 for k in KEEP_SETS.keys()}
    ovl_made = 0

    for cls in CLASSES:
        lbl_dir = SEG_ROOT / cls / "labels"
        if not lbl_dir.is_dir():
            print(f"[WARN] Missing labels dir: {lbl_dir}")
            continue

        # Prepare out dirs
        out_dirs = {name: SEG_ROOT/cls/name for name in KEEP_SETS.keys()}
        for d in out_dirs.values(): d.mkdir(parents=True, exist_ok=True)
        ovl_dir = SEG_ROOT/cls/"overlay_all"
        if GEN_OVERLAY_ALL: ovl_dir.mkdir(parents=True, exist_ok=True)

        label_paths = sorted(lbl_dir.glob("*.png"))
        pbar = tqdm(label_paths, desc=f"[{cls}] exporting", leave=False)
        for lp in pbar:
            stem = lp.stem
            # locate original
            op = find_orig_by_stem(ORIG_ROOT, cls, stem)
            if op is None:
                pbar.set_postfix_str("orig-missing")
                continue

            # load original + label
            img = Image.open(op).convert("RGB")
            img_np = np.array(img)  # HxWx3 RGB
            h, w = img_np.shape[:2]

            lbl = load_label(lp)
            lbl = ensure_label_size(lbl, w, h)

            # masked variants
            for name, keep_ids in KEEP_SETS.items():
                out = apply_mask(img_np, lbl, keep_ids)
                Image.fromarray(out).save(out_dirs[name]/f"{stem}.png")
                made[name] += 1

            # overlay (optional)
            if GEN_OVERLAY_ALL:
                ov = make_overlay(img_np[..., ::-1], lbl, alpha=OVERLAY_ALPHA)  # ensure proper color mapping
                Image.fromarray(ov).save(ovl_dir/f"{stem}.png")
                ovl_made += 1

            total += 1

    print(f"Processed labels: {total}")
    for k,v in made.items(): print(f"  {k}: {v} images")
    if GEN_OVERLAY_ALL: print(f"  overlay_all: {ovl_made} images")

export_masked_variants()