# Setup

In [None]:
"""Import necessary packages"""

import sys, struct, math, os, time
import numpy as np
from ismember import ismember
import umap
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import zscore, ttest_rel, pearsonr
from scipy import signal, special

sys.path.append(os.path.dirname(os.getcwd()) + '/util/')
sys.path.append('./util/')
from intanutil.read_header import read_header
from intanutil.get_bytes_per_data_block import get_bytes_per_data_block
from intanutil.read_one_data_block import read_one_data_block
from intanutil.notch_filter import notch_filter
from intanutil.data_to_result import data_to_result

import spikeinterface
spikeinterface.__version__
import spikeinterface
import spikeinterface.extractors as se
import spikeinterface.toolkit as st
import spikeinterface.sorters as ss
import spikeinterface.comparison as sc
import spikeinterface.widgets as sw
import matplotlib.pyplot as plt
from spikeinterface import WaveformExtractor
import shutil
import seaborn as sns
import pandas as pd
import seaborn as sns
from scipy.io import loadmat
from spikeinterface.core.npzsortingextractor import NpzSortingExtractor
from pylab import *
ss.Kilosort3Sorter.set_kilosort3_path('/kilosort3')
import pylab
from sklearn.neighbors import LocalOutlierFactor
from pprint import pprint

from probeinterface.utils import combine_probes
from probeinterface import generate_multi_columns_probe, Probe
from probeinterface.plotting import plot_probe

import copy
from util.brpylib import NsxFile, brpylib_ver
from calculate_features import features_5

from CurationTool import shank_autocorrelogram_show, shank_correlogram_show
from spikeinterface.sorters import WaveClusSorter, IronClustSorter, Kilosort3Sorter

In [None]:
def read_data(filename):
    """Reads Intan Technologies RHD2000 data file generated by evaluation board GUI.

    Data are returned in a dictionary, for future extensibility.

    Parameters
    ----------
    filename : str
        Path to the ``.rhd`` file.

    Returns
    -------
    dict
        The output of ``data_to_result(header, data, data_present)``. Callers in
        this notebook read ``'amplifier_data'`` (scaled to microvolts) and
        ``'frequency_parameters'`` from it.

    Raises
    ------
    Exception
        If the bytes after the header are not a whole number of data blocks, or
        if the end of the file is not reached after reading every block.
    """

    tic = time.time()
    filesize = os.path.getsize(filename)

    # The context manager closes the file even if header parsing or block
    # reading raises part-way through (the handle used to leak in that case).
    with open(filename, 'rb') as fid:
        header = read_header(fid)

        # (header key, label) pairs for the channel-count summary.
        channel_counts = (
            ('num_amplifier_channels', 'amplifier channel'),
            ('num_aux_input_channels', 'auxiliary input channel'),
            ('num_supply_voltage_channels', 'supply voltage channel'),
            ('num_board_adc_channels', 'board ADC channel'),
            ('num_board_dig_in_channels', 'board digital input channel'),
            ('num_board_dig_out_channels', 'board digital output channel'),
            ('num_temp_sensor_channels', 'temperature sensors channel'),
        )
        for key, label in channel_counts:
            print('Found {} {}{}.'.format(header[key], label, plural(header[key])))
        print('')

        # Determine how many samples the data file contains.
        bytes_per_block = get_bytes_per_data_block(header)

        # How many data blocks remain in this file?
        bytes_remaining = filesize - fid.tell()
        data_present = bytes_remaining > 0

        if bytes_remaining % bytes_per_block != 0:
            raise Exception('Something is wrong with file size : should have a whole number of data blocks')

        num_data_blocks = int(bytes_remaining / bytes_per_block)

        samples_per_block = header['num_samples_per_data_block']
        # Aux inputs get one sample per 4 amplifier samples; supply voltage (and
        # temperature) get one sample per data block.
        num_amplifier_samples = samples_per_block * num_data_blocks
        num_aux_input_samples = int((samples_per_block / 4) * num_data_blocks)
        num_supply_voltage_samples = 1 * num_data_blocks
        num_board_adc_samples = samples_per_block * num_data_blocks
        num_board_dig_in_samples = samples_per_block * num_data_blocks
        num_board_dig_out_samples = samples_per_block * num_data_blocks

        record_time = num_amplifier_samples / header['sample_rate']

        if data_present:
            print('File contains {:0.3f} seconds of data.  Amplifiers were sampled at {:0.2f} kS/s.'.format(record_time, header['sample_rate'] / 1000))
        else:
            print('Header file contains no data.  Amplifiers were sampled at {:0.2f} kS/s.'.format(header['sample_rate'] / 1000))

        if data_present:
            # Pre-allocate memory for data.
            print('')
            print('Allocating memory for data...')

            data = {}
            version = header['version']
            # np.int / np.bool were removed in NumPy 1.24; they were aliases of
            # the builtins int / bool, which are used instead.
            if (version['major'] == 1 and version['minor'] >= 2) or (version['major'] > 1):
                data['t_amplifier'] = np.zeros(num_amplifier_samples, dtype=int)
            else:
                data['t_amplifier'] = np.zeros(num_amplifier_samples, dtype=np.uint)

            data['amplifier_data'] = np.zeros([header['num_amplifier_channels'], num_amplifier_samples], dtype=np.uint)
            data['aux_input_data'] = np.zeros([header['num_aux_input_channels'], num_aux_input_samples], dtype=np.uint)
            data['supply_voltage_data'] = np.zeros([header['num_supply_voltage_channels'], num_supply_voltage_samples], dtype=np.uint)
            data['temp_sensor_data'] = np.zeros([header['num_temp_sensor_channels'], num_supply_voltage_samples], dtype=np.uint)
            data['board_adc_data'] = np.zeros([header['num_board_adc_channels'], num_board_adc_samples], dtype=np.uint)

            # Digital inputs/outputs are decoded into one boolean row per channel;
            # the *_raw arrays hold the bit-packed words that are decoded later.
            # Use dtype=np.uint instead of bool to get 0/1 integers.
            data['board_dig_in_data'] = np.zeros([header['num_board_dig_in_channels'], num_board_dig_in_samples], dtype=bool)
            data['board_dig_in_raw'] = np.zeros(num_board_dig_in_samples, dtype=np.uint)

            data['board_dig_out_data'] = np.zeros([header['num_board_dig_out_channels'], num_board_dig_out_samples], dtype=bool)
            data['board_dig_out_raw'] = np.zeros(num_board_dig_out_samples, dtype=np.uint)

            # Read sampled data from file.
            print('Reading data from file...')

            # Write positions into each preallocated array, advanced per block.
            indices = {
                'amplifier': 0,
                'aux_input': 0,
                'supply_voltage': 0,
                'board_adc': 0,
                'board_dig_in': 0,
                'board_dig_out': 0,
            }

            print_increment = 10
            percent_done = print_increment
            for i in range(num_data_blocks):
                read_one_data_block(data, header, indices, fid)

                # Increment indices
                indices['amplifier'] += samples_per_block
                indices['aux_input'] += int(samples_per_block / 4)
                indices['supply_voltage'] += 1
                indices['board_adc'] += samples_per_block
                indices['board_dig_in'] += samples_per_block
                indices['board_dig_out'] += samples_per_block

                fraction_done = 100 * (1.0 * i / num_data_blocks)
                if fraction_done >= percent_done:
                    print('{}% done...'.format(percent_done))
                    percent_done = percent_done + print_increment

            # Make sure we have read exactly the right amount of data.
            bytes_remaining = filesize - fid.tell()
            if bytes_remaining != 0:
                raise Exception('Error: End of file not reached.')

    if data_present:
        print('Parsing data...')

        # Extract digital input channels to separate variables.
        for i in range(header['num_board_dig_in_channels']):
            data['board_dig_in_data'][i, :] = np.not_equal(np.bitwise_and(data['board_dig_in_raw'], (1 << header['board_dig_in_channels'][i]['native_order'])), 0)

        # Extract digital output channels to separate variables.
        for i in range(header['num_board_dig_out_channels']):
            data['board_dig_out_data'][i, :] = np.not_equal(np.bitwise_and(data['board_dig_out_raw'], (1 << header['board_dig_out_channels'][i]['native_order'])), 0)

        # Scale voltage levels appropriately.
        data['amplifier_data'] = np.multiply(0.195, (data['amplifier_data'].astype(np.int32) - 32768))      # units = microvolts
        data['aux_input_data'] = np.multiply(37.4e-6, data['aux_input_data'])               # units = volts
        data['supply_voltage_data'] = np.multiply(74.8e-6, data['supply_voltage_data'])     # units = volts
        if header['eval_board_mode'] == 1:
            data['board_adc_data'] = np.multiply(152.59e-6, (data['board_adc_data'].astype(np.int32) - 32768)) # units = volts
        elif header['eval_board_mode'] == 13:
            data['board_adc_data'] = np.multiply(312.5e-6, (data['board_adc_data'].astype(np.int32) - 32768)) # units = volts
        else:
            data['board_adc_data'] = np.multiply(50.354e-6, data['board_adc_data'])           # units = volts
        data['temp_sensor_data'] = np.multiply(0.01, data['temp_sensor_data'])               # units = deg C

        # Check for gaps in timestamps.
        num_gaps = np.sum(np.not_equal(data['t_amplifier'][1:]-data['t_amplifier'][:-1], 1))
        if num_gaps == 0:
            print('No missing timestamps in data.')
        else:
            print('Warning: {0} gaps in timestamp data found.  Time scale will not be uniform!'.format(num_gaps))

        # Scale time steps (units = seconds).
        data['t_amplifier'] = data['t_amplifier'] / header['sample_rate']
        data['t_aux_input'] = data['t_amplifier'][range(0, len(data['t_amplifier']), 4)]
        data['t_supply_voltage'] = data['t_amplifier'][range(0, len(data['t_amplifier']), header['num_samples_per_data_block'])]
        data['t_board_adc'] = data['t_amplifier']
        data['t_dig'] = data['t_amplifier']
        data['t_temp_sensor'] = data['t_supply_voltage']

        # If the software notch filter was selected during the recording, apply the
        # same notch filter to amplifier data here.
        if header['notch_filter_frequency'] > 0 and header['version']['major'] < 3:
            print('Applying notch filter...')

            print_increment = 10
            percent_done = print_increment
            for i in range(header['num_amplifier_channels']):
                data['amplifier_data'][i,:] = notch_filter(data['amplifier_data'][i,:], header['sample_rate'], header['notch_filter_frequency'], 10)

                fraction_done = 100 * (i / header['num_amplifier_channels'])
                if fraction_done >= percent_done:
                    print('{}% done...'.format(percent_done))
                    percent_done += print_increment
    else:
        data = []

    # Move variables to result struct.
    result = data_to_result(header, data, data_present)

    print('Done!  Elapsed time: {0:0.1f} seconds'.format(time.time() - tic))
    return result

def plural(n):
    """Return the English plural suffix for a count.

    Gives ``''`` when ``n`` equals 1 and ``'s'`` for any other value, so a
    caller can write ``'channel' + plural(n)``.
    """
    return '' if n == 1 else 's'

def sorting_day_split(sorting, date_id_all, day_length, pack_folder, sorting_save_name='firings_inlier',
                      waveform_folder=None, min_spikes=20):
    """Plot a raster of a multi-day sorting and save one sub-sorting per day.

    The sorting is taken to cover several recording days concatenated in time,
    with day ``i`` occupying ``day_length[i]`` consecutive frames.

    First, a raster of all units is drawn with each day shaded in its own
    colour and saved as ``rasters.pdf`` in ``waveform_folder``. Then, for every
    day, the matching frame range is sliced out of ``sorting``. Units with
    ``min_spikes`` or fewer spikes in that range are dropped, and the result is
    written to ``<pack_folder>/<date_id>/sorting/<sorting_save_name>.npz``.

    Parameters
    ----------
    sorting : spikeinterface sorting extractor
        Sorting of the concatenated recording.
    date_id_all : sequence of str
        Folder name for each day; needs at least ``len(day_length)`` entries.
    day_length : sequence of int
        Number of frames recorded on each day, in concatenation order.
    pack_folder : str
        Output root; per-day folders are created beneath it.
    sorting_save_name : str
        File stem of each saved per-day sorting.
    waveform_folder : str, optional
        Folder for ``rasters.pdf``. Defaults to ``pack_folder + 'waveforms/'``,
        the path this function previously took from a notebook global.
    min_spikes : int
        A unit is kept for a day only if it has strictly more spikes than this.
    """
    sampling_freq = sorting.get_sampling_frequency()
    if waveform_folder is None:
        waveform_folder = pack_folder + 'waveforms/'

    # Frame offsets of the day boundaries: day i spans [day_edges[i], day_edges[i + 1]).
    day_edges = np.concatenate(([0], np.cumsum(day_length))).astype(int)
    num_days = len(day_length)

    fig, ax = plt.subplots(1, 1, figsize=(20, 40))
    # time_range is in seconds (the x axis of the raster, as in the axvspan
    # calls below), so convert the total frame count before passing it.
    sw.plot_rasters(sorting, time_range=(0, day_edges[-1] / sampling_freq), ax=ax)

    cmap = plt.get_cmap('rainbow')
    for i in range(num_days):
        ax.axvspan(day_edges[i] / sampling_freq, day_edges[i + 1] / sampling_freq,
                   facecolor=cmap(1. * i / num_days), alpha=0.1)

    os.makedirs(waveform_folder, exist_ok=True)
    plt.savefig(waveform_folder + '/rasters.pdf', dpi=300)
    plt.show()

    for i in range(num_days):
        pack_folder_i = pack_folder + '/' + date_id_all[i] + '/'
        # makedirs also creates pack_folder_i itself when it is missing.
        os.makedirs(pack_folder_i + '/sorting/', exist_ok=True)

        sub_sorting = sorting.frame_slice(int(day_edges[i]), int(day_edges[i + 1]))

        # Drop units that barely fire on this day.
        keep_unit_ids = [unit_id for unit_id in sub_sorting.unit_ids
                         if sub_sorting.get_unit_spike_train(unit_id=unit_id).size > min_spikes]
        curated_sub_sorting = sub_sorting.select_units(unit_ids=keep_unit_ids, renamed_unit_ids=None)

        save_path = pack_folder_i + '/sorting/' + sorting_save_name + '.npz'
        NpzSortingExtractor.write_sorting(curated_sub_sorting, save_path)

def create_mesh_probe(n):
    """Build the 2-D mesh-electrode ``Probe`` used throughout this notebook.

    The probe has 32 circular contacts (radius 5, in um) on a 4 x 8 grid with
    80 um pitch: x in {0, 80, 160, 240}, y in {0, 80, ..., 560}. Contact k sits
    at ``positions[k]`` and is wired to device channel k.

    Parameters
    ----------
    n : int
        Number of device channels to wire. It should equal the number of
        contacts (32); other values are presumably rejected by
        probeinterface's ``set_device_channel_indices``.

    Returns
    -------
    probeinterface.Probe
        The probe, annotated with ``first_index=0``.
    """
    # Contact positions for the "distribution ch 16-48" ordering.
    # An alternative ordering used earlier was:
    #   [[80,0],[80,80],[80,160],[80,240],[0,0],[0,80],[0,160],[0,240],
    #    [0,320],[0,400],[0,480],[0,560],[80,320],[80,400],[80,480],[80,560],
    #    [160,560],[160,480],[160,400],[160,320],[240,560],[240,480],[240,400],
    #    [240,320],[240,240],[240,160],[240,80],[240,0],[160,240],[160,160],
    #    [160,80],[160,0]]
    positions = np.array([[160,0],[80,0],[160,80],[80,80],[160,160],
                          [80,160],[160,240],[80,240],[160,320],[80,320],
                          [160,400],[80,400],[160,480],[80,480],[160,560],
                          [80,560],[240,0],[0,0],[240,80],[0,80],
                          [240,160],[0,160],[240,240],[0,240],[240,320],
                          [0,320],[240,400],[0,400],[240,480],[0,480],
                          [240,560],[0,560]])

    mesh_probe = Probe(ndim=2, si_units='um')
    mesh_probe.set_contacts(positions=positions, shapes='circle', shape_params={'radius': 5})
    mesh_probe.annotate(first_index=0)
    # Identity wiring: contact k -> device channel k.
    mesh_probe.set_device_channel_indices(list(range(n)))
    return mesh_probe

def stack_recordings(pack_folder_pre, mesh_probe, trigger_val=3.5):
    """Load every Intan ``.rhd`` file under a folder and stack them in time.

    Files are found recursively with ``os.walk`` and visited in sorted
    directory/file-name order. Each is read with ``read_data``, and the
    amplifier data (transposed to samples x channels) are concatenated along
    the sample axis.

    Parameters
    ----------
    pack_folder_pre : str
        Root folder searched for files whose name ends in ``.rhd``.
    mesh_probe : probeinterface.Probe
        Probe attached to the returned recording.
    trigger_val : float
        Not used; kept so existing calls keep working.

    Returns
    -------
    recording : se.NumpyRecording
        Recording of the stacked traces, with ``mesh_probe`` set.
    cont_data_all : np.ndarray
        The stacked traces, shape (total_samples, num_channels).
        NOTE(review): the notebook stores this second value as
        ``session_length.npy``, but it is the full data array, not a length;
        confirm that is intended.

    Raises
    ------
    ValueError
        If no ``.rhd`` file is found, or if files disagree on the amplifier
        sampling rate.
    """
    cont_data_all = []
    sampling_freq = None

    for dirpath, dirnames, filenames in os.walk(pack_folder_pre):
        # os.walk yields entries in arbitrary filesystem order. Sorting makes
        # the concatenation order deterministic; sorting dirnames in place
        # also fixes the order in which os.walk descends.
        dirnames.sort()
        for filename in sorted(filenames):
            if not filename.endswith('.rhd'):
                continue
            file_path = dirpath + '/' + filename
            print(file_path)
            raw_data = read_data(file_path)

            file_freq = raw_data['frequency_parameters']['amplifier_sample_rate']
            if sampling_freq is None:
                sampling_freq = file_freq
            elif file_freq != sampling_freq:
                # Stacking files with different rates would silently corrupt time.
                raise ValueError(f'{file_path} was sampled at {file_freq} Hz, '
                                 f'expected {sampling_freq} Hz like the preceding files')

            cont_data_all.append(raw_data['amplifier_data'].T)

    if not cont_data_all:
        raise ValueError(f'No .rhd files found under {pack_folder_pre}')

    cont_data_all = np.vstack(cont_data_all)
    recording = se.NumpyRecording(traces_list=cont_data_all, sampling_frequency=sampling_freq)
    recording.set_probe(mesh_probe, in_place=True)

    return recording, cont_data_all

In [None]:
# Build the 32-channel mesh probe used by later cells and show it twice:
# labelled with contact (channel) indices, then with device channel indices.
mesh_probe = create_mesh_probe(32)
plot_probe(mesh_probe, with_channel_index=True)
plot_probe(mesh_probe, with_device_index=True)

# Conduct Spike Sorting

In [None]:
"""Set parameters"""

date_id_all = ['043023_axolotl_B_1', '050123_axolotl_B_1', '050223_axolotl_B_1', '050323_axolotl_B_1', '050423_axolotl_B_1']

recording_traces = []
session_length_concat = []
day_length = []
cont_trigger_all_all = []
save_folder_name = '_'.join(date_id_all)
data_folder_all = f'./processed_data/Ephys_concat_{save_folder_name}/'

sorting_method="mountainsort"

sorting_save_path = data_folder_all + sorting_method + '/'
pack_folder = sorting_save_path

output_folder = sorting_save_path + '/sorting'
firing_save_path = output_folder + f'/firings.npz'

freq_max = 3000
freq_min = 300
fs = 10000
default_params = {
        'detect_sign': -1,  # Use -1, 0, or 1, depending on the sign of the spikes in the recording
        'adjacency_radius': -1,  # Use -1 to include all channels in every neighborhood
        'freq_min': 300,  # Use None for no bandpass filtering
        'freq_max': 3000,
        'filter': True,
        'whiten': True,  # Whether to do channel whitening as part of preprocessing
        # 'curation': False,
        # 'num_workers': None,
        'num_workers': 9,
        'clip_size': 50,
        'detect_threshold': 5, # 5
        'detect_interval': 30,  # Minimum number of timepoints between events detected on the same channel, 30
        # 'noise_overlap_threshold': None,  # Use None for no automated curation'
    }

if sorting_method == 'waveclus':
    default_TDC_params = ss.WaveClusSorter.default_params()
    default_TDC_params['detect_threshold']=5
    pprint(default_TDC_params)
if sorting_method == 'kilosort3':
    default_TDC_params = ss.Kilosort3Sorter.default_params()
    pprint(default_TDC_params)
if sorting_method == 'ironclust':
    default_TDC_params = ss.IronClustSorter.default_params()
    pprint(default_TDC_params)    

In [None]:
"""Load Recordings"""

if os.path.exists(data_folder_all+'recordings/'):
    print('data_folder_all already exists')
    recording_concat = spikeinterface.core.base.BaseExtractor.load_from_folder(data_folder_all+'recordings/')
    session_length_concat = np.load(data_folder_all+'recordings/session_length.npy')
    day_length=np.load(data_folder_all+'recordings/day_length.npy')

else:
    for date_id in date_id_all:
        pack_folder_pre =  f'/home/jialiulab/disk1/yichun/hao_sheng/data/cyborg_axolotl/{date_id}' # Source data folder (flipped?)
        data_folder_pre = data_folder_all + date_id + '/recordings/' # Output folder
        
        # If output folder exists, just load recording object from there
        if os.path.exists(data_folder_pre):
            recording = spikeinterface.core.base.BaseExtractor.load_from_folder(data_folder_pre)
            session_length = np.load(data_folder_pre + 'session_length.npy')
        
        # Otherwise, read in .rhd file, create recording object, and save down recording object
        else:
            mesh_probe = create_mesh_probe(32)
            recording, session_length = stack_recordings(pack_folder_pre, mesh_probe, trigger_val=3)
            recording.set_probe(mesh_probe, in_place=True)
            recording = recording.save(folder=data_folder_pre)
            np.save(data_folder_pre+'session_length.npy', session_length)
        
        sampling_freq = recording.get_sampling_frequency()
        recording_trace = recording.get_traces()
        recording_traces.append(recording_trace)
        session_length_concat.append(session_length)
        print(recording_trace.shape)
        day_length.append(recording_trace.shape[0])

    recording_traces = np.vstack(recording_traces)
    session_length_concat = np.vstack(session_length_concat)

    recording_concat = se.NumpyRecording(traces_list=recording_traces, sampling_frequency=sampling_freq)
    recording_concat.set_probe(mesh_probe, in_place=True)
    recording_concat = recording_concat.save(folder = data_folder_all + 'recordings/')

    np.save(data_folder_all + 'recordings/session_length.npy', session_length_concat)
    np.save(data_folder_all + 'recordings/day_length.npy', day_length)
    
print(recording_concat)
print('Num. channels = {}'.format(len(recording_concat.get_channel_ids())))
print('Sampling frequency = {} Hz'.format(recording_concat.get_sampling_frequency()))
print('Num. timepoints seg0= {}'.format(recording_concat.get_num_segments()))

if os.path.exists(sorting_save_path)==False:
    print(sorting_save_path)
    os.mkdir(sorting_save_path)

In [None]:
# Preprocess the concatenated recording: band-pass filter between freq_min and
# freq_max Hz, then apply a global average common reference across channels.
recording_f = st.preprocessing.bandpass_filter(recording_concat, freq_min=freq_min, freq_max=freq_max)
recording_cmr = st.preprocessing.common_reference(recording_f, reference='global',operator='average')

In [None]:
# Run MountainSort4 only when no cached firings file exists; either way, the
# sorting is (re)loaded from firing_save_path at the end.
if not os.path.exists(firing_save_path):
    print(0)
    # The variable name is historical; the sorter actually run is MountainSort4.
    # remove_existing_folder must be a bool (it was the string 'True').
    sorting_wave_clus = ss.run_sorter(sorter_name='mountainsort4',
                                      recording=recording_cmr,
                                      remove_existing_folder=True,
                                      output_folder=output_folder,
                                      **default_params,)
    # Discard units with 20 or fewer spikes before caching the result.
    keep_unit_ids = [unit_id for unit_id in sorting_wave_clus.unit_ids
                     if sorting_wave_clus.get_unit_spike_train(unit_id=unit_id).size > 20]

    sorting = sorting_wave_clus.select_units(unit_ids=keep_unit_ids, renamed_unit_ids=None)
    NpzSortingExtractor.write_sorting(sorting, firing_save_path)

sorting = se.NpzSortingExtractor(firing_save_path)

# Extract and show waveforms (all days)

In [None]:
def sorting_unit_show(we, recording_cmr, sorting, pack_folder, waveform_type):
    """Save template figures for every unit of a sorting.

    Two PDFs are written to ``pack_folder + 'waveforms/'``, four panels per row:

    * ``templates<waveform_type>.pdf`` — all-channel templates drawn by
      ``sw.plot_unit_templates``;
    * ``waveform<waveform_type>.pdf`` — each unit's template on the channel
      with its most negative peak, one colour per unit.

    Parameters
    ----------
    we : WaveformExtractor
        Waveforms/templates computed for ``sorting``.
    recording_cmr : recording
        Unused; kept for call compatibility.
    sorting : sorting extractor
        Its ``unit_ids`` determine which units are plotted.
    pack_folder : str
        Output root; figures go into its ``waveforms/`` sub-folder.
    waveform_type : str
        Suffix appended to both file names (may be ``''``).
    """
    unit_ids = sorting.unit_ids
    num_units = len(unit_ids)
    if num_units == 0:
        # plt.subplots(0, 4) would raise; there is nothing to draw.
        print('sorting_unit_show: sorting has no units, nothing to plot.')
        return

    waveform_folder = pack_folder + 'waveforms/'
    nrows = int(np.ceil(num_units / 4))
    figsize = (20, 5 * nrows)

    fig, axs = plt.subplots(nrows, 4, figsize=figsize)
    sw.plot_unit_templates(we, unit_ids=unit_ids, axes=axs)
    plt.savefig(waveform_folder+'/templates' + waveform_type + '.pdf',dpi=600)

    # unit id -> channel with the most negative template peak.
    # NOTE(review): the value is used as a column index into the template below,
    # which presumes channel ids equal channel indices (0..N-1) — confirm for
    # recordings with custom channel ids.
    extremum_channels_ids = st.get_template_extremum_channel(we, peak_sign='neg')

    cmap = plt.get_cmap('rainbow')
    fig, axs = plt.subplots(nrows, 4, figsize=figsize)
    # np.ravel flattens the row-major axes grid (and leaves the 1-D array
    # returned when nrows == 1 unchanged), so panel i shows unit i.
    for i, (ax, unit_id) in enumerate(zip(np.ravel(axs), unit_ids)):
        template = we.get_template(unit_id)
        ax.plot(template[:, extremum_channels_ids[unit_id]], lw=3, label=unit_id,
                color=cmap(1. * i / num_units))
        ax.set_title(f'template{unit_id}')

    plt.savefig(waveform_folder+'/waveform' + waveform_type + '.pdf',dpi=600)

In [None]:
# Extract waveforms from 1 ms before to 2 ms after each spike of `sorting` into
# <pack_folder>/waveforms/ (load_if_exists=True reuses an existing folder).
waveform_folder = pack_folder + 'waveforms/'

we = spikeinterface.extract_waveforms(recording_cmr, sorting, waveform_folder, 
                                      load_if_exists=True, ms_before=1, ms_after=2., 
                                      max_spikes_per_unit=1000000, n_jobs=-1, chunk_size=30000)

In [None]:
# NOTE(review): the result of extract_waveforms is indexed with [0], which
# suggests the installed spikeinterface (or a local patch) returns a sequence —
# confirm. Re-running this cell would index the extractor a second time.
we = we[0]
# Attach the mesh probe geometry to the extractor's recording.
we.recording.set_probe(mesh_probe, in_place=True)

In [None]:
# Split the concatenated sorting into per-day files
# (<pack_folder>/<date_id>/sorting/firings.npz) and save the day-shaded raster.
sorting_day_split(sorting, date_id_all, day_length, pack_folder, 
                  sorting_save_name='firings')

In [None]:
# Inter-spike-interval histograms (200 ms window, 1 ms bins) for every unit.
# NOTE(review): the 8 x 4 grid provides 32 axes; confirm that is enough for the
# number of units in `sorting`.
fig,ax = plt.subplots(8,4,figsize=(15,10))
sw.plot_isi_distribution(sorting, window_ms=200.0, bin_ms=1.0,axes=ax)
plt.savefig(waveform_folder+'/ISI.pdf',dpi=600)

In [None]:
# Save template and extremum-channel waveform figures for the uncurated sorting.
sorting_unit_show(we, recording_cmr, sorting, pack_folder, '')

# Curation

In [None]:
# Manual curation decisions. Each entry of merge_unit_ids_pack is a list of unit
# ids whose spikes are relabelled to its first id; delete_unit_ids_pack lists
# unit ids to drop entirely.
merge_unit_ids_pack = []
delete_unit_ids_pack = [3,5,9,21,22,24,25,26]

# When False, existing waveform folders are deleted and recomputed.
we_load_if_exists = True
waveform_show = False  # NOTE(review): not referenced by the cells shown here.
# Suffix of the curation output folder name.
input_state = 'merged'

In [None]:
# Folder that receives the curation plots for this input state.
curation_save_folder = pack_folder + f'/curation_result_{input_state}/'

if os.path.exists(curation_save_folder)==False:
    os.mkdir(curation_save_folder)

In [None]:
# Apply the manual merges and deletions.
# NOTE(review): `merged_sorting` is the same object as `sorting`, and `S` is its
# first segment, so relabelling `S.spike_labels` in place also modifies
# `sorting` itself. The code relies on the private `_sorting_segments` /
# `spike_labels` internals of the loaded NPZ sorting.
S = sorting._sorting_segments[0]
merged_sorting = sorting
remove_ids = []

for idx in range(len(merge_unit_ids_pack)):
    merge_unit_ids = merge_unit_ids_pack[idx]

    # Give every spike of the listed units the label of the first unit.
    for unit_id_id, unit_id in enumerate(merge_unit_ids):
        S.spike_labels[S.spike_labels == unit_id] = merge_unit_ids[0]

    merged_sorting._sorting_segments[0] = S
    # The other ids now have no spikes; they are removed below.
    remove_ids.extend(merge_unit_ids[1:])

remove_ids+=delete_unit_ids_pack
keep_ids = merged_sorting.unit_ids[~np.isin(merged_sorting.unit_ids, remove_ids)]

merged_sorting = merged_sorting.select_units(unit_ids=keep_ids, renamed_unit_ids=None)    

# Unless reusing, delete any previous merged-waveform folder so it is recomputed.
waveform_save_folder = pack_folder + '/waveforms_merged'
if(we_load_if_exists == False):
    if os.path.exists(waveform_save_folder):
        shutil.rmtree(waveform_save_folder)

In [None]:
# Persist the curated sorting, reload it from the .npz file, and extract its
# waveforms (1 ms before to 2 ms after each spike) into waveform_save_folder.
save_path = pack_folder + f'/sorting/firings_merged.npz'
NpzSortingExtractor.write_sorting(merged_sorting, save_path)

merged_sorting = se.NpzSortingExtractor(save_path)

merged_we = spikeinterface.extract_waveforms(recording_cmr, merged_sorting, waveform_save_folder, 
                                             load_if_exists=we_load_if_exists, overwrite=False,ms_before=1, ms_after=2, 
                                             max_spikes_per_unit=1000000, n_jobs=1, chunk_size=30000)

In [None]:
# From here on, `sorting` and `we` refer to the curated (merged) results.
sorting = merged_sorting
we = merged_we

In [None]:
# NOTE(review): indexes the extract_waveforms result with [0]; confirm what the
# installed spikeinterface returns. Re-running this cell indexes it again.
we=we[0]

# Analysis (All Days)

In [None]:
# Split the curated sorting into per-day files
# (<pack_folder>/<date_id>/sorting/firings_merged.npz) and redraw the raster.
sorting_day_split(sorting, date_id_all, day_length, pack_folder, 
                  sorting_save_name='firings_merged')

In [None]:
# ISI histograms (200 ms window, 2 ms bins) for the curated units.
# NOTE(review): this writes to the same waveform_folder + '/ISI.pdf' path as the
# pre-curation ISI plot, overwriting it, and the 5 x 4 grid provides only
# 20 axes — confirm both are intended.
fig,ax = plt.subplots(5,4,figsize=(15,12))
sw.plot_isi_distribution(sorting, window_ms=200.0, bin_ms=2.0,axes=ax)
plt.savefig(waveform_folder+'/ISI.pdf',dpi=600)

In [None]:
# Template figures for the curated sorting.
# NOTE(review): with the '' suffix this writes the same templates.pdf /
# waveform.pdf names as the pre-curation call, overwriting those figures.
sorting_unit_show(we, recording_cmr, sorting, pack_folder, '')

## By Electrode

In [None]:
# Group units by the electrode (channel) on which their template is most
# negative. Every channel is its own group (32 "electrodes"). For each
# electrode, this builds a sub-sorting of its units and a WaveformExtractor on
# that electrode's slice of the recording.
extremum_channels_ids = st.get_template_extremum_channel(we, peak_sign='neg')
probe_groups = np.arange(0,32)
NumShanks = 32

# split_by() is expected to split by the 'group' property set here, giving
# one single-channel recording per group value.
recording_cmr.set_property('group', probe_groups, ids=None)
slice_recording = recording_cmr.split_by()
slice_sorting = []
slice_we = []
slice_unit_ids = []

for shank_id in np.arange(NumShanks):
    # Channel indices belonging to this group (a single channel here).
    shank_channel_ids = np.where(probe_groups==shank_id)[0]
    # Units whose extremum channel lies in this group.
    # NOTE(review): assumes extremum_channels_ids iterates in the same order as
    # sorting.unit_ids and that channel ids equal channel indices.
    shank_unit_ids = sorting.unit_ids[np.where(np.isin(np.array(list(extremum_channels_ids.values())), shank_channel_ids))[0]]

    print(f'Electrode:{shank_id+1}, units: {shank_unit_ids}')
    slice_unit_ids.append(shank_unit_ids)

    shank_sorting = sorting.select_units(unit_ids=shank_unit_ids, renamed_unit_ids=None)
    slice_sorting.append(shank_sorting)

    shank_waveform_folder = pack_folder + f'/waveforms_electrode{shank_id+1}'
    if(we_load_if_exists == False):
        if os.path.exists(shank_waveform_folder):
            shutil.rmtree(shank_waveform_folder)
 
    shank_recording = slice_recording[shank_id]
    shank_we = spikeinterface.extract_waveforms(shank_recording, shank_sorting, shank_waveform_folder,
        load_if_exists=we_load_if_exists,
        ms_before=1, ms_after=2., max_spikes_per_unit=1000000,
        n_jobs=1, chunk_size=30000)
    
    slice_we.append(shank_we)

# Analysis (By Day)

In [None]:
"""Curation by Day"""

we_load_if_exists = True

print(sorting)
print(recording_cmr)
for day_id in range(len(date_id_all)):
    
    data_folder_day = data_folder_all + date_id_all[day_id] + '/'
    pack_folder_day = pack_folder + date_id_all[day_id] + '/'
    if os.path.exists(data_folder_day)==False:
        os.mkdir(data_folder_day)
    
    curation_save_folder_day = pack_folder_day + 'curation_result_merged/'
    if os.path.exists(curation_save_folder_day)==False:
        os.mkdir(curation_save_folder_day)
    
    # Load sorting object
    firing_save_path_day = pack_folder_day + 'sorting/firings_merged.npz'
    sorting_day = se.NpzSortingExtractor(firing_save_path_day)
    print(sorting_day)
    
    # Load recording object
    recording_save_path_day = data_folder_day + 'recordings/'
    recording_day = spikeinterface.core.base.BaseExtractor.load_from_folder(recording_save_path_day)
    recording_f_day = st.preprocessing.bandpass_filter(recording_day, freq_min=freq_min, freq_max=freq_max)
    recording_cmr_day = st.preprocessing.common_reference(recording_f_day, reference='global', operator='average')
    print(recording_cmr_day)
    
    # Save down waveform object
    waveform_save_folder_day = pack_folder_day + 'waveforms_merged/'
    if(we_load_if_exists == False):
        if os.path.exists(waveform_save_folder_day):
            shutil.rmtree(waveform_save_folder_day)
    
    we_day = spikeinterface.extract_waveforms(recording_cmr_day, sorting_day, waveform_save_folder_day, 
                                              load_if_exists=we_load_if_exists, overwrite=False, ms_before=1, ms_after=2, 
                                              max_spikes_per_unit=1000000, n_jobs=1, chunk_size=30000)


# Plot autocorrelogram

In [None]:
"""Autocorrelogram with 20ms window"""

corr_bad_units = shank_autocorrelogram_show(slice_sorting, window_ms=20.0, bin_ms=0.1, threshold_ms=1, 
                                            symmetrize=True, neuron_id_rename=False, save_path=curation_save_folder)

In [None]:
"""Autocorrelogram with 100ms window"""

corr_bad_units = shank_autocorrelogram_show(slice_sorting, window_ms=100.0, bin_ms=0.1, threshold_ms=1, 
                                            symmetrize=True, neuron_id_rename=False, save_path=curation_save_folder)

# Plot correlogram

In [None]:
# Autocorrelograms of all curated units, saved alongside the waveform figures.
sw.plot_autocorrelograms(sorting, sorting.unit_ids)
plt.savefig(waveform_folder+'/autocorrelogram.pdf',dpi=600)

In [None]:
# Cross-correlograms for the curated units (this figure is not saved to disk).
sw.plot_crosscorrelograms(sorting, sorting.unit_ids)

In [None]:
# Correlograms for the whole curated sorting (passed as a single "shank") via the
# local CurationTool helper: 200 ms window, 1 ms bins, saved to curation_save_folder.
shank_correlogram_show([sorting], window_ms=200.0, bin_ms=1.0, threshold_ms=3, symmetrize=True, neuron_id_rename=False, save_path=curation_save_folder)

# Other settings noted by the author (presumably alternatives that were tried):
# window_ms = 100, 100, 50
# bin_ms = 1.0
# threshold_ms = 3


# Displacement

In [None]:
def location_cal(info, degree=2):
    """Estimate each unit's 2-D position as an amplitude-weighted centre of mass.

    For every unit, the peak-to-peak amplitude of its template on each of the
    32 mesh channels is raised to ``degree``. Those values weight the channel
    contact positions, and the weighted mean position is the estimate.

    Parameters
    --------------------------
    info: pandas DataFrame with a ``'template'`` column, one row per unit.
        Each template is a 2-D array indexed ``[sample, channel]`` with at
        least 32 channel columns; only channels 0-31 are used.
    degree: exponent applied to the amplitudes before weighting. Larger values
        pull the estimate toward the strongest channel.

    Returns
    --------------------------
    np.array(location_day): array of shape (len(info), 2) with the x and y
        position of each unit, in the coordinates of ``sensor_positions``
        (the same layout as ``create_mesh_probe``). An empty ``info`` yields
        an empty 1-D array; a template that is flat on every channel yields NaN.
    """
    # Contact position of channel k is sensor_positions[k].
    sensor_positions = np.array([[160,0],[80,0],[160,80],[80,80],[160,160],
                                 [80,160],[160,240],[80,240],[160,320],[80,320],
                                 [160,400],[80,400],[160,480],[80,480],[160,560],
                                 [80,560],[240,0],[0,0],[240,80],[0,80],
                                 [240,160],[0,160],[240,240],[0,240],[240,320],
                                 [0,320],[240,400],[0,400],[240,480],[0,480],
                                 [240,560],[0,560]])
    sensor_channels = np.arange(32)
    x_positions = sensor_positions[:, 0]
    y_positions = sensor_positions[:, 1]

    location_day = []
    for template in info['template']:
        template = template[:, sensor_channels]
        # Peak-to-trough amplitude of the template on each channel.
        amplitudes = np.max(template, axis=0) - np.min(template, axis=0)
        weights = amplitudes ** degree
        total_weight = np.sum(weights)
        location_day.append([np.sum(x_positions * weights) / total_weight,
                             np.sum(y_positions * weights) / total_weight])

    return np.array(location_day)

def unit_position_plot(location_day0, location_day1, day0_name, day_follow_name, with_device_index=True, 
                       degree=2, colors=('darkblue', 'red'), s=(100, 80), linewidth=1, save_folder='./'):
    """
    Draw how unit positions moved between two days on the 32-channel probe.

    Row i of ``location_day0`` and row i of ``location_day1`` must describe the same unit.
    Each unit's first-day position is a dot in colors[0] (size s[0]). Its later position
    is a dot in colors[1] (size s[1]), joined to the first dot by a line of width
    ``linewidth``. The figure is saved as
    ``save_folder + 'Displacement_<day0_name>_<day_follow_name>_degree<degree>.pdf'``.

    Parameters
    ----------
    location_day0, location_day1: arrays of shape (n_units, 2) with x/y positions.
    day0_name, day_follow_name: labels used in the title and file name.
    with_device_index: forwarded to plot_probe.
    degree: only used in the file name (the weighting degree used for the locations).
    colors: (first-day colour, later-day colour).
    s: (first-day marker size, later-day marker size).
    linewidth: width of the line joining a unit's two positions.
    save_folder: prefix (a folder ending in '/') for the saved file.

    Raises
    ------
    ValueError: if the two location arrays have different numbers of rows.
    """
    location_day0 = np.asarray(location_day0)
    location_day1 = np.asarray(location_day1)
    if location_day0.shape[0] != location_day1.shape[0]:
        raise ValueError(
            f'location_day0 has {location_day0.shape[0]} units but location_day1 has '
            f'{location_day1.shape[0]}; rows must be paired unit by unit')

    fig, ax = plt.subplots(1, 1, figsize=(4, 6))

    mesh_probe = create_mesh_probe(32)
    plot_probe(mesh_probe, with_device_index=with_device_index, ax=ax)

    # First-day positions are drawn first so the later-day dots and lines sit on top.
    for neuron_id in range(location_day0.shape[0]):
        ax.scatter(location_day0[neuron_id, 0], location_day0[neuron_id, 1], marker='.',
                   s=s[0], color=colors[0], label=f'neuron{neuron_id+1}')

    for neuron_id in range(location_day1.shape[0]):
        ax.scatter(location_day1[neuron_id, 0], location_day1[neuron_id, 1], marker='.',
                   s=s[1], color=colors[1], label=f'neuron{neuron_id+1}')
        ax.plot([location_day0[neuron_id, 0], location_day1[neuron_id, 0]],
                [location_day0[neuron_id, 1], location_day1[neuron_id, 1]],
                c=colors[1], linewidth=linewidth)

    fig.suptitle(f'{day0_name} - {day_follow_name}')
    plt.tight_layout()
    plt.savefig(save_folder+f'Displacement_{day0_name}_{day_follow_name}_degree{degree}.pdf')

def unit_position_plot_oneday(location, day_name, unit_ids, with_device_index=True, color='gray',
                              s=[100,80], linewidth=1, save_folder='./'):
    """
    Plot every unit's estimated position on the 32-channel probe for a single day.

    Unit k (k < len(unit_ids)) is drawn as a dot at ``location[k]`` and annotated with
    ``unit_ids[k]``. The figure is saved as ``save_folder + 'Unit_position_<day_name>.pdf'``.

    Parameters
    ----------
    location: array of shape (n_units, 2) with x/y positions (e.g. from location_cal).
    day_name: label used for the title and the file name.
    unit_ids: unit ids, one per row of ``location``. Its length sets how many units are drawn.
    with_device_index: forwarded to plot_probe.
    color: dot colour.
    s: marker sizes; only s[0] is used.
    linewidth: accepted for signature compatibility; not used.
    save_folder: prefix (a folder ending in '/') for the saved file.
    """
    fig, ax = plt.subplots(1, 1, figsize=(4, 6))
    plot_probe(create_mesh_probe(32), with_device_index=with_device_index, ax=ax)

    dot_size = s[0]
    for rank in range(len(unit_ids)):
        x_pos = location[rank, 0]
        y_pos = location[rank, 1]
        ax.scatter(x_pos, y_pos, marker='.', s=dot_size, color=color, label=f'neuron{rank+1}')
        ax.annotate(unit_ids[rank], (x_pos, y_pos))

    fig.suptitle(f'{day_name}')
    plt.tight_layout()
    plt.savefig(save_folder + f'Unit_position_{day_name}.pdf')
    
def geometry_drift_plot(location_day0, location_day1, day0_name, day1_name, degree=2, levels=5,thresh=.3,
                        colors=['darkblue','red'], s=[100,80], lim=[-15,15], save_folder='./'):
    """
    Plot the distribution of per-unit displacement between two days.

    Each unit's drift is ``location_day1 - location_day0``, taken separately for x and y.
    A gray seaborn KDE of the drift vectors is drawn. A line in colors[1] runs from the
    origin (dot in colors[0], size s[0]) to the mean drift (dot in colors[1], size s[1]).
    The figure is saved as
    ``save_folder + 'geometry_drift_<day0_name>_<day1_name>_degree<degree>.pdf'``.

    Parameters
    ----------
    location_day0, location_day1: arrays of shape (n_units, 2), rows paired by unit.
    day0_name, day1_name: labels used in the title and file name.
    degree: only used in the file name.
    levels, thresh: forwarded to seaborn.kdeplot.
    colors, s: colours and marker sizes for (origin, mean drift).
    lim: axis limits applied to both x and y.
    save_folder: prefix (a folder ending in '/') for the saved file.
    """
    drift = {
        'drift_x': location_day1[:, 0] - location_day0[:, 0],
        'drift_y': location_day1[:, 1] - location_day0[:, 1],
    }
    mean_x = drift['drift_x'].mean()
    mean_y = drift['drift_y'].mean()

    fig, ax = plt.subplots(figsize=(4, 4))
    sns.kdeplot(data=drift, x='drift_x', y='drift_y', levels=levels, ax=ax,
                thresh=thresh, color='gray', zorder=-1)

    # Mean drift vector, drawn from the origin.
    ax.plot([0, mean_x], [0, mean_y], color=colors[1], zorder=1, linewidth=2)
    ax.scatter([0], [0], color=colors[0], zorder=2, s=s[0])
    ax.scatter([mean_x], [mean_y], color=colors[1], zorder=3, s=s[1])

    ax.set_xlim(lim)
    ax.set_ylim(lim)
    ax.set_title(f'geometry drift: {day0_name}-{day1_name}')
    plt.tight_layout()

    plt.savefig(save_folder + f'geometry_drift_{day0_name}_{day1_name}_degree{degree}.pdf')

In [None]:
"""Create info file"""

# Day labels of the recordings.
# NOTE(review): `days` is not used in this cell; the loop below iterates over
# date_id_all (defined in an earlier cell). Confirm the two lists agree.
days = ['043023', '050123', '050223', '050323', '050423']

# Column accumulators for the `info` table built at the end of this cell.
day_name_all = []
day_id_all = []
unit_id_all = []
template_all = []
waveform_all = []

for day_id, day_name in enumerate(date_id_all):
    
    data_folder_day = data_folder_all + day_name + '/'
    pack_folder_day = pack_folder + day_name + '/'
    # NOTE(review): data_folder_day is not used in this cell; slice_curated_ids and
    # shank_ids are only referenced by the disabled code in the string literal below.
    slice_curated_ids = [np.nan]*NumShanks
    shank_ids = []
    
    # Disabled per-shank loading code, kept as an (unused) string literal.
    """
    for shank_id in range(NumShanks):
        shank_curated_waveform_folder = pack_folder + f'/waveforms_electrode{shank_id+1}'
        if os.path.exists(shank_curated_waveform_folder)==True:
            slice_curated_we = WaveformExtractor.load_from_folder(shank_curated_waveform_folder)
            slice_curated_ids[shank_id] = slice_curated_we.sorting.unit_ids
            slice_shank_ids = shank_id*np.ones((len(slice_curated_ids[shank_id]),)).astype(int)
            shank_ids.append(slice_shank_ids)

    shank_ids = np.hstack(shank_ids)
    curated_ids = np.hstack(slice_curated_ids)
    curated_ids = curated_ids[~np.isnan(curated_ids)]
    """
    
    # Load this day's merged waveform extractor and the templates of all its units.
    waveform_folder_day = pack_folder_day + 'waveforms_merged/'
    we_day = WaveformExtractor.load_from_folder(waveform_folder_day)
    curated_ids_day = we_day.sorting.unit_ids
    template_day = we_day.get_all_templates(unit_ids=curated_ids_day)
    
    # One row per unit on this day. Every extracted waveform of the unit is kept
    # in memory.
    for idx, unit_id in enumerate(curated_ids_day):

        day_name_all.append(day_name)
        day_id_all.append(day_id)
        unit_id_all.append(unit_id)
        template_all.append(template_day[idx,:,:])
        waveform_all.append(we_day.get_waveforms(unit_id=unit_id))

# info: one row per (day, unit) with columns day_name, day_id (position of the day
# in date_id_all), unit_id, template and waveform.
info = {'day_name': day_name_all, 'day_id': day_id_all, 'unit_id': unit_id_all, 'template': template_all, 'waveform': waveform_all}
info = pd.DataFrame(info)

In [None]:
"""Create displacement plots"""

# For every pair (day0, later day), plot how the units found on both days moved.
for day0_id in range(len(date_id_all)-1):
    
    day0_name = date_id_all[day0_id]
    info_day0 = info.loc[info['day_id']==day0_id]
    
    unit_ids_day0 = info_day0['unit_id'].values # unit ids present on day0
    template_day0 = info_day0['template'].values # unit templates for day0
    locations_day0 = location_cal(info_day0) # x/y location of each day0 unit, in row order of info_day0
    print(day0_name)
    
    # NOTE(review): the range starts at day0_id, so each day is also compared with
    # itself (all displacements zero). Use day0_id+1 if only later days are wanted.
    for day_follow_id in np.arange(day0_id, len(date_id_all)):
        
        day_follow_name = date_id_all[day_follow_id]
        info_day_follow = info.loc[info['day_id']==day_follow_id]
        
        unit_ids_day_follow = info_day_follow['unit_id'].values # unit ids present on day_follow
        template_day_follow = info_day_follow['template'].values # unit templates for day_follow
        locations_day_follow = location_cal(info_day_follow) # x/y location of each day_follow unit
        
        # Units present on both days, plus the row of each one in its own day's table.
        # Indexing with these row indices keeps row i of both stable-location arrays on
        # the same unit, even if the two days list their units in different orders.
        # Two independent boolean masks only pair correctly when those orders agree.
        stable_unit_ids, indices_day0, indices_day_follow = np.intersect1d(
            unit_ids_day0, unit_ids_day_follow, return_indices=True)
        locations_day0_stable = locations_day0[indices_day0,:] # day0 locations of stable units
        locations_day_follow_stable = locations_day_follow[indices_day_follow,:] # day_follow locations, same unit order

        # Unit position plot
        unit_position_plot(locations_day0_stable, locations_day_follow_stable,
                           day0_name, day_follow_name, with_device_index=False, degree=2,
                           colors=['darkblue','red'], s=[400,160], linewidth=2, 
                           save_folder=pack_folder)
        
        # Geometry drift plot
        geometry_drift_plot(locations_day0_stable, locations_day_follow_stable, day0_name, day_follow_name, 
                            degree=2, colors=['darkblue','red'], s=[100,60], lim=[-400,200], levels=5, 
                            thresh=0.2, save_folder=pack_folder)

# Unit Location (All Days)

In [None]:
# Build an info-style table for the merged (all-days) sorting held in `we`,
# which is defined in an earlier cell.
day_name_concat = []
day_id_concat = []
unit_id_concat = []
template_concat = []

curated_ids = we.sorting.unit_ids
curated_templates = we.get_all_templates(unit_ids=curated_ids)

# One row per curated unit; both day_name and day_id are set to the label 'All_Days'.
for idx, unit_id in enumerate(curated_ids):
    day_name_concat.append('All_Days')
    day_id_concat.append('All_Days')
    unit_id_concat.append(unit_id)
    template_concat.append(curated_templates[idx,:,:])

info_concat = {'day_name': day_name_concat, 'day_id': day_id_concat, 'unit_id': unit_id_concat, 'template': template_concat}
info_concat = pd.DataFrame(info_concat)

"""Location plots for all days"""

locations_concat = location_cal(info_concat)

unit_position_plot_oneday(locations_concat, 'All_Days', unit_ids = unit_id_concat, with_device_index=False, 
                          color='gray', s=[400,160], linewidth=2, save_folder=pack_folder)

"""Location plots for individual days"""

for day_id, day_name in enumerate(date_id_all):
    
    info_day_TEMP = info.loc[info['day_id']==day_id]
    locations_day_TEMP = location_cal(info_day_TEMP)
    
    unit_position_plot_oneday(locations_day_TEMP, day_name, unit_ids = info_day_TEMP['unit_id'].values, with_device_index=False, 
                              color='gray', s=[400,160], linewidth=2, save_folder=pack_folder)

"""-------------------------"""

# Spike raster of the merged sorting from time 0 to sum(day_length).
# Presumably day_length holds each day's duration in seconds; TODO confirm.
fig, ax = plt.subplots(1,1,figsize=(20,20))
sw.plot_rasters(we.sorting, time_range=(0, np.sum(day_length)), ax=ax)

"""-------------------------"""
# Re-extract waveforms of merged_sorting with a 2 ms / 2 ms window and a per-unit
# cap of 1,000,000 spikes, into waveform_save_folder + '_long'.
we_long = spikeinterface.extract_waveforms(recording_cmr, merged_sorting, waveform_save_folder+'_long', 
                                           load_if_exists=we_load_if_exists, ms_before=2, ms_after=2, 
                                           max_spikes_per_unit=1000000, n_jobs=1, chunk_size=30000)
# NOTE(review): indexing [0] assumes extract_waveforms returned a sequence here.
# A plain WaveformExtractor would not support it; confirm against the installed
# spikeinterface version.
we_long = we_long[0]

# NOTE(review): this passes `sorting`, whereas the waveforms above were extracted
# from `merged_sorting`; confirm which one is intended.
sorting_unit_show(we_long, recording_cmr, sorting, pack_folder, '_long')

# Waveform Evolution

In [None]:
def waveform_overlay_plot(info_day_select, unit_ids_select, y_scale_factor=0.8, x_scale_factor=1, 
                          ylim=(-100, 200), y_displace=5, alpha_lim=(0.5, 1), save_folder='./'):
    """
    Overlay each unit's template from several days on the probe layout.

    One subplot is drawn per unit in ``unit_ids_select[0]``. In each subplot, the unit's
    template on each of the 32 channels is drawn at that channel's probe position.
    Successive days are shifted down by ``y_displace`` (before y scaling) and drawn with
    increasing opacity. Days on which the unit has no template row are skipped.

    Parameters
    ----------------------------------
    info_day_select: list of pandas DataFrames, one per day, each with 'day_name',
        'unit_id' and 'template' columns (template indexed as [sample, channel]).
    unit_ids_select: list with one entry per day. Only the first entry is used to
        choose the units (and their colours); the number of entries is the number of
        days drawn.
    y_scale_factor: vertical scale applied to template amplitude and day offset.
    x_scale_factor: horizontal scale applied to the template sample axis.
    ylim: y-limits set on each subplot (ax.axis('equal') is applied afterwards).
    y_displace: vertical offset between successive days, in template units.
    alpha_lim: (first-day, last-day) opacity; days get evenly spaced opacities.
    save_folder: output folder, created if missing. It is used as a string prefix,
        so it should end with '/'.

    Returns
    ----------------------------------
    None. The figure is saved as 'waveform_similarity_overlay_' followed by
    '_<day_name>' for each day, plus '.pdf'.
    """
    sensor_location = np.array([[160,0],[80,0],[160,80],[80,80],[160,160],
                                 [80,160],[160,240],[80,240],[160,320],[80,320],
                                 [160,400],[80,400],[160,480],[80,480],[160,560],
                                 [80,560],[240,0],[0,0],[240,80],[0,80],
                                 [240,160],[0,160],[240,240],[0,240],[240,320],
                                 [0,320],[240,400],[0,400],[240,480],[0,480],
                                 [240,560],[0,560]])

    sensor_channels = np.arange(32)

    n_units = len(unit_ids_select[0])
    n_days = len(unit_ids_select)

    # One colour per unit, evenly spaced along the colormap.
    cm = pylab.get_cmap('gist_rainbow')
    colors = [cm(1. * i / n_units) for i in range(n_units)]

    # Evenly spaced opacities from alpha_lim[0] (first day) to alpha_lim[1] (last day).
    # linspace hits the end point exactly. It also works for a single day, where the
    # previous arange step divided by zero and a float step could overshoot alpha > 1.
    alpha_s = np.linspace(alpha_lim[0], alpha_lim[1], n_days)

    os.makedirs(save_folder, exist_ok=True)

    # squeeze=False keeps `axes` 2-D even for a single unit, so axes[i, 0] always works.
    fig, axes = plt.subplots(n_units, 1, figsize=(6, 10 * n_units), squeeze=False)

    # File name: the prefix followed by '_<day_name>' for every day, in day order.
    save_name = 'waveform_similarity_overlay_'
    for align_id in range(n_days):
        save_name = save_name + '_' + info_day_select[align_id]['day_name'].values[0]

    # Outer loop: one subplot per unit. Inner loop: overlay that unit's template from each day.
    for unit_pos, unit_id in enumerate(unit_ids_select[0]):
        ax = axes[unit_pos, 0]
        for align_id in range(n_days):
            info_day = info_day_select[align_id]
            match = info_day.loc[info_day['unit_id'] == unit_id, 'template']
            if match.empty:
                # The unit has no template on this day: nothing to draw.
                continue

            template = match.values[0][:, sensor_channels]
            tps = np.arange(template.shape[0])  # sample index along the template

            for idx in range(len(sensor_channels)):
                x = tps * x_scale_factor + sensor_location[idx][0]
                y = y_scale_factor * (template[:, idx] - y_displace * align_id) + sensor_location[idx][1]
                ax.plot(x, y, linewidth=3, c=colors[unit_pos], alpha=alpha_s[align_id])

        ax.set_title(f'unit{unit_id}')
        ax.set_ylim(ylim)
        ax.axis('equal')
        ax.axis('off')

    plt.savefig(save_folder+save_name+'.pdf')

## Waveform Evolution Per Electrode

In [None]:
# For every day, collect that day's rows of `info` and the list of units to plot.
# The same unit list (curated_ids, the units of the merged sorting `we`) is used
# for every day. Presumably some units have no row on some days; waveform_overlay_plot
# skips those.
info_day_select = []
unit_ids_select = []
        
for day_name in date_id_all:
    
    info_day = info.loc[info['day_name']==day_name]
    unit_ids = curated_ids
    
    info_day_select.append(info_day)
    unit_ids_select.append(unit_ids)

waveform_overlay_plot(info_day_select, unit_ids_select, ylim=[-500,200], y_scale_factor=0.35, 
                      x_scale_factor=1, y_displace=20, alpha_lim=[0.3,1], save_folder=pack_folder)

In [None]:
# For each curated unit, print the day names on which it has a row in `info`.
for unit_id in curated_ids:
    print('unit_id:', unit_id)
    print(info.loc[info['unit_id']==unit_id]['day_name'].values)

## Waveform Evolution Extremum

In [None]:
# For every day, load its merged waveform extractor and run sorting_unit_show on it,
# tagging the outputs with '_byday_<day_name>'.
for day_id, day_name in enumerate(date_id_all):
    
    # NOTE(review): data_folder_day, slice_curated_ids, shank_ids and template_day are
    # computed but not used by this cell (copied from the info-building cell).
    data_folder_day = data_folder_all + day_name + '/'
    pack_folder_day = pack_folder + day_name + '/'
    slice_curated_ids = [np.nan]*NumShanks
    shank_ids = []
    
    waveform_folder_day = pack_folder_day + 'waveforms_merged/'
    we_day = WaveformExtractor.load_from_folder(waveform_folder_day)
    curated_ids_day = we_day.sorting.unit_ids
    template_day = we_day.get_all_templates(unit_ids=curated_ids_day)
    
    sorting_unit_show(we_day, we_day.recording, we_day.sorting, pack_folder, f'_byday_{day_name}')

# Waveform Similarity

In [None]:
def waveform_similarity_cal(choosen_day, template_day0, template_day_follow):
    """
    Compute the Pearson correlation between pairs of templates from two days.

    Parameters
    ----------
    choosen_day: iterable of index pairs. pair[0] indexes template_day0 and pair[1]
        indexes template_day_follow.
    template_day0, template_day_follow: indexable collections of template arrays.
        Each template is transposed and flattened before correlating.

    Returns
    -------
    (r_values, p_values): two 1-D numpy arrays with one entry per pair, in input order.
    """
    def flatten(template):
        # Transpose first, so the flattened vector is grouped by the template's last axis.
        return np.reshape(template.T, (-1,))

    r_list = []
    p_list = []
    for pair in choosen_day:
        first = flatten(template_day0[pair[0]])
        second = flatten(template_day_follow[pair[1]])
        r, p = pearsonr(first, second)
        r_list.append(r)
        p_list.append(p)

    return np.array(r_list), np.array(p_list)

In [None]:
# Compare template similarity (Pearson r) for the same unit across days ("within")
# against different units across days ("across"). One histogram is drawn per day
# pair, plus a pooled histogram over all pairs.
r_values_within_all = []
r_values_across_all = []

# day0 is the first day of the pair
for day0_id, day0_name in enumerate(date_id_all[:-1]):
    info_day0 = info.loc[info['day_name']==day0_name]
    unit_ids_day0 = info_day0['unit_id'].values
    template_day0 = info_day0['template'].values

    # day1 is the second day of the pair; only later days are used
    for day1_name in date_id_all[(day0_id+1):]:
        info_day1 = info.loc[info['day_name']==day1_name]
        unit_ids_day1 = info_day1['unit_id'].values
        template_day1 = info_day1['template'].values
                
        # Debug output.
        print(type(template_day1))
        print(template_day1.shape)
        
        # Row-index pairs (day0 row, day1 row): same unit id -> within,
        # different unit id -> across.
        within_choosen = []
        across_choosen = []
        for unit_day0_count, unit_day0 in enumerate(unit_ids_day0):
            for unit_day1_count, unit_day1 in enumerate(unit_ids_day1):
                if unit_day0 == unit_day1:
                    within_choosen.append((unit_day0_count, unit_day1_count))
                else:
                    across_choosen.append((unit_day0_count, unit_day1_count))
        
        # Debug output (the across list grows with n_units^2).
        print(within_choosen)
        print(across_choosen)
        r_values_within, _ = waveform_similarity_cal(within_choosen, template_day0, template_day1)
        r_values_across, _ = waveform_similarity_cal(across_choosen, template_day0, template_day1)
        
        # Histogram over [-0.5, 1] in 0.02 bins. density * binsize gives the fraction
        # of in-range values per bin; values outside the range are dropped.
        fig, ax = plt.subplots(figsize=(4,4))
        maxbin = 1
        minbin = -0.5
        binsize = 0.02
        bin_counts_day_within, bin_edges = np.histogram(r_values_within, 
                                                        bins=np.arange(minbin,maxbin+binsize,binsize), 
                                                        density=True)
        bin_counts_day_across, bin_edges = np.histogram(r_values_across, 
                                                        bins=np.arange(minbin,maxbin+binsize,binsize), 
                                                        density=True)

        line1 = ax.bar(bin_edges[:-1]+binsize/2,bin_counts_day_within*binsize,width=binsize,color='red', alpha=0.7)
        line2 = ax.bar(bin_edges[:-1]+binsize/2,bin_counts_day_across*binsize,width=binsize,color='gray', alpha=0.7)
        ax.legend([line1, line2],['Within units', 'Across units'])
        ax.set_title(f'{day0_name}-{day1_name}')
        ax.set_xlabel('Waveform similarity')
        ax.set_ylabel('Probability')
        
        r_values_within_all.append(r_values_within)
        r_values_across_all.append(r_values_across)

# Pool all day pairs.
# NOTE(review): np.hstack raises on an empty list, i.e. when date_id_all has fewer
# than two days.
r_values_within_all = np.hstack(r_values_within_all)
r_values_across_all = np.hstack(r_values_across_all)

fig, ax = plt.subplots(figsize=(4,4))
maxbin = 1
minbin = -0.5
binsize = 0.02
bin_counts_day_within, bin_edges = np.histogram(r_values_within_all, bins=np.arange(minbin,maxbin+binsize,binsize), density=True)
bin_counts_day_across, bin_edges = np.histogram(r_values_across_all, bins=np.arange(minbin,maxbin+binsize,binsize), density=True)

line1 = ax.bar(bin_edges[:-1]+binsize/2,bin_counts_day_within*binsize,width=binsize,color='red', alpha=0.7)
line2 = ax.bar(bin_edges[:-1]+binsize/2,bin_counts_day_across*binsize,width=binsize,color='gray', alpha=0.7)
ax.legend([line1, line2],['Within units', 'Across units'])
ax.set_title('All days')
ax.set_xlabel('Waveform similarity')
ax.set_ylabel('Probability')

# UMAP

In [None]:
def dimensionality_reduction_cal(info_day_select, method='umap',  n_components=2, svd_solver='auto', nb_points = 20000, 
                                 umap_params=None, points_nb=10000):
    """
    Embed individual spike waveforms from several days into a low-dimensional space.

    Waveforms of every unit on every day are flattened across the 32 channels
    (``order='F'``) and trimmed to the ``nb_points`` waveforms closest (Euclidean) to a
    reference mean waveform. They are then stacked and projected with PCA or UMAP.

    Parameters
    ------------------------------
    info_day_select: list of pandas DataFrames, one per day, each with 'unit_id' and
        'waveform' columns. Each waveform array is indexed [spike, sample, channel].
    method: 'pca' uses scikit-learn PCA; any other value uses UMAP.
    n_components: number of embedding dimensions; must be >= 2, since the first two
        are returned.
    svd_solver: forwarded to PCA.
    nb_points: maximum number of waveforms kept per unit and day.
    umap_params: dict with 'n_neighbors', 'random_state', 'min_dist' and 'metric' for
        UMAP. Defaults to {'n_neighbors': 20, 'random_state': 2, 'min_dist': 0.1,
        'metric': 'euclidean'}.
    points_nb: maximum number of waveforms per unit and day used to compute the
        reference mean waveform.

    Returns
    ------------------------------
    pandas DataFrame with one row per kept waveform and columns 'waveform_pc1',
    'waveform_pc2' (first two embedding coordinates), 'day_id' (position of the day
    in info_day_select) and 'unit_id'.
    """
    if umap_params is None:
        umap_params = {'n_neighbors': 20, 'random_state': 2, 'min_dist': 0.1, 'metric': 'euclidean'}

    sensor_channels = np.arange(32)

    def _flatten(waveform):
        # (n_spikes, n_samples, n_channels) -> (n_spikes, n_samples * n_channels)
        return np.reshape(waveform, (waveform.shape[0], -1), order='F')

    # Reference waveform: mean over (up to) the first `points_nb` spikes of every unit
    # on every day.
    # NOTE(review): this is one mean pooled over ALL units, not each unit's own
    # template, so the trimming below keeps the spikes closest to the population
    # average. Confirm this is intended.
    reference = np.mean(np.vstack([
        _flatten(info_day.iloc[row]['waveform'][:points_nb, :, sensor_channels])
        for info_day in info_day_select
        for row in range(len(info_day))
    ]), axis=0)

    waveforms = []
    unit_ids = []
    day_ids = []
    for day_id, info_day in enumerate(info_day_select):
        for row, unit_id in enumerate(info_day['unit_id'].values):
            waveform = _flatten(info_day.iloc[row]['waveform'][:, :, sensor_channels])
            distances = np.sqrt(np.sum(np.square(waveform - reference), axis=1))
            waveform = waveform[np.argsort(distances)[:nb_points], :]
            waveforms.append(waveform)
            unit_ids.append(unit_id * np.ones((waveform.shape[0],)).astype(int))
            day_ids.append(day_id * np.ones((waveform.shape[0],)).astype(int))

    waveforms = np.vstack(waveforms)
    unit_ids = np.hstack(unit_ids)
    day_ids = np.hstack(day_ids)

    if method == 'pca':
        # Import here so this path does not depend on a notebook-level PCA import.
        from sklearn.decomposition import PCA
        pca_ = PCA(n_components=n_components, svd_solver=svd_solver).fit(waveforms)
        waveforms_2d = pca_.transform(waveforms)
    else:
        mapper_ = umap.UMAP(n_neighbors=umap_params['n_neighbors'],
                            random_state=umap_params['random_state'],
                            min_dist=umap_params['min_dist'],
                            n_components=n_components,
                            metric=umap_params['metric']).fit(waveforms)
        waveforms_2d = mapper_.transform(waveforms)

    data = {'waveform_pc1': waveforms_2d[:, 0], 'waveform_pc2': waveforms_2d[:, 1],
            'day_id': day_ids, 'unit_id': unit_ids}
    return pd.DataFrame(data=data)

def dimensionality_reduction_stability_plot(df, method='umap', figsize=(30, 20), shank_displace_factor=10, 
                                            day_displace_factor=20, distance=2, nb_points=20000, 
                                            azim = 270, elev = 15, dist = 7, colors=None, z_scale=1):
    """
    Stack per-day 2-D waveform embeddings along z and link each unit's centroid across days.

    For every day, each unit's points are filtered in two passes:
    1. keep points within ``2 * distance`` of the unit's all-day centroid;
    2. keep points within ``distance`` of the centroid of what remains.
    The kept points are scattered at z = day_index * day_displace_factor + 0.1. Each
    unit's filtered centroid on a day is joined by a line to its centroid on the most
    recent earlier day on which the unit was present.

    Parameters
    ----------
    df: pandas DataFrame with columns 'waveform_pc1', 'waveform_pc2', 'day_id' and
        'unit_id' (as returned by dimensionality_reduction_cal).
    method: not used by the drawing (axis labelling is disabled); kept for compatibility.
    figsize: matplotlib figure size.
    shank_displace_factor: unused; kept for compatibility.
    day_displace_factor: z spacing between consecutive days.
    distance: filtering radius in embedding units; np.inf keeps every point.
    nb_points: unused; kept for compatibility.
    azim, elev, dist: view settings assigned to the 3-D axis.
    colors: one colour per unit, indexed by the unit's position in
        np.unique(df['unit_id']). Defaults to evenly spaced 'rainbow' colours.
    z_scale: unused (the set_ax_style call is disabled); kept for compatibility.

    Returns
    -------
    The 3-D matplotlib axis.
    """
    unit_ids_all = np.unique(df['unit_id'].values)
    # Colour slot of each unit, by its rank among ALL units. This keeps a unit's colour
    # the same on every day even when other units are missing on some days. (A per-day
    # rank shifts whenever a unit is absent.)
    color_index = {str(unit_id): rank for rank, unit_id in enumerate(unit_ids_all)}

    if colors is None:
        cm = pylab.get_cmap('rainbow')
        n_colors = len(unit_ids_all)
        colors = [cm(1. * i / n_colors) for i in range(n_colors)]

    day_ids = np.unique(df['day_id'].values)
    day_displace = np.arange(len(day_ids)) * day_displace_factor

    fig = plt.figure(figsize=figsize)
    ax = fig.add_subplot(111, projection='3d')

    # Centroid of each unit over all days; used as the anchor for the first filtering pass.
    centroids_all = {}
    for unit_id in unit_ids_all:
        points = df.loc[df['unit_id'] == unit_id, ['waveform_pc1', 'waveform_pc2']].values
        centroids_all[str(unit_id)] = np.mean(points, axis=0)

    # Per-day filtered centroids, keyed by str(unit_id); one dict per day, in day order.
    centroids_all_day = []
    for day_pos, day_id in enumerate(day_ids):
        df_day = df.loc[df['day_id'] == day_id]
        centroids = {}
        for unit_id in np.unique(df_day['unit_id'].values):
            key = str(unit_id)
            points = df_day.loc[df_day['unit_id'] == unit_id, ['waveform_pc1', 'waveform_pc2']].values

            # Pass 1: points within 2*distance of the unit's all-day centroid.
            # The random permutation only changes drawing order (global numpy RNG).
            dist_to_global = np.linalg.norm(points - centroids_all[key], axis=1)
            order = np.random.permutation(len(dist_to_global))
            kept = points[order, :][dist_to_global[order] < distance * 2, :]

            # Pass 2: points within `distance` of the centroid of the pass-1 points.
            centroid = np.mean(kept, axis=0)
            dist_to_local = np.linalg.norm(kept - centroid, axis=1)
            order = np.random.permutation(len(dist_to_local))
            kept = kept[order, :][dist_to_local[order] < distance, :]

            centroids[key] = np.mean(kept, axis=0)

            # Embedding axis 2 on x, axis 1 on y, day offset on z.
            ax.scatter(kept[:, 1], kept[:, 0],
                       day_displace[day_pos] * np.ones((kept.shape[0],)) + 0.1,
                       edgecolor='grey', s=10, linewidth=0.1, alpha=0.5,
                       color=colors[color_index[key]], zorder=-1)

        centroids_all_day.append(centroids)

    # Link each unit's centroid on day d to its centroid on the most recent earlier day
    # on which it was present. A unit seen for the first time gets a zero-length segment.
    for day_pos in range(1, len(day_ids)):
        for key, centroid_curr in centroids_all_day[day_pos].items():
            centroid_past = centroid_curr
            day_past = day_pos
            for past_pos in range(day_pos):
                if key in centroids_all_day[past_pos]:
                    centroid_past = centroids_all_day[past_pos][key]
                    day_past = past_pos

            ax.plot([centroid_past[1], centroid_curr[1]],
                    [centroid_past[0], centroid_curr[0]],
                    [day_displace[day_past] + 1, day_displace[day_pos] + 1],
                    color=colors[color_index[key]],
                    marker='o',
                    markersize=5,
                    markerfacecolor='None',
                    markeredgewidth=1,
                    markeredgecolor='black',
                    alpha=1.,
                    linewidth=3,
                    zorder=1000)

    ax.set_xticks([])
    ax.set_yticks([])
    # Axis styling is disabled; to enable it, call e.g.
    # set_ax_style(ax, 'umap2', 'umap1', 'Mouse age (months)', 1, 1, z_scale)
    ax.azim = azim
    ax.elev = elev
    ax.dist = dist

    return ax

def set_ax_style(ax, x_label, y_label, z_label,
                x_scale, y_scale, z_scale, background=False):
    """
    Label a 3-D axis, optionally hide its background, and rescale its projection.

    Parameters
    ----------
    ax: a 3-D matplotlib axis.
    x_label, y_label, z_label: axis labels, each drawn with labelpad=20.
    x_scale, y_scale, z_scale: relative axis scales. They are normalised by their
        maximum and applied by wrapping the axis' projection matrix.
    background: when False, panes get white edges and no fill, spines are hidden
        and the axis is turned off.

    Returns
    -------
    The same axis, with ``ax.get_proj`` replaced by the scaled projection.
    """
    for set_label, text in ((ax.set_xlabel, x_label),
                            (ax.set_ylabel, y_label),
                            (ax.set_zlabel, z_label)):
        set_label(text, labelpad=20)

    if not background:
        panes = (ax.xaxis.pane, ax.yaxis.pane, ax.zaxis.pane)
        for pane in panes:
            pane.set_edgecolor('w')
        for side in ('right', 'top', 'left', 'bottom'):
            ax.spines[side].set_visible(False)
        ax.axis('off')
        for pane in panes:
            pane.fill = False

    # Diagonal scaling matrix normalised so the largest axis scale is 1; the
    # homogeneous coordinate stays 1.
    scale = np.diag([x_scale, y_scale, z_scale, 1.0])
    scale = scale * (1.0 / scale.max())
    scale[3, 3] = 1.0

    def short_proj():
        # Call the class implementation directly, because ax.get_proj is replaced below.
        return np.dot(Axes3D.get_proj(ax), scale)

    ax.get_proj = short_proj
    return ax

In [None]:
# Embed the waveforms of every unit/day selected above (info_day_select) with UMAP
# and cache the result as a CSV in pack_folder.
df = dimensionality_reduction_cal(info_day_select, method='umap',  n_components=2, 
                                  umap_params={'n_neighbors': 20, 'random_state': 2, 'min_dist': 0.1, 'metric': 'euclidean'})

df.to_csv(pack_folder+'dimensionality_reduction_info.csv')

In [None]:
# Reload the cached embedding (the CSV also carries the saved index as an extra column).
df= pd.read_csv(pack_folder+'dimensionality_reduction_info.csv')

# Colour table indexed directly by unit id: row k is the 'gist_rainbow' colour for
# unit id k, for k in 0..max unit id.
cm = pylab.get_cmap('gist_rainbow')
NUM_COLORS = max(np.unique(df['unit_id'].values)) + 1
colors = []
for i in range(NUM_COLORS):
    colors.append(cm(1. * i / NUM_COLORS))
colors = np.array(colors)

In [None]:
# Colours of the units present in df, in sorted unit-id order.
colors_plot = colors[np.unique(df['unit_id'].values)]
# Points filtered to within distance=1 of each unit's per-day centroid.
ax = dimensionality_reduction_stability_plot(df, method='umap', figsize=(20, 20), shank_displace_factor=10, 
                                             day_displace_factor=40, distance=1, nb_points=20000, 
                                             azim = 270, elev = 15, dist = 7, colors=colors_plot, z_scale=2)
plt.savefig(pack_folder + f'umap_across_days.pdf', dpi=300)

In [None]:
# Same plot without spatial filtering: distance=np.inf keeps every point.
# nb_points is not used by the plot function, so its value has no effect here.
colors_plot = colors[np.unique(df['unit_id'].values)]
ax = dimensionality_reduction_stability_plot(df, method='umap', figsize=(20, 20), shank_displace_factor=10, 
                                             day_displace_factor=40, distance=np.inf, nb_points=np.exp(100), 
                                             azim = 270, elev = 15, dist = 7, colors=colors_plot, z_scale=2)
plt.savefig(pack_folder + f'umap_across_days_not_confined.pdf', dpi=300)

# Waveform Metrics

In [None]:
# Per-spike waveform features computed by `features_5` and summarised per unit.
_WE_FEATURE_NAMES = ('peak_to_valley', 'peak_trough_ratio', 'halfwidth',
                     'repolarization_slope', 'recovery_slope')

# UMAP settings used by `feature_cal` when the caller passes none.
_DEFAULT_UMAP_PARAMS = {'n_neighbors': 20, 'random_state': 2, 'min_dist': 0.1,
                        'metric': 'euclidean', 'n_components': 2}


def _collect_unit_waveforms(slice_curated_we, probe_groups, points_nb=None):
    """Gather per-spike waveforms and labels for every unit of a waveform extractor.

    Each unit is assigned to the group ``probe_groups[c]`` of its extremum
    (``peak_sign='neg'``) template channel ``c``. Multi-channel waveforms are
    restricted to the channels of that group.

    Parameters
    ----------
    slice_curated_we : WaveformExtractor
        Provides ``sorting.unit_ids``, ``get_waveforms`` and ``get_template``.
    probe_groups : np.ndarray
        Group id for each channel index.
    points_nb : int or None
        If given, keep only the ``points_nb`` spikes of each unit whose group
        waveform has the smallest summed squared deviation from the unit mean.

    Returns
    -------
    tuple
        ``(unit_ids, amplitudes, waveforms, waveforms_all_channel,
        waveform_labels, shank_ids)``. The parts are:

        - the peak-to-peak template amplitude on each unit's extremum channel;
        - the extremum-channel waveforms of all units, stacked along axis 0;
        - the group-channel waveforms, flattened to one row per spike;
        - the unit id and the group id of every row, stored as floats.

    Raises
    ------
    ValueError
        If the extractor contains no units.
    """
    unit_ids = slice_curated_we.sorting.unit_ids
    if len(unit_ids) == 0:
        raise ValueError('The waveform extractor contains no units.')
    extremum_channel_ids = st.get_template_extremum_channel(slice_curated_we, peak_sign='neg')

    amplitudes = []
    waveforms = []
    waveforms_all_channel = []
    waveform_labels = []
    shank_ids = []
    for unit_id in unit_ids:
        extremum_channel_id = extremum_channel_ids[unit_id]
        shank_id = probe_groups[extremum_channel_id]
        select_channels = np.where(probe_groups == shank_id)[0]

        waveform = slice_curated_we.get_waveforms(unit_id=unit_id)
        template = slice_curated_we.get_template(unit_id=unit_id)
        amplitude = np.max(template, axis=0) - np.min(template, axis=0)
        amplitudes.append(amplitude[extremum_channel_id])

        if points_nb is not None:
            # Keep the spikes closest (summed squared error over samples and
            # group channels) to the unit's mean group waveform.
            waveform_shank = waveform[:, :, select_channels]
            deviation = np.square(waveform_shank - np.mean(waveform_shank, axis=0)).sum(axis=1).sum(axis=1)
            waveform = waveform[np.argsort(deviation)[:points_nb], :, :]

        # Fortran order: each row is channel 0's samples, then channel 1's, ...
        waveforms_all_channel.append(
            np.reshape(waveform[:, :, select_channels], (waveform.shape[0], -1), order='F'))
        waveform = waveform[:, :, extremum_channel_id]
        waveforms.append(waveform)
        waveform_labels.append(np.ones((waveform.shape[0],)) * unit_id)
        shank_ids.append(np.ones((waveform.shape[0],)) * shank_id)

    return (unit_ids, np.array(amplitudes), np.vstack(waveforms),
            np.vstack(waveforms_all_channel), np.hstack(waveform_labels),
            np.hstack(shank_ids))


def _reject_outliers(values, n_std=3):
    """Return ``values`` without NaNs and without points beyond ``n_std`` std of the mean.

    The bounds are inclusive, so a constant feature (std == 0) keeps its values
    instead of being filtered down to an empty array.
    """
    values = values[~np.isnan(values)]
    if values.size == 0:
        return values
    return values[np.abs(values - np.mean(values)) <= n_std * np.std(values)]


def _l_ratio(points, labels, unit_id):
    """L-ratio of cluster ``unit_id`` among ``points`` (rows labelled by ``labels``).

    The score is the sum, over points of other clusters, of ``1 - chi2.cdf(d**2, dof)``.
    Here ``d`` is the point's Mahalanobis distance to the cluster mean, using the
    cluster covariance. The sum is divided by the cluster size.

    Returns NaN when the cluster or the remaining points hold fewer than 2 rows,
    or when the cluster covariance is singular.
    """
    from scipy.spatial.distance import cdist
    from scipy.stats import chi2

    pcs_for_this_unit = points[labels == unit_id]
    pcs_for_other_units = points[labels != unit_id]
    if min(pcs_for_this_unit.shape[0], pcs_for_other_units.shape[0]) < 2:
        return np.nan
    try:
        VI = np.linalg.inv(np.cov(pcs_for_this_unit.T))
    except np.linalg.LinAlgError:
        return np.nan

    mean_value = np.mean(pcs_for_this_unit, axis=0, keepdims=True)
    mahalanobis_other = cdist(mean_value, pcs_for_other_units, 'mahalanobis', VI=VI)[0]
    dof = pcs_for_this_unit.shape[1]  # number of features (embedding dimensions)
    # Normalise by the size of this cluster, not by the number of other spikes.
    return np.sum(1 - chi2.cdf(np.square(mahalanobis_other), dof)) / pcs_for_this_unit.shape[0]


def _cluster_quality_scores(unit_ids, waveforms_all_channel, waveform_labels, shank_ids, umap_params):
    """Per-unit L-ratio and silhouette score, computed shank by shank in a UMAP embedding.

    On every shank with at least two units, an equal number of spikes (the
    smallest unit's count) is taken from each unit and embedded with UMAP.
    The shank's single silhouette score is assigned to all of its units.

    Returns ``(l_ratios, sil_scores)``, two float arrays aligned with ``unit_ids``.
    Units on single-unit shanks stay NaN.
    """
    from sklearn.metrics import silhouette_score

    l_ratios = np.full(len(unit_ids), np.nan)
    sil_scores = np.full(len(unit_ids), np.nan)
    for shank_id in np.unique(shank_ids):
        labels_on_shank = waveform_labels[shank_ids == shank_id]
        unit_ids_shank = np.unique(labels_on_shank)
        if len(unit_ids_shank) < 2:
            continue  # separation scores need at least two clusters

        n_per_unit = min(np.sum(labels_on_shank == u) for u in unit_ids_shank)
        waveforms_shank = np.vstack([waveforms_all_channel[waveform_labels == u][:n_per_unit]
                                     for u in unit_ids_shank])
        labels_shank = np.repeat(unit_ids_shank, n_per_unit)

        mapper_ = umap.UMAP(n_neighbors=umap_params['n_neighbors'],
                            random_state=umap_params['random_state'],
                            min_dist=umap_params['min_dist'],
                            n_components=umap_params['n_components'],
                            metric=umap_params['metric']).fit(waveforms_shank)
        shank_waveforms_2d = mapper_.transform(waveforms_shank)

        sil_scores[np.isin(unit_ids, unit_ids_shank)] = silhouette_score(shank_waveforms_2d, labels_shank)
        for unit_id in unit_ids_shank:
            l_ratios[unit_ids == unit_id] = _l_ratio(shank_waveforms_2d, labels_shank, unit_id)
    return l_ratios, sil_scores


def feature_cal(slice_curated_we, quality_scores_cal=True, points_nb=None, umap_params=None):
    """Compute per-unit waveform and quality metrics for a waveform extractor.

    Waveform shape features (``features_5``) are computed for every spike on the
    unit's extremum channel. They are then averaged per unit after dropping NaNs
    and values beyond 3 standard deviations. ``snr`` and ``firing_rate`` come
    from ``st.compute_quality_metrics``.

    Parameters
    ----------
    slice_curated_we : WaveformExtractor
        Curated waveforms of the units to characterise.
    quality_scores_cal : bool
        Also compute ``l_ratio`` and ``sil_scores``. Both are computed per shank
        in a UMAP embedding of balanced spike sets.
    points_nb : int or None
        Keep at most this many spikes per unit: those closest to the unit's
        mean waveform.
    umap_params : dict or None
        UMAP settings with keys ``n_neighbors``, ``random_state``, ``min_dist``,
        ``n_components`` and ``metric``. Defaults to ``_DEFAULT_UMAP_PARAMS``.

    Returns
    -------
    dict
        Maps each metric name to one value per unit. The keys are the five
        waveform features (lists), ``amplitude``, ``snr``, ``firing_rate`` and,
        when requested, ``l_ratio`` and ``sil_scores``.
    """
    if umap_params is None:
        umap_params = _DEFAULT_UMAP_PARAMS

    probe_groups = np.arange(0, 32)
    # NOTE(review): the returned components are not used below. With
    # load_if_exists=False this recomputes PCs on the extractor; confirm it is still needed.
    st.compute_principal_components(slice_curated_we, load_if_exists=False, n_components=2,
                                    mode='by_channel_local')
    qc_metrics = st.compute_quality_metrics(slice_curated_we, metric_names=['snr', 'firing_rate'])
    sampling_frequency = slice_curated_we.sorting.get_sampling_frequency()

    (unit_ids, amplitudes, waveforms, waveforms_all_channel,
     waveform_labels, shank_ids) = _collect_unit_waveforms(slice_curated_we, probe_groups, points_nb)

    we_metrics = features_5(waveforms, sampling_frequency, feature_names=list(_WE_FEATURE_NAMES))

    metrics = {}
    for we_feature_name in _WE_FEATURE_NAMES:
        feature = we_metrics[we_feature_name]
        metrics[we_feature_name] = [np.mean(_reject_outliers(feature[waveform_labels == unit_id]))
                                    for unit_id in unit_ids]

    metrics['amplitude'] = amplitudes
    metrics['snr'] = qc_metrics['snr'].values
    metrics['firing_rate'] = qc_metrics['firing_rate'].values

    if quality_scores_cal:
        metrics['l_ratio'], metrics['sil_scores'] = _cluster_quality_scores(
            unit_ids, waveforms_all_channel, waveform_labels, shank_ids, umap_params)

    return metrics


def feature_dist_cal(slice_curated_we, quality_scores_cal=True, points_nb=None,
                     umap_params=None, save_folder='./'):
    """Save per-unit histograms of the waveform features, one PDF per feature.

    Each ``<save_folder>/waveform_metric_<feature>.pdf`` holds a 20-bin histogram
    for every unit, four units per row. Each histogram shows that unit's
    per-spike feature values with NaNs and values beyond 3 standard deviations removed.

    Parameters
    ----------
    slice_curated_we : WaveformExtractor
        Curated waveforms of the units to plot.
    quality_scores_cal, umap_params :
        Unused; accepted so the signature mirrors ``feature_cal``.
    points_nb : int or None
        Keep at most this many spikes per unit: those closest to the unit's
        mean waveform.
    save_folder : str
        Folder that receives the PDFs.

    Returns
    -------
    None
    """
    probe_groups = np.arange(0, 32)
    sampling_frequency = slice_curated_we.sorting.get_sampling_frequency()
    unit_ids, _, waveforms, _, waveform_labels, _ = _collect_unit_waveforms(
        slice_curated_we, probe_groups, points_nb)

    we_metrics = features_5(waveforms, sampling_frequency, feature_names=list(_WE_FEATURE_NAMES))

    # Size the grid from this extractor's own units; squeeze=False keeps `axs`
    # 2-D even when all units fit on a single row.
    n_cols = 4
    n_rows = int(np.ceil(len(unit_ids) / n_cols))
    for we_feature_name in _WE_FEATURE_NAMES:
        feature = we_metrics[we_feature_name]
        fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)
        for ax, unit_id in zip(axs.flat, unit_ids):
            ax.hist(_reject_outliers(feature[waveform_labels == unit_id]), bins=20)
            ax.set_title(f'{we_feature_name} - Unit {unit_id}')
        fig.savefig(save_folder + '/waveform_metric_' + we_feature_name + '.pdf', dpi=600)

    return None

In [None]:
# Every metric name `feature_cal` can report. 'sil_scores' and 'l_ratio' are
# only produced when `quality_scores_cal` is True.
feature_names_all = [
    'snr', 'firing_rate', 'amplitude',
    'peak_to_valley', 'peak_trough_ratio', 'halfwidth',
    'repolarization_slope', 'recovery_slope',
    'sil_scores', 'l_ratio',
]
quality_scores_cal = False  # skip the silhouette / L-ratio computation
info_load_if_exists = False

In [None]:
"""Calculate waveform features for all-days"""

# The cluster-quality scores exist only when `quality_scores_cal` is set. Drop
# them by name rather than by position, so reordering `feature_names_all` is safe.
if quality_scores_cal:
    feature_names = feature_names_all
else:
    feature_names = [name for name in feature_names_all if name not in ('sil_scores', 'l_ratio')]

# Calculate the 'average' metrics for units (all-days)
umap_params = {'n_neighbors': 10, 'random_state': 2, 'min_dist': 0.05, 'metric': 'euclidean', 'n_components': 10}
metrics = feature_cal(we, quality_scores_cal=quality_scores_cal, points_nb=5000, umap_params=umap_params)

# Create histograms of select metrics for units (all-days)
feature_dist_cal(we, quality_scores_cal=quality_scores_cal, points_nb=5000, umap_params=umap_params, save_folder=pack_folder)

In [None]:
"""Calculate waveform features for individual days"""

# One metrics dict per entry of `date_id_all`, in the same order
metrics_coll = []

umap_params = {'n_neighbors': 10, 'random_state': 2, 'min_dist': 0.05, 'metric': 'euclidean', 'n_components': 10}
for day_id, day_name in enumerate(date_id_all):
    # Folders for this day; the merged waveforms are stored under the day's pack folder
    data_folder_day = f'{data_folder_all}{day_name}/'
    pack_folder_day = f'{pack_folder}{day_name}/'
    waveform_folder_day = f'{pack_folder_day}waveforms_merged/'
    we_day = WaveformExtractor.load_from_folder(waveform_folder_day)

    day_kwargs = dict(quality_scores_cal=quality_scores_cal, points_nb=5000, umap_params=umap_params)

    # 'Average' metrics for the units of this day
    metrics_day = feature_cal(we_day, **day_kwargs)
    metrics_coll.append(metrics_day)

    # Per-unit feature histograms for this day, written to the day's pack folder
    feature_dist_cal(we_day, save_folder=pack_folder_day, **day_kwargs)

# Power Spectrum (Welch)

In [None]:
sample_rate = recording_cmr.get_sampling_frequency()
# Transposed so that each row of `trace_cmr` is one channel's trace
trace_cmr = np.transpose(recording_cmr.get_traces())

In [None]:
"""Computes power spectrum using Welch's method to reduce noise"""

# One panel per channel, four per row. The grid is derived from the channel count
# (8x4 for 32 channels); squeeze=False keeps `axs` 2-D for a single row.
n_cols = 4
n_rows = max(1, int(np.ceil(trace_cmr.shape[0] / n_cols)))
fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)

for i in range(trace_cmr.shape[0]):
    ax = axs[i // n_cols, i % n_cols]

    # Welch estimate with a flattop window, 1024-sample segments and 'spectrum'
    # scaling (power); its square root is plotted on a log axis.
    f, Pxx_spec = signal.welch(trace_cmr[i], sample_rate, 'flattop', 1024, scaling='spectrum')

    ax.semilogy(f, np.sqrt(Pxx_spec))
    ax.set_title(f'Shank {i+1}')

In [None]:
"""Computes power spectrum using Welch's method to reduce noise (individual days)"""

for day_id, day_name in enumerate(date_id_all):

    data_folder_day = data_folder_all + day_name + '/'
    pack_folder_day = pack_folder + day_name + '/'

    # Load the saved recording, band-pass filter it, and apply a global average reference
    recording_save_path_day = data_folder_day + 'recordings/'
    recording_day = spikeinterface.core.base.BaseExtractor.load_from_folder(recording_save_path_day)
    recording_f_day = st.preprocessing.bandpass_filter(recording_day, freq_min=freq_min, freq_max=freq_max)
    recording_cmr_day = st.preprocessing.common_reference(recording_f_day, reference='global', operator='average')

    # Plot power spectra: one panel per channel, four per row
    trace_cmr_day = recording_cmr_day.get_traces().T
    sample_rate_day = recording_cmr_day.get_sampling_frequency()

    n_cols = 4
    n_rows = max(1, int(np.ceil(trace_cmr_day.shape[0] / n_cols)))
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(20, 5 * n_rows), squeeze=False)

    for i in range(trace_cmr_day.shape[0]):
        ax = axs[i // n_cols, i % n_cols]

        # Compute the power spectrum using Welch's method
        f, Pxx_spec = signal.welch(trace_cmr_day[i], sample_rate_day, 'flattop', 1024, scaling='spectrum')

        # Log y-axis, consistent with the all-days spectrum cell
        ax.semilogy(f, np.sqrt(Pxx_spec))
        ax.set_title(f'Shank {i+1}')

    fig.savefig(pack_folder_day + '/power_spectrum.pdf', dpi=600)