# Playground2: Kilosort + Template Metrics

This notebook:
1. Loads and preprocesses the recording
2. Runs Kilosort4 (if not already done) and loads the sorting results
3. Computes template metrics and quality metrics, then saves them to CSV

In [None]:
# pathlib is used for all filesystem paths; spikeinterface.full is used as `si` below
from pathlib import Path
import spikeinterface.full as si

# Report which SpikeInterface version this notebook is running against
print(f"SpikeInterface version: {si.__version__}")

## Configuration

In [None]:
# Input data: the raw binary and its meta file live in the same session folder
data_folder = Path("/Users/jf5479/Downloads/AL031_2019-12-02")
bin_file = data_folder.joinpath("AL031_2019-12-02_bank1_NatIm_g0_t0_bc_decompressed.imec0.ap.bin")
meta_file = data_folder.joinpath("AL031_2019-12-02_bank1_NatIm_g0_t0.imec0.ap.meta")

# Outputs: everything this notebook writes goes under one directory
output_folder = data_folder.joinpath("spikeinterface_output")
kilosort_output = output_folder.joinpath("kilosort4_output")
analyzer_folder = output_folder.joinpath("sorting_analyzer")

# Keyword arguments forwarded (via **job_kwargs) to the chunked computations below
job_kwargs = {"n_jobs": -1, "chunk_duration": "1s", "progress_bar": True}

# Sanity check: show the data folder and whether both input files are present
print(f"Data folder: {data_folder}")
print(f"Bin file exists: {bin_file.exists()}")
print(f"Meta file exists: {meta_file.exists()}")

## 1. Load Recording

In [None]:
# The bin and meta files have different names, so we need to load manually:
# read the acquisition parameters from the meta file, open the binary directly,
# then attach gains, offsets, the probe and the inter-sample shifts by hand.
from neo.rawio.spikeglxrawio import read_meta_file
from spikeinterface.extractors.cbin_ibl import extract_stream_info
from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts
import probeinterface

# Read meta file
meta = read_meta_file(meta_file)
info = extract_stream_info(meta_file, meta)

# Get parameters
num_channels = info["num_chan"]
sampling_frequency = info["sampling_rate"]
channel_gains = info["channel_gains"]
channel_offsets = info["channel_offsets"]
channel_ids = info["channel_names"]

# Remove sync channel (last channel)
num_channels_no_sync = num_channels - 1
channel_gains_no_sync = channel_gains[:-1]
channel_offsets_no_sync = channel_offsets[:-1]
channel_ids_no_sync = channel_ids[:-1]

print(f"Sampling frequency: {sampling_frequency} Hz")
print(f"Number of channels (without sync): {num_channels_no_sync}")

# Load as binary recording. The file is read with the full channel count
# (sync included); the sync channel is dropped right after.
recording = si.read_binary(
    file_paths=bin_file,
    sampling_frequency=sampling_frequency,
    num_channels=num_channels,  # Include sync for reading, will remove later
    dtype="int16",
)

# Remove sync channel using select_channels
recording = recording.select_channels(channel_ids=recording.channel_ids[:-1])

# Set gains and offsets
recording.set_channel_gains(channel_gains_no_sync)
recording.set_channel_offsets(channel_offsets_no_sync)

# Load and attach probe from meta file
probe = probeinterface.read_spikeglx(meta_file)
recording = recording.set_probe(probe)

# Set inter_sample_shift property for phase correction (needed for Neuropixels)
ptype = probe.annotations.get("probe_type", 0)
# Compare as strings so the check works whether the annotation is stored as an
# int or a str. NOTE(review): confirm which type your probeinterface version uses.
if str(ptype) in {"21", "24"}:  # NP2.0
    num_channels_per_adc = 16
    num_cycles = 16
else:  # NP1.0
    num_channels_per_adc = 12
    # Per the SpikeInterface documentation NP1.0 uses 13 ADC cycles; when
    # num_cycles is omitted the helper falls back to num_channels_per_adc (12),
    # which yields slightly wrong shifts for NP1.0.
    num_cycles = 13

sample_shifts = get_neuropixels_sample_shifts(
    recording.get_num_channels(),
    num_channels_per_adc,
    num_cycles=num_cycles,
)
recording.set_property("inter_sample_shift", sample_shifts)

print(f"Loaded recording: {recording}")

In [None]:
# Quick look at what was loaded: total duration, attached probe, and the
# recording itself as the cell's last expression so the notebook displays it
total_duration_s = recording.get_total_duration()
print(f"Duration: {total_duration_s:.2f} s")
attached_probe = recording.get_probe()
print(f"Probe: {attached_probe}")
recording

## 2. Preprocessing

In [None]:
# Step 1: high-pass filter the raw recording at 300 Hz
rec_filtered = si.highpass_filter(recording, freq_min=300.0)

# Step 2: detect bad channels on the filtered data and drop them when any are found
bad_channel_ids, channel_labels = si.detect_bad_channels(rec_filtered)
print(f"Bad channels detected: {len(bad_channel_ids)}")
if len(bad_channel_ids) == 0:
    # Nothing to remove: keep working on the filtered recording as-is
    rec_clean = rec_filtered
else:
    print(f"Bad channel IDs: {bad_channel_ids}")
    rec_clean = rec_filtered.remove_channels(bad_channel_ids)

# phase_shift is deliberately not applied here; the original note says
# Kilosort4 handles this internally (NOTE(review): confirm before relying on it).
# Step 3: global common median reference
rec_preprocessed = si.common_reference(
    rec_clean,
    operator="median",
    reference="global",
)

print(f"Preprocessed recording: {rec_preprocessed}")

## 3. Run Kilosort4 (if not already done)

In [None]:
# Reuse a previous Kilosort4 run when its results are already on disk.
# SpikeInterface's run_sorter stores the sorter's own files under
# <folder>/sorter_output, so spike_times.npy is looked up there first; the
# top-level location is still accepted so the original check keeps working.
# Without this, the check never matched and every execution re-ran Kilosort4,
# with remove_existing_folder=True deleting the earlier result.
spike_times_candidates = (
    kilosort_output / "sorter_output" / "spike_times.npy",
    kilosort_output / "spike_times.npy",
)

if any(candidate.exists() for candidate in spike_times_candidates):
    print(f"Kilosort output already exists at: {kilosort_output}")
    print("Loading existing sorting results...")
    sorting = si.read_sorter_folder(kilosort_output)
else:
    print(f"Running Kilosort4, output will be saved to: {kilosort_output}")
    print(f"Installed sorters: {si.installed_sorters()}")

    # Run Kilosort4 on the preprocessed recording
    sorting = si.run_sorter(
        sorter_name="kilosort4",
        recording=rec_preprocessed,
        folder=kilosort_output,
        verbose=True,
        remove_existing_folder=True,  # Remove any failed previous attempts
    )
    print("Kilosort4 completed!")

print(f"Sorting result: {sorting}")
print(f"Number of units: {len(sorting.unit_ids)}")

## 4. Create SortingAnalyzer

In [None]:
# Build the SortingAnalyzer on first use; on later runs load it from its folder
if not analyzer_folder.exists():
    print(f"Creating new analyzer at: {analyzer_folder}")
    analyzer = si.create_sorting_analyzer(
        sorting=sorting,
        recording=rec_preprocessed,
        format="binary_folder",
        folder=analyzer_folder,
        sparse=True,
    )
else:
    print(f"Loading existing analyzer from: {analyzer_folder}")
    analyzer = si.load_sorting_analyzer(analyzer_folder)

# Last expression of the cell: display the analyzer
analyzer

## 5. Compute Extensions for Template Metrics

In [None]:
# Random spike selection (uniform, at most 500 spikes per unit); skipped when
# the analyzer already has the extension
random_spikes_params = dict(method="uniform", max_spikes_per_unit=500)
if not analyzer.has_extension("random_spikes"):
    print("Computing random_spikes...")
    analyzer.compute("random_spikes", **random_spikes_params)

In [None]:
# Waveform extraction (1.5 ms before / 2.0 ms after), run with the shared
# job_kwargs; skipped when the analyzer already has the extension
waveform_params = dict(ms_before=1.5, ms_after=2.0)
if not analyzer.has_extension("waveforms"):
    print("Computing waveforms...")
    analyzer.compute("waveforms", **waveform_params, **job_kwargs)

In [None]:
# Templates with the average, median and std operators; skipped when the
# analyzer already has the extension
template_operators = ["average", "median", "std"]
if not analyzer.has_extension("templates"):
    print("Computing templates...")
    analyzer.compute("templates", operators=template_operators)

In [None]:
# Noise levels with default parameters; skipped when the analyzer already has
# the extension
noise_levels_missing = not analyzer.has_extension("noise_levels")
if noise_levels_missing:
    print("Computing noise_levels...")
    analyzer.compute("noise_levels")

## 6. Compute Template Metrics

In [None]:
# Template metrics, with the multi-channel metrics switched on; skipped when
# the analyzer already has the extension
template_metrics_params = dict(include_multi_channel_metrics=True)
if not analyzer.has_extension("template_metrics"):
    print("Computing template_metrics...")
    analyzer.compute("template_metrics", **template_metrics_params)

# Pull the metrics table out of the extension and display it
template_metrics_ext = analyzer.get_extension("template_metrics")
template_metrics = template_metrics_ext.get_data()
template_metrics

## 7. Compute Quality Metrics (needed for the combined metrics table in section 8)

In [None]:
# Spike amplitudes, run with the shared job_kwargs; skipped when the analyzer
# already has the extension
spike_amplitudes_missing = not analyzer.has_extension("spike_amplitudes")
if spike_amplitudes_missing:
    print("Computing spike_amplitudes...")
    analyzer.compute("spike_amplitudes", **job_kwargs)

In [None]:
# Correlograms with default parameters; skipped when the analyzer already has
# the extension
correlograms_missing = not analyzer.has_extension("correlograms")
if correlograms_missing:
    print("Computing correlograms...")
    analyzer.compute("correlograms")

In [None]:
# Quality metrics with default parameters; skipped when the analyzer already
# has the extension
quality_metrics_missing = not analyzer.has_extension("quality_metrics")
if quality_metrics_missing:
    print("Computing quality_metrics...")
    analyzer.compute("quality_metrics")

# Pull the metrics table out of the extension and display it
quality_metrics_ext = analyzer.get_extension("quality_metrics")
quality_metrics = quality_metrics_ext.get_data()
quality_metrics

## 8. Summary

In [None]:
# Recap of the run: number of sorted units, where the analyzer lives, and the
# extension names reported by analyzer.get_loaded_extension_names()
print(f"Total units: {len(sorting.unit_ids)}")
print(f"Analyzer saved to: {analyzer_folder}")
# Plain string literal: there are no placeholders, so no f-prefix is needed
print("\nAvailable extensions:")
for ext_name in analyzer.get_loaded_extension_names():
    print(f"  - {ext_name}")

In [None]:
# Outer-join the template and quality metric tables on their index, so rows
# present in either table are kept; the result is displayed as the cell output
combined_metrics = template_metrics.join(other=quality_metrics, how="outer")
combined_metrics

In [None]:
# Write the combined metrics table to CSV in the output folder, creating the
# folder first if it does not exist yet
metrics_csv = output_folder / "combined_metrics.csv"
output_folder.mkdir(parents=True, exist_ok=True)
combined_metrics.to_csv(metrics_csv)
print(f"Metrics saved to: {metrics_csv}")