In [3]:
import sys
import os
sys.path.append('/root/capsule/code/beh_ephys_analysis')
from harp.clock import decode_harp_clock, align_timestamps_to_anchor_points
from open_ephys.analysis import Session
import datetime
from aind_ephys_rig_qc.temporal_alignment import search_harp_line
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib.colors import LinearSegmentedColormap
import pandas as pd
from pynwb import NWBFile, TimeSeries, NWBHDF5IO
from scipy.io import loadmat
from scipy.stats import zscore
import ast
from utils.plot_utils import combine_pdf_big

from open_ephys.analysis import Session
from pathlib import Path
import glob

import json
import seaborn as sns
from PyPDF2 import PdfMerger
from sklearn.linear_model import LinearRegression
import statsmodels.api as sm
import re
from aind_dynamic_foraging_basic_analysis.plot.plot_foraging_session import plot_foraging_session
from aind_dynamic_foraging_data_utils.nwb_utils import load_nwb_from_filename
from hdmf_zarr.nwb import NWBZarrIO
from utils.beh_functions import session_dirs, parseSessionID, load_model_dv, makeSessionDF, get_session_tbl, get_unit_tbl, get_history_from_nwb
from utils.ephys_functions import*
from utils.opto_utils import opto_metrics, load_opto_sig
import pandas as pd
import pickle
import scipy.stats as stats
from joblib import Parallel, delayed
from multiprocessing import Pool
from functools import partial
import time
import spikeinterface as si
import shutil 
import seaborn as sns
import math
import seaborn as sns
from sklearn.decomposition import PCA
from scipy.stats import zscore
from joblib import Parallel, delayed
%matplotlib inline

In [4]:
# Build the combined session table from both asset lists, then keep only
# rows whose session_id is a string and is not explicitly excluded.
dfs = [
    pd.read_csv(csv_path)
    for csv_path in (
        '/root/capsule/code/data_management/session_assets.csv',
        '/root/capsule/code/data_management/hopkins_session_assets.csv',
    )
]
df = pd.concat(dfs)

# Sessions that are dropped by hand.
exclude = [
    'ecephys_717120_2024-03-06_12-23-53',
    'ecephys_713854_2024-03-08_14-54-25',
    'ecephys_713854_2024-03-08_16-20-33',
    'behavior_754897_2025-03-15_11-32-18',
]

# Step 1: drop the excluded sessions.
not_excluded = ~df['session_id'].isin(exclude)
df = df.loc[not_excluded]

# Step 2: drop rows whose session_id is not a string.
has_str_id = df['session_id'].map(lambda session_id: isinstance(session_id, str))
df = df.loc[has_str_id]

In [5]:
def _column_or_nan(unit_tbl, column):
    """Return ``unit_tbl[column]`` as a list, or one NaN per row if the column is absent."""
    if column in unit_tbl.columns:
        return unit_tbl[column].tolist()
    return [np.nan] * len(unit_tbl)


def _prefer_array(preferred, fallback):
    """Pair two per-unit sequences, taking the fallback wherever the preferred entry is not an ndarray."""
    return [
        pref_item if isinstance(pref_item, np.ndarray) else fallback_item
        for pref_item, fallback_item in zip(preferred, fallback)
    ]


def _count_trials(session_df, unit_drift):
    """Count go-cue times that fall inside the unit's ``ephys_cut`` bounds (0 without a session table)."""
    if session_df is None:
        return 0
    go_cue_times = session_df['goCue_start_time']
    if unit_drift is not None:
        if unit_drift['ephys_cut'][0] is not None:
            go_cue_times = go_cue_times[go_cue_times >= unit_drift['ephys_cut'][0]]
        if unit_drift['ephys_cut'][1] is not None:
            go_cue_times = go_cue_times[go_cue_times <= unit_drift['ephys_cut'][1]]
    return len(go_cue_times)


def process_session(session, beh, rec_side, probe, target='soma'):
    """Collect per-unit metrics of one session into a dict of per-unit columns.

    Parameters
    ----------
    session : str
        Session identifier handed to the project loaders.
    beh, rec_side, probe
        Stored unchanged in the result under 'in_df', 'rec_side' and 'probe'.
    target : str, optional
        Not used inside this function; kept so existing callers keep working.

    Returns
    -------
    dict or None
        ``None`` when the session is skipped: a 'ZS' session without raw NWB
        directory or curated unit table, no curated directory, or an
        empty/missing unit table. Otherwise a dict in which every list-valued
        entry has exactly one element per row of the curated unit table, so
        the dict can be turned into a DataFrame directly. Units without opto
        metrics get empty arrays in the per-condition entries ('all_*').
    """
    session_dir = session_dirs(session)

    # --- skip missing or invalid sessions ---
    if 'ZS' in session:
        if (not os.path.exists(session_dir['nwb_dir_raw'])) or (get_unit_tbl(session, 'curated') is None):
            print(f'Skipping {session} due to no neuron data')
            return None
    if session_dir['curated_dir_curated'] is None:
        return None

    print(f'Processing {session}')
    data_type = 'curated'

    # --- load data ---
    unit_tbl = get_unit_tbl(session, data_type)
    if unit_tbl is None or len(unit_tbl) == 0:
        return None
    n_units = len(unit_tbl)

    opto_metrics_session = opto_metrics(session, data_type=data_type)
    session_df = get_session_tbl(session)
    session_opto_sig = load_opto_sig(session, data_type=data_type)

    # --- basic derived columns ---
    # Some tables carry these columns with an '_x' suffix instead of the
    # plain names; pick whichever naming is present.
    suffix = '' if 'p_max' in unit_tbl.columns else '_x'
    p_max = unit_tbl['p_max' + suffix].tolist()
    p_mean = unit_tbl['p_mean' + suffix].tolist()
    lat_max_p = unit_tbl['lat_max_p' + suffix].tolist()
    eu = unit_tbl['euc_max_p' + suffix].tolist()
    corr = unit_tbl['corr_max_p' + suffix].tolist()
    peaks = unit_tbl['peak' + suffix].values
    amp = unit_tbl['amp' + suffix].values

    # The three CCF coordinates are taken together, keyed on 'x_ccf' only.
    if 'x_ccf' in unit_tbl.columns:
        x_ccf = unit_tbl['x_ccf'].tolist()
        y_ccf = unit_tbl['y_ccf'].tolist()
        z_ccf = unit_tbl['z_ccf'].tolist()
    else:
        x_ccf = [np.nan] * n_units
        y_ccf = [np.nan] * n_units
        z_ccf = [np.nan] * n_units

    # --- waveform-related ---
    # Use the '*_opt' waveform where it is an ndarray, else the regular one.
    if 'peak_wf_opt' in unit_tbl.columns:
        wf_opt = _prefer_array(unit_tbl['peak_wf_opt'], unit_tbl['peak_wf'])
        wf_opt_aligned = _prefer_array(unit_tbl['peak_wf_opt_aligned'], unit_tbl['peak_wf_aligned'])
        wf_opt_2d = _prefer_array(unit_tbl['mat_wf_opt'], unit_tbl['wf_2d'])
    else:
        wf_opt = unit_tbl['peak_wf'].tolist()
        wf_opt_aligned = unit_tbl['peak_wf_aligned'].tolist()
        wf_opt_2d = unit_tbl['wf_2d'].tolist()

    # Peak-to-peak amplitude of the chosen waveform; table amplitude otherwise.
    amp_opt = [
        np.max(wf_opt_curr) - np.min(wf_opt_curr) if isinstance(wf_opt_curr, np.ndarray) else curr_amp_unit
        for wf_opt_curr, curr_amp_unit in zip(wf_opt, amp)
    ]
    if 'amplitude_opt' in unit_tbl.columns:
        peak_opt = [
            curr_peak_opt if not np.isnan(curr_peak_opt) else curr_peak
            for curr_peak_opt, curr_peak in zip(unit_tbl['amplitude_opt'].values, peaks)
        ]
    else:
        peak_opt = list(peaks)

    # NOTE(review): the guard checks 'peak_waveform_raw_aligned' but the body
    # reads the '*_fake*' columns -- confirm both sets always exist together.
    if 'peak_waveform_raw_aligned' in unit_tbl.columns:
        wf_raw = unit_tbl['peak_waveform_raw_fake_aligned'].tolist()
        wf_2d_raw = unit_tbl['mat_wf_raw_fake'].tolist()
        # Raw peak is expressed relative to the first sample of the raw waveform.
        peak_raw = [
            curr_peak_raw - curr_wf[0] if curr_peak_raw is not None and not np.isnan(curr_peak_raw)
            else None
            for curr_peak_raw, curr_wf in zip(unit_tbl['peak_raw_fake'], wf_raw)
        ]
        amp_raw = unit_tbl['amplitude_raw_fake'].tolist()
    else:
        wf_raw = [None] * n_units
        wf_2d_raw = [None] * n_units
        peak_raw = [None] * n_units
        amp_raw = [None] * n_units

    # --- waveform-independent scalar values ---
    isi_v = unit_tbl['isi_violations_ratio'].tolist()
    snr = unit_tbl['snr'].tolist()
    y_loc = unit_tbl['y_loc'].tolist()
    fr = unit_tbl['firing_rate'].tolist()
    decoder = unit_tbl['decoder_label'].tolist()
    tag_loc = _column_or_nan(unit_tbl, 'tagged_loc')
    top = _column_or_nan(unit_tbl, 'LC_range_top')
    bottom = _column_or_nan(unit_tbl, 'LC_range_bottom')

    # --- opto per-unit results ---
    resp_p_all_conditions, resp_lat_all_conditions = [], []
    mean_p_all_conditions, eu_all_conditions = [], []
    corr_all_conditions, sig_counts_all_conditions = [], []
    all_sig_counts = []
    trial_count = []

    for unit_id in unit_tbl['unit_id'].values:
        unit_opto = opto_metrics_session.load_unit(unit_id)
        unit_opto_sig = session_opto_sig.load_unit(unit_id) if session_opto_sig is not None else None
        unit_drift = load_drift(session, unit_id, data_type=data_type)

        if unit_opto is None:
            # No opto metrics for this unit. Keep its row with empty
            # placeholders instead of skipping it, so every per-unit list
            # stays aligned with unit_tbl.
            (curr_p_resp_all, curr_lat_resp_all, curr_p_mean_all,
             curr_eu_all, curr_corr_all, curr_sig_num_all) = (np.array([]) for _ in range(6))
        else:
            curr_p_resp_all = unit_opto['resp_p_bl'].values
            curr_lat_resp_all = unit_opto['resp_lat'].values
            curr_p_mean_all = unit_opto['mean_p'].values
            curr_eu_all = unit_opto['euclidean_norm'].values
            curr_corr_all = unit_opto['correlation'].values

            unit_opto['sig_num'] = np.full(len(unit_opto), np.nan)
            if unit_opto_sig is not None:
                if not session_dir['aniID'].startswith('ZS'):
                    # Match each opto condition to its significance row by
                    # power and site (and stim time when more than one
                    # pre/post value is present).
                    for cond_ind, row in unit_opto.iterrows():
                        filt = (unit_opto_sig['power'] == row['powers']) & (unit_opto_sig['site'] == row['sites'])
                        if len(unit_opto_sig['pre_post'].unique()) > 1:
                            filt &= (unit_opto_sig['pre_post'] == row['stim_times'])
                        curr_sig_rows = unit_opto_sig[filt]
                        if len(curr_sig_rows) >= 1:
                            unit_opto.loc[cond_ind, 'sig_num'] = curr_sig_rows['p_sig_count'].values[0]
                elif len(unit_opto_sig) > 0:
                    # 'ZS' animals: the first significance count is applied to
                    # all conditions. An empty table leaves the NaN defaults.
                    unit_opto['sig_num'] = unit_opto_sig['p_sig_count'].values[0]
            curr_sig_num_all = unit_opto['sig_num'].values

        curr_max_count = np.nan if unit_opto_sig is None else unit_opto_sig['p_sig_count'].max()

        resp_p_all_conditions.append(curr_p_resp_all)
        resp_lat_all_conditions.append(curr_lat_resp_all)
        mean_p_all_conditions.append(curr_p_mean_all)
        eu_all_conditions.append(curr_eu_all)
        corr_all_conditions.append(curr_corr_all)
        sig_counts_all_conditions.append(curr_sig_num_all)
        all_sig_counts.append(curr_max_count)
        trial_count.append(_count_trials(session_df, unit_drift))

    # --- final dictionary ---
    return {
        'session': session,
        'unit': unit_tbl['unit_id'].tolist(),
        'qc_pass': unit_tbl['default_qc'].tolist(),
        # Reuse the guarded list so a missing 'tagged_loc' column yields NaN
        # instead of a KeyError.
        'opto_tagged': tag_loc,
        'in_df': beh,
        'trial_count': trial_count,
        'p_max': p_max,
        'p_mean': p_mean,
        'sig_counts': all_sig_counts,
        'lat_max_p': lat_max_p,
        'isi_violations': isi_v,
        'snr': snr,
        'eu': eu,
        'corr': corr,
        'amp': amp_opt,
        'amp_raw': amp_raw,
        'peak': peak_opt,
        'peak_raw': peak_raw,
        'wf': wf_opt,
        'wf_raw': wf_raw,
        'wf_aligned': wf_opt_aligned,
        'wf_2d': wf_opt_2d,
        'wf_2d_raw': wf_2d_raw,
        'probe': probe,
        'y_loc': y_loc,
        'rec_side': rec_side,
        'top': top,
        'bottom': bottom,
        'tag_loc': tag_loc,
        'fr': fr,
        'decoder': decoder,
        'all_p_max': resp_p_all_conditions,
        'all_p_mean': mean_p_all_conditions,
        'all_lat_max_p': resp_lat_all_conditions,
        'all_corr': corr_all_conditions,
        'all_eu': eu_all_conditions,
        'all_sig_counts': sig_counts_all_conditions,
        'x_ccf': x_ccf,
        'y_ccf': y_ccf,
        'z_ccf': z_ccf
    }


In [6]:
target = 'soma'


def safe_process(session, beh, rec_side, probe):
    """Run ``process_session`` for one session with the module-level ``target``.

    Despite the name, no exception is caught here: any error raised by
    ``process_session`` propagates to the caller (and therefore out of the
    parallel run below).
    """
    session_result = process_session(session, beh, rec_side, probe, target=target)
    return session_result


# One delayed job per row of the session table, dispatched to 8 workers.
# The output keeps the row order of ``df``.
session_jobs = [
    delayed(safe_process)(row['session_id'], row['behavior'], row['side'], row['probe'])
    for _, row in df.iterrows()
]
results = Parallel(n_jobs=8)(session_jobs)


Processing ecephys_713854_2024-03-05_13-01-09
Processing ecephys_713854_2024-03-08_15-43-01
Processing ecephys_713854_2024-03-05_12-01-40
Processing ecephys_713854_2024-03-05_13-31-20
Processing ecephys_717120_2024-03-07_12-12-02
Processing ecephys_717120_2024-03-06_12-54-27
Processing ecephys_684930_2023-09-27_10-04-04
Processing ecephys_684930_2023-09-28_11-45-27
Processing ecephys_687697_2023-09-15_11-30-06
Processing ecephys_684930_2023-09-28_12-44-15
Processing ecephys_687697_2023-09-15_12-36-06
Processing ecephys_691893_2023-10-05_12-46-57
Processing ecephys_691893_2023-10-06_13-48-18
Processing behavior_716325_2024-05-31_10-31-14




Processing behavior_717121_2024-06-15_10-00-58
Processing behavior_751004_2024-12-20_13-26-11
Processing behavior_751004_2024-12-19_11-50-37
Processing behavior_751004_2024-12-21_13-28-28
Processing behavior_751004_2024-12-23_14-20-03
Processing behavior_751004_2024-12-22_13-09-17
Processing behavior_751769_2025-01-16_11-32-05
Processing behavior_751769_2025-01-17_11-37-39
Processing behavior_751769_2025-01-18_10-15-25
Processing behavior_758017_2025-02-04_11-57-38
Processing behavior_758017_2025-02-05_11-42-34
Processing behavior_758017_2025-02-06_11-26-14
Processing behavior_758017_2025-02-07_14-11-08
Processing behavior_751766_2025-02-11_11-53-38
Processing behavior_751766_2025-02-13_11-31-21
Processing behavior_751766_2025-02-14_11-37-11
Processing behavior_751181_2025-02-25_12-12-35
Processing behavior_751181_2025-02-26_11-51-19
Processing behavior_751181_2025-02-27_11-24-47
Processing behavior_754897_2025-03-11_12-07-41
Processing behavior_754897_2025-03-12_12-23-15
Processing be

In [9]:
# Keep only sessions that produced a result (skipped sessions return None).
results = list(filter(lambda res: res is not None, results))

In [10]:
# Order the results the same way the sessions appear in df.
# Each session_id maps to its row position (a repeated id keeps the last one).
session_order = dict(zip(df['session_id'].tolist(), range(len(df))))
results.sort(key=lambda res: session_order[res['session']])

In [13]:
# One DataFrame per session, built from its per-unit result dict.
results_df = list(map(pd.DataFrame, results))

In [14]:
# Stack all per-session tables row-wise under a fresh 0..N-1 index.
combined_tagged_units = pd.concat(objs=results_df, axis=0, ignore_index=True)

  combined_tagged_units = pd.concat(results_df, ignore_index=True)


In [16]:
# Save the combined unit table as a pickle in the combined folder.
combined_dir = '/root/capsule/scratch/combined/combine_unit_tbl'
# Create the folder first: open(..., 'wb') does not create missing parent
# directories and would raise FileNotFoundError on a fresh scratch volume.
os.makedirs(combined_dir, exist_ok=True)
with open(os.path.join(combined_dir, 'combined_unit_tbl.pkl'), 'wb') as f:
    pickle.dump(combined_tagged_units, f)