# SWB ROI + TFR Info
01/30/24

from bp re-referenced data

In [59]:
import numpy as np
import mne
from glob import glob
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import seaborn as sns
from scipy.stats import zscore, linregress, ttest_ind, ttest_rel, ttest_1samp
import pandas as pd
from mne.preprocessing.bads import _find_outliers
import os 
import joblib
import re
import datetime
import scipy
import datetime


In [15]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
import sys
sys.path.append('/sc/arion/projects/guLab/Alie/SWB/ephys_analysis/LFPAnalysis/')

In [17]:
from LFPAnalysis import lfp_preprocess_utils, sync_utils, analysis_utils, nlx_utils

In [60]:
# Specify root directory for un-archived data and results 
base_dir = '/sc/arion/projects/guLab/Alie/SWB/'
anat_dir = f'{base_dir}ephys_analysis/recon_labels/'
neural_dir = f'{base_dir}ephys_analysis/data/'
behav_dir = f'{base_dir}swb_behav_models/data/behavior_preprocessed/'

subj_list = pd.read_excel(f'{base_dir}ephys_analysis/subj_info/SWB_elec_info_minerva.xlsx', usecols=[0])
subj_ids = subj_list.PatientID.to_list()

bdi_list = pd.read_csv(f'{base_dir}ephys_analysis/subj_info/bdi_bai_info_01312024.csv', usecols=['bdi'])
subj_bdis = bdi_list.bdi.to_list()

date = datetime.date.today().strftime('%m%d%Y')
print(date)

01312024


# Get YBA Info for all bp ref anodes

In [19]:
#load all bp data ch names for each subj + extract anode name 

subj_bp_ch_names = {}

for subj_id in subj_ids:
    bp_data = mne.io.read_raw_fif(f'{neural_dir}{subj_id}/bp_ref_ieeg.fif',preload=False)
    anode_list = [x.split('-')[0] for x in bp_data.ch_names] 
    subj_bp_ch_names[subj_id] = {'anode-cathode':bp_data.ch_names,
                                'anode':anode_list}
    
    

Opening raw data file /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS002/bp_ref_ieeg.fif...
    Range : 0 ... 1083499 =      0.000 ...  2166.998 secs
Ready.
Opening raw data file /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS003/bp_ref_ieeg.fif...
    Range : 0 ... 1250999 =      0.000 ...  2501.998 secs
Ready.
Opening raw data file /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS004/bp_ref_ieeg.fif...
    Range : 0 ... 1155311 =      0.000 ...  2310.622 secs
Ready.
Opening raw data file /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS009/bp_ref_ieeg.fif...
    Range : 0 ... 1077655 =      0.000 ...  2155.310 secs
Ready.
Opening raw data file /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS011/bp_ref_ieeg.fif...
    Range : 0 ... 1578124 =      0.000 ...  3156.248 secs
Ready.
Opening raw data file /sc/arion/projects/guLab/Alie/SWB/ephys_analysis/data/MS015/bp_ref_ieeg.fif...
    Range : 0 ... 1108437 =      0.000 ...  2216.874 secs
Ready.
Open

In [20]:
subj_bp_ch_names

{'MS002': {'anode-cathode': ['lacas1-lacas2',
   'lacas2-lacas3',
   'lacas3-lacas4',
   'lacas4-lacas5',
   'lacas5-lacas6',
   'lacas6-lacas7',
   'lacas7-lacas8',
   'lagit1-lagit2',
   'lagit2-lagit3',
   'lagit3-lagit4',
   'laims1-laims2',
   'laims2-laims3',
   'laims3-laims4',
   'laims4-laims5',
   'laims5-laims6',
   'laims11-laims12',
   'laims12-laims13',
   'lhplt1-lhplt2',
   'lhplt2-lhplt3',
   'lhplt3-lhplt4',
   'lhplt9-lhplt10',
   'lhplt10-lhplt11',
   'lhplt11-lhplt12',
   'lloif1-lloif2',
   'lloif2-lloif3',
   'lloif7-lloif8',
   'lloif8-lloif9',
   'lmoif11-lmoif12',
   'lmoif12-lmoif13',
   'lpips1-lpips2',
   'lpips10-lpips11',
   'lsif6-lsif7',
   'lsif7-lsif8',
   'racas2-racas3',
   'racas7-racas8',
   'racas8-racas9',
   'racas12-racas13',
   'ragit1-ragit2',
   'ragit2-ragit3',
   'ragit7-ragit8',
   'ragit8-ragit9',
   'ragit9-ragit10',
   'raims1-raims2',
   'raims2-raims3',
   'raims3-raims4',
   'raims11-raims12',
   'raims12-raims13',
   'raims13-raim

In [25]:
subj_id

'MS024'

In [28]:
### updated ROI function 

# There are some things that MNE is not that good at, or simply does not do. Let's write our own code for these. 
def select_rois_picks(elec_data, chan_name, YBA_ROI_labels, manual_col='ManualExamination'):
    
    """
    Grab specific roi for the channel you are looking at 
    """
    YBA_ROI_labels['Long.name'] = YBA_ROI_labels['Long.name'].str.lower().str.replace(" ", "")
    
    roi = np.nan
    NMM_label = elec_data[elec_data.label==chan_name].NMM.str.lower().str.strip()
    BN246_label = elec_data[elec_data.label==chan_name].BN246.str.lower().str.strip()

    # Account for individual differences in labelling: 
    YBA_label = elec_data[elec_data.label==chan_name].YBA_1.str.lower().str.replace(" ", "")
    manual_label = elec_data[elec_data.label==chan_name][manual_col].str.lower().str.replace(" ", "")

    # Only NMM assigns entorhinal cortex 
    if NMM_label.str.contains('entorhinal').iloc[0]:
        roi = 'EC'

    # First priority: Use YBA labels if there is no manual label
    if pd.isna(manual_label).iloc[0]:
        try:
            roi = YBA_ROI_labels[YBA_ROI_labels['Long.name']==YBA_label.values[0]].Custom.values[0]
        except IndexError:
            # This is probably white matter or out of brain, but not manually labelled as such
            roi = np.nan
    else:
        # Now look at the manual labels: 
        if YBA_label.str.contains('unknown').iloc[0]:
            # prioritize thalamus labels! Which are not present in YBA for some reason
            if (manual_label.str.contains('thalamus').iloc[0]):
                roi = 'THAL'
            else:
                try:
                    roi = YBA_ROI_labels[YBA_ROI_labels['Long.name']==manual_label.values[0]].Custom.values[0]
                except IndexError: 
                    # This is probably white matter or out of brain, and manually labelled as such
                    roi = np.nan

    # Next  use BN246 labels if still unlabeled
    if pd.isna(roi):
        # Just use the dumb BN246 label from LeGui, stripping out the hemisphere which we don't care too much about at the moment
        if (BN246_label.str.contains('hipp').iloc[0]):
            roi = 'HPC'
        elif (BN246_label.str.contains('amyg').iloc[0]):
            roi = 'AMY'
        elif (BN246_label.str.contains('ins').iloc[0]):
            roi = 'INS'
        elif (BN246_label.str.contains('ifg').iloc[0]):
            roi = 'IFG'
        elif (BN246_label.str.contains('org').iloc[0]):
            roi = 'OFC' 
        elif (BN246_label.str.contains('mfg').iloc[0]):
            roi = 'dlPFC'
        elif (BN246_label.str.contains('sfg').iloc[0]):
            roi = 'dmPFC'

    if pd.isna(roi):
        # Just use the dumb NMM label from LeGui, stripping out the hemisphere which we don't care too much about at the moment
        if (NMM_label.str.contains('hippocampus').iloc[0]):
            roi = 'HPC'
        if (NMM_label.str.contains('amygdala').iloc[0]):
            roi = 'AMY'
        if (NMM_label.str.contains('acgc').iloc[0]):
            roi = 'ACC'
        if (NMM_label.str.contains('mcgc').iloc[0]):
            roi = 'MCC'
        if (NMM_label.str.contains('ofc').iloc[0]):
            roi = 'OFC'
        if (NMM_label.str.contains('mfg').iloc[0]):
            roi = 'dlPFC'
        if (NMM_label.str.contains('sfg').iloc[0]):
            roi = 'dmPFC'  

    if pd.isna(roi):
        # This is mostly temporal gyrus
        roi = 'Unknown'

    return roi


In [29]:
anodes_anat_dict = {}

# Load the YBA ROI labels, custom assigned by Salman: 
file_path = '/sc/arion/projects/guLab/Alie/SWB/ephys_analysis/LFPAnalysis/LFPAnalysis/YBA_ROI_labelled.xlsx'
YBA_ROI_labels = pd.read_excel(file_path)


for subj_id in subj_ids:
    anat_file = glob(f'{anat_dir}{subj_id}_labels.csv')[0]
    elec_locs = pd.read_csv(anat_file)
    elec_locs = elec_locs[elec_locs.columns.drop(list(elec_locs.filter(regex='Unnamed')))]
    locs_names = list(elec_locs.label.str.lower())
    elec_locs.label = elec_locs.label.str.lower()
    
    subj_anodes = subj_bp_ch_names[subj_id]['anode']
    
    yba_anode_list = []
    roi_anode_list = []
    
    for ch in subj_anodes:
        if pd.isnull(elec_locs.loc[elec_locs.label==ch,'ManualExamination']).values.any():
            yba_anode_list.append(elec_locs.YBA_1[elec_locs.label==ch]) 
        else:
            yba_anode_list.append(elec_locs.ManualExamination[elec_locs.label==ch]) 
            
        ch_roi = select_rois_picks(elec_locs,ch,YBA_ROI_labels) #can't use unless LFP Analysis is installed as a package
        roi_anode_list.append(ch_roi)
    
    anodes_anat_dict[subj_id] = {'anode':subj_anodes,'yba_labels':yba_anode_list,'roi_labels':roi_anode_list}
    
    

In [None]:
# ch_roi = select_rois_picks(elec_data, chan_name, manual_col='collapsed_manual')

In [30]:
anodes_anat_dict

{'MS002': {'anode': ['lacas1',
   'lacas2',
   'lacas3',
   'lacas4',
   'lacas5',
   'lacas6',
   'lacas7',
   'lagit1',
   'lagit2',
   'lagit3',
   'laims1',
   'laims2',
   'laims3',
   'laims4',
   'laims5',
   'laims11',
   'laims12',
   'lhplt1',
   'lhplt2',
   'lhplt3',
   'lhplt9',
   'lhplt10',
   'lhplt11',
   'lloif1',
   'lloif2',
   'lloif7',
   'lloif8',
   'lmoif11',
   'lmoif12',
   'lpips1',
   'lpips10',
   'lsif6',
   'lsif7',
   'racas2',
   'racas7',
   'racas8',
   'racas12',
   'ragit1',
   'ragit2',
   'ragit7',
   'ragit8',
   'ragit9',
   'raims1',
   'raims2',
   'raims3',
   'raims11',
   'raims12',
   'raims13',
   'rhplt1',
   'rhplt2',
   'rhplt8',
   'rhplt9',
   'rmoif1',
   'rmoif2',
   'rmoif6',
   'rmoif7',
   'rmoif8',
   'rmoif9'],
  'yba_labels': [0    Left cingulate gyrus D
   Name: YBA_1, dtype: object,
   4    Left cingulate gyrus E
   Name: YBA_1, dtype: object,
   5    Left cingulate gyrus F
   Name: YBA_1, dtype: object,
   6    Left cingu

In [45]:
my_rois = ['OFC','vmPFC','vlPFC','dmPFC','dlPFC','AINS','PINS','ACC','MCC','PCC','AMY','HPC','PHG']
anat_df_cols = ['subj_id','bdi']+my_rois

anat_df = pd.DataFrame(columns = anat_df_cols)

In [58]:
for subj_id in subj_ids:
    
    subj_counts = {}
    subj_rois = anodes_anat_dict[subj_id]['roi_labels']
    
    for roi in my_rois:
        subj_counts[roi] = [subj_rois.count(roi)]
    
    subj_counts['subj_id'] = subj_id
    subj_counts['bdi']     = subj_bdis[subj_ids.index(subj_id)]
    anat_df = pd.concat([anat_df,pd.DataFrame(subj_counts)])
        
anat_df       

Unnamed: 0,subj_id,bdi,OFC,vmPFC,vlPFC,dmPFC,dlPFC,AINS,PINS,ACC,MCC,PCC,AMY,HPC,PHG
0,MS002,13,6,0,5,7,4,8,1,6,0,0,5,5,0
0,MS003,32,3,1,1,4,2,6,1,6,0,0,4,0,0
0,MS004,11,0,0,2,1,2,0,0,0,0,0,0,1,0
0,MS009,41,1,0,0,0,4,1,2,4,0,0,0,0,0
0,MS011,14,0,0,2,0,0,0,1,2,1,0,0,0,0
0,MS015,19,0,0,1,2,3,0,0,1,2,0,1,2,0
0,MS016,14,3,0,1,3,2,0,4,2,0,0,2,7,0
0,MS017,8,0,2,0,0,2,2,0,3,0,0,1,2,0
0,MS019,7,4,2,3,3,1,0,0,3,3,0,0,0,0
0,MS020,16,0,0,0,3,0,0,7,2,0,0,0,1,0


In [61]:
anat_df.to_csv(f'{base_dir}ephys_analysis/subj_info/anat_coverage_info_{date}.csv')

# TFRs by Condition

In [None]:
#rois = ['hippocampus', 'amygdala', 'insula', 'cingulate' ,'frontal']
#region = 'frontal orbital'
region = 'frontal orbital'

# band definitions for y-axis
#yticks =  [70, 100, 125, 150, 175] 
yticks =  [4, 30, 60, 120, 180] 

# task condition to contrast 
#conditions = ["(all_epochs == 1)"]
conditions = ["(acquired == 1)" and "(reward == 1)",
             "(acquired == 0)"]


# Define your conditions as lambda functions
conditions = [
    lambda acquired, reward: acquired == 1 and reward == 1,
    lambda acquired: acquired == 0
]

cond_name = 'acquired'


for subj_id in subj_ids:
    #save_path = f'{base_dir}/work/qasims01/MemoryBanditData/EMU/Subjects/{subj_id}'
    # Get electrode df 
    # electrode files could either be csv or excel
    elec_files = glob(f'{anat_dir}/*.csv') + glob(f'{anat_dir}/*.xlsx')
    # There should really only be one 
    elec_file = elec_files[0]
    elec_data = lfp_preprocess_utils.load_elec(elec_file)

    anode_list = [x.split('-')[0] for x in epochs_all_evs[event].ch_names]
    elec_df = elec_data[elec_data.label.str.lower().isin(anode_list)]
    elec_df['label'] = epochs_all_evs[event].ch_names

    picks = analysis_utils.select_picks_rois(elec_df, region)
    print(picks)
    #picks = ['lmoif1-lmoif8'] #- plot just a single electrode 
    
    for event in analysis_evs:
        fig, ax = plt.subplots(1, 2, figsize=(10,3))
        for ix, cond in enumerate(conditions):
            # Set the times for 
            times = power_epochs[event].times
            plot_data = np.nanmean(np.nanmean(power_epochs[event][cond].copy().pick_channels(picks).data, axis=0), axis=0)

            im = ax[ix].imshow(plot_data,
                      extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation=None,
                      aspect='auto', origin='lower', cmap='RdBu_r', vmin = -np.nanmax(np.abs(plot_data)), vmax = np.nanmax(np.abs(plot_data)))
            ax[ix].set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{region}_{cond}_{event}')
            fig.colorbar(im, ax=ax[ix])

In [None]:
# Let's load the power epochs for specific conditions: 

analysis_evs = ['feedback_start']

# band definitions for y-axis
# yticks = [4, 8, 13, 30, 60, 120]

# task condition to contrast 
conditions = ['(rpe>0)',
             '(rpe<0)']

cond_name = 'rpe'

# conditions = ['(rpe>0) & (hits==1)',
#              '(rpe>0) & (hits==0)']

# cond_name = 'SME'

tfr_group_data = {f'{a}': {f'{b}': [] for b in conditions} for a in rois}

power_epochs = {f'{a}': {f'{b}': np.nan for b in analysis_evs} for a in subj_ids}

for subj_id in subj_ids: 
    if subj_id in ['MS024', 'MS034', 'MS038']:
        continue
    for event in analysis_evs:
        filepath = f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/scratch/TFR'
        power_epochs[subj_id][event] = mne.time_frequency.read_tfrs(f'{filepath}/{event}-tfr.h5')[0]
                
        # replace IED metadata with behavioral metadata
        if 'SME' in cond_name:
            epochs_to_analyze = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True).trials_gamble.values - 1

            power_epochs[subj_id][event] = power_epochs[subj_id][event][epochs_to_analyze.astype(int)]

            power_epochs[subj_id][event].metadata = combined_df[(combined_df.participant==subj_id) & (combined_df.condition=='Day1')].dropna(subset=['trials_gamble']).sort_values(by='trials_gamble').reset_index(drop=True)
        else:
            power_epochs[subj_id][event].metadata = learn_df[(learn_df.participant==subj_id)]
            
        # Get electrode df 
        elec_df = pd.read_csv(f'{base_dir}/projects/guLab/Salman/EphysAnalyses/{subj_id}/Day1_reref_elec_df')
        
        
        progress_bar = tqdm(elec_df.label, ascii=True, desc='Aggregating TFRs')

        for elec_ix, chan in enumerate(progress_bar):
            region = elec_df.salman_region.iloc[elec_ix]
            if region in rois:
                for ix, cond in enumerate(conditions):
                    # Set the times for 
                    times = power_epochs[subj_id][event].times
                    if 'SME' in cond_name:
                        plot_data = np.nanmean(np.nanmean(power_epochs[subj_id][event][cond].pick_channels([chan]).data, axis=0), axis=0)
                    else:
                        plot_data = np.nanmean(np.nanmean(power_epochs[subj_id][event][cond].pick_channels([chan]).data, axis=0), axis=0)
                    tfr_group_data[region][cond].append(plot_data)                    
                    

In [None]:
# rois = ['orbital', 'cingulate', 'amygdala', 'insular', 'hippocampus','superior frontal', 'middle frontal', 'inferior frontal']
region = 'orbital'

# band definitions for y-axis
#freqs = np.logspace(*np.log10([4,200]),num=40)
yticks = np.arange(4,200,step=10)

#cond_names = [ev(0,-,+),cr,rpe(0,-,+),profit(0,-,+),total_cpe(-,+),decision_cpe(-,+),regret(-,0),relief(0,+)]
# task condition to contrast 
conditions = ["(total_cpe<0)",
             "(total_cpe>0)"]


cond_name = 'total_cpe'

event = 'DecisionOnset'
for subj_id in subj_ids:
    
    # Get electrode df 
    elec_file = anat_path + subj_id + '_labels.csv'

    elec_data = lfp_preprocess_utils.load_elec(elec_file)

    #elec_data = elec_data.dropna(how='all')
    anode_list = [x.split('-')[0] for x in epochs_all_subjs_all_evs[subj_id][event].ch_names] #having problem here because anode names does not match elec_df!!
    #new_anode_list = lfp_preprocess_utils.match_elec_names(anode_list,elec_data.label)
    elec_df = elec_data[elec_data.label.str.lower().isin(anode_list)]
    #elec_df = elec_data[elec_data.label.str.lower().isin(new_anode_list[0])]

    
    
    elec_df['label'] = epochs_all_subjs_all_evs[subj_id][event].ch_names

    picks = analysis_utils.select_picks_rois(elec_df, region)
    #print(picks)
    #print(subj_id,len(picks))
    
    for event in analysis_evs:
        fig, ax = plt.subplots(1, 2, figsize=(25, 9), dpi=300)
        for ix, cond in enumerate(conditions):
            plot_title = ['Regret','Relief']
            region_name = region[0].upper() + region[1:]
            # Set the times for 
            times = power_epochs[subj_id][event].times
            plot_data = np.nanmean(np.nanmean(power_epochs[subj_id][event][cond].copy().pick_channels(picks).data, axis=0), axis=0)

            im = ax[ix].imshow(plot_data,
                      extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation='Bicubic',
                      aspect='auto', origin='lower', cmap='RdBu_r',vmin = -np.nanmax(np.abs(plot_data)), vmax = np.nanmax(np.abs(plot_data)))
            ax[ix].set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency',title=f'{region_name} {plot_title[ix]} Encoding')
            fig.colorbar(im, ax=ax[ix])
    plt.savefig(f'{base_dir}figs/final_total_cpe_TFRs/{subj_id}_{region}_{cond_name}.png', dpi='figure', format='png', metadata=None,
        bbox_inches=None, pad_inches=0.1,
        facecolor='auto', edgecolor='auto',
        backend=None)   


In [None]:
subj_id = 'MS002'
region = 'orbital'
cond_name = 'total_cpe'

conditions = ["(total_cpe<0)",
             "(total_cpe>0)"]

save_path = f'{base_dir}figs/{cond_name}_clusters/{subj_id}_{region}_{cond_name}'

# Get electrode df 
elec_file = anat_path + subj_id + '_labels.csv'

elec_data = lfp_preprocess_utils.load_elec(elec_file)

anode_list = [x.split('-')[0] for x in epochs_all_subjs_all_evs[subj_id][event].ch_names]
elec_df = elec_data[elec_data.label.str.lower().isin(anode_list)]
elec_df['label'] = epochs_all_subjs_all_evs[subj_id][event].ch_names
    
picks = analysis_utils.select_picks_rois(elec_df, region)

for event in analysis_evs:
    
    # Average the data in each condition across channels 
    X = np.nanmean(power_epochs[subj_id][event].copy().pick_channels(picks).data, axis=1)
    
    F_obs, clusters, cluster_p_values, H0 = \
    mne.stats.permutation_cluster_1samp_test(X, n_permutations=500, out_type='mask', verbose=True)
    
    if any(cluster_p_values<=0.05):
#     print(region)
        # Create new stats image with only significant clusters
        fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

        times =  power_epochs[subj_id][event][conditions[1]].times


        # Average the data in each condition across epochs for plotting
        evoked_power = np.nanmean(X, axis=0)
    #     evoked_power_2 = np.nanmean(X[1], axis=0)
    #     evoked_power_contrast = evoked_power_1 - evoked_power_2
        signs = np.sign(evoked_power)

        F_obs_plot = np.nan * np.ones_like(F_obs)
        for c, p_val in zip(clusters, cluster_p_values):
            if p_val <= 0.05:
                F_obs_plot[c] = F_obs[c] * signs[c]

        ax.imshow(F_obs,
                  extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation = 'Bicubic',
                  aspect='auto', origin='lower', cmap='Greys_r')
        max_F = np.nanmax(abs(F_obs_plot))
        ax.imshow(F_obs_plot,
                  extent=[times[0], times[-1], freqs[0], freqs[-1]],
                  aspect='auto', origin='lower', cmap='RdBu_r',
                  vmin=-max_F, vmax=max_F)

        ax.set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{region}_{cond_name}_{event}')

            # # ax.set_title(f'Induced power ({ch_name})')
    
    

In [None]:
subj_id = 'MS002'
region = 'frontal'

conditions = ["(total_cpe<0)",
             "(total_cpe>0)"]
cond_name = 'total_cpe'


save_path = f'{base_dir}figs/region_two_sammple_clust/{subj_id}_{region}_{cond_name}'

elec_file = anat_path + subj_id + '_labels.csv'

elec_data = lfp_preprocess_utils.load_elec(elec_file)

anode_list = [x.split('-')[0] for x in epochs_all_subjs_all_evs[subj_id][event].ch_names]
elec_df = elec_data[elec_data.label.str.lower().isin(anode_list)]
elec_df['label'] = epochs_all_subjs_all_evs[subj_id][event].ch_names
    
picks = analysis_utils.select_picks_rois(elec_df, region)



for event in analysis_evs:
    
    # Average the data in each condition across channels 
    X = [np.nanmean(power_epochs[subj_id][event][conditions[0]].copy().pick_channels(picks).data, axis=1), 
         np.nanmean(power_epochs[subj_id][event][conditions[1]].copy().pick_channels(picks).data, axis=1)]
    
    F_obs, clusters, cluster_p_values, H0 = \
    mne.stats.permutation_cluster_test(X, n_permutations=500, out_type='mask', verbose=True)
    
    if any(cluster_p_values<=0.05):
        # Create new stats image with only significant clusters
        fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

        times =  power_epochs[subj_id][event][conditions[0]].times


        # Average the data in each condition across epochs for plotting
        evoked_power_1 = np.nanmean(X[0], axis=0)
        evoked_power_2 = np.nanmean(X[1], axis=0)
        evoked_power_contrast = evoked_power_1 - evoked_power_2
        signs = np.sign(evoked_power_contrast)

        F_obs_plot = np.nan * np.ones_like(F_obs)
        for c, p_val in zip(clusters, cluster_p_values):
            if p_val <= 0.05:
                F_obs_plot[c] = F_obs[c] * signs[c]

        ax.imshow(F_obs,
                  extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation = 'Bicubic',
                  aspect='auto', origin='lower', cmap='gray')
        max_F = np.nanmax(abs(F_obs_plot))
        ax.imshow(F_obs_plot,
                  extent=[times[0], times[-1], freqs[0], freqs[-1]],
                  aspect='auto', origin='lower', cmap='RdBu_r',
                  vmin=-max_F, vmax=max_F)

        ax.set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{region}_{cond_name}_{event}')

        # # ax.set_title(f'Induced power ({ch_name})')
    
    

In [None]:
###### single electrode cluster permutation

subj_id = 'MS017'
cond_name = 'total_cpe'
elec_list = ms017_insula_elecs
event = 'DecisionOnset'

conditions = ["(total_cpe<0)",
             "(total_cpe>0)"]

#save_path = f'{base_dir}figs/{cond_name}_clusters/{subj_id}_{region}_{cond_name}'

In [None]:
for elec in elec_list:
    
    for c in conditions:
        cond = c
        # Average the data in each condition across channels 
        X = np.nanmean(power_epochs[subj_id][event][cond].copy().pick_channels([elec]).data, axis=1)

        F_obs, clusters, cluster_p_values, H0 = \
        mne.stats.permutation_cluster_1samp_test(X, n_permutations=500, out_type='mask', verbose=True)
        if any(cluster_p_values<=0.05):
        #     print(region)
            # Create new stats image with only significant clusters
            fig, ax = plt.subplots(1, 1, figsize=(6, 6), dpi=300)

            times =  power_epochs[subj_id][event][cond].times

            # Average the data in each condition across epochs for plotting
            evoked_power = np.nanmean(X, axis=0)
        #     evoked_power_2 = np.nanmean(X[1], axis=0)
        #     evoked_power_contrast = evoked_power_1 - evoked_power_2
            signs = np.sign(evoked_power)

            F_obs_plot = np.nan * np.ones_like(F_obs)
            for c, p_val in zip(clusters, cluster_p_values):
                if p_val <= 0.05:
                    F_obs_plot[c] = F_obs[c] * signs[c]

            ax.imshow(F_obs,
                        extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation = 'Bicubic',
                        aspect='auto', origin='lower', cmap='gray')
            max_F = np.nanmax(abs(F_obs_plot))
            ax.imshow(F_obs_plot,
                        extent=[times[0], times[-1], freqs[0], freqs[-1]],
                        aspect='auto', origin='lower', cmap='RdBu_r',
                        vmin=-max_F, vmax=max_F)

            ax.set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{elec}_{cond}_{event}')

            plt.savefig(f'{base_dir}figs/one_sample_clust/{subj_id}_{elec}_{cond}.png', dpi='figure', format='png', metadata=None,
            bbox_inches=None, pad_inches=0.1,
            facecolor='auto', edgecolor='auto',
            backend=None)  

    


In [None]:
conditions = ["(total_cpe>0)",
                "(total_cpe<0)"] #relief
cond_name = 'total_cpe'
yticks = np.arange(4,200,step=10)
region = 'inferior frontal'





In [None]:
subj_ids = ['MS002','MS003','MS016','MS017','MS019','MS020','MS022','MS027']

for subj_id in subj_ids:

    save_path = f'{base_dir}figs/region_two_sammple_clust/{subj_id}_{region}_{cond_name}'

    elec_file = anat_path + subj_id + '_labels.csv'

    elec_data = lfp_preprocess_utils.load_elec(elec_file)

    anode_list = [x.split('-')[0] for x in epochs_all_subjs_all_evs[subj_id][event].ch_names]
    elec_df = elec_data[elec_data.label.str.lower().isin(anode_list)]
    elec_df['label'] = epochs_all_subjs_all_evs[subj_id][event].ch_names
        
    picks = analysis_utils.select_picks_rois(elec_df, region)


    for event in analysis_evs:
        
        # Average the data in each condition across channels 
        X = [np.nanmean(power_epochs[subj_id][event][conditions[0]].copy().pick_channels(picks).data, axis=1), 
            np.nanmean(power_epochs[subj_id][event][conditions[1]].copy().pick_channels(picks).data, axis=1)]
        
        F_obs, clusters, cluster_p_values, H0 = \
        mne.stats.permutation_cluster_test(X, n_permutations=500, out_type='mask', verbose=True)
        
        if any(cluster_p_values<=0.05):
            # Create new stats image with only significant clusters
            fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

            times =  power_epochs[subj_id][event][conditions[0]].times


            # Average the data in each condition across epochs for plotting
            evoked_power_1 = np.nanmean(X[0], axis=0)
            evoked_power_2 = np.nanmean(X[1], axis=0)
            evoked_power_contrast = evoked_power_1 - evoked_power_2
            signs = np.sign(evoked_power_contrast)

            F_obs_plot = np.nan * np.ones_like(F_obs)
            for c, p_val in zip(clusters, cluster_p_values):
                if p_val <= 0.05:
                    F_obs_plot[c] = F_obs[c] * signs[c]

            ax.imshow(F_obs,
                    extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation = 'Bicubic',
                    aspect='auto', origin='lower', cmap='gray')
            max_F = np.nanmax(abs(F_obs_plot))
            ax.imshow(F_obs_plot,
                    extent=[times[0], times[-1], freqs[0], freqs[-1]],
                    aspect='auto', origin='lower', cmap='RdBu_r',
                    vmin=-max_F, vmax=max_F)

            ax.set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{region}_{cond_name}_{event}')
            
            plt.savefig('f{save_path}.png', dpi='figure', format='png', metadata=None,
                        bbox_inches=None, pad_inches=0.1,facecolor='auto', edgecolor='auto',backend=None)  

            # # ax.set_title(f'Induced power ({ch_name})')
        
            
        

In [None]:
region = 'lpfc'

elec_df = pd.read_csv('/Users/christinamaher/Desktop/ieeg_data/MS009/Anat/MS009_labels.csv')

elecs_to_pick = elec_df.loc[elec_df['bin'] == region, 'NMMlabel'].str.lower() + '-'
elecs_to_pick = elecs_to_pick.tolist()

picks = []
for e in elecs_to_pick:
    picks_temp = list(filter(lambda s: e in s,  power_epochs[0].info['ch_names']))
    picks.append(picks_temp)

picks = [item for sublist in picks for item in sublist]
picks = [picks]
picks

In [None]:
subj_id = 'MS009'
region = 'lpfc'
analysis_evs = ['choice_ts']
freqs = np.logspace(*np.log10([2, 200]), num=30)
yticks = [4, 30, 60, 120]



save_path = '/Users/christinamaher/Desktop/christinamaher/ieeg_data/MS009/Ephys'
# Get electrode df 

elec_df = pd.read_csv('/Users/christinamaher/Desktop/ieeg_data/MS009/Anat/MS009_labels.csv')

elecs_to_pick = elec_df.loc[elec_df['bin'] == region, 'NMMlabel'].str.lower() + '-'
elecs_to_pick = elecs_to_pick.tolist()

picks = []
for e in elecs_to_pick:
    picks_temp = list(filter(lambda s: e in s,  power_epochs[0].info['ch_names']))
    picks.append(picks_temp)

picks = [item for sublist in picks for item in sublist][2]
picks = [picks]
    
conditions = ["(condition == 'no_hint')",
             "(condition == 'hint')"]

cond_name = 'Context'


for event in analysis_evs:
    
    # Average the data in each condition across channels 
    X = [np.nanmean(power_epochs[0][conditions[0]].copy().pick_channels(picks).data, axis=1), 
         np.nanmean(power_epochs[0][conditions[1]].copy().pick_channels(picks).data, axis=1)]
    
    F_obs, clusters, cluster_p_values, H0 = \
    mne.stats.permutation_cluster_test(X, n_permutations=500, out_type='mask', verbose=True)
    
    if any(cluster_p_values<=0.05):
        # Create new stats image with only significant clusters
        fig, ax = plt.subplots(1, 1, figsize=(6, 4), dpi=300)

        times =  power_epochs[0][conditions[0]].times


        # Average the data in each condition across epochs for plotting
        evoked_power_1 = np.nanmean(X[0], axis=0)
        evoked_power_2 = np.nanmean(X[1], axis=0)
        evoked_power_contrast = evoked_power_1 - evoked_power_2
        signs = np.sign(evoked_power_contrast)

        F_obs_plot = np.nan * np.ones_like(F_obs)
        for c, p_val in zip(clusters, cluster_p_values):
            if p_val <= 0.05:
                F_obs_plot[c] = F_obs[c] * signs[c]

        ax.imshow(F_obs,
                  extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation = 'Bicubic',
                  aspect='auto', origin='lower', cmap='gray')
        max_F = np.nanmax(abs(F_obs_plot))
        ax.imshow(F_obs_plot,
                  extent=[times[0], times[-1], freqs[0], freqs[-1]],
                  aspect='auto', origin='lower', cmap='RdBu_r',
                  vmin=-max_F, vmax=max_F)

        ax.set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{region}_{cond_name}_{event}')
        
        # # ax.set_title(f'Induced power ({ch_name})')

In [None]:
region = 'hippocampus'
# band definitions for y-axis
yticks = [4, 8, 13, 30, 60, 120]

conditions = ["(condition == 'hint')",
             "(condition == 'no_hint')"]

cond_name = 'Context'

elecs_to_pick = elec_df.loc[elec_df['bin'] == region, 'NMMlabel'].str.lower() + '-'
elecs_to_pick = elecs_to_pick.tolist()

picks = []
for e in elecs_to_pick:
    picks_temp = list(filter(lambda s: e in s,  power_epochs['choice_ts'].info['ch_names']))[0]
    picks.append(picks_temp)

    
for event in analysis_evs:
    fig, ax = plt.subplots(1, 2, figsize=(20, 6), dpi=300)
    for ix, cond in enumerate(conditions):
        times = power_epochs[event].times
        plot_data = np.nanmean(np.nanmean(power_epochs[event][cond].copy().pick_channels(picks).data, axis=0), axis=0)

        im = ax[ix].imshow(plot_data,
                    extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation='Bicubic',
                    aspect='auto', origin='lower', cmap='RdBu_r', vmin = -np.nanmax(np.abs(plot_data)), vmax = np.nanmax(np.abs(plot_data)))
        ax[ix].set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{region}_{cond}')
        fig.colorbar(im, ax=ax[ix])

In [None]:
region = 'lpfc'
# band definitions for y-axis
yticks = [4, 30, 60, 120]

conditions = ["(condition == 'no_hint')",
             "(condition == 'hint')"]

cond_name = 'Context'

elecs_to_pick = elec_df.loc[elec_df['bin'] == region, 'NMMlabel'].str.lower() + '-'
elecs_to_pick = elecs_to_pick.tolist()

picks = []
for e in elecs_to_pick:
    picks_temp = list(filter(lambda s: e in s,  power_epochs[0].info['ch_names']))
    picks.append(picks_temp)

picks = [item for sublist in picks for item in sublist][2]
picks = [picks]
    
for event in analysis_evs:
    fig, ax = plt.subplots(1, 3, figsize=(15, 4), dpi=300)
    for ix, cond in enumerate(conditions):
        times = power_epochs[0].times
        plot_data = np.nanmean(np.nanmean(power_epochs[0][cond].copy().pick_channels(picks).data, axis=0), axis=0)

        im = ax[ix].imshow(plot_data,
                    extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation='Bicubic',
                    aspect='auto', origin='lower', cmap='RdBu_r', vmin = -0.6, vmax = 0.6)
        ax[ix].set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{region}_{cond}')
        fig.colorbar(im, ax=ax[ix])
    
    ax[2].imshow(F_obs,extent=[times[0], times[-1], freqs[0], freqs[-1]], interpolation = 'Bicubic',aspect='auto', origin='lower', cmap='gray')
    max_F = np.nanmax(abs(F_obs_plot))
    ax[2].imshow(F_obs_plot,extent=[times[0], times[-1], freqs[0], freqs[-1]],aspect='auto', origin='lower', cmap='RdBu_r',vmin=-max_F, vmax=max_F)

    ax[2].set(yticks=yticks, xlabel='Time (s)', ylabel='Frequency', title=f'{subj_id}_{region}_{cond_name}_{event}')
fig.tight_layout()