# VLM Spike Curation Tutorial

This tutorial shows how to use vision language models (VLMs) to help curate spike sorting results by automatically identifying which units should be merged.

## What is Spike Curation?

After spike sorting, you often have multiple units that actually represent the same neuron. **Curation** is the process of identifying and merging these duplicate units to get a cleaner, more accurate representation of your neural data.

## What This Tutorial Does

This tutorial uses AI to:
1. **Find potential merge candidates** - Automatically identifies units that might be duplicates
2. **Analyze visual features** - Uses AI to examine crosscorrelograms and amplitude plots
3. **Make merge decisions** - The AI decides which units should be merged based on visual analysis
4. **Save results** - Creates a log of all merge decisions for review

## Prerequisites

- A spikeinterface SortingAnalyzer object (from a completed spike sorting run) with the `template_similarity` extension computed, which the candidate search in Step 3 relies on
- An OpenAI API key (for GPT-4o) or other supported model
- The required Python packages installed

Let's get started!

## Step 0: Set Up Your API Keys

**IMPORTANT:** Before starting this tutorial, you must configure your API keys. These keys allow the AI models to process and analyze your spike sorting data. Based on our experience, GPT-4o is the fastest model for spike curation tasks.

### Which API Key Do You Need?

- **Recommended (GPT-4o):** Set your `OPENAI_API_KEY`
- **Anthropic models:** Set your `ANTHROPIC_API_KEY`
- **Google models:** Set your `GOOGLE_API_KEY`

### How to Obtain Your API Keys

- **OpenAI:** https://platform.openai.com/api-keys
- **Anthropic:** https://console.anthropic.com/
- **Google:** https://makersuite.google.com/app/apikey

### Configuring Your `.env` File

Run the cell below to create a `.env` file template for your API keys. The template contains only `OPENAI_API_KEY`; if you use another provider, add its key to the file by hand. This file will be automatically loaded by the tutorial.


In [None]:
import os
from dotenv import load_dotenv

# Create or update .env file with your API keys
# Replace the values below with your actual API keys

env_content = """# API Keys for Spike Agent
# Replace the values below with your actual API keys
# You can leave keys empty if you're not using that provider

# OpenAI API Key (required for GPT-4o, GPT-4o-mini, etc.)
OPENAI_API_KEY=your_openai_key_here

"""

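# Write the .env file only if it does not exist yet, so existing keys are not overwritten
if os.path.exists('.env'):
    print(".env file already exists; leaving it unchanged.")
else:
    with open('.env', 'w') as f:
        f.write(env_content)
    print(".env file created!")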
print("\n   IMPORTANT: Please edit the .env file and replace:")
print("   - 'your_openai_key_here' with your actual OpenAI API key")
print("\n  You can edit the .env file in your file explorer or text editor.")
print("   Make sure there are NO spaces around the = sign!")
print("\nAfter updating the .env file, run the next cell to verify your keys are loaded.")


In [None]:
# Load environment variables from .env file
load_dotenv()

# Only check for OpenAI (GPT-4o) API key
print("Checking for OpenAI GPT-4o API key...")
print("-" * 50)

openai_key = os.getenv("OPENAI_API_KEY", "")

if openai_key and openai_key != "your_openai_key_here":
    print("OPENAI_API_KEY: Found (ready to use GPT-4o)")
    print("-" * 50)
    print("\nGreat! GPT-4o is ready to use.")
    print("   You can proceed with the tutorial.")
    print("\nAvailable model:")
    print("   - GPT-4o")
else:
    print("OPENAI_API_KEY: Not found or not updated")
    print("-" * 50)
    print("\nERROR: OpenAI API key not found!")
    print("   Please edit the .env file and add your OpenAI API key.")
    print("   Then re-run this cell to verify.")


**Once you've verified your API keys are loaded above, you can proceed to the next steps!**

The tutorial will use these keys automatically when you call `get_model()`. Make sure you have at least one key set up before continuing.


## Step 1: Import Required Libraries

First, we import the core libraries. Over the course of the tutorial we'll use:
- **spikeinterface** - For loading and working with spike sorting data
- **matplotlib** - For plotting and visualization (imported in Step 4)
- **PIL** - For image processing (imported in Step 4)
- Custom tools from this package for VLM merge analysis (imported in Step 5)

The call `si.set_global_job_kwargs(n_jobs=-1)` tells spikeinterface to use all available CPU cores for faster processing.


In [None]:
import os
import spikeinterface.full as si
si.set_global_job_kwargs(n_jobs=-1)

## Step 2: Load Your Sorting Analyzer

**Important:** Replace `<path_to_analyzer_folder>` with the actual path to your SortingAnalyzer folder.

The SortingAnalyzer contains all the information about your spike sorting results, including:
- Spike times for each unit
- Waveforms
- Quality metrics
- Template information

**Example path:** `"/path/to/your/sorting_analyzer_folder"`


In [None]:
sorting_folder = "<path_to_analyzer_folder>"
analyzer_to_merge = si.load_sorting_analyzer(sorting_folder)

## Step 3: Find Potential Merge Candidates

This step automatically identifies units that might be duplicates and should be merged.

### How It Works

The algorithm uses **template similarity** to find units with similar waveforms. Pairs of units whose templates differ by less than a threshold are kept as potential merge candidates.

### Parameters Explained

- **`template_diff_thresh: 0.9`** - Upper bound on the template *difference* (1 - similarity) for a pair to be considered. In spikeinterface this is a difference threshold, not a minimum similarity, so 0.9 is permissive: it keeps any pair whose similarity exceeds roughly 0.1 and leaves the final decision to the VLM.
  - Lower values (e.g., 0.25) = more conservative (fewer candidate pairs)
  - Higher values (e.g., 0.95) = more permissive (more candidate pairs)

- **`resolve_graph=False`** - Returns the candidate pairs as they are, instead of combining connected pairs into larger multi-unit groups

### Output

The code will print how many potential merge groups were found and show the first 15 groups. With `resolve_graph=False`, each group is a pair of unit IDs that might represent the same neuron.
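
As a rough illustration of what candidate selection does, the sketch below applies a difference threshold to a toy similarity matrix and then shows how pairs that share a unit could be combined into larger groups. It is plain Python, not spikeinterface code: the matrix values and the helper names `candidate_pairs` and `group_pairs` are hypothetical, and it assumes the threshold is compared against `1 - similarity`.

```python
def candidate_pairs(similarity, unit_ids, diff_thresh):
    """Return unit pairs whose template difference (1 - similarity) is below diff_thresh."""
    pairs = []
    for i in range(len(unit_ids)):
        for j in range(i + 1, len(unit_ids)):
            if 1.0 - similarity[i][j] < diff_thresh:
                pairs.append((unit_ids[i], unit_ids[j]))
    return pairs


def group_pairs(pairs):
    """Combine pairs that share a unit into larger groups (connected components)."""
    groups = []
    for a, b in pairs:
        touching = [g for g in groups if a in g or b in g]
        merged = {a, b}.union(*touching)
        groups = [g for g in groups if g not in touching] + [merged]
    return sorted(sorted(g) for g in groups)


# Toy similarity matrix for four hypothetical units (assumed values)
unit_ids = [0, 1, 2, 3]
similarity = [
    [1.00, 0.95, 0.20, 0.05],
    [0.95, 1.00, 0.92, 0.02],
    [0.20, 0.92, 1.00, 0.15],
    [0.05, 0.02, 0.15, 1.00],
]

strict_pairs = candidate_pairs(similarity, unit_ids, diff_thresh=0.1)
loose_pairs = candidate_pairs(similarity, unit_ids, diff_thresh=0.9)
print("strict pairs:", strict_pairs)   # [(0, 1), (1, 2)]
print("loose pairs:", loose_pairs)     # [(0, 1), (0, 2), (1, 2), (2, 3)]
print("strict groups:", group_pairs(strict_pairs))  # [[0, 1, 2]]
```

In this toy case a larger threshold admits more pairs, and grouping turns the two strict pairs into a single three-unit group.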


In [None]:
from spikeinterface.curation import compute_merge_unit_groups
print("--- Step 1: Computing potential merge candidates ---")
steps = ["template_similarity"]
steps_params = {
    "template_similarity": {"template_diff_thresh": 0.9}
}
potential_merge_groups = compute_merge_unit_groups(
    analyzer_to_merge,
    resolve_graph=False,
    steps_params=steps_params,
    steps=steps
)

num_groups = len(potential_merge_groups)
print(f"Found {num_groups} potential merge groups.")
if num_groups > 15:
    print("Showing only the first 15 groups:")
for i, group in enumerate(potential_merge_groups[:15]):
    print(f"  Group {i}: {group}")

## Step 4: Helper Functions for Image Processing

These functions help prepare images for the AI to analyze:

- **`concat_images_horizontally()`** - Combines multiple images side-by-side so the AI can compare them
- **`plot_concat_images()`** - Displays the combined images

These are utility functions used internally by the VLM merge process. You don't need to modify them, but they're here if you want to visualize the images yourself.


In [None]:
import matplotlib.pyplot as plt
from PIL import Image
import io
import base64
import pandas as pd

def concat_images_horizontally(b64_images, resize_height=300):
    """Decode base64-encoded images and paste them side by side on a white canvas."""
    images = [Image.open(io.BytesIO(base64.b64decode(b64))) for b64 in b64_images]

    if resize_height is not None:
        images = [
            img.resize((int(img.width * resize_height / img.height), resize_height))
            for img in images
        ]
    
    total_width = sum(img.width for img in images)
    max_height = max(img.height for img in images)
    
    new_img = Image.new("RGB", (total_width, max_height), (255, 255, 255))
    
    x_offset = 0
    for img in images:
        new_img.paste(img, (x_offset, 0))
        x_offset += img.width
    
    return new_img

def plot_concat_images(b64_images):
    """Display the horizontally concatenated images in a single matplotlib axis."""
    img = concat_images_horizontally(b64_images)
    fig, ax = plt.subplots(1, 1, figsize=(15, 4))
    ax.imshow(img)
    ax.axis("off")
    plt.show()

## Step 5: Prepare for VLM Analysis

This step sets up the AI model and prepares the images for analysis.

### What Happens Here

1. **Model Selection** - Choose any supported vision model (OpenAI, Anthropic, or Google)
2. **API Key Verification** - Automatically checks that you have the correct API key for your chosen model
3. **Folder Setup** - Determines where to save merge images and results
4. **Image Generation** - Creates visualization images for each potential merge group

### Available Models

You can use any of these models by changing the `model_name` variable:

**OpenAI Models:**
- `"gpt-4o"` (recommended, best vision capabilities)
- `"gpt-4o-mini"` (faster, cheaper)
- `"gpt-4-turbo"`
- `"gpt-3.5-turbo"` (text-only; it cannot read the merge images, so avoid it for this task)

**Anthropic Models:**
- `"claude_4_sonnet"`
- `"claude_3_5_sonnet"`
- `"claude_3_opus"`
- `"claude_3_haiku"` (fastest)

**Google Models:**
- `"gemini-2.5-pro"`
- `"gemini-1.5-flash"` (fastest)
- `"gemini-1.5-pro"`

### Important Notes

- **Model Selection**: Change `model_name = "gpt-4o"` to any model from the list above
- **API Key**: The code automatically checks for the correct API key based on your model choice
- **Sorting Folder**: The code automatically finds your sorting folder, but you can set it manually if needed
- **Image Creation**: This may take a few minutes depending on how many merge candidates you have

The `merge_img_df` contains all the images that will be sent to the AI for analysis.
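
The API key check in the cell below can also be thought of as a lookup from model name to environment variable. The following is a minimal sketch of that idea, not the package's implementation: `PROVIDER_ENV_VARS`, `required_env_var`, and `has_usable_key` are hypothetical names, and matching on a name prefix is a simplification that would not cover a model such as `"o1"`.

```python
import os

# Hypothetical lookup table: model-name prefix -> environment variable
PROVIDER_ENV_VARS = {
    "gpt": "OPENAI_API_KEY",
    "claude": "ANTHROPIC_API_KEY",
    "gemini": "GOOGLE_API_KEY",
}


def required_env_var(model_name):
    """Return the environment variable expected for a model name."""
    for prefix, env_var in PROVIDER_ENV_VARS.items():
        if model_name.startswith(prefix):
            return env_var
    raise ValueError(f"Unknown model '{model_name}'. Please check the model name.")


def has_usable_key(model_name, environ=os.environ):
    """True if the key is set and is not a 'your_..._here' placeholder."""
    value = environ.get(required_env_var(model_name), "")
    return bool(value) and not value.startswith("your_")


print(required_env_var("gpt-4o"))
print(has_usable_key("gpt-4o", {"OPENAI_API_KEY": "your_openai_key_here"}))
```

Passing a plain dictionary as `environ` makes the check easy to try without touching your real environment.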


In [None]:
# --- Part 2: Run VLM Merge Decision ---
import os
from dotenv import load_dotenv

# Ensure .env file is loaded (in case notebook was restarted)
load_dotenv()

import tool.si_custom as sic
from tool.vlm_merge import run_vlm_merge, plot_merge_results
from tool.utils import get_model
from spikeinterface.curation.curation_tools import resolve_merging_graph

if len(potential_merge_groups) == 0:
    print("NO_MERGE_CANDIDATES::No potential merge groups found. Skipping VLM merge analysis.")
    # Set the output to be the same as the input if no merges are performed
    merged_analyzer = analyzer_to_merge
else:
    print("\n--- Step 2: Running VLM to decide on merges ---")
    
    # Choose which model to use (you can change this to any supported model)
    # Available models:
    # OpenAI: "gpt-4o", "gpt-4o-mini", "gpt-4-turbo", "gpt-3.5-turbo"
    # Anthropic: "claude_4_sonnet", "claude_3_5_sonnet", "claude_3_opus", "claude_3_haiku"
    # Google: "gemini-2.5-pro", "gemini-1.5-flash", "gemini-1.5-pro"
    model_name = "gpt-4o"  # Change this to your preferred model
    
    # Verify API key is available for the selected model
    openai_key = os.getenv("OPENAI_API_KEY", "")
    anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
    google_key = os.getenv("GOOGLE_API_KEY", "")
    
    # Check which provider the model belongs to
    openai_models = ["gpt-4o", "gpt-4o-mini", "gpt-4.1", "o1", "gpt-4-turbo", "gpt-3.5-turbo"]
    anthropic_models = ["claude_4_sonnet", "claude_4_opus", "claude_3_7_sonnet", "claude_3_5_sonnet", "claude_3_opus", "claude_3_haiku", "claude_3_sonnet"]
    google_models = ["gemini-2.5-pro", "gemini-2.0-flash-exp", "gemini-1.5-flash", "gemini-1.5-flash-8b", "gemini-1.5-pro"]
    
    if model_name in openai_models:
        if not openai_key or openai_key == "your_openai_key_here":
            raise ValueError(
                f"ERROR: OPENAI_API_KEY not found for model '{model_name}'!\n"
                "Please go back to Step 0 and set up your OpenAI API key in the .env file.\n"
                "The .env file should contain: OPENAI_API_KEY=your_actual_key_here"
            )
    elif model_name in anthropic_models:
        if not anthropic_key or anthropic_key == "your_anthropic_key_here":
            raise ValueError(
                f"ERROR: ANTHROPIC_API_KEY not found for model '{model_name}'!\n"
                "Please go back to Step 0 and set up your Anthropic API key in the .env file.\n"
                "The .env file should contain: ANTHROPIC_API_KEY=your_actual_key_here"
            )
    elif model_name in google_models:
        if not google_key or google_key == "your_google_key_here":
            raise ValueError(
                f"ERROR: GOOGLE_API_KEY not found for model '{model_name}'!\n"
                "Please go back to Step 0 and set up your Google API key in the .env file.\n"
                "The .env file should contain: GOOGLE_API_KEY=your_actual_key_here"
            )
    else:
        raise ValueError(f"ERROR: Unknown model '{model_name}'. Please check the model name.")
    
    print(f"API key loaded. Initializing model: {model_name}...")
    model = get_model(model_name=model_name)
    # `sorting_folder` should be defined in Step 2; fall back gracefully if it is not.
    if 'sorting_folder' not in globals() or sorting_folder is None:
        if analyzer_to_merge.folder:
            sorting_folder = os.path.dirname(analyzer_to_merge.folder)
        else:
            sorting_folder = os.getcwd()  # Fallback to current directory
            print(f"Warning: `sorting_folder` not found. Defaulting to current directory: {sorting_folder}")
    # Print only after the check above, so an undefined variable cannot raise a NameError
    print(f"sorting_folder: {sorting_folder}")

    merge_img_df = sic.create_merge_img_df(analyzer_to_merge, unit_groups=potential_merge_groups, load_if_exists=False, save_folder=sorting_folder)


## Step 6: Run AI Merge Analysis

This is the main step where the AI analyzes each potential merge group and decides whether units should be merged.

### How the AI Makes Decisions

The AI looks at two types of visual features:

1. **Crosscorrelograms** - Shows the temporal relationship between spikes from different units
   - If units are the same neuron, their crosscorrelogram should show a sharp peak at time 0
   - If units are different neurons, the crosscorrelogram should be flat

2. **Amplitude Plots** - Shows the distribution of spike amplitudes over time
   - Same neuron = similar amplitude distributions
   - Different neurons = different amplitude patterns
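
To make the first feature concrete, here is a minimal sketch of how a crosscorrelogram can be computed from two lists of spike times. It is a brute-force illustration only; spikeinterface computes correlograms itself, and the function name `cross_correlogram`, its parameters, and the toy spike times are assumptions made for this example.

```python
def cross_correlogram(times_a, times_b, window_ms=10.0, bin_ms=1.0):
    """Histogram of spike-time differences (b - a) within +/- window_ms.

    times_a and times_b are spike times in milliseconds.
    Returns a list of counts, one per bin, ordered from -window_ms to +window_ms.
    """
    n_bins = int(round(2 * window_ms / bin_ms))
    counts = [0] * n_bins
    for ta in times_a:
        for tb in times_b:
            lag = tb - ta
            if -window_ms <= lag < window_ms:
                # Clamp the index to guard against floating-point rounding at the edge
                index = min(int((lag + window_ms) // bin_ms), n_bins - 1)
                counts[index] += 1
    return counts


# Toy spike trains: unit B fires 0.5 ms after every spike of unit A
unit_a = [10.0, 50.0, 90.0]
unit_b = [10.5, 50.5, 90.5]

# Bins cover [-2, -1), [-1, 0), [0, 1), [1, 2) ms
counts = cross_correlogram(unit_a, unit_b, window_ms=2.0, bin_ms=1.0)
print(counts)  # [0, 0, 3, 0]
```

All three lags fall in the bin just to the right of zero, which is the kind of structure near time 0 that the plots sent to the AI make visible.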

### Parameters Explained

- **`features`** - Which visual features to analyze
  - `["crosscorrelograms", "amplitude_plot"]` - Analyze both
  - `["crosscorrelograms"]` - Only crosscorrelograms (faster)
  - `["amplitude_plot"]` - Only amplitude plots

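- **`good_merge_groups=[1]`** - Example groups that are known good merges (for training/calibration). The value here is a placeholder; replace it with groups you have verified on your own data
- **`bad_merge_groups=[0]`** - Example groups that are known bad merges (for training/calibration). This value is also a placeholder
- **`num_workers=50`** - Number of parallel workers (adjust based on your system and your API rate limits)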

### Output

The function returns a DataFrame with:
- Merge group IDs
- AI's decision (merge or don't merge)
- Reasoning/explanation for each decision
- Confidence scores

The results are automatically saved to `vlm_merge_reasoning.csv` in your sorting folder for review.

### What to Expect

- **Processing Time**: This can take 10-30 minutes depending on the number of merge candidates
- **API Costs**: Each merge group requires an API call. Monitor your usage if using paid APIs
- **Results**: Review the CSV file to see the AI's reasoning before applying merges


In [None]:
merge_results_df = run_vlm_merge(
    model=model,
    merge_unit_groups=potential_merge_groups,
    img_df=merge_img_df,
    features=["crosscorrelograms", "amplitude_plot"], # available options are: ["crosscorrelograms", "amplitude_plot"]
    good_merge_groups=[1], 
    bad_merge_groups=[0],
    num_workers=50
)

# Save CSV to deterministic path derived from analyzer
print(f"Using sorting_folder for merge outputs: {sorting_folder}")
merge_csv_path = os.path.join(sorting_folder,'vlm_merge_reasoning.csv')
merge_results_df.to_csv(merge_csv_path)
print(f"VLM merge reasoning log saved to: {merge_csv_path}")

## Next Steps

After running the analysis, you should:

1. **Review the Results** - Open `vlm_merge_reasoning.csv` and check the AI's decisions
2. **Verify Merge Decisions** - Look at a few examples manually to ensure the AI is making good decisions
3. **Apply Merges** - If you're satisfied, you can apply the merges using spikeinterface's merge functions
4. **Re-analyze** - After merging, you may want to re-run quality metrics and create a new SortingAnalyzer
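
As one possible way to go from reviewed decisions to a list of groups to merge, the sketch below filters a small hand-written table. The dictionary keys `units` and `merge` and the helper `accepted_merge_groups` are hypothetical; the real columns in `vlm_merge_reasoning.csv` may be named differently. The final commented line assumes that your spikeinterface version provides `SortingAnalyzer.merge_units`, so check your installed version before relying on it.

```python
def accepted_merge_groups(decisions):
    """Keep only groups flagged for merging that contain at least two units."""
    return [
        list(entry["units"])
        for entry in decisions
        if entry["merge"] and len(entry["units"]) >= 2
    ]


# Hypothetical, hand-reviewed decisions; the real CSV columns may be named differently
reviewed = [
    {"units": [3, 7], "merge": True},
    {"units": [4, 9], "merge": False},
    {"units": [12, 15, 18], "merge": True},
]

groups_to_merge = accepted_merge_groups(reviewed)
print(groups_to_merge)  # [[3, 7], [12, 15, 18]]

# With spikeinterface installed, the groups could then be applied with something like:
# merged_analyzer = analyzer_to_merge.merge_units(merge_unit_groups=groups_to_merge)
```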

## Tips for Best Results

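- **Start with a conservative candidate set** to avoid false positives. Because `template_diff_thresh` bounds the template difference, a lower value is the more conservative choice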
- **Review a sample** of merge decisions before applying all of them
- **Adjust parameters** based on your data quality and needs
- **Use good/bad examples** if you have known merge cases to improve AI accuracy

## Troubleshooting

- **No merge candidates found**: Try raising `template_diff_thresh` (it is an upper bound on template difference, so a higher value admits more pairs) or check your data quality
- **AI making poor decisions**: Provide more good/bad examples or adjust the features being analyzed
- **Slow processing**: Analyze fewer features at once. If requests are being throttled, reduce `num_workers` so you stay under your API rate limit
- **API errors**: Check your API key and rate limits
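
For transient API errors such as rate limiting, one common pattern is to retry with an increasing delay. The helper below is a generic sketch and not part of this package or of any provider SDK; `call_with_retries` and its parameters are hypothetical, and the demo records the delays instead of actually sleeping.

```python
import time


def call_with_retries(func, max_attempts=4, base_delay_s=1.0, sleep=time.sleep):
    """Call func(); on an exception, wait base_delay_s * 2**attempt and try again."""
    for attempt in range(max_attempts):
        try:
            return func()
        except Exception:
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay_s * 2 ** attempt)


# Demo with a fake call that fails twice; delays are recorded rather than slept
attempts = {"count": 0}
recorded_delays = []


def flaky_call():
    attempts["count"] += 1
    if attempts["count"] < 3:
        raise RuntimeError("simulated rate limit")
    return "ok"


result = call_with_retries(flaky_call, sleep=recorded_delays.append)
print(result, recorded_delays)  # ok [1.0, 2.0]
```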

Good luck with your curation!
