In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
from scipy import signal
import matplotlib.pyplot as plt
import xarray as xr
from pyedflib import highlevel
import yasa
import tqdm
import glob
import ghibtools as gt
import mne
from params import *

## PARAMS

In [3]:
# Locate this patient's EDF recording. Matches are sorted so the pick is
# deterministic (raw glob order depends on the filesystem), and an explicit
# error replaces the bare IndexError when no file is found.
edf_candidates = sorted(glob.glob(f'../data/{patient}/*.edf'))
if not edf_candidates:
    raise FileNotFoundError(f'No .edf file found in ../data/{patient}/')
input_file = edf_candidates[0]

In [4]:
# When True, the figures below are also saved under ../presentation/.
save_presentation = False

In [5]:
# When True, the staging comparison (.xlsx in ../df_analyse/) and the staged
# DataArray (.nc in ../dataarray/) are written to disk.
save_da = True

## TOOLS

In [6]:
def eeg_mono_to_bipol(da, dérivations):
    """Build bipolar derivations (POLE1 - POLE2) from monopolar channels.

    Monopolar EEG channels are looked up as 'EEG <pole>'. The first pole is
    looked up under its bare label when it is one of the EOG channels
    'EOGDt' or 'EOGG'; the second pole is always an 'EEG <pole>' channel.

    Parameters
    ----------
    da : xarray.DataArray
        Monopolar signals, selected with ``da.loc[label, :]`` (channel label
        on the first axis).
    dérivations : list of str
        Derivation names written 'POLE1-POLE2', e.g. 'C4-T4' or 'EOGG-A2'.

    Returns
    -------
    xarray.DataArray
        One difference signal per derivation, stacked along a 'chan'
        dimension whose coordinate is `dérivations`.
    """
    differences = []
    for derivation in dérivations:
        pieces = derivation.split('-')
        anode, cathode = pieces[0], pieces[1]
        # EOG channels are stored without the 'EEG ' prefix.
        anode_label = anode if anode in ('EOGDt', 'EOGG') else f'EEG {anode}'
        cathode_label = f'EEG {cathode}'
        differences.append(da.loc[anode_label, :] - da.loc[cathode_label, :])
    stacked = xr.concat(differences, dim='chan')
    return stacked.assign_coords({'chan': dérivations})

In [7]:
def to_notch_da(da, fs):
    """Return a copy of `da` where every channel went through ``gt.notch``.

    Besides filtering, the 'ECG' channel (when present) is sign-inverted. The
    input array itself is not modified.

    Parameters
    ----------
    da : xarray.DataArray
        Signals with a 'chan' coordinate, channel on the first axis.
    fs : float
        Sampling rate passed to ``gt.notch``.
        NOTE(review): the notch frequency is set inside ghibtools; confirm.

    Returns
    -------
    xarray.DataArray
        Filtered copy with the same coordinates as `da`.
    """
    notched = da.copy()
    for chan_name in da.coords['chan'].values:
        filtered = gt.notch(da.sel(chan=chan_name).values, fs=fs)
        # ECG polarity is flipped on purpose; reason not stated in the code
        # (presumably an inverted recording lead — confirm).
        notched.loc[chan_name, :] = -filtered if chan_name == 'ECG' else filtered
    return notched

In [8]:
def da_to_mne_object(da, srate):
    """Wrap a DataArray into an ``mne.io.RawArray`` with 'misc' channels.

    Parameters
    ----------
    da : xarray.DataArray
        Signals with a 'chan' coordinate; ``da.values`` is passed as the
        RawArray data, so channels are assumed to be on the first axis.
    srate : float
        Sampling frequency in Hz.

    Returns
    -------
    mne.io.RawArray
        Raw object whose channel names are ``da.coords['chan']``.

    NOTE(review): values are passed through unscaled; MNE conventionally
    stores data in volts — confirm the units expected by downstream consumers
    (YASA).
    """
    names = list(da.coords['chan'].values)
    info = mne.create_info(names, srate, ch_types='misc', verbose=None)
    return mne.io.RawArray(data=da.values, info=info, first_samp=0, copy='auto', verbose=None)

## LOAD RAW

In [None]:
# Read the whole EDF: `signals` holds the samples (indexed [channel, sample]
# below), `signal_headers` one dict per channel ('label', 'sample_rate',
# 'dimension', ...), `header` the file-level header.
signals, signal_headers, header = highlevel.read_edf(input_file)

In [None]:
# Recording metadata and channel bookkeeping.
# NOTE(review): sample rate and unit are read from the first channel only;
# this assumes every channel shares them (they are stacked into one array below).
srate = signal_headers[0]['sample_rate']
chans = [chan_dict['label'] for chan_dict in signal_headers]
eeg_chans = [chan for chan in chans if 'EEG' in chan]
eeg_chans_clean = [chan.split(' ')[1] for chan in chans if 'EEG' in chan]  # 'EEG Fp1' -> 'Fp1'
eog_chans = [chan for chan in chans if 'EOG' in chan]
physio_chans = [chan for chan in chans if 'EEG' not in chan]  # every non-EEG channel, EOG included
unit = signal_headers[0]['dimension']
# Time axis derived from the sample count: np.arange with a float step
# (1 / srate) can return one element too many through rounding, which would
# make the DataArray construction fail on a length mismatch.
time = np.arange(signals.shape[1]) / srate
dérivations = ['Fp2-C4', 'C4-T4', 'T4-O2', 'Fz-Cz', 'Cz-Pz', 'Fp1-C3', 'C3-T3', 'T3-O1', 'EOGDt-A1', 'EOGG-A2']
print(f'Srate : {srate}')
print(f'Total duration : {int(time[-1])} seconds')
print(f'Nb of eeg electrodes : {len(eeg_chans)}')
print(f'Nb physios electrodes : {len(physio_chans)}')

## PUT RAW IN DA

In [None]:
# All raw signals as one labelled (chan, time) DataArray.
da = xr.DataArray(data = signals, dims = ['chan','time'] , coords = {'chan':chans , 'time':time})

## DISPLAY MONOPOLAR

* raw signals

In [None]:
# 0.5 s of raw monopolar EEG. Passing `size` makes xarray create its own
# figure, so an explicit plt.figure() beforehand would only leave an empty
# figure behind.
da.loc[eeg_chans, 10:10.5].plot.line(x='time', hue='chan', size=10)
if save_presentation:
    plt.savefig(f'../presentation/raw_signals_{patient}')
plt.show()

* psd of raw signals

In [None]:
# PSD of each raw monopolar EEG channel, labelled like the bipolar PSD plots.
plt.figure()
for chan in eeg_chans:
    f, Pxx = gt.spectre(da.loc[chan, :].values, srate=srate, wsize=50)
    plt.plot(f, Pxx, label=chan)
plt.legend()
if save_presentation:
    plt.savefig(f'../presentation/psd_raw_mono_{patient}')
plt.show()

## MONO TO BIPOLAR 

* signals

In [None]:
# Re-reference the monopolar montage into the bipolar derivations listed in `dérivations`.
da_bipol = eeg_mono_to_bipol(da, dérivations = dérivations)

In [None]:
# 1 s of every bipolar derivation. Passing `size` makes xarray create its own
# figure, so no plt.figure() is opened beforehand (it would stay empty).
da_bipol.loc[:, 10:11].plot.line(x='time', hue='chan', size=10)
if save_presentation:
    plt.savefig(f'../presentation/raw_bipol_signals_{patient}')
plt.show()

* psd

In [None]:
# PSD of each bipolar derivation over the full frequency range.
plt.figure()
for dérivation in dérivations:
    f, Pxx = gt.spectre(da_bipol.sel(chan=dérivation).values, srate=srate, wsize=50)
    plt.plot(f, Pxx, label=dérivation)
# Build the legend once, after every curve has been drawn.
plt.legend()
if save_presentation:
    plt.savefig(f'../presentation/psd_raw_bipol_{patient}')
plt.show()

In [None]:
# Same bipolar PSDs, zoomed on the first frequency bins.
# NOTE(review): the frequency this bin count reaches depends on gt.spectre's
# resolution for wsize=50 — confirm the intended upper frequency.
n_bins_zoom = 1000
plt.figure()
for dérivation in dérivations:
    f, Pxx = gt.spectre(da_bipol.sel(chan=dérivation).values, srate=srate, wsize=50)
    plt.plot(f[:n_bins_zoom], Pxx[:n_bins_zoom], label=dérivation)
# Build the legend once, after every curve has been drawn.
plt.legend()
plt.show()

* Join the bipolar EEG derivations with the physiological (non-EEG) channels

In [None]:
# Every non-EEG channel (EOG channels included, see `physio_chans`).
da_physios = da.loc[physio_chans,:]

In [None]:
# Bipolar EEG derivations followed by the physiological channels, in one array.
da_all = xr.concat([da_bipol , da_physios], dim = 'chan')

In [None]:
# Visual check: 5 s of every channel, one row per channel.
da_all.loc[:,20:25].plot.line(x='time', row = 'chan')

In [None]:
# 20 s of the 'DEBIT' channel (presumably respiratory airflow, French "débit" — confirm).
da_all.loc['DEBIT',20:40].plot()

In [None]:
# Notch-filter every channel (this also sign-inverts 'ECG', see to_notch_da).
da_all_notched = to_notch_da(da_all, fs=srate)

In [None]:
# PSD of each bipolar derivation after notch filtering.
plt.figure()
for dérivation in dérivations:
    f, Pxx = gt.spectre(da_all_notched.sel(chan=dérivation).values, srate=srate, wsize=50)
    plt.plot(f, Pxx, label=dérivation)
# Build the legend once, after every curve has been drawn.
plt.legend()
if save_presentation:
    plt.savefig(f'../presentation/psd_post_notch_{patient}')
plt.show()

In [None]:
# MNE Raw object built from the notched signals; this is what YASA consumes below.
raw = da_to_mne_object(da_all_notched, srate=srate)

In [None]:
# Bipolar derivation names that are EEG only (the two EOG derivations excluded).
eeg_names = [derivation for derivation in dérivations
             if derivation not in ('EOGDt-A1', 'EOGG-A2')]
eeg_names

## **YASA**

* Read the human-scored hypnogram

In [None]:
# Human-scored hypnogram of this patient. Matches are sorted for a
# deterministic pick, and a missing file raises an explicit error instead of
# a bare IndexError.
hypno_candidates = sorted(glob.glob(f'../data/{patient}/*AhypnoEXP.txt'))
if not hypno_candidates:
    raise FileNotFoundError(f'No *AhypnoEXP.txt hypnogram found in ../data/{patient}/')
txt_hypno_path = hypno_candidates[0]

In [None]:
# Load the hypnogram text file as a 2-D array of strings (whitespace-separated columns).
hypno = np.loadtxt(txt_hypno_path, dtype = str)

In [None]:
# Keep only column 2 of the hypnogram file, i.e. the stage label per epoch
# (column meaning assumed from this indexing — confirm against the file format).
# The ndim guard makes the cell safe to re-run: a second `[:, 2]` on the
# already 1-D array would raise an IndexError.
if hypno.ndim == 2:
    hypno = hypno[:, 2]
pd.Series(hypno).value_counts()

* YASA sleep staging, with channels chosen per YASA's recommendations: EEG = C4-T4 ("preferentially a central electrode"); EOG = EOGG-A2, i.e. the left EOG re-referenced to A2 ("preferentially, the left LOC channel referenced either to the mastoid (e.g. E1-M2) or Fpz."); EMG = Menton, the chin electrode ("Preferentially a chin electrode.")

In [None]:
# YASA automatic staging: EEG = C4-T4, EOG = EOGG-A2 (left EOG referenced to A2),
# EMG = 'Menton' (the chin electrode).
sls = yasa.SleepStaging(raw , eeg_name = 'C4-T4' , eog_name = 'EOGG-A2', emg_name='Menton')

In [None]:
# Predicted hypnogram: one stage label per 30 s epoch.
y_pred = sls.predict()
y_pred

In [None]:
# Predicted probability of each sleep stage for every epoch.
predict_probas = sls.predict_proba()
predict_probas

* plot yasa sleep staging

In [None]:
# Plot the predicted stage probabilities. plot_predict_proba draws into a
# figure it creates itself (a pandas area plot with its own figsize in current
# yasa versions — confirm), so no empty plt.figure() is opened first.
sls.plot_predict_proba()
if save_presentation:
    # Patient id in the file name, like every other saved figure, so that
    # running the notebook for several patients does not overwrite this plot.
    plt.savefig(f'../presentation/yasa_sleep_staging_{patient}')
plt.show()

* Compare human vs. YASA sleep staging (number of epochs per stage)

In [None]:
# Epoch counts per stage, YASA ('ia') vs human scoring, aligned on `stages_labels`.
# The human labels are mapped 'REM' -> 'R' so both use the same vocabulary.
# Stages absent from one scoring get a count of 0 rather than NaN.
n_epochs_ia = pd.Series(y_pred).size
ia_stages = pd.Series(y_pred).value_counts().reindex(index=stages_labels, fill_value=0)  # IA staging
n_epochs_human = pd.Series(hypno).size
human_stages = (pd.Series(hypno).value_counts()
                .rename({'REM': 'R'})
                .reindex(index=stages_labels, fill_value=0))  # human staging
# Build the frame from a dict so the rows are labelled 'ia'/'human' regardless
# of the name value_counts() gives its result (pandas >= 2.0 names it 'count',
# so a positional rename such as {0: 'ia', 1: 'human'} would not apply).
df_staging_compare = pd.DataFrame({'ia': ia_stages, 'human': human_stages}).T
df_staging_compare.insert(0, 'patient', patient)
# Index levels ('index', 'patient'): scorer then patient id.
df_staging_compare = df_staging_compare.rename_axis('index').set_index('patient', append=True)
if save_da:
    df_staging_compare.to_excel(f'../df_analyse/df_staging_compare_{patient}.xlsx')
df_staging_compare

* Cut the recording into 30 s epochs, stored in a DataArray multi-indexed by epoch number and YASA sleep-stage label

In [None]:
# YASA hypnogram as a Series (position = epoch number) and its stage counts.
y_pred_series = pd.Series(y_pred)
y_pred_series.value_counts()

In [None]:
# Display the per-epoch predicted stages.
y_pred_series

In [None]:
# Empty container with one slot per YASA epoch: (chan, epoch, within-epoch time).
epochs = np.arange(0, y_pred_series.size, 1)
# Within-epoch time axis built from an integer sample count: a float-step
# np.arange(0, 30, 1 / srate) can return one extra sample, which would not
# match the 30 s slices copied in below.
epoch_time = np.arange(int(round(30 * srate))) / srate
da_epoched = gt.init_da({'chan': da_all_notched.coords['chan'].values, 'epoch': epochs, 'time': epoch_time})

In [None]:
# Sanity check of the container shape: (n_chan, n_epochs, samples per epoch).
da_epoched.shape

In [None]:
# Copy each consecutive 30 s window of the notched recording into its epoch slot.
# Positional slicing yields exactly `samples_per_epoch` samples per epoch,
# whereas an inclusive label slice on the float time coordinate depends on
# exact float matches at both boundaries.
samples_per_epoch = int(round(30 * srate))
for epoch in epochs:
    first_sample = epoch * samples_per_epoch
    epoch_window = slice(first_sample, first_sample + samples_per_epoch)
    da_epoched.loc[:, epoch, :] = da_all_notched.isel(time=epoch_window).values

In [None]:
# Display the filled epoched array.
da_epoched

In [None]:
# 2-D view of Fp2-C4: every epoch against within-epoch time.
da_epoched.loc['Fp2-C4',:,:].plot()

In [None]:
# Replace the plain 'epoch' coordinate with a MultiIndex (epoch number, YASA
# stage) so epochs can later be selected by stage with .sel(stages=...).
# NOTE(review): recent xarray versions warn when a pandas.MultiIndex is passed
# directly to assign_coords; xr.Coordinates.from_pandas_multiindex is the
# suggested replacement if the installed version supports it.
midx = [epochs, y_pred_series.values]
midx_ready = pd.MultiIndex.from_arrays(midx, names=('epochs', 'stages'))
da_epoched_midx = da_epoched.assign_coords(coords = {'epoch':midx_ready})

In [None]:
# Display the multi-indexed epoched array.
da_epoched_midx

In [None]:
def stack_stages(da_midx, stage, srate = srate):
    """Concatenate, in epoch order, every epoch scored as `stage`.

    Parameters
    ----------
    da_midx : xarray.DataArray
        Epoched data whose 'epoch' dimension carries a MultiIndex with levels
        'epochs' (epoch number) and 'stages' (stage label).
    stage : str
        Stage label to extract; it must label at least one epoch.
    srate : float
        Sampling rate used to rebuild the time axis (defaults to the notebook's
        `srate` at definition time).

    Returns
    -------
    xarray.DataArray
        One continuous signal per channel, with 'time' restarting at 0 and
        running over all concatenated epochs of `stage`.

    Raises
    ------
    ValueError
        If no epoch is labelled `stage`.
    """
    if stage not in set(da_midx.coords['stages'].values.tolist()):
        raise ValueError(f'No epoch scored as {stage!r}')
    da_stage = da_midx.sel(stages=stage)
    # After selecting one stage, the remaining level 'epochs' indexes the epochs.
    pieces = [da_stage.sel(epochs=epoch).reset_coords(drop=True)
              for epoch in da_stage.coords['epochs'].values]
    da_concat = xr.concat(pieces, dim='time')
    # Time axis from the sample count: a float-step np.arange can return one
    # value too many and make assign_coords fail on a length mismatch.
    n_samples = da_concat.sizes['time']
    return da_concat.assign_coords({'time': np.arange(n_samples) / srate})

In [None]:
# Continuous signal of the stage under study (`stage_to_study`, presumably set in params).
da_stacked_stage = stack_stages(da_midx = da_epoched_midx, stage=stage_to_study)

In [None]:
# Display the stacked single-stage array.
da_stacked_stage

* Stack the multi-indexed 30 s epochs to get one continuous signal per channel and per stage (all epochs of a stage concatenated)

In [None]:
def stack_all_stages(da_midx, stages=stages_labels, srate=srate):
    """Stack every stage's concatenated epochs along a new 'stage' dimension.

    Parameters
    ----------
    da_midx : xarray.DataArray
        Epoched data multi-indexed by ('epochs', 'stages'), as used by stack_stages.
    stages : list of str
        Stage labels to stack (defaults to `stages_labels` at definition time).
    srate : float
        Sampling rate forwarded to stack_stages.

    Returns
    -------
    xarray.DataArray
        (stage, chan, time) array. Stages have different durations, so shorter
        ones are NaN-padded at the end (xr.concat's default outer join on 'time').
    """
    per_stage = [stack_stages(da_midx=da_midx, stage=label, srate=srate) for label in stages]
    return xr.concat(per_stage, dim='stage').assign_coords({'stage': stages})

In [None]:
# One continuous signal per stage and per channel.
da_all_stages = stack_all_stages(da_midx=da_epoched_midx)

In [None]:
# Display the (stage, chan, time) array.
da_all_stages

* duration by stage

In [None]:
def get_duration_by_stade(da_all=None, stages=stages_labels, srate=srate):
    """Print the total duration, in whole seconds, of each sleep stage.

    Parameters
    ----------
    da_all : xarray.DataArray, optional
        Output of stack_all_stages (dims 'stage', 'chan', 'time', shorter
        stages NaN-padded). Defaults to the notebook's `da_all_stages`, looked
        up at call time so a recomputed array is never shadowed by a stale
        default bound when the function was defined.
    stages : list of str
        Stages to report (defaults to `stages_labels` at definition time).
    srate : float
        Sampling rate used to convert sample counts into seconds.
    """
    if da_all is None:
        da_all = da_all_stages
    for stage in stages:
        n_samples = da_all.sel(stage=stage).dropna(dim='time').sizes['time']
        # Duration = n_samples / srate. The last time stamp is (n - 1) / srate,
        # which, truncated to int, reports every stage one second short.
        print(stage, int(n_samples / srate))

In [None]:
# Print the recorded duration of every stage.
get_duration_by_stade()

In [None]:
# First 100 s of the concatenated N1 signal on Fp2-C4. Selecting one stage and
# one channel leaves a 1-D array, so no `hue` argument applies.
da_all_stages.loc['N1', 'Fp2-C4', 0:100].plot.line(x='time')

In [None]:
# Whole concatenated N1 ECG (NaN padding past the N1 duration is not drawn).
da_all_stages.loc['N1','ECG',:].plot()

In [None]:
# Left EOG, monopolar vs re-referenced to A2, over 20 s of N1.
da_all_stages.loc['N1',['EOGG','EOGG-A2'],30:50].plot.line(x='time', hue='chan')

In [None]:
# Persist the staged array for later analyses.
if save_da:
    da_all_stages.to_netcdf(f'../dataarray/da_staged_{patient}.nc')

In [None]:
# Recap: recording info, YASA vs human epoch counts, duration per stage.
summary_lines = [
    f'Srate : {srate}',
    f'Total duration : {int(time[-1])} seconds',
    f'Nb of eeg electrodes : {len(eeg_chans)}',
    f'Nb physios electrodes : {len(physio_chans)}',
]
print('\n'.join(summary_lines))
extra_epochs = n_epochs_human - n_epochs_ia
print(f"YASA trouve {n_epochs_ia} époques de 30 secs, hypnogramme human made en note {n_epochs_human} soit {extra_epochs} de plus soit {extra_epochs * 30} secondes de plus")
get_duration_by_stade()