In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
from scipy import signal
import matplotlib.pyplot as plt
import xarray as xr
import ghibtools as gh
import tqdm
from params import *
from deform_tools import deform_to_cycle_template

%matplotlib inline

In [3]:
# When True, the plotting cells below also write each figure under ../presentation_4/.
save = False

In [4]:
# Position of the inspiration->expiration transition as a fraction of the stretched cycle.
# Used as the default `expi_ratio` of tf_cycle_stretch_expi (forwarded to
# deform_to_cycle_template as `inspi_ratio`) and to place the red transition line in the figures.
ratio = 0.3

In [5]:
# When True, only spindles detected on channels 'Fp1-C3' / 'Fp2-C4' are counted per cycle,
# and saved figure names get the '_frontal_spindle_only' suffix.
frontal_spindles_only = False

In [6]:
def tf_cycle_stretch_expi(da, rsp_features, nb_point_by_cycle=1000, expi_ratio = ratio):
    """Stretch a time-frequency DataArray onto a fixed-length respiration-cycle template.

    For each normalisation mode ('raw', 'normal'), the slice ``da.loc[mode, :, :]`` is
    passed (transposed) to ``deform_to_cycle_template``. Every cycle it returns is stored
    as ``nb_point_by_cycle`` points.

    Parameters
    ----------
    da : xarray.DataArray
        Indexable as ``da.loc[mode, :, :]`` for mode in {'raw', 'normal'}, with 'time'
        and 'freqs' coordinates (assumed dims (mode, freqs, time) -- TODO confirm).
    rsp_features : pandas.DataFrame
        One row per cycle with 'start_time' and 'transition_time' columns. The row
        POSITION is the cycle number used by deform_to_cycle_template.
    nb_point_by_cycle : int
        Number of template points per cycle.
    expi_ratio : float
        Forwarded as ``inspi_ratio`` to deform_to_cycle_template (the name is historical).
        The default is the notebook-level ``ratio``, bound when this cell is executed.

    Returns
    -------
    (da_stretch_cycle, new_rsp_features)
        da_stretch_cycle is built by gh.init_da from the keys 'normalisation', 'cycle',
        'freqs' and 'point' (presumably dims in that order). new_rsp_features holds the
        rows of the stretched cycles, in the same order as the 'cycle' coordinate.

    Raises
    ------
    ValueError
        If deform_to_cycle_template returns no cycle.
    """
    modes = ['raw', 'normal']
    # deform_to_cycle_template only receives the raw values (no index), so the
    # `cycles` it returns are row POSITIONS in rsp_features.
    # NOTE(review): only start/transition times are passed. If the template closes
    # cycle i at row i+1's start, cycles dropped upstream (e.g. by the duration filter)
    # make the preceding cycle span the gap -- confirm in deform_tools.
    cycle_times = rsp_features[['start_time', 'transition_time']].values
    times = da.coords['time'].values

    da_stretch_cycle = None
    cycles = []
    for mode in modes:
        clipped_times, times_to_cycles, cycles, cycle_points, deformed_data = deform_to_cycle_template(
            data = da.loc[mode, :, :].values.T,
            times = times,
            cycle_times = cycle_times,
            nb_point_by_cycle = nb_point_by_cycle,
            inspi_ratio = expi_ratio,
        )
        if len(cycles) == 0:
            # Fail loudly: returning None would only surface later as an obscure error.
            raise ValueError(f'tf_cycle_stretch_expi: no cycle could be stretched for mode {mode!r}')
        # Assumed (freqs, n_cycle_positions * nb_point_by_cycle) after transposing -- TODO confirm.
        deformed = deformed_data.T

        # Allocate the output once, using the first mode's cycle list.
        if da_stretch_cycle is None:
            da_stretch_cycle = gh.init_da({'normalisation': modes, 'cycle': cycles, 'freqs': da.coords['freqs'].values, 'point': np.arange(0, nb_point_by_cycle, 1)})

        for cycle in cycles:
            # Cycle number `cycle` occupies a contiguous block of nb_point_by_cycle columns.
            da_stretch_cycle.loc[mode, cycle, :, :] = deformed[:, cycle * nb_point_by_cycle:(cycle + 1) * nb_point_by_cycle]

    # Select rows by POSITION (cycles are positions), in the same order as the
    # 'cycle' coordinate. Label-based isin only worked when index == position.
    new_rsp_features = rsp_features.iloc[np.asarray(cycles, dtype=int)]
    return da_stretch_cycle, new_rsp_features

In [7]:
def inspi_to_expi_rsp_features(patient, max_cycle_duration=20):
    """Build one patient's per-cycle respiration feature table with spindle counts.

    A cycle starts at a row's inspiration onset ('inspi_time') and has its transition at
    that row's expiration onset ('expi_time'). It stops at the NEXT row's inspiration onset.
    Spindle peaks of stage 'N2' are counted when they fall in [start, stop). When the
    notebook-level ``frontal_spindles_only`` is True, only channels 'Fp1-C3' and
    'Fp2-C4' are counted.

    Parameters
    ----------
    patient : str
        Used to build the input file names under ../df_analyse/.
    max_cycle_duration : float
        Cycles whose duration is >= this value are dropped (same unit as the time
        columns, presumably seconds). The default of 20 matches the former hard-coded cut.

    Returns
    -------
    pandas.DataFrame
        One row per kept cycle, with a fresh 0..n-1 index. Columns: patient, start_idx,
        transition_idx, stop_idx, start_time, transition_time, stop_time, expi_duration,
        inspi_duration, cycle_duration, ratio_transition, spindled (0/1), n_spindles.
    """
    rsp_features = pd.read_excel(f'../df_analyse/resp_features_new_{patient}.xlsx', index_col = [0])
    spindles = pd.read_excel(f'../df_analyse/spindles_{patient}.xlsx', index_col = [0])
    spindles = spindles[spindles['stage'] == 'N2']
    if frontal_spindles_only:
        spindles = spindles[spindles['Channel'].isin(['Fp1-C3','Fp2-C4'])] # ONLY FRONTAL SPINDLES !!

    # Pair each row with the next row by POSITION. A label lookup (`i + 1`) would
    # silently assume a consecutive integer index.
    current = rsp_features.iloc[:-1]
    following = rsp_features.iloc[1:]

    start = current['inspi_time'].to_numpy()
    transition = current['expi_time'].to_numpy()
    stop = following['inspi_time'].to_numpy()
    cycle_duration = stop - start

    # Count peaks with start <= peak < stop by binary search on sorted peak times.
    # Clip at 0 so malformed cycles (stop < start) count no spindle.
    peaks = np.sort(spindles['Peak'].dropna().to_numpy())
    n_spindles = np.maximum(
        np.searchsorted(peaks, stop, side='left') - np.searchsorted(peaks, start, side='left'),
        0,
    )

    df_expi_rsp_features = pd.DataFrame({
        'patient': current['patient'].to_numpy(),
        'start_idx': current['inspi_index'].to_numpy(),
        'transition_idx': current['expi_index'].to_numpy(),
        'stop_idx': following['inspi_index'].to_numpy(),
        'start_time': start,
        'transition_time': transition,
        'stop_time': stop,
        'expi_duration': stop - transition,
        'inspi_duration': transition - start,
        'cycle_duration': cycle_duration,
        'ratio_transition': (transition - start) / cycle_duration,
        'spindled': (n_spindles > 0).astype(int),
        'n_spindles': n_spindles,
    })
    # Keep short cycles only, with a fresh 0..n-1 index so that index labels equal
    # row positions (cycle numbers are positional downstream).
    keep = df_expi_rsp_features['cycle_duration'] < max_cycle_duration
    return df_expi_rsp_features[keep].reset_index(drop=True)

In [8]:
def get_midx_stretched_da(patient, df_all_expi_features):
    """Average one patient's stretched TF cycles by spindling status.

    Loads ../dataarray/da_tf_frontal_{patient}.nc, stretches it with
    tf_cycle_stretch_expi using the patient's rows of ``df_all_expi_features``, and
    returns a DataArray with a new 'spindle_mode' dimension:

    - 'all': mean over every stretched cycle
    - 'spindled': mean over cycles with spindled == 1
    - 'unspindled': mean over cycles with spindled == 0
    - 'diff': spindled - unspindled

    A spindling group with no cycles yields NaN instead of raising KeyError.
    """
    da = xr.load_dataarray(f'../dataarray/da_tf_frontal_{patient}.nc')
    rsp_features_patient = df_all_expi_features[df_all_expi_features['patient'] == patient]
    print(patient, rsp_features_patient['spindled'].value_counts())
    da_stretch_cycle, new_rsp_features = tf_cycle_stretch_expi(da = da , rsp_features = rsp_features_patient)
    da_stretch_midx = gh.midx_da(da = da_stretch_cycle , dim = 'cycle', midx_labels = ('c','spindling','n_spindles'), midx_coords = [new_rsp_features.index, list(new_rsp_features.loc[:,'spindled']), list(new_rsp_features.loc[:,'n_spindles'])])

    # Group means; reindex so a patient lacking one group gets NaN rather than a KeyError.
    mean_cycle_da = da_stretch_midx.groupby('spindling').mean().reindex(spindling=[0, 1])
    unspindled = mean_cycle_da.sel(spindling=0, drop=True)
    spindled = mean_cycle_da.sel(spindling=1, drop=True)
    # Grand mean over all cycles. The former `mean_cycle_da.mean('spindling')` weighted
    # both groups equally regardless of how many cycles each contained.
    all_cycles = da_stretch_cycle.mean('cycle').transpose(*spindled.dims)
    diff = spindled - unspindled
    da_return = xr.concat([all_cycles, spindled, unspindled, diff], dim = 'spindle_mode').assign_coords({'spindle_mode':['all','spindled','unspindled','diff']})
    return da_return

In [None]:
# Stack every patient's cycle table. The per-patient 0..n-1 index is deliberately kept
# (no ignore_index): cycle numbers are matched to index labels per patient downstream.
df_all_expi_features = pd.concat(
    [inspi_to_expi_rsp_features(patient_id) for patient_id in patients]
)

In [None]:
# Inspect the stacked per-cycle table (index restarts at 0 for each patient).
df_all_expi_features

In [None]:
# Summary statistics of the numeric cycle features (durations, ratio, spindle counts).
df_all_expi_features.describe()

In [None]:
# How many cycles contain 0, 1, 2, ... spindle peaks (all patients pooled).
df_all_expi_features['n_spindles'].value_counts()

In [None]:
# One spindle-mode-averaged array per participant, stacked along 'participant'.
# The list stays bound to `concat` because the next cell re-reads it.
concat = [get_midx_stretched_da(participant_id, df_all_expi_features) for participant_id in patients]
da_all = xr.concat(concat, dim = 'participant').assign_coords({'participant': patients})

In [None]:
# NOTE(review): this repeats the last line of the previous cell (same `concat` list, same
# result); it is redundant and can be deleted.
da_all = xr.concat(concat, dim = 'participant').assign_coords({'participant':patients})

In [None]:
# Inspect the participant-stacked array (dims and coordinates).
da_all

In [None]:
# Participant-averaged 'diff' (spindled - unspindled) power for the 'normal'
# normalisation, averaged over 11-16 Hz, along the stretched cycle.
(
    da_all.mean('participant')
    .sel(normalisation = 'normal')
    .loc['diff', 11:16, :]
    .mean('freqs')
    .plot.line(x = 'point')
)

In [None]:
# Keep only the 'normal' normalisation for all figures below.
da = da_all.loc[dict(normalisation = 'normal')]

In [None]:
# Stretched-cycle x-axis geometry. The number of points per cycle is read from the data
# instead of repeating nb_point_by_cycle's value as a literal 1000.
n_points_per_cycle = da.sizes['point']
# x position of the inspiration->expiration transition.
xvline = ratio * n_points_per_cycle
# Tick positions centred on the inspiration and expiration segments.
inspi_label_pos = xvline / 2
expi_label_pos = (xvline + n_points_per_cycle) / 2

In [None]:
# Suffix for saved figure file names. It must be a string because it is interpolated into
# f-strings; None would be written into the file name as the literal text 'None'.
save_title_append = '_frontal_spindle_only' if frontal_spindles_only else ''

In [None]:
# One figure per spindle mode: a grid of per-patient sigma-band time-frequency maps over
# the stretched respiration cycle. The grid grows with the number of patients
# (5 columns; 10 patients gives the original 2 x 5 layout).
min_freq = 11
max_freq = 17
n_cols = 5
n_rows = max(1, int(np.ceil(len(patients) / n_cols)))
# `or ''` guards against a None suffix, which an f-string would render as 'None'.
file_suffix = save_title_append or ''

for mode in da.coords['spindle_mode'].values:
    fig, axs = plt.subplots(ncols = n_cols, nrows = n_rows, constrained_layout = True, figsize = (4 * n_cols, 4 * n_rows), squeeze = False)
    fig.suptitle(f'Mean normalized respiration cycle Phase-Frequency Plot : cycles = {mode} (inspi start)', fontsize = 20, y = 1.05)
    for ax, patient in zip(axs.ravel(), patients):
        data = da.loc[patient, mode , min_freq:max_freq,:].values
        ax.pcolormesh(da.coords['point'], da.coords['freqs'].loc[min_freq:max_freq], data)
        ax.set_title(patient)
        ax.set_ylabel('Freq [Hz]')
        ax.set_xlabel('Respiration Cycle Phase')
        ax.vlines(x = xvline, ymin = min_freq, ymax=max_freq, color = 'r')
        ax.set_xticks([inspi_label_pos,xvline,expi_label_pos])
        ax.set_xticklabels(['inspi','i-e transition','expi'], rotation=45, fontsize=10)
    # Hide grid cells left empty when the patient count is not a multiple of n_cols.
    for ax in axs.ravel()[len(patients):]:
        ax.set_visible(False)
    if save:
        fig.savefig(f'../presentation_4/stretch_tf_by_patient_{mode}_inspi_start{file_suffix}', bbox_inches = 'tight')
    plt.show()

In [None]:
# Participant-averaged sigma-band time-frequency map, one panel per spindle mode.
modes = da.coords['spindle_mode'].values
min_freq = 11
# NOTE(review): the per-patient grid uses 11-17 Hz and the 1-D diff plot 12-13 Hz;
# confirm the band is meant to differ between figures.
max_freq = 16
# `or ''` guards against a None suffix, which an f-string would render as 'None'.
file_suffix = save_title_append or ''

# One column per mode (4 modes gives the original 20 x 5 figure).
fig, axs = plt.subplots(ncols = len(modes), constrained_layout = True, figsize = (5 * len(modes), 5), squeeze = False)
fig.suptitle('Mean Patient * respiration cycles normalized Phase-Frequency Plot (inspi start)', fontsize = 20, y = 1.05)
for ax, mode in zip(axs[0], modes):
    data = da.loc[:, mode , min_freq:max_freq,:].mean('participant').values
    ax.pcolormesh( da.coords['point'] , da.coords['freqs'].loc[min_freq:max_freq] , data)
    ax.set_title(mode)
    ax.set_ylabel('Freq [Hz]')
    ax.set_xlabel('Respiration Cycle Phase')
    ax.vlines(x = xvline, ymin = min_freq, ymax=max_freq, color = 'r')
    ax.set_xticks([inspi_label_pos,xvline,expi_label_pos])
    ax.set_xticklabels(['inspi','i-e transition','expi'], rotation=45, fontsize=10)
if save:
    fig.savefig(f'../presentation_4/stretch_tf_mean_patient_modes_inspi_start{file_suffix}', bbox_inches = 'tight')
plt.show()

In [None]:
# 1-D profile of the spindled - unspindled sigma power along the stretched cycle.
min_freq = 12
max_freq = 13
# `or ''` guards against a None suffix, which an f-string would render as 'None'.
file_suffix = save_title_append or ''

fig, ax = plt.subplots(constrained_layout = True, figsize = (10,7))
# 'diff' is spindled - unspindled (built in get_midx_stretched_da), averaged here over
# participants and over the min_freq..max_freq band.
data = da.loc[:, 'diff' , min_freq:max_freq,:].mean(['participant','freqs']).values
ax.plot(da.coords['point'], data)
ax.set_title('Mean Patient * (Spindled - Unspindled) respiration cycles normalized Sigma Power')
ax.set_ylabel('Sigma Power')
ax.set_xlabel('Respiration Cycle Phase')
# axvline spans the whole y axis. The former vlines(ymin=min(data), ymax=max(data)) used
# Python's builtin min/max, whose result is order-dependent when data contains NaN.
ax.axvline(x = xvline, color = 'r')
ax.set_xticks([inspi_label_pos,xvline,expi_label_pos])
ax.set_xticklabels(['inspi','i-e transition','expi'], rotation=45, fontsize=10)
if save:
    fig.savefig(f'../presentation_4/stretch_sigma_power_mean_patient_inspi_start{file_suffix}', bbox_inches = 'tight')
plt.show()