# Extract adjacency matrices per session (Gaussian)
From each `*_combined_gauss.pkl` (one per subject/session), extract adjacency in three forms:
- **binary** (edge present/absent),
- **max_p_lag** (lag at max p),
- **max_te_lag** (lag at max TE).  
Save each as `.npy` next to the combined `.pkl`. Set `USE_FDR = True` in the cell below to extract FDR-corrected results instead of the raw ones.


In [None]:
import pickle
from pathlib import Path
import numpy as np

# --- Paths ---
gauss_dir = Path('/lustre/majlepy2/myproject/Results/gauss_results')

# Session-level combined results (.pkl)
combined_files = sorted(gauss_dir.glob("*_combined_gauss.pkl"))
print(f"Found {len(combined_files)} combined session files.")

# Toggle FDR if needed for adjacency extraction (False = raw IDTxl results)
USE_FDR = False

# Representation name -> dtype of the saved array. The name is used both as
# the `weights` argument and as the suffix of the output file name.
ADJ_DTYPES = {'binary': np.uint8, 'max_p_lag': float, 'max_te_lag': float}


def adjacency_as_array(res, weights, fdr):
    """Return ``res.get_adjacency_matrix(weights, fdr=fdr)`` as a 2-D ndarray.

    NOTE(review): newer IDTxl releases appear to return an ``AdjacencyMatrix``
    wrapper object instead of a plain array -- confirm against the installed
    version. If the returned object exposes a weight matrix attribute it is
    used; otherwise the object itself is converted, as before.

    Raises:
        TypeError: if the result cannot be interpreted as a 2-D matrix.
    """
    adj = res.get_adjacency_matrix(weights, fdr=fdr)
    for attr in ('weight_matrix', '_weight_matrix'):
        wrapped = getattr(adj, attr, None)
        if wrapped is not None:
            adj = wrapped
            break
    matrix = np.asarray(adj)
    if matrix.ndim != 2:
        # Fail loudly instead of saving a 0-d object array to disk.
        raise TypeError(
            f"get_adjacency_matrix({weights!r}) did not yield a 2-D matrix "
            f"(got {type(adj).__name__} with shape {matrix.shape})"
        )
    return matrix


# --- Extract adjacency matrices from each combined session ---
for file in combined_files:
    # NOTE: unpickling can execute arbitrary code; only load result files
    # produced by this project.
    with open(file, 'rb') as f:
        res = pickle.load(f)

    # Extract every representation first so nothing is written for a session
    # whose extraction fails part-way.
    matrices = {
        weights: adjacency_as_array(res, weights, USE_FDR).astype(dtype)
        for weights, dtype in ADJ_DTYPES.items()
    }

    # Save each adjacency matrix as .npy next to the .pkl
    base = file.with_suffix('')  # drop .pkl
    for weights, matrix in matrices.items():
        np.save(base.with_name(f"{base.name}_{weights}.npy"), matrix)

    print(f"Saved matrices for {file.name}")

print("All session adjacency matrices extracted and saved as .npy (binary, max_p_lag, max_te_lag).")


# Build group-level summaries (Gaussian)
Use `subject_session_metadata.csv` to map sessions to groups.  
For each group:
- Stack **binary** matrices → fraction of sessions with edge (**edge_presence**).
- Stack **max_p_lag** matrices → **mode lag** where the edge exists.  
Save per-group arrays.


In [None]:
import numpy as np
import pandas as pd
from pathlib import Path
from collections import Counter

# --- Paths ---
gauss_dir = Path('/lustre/majlepy2/myproject/Results/gauss_results')
meta = pd.read_csv('/lustre/majlepy2/myproject/subject_session_metadata.csv')
# Cast to str so numeric subject/session columns still concatenate into a key.
meta['sub_ses'] = meta['subject'].astype(str) + '_' + meta['session'].astype(str)

# Build lookup: sub_ses -> group
group_dict = dict(zip(meta['sub_ses'], meta['group']))

# Find all per-session binary and lag matrices (Gaussian)
BINARY_SUFFIX = '_gauss_binary.npy'
LAG_SUFFIX = '_gauss_max_p_lag.npy'
binary_files = sorted(gauss_dir.glob(f"*{BINARY_SUFFIX}"))
lag_files = sorted(gauss_dir.glob(f"*{LAG_SUFFIX}"))

# Pair files by session key rather than by position in two sorted lists:
# equal counts do not guarantee that the i-th binary and i-th lag file belong
# to the same session.
binary_keys = {p.name[:-len(BINARY_SUFFIX)] for p in binary_files}
lag_keys = {p.name[:-len(LAG_SUFFIX)] for p in lag_files}
unpaired = sorted(binary_keys ^ lag_keys)
if unpaired:
    raise ValueError(
        f"Sessions missing either the binary or the max_p_lag matrix: {unpaired}"
    )

group_to_binary = {}
group_to_lags = {}

# --- Assign each session's matrices to its group ---
for bin_file in binary_files:
    key = bin_file.name[:-len(BINARY_SUFFIX)]
    lag_file = bin_file.with_name(key + LAG_SUFFIX)

    # Remove trailing '_combined' if present to match metadata key format
    sub_ses = key[:-len('_combined')] if key.endswith('_combined') else key
    group = group_dict.get(sub_ses)
    if group is None:
        print(f"WARNING: {sub_ses} not found in metadata.")
        continue

    # Load matrices and append to group collections
    group_to_binary.setdefault(group, []).append(np.load(bin_file))
    group_to_lags.setdefault(group, []).append(np.load(lag_file))

print("Session matrices loaded and grouped by experimental group.")

# --- Compute group-level summaries ---
results = {}
for group in group_to_binary:
    binaries = np.stack(group_to_binary[group], axis=0)  # shape: (n_sessions, n_nodes, n_nodes)
    lags = np.stack(group_to_lags[group], axis=0)        # shape: (n_sessions, n_nodes, n_nodes)
    n_sessions = binaries.shape[0]

    # Edge presence = fraction of sessions with edge present
    edge_presence = binaries.mean(axis=0)

    # Lag summary = most common lag per edge (NaN if edge absent across sessions).
    # Only edges present in at least one session need a mode, so skip the rest.
    present = binaries == 1
    lag_summary = np.full(binaries.shape[1:], np.nan)
    for i, j in np.argwhere(present.any(axis=0)):
        lags_present = lags[:, i, j][present[:, i, j]]
        # Ties resolve to the lag seen first in session order.
        lag_summary[i, j] = Counter(lags_present).most_common(1)[0][0]

    results[group] = {
        "edge_presence": edge_presence,
        "lag_summary": lag_summary,
        "n_sessions": n_sessions,
    }
    print(f"Gaussian {group}: {n_sessions} sessions, edge_presence matrix shape: {edge_presence.shape}")

# --- Save per-group summaries ---
for group, group_results in results.items():
    np.save(gauss_dir / f"{group}_gauss_edge_presence.npy", group_results["edge_presence"])
    np.save(gauss_dir / f"{group}_gauss_lag_summary.npy", group_results["lag_summary"])
    print(f"Saved Gaussian summary arrays for {group}")

print("Gaussian group-level aggregation complete.")


# Visualize group-level heatmaps (Gaussian)
For each group:
- **Heatmap 1**: edge presence fraction (0–1).  
- **Heatmap 2**: most common lag (masked if presence < threshold).  
Figures are saved as PNGs.


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# --- Paths ---
gauss_dir = Path('/lustre/majlepy2/myproject/Results/gauss_results')
out_dir = Path('/home/majlepy2/myproject')
out_dir.mkdir(parents=True, exist_ok=True)

# --- Config ---
groups = ['healthy', 'PD-on', 'PD-off']
est_label = "Gaussian"
presence_threshold = 0.70  # visualization threshold
threshold_pct = int(presence_threshold * 100)

# --- Generate heatmaps per group ---
for group in groups:
    print(f"\n=== {group} ===")
    # These names must match what the aggregation cell writes:
    # "{group}_gauss_edge_presence.npy" and "{group}_gauss_lag_summary.npy".
    edge_file = gauss_dir / f"{group}_gauss_edge_presence.npy"
    lag_file  = gauss_dir / f"{group}_gauss_lag_summary.npy"

    # Skip the group if either summary array is missing; report the first one
    # that is not found.
    missing = next((p for p in (edge_file, lag_file) if not p.exists()), None)
    if missing is not None:
        print(f"Missing: {missing}")
        continue

    edge_presence = np.load(edge_file)
    lag_summary   = np.load(lag_file)

    # Mask lag values where edge presence < threshold (work on a float copy so
    # NaN can be stored).
    lag_display = np.array(lag_summary, dtype=float)
    lag_display[edge_presence < presence_threshold] = np.nan

    # --- Create figure with 2 panels ---
    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    # Panel 1 — Edge presence fraction
    sns.heatmap(
        edge_presence, ax=axs[0], cmap="Blues", vmin=0, vmax=1,
        cbar_kws={'label': 'Fraction of sessions (edge present)'}
    )
    axs[0].set_title(f"{group} – {est_label}: Edge Presence Fraction")

    # Panel 2 — Most common lag (only if presence ≥ threshold)
    sns.heatmap(
        lag_display, ax=axs[1], cmap="viridis",
        cbar_kws={'label': f"Most Common Lag (shown where presence ≥ {threshold_pct}%)"}
    )
    axs[1].set_title(f"{group} – {est_label}: Most Common Lag (edges with ≥{threshold_pct}% presence)")

    for ax in axs:
        ax.set_xlabel("Target Node")
        ax.set_ylabel("Source Node")

    fig.tight_layout()

    # Save to PNG
    out_png = out_dir / f"{group.replace(' ', '_')}_{est_label.lower()}_adj_{threshold_pct}.png"
    fig.savefig(out_png, dpi=200, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated or headless runs do not accumulate open figures.
    plt.close(fig)
    print(f"Saved: {out_png}")
