# Imports

In [1]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# import warnings
# warnings.filterwarnings('ignore')

In [2]:
# Add bombcell to the Python path if it was not installed with pip.
# NOTE(review): this assumes the notebook is launched from a folder that sits
# directly inside the package root - confirm for your checkout.
demo_dir = Path(os.getcwd())
pyBombCell_dir = demo_dir.parent
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if str(pyBombCell_dir) not in sys.path:
    sys.path.append(str(pyBombCell_dir))

In [3]:
%load_ext autoreload
%autoreload 2

import bombcell as bc

# Define data paths

By default, the paths below point to BombCell's toy dataset.

In [4]:
ks_dir = demo_dir / 'toy_data'  # Replace with your kilosort directory
raw_dir = None  # Leave as None if there is no raw data; otherwise set it to the path of your raw data
# "~" stands for the home directory and forward slashes also work on Windows.
# The "~" is expanded here because Python's file APIs do not expand it on
# their own; the value is kept as a plain string.
save_path = str(Path("~/Downloads/bombcell_plots").expanduser())

# If a raw data directory with a meta folder is not given,
# please input the gain manually
gain_to_uV = np.nan

In [5]:
# Load the spike-sorting output stored in ks_dir
(spike_times_samples, spike_templates, template_waveforms, template_amplitudes,
 pc_features, pc_features_idx, channel_positions, good_channels) = bc.load_ephys_data(ks_dir)

# ephys_raw_data and the returned gain will be None if no raw_dir is given.
# Only adopt the returned gain when there is one, so that the manually
# entered gain_to_uV is not silently replaced by None.
ephys_raw_data, meta_path, gain_from_raw = bc.manage_if_raw_data(raw_dir)
if gain_from_raw is not None:
    gain_to_uV = gain_from_raw


In [6]:
# Build the default parameter set for this dataset; the meta path returned by
# bc.manage_if_raw_data is forwarded as the ephys meta directory.
param = bc.default_parameters(
    ks_dir,
    raw_dir,
    ephys_meta_dir=meta_path,
)

In [7]:
# Extract or load in raw waveforms
if raw_dir is not None:  # identity check is the idiomatic test for None
    raw_waveforms_full, raw_waveforms_peak_channel, SNR = bc.extract_raw_waveforms(
        param,
        spike_templates.squeeze(),
        spike_times_samples.squeeze(),
        param['re_extract_raw'],
        save_path,
    )
else:
    # No raw data: leave the raw-waveform outputs empty and switch the
    # extraction step off in the parameters.
    raw_waveforms_full = None
    raw_waveforms_peak_channel = None
    SNR = None
    param['extract_raw_waveforms'] = False  # No waveforms to extract!

In [8]:
# Look up the peak channel of every template before deduplication
max_channels = bc.get_waveform_max_channel(template_waveforms)

# Remove duplicate spikes. The optional arrays are passed by keyword; every
# array handed in is rebound to the version returned by the call.
optional_arrays = dict(
    pc_features=pc_features,
    raw_waveforms_full=raw_waveforms_full,
    raw_waveforms_peak_channel=raw_waveforms_peak_channel,
    signal_to_noise_ratio=SNR,
)
dedup_outputs = bc.remove_duplicate_spikes(
    spike_times_samples, spike_templates, template_amplitudes,
    max_channels, save_path, param,
    **optional_arrays,
)
(non_empty_units, duplicate_spike_idx,
 spike_times_samples, spike_templates, template_amplitudes,
 pc_features, raw_waveforms_full, raw_waveforms_peak_channel,
 signal_to_noise_ratio, max_channels) = dedup_outputs


# Divide the recording into time chunks (spike times converted to seconds)
spike_times_seconds = spike_times_samples / param['ephys_sample_rate']
first_spike_s = np.min(spike_times_seconds)
last_spike_s = np.max(spike_times_seconds)
# NOTE(review): np.arange excludes its stop value, so the chunk edges end
# before the last spike time - confirm whether the tail should form a chunk.
time_chunks = (
    np.arange(first_spike_s, last_spike_s, param['delta_time_chunk'])
    if param['compute_time_chunks']
    else np.array((first_spike_s, last_spike_s))
)

# TODO (original author's note): this should be obtained as part of removing
# duplicate spikes.
unique_templates = np.unique(spike_templates)

In [9]:
# Initialize quality metrics dictionary
n_units = unique_templates.size
# Use the SNR returned by bc.remove_duplicate_spikes rather than the
# pre-deduplication `SNR`, in line with the other arrays that call rebinds.
quality_metrics = bc.create_quality_metrics_dict(n_units, snr=signal_to_noise_ratio)
quality_metrics['max_channels'] = max_channels
param['use_hill_method'] = True  # use the old method for RPVs
# NOTE(review): time_chunks was already built from the previous value of
# param['compute_time_chunks']; changing it here does not rebuild time_chunks.
param['compute_time_chunks'] = False

# Complete with remaining quality metrics
quality_metrics, times = bc.get_all_quality_metrics(
    unique_templates,
    spike_times_seconds,
    spike_templates,
    template_amplitudes,
    time_chunks,
    pc_features,
    pc_features_idx,
    quality_metrics,
    raw_waveforms_full,
    channel_positions,
    template_waveforms,
    param,
)

Computing BombCell quality metrics:   0%|          | 0/15 units

  fit_params = curve_fit(gaussian_cut, amp_bin_gaussian, spike_counts_per_amp_bin_gaussian, p0 = p0, ftol = 1e-3, xtol = 1e-3, maxfev = 10000)[0]
  spike_depth_in_channels = np.sum(pc_channel_pos_weights[np.newaxis, :] * pc_features_pc1 ** 2, axis = 1) / np.nansum(pc_features_pc1 ** 2, axis = 1)
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  p_missing = ((surrogate_area - np.sum(spike_counts_per_amp_bin) * bin_step) / surrogate_area) * 100


In [10]:
# Derive a unit-type label for every unit from the quality metrics
# (labels such as MUA / NOISE appear in the table output below).
unit_type, unit_type_string = bc.get_quality_unit_type(param, quality_metrics)

# Assemble the quality-metric table and preview its first rows
qm_table = bc.make_qm_table(
    quality_metrics,
    param,
    unique_templates,
    unit_type_string,
)
qm_table.head()

Unnamed: 0,Original ID,NaN result,Peaks,Troughs,Waveform Min Length,Waveform Max Length,Baseline,Spatial Decay,Min Spikes,Missing Spikes,RPVs,Presence Ratio,Not Somatic,Good Unit
0,0,False,False,False,False,False,False,False,False,True,True,False,True,MUA
1,1,False,False,True,False,False,True,True,False,False,True,False,False,NOISE
2,2,False,False,False,False,False,False,False,False,False,True,False,True,MUA
3,3,False,False,False,False,False,False,False,True,False,False,False,True,MUA
4,4,False,False,False,False,False,False,False,False,False,True,False,True,MUA


To save the results as a Parquet file, either PyArrow or fastparquet must be installed.

In [11]:
# Write the quality metrics, unit labels and raw-waveform arrays to save_path
bc.save_results(
    quality_metrics,
    unit_type_string,
    unique_templates,
    param,
    raw_waveforms_full,
    raw_waveforms_peak_channel,
    save_path,
)