In [None]:
#PTE code credit to https://github.com/patrk/pyPTE

from typing import Tuple
import numpy as np
import numpy.typing as npt
from scipy.signal import hilbert

def get_delay(phase: npt.NDArray) -> int:
    """Estimate the PTE prediction delay (in samples) from phase sign flips.

    Heuristic from pyPTE: ``delay = round(n_channels * n_samples / n_flips)``,
    where ``n_flips`` counts sign changes between consecutive samples of each
    channel. A wrapped phase in [-pi, pi] changes sign roughly twice per cycle,
    so the delay is roughly half the mean oscillation period.

    Parameters
    ----------
    phase : 2-D array, shape (n_channels, n_samples), of (wrapped) phase values.

    Returns
    -------
    int
        The estimated delay; always >= 1 because there are fewer than
        ``n_channels * n_samples`` consecutive pairs.

    Raises
    ------
    ValueError
        If the phase never changes sign, in which case the delay is undefined.
    """
    m, n = phase.shape
    # Compare each sample only with its predecessor. np.roll would also pair the
    # last sample with the first one and count a spurious wrap-around flip.
    n_flips = int(np.count_nonzero(phase[:, 1:] * phase[:, :-1] < 0))
    if n_flips == 0:
        raise ValueError("phase never changes sign; cannot estimate a delay")
    delay = int(np.round(m * n / n_flips))
    print(f"delay = {delay}")
    return delay

def get_phase(time_series: npt.ArrayLike) -> npt.NDArray:
    """Return the instantaneous phase of each row of ``time_series``.

    The analytic signal is built with ``scipy.signal.hilbert`` along axis 1
    (time), and its complex argument is returned in radians, in [-pi, pi].
    """
    return np.angle(hilbert(time_series, axis=1))

def get_discretized_phase(phase: npt.NDArray, binsize: float) -> npt.NDArray:
    """Map each phase value to an integer bin index ``ceil(phase / binsize)``.

    Returns an int32 array with the same shape as ``phase``.
    """
    bin_index = np.ceil(np.divide(phase, binsize))
    return bin_index.astype(np.int32)

def get_binsize(phase: npt.NDArray, c: float = 3.49) -> float:
    """Return a histogram bin width for ``phase`` using a Scott-style rule.

    width = c * mean over channels of the per-channel sample std (ddof=1)
            * n_samples ** (-1/3)

    ``phase`` must be 2-D with shape (n_channels, n_samples).
    """
    _, n_samples = phase.shape
    channel_std = np.std(phase, axis=1, ddof=1)
    return c * np.mean(channel_std) * n_samples ** (-1.0 / 3)

def get_bincount(binsize: float) -> int:
    """Return how many bins of width ``binsize`` start inside [0, 2*pi)."""
    left_edges = np.arange(0, 2 * np.pi, binsize)
    return left_edges.size

import numpy as np
from scipy.signal import hilbert
from typing import Tuple
import numpy.typing as npt

def _entropy(counts: npt.NDArray, total: int) -> float:
    """Return the Shannon entropy (bits) of a histogram built from ``total`` samples.

    Empty bins are dropped before the log, so no epsilon is needed.
    """
    p = counts[counts > 0] / total
    return float(-np.sum(p * np.log2(p)))


def compute_PTE(phase: npt.NDArray, delay: int) -> npt.NDArray:
    """Compute the raw phase transfer entropy for every ordered channel pair.

    ``PTE[i, j]`` estimates information flow from channel ``i`` (source x) to
    channel ``j`` (target y) as

        H(y', y) + H(y, x) - H(y) - H(y', y, x)

    where y' is y shifted ``delay`` samples into the future. Probabilities are
    plug-in estimates from joint histograms with one bin per integer value.

    Parameters
    ----------
    phase : 2-D integer array, shape (n_channels, n_samples)
        Non-negative discretized phase (bin indices).
    delay : int
        Prediction lag in samples; must satisfy ``1 <= delay < n_samples``.

    Returns
    -------
    (n_channels, n_channels) float array of PTE values in bits. The diagonal is
    zero up to floating-point rounding.

    Raises
    ------
    ValueError
        If ``delay`` is out of range or ``phase`` contains negative values.
    TypeError
        If ``phase`` cannot be safely cast to int64 (e.g. a float array).
    """
    codes = np.asarray(phase).astype(np.int64, casting="safe")
    m, n = codes.shape
    if not 1 <= delay < n:
        raise ValueError(f"delay must satisfy 1 <= delay < {n}, got {delay}")
    if codes.size and codes.min() < 0:
        raise ValueError("phase must contain non-negative bin indices")

    # One bin per integer value. Flattened joint indices such as a*K + b enumerate
    # cells in the same (lexicographic) order as a C-ordered K x K histogram.
    n_bins = int(codes.max()) + 1 if codes.size else 1
    length = n - delay
    future = codes[:, delay:]  # y' for every channel
    past = codes[:, :-delay]   # y (target) or x (source) for every channel

    pte = np.zeros((m, m), dtype=float)
    for j in range(m):
        ypr, y = future[j], past[j]
        ypr_y = ypr * n_bins + y
        # Terms that depend only on the target are computed once per target.
        Hy = _entropy(np.bincount(y), length)
        Hypr_y = _entropy(np.bincount(ypr_y), length)
        for i in range(m):
            x = past[i]
            Hy_x = _entropy(np.bincount(y * n_bins + x), length)
            Hypr_y_x = _entropy(np.bincount(ypr_y * n_bins + x), length)
            pte[i, j] = Hypr_y + Hy_x - Hy - Hypr_y_x

    return pte


def compute_dPTE_rawPTE(
    phase: npt.NDArray, delay: int
) -> Tuple[npt.NDArray, npt.NDArray]:
    """Return ``(dPTE, raw_PTE)`` for a discretized phase array.

    ``raw_PTE`` is the output of ``compute_PTE``. The directed PTE is the
    normalized share of each direction,

        dPTE[i, j] = PTE[i, j] / (PTE[i, j] + PTE[j, i])   for i != j,

    with zeros on the diagonal. Pairs whose two directions sum to zero yield
    NaN (the division warnings are suppressed).
    """
    raw_PTE = compute_PTE(phase, delay)
    # Upper triangle (i < j) of pair_sum holds PTE[i, j] + PTE[j, i].
    pair_sum = np.triu(raw_PTE) + np.tril(raw_PTE).T
    with np.errstate(divide="ignore", invalid="ignore"):
        upper = np.triu(raw_PTE / pair_sum, 1)
        lower = np.tril(raw_PTE / pair_sum.T, -1)
        dPTE = upper + lower
    return dPTE, raw_PTE

def PTE(d_phase, delay):
    """Shorthand for ``compute_dPTE_rawPTE``; returns ``(dPTE, raw_PTE)``."""
    dPTE, raw_PTE = compute_dPTE_rawPTE(d_phase, delay)
    return dPTE, raw_PTE

In [None]:
import mne
import glob
import numpy as np
import os
from scipy.signal import hilbert
# Gather all EEGLAB recordings. glob returns paths in arbitrary (filesystem)
# order, so sort them to make the subject order (and therefore the row order of
# the stacked per-subject result arrays) reproducible across machines and runs.
file_paths = sorted(glob.glob('/teamspace/studios/this_studio/dataset/derivatives/sub-*/eeg/*.set'))

# Function to extract subject ID from path
def get_subject_id(filepath):
    """Return the integer subject number from a BIDS-style ``sub-<N>`` path component.

    The first path component that starts with ``sub-`` followed by decimal
    digits (optionally followed by ``_...``, as in BIDS file names) wins.
    Components that merely contain ``sub-`` elsewhere are ignored rather than
    crashing ``int()``.

    Returns
    -------
    int or None
        The subject number, or None if no component matches.
    """
    # normpath converts '/' to os.sep on Windows so the split below works there too.
    for part in os.path.normpath(filepath).split(os.sep):
        if not part.startswith('sub-'):
            continue
        label = part[len('sub-'):].split('_', 1)[0]
        if label.isdecimal():
            return int(label)
    return None

# Separate file paths by group using the subject number found in the path:
# 1-36 -> alz, 37-65 -> ctrl, 66 and above -> ftd; anything else is ignored.
alz_file_paths, ctrl_file_paths, ftd_file_paths = [], [], []

for fpath in file_paths:
    subj_id = get_subject_id(fpath)
    if subj_id is None or subj_id < 1:
        continue
    if subj_id <= 36:
        alz_file_paths.append(fpath)
    elif subj_id <= 65:
        ctrl_file_paths.append(fpath)
    else:
        ftd_file_paths.append(fpath)

# Frequency bands as (low, high) edges, passed to Raw.filter(l_freq, h_freq)
# in the order listed here.
freq_bands = dict(
    delta=(0.5, 4),
    theta=(4, 8),
    alpha=(8, 12),
    beta=(12, 30),
    gamma=(30, 45),
)

def _band_subject_means(raw, fmin, fmax, win_length, overlap):
    """Return (mean dPTE, mean raw PTE, delay) for one recording in one band.

    Assumes the recording is at least one window long (checked by the caller).
    """
    # Raw.filter modifies the object in place; filter a copy so `raw` stays
    # unfiltered for the next band.
    band_raw = raw.copy().filter(fmin, fmax)
    fs = band_raw.info['sfreq']
    n_win = int(win_length * fs)
    # Guard against a zero step when overlap is close to 1 (range rejects step 0).
    step = max(1, int(n_win * (1 - overlap)))

    data = band_raw.get_data()  # shape: (n_channels, n_samples_total)
    # Phase, delay and bin size are computed once over the whole recording,
    # then the discretized phase is cut into windows.
    phase_matrix = np.angle(hilbert(data, axis=1))
    delay = get_delay(phase_matrix)
    binsize = get_binsize(phase_matrix)
    # np.angle is in [-pi, pi]; shifting by pi keeps the bin indices non-negative.
    d_phase = get_discretized_phase(phase_matrix + np.pi, binsize)

    n_total = data.shape[1]
    dPTE_windows, raw_PTE_windows = [], []
    for start in range(0, n_total - n_win + 1, step):
        dPTE_vals, raw_PTE_vals = PTE(d_phase[:, start:start + n_win], delay)
        dPTE_windows.append(dPTE_vals)
        raw_PTE_windows.append(raw_PTE_vals)

    return np.mean(dPTE_windows, axis=0), np.mean(raw_PTE_windows, axis=0), delay


def compute_dPTE_and_raw_PTE_all_subjects(file_paths, win_length=10.0, overlap=0.5, bands=None):
    """Compute subject-averaged dPTE and raw PTE matrices for each frequency band.

    Each file is read once. For every band, a filtered copy is analysed with
    sliding windows of ``win_length`` seconds overlapping by the fraction
    ``overlap``, and the window results are averaged per subject.

    Parameters
    ----------
    file_paths : iterable of str
        EEGLAB ``.set`` files, read with ``mne.io.read_raw_eeglab``.
    win_length : float
        Window length in seconds; must be > 0.
    overlap : float
        Fractional overlap between consecutive windows, in [0, 1).
    bands : dict or None
        Mapping band name -> (fmin, fmax) passed to ``Raw.filter``. Defaults to
        the module-level ``freq_bands``.

    Returns
    -------
    dict
        band name -> {"dPTE": (n_subjects, n_ch, n_ch) array or None,
        "raw_PTE": same, "delays": {subject_id: delay},
        "channel_names": list of str or None}. Rows follow ``file_paths``
        order. Recordings shorter than one window are skipped with a message.

    Raises
    ------
    ValueError
        If ``win_length``/``overlap`` are invalid, or if recordings have
        different channel lists (their matrices could not be averaged).
    """
    if win_length <= 0:
        raise ValueError(f"win_length must be > 0, got {win_length}")
    if not 0 <= overlap < 1:
        raise ValueError(f"overlap must be in [0, 1), got {overlap}")
    if bands is None:
        bands = freq_bands

    per_band = {name: {"dPTE": [], "raw_PTE": [], "delays": {}} for name in bands}
    channel_names = None

    for file_path in file_paths:
        subj_id = get_subject_id(file_path)
        raw = mne.io.read_raw_eeglab(file_path, preload=True)  # read once, reused for every band

        if raw.n_times < int(win_length * raw.info['sfreq']):
            print(f'Skipping file: {file_path} (shorter than one {win_length} s window)')
            continue

        if channel_names is None:
            channel_names = raw.ch_names
        elif raw.ch_names != channel_names:
            raise ValueError(
                f'Channel list of {file_path} differs from the first recording; '
                'PTE matrices cannot be stacked across subjects.'
            )

        for band_name, (fmin, fmax) in bands.items():
            print(f'Processing file: {file_path}, Band: {band_name}')
            mean_dPTE, mean_raw_PTE, delay = _band_subject_means(
                raw, fmin, fmax, win_length, overlap
            )
            acc = per_band[band_name]
            acc["dPTE"].append(mean_dPTE)
            acc["raw_PTE"].append(mean_raw_PTE)
            acc["delays"][subj_id] = delay

    all_results = {}
    for band_name, acc in per_band.items():
        # Stack to shape (n_subjects, n_channels, n_channels); None if nothing was processed.
        all_results[band_name] = {
            "dPTE": np.stack(acc["dPTE"], axis=0) if acc["dPTE"] else None,
            "raw_PTE": np.stack(acc["raw_PTE"], axis=0) if acc["raw_PTE"] else None,
            "delays": acc["delays"],
            "channel_names": channel_names,
        }

    return all_results


# Compute for each group. Each result maps band name -> dict; np.savez stores
# every band dict as a 0-d object array, so reload a band with
# np.load(path, allow_pickle=True)[band].item().
output_dir = './dPTE_results'
# np.savez does not create missing directories. Create the directory before the
# long computation so a missing folder cannot discard hours of results.
os.makedirs(output_dir, exist_ok=True)

dPTE_alz = compute_dPTE_and_raw_PTE_all_subjects(alz_file_paths)
dPTE_ctrl = compute_dPTE_and_raw_PTE_all_subjects(ctrl_file_paths)
dPTE_ftd = compute_dPTE_and_raw_PTE_all_subjects(ftd_file_paths)

# Save results
np.savez(os.path.join(output_dir, 'dPTE_alz_results.npz'), **dPTE_alz)
np.savez(os.path.join(output_dir, 'dPTE_ctrl_results.npz'), **dPTE_ctrl)
np.savez(os.path.join(output_dir, 'dPTE_ftd_results.npz'), **dPTE_ftd)
