In [11]:
# %reset
# #%matplotlib qt

In [12]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import scipy.stats
import mne
from mne.time_frequency import tfr_morlet
from mne.stats import permutation_cluster_1samp_test
import gc
import os
import copy
import pickle
from os.path import exists
import mne
import numpy as np
from mne import create_info
#from auxiliary import AuxFuncs
from IPython.utils import io

#gc.collect()

In [13]:
epochs_dir = 'C:\AnatArzData\YKM_data\epochs_and_evoked_allSubs'
prepro_name = "referenced"
import_type = "5Electorodes_plainEEGLAB"
trial_exclution_str = "_excOulierTrials-3.5" #"_excOulierTrials-2.5"

## Import epochs

In [14]:
class AuxFuncs:
    def __init__(self, import_path):
        # import epochs data and meta-data
        with open(import_path, "rb") as file:
            [allEvents_df, allEpochs_condIdDict, configu] = pickle.load(file)

        self.allEpochs_condIdDict = allEpochs_condIdDict
        self.allEvents_df = allEvents_df
        self.config = configu
        self.info, self.montage = self.get_subject_info()

    def get_subject_info(self, example_subject="32"):
        subject_setfile_wake_n = f'{self.config["set_files_dir"]}\\s_{example_subject}_wake_night_referenced.set'
        output_file_path = (
            f"{self.config['outputs_dir_path'] }/epochs_Wn_s{example_subject}_file"
        )

        if exists(output_file_path):
            with open(output_file_path, "rb") as config_dictionary_file:
                epochs_Wn_example_sub = pickle.load(config_dictionary_file)
                # print(epochs_Wn_example_sub)
        else:
            epochs_Wn_example_sub = mne.io.read_epochs_eeglab(
                subject_setfile_wake_n,
                events=None,
                event_id=None,
                eog=(),
                verbose=None,
                uint16_codec=None,
            )
            with open(output_file_path, "wb") as epochs_Wn_s_example_file:
                pickle.dump(epochs_Wn_example_sub, epochs_Wn_s_example_file)

        montage = mne.channels.make_standard_montage("GSN-HydroCel-128")
        epochs_Wn_example_sub_piked = epochs_Wn_example_sub.pick_channels(
            self.config["ch_names"]
        )
        epochs_Wn_example_sub_monatged = epochs_Wn_example_sub_piked.set_montage(
            montage
        )
        epochs_info = epochs_Wn_example_sub_monatged.info
        return epochs_info, epochs_Wn_example_sub_monatged

    def getEpochsPerCond(self, cond_df, y_ax, dataset, outputType="dict"):
        df_minTrials = copy.deepcopy(cond_df)
        df_minTrials = df_minTrials[
            (df_minTrials.SamplesCount > 0)
        ]  # discard cond with 0 enough samples
        keys = (str(key) for key in df_minTrials.Cond_id)
        epochs_allSamples = {str_key: dataset[str_key] for str_key in keys}
        if outputType == "array":
            epochs_allSamples_arr = np.zeros(
                (len(self.config["electrodes"]), len(y_ax), 0)
            )
            for epoch_key in epochs_allSamples:
                epochs_allSamples_arr = np.concatenate(
                    (epochs_allSamples_arr, epochs_allSamples[epoch_key]), axis=2
                )
            return df_minTrials, epochs_allSamples_arr
        return df_minTrials, epochs_allSamples

    # output: [#epochs, #elect, #times]
    def getEpochsPerConstraint(self, constraints):
        curr_df = self.allEvents_df.copy(deep=True)
        # apply constraints
        for key in constraints:
            curr_df = curr_df[(curr_df[key] == constraints[key])]

        curr_df = curr_df[
            (curr_df.SamplesCount > 0)
        ]  # discard cond with 0 enough samples
        epochsPerCond = {}
        for key in curr_df.Cond_id:
            epochsPerCond[str(key)] = self.allEpochs_condIdDict[str(key)]

        # load the metadata first and them the data. More time efficient
        epochs_name_in_const = []
        epochs_trial_num_per_cond = []
        for cond in epochsPerCond:
            for curr_cond_trial in range(epochsPerCond[cond].shape[2]):
                if len(epochs_name_in_const) == 0:
                    epochs_name_in_const = [cond]
                    epochs_trial_num_per_cond = [curr_cond_trial]
                else:
                    epochs_name_in_const = np.append(epochs_name_in_const, cond)
                    epochs_trial_num_per_cond = np.append(
                        epochs_trial_num_per_cond, curr_cond_trial
                    )

        epochs_in_const = np.zeros(
            (
                len(epochs_name_in_const),
                len(self.config["electrodes"]),
                len(self.config["times"]),
            )
        )
        trial = 0
        for cond in epochsPerCond:
            for curr_cond_trial in range(epochsPerCond[cond].shape[2]):
                epochs_in_const[trial, :, :] = epochsPerCond[cond][
                    :, :, curr_cond_trial
                ]
                trial += 1

        return epochs_in_const, epochs_name_in_const, epochs_trial_num_per_cond

    # output: [#conds, #elect, #times]
    def getEvokedPerCondAndElectd(
        self,
        constraints,
        df,
        dataset,
        y_ax,
        outputType="array",
        tmin=-0.1,
        baseline=(None, 0),
    ):
        curr_df = copy.deepcopy(df)
        # apply constraints
        for key in constraints:
            curr_df = curr_df[curr_df[key] == constraints[key]]

        conds_df, epochsPerCond = self.getEpochsPerCond(curr_df, y_ax, dataset)
        evoked_perCond_andElectd = np.zeros(
            (len(epochsPerCond), np.size(self.config["electrodes"]), np.size(y_ax))
        )

        for cond_i, cond in enumerate(epochsPerCond):
            evoked_perCond_andElectd[cond_i] = np.squeeze(
                np.nanmean(epochsPerCond[cond], axis=2)
            )

        if outputType == "array":
            return conds_df, evoked_perCond_andElectd
        if outputType == "mne":
            mne_epochs = mne.EpochsArray(
                evoked_perCond_andElectd, self.info, tmin=tmin, baseline=baseline
            )
            return conds_df, mne_epochs

    def create_output_dir(self, dir_name):
        fig_output_dir_path = f"{self.config['outputs_dir_path']}/{dir_name}"
        if not exists(fig_output_dir_path):
            mkdir(fig_output_dir_path)
        return fig_output_dir_path


In [15]:
import_path = f'{epochs_dir}\\{import_type}{trial_exclution_str}.pkl'
aux = AuxFuncs(import_path)

allEpochs_perCond = aux.allEpochs_condIdDict
allEvents_df = aux.allEvents_df
c = aux.config
times = c['times']
time0_i = c['time0_i']

fig_output_dir = f"{c['outputs_dir_path']}/timefreq_clusterPerm"
if not os.path.exists(fig_output_dir):
    os.mkdir(fig_output_dir)

In [16]:
def applyDesign(ax,title=''):
    ax.get_figure().patch.set_facecolor('#f5f1ecff')
    ax.set_facecolor('silver')
    ax.set_title(title)
    ax.legend(loc = 'upper right',prop={'size': 10})
    ax.axvline(x=0,color='gray', linestyle='--',label ="_nolegend_")
    ax.axhline(y=0, color='gray', linestyle='-',label ="_nolegend_")
    ax.set_ylabel('magnitude')
    ax.tick_params(axis='both', which='major', labelsize=14)
    ax.set_xlabel('msec')

## Time-freq functions

In [17]:
cond_1 = {'TOA_cond':'Rand','Vigilance':'N3'}
cond_2 = {'TOA_cond':'Fixed','Vigilance':'N3'}
contrast_name = "allSubs_RandvsFixed_N3"

# Parameters:
# ==================
# zscore baseline
# Tail is 0, so the statistic is thresholded on both sides of the distribution.
baseline_period = (None, 0) ## For timefreq analysis
p_value = 0.05 # default 0.05 # for clusters
decim = 1# default 2 for testing.. For reals- 1
n_permutations=1000 # default 1k for testing. For reals - 10K
min_freq = 4
freqs = np.arange(min_freq,80, 1)  # define frequencies of interest
n_cycles = freqs / min_freq # different number of cycle per frequency

tail = 0 # 0 = two-tailed test (for clusters (as we want positive and negative clusters))
p_value_pixels = 0.05

test_per_voxel='ttest'
is_tfr_performAverageBaseline=True

In [18]:
a= np.zeros((len(c['electrodes']), len(freqs),len(times)))

In [19]:
def plot_clusters_map(ax1,T_obs,ch_idx,tfr_epochs,t_thresh,T_obs_plot):
    vmax = np.max(np.abs(T_obs))
    vmin = -vmax
    ax1.imshow(T_obs[ch_idx], cmap=plt.cm.gray,
               extent=[times[0], times[-1], freqs[0], freqs[-1]],
               aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
    ax1.imshow(T_obs_plot[ch_idx], cmap=plt.cm.RdBu_r,
               extent=[times[0], times[-1], freqs[0], freqs[-1]],
               aspect='auto', origin='lower', vmin=vmin, vmax=vmax)
    #ax1.set_colorbar()
    ax1.set_xlabel('Time (ms)')
    ax1.set_ylabel('Frequency (Hz)')
    ax1.set_title(f'Induced power ({tfr_epochs.ch_names[ch_idx]})\nThreshold:{t_thresh}\ncluster p_val={p_value},n_permutation={n_permutations}\ncontrast:{cond_1};{cond_2}')
def plot_elecds_erps(ax2,mean_epochs_time_diff,ch_idx,tfr_epochs):
    ax2.plot(times,mean_epochs_time_diff.T,color='blue')
    mean_electrods = np.nanmean(mean_epochs_time_diff,axis=0)
    ax2.plot(times,mean_electrods,color='yellow',label='mean')
    ax2.plot(times,mean_epochs_time_diff[ch_idx,:],label=tfr_epochs.ch_names[ch_idx],color='red')
    ax2.legend()
    ax2.set_xlim(times[0],times[-1])
    applyDesign(ax2,'ERP difference')

## Run many contrasts at once

In [20]:
def stat_fun_wilcox(X):
    result = scipy.stats.wilcoxon(X)
    return result.statistic

def getClustersPerConditions_5timesBaseline(c,cond_1,cond_2, contrast_name,cluster_range_index,is_tfr_performAverageBaseline=True,baseline_period=(None, 0),test_per_voxel='ttest'):
    if test_per_voxel=='ttest':
        test_per_voxel = mne.stats.ttest_1samp_no_p
    else:
        test_per_voxel = stat_fun_wilcox

    epochs_time_diff = []
    epochs_power_diff = []
    for sub in c['subs']:
        currContr_conds1 = cond_1.copy()
        currContr_conds1['Subject'] = sub
        __, cont_epochs1 = aux.getEvokedPerCondAndElectd(currContr_conds1,allEvents_df,allEpochs_perCond,times, outputType='mne')

        if cond_2 == "baseline":
            baseline_epochs_data = copy.deepcopy(cont_epochs1.get_data())

            baseline_orig_data = baseline_epochs_data[:,:,0:time0_i]
            baseline_epochs_data[:,:,time0_i:time0_i*2] = baseline_orig_data
            baseline_epochs_data[:,:,time0_i*2:time0_i*3] = baseline_orig_data
            baseline_epochs_data[:,:,time0_i*3:time0_i*4] = baseline_orig_data
            baseline_epochs_data[:,:,time0_i*4:time0_i*5] = baseline_orig_data
            baseline_epochs_data[:,:,time0_i*5:] = baseline_orig_data[:,:,:13]

            with io.capture_output() as captured: # suppress output
                cont_epochs2 = mne.EpochsArray(baseline_epochs_data, aux.info,tmin=-0.1)
        else:
            currContr_conds2 = cond_2.copy()
            currContr_conds2['Subject'] = sub
            __, cont_epochs2 = aux.getEvokedPerCondAndElectd(currContr_conds2,allEvents_df,allEpochs_perCond,times, outputType='mne')

        if is_tfr_performAverageBaseline:
            with io.capture_output() as captured: # suppress output
                tfr_epochs1 = tfr_morlet(cont_epochs1, freqs, n_cycles=n_cycles, decim=decim, average=False, return_itc=False)
                tfr_epochs2 = tfr_morlet(cont_epochs2, freqs, n_cycles=n_cycles, decim=decim, average=False, return_itc=False)

            both_tfr_data = np.concatenate((tfr_epochs1.data,tfr_epochs2.data),axis=0)
            both_tfr_data_tfr = copy.deepcopy(tfr_epochs1)
            both_tfr_data_tfr.data = both_tfr_data

            #Correction is applied to all epoch and channel together in the following way: 1.Calculate the mean signal of the baseline period. 2.Subtract this mean from the entire epoch.
            with io.capture_output() as captured: # suppress output
                both_tfr_data_tfr.apply_baseline(mode='zscore', baseline=baseline_period)

            num_of_cond1_trials = tfr_epochs1.data.shape[0]
            epochs_power1 = np.mean(both_tfr_data_tfr.data[:num_of_cond1_trials,:,:,:],axis=(0,1)) # elec, freqs, time
            epochs_power2 = np.mean(both_tfr_data_tfr.data[num_of_cond1_trials:,:,:,:],axis=(0,1)) # elec, freqs, time

            epochs_time_diff.append(np.mean(cont_epochs1,axis=0) - np.mean(cont_epochs2,axis=0))
        else: # apply baseline seperately
            tfr_epochs1 = tfr_morlet(cont_epochs1, freqs, n_cycles=n_cycles, decim=decim, average=True, return_itc=False)
            tfr_epochs2 = tfr_morlet(cont_epochs2, freqs, n_cycles=n_cycles, decim=decim, average=True, return_itc=False)
            tfr_epochs1.apply_baseline(mode='zscore', baseline=baseline_period)
            tfr_epochs2.apply_baseline(mode='zscore', baseline=baseline_period)
            epochs_power1 = tfr_epochs1.data # elec, freqs, time
            epochs_power2 = tfr_epochs2.data # elec, freqs, time

            epochs_time_diff.append(np.mean(cont_epochs1.get_data(),axis=0) - np.mean(cont_epochs2.get_data(),axis=0))

        means_diff = epochs_power1-epochs_power2
        epochs_power_diff.append(means_diff)

    # (n_epochs, n_channels, n_freqs, n_times)
    epochs_power_diff_arr = np.zeros((len(c['subs']),len(c['electrodes']),len(freqs),len(times)))
    for s,subject in enumerate(c['subs']):
        epochs_power_diff_arr[s,:,:,:] = epochs_power_diff[s]

    #### Define adjacency for statistics
    tfr_epochs = tfr_epochs1
    sensor_adjacency, ch_names = mne.channels.find_ch_adjacency(tfr_epochs.info,ch_type=None)
    use_idx = [ch_names.index(ch_name) for ch_name in tfr_epochs.ch_names]
    sensor_adjacency = sensor_adjacency[use_idx][:, use_idx]
    assert sensor_adjacency.shape == (len(tfr_epochs.ch_names), len(tfr_epochs.ch_names))
    assert epochs_power_diff_arr.shape == (len(c['subs']), len(tfr_epochs.ch_names), len(tfr_epochs.freqs), len(tfr_epochs.times))
    adjacency = mne.stats.combine_adjacency(sensor_adjacency, len(tfr_epochs.freqs), len(tfr_epochs.times))
    assert adjacency.shape[0] == adjacency.shape[1] == len(tfr_epochs.ch_names) * len(tfr_epochs.freqs) * len(tfr_epochs.times)

    ### run cluster permutation
    degrees_of_freedom = len(c['subs']) - 1
    t_thresh = scipy.stats.t.ppf(1 - p_value_pixels / 2, df=degrees_of_freedom)
    T_obs, clusters, cluster_p_values, H0 = permutation_cluster_1samp_test(epochs_power_diff_arr[...,cluster_range_index[0]:cluster_range_index[1]], n_permutations=n_permutations, threshold=t_thresh, tail=tail,  out_type='mask', verbose=True, stat_fun=test_per_voxel)

    ############# plot
    T_obs_plot = np.nan * np.ones_like(T_obs)
    for c, clust_p_val in zip(clusters, cluster_p_values):
        if clust_p_val <= p_value:
            T_obs_plot[c] = T_obs[c]
    mean_epochs_time_diff = np.nanmean(epochs_time_diff,axis=0)

    padded_t_obs = np.zeros((len(c['electrodes']), len(freqs),len(times)))
    padded_t_obs_plot = np.zeros((len(c['electrodes']), len(freqs),len(times)))
    padded_t_obs[:,:,cluster_range_index[0]:cluster_range_index[1]]=T_obs
    padded_t_obs_plot[:,:,cluster_range_index[0]:cluster_range_index[1]]=T_obs_plot

    for ch_idx, elecd in enumerate(c['electrodes']):
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
        fig.subplots_adjust(left=0.12,bottom= 0.08,right= 0.96, top=0.85, wspace=0.2, hspace=0.3)
        plot_clusters_map(ax1,padded_t_obs,ch_idx,tfr_epochs,t_thresh,padded_t_obs_plot)
        plot_elecds_erps(ax2,mean_epochs_time_diff,ch_idx,tfr_epochs)
        plt.ioff()
        plt.savefig(f'{fig_output_dir}/timeFreq_clusters_{contrast_name}_{tfr_epochs.ch_names[ch_idx]}.png',bbox_inches='tight')
       
       
contrasts_5timesBase = {}    
contrasts_5timesBase["allSubs_RandvsFixed_Wn"] = {'cond_1':{'TOA_cond':'Rand','Vigilance':'Wn'},'cond_2':{'TOA_cond':'Fixed','Vigilance':'Wn'}}
for contrast in contrasts_5timesBase:
    cond_1 = contrasts_5timesBase[contrast]['cond_1']
    cond_2 = contrasts_5timesBase[contrast]['cond_2']
    print(f'Now computeing {cond_1} and {cond_2}')

    contrast_name = f"{contrast.title()}_-{test_per_voxel}"
    getClustersPerConditions_5timesBaseline(c,cond_1, cond_2, contrast_name,cluster_range_index=[time0_i, time0_i*2],is_tfr_performAverageBaseline=is_tfr_performAverageBaseline, test_per_voxel=test_per_voxel) 

Now computeing {'TOA_cond': 'Rand', 'Vigilance': 'Wn'} and {'TOA_cond': 'Fixed', 'Vigilance': 'Wn'}
Not setting metadata
33 matching events found
Setting baseline interval to [-0.1, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Not setting metadata
9 matching events found
Setting baseline interval to [-0.1, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Not setting metadata
36 matching events found
Setting baseline interval to [-0.1, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Not setting metadata
9 matching events found
Setting baseline interval to [-0.1, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Not setting metadata
33 matching events found
Setting baseline interval to [-0.1, 0.0] sec
Applying baseline correction (mode: mean)
0 projection items activated
Not setting metadata
9 matching events found
Setting baseline interval to [-0.1, 0.0] sec
Ap

100%|██████████| Permuting : 999/999 [00:06<00:00,  162.83it/s]


IndexError: only integers, slices (`:`), ellipsis (`...`), numpy.newaxis (`None`) and integer or boolean arrays are valid indices

In [None]:
np.zeros((len(c['electrodes']), len(freqs),len(times)))