# Decode context from spikes, facemap, or LP data

1 - Either use all annotated & uploaded ephys sessions as input (cell 1A) or provide a list of session IDs (cell 1B)

2 - Set a save path and filename for the output (one .pkl file per session)

3 - Set parameters (each one is described by an inline comment in the parameters cell)

4 - Run decoding!

In [1]:
import sys
# sys.path.append(r"C:\Users\shailaja.akella\Dropbox (Personal)\DR\dynamic_routing_analysis_ethan\src")

import npc_lims
from dynamic_routing_analysis import decoding_utils, path_utils
from npc_sessions import DynamicRoutingSession
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import os
import pandas as pd
import upath

%load_ext autoreload
%autoreload 2

In [None]:
#1A get all uploaded & annotated ephys sessions

# Query filters: uploaded + annotated ephys sessions from the DynamicRouting
# project with no recorded issues.
session_filters = dict(
    is_ephys=True,
    is_uploaded=True,
    is_annotated=True,
    project='DynamicRouting',
    issues=[],
)
ephys_sessions = tuple(npc_lims.get_session_info(**session_filters))

In [None]:
#1B alternatively, provide a list of session ids:
# session_id_list=['733891_2024-09-19','712815_2024-05-22','708016_2024-05-01','664851_2023-11-14','702136_2024-03-05','686176_2023-12-05']
# session_id_list=['703333_2024-04-08']
session_id_list=['667252_2023-09-25'] #only 4 blocks - test error handling

# Look up the session-info record for each requested session ID, in order.
session_list = [npc_lims.get_session_info(sid) for sid in session_id_list]
ephys_sessions = tuple(session_list)
ephys_sessions

In [None]:
#2 set savepath and filename
# The decoding loop looks for each session's result at
# savepath / '<session_id>_<filename>'.
savepath=upath.UPath(r"\\allen\programs\mindscope\workgroups\templeton\TTOC\decoding results\test_logger")
filename='test_logger.pkl'

# filename='2024_10_28'
# savepath = path_utils.DECODING_ROOT_PATH / 'decoding_test_2024_10_28'

# session id -> repr(exception). Only filled when the try/except in the
# decoding loop is enabled (it is currently commented out).
except_list={}

#3 set parameters
#linear shift decoding currently just takes the average firing rate over all bins defined here
spikes_binsize=0.2 #bin size in seconds
spikes_time_before=0.2 #time before the stimulus per trial
spikes_time_after=0.01 #time after the stimulus per trial

# #not used for linear shift decoding, were used in a previous iteration of decoding analysis
# decoder_binsize=0.2
# decoder_time_before=0.2
# decoder_time_after=0.1


params = {
    'n_units': ['all'], #number of units to sample for each area (list)
    'n_repeats': 25,  # number of times to repeat decoding with different randomly sampled units
    'input_data_type': 'spikes',  # spikes or facemap or LP
    'vid_angle_facemotion': 'face', # behavior, face, eye
    'vid_angle_LP': 'behavior',  # video angle for LP input (presumably same options as above - confirm)
    # for linear shift decoding, how many trials to use for the shift. '4_blocks_plus' is best
    'central_section': '4_blocks_plus',
    'predict': 'context', # 'context' or 'vis_appropriate_response'
    'exclude_cue_trials': False,  # option to totally exclude autorewarded trials
    'n_unit_threshold': 20,  # minimum number of units to include an area in the analysis
    'keep_n_SVDs': 500,  # number of SVD components to keep for facemap data
    'LP_parts_to_keep': ['ear_base_l', 'eye_bottom_l', 'jaw', 'nose_tip', 'whisker_pad_l_side'],
    'spikes_binsize': spikes_binsize,
    'spikes_time_before': spikes_time_before,
    'spikes_time_after': spikes_time_after,
    # 'decoder_binsize':decoder_binsize,
    # 'decoder_time_before':decoder_time_before,
    # 'decoder_time_after':decoder_time_after,
    'savepath': savepath,
    'filename': filename,
    'use_structure_probe': True,  # if True, append probe name to area name when multiple probes in the same area
    'crossval': '5_fold_constant',  # '5_fold', '5_fold_constant', or 'blockwise' - blockwise untested with linear shift
    'labels_as_index': True,  # convert labels (context names) to index [0,1]
    'decoder_type': 'LogisticRegression',  # 'linearSVC' or 'LDA' or 'RandomForest' or 'LogisticRegression'
    'only_use_all_units': True, #if True, do not run decoding with different areas, only with all areas -- for debugging
}


# Decode context for each selected session, then concatenate the decoder
# output (overall table, then one trialwise table per n_units value).
# NOTE: `[:1]` limits the run to the first session (debugging); drop the slice
# to process every session in `ephys_sessions`.
for ephys_session in ephys_sessions[:1]:
    # Optional: skip sessions whose result file already exists. The path must
    # match `file_path` below: `filename` already ends in '.pkl' and `savepath`
    # is a UPath, so join with `/` rather than string `+`.
    # if (savepath / (ephys_session.id[:17] + '_' + filename)).exists():
    #     print(ephys_session.id[:17] + ' completed, skipping...')
    #     continue
    # Optional: to keep going after a failing session, wrap the body below in
    # `try:` and add
    #     except Exception as e:
    #         except_list[ephys_session.id] = repr(e)
    # Key on `ephys_session.id`: `session` is unbound (first iteration) or []
    # (later iterations) when loading fails.
    session = DynamicRoutingSession(ephys_session.id)
    print(session.id+' loaded')
    # Sessions without a 'structure' column in the electrodes table are skipped.
    if 'structure' in session.electrodes[:].columns:
        session_info = ephys_session
        session_id = str(session_info.id)
        # Load the cached trials and units tables for this session.
        trials = pd.read_parquet(
            npc_lims.get_cache_path('trials', session_id, 'any')
        )
        units = pd.read_parquet(
            npc_lims.get_cache_path('units', session_id, 'any')
        )
        decoding_utils.decode_context_with_linear_shift(session=None, params=params, trials=trials, units=units, session_info=session_info)

        # Result file for this session (same naming as the skip check above).
        file_path = savepath / (ephys_session.id[:17] + '_' + filename)

        decoding_results = decoding_utils.concat_decoder_results(file_path, savepath=savepath, return_table=True, single_session=True)

        # Recover the decoded n_units values from the 'true_accuracy_<n_units>'
        # columns, then build a trialwise table for each of them.
        if decoding_results is not None:
            n_units = []
            for col in decoding_results.filter(like='true_accuracy_').columns.values:
                col_parts = col.split('_')
                if len(col_parts) == 3:
                    try:
                        n_units.append(int(col_parts[2]))
                    except ValueError:
                        # Non-numeric suffix (e.g. 'all') is kept as a string.
                        n_units.append(col_parts[2])
                else:
                    n_units.append(None)

            for nu in n_units:
                decoding_utils.concat_trialwise_decoder_results(file_path, savepath=savepath, return_table=False, n_units=nu, single_session=True)

    else:
        print('no structure column found in electrodes table, moving to next recording')
    # Drop the reference to the session object before the next iteration.
    session = []


In [None]:
# savepath=upath.UPath(r"\\allen\programs\mindscope\workgroups\templeton\TTOC\decoding results\predict_appropriate_response")
# decoding_utils.concat_decoder_results(file_path,savepath=savepath,return_table=True,single_session=True)
# decoding_utils.concat_trialwise_decoder_results(file_path,savepath=savepath,return_table=False,n_units=nu,single_session=True)

In [None]:
# savepath=r"\\allen\programs\mindscope\workgroups\templeton\TTOC\decoding results\test"
# filename='decoding_results_test'

# decoding_utils.concat_decoder_summary_tables(savepath)

In [3]:
import upath

# Unbind any `path` left from an earlier cell (binding it first so `del` never
# fails). If every assignment below is commented out, `path.iterdir()` then
# raises NameError instead of silently listing an old directory.
path = []
del path

# Earlier result directories; uncomment one to inspect it instead:
# path=path_utils.DECODING_ROOT_PATH
# path=path_utils.DECODING_ROOT_PATH / 'n_units_test_2024-11-06T00:18:30.855494'
# path=path_utils.DECODING_ROOT_PATH / 'n_units_test_medium_unit_criteria_2024-11-07T00:47:13.551561'
# path=path_utils.DECODING_ROOT_PATH / 'full_test_LDA_medcrit_2024-11-09T00:33:11.111162'
# path=path_utils.DECODING_ROOT_PATH / 'full_test_logreg_medcrit_2024-11-11T18:39:59.601162' / 're_run'
# path=path_utils.DECODING_ROOT_PATH / 'full_logreg_medcrit_2024-11-26T16:45:38.919765'###old###
# path=path_utils.DECODING_ROOT_PATH / 'full_logreg_medcrit_2024-11-26T23:54:35.702811'
# path=path_utils.DECODING_ROOT_PATH / 'full_logreg_medcrit_2024-11-26T23:54:35.702811' / 'summary_re_run_0'
# path=path_utils.DECODING_ROOT_PATH / 'full_logreg_medcrit_2024-11-26T23:54:35.702811' / 'summary_re_run_1'
path = path_utils.DECODING_ROOT_PATH / 'logreg_many_nunits_0_2024-12-10-0'

# filename='decoding_results_test_2024_10_28.pkl'

# Every entry in the results directory (printed), its bare names, and the
# subset whose path contains 'results.csv'.
all_paths = list(path.iterdir())
for entry in all_paths:
    print(entry)
all_filenames = [entry.name for entry in all_paths]
csvs = [entry for entry in all_paths if 'results.csv' in str(entry)]

# Same listing for a second directory `path_0` (used by the copy cell below):
# all_paths_0 = list(path_0.iterdir())
# for entry in all_paths_0:
#     print(entry)
# all_filenames_0 = [entry.name for entry in all_paths_0]
# csvs_0 = [entry for entry in all_paths_0 if 'results.csv' in str(entry)]


s3://aind-scratch-data/dynamic-routing/ethan/decoding-results/logreg_many_nunits_0_2024-12-10-0/667252_2023-09-25_2024-12-10-0.json


In [None]:
# Number of entries whose path contains 'results.csv' in the listed directory.
# len(all_paths)/21
# path.name.split('T17')[1]
n_result_csvs = len(csvs)
n_result_csvs

In [None]:
# Collect the result files in `path` that belong to one session.
mouseid='686176'
date='2023-12-07'
# Result files are named '<mouseid>_<date>_...'. Match that prefix in the file
# name only. Searching the full path string would also match the run-directory
# name, e.g. a date like '2024-12-10' matches every file in
# 'logreg_many_nunits_0_2024-12-10-0'.
session_prefix = f'{mouseid}_{date}'
session_paths = [entry for entry in path.iterdir() if session_prefix in entry.name]


In [None]:
# Show the first result file found for the selected session.
session_paths[0]

In [None]:
#load decoding results from pickle
# NOTE: pickle.loads can execute arbitrary code. Only load result files you
# produced yourself.
# NOTE(review): the directory listing printed earlier showed .json result
# files; confirm session_paths[0] is actually a pickle.
import pickle
raw_result_bytes = upath.UPath(session_paths[0]).read_bytes()
results = pickle.loads(raw_result_bytes)

In [None]:
# Plot column 1 of the decoder's 'predict_proba' array, trial by trial, for one
# brain area of the selected session.
sel_area='SNr'
# Derive the session key from the selection above instead of hard-coding it,
# so the key always matches the file that was loaded.
session_key = f'{mouseid}_{date}'
# NOTE(review): 'no_shift' / 'all' / [0] presumably select the unshifted fit,
# all units, and the first repeat - confirm against decoding_utils.
area_result = results[session_key]['results'][sel_area]['no_shift']['all'][0]
test_array = area_result['predict_proba'][:, 1]
fig, ax = plt.subplots(1, 1)
ax.plot(test_array)
ax.set_xlabel('trial')
ax.set_ylabel('predict_proba')
ax.set_title(sel_area)

In [None]:
# Load the concatenated per-trial decoder-confidence table from local disk.
confidence_pkl = r"D:\decoding_results_from_CO\logreg_2024-11-27_re_concat_1\decoder_confidence_all_trials_all_units.pkl"
all_confidence = pd.read_pickle(confidence_pkl)

In [None]:
# Fraction of all 'predict_proba' values (concatenated across rows) below zero.
# Presumably a sanity check, since probabilities should never be negative.
flat_proba = np.hstack(all_confidence['predict_proba'].values)
(flat_proba < 0).mean()

In [None]:
# Display the decoder-confidence table loaded from disk.
all_confidence

In [None]:
# savepath = path
# file_path = savepath / f"703333_2024-04-08_.pkl"

# decoding_results=decoding_utils.concat_decoder_results(file_path,savepath=savepath,return_table=True,single_session=True)

# #find n_units to loop through for next step
# n_units=[]
# for col in decoding_results.filter(like='true_accuracy_').columns.values:
#     if len(col.split('_'))==3:
#         temp_n_units=col.split('_')[2]
#         try:
#             n_units.append(int(temp_n_units))
#         except:
#             n_units.append(temp_n_units)
#     else:
#         n_units.append(None)

# decoding_results=[]

# for nu in n_units:
#     decoding_utils.concat_trialwise_decoder_results(file_path,savepath=savepath,return_table=False,n_units=nu,single_session=True)


In [None]:
# Load one session's linear-shift decoding results table from `path`.
results_csv_name = '703333_2024-04-08_linear_shift_decoding_results.csv'
filepath = path / results_csv_name
decoding_results = pd.read_csv(filepath)

decoding_results

In [None]:
#compare different n units



In [None]:
# import pickle
# #copy files if do not exist
# for temp_path in all_paths_0:
#     if 'T17' in temp_path.name:
#         temp_name = temp_path.name.split('T17')[0] + '2024-10-29T17' + path.name.split('T17')[1]
#     else:
#         temp_name = temp_path.name
#     if temp_name not in all_filenames:
#         # print(path.name)
#         print(temp_name)
#         if '.csv' in temp_name:
#             result=pd.read_csv(temp_path)
#             result.to_csv(path / temp_name, index=False)
#         elif '.pkl' in temp_name:
#             result=pickle.loads(upath.UPath(temp_path).read_bytes())
#             (path / temp_name).write_bytes(pickle.dumps(result, protocol=pickle.HIGHEST_PROTOCOL))
        

In [None]:
# Displays the `upath.UPath` class object; appears to be a leftover inspection cell.
upath.UPath

In [None]:
# S3 location of one session's .pkl decoding result from an earlier test run.
s3_result_uri = 's3://aind-scratch-data/dynamic-routing/ethan/decoding-results/full_test_1_2024-10-29T17:46:03.914748/626791_2022-08-16_2024-10-29T17:46:03.914748.pkl'
file_path = upath.UPath(s3_result_uri)

In [None]:
# True if `file_path` points to an existing file.
file_path.is_file()

In [None]:
# Concatenate decoder summary tables from the results directory `path`;
# `savepath` is presumably the local output folder - confirm in decoding_utils.
# NOTE: this rebinds `savepath` from the UPath set in the parameters cell to a
# plain string; re-run that cell before decoding again.
# savepath=r'D:\decoding_results_from_CO\n_units_test_2024-11-06'
# savepath=r'D:\decoding_results_from_CO\n_units_test_medium_unit_criteria_2024-11-07'
# savepath=r'D:\decoding_results_from_CO\lda_test_2024-11-11'
# savepath=r'D:\decoding_results_from_CO\logreg_test_2024-11-13'
savepath=r'D:\decoding_results_from_CO\logreg_2024-11-27_re_concat_1'
decoding_utils.concat_decoder_summary_tables(path,savepath)

In [None]:
# Load the per-trial decoder-confidence table from the local summary folder.
confidence_table_path = r"D:\decoding_results_from_CO\logreg_2024-11-27_re_concat_1\decoder_confidence_all_trials_all_units.pkl"
test = pd.read_pickle(confidence_table_path)

In [None]:
# Distinct values in the 'probe' column of the loaded table.
test['probe'].unique()

In [None]:
# Show the 'trial_index' column of the loaded table.
test['trial_index']