<a href="https://colab.research.google.com/github/a2m-dotcom/DLBCL_Pub/blob/main/Threezone%20Creation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# reprocess_missing_wsis.py
# Usage: run in same env where your previous pipeline ran (so dependencies exist).
# Adjust ROOT and out_dir paths below if required.

import json, math, traceback
from pathlib import Path
from collections import defaultdict
import numpy as np
from PIL import Image, ImageDraw
from scipy.ndimage import gaussian_filter, binary_dilation, label, generate_binary_structure
import pandas as pd
import os

# ---------------- CONFIG - adjust if needed ----------------
# Where the per-WSI input folders live, and where this run writes its results.
ROOT = Path("/content/drivee/MyDrive/DLBCLMORPH/Results")  # top-level WSI folder (one sub-folder per WSI)
out_dir = ROOT.joinpath("zone_counts_batch_outputs")      # same run folder used by the earlier batch
per_wsi_csv_folder = out_dir.joinpath("per_wsi_csvs")     # one sub-folder per processed WSI
per_wsi_csv_folder.mkdir(exist_ok=True, parents=True)     # created eagerly when this cell/module runs

# Processing parameters -- keep these identical to the earlier run so that
# re-processed WSIs are comparable with the ones already on disk.
downscale = 100                   # full-resolution coordinates are divided by this to get mask pixels
gaussian_sigma = 1.0              # sigma of the Gaussian smoothing applied to the tumour mask
dilation_iters = 10               # iterations of the morphological closing applied to the tumour mask
min_component_size = 100          # tumour components smaller than this (in mask pixels) are dropped
zone2_dilation_size = 50          # side length (mask pixels) of the square structuring element for zone 2
restrict_zone3_to_tissue = False  # if True, zone 3 is intersected with the rasterized tissue mask
save_qc_images = True             # write a coloured zone-overlay PNG per WSI

# helper: robustly find geojson/json file for a wsi (nested or flat or "Copy of")
def find_geojson_for_wsi(wsi_folder: Path):
    """Locate the cell-annotation GeoJSON/JSON file for one WSI folder.

    Candidates are tried in priority order and the first hit wins:

    1. ``<wsi>/<wsi>/cells.geojson`` (nested export layout);
    2. any ``*.geojson``, then any ``*.json``, in the nested ``<wsi>/<wsi>/`` folder;
    3. ``<wsi>/cells.geojson`` (flat layout);
    4. in ``<wsi>/``: ``*cells*.geojson``, ``*<wsi>*.geojson``, ``*.geojson``,
       then ``*.json``;
    5. in ``<wsi>/``: ``Copy*`` files whose suffix is ``.geojson``/``.json`` in
       any letter case. The glob patterns above are case-sensitive on POSIX,
       so this catches e.g. ``Copy of x.GeoJSON``.

    Within every glob pattern the matches are sorted by path. The choice is
    therefore deterministic; raw ``Path.glob`` order depends on the
    filesystem, and the nested-folder lookup used to rely on it.

    Args:
        wsi_folder: folder for a single WSI; its name is the WSI id.

    Returns:
        Path of the chosen file, or None when nothing matches (including when
        ``wsi_folder`` does not exist).
    """
    def first_match(folder: Path, pattern: str):
        # Smallest path matching the pattern, or None (also None for a missing folder).
        return next(iter(sorted(folder.glob(pattern))), None)

    nested_dir = wsi_folder / wsi_folder.name
    nested = nested_dir / "cells.geojson"
    if nested.exists():
        return nested
    # In the nested folder, any .geojson is preferred over any .json.
    for pattern in ("*.geojson", "*.json"):
        hit = first_match(nested_dir, pattern)
        if hit is not None:
            return hit

    flat = wsi_folder / "cells.geojson"
    if flat.exists():
        return flat

    # Progressively looser fallbacks in the WSI folder itself.
    for pattern in ("*cells*.geojson", f"*{wsi_folder.name}*.geojson", "*.geojson", "*.json"):
        hit = first_match(wsi_folder, pattern)
        if hit is not None:
            return hit

    # Last resort: "Copy of ..." exports with a differently-cased suffix.
    for p in sorted(wsi_folder.glob("Copy*")):
        if p.suffix.lower() in (".geojson", ".json"):
            return p
    return None

# small utility functions copied/adapted from your pipeline
def to_mask_coord_floor(x_full, y_full, downscale):
    """Project a full-resolution point onto the downscaled mask grid.

    Both coordinates are converted to float, divided by ``downscale`` and
    floored (rounded toward -inf, so -1 maps to -1 rather than 0 for
    ``downscale=100``).

    Returns:
        ``(mx, my)`` tuple of Python ints. The result is not clamped to any
        canvas.
    """
    def _cell(value):
        return int(math.floor(float(value) / downscale))

    return _cell(x_full), _cell(y_full)

def centroid_of_ring(ring):
    """Return the vertex-mean ``(x, y)`` of a coordinate ring.

    This is the arithmetic mean of the vertices, not the area-weighted
    polygon centroid. When the ring is closed (GeoJSON linear rings repeat
    the first vertex at the end), that vertex is counted twice. This is kept
    as-is for consistency with counts produced by the earlier pipeline.

    Points that are ``None``, or whose x or y is ``None``, are skipped.
    Previously such a point raised ``TypeError`` and aborted the whole WSI.
    Any components beyond the first two (e.g. z) are ignored.

    Args:
        ring: sequence of ``[x, y, ...]`` positions.

    Returns:
        ``(cx, cy)`` as floats, or ``(None, None)`` if no usable point exists.
    """
    pts = [
        (p[0], p[1])
        for p in ring
        if p is not None and p[0] is not None and p[1] is not None
    ]
    if not pts:
        return None, None
    n = len(pts)
    return sum(x for x, _ in pts) / n, sum(y for _, y in pts) / n

def rasterize_polygons_to_canvas(polygons_by_class, canvas_w, canvas_h, downscale):
    """Rasterize every polygon of every class onto a downscaled canvas.

    Only the first (outer) ring of each polygon is drawn; holes are ignored.
    Vertices are mapped with ``to_mask_coord_floor`` and clamped into the
    canvas. Polygons with fewer than 3 vertices are skipped. A polygon that
    PIL fails to draw is skipped silently (best effort).

    Args:
        polygons_by_class: mapping class name -> list of polygons. Each
            polygon is a list of rings; each ring is a list of ``(x, y)``
            full-resolution points.
        canvas_w, canvas_h: canvas size in mask pixels.
        downscale: full-resolution units per mask pixel.

    Returns:
        ``(canvas, tissue_mask)``:
        - ``canvas``: an RGB ``PIL.Image`` with all polygons filled light
          grey (every class uses the same colour).
        - ``tissue_mask``: a ``(canvas_h, canvas_w)`` uint8 array that is 1
          inside or on the outline of any polygon and 0 elsewhere.
    """
    fill_rgb = (200, 200, 200)  # same grey for every class
    canvas = Image.new("RGB", (canvas_w, canvas_h), (0, 0, 0))
    draw = ImageDraw.Draw(canvas)
    # One shared mask image: filling each polygon into it yields the same union
    # as a separate image per polygon combined with np.maximum. It avoids
    # allocating and converting a full-canvas image for every polygon, which
    # was O(polygons x canvas).
    mask_img = Image.new("L", (canvas_w, canvas_h), 0)
    mask_draw = ImageDraw.Draw(mask_img)
    max_x, max_y = canvas_w - 1, canvas_h - 1
    for polylist in polygons_by_class.values():
        for poly in polylist:
            if not poly:
                continue
            outer = poly[0] if isinstance(poly, list) and len(poly) > 0 else poly
            pts = [to_mask_coord_floor(x, y, downscale) for x, y in outer]
            pts_clamped = [(max(0, min(max_x, xx)), max(0, min(max_y, yy))) for xx, yy in pts]
            if len(pts_clamped) < 3:
                continue
            try:
                draw.polygon(pts_clamped, fill=fill_rgb)
                mask_draw.polygon(pts_clamped, outline=1, fill=1)
            except Exception:
                # best effort: an undrawable polygon is skipped, not fatal
                continue
    tissue_mask = np.array(mask_img, dtype=np.uint8)
    return canvas, tissue_mask

def make_downscaled_tumor_mask(tumor_polygons, downscale, canvas_w, canvas_h):
    """Rasterize tumour polygons into a binary mask on the downscaled grid.

    Args:
        tumor_polygons: list of polygons. Each polygon is a list of rings,
            and only its first (outer) ring is drawn.
        downscale: full-resolution units per mask pixel.
        canvas_w, canvas_h: mask size in pixels.

    Returns:
        ``(canvas_h, canvas_w)`` uint8 array: 1 inside (and on the outline
        of) any drawable polygon, 0 elsewhere. Polygons with fewer than 3
        vertices, and polygons PIL refuses to draw, are skipped.
    """
    last_col, last_row = canvas_w - 1, canvas_h - 1

    def clamp(v, hi):
        return max(0, min(hi, v))

    mask = Image.new("L", (canvas_w, canvas_h), 0)
    pen = ImageDraw.Draw(mask)
    for polygon in tumor_polygons:
        if not polygon:
            continue
        ring = polygon[0] if (isinstance(polygon, list) and len(polygon) > 0) else polygon
        vertices = []
        for px, py in ring:
            mx, my = to_mask_coord_floor(px, py, downscale)
            vertices.append((clamp(mx, last_col), clamp(my, last_row)))
        if len(vertices) < 3:
            continue
        try:
            pen.polygon(vertices, outline=1, fill=1)
        except Exception:
            continue
    return np.array(mask, dtype=np.uint8)

def postprocess_mask(binary_mask, gaussian_sigma=1.0, dilation_iters=5, min_component_size=100):
    """Smooth, close and de-speckle a binary mask.

    Steps:
      1. Gaussian-blur the mask with ``gaussian_sigma`` and re-threshold at 0.5.
      2. Apply a morphological closing with an 8-connected 3x3 structuring
         element: ``dilation_iters`` dilations of the foreground, then
         ``dilation_iters`` dilations of the background. The closing is
         skipped when ``dilation_iters < 1``.
      3. Drop 8-connected components with fewer than ``min_component_size``
         pixels, then relabel.

    Args:
        binary_mask: 2-D array of 0/1 (or bool) values.
        gaussian_sigma: sigma for ``scipy.ndimage.gaussian_filter``.
        dilation_iters: closing iterations; must be >= 1 for closing to run.
        min_component_size: minimum component area, in pixels, to keep.

    Returns:
        ``(mask, labeled)``: the uint8 0/1 mask and its connected-component
        label array. When no component exists, ``labeled`` is all zeros with
        the mask's dtype.
    """
    struct = generate_binary_structure(2, 2)
    smooth = gaussian_filter(binary_mask.astype(float), sigma=gaussian_sigma)
    smooth = (smooth > 0.5).astype(np.uint8)

    # scipy's binary_dilation treats iterations < 1 as "repeat until nothing
    # changes", which would flood the whole image. Interpret it as "no closing".
    if dilation_iters >= 1:
        dil = binary_dilation(smooth, structure=struct, iterations=dilation_iters).astype(np.uint8)
        # Erosion implemented as dilation of the background (border_value=0,
        # so shapes touching the array edge are not eroded from that side).
        inv = binary_dilation((dil == 0).astype(np.uint8), structure=struct,
                              iterations=dilation_iters).astype(np.uint8)
        smooth = (1 - inv).astype(np.uint8)

    labeled_array, num_features = label(smooth, structure=struct)
    if num_features == 0:
        return smooth, np.zeros_like(smooth)

    # All component sizes in one pass, instead of one full-image scan per label.
    sizes = np.bincount(labeled_array.ravel(), minlength=num_features + 1)
    too_small = sizes < min_component_size
    too_small[0] = False  # label 0 is background, never "removed"
    smooth[too_small[labeled_array]] = 0
    labeled_array, _ = label(smooth, structure=struct)
    return smooth, labeled_array

def get_zone_two(arr, dilation_size=50):
    """Return the ring ("zone 2") around a binary mask.

    Zone 2 is every pixel reached by dilating ``arr`` with a
    ``dilation_size x dilation_size`` square that is not itself set in
    ``arr``. With an even ``dilation_size`` (e.g. the default 50) the square
    has no centre pixel, so the ring is one pixel wider on one side than on
    the other.

    Args:
        arr: 2-D mask. Any non-zero value counts as foreground, so bool, 0/1
            and 0/255 inputs all work.
        dilation_size: side length of the square structuring element, in
            mask pixels.

    Returns:
        uint8 array of the same shape: 1 in zone 2, 0 elsewhere.
    """
    fg = np.asarray(arr).astype(bool)
    structuring_element = np.ones((dilation_size, dilation_size), dtype=bool)
    dilated = binary_dilation(fg, structure=structuring_element)
    # Logical "dilated AND NOT original" instead of integer subtraction.
    # Subtraction raises on bool input, and wraps around in uint8 when the
    # mask uses values > 1 (e.g. 255), which marks tumour pixels as zone 2.
    return (dilated & ~fg).astype(np.uint8)

# ---------- helper: safe WSI list discovery ----------
def list_input_wsis(root: Path):
    """Return the sorted names of all immediate sub-directories of ``root``.

    Every sub-directory is treated as a candidate WSI folder; no name pattern
    (such as ``<digits>_<digit>``) is enforced. Plain files are ignored.
    """
    names = []
    for entry in root.iterdir():
        if entry.is_dir():
            names.append(entry.name)
    names.sort()
    return names

def list_output_wsis(out_folder: Path):
    """Return the sorted names of the sub-directories of ``out_folder``.

    A missing ``out_folder`` is not an error: nothing has been produced yet,
    so an empty list is returned.
    """
    if out_folder.exists():
        return sorted(entry.name for entry in out_folder.iterdir() if entry.is_dir())
    return []

# --------- single-WSI small runner (a trimmed version of your main loop) ----------
def run_one_wsi_by_id(wsi_id: str) -> dict:
    """Build zone masks for one WSI and count its annotated polygons per zone and class.

    Steps:
      1. Locate and load the WSI's GeoJSON/JSON (``find_geojson_for_wsi``).
      2. Group Polygon/MultiPolygon geometries by classification name and
         track the largest x/y coordinate to size the downscaled canvas.
      3. Pick the tumour class. This is the first class name starting with
         "tumor" (case-insensitive); otherwise, the class with the largest
         rasterized area.
      4. Build three zone masks and save them as ``.npy`` files:
         - zone 1: the post-processed tumour mask;
         - zone 2: the dilated ring around zone 1;
         - zone 3: everything else, optionally restricted to the tissue mask.
      5. Assign every polygon to a zone using the vertex-mean of its outer
         ring, write per-class zone counts to a CSV, and optionally save a
         QC overlay PNG.

    Reads module-level config: ROOT, per_wsi_csv_folder, downscale,
    gaussian_sigma, dilation_iters, min_component_size, zone2_dilation_size,
    restrict_zone3_to_tissue, save_qc_images.

    Args:
        wsi_id: name of the WSI folder under ROOT. Also used as the output
            folder name and as the CSV/PNG filename prefix.

    Returns:
        dict with ``status`` and a human-readable ``msg``. ``status`` is one
        of "ok", "missing_geojson", "read_error", "no_features",
        "no_polygon_classes" or "no_tumor_polygons". Exceptions raised
        outside the guarded file read, thumbnail save and QC-overlay steps
        propagate to the caller.
    """
    wsi_folder = ROOT / wsi_id
    geojson_path = find_geojson_for_wsi(wsi_folder)
    if geojson_path is None:
        return dict(status="missing_geojson", msg=f"no geojson found in {wsi_folder}")

    try:
        with open(geojson_path, "r") as f:
            gj = json.load(f)
    except Exception as e:
        # any open/parse failure is reported in the result instead of raised
        return dict(status="read_error", msg=str(e))

    # accept either a bare list of features or a dict with a "features" list
    features = gj if isinstance(gj, list) else gj.get("features", [])
    if not features:
        return dict(status="no_features", msg="no features in file")

    # ---- pass 1: group polygons by class name and find the coordinate extent ----
    classes = defaultdict(list)  # class name -> list of polygons (each a list of rings)
    global_max_x = 0.0
    global_max_y = 0.0

    for feat in features:
        if not isinstance(feat, dict): continue
        props = feat.get("properties", {}) or {}
        # class is read from "classification" (a dict with "name", or a plain
        # value) and falls back to "class"; missing -> "Unknown"
        cls_field = props.get("classification", props.get("class"))
        if isinstance(cls_field, dict):
            cls_name = cls_field.get("name", "Unknown")
        else:
            cls_name = str(cls_field) if cls_field is not None else "Unknown"
        geom = feat.get("geometry", {}) or {}
        gtype = geom.get("type")
        coords = geom.get("coordinates") or []
        # normalise to a list of polygons; other geometry types (points, lines) are ignored
        if gtype == "Polygon":
            polys = [coords]
        elif gtype == "MultiPolygon":
            polys = coords
        else:
            polys = []
        classes[cls_name].extend(polys)
        for poly in polys:
            for ring in poly:
                # NOTE(review): 2-element unpacking; positions with a third (z)
                # component would raise ValueError here
                for x,y in ring:
                    if x is None or y is None: continue
                    global_max_x = max(global_max_x, float(x))
                    global_max_y = max(global_max_y, float(y))

    if len(classes) == 0:
        return dict(status="no_polygon_classes", msg="no polygon classes parsed")

    # +1 so the pixel containing the largest coordinate is inside the canvas
    canvas_w = int(global_max_x // downscale) + 1
    canvas_h = int(global_max_y // downscale) + 1

    # rasterize all classes: grey RGB thumbnail + binary tissue mask
    thumbnail_img, tissue_mask = rasterize_polygons_to_canvas(classes, canvas_w, canvas_h, downscale)

    # find tumor class (case-insensitive startswith 'tumor')
    tumor_key = next((k for k in classes.keys() if k and k.lower().startswith("tumor")), None)
    if tumor_key is None:
        # fallback: pick largest class by pixel area (approx)
        # approximate area by rasterizing each class quickly
        class_areas = {}
        for k, polys in classes.items():
            m = Image.new("L", (canvas_w, canvas_h), 0)
            d = ImageDraw.Draw(m)
            for poly in polys:
                if not poly: continue
                outer = poly[0] if isinstance(poly, list) and len(poly)>0 else poly
                pts = [to_mask_coord_floor(x,y,downscale) for x,y in outer]
                pts_clamped = [ (max(0,min(canvas_w-1,xx)), max(0,min(canvas_h-1,yy))) for xx,yy in pts ]
                if len(pts_clamped)>=3:
                    try:
                        d.polygon(pts_clamped, outline=1, fill=1)
                    except Exception:
                        pass
            # area = number of mask pixels covered by the union of this class's polygons
            class_areas[k] = np.array(m).sum()
        # choose largest non-zero class
        sorted_cls = sorted(class_areas.items(), key=lambda kv: kv[1], reverse=True)
        if sorted_cls and sorted_cls[0][1] > 0:
            tumor_key = sorted_cls[0][0]
        else:
            return dict(status="no_tumor_polygons", msg=f"no tumor-like class found; class_areas={class_areas}")

    tumor_polys = classes[tumor_key]

    # create masks: zone 1 = smoothed/closed/de-speckled tumour mask
    tumor_mask_raw = make_downscaled_tumor_mask(tumor_polys, downscale, canvas_w, canvas_h)
    tumor_mask_proc, labeled_mask = postprocess_mask(tumor_mask_raw,
                                                     gaussian_sigma=gaussian_sigma,
                                                     dilation_iters=dilation_iters,
                                                     min_component_size=min_component_size)

    # zone 2 = dilated ring around zone 1; zone 3 = every pixel in neither zone 1 nor zone 2
    zone2 = get_zone_two(tumor_mask_proc, dilation_size=zone2_dilation_size)
    zone3_raw = (1 - np.clip(tumor_mask_proc + zone2, 0, 1)).astype(np.uint8)
    if restrict_zone3_to_tissue:
        zone3 = (zone3_raw & (tissue_mask > 0).astype(np.uint8)).astype(np.uint8)
    else:
        zone3 = zone3_raw

    # save outputs into per_wsi_csv_folder/wsi_id
    # NOTE(review): this folder is created before the CSV is written, so a WSI
    # that fails later still leaves a folder behind; anything that treats
    # "folder exists" as "done" will skip it on the next run.
    wsi_out_dir = per_wsi_csv_folder / wsi_id
    wsi_out_dir.mkdir(parents=True, exist_ok=True)
    np.save(wsi_out_dir / "tumor_mask_raw.npy", tumor_mask_raw)
    np.save(wsi_out_dir / "tumor_mask_processed.npy", tumor_mask_proc)
    np.save(wsi_out_dir / "zone2.npy", zone2)
    np.save(wsi_out_dir / "zone3.npy", zone3)
    try:
        thumbnail_img.save(wsi_out_dir / "thumbnail_from_geojson.png")
    except Exception:
        # best effort: a failed thumbnail save is silently ignored
        pass

    # ---- pass 2: centroid assignment & counts (same as before) ----
    zone_counts = defaultdict(int)        # zone -> count (not written out)
    zone_class_counts = defaultdict(int)  # (class, zone) -> count
    total_polygons = 0  # counts every polygon entry, including empty ones; not used in outputs
    oob = 0             # centroids outside the canvas; not used in outputs
    for feat in features:
        if not isinstance(feat, dict): continue
        props = feat.get("properties", {}) or {}
        cls_field = props.get("classification", props.get("class"))
        if isinstance(cls_field, dict):
            class_name = cls_field.get("name", "Unknown")
        else:
            class_name = str(cls_field) if cls_field is not None else "Unknown"
        geom = feat.get("geometry", {}) or {}
        gtype = geom.get("type")
        coords = geom.get("coordinates") or []
        if gtype == "Polygon":
            poly_list = [coords]
        elif gtype == "MultiPolygon":
            poly_list = coords
        else:
            poly_list = []
        for poly in poly_list:
            total_polygons += 1
            if not poly: continue
            outer = poly[0] if isinstance(poly, list) and len(poly)>0 else poly
            cx_full, cy_full = centroid_of_ring(outer)
            if cx_full is None:
                continue
            mx, my = to_mask_coord_floor(cx_full, cy_full, downscale)
            if not (0 <= mx < canvas_w and 0 <= my < canvas_h):
                oob += 1
                continue
            # precedence zone1 > zone2 > zone3; "unassigned" can only occur
            # when restrict_zone3_to_tissue is True (otherwise the three
            # zones cover every pixel)
            if tumor_mask_proc[my, mx] == 1:
                zone = "zone1"
            elif zone2[my, mx] == 1:
                zone = "zone2"
            elif zone3[my, mx] == 1:
                zone = "zone3"
            else:
                zone = "unassigned"
            zone_counts[zone] += 1
            zone_class_counts[(class_name, zone)] += 1

    # build per-wsi rows and save small CSV ("unassigned" counts are not included)
    classes_seen = sorted({c for (c,z) in zone_class_counts.keys()})
    per_wsi_rows = []
    for cls in classes_seen:
        r = {
            "wsi_id": wsi_id,
            "class": cls,
            "zone1_count": int(zone_class_counts.get((cls, "zone1"), 0)),
            "zone2_count": int(zone_class_counts.get((cls, "zone2"), 0)),
            "zone3_count": int(zone_class_counts.get((cls, "zone3"), 0)),
            "total_polygons_assigned": int(sum(zone_class_counts.get((cls, z), 0) for z in ["zone1","zone2","zone3"]))
        }
        per_wsi_rows.append(r)
    if per_wsi_rows:
        # NOTE(review): the "dil5" suffix is hard-coded and does not reflect the
        # configured dilation_iters — confirm the naming is intended
        pd.DataFrame(per_wsi_rows).to_csv(wsi_out_dir / f"{wsi_id}_zone_counts_dil5.csv", index=False)

    # optional QC overlay (skip failures)
    if save_qc_images:
        try:
            W,H = canvas_w, canvas_h  # unused
            # thumbnail is already canvas-sized, so this resize does not change its size
            base = thumbnail_img.convert("RGBA").resize((canvas_w, canvas_h), Image.LANCZOS)
            overlay = Image.new("RGBA", (canvas_w, canvas_h), (0,0,0,0))
            # paint a semi-transparent solid colour wherever mask m is 1
            def paste_mask(m, color_rgb, alpha=0.45):
                if m.sum() == 0: return
                mask_img = Image.fromarray((m*255).astype("uint8")).convert("L")
                color_img = Image.new("RGBA", (canvas_w, canvas_h), color_rgb + (int(alpha*255),))
                overlay.paste(color_img, (0,0), mask_img)
            # painted in this order so later layers overwrite earlier ones: zone3 green, zone2 yellow, tumour red
            paste_mask(zone3, (0,180,0), 0.35)
            paste_mask(zone2, (255,210,0), 0.45)
            paste_mask(tumor_mask_proc, (220,20,60), 0.5)
            composite = Image.alpha_composite(base, overlay)
            composite.save(wsi_out_dir / f"{wsi_id}_zones_overlay_dil5.png")
        except Exception:
            # best effort: QC overlay failures are silently ignored
            pass

    return dict(status="ok", msg=f"processed {wsi_id}")

# ------------- main: compare and re-run for missing WSIs -------------
if __name__ == "__main__":
    # Entry point: find WSI folders under ROOT that have no per-WSI output
    # folder yet, process them one by one, and record any that fail.
    #
    # The run's own output folder (out_dir, or any ancestor of it under ROOT)
    # and hidden folders such as ".ipynb_checkpoints" sit under ROOT but are
    # not WSIs. Without this filter they were reported as "missing" on every
    # run and then failed with "missing_geojson".
    input_wsis = [
        w for w in list_input_wsis(ROOT)
        if not w.startswith(".")
        and (ROOT / w) != out_dir
        and (ROOT / w) not in out_dir.parents
    ]
    output_wsis = list_output_wsis(per_wsi_csv_folder)
    done = set(output_wsis)  # O(1) membership instead of scanning a list per WSI
    missing = sorted(w for w in input_wsis if w not in done)
    print(f"Input WSI count: {len(input_wsis)}")
    print(f"Output WSI count (folders found under {per_wsi_csv_folder}): {len(output_wsis)}")
    print(f"Missing (to process): {len(missing)}")

    # save missing list
    pd.DataFrame({"missing_wsi": missing}).to_csv(out_dir / "missing_wsis.csv", index=False)

    # process each missing WSI serially (you can parallelize later)
    errors = []
    for i, wsi in enumerate(missing, start=1):
        print(f"\n[{i}/{len(missing)}] processing missing WSI: {wsi}")
        try:
            res = run_one_wsi_by_id(wsi)
            if res.get("status") != "ok":
                print(" SKIPPED:", res)
                errors.append({"wsi": wsi, "status": res.get("status"), "msg": res.get("msg")})
            else:
                print(" DONE:", res.get("msg"))
        except Exception as e:
            # top-level boundary: keep going with the next WSI, but record the traceback
            tb = traceback.format_exc()
            print(" ERROR running:", e)
            errors.append({"wsi": wsi, "status": "exception", "msg": str(e), "trace": tb})

    if errors:
        pd.DataFrame(errors).to_csv(out_dir / "faulty_wsis_missing_run.csv", index=False)
        print("Some WSIs failed; saved report to faulty_wsis_missing_run.csv")
    else:
        print("All missing WSIs processed successfully.")

    print("Finished.")