# Imports

In [59]:
import os
import sys
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import warnings

# JF: I think we want to remove this -> a blanket 'ignore' filter hides every
# warning raised by any code run later in this kernel, so it is now opt-in.
# Prefer silencing specific, known-benign categories instead, e.g.
#   warnings.filterwarnings('ignore', category=FutureWarning)
SUPPRESS_ALL_WARNINGS = False
if SUPPRESS_ALL_WARNINGS:
    warnings.filterwarnings('ignore')

In [60]:
# Add bombcell to Python path if not installed with pip JF: this should be inside the bc.load_ephys_data() function
# The demo notebook lives one level below the pyBombCell package directory.
demo_dir = Path.cwd()
pyBombCell_dir = demo_dir.parent
# Only add the directory once: a plain append would add a duplicate entry
# every time this cell is re-run in the same kernel.
if str(pyBombCell_dir) not in sys.path:
    sys.path.append(str(pyBombCell_dir))

In [61]:
# Enable IPython's autoreload so edits to the bombcell source are picked up
# without restarting the kernel (mode 2 reloads modules before each execution).
# NOTE(review): re-running this cell in the same kernel makes `%load_ext` print an
# "already loaded" message; it is harmless.
%load_ext autoreload 
%autoreload 2

import bombcell as bc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Define data paths

By default, the paths below point to BombCell's toy dataset; replace them with the paths to your own data.

In [62]:
ks_dir = demo_dir / 'toy_data'  # Replace with your Kilosort output directory
raw_dir = None  # Leave as None if there is no raw data; otherwise replace with the path to your raw data
# Python does not expand "~" inside plain strings, so expand it explicitly to the
# user's home directory (the result is still a str). Forward slashes also work on Windows.
save_path = os.path.expanduser("~/Downloads/bombcell_plots")

# If no raw data directory is given, the gain cannot be read from a meta file,
# so please enter the gain manually
gain_to_uV = np.nan

In [63]:
# Load the Kilosort outputs used by the quality-metric pipeline.
(spike_times_samples, spike_templates, template_waveforms, template_amplitudes,
 pc_features, pc_features_idx, channel_positions, good_channels) = bc.load_ephys_data(ks_dir)

# JF: everything below this should be in one function (the if statement)
# With raw data, prepare it via bc.manage_data_compression and read the gain from
# the returned meta file; otherwise keep the manually entered gain_to_uV.
if raw_dir is not None:
    ephys_raw_data, meta_path = bc.manage_data_compression(raw_dir, decompressed_data_local=raw_dir)
    gain_to_uV = bc.get_gain_spikeglx(meta_path)
else:
    # Define both names in every branch so later cells never hit a NameError.
    ephys_raw_data = None
    meta_path = None

In [64]:
# Start from BombCell's default parameters, then override the options this demo changes.
param = bc.default_parameters(ks_dir, raw_dir, ephys_meta_dir=meta_path)
param.update({
    'compute_time_chunks': 1,
    'compute_drift': 1,
    'compute_distance_metrics': 0,
})
param

{'show_detail_plots': False,
 'show_summary_plots': True,
 'verbose': True,
 're_extract_raw': False,
 'save_as_tsv': True,
 'unit_type_for_phy': True,
 'ephys_kilosort_path': PosixPath('/home/julie/Dropbox/MATLAB/bombcell/pyBombCell/Demos/toy_data'),
 'save_mat_file': False,
 'remove_duplicate_spike': True,
 'duplicate_spikes_window_s': 1e-05,
 'save_spike_without_duplicates': True,
 'recompute_duplicate_spike': False,
 'detrend_waveform': True,
 'n_raw_spikes_to_extract': 500,
 'save_multiple_raw': False,
 'decompress_data': False,
 'extract_raw_waveforms': True,
 'probe_type': 1,
 'tauR_values_min': 0.002,
 'tauR_values_max': 0.002,
 'tauR_values_steps': 0.0005,
 'tauC': 0.0001,
 'use_hill_method': 1,
 'compute_time_chunks': 1,
 'delta_time_chunk': 360,
 'presence_ratio_bin_size': 60,
 'drift_bin_size': 60,
 'compute_drift': 1,
 'min_thresh_detect_peaks_troughs': 0.2,
 'first_peak_ratio': 1.1,
 'normalize_spatial_decay': True,
 'min_width_first_peak': 4,
 'min_main_peak_to_trough_ra

In [65]:
# Extract or load in raw waveforms JF: this should be inside bc.get_all_quality_metrics
# Raw waveforms (and SNR) are only available when a raw data directory was given;
# otherwise the downstream steps receive None and raw extraction is switched off.
if raw_dir is not None:
    raw_waveforms_full, raw_waveforms_peak_channel, SNR = bc.extract_raw_waveforms(
        param,
        spike_templates.squeeze(),
        spike_times_samples.squeeze(),
        param['re_extract_raw'],
        save_path,
    )
else:
    raw_waveforms_full = None
    raw_waveforms_peak_channel = None
    SNR = None
    param['extract_raw_waveforms'] = False  # No waveforms to extract!

In [66]:
# Peak channel of each template waveform.
# JF: this should be inside bc.get_all_quality_metrics
max_channels = bc.get_waveform_max_channel(template_waveforms)

# Remove duplicate spikes. The call hands back new values for the arrays passed in
# (plus non_empty_units and duplicate_spike_idx); they are rebound below.
# JF: this should be inside bc.get_all_quality_metrics
(non_empty_units, duplicate_spike_idx,
 spike_times_samples, spike_templates, template_amplitudes, pc_features,
 raw_waveforms_full, raw_waveforms_peak_channel,
 signal_to_noise_ratio, max_channels) = bc.remove_duplicate_spikes(
    spike_times_samples,
    spike_templates,
    template_amplitudes,
    max_channels,
    save_path,
    param,
    pc_features=pc_features,
    raw_waveforms_full=raw_waveforms_full,
    raw_waveforms_peak_channel=raw_waveforms_peak_channel,
    signal_to_noise_ratio=SNR,
)


# Divide recording into time chunks JF: this should be inside bc.get_all_quality_metrics
spike_times_seconds = spike_times_samples / param['ephys_sample_rate']
first_spike_s = np.min(spike_times_seconds)
last_spike_s = np.max(spike_times_seconds)
if param['compute_time_chunks']:
    # time_chunks holds chunk edges (cf. the two-edge else branch). np.arange
    # excludes its stop value, so append the last spike time as the final edge:
    # otherwise the trailing partial chunk is dropped, and a recording shorter
    # than delta_time_chunk would produce a single edge, i.e. no chunk at all.
    time_chunks = np.append(
        np.arange(first_spike_s, last_spike_s, param['delta_time_chunk']),
        last_spike_s,
    )
else:
    time_chunks = np.array((first_spike_s, last_spike_s))

# Should be got as part of removing duplicate spikes!!!
unique_templates = np.unique(spike_templates)

In [67]:
# Initialize quality metrics dictionary JF: this should be inside bc.get_all_quality_metrics
n_units = unique_templates.size
# Use the SNR handed back by bc.remove_duplicate_spikes rather than the pre-dedup
# SNR, so it matches unique_templates / max_channels, which also come from the
# deduplicated outputs. NOTE(review): confirm remove_duplicate_spikes filters SNR
# to non-empty units the same way it does max_channels.
quality_metrics = bc.create_quality_metrics_dict(n_units, snr=signal_to_noise_ratio)
quality_metrics['max_channels'] = max_channels

# Complete with remaining quality metrics
quality_metrics, times = bc.get_all_quality_metrics(unique_templates,
                                                    spike_times_seconds,
                                                    spike_templates,
                                                    template_amplitudes,
                                                    time_chunks,
                                                    pc_features,
                                                    pc_features_idx,
                                                    quality_metrics,
                                                    raw_waveforms_full,
                                                    channel_positions,
                                                    template_waveforms, param)

Computing BombCell quality metrics: 100%|██████████| 15/15 units


In [68]:
# Classify each unit (e.g. MUA / NOISE) from its quality metrics.
# JF: this should be inside bc.get_all_quality_metrics
(unit_type,
 unit_type_string) = bc.get_quality_unit_type(param, quality_metrics)

# Tabulate per-unit metric results and preview the first rows.
qm_table = bc.make_qm_table(quality_metrics, param, unique_templates, unit_type_string)
qm_table.head()

Unnamed: 0,Original ID,NaN result,Peaks,Troughs,Waveform Min Length,Waveform Max Length,Baseline,Spatial Decay,Min Spikes,Missing Spikes,RPVs,Presence Ratio,Not Somatic,Good Unit
0,0,False,False,False,False,False,False,False,False,True,True,False,True,MUA
1,1,False,False,True,False,False,True,True,False,False,True,False,False,NOISE
2,2,False,False,False,False,False,False,False,False,False,True,False,True,MUA
3,3,False,False,False,False,False,False,False,True,False,False,False,True,MUA
4,4,False,False,False,False,False,False,False,False,False,True,False,True,MUA


In [69]:
# Save the quality metrics and unit classifications (presumably under save_path).
# Requires pyarrow or fastparquet: save_results writes parquet output and raises
# an ImportError if neither engine is installed.
# JF: this should be inside bc.get_all_quality_metrics
bc.save_results(
    quality_metrics,
    unit_type_string,
    unique_templates,
    param,
    raw_waveforms_full,
    raw_waveforms_peak_channel,
    save_path,
)

ImportError: Unable to find a usable engine; tried using: 'pyarrow', 'fastparquet'.
A suitable version of pyarrow or fastparquet is required for parquet support.
Trying to import the above resulted in these errors:
 - Missing optional dependency 'pyarrow'. pyarrow is required for parquet support. Use pip or conda to install pyarrow.
 - Missing optional dependency 'fastparquet'. fastparquet is required for parquet support. Use pip or conda to install fastparquet.