## Processing playground 
### tracking neurons across days 

In [12]:
# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Enable autoreload like in the demo
%load_ext autoreload
%autoreload 2

# Import bombcell like in the demo
import bombcell as bc

print("Available bombcell functions:")
print([attr for attr in dir(bc) if not attr.startswith('_')])

# UnitMatch imports
import UnitMatchPy.bayes_functions as bf
import UnitMatchPy.utils as util
import UnitMatchPy.overlord as ov
import UnitMatchPy.save_utils as su
import UnitMatchPy.GUI as gui
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.default_params as default_params

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
Available bombcell functions:


In [None]:
## Step 1: Set up file paths
# KiloSort directories - USE SAME SESSION TWICE for merge/split testing
# loop for animals in cta_backwards directory starting with calca_2 ot calca_3. use two last day folders inside. 
KS_dirs = [r'/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/kilosort4',
           r'/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/kilosort4']  # Same session repeated

# BombCell output directories - USE SAME SESSION TWICE  
custom_bombcell_paths = [r'/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf',
                         r'/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf']  # Same session repeated

# Output directory for saving results - add testing suffix
save_dir = r'/home/jf5479/cup/Chris/data/cta_backwards/calca_302/unitmatch_output_testing_jf'
os.makedirs(save_dir, exist_ok=True)

print(f"KiloSort directories: {KS_dirs}")
print(f"BombCell directories: {custom_bombcell_paths}")
print(f"Output directory: {save_dir}")
print("")
print("NOTE: Using the SAME session twice to test for merges/splits within the session")
print("This is correct for validating UnitMatch merge detection capabilities")

KiloSort directories: ['/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/kilosort4', '/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/kilosort4']
BombCell directories: ['/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf', '/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf']
Output directory: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/unitmatch_output_testing_jf

NOTE: Using the SAME session twice to test for merges/splits within the session
This is correct for validating UnitMatch merge detection capabilities


In [15]:
## Step 2: Run BombCell quality metrics and extract raw waveforms
print("Starting BombCell processing...")

# Process each session with BombCell
bombcell_results = {}
for i, session_dir in enumerate(KS_dirs):
    print(f"Processing session {i+1}: {session_dir}")
    
    # Find raw data file (.bin) and meta file (.meta) 
    session_path = Path(session_dir).parent
    raw_files = list(session_path.glob("*.ap.bin"))
    meta_files = list(session_path.glob("*.ap.meta"))
    
    raw_file = str(raw_files[0]) if raw_files else None
    meta_file = str(meta_files[0]) if meta_files else None
    
    print(f"  Raw file: {raw_file}")
    print(f"  Meta file: {meta_file}")
    
    # Get UnitMatch-optimized BombCell parameters (includes saveMultipleRaw=True)
    param = bc.default_parameters.get_unit_match_parameters(session_dir, 
                                                           raw_file=raw_file,  # Provide raw file path
                                                           meta_file=meta_file,  # Provide meta file path
                                                           kilosort_version=4)  # Adjust based on your KS version
    
    # Speed optimizations but keep 1000 spikes
    param['computeDistanceMetrics'] = False  # Disable expensive metrics
    param['computeDrift'] = False
    param['saveAsTSV'] = True  # Save results in phy-compatible format
    param['plotGlobal'] = False  # Disable plotting for speed
    param['plotDetails'] = False  # Disable detailed plots
    param['nRawSpikesToExtract'] = 100 #(default from get_unit_match_parameters)
    
    # Verify raw waveform extraction is enabled
    print(f"  Raw data file in param: {param.get('raw_data_file', 'None')}")
    print(f"  extractRaw: {param.get('extractRaw', False)}")
    print(f"  saveMultipleRaw: {param.get('saveMultipleRaw', False)}")
    print(f"  nRawSpikesToExtract: {param.get('nRawSpikesToExtract', 'Unknown')}")
    
    # Set BombCell output directory with testing suffix
    bc_output_dir = Path(session_dir).parent / 'bombcell_testing_jf'
    
    try:
        # Run BombCell - the function should be imported at top level
        (quality_metrics, param, unit_type, unit_type_string) = bc.run_bombcell(
            session_dir, bc_output_dir, param
        )
        
        # Check for NaNs in saved raw waveforms
        print(f"  Checking for NaNs in raw waveforms...")
        raw_waveforms_dir = bc_output_dir / 'RawWaveforms'
        if raw_waveforms_dir.exists():
            npy_files = list(raw_waveforms_dir.glob('*.npy'))
            nan_files = []
            total_files = len(npy_files)
            
            for npy_file in npy_files:
                try:
                    data = np.load(npy_file)
                    if np.any(np.isnan(data)):
                        nan_files.append(npy_file.name)
                except Exception as e:
                    print(f"    Error loading {npy_file.name}: {e}")
            
            print(f"  Raw waveform files: {total_files}")
            print(f"  Files with NaNs: {len(nan_files)}")
            if nan_files:
                print(f"  NaN files: {nan_files[:5]}...")  # Show first 5
        else:
            print(f"  ❌ RawWaveforms directory not found!")
        
        bombcell_results[f'session_{i+1}'] = {
            'quality_metrics': quality_metrics,
            'unit_type': unit_type,
            'unit_type_string': unit_type_string,
            'param': param,
            'session_dir': session_dir,
            'bc_output_dir': str(bc_output_dir),
            'nan_files_count': len(nan_files) if 'nan_files' in locals() else 0
        }
        
        print(f"BombCell processing complete for session {i+1}")
        print(f"  - Total units: {len(quality_metrics['phy_clusterID'])}")
        print(f"  - Good units: {sum(np.array(unit_type_string) == 'GOOD')}")
        print(f"  - Results saved to: {bc_output_dir}")
        
    except Exception as e:
        print(f"Error processing session {i+1}: {e}")
        import traceback
        traceback.print_exc()
        continue

print("BombCell processing completed for all sessions.")

# Summary of NaN detection
print("\n=== NaN Detection Summary ===")
for session, result in bombcell_results.items():
    nan_count = result.get('nan_files_count', 0)
    print(f"{session}: {nan_count} files with NaNs")

Starting BombCell processing...
Processing session 1: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/kilosort4
  Raw file: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/cz_npxl_g0_tcat.imec0.ap.bin
  Meta file: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/cz_npxl_g0_tcat.imec0.ap.meta
Using raw data cz_npxl_g0_t0.imec0.ap.bin.
  Raw data file in param: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/cz_npxl_g0_t0.imec0.ap.bin
  extractRaw: True
  saveMultipleRaw: True
  nRawSpikesToExtract: 100
🚀 Starting BombCell quality metrics pipeline...
📁 Processing data from: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/kilosort4
Results will be saved to: /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf

Loading eph

Computing bombcell quality metrics:   0%|          | 0/657 units

KeyboardInterrupt: 

In [9]:
## Step 3: Prepare data for UnitMatch
print("Preparing data for UnitMatch...")

# Since we know the raw waveforms are saved, let's directly use the paths
print("Using predefined BombCell paths...")
bc_output_dirs = custom_bombcell_paths

# Check if directories exist and have RawWaveforms
for i, bc_dir in enumerate(bc_output_dirs):
    raw_waveforms_dir = Path(bc_dir) / 'RawWaveforms'
    if raw_waveforms_dir.exists():
        npy_files = list(raw_waveforms_dir.glob('*.npy'))
        print(f"  Session {i+1}: {len(npy_files)} .npy files found in {raw_waveforms_dir}")
    else:
        print(f"  Session {i+1}: RawWaveforms directory not found at {raw_waveforms_dir}")

# Get default UnitMatch parameters
param = default_params.get_default_param()

# Set up paths for UnitMatch - using the directories defined above
param['KS_dirs'] = KS_dirs

print(f"BombCell output directories: {bc_output_dirs}")

# Manually construct the correct paths since util.paths_from_KS isn't working properly
wave_paths = []
unit_label_paths = []
channel_pos = []

for i, (ks_dir, bc_dir) in enumerate(zip(KS_dirs, bc_output_dirs)):
    # Raw waveforms path
    wave_path = Path(bc_dir) / 'RawWaveforms'
    wave_paths.append(str(wave_path))
    
    # Unit label path - point to the specific TSV file
    unit_label_path = Path(bc_dir) / 'cluster_bc_unitType.tsv'
    unit_label_paths.append(str(unit_label_path))
    
    # Channel positions from KiloSort
    channel_pos_path = Path(ks_dir) / 'channel_positions.npy'
    if channel_pos_path.exists():
        channel_pos.append(np.load(channel_pos_path))
    else:
        print(f"WARNING: channel_positions.npy not found in {ks_dir}")

print(f"Raw waveform paths: {wave_paths}")
print(f"Unit label paths: {unit_label_paths}")
print(f"Channel positions loaded: {len(channel_pos)} sessions")

# Verify the unit label files exist
for i, unit_label_path in enumerate(unit_label_paths):
    if Path(unit_label_path).exists():
        print(f"  Session {i+1}: Unit label file found: {unit_label_path}")
    else:
        print(f"  Session {i+1}: Unit label file NOT found: {unit_label_path}")

# Get probe geometry
if len(channel_pos) > 0:
    param = util.get_probe_geometry(channel_pos[0], param)
    print("Data preparation for UnitMatch complete.")
else:
    print("ERROR: No channel positions loaded")

Preparing data for UnitMatch...
Using predefined BombCell paths...
  Session 1: 657 .npy files found in /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf/RawWaveforms
  Session 2: 657 .npy files found in /home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf/RawWaveforms
BombCell output directories: ['/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf', '/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf']
Raw waveform paths: ['/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf/RawWaveforms', '/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-04-19/cz_npxl_g0/cz_npxl_g0_imec0/bombcell_testing_jf/RawWaveforms']
Unit label paths: ['/home/jf5479/cup/Chris/data/cta_backwards/calca_302/2023-0

In [10]:
## Step 4: Run UnitMatch - Data Loading and Parameter Extraction

# Check if previous step completed successfully
if 'wave_paths' not in locals():
    print("ERROR: wave_paths not defined - UnitMatch preparation failed")
    print("Please run the previous cell (Step 3) successfully first")
    print("Current local variables:", [var for var in locals().keys() if not var.startswith('_')])
else:
    print("Starting UnitMatch processing...")

    # STEP 0 -- Data preparation
    print("Loading good waveforms...")
    waveform, session_id, session_switch, within_session, good_units, param = util.load_good_waveforms(
        wave_paths, unit_label_paths, param, good_units_only=True
    ) 

    # Fix channel positions - convert 2D to 3D if needed
    for i, ch_pos in enumerate(channel_pos):
        if ch_pos.shape[1] == 2:
            # Add z=0 dimension for 2D Neuropixels probes
            print(f"Converting session {i+1} channel positions from 2D to 3D")
            channel_pos[i] = np.column_stack([ch_pos, np.zeros(ch_pos.shape[0])])
        print(f"Session {i+1} channel positions shape: {channel_pos[i].shape}")

    # You may need to set peak location if it's not automatically detected correctly
    # param['peak_loc'] = # set as a value if the peak location is NOT ~ half the spike width

    # Create clus_info containing all unit id/session related info
    clus_info = {
        'good_units': good_units, 
        'session_switch': session_switch, 
        'session_id': session_id, 
        'original_ids': np.concatenate(good_units)
    }

    print(f"Total number of good units: {param['n_units']}")
    print(f"Number of sessions: {len(KS_dirs)}")
    
    # Check for NaNs in waveform data before processing
    total_nans = np.sum(np.isnan(waveform))
    total_elements = waveform.size
    print(f"Waveform data quality check: {total_nans}/{total_elements} NaN values ({100*total_nans/total_elements:.2f}%)")
    
    if total_nans > 0:
        print("WARNING: Found NaN values in waveform data - this may cause metric calculation warnings")
        # Find which units have NaNs
        units_with_nans = []
        for i in range(waveform.shape[0]):
            if np.any(np.isnan(waveform[i])):
                units_with_nans.append(i)
        print(f"Units with NaN waveforms: {len(units_with_nans)} out of {waveform.shape[0]}")

    # STEP 1 -- Extract parameters from waveforms
    print("Extracting waveform parameters...")
    print("(RuntimeWarnings about NaN slices are expected and handled by UnitMatch)")
    extracted_wave_properties = ov.extract_parameters(waveform, channel_pos, clus_info, param)
    print("Parameter extraction complete.")

Starting UnitMatch processing...
Loading good waveforms...
Converting session 1 channel positions from 2D to 3D
Session 1 channel positions shape: (384, 3)
Converting session 2 channel positions from 2D to 3D
Session 2 channel positions shape: (384, 3)
Total number of good units: 292
Number of sessions: 2
Waveform data quality check: 0/13679616 NaN values (0.00%)
Extracting waveform parameters...
Parameter extraction complete.


In [11]:
## Step 5: UnitMatch - Metric Calculation and Drift Correction
print("Calculating similarity metrics and applying drift correction...")

# Temporarily patch the UnitMatch overlord function to debug the quantile issue
import UnitMatchPy.overlord as ov_original
import UnitMatchPy.metric_functions as mf

def patched_extract_metric_scores(extracted_wave_properties, session_switch, within_session, param, niter=2):
    """Patched version with debugging for quantile issue"""
    import numpy as np
    
    # Unpack needed arrays from the ExtractedWaveProperties dictionary
    amplitude = extracted_wave_properties['amplitude']
    spatial_decay = extracted_wave_properties['spatial_decay']
    spatial_decay_fit = extracted_wave_properties['spatial_decay_fit']
    avg_waveform = extracted_wave_properties['avg_waveform']
    avg_waveform_per_tp = extracted_wave_properties['avg_waveform_per_tp']
    avg_centroid = extracted_wave_properties['avg_centroid']

    # These scores are NOT affected by the drift correction
    amp_score = mf.get_simple_metric(amplitude)
    spatial_decay_score = mf.get_simple_metric(spatial_decay)
    spatial_decay_fit_score = mf.get_simple_metric(spatial_decay_fit, outlier=True)
    wave_corr_score = mf.get_wave_corr(avg_waveform, param)
    wave_mse_score = mf.get_waveforms_mse(avg_waveform, param)

    # Affected by drift
    for i in range(niter):
        avg_waveform_per_tp_flip = mf.flip_dim(avg_waveform_per_tp, param)
        euclid_dist = mf.get_Euclidean_dist(avg_waveform_per_tp_flip, param)

        centroid_dist, centroid_var = mf.centroid_metrics(euclid_dist, param)

        euclid_dist_rc = mf.get_recentered_euclidean_dist(avg_waveform_per_tp_flip, avg_centroid, param)

        centroid_dist_recentered = mf.recentered_metrics(euclid_dist_rc)
        traj_angle_score, traj_dist_score = mf.dist_angle(avg_waveform_per_tp_flip, param)

        # Average Euc Dist
        euclid_dist = np.nanmin(euclid_dist[:,param['peak_loc'] - param['waveidx'] == 0, :,:].squeeze(), axis=1)

        # TotalScore
        include_these_pairs = np.argwhere(euclid_dist < param['max_dist'])  # array indices of pairs to include
        include_these_pairs_idx = np.zeros_like(euclid_dist)
        include_these_pairs_idx[euclid_dist < param['max_dist']] = 1

        # Make a dictionary of score to include
        centroid_overlord_score = (centroid_dist_recentered + centroid_var) / 2
        waveform_score = (wave_corr_score + wave_mse_score) / 2
        trajectory_score = (traj_angle_score + traj_dist_score) / 2

        scores_to_include = {'amp_score': amp_score, 'spatial_decay_score': spatial_decay_score, 'centroid_overlord_score': centroid_overlord_score,
                            'centroid_dist': centroid_dist, 'waveform_score': waveform_score, 'trajectory_score': trajectory_score}

        total_score, predictors = mf.get_total_score(scores_to_include, param)

        # Initial thresholding
        if (i < niter - 1):
            # Get the threshold for a match
            thrs_opt = mf.get_threshold(total_score, within_session, euclid_dist, param, is_first_pass=True)

            param['n_expected_matches'] = np.sum((total_score > thrs_opt).astype(int))
            prior_match = 1 - (param['n_expected_matches'] / len(include_these_pairs))
            candidate_pairs = total_score > thrs_opt

            drifts, avg_centroid, avg_waveform_per_tp = mf.drift_n_sessions(candidate_pairs, session_switch, avg_centroid, avg_waveform_per_tp, total_score, param)

    # Final threshold calculation with debugging
    thrs_opt = mf.get_threshold(total_score, within_session, euclid_dist, param, is_first_pass=False)
    param['n_expected_matches'] = np.sum((total_score > thrs_opt).astype(int))
    prior_match = 1 - (param['n_expected_matches'] / len(include_these_pairs))
    
    # Debug: Let's examine the total_score and prior_match values
    print(f"Debug patched function:")
    print(f"include_these_pairs length: {len(include_these_pairs)}")
    print(f"n_expected_matches: {param['n_expected_matches']}")
    print(f"prior_match calculation: 1 - ({param['n_expected_matches']} / {len(include_these_pairs)}) = {prior_match}")
    print(f"prior_match value: {prior_match}")
    print(f"prior_match type: {type(prior_match)}")
    print(f"Is prior_match finite: {np.isfinite(prior_match)}")
    print(f"total_score shape: {total_score.shape}")
    print(f"include_these_pairs_idx sum: {np.sum(include_these_pairs_idx.astype(bool))}")
    print(f"total_score[include_these_pairs_idx.astype(bool)] shape: {total_score[include_these_pairs_idx.astype(bool)].shape}")
    
    if len(include_these_pairs) > 0:
        print(f"total_score min: {np.min(total_score[include_these_pairs_idx.astype(bool)])}")
        print(f"total_score max: {np.max(total_score[include_these_pairs_idx.astype(bool)])}")
    
    # Original problematic line that we're trying to debug
    try:
        thrs_opt = np.quantile(total_score[include_these_pairs_idx.astype(bool)], prior_match)
        print(f"Quantile calculation succeeded: {thrs_opt}")
    except Exception as e:
        print(f"Quantile calculation failed: {e}")
        # Let's try to understand what's wrong
        if not (0 <= prior_match <= 1):
            print("prior_match is outside [0,1] range!")
            if prior_match < 0:
                print("prior_match is negative - using 0.0 instead")
                prior_match = 0.0
            elif prior_match > 1:
                print("prior_match is greater than 1 - using 1.0 instead")
                prior_match = 1.0
            
            try:
                thrs_opt = np.quantile(total_score[include_these_pairs_idx.astype(bool)], prior_match)
                print(f"Quantile calculation with corrected prior_match succeeded: {thrs_opt}")
            except Exception as e2:
                print(f"Quantile calculation still failed: {e2}")
                thrs_opt = np.median(total_score[include_these_pairs_idx.astype(bool)])
                print(f"Using median as fallback: {thrs_opt}")
    
    candidate_pairs = total_score > thrs_opt
    return total_score, candidate_pairs, scores_to_include, predictors

# Use the patched function
try:
    total_score, candidate_pairs, scores_to_include, predictors = patched_extract_metric_scores(
        extracted_wave_properties, session_switch, within_session, param, niter=2
    )

    print(f"SUCCESS: Number of candidate pairs: {np.sum(candidate_pairs)}")
    print(f"Scores included: {list(scores_to_include.keys())}")
    print("Metric calculation and drift correction complete.")
    
except Exception as e:
    print(f"ERROR in patched function: {e}")
    import traceback
    traceback.print_exc()

Calculating similarity metrics and applying drift correction...


  ang = np.abs( x1[dim_id1,:,:,:,:] - x2[dim_id1,:,:,:,:]) / np.abs(x1[dim_id2,:,:,:,:] - x2[dim_id2,:,:,:,:])
  return np.nanmean(a, axis, out=out, keepdims=keepdims)
  new_vals = np.nanmin(tmpdat, axis =1, keepdims=True) + np.nanmax(tmpdat, axis = 1, keepdims=True) - tmpdat


Done drift correction per shank for session pair 1 and 2


  centroid_dist = np.nanmin( euclid_dist[:,new_peak_loc - waveidx ==0,:,:].squeeze(), axis =1 ).squeeze()
  centroid_var = np.nanmin( np.nanvar(euclid_dist, axis = 1, ddof = 1 ).squeeze(), axis =1 ).squeeze()
  centroid_var = np.nanmin( np.nanvar(euclid_dist, axis = 1, ddof = 1 ).squeeze(), axis =1 ).squeeze()


Debug patched function:
include_these_pairs length: 2968
n_expected_matches: 3595
prior_match calculation: 1 - (3595 / 2968) = -0.21125336927223715
prior_match value: -0.21125336927223715
prior_match type: <class 'numpy.float64'>
Is prior_match finite: True
total_score shape: (292, 292)
include_these_pairs_idx sum: 2968
total_score[include_these_pairs_idx.astype(bool)] shape: (2968,)
total_score min: 0.1306966085751271
total_score max: 1.0
Quantile calculation failed: Quantiles must be in the range [0, 1]
prior_match is outside [0,1] range!
prior_match is negative - using 0.0 instead
Quantile calculation with corrected prior_match succeeded: 0.1306966085751271
SUCCESS: Number of candidate pairs: 84688
Scores included: ['amp_score', 'spatial_decay_score', 'centroid_overlord_score', 'centroid_dist', 'waveform_score', 'trajectory_score']
Metric calculation and drift correction complete.


  centroid_dist_recentered = np.nanmin( np.nanmean(euclid_dist_2, axis =1), axis =1)
  centroid_dist_recentered = np.nanmin( np.nanmean(euclid_dist_2, axis =1), axis =1)
  euclid_dist = np.nanmin(euclid_dist[:,param['peak_loc'] - param['waveidx'] == 0, :,:].squeeze(), axis=1)


In [None]:
## Step 6: UnitMatch - Naive Bayes Classification
print("Running Naive Bayes classification...")

# STEP 5 -- Probability analysis
# Get prior probability of being a match
prior_match = 1 - (param['n_expected_matches'] / param['n_units']**2)
priors = np.array((prior_match, 1-prior_match))

print(f"Prior probability of match: {prior_match:.4f}")

# Construct distributions (kernels) for Naive Bayes Classifier
labels = candidate_pairs.astype(int)
cond = np.unique(labels)
score_vector = param['score_vector']
parameter_kernels = np.full((len(score_vector), len(scores_to_include), len(cond)), np.nan)

parameter_kernels = bf.get_parameter_kernels(scores_to_include, labels, cond, param, add_one=1)

# Get probability of each pair being a match
probability = bf.apply_naive_bayes(parameter_kernels, priors, predictors, param, cond)

# Reshape probability matrix
output_prob_matrix = probability[:,1].reshape(param['n_units'], param['n_units'])

print("Naive Bayes classification complete.")
print(f"Probability matrix shape: {output_prob_matrix.shape}")

In [None]:
## Step 7: Evaluate Results and Apply Threshold
print("Evaluating UnitMatch results...")

# Evaluate output with different thresholds
util.evaluate_output(output_prob_matrix, param, within_session, session_switch, match_threshold=0.75)

# Set match threshold (you can experiment with different values)
match_threshold = param['match_threshold']  # or set your own value, e.g., 0.75

# Apply threshold to create binary match matrix
output_threshold = np.zeros_like(output_prob_matrix)
output_threshold[output_prob_matrix > match_threshold] = 1

# Visualize the thresholded matches
plt.figure(figsize=(10, 8))
plt.imshow(output_threshold, cmap='Greys')
plt.title(f'Unit Matches (threshold = {match_threshold})')
plt.xlabel('Unit Index')
plt.ylabel('Unit Index')
plt.colorbar()
plt.show()

# Count matches
n_matches = np.sum(output_threshold) // 2  # Divide by 2 because matrix is symmetric
print(f"Number of putative matches found: {n_matches}")
print(f"Match threshold used: {match_threshold}")

In [None]:
## Step 8: Prepare and Launch GUI (Optional)
print("Preparing data for GUI...")

# Format data for GUI
amplitude = extracted_wave_properties['amplitude']
spatial_decay = extracted_wave_properties['spatial_decay']
avg_centroid = extracted_wave_properties['avg_centroid']
avg_waveform = extracted_wave_properties['avg_waveform']
avg_waveform_per_tp = extracted_wave_properties['avg_waveform_per_tp']
wave_idx = extracted_wave_properties['good_wave_idxs']
max_site = extracted_wave_properties['max_site']
max_site_mean = extracted_wave_properties['max_site_mean']

# Process info for GUI
gui.process_info_for_GUI(
    output_prob_matrix, match_threshold, scores_to_include, total_score, amplitude, spatial_decay,
    avg_centroid, avg_waveform, avg_waveform_per_tp, wave_idx, max_site, max_site_mean, 
    waveform, within_session, channel_pos, clus_info, param
)

print("GUI data preparation complete.")
print("To launch the GUI, run the next cell.")

In [None]:
## Step 9: Launch GUI for Manual Curation (Optional)
# Uncomment the lines below to run the GUI for manual curation

# print("Launching UnitMatch GUI...")
# is_match, not_match, matches_GUI = gui.run_GUI()

# # If you ran the GUI, curate the matches
# matches_curated = util.curate_matches(matches_GUI, is_match, not_match, mode='And')
# print(f"Manual curation complete. Curated matches: {len(matches_curated)}")

print("GUI section ready. Uncomment the lines above to run manual curation.")

In [None]:
## Step 10: Save Results
print("Saving UnitMatch results...")

# Get matches from thresholded matrix
matches = np.argwhere(output_threshold == 1)

# Assign unique IDs to matched units
UIDs = aid.assign_unique_id(output_prob_matrix, param, clus_info)

# Create output directory with testing suffix
unitmatch_output_dir = os.path.join(save_dir, 'unitmatch_results_testing_jf')
os.makedirs(unitmatch_output_dir, exist_ok=True)

# Save results
# NOTE: Change 'matches' to 'matches_curated' if you performed manual curation with the GUI
su.save_to_output(
    unitmatch_output_dir, 
    scores_to_include, 
    matches,  # Use matches_curated if you did manual curation
    output_prob_matrix, 
    avg_centroid, 
    avg_waveform, 
    avg_waveform_per_tp, 
    max_site,
    total_score, 
    output_threshold, 
    clus_info, 
    param, 
    UIDs=UIDs, 
    matches_curated=None,  # Set to matches_curated if you did manual curation
    save_match_table=True
)

print(f"Results saved to: {unitmatch_output_dir}")
print(f"Number of matches saved: {len(matches)}")
print(f"Unique IDs assigned to {len(UIDs)} units")

# Print summary
print("\n=== PROCESSING SUMMARY ===")
print(f"BombCell processed {len(KS_dirs)} sessions")
print(f"UnitMatch analyzed {param['n_units']} good units")
print(f"Found {n_matches} putative matches")
print(f"Results saved to: {unitmatch_output_dir}")
print("Processing complete!")