In [2]:
from mnelab.io.writers import write_edf
import mne
import pandas as pd
import numpy as np
import antropy as ant
import scipy.stats as sp_stats
from sklearn.model_selection import StratifiedKFold
from lightgbm import LGBMClassifier
import joblib
import mne_features
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from imblearn.under_sampling import RandomUnderSampler
from sklearn.metrics import confusion_matrix

In [1]:
# same function as mne but support nan values (const physical calues)
def write_edf_nan(fname, raw):
    """Export raw to EDF/BDF file (requires pyEDFlib)."""
    import pyedflib
    from pathlib import Path

    suffixes = Path(fname).suffixes
    ext = "".join(suffixes[-1:])
    if ext == ".edf":
        filetype = pyedflib.FILETYPE_EDFPLUS
        dmin, dmax = -32768, 32767
    elif ext == ".bdf":
        filetype = pyedflib.FILETYPE_BDFPLUS
        dmin, dmax = -8388608, 8388607
    data = raw.get_data() * 1e6  # convert to microvolts
    fs = raw.info["sfreq"]
    nchan = raw.info["nchan"]
    ch_names = raw.info["ch_names"]
    if raw.info["meas_date"] is not None:
        meas_date = raw.info["meas_date"]
    else:
        meas_date = None
    prefilter = (f"{raw.info['highpass']}Hz - "
                 f"{raw.info['lowpass']}")
    pmin, pmax = data.min(axis=1), data.max(axis=1)
    f = pyedflib.EdfWriter(fname, nchan, filetype)
    channel_info = []
    data_list = []
    for i in range(nchan):
        channel_info.append(dict(label=ch_names[i],
                                 dimension="uV",
                                 sample_rate=fs,
                                 physical_min=-5000,
                                 physical_max=5000,
                                 digital_min=dmin,
                                 digital_max=dmax,
                                 transducer="",
                                 prefilter=prefilter))
        data_list.append(data[i])
    f.setTechnician("Exported by MNELAB")
    f.setSignalHeaders(channel_info)
    if raw.info["meas_date"] is not None:
        f.setStartdatetime(meas_date)
    # note that currently, only blocks of whole seconds can be written
    f.writeSamples(data_list)
    for annot in raw.annotations:
        f.writeAnnotation(annot["onset"], annot["duration"], annot["description"])

In [3]:
# Load the data
subj= '701'
depth = sorted(['RAH1', 'LAH1', 'RA1', 'LA1', 'LEC1', 'REC1', 'RPHG1', 'LPHG1', 'RMH1', 'LMH1', 'RAH2', 'LAH2', 'RA2', 'LA2', 'LEC2', 'REC2', 'RPHG2', 'LPHG2', 'RMH2', 'LMH2'])
scalp = ['C3', 'C4', 'Cz', 'Pz', 'EOG1', 'EOG2']


In [ ]:
# for one file
raw = mne.io.read_raw(r'D:\Bonn\012fn1\012fn1-2012-07-04-evening.edf', preload=True)

In [ ]:
# for split files
raw1 = mne.io.read_raw(r'D:\Bonn\012fn1\012fn1-2012-07-04-evening.edf', preload=True)
raw2 = mne.io.read_raw(r'D:\Bonn\012fn1\012fn1-2012-07-05-morning.edf', preload=True)
raw = mne.concatenate_raws([raw1, raw2])
raw.info.ch_names

In [None]:
# rename as the convention
new_names = {}
for ch in raw.info.ch_names:
    prefix, suffix = ch.split(' ')
    suffix = suffix.replace('PHC', 'PHG')
    if prefix == 'SEEG':
        new_names[ch] = suffix[-2] + suffix[:-2] + suffix[-1]
    else:
        new_names[ch] = suffix
new_names

In [None]:
raw.rename_channels(new_names)
write_edf(r'D:\Bonn\012fn1\P701_overnightData.edf', raw)
raw.save(r'D:\Bonn\012fn1\P701_overnightData.fif', overwrite=True)

In [None]:
# filters and resample
curr_scalp = [x for x in scalp if x in raw.ch_names]
curr_depth = [x for x in depth if x in raw.ch_names]
raw.pick(curr_depth + curr_scalp)
raw.load_data()
raw.filter(l_freq=0.1, h_freq=500, picks=depth, phase='zero-double')
# TODO: 50 or 60?
raw.notch_filter((60, 120, 180, 240), method='spectrum_fit', phase='zero-double')
raw.filter(l_freq=0.1, h_freq=40, picks=scalp, phase='zero-double')
raw.resample(1000)
write_edf(rf'D:\Bonn\012fn1\P{subj}_mtl_filtered.edf', raw)
raw.save(rf'D:\Bonn\012fn1\P{subj}_mtl_filtered.fif', overwrite=True)

In [4]:
# Cleaning
raw = mne.io.read_raw_fif(rf'D:\Bonn\012fn1\P{subj}_mtl_filtered.fif')
raw.set_channel_types({x: 'eog' for x in ['EOG1', 'EOG2']})
raw.set_channel_types({x: 'seeg' for x in depth if x in raw.ch_names})
raw.plot(duration=60 * 5, scalings=dict(eeg=5e-4, eog=4e-4, seeg=5e-4))

Opening raw data file D:\Bonn\012fn1\P701_mtl_filtered.fif...


  raw = mne.io.read_raw_fif(rf'D:\Bonn\012fn1\P{subj}_mtl_filtered.fif')


Isotrak not found
    Range : 0 ... 21459999 =      0.000 ... 21459.999 secs
Ready.
Opening raw data file D:\Bonn\012fn1\P701_mtl_filtered-1.fif...
Isotrak not found
    Range : 21460000 ... 42919999 =  21460.000 ... 42919.999 secs
Ready.
Opening raw data file D:\Bonn\012fn1\P701_mtl_filtered-2.fif...
Isotrak not found
    Range : 42920000 ... 52973999 =  42920.000 ... 52973.999 secs
Ready.
Using qt as 2D backend.


<mne_qt_browser._pg_figure.MNEQtBrowser(0x2d076ccaa20) at 0x000002D00246F7C0>

In [None]:
# save the fif with annotations
raw.drop_channels(raw.info['bads'], on_missing='ignore')
raw.save(f'D:\\Bonn\\P{subj}_mtl_clean.fif', overwrite=True)

In [8]:
# save the edf with annotations
clean_raw = mne.io.read_raw_fif(f'D:\\Bonn\\012fn1\\P{subj}_mtl_clean.fif') # in case of re-run (memory issues)
nan_clean = clean_raw.get_data(reject_by_annotation='NaN')
final = mne.io.RawArray(nan_clean, clean_raw.info)
write_edf_nan(f'D:\\Bonn\\P{subj}_mtl_clean.edf', final)

Opening raw data file D:\Bonn\012fn1\P701_mtl_clean.fif...


  clean_raw = mne.io.read_raw_fif(f'D:\\Bonn\\012fn1\\P{subj}_mtl_clean.fif') # in case of re-run (memory issues)


Isotrak not found
    Range : 0 ... 24385999 =      0.000 ... 24385.999 secs
Ready.
Opening raw data file D:\Bonn\012fn1\P701_mtl_clean-1.fif...
Isotrak not found
    Range : 24386000 ... 48771999 =  24386.000 ... 48771.999 secs
Ready.
Opening raw data file D:\Bonn\012fn1\P701_mtl_clean-2.fif...
Isotrak not found
    Range : 48772000 ... 52973999 =  48772.000 ... 52973.999 secs
Ready.
Setting 1485569 of 52974000 (2.80%) samples to NaN, retaining 51488431 (97.20%) samples.
Creating RawArray with float64 data, n_channels=22, n_times=52974000
    Range : 0 ... 52973999 =      0.000 ... 52973.999 secs
Ready.


In [9]:
# extract features from the epochs of a single channel
def extract_epochs_features(epochs, subj):
    mobility, complexity = ant.hjorth_params(epochs, axis=1)
    feat = {
        'subj': np.full(len(epochs), subj),
        'epoch_id': np.arange(len(epochs)),
        'kurtosis': sp_stats.kurtosis(epochs, axis=1),
        'hjorth_mobility': mobility,
        'hjorth_complexity': complexity,
        'ptp_amp': np.ptp(epochs, axis=1),
        'samp_entropy': np.apply_along_axis(ant.sample_entropy, axis=1, arr=epochs)
    }

    kaiser = mne_features.univariate.compute_teager_kaiser_energy(np.array(epochs))

    # Reshape the list into a 2D array with 12 columns (each row will have 12 values)
    reshaped_list = np.array(kaiser).reshape(-1, 12)

    # Create the DataFrame
    X_new = pd.DataFrame(reshaped_list)
    # rename columns
    X_new.columns = [
        f'teager_kaiser_energy_{i}_mean' if j % 2 == 0 else f'teager_kaiser_energy_{i}_std'
        for i in range(6) for j in range(2)
    ]

    # Convert to dataframe
    feat = pd.DataFrame(feat)
    feat = pd.concat([feat, X_new], axis=1)

    return feat

# get features and labels of a single subject (all channels)
def get_subj_data(raw, chan):
    window_size = 250  # ms
    epochs = []
    chan_raw = raw.copy().pick([chan]).get_data().flatten()
    # normalize chan
    chan_norm = (chan_raw - np.nanmean(chan_raw)) / np.nanstd(chan_raw)
    # run on all 250ms epochs excluding the last 1s
    for i in range(0, len(chan_norm) - 4 * window_size, window_size):
        if not np.isnan(chan_norm[i: i + window_size]).any():
            epochs.append(chan_norm[i: i + window_size])

    # add epoch-level features
    curr_feat = extract_epochs_features(epochs, subj, raw.info['sfreq'])
    # add channel-level features
    chan_feat = {
        'chan_name': chan,
        'chan_ptp': np.ptp(chan_norm),
        'chan_kurt': sp_stats.kurtosis(chan_norm),
    }
    
    for feat in chan_feat.keys():
        curr_feat[feat] = chan_feat[feat]

    # save the epochs as column for debugging/visualization
    curr_feat['epoch'] = epochs

    return curr_feat

In [10]:
from mne_features.univariate import get_univariate_funcs, compute_pow_freq_bands

def extract_all_epochs_features(epochs, subj, sr):
    feat = {
        'subj': np.full(len(epochs), subj),
        'epoch_id': np.arange(len(epochs)),
    }

    selected_funcs = get_univariate_funcs(sr)
    selected_funcs.pop('spect_edge_freq', None)
    bands_dict = {'theta': (4, 8), 'alpha': (8, 12), 'sigma': (12, 16), 'beta': (16, 30), 'gamma': (30, 100), 'fast': (100, 300)}
    params = {'pow_freq_bands__freq_bands': bands_dict, 'pow_freq_bands__ratios': 'all', 'pow_freq_bands__psd_method': 'multitaper',
              'energy_freq_bands__freq_bands': bands_dict}
    X_new = extract_features(np.array(epochs)[:, np.newaxis, :], sr, selected_funcs, funcs_params=params, return_as_df=True)
    X_new['abspow'] = compute_pow_freq_bands(sr, np.array(epochs), {'total': (0.1, 500)}, False, psd_method='multitaper')
    # rename columns
    names = []
    for name in X_new.columns:
        if type(name) is tuple:
            if name[1] == 'ch0':
                names.append(name[0])
            else:
                names.append(name[0] + '_' + name[1].replace('ch0_', ''))
        else:
            names.append(name)

    X_new.columns = names

    # add ratios between bands
    X_new['energy_freq_bands_ab'] = X_new['energy_freq_bands_alpha'] / X_new['energy_freq_bands_beta']
    X_new['energy_freq_bands_ag'] = X_new['energy_freq_bands_alpha'] / X_new['energy_freq_bands_gamma']
    X_new['energy_freq_bands_as'] = X_new['energy_freq_bands_alpha'] / X_new['energy_freq_bands_sigma']
    X_new['energy_freq_bands_af'] = X_new['energy_freq_bands_alpha'] / X_new['energy_freq_bands_fast']
    X_new['energy_freq_bands_at'] = X_new['energy_freq_bands_alpha'] / X_new['energy_freq_bands_theta']
    X_new['energy_freq_bands_bt'] = X_new['energy_freq_bands_beta'] / X_new['energy_freq_bands_theta']
    X_new['energy_freq_bands_bs'] = X_new['energy_freq_bands_beta'] / X_new['energy_freq_bands_sigma']
    X_new['energy_freq_bands_bg'] = X_new['energy_freq_bands_beta'] / X_new['energy_freq_bands_gamma']
    X_new['energy_freq_bands_bf'] = X_new['energy_freq_bands_beta'] / X_new['energy_freq_bands_fast']
    X_new['energy_freq_bands_st'] = X_new['energy_freq_bands_sigma'] / X_new['energy_freq_bands_theta']
    X_new['energy_freq_bands_sg'] = X_new['energy_freq_bands_sigma'] / X_new['energy_freq_bands_gamma']
    X_new['energy_freq_bands_sf'] = X_new['energy_freq_bands_sigma'] / X_new['energy_freq_bands_fast']
    X_new['energy_freq_bands_gt'] = X_new['energy_freq_bands_gamma'] / X_new['energy_freq_bands_theta']
    X_new['energy_freq_bands_gf'] = X_new['energy_freq_bands_gamma'] / X_new['energy_freq_bands_fast']
    X_new['energy_freq_bands_ft'] = X_new['energy_freq_bands_fast'] / X_new['energy_freq_bands_theta']

    # Convert to dataframe
    feat = pd.DataFrame(feat)
    feat = pd.concat([feat, X_new], axis=1)

    return feat

def get_subj_data_zeeg(raw, chan):
    window_size = 250  # ms
    epochs = []
    chan_raw = raw.copy().pick([chan]).get_data().flatten()
    # normalize chan
    chan_norm = (chan_raw - np.nanmean(chan_raw)) / np.nanstd(chan_raw)
    # run on all 250ms epochs excluding the last 1s
    for i in range(0, len(chan_norm) - 4 * window_size, window_size):
        if not np.isnan(chan_norm[i: i + window_size]).any():
            epochs.append(chan_norm[i: i + window_size])

    # add epoch-level features
    curr_feat = extract_all_epochs_features(epochs, subj, raw.info['sfreq'])
    # add channel-level features
    chan_feat = {
        'chan_name': chan,
        'chan_ptp': np.ptp(chan_norm),
        'chan_kurt': sp_stats.kurtosis(chan_norm),
    }
    
    for feat in chan_feat.keys():
        curr_feat[feat] = chan_feat[feat]

    # save the epochs as column for debugging/visualization
    curr_feat['epoch'] = epochs

    return curr_feat

In [11]:
# run depth model for creating labels
depth_model, features = joblib.load(r'C:\repos\depth_ieds\lgbm_full_f15_s25_b_V5.pkl').values()

In [None]:
# extract depth features
chan = 'RAH1'
temp = final.copy().pick(chan).crop(60*60*5 + 60*20, 60*60*5 + 60*25)
subj_data = get_subj_data(temp, chan)
# save to pkl
# subj_data.to_pickle(rf'D:\Bonn\P{subj}_depth_feat.pkl')

subj_data

In [41]:
prob = depth_model.predict_proba(subj_data[features])[:, 1]
y_depth = (prob > 0.8).astype(int)

In [18]:
raw= clean_raw.copy().crop(60*60*5 + 60*20, 60*60*5 + 60*25)

In [21]:
curr_chans = [chan for chan in raw.ch_names if chan in depth]
# get only one deepest channel from each location
min_indexes = {}
for item in curr_chans:
    prefix = item[:-1]
    index = int(item[-1])
    if prefix not in min_indexes or index < int(min_indexes[prefix][-1][-1]):
        min_indexes[prefix] = item
y_curr = None
for chan in min_indexes.values():
    curr_feat = get_subj_data(raw, chan)
    predictions = depth_model.predict_proba(curr_feat[features])
    print(sum((predictions[:, 1] >= 0.8).astype(int)), chan)
    if y_curr is None:
        y_curr = (predictions[:, 1] >= 0.8).astype(int)
    else:
        y_curr += (predictions[:, 1] >= 0.8).astype(int)

# at least 2 channels should be above threshold?
y_curr[y_curr == 1] = 0
y_curr[y_curr > 1] = 1
y_depth = y_curr
y_depth

27 RAH1
40 LAH1
94 RA1
30 LA1
10 LEC1
124 REC1
26 LPHG1
23 RMH1
25 LMH1


array([0, 0, 0, ..., 0, 0, 0])

In [22]:
# total spikes
print(sum(y_depth))
# ratio
print(y_depth.sum() / len(y_depth))

85
0.07101086048454469


In [24]:
# extract scalp features
eog1 = get_subj_data_zeeg(raw, 'EOG1')
eog2 = get_subj_data_zeeg(raw, 'EOG2')
# combine and rename columns
subj_feat = pd.concat([eog1, eog2], axis=1, ignore_index=True) 
subj_feat.columns = [f'eog1_{col}' for col in eog1.columns] + [f'eog2_{col}' for col in eog2.columns]
# save to pkl
# subj_feat.to_pickle(rf'D:\Bonn\P{subj}_zeeg_feat.pkl')
subj_feat

Unnamed: 0,eog1_subj,eog1_epoch_id,eog1_app_entropy,eog1_decorr_time,eog1_energy_freq_bands_theta,eog1_energy_freq_bands_alpha,eog1_energy_freq_bands_sigma,eog1_energy_freq_bands_beta,eog1_energy_freq_bands_gamma,eog1_energy_freq_bands_fast,...,eog2_energy_freq_bands_st,eog2_energy_freq_bands_sg,eog2_energy_freq_bands_sf,eog2_energy_freq_bands_gt,eog2_energy_freq_bands_gf,eog2_energy_freq_bands_ft,eog2_chan_name,eog2_chan_ptp,eog2_chan_kurt,eog2_epoch
0,701,0,0.172145,-1.000,0.054619,0.034944,0.030562,0.090199,0.092207,0.000079,...,3.366357,0.614865,493.295773,5.474957,802.283626,0.006824,EOG2,10.359066,3.673379,"[0.2910601865218791, 0.27085540444514555, 0.24..."
1,701,1,0.173184,-1.000,0.046688,0.046029,0.047431,0.132119,0.072265,0.000273,...,0.695229,0.724523,370.600312,0.959568,511.509633,0.001876,EOG2,10.359066,3.673379,"[0.42728626838803074, 0.4326197366924741, 0.43..."
2,701,2,0.177867,-1.000,0.032664,0.026965,0.015006,0.096482,0.060310,0.000723,...,0.233752,0.288732,23.563429,0.809581,81.610027,0.009920,EOG2,10.359066,3.673379,"[-0.9316964981397581, -0.9154853175538413, -0...."
3,701,3,0.121215,-1.000,0.040300,0.053474,0.059047,0.072183,0.069587,0.000545,...,13.485761,0.625284,66.848866,21.567415,106.909595,0.201735,EOG2,10.359066,3.673379,"[-0.4146620600095086, -0.43864427108781995, -0..."
4,701,4,0.079268,-1.000,0.159191,0.051980,0.010390,0.062352,0.094975,0.000335,...,0.092975,0.211318,58.348508,0.439977,276.117657,0.001593,EOG2,10.359066,3.673379,"[-0.4385535960016858, -0.4189118817956746, -0...."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,701,1192,0.081332,-1.000,0.144600,0.056448,0.013222,0.037078,0.017838,0.000049,...,0.093380,0.767438,387.699334,0.121677,505.186288,0.000241,EOG2,10.359066,3.673379,"[0.21704725781916898, 0.22380713474712743, 0.2..."
1193,701,1193,0.033331,-1.000,0.029814,0.013494,0.015413,0.069475,0.029345,0.000052,...,0.264601,0.315946,543.724512,0.837489,1720.940141,0.000487,EOG2,10.359066,3.673379,"[0.9136661454550974, 0.9162567031119532, 0.921..."
1194,701,1194,0.065752,-1.000,0.035587,0.026600,0.021840,0.067818,0.045333,0.000124,...,0.404642,0.648315,470.665064,0.624144,725.981701,0.000860,EOG2,10.359066,3.673379,"[-0.8245014135779527, -0.8289155812807444, -0...."
1195,701,1195,0.093725,0.048,0.030497,0.030434,0.026980,0.073301,0.047034,0.000111,...,0.356107,0.641986,618.843285,0.554696,963.951115,0.000575,EOG2,10.359066,3.673379,"[-1.6877949656409865, -1.6818209465981702, -1...."


In [25]:
meta_data = ['subj', 'epoch_id', 'chan_name', 'epoch']
feat_to_choose_depth = ['teager_kaiser_energy_1_std', 'teager_kaiser_energy_2_std', 'chan_ptp', 'ptp_amp', 'hjorth_mobility', 'hjorth_complexity', 'teager_kaiser_energy_3_std', 'chan_kurt', 'teager_kaiser_energy_5_mean', 'teager_kaiser_energy_0_std', 'kurtosis', 'teager_kaiser_energy_0_mean', 'teager_kaiser_energy_1_mean', 'teager_kaiser_energy_5_std', 'samp_entropy']
feat_to_choose2 = ['chan', 'bands_gf', 'bands_gamma', 'bands_fast', 'bands_beta/gamma', 'bands_theta/fast', 'ptp', 'skew', 'teager', 'bands_sf', 'bands_bf', 'bands_bg', 'rms', 'katz', 'kurt', 'slope', 'mobility', 'hurst', 'wavelet']
feat_to_choose3 = ['chan', 'bands_gf', 'bands_gamma', 'energy_freq_bands_fast', 'bands_beta/gamma', 'bands_theta/fast', 'ptp', 'skew', 'teager', 'bands_sf', 'bands_bf', 'rms', 'kurt', 'slope', 'mobility', 'eog1_wavelet_coef_energy_0']
remove_3 = ['gamma/beta', 'fast/theta', 'fast/gamma']
remove_4 = ['gamma/sigma', 'fast/alpha', 'coef_1', '2_mean']
# get only columns that contain the string in feat_to_choose
x_feat = subj_feat[subj_feat.columns[subj_feat.columns.str.contains('|'.join(feat_to_choose2))]]
# now remove the metadata and some unwanted features
clean_feat = x_feat[x_feat.columns[~x_feat.columns.str.contains('|'.join(meta_data+remove_3))]]
clean_feat

Unnamed: 0,eog1_energy_freq_bands_gamma,eog1_energy_freq_bands_fast,eog1_hjorth_mobility,eog1_hjorth_mobility_spect,eog1_hurst_exp,eog1_katz_fd,eog1_kurtosis,eog1_pow_freq_bands_gamma,eog1_pow_freq_bands_fast,eog1_pow_freq_bands_theta/fast,...,eog2_wavelet_coef_energy_1,eog2_wavelet_coef_energy_2,eog2_wavelet_coef_energy_3,eog2_wavelet_coef_energy_4,eog2_energy_freq_bands_bg,eog2_energy_freq_bands_bf,eog2_energy_freq_bands_sf,eog2_energy_freq_bands_gf,eog2_chan_ptp,eog2_chan_kurt
0,0.092207,0.000079,0.072972,1.724620,0.752238,1.242145,1.728290,0.017627,0.000323,1544.998765,...,0.001134,0.048922,0.629345,1.829672,0.694779,557.409740,493.295773,802.283626,10.359066,3.673379
1,0.072265,0.000273,0.111870,2.107009,0.618507,1.455004,3.690966,0.041697,0.001513,314.222095,...,0.000157,0.016922,0.146240,2.472885,2.287288,1169.969886,370.600312,511.509633,10.359066,3.673379
2,0.060310,0.000723,0.241754,1.348111,0.835812,1.202516,2.373995,0.023218,0.000578,829.726495,...,0.000729,0.039661,0.212950,0.856875,1.336773,109.094089,23.563429,81.610027,10.359066,3.673379
3,0.069587,0.000545,0.147197,1.220273,0.871311,1.220477,1.959333,0.008369,0.000225,2373.250111,...,0.001720,0.042194,0.204139,1.383887,0.794851,84.977251,66.848866,106.909595,10.359066,3.673379
4,0.094975,0.000335,0.119188,3.409393,0.819717,1.303538,1.565871,0.005502,0.000176,3013.886212,...,0.001212,0.031702,0.528629,0.481083,0.609592,168.319031,58.348508,276.117657,10.359066,3.673379
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,0.017838,0.000049,0.160558,1.170072,1.252324,1.255017,2.791343,0.004847,0.000779,674.446268,...,0.000161,0.004853,0.106114,1.726922,2.168267,1095.378631,387.699334,505.186288,10.359066,3.673379
1193,0.029345,0.000052,0.165910,2.281380,0.989556,1.081192,1.772199,0.006619,0.001344,390.683242,...,0.000052,0.007928,0.155167,0.547002,2.103573,3620.123816,543.724512,1720.940141,10.359066,3.673379
1194,0.045333,0.000124,0.027994,1.724609,0.775838,1.194464,1.403741,0.003874,0.000345,1539.859461,...,0.000088,0.004882,0.123664,1.340190,1.950196,1415.806745,470.665064,725.981701,10.359066,3.673379
1195,0.047034,0.000111,0.100966,1.415740,0.810222,1.170506,2.907787,0.014566,0.001176,434.453827,...,0.000136,0.006415,0.164723,1.440781,1.232480,1188.050163,618.843285,963.951115,10.359066,3.673379


In [26]:
# unbalanced model
metrics = {'accuracy': [], 'precision': [], 'sensitivity': [], 'specificity': [],'f1': [], 'ROCAUC': [], 'PRAUC': []}
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=8)
i = 1
x, y = clean_feat, y_depth
for train_index, test_index in kf.split(x, y):
    print(f'Fold {i}')
    i += 1
    z_model = LGBMClassifier()
    x_train_fold, x_test_fold = x.iloc[train_index], x.iloc[test_index]
    y_train_fold, y_test_fold = y[train_index], y[test_index]
    z_model.fit(x_train_fold, y_train_fold)
    y_pred = z_model.predict(x_test_fold)
    y_true = y_test_fold
    # save scores in dict
    metrics['accuracy'].append(accuracy_score(y_true, y_pred))
    metrics['precision'].append(precision_score(y_true, y_pred))
    metrics['sensitivity'].append(recall_score(y_true, y_pred))
    metrics['f1'].append(f1_score(y_true, y_pred))
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics['specificity'].append(tn / (tn + fp))
    metrics['ROCAUC'].append(roc_auc_score(y_true, y_pred))
    metrics['PRAUC'].append(average_precision_score(y_true, y_pred))

# print results as df
results = pd.DataFrame(metrics)
# add mean row
results.loc['mean'] = results.mean()
print(sum(y)/len(y))
results

Fold 1
[LightGBM] [Info] Number of positive: 68, number of negative: 889
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001462 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 23460
[LightGBM] [Info] Number of data points in the train set: 957, number of used features: 92
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.071055 -> initscore=-2.570590
[LightGBM] [Info] Start training from score -2.570590
Fold 2
[LightGBM] [Info] Number of positive: 68, number of negative: 889
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.001084 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 23460
[LightGBM] [Info] Number of data points in the train set: 957, number of used features: 92
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.071055 -> initscore=-2.570590
[LightGBM] [Info] Start training from score -2.570590
Fold 3
[Li

Unnamed: 0,accuracy,precision,sensitivity,specificity,f1,ROCAUC,PRAUC
0,0.925,0.0,0.0,0.995516,0.0,0.497758,0.070833
1,0.933333,0.666667,0.117647,0.995516,0.2,0.556581,0.140931
2,0.937238,1.0,0.117647,1.0,0.210526,0.558824,0.180409
3,0.945607,0.833333,0.294118,0.995495,0.434783,0.644807,0.295307
4,0.937238,0.625,0.294118,0.986486,0.4,0.640302,0.234033
mean,0.935683,0.625,0.164706,0.994603,0.249062,0.579654,0.184303


In [27]:
# undersample
rus = RandomUnderSampler(random_state=8)
x, y = rus.fit_resample(clean_feat, y_depth)
len(y)

170

In [28]:
metrics = {'accuracy': [], 'precision': [], 'sensitivity': [], 'specificity': [],'f1': [], 'ROCAUC': [], 'PRAUC': []}
kf = StratifiedKFold(n_splits=5, shuffle=True, random_state=8)
i = 1
for train_index, test_index in kf.split(x, y):
    print(f'Fold {i}')
    i += 1
    z_model = LGBMClassifier()
    x_train_fold, x_test_fold = x.iloc[train_index], x.iloc[test_index]
    y_train_fold, y_test_fold = y[train_index], y[test_index]
    z_model.fit(x_train_fold, y_train_fold)
    y_pred = z_model.predict(x_test_fold)
    y_true = y_test_fold
    # save scores in dict
    metrics['accuracy'].append(accuracy_score(y_true, y_pred))
    metrics['precision'].append(precision_score(y_true, y_pred))
    metrics['sensitivity'].append(recall_score(y_true, y_pred))
    metrics['f1'].append(f1_score(y_true, y_pred))
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    metrics['specificity'].append(tn / (tn + fp))
    metrics['ROCAUC'].append(roc_auc_score(y_true, y_pred))
    metrics['PRAUC'].append(average_precision_score(y_true, y_pred))

# print results as df
results = pd.DataFrame(metrics)
# add mean row
results.loc['mean'] = results.mean()
print(sum(y)/len(y))
results

Fold 1
[LightGBM] [Info] Number of positive: 68, number of negative: 68
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000197 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 4324
[LightGBM] [Info] Number of data points in the train set: 136, number of used features: 92
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Fold 2
[LightGBM] [Info] Number of positive: 68, number of negative: 68
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000185 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 4324
[LightGBM] [Info] Number of data points in the train set: 136, number of used features: 92
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Fold 3
[LightGBM] [Info] Number of positive: 68, number of negative: 68
[LightGBM] [Info] Auto-choosing col-wise multi-threa

Unnamed: 0,accuracy,precision,sensitivity,specificity,f1,ROCAUC,PRAUC
0,0.735294,0.75,0.705882,0.764706,0.727273,0.735294,0.676471
1,0.735294,0.785714,0.647059,0.823529,0.709677,0.735294,0.684874
2,0.617647,0.666667,0.470588,0.764706,0.551724,0.617647,0.578431
3,0.764706,0.736842,0.823529,0.705882,0.777778,0.764706,0.695046
4,0.558824,0.555556,0.588235,0.529412,0.571429,0.558824,0.53268
mean,0.682353,0.698956,0.647059,0.717647,0.667576,0.682353,0.6335


In [29]:
# visualize the models
def map_nan_index(edf):
    sr = 1000
    window_size = 250
    raw = mne.io.read_raw(edf)
    raw_data = raw.pick([raw.ch_names[0]])
    if raw_data.info['sfreq'] != sr:
        raw_data.resample(sr)
    raw_data = raw_data.get_data(reject_by_annotation='NaN')[0]
    map = []

    for j, i in enumerate(range(0, len(raw_data), window_size)):
        curr_block = raw_data[i: i + window_size]
        if i + window_size < len(raw_data):
            if not np.isnan(curr_block).any():
                map.append(j)
    return map

# a new model includes all the data
z_model = LGBMClassifier().fit(x, y)
y_proba = z_model.predict_proba(clean_feat).T
y_scalp = [p > 0.8 for p in y_proba[1]]
print(sum(y_scalp), sum(y_depth))
index_map = map_nan_index(r'D:\Bonn\012fn1\P701_mtl_clean.edf')
scalp_indexes = np.where(np.array(y_scalp) == True)[0]
scalp_onsets = [index_map[int(x)] / 4 for x in scalp_indexes]
depth_indexes = np.where(y_depth == 1)[0]
depth_onsets = [index_map[int(x)] / 4 for x in depth_indexes]
both = [x for x in scalp_onsets if x in depth_onsets]
depth_without_both = [x for x in depth_onsets if x not in both]
scalp_without_both = [x for x in scalp_onsets if x not in both]
# all annot
annot = mne.Annotations(scalp_without_both, [0.25] * len(scalp_without_both), ['scalp'] * len(scalp_without_both)).append(depth_without_both, [0.25] * len(depth_without_both), ['depth'] * len(depth_without_both)).append(both, [0.25] * len(both), ['both'] * len(both))
raw.set_annotations(annot)

raw.plot(duration=30, scalings='auto')

[LightGBM] [Info] Number of positive: 85, number of negative: 85
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.000220 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 5338
[LightGBM] [Info] Number of data points in the train set: 170, number of used features: 92
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
259 85
Extracting EDF parameters from D:\Bonn\012fn1\P701_mtl_clean.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Using qt as 2D backend.


<mne_qt_browser._pg_figure.MNEQtBrowser(0x1941c16eb90) at 0x000001964D04C3C0>

In [None]:
v2_model = joblib.load('validation_models_v2\\lgbm_s13_f98_b_sym.pkl')