## Processing playground
### Tracking neurons across days (BombCell + UnitMatch)

In [ ]:
# Import necessary libraries
# Standard library + scientific stack used by every later cell
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# Enable autoreload like in the demo
# (IPython magics: these two lines only work inside Jupyter/IPython, not plain Python)
%load_ext autoreload
%autoreload 2

# Import bombcell like in the demo
import bombcell as bc

# List bombcell's public attributes, to check which API the installed version exposes
print("Available bombcell functions:")
print([attr for attr in dir(bc) if not attr.startswith('_')])

# UnitMatch imports. Aliases as used by the later steps:
#   bf   -> Naive Bayes kernels / classification (Step 6)
#   util -> path discovery, waveform loading, evaluation (Steps 3, 4, 7)
#   ov   -> waveform parameter + similarity metric extraction (Steps 4, 5)
#   su   -> saving results (Step 10)
#   gui  -> manual curation GUI (Steps 8, 9)
#   aid  -> unique ID assignment across sessions (Step 10)
import UnitMatchPy.bayes_functions as bf
import UnitMatchPy.utils as util
import UnitMatchPy.overlord as ov
import UnitMatchPy.save_utils as su
import UnitMatchPy.GUI as gui
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.default_params as default_params

In [ ]:
## Step 1: Set up file paths
# Every path is derived from one animal folder plus the recording dates, so the
# KiloSort and BombCell directory lists cannot drift out of sync when edited.
data_root = Path(r'/home/jf5479/cup/Chris/data/cta_backwards/calca_302')
recording_dates = ['2023-04-19', '2023-04-20']
probe_subdir = Path('cz_npxl_g0') / 'cz_npxl_g0_imec0'

# KiloSort directories (one per recording day)
KS_dirs = [str(data_root / date / probe_subdir / 'kilosort4') for date in recording_dates]

# BombCell output directories, next to each kilosort4 folder - add testing suffix
custom_bombcell_paths = [str(Path(ks_dir).parent / 'bombcell_testing_jf') for ks_dir in KS_dirs]

# Output directory for saving results - add testing suffix
save_dir = str(data_root / 'unitmatch_output_testing_jf')
os.makedirs(save_dir, exist_ok=True)

print(f"KiloSort directories: {KS_dirs}")
print(f"BombCell directories: {custom_bombcell_paths}")
print(f"Output directory: {save_dir}")

In [ ]:
## Step 2: Run BombCell quality metrics and extract raw waveforms
# For each KiloSort session: locate the .ap.bin/.ap.meta files next to the
# kilosort4 folder, run BombCell with raw-waveform extraction enabled, and store
# its outputs in `bombcell_results` under the key 'session_<n>' (1-based).
# Sessions whose BombCell run raises are listed in `failed_sessions` instead of
# disappearing silently, so the final message (and Step 3) reflect reality.
import traceback

print("Starting BombCell processing...")

bombcell_results = {}
failed_sessions = []  # 1-based numbers of sessions whose BombCell run raised
for i, session_dir in enumerate(KS_dirs):
    session_number = i + 1
    print(f"Processing session {session_number}: {session_dir}")

    # Raw data (.bin) and meta (.meta) files sit in the parent of the KiloSort folder.
    # sorted() makes the choice deterministic if more than one file matches.
    session_path = Path(session_dir).parent
    raw_files = sorted(session_path.glob("*.ap.bin"))
    meta_files = sorted(session_path.glob("*.ap.meta"))

    raw_file = str(raw_files[0]) if raw_files else None
    meta_file = str(meta_files[0]) if meta_files else None

    print(f"  Raw file: {raw_file}")
    print(f"  Meta file: {meta_file}")
    if len(raw_files) > 1 or len(meta_files) > 1:
        print("  WARNING: several .ap.bin/.ap.meta files found; using the first (sorted) one of each")
    if raw_file is None or meta_file is None:
        # extractRaw is requested below; without these files raw waveforms for
        # UnitMatch presumably cannot be extracted -- check this session's output.
        print("  WARNING: .ap.bin or .ap.meta file not found; raw waveform extraction will likely fail")

    # Default BombCell parameters for this session. Kept under its own name so it
    # does not clobber the UnitMatch `param` dict built in Step 3.
    bc_param = bc.default_parameters.get_default_parameters(
        session_dir,
        raw_file=raw_file,  # Provide raw file path
        meta_file=meta_file,  # Provide meta file path
        kilosort_version=4,  # Adjust based on your KS version
    )

    # Customize parameters for UnitMatch integration
    bc_param['extractRaw'] = True  # Ensure raw waveforms are extracted for UnitMatch
    bc_param['computeDistanceMetrics'] = False  # Disable expensive metrics
    bc_param['computeDrift'] = False
    bc_param['saveAsTSV'] = True  # Save results in phy-compatible format

    # Use the BombCell directory configured in Step 1 so all steps agree on it.
    bc_output_dir = Path(custom_bombcell_paths[i])

    try:
        quality_metrics, bc_param, unit_type, unit_type_string = bc.run_bombcell(
            session_dir, bc_output_dir, bc_param
        )
    except Exception as e:
        print(f"Error processing session {session_number}: {e}")
        traceback.print_exc()
        failed_sessions.append(session_number)
        continue

    bombcell_results[f'session_{session_number}'] = {
        'quality_metrics': quality_metrics,
        'unit_type': unit_type,
        'unit_type_string': unit_type_string,
        'param': bc_param,
        'session_dir': session_dir,
        'bc_output_dir': str(bc_output_dir),
    }

    n_good = int(np.sum(np.asarray(unit_type_string) == 'GOOD'))
    print(f"BombCell processing complete for session {session_number}")
    print(f"  - Total units: {len(quality_metrics['phy_clusterID'])}")
    print(f"  - Good units: {n_good}")
    print(f"  - Results saved to: {bc_output_dir}")

if failed_sessions:
    print(f"BombCell processing FAILED for session(s): {failed_sessions}")
else:
    print("BombCell processing completed for all sessions.")

In [ ]:
## Step 3: Prepare data for UnitMatch
print("Preparing data for UnitMatch...")

# Get default UnitMatch parameters (replaces any `param` left over from earlier cells)
param = default_params.get_default_param()

# Set up paths for UnitMatch - using the directories defined above
param['KS_dirs'] = KS_dirs

# Use the BombCell output directories recorded in Step 2. Every session must be
# present, otherwise the BombCell list would not line up with KS_dirs.
session_keys = [f'session_{n}' for n in range(1, len(KS_dirs) + 1)]
missing_sessions = [key for key in session_keys if key not in bombcell_results]
if missing_sessions:
    raise RuntimeError(
        f"No BombCell results for {missing_sessions}; fix and re-run Step 2 before preparing UnitMatch."
    )
bc_output_dirs = [bombcell_results[key]['bc_output_dir'] for key in session_keys]

# Get paths from BombCell outputs
wave_paths, unit_label_paths, channel_pos = util.paths_from_KS(
    KS_dirs,
    custom_bombcell_paths=bc_output_dirs,
)

# Get probe geometry (taken from the first session's channel positions)
param = util.get_probe_geometry(channel_pos[0], param)

print(f"Raw waveform paths: {wave_paths}")
print(f"Unit label paths: {unit_label_paths}")
print(f"Channel positions loaded: {len(channel_pos)} sessions")
print("Data preparation for UnitMatch complete.")

In [None]:
## Step 4: Run UnitMatch - Data Loading and Parameter Extraction
# Loads the waveforms of the selected units for every session and extracts the
# per-unit waveform properties that Step 5 compares across units.
print("Starting UnitMatch processing...")

# STEP 0 -- Data preparation
print("Loading good waveforms...")
# good_units_only=True presumably restricts loading to units labelled good in
# unit_label_paths. `param` is returned updated; param['n_units'] (read below)
# is presumably filled in here -- TODO confirm against UnitMatchPy.
waveform, session_id, session_switch, within_session, good_units, param = util.load_good_waveforms(
    wave_paths, unit_label_paths, param, good_units_only=True
) 

# You may need to set peak location if it's not automatically detected correctly
# param['peak_loc'] = # set as a value if the peak location is NOT ~ half the spike width

# Create clus_info containing all unit id/session related info
# original_ids concatenates good_units (presumably one id array per session) into
# one flat array, in the same unit order used by the later matrices.
clus_info = {
    'good_units': good_units, 
    'session_switch': session_switch, 
    'session_id': session_id, 
    'original_ids': np.concatenate(good_units)
}

print(f"Total number of good units: {param['n_units']}")
print(f"Number of sessions: {len(KS_dirs)}")

# STEP 1 -- Extract parameters from waveforms
# The returned dict is read in Step 5 and (keys such as 'amplitude',
# 'avg_centroid', 'max_site') in Step 8.
print("Extracting waveform parameters...")
extracted_wave_properties = ov.extract_parameters(waveform, channel_pos, clus_info, param)
print("Parameter extraction complete.")

In [None]:
## Step 5: UnitMatch - Metric Calculation and Drift Correction
print("Calculating similarity metrics and applying drift correction...")

# STEPS 2, 3, 4 -- Extract metric scores with drift correction
total_score, candidate_pairs, scores_to_include, predictors = ov.extract_metric_scores(
    extracted_wave_properties, session_switch, within_session, param, niter=2
)

# candidate_pairs is an array of 0/1 flags (Step 6 turns it into integer labels),
# so len() would only give its first dimension; count the flagged pairs instead.
n_candidate_pairs = int(np.count_nonzero(candidate_pairs))

print(f"Number of candidate pairs: {n_candidate_pairs}")
print(f"Scores included: {scores_to_include}")
print("Metric calculation and drift correction complete.")

In [None]:
## Step 6: UnitMatch - Naive Bayes Classification
print("Running Naive Bayes classification...")

# STEP 5 -- Probability analysis
# `priors` is ordered like `cond` below: index 0 = non-match (label 0),
# index 1 = match (label 1). NOTE: following the UnitMatch demo, the variable
# `prior_match` actually holds the NON-match prior; the match prior is priors[1].
prior_match = 1 - (param['n_expected_matches'] / param['n_units']**2)
priors = np.array((prior_match, 1-prior_match))

print(f"Prior probability of match: {priors[1]:.4f}")

# Candidate pairs (0/1) act as training labels for the classifier's kernels
labels = candidate_pairs.astype(int)
cond = np.unique(labels)

# Construct distributions (kernels) for Naive Bayes Classifier
parameter_kernels = bf.get_parameter_kernels(scores_to_include, labels, cond, param, add_one=1)

# Get probability of each pair being a match
probability = bf.apply_naive_bayes(parameter_kernels, priors, predictors, param, cond)

# Column 1 is the match condition; reshape into an n_units x n_units matrix
output_prob_matrix = probability[:,1].reshape(param['n_units'], param['n_units'])

print("Naive Bayes classification complete.")
print(f"Probability matrix shape: {output_prob_matrix.shape}")

In [None]:
## Step 7: Evaluate Results and Apply Threshold
print("Evaluating UnitMatch results...")

# Choose the threshold once so the evaluation, the plot, the counts and the
# matches saved in Step 10 all refer to the same value.
match_threshold = param['match_threshold']  # or set your own value, e.g., 0.75

# Evaluate output at the chosen threshold
util.evaluate_output(output_prob_matrix, param, within_session, session_switch, match_threshold=match_threshold)

# Apply threshold to create binary match matrix (1 where probability > threshold)
output_threshold = np.zeros_like(output_prob_matrix)
output_threshold[output_prob_matrix > match_threshold] = 1

# Visualize the thresholded matches
plt.figure(figsize=(10, 8))
plt.imshow(output_threshold, cmap='Greys')
plt.title(f'Unit Matches (threshold = {match_threshold})')
plt.xlabel('Unit Index')
plt.ylabel('Unit Index')
plt.colorbar()
plt.show()

# Count matches as unordered pairs of DIFFERENT units (i < j). The diagonal
# compares each unit with itself and is excluded; the probability matrix is not
# guaranteed to be symmetric, so a pair counts if it passes in either direction.
is_match = output_threshold > 0
pair_matches = np.triu(is_match | is_match.T, k=1)
n_matches = int(np.count_nonzero(pair_matches))

# Matches between different recording days are the ones of interest here.
# Assumes session_id holds one session index per unit (as stored in clus_info) -- TODO confirm.
unit_sessions = np.asarray(session_id).ravel()
across_sessions = unit_sessions[:, None] != unit_sessions[None, :]
n_across_session_matches = int(np.count_nonzero(pair_matches & across_sessions))

print(f"Number of putative matches found: {n_matches}")
print(f"  of which across sessions: {n_across_session_matches}")
print(f"Match threshold used: {match_threshold}")

In [None]:
## Step 8: Prepare and Launch GUI (Optional)
# Unpacks the waveform properties from Step 4 and hands them, together with the
# match probabilities and the threshold from Steps 6-7, to the UnitMatch GUI.
print("Preparing data for GUI...")

# Format data for GUI (several of these are also passed to su.save_to_output in Step 10)
amplitude = extracted_wave_properties['amplitude']
spatial_decay = extracted_wave_properties['spatial_decay']
avg_centroid = extracted_wave_properties['avg_centroid']
avg_waveform = extracted_wave_properties['avg_waveform']
avg_waveform_per_tp = extracted_wave_properties['avg_waveform_per_tp']
wave_idx = extracted_wave_properties['good_wave_idxs']
max_site = extracted_wave_properties['max_site']
max_site_mean = extracted_wave_properties['max_site_mean']

# Process info for GUI (uses match_threshold as set in Step 7)
gui.process_info_for_GUI(
    output_prob_matrix, match_threshold, scores_to_include, total_score, amplitude, spatial_decay,
    avg_centroid, avg_waveform, avg_waveform_per_tp, wave_idx, max_site, max_site_mean, 
    waveform, within_session, channel_pos, clus_info, param
)

print("GUI data preparation complete.")
print("To launch the GUI, run the next cell.")

In [None]:
## Step 9: Launch GUI for Manual Curation (Optional)
# Uncomment the lines below to run the GUI for manual curation
# (requires Step 8 to have run in this kernel). If you curate here, pass
# `matches_curated` to su.save_to_output in Step 10 instead of `matches`.

# print("Launching UnitMatch GUI...")
# is_match, not_match, matches_GUI = gui.run_GUI()

# # If you ran the GUI, curate the matches
# matches_curated = util.curate_matches(matches_GUI, is_match, not_match, mode='And')
# print(f"Manual curation complete. Curated matches: {len(matches_curated)}")

print("GUI section ready. Uncomment the lines above to run manual curation.")

In [ ]:
## Step 10: Save Results
print("Saving UnitMatch results...")

# Every (i, j) index pair whose thresholded value is 1: ordered pairs, so both
# directions appear, plus any diagonal (self-comparison) entries. This is the
# form passed to su.save_to_output below.
matches = np.argwhere(output_threshold == 1)

# Assign unique IDs to matched units
UIDs = aid.assign_unique_id(output_prob_matrix, param, clus_info)
# NOTE(review): in UnitMatchPy assign_unique_id appears to return a list of UID
# arrays (several matching-strictness variants), so len(UIDs) would count the
# variants, not units -- TODO confirm. Report the unit count from clus_info instead.
n_units_with_uid = len(clus_info['original_ids'])

# Create output directory with testing suffix
unitmatch_output_dir = os.path.join(save_dir, 'unitmatch_results_testing_jf')
os.makedirs(unitmatch_output_dir, exist_ok=True)

# Save results
# NOTE: Change 'matches' to 'matches_curated' if you performed manual curation with the GUI
su.save_to_output(
    unitmatch_output_dir,
    scores_to_include,
    matches,  # Use matches_curated if you did manual curation
    output_prob_matrix,
    avg_centroid,
    avg_waveform,
    avg_waveform_per_tp,
    max_site,
    total_score,
    output_threshold,
    clus_info,
    param,
    UIDs=UIDs,
    matches_curated=None,  # Set to matches_curated if you did manual curation
    save_match_table=True
)

print(f"Results saved to: {unitmatch_output_dir}")
print(f"Number of match entries saved (ordered pairs, incl. self-comparisons): {len(matches)}")
print(f"Unique IDs assigned to {n_units_with_uid} units")

# Print summary
print("\n=== PROCESSING SUMMARY ===")
print(f"BombCell processed {len(KS_dirs)} sessions")
print(f"UnitMatch analyzed {param['n_units']} good units")
print(f"Found {n_matches} putative matches")
print(f"Results saved to: {unitmatch_output_dir}")
print("Processing complete!")