In [107]:
import numpy as np
import scipy.io as sio
from scipy.signal import resample
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from pathlib import Path
from scipy import signal

import os
from gammatone.filters import centre_freqs, make_erb_filters, erb_filterbank

# Directory that holds the per-subject "S<id>_data_preproc.mat" files read by load_subject.
# NOTE(review): absolute, machine-specific path — adjust (or make configurable) before re-running elsewhere.
PREPROC_DIR = "/home/naren-root/Documents/FYP/AAD/Notebooks/Dataset/DATA_preproc/"
# Output directory for figures written by the plot_* helpers; created here if it does not exist.
FIG_DIR = Path("figures"); FIG_DIR.mkdir(parents=True, exist_ok=True)



In [48]:
def _unwrap_1d(x):
    x = np.asarray(x)
    if x.dtype == object:
        for el in x.flat:
            arr = np.asarray(el)
            if arr.size:
                x = arr
                break
        else:
            return np.array([], dtype=float)
    x = np.array(x, dtype=float).squeeze()
    if x.ndim > 1:
        x = x.mean(axis=0)
    return x

def _get_fs(mat):
    return float(np.array(mat["data"].fsample.eeg))  # 64 Hz in your files

def _get_trial_count(mat):
    return int(np.asarray(mat["data"].eeg).shape[0])

def _get_trial_eeg(mat, i):
    arr = np.asarray(mat["data"].eeg)[i]   # (time, n_chan)
    arr = np.array(arr, dtype=float)
    assert arr.ndim == 2
    return arr.T  # -> (n_chan, time)

def _get_trial_labels(mat, i):
    labels = np.asarray(mat["data"].dim.chan.eeg)[i]  # (n_chan,) object
    return [str(x) for x in labels.tolist()]

def _get_trial_env(mat, i, which):
    if not hasattr(mat["data"], which):
        return None
    raw = np.asarray(getattr(mat["data"], which))[i]
    return _unwrap_1d(raw)

def _get_trial_att_events(mat, i):
    """Return (samples, values) from data.event.eeg[i], if present.
       values expected to be 1 or 2 (attended stream id)."""
    ev_all = getattr(mat["data"].event, "eeg", None)
    if ev_all is None:
        return None, None
    ev_i = np.asarray(ev_all)[i]
    if not hasattr(ev_i, "_fieldnames"):
        return None, None
    samples = getattr(ev_i, "sample", None)
    values  = getattr(ev_i, "value", None)
    if samples is None or values is None:
        return None, None
    samples = np.atleast_1d(np.array(samples, dtype=int).squeeze())
    values  = np.atleast_1d(np.array(values, dtype=int).squeeze())
    if samples.size != values.size:
        m = min(samples.size, values.size)
        samples, values = samples[:m], values[:m]
    return samples, values

def _build_att_mask_trial(n_t, samples, values):
    """
    Per-sample attention id for a single trial:
      0 = unknown, 1 = wavA attended, 2 = wavB attended.
    MATLAB 'samples' are 1-based; convert to 0-based.
    """
    att = np.zeros(n_t, dtype=np.uint8)
    if samples is None or values is None or samples.size == 0:
        return att
    samp0 = np.clip(samples - 1, 0, n_t - 1)
    order = np.argsort(samp0)
    samp0, values = samp0[order], values[order]
    for k, s in enumerate(samp0):
        v = values[k] if values[k] in (1, 2) else 0
        e = samp0[k + 1] if k + 1 < samp0.size else n_t
        att[s:e] = v
    return att

In [102]:
def load_subject(preproc_dir, subj_id, drop_exg=True):
    """Load one subject's preprocessed .mat file and concatenate its trials.

    Parameters
    ----------
    preproc_dir : str or Path
        Directory containing ``S<id>_data_preproc.mat`` (the unpadded name is
        tried first, then the two-digit zero-padded one).
    subj_id : int
        Subject number used to build the file name.
    drop_exg : bool
        If True, channels whose label starts with "EXG" (case-insensitive)
        are removed. The mask is derived from the first trial's labels and
        applied to every trial.

    Returns
    -------
    dict with keys
        fs, ch_names, eeg (n_ch, total_samples), envA, envB, attmask,
        lengths (per-trial sample counts), subj_id, has_two_streams, path.
        Trials with a missing/empty envelope contribute zeros of the trial's
        length to envA / envB.

    Raises
    ------
    FileNotFoundError
        If neither candidate file exists.
    """
    preproc_dir = Path(preproc_dir)
    file_names = (
        f"S{subj_id}_data_preproc.mat",
        f"S{subj_id:02d}_data_preproc.mat",
    )
    mat_path = next(
        (preproc_dir / name for name in file_names if (preproc_dir / name).exists()),
        None,
    )
    if mat_path is None:
        raise FileNotFoundError(f"No file for subject {subj_id} in {preproc_dir}")

    mat = sio.loadmat(mat_path, squeeze_me=True, struct_as_record=False)
    fs = _get_fs(mat)
    n_trials = _get_trial_count(mat)

    def _fit_length(env, n_t):
        """Resample a non-empty envelope to n_t samples; map missing/empty to None."""
        if env is not None and env.size and env.size != n_t:
            env = resample(env, n_t)
        return env if env is not None and env.size else None

    eeg_trials, labels_trials = [], []
    envA_trials, envB_trials = [], []
    attmask_trials, lens = [], []

    for trial in range(n_trials):
        trial_eeg = _get_trial_eeg(mat, trial)          # (n_ch, n_t)
        trial_labels = _get_trial_labels(mat, trial)
        raw_a = _get_trial_env(mat, trial, "wavA")
        raw_b = _get_trial_env(mat, trial, "wavB")
        n_t = trial_eeg.shape[1]

        # Envelopes are forced onto the EEG time base of the same trial.
        fitted_a = _fit_length(raw_a, n_t)
        fitted_b = _fit_length(raw_b, n_t)

        ev_samples, ev_values = _get_trial_att_events(mat, trial)

        eeg_trials.append(trial_eeg)
        labels_trials.append(trial_labels)
        envA_trials.append(fitted_a)
        envB_trials.append(fitted_b)
        attmask_trials.append(_build_att_mask_trial(n_t, ev_samples, ev_values))
        lens.append(n_t)

    ch_names = labels_trials[0]
    if drop_exg:
        is_kept = [not label.upper().startswith("EXG") for label in ch_names]
        ch_names = [label for label, flag in zip(ch_names, is_kept) if flag]
        eeg_trials = [trial_eeg[is_kept, :] for trial_eeg in eeg_trials]

    eeg = np.concatenate(eeg_trials, axis=1)
    total = eeg.shape[1]

    def _join_envs(env_list):
        """Concatenate per-trial envelopes, zero-filling trials without one."""
        filled = [
            np.zeros(n_t, dtype=float) if env is None or env.size == 0 else env
            for env, n_t in zip(env_list, lens)
        ]
        return np.concatenate(filled) if filled else None

    envA = _join_envs(envA_trials)
    envB = _join_envs(envB_trials)
    if attmask_trials:
        attmask = np.concatenate(attmask_trials)
    else:
        attmask = np.zeros(total, dtype=np.uint8)

    return {
        "fs": fs,
        "ch_names": ch_names,
        "eeg": eeg,
        "envA": envA,            # Audio 1 (wavA)
        "envB": envB,            # Audio 2 (wavB)
        "attmask": attmask,      # 0/1/2 per-sample
        "lengths": lens,
        "subj_id": subj_id,
        "has_two_streams": (envA is not None and envB is not None and (envA.any() or envB.any())),
        "path": str(mat_path),
    }


In [103]:
def _lp_filter(y, fs, cutoff_hz, order = 4):
    if cutoff_hz is None or cutoff_hz <= 0:
        return y
    nyq = fs / 2.0
    cutoff = min(cutoff_hz, nyq * 0.99)
    b, a = signal.butter(order, cutoff / nyq)
    return signal.filtfilt(b, a, y)

def _resample_to(y, fs_in, fs_out):
    if fs_out is None or fs_out == fs_in:
        return y, fs_in
    # robust poly-phase resampling
    from fractions import Fraction
    frac = Fraction(fs_out, fs_in).limit_denominator(1000)
    up, down = frac.numerator, frac.denominator
    y2 = signal.resample_poly(y, up, down)
    return y2, fs_out


In [104]:

def _normalize_envelope(env, normalize):
    """Scale ``env``: "unit" -> peak magnitude 1, "zscore" -> zero mean / unit std, anything else -> unchanged."""
    if normalize == "unit":
        peak = np.max(np.abs(env))
        if peak > 0:
            env = env / (peak + 1e-12)
    elif normalize == "zscore":
        env = (env - env.mean()) / (env.std() + 1e-12)
    return env


def gammatone_hilbert_envelope(audio, fs, *, num_bands = 32, fmin = 50.0, fmax=None,
                               compress = "pow", compress_exp = 0.6, aggregate = "sum",
                               lowpass_hz = 8.0, target_fs = None, normalize = "unit",
                               return_bands = False):
    """Compute a broadband amplitude envelope of ``audio``.

    Two paths exist:

    * Simple path (``fs < 1000``, i.e. the input already looks like a
      low-rate envelope, or the gammatone functions are unavailable):
      |Hilbert| -> low-pass -> resample -> normalise. No filterbank, no
      compression.
    * Gammatone path: ERB filterbank -> per-band |Hilbert| -> compression
      -> aggregation across bands -> low-pass -> resample -> normalise.

    Parameters
    ----------
    audio : array-like
        Input signal; flattened to 1-D.
    fs : float
        Sampling rate of ``audio`` in Hz.
    num_bands, fmin : passed to ``centre_freqs`` to lay out the filterbank.
    fmax : float or None
        Centre frequencies above this are dropped (so fewer than
        ``num_bands`` bands may remain). None -> ``min(0.45 * fs, fs / 2)``.
    compress : "pow" | "log" | other
        "pow" raises band envelopes to ``compress_exp``; "log" applies
        ``log1p``; any other value leaves them unchanged.
    aggregate : "sum" | "rms" | other
        How bands are combined; any other value means the mean.
    lowpass_hz : float or None
        Cutoff for the smoothing low-pass (None / <= 0 disables it).
    target_fs : float or None
        Output rate; None keeps ``fs``.
    normalize : "unit" | "zscore" | other
        Final scaling; any other value leaves the envelope unscaled.
    return_bands : bool
        If True, ``meta["band_envs"]`` holds the compressed per-band
        envelopes (gammatone path only).

    Returns
    -------
    env : ndarray of float
    meta : dict
        ``fs`` (rate of ``env``), ``cf``, ``band_envs``, ``band_fs``,
        ``aggregate``, ``compress``. ``band_fs`` is the rate of
        ``band_envs``: the band envelopes are neither low-passed nor
        resampled, so they stay at the input rate ``fs`` even when ``env``
        was resampled to ``target_fs``.

    Raises
    ------
    ValueError
        If no centre frequency is <= ``fmax``.
    """
    x = np.asarray(audio, dtype=float).ravel()

    # If `audio` seems to be an envelope already (e.g., 64 Hz), do a simple path
    if fs < 1000 or centre_freqs is None or make_erb_filters is None or erb_filterbank is None:
        env = np.abs(signal.hilbert(x))
        env = _lp_filter(env, fs, lowpass_hz)
        env, fs_env = _resample_to(env, fs, target_fs)
        env = _normalize_envelope(env, normalize)
        meta = {"fs": fs_env, "cf": None, "band_envs": None, "band_fs": None,
                "aggregate": "simple", "compress": "none"}
        return env.astype(float), meta

    # Proper gammatone filterbank
    if fmax is None:
        fmax = min(0.45 * fs, fs / 2.0)
    cf = centre_freqs(fs, num_bands, fmin)
    cf = cf[cf <= fmax]
    if cf.size == 0:
        raise ValueError("No center frequencies within [fmin, fmax].")

    fcoefs = make_erb_filters(fs, cf)
    y = erb_filterbank(x, fcoefs)            # shape: (n_bands, T)
    band_envs = np.abs(signal.hilbert(y, axis=-1))

    # Compression
    if compress == "pow":
        band_envs = np.power(band_envs, float(compress_exp))
    elif compress == "log":
        band_envs = np.log1p(band_envs)

    # Aggregate across bands (a 1-D result means there was only one band)
    if band_envs.ndim == 1:
        agg = band_envs
    elif aggregate == "sum":
        agg = band_envs.sum(axis=0)
    elif aggregate == "rms":
        agg = np.sqrt((band_envs ** 2).mean(axis=0))
    else:  # mean
        agg = band_envs.mean(axis=0)

    # Low-pass smoothing and (optional) resample to target_fs
    agg = _lp_filter(agg, fs, lowpass_hz)
    env, fs_env = _resample_to(agg, fs, target_fs)

    # Final normalization
    env = _normalize_envelope(env, normalize)

    meta = {"fs": fs_env, "cf": cf,
            "band_envs": band_envs if return_bands else None,
            # band_envs were not resampled, so their rate is the input fs.
            "band_fs": fs if return_bands else None,
            "aggregate": aggregate, "compress": compress}
    return env.astype(float), meta


In [106]:


def _unique_path(prefix: str) -> Path:
    """Return a not-yet-existing PNG path inside FIG_DIR.

    Tries ``<prefix>.png`` first, then ``<prefix>_1.png``, ``<prefix>_2.png``,
    ... until a free name is found.
    """
    candidate = FIG_DIR / f"{prefix}.png"
    attempt = 0
    while candidate.exists():
        attempt += 1
        candidate = FIG_DIR / f"{prefix}_{attempt}.png"
    return candidate

def plot_env_overlay(audio: np.ndarray,
                     fs_audio: float,
                     env: np.ndarray,
                     meta: dict,
                     t_start: float = 0.0,
                     t_end: float | None = None,
                     audio_norm: str = "none",   # "none" | "zscore"
                     env_norm: str = "unit",     # "unit" | "zscore" | "none"
                     title: str | None = None,
                     out_path: str | None = None) -> str:
    """Plot a window of ``audio`` with its envelope on top and save it as a PNG.

    The envelope's sampling rate is read from ``meta["fs"]`` (falling back to
    ``fs_audio``). ``t_start`` / ``t_end`` are in seconds; ``t_end=None``
    means the end of whichever signal is longer. Each signal is cropped to the
    window and optionally normalised before plotting.

    Returns the path of the written file. When ``out_path`` is None a unique
    name inside FIG_DIR is generated from the window bounds.
    """
    fs_env = float(meta.get("fs", fs_audio))
    if t_end is None:
        t_end = max(len(audio) / fs_audio, len(env) / fs_env)

    def _window(n_samples, rate):
        """Sample-index bounds of [t_start, t_end] clipped to the signal."""
        lo = max(0, int(np.floor(t_start * rate)))
        hi = min(n_samples, int(np.ceil(t_end * rate)))
        return lo, hi

    a_lo, a_hi = _window(len(audio), fs_audio)
    e_lo, e_hi = _window(len(env), fs_env)

    audio_seg = np.asarray(audio[a_lo:a_hi], float)
    env_seg = np.asarray(env[e_lo:e_hi], float)
    audio_t = np.arange(a_lo, a_hi) / fs_audio
    env_t = np.arange(e_lo, e_hi) / fs_env

    if audio_norm == "zscore":
        audio_seg = (audio_seg - audio_seg.mean()) / (audio_seg.std() + 1e-12)
    if env_norm == "unit":
        peak = np.max(np.abs(env_seg))
        if peak > 0:
            env_seg = env_seg / (peak + 1e-12)
    elif env_norm == "zscore":
        env_seg = (env_seg - env_seg.mean()) / (env_seg.std() + 1e-12)

    fig, ax = plt.subplots(figsize=(12, 4))
    ax.plot(audio_t, audio_seg, lw=0.8, label="audio")
    ax.plot(env_t, env_seg, lw=1.2, label="envelope")
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amplitude")
    ax.set_title(title or "Audio + Envelope")
    ax.legend()
    fig.tight_layout()

    if out_path is None:
        out_path = str(_unique_path(f"env_overlay_{int(t_start*1000)}ms_{int(t_end*1000)}ms"))
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=200)
    # Close explicitly so repeated calls do not accumulate open figures.
    plt.close(fig)
    return out_path

def plot_env_bands(meta: dict,
                   t_start: float = 0.0,
                   t_end: float | None = None,
                   vmin: float | None = None,
                   vmax: float | None = None,
                   title: str | None = None,
                   out_path: str | None = None) -> str:
    band_envs = meta.get("band_envs", None)
    fs_env = float(meta.get("fs", 0))
    cf = meta.get("cf", None)
    if band_envs is None or fs_env <= 0:
        raise ValueError("meta must contain 'band_envs' and 'fs' to plot bands")

    B, T = band_envs.shape
    if t_end is None:
        t_end = T / fs_env
    j0 = max(0, int(np.floor(t_start * fs_env)))
    j1 = min(T, int(np.ceil(t_end * fs_env)))
    be = np.asarray(band_envs[:, j0:j1], float)
    t = np.arange(j0, j1) / fs_env

    plt.figure(figsize=(12, 4))
    extent = [t[0], t[-1] if t.size else t_end, (cf[0] if cf is not None else 1), (cf[-1] if cf is not None else B)]
    plt.imshow(be[::-1, :], aspect="auto", extent=[extent[0], extent[1], extent[2], extent[3]],
               vmin=vmin, vmax=vmax)
    plt.colorbar(label="Envelope")
    plt.xlabel("Time (s)")
    plt.ylabel("Center frequency (Hz)" if cf is not None else "Band index")
    plt.title(title or "Per-band envelopes")
    plt.tight_layout()

    if out_path is None:
        out_path = str(_unique_path(f"env_bands_{int(t_start*1000)}ms_{int(t_end*1000)}ms"))
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(out_path, dpi=200); plt.close()
    return out_path


In [None]:
def _clean_to_64(eeg, ch_names=None):
    X = np.asarray(eeg, float)
    n_ch = X.shape[0]
    if n_ch <= 64:
        return X, (ch_names if ch_names is not None else [f"Ch{i+1}" for i in range(n_ch)]), np.arange(n_ch)

    if ch_names is not None:
        bad_prefixes = ("EXG", "MISC", "AUX", "STATUS", "TRIG")
        bad_contains = ("EOG", "ECG", "EMG")
        keep_idx = []
        for i, nm in enumerate(ch_names):
            nm_u = str(nm).upper()
            if any(nm_u.startswith(p) for p in bad_prefixes): continue
            if any(k in nm_u for k in bad_contains): continue
            keep_idx.append(i)
        if len(keep_idx) >= 64:
            keep_idx = keep_idx[:64]
        else:
            rest = [i for i in range(n_ch) if i not in keep_idx]
            keep_idx = (keep_idx + rest)[:64]
    else:
        keep_idx = list(range(64))

    keep_idx = np.asarray(keep_idx, int)
    X64 = X[keep_idx, :]
    ch64 = [ch_names[i] for i in keep_idx] if ch_names is not None else [f"Ch{i+1}" for i in range(64)]
    return X64, ch64, keep_idx


def subject_eeg_audio(preproc_dir,
                       subj_id,
                       num_bands=64,
                       fmin=50.0,
                       fmax=None,
                       lowpass_hz=8.0,
                       normalize="unit"):
    """
    Load one subject and return ``(eeg_TxC, env_att, fs, att_AB)``.

    * ``eeg_TxC``: EEG as (T, C) after EXG removal and ``_clean_to_64``.
    * ``env_att``: per-sample attended envelope (T,): the wavA envelope where
      the attention mask is 1, the wavB envelope where it is 2, 0.0 elsewhere.
    * ``fs``: sampling rate as float.
    * ``att_AB``: per-sample labels 'A' (mask 1), 'B' (mask 2) or 'U' (other).

    Envelopes are produced by ``gammatone_hilbert_envelope`` at the EEG rate;
    a missing stream is replaced by zeros before that step.
    """
    subject = load_subject(preproc_dir, subj_id, drop_exg=True)

    fs = float(subject["fs"])
    # Only the trimmed data is needed here; names and indices are discarded.
    eeg, _, _ = _clean_to_64(
        np.asarray(subject["eeg"], float),            # (n_ch, T)
        subject.get("ch_names", None),
    )

    n_samples = eeg.shape[1]
    att = np.asarray(subject["attmask"], np.uint8)

    def _as_signal(stream):
        """Stored stream as float array, or zeros when the stream is absent."""
        if stream is None:
            return np.zeros(n_samples, float)
        return np.asarray(stream, float)

    wavA = _as_signal(subject["envA"])
    wavB = _as_signal(subject["envB"])

    def _envelope(wave):
        envelope, _ = gammatone_hilbert_envelope(
            wave, fs,
            num_bands=num_bands, fmin=fmin, fmax=fmax,
            compress="pow", compress_exp=0.6,
            aggregate="sum", lowpass_hz=lowpass_hz,
            target_fs=fs, normalize=normalize, return_bands=False,
        )
        # Force the envelope onto the EEG time base if lengths disagree.
        if envelope.size != n_samples:
            envelope = signal.resample(envelope, n_samples)
        return envelope.astype(float)

    envA = _envelope(wavA)
    envB = _envelope(wavB)

    env_att = np.select([att == 1, att == 2], [envA, envB], default=0.0).astype(float)

    # att is uint8, so a 256-entry lookup table covers every possible value.
    label_lut = np.full(256, "U", dtype="<U1")
    label_lut[1] = "A"
    label_lut[2] = "B"
    att_AB = label_lut[att]

    return eeg.T, env_att, fs, att_AB

In [114]:
# Run the pipeline for subject 1; the cell output is the returned
# (eeg_TxC, env_att, fs, att_AB) tuple.
subject_eeg_audio(PREPROC_DIR, 1, num_bands = 64)

['Fp1', 'AF7', 'AF3', 'F1', 'F3', 'F5', 'F7', 'FT7', 'FC5', 'FC3', 'FC1', 'C1', 'C3', 'C5', 'T7', 'TP7', 'CP5', 'CP3', 'CP1', 'P1', 'P3', 'P5', 'P7', 'P9', 'PO7', 'PO3', 'O1', 'Iz', 'Oz', 'POz', 'Pz', 'CPz', 'Fpz', 'Fp2', 'AF8', 'AF4', 'AFz', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT8', 'FC6', 'FC4', 'FC2', 'FCz', 'Cz', 'C2', 'C4', 'C6', 'T8', 'TP8', 'CP6', 'CP4', 'CP2', 'P2', 'P4', 'P6', 'P8', 'P10', 'PO8', 'PO4', 'O2']


(array([[  9.70947347,  27.700268  ,  28.42683385, ...,  -6.03082454,
         -34.35721669, -33.96078393],
        [  4.49293742,  27.2494236 ,  24.95064359, ...,  -4.60977458,
         -29.91686625, -30.8138056 ],
        [ 10.34741619,  31.41075145,  29.21696963, ...,  -6.62789446,
         -33.16222181, -36.59362892],
        ...,
        [ -9.3691818 ,   6.88691982,   1.0286609 , ...,   0.54570865,
           6.39622434,   6.18335204],
        [-12.39941793,   3.87620143,  -3.27677883, ...,  -0.99939222,
           2.42031908,   2.80503614],
        [-11.26664088,   1.86334054,  -2.47354118, ...,   2.25381339,
           7.16669159,   6.71865936]], shape=(192000, 64)),
 array([0.38223552, 0.26540759, 0.16863371, ..., 0.2134771 , 0.21982443,
        0.23109459], shape=(192000,)),
 64.0,
 array(['B', 'B', 'B', ..., 'A', 'A', 'A'], shape=(192000,), dtype='<U1'))