In [None]:
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import npc_lims
from npc_sessions import DynamicRoutingSession, get_sessions
from dynamic_routing_analysis import spike_utils


In [None]:
# Fetch the session-info records from npc_lims and display the first one as a
# quick look at what a record contains.
tracked_sessions: tuple[npc_lims.SessionInfo, ...] = npc_lims.get_session_info()
tracked_sessions[0]

In [None]:
# Session analysed by the rest of this notebook; change this one constant to
# re-run the analysis on a different session.
SESSION_ID = '668755_2023-08-30'
session = DynamicRoutingSession(SESSION_ID)


In [None]:
# Display the identifier of the loaded session.
session.id

In [None]:
# Slice the electrodes table once and reuse it, rather than slicing it again
# for the column lookup.
electrodes_df = session.electrodes[:]

# Report which structure labels are present, if the table has that column.
if 'structure' in electrodes_df.columns:
    print(electrodes_df['structure'].unique())
else:
    print('no structure column found in electrodes table')

In [None]:
# Number of units per structure, restricted to rows where `default_qc` is true.
qc_units = session.units[:].query('default_qc')
qc_units['structure'].value_counts()

In [None]:
# List the columns available in the units table.
session.units[:].columns

In [None]:
# List the columns available in the trials table.
session.trials[:].columns

In [None]:
#make trial aligned 3d spike tensor
# Later cells index the result by `unit_id`, `time` and `trials`.
# NOTE(review): time_before / time_after / binsize are presumably in seconds
# relative to each trial's alignment event -- confirm against
# spike_utils.make_neuron_time_trials_tensor.
time_before = 0.5
time_after = 1.0
binsize = 0.025
trial_da = spike_utils.make_neuron_time_trials_tensor(session.units, session.trials, time_before, time_after, binsize)

In [None]:
# Inspect the size of each dimension of the trial-aligned tensor.
trial_da.shape

In [None]:
#calculate aud vs. vis context differences

# Slice the trials table once and reuse it for both context queries.
trials_df = session.trials[:]

# Restrict to the pre-alignment baseline window (time labels from -0.2 to 0)
# once, so both contexts are guaranteed to use the same window.
baseline_da = trial_da.sel(time=slice(-0.2, 0))

# Per-unit mean over the baseline window and over all trials of each context.
vis_context_fr = baseline_da.sel(
    trials=trials_df.query('is_vis_context').index,
).mean(dim=['trials', 'time'])

aud_context_fr = baseline_da.sel(
    trials=trials_df.query('is_aud_context').index,
).mean(dim=['trials', 'time'])

# Positive values: higher baseline activity in the vis context than in aud.
vis_vs_aud_diff = vis_context_fr - aud_context_fr


In [None]:
# Re-display the trials table columns for reference.
session.trials[:].columns

In [None]:
# Distinct values of the `stim_name` column in the trials table.
session.trials[:]['stim_name'].unique()

In [None]:
# Histogram of the per-unit vis-minus-aud baseline difference across all units.
# Bin edges run from -10 to 9.5 in steps of 0.5; values outside are not drawn.
baseline_diff_bins = np.arange(-10, 10, 0.5)

fig, ax = plt.subplots()
ax.hist(vis_vs_aud_diff, bins=baseline_diff_bins)
ax.set_xlabel('vis - aud context baseline FR')
ax.set_ylabel('unit count')

In [None]:
##plot example units with context differences -- subplot for each stimulus

# Candidate units: vis-minus-aud baseline difference above this threshold.
min_context_diff = 5
candidate_units = np.where(vis_vs_aud_diff > min_context_diff)[0]
if candidate_units.size == 0:
    raise ValueError(
        f'no units with vis - aud context baseline difference > {min_context_diff}'
    )

# Seeded generator so Restart & Run All always picks the same example unit.
example_unit_rng = np.random.default_rng(0)
sel_unit = example_unit_rng.choice(candidate_units)
# Alternative manual selections:
# sel_unit=session.units[:].query('structure.str.contains("ORB") and firing_rate>=5').index.values[5]
# sel_unit=session.units[:].query('structure.str.contains("AI") and firing_rate>=4').index.values[0]

# NOTE(review): sel_unit is a *position* from np.where, but is used below both
# as a `unit_id` label (.sel) and as a row position (.iloc). This only agrees if
# the tensor's unit_id coordinate equals the units table's default 0..N-1 index
# -- confirm against spike_utils.make_neuron_time_trials_tensor.

# Slice the tables once instead of on every loop iteration.
units_df = session.units[:]
trials_df = session.trials[:]

# The selected unit's (time x trials) slice is the same for every stimulus.
unit_da = trial_da.sel(unit_id=sel_unit)

stims=['vis1','vis2','sound1','sound2']

fig,ax=plt.subplots(2,2)
ax=ax.flatten()

for st,stim in enumerate(stims):

    stim_trials=trials_df.query('stim_name==@stim')

    # Trial-averaged response to this stimulus in each context.
    vis_context_spikes=unit_da.sel(
        trials=stim_trials.query('is_vis_context').index).mean(dim=['trials'])

    aud_context_spikes=unit_da.sel(
        trials=stim_trials.query('is_aud_context').index).mean(dim=['trials'])

    ax[st].plot(vis_context_spikes.time, vis_context_spikes.values, label='vis context',color='g')
    ax[st].plot(aud_context_spikes.time, aud_context_spikes.values, label='aud context',color='b')
    ax[st].axvline(0, color='k', linestyle='--')  # marks time 0 on the x-axis
    ax[st].set_title(stim)
    ax[st].legend()

# Figure-level title and layout only need to be applied once, after all panels
# are drawn. The f-string also tolerates a non-string (e.g. missing) structure.
sel_unit_row = units_df.iloc[sel_unit]
fig.suptitle(f"unit {sel_unit_row['unit_id']}; {sel_unit_row['structure']}")

fig.tight_layout()

In [None]:
# Distinct values of the `structure` column across all units (no QC filter).
session.units[:]['structure'].unique()

In [None]:
# Re-display the units table columns for reference.
session.units[:].columns

In [None]:
#plot distribution of vis vs. aud context fr differences for each area
xbins=np.arange(-10,10,1)

# Slice the units table once instead of on every loop iteration.
units_df = session.units[:]

# dropna(): a missing structure value can never satisfy the equality query
# below, so it would only produce an empty figure titled "nan".
for area in units_df['structure'].dropna().unique():

    fig,ax=plt.subplots(1,1)

    # NOTE(review): the table's index values are used as *positions* into
    # vis_vs_aud_diff; this assumes the units table has a default 0..N-1 index
    # in the same order as the tensor's unit axis -- confirm.
    area_unit_idx = units_df.query('structure==@area').index.values
    area_vis_vs_aud_diff = vis_vs_aud_diff[area_unit_idx]

    ax.hist(area_vis_vs_aud_diff, bins=xbins)
    ax.set_xlabel('vis - aud context baseline FR')
    ax.set_ylabel('unit count')
    ax.set_title(area)

    fig.tight_layout()