In [None]:
#!/usr/bin/env python3
"""
Pipeline:   EEG → 256 Hz resample → 8-level DWT (db4) → per-band reconstruction → Hilbert phase → sliding windows → PTE/dPTE (Numba-JIT, parallel)
Computes per-subject mean dPTE and raw PTE across windows, returning the structure:
  all_results[band] = { 'dPTE': ..., 'raw_PTE': ..., 'delays': ..., 'channel_names': ... }
and saves NPZ files: dPTE_alz_results.npz, dPTE_ctrl_results.npz, dPTE_ftd_results.npz in OUT_DIR.
"""

import glob
import os
import numpy as np
import mne
import pywt
from scipy.signal import hilbert
from joblib import Parallel, delayed
import numba as nb
import numpy.typing as npt

# ───────────── CONFIGURATION ─────────────────────────────────────────────────
# Wavelet decomposition: depth passed to pywt.wavedec and the mother wavelet.
MAX_LVL = 8
WAVELET = 'db4'

# Band name -> indices into the pywt.wavedec coefficient list
# [cA8, cD8, cD7, ..., cD1]. The Hz ranges in the comments assume the
# 256 Hz rate that process_one resamples every recording to.
band2levels = dict(
    delta=[1, 2, 3],  # cD8, cD7, cD6 -> ~0.5-4 Hz
    theta=[4],        # cD5 -> ~4-8 Hz
    alpha=[5],        # cD4 -> ~8-16 Hz
    beta=[6],         # cD3 -> ~16-32 Hz
    gamma=[7],        # cD2 -> ~32-64 Hz
)

# Sliding-window settings.
WIN_LEN = 10     # window length in seconds
OVERLAP = 0.5    # fraction of overlap between consecutive windows

# Execution and I/O.
N_JOBS = 64      # joblib worker count
DATA_DIR = "/home/s.dharia-ra/Shyamal/EEG_Phase_Project/major-revisions/dataset"
OUT_DIR = "./dPTE_results"

# ───────────── SUBJECT GROUPING ───────────────────────────────────────────────
def get_subject_id(fp: str) -> "int | None":
    """Return the numeric BIDS subject ID encoded in a file path.

    Path components are scanned left to right. The first component that
    starts with ``sub-<digits>`` supplies the ID, e.g. ``sub-007`` or
    ``sub-007_task-rest_eeg.set`` -> 7. Components beginning with ``sub-``
    but not followed by digits are skipped rather than crashing ``int()``.

    Returns ``None`` when no component carries a numeric subject ID.
    """
    import re
    from pathlib import PurePath

    # PurePath splits on the platform's separator(s). On Windows, for
    # example, glob may mix '/' and '\\'.
    for part in PurePath(fp).parts:
        match = re.match(r"sub-(\d+)", part)
        if match:
            return int(match.group(1))
    return None

# glob order depends on the filesystem. Sort so that subjects are processed,
# and stacked into the saved dPTE/raw_PTE arrays, in a reproducible order.
all_paths = sorted(glob.glob(f"{DATA_DIR}/sub-*/eeg/*.set"))
groups = {'alz': [], 'ctrl': [], 'ftd': []}
for fp in all_paths:
    sid = get_subject_id(fp)
    if sid is None:
        continue
    # Group assignment by subject-ID range:
    # <= 36 -> 'alz', 37-65 -> 'ctrl', > 65 -> 'ftd'.
    if sid <= 36:
        groups['alz'].append(fp)
    elif sid <= 65:
        groups['ctrl'].append(fp)
    else:
        groups['ftd'].append(fp)

# ───────────── NUMBA-JIT PTE/dPTE ─────────────────────────────────────────────
@nb.njit(fastmath=True, cache=True)
def _entropy(counts, length):
    """Shannon entropy, in bits, of a histogram of sample counts.

    Each non-empty bin contributes -p * log2(p) with p = count / length.
    Empty bins are skipped, following the 0 * log 0 = 0 convention. The
    callers in this file pass ``length`` equal to the number of binned
    samples.
    """
    ent = 0.0
    for k in range(counts.shape[0]):
        n_k = counts[k]
        if n_k != 0:
            prob = n_k / length
            ent -= prob * np.log2(prob)
    return ent

@nb.njit(fastmath=True, cache=True)
def compute_PTE_numba(phase, delay):
    """Raw phase transfer entropy between every ordered pair of channels.

    Parameters
    ----------
    phase : 2-D integer array (channels x samples) of non-negative phase-bin
        labels. ``np.bincount`` requires non-negative values.
    delay : prediction horizon in samples; must satisfy 1 <= delay < n_samples.

    Returns
    -------
    raw : (m, m) float64 array. ``raw[i, j]`` is the PTE from source channel
        ``i`` (x) to target channel ``j`` (y):

            H(y_{t+d}, y_t) + H(y_t, x_t) - H(y_t) - H(y_{t+d}, y_t, x_t)

        The diagonal evaluates to 0 up to rounding.
    """
    m, n = phase.shape
    if delay < 1 or delay >= n:
        raise ValueError("delay must satisfy 1 <= delay < n_samples")
    L = n - delay
    raw = np.zeros((m, m), np.float64)

    # One alphabet size for all pairs. The joint indices below are injective
    # for any vmax > max label, and entropies ignore empty bins, so the
    # values do not depend on vmax. No minlength is needed for the same
    # reason.
    vmax = int(phase.max()) + 1

    # Target-only terms depend on j alone: compute them once per channel
    # instead of once per (i, j) pair.
    Hy = np.empty(m, np.float64)
    Hypr = np.empty(m, np.float64)
    for j in range(m):
        y = phase[j, :L]
        ypr = phase[j, delay:]
        Hy[j] = _entropy(np.bincount(y), L)
        Hypr[j] = _entropy(np.bincount(ypr + vmax * y), L)

    for i in range(m):
        x = phase[i, :L]
        for j in range(m):
            y = phase[j, :L]
            ypr = phase[j, delay:]
            Hyx = _entropy(np.bincount(y + vmax * x), L)
            Hyprx = _entropy(np.bincount(ypr + vmax * (y + vmax * x)), L)
            raw[i, j] = Hypr[j] + Hyx - Hy[j] - Hyprx
    return raw

@nb.njit(fastmath=True, cache=True)
def dPTE_from_raw(raw):
    """Directed PTE (dPTE) from a raw PTE matrix.

    ``dp[i, j] = raw[i, j] / (raw[i, j] + raw[j, i])``, so for i != j
    ``dp[i, j] + dp[j, i] == 1``. Values > 0.5 mean the information flow
    i -> j dominates j -> i. The diagonal is 0.

    The previous implementation filled the lower triangle with the
    transpose of the ratio matrix. That copied dPTE_ji into [i, j], made
    the result symmetric, and lost the reverse direction. Here both
    triangles are computed explicitly.

    When ``raw[i, j] + raw[j, i] <= 0`` (no measurable transfer either way)
    the pair gets the neutral value 0.5 instead of 0/0 = NaN. Under
    fastmath, NaN-producing arithmetic is unreliable anyway.
    """
    m = raw.shape[0]
    dp = np.zeros((m, m), np.float64)
    for i in range(m):
        for j in range(i + 1, m):
            total = raw[i, j] + raw[j, i]
            if total > 0.0:
                d = raw[i, j] / total
            else:
                d = 0.5
            dp[i, j] = d
            dp[j, i] = 1.0 - d
    return dp

# ───────────── UTILS ─────────────────────────────────────────────────────────
def reconstruct_band_dwt(data: np.ndarray, levels: list[int]) -> np.ndarray:
    """Rebuild the part of ``data`` carried by selected DWT coefficient arrays.

    ``data`` is decomposed along axis 1 with ``pywt.wavedec`` (wavelet
    ``WAVELET``, depth ``MAX_LVL``). That function returns
    ``[cA_MAX_LVL, cD_MAX_LVL, ..., cD1]``. ``levels`` are indices into that
    list. All other coefficient arrays are zeroed before ``pywt.waverec``.

    Returns an array with the same length along axis 1 as ``data``.
    """
    coeffs = pywt.wavedec(data, WAVELET, axis=1, level=MAX_LVL)
    kept = [np.zeros_like(c) for c in coeffs]
    for lv in levels:
        kept[lv] = coeffs[lv]
    rec = pywt.waverec(kept, WAVELET, axis=1)
    # waverec can return one extra trailing sample when the input length is
    # odd. Trim it so the band signal stays sample-aligned with ``data``.
    return rec[:, :data.shape[1]]

def get_delay(phase: npt.NDArray) -> int:
    """Estimate the PTE prediction delay, in samples, from phase sign changes.

    delay = round((n_channels * n_samples) / n_sign_changes), where a sign
    change is a pair of temporally adjacent samples in the same channel
    whose product is negative.

    Raises
    ------
    ValueError
        If ``phase`` contains no sign change, so no delay can be estimated.
    """
    m, n = phase.shape
    n_total = m * n
    # Compare only adjacent samples. np.roll would also pair each row's last
    # sample with its first, adding a spurious wrap-around "crossing".
    n_changes = int(np.count_nonzero(phase[:, 1:] * phase[:, :-1] < 0))
    if n_changes == 0:
        raise ValueError("phase has no sign changes; cannot estimate a delay")
    return int(round(n_total / n_changes))


def get_binsize(phase: npt.NDArray, c: float=3.49) -> float:
    """Histogram bin width from Scott's normal-reference rule.

    width = c * sigma * n ** (-1/3). Here sigma is the per-row sample
    standard deviation (ddof=1) averaged over rows, and n is the number of
    columns (samples).
    """
    _, n_samples = phase.shape
    mean_sd = np.std(phase, axis=1, ddof=1).mean()
    return c * mean_sd * n_samples ** (-1 / 3)

def discretize_phase(phase: npt.NDArray, binsize: float) -> npt.NDArray:
    """Map continuous phase values to integer bin labels.

    Each value becomes ``ceil(value / binsize)``, stored as ``int32``.
    """
    bins = np.ceil(phase / binsize)
    return bins.astype(np.int32)

# ───────────── PROCESS ONE SUBJECT & BAND ─────────────────────────────────────
def process_one(fp: str, levels: list[int]):
    """Compute window-averaged dPTE and raw PTE for one recording in one band.

    Steps:
    1. Load the EEGLAB file with MNE and resample to 256 Hz.
    2. Keep only the DWT coefficient arrays in ``levels``
       (see ``reconstruct_band_dwt``).
    3. Take the Hilbert phase.
    4. Estimate a single delay and bin width from the whole recording, and
       discretize phase + pi.
    5. Evaluate PTE/dPTE on sliding windows of ``WIN_LEN`` seconds with
       fractional overlap ``OVERLAP``.

    Returns
    -------
    tuple
        ``(subject_id, mean_dPTE, mean_raw_PTE, delay, channel_names)``.
        The matrices are (n_channels, n_channels) means over all windows.
        If the recording is shorter than one window, both matrices are
        NaN-filled and a warning is emitted. The returned shapes then still
        stack with other subjects.
    """
    import warnings

    recording = mne.io.read_raw_eeglab(fp, preload=True, verbose='ERROR')
    recording.resample(256)
    fs = recording.info['sfreq']

    data = reconstruct_band_dwt(recording.get_data(), levels)
    phase = np.angle(hilbert(data, axis=1))
    # Delay and bin width come from the full recording and are shared by
    # every window.
    delay = get_delay(phase)
    binsz = get_binsize(phase)
    # np.angle lies in [-pi, pi]. Shifting by pi keeps bin labels
    # non-negative, as np.bincount in compute_PTE_numba requires.
    dph = discretize_phase(phase + np.pi, binsz)

    win = int(WIN_LEN * fs)
    # A zero step (OVERLAP >= 1) would make range() raise.
    step = max(1, int(win * (1 - OVERLAP)))
    n_ch, n_samp = dph.shape

    # Running sums instead of buffering every window's matrices.
    dp_sum = np.zeros((n_ch, n_ch), np.float64)
    raw_sum = np.zeros((n_ch, n_ch), np.float64)
    n_win = 0
    for st in range(0, n_samp - win + 1, step):
        rawP = compute_PTE_numba(dph[:, st:st + win], delay)
        raw_sum += rawP
        dp_sum += dPTE_from_raw(rawP)
        n_win += 1

    sid = get_subject_id(fp)
    if n_win == 0:
        warnings.warn(
            f"{fp}: {n_samp} samples is shorter than one {win}-sample window; "
            "returning NaN matrices."
        )
        nan_mat = np.full((n_ch, n_ch), np.nan)
        return sid, nan_mat, nan_mat.copy(), delay, recording.ch_names
    return sid, dp_sum / n_win, raw_sum / n_win, delay, recording.ch_names

# ───────────── COMPUTE ALL SUBJECTS ───────────────────────────────────────────
def compute_dPTE_and_raw_PTE_all_subjects(paths: list[str]) -> dict:
    """Run ``process_one`` for every path and band, and stack the results.

    Returns
    -------
    dict
        ``{band: {'dPTE', 'raw_PTE', 'delays', 'channel_names', 'subject_ids'}}``.
        ``dPTE``/``raw_PTE`` have shape (n_subjects, n_ch, n_ch), with rows
        in the order of ``paths``. ``subject_ids`` gives the subject of each
        row. ``delays`` maps subject ID -> delay in samples.

    Raises
    ------
    ValueError
        If subjects do not share the same channel names in the same order.
        Stacking their matrices would otherwise silently misalign channels.
    """
    all_results = {}
    for band, levels in band2levels.items():
        print(f"Processing band: {band} ({len(paths)} subs)")
        out = Parallel(n_jobs=N_JOBS)(delayed(process_one)(fp, levels) for fp in paths)
        if not out:
            all_results[band] = {'dPTE': None, 'raw_PTE': None, 'delays': {},
                                 'channel_names': None, 'subject_ids': []}
            continue

        sids, dps, raws, delays, chs = zip(*out)
        ref_chs = list(chs[0])
        for sid, names in zip(sids, chs):
            if list(names) != ref_chs:
                raise ValueError(
                    f"Channel layout of subject {sid} differs from subject "
                    f"{sids[0]}; cannot stack connectivity matrices."
                )

        all_results[band] = {
            'dPTE': np.stack(dps, axis=0),
            'raw_PTE': np.stack(raws, axis=0),
            'delays': dict(zip(sids, delays)),
            'channel_names': chs[0],
            # Row i of dPTE/raw_PTE belongs to subject_ids[i].
            'subject_ids': list(sids),
        }
    return all_results

# ───────────── MAIN & SAVE ──────────────────────────────────────────────────
if __name__ == '__main__':
    # One NPZ per diagnostic group. Each band's result dict is stored under
    # its band name.
    os.makedirs(OUT_DIR, exist_ok=True)
    for group_name, group_paths in groups.items():
        fname = f"dPTE_{group_name}_results.npz"
        band_results = compute_dPTE_and_raw_PTE_all_subjects(group_paths)
        np.savez(f"{OUT_DIR}/{fname}", **band_results)
        print(f"Saved: {fname}")
