In [9]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import roc_curve, precision_recall_curve, auc
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, confusion_matrix
import joblib
import mne
import antropy as ant
import scipy.stats as sp_stats
import mne_features

### Preprocessing

In [ ]:
# Notebook-wide configuration used by the preprocessing cells below
sr = 1000  # sampling rate, Hz
window_size = 250  # length of one epoch, in samples
mtl_path = 'P%s_mtl_clean.fif'  # recording path template; '%s' is replaced by the subject id
all_subjects = ['01', '02', '03']  # example ids
depth_model = joblib.load(r'depth_model.pkl')  # classifier applied to depth-channel epochs
# Depth-electrode contacts that are scored when present in a recording
depth_channels = [
    'RAH1', 'LAH1', 'RA1', 'LA1', 'LEC1', 'REC1', 'RPHG1', 'LPHG1', 'RMH1', 'LMH1',
    'RAH2', 'LAH2', 'RA2', 'LA2', 'LEC2', 'REC2', 'RPHG2', 'LPHG2', 'RMH2', 'LMH2',
]

In [ ]:
def extract_depth_features(epochs, subj):
    """
    Build the per-epoch feature table consumed by the depth model.

    Args:
        epochs (np.ndarray or List): 2D array-like (n_epochs, n_samples) holding
                                     one signal window per row.
        subj (str): Subject identifier, repeated on every row.

    Returns:
        pd.DataFrame: One row per epoch with the columns 'subj', 'epoch_id',
                      'kurtosis', 'hjorth_mobility', 'hjorth_complexity',
                      'ptp_amp', 'samp_entropy', followed by 12
                      'teager_kaiser_energy_{i}_{mean|std}' columns (i = 0..5).
    """
    data = np.asarray(epochs)
    n_epochs = len(data)
    hj_mobility, hj_complexity = ant.hjorth_params(data, axis=1)

    # Scalar descriptors, each evaluated along the sample axis of every epoch
    per_epoch = pd.DataFrame({
        'subj': np.full(n_epochs, subj),
        'epoch_id': np.arange(n_epochs),
        'kurtosis': sp_stats.kurtosis(data, axis=1),
        'hjorth_mobility': hj_mobility,
        'hjorth_complexity': hj_complexity,
        'ptp_amp': np.ptp(data, axis=1),
        'samp_entropy': np.apply_along_axis(ant.sample_entropy, axis=1, arr=data),
    })

    # Teager-Kaiser energy: the returned values are folded into 12 per epoch,
    # labelled as (mean, std) pairs for indices 0..5.
    # NOTE(review): the fixed width of 12 presumably depends on the epoch length
    # and wavelet used by mne_features - confirm if window_size changes.
    tke_values = mne_features.univariate.compute_teager_kaiser_energy(data)
    tke_columns = [
        f'teager_kaiser_energy_{i}_{stat}'
        for i in range(6) for stat in ('mean', 'std')
    ]
    tke_frame = pd.DataFrame(np.asarray(tke_values).reshape(-1, 12), columns=tke_columns)

    return pd.concat([per_epoch, tke_frame], axis=1)

def raw_chan_to_feat(raw, chan, subj, depth):
    """
    Turn one channel of a recording into a per-epoch feature table.

    The channel is z-scored (NaN-aware), cut into consecutive non-overlapping
    windows of `window_size` samples, and windows containing any NaN are
    dropped. The surviving windows go to the depth or the zEEG feature
    extractor, and three whole-channel statistics plus the channel name are
    then attached to every row.

    Args:
        raw (mne.io.Raw): The MNE raw object.
        chan (str): Name of the channel to process.
        subj (str): Subject identifier.
        depth (bool): True -> extract_depth_features, False -> extract_zeeg_features.

    Returns:
        pd.DataFrame: One row per retained epoch, with the extractor's columns
                      plus 'chan_name', 'chan_ptp', 'chan_skew', 'chan_kurt'.
    """
    # reject_by_annotation='NaN': annotated-bad samples are expected to come
    # back as NaN (see MNE docs), which is what the NaN filtering below relies on.
    signal = raw.copy().pick([chan]).get_data(reject_by_annotation='NaN').flatten()

    # z-score over the valid (non-NaN) samples
    signal_z = (signal - np.nanmean(signal)) / np.nanstd(signal)

    # Window starts stop `sr` samples before the end (the last second is excluded)
    candidate_windows = (
        signal_z[start: start + window_size]
        for start in range(0, len(signal_z) - sr, window_size)
    )
    epochs = [win for win in candidate_windows if not np.isnan(win).any()]

    if depth:
        features = extract_depth_features(epochs, subj)
    else:  # zeeg
        features = extract_zeeg_features(epochs, subj, raw.info['sfreq'])

    # Channel-level descriptors, broadcast to all epochs of this channel
    valid = signal_z[~np.isnan(signal_z)]
    features['chan_name'] = chan
    features['chan_ptp'] = np.ptp(valid)
    features['chan_skew'] = sp_stats.skew(valid)
    features['chan_kurt'] = sp_stats.kurtosis(valid)

    return features

def get_depth_pred(subjects, threshold=0.8, min_channels=2):
    """
    Generates depth model predictions (labels) for a list of subjects.

    For every subject the recording is loaded, each target depth channel that
    exists in the recording is converted to per-epoch features and scored by
    `depth_model`, and an epoch is labelled positive when enough channels
    exceed the probability threshold.

    Args:
        subjects: List of subject IDs.
        threshold (float): Probability threshold for a single channel to be
                           considered "active" in an epoch.
        min_channels (int): The minimum number of active channels required to
                            label an epoch as positive (1).

    Returns:
        dict: Maps subject ID -> np.ndarray of binary labels (one per epoch).

    Raises:
        ValueError: If a recording contains none of the channels listed in
                    `depth_channels` (previously this surfaced as an opaque
                    TypeError on a None accumulator).
    """
    # Loop invariants: the model's expected columns and a set for O(1) membership
    feature_names = depth_model.get_booster().feature_names
    target_channels = set(depth_channels)

    y_all = {}
    for subj in subjects:
        raw = mne.io.read_raw(mtl_path % subj)
        # Channels that exist in *both* the raw file and our target list
        curr_chans = [chan for chan in raw.ch_names if chan in target_channels]
        if not curr_chans:
            raise ValueError(
                f"Subject {subj!r}: none of the target depth channels were found in the recording"
            )

        # Per-epoch count of channels whose probability reaches the threshold
        active_count = None
        for chan in curr_chans:
            curr_feat = raw_chan_to_feat(raw, chan, subj, depth=True)
            proba = depth_model.predict_proba(curr_feat[feature_names])[:, 1]
            active = (proba >= threshold).astype(int)
            active_count = active if active_count is None else active_count + active

        # At least `min_channels` channels must be active for a positive label
        y_all[subj] = (active_count > min_channels - 1).astype(int)

    return y_all

In [ ]:
def extract_zeeg_features(epochs, subj, sr):
    """
    Compute the full zEEG feature table for a set of epochs.

    Feature families, all taken from `mne_features.univariate`:
    - descriptive statistics (ptp, mean, std, variance, skewness, kurtosis,
      0.75 quantile, rms, line length, zero crossings)
    - fractal / nonlinear measures (Higuchi, Katz, Hurst, Hjorth time/spectral)
    - entropies (approximate, sample, spectral, SVD), SVD Fisher information,
      decorrelation time
    - total power, normalized band powers and every ordered band-power ratio
    - band energies and one ratio per unordered band pair
    - spectral slope fit, wavelet coefficient energies, Teager-Kaiser energies

    Args:
        epochs (np.ndarray or List): 2D array-like (n_epochs, n_samples).
        subj (str): Subject identifier written to the 'subj' column.
        sr (int): Sampling rate in Hz.

    Returns:
        pd.DataFrame: One row per epoch; 'subj' and 'epoch_id' come first,
                      followed by one column per feature.
    """
    uni = mne_features.univariate
    data = np.array(epochs)

    def _table(values, names):
        # Fold the extractor output into rows of len(names) values each
        return pd.DataFrame(np.array(values).reshape(-1, len(names)), columns=names)

    # Frequency bands (Hz); the dict order defines the column order
    bands = {
        'theta': (4, 8), 'alpha': (8, 12), 'sigma': (12, 16),
        'beta': (16, 30), 'gamma': (30, 100), 'fast': (100, 300)
    }
    band_names = list(bands)
    welch_kwargs = {'psd_method': 'welch', 'psd_params': None}

    # --- One value per epoch: statistics, complexity, entropy, total power ---
    scalar_features = {
        # basic statistics
        'ptp_amp': uni.compute_ptp_amp(data),
        'mean': uni.compute_mean(data),
        'std': uni.compute_std(data),
        'variance': uni.compute_variance(data),
        'skewness': uni.compute_skewness(data),
        'kurtosis': uni.compute_kurtosis(data),
        'quantile': uni.compute_quantile(data, q=0.75),
        'rms': uni.compute_rms(data),
        'line_length': uni.compute_line_length(data),
        'zero_crossings': uni.compute_zero_crossings(data, threshold=np.finfo(float).eps),
        # fractal & nonlinear
        'higuchi_fd': uni.compute_higuchi_fd(data, kmax=10),
        'katz_fd': uni.compute_katz_fd(data),
        'hurst_exp': uni.compute_hurst_exp(data),
        'hjorth_mobility': uni.compute_hjorth_mobility(data),
        'hjorth_complexity': uni.compute_hjorth_complexity(data),
        'hjorth_mobility_spect': uni.compute_hjorth_mobility_spect(sr, data, normalize=False, **welch_kwargs),
        'hjorth_complexity_spect': uni.compute_hjorth_complexity_spect(sr, data, normalize=False, **welch_kwargs),
        # entropy
        'app_entropy': uni.compute_app_entropy(data, emb=2, metric='chebyshev'),
        'samp_entropy': uni.compute_samp_entropy(data, emb=2, metric='chebyshev'),
        'spect_entropy': uni.compute_spect_entropy(sr, data, **welch_kwargs),
        'svd_entropy': uni.compute_svd_entropy(data, tau=2, emb=10),
        'svd_fisher_info': uni.compute_svd_fisher_info(data, tau=2, emb=10),
        'decorr_time': uni.compute_decorr_time(sr, data),
        # absolute power over the single 0.1-500 Hz band
        'abspow_': uni.compute_pow_freq_bands(sr, data, {'total': (0.1, 500)}, False, psd_method='multitaper'),
    }
    df_basic = pd.DataFrame(scalar_features)

    # --- Spectral slope: four fit outputs per epoch ---
    df_slope = _table(
        uni.compute_spect_slope(sr, data, fmin=0.1, fmax=50, with_intercept=True, psd_method='welch'),
        ['spect_slope_intercept', 'spect_slope_slope', 'spect_slope_MSE', 'spect_slope_R2'],
    )

    # --- Normalized band power, plus the ratio of every ordered band pair ---
    df_pow = _table(
        uni.compute_pow_freq_bands(
            data=data, sfreq=sr, freq_bands=bands, normalize=True,
            ratios=None, psd_method='multitaper', log=False
        ),
        [f'pow_freq_bands_{band}' for band in band_names],
    )
    pow_ratios = pd.DataFrame({
        f'pow_freq_bands_{num}/{den}': df_pow[f'pow_freq_bands_{num}'] / df_pow[f'pow_freq_bands_{den}']
        for num in band_names for den in band_names if num != den
    })
    df_pow = pd.concat([df_pow, pow_ratios], axis=1)

    # --- Band energy, plus one ratio per unordered pair ---
    df_energy = _table(
        uni.compute_energy_freq_bands(sr, data, freq_bands=bands),
        [f'energy_freq_bands_{band}' for band in band_names],
    )
    # Ratio columns are named by the two bands' first letters (e.g. 'ta' =
    # theta/alpha); only earlier-band / later-band is kept for each pair.
    for pos, num in enumerate(band_names):
        for den in band_names[pos + 1:]:
            df_energy[f'energy_freq_bands_{num[0]}{den[0]}'] = (
                df_energy[f'energy_freq_bands_{num}'] / df_energy[f'energy_freq_bands_{den}']
            )

    # --- Wavelet coefficient energy (5 values per epoch) ---
    df_wave = _table(
        uni.compute_wavelet_coef_energy(data, wavelet_name='db4'),
        [f'wavelet_coef_energy_{i}' for i in range(5)],
    )

    # --- Teager-Kaiser energy (12 values per epoch: mean/std for indices 0..5) ---
    df_kaiser = _table(
        uni.compute_teager_kaiser_energy(data),
        [f'teager_kaiser_energy_{i}_{stat}' for i in range(6) for stat in ['mean', 'std']],
    )

    # --- Assemble, then prepend the identifying columns ---
    df = pd.concat([df_basic, df_slope, df_energy, df_kaiser, df_pow, df_wave], axis=1)
    df.insert(0, 'subj', subj)
    df.insert(1, 'epoch_id', np.arange(len(df)))

    return df


def get_all_features_per_chan(chan, subjects):
    """
    Extract zEEG features for one channel from every subject's recording.

    Each subject's raw file is loaded from `mtl_path` and handed to
    raw_chan_to_feat with depth=False.

    Args:
        chan (str): Channel name to extract features for (e.g., "zeeg1").
        subjects (List[str]): Subject IDs to process.

    Returns:
        dict: Maps subject ID -> feature DataFrame returned by raw_chan_to_feat.
    """
    def _features_for(subject_id):
        # Load the subject's recording and featurize the requested channel
        recording = mne.io.read_raw(mtl_path % subject_id)
        return raw_chan_to_feat(recording, chan, subject_id, depth=False)

    return {subject_id: _features_for(subject_id) for subject_id in subjects}

In [ ]:
# Build the zEEG training set: depth-model labels plus the features of both
# zEEG channels, grouped per subject
y_all = get_depth_pred(all_subjects)
zeeg1_all = get_all_features_per_chan('zeeg1', all_subjects)
zeeg2_all = get_all_features_per_chan('zeeg2', all_subjects)
subj_data = {
    subj: {'zeeg1': zeeg1_all[subj], 'zeeg2': zeeg2_all[subj], 'y': y_all[subj]}
    for subj in all_subjects
}

# Persist so the modelling cells can start from this file
joblib.dump(subj_data, 'zeeg_training_data.pkl')


### zEEG model

In [None]:
subj_data = joblib.load('zeeg_training_data.pkl')

# Combine all subjects into one design matrix `x` and one label vector `y`.
# With symmetric=True every epoch is added twice: once as (zeeg1, zeeg2) and
# once with the two channels swapped, both carrying the same label.
symmetric = True
x_parts = []  # per-subject frames, concatenated once after the loop
y_parts = []  # matching label arrays, in the same order as x_parts
for subj in all_subjects:
    zeeg1_subj = subj_data[subj]['zeeg1'].reset_index(drop=True)
    zeeg2_subj = subj_data[subj]['zeeg2'].reset_index(drop=True)
    y_subj = np.asarray(subj_data[subj]['y'])

    # Rows are paired positionally, so all three sources must have the same
    # number of epochs; otherwise labels would silently shift or rows be NaN-padded.
    if not (len(zeeg1_subj) == len(zeeg2_subj) == len(y_subj)):
        raise ValueError(
            f"Subject {subj!r}: epoch count mismatch "
            f"(zeeg1={len(zeeg1_subj)}, zeeg2={len(zeeg2_subj)}, y={len(y_subj)})"
        )

    x_subj = pd.concat([zeeg1_subj, zeeg2_subj], axis=1, ignore_index=True)
    x_subj.columns = [f'zeeg1_{col}' for col in zeeg1_subj.columns] + [f'zeeg2_{col}' for col in zeeg2_subj.columns]
    x_parts.append(x_subj)
    y_parts.append(y_subj)

    if symmetric:
        # Swapped copy: zeeg2 data under the zeeg1_* names and vice versa
        x_sym = pd.concat([zeeg2_subj, zeeg1_subj], axis=1, ignore_index=True)
        x_sym.columns = x_subj.columns
        x_parts.append(x_sym)
        y_parts.append(y_subj)

# A single concatenation avoids re-copying the accumulated data on every iteration
x = pd.concat(x_parts, ignore_index=True) if x_parts else pd.DataFrame()
# float dtype matches what concatenating onto an empty float array produced before
y = np.concatenate(y_parts).astype(float) if y_parts else np.array([])

In [None]:
# Balance the classes within each subject: draw the same number of positive
# and negative epochs (capped at max_samples) per subject.
x['pred'] = y
max_samples = 3000
neg_parts = []  # sampled negatives, one frame per subject
pos_parts = []  # sampled positives, one frame per subject
for subj in all_subjects:
    # Select the subject's rows once, then split by label
    subj_rows = x[x['zeeg1_subj'] == subj]
    subj_pos = subj_rows[subj_rows['pred'] == 1]
    subj_neg = subj_rows[subj_rows['pred'] == 0]

    n_EDs = subj_pos.shape[0]
    sample_count = min(max_samples, n_EDs)
    # NOTE(review): replace=True means the same epoch can be drawn more than
    # once; with a shuffled K-fold afterwards, duplicates may end up in both
    # train and test folds - confirm this is intended.
    # Get negative samples for this subject
    neg_parts.append(subj_neg.sample(sample_count, replace=True, random_state=8))
    # Get positive samples for this subject
    pos_parts.append(subj_pos.sample(sample_count, replace=True, random_state=8))

# One concatenation per class instead of growing a frame inside the loop
sampled_data_0 = pd.concat(neg_parts) if neg_parts else pd.DataFrame()
sampled_data_1 = pd.concat(pos_parts) if pos_parts else pd.DataFrame()

# --- Final Assembly of Balanced Data ---
# Positives first, then negatives; the index is reset to 0..N-1
sampled_data = pd.concat([sampled_data_1, sampled_data_0], ignore_index=True)
x = sampled_data.drop(columns='pred')
y = sampled_data['pred']
sampled_data

In [10]:
# choose features: drop every column whose name contains a metadata token
meta_data = ['subj', 'epoch_id', 'chan_name', 'epoch']
x_feat = x[x.columns[~x.columns.str.contains('|'.join(meta_data))]]
# Positional label array: kf.split yields positions, and x_feat is indexed
# with .iloc below, so the labels must be positional too (works whether y is
# a Series or an ndarray).
y_arr = np.asarray(y)

# Initialize containers (one entry per fold)
metrics = {'accuracy': [], 'precision': [], 'sensitivity': [], 'specificity': [], 'f1': [], 'ROCAUC': [], 'PRAUC': []}
# Store all out-of-fold predictions for aggregate plots
all_y_true = []
all_y_prob = []

# NOTE(review): folds are split at the epoch level. The same subject, and the
# symmetric / resampled copies of the same epoch, can therefore appear in both
# train and test, which may inflate these scores - consider a grouped split.
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=8)
for train_index, test_index in kf.split(x_feat, y_arr):
    model = xgb.XGBClassifier()
    x_train_fold, x_test_fold = x_feat.iloc[train_index], x_feat.iloc[test_index]
    y_train_fold, y_test_fold = y_arr[train_index], y_arr[test_index]

    model.fit(x_train_fold, y_train_fold)
    y_prob = model.predict_proba(x_test_fold)[:, 1]  # probabilities for class 1
    y_pred = (y_prob > 0.5).astype(int)  # thresholding at 0.5
    y_true = y_test_fold
    all_y_true.extend(y_true)
    all_y_prob.extend(y_prob)

    # save metrics
    # labels=[0, 1] keeps the matrix 2x2 even if a fold happens to contain a
    # single class, so the four-way unpacking cannot fail.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    metrics['accuracy'].append(accuracy_score(y_true, y_pred))
    metrics['precision'].append(precision_score(y_true, y_pred))
    metrics['sensitivity'].append(recall_score(y_true, y_pred))
    metrics['f1'].append(f1_score(y_true, y_pred))
    metrics['specificity'].append(tn / (tn + fp))
    metrics['ROCAUC'].append(roc_auc_score(y_true, y_prob))
    metrics['PRAUC'].append(average_precision_score(y_true, y_prob))

# Create results table: one row per fold plus a final 'mean' row
results = pd.DataFrame(metrics)
results.loc['mean'] = results.mean()
results

Fold 1
Fold 2
Fold 3
Fold 4
Fold 5
Positive class ratio: 0.5


Unnamed: 0,accuracy,precision,sensitivity,specificity,f1,ROCAUC,PRAUC
0,0.691008,0.693682,0.684105,0.697911,0.68886,0.765648,0.772923
1,0.693594,0.697521,0.683653,0.703534,0.690517,0.769405,0.775377
2,0.691862,0.697729,0.677026,0.706697,0.687221,0.765583,0.771551
3,0.693092,0.697901,0.680942,0.705241,0.689317,0.767562,0.77614
4,0.687494,0.692928,0.673411,0.701576,0.68303,0.761794,0.767891
mean,0.69141,0.695952,0.679827,0.702992,0.687789,0.765999,0.772776


In [ ]:
# ROC and precision-recall curves over the pooled cross-validation predictions
fpr, tpr, _ = roc_curve(all_y_true, all_y_prob)
precision, recall, _ = precision_recall_curve(all_y_true, all_y_prob)
roc_auc = auc(fpr, tpr)
pr_auc = auc(recall, precision)  # area under the plotted PR curve

fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: ROC with the chance diagonal
ax_roc.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.2f})')
ax_roc.plot([0, 1], [0, 1], 'k--')
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('Receiver Operating Characteristic')
ax_roc.legend()

# Right panel: precision-recall
ax_pr.plot(recall, precision, label=f'PR curve (AUC = {pr_auc:.2f})')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_title('Precision-Recall Curve')
ax_pr.legend()

plt.show()

### Application on non-invasive data

In [ ]:
# Per-cohort feature stores, assumed to have been computed beforehand with the
# zEEG feature extraction above; they are indexed as store[subject][channel].
# Healthy controls
subjects_HC = ['HC1', 'HC2', 'HC3']  # example healthy control ids
hc = joblib.load("v1_HC.pkl")
# Epilepsy patients
subjects_EPI = ['EPI1', 'EPI2', 'EPI3']  # example epilepsy patient ids
epi = joblib.load("v1_EPI.pkl")

In [ ]:
def _threshold_with_refractory(proba, threshold, refractory):
    """Flag samples with proba >= threshold, jumping `refractory` samples ahead after each hit."""
    # Always advance by at least one sample: a zero-length refractory window
    # would otherwise leave the index stuck on the first detection forever.
    step = max(int(refractory), 1)
    detections = np.zeros_like(proba, dtype=int)
    i = 0
    n_samples = len(proba)
    while i < n_samples:
        if proba[i] >= threshold:
            detections[i] = 1
            i += step
        else:
            i += 1
    return detections


def extract_rate_per_bin(model, channel_1, channel_2, threshold, bin_length, sec_limit,
                         sampling_hz=4, groups=None):
    """
    Computes EDs-per-minute for each time bin of each subject in each group.

    For every subject the two channel feature tables are joined side by side,
    reduced to the model's feature columns, cut into consecutive bins, scored
    by the model, thresholded with a refractory window, and converted to a
    detections-per-minute rate.

    Args:
        model: Trained classifier exposing `get_booster().feature_names` and
               `predict_proba`.
        channel_1 (str): Key of the first zEEG channel in the store (e.g., 'E227').
        channel_2 (str): Key of the second zEEG channel in the store (e.g., 'E254').
        threshold (float): Probability threshold for ED detection (0 to 1).
        bin_length (int): Duration of each analysis bin in minutes.
        sec_limit (float): Refractory period in seconds after a detection.
                           Values shorter than one row (including 0) mean
                           "no refractory period".
        sampling_hz (int): Feature rows per second. The default of 4 is the
                           rate this notebook assumed (240 rows per minute).
        groups (dict, optional): Maps group label -> (subject ids, store), where
                           store[subject][channel] is a feature DataFrame.
                           Defaults to the notebook-level cohorts
                           {"HC": (subjects_HC, hc), "EPI": (subjects_EPI, epi)}.

    Returns:
        pd.DataFrame: Columns ["subject", "EDs_per_min", "group"], one row per
                      retained bin; "subject" is "<subject id>_<bin number>".

    Raises:
        ValueError: If `bin_length` does not span at least one feature row.
    """
    if groups is None:
        # Fall back to the cohorts loaded at notebook level
        groups = {"HC": (subjects_HC, hc), "EPI": (subjects_EPI, epi)}

    samples_per_min = sampling_hz * 60
    bin_stride = int(bin_length * samples_per_min)
    if bin_stride < 1:
        raise ValueError("bin_length must span at least one feature row")
    refractory = int(sec_limit * sampling_hz)  # rows skipped after a detection

    # Exact feature order known to the trained model
    feat_names = model.get_booster().feature_names

    rows = []  # rows for the output DataFrame
    for group_label, (subjects, store) in groups.items():
        for subj in subjects:
            # Build the subject's feature matrix from the two zEEG channels
            zeeg1 = store[subj][channel_1]
            zeeg2 = store[subj][channel_2]
            curr_feat = pd.concat([zeeg1, zeeg2], axis=1, ignore_index=True)
            curr_feat.columns = [f"zeeg1_{c}" for c in zeeg1.columns] + [f"zeeg2_{c}" for c in zeeg2.columns]
            X_all = curr_feat[feat_names]

            part = 0  # numbers only the bins that are kept
            for start in range(0, len(X_all), bin_stride):
                X = X_all.iloc[start:start + bin_stride, :]

                # Skip very short tail bins (< 50% of target size) to reduce noise
                if len(X) < (bin_stride // 2):
                    continue

                # Class-1 probabilities, then thresholding with the refractory
                # window. The window restarts at every bin boundary.
                proba = model.predict_proba(X)[:, 1]
                result = _threshold_with_refractory(proba, threshold, refractory)

                # Detections divided by the bin's actual duration in minutes
                spm = result.sum() / (len(result) / samples_per_min)

                rows.append({
                    "subject": f"{subj}_{part}",
                    "EDs_per_min": spm,
                    "group": group_label
                })
                part += 1

    return pd.DataFrame(rows, columns=["subject", "EDs_per_min", "group"])

In [ ]:
# Inference settings
model_im = joblib.load('zeeg_model.pkl')  # trained zEEG model
threshold = 0.7  # probability threshold for ED detection
bin_length = 20  # analysis bin size, minutes
sec_limit = 2  # refractory period, seconds

# Keys of the two channels inside the feature stores
zeeg1_id = 'E227'
zeeg2_id = 'E254'

df = extract_rate_per_bin(
    model=model_im,
    channel_1=zeeg1_id,
    channel_2=zeeg2_id,
    threshold=threshold,
    bin_length=bin_length,
    sec_limit=sec_limit,
)
# Write the per-bin rates to disk as CSV
df.to_csv('zeeg_eds_per_min.csv', index=False)