## Processing playground - Following unitmatch_maja.py structure
### Tracking neurons across days with BombCell + UnitMatch integration



In [1]:
# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Enable autoreload like in the demo
%load_ext autoreload
%autoreload 2

# Import bombcell like in the demo
import bombcell as bc

print("Available bombcell functions:")
print([attr for attr in dir(bc) if not attr.startswith('_')])

# UnitMatch imports
import UnitMatchPy.bayes_functions as bf
import UnitMatchPy.utils as util
import UnitMatchPy.overlord as ov
import UnitMatchPy.save_utils as su
import UnitMatchPy.GUI as gui
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.default_params as default_params

✅ ipywidgets available - interactive GUI ready
Available bombcell functions:


In [2]:
## Step 1: Set up file paths - Following unitmatch_maja.py structure

# Define subject and sessions like in unitmatch_maja.py
subject = "a2a_230"
date1 = "20241009"
# TODO: set date2 to a *different* recording day for cross-day tracking. While it equals
# date1, both "sessions" resolve to the same folders (see the duplicate check below).
date2 = "20241009"
dates = [date1, date2]

# Base path structure following unitmatch_maja pattern
base_path = r"/home/jf5479/cup/Maja/for_Julie/"

# Construct KiloSort directories following unitmatch_maja pattern
ks_dir1 = os.path.join(base_path, subject, date1, "kilosort4")
ks_dir2 = os.path.join(base_path, subject, date2, "kilosort4")

KS_dirs = [ks_dir1, ks_dir2]

# Identical session folders mean BombCell writes every session to the same output folder
# and the later UnitMatch steps compare a recording with itself -> warn loudly.
if len(set(KS_dirs)) != len(KS_dirs):
    print(f"WARNING: duplicate session directories in KS_dirs: {KS_dirs}")
    print("         Cross-day matching is not meaningful until the session dates differ.")

# Construct BombCell output directories following unitmatch_maja pattern
output_dirs = [os.path.join(base_path, subject, date, "bombcell_testing_jf")
               for date in dates]

# Create per-session save directories following unitmatch_maja lines 34-36.
# NOTE: after the loop, `save_dir` still holds the LAST session's folder only.
for output_dir in output_dirs:
    save_dir = Path(output_dir) / "unitmatch"
    save_dir.mkdir(parents=True, exist_ok=True)
    print(f"Using unitmatch directory: {save_dir}")

# Final (cross-session) output directory
final_save_dir = os.path.join(base_path, subject, "unitmatch_output_testing_jf")
os.makedirs(final_save_dir, exist_ok=True)

print(f"\nKiloSort directories: {KS_dirs}")
print(f"BombCell directories: {output_dirs}")
print(f"Final output directory: {final_save_dir}")
print(f"Processing cross-day sessions: {date1} -> {date2}")

Using unitmatch directory: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/bombcell_testing_jf/unitmatch
Using unitmatch directory: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/bombcell_testing_jf/unitmatch

KiloSort directories: ['/home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/kilosort4', '/home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/kilosort4']
BombCell directories: ['/home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/bombcell_testing_jf', '/home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/bombcell_testing_jf']
Final output directory: /home/jf5479/cup/Maja/for_Julie/a2a_230/unitmatch_output_testing_jf
Processing cross-day sessions: 20241009 -> 20241009


In [None]:
## Step 2: Run BombCell quality metrics and extract raw waveforms
print("Starting BombCell processing...")


def _scan_raw_waveforms_for_nans(raw_waveforms_dir):
    """Scan the ``*.npy`` files in a BombCell ``RawWaveforms`` folder for NaNs.

    Parameters
    ----------
    raw_waveforms_dir : pathlib.Path
        Folder written by BombCell's raw-waveform extraction.

    Returns
    -------
    tuple[int, list[str]] or None
        ``(number_of_npy_files, names_of_files_containing_nan)``, or None when the
        folder does not exist. Files that fail to load are reported and counted in
        the total, but not in the NaN list.
    """
    if not raw_waveforms_dir.exists():
        return None
    npy_files = list(raw_waveforms_dir.glob('*.npy'))
    nan_file_names = []
    for npy_file in npy_files:
        try:
            if np.any(np.isnan(np.load(npy_file))):
                nan_file_names.append(npy_file.name)
        except Exception as e:
            print(f"    Error loading {npy_file.name}: {e}")
    return len(npy_files), nan_file_names


# Process each session with BombCell
bombcell_results = {}
for i, session_dir in enumerate(KS_dirs):
    print(f"Processing session {i+1}: {session_dir}")

    # Raw data (.ap.bin) and meta (.ap.meta) files sit in the folder above kilosort4.
    # Sorted so the chosen file is deterministic when several match.
    session_path = Path(session_dir).parent
    raw_files = sorted(session_path.glob("*.ap.bin"))
    meta_files = sorted(session_path.glob("*.ap.meta"))

    raw_file = str(raw_files[0]) if raw_files else None
    meta_file = str(meta_files[0]) if meta_files else None

    print(f"  Raw file: {raw_file}")
    print(f"  Meta file: {meta_file}")
    if len(raw_files) > 1:
        print(f"  WARNING: {len(raw_files)} *.ap.bin files found; using the first one")

    # Get UnitMatch-optimized BombCell parameters (includes saveMultipleRaw=True)
    param = bc.default_parameters.get_unit_match_parameters(session_dir,
                                                           raw_file=raw_file,  # Provide raw file path
                                                           meta_file=meta_file,  # Provide meta file path
                                                           kilosort_version=4)  # Adjust based on your KS version

    # Speed optimizations
    param['computeDistanceMetrics'] = False  # Disable expensive metrics
    param['computeDrift'] = False
    param['saveAsTSV'] = True  # Save results in phy-compatible format
    param['plotGlobal'] = False  # Disable plotting for speed
    param['plotDetails'] = False  # Disable detailed plots
    # Overrides the value from get_unit_match_parameters (a previous run printed 1000 here).
    # Fewer spikes is faster; presumably the averaged waveforms get noisier.
    param['nRawSpikesToExtract'] = 100

    # Verify raw waveform extraction is enabled
    print(f"  Raw data file in param: {param.get('raw_data_file', 'None')}")
    print(f"  extractRaw: {param.get('extractRaw', False)}")
    print(f"  saveMultipleRaw: {param.get('saveMultipleRaw', False)}")
    print(f"  nRawSpikesToExtract: {param.get('nRawSpikesToExtract', 'Unknown')}")

    # Same folder Step 1 stored in `output_dirs` (and Step 3 reads from), so they cannot diverge.
    bc_output_dir = Path(output_dirs[i])

    try:
        (quality_metrics, param, unit_type, unit_type_string) = bc.run_bombcell(
            session_dir, bc_output_dir, param
        )

        # Reset per session so a missing RawWaveforms folder never reports the
        # previous session's NaN files.
        nan_files = []
        print("  Checking for NaNs in raw waveforms...")
        scan = _scan_raw_waveforms_for_nans(bc_output_dir / 'RawWaveforms')
        if scan is None:
            print("  ❌ RawWaveforms directory not found!")
        else:
            total_files, nan_files = scan
            print(f"  Raw waveform files: {total_files}")
            print(f"  Files with NaNs: {len(nan_files)}")
            if nan_files:
                print(f"  NaN files: {nan_files[:5]}...")  # Show first 5

        bombcell_results[f'session_{i+1}'] = {
            'quality_metrics': quality_metrics,
            'unit_type': unit_type,
            'unit_type_string': unit_type_string,
            'param': param,
            'session_dir': session_dir,
            'bc_output_dir': str(bc_output_dir),
            'nan_files_count': len(nan_files),
        }

        print(f"BombCell processing complete for session {i+1}")
        print(f"  - Total units: {len(quality_metrics['phy_clusterID'])}")
        print(f"  - Good units: {sum(np.array(unit_type_string) == 'GOOD')}")
        print(f"  - Results saved to: {bc_output_dir}")

    except Exception as e:
        # Best-effort per session: report and move on to the next one.
        print(f"Error processing session {i+1}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("BombCell processing completed for all sessions.")

# Summary of NaN detection
print("\n=== NaN Detection Summary ===")
for session, result in bombcell_results.items():
    nan_count = result.get('nan_files_count', 0)
    print(f"{session}: {nan_count} files with NaNs")

Starting BombCell processing...
Processing session 1: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/kilosort4
  Raw file: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/TowersTask_g0_tcat.imec0.ap.bin
  Meta file: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/TowersTask_g0_tcat.imec0.ap.meta
Using raw data TowersTask_g0_tcat.imec0.ap.bin.
  Raw data file in param: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/TowersTask_g0_tcat.imec0.ap.bin
  extractRaw: True
  saveMultipleRaw: True
  nRawSpikesToExtract: 1000
🚀 Starting BombCell quality metrics pipeline...
📁 Processing data from: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/kilosort4
Results will be saved to: /home/jf5479/cup/Maja/for_Julie/a2a_230/20241009/bombcell_testing_jf

Loading ephys data...
Loaded ephys data: 612 units, 16,681,348 spikes

🔍 Extracting raw waveforms...


0it [00:00, ?it/s]

[Parallel(n_jobs=-1)]: Using backend LokyBackend with 12 concurrent workers.
[Parallel(n_jobs=-1)]: Done   1 tasks      | elapsed:   13.2s


In [None]:
## Step 3: Setup UnitMatch Parameters - Following unitmatch_maja.py structure

print("Setting up UnitMatch parameters...")

# Get default UnitMatch parameters (following unitmatch_maja line 39-40).
# NOTE: this rebinds `param`, replacing the BombCell parameters from Step 2.
param = default_params.get_default_param()
param['KS_dirs'] = KS_dirs

print("UnitMatch parameters initialized following unitmatch_maja structure.")

# One entry per session in each list, in KS_dirs order.
wave_paths = []
unit_label_paths = []
channel_pos = []

print("Constructing data paths...")

for i, (ks_dir, output_dir) in enumerate(zip(KS_dirs, output_dirs)):
    # BombCell raw waveforms and unit-type labels for this session
    wave_paths.append(str(Path(output_dir) / 'RawWaveforms'))
    unit_label_paths.append(str(Path(output_dir) / 'cluster_bc_unitType.tsv'))

    # Channel positions from KiloSort. `channel_pos` is passed to UnitMatch as a per-session
    # list, so a missing session cannot be skipped without misaligning the others.
    channel_pos_path = Path(ks_dir) / 'channel_positions.npy'
    if not channel_pos_path.exists():
        raise FileNotFoundError(f"channel_positions.npy not found in {ks_dir}")
    ch_pos = np.load(channel_pos_path)
    if ch_pos.shape[1] == 2:
        print(f"  Converting session {i+1} channel positions from 2D to 3D")
        # Prepend a column of ones -> [1, x, y], the layout UnitMatchPy's own KiloSort loader
        # (utils.paths_from_KS) builds. NOTE(review): appending zeros ([x, y, 0]) would put
        # x/y in different columns than the library expects - confirm if results change.
        ch_pos = np.insert(ch_pos, 0, np.ones(ch_pos.shape[0]), axis=1)
    channel_pos.append(ch_pos)
    print(f"  Session {i+1} channel positions shape: {ch_pos.shape}")

print("\nPath construction complete:")
print(f"Wave paths: {wave_paths}")
print(f"Unit label paths: {unit_label_paths}")
print(f"Channel positions: {len(channel_pos)} sessions")

# Verify files exist
print("\nVerifying file existence:")
for i, (wave_path, unit_path) in enumerate(zip(wave_paths, unit_label_paths)):
    wave_exists = Path(wave_path).exists()
    unit_exists = Path(unit_path).exists()
    npy_count = len(list(Path(wave_path).glob('*.npy'))) if wave_exists else 0
    print(f"  Session {i+1}: Wave dir={wave_exists} ({npy_count} files), Unit file={unit_exists}")

# Probe geometry is taken from the first session's channel layout
if not channel_pos:
    raise ValueError("No sessions configured in KS_dirs - nothing to match")
param = util.get_probe_geometry(channel_pos[0], param)
print("Probe geometry configured following unitmatch_maja pattern.")

print("UnitMatch setup complete - ready for data loading.")

In [None]:
## Step 4: UnitMatch Data Loading - Following unitmatch_maja.py STEP 0

print("Starting UnitMatch data loading...")

# STEP 0 -- Data preparation (following unitmatch_maja line 44)
print("Loading good waveforms...")
print("(Note: unitmatch_maja.py mentions potential 'ValueError: array must not contain infs or NaNs')")

try:
    waveform, session_id, session_switch, within_session, good_units, param = util.load_good_waveforms(
        wave_paths, unit_label_paths, param, good_units_only=True
    )

    print("✓ Waveform loading successful!")
    print(f"  - Waveform shape: {waveform.shape}")
    print(f"  - Total good units: {param['n_units']}")
    print(f"  - Number of sessions: {len(KS_dirs)}")

    # Create clus_info containing all unit id/session related info (following unitmatch_maja lines 48-49)
    clus_info = {
        'good_units': good_units,
        'session_switch': session_switch,
        'session_id': session_id,
        'original_ids': np.concatenate(good_units)
    }

    print("✓ Cluster info created following unitmatch_maja structure:")
    print(f"  - Good units per session: {[len(gu) for gu in good_units]}")
    print(f"  - Session switch points: {session_switch}")
    print(f"  - Within session labels: {np.unique(within_session)}")

    # Data quality check. The downstream error mentions both infs and NaNs, so count both.
    total_nans = int(np.isnan(waveform).sum())
    total_infs = int(np.isinf(waveform).sum())
    total_elements = waveform.size
    nan_percentage = 100 * total_nans / total_elements if total_elements else 0.0

    print("✓ Data quality check:")
    print(f"  - NaN values: {total_nans}/{total_elements} ({nan_percentage:.2f}%)")
    print(f"  - Inf values: {total_infs}")

    if total_nans or total_infs:
        print("  - Non-finite values detected - this may cause the error mentioned in unitmatch_maja.py")
        # Axis 0 indexes units: flatten the remaining axes and flag any non-finite entry.
        n_waveform_units = waveform.shape[0]
        units_with_nans = int(
            (~np.isfinite(waveform)).reshape(n_waveform_units, -1).any(axis=1).sum()
        )
        print(f"  - Units with NaN/Inf waveforms: {units_with_nans} out of {n_waveform_units}")

    print("✓ Data loading complete - ready for parameter extraction!")

except Exception as e:
    print(f"✗ ERROR in data loading: {e}")
    print("This may be the error mentioned in unitmatch_maja.py")
    import traceback
    traceback.print_exc()
    raise

In [None]:
## Step 5: UnitMatch Parameter Extraction - Following unitmatch_maja.py STEP 1

print("Starting parameter extraction from waveforms...")

# STEP 1 -- Extract parameters from waveform (following unitmatch_maja line 56)
print("Extracting waveform parameters...")
print("(unitmatch_maja.py comment: 'I get an error: ValueError: array must not contain infs or NaNs')")

try:
    extracted_wave_properties = ov.extract_parameters(waveform, channel_pos, clus_info, param)

    print("✓ Parameter extraction successful!")
    print(f"  - Extracted properties: {list(extracted_wave_properties.keys())}")

    # Verbose output about extracted properties
    for prop_name, prop_data in extracted_wave_properties.items():
        if hasattr(prop_data, 'shape'):
            print(f"  - {prop_name}: shape {prop_data.shape}")
        else:
            print(f"  - {prop_name}: {type(prop_data)}")

    print("✓ Ready for metric score calculation (STEPS 2,3,4)!")

except Exception as e:
    print(f"✗ ERROR in parameter extraction: {e}")
    print("This matches the error mentioned in unitmatch_maja.py: 'ValueError: array must not contain infs or NaNs'")

    # Inspect the INPUTS: `extracted_wave_properties` was not assigned by the failed call
    # above, so any value still in the kernel would be stale from an earlier run.
    print("\nDebugging extraction inputs...")
    try:
        inputs_to_check = {'waveform': waveform}
        inputs_to_check.update({f'channel_pos[{k}]': pos for k, pos in enumerate(channel_pos)})
        for input_name, input_data in inputs_to_check.items():
            input_array = np.asarray(input_data)
            if np.issubdtype(input_array.dtype, np.number):
                has_nan = bool(np.any(np.isnan(input_array)))
                has_inf = bool(np.any(np.isinf(input_array)))
                print(f"  - {input_name}: NaN={has_nan}, Inf={has_inf}")
    except Exception as debug_error:
        print(f"Could not debug extraction inputs: {debug_error}")

    import traceback
    traceback.print_exc()
    raise

In [None]:
## Step 6: UnitMatch Metric Calculation - Following unitmatch_maja.py STEPS 2,3,4

import traceback

print("Starting metric score calculation and drift correction...")

# STEPS 2, 3, 4 -- Extract metric scores (following unitmatch_maja line 67)
print("Calculating similarity metrics with drift correction...")
print("(RuntimeWarnings are expected when processing NaN-containing data)")

try:
    total_score, candidate_pairs, scores_to_include, predictors = ov.extract_metric_scores(
        extracted_wave_properties, session_switch, within_session, param, niter=2
    )

    print("✓ Metric calculation successful!")
    print(f"  - Total score shape: {total_score.shape}")
    print(f"  - Number of candidate pairs: {np.sum(candidate_pairs)}")
    print(f"  - Scores included: {list(scores_to_include.keys())}")
    print(f"  - Predictors shape: {predictors.shape if hasattr(predictors, 'shape') else type(predictors)}")

    print("✓ Ready for Naive Bayes classification (STEP 5)!")

except Exception as e:
    print(f"✗ ERROR in metric calculation: {e}")

    # Only the known quantile-range failure gets a retry; anything else is re-raised.
    if "Quantiles must be in the range [0, 1]" not in str(e):
        traceback.print_exc()
        raise

    # Fallback: rerun with a single iteration (niter=1) instead of 2.
    # NOTE: results from this path used fewer iterations than the default run above.
    print("This is the quantile range error - retrying with niter=1...")
    try:
        total_score, candidate_pairs, scores_to_include, predictors = ov.extract_metric_scores(
            extracted_wave_properties, session_switch, within_session, param, niter=1
        )
    except Exception as e2:
        print(f"✗ Retry with niter=1 failed: {e2}")
        traceback.print_exc()
        raise

    print("✓ Metric calculation succeeded with niter=1!")
    print(f"  - Total score shape: {total_score.shape}")
    print(f"  - Number of candidate pairs: {np.sum(candidate_pairs)}")

In [None]:
## Step 7: UnitMatch Naive Bayes Classification - Following unitmatch_maja.py STEP 5

print("Running Naive Bayes classification...")

# STEP 5 -- Probability analysis (following unitmatch_maja lines 69-81)
print("Performing probability analysis...")

# Class priors, ordered like `cond` below. Assumes candidate_pairs is boolean, so
# cond == [0 (non-match), 1 (match)]; column 1 of `probability` is read as P(match).
# Despite its name (kept from unitmatch_maja line 71), `prior_match` is the NON-match prior.
prior_match = 1 - (param['n_expected_matches'] / param['n_units']**2)
priors = np.array((prior_match, 1 - prior_match))

print("✓ Prior probability calculation:")
print(f"  - Prior non-match probability: {prior_match:.4f}")
print(f"  - Prior match probability: {1 - prior_match:.4f}")
print(f"  - Expected matches: {param['n_expected_matches']}")
print(f"  - Total possible pairs: {param['n_units']**2}")

# Construct distributions (kernels) for Naive Bayes Classifier (following unitmatch_maja lines 73-78)
labels = candidate_pairs.astype(int)
cond = np.unique(labels)

print("✓ Constructing Naive Bayes kernels...")
parameter_kernels = bf.get_parameter_kernels(scores_to_include, labels, cond, param, add_one=1)

# Get probability of each pair being a match (following unitmatch_maja lines 79-81)
print("✓ Applying Naive Bayes classification...")
probability = bf.apply_naive_bayes(parameter_kernels, priors, predictors, param, cond)

# P(match) per pair, reshaped to an (n_units x n_units) matrix (following unitmatch_maja line 81)
output_prob_matrix = probability[:, 1].reshape(param['n_units'], param['n_units'])

print("✓ Naive Bayes classification complete!")
print(f"  - Probability matrix shape: {output_prob_matrix.shape}")
print(f"  - Probability range: {np.nanmin(output_prob_matrix):.3f} to {np.nanmax(output_prob_matrix):.3f}")
print(f"  - Mean probability: {np.nanmean(output_prob_matrix):.3f}")

# Binarize the probability matrix. `match_threshold` and `output_threshold` are consumed
# by Step 8 (GUI preparation) and Step 10 (saving) and must be defined here.
match_threshold = param['match_threshold']
output_threshold = np.zeros_like(output_prob_matrix)
output_threshold[output_prob_matrix > match_threshold] = 1
print(f"  - Pairs above match threshold ({match_threshold}): {int(np.sum(output_threshold))}")

print("✓ Ready for result evaluation and GUI!")

In [None]:
## Step 8: Prepare and Launch GUI (Optional)
print("Preparing data for GUI...")

# Pull the per-unit waveform summaries the GUI needs out of the Step 5 results.
# Note the GUI's `wave_idx` comes from the 'good_wave_idxs' key.
(amplitude, spatial_decay, avg_centroid, avg_waveform,
 avg_waveform_per_tp, wave_idx, max_site, max_site_mean) = (
    extracted_wave_properties[key]
    for key in (
        'amplitude',
        'spatial_decay',
        'avg_centroid',
        'avg_waveform',
        'avg_waveform_per_tp',
        'good_wave_idxs',
        'max_site',
        'max_site_mean',
    )
)

# Hand everything to the GUI module so the next cell can launch it.
gui.process_info_for_GUI(
    output_prob_matrix,
    match_threshold,
    scores_to_include,
    total_score,
    amplitude,
    spatial_decay,
    avg_centroid,
    avg_waveform,
    avg_waveform_per_tp,
    wave_idx,
    max_site,
    max_site_mean,
    waveform,
    within_session,
    channel_pos,
    clus_info,
    param,
)

print("GUI data preparation complete.")
print("To launch the GUI, run the next cell.")

In [None]:
## Step 9: Launch GUI for Manual Curation (Optional)
# Requires the GUI data prepared in Step 8. The GUI is interactive, so curated results
# cannot be reproduced from code alone - keep RUN_GUI False for Restart & Run All.
RUN_GUI = False

if RUN_GUI:
    print("Launching UnitMatch GUI...")
    is_match, not_match, matches_GUI = gui.run_GUI()

    # Combine GUI decisions with the automatic matches
    matches_curated = util.curate_matches(matches_GUI, is_match, not_match, mode='And')
    print(f"Manual curation complete. Curated matches: {len(matches_curated)}")
else:
    print("GUI section skipped. Set RUN_GUI = True to run manual curation.")

In [None]:
## Step 10: Save Results
print("Saving UnitMatch results...")

# (row, col) index pairs of the thresholded match matrix from Step 7.
# NOTE(review): taken over the full square matrix, so this may list both (i, j) and (j, i)
# and diagonal self-pairs if they pass the threshold.
matches = np.argwhere(output_threshold == 1)
n_matches = len(matches)

# Assign unique IDs to matched units
UIDs = aid.assign_unique_id(output_prob_matrix, param, clus_info)

# Cross-session results go in the subject-level output folder from Step 1. (`save_dir` is
# only a leftover of Step 1's per-session loop and points at the last session's folder.)
unitmatch_output_dir = os.path.join(final_save_dir, 'unitmatch_results_testing_jf')
os.makedirs(unitmatch_output_dir, exist_ok=True)

# Read waveform summaries straight from Step 5 so saving does not depend on the
# optional GUI-preparation cell (Step 8) having been run.
avg_centroid = extracted_wave_properties['avg_centroid']
avg_waveform = extracted_wave_properties['avg_waveform']
avg_waveform_per_tp = extracted_wave_properties['avg_waveform_per_tp']
max_site = extracted_wave_properties['max_site']

# Save results
# NOTE: Change 'matches' to 'matches_curated' if you performed manual curation with the GUI
su.save_to_output(
    unitmatch_output_dir,
    scores_to_include,
    matches,  # Use matches_curated if you did manual curation
    output_prob_matrix,
    avg_centroid,
    avg_waveform,
    avg_waveform_per_tp,
    max_site,
    total_score,
    output_threshold,
    clus_info,
    param,
    UIDs=UIDs,
    matches_curated=None,  # Set to matches_curated if you did manual curation
    save_match_table=True
)

print(f"Results saved to: {unitmatch_output_dir}")
print(f"Number of matches saved: {n_matches}")
# NOTE(review): assign_unique_id's return is presumably a list of UID variants rather than
# one entry per unit, so len(UIDs) is not a unit count - report the analysed unit count.
print(f"Unique IDs assigned to {param['n_units']} units")

# Print summary
print("\n=== PROCESSING SUMMARY ===")
print(f"BombCell processed {len(KS_dirs)} sessions")
print(f"UnitMatch analyzed {param['n_units']} good units")
print(f"Found {n_matches} putative matches")
print(f"Results saved to: {unitmatch_output_dir}")
print("Processing complete!")