In [1]:
import numpy as np
import scipy.stats as stats
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
from dynamic_routing_analysis import spike_utils
import npc_lims
import os

import matplotlib
import matplotlib.font_manager as fm

matplotlib.rcParams['font.size'] = 8
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42
fm.FontProperties().set_family('arial')

%load_ext autoreload
%autoreload 2
# %matplotlib inline
%matplotlib widget

In [None]:
# 728917_2024-10-04 - VGAT mouse, ORB inactivation
# 741611_2024-09-11, 741611_2024-09-12 - PV-Cre mouse, PL inactivation
# 748098_2024-10-28 - PV-Cre mouse, PL inactivation
# 762124_2024-12-02 - Chrimson in ORB 

In [2]:
results_path=r's3://aind-scratch-data/dynamic-routing/decoding/results/v265_0_consolidated.parquet'

session_table_path=r"\\allen\programs\mindscope\workgroups\dynamicrouting\Ethan\CO decoding results\session_table_v0.265.csv"
session_table=pl.read_csv(session_table_path).to_pandas()

In [3]:
session_table.query("session_id=='728917_2024-10-04'")

Unnamed: 0,session_id,is_production,project,date,time,subject_id,subject_age_days,subject_sex,subject_genotype,implant,...,n_hits,n_contingent_rewards,n_responses,n_trials,is_first_block_aud,is_first_block_vis,is_engaged,is_good_behavior,is_bad_behavior,is_stage_5_passed
197,728917_2024-10-04,False,DynamicRouting,2024-10-04,11:50:09,728917,216,F,VGAT-ChR2-YFP/wt,,...,0.0;0.0;0.0;0.0;0.0;0.0,0.0;0.0;0.0;0.0;0.0;0.0,0.0;0.0;0.0;0.0;0.0;0.0,92.0;91.0;94.0;93.0;90.0;91.0,False,True,False,False,False,True


In [3]:
session_id='728917_2024-10-04'
trials=pd.read_parquet(
    npc_lims.get_cache_path('trials',session_id)
)
units=pd.read_parquet(
    npc_lims.get_cache_path('units',session_id)
)
units=units.sort_values('peak_channel')

In [5]:
trials.columns

Index(['start_time', 'stop_time', 'quiescent_start_time',
       'quiescent_stop_time', 'stim_start_time', 'stim_stop_time',
       'opto_start_time', 'opto_stop_time', 'response_window_start_time',
       'response_window_stop_time', 'reward_time',
       'post_response_window_start_time', 'post_response_window_stop_time',
       'stim_name', 'block_index', 'rewarded_modality', 'trial_index',
       'trial_index_in_block', 'opto_wavelength', 'opto_location_bregma_x',
       'opto_location_bregma_y', 'opto_duration', 'opto_label', 'opto_power',
       'opto_stim_name', 'repeat_index', 'is_response', 'is_correct',
       'is_incorrect', 'is_hit', 'is_false_alarm', 'is_correct_reject',
       'is_miss', 'is_go', 'is_nogo', 'is_rewarded', 'is_noncontingent_reward',
       'is_contingent_reward', 'is_reward_scheduled', 'is_instruction',
       'is_aud_stim', 'is_vis_stim', 'is_catch', 'is_target', 'is_aud_target',
       'is_vis_target', 'is_nontarget', 'is_aud_nontarget', 'is_vis_nontarge

In [7]:
trials['opto_power'].unique()

array([nan, 5. , 1.5])

In [8]:
trials['is_opto'].mean()

0.49909255898366606

In [11]:
trials['opto_label'].unique()

array(['', 'probe_5mW', 'probe_1mW', 'lateral_5mW', 'posterior_5mW',
       'lateral_1mW', 'posterior_1mW'], dtype=object)

In [15]:
units['electrode_group_name'].unique()

array(['probeB'], dtype=object)

In [20]:
np.nanmean((trials['opto_start_time']-trials['stim_start_time']).values)


-0.026153908064850434

In [None]:
# units.sort_values('peak_channel')['peak_channel']

id
0        1
1        1
311      1
15       1
2        3
      ... 
316    358
651    358
531    358
630    371
309    371
Name: peak_channel, Length: 654, dtype: int64

In [4]:
spikes_time_before=0.5
spikes_time_after=1.0
spikes_binsize=0.1
trial_da = spike_utils.make_neuron_time_trials_tensor(units, trials, spikes_time_before, spikes_time_after, spikes_binsize)

In [29]:
# trials[['opto_power','opto_label']].value_counts()
trials['opto_label'].unique()

array(['', 'probe_5mW', 'probe_1mW', 'lateral_5mW', 'posterior_5mW',
       'lateral_1mW', 'posterior_1mW'], dtype=object)

In [63]:
#calculate metrics for each unit

unit_opto_metrics={
    'unit_id':[],
    'peak_channel':[],
    'opto_power':[],
    'opto_cond':[],
    'baseline':[],
    'vis1':[],
    'vis2':[],
    'sound1':[],
    'sound2':[],
    'catch':[],
}

opto_labels=trials['opto_label'].unique()

baseline_time_start=-0.5
baseline_time_stop=-0.1
baseline_dur=baseline_time_stop-baseline_time_start
stim_time_start=0.1
stim_time_stop=0.5
stim_dur=stim_time_stop-stim_time_start

#in no opto, and each opto condition combination
#baseline firing rate, vis1 response rate, vis2 response rate, sound1 response rate, sound2 response rate

for uu,unit in units.iterrows():
    unit_id=unit['unit_id']

    #no opto trials
    opto_power=0
    opto_cond='no_opto'
    no_opto_trials=trials.query('~is_opto')
    baseline=trial_da.sel(
        unit_id=unit_id,trials=no_opto_trials['trial_index'].values
        ).sel(time=slice(baseline_time_start,baseline_time_stop)).mean()/baseline_dur
    
    vis1_no_opto_trials=trials.query('~is_opto & (stim_name=="vis1")')
    vis1=trial_da.sel(
        unit_id=unit_id,trials=vis1_no_opto_trials['trial_index'].values
        ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
    vis2_no_opto_trials=trials.query('~is_opto & (stim_name=="vis2")')
    vis2=trial_da.sel(
        unit_id=unit_id,trials=vis2_no_opto_trials['trial_index'].values
        ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
    sound1_no_opto_trials=trials.query('~is_opto & (stim_name=="sound1")')
    sound1=trial_da.sel(
        unit_id=unit_id,trials=sound1_no_opto_trials['trial_index'].values
        ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
    sound2_no_opto_trials=trials.query('~is_opto & (stim_name=="sound2")')
    sound2=trial_da.sel(
        unit_id=unit_id,trials=sound2_no_opto_trials['trial_index'].values
        ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
    catch=trial_da.sel(
        unit_id=unit_id,trials=no_opto_trials.query('stim_name=="catch"')['trial_index'].values
        ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur

    
    unit_opto_metrics['unit_id'].append(unit_id)
    unit_opto_metrics['peak_channel'].append(unit['peak_channel'])
    unit_opto_metrics['opto_power'].append(opto_power)
    unit_opto_metrics['opto_cond'].append(opto_cond)
    unit_opto_metrics['baseline'].append(baseline.values)
    unit_opto_metrics['vis1'].append(vis1.values)
    unit_opto_metrics['vis2'].append(vis2.values)
    unit_opto_metrics['sound1'].append(sound1.values)
    unit_opto_metrics['sound2'].append(sound2.values)
    unit_opto_metrics['catch'].append(catch.values)

    #loop through opto conditions
    for opto_cond in opto_labels:
        if opto_cond=='':
            continue
        opto_power=trials.query(f'opto_label=="{opto_cond}"')['opto_power'].iloc[0]
        opto_trials=trials.query(f'is_opto & (opto_label=="{opto_cond}")')
        
        #baseline firing rate, vis1 response rate, vis2 response rate, sound1 response rate, sound2 response rate
        baseline=trial_da.sel(
            unit_id=unit_id,trials=opto_trials['trial_index'].values
            ).sel(time=slice(baseline_time_start,baseline_time_stop)).mean()/baseline_dur
        vis1_opto_trials=trials.query(f'is_opto & (stim_name=="vis1") & (opto_label=="{opto_cond}")')
        vis1=trial_da.sel(
            unit_id=unit_id,trials=vis1_opto_trials['trial_index'].values
            ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
        vis2_opto_trials=trials.query(f'is_opto & (stim_name=="vis2") & (opto_label=="{opto_cond}")')
        vis2=trial_da.sel(
            unit_id=unit_id,trials=vis2_opto_trials['trial_index'].values
            ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
        sound1_opto_trials=trials.query(f'is_opto & (stim_name=="sound1") & (opto_label=="{opto_cond}")')
        sound1=trial_da.sel(
            unit_id=unit_id,trials=sound1_opto_trials['trial_index'].values
            ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
        sound2_opto_trials=trials.query(f'is_opto & (stim_name=="sound2") & (opto_label=="{opto_cond}")')
        sound2=trial_da.sel(
            unit_id=unit_id,trials=sound2_opto_trials['trial_index'].values
            ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
        catch=trial_da.sel(
            unit_id=unit_id,trials=opto_trials.query('stim_name=="catch"')['trial_index'].values
            ).sel(time=slice(stim_time_start,stim_time_stop)).mean()/stim_dur
        
        unit_opto_metrics['unit_id'].append(unit_id)
        unit_opto_metrics['peak_channel'].append(unit['peak_channel'])
        unit_opto_metrics['opto_power'].append(opto_power)
        unit_opto_metrics['opto_cond'].append(opto_cond)
        unit_opto_metrics['baseline'].append(baseline.values)
        unit_opto_metrics['vis1'].append(vis1.values)
        unit_opto_metrics['vis2'].append(vis2.values)
        unit_opto_metrics['sound1'].append(sound1.values)
        unit_opto_metrics['sound2'].append(sound2.values)
        unit_opto_metrics['catch'].append(catch.values)

unit_opto_metrics_df=pd.DataFrame(unit_opto_metrics)


In [65]:
unit_opto_metrics_df

Unnamed: 0,unit_id,peak_channel,opto_power,opto_cond,baseline,vis1,vis2,sound1,sound2,catch
0,728917_2024-10-04_B-0,1,0.0,no_opto,0.20380434782608695,0.6076388888888888,0.5208333333333334,0.24671052631578946,0.6578947368421052,0.0
1,728917_2024-10-04_B-0,1,5.0,probe_5mW,0.49019607843137253,0.0,0.6944444444444444,0.625,0.0,0.5681818181818181
2,728917_2024-10-04_B-0,1,1.5,probe_1mW,0.27777777777777773,0.0,2.7777777777777777,0.0,0.0,0.0
3,728917_2024-10-04_B-0,1,5.0,lateral_5mW,0.0,0.0,1.5625,0.0,0.0,0.0
4,728917_2024-10-04_B-0,1,5.0,posterior_5mW,0.4464285714285714,0.78125,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...
4573,728917_2024-10-04_B-320,371,1.5,probe_1mW,1.111111111111111,1.25,0.0,0.0,0.0,11.11111111111111
4574,728917_2024-10-04_B-320,371,5.0,lateral_5mW,0.4261363636363636,0.8928571428571428,0.0,0.0,0.0,0.0
4575,728917_2024-10-04_B-320,371,5.0,posterior_5mW,0.29761904761904756,0.0,0.0,0.0,0.0,0.0
4576,728917_2024-10-04_B-320,371,1.5,lateral_1mW,2.734375,2.7777777777777777,3.75,1.25,4.166666666666667,70.625


In [8]:
# sel_unit_to_plot='728917_2024-10-04_B-320'
savepath=r"C:\Users\ethan.mcbride\OneDrive - Allen Institute\quick figures\2025-05-01-opto_perturb_plots\psths"

sel_units_to_plot=units.query('isi_violations_ratio<=0.5 and amplitude_cutoff<=0.1 and presence_ratio>=0.7 and \
                                activity_drift<=0.2 and decoder_label!="noise" and firing_rate>=2.0')['unit_id'].values

opto_labels=trials['opto_label'].unique()


for unit_id in sel_units_to_plot:
    # unit=unit_opto_metrics_df.query(f'unit_id=="{unit_id}"')
    unit_PSTHs={
        'unit_id':[],
        'opto_cond':[],
        'stim_name':[],
        'time':[],
        'PSTH':[]
    }
    for opto_cond in opto_labels:
        for stim_name in ['vis1','vis2','sound1','sound2','catch']:
            if opto_cond=='' or opto_cond=='no_opto':
                opto_cond='no_opto'
                trials_opto=trials.query('~is_opto and (stim_name==@stim_name)')
            else:
                trials_opto=trials.query(f'opto_label=="{opto_cond}" and (stim_name==@stim_name)')

            # unit_id=sel_unit_to_plot
            unit_PSTH=trial_da.sel(unit_id=unit_id).sel(trials=trials_opto['trial_index'].values).mean(dim='trials')

            unit_PSTHs['unit_id'].append(unit_id)
            unit_PSTHs['opto_cond'].append(opto_cond)
            unit_PSTHs['stim_name'].append(stim_name)
            unit_PSTHs['time'].append(unit_PSTH.time.values)
            unit_PSTHs['PSTH'].append(unit_PSTH.values)

    unit_PSTHs_df=pd.DataFrame(unit_PSTHs)

    fig,ax=plt.subplots(2,3,figsize=(10,6),sharex=True,sharey=True)
    ax=ax.flatten()

    for ss,stim_name in enumerate(['vis1','vis2','sound1','sound2','catch']):

        ax[ss].axvline(0,color='k',linestyle='--',linewidth=0.5)
        ax[ss].axvline(0.5,color='k',linestyle='--',linewidth=0.5)

        for opto_cond in opto_labels:
            if opto_cond=='':
                opto_cond='no_opto'
                color='k'
            elif opto_cond=='no_opto':
                color='k'
            elif opto_cond=='probe_1mW':
                color='lightskyblue'
            elif opto_cond=='probe_5mW':
                color='royalblue'
            else:
                continue
            unit_PSTH=unit_PSTHs_df.query(f'opto_cond=="{opto_cond}" and stim_name==@stim_name')['PSTH'].values[0]
            unit_PSTH_time=unit_PSTHs_df.query(f'opto_cond=="{opto_cond}" and stim_name==@stim_name')['time'].values[0]
            ax[ss].plot(unit_PSTH_time,unit_PSTH,color=color,label=opto_cond)

        # ax.legend()
        ax[ss].set_xlabel('Time relative to stimulus (s)')
        ax[ss].set_ylabel('Firing rate (Hz)')
        ax[ss].set_title(stim_name)

    fig.suptitle(f"{unit_id}; {units.query('unit_id==@unit_id')['structure'].values[0]}; ch {units.query('unit_id==@unit_id')['peak_channel'].values[0]}")

    fig.tight_layout()

    fig.savefig(os.path.join(savepath,f"{unit_id}.png"),dpi=300)
    plt.close(fig)

    # break

In [87]:
units.query('isi_violations_ratio<=0.5 and amplitude_cutoff<=0.1 and presence_ratio>=0.7 and \
            activity_drift<=0.2 and decoder_label!="noise" and firing_rate>=1.0').shape

(257, 61)

In [None]:
units.query(f'unit_id=="{unit_id}"')['structure']
units.query(f'unit_id=="{unit_id}"')['peak_channel']

Unnamed: 0_level_0,activity_drift,amplitude,amplitude_cutoff,amplitude_cv_median,amplitude_cv_range,amplitude_median,ccf_ap,ccf_dv,ccf_ml,channels,...,unit_id,velocity_above,velocity_below,spike_times,obs_intervals,electrodes,session_idx,date,subject_id,session_id
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
309,0.017,85.264935,0.002052,,,46.8,,,,"[358, 359, 360, 361, 362, 364, 365, 366, 367, ...",...,728917_2024-10-04_B-320,,,"[28.819634362063137, 29.085600060700422, 39.85...","[[24.2649842728541, 3688.0]]","[358, 359, 360, 361, 362, 364, 365, 366, 367, ...",0,2024-10-04,728917,728917_2024-10-04


In [None]:

#loop over channel bins

#find units in each channel bin

#take response fraction - sel_opto_cond / no_opto for baseline and each stimulus

In [None]:
#plot baseline & stimulus evoked activty with and without opto across the whole probe?



UndefinedVariableError: name 'opto_cond' is not defined