In [1]:
import numpy as np
import xarray as xr
import scipy.signal as sg
import pandas as pd
import scipy.stats as st
import matplotlib.pyplot as plt
from matplotlib import patches
import npc_lims
from sklearn.metrics import roc_curve, roc_auc_score
from statsmodels.stats.multitest import fdrcorrection
from npc_sessions import DynamicRoutingSession
from dynamic_routing_analysis import spike_utils
import os

%load_ext autoreload
%autoreload 2
%matplotlib widget

In [None]:
#lick modulation
#vis lick modulation = difference between CR (no lick) and FA (lick) trials in vis blocks (sound1 stim)
#aud lick modulation = difference between CR (no lick) and FA (lick) trials in aud blocks (vis1 stim)

def compute_lick_modulation(trials, units, session_info, save_path):
    """Compute per-unit lick-modulation metrics for one session and pickle them.

    For DynamicRouting sessions, trials with vis1 in the aud context or sound1
    in the vis context are split into lick (``is_response``) and non-lick
    trials. Each unit's mean rate in the 0.2-0.5 s window is compared between
    the two groups. The z-score denominator is the std of the -0.5 to -0.2 s
    window over all of those trials.

    Parameters
    ----------
    trials : pandas.DataFrame
        Trials table. For DynamicRouting it needs ``stim_name``,
        ``context_name`` and ``is_response`` columns.
    units : pandas.DataFrame
        Units table with ``unit_id`` and ``session_id`` columns. It is passed
        to ``spike_utils.make_neuron_time_trials_tensor``.
    session_info
        Object with ``project`` (str) and ``id`` attributes.
    save_path : str
        Existing directory that receives ``<id>_lick_modulation.pkl``.

    Returns
    -------
    None
        Writes a pickled DataFrame with one row per unit. For Templeton
        sessions the table has columns but no rows. For other projects a
        message is printed and nothing is written.
    """
    lick_modulation = {
        'unit_id': [],
        'session_id': [],
        'project': [],
        # 'structure':[],
        'lick_modulation_index': [],
        'lick_modulation_zscore': [],
        'lick_modulation_p_value': [],
        'lick_modulation_sign': [],
        'lick_modulation_roc_auc': [],
    }

    is_templeton = "Templeton" in session_info.project
    if not is_templeton and "DynamicRouting" not in session_info.project:
        print('incompatible project: ',session_info.project,'; skipping')
        return

    out_file = os.path.join(save_path, session_info.id + '_lick_modulation.pkl')

    if is_templeton:
        # Lick modulation is not computed per unit for Templeton sessions; save
        # the column-only table without building the (expensive) spike tensor.
        pd.DataFrame(lick_modulation).to_pickle(out_file)
        return

    def _roc_auc(pos_values, neg_values):
        # roc_auc_score raises ValueError when only one class is present
        # (e.g. the mouse never withheld a lick); report NaN for that unit instead
        if pos_values.size == 0 or neg_values.size == 0:
            return np.nan
        labels = np.concatenate([np.ones(pos_values.size), np.zeros(neg_values.size)])
        return roc_auc_score(labels, np.concatenate([pos_values, neg_values]))

    # make data array first
    # (presumably aligned to stimulus onset; -0.5..+0.5 s in 25 ms bins)
    time_before = 0.5
    time_after = 0.5
    binsize = 0.025
    trial_da = spike_utils.make_neuron_time_trials_tensor(units, trials, time_before, time_after, binsize)

    # vis1 in aud blocks or sound1 in vis blocks, split by whether the mouse licked
    baseline_trials = trials.query('(stim_name=="vis1" and context_name=="aud") or '
                                   '(stim_name=="sound1" and context_name=="vis")')
    lick_trials = baseline_trials.query('is_response==True')
    non_lick_trials = baseline_trials.query('is_response==False')

    for _, unit in units.iterrows():
        unit_da = trial_da.sel(unit_id=unit['unit_id'])

        lick_modulation['unit_id'].append(unit['unit_id'])
        lick_modulation['session_id'].append(str(unit['session_id']))
        lick_modulation['project'].append(str(session_info.project))
        # lick_modulation['structure'].append(unit['structure'])

        # mean rate per trial in the response window and in the pre-stimulus window
        lick_frs_by_trial = unit_da.sel(time=slice(0.2,0.5),trials=lick_trials.index).mean(dim='time',skipna=True)
        non_lick_frs_by_trial = unit_da.sel(time=slice(0.2,0.5),trials=non_lick_trials.index).mean(dim='time',skipna=True)
        baseline_frs_by_trial = unit_da.sel(time=slice(-0.5,-0.2),trials=baseline_trials.index).mean(dim='time',skipna=True)

        lick_mean = lick_frs_by_trial.mean(skipna=True)
        non_lick_mean = non_lick_frs_by_trial.mean(skipna=True)

        lick_modulation_zscore = (lick_mean-non_lick_mean)/baseline_frs_by_trial.std(skipna=True)
        lick_modulation['lick_modulation_zscore'].append(lick_modulation_zscore.values)

        lick_modulation_index = (lick_mean-non_lick_mean)/(lick_mean+non_lick_mean)
        lick_modulation['lick_modulation_index'].append(lick_modulation_index.values)

        pval = st.mannwhitneyu(lick_frs_by_trial.values, non_lick_frs_by_trial.values,nan_policy='omit')[1]
        # pval = st.ranksums(lick_frs_by_trial.values, non_lick_frs_by_trial.values,nan_policy='omit')[1]
        lick_modulation['lick_modulation_p_value'].append(pval)

        lick_modulation['lick_modulation_sign'].append(np.sign(lick_mean.values-non_lick_mean.values))

        # ROC AUC with lick trials as the positive class
        lick_modulation['lick_modulation_roc_auc'].append(
            _roc_auc(lick_frs_by_trial.values, non_lick_frs_by_trial.values))

    pd.DataFrame(lick_modulation).to_pickle(out_file)


In [None]:
# context modulation of stimulus responses metric


def compute_stim_context_modulation(trials, units, session_info, save_path):
    """Compute per-unit stimulus, context and behavior metrics for one session and pickle them.

    For every unit this computes:
      * baseline (-0.1 to 0 s) context modulation, vis- vs. aud-context trials
      * 0-0.1 s stimulus-discrimination ROC AUCs: vis1 vs vis2, sound1 vs sound2,
        vis1 vs sound1, vis2 vs sound2, and all vis vs all sound
      * hit / CR / FA ROC AUCs in 0-0.1, 0.1-0.2 and 0.2-0.3 s windows
        (DynamicRouting only; NaN otherwise)
      * per stimulus: modulation relative to the same trials' baseline
        (0-0.1 s and 0.1-0.2 s), the time of the peak trial-averaged |response|
        in 0-0.3 s, and raw and baseline-subtracted ("evoked") context modulation

    Templeton sessions have no context blocks, so fake contexts are assigned:
    six alternating vis/aud blocks of 10*60 time units (presumably seconds),
    starting with a randomly chosen context. The original columns are kept as
    ``true_block_index`` / ``true_context_name``.

    Parameters
    ----------
    trials : pandas.DataFrame
        Trials table (``stim_name``, ``context_name``, ``start_time``, ...).
        NOTE: mutated in place for Templeton sessions (fake-context columns).
    units : pandas.DataFrame
        Units table with ``unit_id``, ``session_id`` and ``spike_times``.
    session_info
        Object with ``project`` and ``id`` attributes.
    save_path : str
        Directory that receives ``<id>_stim_context_modulation.pkl``.

    Returns
    -------
    None
        Writes ``units`` (minus ``spike_times``) merged with the metrics on
        ``unit_id`` / ``session_id``.
    """

    def _roc_auc(pos_values, neg_values):
        """ROC AUC with pos_values as the positive class; NaN if either group is empty."""
        pos_values = np.asarray(pos_values)
        neg_values = np.asarray(neg_values)
        # roc_auc_score raises when only one class is present, which would
        # otherwise abort the whole session (e.g. a session with no FA trials)
        if pos_values.size == 0 or neg_values.size == 0:
            return np.nan
        labels = np.concatenate([np.ones(pos_values.size), np.zeros(neg_values.size)])
        return roc_auc_score(labels, np.concatenate([pos_values, neg_values]))

    def _stimulus_modulation(response_frs, baseline_frs_by_trial):
        """(zscore, index, p_value, sign, roc_auc) of per-trial response vs. the same trials' baseline."""
        zscore = ((response_frs-baseline_frs_by_trial.mean(skipna=True))
                  / baseline_frs_by_trial.std(skipna=True)).mean(skipna=True).values
        index = ((response_frs-baseline_frs_by_trial).mean(skipna=True)
                 / (response_frs+baseline_frs_by_trial).mean(skipna=True)).values
        # pval = st.ks_2samp(response_frs.values, baseline_frs_by_trial.values)[1]
        pval = st.wilcoxon(response_frs.values, baseline_frs_by_trial.values,nan_policy='omit',zero_method='zsplit')[1]
        sign = np.sign(np.mean(response_frs.values-baseline_frs_by_trial.values))
        auc = _roc_auc(response_frs.values, baseline_frs_by_trial.values)
        return zscore, index, pval, sign, auc

    stim_names = trials['stim_name'].unique()

    stim_context_modulation = {
        'unit_id':[],
        'session_id':[],
        'project':[],
        'baseline_context_modulation_index':[],
        'baseline_context_modulation_p_value':[],
        'baseline_context_modulation_zscore':[],
        'baseline_context_modulation_sign':[],
        'baseline_context_roc_auc':[],
        'vis_discrim_roc_auc':[],
        'aud_discrim_roc_auc':[],
        'target_discrim_roc_auc':[],
        'nontarget_discrim_roc_auc':[],
        'vis_vs_aud':[],
        'cr_vs_fa_early_roc_auc':[],
        'hit_vs_cr_early_roc_auc':[],
        'hit_vs_fa_early_roc_auc':[],
        'cr_vs_fa_mid_roc_auc':[],
        'hit_vs_cr_mid_roc_auc':[],
        'hit_vs_fa_mid_roc_auc':[],
        'cr_vs_fa_late_roc_auc':[],
        'hit_vs_cr_late_roc_auc':[],
        'hit_vs_fa_late_roc_auc':[],
    }
    per_stim_suffixes = [
        '_context_modulation_index',
        '_context_modulation_zscore',
        '_context_modulation_sign',
        '_context_modulation_p_value',
        '_context_modulation_roc_auc',
        '_evoked_context_modulation_index',
        '_evoked_context_modulation_zscore',
        '_evoked_context_modulation_sign',
        '_evoked_context_modulation_p_value',
        '_stimulus_modulation_index',
        '_stimulus_modulation_zscore',
        '_stimulus_modulation_p_value',
        '_stimulus_modulation_sign',
        '_stimulus_modulation_roc_auc',
        '_stimulus_late_modulation_index',
        '_stimulus_late_modulation_zscore',
        '_stimulus_late_modulation_p_value',
        '_stimulus_late_modulation_sign',
        '_stimulus_late_modulation_roc_auc',
        '_stim_latency',
    ]
    for ss in stim_names:
        for suffix in per_stim_suffixes:
            stim_context_modulation[ss+suffix] = []

    contexts=trials['context_name'].unique()

    if 'Templeton' in session_info.project:
        contexts = ['aud','vis']

        start_time=trials['start_time'].iloc[0]
        fake_context=np.full(len(trials), fill_value='nan')
        fake_block_nums=np.full(len(trials), fill_value=np.nan)

        # NOTE: unseeded, so the fake context order can differ between runs
        if np.random.choice(contexts,1)=='vis':
            block_contexts=['vis','aud','vis','aud','vis','aud']
        else:
            block_contexts=['aud','vis','aud','vis','aud','vis']

        trials['true_block_index']=trials['block_index']
        trials['true_context_name']=trials['context_name']

        # Each pass overwrites every trial from this block's start onward, so the
        # last pass wins; trials after the 6th block's start all land in block 5.
        # NOTE(review): index labels are used as positions into the numpy arrays,
        # which assumes trials has a 0..n-1 RangeIndex.
        for block in range(0,6):
            block_start_time=start_time+block*10*60
            block_trials=trials.query('start_time>=@block_start_time').index
            fake_context[block_trials]=block_contexts[block]
            fake_block_nums[block_trials]=block

        trials['context_name']=fake_context
        trials['block_index']=fake_block_nums
        trials['is_vis_context']=trials['context_name']=='vis'
        trials['is_aud_context']=trials['context_name']=='aud'

    #make data array first
    time_before = 0.1
    time_after = 0.3
    binsize = 0.025
    trial_da = spike_utils.make_neuron_time_trials_tensor(units, trials, time_before, time_after, binsize)

    # --- trial subsets: independent of the unit, so computed once ---
    vis_context_index = trials.query('context_name=="vis"').index
    aud_context_index = trials.query('context_name=="aud"').index
    # positional labels for the baseline AUC; assumes trial_da's trials follow trials' row order
    is_vis_context = (trials['context_name']=='vis').values

    vis1_trials = trials.query('stim_name=="vis1"')
    vis2_trials = trials.query('stim_name=="vis2"')
    aud1_trials = trials.query('stim_name=="sound1"')
    aud2_trials = trials.query('stim_name=="sound2"')

    #HIT/CR/FA - currently only makes sense for DR experiments
    behav_time_windows_start=[0,0.1,0.2]
    behav_time_windows_end=[0.1,0.2,0.3]
    behav_time_window_labels=['early','mid','late']
    is_dynamic_routing = 'DynamicRouting' in session_info.project
    if is_dynamic_routing:
        cr_trials=trials.query('is_response==False and is_correct==True and is_target==True')
        fa_trials=trials.query('is_response==True and is_correct==False and is_target==True')
        hit_trials=trials.query('is_response==True and is_correct==True and is_target==True')

    # per stimulus: (all trials, trials in the "same" context, trials in the "other" context)
    stim_trial_sets = {}
    for ss in stim_names:
        if ss=='catch':
            # catch has no modality; uses the first two contexts in order of appearance
            same_context=contexts[0]
            other_context=contexts[1]
        elif 'sound' in ss:
            same_context='aud'
            other_context='vis'
        elif 'vis' in ss:
            same_context='vis'
            other_context='aud'
        # NOTE(review): any other stim_name silently reuses the previous
        # stimulus's contexts (NameError if it is the first one)
        stim_trial_sets[ss] = (
            trials.query('stim_name==@ss'),
            trials.query('context_name==@same_context and stim_name==@ss'),
            trials.query('context_name==@other_context and stim_name==@ss'),
        )

    #for each unit
    for _, unit in units.iterrows():
        unit_da = trial_da.sel(unit_id=unit['unit_id'])

        stim_context_modulation['unit_id'].append(unit['unit_id'])
        stim_context_modulation['session_id'].append(str(unit['session_id']))
        stim_context_modulation['project'].append(str(session_info.project))

        #find baseline frs across all trials
        baseline_frs = unit_da.sel(time=slice(-0.1,0)).mean(dim='time')

        vis_baseline_frs = baseline_frs.sel(trials=vis_context_index)
        aud_baseline_frs = baseline_frs.sel(trials=aud_context_index)

        pval = st.mannwhitneyu(vis_baseline_frs.values, aud_baseline_frs.values,nan_policy='omit')[1]
        # pval = st.ranksums(vis_baseline_frs.values, aud_baseline_frs.values,nan_policy='omit')[1]
        stim_context_modulation['baseline_context_modulation_p_value'].append(pval)

        vis_baseline_mean = vis_baseline_frs.mean(skipna=True).values
        aud_baseline_mean = aud_baseline_frs.mean(skipna=True).values

        baseline_modulation_zscore=(vis_baseline_mean-aud_baseline_mean)/baseline_frs.std(skipna=True)
        stim_context_modulation['baseline_context_modulation_zscore'].append(baseline_modulation_zscore.values)

        baseline_modulation_index=(vis_baseline_mean-aud_baseline_mean)/(vis_baseline_mean+aud_baseline_mean)
        stim_context_modulation['baseline_context_modulation_index'].append(baseline_modulation_index)

        stim_context_modulation['baseline_context_modulation_sign'].append(np.sign(np.mean(vis_baseline_mean-aud_baseline_mean)))

        #auc for baseline frs: vis context (positive) vs. every other trial
        baseline_values = baseline_frs.values
        stim_context_modulation['baseline_context_roc_auc'].append(
            _roc_auc(baseline_values[is_vis_context], baseline_values[~is_vis_context]))

        #cross stimulus discrimination (first 100 ms)
        vis1_frs = unit_da.sel(time=slice(0,0.1),trials=vis1_trials.index).mean(dim='time',skipna=True).values
        vis2_frs = unit_da.sel(time=slice(0,0.1),trials=vis2_trials.index).mean(dim='time',skipna=True).values
        aud1_frs = unit_da.sel(time=slice(0,0.1),trials=aud1_trials.index).mean(dim='time',skipna=True).values
        aud2_frs = unit_da.sel(time=slice(0,0.1),trials=aud2_trials.index).mean(dim='time',skipna=True).values

        stim_context_modulation['vis_discrim_roc_auc'].append(_roc_auc(vis1_frs, vis2_frs))
        stim_context_modulation['aud_discrim_roc_auc'].append(_roc_auc(aud1_frs, aud2_frs))
        #targets: vis1 vs sound1; nontargets: vis2 vs sound2
        stim_context_modulation['target_discrim_roc_auc'].append(_roc_auc(vis1_frs, aud1_frs))
        stim_context_modulation['nontarget_discrim_roc_auc'].append(_roc_auc(vis2_frs, aud2_frs))
        #vis vs. aud
        stim_context_modulation['vis_vs_aud'].append(
            _roc_auc(np.concatenate([vis1_frs, vis2_frs]), np.concatenate([aud1_frs, aud2_frs])))

        #HIT/CR/FA
        for tw,time_window in enumerate(behav_time_window_labels):
            if is_dynamic_routing:
                window = slice(behav_time_windows_start[tw],behav_time_windows_end[tw])
                cr_frs = unit_da.sel(time=window,trials=cr_trials.index).mean(dim='time',skipna=True).values
                fa_frs = unit_da.sel(time=window,trials=fa_trials.index).mean(dim='time',skipna=True).values
                hit_frs = unit_da.sel(time=window,trials=hit_trials.index).mean(dim='time',skipna=True).values
                cr_vs_fa_auc = _roc_auc(cr_frs, fa_frs)
                hit_vs_cr_auc = _roc_auc(hit_frs, cr_frs)
                hit_vs_fa_auc = _roc_auc(hit_frs, fa_frs)
            else:
                cr_vs_fa_auc = hit_vs_cr_auc = hit_vs_fa_auc = np.nan
            stim_context_modulation['cr_vs_fa_'+time_window+'_roc_auc'].append(cr_vs_fa_auc)
            stim_context_modulation['hit_vs_cr_'+time_window+'_roc_auc'].append(hit_vs_cr_auc)
            stim_context_modulation['hit_vs_fa_'+time_window+'_roc_auc'].append(hit_vs_fa_auc)

        #loop through stimuli
        for ss in stim_names:
            stim_trials, same_context_trials, other_context_trials = stim_trial_sets[ss]

            #stimulus modulation (0-0.1 s) and late modulation (0.1-0.2 s) vs. the same trials' baseline
            stim_baseline_frs_by_trial = baseline_frs.sel(trials=stim_trials.index)
            for prefix, window in (('_stimulus_modulation', slice(0,0.1)),
                                   ('_stimulus_late_modulation', slice(0.1,0.2))):
                response_frs = unit_da.sel(time=window,trials=stim_trials.index).mean(dim='time',skipna=True)
                zscore, index, pval, sign, auc = _stimulus_modulation(response_frs, stim_baseline_frs_by_trial)
                stim_context_modulation[ss+prefix+'_zscore'].append(zscore)
                stim_context_modulation[ss+prefix+'_index'].append(index)
                stim_context_modulation[ss+prefix+'_p_value'].append(pval)
                stim_context_modulation[ss+prefix+'_sign'].append(sign)
                stim_context_modulation[ss+prefix+'_roc_auc'].append(auc)

            #latency: time of the peak trial-averaged |rate| in 0-0.3 s
            # (select first, then abs, instead of abs() over the whole tensor)
            stim_latency = np.abs(unit_da.sel(time=slice(0,0.3),trials=stim_trials.index)).mean(dim='trials',skipna=True).idxmax(dim='time').values
            stim_context_modulation[ss+'_stim_latency'].append(stim_latency)

            same_context_baseline_frs = baseline_frs.sel(trials=same_context_trials.index)
            other_context_baseline_frs = baseline_frs.sel(trials=other_context_trials.index)

            #raw frs during stim (first 100ms)
            same_context_frs_by_trial = unit_da.sel(time=slice(0,0.1),trials=same_context_trials.index).mean(dim='time',skipna=True)
            other_context_frs_by_trial = unit_da.sel(time=slice(0,0.1),trials=other_context_trials.index).mean(dim='time',skipna=True)

            pval = st.mannwhitneyu(same_context_frs_by_trial.values, other_context_frs_by_trial.values,nan_policy='omit')[1]
            # pval = st.ranksums(same_context_frs_by_trial.values, other_context_frs_by_trial.values,nan_policy='omit')[1]
            stim_context_modulation[ss+'_context_modulation_p_value'].append(pval)

            same_context_frs = same_context_frs_by_trial.mean(skipna=True).values
            other_context_frs = other_context_frs_by_trial.mean(skipna=True).values

            context_modulation_zscore=(same_context_frs-other_context_frs)/stim_baseline_frs_by_trial.std(skipna=True)
            stim_context_modulation[ss+'_context_modulation_zscore'].append(context_modulation_zscore.values)

            stim_context_modulation[ss+'_context_modulation_sign'].append(np.sign(np.mean(same_context_frs-other_context_frs)))

            stim_context_modulation[ss+'_context_modulation_roc_auc'].append(
                _roc_auc(same_context_frs_by_trial.values, other_context_frs_by_trial.values))

            #evoked frs during stim (first 100ms) = stim-window rate minus the same trial's baseline
            same_context_evoked_frs_by_trial = same_context_frs_by_trial-same_context_baseline_frs
            other_context_evoked_frs_by_trial = other_context_frs_by_trial-other_context_baseline_frs

            pval = st.mannwhitneyu(same_context_evoked_frs_by_trial.values, other_context_evoked_frs_by_trial.values,nan_policy='omit')[1]
            # pval = st.ranksums(same_context_evoked_frs_by_trial.values, other_context_evoked_frs_by_trial.values,nan_policy='omit')[1]
            stim_context_modulation[ss+'_evoked_context_modulation_p_value'].append(pval)

            same_context_evoked_frs = same_context_evoked_frs_by_trial.mean(skipna=True).values
            other_context_evoked_frs = other_context_evoked_frs_by_trial.mean(skipna=True).values

            context_modulation_evoked_zscore=(same_context_evoked_frs-other_context_evoked_frs)/stim_baseline_frs_by_trial.std(skipna=True)
            stim_context_modulation[ss+'_evoked_context_modulation_zscore'].append(context_modulation_evoked_zscore.values)

            stim_context_modulation[ss+'_evoked_context_modulation_sign'].append(
                np.sign(np.mean(same_context_evoked_frs-other_context_evoked_frs)))

            #negative numbers can make the index behave weirdly, so shift BOTH by the
            #same pre-shift minimum (computed once, before either value changes)
            if same_context_evoked_frs<0 or other_context_evoked_frs<0:
                evoked_min = np.min([same_context_evoked_frs,other_context_evoked_frs])
                same_context_evoked_frs = same_context_evoked_frs - evoked_min
                other_context_evoked_frs = other_context_evoked_frs - evoked_min

            #compute metrics
            raw_fr_metric=(same_context_frs-other_context_frs)/(same_context_frs+other_context_frs)
            stim_context_modulation[ss+'_context_modulation_index'].append(raw_fr_metric)

            evoked_fr_metric=(same_context_evoked_frs-other_context_evoked_frs)/(same_context_evoked_frs+other_context_evoked_frs)
            stim_context_modulation[ss+'_evoked_context_modulation_index'].append(evoked_fr_metric)

    # session_id was stored as str above; the merge presumably relies on
    # units['session_id'] also being str-like (TODO confirm)
    unit_metric_merge=units.reset_index(drop=True).merge(pd.DataFrame(stim_context_modulation),on=['unit_id','session_id'])
    unit_metric_merge=unit_metric_merge.drop(columns=['spike_times'])
    unit_metric_merge.to_pickle(os.path.join(save_path,session_info.id+'_stim_context_modulation.pkl'))
    # return(stim_context_modulation)


In [None]:
# Scratch (disabled): manual version of the merge/save step at the end of
# compute_stim_context_modulation, presumably kept for debugging the merge.
# unit_metric_merge=units.merge(pd.DataFrame(stim_context_modulation),on=['unit_id','session_id'])
# unit_metric_merge=unit_metric_merge.drop(columns=['spike_times'])
# unit_metric_merge.to_pickle(os.path.join(save_path,session_info.id+'_stim_context_modulation.pkl'))
# units.reset_index(drop=True)
# units.reset_index(drop=True).merge(pd.DataFrame(stim_context_modulation),on=['unit_id','session_id'])

In [None]:
# Scratch (disabled): inspect the units table's session_id / columns.
# units['session_id'].iloc[0]
# units.columns.values

In [None]:
# Scratch (disabled): inspect the session_id stored in the metrics dict.
# pd.DataFrame(stim_context_modulation)['session_id'].iloc[0]

In [2]:
# All ephys sessions that npc_lims reports as uploaded and annotated.
# A project filter was tried and left disabled:
#     if s.project=='DynamicRouting' or s.project=='TempletonPilotEphys'
ephys_sessions = tuple(
    npc_lims.get_session_info(is_ephys=True, is_uploaded=True, is_annotated=True)
)

In [None]:
# Pick one session for interactive inspection.
# Only the last expression (the NWB path) is displayed; the units cache path
# on the line before it is computed but not shown.
# len(ephys_sessions)
# session_info.id
# ephys_sessions
session=ephys_sessions[10]

npc_lims.get_cache_path('units',session)
npc_lims.get_nwb_path(session)

In [15]:
# Scratch (disabled): open the selected session's NWB file directly with pynwb.
# from pynwb import NWBHDF5IO, NWBFile
# with NWBHDF5IO(npc_lims.get_nwb_path(session), "r") as io:
#     read_nwbfile = io.read()

In [None]:
# Batch: compute and save unit metrics for every session in `ephys_sessions`.
# Failed sessions are recorded in `except_dict` (session id -> exception) so one
# bad session does not stop the batch.
# NOTE(review): this calls the spike_utils versions of the metric functions,
# not the ones defined earlier in this notebook.
except_dict={}

save_path=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\new_annotations\single unit metrics"
lick_save_path=os.path.join(save_path,'lick_modulation')
# compute_lick_modulation pickles into lick_save_path, so it must exist too
# (makedirs also creates save_path)
os.makedirs(lick_save_path, exist_ok=True)

for session_info in ephys_sessions:
    try:
        try:
            # load THIS iteration's session (keyed by session_info, not the global `session`)
            trials=pd.read_parquet(
                        npc_lims.get_cache_path('trials',session_info.id,version='0.0.214')
                    )
            units=pd.read_parquet(
                        npc_lims.get_cache_path('units',session_info.id,version='0.0.214')
                    )
        except Exception as e:
            print(session_info.id,'failed to load trials and/or units')
            except_dict[session_info.id]=e
            continue

        # NOTE: for Templeton sessions this writes fake-context columns into `trials`
        # before it is passed on to the lick-modulation step
        spike_utils.compute_stim_context_modulation(trials, units, session_info, save_path)

        spike_utils.compute_lick_modulation(trials, units, session_info, lick_save_path)

        print(session_info.id,'done')

    except Exception as e:
        print(session_info.id,'failed')
        except_dict[session_info.id]=e


In [None]:
# Display the session chosen for inspection earlier (ephys_sessions[10]).
session

In [None]:
# Load one session's saved stim-context-modulation output, to compare columns against `y` below.
x=pd.read_pickle(r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\new_annotations\single unit metrics\676909_2023-12-11_0_stim_context_modulation.pkl")
x

In [None]:
# Load a second session's saved output, used as the reference column set.
y=pd.read_pickle(r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\new_annotations\single unit metrics\620263_2022-07-26_0_stim_context_modulation.pkl")
y

In [None]:
# Report every column of `y` that is missing from `x`.
for col in y.columns:
    if col in x.columns:
        continue
    print(col+' not in new data')

In [None]:
# load and concat all the context-mod dataframes
# One pd.concat over a list (the per-file concat in a loop was quadratic, and
# left `all_data` undefined when no files matched). pd.concat raises
# "No objects to concatenate" if the directory holds no .pkl files.
loadpath = r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\new_annotations\single unit metrics"
all_files = sorted(f for f in os.listdir(loadpath) if f.endswith('.pkl'))
all_data = pd.concat(
    [pd.read_pickle(os.path.join(loadpath, ff)) for ff in all_files],
    axis=0,
)

combined_path = os.path.join(loadpath, 'combined')
os.makedirs(combined_path, exist_ok=True)
all_data.to_pickle(os.path.join(combined_path, 'all_stim_context_modulation_new.pkl'))

In [None]:
# load and concat all the lick-mod dataframes
# One pd.concat over a list instead of a per-file concat in a loop; raises
# "No objects to concatenate" if no .pkl files are found.
loadpath = r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\new_annotations\single unit metrics\lick_modulation"
all_lick_files = sorted(f for f in os.listdir(loadpath) if f.endswith('.pkl'))
all_lick_data = pd.concat(
    [pd.read_pickle(os.path.join(loadpath, ff)) for ff in all_lick_files],
    axis=0,
)

combined_path = r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\new_annotations\single unit metrics\combined"
os.makedirs(combined_path, exist_ok=True)
all_lick_data.to_pickle(os.path.join(combined_path, 'all_lick_modulation.pkl'))

In [None]:
# Sessions that failed in the batch loop, mapped to the exception raised.
except_dict

In [None]:
# Scratch (disabled): list row index and unit_id for each unit.
# for uu,unit in units.iterrows():
#     print(uu,unit['unit_id'])

In [None]:
# Display the session info at position 10.
ephys_sessions[10]

In [None]:
# Number of unit rows in the combined context-modulation table.
len(all_data)

In [None]:
# Scratch (disabled): session ids previously listed for re-processing; not used below.
# redo_list=['676909_2023-12-12','670248_2023-08-01','670180_2023-07-27','670180_2023-07-26','670181_2023-07-18','668759_2023-07-12',
#            '668759_2023-07-11','662983_2023-05-16','649944_2023-02-28','649944_2023-02-27','646318_2023-01-18','646318_2023-01-17',
#            '644547_2022-12-06','644547_2022-12-05','636397_2022-09-27','636397_2022-09-26','628801_2022-09-19','620264_2022-08-02',
#            '620263_2022-07-27','620263_2022-07-26']
# redo_list=['676909_2023-12-12','674562_2023-10-04','670248_2023-08-01','670180_2023-07-27','670180_2023-07-26','670181_2023-07-18','668759_2023-07-12',
#            '668759_2023-07-11','662983_2023-05-16','649944_2023-02-28','649944_2023-02-27','646318_2023-01-18','646318_2023-01-17',
#            '644547_2022-12-06','644547_2022-12-05','636397_2022-09-27','636397_2022-09-26','628801_2022-09-19','626791_2022-08-15',
#            '626791_2022-08-16','626791_2022-08-17','620264_2022-08-02','620263_2022-07-27','620263_2022-07-26']

In [None]:
# Display the first session's info (loaded in the next cell).
ephys_sessions[0]

In [None]:
# Load trials, units and performance tables for one session for interactive exploration.
# NOTE(review): version='any' here, while the batch loop pins version='0.0.214',
# so these tables may differ from the ones the saved metrics were computed from.
# NOTE(review): later cells use `trial_da`, whose construction is commented out
# below; uncomment it for a Restart-&-Run-All to work.
session = ephys_sessions[0]
session_info=npc_lims.get_session_info(session)
trials=pd.read_parquet(
            npc_lims.get_cache_path('trials',session.id,version='any')
        )
units=pd.read_parquet(
            npc_lims.get_cache_path('units',session.id,version='any')
        )
performance=pd.read_parquet(
            npc_lims.get_cache_path('performance',session.id,version='any')
        )

# #make data array first
# time_before = 1.0
# time_after = 1.0
# binsize = 0.025
# trial_da = spike_utils.make_neuron_time_trials_tensor(units, trials, time_before, time_after, binsize)

In [None]:
# Scratch (disabled): show the per-session performance table.
# performance

In [None]:
# Column names available in the trials table.
trials.columns.values

In [None]:
# Scratch (disabled): overlay target/context flags across trials.
# fig,ax=plt.subplots()
# ax.plot(trials['is_vis_target'],alpha=0.5)
# ax.plot(trials['is_target'],alpha=0.5)
# ax.plot(trials['is_vis_context'],alpha=0.5)

In [None]:
# Scratch (disabled): the CR / hit / FA trial definitions used in compute_stim_context_modulation.
# #cr trials
# cr_trials=trials.query('is_response==False and is_correct==True and is_target==True')

# #hit trials
# hit_trials=trials.query('is_response==True and is_correct==True and is_target==True')

# #fa trials
# fa_trials=trials.query('is_response==True and is_correct==False and is_target==True')

In [None]:
# Scratch (disabled): save the units table without spike times to a local drive.
# save_path=r"D:\\"
# units_no_spikes=units.drop(columns=['spike_times'])
# units_no_spikes.to_pickle(os.path.join(save_path,session_info.id+'_units_no_spikes.pkl'))

In [None]:
# Brain structures recorded in this session.
units['structure'].unique()

In [None]:
# unit_id of the first ORBl unit with firing_rate >= 5.
units.query('structure=="ORBl" and firing_rate>=5')['unit_id'].iloc[0]

In [None]:
# Distribution of one example unit's mean rate in the -0.2 to 0 s window
# (the 11th ACAd unit with firing_rate >= 5), split by vis vs. aud context trials.
# NOTE(review): requires `trial_da`, which is only built by commented-out code
# earlier in the notebook, so this cell fails on a fresh kernel as written.
sel_unit=units.query('structure=="ACAd" and firing_rate>=5')['unit_id'].iloc[10]
unit_da=trial_da.sel(unit_id=sel_unit,time=slice(-0.2,0)).mean(dim='time')

# sel_unit=units.query('structure=="ACAd"')['unit_id'].values
# unit_da=trial_da.sel(unit_id=sel_unit,time=slice(-1.0,0)).mean(dim=['time','unit_id'])

vis_context_trials=trials.query('context_name=="vis"')
aud_context_trials=trials.query('context_name=="aud"')

# sel_unit=units.query('structure=="ACAd" and firing_rate>=5')['unit_id'].iloc[10]
# vis1_trials=trials.query('stim_name=="vis1"')
# unit_da=trial_da.sel(unit_id=sel_unit,time=slice(0,0.1),trials=vis1_trials.index).mean(dim='time')

# vis_context_trials=trials.query('context_name=="vis" and stim_name=="vis1"')
# aud_context_trials=trials.query('context_name=="aud" and stim_name=="vis1"')

vis_frs=unit_da.sel(trials=vis_context_trials.index).values
aud_frs=unit_da.sel(trials=aud_context_trials.index).values

hist_bins=np.arange(0,120,5)
fig,ax=plt.subplots()
ax.hist(vis_frs,bins=hist_bins,alpha=0.5,label='vis context')
ax.hist(aud_frs,bins=hist_bins,alpha=0.5,label='aud context')
ax.set_xlabel('mean firing rate, -0.2 to 0 s')
ax.set_ylabel('number of trials')
ax.set_title('unit '+str(sel_unit))
ax.legend();

In [None]:
# ROC analysis: how well does this unit's pre-stimulus value (unit_da, one value per
# trial) separate visual-context trials (label True) from all other trials?
# roc_curve / roc_auc_score are already imported in the first cell.
binary_label = trials['context_name'] == 'vis'
# binary_label=trials.query('stim_name=="vis1"')['context_name']=='vis'

# Select scores by the trials index so they line up row-for-row with binary_label,
# instead of relying on unit_da's trial order happening to match the trials table.
all_frs = unit_da.sel(trials=trials.index).values

fpr, tpr, thresholds = roc_curve(binary_label, all_frs)
roc_auc = roc_auc_score(binary_label, all_frs)

In [None]:
# Score thresholds at which roc_curve evaluated (fpr, tpr).
np.asarray(thresholds)

In [None]:
# Plot the context-decoding ROC curve computed above, with its AUC in the title.
fig, ax = plt.subplots()
ax.plot(fpr, tpr)
ax.set(
    xlabel='false positive rate',
    ylabel='true positive rate',
    title=f'ROC curve, AUC = {roc_auc}',
)

In [None]:
# The same ROC curve rendered with scikit-learn's display helper.
# Named `roc_display` (not `display`) so it does not shadow IPython's display()
# builtin for every later cell in the notebook.
from sklearn.metrics import RocCurveDisplay

roc_display = RocCurveDisplay(fpr=fpr, tpr=tpr, roc_auc=roc_auc)
roc_display.plot()

In [None]:
# Stimulus-responsiveness ROC for the example ACAd unit on vis1 trials:
# mean activity in time [-0.1, 0] (baseline, label 0) vs [0, 0.1] (stimulus, label 1).
sel_unit = units.loc[(units['structure'] == "ACAd") & (units['firing_rate'] >= 5), 'unit_id'].iloc[10]
vis1_trials = trials[trials['stim_name'] == "vis1"]

baseline_da, stim_da = (
    trial_da.sel(unit_id=sel_unit, time=window, trials=vis1_trials.index).mean(dim='time')
    for window in (slice(-0.1, 0), slice(0, 0.1))
)

# Stimulus values first (label 1.0), then baseline values (label 0.0).
binary_label = np.repeat([1.0, 0.0], [len(stim_da), len(baseline_da)])
all_frs = np.concatenate([stim_da.values, baseline_da.values])

fpr, tpr, thresholds = roc_curve(binary_label, all_frs)
roc_auc = roc_auc_score(binary_label, all_frs)

fig, ax = plt.subplots(2, 1)
ax[0].hist(baseline_da.values, bins=np.arange(0, 120, 5), alpha=0.5)
ax[0].hist(stim_da.values, bins=np.arange(0, 120, 5), alpha=0.5)

ax[1].plot(fpr, tpr)
ax[1].set(
    xlabel='false positive rate',
    ylabel='true positive rate',
    title=f'ROC curve, AUC = {roc_auc}',
)

fig.tight_layout()



In [None]:
# Load one session's saved lick-modulation table from the shared drive.
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted share.
x = pd.read_pickle(
    filepath_or_buffer=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics\lick_modulation\664851_2023-11-16_0_lick_modulation.pkl"
)

In [None]:
# Display the per-unit table loaded from the lick-modulation pickle.
x

In [None]:
# Histogram of baseline context AUC folded about chance: |AUC - 0.5| + 0.5,
# so both directions of modulation map into [0.5, 1].
fig, ax = plt.subplots()
ax.hist((x['baseline_context_roc_auc'] - 0.5).abs() + 0.5)

In [None]:
# Scatter of baseline context-modulation z-score against baseline context AUC.
fig, ax = plt.subplots()
ax.plot(
    x['baseline_context_modulation_zscore'],
    x['baseline_context_roc_auc'],
    linestyle='None',
    marker='.',
)

In [None]:
# Display the current session's trials table.
trials

In [None]:
# Inspect one entry of the session list.
# NOTE(review): `ephys_sessions` is not defined in this part of the notebook — presumably
# created in an earlier cell; confirm it exists before Restart & Run All.
ephys_sessions[5]

In [None]:
# # load and concat all the context-mod dataframes
# loadpath = r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics"
# all_files = os.listdir(loadpath)
# all_files = [f for f in all_files if f.endswith('.pkl')]
# for ff in all_files:
#     if ff==all_files[0]:
#         all_data=pd.read_pickle(os.path.join(loadpath,ff))
#     else:
#         all_data=pd.concat([all_data,pd.read_pickle(os.path.join(loadpath,ff))],axis=0)

# all_data.to_pickle(os.path.join(loadpath,'combined','all_stim_context_modulation_new.pkl'))

In [None]:
# # load and concat all the lick-mod dataframes
# loadpath = r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics\lick_modulation"
# all_lick_files = os.listdir(loadpath)
# all_lick_files = [f for f in all_lick_files if f.endswith('.pkl')]
# for ff in all_lick_files:
#     if ff==all_lick_files[0]:
#         all_lick_data=pd.read_pickle(os.path.join(loadpath,ff))
#     else:
#         all_lick_data=pd.concat([all_lick_data,pd.read_pickle(os.path.join(loadpath,ff))],axis=0)

# all_lick_data.to_pickle(os.path.join(r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics\combined",'all_lick_modulation.pkl'))

In [None]:
# Combined per-unit stimulus/context-modulation table across sessions.
# The concat cell that builds `all_data` is commented out, so on a fresh kernel the
# name may not exist; in that case load the combined pickle that cell writes.
# NOTE(review): unpickling can execute arbitrary code — only load files from a trusted share.
if 'all_data' not in globals():
    all_data = pd.read_pickle(os.path.join(
        r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics",
        'combined',
        'all_stim_context_modulation_new.pkl',
    ))

all_data.columns.values

In [None]:
# Stimulus responsiveness of all good-quality DynamicRouting units.
# Good quality: presence_ratio >= 0.99, isi_violations_ratio <= 0.1, amplitude_cutoff <= 0.1.
sel_units = all_data.query(
    'presence_ratio>=0.99 and '
    'isi_violations_ratio<=0.1 and '
    'amplitude_cutoff<=0.1 and '
    'project.str.contains("DynamicRouting")'
)

# FDR-correct each stimulus's modulation p-values across units (fdrcorrection).
# NaN p-values are excluded from the correction and stay NaN: statsmodels does not
# handle NaN, and a single NaN would turn every corrected value into NaN.
adj_pvals = pd.DataFrame({'unit_id': sel_units['unit_id']})
for stim in ['vis1', 'vis2', 'sound1', 'sound2']:
    raw_p = sel_units[stim + '_stimulus_modulation_p_value'].to_numpy(dtype=float)
    adj_p = np.full_like(raw_p, np.nan)
    valid = ~np.isnan(raw_p)
    if valid.any():
        adj_p[valid] = fdrcorrection(raw_p[valid])[1]
    adj_pvals[stim] = adj_p

# Mutually exclusive response categories at adjusted p < 0.05
# (units with NaN adjusted p-values may fall into none of them).
#each stim only
vis1_stim_resp = adj_pvals.query('vis1<0.05 and vis2>=0.05 and sound1>=0.05 and sound2>=0.05')
vis2_stim_resp = adj_pvals.query('vis2<0.05 and vis1>=0.05 and sound1>=0.05 and sound2>=0.05')
sound1_stim_resp = adj_pvals.query('sound1<0.05 and sound2>=0.05 and vis1>=0.05 and vis2>=0.05')
sound2_stim_resp = adj_pvals.query('sound2<0.05 and sound1>=0.05 and vis1>=0.05 and vis2>=0.05')

#both vis, no aud
both_vis_stim_resp = adj_pvals.query('vis1<0.05 and vis2<0.05 and sound1>=0.05 and sound2>=0.05')
#both aud, no vis
both_sound_stim_resp = adj_pvals.query('sound1<0.05 and sound2<0.05 and vis1>=0.05 and vis2>=0.05')

#at least one vis AND at least one aud stimulus ("or" within each modality, so
#e.g. vis1+sound1 units are counted here instead of in no category)
mixed_stim_resp = adj_pvals.query('(vis1<0.05 or vis2<0.05) and (sound1<0.05 or sound2<0.05)')

#none
no_stim_resp = adj_pvals.query('vis1>=0.05 and vis2>=0.05 and sound1>=0.05 and sound2>=0.05')

In [None]:
# Pie chart of the stimulus-response categories across all selected units.
category_groups = [
    ('vis1 only', vis1_stim_resp),
    ('vis2 only', vis2_stim_resp),
    ('both vis', both_vis_stim_resp),
    ('sound1 only', sound1_stim_resp),
    ('sound2 only', sound2_stim_resp),
    ('both sound', both_sound_stim_resp),
    ('mixed', mixed_stim_resp),
    ('none', no_stim_resp),
]
labels = [name for name, _ in category_groups]
sizes = [len(group) for _, group in category_groups]

fig, ax = plt.subplots()
ax.pie(sizes, labels=labels, autopct='%1.1f%%')
ax.set_title(f'n = {len(sel_units)} units')

fig.tight_layout()

In [None]:
# Per-area stimulus responsiveness of good-quality DynamicRouting units.
# For every structure in all_data, this cell does four things:
#   1. FDR-correct each stimulus's modulation p-values within the area.
#   2. Sort units into mutually exclusive response categories.
#   3. Record the counts in `area_number_responsive_to_stim`.
#   4. Save a pie chart for areas that have any units.

area_number_responsive_to_stim = {
    'area': [],
    'vis1': [],
    'vis2': [],
    'sound1': [],
    'sound2': [],
    'both_vis': [],
    'both_sound': [],
    'mixed': [],
    'none': [],
}

for sel_area in all_data['structure'].unique():

    # Exact match on the structure name: str.contains would also pull in any area whose
    # name contains this one as a substring (e.g. "VISp" would match VISpm/VISpl units).
    sel_units = all_data.query(
        'presence_ratio>=0.99 and '
        'isi_violations_ratio<=0.1 and '
        'amplitude_cutoff<=0.1 and '
        'project.str.contains("DynamicRouting") and '
        'structure==@sel_area'
    )

    # FDR correction per stimulus within this area. NaN p-values are excluded from the
    # correction and stay NaN (statsmodels does not handle NaN; one NaN would make every
    # corrected value NaN). Empty areas skip the correction entirely.
    adj_pvals = pd.DataFrame({'unit_id': sel_units['unit_id']})
    for stim in ['vis1', 'vis2', 'sound1', 'sound2']:
        raw_p = sel_units[stim + '_stimulus_modulation_p_value'].to_numpy(dtype=float)
        adj_p = np.full_like(raw_p, np.nan)
        valid = ~np.isnan(raw_p)
        if valid.any():
            adj_p[valid] = fdrcorrection(raw_p[valid])[1]
        adj_pvals[stim] = adj_p

    #each stim only
    vis1_stim_resp = adj_pvals.query('vis1<0.05 and vis2>=0.05 and sound1>=0.05 and sound2>=0.05')
    vis2_stim_resp = adj_pvals.query('vis2<0.05 and vis1>=0.05 and sound1>=0.05 and sound2>=0.05')
    sound1_stim_resp = adj_pvals.query('sound1<0.05 and sound2>=0.05 and vis1>=0.05 and vis2>=0.05')
    sound2_stim_resp = adj_pvals.query('sound2<0.05 and sound1>=0.05 and vis1>=0.05 and vis2>=0.05')

    #both vis, no aud
    both_vis_stim_resp = adj_pvals.query('vis1<0.05 and vis2<0.05 and sound1>=0.05 and sound2>=0.05')
    #both aud, no vis
    both_sound_stim_resp = adj_pvals.query('sound1<0.05 and sound2<0.05 and vis1>=0.05 and vis2>=0.05')

    #at least one vis AND at least one aud stimulus
    mixed_stim_resp = adj_pvals.query('(vis1<0.05 or vis2<0.05) and (sound1<0.05 or sound2<0.05)')

    #none
    no_stim_resp = adj_pvals.query('vis1>=0.05 and vis2>=0.05 and sound1>=0.05 and sound2>=0.05')

    area_number_responsive_to_stim['area'].append(sel_area)
    area_number_responsive_to_stim['vis1'].append(len(vis1_stim_resp))
    area_number_responsive_to_stim['vis2'].append(len(vis2_stim_resp))
    area_number_responsive_to_stim['sound1'].append(len(sound1_stim_resp))
    area_number_responsive_to_stim['sound2'].append(len(sound2_stim_resp))
    area_number_responsive_to_stim['both_vis'].append(len(both_vis_stim_resp))
    area_number_responsive_to_stim['both_sound'].append(len(both_sound_stim_resp))
    area_number_responsive_to_stim['mixed'].append(len(mixed_stim_resp))
    area_number_responsive_to_stim['none'].append(len(no_stim_resp))

    labels = ['vis1 only', 'vis2 only', 'both vis',
              'sound1 only', 'sound2 only', 'both sound',
              'mixed', 'none']

    sizes = [len(vis1_stim_resp), len(vis2_stim_resp), len(both_vis_stim_resp),
             len(sound1_stim_resp), len(sound2_stim_resp), len(both_sound_stim_resp),
             len(mixed_stim_resp), len(no_stim_resp)]

    if np.sum(sizes) > 0:
        fig, ax = plt.subplots()
        ax.pie(sizes, labels=labels, autopct='%1.1f%%')
        ax.set_title('area='+sel_area+'; n='+str(len(sel_units))+' units')

        fig.tight_layout()

        fig.savefig(
            os.path.join(r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics\plots\stimulus responsiveness", sel_area+"_DR.png"),
            dpi=300, facecolor='w', edgecolor='w',
            orientation='portrait', format='png',
            transparent=True, bbox_inches='tight', pad_inches=0.1,
            metadata=None)

        # close this specific figure so the loop does not accumulate open figures
        plt.close(fig)

area_number_responsive_to_stim = pd.DataFrame(area_number_responsive_to_stim)

In [None]:
# Areas whose category counts sum to more than 20 units.
area_number_responsive_to_stim[
    area_number_responsive_to_stim.drop(columns='area').sum(axis=1) > 20
]

In [None]:
# Convert per-area category counts into fractions of each area's total unit count.
# Areas with zero units keep all-zero fractions. A `total_n` column is appended, and the
# same totals are kept in the `total_n` dict ('area' and 'total_n' lists).
count_columns = area_number_responsive_to_stim.columns.drop('area')
area_totals = area_number_responsive_to_stim[count_columns].sum(axis=1)

total_n = {
    'area': area_number_responsive_to_stim['area'].tolist(),
    'total_n': area_totals.tolist(),
}

# Vectorized division replaces the row-by-row iloc writes: writing floats into int count
# columns in place, and positional `row[1:]` on a labeled Series, are both deprecated in
# pandas. Dividing by NaN for empty areas and then zero-filling keeps their fractions at 0.
area_fraction_responsive_to_stim = area_number_responsive_to_stim.copy()
area_fraction_responsive_to_stim[count_columns] = (
    area_number_responsive_to_stim[count_columns]
    .div(area_totals.where(area_totals > 0), axis=0)
    .fillna(0)
)
area_fraction_responsive_to_stim['total_n'] = area_totals

In [None]:
# Per-area total unit counts (dict with 'area' and 'total_n' lists).
total_n

In [None]:
# Per-area fraction of units in each stimulus-response category, plus total_n.
area_fraction_responsive_to_stim

In [None]:
# Save the per-area response-category fractions to the shared combined-results folder.
area_fraction_responsive_to_stim.to_csv(
    path_or_buf=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\single unit metrics\combined\area_fraction_responsive_to_stim_new.csv"
)

In [None]:
# Quick look at FDR-corrected vis1 stimulus-modulation p-values across all units.
# NaN p-values are dropped first: statsmodels' fdrcorrection does not handle NaN, and a
# single NaN would make every corrected p-value NaN.
fdrcorrection(all_data['vis1_stimulus_modulation_p_value'].dropna())

In [None]:
# Evoked vs raw vis1 context-modulation index for good-quality units responsive to vis1
# (stimulus-modulation p < 0.01): one panel per project.
xbins = np.arange(-1, 1.1, 0.1)
fig, ax = plt.subplots(2, 1)

good_quality = 'presence_ratio>=0.99 and isi_violations_ratio<=0.1 and amplitude_cutoff<=0.1'
dr_good_unit_ids = all_data.query(good_quality + ' and project.str.contains("DynamicRouting")')['unit_id'].values
templ_good_unit_ids = all_data.query(good_quality + ' and project.str.contains("Templeton")')['unit_id'].values

for panel, panel_ids, panel_title in [(ax[0], dr_good_unit_ids, 'Dynamic Routing units'),
                                      (ax[1], templ_good_unit_ids, 'Templeton units')]:
    vis1_responsive = all_data.query('vis1_stimulus_modulation_p_value<0.01 and unit_id in @panel_ids')
    panel.hist(vis1_responsive['vis1_evoked_context_modulation_index'], bins=xbins, alpha=0.5, label='evoked')
    panel.hist(vis1_responsive['vis1_context_modulation_index'], bins=xbins, alpha=0.5, label='raw')
    panel.legend()
    panel.set_title(panel_title)

fig.tight_layout()

In [None]:
# Same comparison as the previous histogram cell, restricted further to units whose
# evoked vis1 context modulation is itself significant (p < 0.01).
xbins = np.arange(-1, 1.1, 0.1)

fig, ax = plt.subplots(2, 1)

quality_clause = 'presence_ratio>=0.99 and isi_violations_ratio<=0.1 and amplitude_cutoff<=0.1'
dr_good_unit_ids = all_data.query(f'{quality_clause} and project.str.contains("DynamicRouting")')['unit_id'].values
templ_good_unit_ids = all_data.query(f'{quality_clause} and project.str.contains("Templeton")')['unit_id'].values

project_panels = {
    'Dynamic Routing units': (ax[0], dr_good_unit_ids),
    'Templeton units': (ax[1], templ_good_unit_ids),
}
for panel_title, (panel, panel_ids) in project_panels.items():
    sig_units = all_data.query(
        'vis1_stimulus_modulation_p_value<0.01 and '
        'vis1_evoked_context_modulation_p_value<0.01 and '
        'unit_id in @panel_ids'
    )
    for column, hist_label in [('vis1_evoked_context_modulation_index', 'evoked'),
                               ('vis1_context_modulation_index', 'raw')]:
        panel.hist(sig_units[column], bins=xbins, alpha=0.5, label=hist_label)
    panel.legend()
    panel.set_title(panel_title)

fig.tight_layout()

In [None]:
# Per-area boxplots of evoked context-modulation index for good-quality DynamicRouting
# units, one figure per stimulus. Only units responsive to that stimulus
# (stimulus-modulation p < 0.01) are included; NaN indices are dropped.
# NOTE(review): stimulus names come from the current session's `trials`; assumes each has
# matching `<stim>_stimulus_modulation_p_value` / `<stim>_evoked_context_modulation_index` columns.

all_areas = all_data['structure'].unique()

for ss in trials['stim_name'].unique():
    stim_p_val_str = f'{ss}_stimulus_modulation_p_value'
    index_column = f'{ss}_evoked_context_modulation_index'

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axhline(0, color='black', linestyle='--')

    for counter, aa in enumerate(all_areas):
        area_units = all_data.query(
            'structure==@aa and '
            'presence_ratio>=0.99 and '
            'isi_violations_ratio<=0.1 and '
            'amplitude_cutoff<=0.1 and '
            'project.str.contains("DynamicRouting")'
        )['unit_id']
        responsive = all_data.query(f'unit_id in @area_units and {stim_p_val_str}<0.01')
        context_mod_values = responsive[index_column].values
        context_mod_values = context_mod_values[~np.isnan(context_mod_values)]
        ax.boxplot(context_mod_values, positions=[counter], showfliers=False)

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas, rotation=90)
    ax.set_ylabel('context modulation index')
    ax.set_title('stim_name: ' + ss)


In [None]:
# Per-area fraction of good-quality DynamicRouting units, among those responsive to the
# stimulus (stimulus-modulation p < 0.01), whose evoked context-modulation index is
# >= +threshold (green, upward bar) or <= -threshold (blue, downward bar).

# index magnitude counted as "strongly context modulated"
threshold = 0.5

all_areas = all_data['structure'].unique()

#loop through stimuli
for ss in trials['stim_name'].unique():
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axhline(0, color='black', linestyle='--')
    #loop through unique areas
    for counter, aa in enumerate(all_areas):

        #get unit ids in this area
        area_units = all_data.query(
            'structure==@aa and '
            'presence_ratio>=0.99 and '
            'isi_violations_ratio<=0.1 and '
            'amplitude_cutoff<=0.1 and '
            'project.str.contains("DynamicRouting")'
        )['unit_id']
        #get context modulation values for these units
        stim_p_val_str = ss + '_stimulus_modulation_p_value'
        context_mod_values = all_data.query('unit_id in @area_units and ' + stim_p_val_str + '<0.01')[ss + '_evoked_context_modulation_index'].values
        context_mod_values = context_mod_values[~np.isnan(context_mod_values)]

        if len(context_mod_values) == 0:
            # no responsive units here: leave the slot empty instead of computing 0/0
            continue

        pos_fraction = np.mean(context_mod_values >= threshold)
        neg_fraction = np.mean(context_mod_values <= -threshold)

        ax.bar(counter, pos_fraction, color='tab:green')
        ax.bar(counter, -neg_fraction, color='tab:blue')

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas, rotation=90)
    ax.set_ylabel('fraction context mod above +/-'+str(threshold))
    ax.set_ylim([-1.05, 1.05])
    ax.set_title('stim_name: '+ss)


In [None]:
#number of units significantly modulated by context

# threshold value
sig_threshold=0.01

all_areas=all_data['structure'].unique()

#loop through stimuli
for ss in trials['stim_name'].unique():
    fig,ax=plt.subplots(figsize=(15,4))
    ax.axhline(0,color='black',linestyle='--')
    #loop through unique areas
    for counter,aa in enumerate(all_areas):
        
        #get unit ids in this area
        area_units=all_data.query('structure==@aa and \
                               presence_ratio>=0.99 and \
                               isi_violations_ratio<=0.1 and \
                               amplitude_cutoff<=0.1 and \
                                    project.str.contains("Templeton")')['unit_id']
                            #    peak_to_valley>0.0004')['unit_id']
        #get context modulation values for these units
        stim_p_val_str=ss+'_stimulus_modulation_p_value'
        context_mod_values=all_data.query('unit_id in @area_units and '+stim_p_val_str+'<0.01')[ss+'_evoked_context_modulation_p_value'].values
        # context_mod_values=stim_context_modulation_df.query('unit_id in @area_units')[ss+'_stimulus_modulation_p_value'].values
        context_mod_values=context_mod_values[~np.isnan(context_mod_values)]

        sig_fraction=np.sum(context_mod_values<threshold)/len(context_mod_values)

        #plot distribution
        # ax.boxplot(context_mod_values,positions=[counter],showfliers=False)
        ax.bar(counter,sig_fraction,color='tab:green')
        # ax.bar(counter,-neg_fraction,color='tab:blue')

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas,rotation=90)
    ax.set_ylabel('fraction significantly modulated by context')
    ax.set_ylim([0,1.05])
    ax.set_title('stim_name: '+ss)


In [None]:
#number of units significantly modulated by context

# threshold value
sig_threshold=0.01

all_areas=all_data['structure'].unique()

#loop through stimuli
for ss in trials['stim_name'].unique():
    fig,ax=plt.subplots(figsize=(10,4))
    ax.axhline(0,color='black',linestyle='--')
    #loop through unique areas
    for counter,aa in enumerate(all_areas):
        
        #get unit ids in this area
        area_units=all_data.query('structure==@aa and \
                               presence_ratio>=0.99 and \
                               isi_violations_ratio<=0.1 and \
                               amplitude_cutoff<=0.1 and \
                                    project.str.contains("DynamicRouting")')['unit_id']
                            #    peak_to_valley>0.0004')['unit_id']
        #get context modulation values for these units
        stim_p_val_str=ss+'_stimulus_modulation_p_value'
        # context_mod_values=all_data.query('unit_id in @area_units and '+stim_p_val_str+'<0.01')[ss+'_evoked_context_modulation_p_value'].values
        stim_mod_values=all_data.query('unit_id in @area_units')[ss+'_stimulus_modulation_p_value'].values
        stim_mod_values=stim_mod_values[~np.isnan(stim_mod_values)]

        sig_fraction=np.sum(stim_mod_values<threshold)/len(stim_mod_values)

        #plot distribution
        # ax.boxplot(context_mod_values,positions=[counter],showfliers=False)
        ax.bar(counter,sig_fraction,color='tab:green')
        # ax.bar(counter,-neg_fraction,color='tab:blue')

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas,rotation=90)
    ax.set_ylabel('fraction significantly modulated by stimulus')
    ax.set_ylim([0,1.05])
    ax.set_title('stim_name: '+ss)


In [None]:
### Only include areas with at least 10 units in each of 3 recordings

In [None]:
# Evoked vs raw vis1 context-modulation index for this session's good-quality units
# that respond to vis1 (stimulus-modulation p < 0.01).
# NOTE(review): `stim_context_modulation_df` is not created in this part of the notebook —
# presumably an earlier cell; confirm before Restart & Run All.
fig, ax = plt.subplots()

good_unit_ids = units.query(
    'presence_ratio>=0.99 and '
    'isi_violations_ratio<=0.1 and '
    'amplitude_cutoff<=0.1'
)['unit_id'].values

vis1_responsive_good = stim_context_modulation_df.query(
    'vis1_stimulus_modulation_p_value<0.01 and unit_id in @good_unit_ids'
)
for column, hist_label in [('vis1_evoked_context_modulation_index', 'evoked'),
                           ('vis1_context_modulation_index', 'raw')]:
    ax.hist(vis1_responsive_good[column], bins=20, alpha=0.5, label=hist_label)
ax.legend()

In [None]:
#plot actual distributions across areas

In [None]:
# Waveform peak-to-valley values for this session's units.
units.loc[:, 'peak_to_valley']

In [None]:
# Per-area boxplots of evoked context-modulation index for this session's good-quality
# units, one figure per stimulus; only units responsive to that stimulus
# (stimulus-modulation p < 0.01) are included and NaN indices are dropped.
# NOTE(review): relies on `stim_context_modulation_df` from an earlier cell — confirm.

all_areas = units['structure'].unique()

for ss in trials['stim_name'].unique():
    stim_p_val_str = f'{ss}_stimulus_modulation_p_value'
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axhline(0, color='black', linestyle='--')

    for counter, aa in enumerate(all_areas):
        # a peak_to_valley>0.0004 waveform filter was tried here and is currently disabled
        area_units = units.query(
            'structure==@aa and '
            'presence_ratio>=0.99 and '
            'isi_violations_ratio<=0.1 and '
            'amplitude_cutoff<=0.1'
        )['unit_id']
        responsive = stim_context_modulation_df.query(f'unit_id in @area_units and {stim_p_val_str}<0.01')
        context_mod_values = responsive[f'{ss}_evoked_context_modulation_index'].values
        context_mod_values = context_mod_values[~np.isnan(context_mod_values)]
        ax.boxplot(context_mod_values, positions=[counter], showfliers=False)

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas, rotation=90)
    ax.set_ylabel('context modulation index')
    ax.set_title('stim_name: ' + ss)


In [None]:
# Values left over from the last area/stimulus iteration of the per-area boxplot cell.
# `context_mod_values` is already a NumPy array (built with `.values`, then NaN-filtered),
# so calling `.values` on it raised AttributeError; display it directly.
context_mod_values

In [None]:
# Per-area fraction of this session's good-quality units, among those responsive to the
# stimulus (stimulus-modulation p < 0.01), whose evoked context-modulation index is
# >= +threshold (green, upward bar) or <= -threshold (blue, downward bar).
# Units are additionally restricted to peak_to_valley < 0.0004
# (presumably narrow-waveform units — confirm the units of peak_to_valley).

# index magnitude counted as "strongly context modulated"
threshold = 0.5

all_areas = units['structure'].unique()

#loop through stimuli
for ss in trials['stim_name'].unique():
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axhline(0, color='black', linestyle='--')
    #loop through unique areas
    for counter, aa in enumerate(all_areas):

        #get unit ids in this area
        area_units = units.query(
            'structure==@aa and '
            'presence_ratio>=0.99 and '
            'isi_violations_ratio<=0.1 and '
            'amplitude_cutoff<=0.1 and '
            'peak_to_valley<0.0004'
        )['unit_id']
        #get context modulation values for these units
        stim_p_val_str = ss + '_stimulus_modulation_p_value'
        context_mod_values = stim_context_modulation_df.query('unit_id in @area_units and ' + stim_p_val_str + '<0.01')[ss + '_evoked_context_modulation_index'].values
        context_mod_values = context_mod_values[~np.isnan(context_mod_values)]

        if len(context_mod_values) == 0:
            # no responsive units here: leave the slot empty instead of computing 0/0
            continue

        pos_fraction = np.mean(context_mod_values >= threshold)
        neg_fraction = np.mean(context_mod_values <= -threshold)

        ax.bar(counter, pos_fraction, color='tab:green')
        ax.bar(counter, -neg_fraction, color='tab:blue')

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas, rotation=90)
    ax.set_ylabel('fraction context mod above +/-'+str(threshold))
    ax.set_ylim([-1.05, 1.05])
    ax.set_title('stim_name: '+ss)


In [None]:
# units['structure'].value_counts()

In [None]:
# Per-area fraction of this session's good-quality units, among those responsive to the
# stimulus (stimulus-modulation p < 0.01), whose evoked context modulation is significant
# (evoked context-modulation p < sig_threshold). One figure per stimulus.
# The comparison uses sig_threshold, not the unrelated `threshold` (0.5) that an
# earlier cell defines for the modulation index.

# significance threshold for the context-modulation p-value
sig_threshold = 0.01

all_areas = units['structure'].unique()

#loop through stimuli
for ss in trials['stim_name'].unique():
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.axhline(0, color='black', linestyle='--')
    #loop through unique areas
    for counter, aa in enumerate(all_areas):

        #get unit ids in this area
        area_units = units.query(
            'structure==@aa and '
            'presence_ratio>=0.99 and '
            'isi_violations_ratio<=0.1 and '
            'amplitude_cutoff<=0.1'
        )['unit_id']
        #get context-modulation p-values for stimulus-responsive units
        stim_p_val_str = ss + '_stimulus_modulation_p_value'
        context_mod_values = stim_context_modulation_df.query('unit_id in @area_units and ' + stim_p_val_str + '<0.01')[ss + '_evoked_context_modulation_p_value'].values
        context_mod_values = context_mod_values[~np.isnan(context_mod_values)]

        if len(context_mod_values) == 0:
            # no responsive units here: leave the slot empty instead of computing 0/0
            continue

        sig_fraction = np.mean(context_mod_values < sig_threshold)

        ax.bar(counter, sig_fraction, color='tab:green')

    ax.set_xticks(range(len(all_areas)))
    ax.set_xticklabels(all_areas, rotation=90)
    ax.set_ylabel('fraction significantly modulated by context')
    ax.set_ylim([0, 1.05])
    ax.set_title('stim_name: '+ss)
