# Bombcell unit labelling

With this notebook you can:
- load a SortingAnalyzer
- compute required extensions
- label units based on quality thresholds
- generate and save summary plots
- save metrics and results

In [None]:
from pathlib import Path

import spikeinterface as si
from spikeinterface.curation import (
    bombcell_get_default_thresholds,
    bombcell_label_units,
    save_thresholds,
    load_thresholds,
)
from spikeinterface.widgets import plot_unit_labelling_all

#### load a SortingAnalyzer

In [None]:
# Point this at the folder of an existing SortingAnalyzer.
# The analyzer must already have been generated; see quickstart.py for how to do this.
analyzer_path = "/Users/jf5479/Downloads/M25_D18/kilosort4_sa"

# Outputs of this notebook go into a "bombcell" subfolder of the analyzer folder.
output_folder = Path(analyzer_path).joinpath("bombcell")

analyzer = si.load_sorting_analyzer(analyzer_path)

# Last expression of the cell: the notebook displays the analyzer summary.
analyzer

#### compute required extensions

In [None]:
# Templates are required by template_metrics; compute them only when missing.
extension_name = "templates"
if not analyzer.has_extension(extension_name):
    analyzer.compute(extension_name)

In [None]:
# Template metrics: compute them only when the analyzer does not have them yet.
extension_name = "template_metrics"
if not analyzer.has_extension(extension_name):
    analyzer.compute(extension_name)

In [None]:
# Quality metrics and their dependencies, computed in this order.
# Extensions the analyzer already has are skipped.
for extension_name in ("spike_amplitudes", "noise_levels", "quality_metrics"):
    if not analyzer.has_extension(extension_name):
        analyzer.compute(extension_name)

#### get metrics

In [None]:
# Fetch the data of each metrics extension; `qm` and `tm` are reused by the
# labelling and plotting cells below.
quality_extension = analyzer.get_extension("quality_metrics")
qm = quality_extension.get_data()

template_extension = analyzer.get_extension("template_metrics")
tm = template_extension.get_data()

# Show which metric columns are available.
print(f"Quality metrics: {list(qm.columns)}")
print(f"Template metrics: {list(tm.columns)}")

#### set labelling thresholds

In [None]:
# Start from the thresholds returned by bombcell_get_default_thresholds()
thresholds = bombcell_get_default_thresholds()

# Alternatively, load thresholds from a JSON file instead:
# thresholds = load_thresholds("my_thresholds.json")

# Last expression of the cell: the notebook displays the thresholds
thresholds

In [None]:
# Optionally modify thresholds
# thresholds["amplitude_median"]["min"] = 50  # stricter
# thresholds["rp_contamination"]["max"] = 0.05  # stricter

In [None]:
# Optionally load thresholds from a previously saved JSON file.
# The file is only written by the "save labelling thresholds" cell at the end
# of this notebook, so it does not exist on a first run. Load it only when it
# is present; otherwise keep the thresholds defined above instead of failing.
thresholds_file = output_folder / "thresholds.json"
if thresholds_file.exists():
    thresholds = load_thresholds(thresholds_file)

The JSON file format looks like:
```json
{
    "amplitude_median": {"min": 40, "max": null},
    "num_positive_peaks": {"min": null, "max": 2},
    "peak_to_trough_duration": {"min": 0.0001, "max": 0.00115}
}
```
`null` in JSON becomes `np.nan` (threshold disabled)

#### label units

In [None]:
# Boolean switches for the labelling call, grouped so they are easy to tweak.
labelling_options = {
    "label_non_somatic": True,
    "split_non_somatic_good_mua": False,
}

# Label the units from both metric tables and the chosen thresholds.
# The call returns two values; presumably one label per unit in coded and in
# string form - confirm against the spikeinterface documentation.
unit_type, unit_type_string = bombcell_label_units(
    quality_metrics=qm,
    template_metrics=tm,
    thresholds=thresholds,
    **labelling_options,
)

#### generate summary plots

In [None]:
# Metric tables and thresholds that were used for labelling, passed on to the plots.
labelling_inputs = {
    "quality_metrics": qm,
    "template_metrics": tm,
    "thresholds": thresholds,
}

# Build the summary plots; `save_folder` is set to the notebook's output folder.
plots = plot_unit_labelling_all(
    analyzer,
    unit_type,
    unit_type_string,
    **labelling_inputs,
    save_folder=output_folder,
)

#### save labelling thresholds

In [None]:
# Make sure the output folder exists before writing into it, so this cell
# also works when the plotting cell above was skipped.
output_folder.mkdir(parents=True, exist_ok=True)

# Store the thresholds used for this run next to the other outputs.
save_thresholds(thresholds, output_folder / "thresholds.json")

# List everything that is now in the output folder.
print(f"Results saved to: {output_folder.absolute()}")
print("\nFiles:")
for f in sorted(output_folder.glob("*")):
    print(f"  - {f.name}")