# Temporal lag structure (Gaussian) — **50%, 70%, 90%** _(complementary)_

**Goal.** Mirror the KSG lag analysis for the **Gaussian** estimator: summarise the **modal TE lag** (50–250 ms in 50 ms steps) for **robust edges** within each group.

**Inputs**
- `Results/gauss_results/sub-XXX_ses-YYY_combined_gauss_binary.npy`
- `Results/gauss_results/sub-XXX_ses-YYY_combined_gauss_max_te_lag.npy`
- `subject_session_metadata.csv`

**Method**
1. For each group and each edge, collect **per-session `max_te_lag`** values **only when the edge is significant** in that session (`binary == 1`).
2. Compute the **modal lag (ms)** per edge **across sessions** in that group (ties are resolved deterministically in favour of the lag value seen first in session order).
3. Using **group presence** (the fraction of the group's sessions in which the edge is significant), define **robust sets** as the edges with presence ≥ **50%**, **70%**, and **90%**.
4. Within each robust set, summarise the **lag repertoire**:
   - counts per bin (50, 100, 150, 200, 250 ms),
   - repertoire width (number of occupied bins),
   - median bin and IQR (in ms).
5. Plot group histograms for **50%**, **70%**, and **90%** (saved as PNGs).

**Key choices**
- **No across-edge FDR** at extraction (`fdr=False`), diagonal excluded.
- EEG downsampled to **20 Hz** (one sample = 50 ms) → 5 lag bins (50–250 ms).
- This Gaussian section is **complementary** (appendix material); the main Results narrative emphasises KSG.

**Outputs (for Appendix)**
- CSVs: `gauss_lag_repertoire_thr50.csv`, `gauss_lag_repertoire_thr70.csv`, `gauss_lag_repertoire_thr90.csv`
- Figures: `lag_histograms_gauss_thr50.png`, `lag_histograms_gauss_thr70.png`, `lag_histograms_gauss_thr90.png`


In [1]:
# ============================
# STEP 3 (Gaussian): Temporal lag structure (max_te_lag)
# Mirrors the KSG Step 3, now for the Gaussian estimator.
# Outputs 50%, 70% and 90% summaries + plots (complementary/appendix).
# ============================
from pathlib import Path
import re
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import Counter

# -------- Paths --------
# Inputs (per-session Gaussian binaries/lags and the metadata CSV) are read
# from under BASE; summary CSVs go to OUTROOT and figures to FIGDIR.
BASE      = Path("/lustre/majlepy2/myproject")
RESULTS   = BASE / "Results" / "gauss_results"
META_CSV  = BASE / "subject_session_metadata.csv"

OUTROOT   = Path("/home/majlepy2/myproject/Step-wise")
FIGDIR    = OUTROOT / "figs"
FIGDIR.mkdir(parents=True, exist_ok=True)  # parents=True also creates OUTROOT

# -------- Metadata: session -> group --------
# Keys are "<subject>_<session>", matching the prefix parsed from result
# filenames later on. Assumes the "subject"/"session" columns are strings
# (string concatenation below) — TODO confirm.
meta = pd.read_csv(META_CSV)
meta["sub_ses"] = meta["subject"] + "_" + meta["session"]
# If a sub_ses key repeats in the CSV, dict() keeps the group of the last row.
subses_to_group = dict(zip(meta["sub_ses"], meta["group"]))

# -------- Group presence: load the cached matrix if present, otherwise compute it --------
def load_or_compute_presence():
    """Return ``{group: presence matrix}`` for "healthy", "PD-off" and "PD-on".

    A cached ``<group>_gauss_edge_presence.npy`` in RESULTS is loaded as-is.
    Otherwise presence is the element-wise mean of the group's per-session
    binary matrices (the fraction of sessions in which each edge is marked 1).
    The diagonal is zeroed, the result is cached to RESULTS, and an [INFO]
    line is printed.

    Raises:
        RuntimeError: if a group has no cache and no binary file maps to it.
    """
    def _session_binaries(group):
        # Visit binaries in sorted filename order. Skip names that don't
        # parse, or whose session belongs to a different group.
        found = []
        for path in sorted(RESULTS.glob("sub-*_*_combined_gauss_binary.npy")):
            parsed = re.match(r"(sub-\d+)_([^_]*)_combined_gauss",
                              path.name.replace("_binary.npy", ""))
            if parsed is None:
                continue
            key = "{}_{}".format(parsed.group(1), parsed.group(2))
            if subses_to_group.get(key) == group:
                found.append(np.load(path).astype(np.uint8))
        return found

    presence_by_group = {}
    for group in ("healthy", "PD-off", "PD-on"):
        cache = RESULTS / f"{group}_gauss_edge_presence.npy"
        if cache.exists():
            presence_by_group[group] = np.load(cache)
            continue

        binaries = _session_binaries(group)
        if len(binaries) == 0:
            raise RuntimeError(f"No binaries found for group '{group}' to compute presence.")
        frac = np.stack(binaries, axis=0).mean(axis=0)
        np.fill_diagonal(frac, 0.0)  # self-edges are never counted as present
        presence_by_group[group] = frac
        np.save(cache, frac)
        print(f"[INFO] Presence computed and saved for {group}: shape={frac.shape}")
    return presence_by_group

# Per-group presence matrices as returned by load_or_compute_presence().
presence = load_or_compute_presence()
P_H, P_OFF, P_ON = presence["healthy"], presence["PD-off"], presence["PD-on"]
# Channel count and diagonal mask are taken from the healthy matrix; the other
# groups are assumed to have the same shape — TODO confirm.
N = P_H.shape[0]
diag = np.eye(N, dtype=bool)

def robust_mask(P, thr):
    """Return a boolean mask of off-diagonal edges whose presence is >= ``thr``.

    Args:
        P: 2-D presence matrix (fraction of sessions per edge).
        thr: minimum presence, e.g. 0.7 for the 70% robust set.

    Returns:
        Boolean array shaped like ``P``; True where ``P >= thr``, always False
        on the diagonal (self-edges).

    The diagonal mask is derived from ``P``'s own shape, not from a global
    computed for one particular group. The function therefore works for any
    matrix size.
    """
    P = np.asarray(P)
    off_diag = ~np.eye(P.shape[0], P.shape[1], dtype=bool)
    return (P >= thr) & off_diag

# -------- Collect per-session binaries and max_te_lag arrays --------
def load_group_sessions():
    """Load ``(binary, lag)`` matrix pairs for every session, grouped by cohort.

    Returns:
        Dict mapping "healthy", "PD-off" and "PD-on" to a list of ``(A, L)``
        tuples in sorted-filename order. ``A`` is the session's binary matrix
        (uint8) and ``L`` its ``max_te_lag`` matrix (float). Files whose names
        don't parse, or whose session isn't mapped to one of these groups, are
        ignored. A per-group session count is printed.

    Raises:
        FileNotFoundError: if a used binary has no ``<stem>_max_te_lag.npy``.
    """
    pattern = re.compile(r"(sub-\d+)_([^_]*)_combined_gauss")
    by_group = {name: [] for name in ("healthy", "PD-off", "PD-on")}
    for binary_file in sorted(RESULTS.glob("sub-*_*_combined_gauss_binary.npy")):
        stem = binary_file.name.replace("_binary.npy", "")  # sub-XXX_ses-YY_combined_gauss
        hit = pattern.match(stem)
        group = subses_to_group.get("_".join(hit.groups())) if hit else None
        if group not in by_group:
            continue
        lag_file = RESULTS / f"{stem}_max_te_lag.npy"
        if not lag_file.exists():
            raise FileNotFoundError(f"Missing max_te_lag for {stem}")
        by_group[group].append(
            (np.load(binary_file).astype(np.uint8), np.load(lag_file).astype(float))
        )
    for name, sessions in by_group.items():
        print(f"{name}: {len(sessions)} sessions")
    return by_group

# Per-group lists of (binary, max_te_lag) matrix pairs; also prints session counts.
group_sessions = load_group_sessions()

# -------- Utility: convert lag coding to ms (1..5 -> 50..250, else pass-through) --------
def to_ms(arr):
    """Return a copy of a lag array expressed in milliseconds.

    If every finite value is <= 10, the array is treated as lag indices in
    samples (20 Hz, i.e. 50 ms per sample), and finite entries are multiplied
    by 50. Otherwise the values are assumed to already be in ms and are
    returned unchanged. Non-finite entries (NaN/inf) are passed through
    untouched, and the input is never modified.

    An input with no finite entries (empty or all-NaN) is returned as a plain
    copy. Previously such input raised from a max-reduction over an empty
    selection.
    """
    a = np.array(arr, copy=True)
    finite = np.isfinite(a)
    # Heuristic: uniformly small values look like bin indices rather than ms.
    if finite.any() and a[finite].max() <= 10:
        a[finite] = a[finite] * 50.0
    return a

# -------- Per-group modal lag per edge, restricted to sessions where the edge is significant --------
def modal_lag_per_edge(session_list):
    """Modal significant lag (ms) for every off-diagonal edge across sessions.

    Args:
        session_list: list of ``(A, L)`` pairs. ``A`` is a binary (N, N)
            matrix and ``L`` the matching lag matrix (converted with ``to_ms``).

    For each edge (i, j) with i != j, a session contributes ``int(lag_ms)``
    only when ``A[i, j] == 1`` and the lag is finite and non-zero. The mode of
    the contributions is stored. Ties go to the value seen first in session
    order, because ``Counter.most_common`` keeps first-encounter order for
    equal counts. N is taken from the first session's binary matrix.

    Returns:
        (N, N) float array of modal lags in ms; NaN where nothing contributed.

    Raises:
        RuntimeError: if ``session_list`` is empty.
    """
    if not session_list:
        raise RuntimeError("Empty session list.")
    n = session_list[0][0].shape[0]
    edges = [(i, j) for i, j in np.ndindex(n, n) if i != j]
    hits = {edge: [] for edge in edges}

    for binary, lags in session_list:
        lags_ms = to_ms(lags)
        significant = binary == 1
        for i, j in edges:
            if significant[i, j] and np.isfinite(lags_ms[i, j]) and lags_ms[i, j] != 0:
                hits[(i, j)].append(int(lags_ms[i, j]))

    modes = np.full((n, n), np.nan, dtype=float)
    for (i, j), seen in hits.items():
        if seen:
            modes[i, j] = Counter(seen).most_common(1)[0][0]
    return modes

# Modal lag matrices per group (ms; NaN where the edge never had a usable significant lag).
L_mode_H_ms   = modal_lag_per_edge(group_sessions["healthy"])
L_mode_OFF_ms = modal_lag_per_edge(group_sessions["PD-off"])
L_mode_ON_ms  = modal_lag_per_edge(group_sessions["PD-on"])

# -------- Summaries for 50%, 70%, 90% --------
# Histogram bins in ms: the lag codes 1..5 mapped by to_ms (x 50 ms).
BINS_MS = np.array([50, 100, 150, 200, 250], dtype=int)

def robust_modal_values(P, L_mode_ms, thr):
    """Return the finite modal lags (cast to int) of robust edges.

    Robust edges are the off-diagonal edges selected by ``robust_mask(P, thr)``.
    Robust edges whose modal lag is NaN are dropped.
    """
    candidate = L_mode_ms[robust_mask(P, thr)]
    return candidate[np.isfinite(candidate)].astype(int)

def counts_width_median_iqr(vals, bins_ms=None):
    """Summarise a set of modal lags (ms) over the fixed lag bins.

    Args:
        vals: 1-D array of modal lags in ms (may be empty).
        bins_ms: bin values to count; defaults to the module-level BINS_MS.

    Returns:
        ``(counts, width, median, iqr_str)``:
        - ``counts`` maps each bin (int) to the number of values exactly equal to it.
        - ``width`` is the number of bins with a non-zero count.
        - ``median`` is the median of ``vals`` (NaN if empty).
        - ``iqr_str`` is "[q25–q75]" from linearly interpolated percentiles
          ("NA" if empty).

    Values outside ``bins_ms`` are not counted in ``counts``, but they do
    enter the median and IQR.
    """
    if bins_ms is None:
        bins_ms = BINS_MS
    vals = np.asarray(vals)
    counts = {int(b): int((vals == b).sum()) for b in bins_ms}
    width = sum(1 for b in bins_ms if counts[int(b)] > 0)
    if vals.size:
        med = float(np.median(vals))
        q25, q75 = np.percentile(vals, [25, 75])
        # ':g' keeps fractional quartiles (e.g. 62.5) that int() used to
        # truncate, while whole numbers still print as "50", not "50.0".
        iqr_str = f"[{q25:g}–{q75:g}]"
    else:
        med, iqr_str = np.nan, "NA"
    return counts, width, med, iqr_str

def summarize_and_plot(thr, tag):
    """Write the lag-repertoire CSV and the histogram figure for one threshold.

    For each group, takes the modal lags of robust edges (presence >= ``thr``,
    via ``robust_modal_values``). It then tabulates per-bin counts, repertoire
    width, median and IQR, and writes them to
    ``OUTROOT/gauss_lag_repertoire_thr<pct>.csv``. One bar chart per group is
    saved to ``FIGDIR/lag_histograms_gauss_thr<pct>.png``.

    Args:
        thr: presence threshold in [0, 1], e.g. 0.7.
        tag: label such as "thr70"; currently unused (file names are derived
            from ``thr``) and kept only for call-site compatibility.
    """
    # round() rather than int(): e.g. 0.57 * 100 == 56.99999999999999, which
    # int() would truncate to 56 in the file names and the axis label.
    pct = int(round(thr * 100))

    groups = [
        ("Healthy", P_H, L_mode_H_ms),
        ("PD-off", P_OFF, L_mode_OFF_ms),
        ("PD-on", P_ON, L_mode_ON_ms),
    ]
    # Each summary is (name, counts, width, median, iqr_str).
    summaries = [
        (name, *counts_width_median_iqr(robust_modal_values(P, L_mode_ms, thr)))
        for name, P, L_mode_ms in groups
    ]

    # CSV summary (n_robust_edges counts only lags that fall in the fixed bins).
    rows = [
        {
            "group": name,
            "n_robust_edges": int(sum(counts.values())),
            "count_50": counts[50], "count_100": counts[100],
            "count_150": counts[150], "count_200": counts[200], "count_250": counts[250],
            "repertoire_width_bins": width,
            "median_lag_ms": med,
            "iqr_lag_ms": iqr,
        }
        for name, counts, width, med, iqr in summaries
    ]
    df = pd.DataFrame(rows)
    csv_out = OUTROOT / f"gauss_lag_repertoire_thr{pct}.csv"
    df.to_csv(csv_out, index=False)
    print("Saved:", csv_out)
    print(df.to_string(index=False))

    # Figure: one bar chart per group on a shared y-axis.
    fig, axes = plt.subplots(1, 3, figsize=(11, 3.5), dpi=200, sharey=True)
    try:
        xpos = range(len(BINS_MS))
        for ax, (name, counts, *_rest) in zip(axes, summaries):
            ax.bar(xpos, [counts[int(b)] for b in BINS_MS])
            ax.set_xticks(xpos)
            ax.set_xticklabels([f"{b}" for b in BINS_MS])
            ax.set_xlabel("Modal lag (ms)")
            ax.set_title(name)
        axes[0].set_ylabel(f"Count of robust edges (≥{pct}%)")
        fig.tight_layout()
        fig_out = FIGDIR / f"lag_histograms_gauss_thr{pct}.png"
        fig.savefig(fig_out, bbox_inches="tight")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
    print("Saved:", fig_out)

# Produce the 50%, 70% and 90% robust-set summaries and figures.
for pct in (50, 70, 90):
    # pct / 100 gives exactly the same floats as the literals 0.50, 0.70, 0.90.
    summarize_and_plot(pct / 100, f"thr{pct}")


healthy: 12 sessions
PD-off: 12 sessions
PD-on: 12 sessions
Saved: /home/majlepy2/myproject/Step-wise/gauss_lag_repertoire_thr50.csv
  group  n_robust_edges  count_50  count_100  count_150  count_200  count_250  repertoire_width_bins  median_lag_ms iqr_lag_ms
Healthy              81        70         10          1          0          0                      3           50.0    [50–50]
 PD-off             105        90         13          1          1          0                      4           50.0    [50–50]
  PD-on             110        93         10          1          3          3                      5           50.0    [50–50]
Saved: /home/majlepy2/myproject/Step-wise/figs/lag_histograms_gauss_thr50.png
Saved: /home/majlepy2/myproject/Step-wise/gauss_lag_repertoire_thr70.csv
  group  n_robust_edges  count_50  count_100  count_150  count_200  count_250  repertoire_width_bins  median_lag_ms iqr_lag_ms
Healthy               6         6          0          0          0          0    