In [1]:
# Automatically reload edited project modules (e.g. deform_tools, ghibtools) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
import pandas as pd
from scipy import signal
import matplotlib.pyplot as plt
import xarray as xr
import ghibtools as gh
import tqdm
from params import *
from deform_tools import deform_to_cycle_template

%matplotlib inline

In [3]:
save = False
# save = True

# save_type = 'png'
# save_type = 'eps'
save_type = 'tiff'
# The article figure cells only implement these formats; fail early instead of silently not saving.
assert save_type in ('png', 'eps', 'tiff'), f'unsupported save_type: {save_type!r}'

In [4]:
# Fraction of each stretched cycle template given to the start->transition segment. It is the
# default `expi_ratio` of tf_cycle_stretch_expi (bound when that function is defined, and forwarded
# as `inspi_ratio` to deform_to_cycle_template) and sets the x-position of the red transition line.
ratio = 0.3

In [5]:
# When True, only spindles on channels 'Fp1-C3' / 'Fp2-C4' are counted per cycle
# (inspi_to_expi_rsp_features), and saved figure names get a '_frontal_spindle_only' suffix.
frontal_spindles_only = False

In [6]:
# `patients` comes from the star import of `params`; presumably the list of participant IDs
# such as 'P1' (used to build the file paths below) — confirm in params.py.
participants = patients

In [7]:
def tf_cycle_stretch_expi(da, rsp_features, nb_point_by_cycle=1000, expi_ratio =ratio):
    """Stretch a time-frequency map onto a fixed-length template for every respiratory cycle.

    Parameters
    ----------
    da : xr.DataArray
        Time-frequency map with 'freqs' and 'time' coords. ``da.values.T`` is handed to
        ``deform_to_cycle_template`` (assumes dims are ('freqs', 'time') so that time
        ends up first — TODO confirm).
    rsp_features : pd.DataFrame
        One row per cycle with 'start_time' and 'transition_time' columns. Its index
        labels are matched against the cycle numbers returned by the deformation, so
        they are expected to be 0-based positions (as built by
        ``inspi_to_expi_rsp_features``).
    nb_point_by_cycle : int
        Number of template points per cycle.
    expi_ratio : float
        Fraction of the template allotted to the start->transition segment, forwarded
        as ``inspi_ratio``. Defaults to the global ``ratio`` as it was when this
        function was defined.

    Returns
    -------
    da_return : xr.DataArray
        Dims ('normalisation', 'cycle', 'freqs', 'point'); 'raw' holds the stretched
        values, 'normal' the same values z-scored with one mean/std over the whole array.
    new_rsp_features : pd.DataFrame
        Rows of ``rsp_features`` whose index is among the stretched cycles.

    Raises
    ------
    ValueError
        If the deformation yields no cycle to stretch.
    """
    _, _, cycles, _, deformed_data = deform_to_cycle_template(
        data=da.values.T,
        times=da.coords['time'].values,
        cycle_times=rsp_features[['start_time', 'transition_time']].values,
        nb_point_by_cycle=nb_point_by_cycle,
        inspi_ratio=expi_ratio,
    )
    if len(cycles) == 0:
        raise ValueError('deform_to_cycle_template returned no cycle to stretch')
    deformed = deformed_data.T  # back to (freqs, stretched points)

    # Allocate once up front instead of lazily inside the loop.
    da_stretch_cycle = gh.init_da({'cycle': cycles,
                                   'freqs': da.coords['freqs'].values,
                                   'point': np.arange(0, nb_point_by_cycle, 1)})
    # The stretched points are laid out cycle after cycle in the order of `cycles`,
    # so a cycle's segment is located by its position, not by its label.
    for position, cycle in enumerate(cycles):
        first = position * nb_point_by_cycle
        da_stretch_cycle.loc[cycle, :, :] = deformed[:, first:first + nb_point_by_cycle]

    all_values = da_stretch_cycle.values
    normalised = (da_stretch_cycle - np.mean(all_values)) / np.std(all_values)
    da_return = xr.concat([da_stretch_cycle, normalised], dim='normalisation')
    da_return = da_return.assign_coords({'normalisation': ['raw', 'normal']})
    new_rsp_features = rsp_features[rsp_features.index.isin(list(cycles))]
    return da_return, new_rsp_features

In [8]:
def inspi_to_expi_rsp_features(participant, rsp_features, frontal_only=None, max_cycle_duration=20):
    """Build one row per respiratory cycle (inspiration onset to next inspiration onset) with spindle counts.

    Parameters
    ----------
    participant : str
        Participant ID used to locate ``../df_analyse/spindles_{participant}.xlsx``.
    rsp_features : pd.DataFrame
        Respiration features with 'inspi_time', 'expi_time', 'inspi_index',
        'expi_index' and 'participant' columns. Cycle ``i`` ends at the inspiration of
        row ``i + 1``, so rows whose successor label is missing (cycles removed at the
        detection step) and the last row are skipped.
    frontal_only : bool or None
        Keep only spindles on 'Fp1-C3' / 'Fp2-C4'. ``None`` (default) falls back to the
        notebook-level ``frontal_spindles_only`` flag.
    max_cycle_duration : float
        Cycles lasting ``max_cycle_duration`` or longer are dropped (same unit as
        'inspi_time', presumably seconds).

    Returns
    -------
    pd.DataFrame
        Columns: participant, start/transition/stop indices and times, expi/inspi/cycle
        durations, ratio_transition (inspi / cycle duration), 'spindled' (1 when at
        least one spindle 'Peak' lies in [start, stop), else 0) and 'n_spindles'.
        The index is a fresh 0-based RangeIndex.
    """
    if frontal_only is None:
        frontal_only = frontal_spindles_only
    spindles = pd.read_excel(f'../df_analyse/spindles_{participant}.xlsx', index_col = [0])
    if frontal_only:
        spindles = spindles[spindles['Channel'].isin(['Fp1-C3','Fp2-C4'])] # ONLY FRONTAL SPINDLES !!

    # Sorted peak times (NaN dropped: NaN never satisfied the range comparison) let each
    # cycle's count be two binary searches instead of masking the whole frame per cycle.
    peaks = np.sort(spindles['Peak'].dropna().to_numpy(dtype=float))

    labels = rsp_features.index
    rows = []
    for i in labels:
        # The cycle needs the next inspiration onset; skip when it was removed or absent.
        if i + 1 not in labels or i == labels[-1]:
            continue
        start = rsp_features.loc[i, 'inspi_time']
        transition = rsp_features.loc[i, 'expi_time']
        stop = rsp_features.loc[i + 1, 'inspi_time']

        start_idx = rsp_features.loc[i, 'inspi_index']
        transition_idx = rsp_features.loc[i, 'expi_index']
        stop_idx = rsp_features.loc[i + 1, 'inspi_index']

        cycle_participant = rsp_features.loc[i, 'participant']

        expi_duration = stop - transition
        inspi_duration = transition - start
        cycle_duration = stop - start

        if cycle_duration < max_cycle_duration:
            ratio_transition = inspi_duration / cycle_duration
            # number of peaks p with start <= p < stop
            n_spindles_in_cycle = int(np.searchsorted(peaks, stop, side='left')
                                      - np.searchsorted(peaks, start, side='left'))
            encoding = int(n_spindles_in_cycle > 0)
            rows.append([cycle_participant, start_idx, transition_idx, stop_idx, start, transition, stop,
                         expi_duration, inspi_duration, cycle_duration, ratio_transition,
                         encoding, n_spindles_in_cycle])
    df_expi_rsp_features = pd.DataFrame(rows, columns = ['participant','start_idx','transition_idx','stop_idx','start_time','transition_time','stop_time','expi_duration','inspi_duration', 'cycle_duration','ratio_transition','spindled','n_spindles'])
    return df_expi_rsp_features

In [19]:
# Quick probe of one participant's frontal TF file (the input loaded by get_midx_stretched_da);
# opening it requires a netCDF backend (netcdf4 or h5netcdf) to be installed.
p1_tf_file = '../dataarray/da_tf_frontal_P1.nc'
xr.open_dataarray(p1_tf_file)

ValueError: found the following matches with the input file in xarray's IO backends: ['netcdf4', 'h5netcdf']. But their dependencies may not be installed, see:
https://docs.xarray.dev/en/stable/user-guide/io.html 
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html

In [9]:
def get_midx_stretched_da(participant, df_all_expi_features):
    """Stretch one participant's frontal TF map per cycle and average it by spindle status.

    Loads ``../dataarray/da_tf_frontal_{participant}.nc``, stretches it with
    ``tf_cycle_stretch_expi`` using that participant's rows of ``df_all_expi_features``,
    tags every cycle with its 'spindled' flag and 'n_spindles' count through
    ``gh.midx_da``, then averages over cycles.

    Returns
    -------
    xr.DataArray
        Stacked along a new 'spindle_mode' dim labelled 'all', 'spindled',
        'unspindled' and 'diff' (= spindled - unspindled).
    """
    da = xr.load_dataarray(f'../dataarray/da_tf_frontal_{participant}.nc')
    rsp_features_participant = df_all_expi_features[df_all_expi_features['participant'] == participant]
    da_stretch_cycle, new_rsp_features = tf_cycle_stretch_expi(da = da , rsp_features = rsp_features_participant)
    da_stretch_midx = gh.midx_da(
        da=da_stretch_cycle,
        dim='cycle',
        midx_labels=('c', 'spindling', 'n_spindles'),
        midx_coords=[new_rsp_features.index,
                     list(new_rsp_features.loc[:, 'spindled']),
                     list(new_rsp_features.loc[:, 'n_spindles'])],
    )
    all_cycles = da_stretch_midx.mean('cycle')
    # Selecting one multi-index level leaves 'spindling' behind as a scalar coordinate;
    # remove it so the subtraction and concat align. `drop_vars` replaces the
    # deprecated `DataArray.drop` for this.
    unspindled = da_stretch_midx.sel(spindling=0).mean('cycle').drop_vars('spindling')
    spindled = da_stretch_midx.sel(spindling=1).mean('cycle').drop_vars('spindling')
    diff = spindled - unspindled
    da_return = xr.concat([all_cycles, spindled, unspindled, diff], dim = 'spindle_mode').assign_coords({'spindle_mode':['all','spindled','unspindled','diff']})
    return da_return

In [10]:
# Per-participant cycle tables. The per-participant 0-based index is kept on purpose
# (no ignore_index): tf_cycle_stretch_expi matches it against cycle positions.
concat = [
    inspi_to_expi_rsp_features(
        participant,
        rsp_features=pd.read_excel(f'../df_analyse/resp_features_{participant}.xlsx', index_col=0),
    )
    for participant in participants
]
df_all_expi_features = pd.concat(concat)

In [11]:
# Per-cycle features of every participant; the index restarts at 0 for each participant.
df_all_expi_features

Unnamed: 0,participant,start_idx,transition_idx,stop_idx,start_time,transition_time,stop_time,expi_duration,inspi_duration,cycle_duration,ratio_transition,spindled,n_spindles
0,P1,155,434,913,0.605469,1.695312,3.566406,1.871094,1.089844,2.960938,0.368074,0,0
1,P1,913,1202,1717,3.566406,4.695312,6.707031,2.011719,1.128906,3.140625,0.359453,0,0
2,P1,1717,1996,2520,6.707031,7.796875,9.843750,2.046875,1.089844,3.136719,0.347447,0,0
3,P1,2520,2795,3324,9.843750,10.917969,12.984375,2.066406,1.074219,3.140625,0.342040,0,0
4,P1,3324,3557,3973,12.984375,13.894531,15.519531,1.625000,0.910156,2.535156,0.359014,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
5248,P20,5431832,5432149,5432757,21218.093750,21219.332031,21221.707031,2.375000,1.238281,3.613281,0.342703,1,1
5249,P20,5432757,5433084,5433677,21221.707031,21222.984375,21225.300781,2.316406,1.277344,3.593750,0.355435,1,1
5250,P20,5433677,5434019,5434632,21225.300781,21226.636719,21229.031250,2.394531,1.335938,3.730469,0.358115,0,0
5251,P20,5434632,5434978,5435616,21229.031250,21230.382812,21232.875000,2.492188,1.351562,3.843750,0.351626,1,1


In [12]:
# Summary statistics of the cycle features (durations, transition ratio, spindle counts).
df_all_expi_features.describe()

Unnamed: 0,start_idx,transition_idx,stop_idx,start_time,transition_time,stop_time,expi_duration,inspi_duration,cycle_duration,ratio_transition,spindled,n_spindles
count,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0,74242.0
mean,2147813.0,2148165.0,2148795.0,8389.894461,8391.269582,8393.730006,2.460425,1.37512,3.835545,0.360606,0.286132,0.829167
std,1337920.0,1337921.0,1337949.0,5226.250236,5226.252958,5226.363279,0.527464,0.29836,0.646628,0.058646,0.451955,1.671135
min,95.0,365.0,822.0,0.371094,1.425781,3.210938,0.726562,0.503906,1.507812,0.072464,0.0,0.0
25%,1021000.0,1021348.0,1021954.0,3988.280273,3989.638672,3992.006836,2.128906,1.167969,3.410156,0.323743,0.0,0.0
50%,2056564.0,2056936.0,2057472.0,8033.453125,8034.908203,8036.998047,2.421875,1.367188,3.800781,0.358854,0.0,0.0
75%,3143971.0,3144295.0,3145013.0,12281.135742,12282.402344,12285.208008,2.726562,1.550781,4.191406,0.395051,1.0,1.0
max,5435616.0,5435929.0,5436498.0,21232.875,21234.097656,21236.320312,9.964844,9.445312,13.28125,0.838809,1.0,17.0


In [13]:
# Distribution of spindles per cycle; a count of 0 is exactly what marks a cycle as unspindled (spindled == 0).
df_all_expi_features['n_spindles'].value_counts()

0     52999
1      6598
2      4440
3      3407
4      2539
5      1935
6      1180
7       632
8       309
9        89
10       51
11       36
12       13
13        9
14        4
17        1
Name: n_spindles, dtype: int64

In [14]:
concat = []
# tqdm (already imported) shows progress and the current participant instead of printing each ID.
progress = tqdm.tqdm(participants)
for participant in progress:
    progress.set_postfix_str(str(participant))
    concat.append(get_midx_stretched_da(participant, df_all_expi_features))
da_all = xr.concat(concat, dim = 'participant').assign_coords({'participant':participants})

P1


ValueError: found the following matches with the input file in xarray's IO backends: ['netcdf4', 'h5netcdf']. But their dependencies may not be installed, see:
https://docs.xarray.dev/en/stable/user-guide/io.html 
https://docs.xarray.dev/en/stable/getting-started-guide/installing.html

In [None]:
# Same statement as the last line of the previous cell: rebuilds da_all from the collected `concat` list without re-running the loop.
da_all = xr.concat(concat, dim = 'participant').assign_coords({'participant':participants})

In [None]:
# NOTE(review): dims are presumably (participant, spindle_mode, normalisation, freqs, point) — confirm from the repr.
da_all

In [None]:
# Grand-average (over participants) z-scored spindled - unspindled power, averaged over 11-16 Hz.
diff_sigma_course = (
    da_all
    .mean('participant')
    .sel(normalisation='normal', spindle_mode='diff', freqs=slice(11, 16))
    .mean('freqs')
)
diff_sigma_course.plot.line(x='point')

In [None]:
# Keep only the z-scored ('normal') stretched maps; every figure below plots from `da`.
da = da_all.sel(normalisation = 'normal')

In [None]:
# Template length: the nb_point_by_cycle default of tf_cycle_stretch_expi.
n_points_per_cycle = 1000
xvline = ratio * n_points_per_cycle  # start->transition boundary, in template points
inspi_label_pos = xvline / 2  # middle of the first segment
expi_label_pos = n_points_per_cycle - ((n_points_per_cycle - xvline) / 2)  # middle of the second segment

In [None]:
# Suffix appended to saved figure names. Use '' rather than None: None is interpolated
# into the f-string file names as the literal text "None".
save_title_append = '_frontal_spindle_only' if frontal_spindles_only else ''

In [None]:
min_freq = 11
max_freq = 17
n_cols = 5
# Enough rows for every participant (2 rows for 10 participants, as before).
n_rows = max(1, (len(participants) + n_cols - 1) // n_cols)
# Guard against a None suffix, which would end up as "None" in the file name.
title_suffix = save_title_append or ''
for mode in da.coords['spindle_mode'].values:
    fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, constrained_layout=True,
                            figsize=(20, 4 * n_rows), squeeze=False)
    fig.suptitle(f'Mean normalized respiration cycle Phase-Frequency Plot : cycles = {mode} (inspi start)', fontsize = 20, y = 1.05)
    flat_axes = axs.ravel()  # row-major: fills each row left to right
    for ax, participant in zip(flat_axes, participants):
        data = da.loc[participant, mode, min_freq:max_freq, :].values
        ax.pcolormesh(da.coords['point'], da.coords['freqs'].loc[min_freq:max_freq], data)
        ax.set_title(participant)
        ax.set_ylabel('Freq [Hz]')
        ax.set_xlabel('Respiration Cycle Phase')
        ax.vlines(x=xvline, ymin=min_freq, ymax=max_freq, color='r')
        ax.set_xticks([inspi_label_pos, xvline, expi_label_pos])
        ax.set_xticklabels(['inspi', 'i-e transition', 'expi'], rotation=45, fontsize=10)
    for ax in flat_axes[len(participants):]:
        fig.delaxes(ax)  # remove empty grid slots
    if save:
        fig.savefig(f'../presentation_4/stretch_tf_by_participant_{mode}_inspi_start{title_suffix}', bbox_inches = 'tight')
    plt.show()

In [None]:
# This figure always shows spindled - unspindled. Set explicitly: it used to be inherited from
# the last iteration of the per-mode loop, and it also names the saved file.
mode = 'diff'
min_freq = 10
max_freq = 16
n_cols = 3
n_participants = len(participants)
n_rows = max(1, (n_participants + n_cols - 1) // n_cols)
n_points = da.sizes['point']
# One colour scale for all panels so the single shared colorbar is valid for every panel.
band = da.sel(spindle_mode=mode, freqs=slice(min_freq, max_freq))
vmin, vmax = float(band.min()), float(band.max())

fig, axs = plt.subplots(ncols=n_cols, nrows=n_rows, constrained_layout=True,
                        figsize=(16, 10 * n_rows / 3), squeeze=False)
flat_axes = axs.ravel()
for position, participant in enumerate(participants):
    ax = flat_axes[position]
    data = da.loc[participant, mode, min_freq:max_freq, :].values
    m = ax.pcolormesh(da.coords['point'], da.coords['freqs'].loc[min_freq:max_freq], data,
                      vmin=vmin, vmax=vmax)
    ax.set_title(participant)
    # Label by grid position rather than by hard-coded participant names.
    if position % n_cols == 0:  # first column
        ax.set_ylabel('Freq [Hz]')
    if position + n_cols >= n_participants:  # no panel below this one
        ax.set_xlabel('Respiration Cycle Phase')
    ax.vlines(x=xvline, ymin=min_freq, ymax=max_freq, color='r')
    # Ticks: cycle start, actual i-e transition (xvline, not a hard-coded 300), half cycle, full cycle.
    ax.set_xticks([0, xvline, n_points / 2, n_points])
    ax.set_xticklabels(['0', 'transition i-e', 'Pi', '2*Pi'], rotation=0, fontsize=10)
for ax in flat_axes[n_participants:]:
    fig.delaxes(ax)  # remove empty grid slots
fig.colorbar(m, ax=list(flat_axes[:n_participants]), label = 'Normalized Power [AU]')

if save:
    if save_type == 'eps':
        fig.savefig(f'../article/stretch_tf_by_participant_{mode}.eps', format='eps', bbox_inches = 'tight', dpi = 300)
    elif save_type == 'png':
        fig.savefig(f'../article/stretch_tf_by_participant_{mode}', bbox_inches = 'tight', dpi = 300)
    elif save_type == 'tiff':
        fig.savefig(f'../article/stretch_tf_by_participant_{mode}.tif', format='tif', bbox_inches = 'tight', dpi = 300)
    else:
        raise ValueError(f'unsupported save_type: {save_type!r}')
plt.show()

In [None]:
# participant_to_subjects = {'P1':'S1','P2':'S2','P3':'S3','P4':'S4','P5':'S5','P6':'S6','P7':'S7','P8':'S8','P9':'S9','P10':'S10'}

# fig, axs = plt.subplots(ncols = 3, nrows = 3, constrained_layout = True, figsize = (16,10))
# # fig.suptitle(f'Mean normalized respiration cycle Phase-Frequency Plot : cycles = {mode} (inspi start)', fontsize = 20, y = 1.05)
# for row, sublists in enumerate([ participants[:3] , participants[3:6], participants[6:] ]): 
#     for col, participant in enumerate(sublists):
#         ax = axs[row, col]
#         min_freq = 10
#         max_freq = 16
#         data = da.loc[participant, 'diff' , min_freq:max_freq,:].values
#         m = ax.pcolormesh(da.coords['point'], da.coords['freqs'].loc[min_freq:max_freq], data)
#         ax.set_title(participant_to_subjects[participant])
#         if participant in ['P1','P4','P7']:
#             ax.set_ylabel('Freq [Hz]')
#         if participant in ['P7','P9','P10']:
#             ax.set_xlabel('Respiration Cycle Phase')
#         ax.vlines(x = xvline, ymin = min_freq, ymax=max_freq, color = 'r')
#         # ax.set_xticks([inspi_label_pos,xvline,expi_label_pos])
#         # ax.set_xticklabels(['inspi','i-e transition','expi'], rotation=45, fontsize=10)
#         ax.set_xticks([0, 0, 300, 500, 1000])
#         ax.set_xticklabels([0, 0, 'transition i-e', 'Pi', '2*Pi'], rotation=0, fontsize=10)
#         # ax.colorbar(m)
# fig.colorbar(m, ax=axs.ravel().tolist(), label = 'Normalized Power [AU]')
# # fig.delaxes(axs[1,4])

# if save:
#     if save_type == 'eps':
#         plt.savefig(f'../article/stretch_tf_by_participant_{mode}.eps', format='eps', bbox_inches = 'tight', dpi = 300)
#     elif save_type == 'png': 
#         plt.savefig(f'../article/stretch_tf_by_participant_{mode}', bbox_inches = 'tight', dpi = 300)
#     elif save_type == 'tiff':
#         plt.savefig(f'../article/stretch_tf_by_participant_{mode}_Subject.tif', format='tif', bbox_inches = 'tight', dpi = 300)
# plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout = True, figsize = (10,5))
min_freq = 10
max_freq = 16
n_points = da.sizes['point']

# Grand average over participants of the z-scored spindled - unspindled map.
data = da.loc[:, 'diff', min_freq:max_freq, :].mean('participant').values
m = ax.pcolormesh(da.coords['point'], da.coords['freqs'].loc[min_freq:max_freq], data)
ax.set_ylabel('Freq [Hz]')
ax.set_xlabel('Respiration Cycle Phase')
ax.vlines(x=xvline, ymin=min_freq, ymax=max_freq, color='r')
# Ticks: cycle start, actual i-e transition (xvline, not a hard-coded 300), half cycle, full cycle.
ax.set_xticks([0, xvline, n_points / 2, n_points])
ax.set_xticklabels(['0', 'transition inspi-expi', 'Pi', '2*Pi'], rotation=0, fontsize=10)
fig.colorbar(m, ax=ax, label = 'Normalized Power [AU]')
if save:
    if save_type == 'png':
        fig.savefig('../article/stretch_tf_mean_participant', bbox_inches = 'tight', dpi = 300)
    elif save_type == 'eps':
        fig.savefig('../article/stretch_tf_mean_participant.eps', format = 'eps', bbox_inches = 'tight', dpi = 300)
    elif save_type == 'tiff':
        fig.savefig('../article/stretch_tf_mean_participant.tif', format = 'tif', bbox_inches = 'tight', dpi = 300)
    else:
        raise ValueError(f'unsupported save_type: {save_type!r}')
plt.show()

In [None]:
modes = da.coords['spindle_mode'].values
min_freq = 11
max_freq = 16
# One column per spindle mode (4 today) instead of a hard-coded column count.
fig, axs = plt.subplots(ncols=len(modes), constrained_layout=True,
                        figsize=(5 * len(modes), 5), squeeze=False)
fig.suptitle('Mean participant * respiration cycles normalized Phase-Frequency Plot (inspi start)', fontsize = 20, y = 1.05)

for ax, mode in zip(axs.ravel(), modes):
    data = da.loc[:, mode, min_freq:max_freq, :].mean('participant').values
    ax.pcolormesh(da.coords['point'], da.coords['freqs'].loc[min_freq:max_freq], data)
    ax.set_title(mode)
    ax.set_ylabel('Freq [Hz]')
    ax.set_xlabel('Respiration Cycle Phase')
    ax.vlines(x=xvline, ymin=min_freq, ymax=max_freq, color='r')
    ax.set_xticks([inspi_label_pos, xvline, expi_label_pos])
    ax.set_xticklabels(['inspi', 'i-e transition', 'expi'], rotation=45, fontsize=10)
if save:
    fig.savefig('../presentation_5/stretch_tf_mean_participant_modes_inspi_start', bbox_inches = 'tight')
plt.show()

In [None]:
fig, ax = plt.subplots(constrained_layout = True, figsize = (10,7))
min_freq = 12
max_freq = 13
# Spindled - unspindled z-scored power, averaged over participants and the 12-13 Hz band.
data = da.loc[:, 'diff', min_freq:max_freq, :].mean(['participant', 'freqs']).values
ax.plot(da.coords['point'], data)
ax.set_title('Mean participant * Spindled respiration cycles normalized Sigma Power')
ax.set_ylabel('Sigma Power')
ax.set_xlabel('Respiration Cycle Phase')
ax.vlines(x=xvline, ymin=data.min(), ymax=data.max(), color='r')
ax.set_xticks([inspi_label_pos, xvline, expi_label_pos])
ax.set_xticklabels(['inspi', 'i-e transition', 'expi'], rotation=45, fontsize=10)
if save:
    # Guard against a None suffix, which would end up as "None" in the file name.
    title_suffix = save_title_append or ''
    fig.savefig(f'../presentation_4/stretch_sigma_power_mean_participant_inspi_start{title_suffix}', bbox_inches = 'tight')
plt.show()