In [1]:
import os
import pandas as pd
import numpy as np
import sys
from pathlib import Path
import pickle
import warnings
from tqdm_joblib import tqdm_joblib
from joblib import Parallel, delayed
import pyvista as pv
from scipy.ndimage import gaussian_filter
import tkinter as tk
from scipy.interpolate import griddata
from pyvistaqt import BackgroundPlotter

warnings.filterwarnings("ignore")

# external packages
from mne_connectivity import spectral_connectivity_epochs, phase_slope_index

# project paths and helpers
sys.path.insert(0, './lib')
sys.path.insert(0, './utils/')

import utils_io
from lib_data import DATA_IO

  from tqdm.autonotebook import tqdm


In [2]:
# Notebook-wide configuration.
PATH_CURR = os.path.abspath(os.curdir)      # absolute path of the current working directory
PATH = os.fspath(Path(PATH_CURR).parent)    # its parent directory, as a plain string
fs = 2048                                   # sampling frequency handed to the loaders and connectivity calls below

### 1. Load Events

In [None]:
# ECoG tapping events: keep the 'controlateral' entry, split into its 'LID' and 'noLID' parts.
EVENTS_ECOG = utils_io.load_ECoG_events(event_category="tapping", fs=fs)
EVENTS_ECOG_LID, EVENTS_ECOG_noLID = (
    EVENTS_ECOG['controlateral'][condition] for condition in ('LID', 'noLID')
)

# LFP tapping events restricted to the "motor" STN area, split the same way.
EVENTS_LFP = utils_io.load_LFP_events(
    event_category="tapping",
    stn_areas=["motor"],
    fs=fs,
)
EVENTS_LFP_MOTOR_LID, EVENTS_LFP_MOTOR_noLID = (
    EVENTS_LFP['motor']['controlateral'][condition] for condition in ('LID', 'noLID')
)

### 2. Define connectivity computation

In [None]:
def surrogate_iteration(seed, data, sfreq, fmin, fmax):
    """Return one surrogate debiased-wPLI estimate with signal 1 shuffled across epochs.

    A copy of ``data`` is taken, the slice ``[:, 1, :]`` of that copy is shuffled
    in place along its first axis with a generator seeded by ``seed``, and
    ``spectral_connectivity_epochs`` is evaluated on the shuffled copy. The
    input array itself is left untouched.

    Parameters
    ----------
    seed : int
        Seed for ``np.random.default_rng``.
    data : np.ndarray
        Array indexed as ``data[:, 1, :]``; callers in this notebook build it
        as (n_epochs, 2, n_times).
    sfreq : float
        Sampling frequency forwarded to the connectivity estimator.
    fmin, fmax : sequence of float
        Lower / upper band edges forwarded to the connectivity estimator.

    Returns
    -------
    np.ndarray
        Row 2 of the estimator's ``get_data()`` output (presumably the
        signal-1 / signal-0 pair in the raveled layout — confirm against
        mne-connectivity), one value per band because ``faverage=True``.
    """
    surrogate = data.copy()
    # ``surrogate[:, 1, :]`` is a view, so the in-place shuffle modifies ``surrogate``.
    np.random.default_rng(seed).shuffle(surrogate[:, 1, :], axis=0)

    estimator_kwargs = dict(
        method="wpli2_debiased",
        mode="multitaper",
        sfreq=sfreq,
        fmin=fmin,
        fmax=fmax,
        faverage=True,
        mt_adaptive=True,
        mt_low_bias=True,
        verbose=False,
    )
    connectivity = spectral_connectivity_epochs(surrogate, **estimator_kwargs)
    all_pairs = connectivity.get_data()
    return all_pairs[2, :]

def compute_wpli_psi(data, sfreq=2048, n_perm=1000, n_jobs=-1):
    """Compute band-averaged debiased wPLI (with permutation p-values) and PSI.

    Parameters
    ----------
    data : np.ndarray
        Epoched signals; callers in this notebook pass an array built as
        (n_epochs, 2, n_times) with the LFP at index 0 and the ECoG at index 1
        of the second axis.
    sfreq : float, default 2048
        Sampling frequency forwarded to the connectivity estimators.
    n_perm : int, default 1000
        Number of surrogate iterations (seeds ``0 .. n_perm - 1``).
    n_jobs : int, default -1
        Forwarded to ``joblib.Parallel``.

    Returns
    -------
    dict
        ``freqs``    : array of shape (2, n_bands) holding the lower band
                       edges in row 0 and the upper band edges in row 1.
        ``wpli``     : row 2 of the debiased-wPLI ``get_data()`` output.
        ``p_values`` : per-band empirical p-values from the surrogates.
        ``psi``      : row 2 of the phase-slope-index ``get_data()`` output.
    """
    # Band edges: theta, alpha, beta_low, beta_high, gamma, gamma_III.
    # These must be plain lists -- a trailing comma after the literal would
    # silently turn each of them into a 1-tuple containing the list.
    fmin = [4, 8, 12, 20, 60, 80]
    fmax = [8, 12, 20, 35, 90, 90]

    conn_PLI = spectral_connectivity_epochs(
        data,
        method=['wpli2_debiased'],
        mode='multitaper',
        sfreq=sfreq,
        fmin=fmin, fmax=fmax,
        faverage=True,   # average inside each band
        mt_adaptive=True,
        mt_low_bias=True,
        verbose=False,
    )
    PLI_observed = conn_PLI.get_data()[2, :]  # one value per band

    # Surrogate iterations (parallelized); each seed yields one row of band values.
    with tqdm_joblib(desc="Running surrogate iterations", total=n_perm):
        surrogate_rows = Parallel(n_jobs=n_jobs)(
            delayed(surrogate_iteration)(seed, data, sfreq, fmin, fmax)
            for seed in range(n_perm)
        )
    surrogates = np.vstack(surrogate_rows)  # (n_perm, n_bands)

    # Empirical two-sided p-values; the +1 terms keep p strictly positive.
    exceed_count = np.sum(np.abs(surrogates) >= np.abs(PLI_observed), axis=0)
    p_values = (exceed_count + 1) / (n_perm + 1)

    # Phase slope index computed separately over the same bands.
    conn_PSI = phase_slope_index(
        data,
        sfreq=sfreq,
        mode='multitaper',
        fmin=fmin, fmax=fmax,
        mt_adaptive=True,
        mt_low_bias=True,
        verbose=False,
    )

    return dict(
        freqs=np.vstack([fmin, fmax]),
        wpli=PLI_observed,
        p_values=p_values,
        psi=conn_PSI.get_data()[2, :],
    )


### 3. Pairwise wPLI & PSI

#### 3.1 Pre-Event Calculation

In [None]:
preEventResults = []

for (patient, lfp_ch), group_lfp in EVENTS_LFP_MOTOR_LID.groupby(["patient", "LFP_channel"]):

    # loop through each ECoG channel for this patient
    group_ecog_patient = EVENTS_ECOG_LID[EVENTS_ECOG_LID["patient"] == patient]
    for ecog_ch, group_ecog in group_ecog_patient.groupby("ECoG_channel"):
        # Keep only events present in BOTH event lists. The filtered frames get
        # their own names so that `group_lfp` is not narrowed cumulatively from
        # one ECoG channel to the next, and both are ordered by event number so
        # that row i of the LFP stack and row i of the ECoG stack are the same event.
        lfp_matched = (
            group_lfp[group_lfp["event_no"].isin(group_ecog["event_no"])]
            .sort_values("event_no", kind="stable")
        )
        ecog_matched = (
            group_ecog[group_ecog["event_no"].isin(lfp_matched["event_no"])]
            .sort_values("event_no", kind="stable")
        )

        # Skip if either group is empty after filtering
        if lfp_matched.empty or ecog_matched.empty:
            print(f"Skipping patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}: no matching events.")
            continue

        lfp_events = np.stack(lfp_matched["pre_event_recording"].to_numpy())    # one row per event
        ecog_events = np.stack(ecog_matched["pre_event_recording"].to_numpy())  # one row per event

        # ensure same number of epochs (no parentheses: `assert (cond, msg)` is always true)
        assert lfp_events.shape[0] == ecog_events.shape[0], \
            f"Number of epochs do not match for patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}"

        # build data array (n_epochs, n_signals=2, n_times)
        data = np.stack([lfp_events, ecog_events], axis=1)

        print(f"Computing wPLI and PSI pre events for patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}, no tapping events={data.shape[0]}")
        out = compute_wpli_psi(data, sfreq=fs, n_perm=1000)

        preEventResults.append(dict(
            patient=patient,
            event_no=lfp_matched["event_no"].to_numpy(),
            LFP_channel=lfp_ch,
            ECoG_channel=ecog_ch,
            freqs=out['freqs'],
            wpli=out['wpli'],
            p_values=out['p_values'],
            psi=out['psi'],
        ))

# Save results
with open(DATA_IO.path_events + "coherence/STN_MOTOR_vs_Cortex_wPLI_LID.pkl", "wb") as f:
    pickle.dump(preEventResults, f)

In [3]:
# Helper functions 
def select_band_from_list(options):
    """Open a small Tk dialog listing ``options`` and return the confirmed entry.

    The user confirms either by double-clicking an entry or by pressing the
    "Confirm" button. Confirming with nothing selected is ignored and the
    dialog stays open.

    Parameters
    ----------
    options : sequence of str
        Entries shown in the list box, in order.

    Returns
    -------
    str
        The selected entry, or an empty string if the window was closed
        without confirming a selection.
    """
    root = tk.Tk()
    root.title("Choose Frequency Band")

    var = tk.StringVar(master=root)

    # create listbox and insert items manually
    listbox = tk.Listbox(root, height=len(options))
    for opt in options:
        listbox.insert(tk.END, opt)
    listbox.pack(padx=20, pady=10)

    def confirm(event=None):
        picked = listbox.curselection()
        if not picked:
            # Nothing highlighted: `listbox.get(())` would raise, so just wait.
            return
        var.set(listbox.get(picked[0]))
        root.quit()

    listbox.bind("<Double-1>", confirm)
    btn = tk.Button(root, text="Confirm", command=confirm)
    btn.pack(pady=5)

    # Closing via the window manager should only leave the main loop; the
    # window is destroyed below, after the chosen value has been read.
    root.protocol("WM_DELETE_WINDOW", root.quit)

    root.mainloop()
    choice = var.get()
    root.destroy()
    return choice

In [None]:
# --- Step 1: Aggregate wPLI results per (patient, ECoG channel) ---

# NOTE: unpickling executes arbitrary code; only load files written by this notebook.
df_list = pd.read_pickle(DATA_IO.path_events + "coherence/STN_MOTOR_vs_Cortex_wPLI_LID.pkl")
df = pd.DataFrame(df_list)

# Frequency bands, in the same order as the band values stored in each "wpli" entry.
bands = {
    "theta"     : (4, 8),
    "alpha"     : (8, 12),
    "beta_low"  : (12, 20),
    "beta_high" : (20, 35),
    "gamma"     : (60, 90),
    "gamma_III" : (80, 90)
}

band_results = []
for (patient, ecog_ch), group in df.groupby(["patient", "ECoG_channel"]):
    wplis = np.stack(group["wpli"].to_numpy())  # one row per LFP channel, one column per band
    mean_wpli = wplis.mean(axis=0)
    assert len(mean_wpli) == len(bands), \
        f"Expected {len(bands)} band values, got {len(mean_wpli)} for patient {patient}, ECoG {ecog_ch}"
    band_results.append(dict(patient=patient, ECoG_channel=ecog_ch, **dict(zip(bands, mean_wpli))))

band_df = pd.DataFrame(band_results)

# --- Step 2: Load ECoG channel coordinates ---
MNI_coordinates      = pd.read_csv(DATA_IO.path_coordinates + "contact_coordinates.csv")
MNI_ECoG_coordinates = MNI_coordinates[MNI_coordinates.recording_type == "ecog"]


# insert x, y, z coordinates right after ECoG_channel
col_idx = band_df.columns.get_loc("ECoG_channel")
band_df.insert(col_idx + 1, "x", np.nan)
band_df.insert(col_idx + 2, "y", np.nan)
band_df.insert(col_idx + 3, "z", np.nan)

# Merge band_df with coordinates
for idx, row in band_df.iterrows():
    patient = row["patient"]
    bipolar = row["ECoG_channel"]

    # split the bipolar reference ("a-b") into two integer contact numbers
    try:
        c1, c2 = bipolar.split("-")
        c1, c2 = int(c1), int(c2)
    except (AttributeError, TypeError, ValueError):
        print(f"Skipping {bipolar} (invalid format)")
        continue

    # select coordinates for both contacts from this patient
    coords_patient = MNI_ECoG_coordinates[MNI_ECoG_coordinates["patient"] == int(patient)]
    c1_coords = coords_patient[coords_patient["contact"] == c1][["x", "y", "z"]].values
    c2_coords = coords_patient[coords_patient["contact"] == c2][["x", "y", "z"]].values

    if len(c1_coords) == 0 or len(c2_coords) == 0:
        print(f"Coordinates missing for patient {patient}, contacts {c1}-{c2}")
        continue

    # midpoint of the two contacts
    mean_coords = np.mean(np.vstack([c1_coords, c2_coords]), axis=0)

    # assign to band_df
    band_df.at[idx, "x"] = mean_coords[0]
    band_df.at[idx, "y"] = mean_coords[1]
    band_df.at[idx, "z"] = mean_coords[2]

# Pick which band to plot
# selected_band = select_band_from_list(list(bands.keys()))
selected_band = 'beta_high'
assert selected_band in bands, f"Unknown band: {selected_band}"

# Rows skipped above still carry NaN coordinates; they must not reach the
# interpolation or the sphere overlay, so plot only the located channels.
has_coords = band_df[["x", "y", "z"]].notna().all(axis=1)
if not has_coords.all():
    print(f"Excluding {int((~has_coords).sum())} channel(s) without coordinates from the plot")
located_df = band_df[has_coords]

coords = located_df[["x", "y", "z"]].to_numpy()
values = located_df[selected_band].to_numpy()

# --- Step 3: Prepare PyVista cortical plot ---
cortex_mesh  = utils_io.load_cortical_atlas_meshes()
cortex_right = cortex_mesh["right_hemisphere"]
cortex_left  = cortex_mesh["left_hemisphere"]

# Merge both hemispheres into one for interpolation
mesh_combined = cortex_right.merge(cortex_left)
mesh_vertices = mesh_combined.points

# Interpolate electrode values onto mesh vertices (NaN outside the electrodes' convex hull)
values_interp = griddata(
    points=coords,
    values=values,
    xi=mesh_vertices,
    method="linear",
    fill_value=np.nan
)

# Smooth the interpolated map; NaNs are replaced by 0 first.
# NOTE(review): values_interp is a 1-D per-vertex array, so this filter smooths
# along the vertex index, which is only a spatial smoothing if neighbouring
# indices are neighbouring vertices -- confirm this is the intended behaviour.
smoothing_sigma = 200
values_interp = gaussian_filter(np.nan_to_num(values_interp), sigma=smoothing_sigma)

# --- Step 4: Add interpolated values as scalars to mesh ---
mesh_combined["connectivity"] = values_interp

# --- Step 5: Plot ---
plotter = BackgroundPlotter()
plotter.add_mesh(
    mesh_combined,
    scalars="connectivity",
    cmap="viridis",
    opacity=1.0,
    smooth_shading=True,
    scalar_bar_args={"title": f"{selected_band} wPLI"}
)

# Overlay the located electrodes as spheres
for idx, row in located_df[["x", "y", "z"]].iterrows():
    sphere = pv.Sphere(radius=1.5, center=[row.x, row.y, row.z])
    plotter.add_mesh(sphere, color="red", smooth_shading=True)

plotter.background_color = "white"
plotter.view_vector((1, 0, 0))   # lateral view
plotter.add_light(pv.Light(position=(0, 0, 1), color="white", intensity=0.6))
plotter.show()


#### 3.2 During Event Calculation

In [None]:
EventResults = []


def _stack_nan_padded(recordings, n_times):
    """Right-pad each 1-D recording with NaN to ``n_times`` samples and stack them row-wise."""
    rows = [np.asarray(rec, dtype=float) for rec in recordings]
    return np.vstack([
        np.pad(row, (0, n_times - len(row)), constant_values=np.nan) for row in rows
    ])


for (patient, lfp_ch), group_lfp in EVENTS_LFP_MOTOR_LID.groupby(["patient", "LFP_channel"]):

    # loop through each ECoG channel for this patient
    group_ecog_patient = EVENTS_ECOG_LID[EVENTS_ECOG_LID["patient"] == patient]
    for ecog_ch, group_ecog in group_ecog_patient.groupby("ECoG_channel"):
        # Keep only events present in BOTH event lists. The filtered frames get
        # their own names so that `group_lfp` is not narrowed cumulatively from
        # one ECoG channel to the next, and both are ordered by event number so
        # that row i of the LFP stack and row i of the ECoG stack are the same event.
        lfp_matched = (
            group_lfp[group_lfp["event_no"].isin(group_ecog["event_no"])]
            .sort_values("event_no", kind="stable")
        )
        ecog_matched = (
            group_ecog[group_ecog["event_no"].isin(lfp_matched["event_no"])]
            .sort_values("event_no", kind="stable")
        )

        # Skip if either group is empty after filtering
        if lfp_matched.empty or ecog_matched.empty:
            print(f"Skipping patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}: no matching events.")
            continue

        # Event recordings vary in length; pad BOTH signals with NaN to one common
        # length so the two stacks can be combined along a new signal axis.
        # NOTE(review): NaN samples will most likely propagate through the spectral
        # estimates -- confirm that NaN padding is the intended way to equalise lengths.
        lfp_recordings = lfp_matched["event_recording"].to_numpy()
        ecog_recordings = ecog_matched["event_recording"].to_numpy()
        max_len = max(len(rec) for rec in (*lfp_recordings, *ecog_recordings))
        lfp_events = _stack_nan_padded(lfp_recordings, max_len)    # shape (n_events, max_len)
        ecog_events = _stack_nan_padded(ecog_recordings, max_len)  # shape (n_events, max_len)

        # ensure same number of epochs (no parentheses: `assert (cond, msg)` is always true)
        assert lfp_events.shape[0] == ecog_events.shape[0], \
            f"Number of epochs do not match for patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}"

        # build data array (n_epochs, n_signals=2, n_times)
        data = np.stack([lfp_events, ecog_events], axis=1)

        print(f"Computing wPLI and PSI during events for patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}, no tapping events={data.shape[0]}")
        out = compute_wpli_psi(data, sfreq=fs, n_perm=1000)

        EventResults.append(dict(
            patient=patient,
            LFP_channel=lfp_ch,
            ECoG_channel=ecog_ch,
            freqs=out['freqs'],
            wpli=out['wpli'],
            p_values=out['p_values'],
            psi=out['psi'],
        ))

# Save the during-event results computed in this cell
with open(DATA_IO.path_events + "coherence/STN_MOTOR_vs_Cortex_wPLI.pkl", "wb") as f:
    pickle.dump(EventResults, f)

#### 3.3 Post-Event Calculation

In [None]:
postEventResults = []

for (patient, lfp_ch), group_lfp in EVENTS_LFP_MOTOR_LID.groupby(["patient", "LFP_channel"]):

    # loop through each ECoG channel for this patient
    group_ecog_patient = EVENTS_ECOG_LID[EVENTS_ECOG_LID["patient"] == patient]
    for ecog_ch, group_ecog in group_ecog_patient.groupby("ECoG_channel"):
        # Keep only events present in BOTH event lists. The filtered frames get
        # their own names so that `group_lfp` is not narrowed cumulatively from
        # one ECoG channel to the next, and both are ordered by event number so
        # that row i of the LFP stack and row i of the ECoG stack are the same event.
        lfp_matched = (
            group_lfp[group_lfp["event_no"].isin(group_ecog["event_no"])]
            .sort_values("event_no", kind="stable")
        )
        ecog_matched = (
            group_ecog[group_ecog["event_no"].isin(lfp_matched["event_no"])]
            .sort_values("event_no", kind="stable")
        )

        # Skip if either group is empty after filtering
        if lfp_matched.empty or ecog_matched.empty:
            print(f"Skipping patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}: no matching events.")
            continue

        lfp_events = np.stack(lfp_matched["post_event_recording"].to_numpy())    # one row per event
        ecog_events = np.stack(ecog_matched["post_event_recording"].to_numpy())  # one row per event

        # ensure same number of epochs (no parentheses: `assert (cond, msg)` is always true)
        assert lfp_events.shape[0] == ecog_events.shape[0], \
            f"Number of epochs do not match for patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}"

        # build data array (n_epochs, n_signals=2, n_times)
        data = np.stack([lfp_events, ecog_events], axis=1)

        print(f"Computing wPLI and PSI post events for patient {patient}, LFP {lfp_ch}, ECoG {ecog_ch}, no tapping events={data.shape[0]}")
        out = compute_wpli_psi(data, sfreq=fs, n_perm=1000)

        postEventResults.append(dict(
            patient=patient,
            LFP_channel=lfp_ch,
            ECoG_channel=ecog_ch,
            freqs=out['freqs'],
            wpli=out['wpli'],
            p_values=out['p_values'],
            psi=out['psi'],
        ))

# NOTE(review): postEventResults is only kept in memory by this cell; nothing here writes it to disk.

In [None]:
# Show the distinct patient identifiers present in the motor-STN LID event table
patient_column = EVENTS_LFP_MOTOR_LID["patient"]
test = patient_column.unique()
print(test)

### 4. Save

In [None]:
# Write the pre-event results to disk as a pickled DataFrame and report the location.
# NOTE(review): only preEventResults is saved here -- confirm whether the
# during-event and post-event results were meant to be included as well.
results_path = DATA_IO.path_events + "coherence/LFP_ECoG_wPLI_PSI.pkl"
results_df = pd.DataFrame(preEventResults)
results_df.to_pickle(results_path)

print("Saved:", results_path)