In [1]:
import os
import sys
sys.path.append('/root/capsule/code/beh_ephys_analysis')
import pandas as pd
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import json
from harp.clock import decode_harp_clock, align_timestamps_to_anchor_points
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import re
from utils.beh_functions import parseSessionID, session_dirs
from utils.plot_utils import shiftedColorMap, template_reorder
from open_ephys.analysis import Session##
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.postprocessing as spost
import spikeinterface.widgets as sw
from aind_dynamic_foraging_basic_analysis.licks.lick_analysis import load_nwb
from aind_ephys_utils import align
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps
from aind_dynamic_foraging_data_utils.nwb_utils import load_nwb_from_filename
from spikeinterface.core.sorting_tools import random_spikes_selection
import pickle
import datetime
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
from tqdm import tqdm

In [4]:
session = 'behavior_751004_2024-12-23_14-20-03'
data_type = 'curated'
target = 'soma'

In [5]:
# opto info
session_dir = session_dirs(session)
opto_df = pd.read_csv(os.path.join(session_dir[f'opto_dir_{data_type}'], f'{session}_opto_session_{target}.csv'), index_col=0)
with open(os.path.join(session_dir[f'opto_dir_{data_type}'], f'{session}_opto_info_{target}.json')) as f:
    opto_info = json.load(f)
powers = opto_info['powers']
sites = opto_info['sites']
num_pulses = opto_info['num_pulses']
duration = opto_info['durations']
pulse_offset = np.array([1/freq for freq in opto_info['freqs']])
total_duration = (pulse_offset) * num_pulses
session_rec = Session(session_dir['session_dir'])
recording = session_rec.recordnodes[0].recordings[0]
timestamps = recording.continuous[0].timestamps
laser_onset_samples = np.searchsorted(timestamps, opto_df['time'].values)
opto_df['laser_onset_samples'] = laser_onset_samples

In [6]:
# recording info
sorting = si.load_extractor(session_dir[f'curated_dir_{data_type}'])
max_spikes_per_unit_spontaneous = 500
spike_vector = sorting.to_spike_vector()
unit_ids = sorting.unit_ids
num_units = len(sorting.unit_ids)
print(f"Total {len(sorting.unit_ids)} units")
max_spikes_per_unit_spontaneous = 500

Total 177 units


In [68]:
# Spike indices
spike_indices = spike_vector["sample_index"]

response_window = opto_info['resp_win']+0.005 # to take care of potential late responses
response_window_samples = int(response_window * sorting.sampling_frequency)

num_cases = len(opto_info['powers']) * len(opto_info['sites']) * len(opto_info['durations'])

unit_index_offset = num_units
all_unit_ids = [f"{u} spont_pre" for u in unit_ids]
all_unit_ids += [f"{u} spont_post" for u in unit_ids]

all_spikes_in_responses = []
spike_indices_removed = []
conditions = ['powers', 'sites', 'durations', 'pre_post']

for power_ind, curr_power in enumerate(opto_info['powers']):
    for site_ind, curr_site in enumerate(opto_info['sites']):  
        for curr_pre_post_ind, curr_pre_post in enumerate(opto_info['pre_post']):                                                           
            for duration_ind, curr_duration in enumerate(opto_info['durations']):
                spikes_in_response = []
                onset_offset_indices = []
                if len(opto_df.query('site == @curr_site and power == @curr_power and duration == @curr_duration and pre_post == @curr_pre_post')) > 0:
                    for freq_ind, curr_freq in enumerate(opto_info['freqs']):
                        for num_pulse_ind, curr_num_pulses in enumerate(opto_info['num_pulses']):
                            onset_samples = opto_df.query(
                                'site == @curr_site and power == @curr_power and duration == @curr_duration and freq == @curr_freq and num_pulses == @curr_num_pulses and pre_post == @curr_pre_post'
                                )['laser_onset_samples'].values
                            if len(onset_samples) == 0:
                                continue
                            pulse_offset = 1/curr_freq * sorting.sampling_frequency
                            pulse_offset_samples = int(pulse_offset * sorting.sampling_frequency)
                            for pulse_ind in range(curr_num_pulses):
                                for onset_sample in onset_samples:
                                    # response window
                                    onset_response = onset_sample + pulse_ind * pulse_offset_samples
                                    onset_offset_indices.append(onset_response)
                                    offset_response = onset_response + response_window_samples
                                    onset_offset_indices.append(offset_response)
                            start_stop_indices = np.searchsorted(spike_indices, np.array(onset_offset_indices))
                            for i, (start, stop) in enumerate(zip(start_stop_indices[::2], start_stop_indices[1::2])):
                                sv = spike_vector[start:stop]
                                if len(sv) > 0:
                                    spike_indices_removed.append(np.arange(start, stop))
                                sv_copy = sv.copy()
                                sv_copy["unit_index"] = sv_copy["unit_index"] + unit_index_offset + num_units
                                spikes_in_response.append(sv_copy)
                                # num_cases += 1
                    
                    spikes_in_response = np.concatenate(spikes_in_response)
                    print('# of responding units: ', len(np.unique(spikes_in_response['unit_index'])))
                    unit_index_offset += num_units
                    all_spikes_in_responses.append(spikes_in_response)
                    new_unit_ids = [f"{u} emission_location:{curr_site}, power:{curr_power}, duration:{curr_duration}, freq:{curr_freq}, pre_post:{curr_pre_post}" for u in unit_ids]
                    print(f"emission_location:{curr_site}, power:{curr_power}, duration:{curr_duration}, freq:{curr_freq}, pre_post:{curr_pre_post}")
                    all_unit_ids += new_unit_ids
                            



            
all_spikes_in_responses = np.concatenate(all_spikes_in_responses)
spike_indices_removed = np.concatenate(spike_indices_removed)
all_spikes_not_in_responses = np.delete(spike_vector, spike_indices_removed)
all_spikes_not_in_responses_pre = all_spikes_not_in_responses[all_spikes_not_in_responses["sample_index"] < (opto_df.query('pre_post == "pre"')['laser_onset_samples'].max() + 5 * 60 * sorting.sampling_frequency)]
all_spikes_not_in_responses_post = all_spikes_not_in_responses[all_spikes_not_in_responses["sample_index"] > (opto_df.query('pre_post == "post"')['laser_onset_samples'].min() - 5 * 60 * sorting.sampling_frequency)]
# select random spontaneous spikes for each unit divided by first and second half of the session
sorting_no_responses_pre = si.NumpySorting(
    all_spikes_not_in_responses_pre, 
    unit_ids=[f'{unit_id}_pre' for unit_id in unit_ids], 
    sampling_frequency=sorting.sampling_frequency
)
random_spike_indices = random_spikes_selection(sorting_no_responses_pre, method="uniform",
                                               max_spikes_per_unit=max_spikes_per_unit_spontaneous)
selected_spikes_no_responses_pre = all_spikes_not_in_responses_pre[random_spike_indices]



sorting_no_responses_post = si.NumpySorting(
    all_spikes_not_in_responses_post, 
    unit_ids=[f'{unit_id}_post' for unit_id in unit_ids], 
    sampling_frequency=sorting.sampling_frequency
)
random_spike_indices = random_spikes_selection(sorting_no_responses_post, method="uniform",
                                               max_spikes_per_unit=max_spikes_per_unit_spontaneous)
selected_spikes_no_responses_post = all_spikes_not_in_responses_post[random_spike_indices]



all_spikes = np.concatenate([selected_spikes_no_responses_pre, selected_spikes_no_responses_post, all_spikes_in_responses])

# sort by segment+index
all_spikes = all_spikes[
    np.lexsort((all_spikes["sample_index"], all_spikes["segment_index"]))
]

sorting_all = si.NumpySorting(
    all_spikes, 
    unit_ids=all_unit_ids, 
    sampling_frequency=sorting.sampling_frequency
)

# of responding units:  66
emission_location:surface_LC, power:10, duration:4, freq:5, pre_post:post
# of responding units:  88
emission_location:surface_LC, power:20, duration:4, freq:5, pre_post:post
# of responding units:  84
emission_location:surface_LC, power:30, duration:4, freq:5, pre_post:post
# of responding units:  116
emission_location:surface_LC, power:30, duration:4, freq:5, pre_post:pre
# of responding units:  84
emission_location:surface_LC, power:40, duration:4, freq:5, pre_post:post
# of responding units:  90
emission_location:surface_LC, power:50, duration:4, freq:5, pre_post:post


In [69]:
print("original:", len(spike_vector))
print("sampled:", len(sorting_all.to_spike_vector()))

original: 6634213
sampled: 130468


In [70]:
def load_and_preprocess_recording(session_folder, stream_name):
    ephys_path = os.path.dirname(session_folder)
    compressed_folder = os.path.join(ephys_path, 'ecephys_compressed')
    recording_zarr = [os.path.join(compressed_folder, f) for f in os.listdir(compressed_folder) if stream_name in f][0]
    recording = si.read_zarr(recording_zarr)
    # preprocess
    recording_processed = spre.phase_shift(recording)
    recording_processed = spre.highpass_filter(recording_processed)    
    recording_processed = spre.common_reference(recording_processed)
    return recording_processed

In [71]:
# filter good channels
recording_processed = load_and_preprocess_recording(session_dir['session_dir'], 'ProbeA')
we = si.load_sorting_analyzer_or_waveforms(session_dir[f'postprocessed_dir_{data_type}'])
good_channel_ids = recording_processed.channel_ids[
    np.in1d(recording_processed.channel_ids, we.channel_ids)
]

recording_processed_good = recording_processed.select_channels(good_channel_ids)
print(f"Num good channels: {recording_processed_good.get_num_channels()}")

Num good channels: 379


In [73]:
num_cases = len(sorting_all.unit_ids) // num_units - 2
num_cases

6

In [75]:
sparsity_mask_all = np.tile(we.sparsity.mask, (num_cases + 2, 1))
sparsity_all = si.ChannelSparsity(
    sparsity_mask_all,
    unit_ids=sorting_all.unit_ids,
    channel_ids=recording_processed_good.channel_ids
)
si.set_global_job_kwargs(n_jobs=-1, progress_bar=True)

In [76]:
si.set_global_job_kwargs(n_jobs=-1, progress_bar=True)

In [77]:
# create analyzer
analyzer_all = si.create_sorting_analyzer(
    # sorting_all.select_units(ROI_unit_ids),
    sorting_all,
    recording_processed_good,
    sparsity=sparsity_all
)

In [78]:
min_spikes_per_unit = 5
keep_unit_ids = []
count_spikes = sorting_all.count_num_spikes_per_unit()
for unit_id, count in count_spikes.items():
    if count >= min_spikes_per_unit:
        keep_unit_ids.append(unit_id)
print(keep_unit_ids)

analyzer = analyzer_all.select_units(keep_unit_ids)
print(f"Number of units with at least {min_spikes_per_unit} spikes: {len(analyzer.unit_ids)}")

['0 spont_pre', '1 spont_pre', '2 spont_pre', '3 spont_pre', '4 spont_pre', '5 spont_pre', '6 spont_pre', '7 spont_pre', '8 spont_pre', '9 spont_pre', '10 spont_pre', '15 spont_pre', '16 spont_pre', '17 spont_pre', '18 spont_pre', '19 spont_pre', '21 spont_pre', '22 spont_pre', '26 spont_pre', '27 spont_pre', '28 spont_pre', '29 spont_pre', '30 spont_pre', '31 spont_pre', '32 spont_pre', '33 spont_pre', '34 spont_pre', '35 spont_pre', '36 spont_pre', '37 spont_pre', '38 spont_pre', '39 spont_pre', '41 spont_pre', '42 spont_pre', '43 spont_pre', '44 spont_pre', '45 spont_pre', '46 spont_pre', '47 spont_pre', '48 spont_pre', '50 spont_pre', '51 spont_pre', '52 spont_pre', '56 spont_pre', '57 spont_pre', '58 spont_pre', '59 spont_pre', '60 spont_pre', '61 spont_pre', '62 spont_pre', '63 spont_pre', '64 spont_pre', '65 spont_pre', '66 spont_pre', '67 spont_pre', '68 spont_pre', '69 spont_pre', '70 spont_pre', '71 spont_pre', '72 spont_pre', '73 spont_pre', '74 spont_pre', '75 spont_pre', '

In [79]:
_ = analyzer.compute("random_spikes", method="all")
_ = analyzer.compute(["waveforms", "templates"])
waveform_zarr_folder = f'{session_dir[f"ephys_processed_dir_{data_type}"]}/opto_waveforms.zarr'
analyzer_saved_zarr = analyzer.save_as(format='zarr', folder = f'{session_dir[f"ephys_processed_dir_{data_type}"]}/opto_waveforms.zarr')

compute_waveforms:   0%|          | 0/6698 [00:00<?, ?it/s]

ValueError: Folder already exists /root/capsule/scratch/751004/behavior_751004_2024-12-23_14-20-03/ephys/curated/processed/opto_waveforms.zarr

In [80]:
analyzer_saved_zarr = analyzer.save_as(format='zarr', folder = f'{session_dir[f"ephys_processed_dir_{data_type}"]}/opto_waveforms.zarr')

In [7]:
load_sorting_analyzer = True
if load_sorting_analyzer:
    # Load
    # analyzer_loaded = si.load_sorting_analyzer(f'{results_folder}/analyzer_saved')
    waveform_zarr_folder = f'{session_dir[f"ephys_processed_dir_{data_type}"]}/opto_waveforms.zarr'
    analyzer = si.load_sorting_analyzer(waveform_zarr_folder)
    print(analyzer)

SortingAnalyzer: 379 channels - 451 units - 1 segments - zarr - sparse - has recording
Loaded 3 extensions: random_spikes, templates, waveforms


In [9]:
conditions = ['powers', 'sites', 'durations', 'pre_post']

In [10]:
columns = conditions + ['unit_id', 'template', 'peak_channel', 'peak_waveform',  "correlation_spont", 'euclidean_spont', 'pre_post']
waveform_metrics = pd.DataFrame(columns=columns)

template_ext = analyzer.get_extension("templates")
extreme_channel_indices = si.get_template_extremum_channel(analyzer, mode = "at_index", outputs = "index")
extreme_channels = si.get_template_extremum_channel(analyzer) 

for ind_id, unit_id in enumerate(analyzer.unit_ids):
    print(unit_id)
    unit_id_name, case = unit_id.split(" ", 1)
    # print(unit_id_name, case)
    unit_template = template_ext.get_unit_template(unit_id)       
    peak_waveform = unit_template[:,extreme_channel_indices[unit_id]]
    # plt.figure()
    # plt.plot(list(unit_waveform[:,extreme_channel_indices["14 spont"]]))
    # plt.show()

    new_row = dict(unit_id=unit_id_name, template=unit_template, peak_channel=extreme_channels[unit_id], peak_waveform=peak_waveform)    
    if "spont" in case:                
        new_row["spont"] = 1
        if "pre" in case:
            new_row["pre_post"] = "pre"
        else:
            new_row["pre_post"] = "post"
        new_row["power"] = np.nan
        new_row["site"] = np.nan
    else:
        split_strs = case.split(",")
        site_str, power_str, duration_str, duration_str, pre_post_stre = split_strs

        site = site_str.split(":")[1]
        power = power_str.split(":")[1]
        duration = duration_str.split(":")[1]
        # site = site_str.split(":")[1]
        new_row["spont"] = 0
        new_row["site"] = site
        new_row["power"] = power
        # new_row["site"] = power
    waveform_metrics = pd.concat([waveform_metrics, pd.DataFrame([new_row])])


0 spont_pre


InvalidIndexError: Reindexing only valid with uniquely valued Index objects

In [43]:
# calculate correlation and Euclidean distance, normalized by power
waveform_metrics['unit_id'].unique()
sparsity = analyzer.sparsity
all_channels = sparsity.channel_ids
# loop through all rows in the dataframe
for index, row in waveform_metrics.iterrows():
    # get the template and peak waveform for the current row
    if row['spont'] == 1:
        continue
    else:
        template = row['template'][0]
        peak_waveform = row['peak_waveform']
        peak_channel = row['peak_channel']
        unit_id = row['unit_id']

        spont_template = waveform_metrics.query(f"spont == 1 and unit_id == '{unit_id}'")['template'].values[0]
        spont_peak_waveform = waveform_metrics.query(f"spont == 1 and unit_id == '{unit_id}'")['peak_waveform'].values[0]
        spont_peak_channel = waveform_metrics.query(f"spont == 1 and unit_id == '{unit_id}'")['peak_channel'].values[0]
        peak_ind = np.argmin(np.min(spont_template, 0))
        peak_channel_new = all_channels[peak_ind]

         

        if spont_peak_channel != peak_channel:
            print(row['unit_id'], peak_channel, 'vs spout', spont_peak_channel)
            print(f'recomputated {peak_channel_new}')
        # print(row['unit_id'], spont_peak_channel)
        
        # # calculate the correlation of the peak waveform with all waveforms in the template
        # correlation = np.corrcoef(template.reshape(-1), peak_waveform.reshape(-1))[0, 1]
        
        # # calculate the Euclidean distance of the peak waveform from all waveforms in the template
        # euclidean_distance = np.linalg.norm(template.reshape(-1) - peak_waveform.reshape(-1))
        
        # # store the results in the dataframe
        # waveform_metrics.at[index, 'correlation_spont'] = correlation
        # waveform_metrics.at[index, 'euclidean_spont'] = euclidean_distance

3 CH18 vs spout CH16
recomputated CH16
21 CH62 vs spout CH61
recomputated CH61
31 CH76 vs spout CH74
recomputated CH74
37 CH90 vs spout CH88
recomputated CH88
41 CH91 vs spout CH89
recomputated CH89
91 CH310 vs spout CH308
recomputated CH308
98 CH325 vs spout CH323
recomputated CH323
100 CH328 vs spout CH326
recomputated CH326
118 CH353 vs spout CH351
recomputated CH351
119 CH357 vs spout CH353
recomputated CH353
121 CH359 vs spout CH357
recomputated CH357
127 CH371 vs spout CH369
recomputated CH369
129 CH375 vs spout CH371
recomputated CH371
145 CH59 vs spout CH61
recomputated CH61
146 CH65 vs spout CH63
recomputated CH63
194 CH69 vs spout CH67
recomputated CH67
195 CH117 vs spout CH115
recomputated CH115
196 CH93 vs spout CH98
recomputated CH98
3 CH18 vs spout CH16
recomputated CH16
7 CH38 vs spout CH36
recomputated CH36
17 CH57 vs spout CH56
recomputated CH56
19 CH51 vs spout CH59
recomputated CH59
21 CH62 vs spout CH61
recomputated CH61
26 CH69 vs spout CH65
recomputated CH65
28 CH

In [None]:
waveform_metrics.to_csv(f"{session_dir[f'ephys_processed_dir_{data_type}']}/{session_id}_opto_waveform_metrics.csv", index=False)
print(f"Saved waveform metrics to {session_dir[f'ephys_processed_dir_{data_type}']}/{session_id}_opto_waveform_metrics.csv")

In [39]:
test = '3'
waveform_metrics.query(f"unit_id == '{test}'")


Unnamed: 0,unit_id,spont,power,site,template,peak_channel,peak_waveform
0,3,1,,,"[[0.0, 0.0, 0.0, 0.0, 0.04094940051436424, -0....",CH16,"[3.6071085929870605, 3.5766892433166504, 5.173..."
0,3,0,10.0,surface_LC,"[[0.0, 0.0, 0.0, 0.0, -5.637272357940674, -4.9...",CH18,"[-13.587952613830566, -5.451136112213135, -0.4..."
0,3,0,20.0,surface_LC,"[[0.0, 0.0, 0.0, 0.0, -5.191874980926514, -9.6...",CH18,"[0.7800002098083496, -4.899374961853027, -3.31..."
0,3,0,30.0,surface_LC,"[[0.0, 0.0, 0.0, 0.0, 0.9832974672317505, -6.6...",CH16,"[3.970531702041626, 2.0661702156066895, 3.9331..."
0,3,0,40.0,surface_LC,"[[0.0, 0.0, 0.0, 0.0, 6.090882301330566, -11.2...",CH18,"[8.293235778808594, 2.6497058868408203, -3.131..."
0,3,0,50.0,surface_LC,"[[0.0, 0.0, 0.0, 0.0, -1.249772071838379, -2.3...",CH18,"[3.5099997520446777, 8.37613582611084, -3.2440..."
