In [1]:
# Jupyter Environment Setup
%matplotlib inline

import json
import logging
import numpy as np
import matplotlib.pyplot as plt
from shapely.geometry import Polygon, box, mapping, shape, MultiPolygon
from shapely.strtree import STRtree
from shapely.ops import unary_union
from shapely.errors import ShapelyDeprecationWarning
from shapely.geometry import box
import warnings
from tqdm.notebook import tqdm  # Better for Jupyter
import os

In [2]:
# Default thresholds used throughout this notebook:
# IoU clustering threshold = 0.2 (20%), minimum detection confidence = 0.9.

# --- Configure logging: INFO and above, formatted as "LEVEL: message" ---
logging.basicConfig(
    format='%(levelname)s: %(message)s',
    level=logging.INFO,
)

# --- Load GeoJSON ---
def load_geojson(file_path):
    """Read a GeoJSON file and return the parsed JSON object.

    GeoJSON text is UTF-8 by specification (RFC 7946), so the file is decoded
    as UTF-8 explicitly rather than with the platform's default encoding
    (which is not UTF-8 on e.g. many Windows setups).

    Args:
        file_path: Path to the ``.geojson`` file.

    Returns:
        The decoded JSON object (normally a FeatureCollection dict), or ``None``
        when the file cannot be opened/read or does not contain valid JSON.
        The failure is logged at ERROR level.
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file. ValueError also covers
        # json.JSONDecodeError and UnicodeDecodeError (both subclasses).
        logging.error(f"Error loading GeoJSON: {e}")
        return None
    logging.info(f"Successfully loaded GeoJSON: {file_path}")
    return data

# --- Geometry Helpers ---
def calculate_overlap(p1, p2):
    """Return the intersection-over-union (IoU) of two geometries.

    Non-intersecting geometries give ``0``. Otherwise the result is the
    intersection area divided by the union area, where the union area is
    obtained by inclusion-exclusion (``area1 + area2 - intersection``).
    A non-positive union area (degenerate geometries) also gives ``0``.
    """
    if p1.intersects(p2):
        overlap_area = p1.intersection(p2).area
        union_area = p1.area + p2.area - overlap_area
        if union_area > 0:
            return overlap_area / union_area
    return 0

# --- Extract polygons from features (handle Polygon and MultiPolygon) ---
def extract_polygons(features, conf_thresh=0.9, target_class="eos"):
    """Select valid polygons of one detection class above a confidence threshold.

    A feature is kept only if ``properties.classification.name`` equals
    ``target_class`` (case-insensitive) and ``properties.confidence`` is at
    least ``conf_thresh``. A missing or null confidence is treated as 1, i.e.
    the feature passes the confidence filter. Polygon geometries are kept when
    valid and non-empty; a MultiPolygon contributes each of its valid,
    non-empty parts as a separate Polygon.

    Args:
        features: Iterable of GeoJSON feature dicts.
        conf_thresh: Minimum (inclusive) confidence for a feature to be kept.
        target_class: Classification name to keep, compared case-insensitively.

    Returns:
        ``(valid_polys, skipped)``: the list of accepted shapely Polygons, and
        the list of rejected feature dicts (wrong class, low confidence,
        unsupported geometry type, no valid non-empty polygon, or a geometry
        that could not be parsed).
    """
    target = target_class.lower()
    valid_polys = []
    skipped = []

    for f in features:
        # GeoJSON allows "properties": null, and exporters may also emit null
        # classification/name/confidence; normalise them instead of crashing.
        props = f.get("properties") or {}
        classification = props.get("classification") or {}
        cls_name = (classification.get("name") or "").lower()
        confidence = props.get("confidence")
        if confidence is None:
            confidence = 1
        if cls_name != target or confidence < conf_thresh:
            skipped.append(f)
            continue

        try:
            geom = shape(f["geometry"])
            if isinstance(geom, Polygon):
                parts = [geom]
            elif isinstance(geom, MultiPolygon):
                parts = list(geom.geoms)
            else:
                skipped.append(f)
                continue

            good_parts = [p for p in parts if p.is_valid and not p.is_empty]
            if good_parts:
                valid_polys.extend(good_parts)
            else:
                # Previously an all-invalid MultiPolygon vanished without
                # being recorded as skipped.
                skipped.append(f)
        except Exception as e:  # malformed/missing geometry in an external file
            logging.warning(f"Invalid geometry skipped: {e}")
            skipped.append(f)

    logging.info(f"Extracted {len(valid_polys)} valid {target_class} polygons, skipped {len(skipped)} features.")
    return valid_polys, skipped

# --- Cluster eos polygons by DFS using IoU threshold ---
def cluster_eos(polygons, iou_thresh=0.2):
    """Merge overlapping eos detections into one polygon per cluster.

    Two polygons are linked when their IoU (``calculate_overlap``) is strictly
    greater than ``iou_thresh``. Clusters are the connected components of that
    link graph, so linking is transitive. They are found with an iterative
    depth-first search seeded from each not-yet-visited index in ascending
    order. Each cluster is merged with ``unary_union``; if the union is a
    MultiPolygon, only its largest-area part is kept.

    Args:
        polygons: List of shapely polygons.
        iou_thresh: IoU above which two polygons are considered the same cell.

    Returns:
        List of merged polygons, one per cluster (empty list for empty input).
    """
    if not polygons:
        return []

    count = len(polygons)
    neighbours = [[] for _ in range(count)]
    for a in range(count):
        for b in range(a + 1, count):
            if calculate_overlap(polygons[a], polygons[b]) > iou_thresh:
                neighbours[a].append(b)
                neighbours[b].append(a)

    seen = set()
    groups = []
    for start in range(count):
        if start in seen:
            continue
        members = []
        pending = [start]
        while pending:
            node = pending.pop()
            if node in seen:
                continue
            seen.add(node)
            members.append(node)
            pending += [other for other in neighbours[node] if other not in seen]
        groups.append(members)

    merged_polys = []
    for members in groups:
        union_geom = unary_union([polygons[k] for k in members])
        if isinstance(union_geom, MultiPolygon):
            union_geom = max(union_geom.geoms, key=lambda part: part.area)
        merged_polys.append(union_geom)

    logging.info(f"Formed {len(groups)} clusters from {count} polygons.")
    return merged_polys

# --- Analyze GeoJSON to find peak HPF and eos clusters ---
def analyze_geojson(data, hpf_size, iou_thresh=0.2, conf_thresh=0.9):
    """Tile the detections into square HPFs and find the HPF with the most clustered eos.

    The grid starts at the floored minimum polygon coordinate and steps by
    ``hpf_size`` until it covers the ceiled maximum coordinate. For each HPF,
    every polygon that intersects it is collected (so a polygon crossing a tile
    border is counted in each tile it touches). These polygons are clustered with
    ``cluster_eos``, and the HPF with the largest cluster count is kept. On ties,
    the first HPF in scan order wins.

    Args:
        data: Parsed GeoJSON FeatureCollection dict (must contain ``'features'``).
        hpf_size: Side length of a square HPF, in the same units as the
            coordinates (a positive integer, since it is used as a range step).
        iou_thresh: IoU threshold passed to ``cluster_eos``.
        conf_thresh: Confidence threshold passed to ``extract_polygons``.

    Returns:
        ``(hpf_count, max_eos, peak_box, peak_merged_polys, min_x, min_y, max_x, max_y)``
        where the extents are the integer grid bounds (floor of the minimum,
        ceil of the maximum polygon coordinate).

    Raises:
        ValueError: if ``hpf_size`` is not positive, or no eos polygon survives
            the filters.
    """
    if hpf_size <= 0:
        raise ValueError(f"hpf_size must be positive, got {hpf_size}")

    features = data['features']
    all_polys, _skipped = extract_polygons(features, conf_thresh=conf_thresh)
    # Empty geometries never intersect an HPF and have no usable bounds.
    all_polys = [p for p in all_polys if not p.is_empty]

    if not all_polys:
        raise ValueError("No valid eos polygons found in the GeoJSON.")

    # One (minx, miny, maxx, maxy) row per polygon: used for the scan extent and
    # as a cheap prefilter before the exact (and slower) shapely intersects().
    bounds = np.array([p.bounds for p in all_polys], dtype=float)
    # floor/ceil rather than int(): int() truncates toward zero, which could
    # leave polygon edges (or negative coordinates) outside the scanned grid.
    min_x, min_y = (int(v) for v in np.floor(bounds[:, :2].min(axis=0)))
    max_x, max_y = (int(v) for v in np.ceil(bounds[:, 2:].max(axis=0)))

    # Guarantee at least one tile per axis, so a zero-width extent still yields
    # an HPF instead of an empty scan and a None peak_box.
    x_starts = range(min_x, max(max_x, min_x + 1), hpf_size)
    y_starts = range(min_y, max(max_y, min_y + 1), hpf_size)

    max_eos = 0
    hpf_count = 0
    peak_box, peak_merged_polys = None, []

    logging.info("Scanning HPFs for peak eos count...")
    for x in tqdm(x_starts, desc="HPF cols"):
        # Closed-interval bbox overlap is a superset of shapely's intersects()
        # (which is also true for boundary contact), so no polygon is lost.
        in_col = (bounds[:, 0] <= x + hpf_size) & (bounds[:, 2] >= x)
        for y in y_starts:
            hpf = box(x, y, x + hpf_size, y + hpf_size)
            candidates = np.flatnonzero(
                in_col & (bounds[:, 1] <= y + hpf_size) & (bounds[:, 3] >= y)
            )
            # Candidates are in ascending index order, matching a full scan.
            local_polys = [all_polys[i] for i in candidates if hpf.intersects(all_polys[i])]
            merged_polys = cluster_eos(local_polys, iou_thresh=iou_thresh)
            eos_count = len(merged_polys)
            if eos_count > max_eos:
                max_eos = eos_count
                peak_box = hpf
                peak_merged_polys = merged_polys
            hpf_count += 1

    logging.info(f"Total HPFs analyzed: {hpf_count}")
    logging.info(f"Peak eos count in one HPF (clustered): {max_eos}")

    return hpf_count, max_eos, peak_box, peak_merged_polys, min_x, min_y, max_x, max_y

# --- Save peak HPF and merged eos polygons as GeoJSON ---
def save_peak_to_geojson(hpf_box, merged_polys, out_path):
    """Write the peak HPF and its merged eos polygons as a GeoJSON FeatureCollection.

    The first feature is the HPF square with ``properties = {"name": "Peak HPF"}``.
    It is followed by one feature per merged polygon, with
    ``properties = {"id": <position in merged_polys>, "name": ""}``. The file at
    ``out_path`` is overwritten and written with 2-space JSON indentation.
    """
    def _feature(geom, props):
        # Key order matches the original output: type, geometry, properties.
        return {"type": "Feature", "geometry": mapping(geom), "properties": props}

    features = [_feature(hpf_box, {"name": "Peak HPF"})]
    features += [
        _feature(poly, {"id": idx, "name": ""})
        for idx, poly in enumerate(merged_polys)
    ]
    collection = {"type": "FeatureCollection", "features": features}

    with open(out_path, 'w') as fh:
        json.dump(collection, fh, indent=2)
    logging.info(f"Saved merged eos GeoJSON to {out_path}")

# --- Generate heatmap of clustered eos counts per HPF ---
def generate_heatmap_clustered(data, hpf_size, iou_thresh=0.2, conf_thresh=0.9):
    """Plot (and return) a grid of clustered eos counts, one cell per HPF.

    The HPF grid is the same one scanned for the peak: it starts at the floored
    minimum polygon coordinate and steps by ``hpf_size`` until it covers the
    ceiled maximum. Each cell holds the number of clusters returned by
    ``cluster_eos`` for the polygons intersecting that HPF. Row index follows y
    and column index follows x, both starting at the minimum coordinate, and the
    image is drawn with ``origin='upper'`` (row 0 at the top).

    Args:
        data: Parsed GeoJSON FeatureCollection dict (must contain ``'features'``).
        hpf_size: Side length of a square HPF (positive integer, used as a range step).
        iou_thresh: IoU threshold passed to ``cluster_eos``.
        conf_thresh: Confidence threshold passed to ``extract_polygons``.

    Returns:
        2-D numpy array of shape ``(n_y_tiles, n_x_tiles)`` with the counts, or
        ``None`` (after logging an error) when no polygon passes the filters.
    """
    features = data['features']
    all_polys, _ = extract_polygons(features, conf_thresh=conf_thresh)
    # Empty geometries never intersect an HPF and have no usable bounds.
    all_polys = [p for p in all_polys if not p.is_empty]

    if not all_polys:
        logging.error("No valid polygons for heatmap generation.")
        return None

    bounds = np.array([p.bounds for p in all_polys], dtype=float)
    # floor/ceil rather than int() (which truncates toward zero) so the grid
    # fully covers every polygon.
    min_x, min_y = (int(v) for v in np.floor(bounds[:, :2].min(axis=0)))
    max_x, max_y = (int(v) for v in np.ceil(bounds[:, 2:].max(axis=0)))

    x_starts = range(min_x, max(max_x, min_x + 1), hpf_size)
    y_starts = range(min_y, max(max_y, min_y + 1), hpf_size)
    # Size the grid from the tiles actually scanned; the former
    # `extent // hpf_size + 1` added an always-zero row/column whenever the
    # extent was an exact multiple of hpf_size.
    heatmap = np.zeros((len(y_starts), len(x_starts)))

    logging.info("Generating heatmap from clustered eos counts...")
    for col, x in enumerate(tqdm(x_starts, desc="Heatmap cols")):
        # Closed bbox overlap is a superset of shapely's intersects().
        in_col = (bounds[:, 0] <= x + hpf_size) & (bounds[:, 2] >= x)
        for row, y in enumerate(y_starts):
            hpf = box(x, y, x + hpf_size, y + hpf_size)
            candidates = np.flatnonzero(
                in_col & (bounds[:, 1] <= y + hpf_size) & (bounds[:, 3] >= y)
            )
            local_polys = [all_polys[i] for i in candidates if hpf.intersects(all_polys[i])]
            merged_polys = cluster_eos(local_polys, iou_thresh=iou_thresh)
            heatmap[row, col] = len(merged_polys)

    fig, ax = plt.subplots(figsize=(10, 8))
    image = ax.imshow(heatmap, cmap='hot', interpolation='nearest', origin='upper')
    fig.colorbar(image, ax=ax, label='Eosinophil Count per HPF')
    ax.set_title('Eosinophil Density Heatmap (Clustered)')
    ax.axis('off')
    plt.show()
    return heatmap



In [None]:
# --- Main Execution ---
if __name__ == "__main__":
    # Run configuration: edit values here only. The output filename is derived
    # from the thresholds so it can no longer disagree with the values used.
    slide_id = '1007401'
    iou_thresh = 0.2   # IoU above which overlapping detections are merged (20%)
    conf_thresh = 0.9  # minimum detection confidence for a feature to be kept
    hpf_size = 2144  # 548μm @ 0.2555 μm/pixel

    file_path = f'{slide_id}_detections_merged.geojson'
    output_geojson = (
        f'{slide_id}_peak_hpf_output14_'
        f'{round(iou_thresh * 100)}_{round(conf_thresh * 100)}_dfs_poly_merged.geojson'
    )

    geojson = load_geojson(file_path)
    if geojson:
        hpf_total, peak_eos, peak_box, peak_merged_polys, min_x, min_y, max_x, max_y = analyze_geojson(
            geojson, hpf_size, iou_thresh=iou_thresh, conf_thresh=conf_thresh)

        if peak_box is None:
            # analyze_geojson found no HPF containing an eos polygon.
            logging.error("No HPF contains an eos polygon; nothing to plot or save.")
        else:
            # Plot peak HPF with merged eos polygons
            fig, ax = plt.subplots(figsize=(8, 8))
            x, y = peak_box.exterior.xy
            ax.plot(x, y, color='red', linewidth=2, label='Peak HPF')

            for poly in peak_merged_polys:
                x, y = poly.exterior.xy
                ax.fill(x, y, color='blue', alpha=0.5)

            ax.set_title("Peak HPF with Merged Eosinophils")
            ax.set_aspect('equal')
            ax.legend()
            # Inverting both axes rotates the view by 180°.
            ax.invert_xaxis()
            ax.invert_yaxis()
            ax.grid(True)
            plt.show()

            save_peak_to_geojson(peak_box, peak_merged_polys, output_geojson)

        generate_heatmap_clustered(geojson, hpf_size, iou_thresh=iou_thresh, conf_thresh=conf_thresh)
    else:
        logging.error("GeoJSON load failed. Terminating.")
