# Build Gaussian results, compute group presence, counts, lag summaries, and overlap (recall)

In [2]:
# ============================
# Gaussian pipeline + KSG vs Gaussian comparison
# Reproducible analysis used in "Nonlinear vs. linear estimators"
#
# What this cell does:
#   1) (Optional) Combines per-target Gaussian IDTxl Results (23 .pkl per session) into
#      one session-level Results object per session.
#   2) Extracts session-level Gaussian adjacency matrices:
#        - binary (edge present per session)
#        - max_te_lag (modal TE delay per edge per session)
#      and saves them as .npy.
#   3) Builds group-level presence matrices (fraction of sessions with the edge).
#   4) Loads previously produced KSG group presence matrices.
#   5) Derives robust-edge counts at 50/70/90% for KSG and Gaussian; saves CSV and plots.
#   6) Summarizes Gaussian lag repertoires for robust edges at 70% and 90%; saves CSVs.
#   7) Computes estimator overlap at 70% and 90% (CONSENSUS and Gaussian RECALL of KSG).
#
# Notes:
#   - No across-edge FDR is applied here (fdr=False), matching the manuscript.
#   - Lags are assumed to be coded as bins 1..5. These are converted to ms by ×50.
#   - KSG presence matrices must already exist in KSG_RESULTS (produced earlier in the project).
#   - Figure outputs are written under OUTBASE; copy/symlink to your LaTeX figure folder as needed.
# ============================

from pathlib import Path
import re, pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# ---------- Paths (EDIT GAUSS_SOURCE_DIR to your local location if needed) ----------
BASE             = Path("/lustre/majlepy2/myproject")
RESULTS          = BASE / "Results"
KSG_RESULTS      = RESULTS / "ksg_results"          # KSG group presence matrices (inputs)
GAUSS_SOURCE_DIR = BASE / "gaussian_mte_50"         # <-- per-target Gaussian Results .pkl files (23 per session)
GAUSS_RESULTS    = RESULTS / "gauss_results"        # derived Gaussian arrays and combined pickles are saved here
META_CSV         = BASE / "subject_session_metadata.csv"

OUTBASE          = Path("/home/majlepy2/myproject/Step-wise")
FIGS             = OUTBASE / "figs" / "Comparison"  # robust-edge count figures (per group)
OUTCSV           = OUTBASE / "Comparison"           # CSV outputs for counts, lags, overlap

# Create every output folder up front, in the order they are declared above.
for _out_dir in (GAUSS_RESULTS, FIGS, OUTCSV):
    _out_dir.mkdir(parents=True, exist_ok=True)

# ---------- Metadata: map 'subject_session' -> group (healthy / PD-off / PD-on) ----------
meta = pd.read_csv(META_CSV)
meta["sub_ses"] = meta["subject"] + "_" + meta["session"]
subses_to_group = {key: grp for key, grp in zip(meta["sub_ses"], meta["group"])}

# ---------- Helper: combine per-target Gaussian Results into one session-level Results ----------
def combine_gaussian_sessions():
    """
    Merge per-target Gaussian IDTxl results into one session-level Results pickle.

    Scans GAUSS_SOURCE_DIR for ``sub-*/ses-*`` directories. A directory is used only when it
    holds exactly 23 ``*.pkl`` files (one IDTxl Results object per target); any other
    directory is reported and skipped. For each usable session whose
    ``GAUSS_RESULTS/<sub>_<ses>_combined_gauss.pkl`` does not exist yet, the per-target
    pickles are loaded in sorted order, a few non-essential settings keys are pruned, and the
    objects are merged with ``ResultsNetworkInference.combine_results`` and pickled.
    Sessions that already have a combined file are left untouched.
    """
    usable = []
    for candidate in sorted(GAUSS_SOURCE_DIR.glob("sub-*/ses-*")):
        n_pkl = len(list(candidate.glob("*.pkl")))
        if n_pkl != 23:
            print(f"[skip] {candidate} has {n_pkl} pkls (need 23)")
            continue
        usable.append(candidate)
    print(f"Gaussian: found {len(usable)} sessions with 23 targets.")

    # settings keys pruned so the combined Results stays clean
    pruned_keys = ('target', 'filename_ckp', 'write_ckp', 'loglevel')

    for session_dir in usable:
        subj, sess = session_dir.parts[-2:]
        target_pkl = GAUSS_RESULTS / f"{subj}_{sess}_combined_gauss.pkl"
        if target_pkl.exists():
            continue
        from idtxl.results import ResultsNetworkInference
        per_target = []
        for pkl_path in sorted(session_dir.glob("*.pkl")):
            with open(pkl_path, "rb") as fh:
                loaded = pickle.load(fh)
            for key in pruned_keys:
                loaded.settings.pop(key, None)
            per_target.append(loaded)
        props = per_target[0].data_properties
        combined = ResultsNetworkInference(
            n_nodes=props['n_nodes'],
            n_realisations=props['n_realisations'],
            normalised=props['normalised'],
        )
        combined.combine_results(*per_target)
        with open(target_pkl, "wb") as fh:
            pickle.dump(combined, fh)
        print(f"[combined] {target_pkl}")

# Combine (idempotent). If GAUSS_SOURCE_DIR is absent, assume combined pickles exist already.
if not GAUSS_SOURCE_DIR.exists():
    print(f"[WARN] GAUSS_SOURCE_DIR not found ({GAUSS_SOURCE_DIR}); assuming combined pickles exist.")
else:
    combine_gaussian_sessions()

# ---------- Extract Gaussian adjacencies to .npy (binary + max_te_lag), no FDR ----------
def get_adj(res, weights, fdr=False):
    """
    Return ``res.get_adjacency_matrix(weights, fdr=fdr)`` as a NumPy array.

    ``res`` is expected to be an IDTxl network-inference Results object; ``weights`` selects
    the matrix flavour (this file uses ``"binary"`` and ``"max_te_lag"``).
    """
    adjacency = res.get_adjacency_matrix(weights, fdr=fdr)
    return np.array(adjacency)

gauss_pkls = sorted(GAUSS_RESULTS.glob("sub-*_*_combined_gauss.pkl"))
print(f"Gaussian: {len(gauss_pkls)} combined session files.")

# For each combined session, write the binary adjacency (uint8, diagonal 0) and the
# max-TE-lag matrix (float, diagonal NaN). Sessions whose two .npy files already exist are
# skipped, so re-running only fills in what is missing.
for f in gauss_pkls:
    stem = f.stem.replace("_combined_gauss", "")  # sub-XXX_ses-YY
    out_bin = GAUSS_RESULTS / f"{stem}_gauss_binary.npy"
    out_lag = GAUSS_RESULTS / f"{stem}_gauss_max_te_lag.npy"
    if not (out_bin.exists() and out_lag.exists()):
        with open(f, "rb") as pf:
            res = pickle.load(pf)
        # No across-edge FDR; per-edge inference threshold; see manuscript
        A = get_adj(res, "binary", fdr=False).astype(np.uint8)
        L = get_adj(res, "max_te_lag", fdr=False).astype(float)
        for mat, diag_value, dst in ((A, 0, out_bin), (L, np.nan, out_lag)):
            np.fill_diagonal(mat, diag_value)
            np.save(dst, mat)
        print(f"[saved] {out_bin.name}, {out_lag.name}")

# ---------- Group-level presence (Gaussian): fraction of sessions with edge present ----------
def group_presence_gauss():
    """
    Build and save group-level Gaussian edge-presence matrices.

    Loads every ``sub-*_*_gauss_binary.npy`` in GAUSS_RESULTS, maps each session to its
    cohort via ``subses_to_group`` and averages the binaries per cohort. The result is the
    fraction of sessions in which each directed edge is present (values in [0, 1], diagonal
    set to 0). Each matrix is saved as ``GAUSS_RESULTS/<group>_gauss_edge_presence.npy``.

    The session key is the first two ``_``-separated fields of the file stem
    (``sub-XXX_ses-YY``). This is the same parsing the modal-lag helper uses, so presence and
    lag summaries are built from the same set of sessions. Files whose key has no known
    group are reported and skipped.

    Returns
    -------
    dict[str, np.ndarray]
        group -> (N x N) presence matrix.

    Raises
    ------
    RuntimeError
        If a group ends up with no sessions.
    """
    groups = {"healthy": [], "PD-off": [], "PD-on": []}
    for bin_path in sorted(GAUSS_RESULTS.glob("sub-*_*_gauss_binary.npy")):
        stem = bin_path.name.replace("_gauss_binary.npy", "")  # sub-XXX_ses-YY
        # A regex such as r"(sub-\d+)_..." would silently drop non-numeric subject IDs
        # that the lag helpers (which split on "_") still include.
        sub_ses = "_".join(stem.split("_")[:2])
        g = subses_to_group.get(sub_ses)
        if g not in groups:
            print(f"[skip] {bin_path.name}: no known group for '{sub_ses}'")
            continue
        groups[g].append(np.load(bin_path).astype(np.uint8))
    pres = {}
    for g, mats in groups.items():
        if not mats:
            raise RuntimeError(f"No Gaussian binaries for group {g}")
        M = np.stack(mats, axis=0)
        P = M.mean(axis=0)          # fraction of sessions with edge present
        np.fill_diagonal(P, 0.0)
        np.save(GAUSS_RESULTS / f"{g}_gauss_edge_presence.npy", P)
        pres[g] = P
        print(f"[presence] {g}: shape={P.shape}")
    return pres

G_presence = group_presence_gauss()

# ---------- Load KSG presence (must be produced earlier in the pipeline) ----------
K_presence = {
    grp: np.load(KSG_RESULTS / f"{grp}_ksg_edge_presence.npy")
    for grp in ("healthy", "PD-off", "PD-on")
}

# ---------- Robust-edge counts vs threshold, per estimator (50/70/90%) ----------
THRS = [0.50, 0.70, 0.90]


def _n_robust(presence, thr):
    """Count off-diagonal entries of the square matrix ``presence`` that are >= ``thr``."""
    off_diag = ~np.eye(presence.shape[0], dtype=bool)
    return int(((presence >= thr) & off_diag).sum())


rows = []
for est, pres in [("KSG", K_presence), ("Gaussian", G_presence)]:
    for g in ["healthy", "PD-off", "PD-on"]:
        rows.extend(
            {
                "estimator": est,
                "group": g,
                "threshold": int(thr * 100),
                "robust_edges": _n_robust(pres[g], thr),
            }
            for thr in THRS
        )
df_counts = pd.DataFrame(rows)
df_counts.to_csv(OUTCSV / "robust_edge_counts_ksg_vs_gaussian.csv", index=False)
print(df_counts)

# ---------- Plot: robust edges vs threshold (one PNG per group; KSG vs Gaussian) ----------
def plot_counts(df):
    """
    Save one bar chart per group comparing KSG and Gaussian robust-edge counts.

    ``df`` needs the columns ``group``, ``threshold``, ``estimator`` and ``robust_edges``.
    Each group is pivoted to thresholds x estimators (thresholds ascending) and written to
    ``FIGS/robust_edges_thresholds_<group>.png``.
    """
    for grp in ("healthy", "PD-off", "PD-on"):
        table = (
            df[df.group == grp]
            .pivot(index="threshold", columns="estimator", values="robust_edges")
            .sort_index()
        )
        fig, ax = plt.subplots(figsize=(5, 3), dpi=150)
        table.plot(kind="bar", ax=ax)
        ax.set_title(f"Robust edges vs threshold — {grp}")
        ax.set_ylabel("# robust edges")
        ax.set_xlabel("Presence threshold (%)")
        fig.tight_layout()
        png = FIGS / f"robust_edges_thresholds_{grp}.png"
        fig.savefig(png, bbox_inches="tight")
        plt.close(fig)
        print(f"[fig] {png}")

plot_counts(df_counts)

# ---------- Lag histograms for Gaussian robust edges (parallel to the KSG timing section) ----------
def to_ms(arr):
    """
    Return a copy of ``arr`` with lag bins converted to milliseconds.

    When the largest finite value is <= 10 the array is treated as bin indices (1..5) and
    every finite entry is multiplied by 50; otherwise the copy is returned unchanged.
    NaN/inf entries are never modified.
    """
    converted = arr.copy()
    mask = np.isfinite(converted)
    if not mask.any():
        return converted
    if np.nanmax(converted[mask]) <= 10:
        converted[mask] = converted[mask] * 50.0
    return converted

def modal_lag_per_edge(session_npy_prefix="_gauss"):
    """
    Per-group modal TE lag (ms) for every directed edge.

    For each ``sub-*_*{session_npy_prefix}_binary.npy`` in GAUSS_RESULTS whose session maps
    to a known group, the companion ``..._max_te_lag.npy`` is loaded and converted with
    ``to_ms``. An edge's lag is collected only when the edge is present (binary == 1), is
    off-diagonal, and the lag is finite and non-zero. The mode across sessions comes from
    ``Counter.most_common``; on ties the value seen first (sorted-file order) wins. Edges
    never observed stay NaN.

    Returns dict[group] -> (N x N) float array. Raises RuntimeError if a group has no
    sessions.
    """
    per_group = {"healthy": [], "PD-off": [], "PD-on": []}
    suffix = f"{session_npy_prefix}_binary.npy"
    for bin_path in sorted(GAUSS_RESULTS.glob(f"sub-*_*{suffix}")):
        stem = bin_path.name.replace(suffix, "").rstrip("_")
        group = subses_to_group.get("_".join(stem.split("_")[:2]))
        if group not in per_group:
            continue
        binary = np.load(bin_path).astype(np.uint8)
        lags = np.load(GAUSS_RESULTS / f"{stem}{session_npy_prefix}_max_te_lag.npy").astype(float)
        per_group[group].append((binary, lags))

    result = {}
    for group, sessions in per_group.items():
        if not sessions:
            raise RuntimeError(f"No sessions for {group} (Gaussian).")
        n = sessions[0][0].shape[0]
        observed = {}  # (i, j) -> lags in ms, in session order
        for binary, lags in sessions:
            lags_ms = to_ms(lags)
            for i in range(n):
                for j in range(n):
                    lag = lags_ms[i, j]
                    if i != j and binary[i, j] == 1 and np.isfinite(lag) and lag != 0:
                        observed.setdefault((i, j), []).append(int(lag))
        modal = np.full((n, n), np.nan)
        for (i, j), values in observed.items():
            modal[i, j] = Counter(values).most_common(1)[0][0]
        result[group] = modal
    return result

G_modal = modal_lag_per_edge("_gauss")

BINS = [50, 100, 150, 200, 250]

def lag_counts_from_presence(P, L_mode_ms, thr):
    """
    Summarise modal lags of edges robust at ``thr``.

    Robust edges are off-diagonal entries with ``P >= thr``. Their modal lags (ms) are taken
    from ``L_mode_ms`` and NaNs are dropped. Returns ``(counts, width, median, q25, q75)``,
    where ``counts`` maps each bin in BINS to the number of edges at exactly that lag,
    ``width`` is the number of occupied bins, and the median/quartiles are NaN when no edge
    remains.
    """
    off_diag = ~np.eye(P.shape[0], dtype=bool)
    robust_lags = L_mode_ms[(P >= thr) & off_diag]
    robust_lags = robust_lags[np.isfinite(robust_lags)].astype(int)
    counts = {b: int(np.count_nonzero(robust_lags == b)) for b in BINS}
    width = len([b for b in BINS if counts[b] > 0])
    if robust_lags.size == 0:
        return counts, width, np.nan, np.nan, np.nan
    q25, q75 = np.percentile(robust_lags, [25, 75])
    return counts, width, float(np.median(robust_lags)), q25, q75

def save_lag_summary(est_name, presence_dict, modal_dict, thr, tag):
    """
    Save and print a per-group lag-repertoire table for edges robust at ``thr``.

    For each group, ``lag_counts_from_presence`` restricts the modal-lag matrix to edges
    with presence >= ``thr`` and bins them. The CSV
    ``OUTCSV/lag_repertoire_<est_name>_thr<int(thr*100)>.csv`` gets one row per group with:
    counts per 50-ms bin; ``n_robust_edges``, the sum of the bin counts (robust edges whose
    modal lag is finite and falls in BINS); ``width_bins``, the number of occupied bins;
    ``median_ms``; and ``iqr`` as "[q25–q75]" in ms, or "NA" when no edge qualifies.
    It is used here for Gaussian at 70% and 90%.

    ``tag`` is accepted for call compatibility but is not used; the file name is derived
    from ``thr``.
    """
    rows = []
    for g in ["healthy","PD-off","PD-on"]:
        C,W,MED,q25,q75 = lag_counts_from_presence(presence_dict[g], modal_dict[g], thr)
        rows.append({
            "estimator": est_name,
            "group": g,
            "threshold": int(thr*100),
            "n_robust_edges": sum(C.values()),
            "count_50": C[50], "count_100": C[100], "count_150": C[150], "count_200": C[200], "count_250": C[250],
            "width_bins": W,
            "median_ms": MED,
            # Quartiles are linearly interpolated and can fall between bins (e.g. 62.5 ms).
            # Format with :g instead of int(), which truncated them; whole values still
            # print without a decimal point.
            "iqr": f"[{q25:g}–{q75:g}]" if np.isfinite(q25) else "NA"
        })
    df = pd.DataFrame(rows)
    out = OUTCSV / f"lag_repertoire_{est_name}_thr{int(thr*100)}.csv"
    df.to_csv(out, index=False)
    print(df.to_string(index=False))
    print("[csv]", out)

for thr in [0.70, 0.90]:
    save_lag_summary("Gaussian", G_presence, G_modal, thr, f"thr{int(thr*100)}")

# ---------- Overlap (CONSENSUS) between estimators at same threshold; report Gaussian RECALL of KSG ----------
def overlap_counts(thr):
    """
    Compare the KSG and Gaussian robust-edge sets at presence threshold ``thr``, per group.

    An edge is robust when its group presence is >= ``thr``; the diagonal is excluded.
    Returns a DataFrame with one row per group and these columns:
      - consensus_edges: edges robust for BOTH estimators
      - ksg_robust, gauss_robust: sizes of the two robust sets
      - gauss_recall_of_ksg: consensus_edges / ksg_robust, rounded to 3 decimals
        (NaN when KSG has no robust edges)
    Nothing is written to disk here; the caller saves the CSV. Gaussian precision is
    computed in a separate cell.
    """
    records = []
    for g in ["healthy","PD-off","PD-on"]:
        Pk, Pg = K_presence[g], G_presence[g]
        off_diag = ~np.eye(Pk.shape[0], dtype=bool)
        ksg_set = (Pk >= thr) & off_diag
        gauss_set = (Pg >= thr) & off_diag
        n_both = (ksg_set & gauss_set).sum()
        n_ksg = ksg_set.sum()
        n_gauss = gauss_set.sum()
        recall = np.nan if n_ksg <= 0 else n_both / n_ksg
        records.append({
            "group": g, "threshold": int(thr*100),
            "consensus_edges": int(n_both),
            "ksg_robust": int(n_ksg),
            "gauss_robust": int(n_gauss),
            "gauss_recall_of_ksg": round(recall, 3),
        })
    return pd.DataFrame(records)

df_ov_70, df_ov_90 = overlap_counts(0.70), overlap_counts(0.90)
df_ov   = pd.concat([df_ov_70, df_ov_90], ignore_index=True)
df_ov.to_csv(OUTCSV / "overlap_ksg_vs_gaussian.csv", index=False)
print(df_ov)

# ---------- Optional placeholder for a side-by-side lag figure (not used in the paper here) ----------
def plot_lag_hist_pair(thr=0.90, tag="thr90"):
    """
    Placeholder for a side-by-side KSG/Gaussian lag figure; intentionally does nothing.

    It is kept only to reserve the signature. KSG timing figures are produced in the
    timing section.
    """
    return None


Gaussian: found 36 sessions with 23 targets.
Gaussian: 36 combined session files.
[presence] healthy: shape=(23, 23)
[presence] PD-off: shape=(23, 23)
[presence] PD-on: shape=(23, 23)
   estimator    group  threshold  robust_edges
0        KSG  healthy         50           477
1        KSG  healthy         70           331
2        KSG  healthy         90           117
3        KSG   PD-off         50           486
4        KSG   PD-off         70           296
5        KSG   PD-off         90            84
6        KSG    PD-on         50           477
7        KSG    PD-on         70           331
8        KSG    PD-on         90           140
9   Gaussian  healthy         50            81
10  Gaussian  healthy         70             6
11  Gaussian  healthy         90             1
12  Gaussian   PD-off         50           105
13  Gaussian   PD-off         70             9
14  Gaussian   PD-off         90             1
15  Gaussian    PD-on         50           110
16  Gaussian    P

# Gaussian lag histograms for robust edges (figures for LaTeX) 

In [3]:
# --- Gaussian lag histograms (tight, readable for LaTeX) ---
# What this cell does:
#   - Loads Gaussian per-session binaries and lag arrays
#   - Computes per-edge modal lag (ms) per group, counting only sessions where the edge is present
#   - Builds group presence matrices (fraction present)
#   - Plots 3-panel histograms of modal lags for robust edges at a chosen threshold (70%, 90%)
#   - Saves figures to OUT_FIGS
#
# Notes:
#   - Lags coded as 1..5 are converted to milliseconds by ×50.
#   - No across-edge FDR is applied here (matches the manuscript).
#   - These figures support the "Lag repertoires (Gaussian)" subsection.

from pathlib import Path
import re, pickle
import numpy as np
import matplotlib.pyplot as plt
from collections import Counter

BASE         = Path("/lustre/majlepy2/myproject")
RESULTS      = BASE / "Results"
GAUSS_DIR    = RESULTS / "gauss_results"         # contains sub-*_gauss_binary.npy and *_gauss_max_te_lag.npy
META_CSV     = BASE / "subject_session_metadata.csv"

# Folder for the Gaussian TE-lag histogram figures (created if missing).
OUT_FIGS     = Path("/home/majlepy2/myproject/Step-wise") / "figs" / "TE_lag"
OUT_FIGS.mkdir(parents=True, exist_ok=True)

# --- metadata: map sub_ses -> group ---
import pandas as pd
meta = pd.read_csv(META_CSV)
meta["sub_ses"] = meta["subject"] + "_" + meta["session"]
subses_to_group = {key: grp for key, grp in zip(meta["sub_ses"], meta["group"])}

# --- helper to collect per-session binaries and lag arrays per group ---
def collect_sessions_gaussian():
    """
    Load per-session Gaussian binaries and lag arrays, grouped by cohort.

    Reads every ``sub-*_*_gauss_binary.npy`` in GAUSS_DIR together with its
    ``*_gauss_max_te_lag.npy`` companion. The session key is the first two ``_``-separated
    fields of the stem (``sub-XXX_ses-YY``) and is mapped to a group via
    ``subses_to_group``; sessions with no known group are skipped.

    Returns dict[group] -> list of (A_binary uint8, L_max_te_lag float) per session. The
    diagonal of A is set to 0 and the diagonal of L to NaN.
    """
    groups = {"healthy": [], "PD-off": [], "PD-on": []}
    for bin_path in sorted(GAUSS_DIR.glob("sub-*_*_gauss_binary.npy")):
        stem = bin_path.name.replace("_gauss_binary.npy","")  # sub-XXX_ses-YY
        parts = stem.split("_")
        sub_ses = "_".join(parts[:2])
        g = subses_to_group.get(sub_ses)
        if g not in groups:
            continue
        L = np.load(GAUSS_DIR / f"{stem}_gauss_max_te_lag.npy").astype(float)
        A = np.load(bin_path).astype(np.uint8)
        # clean diag
        np.fill_diagonal(A, 0)
        np.fill_diagonal(L, np.nan)
        groups[g].append((A, L))
    return groups

G_sessions = collect_sessions_gaussian()
# Validate with an explicit exception: `assert` statements are removed under `python -O`,
# which would let an empty group through to the per-group computations below.
_empty_groups = [g for g, items in G_sessions.items() if not items]
if _empty_groups:
    raise RuntimeError(
        "Missing Gaussian sessions for some group: " + ", ".join(_empty_groups)
    )

# --- compute modal lag per edge (ms), counting only when edge present ---
def to_ms(L):
    """
    Convert lag bins to milliseconds on a copy of ``L``.

    If the largest finite value is <= 10, the values are treated as bin indices (1..5) and
    each finite entry is scaled by 50. Otherwise, or when nothing is finite, the copy is
    returned as-is.
    """
    out = L.copy()
    ok = np.isfinite(out)
    scale_it = ok.any() and np.nanmax(out[ok]) <= 10
    if scale_it:
        out[ok] = 50.0 * out[ok]
    return out

def modal_lag_per_edge_gauss(G_sessions):
    """
    Per-group modal lag (ms) of every directed edge, counting only sessions where the edge
    is present.

    Parameters
    ----------
    G_sessions : dict[str, list[tuple[np.ndarray, np.ndarray]]]
        group -> list of (A_binary, L_max_te_lag) pairs, as returned by
        ``collect_sessions_gaussian``.

    Returns
    -------
    dict[str, np.ndarray]
        group -> (N x N) float array of modal lags in ms. Edges never observed, and the
        diagonal, are NaN. Ties in the mode go to the lag seen first in session order
        (``Counter.most_common`` keeps first-encountered order for equal counts).

    Raises
    ------
    RuntimeError
        If a group has no sessions.
    """
    out = {}
    for g, items in G_sessions.items():
        if not items:
            # Fail with a clear message instead of an IndexError on items[0] below.
            raise RuntimeError(f"No Gaussian sessions found for group {g}")
        N = items[0][0].shape[0]
        acc = [[[] for _ in range(N)] for __ in range(N)]
        for A, L in items:
            Lms = to_ms(L)
            present = (A == 1)
            for i in range(N):
                for j in range(N):
                    if i == j:
                        continue
                    # keep only sessions where the edge is present with a usable (non-zero) lag
                    if present[i, j] and np.isfinite(Lms[i, j]) and Lms[i, j] != 0:
                        acc[i][j].append(int(Lms[i, j]))
        mode = np.full((N, N), np.nan)
        for i in range(N):
            for j in range(N):
                if i == j:
                    continue
                if acc[i][j]:
                    mode[i, j] = Counter(acc[i][j]).most_common(1)[0][0]
        out[g] = mode
    return out

G_modal = modal_lag_per_edge_gauss(G_sessions)

# --- group-level presence matrices (fraction present) ---
def group_presence_gauss(G_sessions):
    """
    Fraction of sessions in which each directed edge is present, per group.

    ``G_sessions`` maps group -> list of (A_binary, L) pairs; only the binaries are used.
    Returns dict[group] -> (N x N) float array with the diagonal set to 0.
    """
    presence = {}
    for group, pairs in G_sessions.items():
        frac = np.mean(np.stack([binary for binary, _ in pairs]), axis=0)  # S x N x N -> N x N
        np.fill_diagonal(frac, 0.0)
        presence[group] = frac
    return presence

G_presence = group_presence_gauss(G_sessions)

# --- plot helper: one row (Healthy / PD-off / PD-on) per threshold ---
BINS = [50, 100, 150, 200, 250]

def plot_gauss_lag_hist(thr=0.70, outname="lag_histograms_gauss_thr70.png"):
    """
    Plot 3-panel histograms of modal lag bins for robust edges (presence >= thr), one panel
    per group, and save the figure to OUT_FIGS/outname.

    For each group, edges with Gaussian presence >= ``thr`` (diagonal excluded) are taken
    from ``G_presence``. Their modal lags from ``G_modal`` are counted per bin in BINS, with
    NaN lags dropped. All panels share one y-limit so the groups can be compared directly,
    and each non-empty bar is labelled with its count.
    """
    fig, axes = plt.subplots(1, 3, figsize=(8.4, 2.8), dpi=200)  # tight & readable in LaTeX
    groups = ["healthy", "PD-off", "PD-on"]
    # precompute counts for harmonized y-limits
    counts_by_g = {}
    for g in groups:
        P = G_presence[g]
        Lm = G_modal[g]
        N = P.shape[0]
        R = (P >= thr) & (~np.eye(N, dtype=bool))
        vals = Lm[R]
        vals = vals[np.isfinite(vals)].astype(int)
        counts_by_g[g] = [int((vals == b).sum()) for b in BINS]
    maxy = max(max(C) for C in counts_by_g.values())
    ymax = max(1, maxy * 1.2)

    for ax, g in zip(axes, groups):
        C = counts_by_g[g]
        ax.bar(range(len(BINS)), C, width=0.6)
        # Label counts with a small offset in *points*. A data-unit offset of >= 1 overshoots
        # the 20% head-room when counts are small (a bar of 2 under ylim 2.4 would be
        # labelled at y=3, outside the axes); a point offset always sits just above the bar.
        for xi, c in enumerate(C):
            if c > 0:
                ax.annotate(str(c), xy=(xi, c), xytext=(0, 2), textcoords="offset points",
                            ha="center", va="bottom", fontsize=8)
        ax.set_xticks(range(len(BINS)))
        ax.set_xticklabels([f"{b}" for b in BINS], fontsize=9)
        ax.set_ylim(0, ymax)
        ax.set_title(g, fontsize=10)
        ax.grid(axis="y", alpha=0.3, linewidth=0.5)
    fig.suptitle(f"Gaussian modal delays (robust edges, ≥ {int(thr*100)}% of sessions)", fontsize=11)
    fig.text(0.5, 0.01, "Delay bin (ms)", ha="center", fontsize=10)
    fig.text(0.01, 0.5, "Count of robust edges", va="center", rotation="vertical", fontsize=10)
    fig.tight_layout(rect=[0.02, 0.04, 1, 0.92])
    fig.savefig(OUT_FIGS / outname, bbox_inches="tight")
    plt.close(fig)
    print("Saved:", OUT_FIGS / outname)

# Figures used for the "Lag repertoires (Gaussian)" subsection:
plot_gauss_lag_hist(0.70, "lag_histograms_gauss_thr70.png")
plot_gauss_lag_hist(0.90, "lag_histograms_gauss_thr90.png")


Saved: /home/majlepy2/myproject/Step-wise/figs/TE_lag/lag_histograms_gauss_thr70.png
Saved: /home/majlepy2/myproject/Step-wise/figs/TE_lag/lag_histograms_gauss_thr90.png


# List overlapping edges (consensus) at 70%

In [4]:
# Report consensus edges (KSG ∩ Gaussian) at 70% presence, by group.
# This is a diagnostic printout to inspect which specific directed edges overlap.
# For summary counts and recall/precision tables, see Cells 1 and 5.

import numpy as np
from pathlib import Path

ksg_dir = Path('/lustre/majlepy2/myproject/Results/ksg_results')
gauss_dir = Path('/lustre/majlepy2/myproject/Results/gauss_results')
groups = ['healthy', 'PD-off', 'PD-on']
thresh = 0.7

for group in groups:
    edge_presence_ksg = np.load(ksg_dir / f"{group}_ksg_edge_presence.npy")
    edge_presence_gauss = np.load(gauss_dir / f"{group}_gauss_edge_presence.npy")

    # Exclude self-loops explicitly, matching the robust-edge definition used by the
    # count/overlap/precision cells, instead of relying on the saved presence matrices
    # having a zeroed diagonal.
    off_diag = ~np.eye(edge_presence_ksg.shape[0], dtype=bool)
    robust_ksg = (edge_presence_ksg >= thresh) & off_diag
    robust_gauss = (edge_presence_gauss >= thresh) & off_diag

    overlap_mask = np.logical_and(robust_ksg, robust_gauss)

    sources, targets = np.where(overlap_mask)
    print(f"\nGroup: {group} - Robust edges present in BOTH KSG and Gaussian (threshold {int(thresh*100)}%):")
    if len(sources) == 0:
        print("  No overlapping edges found.")
    else:
        for src, tgt in zip(sources, targets):
            print(f"  Edge: {src} -> {tgt}")



Group: healthy - Robust edges present in BOTH KSG and Gaussian (threshold 70%):
  Edge: 1 -> 20
  Edge: 2 -> 12
  Edge: 5 -> 3
  Edge: 7 -> 6
  Edge: 21 -> 0

Group: PD-off - Robust edges present in BOTH KSG and Gaussian (threshold 70%):
  Edge: 0 -> 11
  Edge: 1 -> 3
  Edge: 8 -> 6
  Edge: 9 -> 18
  Edge: 14 -> 17
  Edge: 17 -> 21
  Edge: 19 -> 18
  Edge: 21 -> 0
  Edge: 21 -> 17

Group: PD-on - Robust edges present in BOTH KSG and Gaussian (threshold 70%):
  Edge: 0 -> 21
  Edge: 4 -> 18
  Edge: 5 -> 0
  Edge: 6 -> 5
  Edge: 8 -> 12
  Edge: 9 -> 19
  Edge: 14 -> 17
  Edge: 17 -> 14
  Edge: 18 -> 1
  Edge: 18 -> 21
  Edge: 21 -> 0


# List overlapping edges (consensus) at 90%

In [5]:
# Same as Cell 3 but at 90% presence. Useful for manual inspection of consensus edges at stricter robustness.

import numpy as np
from pathlib import Path

ksg_dir = Path('/lustre/majlepy2/myproject/Results/ksg_results')
gauss_dir = Path('/lustre/majlepy2/myproject/Results/gauss_results')
groups = ['healthy', 'PD-off', 'PD-on']
thresh = 0.9

for group in groups:
    edge_presence_ksg = np.load(ksg_dir / f"{group}_ksg_edge_presence.npy")
    edge_presence_gauss = np.load(gauss_dir / f"{group}_gauss_edge_presence.npy")

    # Exclude self-loops explicitly, matching the robust-edge definition used by the
    # count/overlap/precision cells, instead of relying on the saved presence matrices
    # having a zeroed diagonal.
    off_diag = ~np.eye(edge_presence_ksg.shape[0], dtype=bool)
    robust_ksg = (edge_presence_ksg >= thresh) & off_diag
    robust_gauss = (edge_presence_gauss >= thresh) & off_diag

    overlap_mask = np.logical_and(robust_ksg, robust_gauss)

    sources, targets = np.where(overlap_mask)
    print(f"\nGroup: {group} - Robust edges present in BOTH KSG and Gaussian (threshold {int(thresh*100)}%):")
    if len(sources) == 0:
        print("  No overlapping edges found.")
    else:
        for src, tgt in zip(sources, targets):
            print(f"  Edge: {src} -> {tgt}")



Group: healthy - Robust edges present in BOTH KSG and Gaussian (threshold 90%):
  No overlapping edges found.

Group: PD-off - Robust edges present in BOTH KSG and Gaussian (threshold 90%):
  Edge: 14 -> 17

Group: PD-on - Robust edges present in BOTH KSG and Gaussian (threshold 90%):
  Edge: 0 -> 21
  Edge: 21 -> 0


# Estimator overlap table with Gaussian precision

In [6]:
# Build a CSV summary of estimator overlap including Gaussian precision w.r.t. KSG.
# Definitions (match manuscript):
#   consensus_edges = |KSG_robust ∩ Gaussian_robust|
#   gauss_precision_wrt_ksg = consensus_edges / |Gaussian_robust|
# Recall (consensus / |KSG_robust|) is produced in Cell 1.

import numpy as np
import pandas as pd
from pathlib import Path

# ---- paths ----
_RESULTS_ROOT = Path('/lustre/majlepy2/myproject/Results')
KSG_DIR   = _RESULTS_ROOT / 'ksg_results'
GAUSS_DIR = _RESULTS_ROOT / 'gauss_results'
OUTCSV    = Path('/home/majlepy2/myproject/Step-wise') / 'Comparison'
OUTCSV.mkdir(parents=True, exist_ok=True)

GROUPS = ['healthy', 'PD-off', 'PD-on']
THRS   = [0.70, 0.90]  # add 0.50 to also tabulate the 50% threshold

def robust_mask(P, thr):
    """Boolean mask of the off-diagonal entries of square matrix ``P`` that are >= ``thr``."""
    return np.logical_and(P >= thr, ~np.eye(len(P), dtype=bool))

rows = []
for thr in THRS:
    for g in GROUPS:
        # load presence matrices
        Pk = np.load(KSG_DIR   / f"{g}_ksg_edge_presence.npy")
        Pg = np.load(GAUSS_DIR / f"{g}_gauss_edge_presence.npy")
        # robust sets (presence >= thr, diagonal excluded)
        Rk = robust_mask(Pk, thr)
        Rg = robust_mask(Pg, thr)
        # consensus
        inter = (Rk & Rg).sum()
        k_only = Rk.sum()
        g_only = Rg.sum()
        # precision: fraction of Gaussian-robust edges that are also robust in KSG.
        # With no Gaussian-robust edges the ratio is 0/0, i.e. undefined. Report NaN (as the
        # recall column of the overlap table does) rather than 0.0, which would wrongly read
        # as "no Gaussian edge is confirmed by KSG".
        precision = inter / g_only if g_only > 0 else np.nan

        rows.append({
            "group": g,
            "threshold": int(thr*100),
            "consensus_edges": int(inter),
            "ksg_robust": int(k_only),
            "gauss_robust": int(g_only),
            "gauss_precision_wrt_ksg": round(precision, 3),
        })

df_precision = pd.DataFrame(rows)
out_file = OUTCSV / "precision_gaussian_wrt_ksg.csv"
df_precision.to_csv(out_file, index=False)
print(df_precision.to_string(index=False))
print("[csv]", out_file)


  group  threshold  consensus_edges  ksg_robust  gauss_robust  gauss_precision_wrt_ksg
healthy         70                5         331             6                    0.833
 PD-off         70                9         296             9                    1.000
  PD-on         70               11         331            12                    0.917
healthy         90                0         117             1                    0.000
 PD-off         90                1          84             1                    1.000
  PD-on         90                2         140             2                    1.000
[csv] /home/majlepy2/myproject/Step-wise/Comparison/precision_gaussian_wrt_ksg.csv


# Consensus-edge lag distribution

In [7]:
# ============================
# Consensus lag distribution (KSG ∩ Gaussian) at 70% and 90%
# Outputs:
#   - CSVs with counts per lag bin (50..250 ms) for the consensus set, per group and threshold
#   - Console summary with proportions at 50 ms
#
# Requirements:
#   - KSG group presence matrices exist in: Results/ksg_results/*_ksg_edge_presence.npy
#   - Gaussian per-session arrays exist in: Results/gauss_results/sub-*_*_gauss_binary.npy and *_gauss_max_te_lag.npy
#   - subject_session_metadata.csv maps sessions to groups
#
# Notes:
#   - No across-edge FDR; diagonals excluded
#   - Lag bins coded 1..5 are converted to ms by ×50
# ============================

from pathlib import Path
import numpy as np
import pandas as pd
from collections import Counter

# ---- paths ----
BASE       = Path('/lustre/majlepy2/myproject')
RESULTS    = BASE / 'Results'
KSG_DIR, GAUSS_DIR = RESULTS / 'ksg_results', RESULTS / 'gauss_results'
META_CSV   = BASE / 'subject_session_metadata.csv'

OUTBASE    = Path('/home/majlepy2/myproject/Step-wise')
OUTCSV     = OUTBASE / 'Comparison'
OUTCSV.mkdir(parents=True, exist_ok=True)

# ---- groups / thresholds / lag bins (ms) ----
GROUPS = ['healthy', 'PD-off', 'PD-on']
THRS   = [0.70, 0.90]
BINS   = [50, 100, 150, 200, 250]

# ---- read metadata to map sub_ses -> group ----
meta = pd.read_csv(META_CSV)
meta['sub_ses'] = meta['subject'] + '_' + meta['session']
subses_to_group = {key: grp for key, grp in zip(meta['sub_ses'], meta['group'])}

# ---- helper: convert lag bins to ms if needed ----
def to_ms(arr):
    """
    Return a copy of ``arr`` with lag bins (1..5) scaled to ms (x50).

    Scaling happens only when the largest finite value is <= 10. Non-finite entries are left
    alone, and arrays with no finite values are returned unchanged.
    """
    a = arr.copy()
    finite = np.isfinite(a)
    looks_like_bins = finite.any() and np.nanmax(a[finite]) <= 10
    if looks_like_bins:
        a[finite] = 50.0 * a[finite]
    return a

# ---- build Gaussian per-group modal lag matrices from per-session arrays ----
def gaussian_modal_lag_by_group():
    """
    Per-group modal lag (ms) of every directed edge from the Gaussian per-session arrays.

    Each ``sub-*_*_gauss_binary.npy`` in GAUSS_DIR is paired with its
    ``_gauss_max_te_lag.npy``. Sessions are mapped to groups by the first two
    ``_``-separated stem fields. The diagonal of each binary is zeroed and that of each lag
    matrix set to NaN. A lag is collected only when the edge is present, off-diagonal,
    finite and non-zero (after ``to_ms``). The mode is taken with ``Counter.most_common``.

    Returns dict[group] -> (N x N) float array; unobserved edges are NaN. Raises
    RuntimeError if a group has no sessions.
    """
    sessions = {'healthy': [], 'PD-off': [], 'PD-on': []}
    for bin_path in sorted(GAUSS_DIR.glob('sub-*_*_gauss_binary.npy')):
        stem = bin_path.name.replace('_gauss_binary.npy', '')
        group = subses_to_group.get('_'.join(stem.split('_')[:2]))
        if group not in sessions:
            continue
        binary = np.load(bin_path).astype(np.uint8)
        lags = np.load(GAUSS_DIR / f'{stem}_gauss_max_te_lag.npy').astype(float)
        np.fill_diagonal(binary, 0)
        np.fill_diagonal(lags, np.nan)
        sessions[group].append((binary, lags))

    modal_by_group = {}
    for group, pairs in sessions.items():
        if not pairs:
            raise RuntimeError(f'No Gaussian sessions found for group {group}')
        n = pairs[0][0].shape[0]
        seen = {}  # (i, j) -> observed lags in ms, in session order
        for binary, lags in pairs:
            lags_ms = to_ms(lags)
            for i in range(n):
                for j in range(n):
                    lag = lags_ms[i, j]
                    if i != j and binary[i, j] == 1 and np.isfinite(lag) and lag != 0:
                        seen.setdefault((i, j), []).append(int(lag))
        modal = np.full((n, n), np.nan)
        for (i, j), vals in seen.items():
            modal[i, j] = Counter(vals).most_common(1)[0][0]
        modal_by_group[group] = modal
    return modal_by_group

# ---- load group presence (KSG and Gaussian) ----
def _load_presence(root, tag):
    """Load ``<group>_<tag>_edge_presence.npy`` from ``root`` for every group in GROUPS."""
    return {g: np.load(root / f'{g}_{tag}_edge_presence.npy') for g in GROUPS}


K_presence = _load_presence(KSG_DIR, 'ksg')
G_presence = _load_presence(GAUSS_DIR, 'gauss')

# ---- compute Gaussian modal lags per group ----
G_modal = gaussian_modal_lag_by_group()

# ---- for each threshold, intersect consensus mask with Gaussian modal lags and summarize ----
def consensus_lag_summary(thr):
    """
    Lag distribution of consensus edges (robust in BOTH KSG and Gaussian) at ``thr``.

    For each group, edges with presence >= ``thr`` in both presence matrices (diagonal
    excluded) are looked up in the Gaussian modal-lag matrix, and NaN lags are dropped.
    Returns a DataFrame with one row per group containing:
      - ``consensus_total``: the number of consensus edges with a finite modal lag
      - per-bin counts
      - ``prop_50ms``: share at 50 ms, rounded to 3 decimals, 'NA' when empty
      - ``median_ms``: NaN when empty
      - ``width_bins``: number of occupied bins
    """
    def _summarise(g):
        Pk, Pg = K_presence[g], G_presence[g]
        off_diag = ~np.eye(Pk.shape[0], dtype=bool)
        consensus = (Pk >= thr) & (Pg >= thr) & off_diag

        lags = G_modal[g][consensus]
        lags = lags[np.isfinite(lags)].astype(int)
        n = int(lags.size)
        per_bin = {b: int(np.count_nonzero(lags == b)) for b in BINS}
        frac50 = per_bin[50] / n if n > 0 else np.nan

        row = {'group': g, 'threshold': int(thr * 100), 'consensus_total': n}
        row.update({f'count_{b}': per_bin[b] for b in (50, 100, 150, 200, 250)})
        row['prop_50ms'] = round(frac50, 3) if np.isfinite(frac50) else 'NA'
        row['median_ms'] = float(np.median(lags)) if n > 0 else np.nan
        row['width_bins'] = sum(1 for b in BINS if per_bin[b] > 0)
        return row

    return pd.DataFrame([_summarise(g) for g in GROUPS])

# ---- run and save CSVs for 70% and 90% ----
all_df = []
for thr in THRS:
    df = consensus_lag_summary(thr)
    out = OUTCSV / f'consensus_lag_distribution_thr{int(thr*100)}.csv'
    df.to_csv(out, index=False)
    print(df.to_string(index=False))
    print('[csv]', out)
    all_df.append(df)

# all thresholds stacked into one table for downstream use
consensus_lag_df = pd.concat(all_df, ignore_index=True)


  group  threshold  consensus_total  count_50  count_100  count_150  count_200  count_250  prop_50ms  median_ms  width_bins
healthy         70                5         5          0          0          0          0      1.000       50.0           1
 PD-off         70                9         7          2          0          0          0      0.778       50.0           2
  PD-on         70               11        10          0          0          1          0      0.909       50.0           2
[csv] /home/majlepy2/myproject/Step-wise/Comparison/consensus_lag_distribution_thr70.csv
  group  threshold  consensus_total  count_50  count_100  count_150  count_200  count_250 prop_50ms  median_ms  width_bins
healthy         90                0         0          0          0          0          0        NA        NaN           0
 PD-off         90                1         1          0          0          0          0       1.0       50.0           1
  PD-on         90                2         2 