In [63]:
import sys
import os
sys.path.append('/root/capsule/code/beh_ephys_analysis')
from utils.beh_functions import parseSessionID, session_dirs, get_unit_tbl, get_session_tbl
from utils.plot_utils import shiftedColorMap, template_reorder, get_gradient_colors
from utils.opto_utils import opto_metrics, get_opto_tbl
from utils.ephys_functions import cross_corr_train, auto_corr_train, load_drift
import json
import matplotlib.pyplot as plt
import pandas as pd
import pickle
from aind_ephys_utils import align
from scipy.stats import wilcoxon
%matplotlib inline

In [66]:
def cal_opto_sigs(session, data_type):
    unit_tbl = get_unit_tbl(session, data_type)
    opto_tbl = get_opto_tbl(session, data_type, loc = 'soma')
    # loop through all conditions
    powers = opto_tbl['power'].unique().tolist()
    sites = opto_tbl['site'].unique().tolist()
    pre_posts = opto_tbl['pre_post'].unique().tolist()
    freqs = opto_tbl['freq'].unique().tolist()
    opto_sigs = pd.DataFrame()
    filter = unit_tbl['default_qc'] == 1
    unit_ids_focus = unit_tbl[filter]['unit_id'].unique().tolist()
    for unit in unit_ids_focus:
        # print(f"Processing unit {unit}...")
        spike_times = unit_tbl[unit_tbl['unit_id']== unit]['spike_times'].values[0]
        opto_tbl_curr = opto_tbl.copy()
        unit_drift = load_drift(session, unit)
        pulse_num = 5
        pre_win_ratio = 0.5
        post_win = 0.05
        if unit_drift is not None:
            if unit_drift['ephys_cut'][0] is not None:
                spike_times = spike_times[spike_times >= unit_drift['ephys_cut'][0]]
                opto_tbl_curr = opto_tbl_curr[opto_tbl_curr['time'] >= unit_drift['ephys_cut'][0]]
            if unit_drift['ephys_cut'][1] is not None:
                spike_times = spike_times[spike_times <= unit_drift['ephys_cut'][1]]
                opto_tbl_curr = opto_tbl_curr[opto_tbl_curr['time'] <= unit_drift['ephys_cut'][1]]
        for power_ind, power in enumerate(powers):
            for site_ind, site in enumerate(sites):
                for freq_ind, freq in enumerate(freqs):
                    for pre_post_ind, pre_post in enumerate(pre_posts):
                        # get the trials for this condition
                        trials = opto_tbl[(opto_tbl['power'] == power) & (opto_tbl['site'] == site) & (opto_tbl['pre_post'] == pre_post) & (opto_tbl['freq'] == freq)]
                        if len(trials) == 0:
                            # print(f"No trials for power {power}, site {site}, freq {freq}, pre_post {pre_post}")
                            continue
                        # get the pulse times
                        train_times = trials['time'].values
                        p_unit_condition = []
                        for pulse_ind in range(pulse_num):
                            pulse_times = train_times + (pulse_ind * 1/freq)  # assuming freq is in Hz and pulse_ind starts from 0
                            # get the opto signal
                            pre_win_counts = align.to_events(spike_times, pulse_times, [-1/freq * pre_win_ratio, 0], return_df=True)
                            post_win_counts = align.to_events(spike_times, pulse_times, [0, post_win], return_df=True)
                            pre_win_freq = [len(pre_win_counts[pre_win_counts['event_index'] == event_ind]) / (1/freq * pre_win_ratio) for event_ind in range(len(pulse_times))]
                            post_win_freq = [len(post_win_counts[post_win_counts['event_index'] == event_ind]) / (post_win) for event_ind in range(len(pulse_times))]
                            # paired non-parametric test
                            stat, p = wilcoxon(pre_win_freq, post_win_freq)
                            p_unit_condition.append(p)
                        # store the results
                        opto_sigs = pd.concat([opto_sigs, pd.DataFrame({
                                            'unit_id': [unit],
                                            'power': [power],
                                            'site': [site],
                                            'freq': [freq],
                                            'pre_post': [pre_post],
                                            'p_unit_condition': [p_unit_condition],  # wrap list of dicts in a list to keep in one cell
                                            'p_sig_count': [sum(p < 0.05 for p in p_unit_condition)],
                                        })], ignore_index=True)
    # save the results
    opto_sigs_file = os.path.join(session_dirs(session, data_type)[f'opto_dir_{data_type}'], f'{session}_opto_sigs.pkl')
    with open(opto_sigs_file, 'wb') as f:
        pickle.dump(opto_sigs, f)
    return opto_sigs

In [65]:
cal_opto_sigs('behavior_754897_2025-03-13_11-20-42','curated')

Processing unit 4...
Processing unit 10...
Processing unit 11...
Processing unit 12...
Processing unit 15...
Processing unit 18...
Processing unit 23...
Processing unit 24...
Processing unit 25...


  z = (r_plus - mn) / se


Processing unit 26...
Processing unit 27...
Processing unit 28...
Processing unit 30...
Processing unit 32...


  z = (r_plus - mn) / se


Processing unit 33...


  z = (r_plus - mn) / se


Processing unit 34...


  z = (r_plus - mn) / se


Processing unit 35...
Processing unit 36...
Processing unit 37...
Processing unit 41...
Processing unit 42...


  z = (r_plus - mn) / se


Processing unit 44...
Processing unit 46...
Processing unit 47...
Processing unit 48...
Processing unit 49...
Processing unit 50...
Processing unit 51...


  z = (r_plus - mn) / se


Processing unit 54...
Processing unit 55...
Processing unit 56...
Processing unit 61...
Processing unit 62...
Processing unit 63...
Processing unit 64...
Processing unit 65...


  z = (r_plus - mn) / se


Processing unit 69...
Processing unit 71...


  z = (r_plus - mn) / se


Processing unit 74...
Processing unit 75...


  z = (r_plus - mn) / se


Processing unit 76...


  z = (r_plus - mn) / se


Processing unit 77...
Processing unit 79...
Processing unit 81...


  z = (r_plus - mn) / se


Processing unit 83...
Processing unit 109...
Processing unit 110...
Processing unit 112...
Processing unit 113...
Processing unit 115...
Processing unit 116...


  z = (r_plus - mn) / se


Processing unit 117...
Processing unit 118...
Processing unit 121...


  z = (r_plus - mn) / se


Processing unit 122...
Processing unit 123...
Processing unit 126...


  z = (r_plus - mn) / se


Processing unit 131...
Processing unit 133...
Processing unit 134...
Processing unit 136...
Processing unit 137...
Processing unit 138...


  z = (r_plus - mn) / se


Processing unit 139...
Processing unit 140...
Processing unit 141...


  z = (r_plus - mn) / se


Processing unit 142...


  z = (r_plus - mn) / se


Processing unit 143...
Processing unit 144...


  z = (r_plus - mn) / se


Processing unit 147...
Processing unit 151...


  z = (r_plus - mn) / se


Processing unit 153...
Processing unit 155...
Processing unit 157...
Processing unit 158...
Processing unit 159...
Processing unit 160...
Processing unit 162...
Processing unit 164...
Processing unit 166...
Processing unit 167...


  z = (r_plus - mn) / se


Processing unit 169...
Processing unit 171...


  z = (r_plus - mn) / se


Processing unit 175...
Processing unit 178...
Processing unit 179...
Processing unit 180...


  z = (r_plus - mn) / se


Processing unit 182...
Processing unit 184...


  z = (r_plus - mn) / se


Processing unit 186...


  z = (r_plus - mn) / se


Processing unit 190...


  z = (r_plus - mn) / se


Processing unit 191...


  z = (r_plus - mn) / se


Processing unit 192...


  z = (r_plus - mn) / se
  z = (r_plus - mn) / se


Processing unit 193...
Processing unit 194...


  z = (r_plus - mn) / se


Processing unit 195...
Processing unit 196...
Processing unit 197...


  z = (r_plus - mn) / se


Processing unit 198...


  z = (r_plus - mn) / se


Processing unit 199...
Processing unit 204...
Processing unit 205...


  z = (r_plus - mn) / se


Processing unit 206...


  z = (r_plus - mn) / se


Processing unit 207...


  z = (r_plus - mn) / se


Processing unit 208...
Processing unit 211...


  z = (r_plus - mn) / se


Processing unit 212...


  z = (r_plus - mn) / se


Processing unit 214...
Processing unit 217...


  z = (r_plus - mn) / se


Processing unit 218...


  z = (r_plus - mn) / se


Processing unit 223...
Processing unit 229...
Processing unit 231...
Processing unit 232...
Processing unit 236...


  z = (r_plus - mn) / se


Processing unit 237...
Processing unit 238...


  z = (r_plus - mn) / se


Processing unit 240...
Processing unit 242...
Processing unit 243...
Processing unit 244...
Processing unit 249...


  z = (r_plus - mn) / se


Processing unit 254...


  z = (r_plus - mn) / se


Processing unit 255...


  z = (r_plus - mn) / se


Processing unit 258...


  z = (r_plus - mn) / se


Processing unit 263...


  z = (r_plus - mn) / se


Processing unit 266...


  z = (r_plus - mn) / se


Processing unit 267...


  z = (r_plus - mn) / se


Processing unit 268...


  z = (r_plus - mn) / se


Processing unit 269...
Processing unit 270...


  z = (r_plus - mn) / se


Processing unit 271...
Processing unit 273...


  z = (r_plus - mn) / se


Processing unit 277...


  z = (r_plus - mn) / se


Processing unit 278...


  z = (r_plus - mn) / se


Processing unit 282...
Processing unit 283...


Unnamed: 0,unit_id,power,site,freq,pre_post,p_unit_condition,p_sig_count
0,4,50,surface_LC,5,pre,"[0.838188625026738, 0.057394761418403314, 0.50...",0
1,4,50,surface_LC,5,post,"[0.6142946646634824, 0.7815112949987133, 0.670...",0
2,4,30,surface_LC,5,pre,"[0.6417580893423203, 0.863791257068895, 0.6250...",0
3,4,30,surface_LC,5,post,"[0.10829365589900912, 0.18229232651282867, 0.6...",0
4,4,10,surface_LC,5,post,"[0.7702877264736667, 0.9269603529511417, 0.025...",1
...,...,...,...,...,...,...,...
954,283,30,surface_LC,5,pre,"[0.00017533020440085103, 9.599730148021778e-05...",5
955,283,30,surface_LC,5,post,"[0.00017614138090729764, 7.934772672716567e-05...",5
956,283,10,surface_LC,5,post,"[4.809547412854571e-05, 4.7894534681727224e-05...",5
957,283,20,surface_LC,5,post,"[3.6597037477988786e-05, 3.833012152637274e-05...",5


In [None]:
session_assets = pd.read_csv('/root/capsule/code/data_management/session_assets.csv')
session_list = session_assets['session_id']
probe_list = session_assets['probe']
probe_list = [probe for probe, session in zip(probe_list, session_list) if isinstance(session, str)]
session_list = [session for session in session_list if isinstance(session, str)]    
from joblib import Parallel, delayed
data_type = 'curated'
def process(session, data_type): 
    print(f'Starting {session}')
    session_dir = session_dirs(session)
    # if os.path.exists(os.path.join(session_dir['beh_fig_dir'], f'{session}.nwb')):
    print(session_dir[f'curated_dir_{data_type}'])
    if session_dir[f'curated_dir_{data_type}'] is not None:
        try:
            # plot_ephys_probe(session, data_type=data_type, probe=probe) 
            cal_opto_sigs(session, data_type)
            plt.close('all')
            print(f'Finished {session}')
        except:
            print(f'Error processing {session}')
            plt.close('all')
    else: 
        print(f'No curated data found for {session}') 
    # elif session\_dir['curated_dir_raw'] is not None:
    #     data_type = 'raw' 
    #     opto_tagging_df_sess = opto_plotting_session(session, data_type, target, resp_thresh=resp_thresh, lat_thresh=lat_thresh, target_unit_ids= None, plot = True, save=True)
Parallel(n_jobs=-8)(
    delayed(process)(session, data_type) 
    for session in session_list
)

In [68]:
class load_opto_sig():
    def __init__(self, session, data_type):
        self.session = session
        self.data_type = data_type
        self.opto_sigs = self.load_opto_sigs()

    def load_opto_sigs(self):
        opto_sigs_file = os.path.join(session_dirs(self.session)[f'opto_dir_{self.data_type}'], f'{self.session}_opto_sigs.pkl')
        if os.path.exists(opto_sigs_file):
            with open(opto_sigs_file, 'rb') as f:
                return pickle.load(f)
        else:
            print(f'No opto sigs found for {self.session}')
            return None

    def get_opto_sigs(self, unit):
        if self.opto_sigs is not None:
            unit_opto_sigs = self.opto_sigs[self.opto_sigs['unit_id'] == unit]
            if not unit_opto_sigs.empty:
                return unit_opto_sigs
            else:
                print(f'No opto sigs found for unit {unit} in session {self.session}')
                return None
        else:
            print(f'No opto sigs loaded for session {self.session}')
            return None

In [71]:
opto_sigs = load_opto_sig('behavior_754897_2025-03-13_11-20-42', 'curated')
curr_unit_sigs = opto_sigs.get_opto_sigs(10)
curr_unit_sigs

Unnamed: 0,unit_id,power,site,freq,pre_post,p_unit_condition,p_sig_count
7,10,50,surface_LC,5,pre,"[3.808961977046787e-05, 0.00014159510770021393...",5
8,10,50,surface_LC,5,post,"[0.008299215995280766, 0.0004087524092441004, ...",5
9,10,30,surface_LC,5,pre,"[0.0028375448887923154, 0.03582047350421736, 0...",5
10,10,30,surface_LC,5,post,"[0.0047315429964332155, 6.887215595255096e-05,...",5
11,10,10,surface_LC,5,post,"[0.806968367170738, 0.27257111883252894, 0.045...",3
12,10,20,surface_LC,5,post,"[0.05500883362926572, 0.006492857745083879, 0....",4
13,10,40,surface_LC,5,post,"[0.00010417979413217688, 0.0001438636130358712...",5


In [73]:
opto_response = opto_metrics('behavior_754897_2025-03-13_11-20-42', 'curated')
unit_opto_response = opto_response.load_unit(10)
unit_opto_response

Unnamed: 0,unit_id,resp_p,resp_p_bl,resp_lat,powers,sites,num_pulses,durations,freqs,stim_times,opto_pass,mean_p,euclidean_norm,correlation
56,10,0.4,0.305467,0.017459,20,surface_LC,5,4,5,post,True,0.175467,0.190139,0.98291
57,10,0.5,0.409505,0.01576,30,surface_LC,5,4,5,pre,True,0.302486,0.135288,0.986911
58,10,0.55,0.455467,0.016305,30,surface_LC,5,4,5,post,True,0.302486,0.151615,0.986995
59,10,0.5,0.405467,0.016863,30,surface_LC,5,4,5,post,True,0.302486,0.151615,0.986995
60,10,0.5,0.405467,0.013983,30,surface_LC,5,4,5,post,True,0.302486,0.151615,0.986995
61,10,0.45,0.355467,0.016462,30,surface_LC,5,4,5,post,True,0.302486,0.151615,0.986995
62,10,0.45,0.355467,0.013426,30,surface_LC,5,4,5,post,True,0.302486,0.151615,0.986995
63,10,0.7,0.605467,0.01681,40,surface_LC,5,4,5,post,True,0.455467,0.145236,0.986634
64,10,0.55,0.455467,0.017369,40,surface_LC,5,4,5,post,True,0.455467,0.145236,0.986634
65,10,0.45,0.355467,0.019567,40,surface_LC,5,4,5,post,True,0.455467,0.145236,0.986634
