In [None]:
import pandas as pd
import numpy as np
import networkx as nx
from collections import defaultdict, Counter
from pathlib import Path

# ===================== CONFIG =====================
# Every input and output lives in the same processed_data folder on the shared drive.
_DATA_DIR = r"Y:\public\projects\AnAl_20240405_Neuromod_PE\PE_mapping\processed_data"

structures_path = _DATA_DIR + r"\structures.csv"            # atlas structure table
input_excel = _DATA_DIR + r"\per_mouse_sheets.xlsx"         # one sheet per mouse
output_excel = _DATA_DIR + r"\den_collapsed_matrix.xlsx"    # mean / std / sem matrices
output_log = _DATA_DIR + r"\den_collapsing_log.csv"         # per-child contribution log
audit_dir = _DATA_DIR                                       # audit CSVs are written here

TARGET_N = 160      # size budget for the collapsed-region selection
ROOT_ID = "997"     # Allen mouse root
BAN_IDS = {"8"}     # never select the shallow hub "Basic cell groups and regions"

# Category roots given as structure-id path prefixes; nothing ABOVE these ids is ever selected.
category_map = dict([
    ("Medulla", "/997/8/343/1065/354/"),
    ("Pons", "/997/8/343/1065/771/"),
    ("Hypothalamus", "/997/8/343/1129/1097/"),
    ("Thalamus", "/997/8/343/1129/549/"),
    ("Midbrain", "/997/8/343/313/"),
    ("Cerebellum", "/997/8/512/"),
    ("Cortical plate", "/997/8/567/688/695/"),
    ("Cortical subplate", "/997/8/567/688/703/"),
    ("Pallidum", "/997/8/567/623/803/"),
    ("Striatum", "/997/8/567/623/477/"),
])

# Selection knobs inside categories.
# NOTE(review): confirm which of these main() actually forwards to the selector.
MIN_OFFSET_BELOW_CATEGORY = 1    # start at least one level below each category root
ALLOW_CATEGORY_ROOTS = False     # keep the category root itself out of the first pass
RELAX_DEPTH_IF_N_SHORT = True    # allow relaxing toward the root when too few regions result
# ==================================================
def canon(p) -> str:
    """Normalise a structure-id path to the canonical '/a/b/c/' form.

    Any run of leading/trailing slashes is removed and exactly one slash is
    put back on each side. Non-string inputs go through ``str`` first;
    ``None`` is treated as an empty path and therefore becomes ``"//"``.
    """
    if p is None:
        core = ""
    else:
        core = str(p).strip("/")
    return "".join(("/", core, "/"))


# ------------------ Structures & Graph ------------------
def load_structures(structures_csv: str):
    """Load the atlas structure table and build id-keyed lookup dicts.

    Parameters
    ----------
    structures_csv : str
        Path of a CSV with at least ``id``, ``acronym``, ``name`` and
        ``structure_id_path`` columns.

    Returns
    -------
    tuple
        ``(S, id_to_name, id_to_path, id_to_acronym)``. ``S`` is the table with
        ``id`` cast to ``str`` and ``structure_id_path`` canonicalised to
        '/.../' by ``canon``; each dict is keyed by that string id.
    """
    # Read the path the caller passed in, not the module-level config value,
    # so the parameter actually controls which table is loaded.
    S = pd.read_csv(structures_csv)
    S["id"] = S["id"].astype(str)
    S["structure_id_path"] = S["structure_id_path"].apply(canon)

    by_id = S.set_index("id")
    id_to_acronym = by_id["acronym"].astype(str).to_dict()
    id_to_path = by_id["structure_id_path"].to_dict()  # canonical now
    id_to_name = by_id["name"].astype(str).to_dict()
    return S, id_to_name, id_to_path, id_to_acronym

def build_graph_from_paths(S: pd.DataFrame) -> nx.DiGraph:
    """Build a parent -> child DiGraph from the ``structure_id_path`` column.

    Each path '/a/b/c/' contributes the edges a->b and b->c; node ids are the
    path segments as strings. Paths shared by many rows simply re-add edges
    that already exist (a DiGraph keeps at most one edge per ordered pair).
    The column is expected to be canonical '/.../' already (see load_structures).
    """
    graph = nx.DiGraph()
    for id_path in S["structure_id_path"]:
        segments = id_path.strip("/").split("/")
        # consecutive (parent, child) pairs along the path
        graph.add_edges_from(zip(segments[:-1], segments[1:]))
    return graph


def node_depth_from_path(sid: str, id_to_path: dict) -> int:
    """Return how many levels *sid* sits below the root (the root itself is 0).

    Ids missing from *id_to_path* resolve to an empty path and hence depth 0.
    """
    segments = str(id_to_path.get(sid, "")).strip("/").split("/")
    hops = len(segments) - 1
    return hops if hops > 0 else 0

# ------------------ Diagnostics ------------------
def diagnose_categories(id_to_path: dict, category_map: dict):
    """Print, for each category prefix, how many atlas nodes fall under it.

    Prefixes are canonicalised with ``canon``; the path values in
    *id_to_path* are expected to be canonical already. For every category the
    node count and min/max depth (root = 0) are printed, or a warning line
    when the prefix matches nothing.
    """
    print("\n[DIAG] Category coverage:")
    for label, raw_prefix in category_map.items():
        prefix = canon(raw_prefix)
        hit_depths = [
            len(path.strip("/").split("/")) - 1
            for path in id_to_path.values()
            if path.startswith(prefix)
        ]
        if not hit_depths:
            print(f"  - {label}: 0 nodes under prefix {prefix}  <<< CHECK THIS PREFIX")
            continue
        print(f"  - {label}: {len(hit_depths)} nodes; depth range {min(hit_depths)}–{max(hit_depths)} (prefix {prefix})")
    print("")


# --------- Category-bounded antichain selection ----------
def select_frontier_antichain_within_categories(
    G,
    id_to_path: dict,
    category_map: dict,
    target_n: int,
    root_id: str = "997",
    *,
    min_offset: int = 1,     # <= 0: start at category roots; >= 1: start at their children
    max_extra_depth: int = 3, # don’t go deeper than (category root depth + max_extra_depth)
    ban_ids: set | None = None,
):
    """
    Coverage-preserving frontier selector.

    - Works inside the given category subtrees (never above them).
    - Builds an antichain 'frontier' that covers all in-category leaves: start
      from each category root (min_offset <= 0) or its in-category children
      (min_offset >= 1; larger offsets behave like 1), then repeatedly replace
      the frontier node with the largest depth-capped leaf count by ALL of its
      in-category children, while:
        * respecting a per-category max depth (root_depth + max_extra_depth),
        * never selecting banned IDs,
        * keeping an antichain (no ancestor/descendant pairs),
        * never letting an expansion push the selection above target_n.
    - A node is expanded only if every one of its in-category children may be
      selected. If a child is banned or over the depth cap, the node stays on
      the frontier so the leaves under that child remain covered. For the same
      reason, a category root that has a banned child is kept as its own
      starting node.

    Parameters
    ----------
    G : networkx.DiGraph
        Parent -> child graph; node ids are compared with the string keys of
        *id_to_path*.
    id_to_path : dict[str, str]
        Region id -> canonical '/.../' structure-id path.
    category_map : dict[str, str]
        Category name -> path prefix whose last id is the category root.
    target_n : int
        Expansion stops before the selection would exceed this size. The
        starting frontier itself is not trimmed, so it may already exceed it.
    root_id : str
        Accepted for interface compatibility; not used by the selection.
    min_offset, max_extra_depth, ban_ids
        See the description above.

    Returns
    -------
    list[str]
        Selected region ids (an antichain), in frontier order.

    Raises
    ------
    AssertionError
        If the final check finds a selected node with a selected descendant.
    """
    from functools import lru_cache

    ban_ids = set(map(str, ban_ids or set()))
    root_id = str(root_id)  # kept for compatibility; not used below

    # --- helpers ---
    def canon_prefix(p) -> str:
        return "/" + str(p).strip("/") + "/"

    cat_prefixes = {k: canon_prefix(v) for k, v in category_map.items()}
    cat_roots = {pref.strip("/").split("/")[-1] for pref in cat_prefixes.values()}

    def depth(sid: str) -> int:
        p = id_to_path.get(sid, "")
        return max(len(str(p).strip("/").split("/")) - 1, 0)

    # Map each node to its (longest) matching category root id, or None
    node_catroot = {}
    for sid, pth in id_to_path.items():
        best = None
        best_len = -1
        for pref in cat_prefixes.values():
            if pth.startswith(pref) and len(pref) > best_len:
                best_len = len(pref)
                best = pref.strip("/").split("/")[-1]
        node_catroot[sid] = best

    # Per-category depth limit
    cat_root_depth = {r: depth(r) for r in cat_roots}

    def in_cat(sid: str):
        cr = node_catroot.get(sid)
        return (cr is not None), cr

    def within_limit(sid: str, cr: str):
        return depth(sid) <= cat_root_depth[cr] + max_extra_depth

    def children_in_cat(sid: str, cr: str):
        return [c for c in G.successors(sid) if node_catroot.get(c) == cr]

    @lru_cache(maxsize=None)
    def descendants(sid: str) -> frozenset:
        # G is never modified here, so each node's descendant set is computed
        # once instead of on every pairwise antichain test.
        return frozenset(nx.descendants(G, sid))

    def related(a: str, b: str) -> bool:
        # True when a and b are in an ancestor/descendant relation.
        return (a in descendants(b)) or (b in descendants(a))

    # Count descendant leaves within the same category (memoized)
    @lru_cache(maxsize=None)
    def leaf_count_in_cat(sid: str, cr: str) -> int:
        kids = children_in_cat(sid, cr)
        if not kids or not within_limit(sid, cr):
            return 1  # treat as terminal for our frontier decision
        return sum(leaf_count_in_cat(k, cr) for k in kids)

    # --- initial frontier: for each category, roots (offset<=0) or their children ---
    frontier = []
    for pref in cat_prefixes.values():
        cr = pref.strip("/").split("/")[-1]
        if cr in ban_ids:  # skip banned roots entirely
            continue
        if min_offset <= 0:
            frontier.append(cr)
            continue
        kids = [c for c in G.successors(cr) if node_catroot.get(c) == cr]
        if not kids or any(k in ban_ids for k in kids):
            # Degenerate branch, or a banned child whose removal would leave a
            # coverage hole: keep the category root itself instead.
            frontier.append(cr)
        else:
            frontier.extend(kids)

    # Keep only valid frontier nodes (inside categories, not banned, within depth limit)
    clean = []
    for v in frontier:
        inside, cr = in_cat(v)
        if not inside or v in ban_ids:
            continue
        if not within_limit(v, cr):
            continue
        clean.append(v)
    frontier = list(dict.fromkeys(clean))  # unique, stable order

    # Enforce antichain on the starting frontier
    sel = []
    for v in frontier:
        if any(related(v, s) for s in sel):
            continue
        sel.append(v)

    # --- refine frontier until ~target_n, by expanding largest subtree ---
    # A selected node is only ever replaced by ALL of its children inside the same category.
    def expandable(v):
        inside, cr = in_cat(v)
        return inside and any(children_in_cat(v, cr)) and depth(v) < cat_root_depth[cr] + max_extra_depth

    # Greedy expansion with fit-to-budget (don’t exceed target_n)
    while len(sel) < target_n:
        # pick expandable node with largest leaf_count (ties: larger id string first)
        candidates = [(leaf_count_in_cat(v, in_cat(v)[1]), v) for v in sel if expandable(v)]
        if not candidates:
            break
        candidates.sort(reverse=True)
        expanded = False
        for _, v in candidates:
            cr = in_cat(v)[1]
            all_kids = children_in_cat(v, cr)
            kids = [k for k in all_kids if within_limit(k, cr) and k not in ban_ids]
            # Replacing v by only some of its children would leave the leaves
            # under the dropped ones unmapped, so keep v unless all are usable.
            if not kids or len(kids) != len(all_kids):
                continue
            # Check budget if we replace v with all its kids
            new_count = len(sel) - 1 + len(kids)
            if new_count > target_n:
                # try a smaller expandable node if available
                continue
            # do the replacement (preserves antichain and coverage)
            sel.remove(v)
            # ensure kids don't conflict with existing selection (they shouldn't, but be safe)
            for k in kids:
                if any(related(k, s) for s in sel):
                    # conflict: undo this expansion to preserve the antichain
                    sel.insert(0, v)  # put v back (order not important)
                    break
            else:
                sel.extend(kids)
                expanded = True
                break
        if not expanded:
            break  # no expansion fits in budget

    # Final antichain check
    selected = set(sel)
    for u in list(selected):
        if selected & descendants(u):
            raise AssertionError(f"Antichain violation: {u} has selected descendant(s).")

    return sel

# -------------- Mapping & Aggregation --------------
def map_child_to_collapsed(id_to_path: dict, selected_ids: list[str]) -> dict:
    """Map each atlas region id to its nearest selected ancestor-or-self.

    The structure-id path of every region is scanned from its own id back
    toward the root; the first id that is in *selected_ids* wins. Regions with
    no selected id on their path are left out of the result. Selected ids map
    to themselves.
    """
    wanted = set(selected_ids)
    result = {}
    for region, path in id_to_path.items():
        chain = [seg for seg in str(path).strip("/").split("/") if seg]
        for seg in chain[::-1]:
            if seg in wanted:
                result[region] = seg
                break
    return result

def _region_id_key(value) -> str:
    """Return *value* as an atlas-id lookup key: 997, 997.0 and " 997" all give "997"; NaN gives ""."""
    if pd.isna(value):
        return ""
    if isinstance(value, (float, np.floating)) and float(value).is_integer():
        return str(int(value))
    return str(value).strip()


def _summarize_hemisphere(assign, mouse_id, mean_d, std_d, sem_d):
    """Reduce ``{collapsed_id: [values]}`` for one mouse into the mean / SD / SEM dicts."""
    for cid, vals in assign.items():
        arr = np.asarray(vals, float)
        n = arr.size
        # sample SD (ddof=1); a single contributing value is reported as 0.0, not NaN
        sd = float(np.std(arr, ddof=1)) if n > 1 else 0.0
        mean_d[mouse_id][cid] = float(np.mean(arr))
        std_d[mouse_id][cid] = sd
        sem_d[mouse_id][cid] = float(sd / np.sqrt(n)) if n > 1 else 0.0


def aggregate_per_mouse_hemi(input_excel: str,
                             structure_to_collapsed: dict,
                             id_to_name: dict,
                             id_to_path: dict):
    """Aggregate child-region values into collapsed regions, per mouse and hemisphere.

    Each sheet of *input_excel* is treated as one mouse (sheet name = mouse id);
    a sheet named "summary" (case-insensitive) is skipped. For every row, the
    leaf id of ``structure_id_path`` is looked up in *structure_to_collapsed*;
    rows that do not map are ignored. The values collected for each collapsed
    region are reduced to mean, sample SD (ddof=1) and SEM; with a single
    contributing row SD and SEM are 0.0.

    Hemisphere values come from columns ``L`` / ``R``. If neither exists, the
    first column whose name ends in ``_L`` / ``_R`` is used instead. If
    ``structure_id_path`` is missing it is rebuilt from ``region_id`` via
    *id_to_path*; float-typed ids such as 997.0 are accepted.

    NOTE(review): the mean is unweighted over the sheet rows that map to a
    region; if a sheet lists a region together with its own sub-regions, all of
    them contribute. Confirm the sheets contain non-overlapping regions.

    Returns
    -------
    ((mean_L, std_L, sem_L), (mean_R, std_R, sem_R), log_rows)
        Each statistic maps mouse_id -> {collapsed_id -> float}; ``log_rows`` is
        a list of dicts, one per contributing (row, hemisphere) value.
    """
    df_by_mouse = pd.read_excel(input_excel, sheet_name=None)
    mean_L, std_L, sem_L = defaultdict(dict), defaultdict(dict), defaultdict(dict)
    mean_R, std_R, sem_R = defaultdict(dict), defaultdict(dict), defaultdict(dict)
    log_rows = []

    for sheet_name, df_mouse in df_by_mouse.items():
        if str(sheet_name).lower() == "summary":
            continue
        # Private copy: helper columns (structure_id_path, L, R) may be added below.
        df_mouse = df_mouse.copy()

        # Ensure structure_id_path exists (reconstruct from region_id if available)
        if "structure_id_path" not in df_mouse.columns:
            if "region_id" not in df_mouse.columns:
                print(f"[WARN] {sheet_name}: missing structure_id_path; skipping.")
                continue
            df_mouse["structure_id_path"] = df_mouse["region_id"].map(
                lambda v: id_to_path.get(_region_id_key(v), "")
            )

        # Detect hemisphere columns (L/R). Fallback: first *_L / *_R column.
        has_L = "L" in df_mouse.columns
        has_R = "R" in df_mouse.columns
        if not (has_L or has_R):
            for c in df_mouse.columns:
                label = str(c)  # Excel headers may be non-strings (e.g. numbers)
                if label.endswith("_L") and not has_L:
                    df_mouse["L"] = df_mouse[c]
                    has_L = True
                if label.endswith("_R") and not has_R:
                    df_mouse["R"] = df_mouse[c]
                    has_R = True
        if not (has_L or has_R):
            print(f"[WARN] {sheet_name}: no L/R columns; skipping.")
            continue

        mouse_id = str(sheet_name)
        hemis = [h for h, present in (("L", has_L), ("R", has_R)) if present]
        assign = {"L": defaultdict(list), "R": defaultdict(list)}

        for _, row in df_mouse.iterrows():
            path = str(row["structure_id_path"]).strip("/")
            if not path:
                continue
            region_id = path.split("/")[-1]  # leaf id as string
            cid = structure_to_collapsed.get(region_id)
            if not cid:
                continue

            for hemi in hemis:  # L before R, as in the contribution log
                if pd.isna(row[hemi]):
                    continue
                value = float(row[hemi])
                assign[hemi][cid].append(value)
                log_rows.append({
                    "mouse": mouse_id, "hemisphere": hemi,
                    "collapsed_region_id": cid,
                    "collapsed_region_name": id_to_name.get(cid, ""),
                    "child_region_id": region_id,
                    "child_structure_id_path": row["structure_id_path"],
                    "cells_per_mm3_contribution": value
                })

        # reduce to mean/std/sem per collapsed region
        _summarize_hemisphere(assign["L"], mouse_id, mean_L, std_L, sem_L)
        _summarize_hemisphere(assign["R"], mouse_id, mean_R, std_R, sem_R)

    return (mean_L, std_L, sem_L), (mean_R, std_R, sem_R), log_rows

def build_wide_with_meta(mean_L, mean_R, id_to_name, id_to_path, id_to_acronym):
    """Build a wide table: metadata columns + one value column per mouse hemisphere.

    Parameters
    ----------
    mean_L, mean_R : mapping
        mouse_id -> {collapsed_region_id (str) -> value}. Despite the names, any
        per-hemisphere statistic can be passed (main() also passes std / sem).
    id_to_name, id_to_path, id_to_acronym : dict
        Atlas lookups keyed by region id string.

    Returns
    -------
    pandas.DataFrame
        Columns ``region_id, acronym, name, structure_id_path, depth`` followed
        by ``<mouse>_L`` / ``<mouse>_R`` (mice sorted; a hemisphere column is
        emitted only when that mouse has at least one value). One row per
        collapsed region that received a value. ``depth`` counts levels below
        the root (root = 0), the same convention as ``node_depth_from_path``
        and the audit CSVs. ``region_id`` is an int when it parses as one.
    """
    cols = {}
    for m in sorted(set(mean_L) | set(mean_R)):
        # .get avoids inserting keys into defaultdict inputs
        if mean_L.get(m):
            cols[f"{m}_L"] = mean_L[m]   # dict: collapsed_id -> value
        if mean_R.get(m):
            cols[f"{m}_R"] = mean_R[m]

    df = pd.DataFrame(cols)
    df.index.name = "collapsed_region_id"

    # attach metadata
    df["region_id"] = df.index.astype(str)  # keep as string key for maps
    df["acronym"] = df["region_id"].map(lambda x: id_to_acronym.get(x, ""))
    df["name"] = df["region_id"].map(lambda x: id_to_name.get(x, ""))
    df["structure_id_path"] = df["region_id"].map(lambda x: id_to_path.get(x, ""))
    # Levels below the root (root = 0): previously the root counted as 1, which
    # disagreed with the selection printout and collapsed_selection.csv.
    df["depth"] = df["structure_id_path"].map(
        lambda p: max(len(str(p).strip("/").split("/")) - 1, 0)
    )

    # Emit integer region ids where possible; keep anything else unchanged.
    def to_int_maybe(x):
        try:
            return int(x)
        except (TypeError, ValueError):
            return x
    df["region_id"] = df["region_id"].map(to_int_maybe)

    meta_cols = ["region_id", "acronym", "name", "structure_id_path", "depth"]
    ordered = meta_cols + [c for c in df.columns if c not in meta_cols]
    return df[ordered].reset_index(drop=True)


# -------------------- Audit Exports (safe) --------------------
def export_audit_files(collapsed_region_ids, id_to_name, id_to_acronym, id_to_path,
                       structure_to_collapsed, audit_dir: str):
    """Write three audit CSVs into *audit_dir* (created if missing). Safe if the selection is empty.

    Files
    -----
    collapsed_selection.csv
        One row per selected region (region_id, acronym, name,
        structure_id_path, depth), sorted by depth then region_id.
    child_to_collapsed_map.csv
        Every region in *id_to_path* that has a collapsed region, with child
        and collapsed metadata side by side.
    unmapped_regions.csv
        Regions in *id_to_path* with no collapsed region.

    The child -> collapsed assignment is read from *structure_to_collapsed*,
    i.e. the mapping that aggregation used, so the audit documents what was
    actually averaged. If it is ``None``, the nearest selected ancestor on each
    structure path is used instead. Every file is written with its header row
    even when it has no data rows.
    """
    out_dir = Path(audit_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    sel_csv = out_dir / "collapsed_selection.csv"
    map_csv = out_dir / "child_to_collapsed_map.csv"
    unmap_csv = out_dir / "unmapped_regions.csv"

    sel_cols = ["region_id", "acronym", "name", "structure_id_path", "depth"]
    child_cols = ["child_region_id", "child_name", "child_acronym", "child_structure_id_path"]
    map_cols = child_cols + ["collapsed_region_id", "collapsed_name", "collapsed_acronym",
                             "collapsed_structure_id_path", "collapsed_depth"]

    def as_int_if_digits(sid: str):
        # Numeric ids are written as integers; anything else is kept verbatim.
        return int(sid) if sid.isdigit() else sid

    # 1) collapsed_selection.csv
    rows = []
    for cid in collapsed_region_ids:
        sid = str(cid)
        rows.append({
            "region_id": as_int_if_digits(sid),
            "acronym": id_to_acronym.get(sid, ""),
            "name": id_to_name.get(sid, ""),
            "structure_id_path": id_to_path.get(sid, ""),
            "depth": node_depth_from_path(sid, id_to_path),
        })
    # Passing `columns` keeps the header even when nothing was selected.
    sel_df = pd.DataFrame(rows, columns=sel_cols)
    sel_df.sort_values(["depth", "region_id"]).to_csv(sel_csv, index=False)

    # 2) child_to_collapsed_map.csv / unmapped_regions.csv
    if structure_to_collapsed is None:
        # No mapping supplied: fall back to the nearest selected ancestor on each path.
        selected_set = set(map(str, collapsed_region_ids))
        mapping = {}
        for sid, pth in id_to_path.items():
            parts = [p for p in str(pth).strip("/").split("/") if p]
            hit = next((p for p in reversed(parts) if p in selected_set), None)
            if hit is not None:
                mapping[sid] = hit
    else:
        mapping = structure_to_collapsed

    map_rows, unmapped = [], []
    for sid, pth in id_to_path.items():
        child = {
            "child_region_id": as_int_if_digits(str(sid)),
            "child_name": id_to_name.get(sid, ""),
            "child_acronym": id_to_acronym.get(sid, ""),
            "child_structure_id_path": pth,
        }
        mapped = mapping.get(sid)
        if mapped is None:
            unmapped.append(child)
            continue
        mapped = str(mapped)
        map_rows.append({
            **child,
            "collapsed_region_id": as_int_if_digits(mapped),
            "collapsed_name": id_to_name.get(mapped, ""),
            "collapsed_acronym": id_to_acronym.get(mapped, ""),
            "collapsed_structure_id_path": id_to_path.get(mapped, ""),
            "collapsed_depth": node_depth_from_path(mapped, id_to_path),
        })
    pd.DataFrame(map_rows, columns=map_cols).to_csv(map_csv, index=False)
    pd.DataFrame(unmapped, columns=child_cols).to_csv(unmap_csv, index=False)

    print(f"[AUDIT] collapsed_selection → {sel_csv} ({len(sel_df)} rows)")
    print(f"[AUDIT] child_to_collapsed_map → {map_csv} ({len(map_rows)} rows)")
    if unmapped:
        print(f"[AUDIT] unmapped_regions → {unmap_csv}  (count={len(unmapped)})")
    else:
        print("[AUDIT] All regions mapped.")


# ------------------------- MAIN -------------------------
def main():
    """Run the collapsing pipeline end to end with the module-level CONFIG values.

    1. Load the atlas table and build the parent -> child graph.
    2. Select collapsed regions inside the configured categories.
    3. Map every atlas region to its nearest selected ancestor.
    4. Aggregate each mouse sheet per hemisphere.
    5-6. Write the mean/std/sem matrices and the contribution log.
    7. Write the audit CSVs.

    Raises
    ------
    ValueError
        If ROOT_ID is not a node of the structure graph.
    """
    # 1) Load structures & build graph
    S, id_to_name, id_to_path, id_to_acronym = load_structures(structures_path)
    G = build_graph_from_paths(S)
    if ROOT_ID not in G:
        raise ValueError(f"Root id {ROOT_ID} not in graph.")

    # Diagnostics: ensure the category prefixes actually match nodes
    diagnose_categories(id_to_path, category_map)

    # 2) Select collapsed regions **within categories** (never above them).
    # All settings come from CONFIG; BAN_IDS keeps the shallow hub "8" out.
    select_kwargs = dict(
        id_to_path=id_to_path,
        category_map=category_map,
        target_n=TARGET_N,
        root_id=ROOT_ID,
        max_extra_depth=3,  # cap depth to category root depth + 3 (tune: 2..4)
        ban_ids=BAN_IDS,
    )
    first_offset = 0 if ALLOW_CATEGORY_ROOTS else MIN_OFFSET_BELOW_CATEGORY
    collapsed_region_ids = select_frontier_antichain_within_categories(
        G, min_offset=first_offset, **select_kwargs
    )

    # Fallback: nothing selected below the roots -> retry starting at the category roots.
    if not collapsed_region_ids and RELAX_DEPTH_IF_N_SHORT and first_offset > 0:
        print("[WARN] No nodes selected under current constraints. Retrying allowing category roots...")
        collapsed_region_ids = select_frontier_antichain_within_categories(
            G, min_offset=0, **select_kwargs
        )

    # Report depths only when something was selected (min()/max() of [] would raise).
    if collapsed_region_ids:
        depths = [node_depth_from_path(s, id_to_path) for s in collapsed_region_ids]
        print(f"[SELECT] Selected {len(collapsed_region_ids)} collapsed nodes within categories. "
              f"Depth range: {min(depths)}–{max(depths)}")
    else:
        print("[SELECT] Still selected 0 nodes — check category prefixes in the DIAG output above.")

    # 3) Map every atlas region to its nearest selected ancestor (will be empty map if selection empty)
    structure_to_collapsed = map_child_to_collapsed(id_to_path, collapsed_region_ids)

    # 4) Aggregate per mouse & hemisphere (cells_per_mm3)
    (mean_L, std_L, sem_L), (mean_R, std_R, sem_R), log_rows = aggregate_per_mouse_hemi(
        input_excel, structure_to_collapsed, id_to_name, id_to_path
    )

    # 5) Build wide tables with metadata
    if collapsed_region_ids:
        df_mean = build_wide_with_meta(mean_L, mean_R, id_to_name, id_to_path, id_to_acronym)
        df_std = build_wide_with_meta(std_L, std_R, id_to_name, id_to_path, id_to_acronym)
        df_sem = build_wide_with_meta(sem_L, sem_R, id_to_name, id_to_path, id_to_acronym)
    else:
        # empty frames with just meta headers
        df_mean = pd.DataFrame(columns=["region_id", "acronym", "name", "structure_id_path", "depth"])
        df_std = df_mean.copy()
        df_sem = df_mean.copy()

    # 6) Save outputs
    with pd.ExcelWriter(output_excel) as writer:
        df_mean.to_excel(writer, sheet_name="mean_cells_per_mm3", index=False)
        df_std.to_excel(writer, sheet_name="std_cells_per_mm3", index=False)
        df_sem.to_excel(writer, sheet_name="sem_cells_per_mm3", index=False)
    print(f"[OK] Collapsed matrices written → {output_excel}")

    pd.DataFrame(log_rows).to_csv(output_log, index=False)
    print(f"[OK] Contribution log written → {output_log} (rows={len(log_rows)})")

    # 7) Audit CSVs (safe if empty)
    export_audit_files(collapsed_region_ids, id_to_name, id_to_acronym, id_to_path,
                       structure_to_collapsed, audit_dir)


# Entry point: runs the whole pipeline when this code is executed as a script.
# In a notebook kernel __name__ is also "__main__", so running this cell runs it too.
if __name__ == "__main__":
    main()



[DIAG] Category coverage:
  - Medulla: 58 nodes; depth range 4–7 (prefix /997/8/343/1065/354/)
  - Pons: 34 nodes; depth range 4–7 (prefix /997/8/343/1065/771/)
  - Hypothalamus: 57 nodes; depth range 4–8 (prefix /997/8/343/1129/1097/)
  - Thalamus: 67 nodes; depth range 4–8 (prefix /997/8/343/1129/549/)
  - Midbrain: 68 nodes; depth range 3–7 (prefix /997/8/343/313/)
  - Cerebellum: 28 nodes; depth range 2–6 (prefix /997/8/512/)
  - Cortical plate: 352 nodes; depth range 4–9 (prefix /997/8/567/688/695/)
  - Cortical subplate: 14 nodes; depth range 4–6 (prefix /997/8/567/688/703/)
  - Pallidum: 15 nodes; depth range 4–7 (prefix /997/8/567/623/803/)
  - Striatum: 23 nodes; depth range 4–7 (prefix /997/8/567/623/477/)

[SELECT] 160 nodes, depth 3–7
[SELECT] Selected 160 collapsed nodes within categories. Depth range: 3–7
[OK] Collapsed matrices written → Y:\public\projects\AnAl_20240405_Neuromod_PE\PE_mapping\processed_data\den_collapsed_matrix.xlsx
[OK] Contribution log written → Y:\pu