## Getting started with BombCell

## Install bombcell

Create a conda environment
```bash
conda create -n bombcell python=3.11
conda activate bombcell
```
Clone the latest bombcell repository from GitHub
```bash
git clone https://github.com/Julie-Fabre/bombcell.git
```
Install bombcell from the local repository
```bash
cd bombcell/pyBombCell
# you could do `pip install .`, but uv is much quicker!
pip install uv
uv pip install . # or uv pip install -e . (-e for editable mode)
```

## Imports

In [20]:
import os, sys
from pathlib import Path
from pprint import pprint

import numpy as np
import pandas as pd
pd.set_option('display.max_columns', 50)
import matplotlib.pyplot as plt

In [21]:
# Optional: only needed if bombcell was NOT installed with pip / uv.
# Uncomment the lines below to make `import bombcell` work from source instead.

# # Add bombcell to Python path if NOT installed with pip
# # If notebook is running in bombcell repo (e.g. a folder inside pyBombCell):
# demo_dir = Path(os.getcwd())
# pyBombCell_dir = demo_dir.parent
# # Else, point to the folder that contains the `bombcell` package
# # (the `pyBombCell` folder from the install instructions above):
# # pyBombCell_dir = "path/to/bombcell/pyBombCell"
# sys.path.append(str(pyBombCell_dir))

In [22]:
%load_ext autoreload
%autoreload 2

import bombcell as bc

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## Define data paths

By default, the paths below point to BombCell's toy dataset; replace them with your own data.

In [23]:
# Replace with your kilosort directory
ks_dir = Path(os.getcwd()) / "toy_data" 

# Set bombcell's output directory
save_path = Path(ks_dir) / "bombcell"

In [24]:
## For Neuropixels probes, provide raw and meta files
# Leave 'None' if no raw data
raw_file_path = None  # "path/to/rawdata.ap.bin"
meta_file_path = None  # "path/to/metadata.ap.meta"

# Only touch bombcell when a path was actually given:
#  - the raw-data folder goes to bombcell's compression handling, which returns the
#    raw data reference passed on to get_default_parameters;
#  - the .meta file provides the uV conversion factor.
# NOTE(review): gain_to_uV is not used by any later cell in this notebook.
ephys_raw_data = (
    bc.manage_data_compression(Path(raw_file_path).parent)
    if raw_file_path is not None
    else None
)
gain_to_uV = (
    bc.get_gain_spikeglx(meta_file_path)
    if meta_file_path is not None
    else None
)

In [None]:
## For non-Neuropixels probes, specify the raw data file and the conversion factor to uV
# Uncomment and fill in. Only `ephys_raw_data` is consumed later in this notebook
# (passed as `raw_file=` to bc.get_default_parameters).
# NOTE(review): gain_to_uV, sampling_rate, nChannels and n_bytes_per_sample are not
# referenced by any later cell, so setting them here has no effect by itself. Confirm
# how bombcell expects them to be supplied (e.g. as entries of `param`).
# ephys_raw_data = "" # .bin, .dat...
# gain_to_uV = None
# sampling_rate = 30_000
# nChannels = 385
# n_bytes_per_sample = 2 # 2 bytes if int16, 4 if int32...

## Get parameters

In [2]:
# Start from bombcell's default parameters for this recording.
# NOTE(review): `gain_to_uV` (computed in the raw/meta cell) is not passed here.
# Confirm whether get_default_parameters needs it or reads it from the meta file.
param = bc.get_default_parameters(
    ks_dir,
    raw_file=ephys_raw_data,
    ephys_meta_dir=meta_file_path,
)

print("BombCell parameters:")
pprint(param)

NameError: name 'bc' is not defined

### Optionally customize parameters

In [None]:
# Uncomment and edit any line to override a default before running bombcell.
# `pprint(param)` in the previous cell lists every available key and its default.
# param["maxRPVviolations"] = 0.1
# param["compute_distance_metrics"] = 1
# param["computeDrift"] = 1
# param["compute_time_chunks"] = 0
# param['reextractRaw'] = False
# ...

## Run bombcell, get unit types and save results 

In [1]:
# Run the bombcell pipeline on the kilosort output. It returns the quality metrics,
# the parameters (re-assigned to `param`), and the unit labels (numeric and string).
# Results are saved under `save_path`.
bombcell_outputs = bc.run_bombcell(ks_dir, save_path, param)
quality_metrics, param, unit_type, unit_type_string = bombcell_outputs

NameError: name 'bc' is not defined

Build the quality metrics table (one row per unit, labelled with bombcell's unit type)

In [29]:
# One row per unit with every quality-metric value; the first column holds
# bombcell's unit label (e.g. GOOD / MUA / NOISE).
quality_metrics_table = pd.DataFrame(quality_metrics)
quality_metrics_table.insert(loc=0, column='Bombcell_unit_type', value=unit_type_string)

quality_metrics_table

Unnamed: 0,Bombcell_unit_type,phy_clusterID,nSpikes,nPeaks,nTroughs,waveformDuration_peakTrough,spatialDecaySlope,waveformBaselineFlatness,scndPeakToTroughRatio,mainPeakToTroughRatio,peak1ToPeak2Ratio,troughToPeak2Ratio,mainPeak_before_width,mainTrough_width,percentageSpikesMissing_gaussian,percentageSpikesMissing_symmetric,RPV_window_index,fractionRPVs,presenceRatio,maxDriftEstimate,cumDriftEstimate,rawAmplitude,signalToNoiseRatio,isolationDistance,Lratio,silhouetteScore,useTheseTimesStart,useTheseTimesStop,peak_channels
0,MUA,0,9705.0,1.0,1.0,600.0,-0.042352,0.021372,0.312463,0.312463,0.28365,11.282845,2.83236,7.31889,11.685531,27.709497,0.0,1.0,0.756757,24.972344,222.087749,,,,,,0.325367,4475.873933,0
1,NOISE,1,24343.0,1.0,2.0,400.0,-0.000496,1.0,0.980433,0.980433,0.001438,709.481628,,8.092349,100.0,31.48414,0.0,1.0,1.0,42.515827,383.897898,,,,,,0.325367,4475.873933,0
2,MUA,2,5906.0,1.0,1.0,733.333333,-0.025323,0.048061,0.431882,0.431882,0.111282,20.806959,0.948645,8.875602,100.0,22.074152,0.0,1.0,0.972973,73.389266,453.775656,,,,,,0.325367,4475.873933,2
3,MUA,3,1342.0,1.0,1.0,300.0,-0.040807,0.014625,0.234792,0.234792,0.021125,201.614166,,4.099461,7.891646,20.872642,0.0,1.0,0.810811,49.149593,632.523427,,,,,,0.325367,4475.873933,2
4,MUA,4,3784.0,2.0,1.0,600.0,-0.013389,0.057277,0.304664,0.304664,0.844419,3.887054,3.345621,8.822806,12.775031,41.360607,0.0,1.0,0.945946,65.015079,406.069652,,,,,,0.325367,4475.873933,4
5,MUA,5,6454.0,1.0,1.0,666.666667,-0.028866,0.044926,0.297345,0.297345,0.151091,22.258665,,6.191829,2.995993,2.330508,0.0,0.982238,0.527027,34.258558,362.233033,,,,,,0.325367,4475.873933,6
6,GOOD,6,13605.0,1.0,1.0,266.666667,-0.02093,0.026933,0.492256,0.492256,0.042574,47.715912,,4.136663,0.0,7.139932,0.0,0.032903,0.824324,52.689211,170.088576,,,,,,0.325367,4475.873933,6
7,MUA,7,171.0,1.0,1.0,600.0,-0.014845,0.019374,0.288174,0.288174,0.067231,51.614994,,4.139432,100.0,35.955056,0.0,1.0,0.689189,53.153589,702.260547,,,,,,0.325367,4475.873933,5
8,MUA,8,390.0,1.0,1.0,600.0,-0.015257,0.045225,0.303427,0.303427,0.149049,22.111492,,5.895327,0.220368,0.0,0.0,0.676498,0.878378,44.889622,566.454679,,,,,,0.325367,4475.873933,9
9,MUA,9,75.0,1.0,1.0,633.333333,-0.019839,0.029477,0.292315,0.292315,0.10084,33.924671,,5.698128,100.0,30.555556,0.0,1.0,0.337838,64.342559,393.136776,,,,,,0.325367,4475.873933,11


In [30]:
# Boolean table built from the metrics and the thresholds in `param`: one row per
# unit, one column per check. True flags a check that falls outside its threshold
# (a GOOD unit has every check False).
boolean_quality_metrics_table = bc.make_qm_table(quality_metrics, param, unit_type_string)

boolean_quality_metrics_table

Unnamed: 0,unit_type,Original ID,NaN result,# peaks,# troughs,duration,baseline flatness,peak2 / trough,spatial decay,# spikes,% spikes missing,presence ratio,fraction RPVs,non somatic,peak(main) / trough,peak1 / peak2
0,MUA,0,False,False,False,False,False,False,False,False,False,False,True,False,False,False
1,NOISE,1,False,False,True,False,True,True,True,False,True,False,True,True,True,False
2,MUA,2,False,False,False,False,False,False,False,False,True,False,True,False,False,False
3,MUA,3,False,False,False,False,False,False,False,False,False,False,True,False,False,False
4,MUA,4,False,False,False,False,False,False,False,False,False,False,True,False,False,False
5,MUA,5,False,False,False,False,False,False,False,False,False,True,True,False,False,False
6,GOOD,6,False,False,False,False,False,False,False,False,False,False,False,False,False,False
7,MUA,7,False,False,False,False,False,False,False,True,True,True,True,False,False,False
8,MUA,8,False,False,False,False,False,False,False,False,False,False,True,False,False,False
9,MUA,9,False,False,False,False,False,False,False,True,True,True,True,False,False,False


Example: get all quality metrics for a single unit by its cluster ID (the last cell of this notebook does this for unit 12)

## Compute ephys properties for cell type classification

In [ ]:
# Per-unit ephys properties (ACG, ISI, waveform morphology, ...) used below for
# cell-type classification; results are also written to `save_path`.
ephys_properties, ephys_param = bc.run_all_ephys_properties(
    ks_dir,
    param=param,
    save_path=save_path,
)

print("Ephys properties computed:")
n_ephys_units = len(ephys_properties)
print(f"Number of units: {n_ephys_units}")
ephys_property_names = list(ephys_properties[0].keys()) if ephys_properties else 'None'
print(f"Properties per unit: {ephys_property_names}")

In [ ]:
# Create DataFrame with ephys properties for analysis
ephys_properties_table = pd.DataFrame(ephys_properties)
ephys_properties_table.insert(0, 'phy_clusterID', range(len(ephys_properties)))
ephys_properties_table.insert(1, 'Bombcell_unit_type', unit_type_string)

print("Ephys properties summary:")
ephys_properties_table.head()

### Cell type classification

Classify neurons based on waveform and firing properties. For striatum recordings, classify as MSN, FSI, TAN, or UIN. For cortex recordings, classify as narrow or wide-spiking neurons.

In [ ]:
# Classify cells - specify brain region for your recording
# Options: 'striatum' or 'cortex'
brain_region = 'striatum'  # Change this based on your recording location

# Region -> (classifier call, labels that classifier produces).
# The lambdas defer the bc.* calls so only the selected region's classifier runs.
region_classifiers = {
    'striatum': (
        lambda: bc.classify_striatum_cells(ephys_properties, ephys_param),
        ['MSN', 'FSI', 'TAN', 'UIN'],
    ),
    'cortex': (
        lambda: bc.classify_cortex_cells(ephys_properties, ephys_param),
        ['Narrow-spiking', 'Wide-spiking'],
    ),
}

if brain_region not in region_classifiers:
    print("Unknown brain region. Using generic classification.")
    cell_types = ['Unknown'] * len(ephys_properties)
    cell_type_names = ['Unknown']
else:
    run_classifier, region_labels = region_classifiers[brain_region]
    cell_types = run_classifier()
    cell_type_names = region_labels

# One label per unit, stored next to its ephys properties.
ephys_properties_table['cell_type'] = cell_types

print(f"Cell type classification for {brain_region}:")
overview_columns = ['phy_clusterID', 'Bombcell_unit_type', 'cell_type']
print(ephys_properties_table[overview_columns].head(10))

In [ ]:
# How many units fall into each cell type, overall and among GOOD units only.
print("Cell type distribution:")
cell_type_counts = ephys_properties_table['cell_type'].value_counts()
print(cell_type_counts)

is_good_unit = ephys_properties_table['Bombcell_unit_type'] == 'GOOD'
good_units = ephys_properties_table[is_good_unit]
if len(good_units) == 0:
    print("\nNo GOOD units found in this dataset.")
else:
    print("\nCell types for GOOD units:")
    good_cell_types = good_units['cell_type'].value_counts()
    print(good_cell_types)

### Visualize ephys properties and cell type classification

In [ ]:
# Plot key ephys properties: one panel per property, with each cell type in its own
# labelled column so the groups can be compared side by side.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# Define colors for each cell type (also used by the example-ACG cell below)
colors = {'MSN': 'blue', 'FSI': 'red', 'TAN': 'green', 'UIN': 'orange',
          'Narrow-spiking': 'red', 'Wide-spiking': 'blue', 'Unknown': 'gray'}

# Properties to plot: (column in ephys_properties_table, axis label)
if brain_region == 'striatum':
    properties = [
        ('firing_rate_mean', 'Mean Firing Rate (Hz)'),
        ('waveform_duration_peak_trough', 'Waveform Duration (μs)'),
        ('acg_pss_ratio', 'Post-Spike Suppression Ratio'),
        ('isi_cv', 'ISI CV'),
        ('waveform_half_width', 'Waveform Half Width (μs)'),
        ('acg_tau_rise', 'ACG Tau Rise (ms)')
    ]
else:
    properties = [
        ('firing_rate_mean', 'Mean Firing Rate (Hz)'),
        ('waveform_duration_peak_trough', 'Waveform Duration (μs)'),
        ('waveform_half_width', 'Waveform Half Width (μs)'),
        ('isi_cv', 'ISI CV'),
        ('acg_pss_ratio', 'Post-Spike Suppression Ratio'),
        ('firing_rate_std', 'Firing Rate Std (Hz)')
    ]

# zip stops at the shorter sequence, so extra properties never overflow the grid.
for ax, (prop, label) in zip(axes, properties):
    ax.set_ylabel(label)

    if prop not in ephys_properties_table.columns:
        # Property not computed for this dataset: say so instead of raising KeyError.
        ax.text(0.5, 0.5, f"'{prop}' not available", transform=ax.transAxes,
                ha='center', va='center')
        ax.set_xticks([])
        continue

    # Previously every cell type was drawn at the same x, so groups overlapped.
    # Each cell type now gets its own x position.
    for x_pos, cell_type in enumerate(cell_type_names):
        mask = ephys_properties_table['cell_type'] == cell_type
        values = ephys_properties_table.loc[mask, prop]
        if len(values) > 0:
            ax.scatter(np.full(len(values), x_pos), values,
                       c=colors.get(cell_type, 'gray'),
                       label=cell_type, alpha=0.6)

    ax.set_xticks(range(len(cell_type_names)))
    ax.set_xticklabels(cell_type_names, rotation=30, ha='right')
    ax.set_xlim(-0.5, len(cell_type_names) - 0.5)

# Hide panels left unused when there are fewer properties than axes.
for ax in axes[len(properties):]:
    ax.set_visible(False)

fig.suptitle(f'Ephys Properties by Cell Type ({brain_region})')
fig.tight_layout()
plt.show()

In [ ]:
# Example: one auto-correlogram (ACG) panel per cell type, using the first unit of each type.
# NOTE: the bars are MOCK data (seeded random values) standing in for a real ACG.
# In real usage, load the unit's actual ACG from bombcell's output instead.
example_units = {}
for cell_type in cell_type_names:
    mask = ephys_properties_table['cell_type'] == cell_type
    if mask.any():
        example_units[cell_type] = ephys_properties_table.loc[mask, 'phy_clusterID'].iloc[0]

if example_units:
    # Seeded generator so the placeholder bars are identical on every re-run.
    mock_rng = np.random.default_rng(0)

    # squeeze=False always gives a 2-D array of Axes, even for a single panel.
    fig, axes = plt.subplots(1, len(example_units),
                             figsize=(4 * len(example_units), 4), squeeze=False)

    for ax, (cell_type, unit_id) in zip(axes[0], example_units.items()):
        # Property row for this example unit
        unit_props = ephys_properties_table[ephys_properties_table['phy_clusterID'] == unit_id].iloc[0]

        # Placeholder ACG (mock data, see note above)
        ax.bar(range(-50, 51), mock_rng.exponential(1, 101),
               color=colors.get(cell_type, 'gray'), alpha=0.7)
        # Use a real newline; the previous '\\n' was drawn as a literal backslash-n.
        ax.set_title(f'{cell_type}\nUnit {unit_id}')
        ax.set_xlabel('Time lag (ms)')
        ax.set_ylabel('Normalized count')

        # Annotate the panel with a couple of the unit's real properties
        ax.text(0.05, 0.95, f'FR: {unit_props["firing_rate_mean"]:.1f} Hz\n'
                             f'Duration: {unit_props["waveform_duration_peak_trough"]:.0f} μs',
                transform=ax.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))

    fig.suptitle('Example Auto-Correlograms by Cell Type (mock ACG data)')
    fig.tight_layout()
    plt.show()
else:
    print("No examples found for visualization.")

In [32]:
# Show every quality metric for one unit, selected by its Phy cluster ID.
u = 12  # cluster ID of the unit to inspect

units = quality_metrics_table['phy_clusterID']
quality_metrics_table[units.eq(u)]

Unnamed: 0,Bombcell_unit_type,phy_clusterID,nSpikes,nPeaks,nTroughs,waveformDuration_peakTrough,spatialDecaySlope,waveformBaselineFlatness,scndPeakToTroughRatio,mainPeakToTroughRatio,peak1ToPeak2Ratio,troughToPeak2Ratio,mainPeak_before_width,mainTrough_width,percentageSpikesMissing_gaussian,percentageSpikesMissing_symmetric,RPV_window_index,fractionRPVs,presenceRatio,maxDriftEstimate,cumDriftEstimate,rawAmplitude,signalToNoiseRatio,isolationDistance,Lratio,silhouetteScore,useTheseTimesStart,useTheseTimesStop,peak_channels
12,MUA,12,14750.0,1.0,1.0,300.0,-0.013697,0.077735,0.49462,0.49462,0.088362,22.880278,,6.52303,100.0,0.0,0.0,1.0,1.0,30.6275,191.151639,,,,,,0.325367,4475.873933,12
