In [56]:
import os
import sys
sys.path.append('/root/capsule/code/beh_ephys_analysis')
import pandas as pd
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patches as patches
import json
from harp.clock import decode_harp_clock, align_timestamps_to_anchor_points
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
import re
from utils.beh_functions import parseSessionID, session_dirs
from utils.plot_utils import shiftedColorMap, template_reorder
from open_ephys.analysis import Session##
import spikeinterface as si
import spikeinterface.extractors as se
import spikeinterface.postprocessing as spost
import spikeinterface.widgets as sw
from aind_dynamic_foraging_basic_analysis.licks.lick_analysis import load_nwb
from aind_ephys_utils import align
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import colormaps
from aind_dynamic_foraging_data_utils.nwb_utils import load_nwb_from_filename
from spikeinterface.core.sorting_tools import random_spikes_selection
import pickle
import datetime
import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.postprocessing as spost
from tqdm import tqdm

In [2]:
session = 'behavior_751004_2024-12-23_14-20-03'
data_type = 'curated'
target = 'soma'

In [3]:
# opto info
session_dir = session_dirs(session)
opto_df = pd.read_csv(os.path.join(session_dir[f'opto_dir_{data_type}'], f'{session}_opto_session_{target}.csv'), index_col=0)
with open(os.path.join(session_dir[f'opto_dir_{data_type}'], f'{session}_opto_info_{target}.json')) as f:
    opto_info = json.load(f)
powers = opto_info['powers']
sites = opto_info['sites']
num_pulses = opto_info['num_pulses']
duration = opto_info['durations']
pulse_offset = np.array([1/freq for freq in opto_info['freqs']])
total_duration = (pulse_offset) * num_pulses
session_rec = Session(session_dir['session_dir'])
recording = session_rec.recordnodes[0].recordings[0]
timestamps = recording.continuous[0].timestamps
laser_onset_samples = np.searchsorted(timestamps, opto_df['time'].values)
opto_df['laser_onset_samples'] = laser_onset_samples

In [4]:
# recording info
sorting = si.load_extractor(session_dir[f'curated_dir_{data_type}'])
max_spikes_per_unit_spontaneous = 500
spike_vector = sorting.to_spike_vector()
unit_ids = sorting.unit_ids
num_units = len(sorting.unit_ids)
print(f"Total {len(sorting.unit_ids)} units")
max_spikes_per_unit_spontaneous = 500

Total 177 units


In [61]:
# Spike indices
spike_indices = spike_vector["sample_index"]

response_window = opto_info['resp_win']+0.005 # to take care of potential late responses
response_window_samples = int(response_window * sorting.sampling_frequency)

num_cases = len(opto_info['powers']) * len(opto_info['sites']) * len(opto_info['durations'])

unit_index_offset = num_units
all_unit_ids = [f"{u} spont" for u in unit_ids]

all_spikes_in_responses = []
spike_indices_removed = []

for power_ind, curr_power in enumerate(opto_info['powers']):
    for site_ind, curr_site in enumerate(opto_info['sites']):                                                             
        for duration_ind, curr_duration in enumerate(opto_info['durations']):
            spikes_in_response = []
            onset_offset_indices = []
            for freq_ind, curr_freq in enumerate(opto_info['freqs']):
                for num_pulse_ind, curr_num_pulses in enumerate(opto_info['num_pulses']):
                    onset_samples = opto_df.query('site == @curr_site and power == @curr_power and duration == @curr_duration and freq == @curr_freq and num_pulses == @curr_num_pulses')['laser_onset_samples'].values
                    pulse_offset = 1/curr_freq * sorting.sampling_frequency
                    pulse_offset_samples = int(pulse_offset * sorting.sampling_frequency)
                    for pulse_ind in range(curr_num_pulses):
                        for onset_sample in onset_samples:
                            # response window
                            onset_response = onset_sample + pulse_ind * pulse_offset_samples
                            onset_offset_indices.append(onset_response)
                            offset_response = onset_response + response_window_samples
                            onset_offset_indices.append(offset_response)
                start_stop_indices = np.searchsorted(spike_indices, np.array(onset_offset_indices))
                for i, (start, stop) in enumerate(zip(start_stop_indices[::2], start_stop_indices[1::2])):
                    sv = spike_vector[start:stop]
                    if len(sv) > 0:
                        spike_indices_removed.append(np.arange(start, stop))
                    sv_copy = sv.copy()
                    sv_copy["unit_index"] = sv_copy["unit_index"] + unit_index_offset
                    spikes_in_response.append(sv_copy)
                    # num_cases += 1
                    
            spikes_in_response = np.concatenate(spikes_in_response)
            print('# of responding units: ', len(np.unique(spikes_in_response['unit_index'])))
            unit_index_offset += num_units
            all_spikes_in_responses.append(spikes_in_response)
            new_unit_ids = [f"{u} emission_location:{curr_site}, power:{curr_power}, duration:{curr_duration}, freq:{curr_freq}" for u in unit_ids]
            print(f"emission_location:{curr_site}, power:{curr_power} , duration: {curr_duration}, freq: {curr_freq} #spikes: {len(spikes_in_response)}")
            all_unit_ids += new_unit_ids
                            



            
all_spikes_in_responses = np.concatenate(all_spikes_in_responses)
spike_indices_removed = np.concatenate(spike_indices_removed)

# select random spontaneous spikes
all_spikes_not_in_responses = np.delete(spike_vector, spike_indices_removed)
sorting_no_responses = si.NumpySorting(
    all_spikes_not_in_responses, 
    unit_ids=unit_ids, 
    sampling_frequency=sorting.sampling_frequency
)
random_spike_indices = random_spikes_selection(sorting_no_responses, method="uniform",
                                               max_spikes_per_unit=max_spikes_per_unit_spontaneous)
selected_spikes_no_responses = all_spikes_not_in_responses[random_spike_indices]

all_spikes = np.concatenate([selected_spikes_no_responses, all_spikes_in_responses])

# sort by segment+index
all_spikes = all_spikes[
    np.lexsort((all_spikes["sample_index"], all_spikes["segment_index"]))
]

sorting_all = si.NumpySorting(
    all_spikes, 
    unit_ids=all_unit_ids, 
    sampling_frequency=sorting.sampling_frequency
)

# of responding units:  66
emission_location:surface_LC, power:10 , duration: 4, freq: 5 #spikes: 488
# of responding units:  88
emission_location:surface_LC, power:20 , duration: 4, freq: 5 #spikes: 926
# of responding units:  131
emission_location:surface_LC, power:30 , duration: 4, freq: 5 #spikes: 2772
# of responding units:  84
emission_location:surface_LC, power:40 , duration: 4, freq: 5 #spikes: 961
# of responding units:  90
emission_location:surface_LC, power:50 , duration: 4, freq: 5 #spikes: 976


In [62]:
print("original:", len(spike_vector))
print("sampled:", len(sorting_all.to_spike_vector()))

original: 6634213
sampled: 82055


In [63]:
def load_and_preprocess_recording(session_folder, stream_name):
    ephys_path = os.path.dirname(session_folder)
    compressed_folder = os.path.join(ephys_path, 'ecephys_compressed')
    recording_zarr = [os.path.join(compressed_folder, f) for f in os.listdir(compressed_folder) if stream_name in f][0]
    recording = si.read_zarr(recording_zarr)
    # preprocess
    recording_processed = spre.phase_shift(recording)
    recording_processed = spre.highpass_filter(recording_processed)    
    recording_processed = spre.common_reference(recording_processed)
    return recording_processed

In [32]:
recording_processed = load_and_preprocess_recording(session_dir['session_dir'], 'ProbeA')

In [66]:
we = si.load_sorting_analyzer_or_waveforms(session_dir[f'postprocessed_dir_{data_type}'])

In [67]:
# filter good channels
good_channel_ids = recording_processed.channel_ids[
    np.in1d(recording_processed.channel_ids, we.channel_ids)
]

recording_processed_good = recording_processed.select_channels(good_channel_ids)
print(f"Num good channels: {recording_processed_good.get_num_channels()}")

Num good channels: 379


In [68]:
num_cases = len(sorting_all.unit_ids) // num_units - 1

In [70]:
sparsity_mask_all = np.tile(we.sparsity.mask, (num_cases + 1, 1))
sparsity_all = si.ChannelSparsity(
    sparsity_mask_all,
    unit_ids=sorting_all.unit_ids,
    channel_ids=recording_processed_good.channel_ids
)
si.set_global_job_kwargs(n_jobs=-1, progress_bar=True)

In [71]:
si.set_global_job_kwargs(n_jobs=-1, progress_bar=True)

In [72]:
# create analyzer
analyzer_all = si.create_sorting_analyzer(
    # sorting_all.select_units(ROI_unit_ids),
    sorting_all,
    recording_processed_good,
    sparsity=sparsity_all
)

In [73]:
min_spikes_per_unit = 5
keep_unit_ids = []
count_spikes = sorting_all.count_num_spikes_per_unit()
for unit_id, count in count_spikes.items():
    if count >= min_spikes_per_unit:
        keep_unit_ids.append(unit_id)
print(keep_unit_ids)

analyzer = analyzer_all.select_units(keep_unit_ids)
print(f"Number of units with at least {min_spikes_per_unit} spikes: {len(analyzer.unit_ids)}")

['0 spont', '1 spont', '2 spont', '3 spont', '4 spont', '5 spont', '6 spont', '7 spont', '8 spont', '9 spont', '10 spont', '15 spont', '16 spont', '17 spont', '18 spont', '19 spont', '21 spont', '22 spont', '26 spont', '27 spont', '28 spont', '29 spont', '30 spont', '31 spont', '32 spont', '33 spont', '34 spont', '35 spont', '36 spont', '37 spont', '38 spont', '39 spont', '41 spont', '42 spont', '43 spont', '44 spont', '45 spont', '46 spont', '47 spont', '48 spont', '50 spont', '51 spont', '52 spont', '56 spont', '57 spont', '58 spont', '59 spont', '60 spont', '61 spont', '62 spont', '63 spont', '64 spont', '65 spont', '66 spont', '67 spont', '68 spont', '69 spont', '70 spont', '71 spont', '72 spont', '73 spont', '74 spont', '75 spont', '76 spont', '77 spont', '78 spont', '79 spont', '80 spont', '81 spont', '82 spont', '83 spont', '84 spont', '85 spont', '86 spont', '87 spont', '88 spont', '89 spont', '90 spont', '91 spont', '92 spont', '93 spont', '94 spont', '95 spont', '96 spont', '

In [74]:
_ = analyzer.compute("random_spikes", method="all")
_ = analyzer.compute(["waveforms", "templates"])

compute_waveforms:   0%|          | 0/6698 [00:00<?, ?it/s]

In [76]:
analyzer_saved_zarr = analyzer.save_as(format='zarr', folder = f'{session_dir[f"ephys_processed_dir_{data_type}"]}/opto_waveforms.zarr')

# load_sorting_analyzer = False
# if load_sorting_analyzer:
#     # Load
#     # analyzer_loaded = si.load_sorting_analyzer(f'{results_folder}/analyzer_saved')
#     analyzer_loaded_zarr = si.load_sorting_analyzer(f'{results_folder}/analyzer_saved.zarr')
#     print(analyzer_loaded_zarr)

In [60]:
columns=["unit_id", "spont", "power", "site", "template", "peak_channel", "peak_waveform"]
waveform_metrics = pd.DataFrame(columns=columns)

template_ext = analyzer.get_extension("templates")
extreme_channel_indices = si.get_template_extremum_channel(analyzer, mode = "at_index", outputs = "index")
extreme_channels = si.get_template_extremum_channel(analyzer) 

for ind_id, unit_id in enumerate(analyzer.unit_ids):
    print(unit_id)
    unit_id_name, case = unit_id.split(" ", 1)
    # print(unit_id_name, case)
    unit_template = template_ext.get_unit_template(unit_id)       
    peak_waveform = unit_template[:,extreme_channel_indices[unit_id]]
    # plt.figure()
    # plt.plot(list(unit_waveform[:,extreme_channel_indices["14 spont"]]))
    # plt.show()

    new_row = dict(unit_id=unit_id_name, template=unit_template, peak_channel=extreme_channels[unit_id], peak_waveform=peak_waveform)    
    if "spont" in case:                
        new_row["spont"] = 1
        new_row["power"] = np.nan
        new_row["site"] = np.nan
    else:
        split_strs = case.split(",")
        site_str, power_str, duration_str, freq_str = split_strs

        site = site_str.split(":")[1]
        power = power_str.split(":")[1]
        duration = duration_str.split(":")[1]
        # site = site_str.split(":")[1]
        new_row["spont"] = 0
        new_row["site"] = site
        new_row["power"] = power
        # new_row["site"] = power
    waveform_metrics = pd.concat([waveform_metrics, pd.DataFrame([new_row])])


0 spont
1 spont
2 spont
0 spont
196 emission_location:surface_LC, power:50, duration:4, freq:5


ValueError: too many values to unpack (expected 3)

In [None]:
# calculate correlation and Euclidean distance, normalized by power


In [None]:


waveform_metrics.to_csv(f"{session_dir[f'ephys_processed_dir_{data_type}']}/{session_id}_opto_waveform_metrics.csv", index=False)
print(f"Saved waveform metrics to {session_dir[f'ephys_processed_dir_{data_type}']}/{session_id}_opto_waveform_metrics.csv")