# VLM Curation Tutorial

This tutorial demonstrates how to use Vision-Language Models (VLMs) to curate spike sorting results.

## What is VLM Curation?

VLM curation uses AI vision models (like GPT-4o or Claude) to classify units as "Good" or "Bad" based on visual features like waveforms, autocorrelograms, and spike locations. This provides an automated, AI-assisted approach to quality control.

## Requirements

- A `SortingAnalyzer` with computed extensions:
  - `waveforms` (for waveform plots)
  - `templates` (for multi-channel waveforms)
  - `correlograms` (for autocorrelograms)
  - `spike_locations` (for spike location plots)
  - `spike_amplitudes` (for amplitude plots)
  - `quality_metrics` (optional, for including quantitative metrics)

- An API key for your chosen VLM provider (OpenAI, Anthropic, or Google)


In [None]:
import os
import sys

# Make the local `src` folder importable when running from the tutorial
# directory (assumes the package lives one level up, in ../src).
_root_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, os.path.join(_root_dir, 'src'))

import spikeinterface as si
import numpy as np
from spikeagent.app.tool.si_custom import create_unit_img_df
from spikeagent.curation.vlm_curation import run_vlm_curation, plot_spike_images_with_result
from spikeagent.app.tool.utils import get_model
from dotenv import load_dotenv

# Load API keys from a local .env file, if one exists.
load_dotenv()

print("Imports successful!")


## Step 1: Create a SortingAnalyzer

We'll create a synthetic `SortingAnalyzer` for demonstration. In practice, you would load your own recording and sorting.


In [None]:
print("Creating synthetic recording and sorting...")
recording, sorting = si.generate_ground_truth_recording(
    durations=[10.0],
    num_channels=16,
    num_units=8,
    sampling_frequency=30000.0,
    noise_kwargs={'noise_levels': 5.0, 'strategy': 'on_the_fly'},
    seed=42
)

sorting_analyzer = si.create_sorting_analyzer(
    sorting=sorting,
    recording=recording,
    format="memory"
)

print(f"Created sorting_analyzer with {len(sorting_analyzer.unit_ids)} units")
print(f"Unit IDs: {list(sorting_analyzer.unit_ids)}")


## Step 2: Compute Required Extensions

We need to compute all the extensions that VLM curation will use to generate images.


In [None]:
print("Computing extensions...")

sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("noise_levels")
sorting_analyzer.compute("waveforms", n_jobs=1)
sorting_analyzer.compute("templates")
sorting_analyzer.compute("correlograms", window_ms=100.0, bin_ms=1.0)
sorting_analyzer.compute("spike_locations", method="center_of_mass")
sorting_analyzer.compute("spike_amplitudes")
sorting_analyzer.compute("quality_metrics")

available_exts = [ext for ext in sorting_analyzer.get_computable_extensions() 
                  if sorting_analyzer.has_extension(ext)]
print(f"Computed extensions: {available_exts}")


## Step 3: Create Image DataFrame

The VLM needs images of each unit's features. We'll create a dataframe containing base64-encoded images.
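
How `create_unit_img_df` encodes its images is not shown in this tutorial. As a rough illustration of what "base64-encoded" means, the sketch below round-trips raw image bytes through Python's standard `base64` module. The helper names `encode_image` and `decode_image` are hypothetical, and the payload is a placeholder rather than a real plot.

```python
import base64

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # the 8 bytes that start every PNG file


def encode_image(image_bytes: bytes) -> str:
    """Encode raw image bytes as an ASCII base64 string."""
    return base64.b64encode(image_bytes).decode("ascii")


def decode_image(encoded: str) -> bytes:
    """Recover the original bytes from a base64 string."""
    return base64.b64decode(encoded)


# Placeholder payload: a PNG signature followed by dummy bytes, not a real plot.
fake_png = PNG_SIGNATURE + b"not-a-real-image"
encoded = encode_image(fake_png)

# The encoding is lossless, so decoding returns the exact input bytes.
print(decode_image(encoded) == fake_png)
```

Storing images as text in this way lets them sit in ordinary dataframe cells and be embedded directly in API requests.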


In [None]:
import tempfile

features = ["waveform_single", "autocorr", "spike_locations"]

print(f"Creating image dataframe with features: {features}...")
# Plots are written to a temporary folder that is deleted on exit; img_df is
# expected to hold the base64-encoded images themselves, not file paths.
with tempfile.TemporaryDirectory() as tmpdir:
    img_df = create_unit_img_df(
        sorting_analyzer,
        unit_ids=None,
        features=features,
        load_if_exists=False,
        save_folder=tmpdir
    )

print(f"Created image dataframe: {img_df.shape}")
print(f"Units: {list(img_df.index)}")
print(f"Features: {list(img_df.columns)}")


## Step 4: Run VLM Curation

Now we'll use a Vision-Language Model to classify each unit as "Good" or "Bad".

**Note:** This requires an API key. Set it in your environment or `.env` file:
- `OPENAI_API_KEY` for GPT-4o
- `ANTHROPIC_API_KEY` for Claude
- `GOOGLE_API_KEY` for Gemini
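
A quick check that a key is actually set can save a failed API call later. The sketch below is an assumption-laden helper, not part of `spikeagent`: `configured_providers` and the provider labels (`"openai"`, `"anthropic"`, `"google"`) are hypothetical, and only the environment variable names come from the list above.

```python
import os
from typing import List, Mapping

# Environment variable per provider, as listed in the note above.
# The provider labels on the left are illustrative only.
PROVIDER_KEYS = {
    "openai": "OPENAI_API_KEY",
    "anthropic": "ANTHROPIC_API_KEY",
    "google": "GOOGLE_API_KEY",
}


def configured_providers(env: Mapping[str, str] = os.environ) -> List[str]:
    """Return providers whose API key variable is set to a non-empty value."""
    return [name for name, var in PROVIDER_KEYS.items() if env.get(var, "").strip()]


# Demo with an explicit mapping instead of the real environment.
print(configured_providers({"OPENAI_API_KEY": "sk-example", "GOOGLE_API_KEY": ""}))
# prints ['openai']
```

Calling `configured_providers()` with no argument inspects the real environment, so it should be run after `load_dotenv()`.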


In [None]:
model_name = "gpt-4o"
model = get_model(model_name)

print(f"Initialized model: {model_name}")

print("\nRunning VLM curation...")
print("This may take a few minutes depending on the number of units...")

results_df = run_vlm_curation(
    model=model,
    sorting_analyzer=sorting_analyzer,
    img_df=img_df,
    features=features,
    good_ids=[],
    bad_ids=[],
    with_metrics=True,
    unit_ids=None,
    num_workers=10
)

print("\nVLM curation complete!")
print("\nResults summary:")
print(f"Total units analyzed: {len(results_df)}")
print(f"Good units: {len(results_df[results_df['final_classification'] == 'Good'])}")
print(f"Bad units: {len(results_df[results_df['final_classification'] == 'Bad'])}")


## Step 5: View Results

Let's examine the results and see which units were classified as good or bad.
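
The result columns `average_score` and `final_classification` suggest that scores are aggregated per unit and then mapped to a label. This tutorial does not show how `run_vlm_curation` does that, so the following is only an illustrative sketch: the mean over per-feature scores, the 0.5 cutoff, and the function name `classify_unit` are all assumptions, not the library's actual logic.

```python
from statistics import mean
from typing import Dict, Tuple


def classify_unit(feature_scores: Dict[str, float], threshold: float = 0.5) -> Tuple[float, str]:
    """Average per-feature scores and label the unit by a cutoff.

    Both the averaging rule and the default threshold are assumptions
    made for illustration.
    """
    if not feature_scores:
        raise ValueError("need at least one feature score")
    average = mean(feature_scores.values())
    label = "Good" if average >= threshold else "Bad"
    return average, label


# Hypothetical scores for one unit, keyed by the feature names used in Step 3.
scores = {"waveform_single": 1.0, "autocorr": 0.5, "spike_locations": 0.0}
print(classify_unit(scores))
```

Under these assumptions, a unit with mixed per-feature scores lands near the cutoff, which is where a manual review of the reasoning text is most useful.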


In [None]:
print("VLM Curation Results:")
print("=" * 60)
print(results_df[['average_score', 'final_classification']].to_string())

good_units = results_df[results_df['final_classification'] == 'Good'].index.tolist()
bad_units = results_df[results_df['final_classification'] == 'Bad'].index.tolist()

print(f"\nGood units: {good_units}")
print(f"Bad units: {bad_units}")

print("\nSample reasoning:")
for unit_id in list(sorting_analyzer.unit_ids)[:3]:
    if unit_id in results_df.index:
        print(f"\nUnit {unit_id}:")
        print(f"Classification: {results_df.loc[unit_id, 'final_classification']}")
        print(f"Score: {results_df.loc[unit_id, 'average_score']:.3f}")
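        reasoning = str(results_df.loc[unit_id, 'combined_reasoning'])
        # Only add an ellipsis when the text was actually truncated.
        suffix = "..." if len(reasoning) > 200 else ""
        print(f"Reasoning: {reasoning[:200]}{suffix}")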


## Step 6: Visualize Results

Plot the units with their classification results.


In [None]:
plot_spike_images_with_result(results_df, img_df, feature="waveform_single")


## Step 7: Apply Curation

Finally, create a curated analyzer containing only the "Good" units.


In [None]:
if good_units:
    curated_analyzer = sorting_analyzer.select_units(good_units)
    print(f"Created curated analyzer with {len(curated_analyzer.unit_ids)} units")
    print(f"Original units: {len(sorting_analyzer.unit_ids)}")
    print(f"Curated units: {len(curated_analyzer.unit_ids)}")
    print(f"Removed: {len(sorting_analyzer.unit_ids) - len(curated_analyzer.unit_ids)} units")
else:
    print("No 'Good' units found. Consider reviewing the results.")
