# Decode context from spikes or facemap

1 - either use all annotated & uploaded ephys sessions as input or provide a list of session_ids

2 - set a savepath and filename for the output - one .pkl file per session

3 - set parameters - descriptions below

4 - run decoding!

In [1]:
import sys
# sys.path.append(r"C:\Users\shailaja.akella\Dropbox (Personal)\DR\dynamic_routing_analysis_ethan\src")

import npc_lims
from dynamic_routing_analysis import decoding_utils, path_utils
from npc_sessions import DynamicRoutingSession
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import os
import pandas as pd
import upath

%load_ext autoreload
%autoreload 2

In [None]:
#1A build the session tuple from every session matching the filters below
# (ephys, uploaded, annotated, DynamicRouting project; issues=[] presumably
# restricts to sessions without recorded issues - confirm against npc_lims)

ephys_sessions = tuple(
    npc_lims.get_session_info(
        is_ephys=True,
        is_uploaded=True,
        is_annotated=True,
        project='DynamicRouting',
        issues=[],
    )
)

In [2]:
#1B alternatively, build the session tuple from an explicit list of session ids
session_id_list = [
    '733891_2024-09-19',
    '712815_2024-05-22',
    '708016_2024-05-01',
    '664851_2023-11-14',
    '702136_2024-03-05',
    '686176_2023-12-05',
]
# single-session alternatives kept for quick testing:
# session_id_list=['703333_2024-04-08']
# session_id_list=['668755_2023-08-30']
# session_id_list=['667252_2023-09-25'] #only 4 blocks - test error handling

# one get_session_info lookup per id, in list order
session_list = [npc_lims.get_session_info(ss) for ss in session_id_list]
ephys_sessions = tuple(session_list)
ephys_sessions  # last expression of the cell: shown as the cell output

(SessionInfo(id='733891_2024-09-19', project='DynamicRouting', is_ephys=True, is_sync=True, allen_path=WindowsUPath('//allen/programs/mindscope/workgroups/dynamicrouting/PilotEphys/Task 2 pilot/DRpilot_733891_20240919'), experiment_day=4, session_kwargs={}, notes='', issues=[]),
 SessionInfo(id='712815_2024-05-22', project='DynamicRouting', is_ephys=True, is_sync=True, allen_path=WindowsUPath('//allen/programs/mindscope/workgroups/dynamicrouting/PilotEphys/Task 2 pilot/DRpilot_712815_20240522'), experiment_day=3, session_kwargs={}, notes='', issues=[]),
 SessionInfo(id='708016_2024-05-01', project='DynamicRouting', is_ephys=True, is_sync=True, allen_path=WindowsUPath('//allen/programs/mindscope/workgroups/dynamicrouting/PilotEphys/Task 2 pilot/DRpilot_708016_20240501'), experiment_day=3, session_kwargs={}, notes='', issues=[]),
 SessionInfo(id='664851_2023-11-14', project='DynamicRouting', is_ephys=True, is_sync=True, allen_path=WindowsUPath('//allen/programs/mindscope/workgroups/dynam

In [30]:
#2 set savepath and filename
# Output directory and the filename suffix shared by every session's result file.
savepath = upath.UPath(
    r"\\allen\programs\mindscope\workgroups\templeton\TTOC\decoding results\test_redefined_metrics\presence_ratio"
)
filename = 'test_redefined_metrics_presence_ratio.pkl'

# alternative output location:
# filename='2024_10_28'
# savepath = path_utils.DECODING_ROOT_PATH / 'decoding_test_2024_10_28'

# Table of recalculated unit metrics, read from a local pickle.
recalculated_unit_metrics = pd.read_pickle(
    r"D:\recalc_metrics\units_with_recalc_metrics.pkl"
)

# Unit filter in DataFrame.query syntax. '@session_id' is resolved by pandas
# from a variable named session_id at the time the query is evaluated, so that
# variable must be set before this string is used.
query_string = ' and '.join([
    'sliding_rp_violation<=0.1',
    'amplitude_cutoff<=0.1',
    'presence_ratio_task>=0.99',
    'session_id==@session_id',
])

# Only referenced by the commented-out exception handler in the decoding loop.
except_list = {}

#3 set parameters
# Linear shift decoding currently just takes the average firing rate over all
# bins defined here: (bin size, time before stimulus, time after stimulus),
# all in seconds and applied per trial.
# previous values: spikes_binsize=0.2, spikes_time_before=0.2, spikes_time_after=0.01
spikes_binsize, spikes_time_before, spikes_time_after = 0.1, 0.0, 0.11

# Not used for linear shift decoding; were used in a previous iteration of the
# decoding analysis:
# decoder_binsize=0.2
# decoder_time_before=0.2
# decoder_time_after=0.1

params = {
    # number of units to sample for each area (list)
    'n_units': [20, 30, 'all'],
    # number of times to repeat decoding with different randomly sampled units
    'n_repeats': 25,
    # spikes or facemap or LP
    'input_data_type': 'spikes',
    # behavior, face, eye
    'vid_angle_facemotion': 'face',
    'vid_angle_LP': 'behavior',
    # for linear shift decoding, how many trials to use for the shift.
    # '4_blocks_plus' is best
    'central_section': '4_blocks_plus',
    # 'context' or 'vis_appropriate_response'
    'predict': 'context',
    # option to totally exclude autorewarded trials
    'exclude_cue_trials': False,
    # minimum number of units to include an area in the analysis
    'n_unit_threshold': 20,
    # number of SVD components to keep for facemap data
    'keep_n_SVDs': 500,
    'LP_parts_to_keep': [
        'ear_base_l',
        'eye_bottom_l',
        'jaw',
        'nose_tip',
        'whisker_pad_l_side',
    ],
    'spikes_binsize': spikes_binsize,
    'spikes_time_before': spikes_time_before,
    'spikes_time_after': spikes_time_after,
    # 'decoder_binsize':decoder_binsize,
    # 'decoder_time_before':decoder_time_before,
    # 'decoder_time_after':decoder_time_after,
    'savepath': savepath,
    'filename': filename,
    # if True, append probe name to area name when multiple probes in the same area
    'use_structure_probe': True,
    # '5_fold', '5_fold_constant', or 'blockwise' - blockwise untested with linear shift
    'crossval': '5_fold_constant',
    # convert labels (context names) to index [0,1]
    'labels_as_index': True,
    # 'linearSVC' or 'LDA' or 'RandomForest' or 'LogisticRegression'
    'decoder_type': 'LogisticRegression',
    # if True, do not run decoding with different areas, only with all areas -- for debugging
    'only_use_all_units': False,
    # if True, return the results of the decoding analysis
    'return_results': True,
    # query string to filter units
    'units_query': query_string,
}


def _n_units_from_columns(columns):
    """Recover the n_units label encoded in each 'true_accuracy_*' column name.

    A name with exactly three '_'-separated parts (e.g. 'true_accuracy_20')
    yields its last part: as an int when it is numeric, otherwise as the raw
    string (e.g. 'all'). Any other name yields None. One entry is returned
    per input column, in the same order.
    """
    labels = []
    for col in columns:
        parts = col.split('_')
        if len(parts) != 3:
            labels.append(None)
            continue
        try:
            labels.append(int(parts[2]))
        except ValueError:
            # Non-numeric label: keep the string. Catching only ValueError
            # (rather than a bare except) lets KeyboardInterrupt and other
            # unrelated errors propagate.
            labels.append(parts[2])
    return labels


#4 run decoding: one pass per session (slice ephys_sessions to run a subset)
for ephys_session in ephys_sessions[:]:
    # Disabled in this cell (kept for reference):
    #   - skipping sessions whose output file already exists
    #   - loading DynamicRoutingSession(ephys_session.id) and requiring a
    #     'structure' column in session.electrodes
    #   - wrapping the body in try/except and recording repr(e) in except_list
    session_info = ephys_session

    # The cached tables are looked up with '_0' appended to the session id
    # (presumably the session index - confirm). The name `session_id` must not
    # be changed: query_string refers to it as @session_id.
    session_id = str(session_info.id) + '_0'
    trials = pd.read_parquet(npc_lims.get_cache_path('trials', session_id, 'any'))
    units = pd.read_parquet(npc_lims.get_cache_path('units', session_id, 'any'))

    # Keep only the units that pass the recalculated quality-metric filter.
    # `sel_units` is likewise referenced by name (@sel_units) in the query.
    sel_units = recalculated_unit_metrics.query(query_string)['unit_id'].tolist()
    units = units.query("unit_id in @sel_units")

    # results=decoding_utils.decode_stimulus_across_context(session=None,params=params,trials=trials,units=units,session_info=session_info)

    decoding_utils.decode_context_with_linear_shift(
        session=None,
        params=params,
        trials=trials,
        units=units,
        session_info=session_info,
    )

    # find path of decoder result
    file_path = savepath / (ephys_session.id[:17] + '_' + filename)

    decoding_results = decoding_utils.concat_decoder_results(
        file_path,
        savepath=savepath,
        return_table=True,
        single_session=True,
    )
    if decoding_results is None:
        # no table returned for this session: nothing to concatenate trialwise
        continue

    # find n_units to loop through for next step
    accuracy_columns = decoding_results.filter(like='true_accuracy_').columns.values
    n_units = _n_units_from_columns(accuracy_columns)
    for nu in n_units:
        decoding_utils.concat_trialwise_decoder_results(
            file_path,
            savepath=savepath,
            return_table=False,
            n_units=nu,
            single_session=True,
        )


PermissionError: [Errno 13] Permission denied: '\\\\allen\\programs\\mindscope\\workgroups\\templeton\\TTOC\\decoding results\\test_redefined_metrics\\presence_ratio\\733891_2024-09-19_test_redefined_metrics_presence_ratio.pkl'

In [29]:
# Debug: repeat the unit selection for a single session outside the main loop.
session_id = '733891_2024-09-19_0'  # also read by query_string via @session_id
units_cache_path = npc_lims.get_cache_path('units', session_id, 'any')
units = pd.read_parquet(units_cache_path)

# ids of the units that pass the quality-metric filter for this session
passing_units = recalculated_unit_metrics.query(query_string)
sel_units = passing_units['unit_id'].tolist()

# last expression of the cell: shows the filtered units table (units itself is not reassigned)
units.query("unit_id in @sel_units")

Unnamed: 0_level_0,amplitude_cutoff,amplitude_cv_median,amplitude_cv_range,amplitude_median,drift_ptp,drift_std,drift_mad,firing_range,firing_rate,isi_violations_ratio,...,structure,location,peak_electrode,spike_times,obs_intervals,device_name,session_idx,date,subject_id,session_id
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
396,0.000217,,,95.939995,9.028496,1.762204,3.166011,9.40,3.023711,0.131366,...,ACB,ACB,14,"[238.64865587916546, 280.74446711942346, 290.7...","[[20.262055945945377, 7791.154417547969]]",20097906812,0,2024-09-19,733891,733891_2024-09-19
2914,0.000132,,,67.860000,14.939842,1.531747,1.367935,7.00,2.748071,0.062480,...,ACB,ACB,1938,"[20.607533333942683, 20.619633289694324, 20.61...","[[20.29100115813666, 7791.168083818202]]",22175718152,0,2024-09-19,733891,733891_2024-09-19
405,0.000238,0.237265,0.100399,49.140000,6.583939,1.617412,1.364893,2.60,3.609356,0.046097,...,ACB,ACB,29,"[20.2966225123565, 20.445022081944085, 20.5138...","[[20.262055945945377, 7791.154417547969]]",20097906812,0,2024-09-19,733891,733891_2024-09-19
2915,0.000025,0.162257,0.155845,107.640000,7.023911,2.139738,2.110443,3.80,5.950915,0.012113,...,ACB,ACB,1946,"[20.384467483006087, 20.515100338627803, 20.67...","[[20.29100115813666, 7791.168083818202]]",22175718152,0,2024-09-19,733891,733891_2024-09-19
2928,0.000016,0.219422,0.147396,79.560000,5.756775,0.882377,1.247640,5.20,3.600095,0.006619,...,ACB,ACB,1973,"[20.57300012689392, 22.029894799172133, 22.760...","[[20.29100115813666, 7791.168083818202]]",22175718152,0,2024-09-19,733891,733891_2024-09-19
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1626,0.000125,,,65.520000,,,,7.00,2.265869,0.000000,...,ZI,ZI,783,"[39.122146614189674, 41.17939279964196, 44.203...","[[20.2620343787872, 7791.156370107085]]",18005120322,0,2024-09-19,733891,733891_2024-09-19
1621,0.000100,0.155377,0.084708,100.619995,6.853111,0.979612,1.387127,12.20,5.190728,0.307259,...,ZI,ZI,770,"[21.43772267405094, 21.60658765955762, 21.7934...","[[20.2620343787872, 7791.156370107085]]",18005120322,0,2024-09-19,733891,733891_2024-09-19
1625,0.000047,,,79.560000,9.281208,1.365571,2.330336,19.40,7.635081,0.027961,...,ZI,ZI,778,"[25.238251503998647, 28.066323348717013, 29.27...","[[20.2620343787872, 7791.156370107085]]",18005120322,0,2024-09-19,733891,733891_2024-09-19
1627,0.000007,0.252908,0.110103,58.499996,11.944893,3.645960,3.387148,12.40,8.333321,0.017913,...,ZI,ZI,789,"[20.69866336519699, 21.305190660159393, 21.375...","[[20.2620343787872, 7791.156370107085]]",18005120322,0,2024-09-19,733891,733891_2024-09-19


In [28]:
# Display the unit ids produced by the most recent quality-metric query.
sel_units

[]

In [25]:
# units

In [24]:
# session_id='712815_2024-05-22_0'
# recalculated_unit_metrics.query('sliding_rp_violation<=0.1 and amplitude_cutoff<=0.1 and presence_ratio_task>=0.99 and session_id==@session_id')

In [None]:
# Bar plot of stimulus-decoding balanced accuracy for one area, one bar per
# stimulus/context combination.
# NOTE(review): `results` is only assigned by the commented-out
# decode_stimulus_across_context call in the decoding loop; it must be defined
# before this cell can run.
sel_area = 'ORBl'

condition_names = [
    'vis_stim_vis_context',
    'vis_stim_aud_context',
    'aud_stim_aud_context',
    'aud_stim_vis_context',
]
area_results = results['703333_2024-04-08']['results'][sel_area]
accuracies = [
    area_results[f'predict_{name}']['all'][0]['balanced_accuracy_test']
    for name in condition_names
]
# keep the individual values available under their original names
(vis_stim_vis_context,
 vis_stim_aud_context,
 aud_stim_aud_context,
 aud_stim_vis_context) = accuracies

bar_positions = np.arange(len(condition_names))
fig, ax = plt.subplots(1, 1)
ax.bar(bar_positions, accuracies)
# dashed reference line at 0.5
ax.axhline(0.5, color='black', linestyle='--', alpha=0.5)
ax.set_xticks(bar_positions)
ax.set_xticklabels(condition_names, rotation=45)
ax.set_ylabel('balanced_accuracy')
ax.set_ylim(0, 1)

ax.set_title('Decoding accuracy for ' + sel_area)


In [None]:
# Display the distinct values in the 'structure' column of the current units table.
units['structure'].unique()