In [1]:
%load_ext autoreload
%autoreload 2

# IMPORT

In [2]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import pandas as pd
import jobtools
from configuration import *
from params import *
from sigma_coupling import sigma_coupling_job
from bibliotheque import init_da

# PARAMS

In [3]:
# Figure parameters read from the star-imported `sigma_coupling_figs_params` dict.
sigma_coupling_chan, fig_global_cycle_type, save_article = (
    sigma_coupling_figs_params[param_name]
    for param_name in ('sigma_coupling_chan', 'fig_global_cycle_type', 'save_article')
)
# Position of the inspiration -> expiration transition on the 0..1 phase axis
# (drawn as the red vertical line and used as the middle x tick in the figures below).
transition_ratio = sigma_coupling_figs_params['sigma_coupling_params']['transition_ratio']
save_article

True

In [4]:

# Output location / format / resolution for every figure saved by this notebook.
# Both branches of the previous `if save_article:` assigned identical values, so the
# settings are set once; `save_article` currently selects the same output as the default.
# TODO: when `save_article` is True, route figures to `article_folder` (if it is defined
# in `configuration`) - that assignment had been left commented out.
save_folder = base_folder / 'results' / 'sigma_coupling_figures'
fig_format = '.tif'
dpis = 300

# savefig() does not create missing directories, so make sure the target exists.
save_folder.mkdir(parents=True, exist_ok=True)

# PREPARE VARIABLES

In [5]:
# Collect, for every subject, its phase-frequency maps and a one-row table with the
# number of averaged cycles per cycle type (from the `data_n_cycle_averaged` attribute).
concat_phase_freqs = []
concat_Ns = []

for run_key in run_keys:
    coupling_maps = sigma_coupling_job.get(run_key)['sigma_coupling']
    n_cycles_row = pd.Series(
        data=coupling_maps.attrs['data_n_cycle_averaged'],
        index=coupling_maps.coords['cycle_type'].values,
    ).to_frame().T
    n_cycles_row.insert(0, 'subject', run_key)
    concat_Ns.append(n_cycles_row)
    concat_phase_freqs.append(coupling_maps)

# Stack all subjects along a new 'subject' dimension labelled by run key.
mean_phase_freqs = (
    xr.concat(concat_phase_freqs, dim='subject')
    .assign_coords({'subject': run_keys})
)
points = mean_phase_freqs.coords['point'].values
freqs = mean_phase_freqs.coords['freq'].values

# Subjects as rows, cycle types as columns.
Ns = pd.concat(concat_Ns).set_index('subject')

cycle_types = mean_phase_freqs.coords['cycle_type'].values

In [6]:
# Display the cycle types available in the sigma-coupling results (the 'cycle_type' coordinate)
cycle_types

array(['all', 'spindled', 'unspindled', 'N2', 'N3', 'diff'], dtype=object)

# 1st FIGURE - One figure per subject: phase-frequency maps with one row per channel and one column per cycle type (all, spindled, unspindled, N2, N3, diff)

In [9]:
# for subject in run_keys:
for subject in run_keys[:2]:
    subject_title = 'P{}'.format(subject.split('S')[1])

    nrows = len(channels_events_select)
    fig, axs = plt.subplots(nrows=nrows, ncols = cycle_types.size, sharex = True, sharey = True, constrained_layout = True, figsize = (20,20))
    fig.suptitle(f'{subject_title}', fontsize = 20)

    for row, chan in enumerate(channels_events_select):
        for col, cycle_type in enumerate(cycle_types):

            ax = axs[row, col]
            N = Ns.loc[subject,cycle_type]
            data = mean_phase_freqs.loc[subject, cycle_type,chan,:,:].data.T
            im = ax.pcolormesh(points, freqs, data)
            ax.axvline(x = transition_ratio, color = 'r')

            ax.set_title(f'{chan} - {cycle_type} - N = {N}')
            if col == 0:
                ax.set_ylabel('Freq [Hz]')
            ax.set_xlabel('Phase')
            ax.set_xticks([0, 0, transition_ratio, 1])
            ax.set_xticklabels([0, 0, 'inspi-expi', '360°'], rotation=45, fontsize=10)

    # fig.savefig(save_folder / f'{subject_title}_phase_freq{fig_format}', bbox_inches = 'tight', dpi = dpis)
    plt.close()

# 2nd FIGURE - One figure per channel: one subplot per subject (20 subjects), each showing the phase-frequency map of the chosen cycle type

In [None]:
# Figure 2: for each channel, one panel per subject showing the phase-frequency map of
# `fig_global_cycle_type`. Subjects are laid out row-major on a grid of `ncols` columns
# (4 x 5 for 20 subjects); panels left over in the last row are removed.
n_subjects = len(run_keys)
ncols = 5
nrows = int(np.ceil(n_subjects / ncols))

for chan in channels_events_select:
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 10), sharex=True, sharey=True,
                            constrained_layout=True, squeeze=False)
    fig.suptitle(f'{chan}', fontsize=20, y=1.05)
    for i, ax in enumerate(axs.flat):
        if i >= n_subjects:
            ax.remove()
            continue
        c = i % ncols
        subject = run_keys[i]
        # transposed so rows follow `freqs` and columns follow `points`
        data = mean_phase_freqs.loc[subject, fig_global_cycle_type, chan, :, :].data.T

        im = ax.pcolormesh(points, freqs, data)
        ax.axvline(x=transition_ratio, color='r')
        subject_title = 'P{}'.format(subject.split('S')[1])
        ax.set_title(subject_title)
        if c == 0:
            ax.set_ylabel('Freq [Hz]')
        # x labels on the lowest panel of each column (no panel below it)
        if i + ncols >= n_subjects:
            ax.set_xlabel('Phase')
            ax.set_xticks([0, transition_ratio, 1])
            ax.set_xticklabels([0, 'inspi-expi', '360°'], rotation=45, fontsize=10)

        plt.colorbar(im, ax=ax, label='Power in µV**2')

    fig.savefig(save_folder / f'mean_phase_freq_subjects_detailed_{chan}{fig_format}', bbox_inches='tight', dpi=dpis)
    plt.close(fig)

# 3rd FIGURE - Mean across subjects of the z-scored phase-frequency matrix for the chosen cycle type, one subplot per channel

In [14]:
def zscore(da):
    """Standardise ``da`` to zero mean and unit standard deviation.

    Mean and standard deviation are taken over *all* elements (no ``dim``),
    so a 2-D phase x frequency map is normalised as a whole. The object's own
    ``.std()`` default is used (ddof=0 for xarray / NumPy).

    Parameters
    ----------
    da : array-like exposing ``.mean()`` and ``.std()`` (e.g. xarray.DataArray)

    Returns
    -------
    Same type as ``da``: ``(da - mean) / std``.
    """
    centred = da - da.mean()
    spread = da.std()
    return centred / spread

In [None]:
# Container for the z-scored maps of `fig_global_cycle_type`, with coordinates in the order
# subject, chan, point, freq (init_da is project-local; presumably it allocates an empty DataArray).
mean_phase_freqs_zscored = init_da(dict(subject=run_keys,
                                        chan=channels_events_select,
                                        point=points,
                                        freq=freqs))

In [None]:
# Z-score each subject / channel phase-frequency map of the chosen cycle type
# (normalised over the whole point x freq matrix).
phase_freqs_fig_type = mean_phase_freqs.sel(cycle_type=fig_global_cycle_type)
for subject in run_keys:
    for chan in channels_events_select:
        single_map = phase_freqs_fig_type.sel(subject=subject, chan=chan)
        mean_phase_freqs_zscored.loc[subject, chan, :, :] = zscore(single_map)

In [None]:
# Figure 3: mean across subjects of the z-scored phase-frequency map of
# `fig_global_cycle_type`, one panel per channel on a grid with `ncols` columns.
n_chans = len(channels_events_select)
ncols = 4
nrows = int(np.ceil(n_chans / ncols))

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(22, 10), constrained_layout=False,
                        sharex=False, sharey=False, squeeze=False)

# Remove every panel that has no channel. With a fixed grid, an extra channel would be
# drawn on an already-removed axis and silently vanish.
for ax in axs.flat[n_chans:]:
    ax.remove()

# Shared colour scale clipped to the [delta, 1 - delta] quantiles of all z-scored values
delta = 0.1
vmin = float(mean_phase_freqs_zscored.quantile(delta))
vmax = float(mean_phase_freqs_zscored.quantile(1 - delta))

for ax, chan in zip(axs.flat, channels_events_select):

    # subject-average, transposed so rows follow `freqs` and columns follow `points`
    data = mean_phase_freqs_zscored.sel(chan=chan).mean('subject').data.T
    im = ax.pcolormesh(points, freqs, data, vmin=vmin, vmax=vmax)
    ax.axvline(x=transition_ratio, color='r')

    ax.set_ylabel('Freq [Hz]')
    ax.set_xticks([0, transition_ratio, 1])
    ax.set_xticklabels([0, 'inspi-expi', '360°'], fontsize=10)
    ax.set_title(chan)

# Single colour bar to the right of the grid (figure coordinates; kept by bbox_inches='tight')
ax_x_start = 1.02
ax_x_width = 0.01
ax_y_start = 0
ax_y_height = 1
cbar_ax = fig.add_axes([ax_x_start, ax_y_start, ax_x_width, ax_y_height])
clb = fig.colorbar(im, cax=cbar_ax)
clb.ax.set_title('Normalized power [AU]', fontsize=10)

fig.savefig(save_folder / f'mean_phase_freq_across_subjects{fig_format}', bbox_inches='tight', dpi=dpis)
plt.close(fig)

# 4th FIGURE - Mean across subjects of the z-scored phase-frequency matrix, plotted for several cycle types and sleep stages (N2, N3)

In [10]:
# NOTE(review): `chan_test` is not referenced anywhere else in this notebook - likely a leftover; consider removing.
chan_test = 'Fz'

In [15]:
# Re-display the available cycle types (same value as shown after the "prepare variables" cell)
cycle_types

array(['all', 'spindled', 'unspindled', 'N2', 'N3', 'diff'], dtype=object)

In [11]:
# Container for the z-scored maps of every cycle type, with coordinates in the order
# subject, cycle_type, chan, point, freq (init_da is project-local; presumably an empty DataArray).
mean_phase_freqs_zscored_c_type = init_da(dict(subject=run_keys,
                                               cycle_type=cycle_types,
                                               chan=channels_events_select,
                                               point=points,
                                               freq=freqs))

In [None]:
# Z-score every subject / channel / cycle-type map over its whole point x freq matrix,
# then average the z-scored maps across subjects.
for subject in run_keys:
    for chan in channels_events_select:
        subject_chan_maps = mean_phase_freqs.sel(subject=subject, chan=chan)
        for c_type in cycle_types:
            normalised = zscore(subject_chan_maps.sel(cycle_type=c_type))
            mean_phase_freqs_zscored_c_type.loc[subject, c_type, chan, :, :] = normalised

mean_phase_freqs_zscored_c_type_mean_sub = mean_phase_freqs_zscored_c_type.mean('subject')

In [None]:
# Figure 4: for each channel, mean across subjects of the z-scored phase-frequency maps,
# one row per sleep stage and one column per cycle type of that stage.
c_types_fig = {'N2': ['N2', 'spindled_N2', 'unspindled_N2', 'diff_N2'],
               'N3': ['N3', 'spindled_N3', 'unspindled_N3', 'diff_N3']
              }
# Saving was disabled (savefig commented out); set to True to write one file per channel.
save_fig4 = False

# Fail early with an explicit message rather than a KeyError inside the plotting loop.
available_c_types = set(mean_phase_freqs_zscored_c_type_mean_sub.coords['cycle_type'].values)
requested_c_types = {c_type for stage_c_types in c_types_fig.values() for c_type in stage_c_types}
missing_c_types = sorted(requested_c_types - available_c_types)
assert not missing_c_types, f'cycle types missing from sigma_coupling results: {missing_c_types}'

stages = list(c_types_fig)
nrows = len(stages)
ncols = max(len(stage_c_types) for stage_c_types in c_types_fig.values())

for chan in channels_events_select:

    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, figsize=(22, 10), constrained_layout=False,
                            sharex=False, sharey=False, squeeze=False)
    # panel titles give only stage and cycle type, so name the channel on the figure
    fig.suptitle(chan, fontsize=20)

    for r, stage in enumerate(stages):
        for c, c_type in enumerate(c_types_fig[stage]):
            ax = axs[r, c]

            # `.loc[c_type, chan]` leaves (point, freq); transpose so rows follow `freqs`
            # and columns follow `points`, as pcolormesh(points, freqs, data) expects
            data = mean_phase_freqs_zscored_c_type_mean_sub.loc[c_type, chan, :, :].values.T
            # no vmin/vmax: each panel is auto-scaled independently
            im = ax.pcolormesh(points, freqs, data)
            ax.axvline(x=transition_ratio, color='r')

            ax.set_ylabel('Freq [Hz]')
            ax.set_xlabel('Respiration phase')
            ax.set_xticks([0, transition_ratio, 1])
            ax.set_xticklabels([0, 'inspi-expi', '360°'], fontsize=10)
            ax.set_title(f'{stage} - {c_type}')
            # independent colour ranges, so each panel needs its own colour bar
            fig.colorbar(im, ax=ax, label='Normalized power [AU]')

    if save_fig4:
        # distinct file name so figure 3 ('mean_phase_freq_across_subjects') is not overwritten
        fig.savefig(save_folder / f'mean_phase_freq_across_subjects_stages_{chan}{fig_format}',
                    bbox_inches='tight', dpi=dpis)
    plt.show()