In [2]:
import numpy as np
import xarray as xr
import scipy.signal as sg
import pandas as pd
import scipy.stats as st
import matplotlib.pyplot as plt
from matplotlib import patches
import npc_lims
from npc_sessions import DynamicRoutingSession, get_sessions
from dynamic_routing_analysis import spike_utils
import os

%load_ext autoreload
%autoreload 2
%matplotlib widget


In [None]:
# All DynamicRouting sessions with ephys that are uploaded and annotated.
DR_ephys_sessions = tuple(
    info
    for info in npc_lims.get_session_info(is_ephys=True)
    if info.is_uploaded and info.is_annotated and info.project == 'DynamicRouting'
)

In [6]:
# Load one example session for the single-session analyses below
# (alternative: DynamicRoutingSession(DR_ephys_sessions[2].id)).
session = DynamicRoutingSession('668755_2023-08-31')

In [14]:
# trials near the start of each block (trial_index_in_block < 5)
session.trials[:].loc[lambda trials: trials['trial_index_in_block'] < 5]

Unnamed: 0_level_0,start_time,stop_time,quiescent_start_time,quiescent_stop_time,stim_start_time,stim_stop_time,opto_start_time,opto_stop_time,response_window_start_time,response_window_stop_time,...,is_aud_target,is_vis_target,is_nontarget,is_aud_nontarget,is_vis_nontarget,is_vis_context,is_aud_context,is_context_switch,is_repeat,is_opto
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
0,2460.26425,2465.78546,2460.26425,2461.812369,2461.812369,2462.312779,,,2461.86548,2462.78304,...,False,True,False,False,False,True,False,False,False,False
1,2466.23591,2471.77383,2466.23591,2467.784039,2467.784039,2468.284459,,,2467.8372,2468.7547,...,False,True,False,False,False,True,False,False,False,False
2,2472.15755,2477.69545,2472.15755,2473.706659,2473.706659,2474.207069,,,2473.75881,2474.67621,...,False,True,False,False,False,True,False,False,False,False
3,2477.7289,2483.25015,2477.7289,2479.279011,2479.279011,2479.779471,,,2479.33018,2480.24755,...,False,True,False,False,False,True,False,False,False,False
4,2485.91899,2491.45698,2485.91899,2487.466841,2487.466841,2487.967241,,,2487.52038,2488.43776,...,False,True,False,False,False,True,False,False,False,False
86,3060.61747,3066.15553,3060.61747,3062.14267,3062.14267,3062.64267,,,3062.21884,3063.13621,...,True,False,False,False,False,False,True,True,False,False
87,3067.12287,3072.66095,3067.12287,3068.64852,3068.64852,3069.14852,,,3068.72428,3069.6417,...,True,False,False,False,False,False,True,False,False,False
88,3073.77853,3079.33312,3073.77853,3075.32057,3075.32057,3075.82057,,,3075.39647,3076.31387,...,True,False,False,False,False,False,True,False,False,False
89,3081.3683,3086.90626,3081.3683,3082.89353,3082.89353,3083.39353,,,3082.96955,3083.88691,...,True,False,False,False,False,False,True,False,False,False
90,3090.49246,3096.04724,3090.49246,3092.03518,3092.03518,3092.53518,,,3092.11053,3093.02799,...,True,False,False,False,False,False,True,False,False,False


In [None]:
# Trial-aligned 3-D spike tensor with dims time / trials / unit_id (dim order and time units:
# presumably seconds — TODO confirm in spike_utils.make_neuron_time_trials_tensor).
time_before, time_after, binsize = 0.5, 1.0, 0.001
trial_da = spike_utils.make_neuron_time_trials_tensor(
    session.units, session.trials, time_before, time_after, binsize
)

In [None]:
# Per-unit mean activity in the pre-stimulus window (-0.15..0), vis- vs aud-context trials
vis_baseline, aud_baseline = (
    trial_da.sel(time=slice(-0.15, 0.0),
                 trials=session.trials[:].query(context).index.values).mean(['time', 'trials'])
    for context in ('is_vis_context', 'is_aud_context')
)

# Per-unit mean activity in the response window (0..0.15), vis- vs aud-target trials only
vis_stim_fr, aud_stim_fr = (
    trial_da.sel(time=slice(0.0, 0.15),
                 trials=session.trials[:].query(target).index.values).mean(['time', 'trials'])
    for target in ('is_vis_target', 'is_aud_target')
)


In [None]:
# Distribution across units of the vis-minus-aud baseline FR difference.
# Bin edges run -10..10 in 0.5 steps, symmetric about 0 (the old upper edge stopped at 9.5);
# differences outside [-10, 10] are not shown.
fig,ax=plt.subplots(1,1)
ax.hist(vis_baseline-aud_baseline,bins=np.arange(-10,10.5,0.5))
ax.set_ylabel('number of units')
ax.set_xlabel('baseline FR difference (vis-aud)')

In [None]:
# Distribution across units of the vis-minus-aud target-stimulus FR difference.
# Bin edges run -10..10 in 0.5 steps, symmetric about 0; differences outside [-10, 10] are not shown.
fig,ax=plt.subplots(1,1)
ax.hist(vis_stim_fr-aud_stim_fr,bins=np.arange(-10,10.5,0.5))
ax.set_ylabel('number of units')
ax.set_xlabel('stim FR difference (vis-aud)')

In [None]:
# Per-unit mean FR, recomputed here so this cell can run on its own:
# baseline = -0.15..0 split by block context; stim = 0..0.15 on target trials only.
vis_baseline = trial_da.sel(time=slice(-0.15,0.0),trials=session.trials[:].query('is_vis_context').index.values).mean(['time','trials'])
aud_baseline = trial_da.sel(time=slice(-0.15,0.0),trials=session.trials[:].query('is_aud_context').index.values).mean(['time','trials'])
vis_stim_fr = trial_da.sel(time=slice(0.0,0.15),trials=session.trials[:].query('is_vis_target').index.values).mean(['time','trials'])
aud_stim_fr = trial_da.sel(time=slice(0.0,0.15),trials=session.trials[:].query('is_aud_target').index.values).mean(['time','trials'])

# Normalised preference indices (vis-aud)/(vis+aud), within [-1, 1] for non-negative rates.
# Units silent in both conditions give 0/0 = NaN and are not drawn.
baseline_index=(vis_baseline-aud_baseline)/(vis_baseline+aud_baseline)
stim_index=(vis_stim_fr-aud_stim_fr)/(vis_stim_fr+aud_stim_fr)

fig,ax=plt.subplots(1,1)
ax.axvline(0,color='k',linestyle='--')
ax.axhline(0,color='k',linestyle='--')
ax.plot(baseline_index,stim_index,'k.',alpha=0.2)

# labels name the plotted quantity: a normalised index, not a raw FR difference
ax.set_xlabel('baseline FR index (vis-aud)/(vis+aud)')
ax.set_ylabel('stim FR index (vis-aud)/(vis+aud)')
ax.set_xlim([-1,1])
ax.set_ylim([-1,1])
ax.set_title('all units')

In [None]:
# Fraction of units in each quadrant of the (baseline index, stim index) plane, where each
# index is (vis-aud)/(vis+aud). Units whose index is undefined (0/0 = NaN, i.e. silent in both
# conditions) are excluded from the denominator, so the fractions describe the plotted units.
# Points lying exactly on an axis belong to no quadrant.
with np.errstate(divide='ignore', invalid='ignore'):
    baseline_index=np.asarray((vis_baseline-aud_baseline)/(vis_baseline+aud_baseline),dtype=float)
    stim_index=np.asarray((vis_stim_fr-aud_stim_fr)/(vis_stim_fr+aud_stim_fr),dtype=float)

valid_units=np.isfinite(baseline_index)&np.isfinite(stim_index)
n_valid_units=np.count_nonzero(valid_units)

def _quadrant_fraction(mask):
    """Fraction of valid units for which `mask` is True (NaN when no unit is valid)."""
    return np.count_nonzero(mask&valid_units)/n_valid_units if n_valid_units else np.nan

n_pos_pos=_quadrant_fraction((baseline_index>0)&(stim_index>0))
n_pos_neg=_quadrant_fraction((baseline_index>0)&(stim_index<0))
n_neg_pos=_quadrant_fraction((baseline_index<0)&(stim_index>0))
n_neg_neg=_quadrant_fraction((baseline_index<0)&(stim_index<0))

[n_pos_pos,n_pos_neg,n_neg_pos,n_neg_neg]

In [None]:
# Scatter of baseline vs stim preference index for all units, annotated with the percentage
# of units in each quadrant (n_pos_pos etc. must already be computed).
fig,ax=plt.subplots(1,1)
ax.axvline(0,color='k',linestyle='--')
ax.axhline(0,color='k',linestyle='--')

ax.plot((vis_baseline-aud_baseline)/(vis_baseline+aud_baseline),
        (vis_stim_fr-aud_stim_fr)/(vis_stim_fr+aud_stim_fr),'k.',alpha=0.2)

# quadrant labels placed at (+/-0.5, +/-0.5) in data coordinates
ax.text(0.5,0.5,'{:.2f}%'.format(n_pos_pos*100),color='r')
ax.text(0.5,-0.5,'{:.2f}%'.format(n_pos_neg*100),color='r')
ax.text(-0.5,0.5,'{:.2f}%'.format(n_neg_pos*100),color='r')
ax.text(-0.5,-0.5,'{:.2f}%'.format(n_neg_neg*100),color='r')

# labels name the plotted quantity: a normalised index, not a raw FR difference
ax.set_xlabel('baseline FR index (vis-aud)/(vis+aud)')
ax.set_ylabel('stim FR index (vis-aud)/(vis+aud)')

ax.set_xlim([-1,1])
ax.set_ylim([-1,1])

ax.set_title('all units')

In [None]:
# raw per-unit baseline index (vis-aud)/(vis+aud) as a plain numpy array
np.divide(vis_baseline.values - aud_baseline.values, vis_baseline.values + aud_baseline.values)

In [None]:
# structures that have at least one unit in this session
session.units[:].loc[:, 'structure'].unique()

In [None]:
sel_area='ACAd'

# Units in sel_area and the trials table, looked up once and reused for every condition.
area_unit_ids=session.units[:].query('structure==@sel_area')['unit_id'].values
trials_table=session.trials[:]

#compute FR differences between vis & aud context baseline (-0.15..0)
vis_baseline = trial_da.sel(time=slice(-0.15,0.0),
                            trials=trials_table.query('is_vis_context').index.values,
                            unit_id=area_unit_ids).mean(['time','trials'])
aud_baseline = trial_da.sel(time=slice(-0.15,0.0),
                            trials=trials_table.query('is_aud_context').index.values,
                            unit_id=area_unit_ids).mean(['time','trials'])

#compute FR differences between vis & aud target stimuli (0..0.15)
vis_stim_fr = trial_da.sel(time=slice(0.0,0.15),
                           trials=trials_table.query('is_vis_target').index.values,
                           unit_id=area_unit_ids).mean(['time','trials'])
aud_stim_fr = trial_da.sel(time=slice(0.0,0.15),
                           trials=trials_table.query('is_aud_target').index.values,
                           unit_id=area_unit_ids).mean(['time','trials'])

fig,ax=plt.subplots(1,1)
ax.axvline(0,color='k',linestyle='--')
ax.axhline(0,color='k',linestyle='--')

# normalised preference indices; units silent in both conditions give NaN and are not drawn
ax.plot((vis_baseline-aud_baseline)/(vis_baseline+aud_baseline),
        (vis_stim_fr-aud_stim_fr)/(vis_stim_fr+aud_stim_fr),'k.',alpha=0.2)
ax.set_xlabel('baseline FR index (vis-aud)/(vis+aud)')
ax.set_ylabel('stim FR index (vis-aud)/(vis+aud)')

ax.set_xlim([-1,1])
ax.set_ylim([-1,1])
ax.set_title(sel_area+' units')

In [None]:
# #plot firing rates for all units
# fig,ax=plt.subplots(1,1)
# ax.plot(vis_baseline,vis_stim_fr,'k.',alpha=0.2)
# ax.plot(aud_baseline,aud_stim_fr,'r.',alpha=0.2)
# ax.set_xlabel('baseline FR')
# ax.set_ylabel('stim FR')
# ax.set_xlim([0,20])
# ax.set_ylim([0,20])
# ax.set_title(sel_area+' units')

def compute_vis_aud_fr_index(trial_da, session, sel_area,
                             baseline_window=(-0.15, 0.0), stim_window=(0.0, 0.15)):
    """Context-corrected visual-vs-auditory firing-rate index for the units of one area.

    For every unit whose ``structure`` equals ``sel_area`` this returns

        (vis_stim - aud_stim) / (vis_stim + aud_stim)
          - (vis_base - aud_base) / (vis_base + aud_base)

    where ``vis_base`` / ``aud_base`` are mean rates in ``baseline_window`` on visual- /
    auditory-context trials, and ``vis_stim`` / ``aud_stim`` are mean rates in ``stim_window``
    on visual- / auditory-target trials.

    Parameters
    ----------
    trial_da : trial-aligned spike tensor with ``time``, ``trials`` and ``unit_id`` dims,
        e.g. from ``spike_utils.make_neuron_time_trials_tensor``.
    session : object whose ``trials[:]`` is a DataFrame with boolean columns
        ``is_vis_context``, ``is_aud_context``, ``is_vis_target``, ``is_aud_target``, and whose
        ``units[:]`` is a DataFrame with ``structure`` and ``unit_id`` columns.
    sel_area : structure name used to select units.
    baseline_window, stim_window : ``(start, stop)`` passed to ``trial_da.sel`` as
        ``slice(start, stop)``. Defaults reproduce the original hard-coded windows.

    Returns
    -------
    One value per selected unit, in the type returned by ``trial_da.sel(...).mean(...)``.
    Units silent in both conditions of a comparison give NaN (0/0).
    """
    trials = session.trials[:]
    units = session.units[:]
    # select the area's units once instead of re-querying the table for every condition
    area_unit_ids = units.loc[units['structure'] == sel_area, 'unit_id'].values
    baseline = slice(*baseline_window)
    stim = slice(*stim_window)

    def _mean_rate(window, condition):
        # mean over time bins and over trials where `condition` is True, one value per unit
        return trial_da.sel(time=window,
                            trials=trials.query(condition).index.values,
                            unit_id=area_unit_ids).mean(['time', 'trials'])

    vis_baseline = _mean_rate(baseline, 'is_vis_context')
    aud_baseline = _mean_rate(baseline, 'is_aud_context')
    vis_stim_fr = _mean_rate(stim, 'is_vis_target')
    aud_stim_fr = _mean_rate(stim, 'is_aud_target')

    stim_index = (vis_stim_fr - aud_stim_fr) / (vis_stim_fr + aud_stim_fr)
    baseline_index = (vis_baseline - aud_baseline) / (vis_baseline + aud_baseline)
    return stim_index - baseline_index


In [None]:
# unit ids of the units recorded in sel_area
session.units[:].loc[lambda units: units['structure'] == sel_area, 'unit_id'].values

In [None]:
# column names of the trials table
session.trials[:].keys()

In [None]:
#significance test for FR difference between vis & aud context baseline:
# per-trial activity averaged over -0.5..0 (time averaged, trials kept), for each context
vis_baseline_by_trial, aud_baseline_by_trial = (
    trial_da.sel(time=slice(-0.5, 0.0),
                 trials=session.trials[:].query(context).index.values).mean(['time'])
    for context in ('is_vis_context', 'is_aud_context')
)

In [None]:
# Wilcoxon rank-sum p-value per unit (aud- vs vis-context per-trial baseline),
# in the order of session.units[:]
p_aud_vs_vis = np.array(
    [st.ranksums(aud_baseline_by_trial.sel(unit_id=unit),
                 vis_baseline_by_trial.sel(unit_id=unit)).pvalue
     for unit in session.units[:]['unit_id'].values],
    dtype=float,
)


In [None]:
# For each structure: fraction of its units whose per-trial baseline differs between aud and
# vis context (rank-sum p < 0.01).
fraction_context_modulated_units = {}

for aa in session.units[:]['structure'].unique():
    area_unit_ids = session.units[:].query('structure==@aa')['unit_id'].values
    p_aud_vs_vis = np.full(len(area_unit_ids), np.nan)
    for uu, unit in enumerate(area_unit_ids):
        p_aud_vs_vis[uu] = st.ranksums(aud_baseline_by_trial.sel(unit_id=unit),
                                       vis_baseline_by_trial.sel(unit_id=unit)).pvalue
    fraction_context_modulated_units[aa] = np.sum(p_aud_vs_vis < 0.01) / len(p_aud_vs_vis)



In [None]:
# per-structure fraction of units with rank-sum p < 0.01
fraction_context_modulated_units

In [None]:
# For each structure (split by probe when a structure was recorded on several probes):
# fraction of its units whose per-trial baseline differs between aud and vis context
# (rank-sum p < 0.01). The structure/probe labels are built here because this cell runs
# before the cell that originally defined `structure_probe`.
units=session.units[:]
# number of distinct probes (group_name) contributing units to each unit's structure
n_probes_per_area=units.groupby('structure')['group_name'].transform('nunique')
structure_probe=pd.DataFrame({
    'structure_probe':np.where(n_probes_per_area>1,
                               units['structure'].astype(str)+'_'+units['group_name'].astype(str),
                               units['structure']),
    'unit_id':units['unit_id']},index=units.index.values)

fraction_context_modulated_units={}

for aa in structure_probe['structure_probe'].unique():
    area_unit_ids=structure_probe.loc[structure_probe['structure_probe']==aa,'unit_id'].values
    p_aud_vs_vis=np.array([st.ranksums(aud_baseline_by_trial.sel(unit_id=unit),
                                       vis_baseline_by_trial.sel(unit_id=unit)).pvalue
                           for unit in area_unit_ids],dtype=float)
    fraction_context_modulated_units[aa]=np.sum(p_aud_vs_vis<0.01)/len(p_aud_vs_vis)

In [None]:
# per structure/probe fraction of units with rank-sum p < 0.01
fraction_context_modulated_units

In [None]:
# column names of the units table
session.units[:].keys()

In [None]:
# Structures with units in this session. Defined here because the only other definition
# is in a later cell, so a fresh-kernel run would otherwise raise NameError.
unique_areas=session.units[:]['structure'].unique()
unique_areas

In [None]:
# probe name (group_name) of the first unit — positional, independent of the index labels
session.units[:]['group_name'].iloc[0]

In [None]:
# Label each unit with its structure; if a structure was recorded on more than one probe,
# append the probe name (group_name) so each probe's part of that structure stays separate.
# Labels are aligned to units by index label (groupby/transform). Writing into a positional
# array with index labels would mislabel units whenever the index is not 0..n-1.
units=session.units[:]
unique_areas=units['structure'].unique()

# number of distinct probes contributing units to each unit's structure
n_probes_per_area=units.groupby('structure')['group_name'].transform('nunique')

structure_probe=pd.DataFrame({
    'structure_probe':np.where(n_probes_per_area>1,
                               units['structure'].astype(str)+'_'+units['group_name'].astype(str),
                               units['structure']),
    'unit_id':units['unit_id']},index=units.index.values)
structure_probe



In [None]:
# Load a second session for the channel checks below.
# NOTE(review): this rebinds `session`; trial_da and the per-unit results computed above
# still come from the session loaded earlier, so later cells mixing them are inconsistent.
session=DynamicRoutingSession('674562_2023-10-04')

In [None]:
# Current session: one row per probe, with a vertical tick at every channel in 0..383 that
# does not appear in sorted_channel_indices. Dashed lines mark channels 0 and 383.
session_probes = session.sorted_channel_indices.keys()

fig, ax = plt.subplots(nrows=1, ncols=1)
for pr, probe in enumerate(session_probes):
    ax.vlines(list(set(range(0, 384)) - set(session.sorted_channel_indices[probe])),
              ymin=pr, ymax=pr + 1)

ax.set_yticks(np.arange(len(session_probes)) + 0.5)
ax.set_yticklabels(session_probes)
ax.set_title(session.id)
ax.axvline(0, color='k', linestyle='--')
ax.axvline(383, color='k', linestyle='--')

In [None]:
# Channels (by name) identified by hand as missing, keyed by probe letter.
# One string per channel — the commas matter: adjacent string literals without commas are
# concatenated by Python into a single string.
manual_missing_channels={
    'A':['AP337', 'AP339', 'AP340', 'AP341', 'AP342', 'AP343', 'AP344', 'AP361', 'AP362',
         'AP363', 'AP364', 'AP365', 'AP366', 'AP367', 'AP368', 'AP369', 'AP370', 'AP371',
         'AP372', 'AP373', 'AP374', 'AP375', 'AP376', 'AP377', 'AP378', 'AP379', 'AP380'],
}

In [None]:
# Channels in 0..383 missing from sorted_channel_indices, per probe of the current session.
# Returned as sorted lists: list(set(...)) follows hash-table order, which is not guaranteed
# to be ascending.
missing_channels={
    probe:sorted(set(range(0,384))-set(session.sorted_channel_indices[probe]))
    for probe in session_probes
}

missing_channels

In [None]:
# y positions of the per-probe row centres used as tick locations
np.arange(0.5, len(session_probes) + 0.5)

In [None]:
# Re-list the uploaded + annotated DynamicRouting ephys sessions for the loop below.
DR_ephys_sessions = tuple(filter(
    lambda info: info.is_uploaded and info.is_annotated and info.project == 'DynamicRouting',
    npc_lims.get_session_info(is_ephys=True),
))

In [None]:
# id of whichever session `session` is currently bound to (depends on which cell ran last)
session.id

In [None]:
# For every DR ephys session: plot, per probe, the channels in 0..383 absent from
# sorted_channel_indices, and collect those lists (one per probe insertion) for the summary
# histograms below. Failures are recorded in except_list, keyed by the session that failed.
all_missing_channels=[]
except_list={}
for DR_session in DR_ephys_sessions[:]:
    fig=None
    try:
        session = DynamicRoutingSession(DR_session.id)
        session_probes=session.sorted_channel_indices.keys()
        fig,ax=plt.subplots(1,1)
        session_missing_channels=[]
        for pr,probe in enumerate(session_probes):
            missing=sorted(set(range(0,384))-set(session.sorted_channel_indices[probe]))
            session_missing_channels.append(missing)
            ax.vlines(missing,pr,pr+1)

        ax.set_yticks(np.arange(len(session_probes))+0.5)
        ax.set_yticklabels(session_probes)
        ax.axvline(0,color='k',linestyle='--')
        ax.axvline(383,color='k',linestyle='--')
        ax.set_title(session.id)
        # add a session's lists only once all its probes succeeded, so a failure part-way
        # through does not leave partial data in the summary
        all_missing_channels.extend(session_missing_channels)
    except Exception as e:
        # key by DR_session.id: if the constructor raised, `session` is still the previous one
        except_list[DR_session.id]=e
        if fig is not None:
            plt.close(fig)
    

In [None]:
# Fraction of probe insertions in which each channel (0..383) was missing.
# Edges 0..384 give one bin per channel. With edges 0..383, numpy's closed last bin
# [382, 383] lumps channels 382 and 383 together.
all_missing_channels_stack=np.hstack(all_missing_channels)
fig,ax=plt.subplots(1,1)
hist,bin_edges=np.histogram(all_missing_channels_stack,bins=np.arange(0,385,1))
ax.bar(bin_edges[:-1],hist/len(all_missing_channels),width=1)
ax.set_xlabel('channel number')
ax.set_ylabel('how often channel was missing')

In [None]:
# Number of probe insertions in which each channel (0..383) was missing.
# Edges 0..384 give one bin per channel, so channel 383 is not merged with 382 in the closed last bin.
all_missing_channels_stack=np.hstack(all_missing_channels)
fig,ax=plt.subplots(1,1)
ax.hist(all_missing_channels_stack,bins=np.arange(0,385,1))
ax.set_xlabel('channel number')
ax.set_ylabel('number of insertions with missing channel')

In [None]:
# per-bin missing counts and bin edges from the fraction histogram above
hist,bin_edges

In [None]:
# number of probe insertions whose missing-channel lists were collected
len(all_missing_channels)

In [None]:
# Missing channels for `probe` in `session`.
# NOTE(review): both names are leftovers from whichever loop last bound them, so the result
# depends on execution order.
set(range(0,384))-set(session.sorted_channel_indices[probe])

In [None]:
# distinct structure/probe labels (sorted). Apply np.unique to the label column only:
# on the whole DataFrame it would also return every unit_id.
np.unique(structure_probe['structure_probe'])

In [None]:
# performance table of the current session
session.performance[:]

In [None]:
# probes (group_name) that recorded MOs units
session.units[:].loc[lambda units: units['structure'] == 'MOs', 'group_name'].unique()

In [None]:
# column names of the trials table
session.trials[:].keys()

In [None]:
# reward_time of the first trial (positional)
session.trials[:].loc[:, 'reward_time'].iloc[0]

In [None]:
 #make train, test splits based on block number (leave-one-block-out):
# for each block, test = positions of that block's trials, train = positions of all others.
block_number = session.trials[:]['block_index'].values
block_numbers = np.unique(block_number)

train = [np.flatnonzero(block_number != bb) for bb in block_numbers]
test = [np.flatnonzero(block_number == bb) for bb in block_numbers]

In [None]:
# training trial positions when the first block is held out
train[0]

In [None]:
# test trial positions: the trials of the first block
test[0]

In [None]:
# column names of the trials table
session.trials[:].keys()

In [None]:
# np.arange(0, 0.1, 0.1)

In [None]:
# bin_size=0.1
# timebin_da=spike_utils.make_neuron_timebins_matrix(session.units[:], session.trials[:], bin_size)

In [None]:
# timebin_da

In [None]:
# Load another session for the time-bin table below (rebinds `session`).
session = DynamicRoutingSession('668755_2023-08-30')

In [None]:
# trials table of the session just loaded
trials=session.trials[:]

In [None]:
# Build time bins of width bin_size from the trials table; the result is unpacked into a
# table and the bins (units of bin_size presumably seconds — confirm in spike_utils).
bin_size=0.1
timebins_table,bins=spike_utils.make_timebins_table(trials, bin_size)

In [None]:
# view the time-bin table as a DataFrame (column-oriented construction)
pd.DataFrame(timebins_table)