In [None]:
%load_ext autoreload
%autoreload 2


import os 
import numpy as np
import pandas as pd
# Hack for running from a downloaded copy of the repository instead of a pip install:
# change the working directory to the parent of the current directory before the
# `bombcell` imports below (presumably so the package is found there - confirm).
# NOTE: os.chdir changes the working directory for the rest of the kernel session,
# so any relative path used later is resolved against BombCell_dir.
current_dir = os.getcwd()
BombCell_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
os.chdir(BombCell_dir)
import bombcell.extract_raw_waveforms as erw
import bombcell.load_ephys_data as led
import bombcell.default_parameters as params
import matplotlib.pyplot as plt
import bombcell.quality_metrics as qm
import time
import bombcell.helper_functions as hf
from pathlib import Path
import bombcell.save_utils as su

In [43]:
# Session locations (machine-specific absolute paths - edit these before running).
# ks_dir    : passed to led.load_ephys_data and params.default_parameters.
# raw_dir   : passed to erw.manage_data_compression; compared against None before use.
# save_path : passed to the waveform extraction, duplicate-spike removal and save steps.
ks_dir = r'c:\Users\Experiment\Data\BC_dev_data\JF093_2023-03-06_site1\kilosort2\site1'
raw_dir = r'c:\Users\Experiment\Data\BC_dev_data\JF093_2023-03-06_site1\site1'
save_path = r'c:\Users\Experiment\Data\BC_dev_data\JF093_2023-03-06_site1\Python results'

# If a raw data directory with a meta file is not given, enter the gain manually here.
# When raw_dir is given, this value is overwritten by led.get_gain_spikeglx(meta_path)
# in the loading cell.
gain_to_uV = np.nan 

In [None]:
# Load the spike-sorting output stored in ks_dir.
(
    spike_times_samples,
    spike_templates,
    template_waveforms,
    template_amplitudes,
    pc_features,
    pc_features_idx,
    channel_positions,
    good_channels,
) = led.load_ephys_data(ks_dir)

# Drop singleton dimensions from the PC features.
pc_features = pc_features.squeeze()

# Always define the raw-data handles so the following cells do not fail with a
# NameError when no raw directory is supplied.
# NOTE(review): confirm that params.default_parameters accepts None for these.
ephys_raw_data = None
meta_path = None
if raw_dir is not None:
    # Locate (and, if needed, decompress) the raw recording next to raw_dir and
    # read the gain from its meta file; this overrides the manually entered gain.
    ephys_raw_data, meta_path = erw.manage_data_compression(raw_dir, decompressed_data_local=raw_dir)
    gain_to_uV = led.get_gain_spikeglx(meta_path)

In [45]:
# Build the parameter dictionary used by every later step of the pipeline.
param = params.default_parameters(
    ks_dir,
    ephys_raw_data,
    ephys_meta_dir=meta_path,
)

In [None]:
# Maximum channel per template, as reported by the quality-metrics module.
max_channels = qm.get_waveform_max_channel(template_waveforms)

# Containers filled in by later cells.
GUI_data = {}
quality_metrics = {'max_channels': max_channels}

# Re-extract the raw waveforms or load previously saved ones,
# depending on param['re_extract_raw'].
raw_waveforms_full, raw_waveforms_peak_channel, SNR = erw.extract_raw_waveforms(
    param,
    spike_templates.squeeze(),
    spike_times_samples.squeeze(),
    param['re_extract_raw'],
    save_path,
)

In [7]:
# Remove duplicate spikes. NOTE: this cell overwrites its own inputs
# (spike_times_samples, spike_templates, ...), so it is not idempotent:
# re-run the loading cells first if it has to be executed again.
(
    non_empty_units,
    duplicate_spike_idx,
    spike_times_samples,
    spike_templates,
    template_amplitudes,
    pc_features,
    raw_waveforms_full,
    raw_waveforms_peak_channel,
    signal_to_noise_ratio,
    max_channels,
) = qm.remove_duplicate_spikes(
    spike_times_samples,
    spike_templates,
    template_amplitudes,
    max_channels,
    save_path,
    param,
    pc_features=pc_features,
    raw_waveforms_full=raw_waveforms_full,
    raw_waveforms_peak_channel=raw_waveforms_peak_channel,
    signal_to_noise_ratio=SNR,
)

pc_features = pc_features.squeeze()

# Time chunks: convert spike times from samples to seconds, then either split the
# recording into chunks of param['delta_time_chunk'] or use one [first, last] span.
spike_times_seconds = spike_times_samples / param['ephys_sample_rate']
first_spike_s = np.min(spike_times_seconds)
last_spike_s = np.max(spike_times_seconds)
time_chunks = (
    np.arange(first_spike_s, last_spike_s, param['delta_time_chunk'])
    if param['compute_time_chunks']
    else np.array((first_spike_s, last_spike_s))
)

# TODO: this should be returned as part of removing duplicate spikes.
unique_templates = np.unique(spike_templates)

In [None]:

# Initialise the quality-metric containers, one entry per remaining unit.
n_units = unique_templates.size
# Use the signal-to-noise values returned by qm.remove_duplicate_spikes rather than
# the pre-deduplication `SNR`, so they stay aligned with the other filtered arrays
# (n_units, max_channels, raw waveforms) produced by that same call.
quality_metrics = hf.create_quality_metrics_dict(n_units, snr=signal_to_noise_ratio)
quality_metrics['max_channels'] = max_channels

param['use_hill_method'] = True  # use the old method for RPVs
# NOTE(review): time_chunks was already built in the previous cell using the earlier
# value of param['compute_time_chunks']; changing the flag here does not rebuild it.
param['compute_time_chunks'] = False

# Compute every quality metric; `times` is the second value returned by the helper.
quality_metrics, times = hf.get_all_quality_metrics(
    unique_templates,
    spike_times_seconds,
    spike_templates,
    template_amplitudes,
    time_chunks,
    pc_features,
    pc_features_idx,
    quality_metrics,
    raw_waveforms_full,
    channel_positions,
    template_waveforms,
    param,
)



In [9]:
# Per-unit boolean flags: each one is True where the metric violates the
# corresponding threshold stored in `param`.

# --- classify noise ---
nan_result = np.isnan(quality_metrics['n_peaks'])

too_many_peaks = quality_metrics['n_peaks'] > param['max_n_peaks']

too_many_troughs = quality_metrics['n_troughs'] > param['max_n_troughs']

too_short_waveform = quality_metrics['waveform_duration_peak_trough'] < param['min_wave_duration']

too_long_waveform = quality_metrics['waveform_duration_peak_trough'] > param['max_wave_duration']

too_noisy_baseline = quality_metrics['waveform_baseline'] > param['max_wave_baseline_fraction']

# Spatial decay: flagged when exp_decay lies above the "min" slope or below the "max" slope.
too_shallow_decay = quality_metrics['exp_decay'] > param['min_spatial_decay_slope']
too_steep_decay = quality_metrics['exp_decay'] < param['max_spatial_decay_slope']
# Backward-compatible alias for the original, misspelled variable name.
to_steap_decay = too_steep_decay

# --- classify as mua ---
# TODO: decide whether ALL or ANY of the criteria must be met.

too_few_total_spikes = quality_metrics['n_spikes'] < param['min_num_spikes_total']

too_many_spikes_missing = quality_metrics['percent_missing_gaussian'] > param['max_perc_spikes_missing']

too_low_presence_ratio = quality_metrics['presence_ratio'] < param['min_presence_ratio']

too_many_RPVs = quality_metrics['fraction_RPVs'] > param['max_RPV']

# Optional criteria: these flags only exist when the matching option is enabled.
if param['extract_raw_waveforms']:
    too_small_amplitude = quality_metrics['raw_amplitude'] < param['min_amplitude']

    too_small_SNR = quality_metrics['signal_to_noise_ratio'] < param['min_SNR']

if param['compute_drift']:
    too_large_drift = quality_metrics['max_drift_estimate'] > param['max_drift']

# Determine whether each unit is somatic or non-somatic.
# These two thresholds are set here, replacing any values already present in `param`.
param['non_somatic_trough_peak_ratio'] = 1.25
param['non_somatic_peak_before_to_after_ratio'] = 1.2

# Encoding: is_somatic == 1 -> somatic (starting value for every unit),
#           is_somatic == 0 -> non-somatic (set by any of the criteria below).
is_somatic = np.ones(unique_templates.size)

# Criterion 1: trough is small relative to the larger of the two flanking peaks.
largest_peak = np.maximum(quality_metrics['main_peak_before'], quality_metrics['main_peak_after'])
small_trough_vs_peak = (quality_metrics['trough'] / largest_peak) < param['non_somatic_trough_peak_ratio']
is_somatic[small_trough_vs_peak] = 0

# Criterion 2: the peak before the trough is large relative to the peak after it.
large_first_peak = (
    quality_metrics['main_peak_before'] / quality_metrics['main_peak_after']
) > param['non_somatic_peak_before_to_after_ratio']
is_somatic[large_first_peak] = 0

# Criterion 3: all four peak-size / width conditions hold together.
narrow_dominant_first_peak = (
    (quality_metrics['main_peak_before'] * param['first_peak_ratio'] > quality_metrics['main_peak_after'])
    & (quality_metrics['width_before'] < param['min_width_first_peak'])
    & (quality_metrics['main_peak_before'] * param['min_main_peak_to_trough_ratio'] > quality_metrics['trough'])
    & (quality_metrics['trough_width'] < param['min_width_main_trough'])
)
is_somatic[narrow_dominant_first_peak] = 0

#is_somatic[np.isnan(quality_metrics['trough'])] = 0
quality_metrics['is_somatic_new'] = is_somatic

# A unit is "not somatic" when it was zeroed above. The original compared against 1,
# which flagged the somatic units instead.
not_somatic = is_somatic == 0

In [10]:
# Derive the unit classification (numeric form and string form) from the
# parameters and the computed quality metrics.
unit_type, unit_type_string = hf.get_quality_unit_type(
    param,
    quality_metrics,
)

In [11]:
# Summary table: one row per unit, one column per criterion flag.
# The flag rows and their labels are listed in the same order.
flag_rows = [
    nan_result,
    too_many_peaks,
    too_many_troughs,
    too_short_waveform,
    too_long_waveform,
    too_noisy_baseline,
    too_shallow_decay,
    too_few_total_spikes,
    too_many_spikes_missing,
    too_many_RPVs,
    too_low_presence_ratio,
    not_somatic,
]
row_labels = [
    'Original ID',
    'NaN result',
    'Peaks',
    'Troughs',
    'Waveform Min Length',
    'Waveform Max Length',
    'Baseline',
    'Spatial Decay',
    'Min Spikes',
    'Missing Spikes',
    'RPVs',
    'Presence Ratio',
    'Not Somatic',
    'Good Unit',
]

# Unit IDs go on top, the flags in the middle and the unit type at the bottom.
qm_table_array = np.vstack((unique_templates, np.array(flag_rows), unit_type))

# TODO: do this for the optional params as well.
qm_table = pd.DataFrame(qm_table_array, index=row_labels).T


In [None]:
# Display the per-unit summary table built in the previous cell.
qm_table

In [40]:
# Save the quality metrics, unit labels, parameters and raw waveforms under save_path.
su.save_results(
    quality_metrics,
    unit_type_string,
    unique_templates,
    param,
    raw_waveforms_full,
    raw_waveforms_peak_channel,
    save_path,
)