# VLM Merge Analysis Tutorial

This tutorial shows how to use Vision-Language Models (VLMs) to decide whether units produced by spike sorting should be merged.

## What is VLM Merge Analysis?

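VLM merge analysis renders diagnostic plots (waveforms, amplitudes, cross-correlograms, PCA projections) for pairs or groups of candidate units and asks a vision-language model whether each group should be merged into a single unit. This helps when a spike sorter over-splits a neuron and assigns its spikes to several units.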

## Requirements

- A `SortingAnalyzer` with these extensions computed:
  - `random_spikes` (required by `waveforms` and `templates`)
  - `waveforms` / `templates` (for waveform plots)
  - `spike_locations` (for spike location plots)
  - `spike_amplitudes` (for amplitude plots)
  - `principal_components` (for PCA cluster plots)
  - `correlograms` (for cross-correlogram plots)

- An API key for your chosen VLM provider (OpenAI, Anthropic, or Google)
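Before building images, it can help to confirm that every extension your chosen features rely on has been computed. The sketch below assumes a feature-to-extension mapping inferred from the requirements list above; `FEATURE_EXTENSIONS` and `missing_extensions` are hypothetical helpers, not part of SpikeAgent, so check the mapping against your version.

```python
from typing import Callable, Dict, Iterable, List

# Hypothetical mapping from plot feature to the analyzer extensions it reads.
# Inferred from the requirements list; verify against your SpikeAgent version.
FEATURE_EXTENSIONS: Dict[str, List[str]] = {
    "waveform_single": ["templates"],
    "waveform_multi": ["templates"],
    "amplitude_plot": ["spike_amplitudes"],
    "crosscorrelograms": ["correlograms"],
    "pca_clustering": ["principal_components"],
    "spike_locations": ["spike_locations"],
}


def missing_extensions(features: Iterable[str],
                       has_extension: Callable[[str], bool]) -> Dict[str, List[str]]:
    """Return {feature: [missing extension names]} for features that cannot be plotted yet."""
    missing = {}
    for feature in features:
        needed = FEATURE_EXTENSIONS.get(feature, [])
        absent = [ext for ext in needed if not has_extension(ext)]
        if absent:
            missing[feature] = absent
    return missing


# Usage with a stand-in for SortingAnalyzer.has_extension:
computed = {"random_spikes", "waveforms", "templates", "correlograms"}
report = missing_extensions(["waveform_single", "amplitude_plot", "crosscorrelograms"],
                            computed.__contains__)
print(report)  # {'amplitude_plot': ['spike_amplitudes']}
```

With a real analyzer, pass `sorting_analyzer.has_extension` as the second argument; an empty dict means every requested feature has its data.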


In [None]:
import os
import sys

# Make the repository's src/ folder importable
# (assumes this notebook sits one directory below the repository root).
_root_dir = os.path.abspath(os.path.join(os.getcwd(), '..'))
sys.path.insert(0, os.path.join(_root_dir, 'src'))

import spikeinterface as si
import numpy as np
from spikeagent.app.tool.si_custom import create_merge_img_df
from spikeagent.curation.vlm_merge import run_vlm_merge, plot_merge_results
from spikeagent.app.tool.utils import get_model
from dotenv import load_dotenv

# Load API keys (e.g. OPENAI_API_KEY) from a .env file into os.environ.
load_dotenv()

print("Imports successful!")


## Step 1: Create a SortingAnalyzer

We'll create a synthetic `SortingAnalyzer` for demonstration. In practice, you would load your own recording and sorting.


In [None]:
print("Creating synthetic recording and sorting...")
# 10 s, 16-channel recording with 6 ground-truth units; the fixed seed makes it reproducible
recording, sorting = si.generate_ground_truth_recording(
    durations=[10.0],
    num_channels=16,
    num_units=6,
    sampling_frequency=30000.0,
    noise_kwargs={'noise_levels': 5.0, 'strategy': 'on_the_fly'},
    seed=42
)

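# format="memory" keeps the analyzer in RAM; use "binary_folder" with folder=... to save it to disk
sorting_analyzer = si.create_sorting_analyzer(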
    sorting=sorting,
    recording=recording,
    format="memory"
)

print(f"Created sorting_analyzer with {len(sorting_analyzer.unit_ids)} units")
print(f"Unit IDs: {list(sorting_analyzer.unit_ids)}")


## Step 2: Compute Required Extensions

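Compute every extension that the VLM merge plots use. Parent extensions come first because some extensions depend on others: `waveforms` and `templates` need `random_spikes`, and `spike_amplitudes` and `spike_locations` need `templates`.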
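If you add more extensions, working out a valid order by hand gets tedious. The sketch below derives one with Python's `graphlib`, assuming the parent map in `EXTENSION_PARENTS` (based on SpikeInterface 0.101+ as we understand it; check your version). `compute_order` is a hypothetical helper for planning, not a SpikeInterface function.

```python
from graphlib import TopologicalSorter

# Assumed parent extensions (based on SpikeInterface >= 0.101; check your version).
# Each key lists the extensions that must exist before it can be computed.
EXTENSION_PARENTS = {
    "random_spikes": [],
    "waveforms": ["random_spikes"],
    "templates": ["random_spikes"],
    "correlograms": [],
    "spike_amplitudes": ["templates"],
    "spike_locations": ["templates"],
    "principal_components": ["random_spikes", "waveforms"],
}


def compute_order(wanted, parents=EXTENSION_PARENTS):
    """Return `wanted` plus all their ancestors, with parents always before children."""
    # Collect the wanted extensions and every ancestor they need.
    needed, stack = set(), list(wanted)
    while stack:
        name = stack.pop()
        if name not in needed:
            needed.add(name)
            stack.extend(parents.get(name, []))
    # TopologicalSorter expects {node: predecessors}; static_order yields predecessors first.
    graph = {name: [p for p in parents.get(name, []) if p in needed] for name in needed}
    return list(TopologicalSorter(graph).static_order())


order = compute_order(["spike_amplitudes", "principal_components"])
print(order)
```

Ties between independent extensions can come out in any order; only the parent-before-child constraint is guaranteed.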


In [None]:
print("Computing extensions...")

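# Parents first: waveforms need random_spikes; amplitudes and locations need templates
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("waveforms", n_jobs=1)
sorting_analyzer.compute("templates")
# Correlograms over a 100 ms window with 1 ms bins
sorting_analyzer.compute("correlograms", window_ms=100.0, bin_ms=1.0)
sorting_analyzer.compute("spike_locations", method="center_of_mass")
sorting_analyzer.compute("spike_amplitudes")
# 3 PCA components, with a separate PCA fitted per channel ('by_channel_local')
sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_local')

# List the extensions that are now available on the analyzer
available_exts = [ext for ext in sorting_analyzer.get_computable_extensions()
                  if sorting_analyzer.has_extension(ext)]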
print(f"Computed extensions: {available_exts}")


## Step 3: Find Potential Merge Candidates

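We'll use SpikeInterface's `compute_merge_unit_groups` to flag units that may belong together because their templates are similar; SpikeAgent uses the same approach. A pair is kept only when its template difference is below `template_diff_thresh` (0.2 here), so raising the threshold yields more candidates. With `resolve_graph=False`, the function returns the raw candidate pairs instead of combining overlapping pairs into larger groups.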
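To make the threshold concrete, here is a toy version of the template-similarity filter. It is a simplified sketch, not SpikeInterface's implementation: it uses cosine similarity between flattened templates as a stand-in metric (SpikeInterface's step may use a different one), and `candidate_pairs` is a hypothetical helper.

```python
import math
from itertools import combinations


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length flattened templates."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def candidate_pairs(templates, template_diff_thresh=0.2):
    """Return unit pairs whose template difference (1 - similarity) is below the threshold."""
    pairs = []
    for (u1, t1), (u2, t2) in combinations(templates.items(), 2):
        if 1.0 - cosine_similarity(t1, t2) < template_diff_thresh:
            pairs.append([u1, u2])
    return pairs


# Toy flattened templates: units 0 and 1 have the same shape at different scales.
templates = {
    0: [0.0, -5.0, 2.0, 0.5],
    1: [0.0, -4.5, 1.8, 0.45],
    2: [1.0, 1.0, -3.0, 0.0],
}
print(candidate_pairs(templates))  # [[0, 1]]
```

Like `resolve_graph=False`, this returns unresolved pairs; a generous threshold (for example 3.0 here) would admit every pair.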


In [None]:
from spikeinterface.curation import compute_merge_unit_groups

print("Computing potential merge candidates based on template similarity...")
steps = ["template_similarity"]
steps_params = {
    "template_similarity": {"template_diff_thresh": 0.2}
}

potential_merge_groups = compute_merge_unit_groups(
    sorting_analyzer,
    resolve_graph=False,
    steps_params=steps_params,
    steps=steps
)

num_groups = len(potential_merge_groups)
print(f"Found {num_groups} potential merge groups.")
if num_groups > 0:
    print("Potential merge groups:")
    for i, group in enumerate(potential_merge_groups[:10]):
        print(f"  Group {i}: {group}")
    if num_groups > 10:
        print(f"  ... and {num_groups - 10} more groups")
else:
    print("No potential merge groups found. Try increasing template_diff_thresh to get more candidates.")


## Step 4: Create Merge Image DataFrame

The VLM needs images comparing units within each group. We'll create a dataframe containing base64-encoded images for each merge group.
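How `create_merge_img_df` encodes its plots isn't shown in this tutorial. As a rough illustration, the sketch below assumes each plot is available as PNG bytes and shows base64 encoding plus the data-URL wrapper that some VLM APIs (e.g. OpenAI's image input) accept. The helper names, the placeholder bytes, and the one-column-per-feature row layout are all assumptions.

```python
import base64


def to_base64_png(png_bytes: bytes) -> str:
    """Encode raw PNG bytes as an ASCII base64 string."""
    return base64.b64encode(png_bytes).decode("ascii")


def to_data_url(b64_png: str) -> str:
    """Wrap a base64 PNG in a data URL so it can be embedded in a request."""
    return f"data:image/png;base64,{b64_png}"


# Hypothetical row for one merge group: one encoded image per feature.
fake_png = b"\x89PNG\r\n\x1a\n" + b"\x00" * 8  # placeholder bytes, not a real plot
row = {feature: to_base64_png(fake_png)
       for feature in ["waveform_single", "amplitude_plot"]}
print(to_data_url(row["waveform_single"])[:40])
```

Base64 keeps the images as plain strings, which is why they can live in a DataFrame after the image files themselves are deleted.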


In [None]:
if len(potential_merge_groups) == 0:
    print("No merge candidates found. Skipping VLM merge analysis.")
    print("Try increasing template_diff_thresh in Step 3, or define merge groups manually.")
else:
    features = ["waveform_single", "amplitude_plot", "crosscorrelograms", "pca_clustering"]
    
    print(f"\nCreating merge image dataframe with features: {features}...")
    import tempfile
    # Image files go to a throwaway folder; the DataFrame keeps base64-encoded copies
    with tempfile.TemporaryDirectory() as tmpdir:
        img_df = create_merge_img_df(
            sorting_analyzer,
            unit_groups=potential_merge_groups,
            features=features,
            load_if_exists=False,
            save_folder=tmpdir
        )
    
    print(f"Created merge image dataframe: {img_df.shape}")
    print(f"Merge groups: {len(img_df)}")
    print(f"Features: {list(img_df.columns)}")


## Step 5: Run VLM Merge Analysis

Now we'll use a Vision-Language Model to analyze whether each potential merge group should actually be merged.
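`run_vlm_merge`'s internals aren't shown here, but its `num_workers` argument suggests that requests for different groups are sent concurrently. The sketch below illustrates that pattern with a thread pool; `classify_group` is a hypothetical offline stand-in for one VLM call, and its toy rule has nothing to do with how a real model decides.

```python
from concurrent.futures import ThreadPoolExecutor


def classify_group(group):
    """Hypothetical stand-in for one VLM request; a real version would send the group's images."""
    # Toy rule so the sketch runs offline: adjacent unit IDs are "merge".
    return "merge" if max(group) - min(group) == 1 else "not merge"


def classify_all(groups, num_workers=10):
    """Send one request per group concurrently; pool.map keeps results in input order."""
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        decisions = list(pool.map(classify_group, groups))
    return dict(enumerate(decisions))


print(classify_all([[0, 1], [2, 5], [3, 4]]))
# {0: 'merge', 1: 'not merge', 2: 'merge'}
```

Threads suit this workload because each call spends most of its time waiting on the network; fewer workers may be needed if the provider enforces rate limits.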

**Note:** This requires an API key. Set it in your environment or `.env` file:
- `OPENAI_API_KEY` for GPT-4o
- `ANTHROPIC_API_KEY` for Claude
- `GOOGLE_API_KEY` for Gemini
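A missing key otherwise only surfaces once the first request fails. The sketch below checks for the right variable up front; the model-name prefixes in `PROVIDER_KEYS` are assumptions (the names `get_model` accepts aren't documented here), and `check_api_key` is a hypothetical helper.

```python
import os

# Assumed mapping from model-name prefix to the environment variable listed above.
PROVIDER_KEYS = {
    "gpt": "OPENAI_API_KEY",
    "claude": "ANTHROPIC_API_KEY",
    "gemini": "GOOGLE_API_KEY",
}


def check_api_key(model_name, environ=os.environ):
    """Return the env var name for model_name's provider, raising if it is unset."""
    for prefix, var in PROVIDER_KEYS.items():
        if model_name.lower().startswith(prefix):
            if not environ.get(var):
                raise RuntimeError(f"{var} is not set; add it to your environment or .env file")
            return var
    raise ValueError(f"Unrecognized model name: {model_name}")


# Usage with an explicit dict; in the notebook, call check_api_key(model_name) after load_dotenv().
print(check_api_key("gpt-4o", {"OPENAI_API_KEY": "sk-placeholder"}))  # OPENAI_API_KEY
```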


In [None]:
if len(potential_merge_groups) == 0:
    print("Skipping VLM merge analysis - no merge candidates found.")
    results_df = None
else:
    model_name = "gpt-4o"
    model = get_model(model_name)
    
    print(f"Initialized model: {model_name}")
    
    print("\nRunning VLM merge analysis...")
    print("This may take a few minutes depending on the number of groups...")
    
    results_df = run_vlm_merge(
        model=model,
        merge_unit_groups=potential_merge_groups,
        img_df=img_df,
        features=features,
        good_merge_groups=[],  # optional few-shot examples of groups that should merge
        bad_merge_groups=[],   # optional few-shot examples of groups that should stay separate
        num_workers=10         # number of parallel workers
    )
    
    decision_counts = results_df['merge_type'].value_counts()
    print("\nVLM merge analysis complete!")
    print("\nResults summary:")
    print(f"Total groups analyzed: {len(results_df)}")
    print(f"Recommended merges: {decision_counts.get('merge', 0)}")
    print(f"Recommended to keep separate: {decision_counts.get('not merge', 0)}")


## Step 6: View Results

Let's examine the results and see which groups were recommended for merging.
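The filtering in the next cell boils down to counting decisions and grouping group indices by decision. The sketch below shows the same idea with plain Python on hypothetical rows that mirror the `results_df` columns used in this tutorial; the data is made up for illustration.

```python
from collections import Counter, defaultdict

# Hypothetical rows mirroring the results_df columns used in this tutorial.
rows = [
    {"group": 0, "merge_type": "merge", "merge_units": [0, 1]},
    {"group": 1, "merge_type": "not merge", "merge_units": [2, 5]},
    {"group": 2, "merge_type": "merge", "merge_units": [3, 4]},
]

# Count decisions, similar to results_df["merge_type"].value_counts().
counts = Counter(row["merge_type"] for row in rows)

# Group indices by decision, similar to the boolean-mask filters on results_df.
by_decision = defaultdict(list)
for row in rows:
    by_decision[row["merge_type"]].append(row["group"])

print(counts["merge"], counts["not merge"])  # 2 1
print(dict(by_decision))  # {'merge': [0, 2], 'not merge': [1]}
```

A `Counter` returns 0 for a decision that never occurs, which avoids special-casing runs with no merges.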


In [None]:
if results_df is not None:
    print("VLM Merge Analysis Results:")
    print("=" * 60)
    print(results_df[['merge_type', 'merge_units']].to_string())
    
    merge_groups = results_df[results_df['merge_type'] == 'merge'].index.tolist()
    keep_separate_groups = results_df[results_df['merge_type'] == 'not merge'].index.tolist()
    
    print(f"\nGroups recommended for merging: {merge_groups}")
    print(f"Groups recommended to keep separate: {keep_separate_groups}")
    
    print("\nDetailed reasoning:")
    for group_idx in range(min(5, len(potential_merge_groups))):
        if group_idx in results_df.index:
            group = potential_merge_groups[group_idx]
            print(f"\nGroup {group_idx} (units {group}):")
            print(f"Decision: {results_df.loc[group_idx, 'merge_type']}")
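            reasoning = str(results_df.loc[group_idx, 'reasoning'])
            # Show at most 300 characters, adding an ellipsis only when text was cut
            if len(reasoning) > 300:
                reasoning = reasoning[:300] + "..."
            print(f"Reasoning: {reasoning}")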
else:
    print("No results to display - no merge candidates were found.")


## Step 7: Visualize Results

Plot the merge groups with their classification results.


In [None]:
if results_df is not None:
    plot_merge_results(results_df, img_df)
else:
    print("No results to plot - no merge candidates were found.")


## Step 8: Apply Merges

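If merges were recommended, apply them with SpikeInterface. Overlapping pairs such as `[1, 2]` and `[2, 3]` must become a single group, so `resolve_merging_graph` first combines the pairs into connected components; `merge_units` then returns a new analyzer and leaves `sorting_analyzer` unchanged.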
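The pair-to-group step is a connected-components problem. The sketch below shows the idea in pure Python with union-find; `resolve_pairs` is a hypothetical illustration of the concept, not SpikeInterface's `resolve_merging_graph`, which may differ in details and output format.

```python
def resolve_pairs(pairs):
    """Merge overlapping unit pairs into disjoint groups (connected components) via union-find."""
    parent = {}

    def find(u):
        parent.setdefault(u, u)
        while parent[u] != u:
            parent[u] = parent[parent[u]]  # path halving keeps trees shallow
            u = parent[u]
        return u

    for group in pairs:
        root = find(group[0])
        for unit in group[1:]:
            other = find(unit)
            if other != root:
                parent[other] = root  # attach the other tree under this group's root

    components = {}
    for unit in list(parent):
        components.setdefault(find(unit), []).append(unit)
    return sorted(sorted(c) for c in components.values())


print(resolve_pairs([[1, 2], [2, 3], [5, 6]]))  # [[1, 2, 3], [5, 6]]
```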


In [None]:
if results_df is not None:
    merge_groups = results_df[results_df['merge_type'] == 'merge'].index.tolist()
    
    if len(merge_groups) > 0:
        from spikeinterface.curation.curation_tools import resolve_merging_graph
        
        merge_unit_pairs = [results_df.loc[group_idx, 'merge_units'] for group_idx in merge_groups]
        final_merge_groups = resolve_merging_graph(sorting_analyzer.sorting, merge_unit_pairs)
        
        print("Applying merges...")
        print(f"Final merge groups: {final_merge_groups}")
        
        if final_merge_groups:
            # merge_units returns a new analyzer; sorting_analyzer itself is not modified
            merged_analyzer = sorting_analyzer.merge_units(
                merge_unit_groups=final_merge_groups,
                sparsity_overlap=0
            )
            print(f"Original units: {len(sorting_analyzer.unit_ids)}")
            print(f"Units after merging: {len(merged_analyzer.unit_ids)}")
        else:
            print("No merges to apply after resolving merge graph.")
    else:
        print("No merges recommended; all candidate groups stay separate.")
else:
    print("No results available - no merge candidates were found.")


## Summary

This tutorial demonstrated:
1. Creating a `SortingAnalyzer` with required extensions
2. Using `compute_merge_unit_groups` to automatically find merge candidates (same as SpikeAgent)
3. Generating merge image dataframes for VLM analysis
4. Running VLM merge analysis to classify merge groups
5. Viewing and visualizing results
6. Applying merges using `resolve_merging_graph` and `merge_units`

### Next Steps

- Try different features: `["waveform_multi", "spike_locations"]`
- Use few-shot prompting by passing example groups as `good_merge_groups` and `bad_merge_groups`
- Adjust `template_diff_thresh` to find more or fewer merge candidates
- Save results: `results_df.to_csv('vlm_merge_results.csv')`
