# Extract adjacency matrices per session (KSG)
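From each `*_combined_ksg.pkl` file (one per subject/session), extract the adjacency matrix in three forms:
- binary adjacency (edge present/absent; weight `binary`),
- lag at max p-value (weight `max_p_lag`),
- lag at max TE value (weight `max_te_lag`).

Save each as a `.npy` file next to the combined `.pkl` (suffixes `_binary.npy`, `_max_p_lag.npy`, `_max_te_lag.npy`).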


In [None]:
import pickle
from pathlib import Path
import numpy as np

ksg_dir = Path('/lustre/majlepy2/myproject/Results/ksg_results')
combined_files = sorted(ksg_dir.glob("*_combined_ksg.pkl"))
print(f"Found {len(combined_files)} combined session files.")

USE_FDR = False  # False = raw; True = FDR-corrected

# Adjacency weighting requested from the results object, paired with the
# dtype the saved matrix is cast to (binary -> uint8, lags -> float).
ADJ_KINDS = (
    ('binary', np.uint8),
    ('max_p_lag', float),
    ('max_te_lag', float),
)

for pkl_path in combined_files:
    # NOTE: pickle.load can execute arbitrary code; only run this on trusted
    # result files produced by our own pipeline.
    with open(pkl_path, 'rb') as fh:
        res = pickle.load(fh)

    # Query all three weightings first, then write them out.
    matrices = {
        kind: np.array(res.get_adjacency_matrix(kind, fdr=USE_FDR)).astype(dtype)
        for kind, dtype in ADJ_KINDS
    }

    # "<sub>_<ses>_combined_ksg.pkl" -> "<sub>_<ses>_combined_ksg_<kind>.npy"
    stem = pkl_path.with_suffix('')
    for kind, matrix in matrices.items():
        np.save(stem.with_name(f"{stem.name}_{kind}.npy"), matrix)

    print(f"Saved matrices for {pkl_path.name}")

print("All session adjacency matrices extracted and saved as .npy (binary, max_p_lag, max_te_lag).")


# Build group-level summaries
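Use `subject_session_metadata.csv` to map each session to its group (e.g. Healthy, PD-off, PD-on).
- Stack the binary matrices → compute, per edge, the fraction of sessions in which the edge is present (edge presence).
- Stack the `max_p_lag` matrices → per edge, take the most common lag across the sessions in which that edge is present.

Save the group-level summaries as `{group}_ksg_edge_presence.npy` and `{group}_ksg_lag_summary.npy`.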


In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter

ksg_dir = Path('/lustre/majlepy2/myproject/Results/ksg_results')
meta = pd.read_csv('/lustre/majlepy2/myproject/subject_session_metadata.csv')
# astype(str) so numeric subject/session IDs do not make the concatenation fail.
meta['sub_ses'] = meta['subject'].astype(str) + '_' + meta['session'].astype(str)

# "<subject>_<session>" -> group label.
group_dict = dict(zip(meta['sub_ses'], meta['group']))

BINARY_SUFFIX = '_ksg_binary.npy'
LAG_SUFFIX = '_ksg_max_p_lag.npy'
COMBINED_TAG = '_combined'


def _session_key(prefix):
    """Strip a trailing '_combined' tag to obtain the '<subject>_<session>' key."""
    if prefix.endswith(COMBINED_TAG):
        return prefix[:-len(COMBINED_TAG)]
    return prefix


def _modal_lag_where_present(binaries, lags):
    """Return, per edge, the most common lag over sessions where the edge is present.

    binaries, lags: session matrices stacked along axis 0 (axis 0 = session,
    the remaining two axes = adjacency matrix). Edges that are never present
    stay NaN. Ties are broken by first occurrence in session order, which is
    how Counter.most_common orders equal counts.
    """
    lag_summary = np.full(binaries.shape[1:], np.nan)
    for i, j in np.ndindex(*binaries.shape[1:]):
        lags_present = lags[:, i, j][binaries[:, i, j] == 1]
        if lags_present.size > 0:
            lag_summary[i, j] = Counter(lags_present).most_common(1)[0][0]
    return lag_summary


binary_files = sorted(ksg_dir.glob(f"*{BINARY_SUFFIX}"))
if not binary_files:
    print(f"WARNING: no *{BINARY_SUFFIX} files found in {ksg_dir}.")

group_to_binary, group_to_lags = {}, {}

for bin_file in binary_files:
    prefix = bin_file.name[:-len(BINARY_SUFFIX)]
    sub_ses = _session_key(prefix)
    group = group_dict.get(sub_ses)
    if group is None:
        print(f"WARNING: {sub_ses} not found in metadata.")
        continue

    # Pair by file name instead of zipping two independently sorted globs:
    # a missing or extra file would otherwise silently misalign sessions.
    lag_file = bin_file.with_name(prefix + LAG_SUFFIX)
    if not lag_file.exists():
        raise FileNotFoundError(
            f"Missing lag matrix for {bin_file.name}: expected {lag_file}"
        )

    group_to_binary.setdefault(group, []).append(np.load(bin_file))
    group_to_lags.setdefault(group, []).append(np.load(lag_file))

results = {}
for group, binary_list in group_to_binary.items():
    binaries = np.stack(binary_list, axis=0)
    lags = np.stack(group_to_lags[group], axis=0)
    n_sessions = binaries.shape[0]

    # Fraction of this group's sessions in which each edge is present.
    edge_presence = binaries.mean(axis=0)
    lag_summary = _modal_lag_where_present(binaries, lags)

    results[group] = {
        "edge_presence": edge_presence,
        "lag_summary": lag_summary,
        "n_sessions": n_sessions,
    }
    print(f"KSG {group}: {n_sessions} sessions, edge_presence matrix shape: {edge_presence.shape}")

for group, group_results in results.items():
    np.save(ksg_dir / f"{group}_ksg_edge_presence.npy", group_results["edge_presence"])
    np.save(ksg_dir / f"{group}_ksg_lag_summary.npy", group_results["lag_summary"])
    print(f"Saved KSG summary arrays for {group}")

print("KSG group-level aggregation complete.")


# Visualize group-level heatmaps
For each group:
- Heatmap 1: edge presence fraction (0–1).
- Heatmap 2: most common lag, masked where edge presence is below `presence_threshold`.

Save the figures as PNG.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

ksg_dir = Path('/lustre/majlepy2/myproject/Results/ksg_results')
out_dir = Path('/home/majlepy2/myproject')
out_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): these labels must match the `group` values in
# subject_session_metadata.csv, since the summary filenames are built from them.
groups = ['healthy', 'PD-on', 'PD-off']
est_label = "KSG"
presence_threshold = 0.90
# round() instead of int(): int(0.29 * 100) == 28 because of float error.
threshold_pct = round(presence_threshold * 100)

for group in groups:
    print(f"\n=== {group} ===")
    # Must match the aggregation step's output names:
    # "{group}_ksg_edge_presence.npy" and "{group}_ksg_lag_summary.npy".
    prefix = f"{group}_{est_label.lower()}"
    edge_file = ksg_dir / f"{prefix}_edge_presence.npy"
    lag_file  = ksg_dir / f"{prefix}_lag_summary.npy"

    missing = [p.name for p in (edge_file, lag_file) if not p.exists()]
    if missing:
        print(f"Missing files for {group}: {', '.join(missing)}")
        continue

    edge_presence = np.load(edge_file)
    lag_summary   = np.load(lag_file)

    # Hide lags for edges that are not consistently present across sessions.
    lag_display = np.array(lag_summary, dtype=float)
    lag_display[edge_presence < presence_threshold] = np.nan

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))
    sns.heatmap(edge_presence, ax=axs[0], cmap="Blues", vmin=0, vmax=1,
                cbar_kws={'label': 'Fraction of sessions (edge present)'})
    axs[0].set_title(f"{group} – {est_label}: Edge Presence Fraction")
    axs[0].set_xlabel("Target Node")
    axs[0].set_ylabel("Source Node")

    sns.heatmap(lag_display, ax=axs[1], cmap="viridis",
                cbar_kws={'label': f"Most Common Lag (≥{threshold_pct}% presence)"})
    axs[1].set_title(f"{group} – {est_label}: Most Common Lag")
    axs[1].set_xlabel("Target Node")
    axs[1].set_ylabel("Source Node")

    fig.tight_layout()
    out_png = out_dir / f"{group}_{est_label.lower()}_adj_{threshold_pct}.png"
    fig.savefig(out_png, dpi=200, bbox_inches='tight')
    plt.close(fig)
    print(f"Saved: {out_png}")
