In [None]:
import numpy as np
import os
import h5py
import time
from scipy.spatial import KDTree
import json
from compare_eis import compare_eis


# --- Recording location and on-disk layout ---
dat_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat"
n_channels = 512
dtype = np.int16

# --- Number of complete multi-channel frames stored in the file ---
file_size_bytes = os.path.getsize(dat_path)
bytes_per_frame = np.dtype(dtype).itemsize * n_channels
total_samples = file_size_bytes // bytes_per_frame  # a trailing partial frame is ignored

# --- Read every complete frame into RAM and view it as [T, C] ---
raw_data = np.fromfile(
    dat_path,
    dtype=dtype,
    count=total_samples * n_channels,
).reshape((total_samples, n_channels))


# --- Parameters ---
# NOTE: `n_channels` and `dtype` are re-declared here; `dtype` becomes the
# string form of the same int16 type used to load the recording.
n_channels = 512
dtype = 'int16'
max_units = 1500          # the main loop stops once this many units were processed
amplitude_threshold = 15
window = (-20, 60)        # snippet extent in samples relative to each spike time
peak_window = 30
# Analyse at most 36M samples, but never claim more samples than the file
# actually holds (the previous unconditional assignment discarded the value
# computed from the file size).
total_samples = min(total_samples, 36_000_000)
fit_offsets = (-5, 10)

do_pursuit = 0            # truthy -> run ei_pursuit_ram in the main loop


h5_in_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_kilosort_data001_spike_times.h5'  # from MATLAB export, to get EI positions
h5_out_path = '/Volumes/Lab/Users/alexth/axolotl/results_pipeline_0528.h5' # where to save data

debug_folder = "/Volumes/Lab/Users/alexth/axolotl/debug"

# Read electrode geometry and the Vision ID table from the MATLAB export.
with h5py.File(h5_in_path, 'r') as h5_in:
    positions_as_stored = h5_in['/ei_positions'][:]
    ei_positions = positions_as_stored.T  # transposed; shape becomes [512 x 2]
    ks_vision_ids = h5_in['/vision_ids'][:]  # shape: (N_units,)

import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)

save_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001_baseline_and_artifacts.json"

# Baselines are cached as JSON next to the recording: compute and store them
# on the first run, reload them on every later run.
if not os.path.exists(save_path):
    print("Computing baselines")
    baselines = axolotl_utils_ram.compute_baselines_int16(raw_data, segment_len=100_000) # shape (512, 360)

    with open(save_path, 'w') as f:
        json.dump({'baselines': baselines.tolist()}, f)
else:
    print("Loading baselines")
    with open(save_path, 'r') as f:
        data = json.load(f)
    # The cached copy is always materialised as float32
    baselines = np.array(data['baselines'], dtype=np.float32)


# Load the Kilosort (KS) EI templates and their spike counts
ks_ei_path = '/Volumes/Lab/Users/alexth/axolotl/ks_eis_subset.h5'
ks_templates = {}
ks_n_spikes = {}

with h5py.File(ks_ei_path, 'r') as ks_file:
    for key in ks_file:
        dset = ks_file[key]
        # Dataset names carry a number after the first '_'; shift it down by one
        ks_id = int(key.split('_')[1]) - 1
        ks_templates[ks_id] = dset[:]
        ks_n_spikes[ks_id] = dset.attrs.get('n_spikes', -1)  # -1 when the attribute is missing

# Fix one ordering of the unit ids and stack the templates in that order
ks_unit_ids = list(ks_templates)
ks_ei_stack = np.stack([ks_templates[uid] for uid in ks_unit_ids], axis=0)  # [N x 512 x 81]

unit_id = 0

print(f"\n=== Starting unit {unit_id} ===")

# Main extraction loop. Each pass:
#   1. picks a reference channel and detects threshold crossings on it,
#   2. clusters the snippets and keeps the cluster with the largest waveform,
#   3. (optionally) runs EI pursuit, re-aligns spikes by cross-correlation,
#   4. drops spikes whose peak amplitude is too small, re-clusters the rest,
#   5. writes diagnostics and the unit to `h5_out_path`,
#   6. hands per-channel residual snippets to `apply_residuals` (is_ram=True),
#      so the next pass runs on the edited `raw_data`.
# The loop ends after `max_units` units, or on Ctrl-C during the save step.
while True:

    start_time = time.time()

    # could cache scores on channels to pre-identify next one
    ref_channel = axolotl_utils_ram.find_dominant_channel_ram(
        raw_data=raw_data,
        segment_len=100_000,
        n_segments=10,
        peak_window=30,
        top_k_neg=20,
        top_k_events=5,
        seed=42,
    )

    threshold, spike_times = axolotl_utils_ram.estimate_spike_threshold_ram(
        raw_data=raw_data,
        ref_channel=ref_channel,
        window=30,
        total_samples_to_read=total_samples,
        refractory=30,
        top_n=100,
    )

    print(f"Channel: {ref_channel}, Threshold: {-threshold:.1f}, Initial spikes: {len(spike_times)}")

    snips, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spike_times,
        window=window,
        selected_channels=np.arange(n_channels),
    )

    # Provisional EI: average over the spike axis, then remove each channel's
    # mean over its first 5 samples.
    ei = np.mean(snips, axis=2)
    ei -= ei[:, :5].mean(axis=1, keepdims=True)

    spikes_for_plot_pre = valid_spike_times

    # Step 6–7: Cluster and select dominant unit
    clusters_pre, pcs_pre, labels_pre, sim_matrix_pre, cluster_eis_pre = axolotl_utils_ram.cluster_spike_waveforms(
        snips, ei, k_start=3, return_debug=True
    )

    ei, spikes_idx, selected_channels, selected_cluster_index_pre = axolotl_utils_ram.select_cluster_with_largest_waveform(
        clusters_pre, ref_channel
    )

    # The clusters were built from `snips`, whose spike axis lines up with
    # `valid_spike_times` (the times extract_snippets_ram actually returned),
    # so index that array rather than the unfiltered `spike_times`.
    spikes_init = valid_spike_times[spikes_idx]

    if do_pursuit:
        # NOTE(review): `save_prefix` is the same for every unit even though the
        # original note says it should be unique per unit — confirm whether
        # ei_pursuit_ram writes files that would be overwritten.
        (
            spikes,
            mean_score,
            valid_score,
            mean_scores_at_spikes,
            valid_scores_at_spikes,
            mean_thresh,
            valid_thresh,
        ) = axolotl_utils_ram.ei_pursuit_ram(
            raw_data=raw_data,
            spikes=spikes_init,                     # absolute sample times
            ei_template=ei,                    # EI from selected cluster
            save_prefix='/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit0',  # set uniquely per unit
            alignment_offset=-window[0],
            fit_percentile=40,                # how many (percentile) spikes to take to fit Gaussian for threshold determination (left-hand side of already found spikes)
            sigma_thresh=5.0,                  # how many Gaussian sigmas to take for threshold
            return_debug=True,
        )
    else:
        # No pursuit: keep the clustered spikes and pass placeholders to the
        # diagnostics. NOTE(review): `mean_scores_at_spikes` receives the spike
        # times here, not scores — kept as in the original; confirm intent.
        spikes = spikes_init
        mean_score = None
        valid_score = None
        mean_scores_at_spikes = spikes
        valid_scores_at_spikes = None
        mean_thresh = None
        valid_thresh = None

    # Step 9a: Extract full snippets from final spike times

    snips_ref_channel, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.array([ref_channel]),
        window=window,
    )

    snips_ref_channel = snips_ref_channel.transpose(2, 0, 1)

    lags = axolotl_utils_ram.estimate_lags_by_xcorr_ram(
        snippets=snips_ref_channel,                # shape [N x C x T]
        peak_channel_idx=0,                 # 0 because the only channel that gets passed is the referent channel
        window=(-5, 10),                  # optional, relative to peak
        max_lag=6,                        # optional, max xcorr shift
    )

    # One lag per extracted snippet, so shift the times that were actually
    # extracted (identical to `spikes` whenever nothing was dropped).
    spikes = valid_spike_times + lags

    snips_full, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.arange(n_channels),
        window=window,
    )
    # Keep `spikes` aligned with the spike axis of `snips_full`; every
    # per-spike array below (baseline segments, masks) relies on that.
    spikes = valid_spike_times

    segment_len = 100_000
    snips_baselined = snips_full.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    segment_indices = spikes // segment_len  # shape: (n_spikes,)

    # Loop through channels and subtract baseline per spike
    for ch in range(n_channels):
        snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]

    # Extract baseline-subtracted waveforms for ref_channel
    ref_snips = snips_baselined[ref_channel, :, :]  # shape: (81, N)

    # Sample index of the spike time inside a snippet (20 for window=(-20, 60))
    peak_idx = -window[0]

    # Mean waveform over all spikes
    ref_mean = ref_snips.mean(axis=1)  # shape: (81,)
    # Magnitude of the mean waveform at the spike-time sample
    ref_peak_amp = np.abs(ref_mean[peak_idx])  # scalar

    # Threshold at 0.66× of mean waveform peak
    threshold_ampl = 0.66 * ref_peak_amp

    # Per-spike magnitude at the same sample
    spike_amplitudes = np.abs(ref_snips[peak_idx, :])  # shape: (N,)

    # Flag bad spikes (too small) and build the complementary keep mask
    is_bad = spike_amplitudes < threshold_ampl
    bad_inds = np.where(is_bad)[0]
    keep_mask = ~is_bad

    # --- Extract bad spike traces for plotting
    bad_spike_traces = snips_baselined[ref_channel, :, bad_inds]  # shape: (n_bad, T)

    # Get original traces for bad_spike_traces (read from the .dat file on disk)
    snips_bad = axolotl_utils_ram.extract_snippets_single_channel(
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat',
        spike_times=spikes[bad_inds],
        ref_channel=ref_channel,
        window=window,
        n_channels=512,
        dtype='int16'
    )

    bad_segment_indices = spikes[bad_inds] // segment_len  # shape: (n_bad,)
    snips_bad[0, :, :] -= baselines[ref_channel, bad_segment_indices][None, :]

    # Apply to real data and snips_baselined
    snips_baselined = snips_baselined[:, :, keep_mask]
    good_mean_trace = np.mean(snips_baselined[ref_channel, :, :], axis=1)
    snips_full = snips_full[:, :, keep_mask]
    valid_spike_times = valid_spike_times[keep_mask]
    spikes = spikes[keep_mask]

    spikes_for_plot_post = spikes

    # Step 9b: Recluster - choose k. snips_full is all channels, baselined - relevant channels will be subselected in the function.

    if len(spikes) < 100:
        # Too few spikes to recluster: feed placeholders to the diagnostics.
        pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
        labels_post = np.array([0])                    # just one fake cluster label
        sim_matrix_post = np.zeros((1, 1))             # fake 1×1 similarity matrix
        ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the “post” cluster
        selected_index_post = 0                        # only one cluster, so index is 0
        cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
        spikes_for_plot_post = np.array([0])           # placeholder spike time
        spike_counts_post = [len(spikes)]              # actual number of spikes (len(snips) was the channel count)
        matches = []                                # no matches

        # Only the referent channel is handled for such units; the residual
        # itself is computed after the save step below.
        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)

    else:
        clusters_post, pcs_post, labels_post, sim_matrix_post, cluster_eis_post = axolotl_utils_ram.cluster_spike_waveforms(
            snips=snips_baselined, ei=ei, k_start=2, return_debug=True
        )

        # Step 9c: choose the best cluster - choose similarity threshold. EI is all channels, baselined
        ei, final_spike_inds, selected_channels, selected_cluster_index_post = axolotl_utils_ram.select_cluster_by_ei_similarity_ram(
            clusters=clusters_post, reference_ei=ei, similarity_threshold=0.95
        )

        spikes = spikes[final_spike_inds]  # convert to absolute spike times
        snips_baselined = snips_baselined[:, :, final_spike_inds]  # cut only the ones that survived

        # Channels whose EI peak-to-peak exceeds the threshold, largest first
        p2p_threshold = 30
        ei_p2p = ei.max(axis=1) - ei.min(axis=1)
        selected_channels = np.where(ei_p2p > p2p_threshold)[0]
        selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

        # check for matching KS units
        lag = 20
        ks_sim_threshold = 0.75

        # Run comparison
        sim = compare_eis(ks_ei_stack, ei, lag).squeeze()  # shape: (num_KS_units,)
        matches = [
            {
                "unit_id": ks_unit_ids[i],
                "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                "similarity": float(sim[i]),
                "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
            }
            for i in np.where(sim > ks_sim_threshold)[0]
        ]

    # DIAGNOSTIC PLOTS

    axolotl_utils_ram.plot_unit_diagnostics(
        output_path=debug_folder,
        unit_id=unit_id,

        # --- From first call to cluster_spike_waveforms
        pcs_pre=pcs_pre,
        labels_pre=labels_pre,
        sim_matrix_pre=sim_matrix_pre,
        cluster_eis_pre=cluster_eis_pre,
        spikes_for_plot_pre=spikes_for_plot_pre,

        # --- From ei_pursuit
        mean_score=mean_score,
        valid_score=valid_score,
        mean_scores_at_spikes=mean_scores_at_spikes,
        valid_scores_at_spikes=valid_scores_at_spikes,
        mean_thresh=mean_thresh,
        valid_thresh=valid_thresh,

        # --- Lag estimation and bad spike filtering
        lags=lags,
        bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
        good_mean_trace=good_mean_trace,
        threshold_ampl=-threshold_ampl,
        ref_channel=ref_channel,
        snips_bad=snips_bad,

        # --- From second clustering
        pcs_post=pcs_post,
        labels_post=labels_post,
        sim_matrix_post=sim_matrix_post,
        cluster_eis_post=cluster_eis_post,
        spikes_for_plot_post=spikes_for_plot_post,

        # --- For axis labels etc.
        window=window,

        ei_positions=ei_positions,
        selected_channels_count=len(selected_channels),

        spikes=spikes,
        orig_threshold=threshold,
        ks_matches=matches
    )

    # Step 10: Save unit metadata
    try:
        with h5py.File(h5_out_path, 'a') as h5:
            group = h5.require_group(f'unit_{unit_id}')

            # Replace any dataset left over from an earlier run of this unit
            for name, data in [
                ('spike_times', spikes.astype(np.int32)),
                ('ei', ei.astype(np.float32)),  # EI is already baselined
                ('selected_channels', selected_channels.astype(np.int32))
            ]:
                if name in group:
                    del group[name]
                group.create_dataset(name, data=data)

            group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))

    except KeyboardInterrupt:
        print("\nKeyboard interrupt detected — exiting safely before write completes.")
        # Actually leave the loop: previously the message was printed but the
        # loop carried on editing raw_data and starting further units.
        break

    except Exception as e:
        # Best effort: report the failed save and keep processing.
        print(f"\nUnexpected error while saving unit_{unit_id}: {e}")

    # Step 11: build the residual snippets that will replace the spikes.
    # `spikes` may have shrunk during reclustering, so the count is re-checked.
    if len(spikes) >= 100:
        snips_full = snips_full[np.ix_(selected_channels, np.arange(snips_full.shape[1]), final_spike_inds)]
        snips_full = snips_full.transpose(2, 0, 1)  # [C × T × N] → [N × C × T]

        # --- Setup ---
        residuals_per_channel = {}
        cluster_ids_per_channel = {}
        scale_factors_per_channel = {}

        for ch_idx, ch in enumerate(selected_channels):
            # Slice data for this channel
            ch_snips = snips_full[:, ch_idx, :]  # shape: (n_spikes, snip_len)
            ch_baselines = baselines[ch, :]    # shape: (n_segments,)

            # Subtract PCA cluster means
            residuals, cluster_ids, scale_factors = axolotl_utils_ram.subtract_pca_cluster_means_ram(
                snippets=ch_snips,
                baselines=ch_baselines,
                spike_times=spikes,
                segment_len=segment_len,  # must match what was used to generate baselines
                n_clusters=5,
                offset_window=(-10, 40)
            )

            # Store results
            residuals_per_channel[ch] = residuals
            cluster_ids_per_channel[ch] = cluster_ids
            scale_factors_per_channel[ch] = scale_factors
    else:
        # Fallback for small units. `snips_baselined` is [C x T x N].
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Transpose to match expected shape: (n_spikes, snip_len)
        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }

    # Step 12: edit raw data
    write_locs = spikes + window[0]  # first sample of every snippet
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=raw_data.shape[0],
        dtype=np.int16,
        n_channels=n_channels,
        is_ram=True,
        is_disk=False
    )
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")

    # Step 13: Repeat until done
    unit_id += 1
    if unit_id >= max_units:
        print("Reached unit limit.")
        break



In [None]:
import matplotlib.pyplot as plt
# Quick look at the start of channel 39 in the in-RAM recording
fig, ax = plt.subplots(figsize=(25, 4))
ax.plot(raw_data[:5000, 39])
ax.set_xlabel('Sample')
ax.set_ylabel('Amplitude')
ax.set_title('Channel 39: First 5,000 samples')
ax.grid(True)
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Parameters ---
# NOTE: this cell rebinds the notebook-wide names `dat_path`, `n_channels`
# and `dtype`; `dat_path` now points at the *_sub.dat file.
dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat'
n_channels = 512
channel = 39       # channel to display
n_samples = 5000   # number of leading time samples to read
dtype = np.int16

# --- Read from disk ---
with open(dat_path, 'rb') as f:
    # Read the first n_samples timepoints (i.e., n_samples * n_channels values)
    raw = np.fromfile(f, dtype=dtype, count=n_samples * n_channels)

# Keep only complete frames so a file shorter than n_samples still reshapes
n_frames = raw.size // n_channels
raw = raw[:n_frames * n_channels].reshape((n_frames, n_channels))  # [time, channel]

# --- Plot ---
# The slice and the title follow `channel` / `n_samples` instead of repeating
# the literals 39 and 5000.
plt.figure(figsize=(25, 4))
plt.plot(raw[:n_samples, channel])
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.title(f'Channel {channel}: First {n_frames:,} samples')
plt.grid(True)
plt.show()


### Test verify_clusters

In [None]:
import numpy as np
import os
import h5py
import time
from scipy.spatial import KDTree
import json
from compare_eis import compare_eis


# --- Recording location and on-disk layout ---
dat_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat"
n_channels = 512
dtype = np.int16

# --- Number of complete multi-channel frames stored in the file ---
file_size_bytes = os.path.getsize(dat_path)
bytes_per_frame = np.dtype(dtype).itemsize * n_channels
total_samples = file_size_bytes // bytes_per_frame  # a trailing partial frame is ignored

# --- Read every complete frame into RAM and view it as [T, C] ---
raw_data = np.fromfile(
    dat_path,
    dtype=dtype,
    count=total_samples * n_channels,
).reshape((total_samples, n_channels))


# --- Parameters ---
# NOTE: `n_channels` and `dtype` are re-declared here; `dtype` becomes the
# string form of the same int16 type used to load the recording.
n_channels = 512
dtype = 'int16'
max_units = 1500
amplitude_threshold = 15
window = (-20, 60)        # snippet extent in samples relative to each spike time
peak_window = 30
# Analyse at most 36M samples, but never claim more samples than the file
# actually holds (the previous unconditional assignment discarded the value
# computed from the file size).
total_samples = min(total_samples, 36_000_000)
fit_offsets = (-5, 10)

do_pursuit = 0            # truthy -> run ei_pursuit_ram in the main loop


h5_in_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_kilosort_data001_spike_times.h5'  # from MATLAB export, to get EI positions
h5_out_path = '/Volumes/Lab/Users/alexth/axolotl/results_pipeline_0528.h5' # where to save data

debug_folder = "/Volumes/Lab/Users/alexth/axolotl/debug"

# Read electrode geometry and the Vision ID table from the MATLAB export.
with h5py.File(h5_in_path, 'r') as h5_in:
    positions_as_stored = h5_in['/ei_positions'][:]
    ei_positions = positions_as_stored.T  # transposed; shape becomes [512 x 2]
    ks_vision_ids = h5_in['/vision_ids'][:]  # shape: (N_units,)

import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)

save_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001_baseline_and_artifacts.json"

# Baselines are cached as JSON next to the recording: compute and store them
# on the first run, reload them on every later run.
if not os.path.exists(save_path):
    print("Computing baselines")
    baselines = axolotl_utils_ram.compute_baselines_int16(raw_data, segment_len=100_000) # shape (512, 360)

    with open(save_path, 'w') as f:
        json.dump({'baselines': baselines.tolist()}, f)
else:
    print("Loading baselines")
    with open(save_path, 'r') as f:
        data = json.load(f)
    # The cached copy is always materialised as float32
    baselines = np.array(data['baselines'], dtype=np.float32)


# Load the Kilosort (KS) EI templates and their spike counts
ks_ei_path = '/Volumes/Lab/Users/alexth/axolotl/ks_eis_subset.h5'
ks_templates = {}
ks_n_spikes = {}

with h5py.File(ks_ei_path, 'r') as ks_file:
    for key in ks_file:
        dset = ks_file[key]
        # Dataset names carry a number after the first '_'; shift it down by one
        ks_id = int(key.split('_')[1]) - 1
        ks_templates[ks_id] = dset[:]
        ks_n_spikes[ks_id] = dset.attrs.get('n_spikes', -1)  # -1 when the attribute is missing

# Fix one ordering of the unit ids and stack the templates in that order
ks_unit_ids = list(ks_templates)
ks_ei_stack = np.stack([ks_templates[uid] for uid in ks_unit_ids], axis=0)  # [N x 512 x 81]


In [None]:

unit_id = 0

print(f"\n=== Starting unit {unit_id} ===")

last_ref_channel = None
remaining_candidates = []


while True:

    start_time = time.time()

    if not remaining_candidates:
        remaining_candidates = axolotl_utils_ram.find_dominant_channel_ram(
            raw_data=raw_data,
            segment_len=100_000,
            n_segments=10,
            peak_window=30,
            top_k_neg=20,
            top_k_events=5,
            seed=42
        )

        # Try candidates until one is far enough from previous

    while remaining_candidates:
        candidate = remaining_candidates.pop(0)
        if last_ref_channel is None:
            ref_channel = candidate
            channel_source='fresh'
            break
        dist = np.linalg.norm(ei_positions[candidate] - ei_positions[last_ref_channel])
        if dist >= 150:
            ref_channel = candidate
            channel_source='cache'
            break
    else:
        # All candidates were too close → recompute and restart this block
        remaining_candidates = axolotl_utils_ram.find_dominant_channel_ram(
            raw_data=raw_data,
            segment_len=100_000,
            n_segments=10,
            peak_window=30,
            top_k_neg=20,
            top_k_events=5,
            seed=42
        )
        # Re-run selection from top of block
        candidate = remaining_candidates.pop(0)
        ref_channel = candidate
        channel_source='fresh'

    # Update tracker
    last_ref_channel = ref_channel

    # # could cache scores on channels to pre-identify next one
    # ref_channel = axolotl_utils_ram.find_dominant_channel_ram(
    #         raw_data = raw_data,
    #         segment_len = 100_000,
    #         n_segments = 10,
    #         peak_window = 30,
    #         top_k_neg = 20,
    #         top_k_events = 5,
    #         seed = 42
    #     )

    threshold, spike_times = axolotl_utils_ram.estimate_spike_threshold_ram(
        raw_data=raw_data,
        ref_channel=ref_channel,
        window = 30,
        total_samples_to_read = total_samples,
        refractory = 30,
        top_n = 100
    )

    print(f"Channel: {ref_channel} (from {channel_source}), Threshold: {-threshold:.1f}, Initial spikes: {len(spike_times)}")

    snips, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spike_times,
        window=window,
        selected_channels=np.arange(n_channels)
    )

    segment_len = 100_000
    snips_baselined = snips.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    segment_indices = valid_spike_times // segment_len  # shape: (n_spikes,)

    # Loop through channels and subtract baseline per spike
    for ch in range(n_channels):
        snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]


    ei = np.mean(snips_baselined, axis=2)
    #ei -= ei[:, :5].mean(axis=1, keepdims=True)

    spikes_for_plot_pre = valid_spike_times

    # Step 6–7: Cluster and select dominant unit
    clusters_pre, pcs_pre, labels_pre, sim_matrix_pre, n_bad_channels_pre, cluster_eis_pre, cluster_to_merged_group_pre  = axolotl_utils_ram.cluster_spike_waveforms(snips_baselined, ei, k_start=3,return_debug=True)

    ei, spikes_idx, selected_channels, selected_cluster_index_pre = axolotl_utils_ram.select_cluster_with_largest_waveform(clusters_pre, ref_channel)

    contributing_original_ids_pre = [
        orig_id for orig_id, merged_idx in cluster_to_merged_group_pre.items()
        if merged_idx == selected_cluster_index_pre
    ]
    contributing_original_ids_pre = np.array(contributing_original_ids_pre)

    spikes_init = spike_times[spikes_idx]

    if do_pursuit:
        (
        spikes,
        mean_score,
        valid_score,
        mean_scores_at_spikes,
        valid_scores_at_spikes,
        mean_thresh,
        valid_thresh
        ) = axolotl_utils_ram.ei_pursuit_ram(
            raw_data=raw_data,
            spikes=spikes_init,                     # absolute sample times
            ei_template=ei,                    # EI from selected cluster
            save_prefix='/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit0',  # set uniquely per unit
            alignment_offset = -window[0],
            fit_percentile = 40,                # how many (percentile) spikes to take to fit Gaussian for threshold determination (left-hand side of already found spikes)
            sigma_thresh = 5.0,                  # how many Gaussian sigmas to take for threshold
            return_debug=True, 

        )
    else:
        spikes = spikes_init
        mean_score=None
        valid_score=None
        mean_scores_at_spikes=spikes
        valid_scores_at_spikes=None
        mean_thresh=None
        valid_thresh=None

    # Step 9a: Extract full snippets from final spike times

    snips_ref_channel, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.array([ref_channel]),
        window=window,
    )

    snips_ref_channel = snips_ref_channel.transpose(2, 0, 1)


    lags = axolotl_utils_ram.estimate_lags_by_xcorr_ram(
        snippets=snips_ref_channel,                # shape [N x C x T]
        peak_channel_idx=0,                 # 0 because the only channel that gets passed is the referent channel
        window=(-5, 10),                  # optional, relative to peak
        max_lag=6,                        # optional, max xcorr shift
    )

    spikes = spikes+lags

    snips_full, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.arange(n_channels),
        window=window,
    )


    segment_len = 100_000
    snips_baselined = snips_full.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    segment_indices = spikes // segment_len  # shape: (n_spikes,)

    # Loop through channels and subtract baseline per spike
    for ch in range(n_channels):
        snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]


    # Extract baseline-subtracted waveforms for ref_channel
    ref_snips = snips_baselined[ref_channel, :, :]  # shape: (81, N)

    # Mean waveform over all spikes
    ref_mean = ref_snips.mean(axis=1)  # shape: (81,)
    # Negative peak (should be near index 20)
    ref_peak_amp = np.abs(ref_mean[-window[0]])  # scalar

    # Threshold at 0.75× of mean waveform peak
    threshold_ampl = 0.75 * ref_peak_amp

    # Get all actual spike values at sample 20
    spike_amplitudes = np.abs(ref_snips[20, :])  # shape: (N,)

    # Flag bad spikes: too small
    bad_inds = np.where(spike_amplitudes < threshold_ampl)[0]

    # Create mask to keep only good spikes
    keep_mask = np.ones(spike_amplitudes.shape[0], dtype=bool)
    keep_mask[bad_inds] = False

    # --- Extract bad spike traces for plotting
    bad_spike_traces = snips_baselined[ref_channel, :, bad_inds]  # shape: (n_bad, T)

    # Get original traces for bad_spike_traces
    snips_bad = axolotl_utils_ram.extract_snippets_single_channel(
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat',
        spike_times=spikes[bad_inds],
        ref_channel=ref_channel,
        window=window,
        n_channels=512,
        dtype='int16'
    )

    segment_indices = spikes[bad_inds] // segment_len  # shape: (n_spikes,)
    snips_bad[0, :, :] -= baselines[ref_channel, segment_indices][None, :]


    # Apply to real data and snips_baselined
    snips_baselined = snips_baselined[:, :, keep_mask]
    good_mean_trace = np.mean(snips_baselined[ref_channel, :, :], axis=1)
    snips_full = snips_full[:, :, keep_mask]
    valid_spike_times = valid_spike_times[keep_mask]
    spikes = spikes[keep_mask]

    spikes_for_plot_post = spikes

    final_spike_inds = np.where(keep_mask)[0]



    # Step 9b: Recluster - choose k. snips_full is all channels, baselined - relevant cahnnels will be subselected in the function.


    if len(spikes)<100:
        pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
        labels_post = np.array([0])                    # just one fake cluster label
        sim_matrix_post = np.zeros((1, 1))             # fake 1×1 similarity matrix
        ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the “post” cluster
        selected_index_post = 0                        # only one cluster, so index is 0
        cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
        spikes_for_plot_post = np.array([0])           # placeholder spike time
        spike_counts_post = [len(snips)]               # use actual number of spikes
        matches = []                                # no matches
        # `snips_baselined` is [C x T x N]
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Assume residuals_fallback is (T, N) from previous step (template-subtracted waveforms)
        # Transpose to match expected shape: (n_spikes, snip_len)
        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }

    else:
        clusters_post, pcs_post, labels_post, sim_matrix_post, n_bad_channels_post, cluster_eis_post, cluster_to_merged_group_post  = axolotl_utils_ram.cluster_spike_waveforms(snips=snips_baselined, ei=ei, k_start=2,return_debug=True)

        # Step 9c: choose the best cluster - choose similarity threshold. EI is all channels, baselined
        ei, final_spike_inds, selected_channels, selected_cluster_index_post = axolotl_utils_ram.select_cluster_by_ei_similarity_ram(clusters=clusters_post,reference_ei=ei,similarity_threshold=0.95)


        spikes = spikes[final_spike_inds]  # convert to absolute spike times
        snips_baselined = snips_baselined[:,:,final_spike_inds] # cut only the ones that survived

        p2p_threshold = 30
        ei_p2p = ei.max(axis=1) - ei.min(axis=1)
        selected_channels = np.where(ei_p2p > p2p_threshold)[0]
        selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

        #print("reclustered pursuit\n")

        # check for matching KS units
        results = []
        lag = 20
        ks_sim_threshold = 0.75

        # Run comparison
        sim = compare_eis(ks_ei_stack, ei, lag).squeeze() # shape: (num_KS_units,)
        matches = [
            {
                "unit_id": ks_unit_ids[i],
                "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                "similarity": float(sim[i]),
                "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
            }
            for i in np.where(sim > ks_sim_threshold)[0]
        ]




    # DIAGNOSTIC PLOTS

    axolotl_utils_ram.plot_unit_diagnostics(
        output_path=debug_folder,
        unit_id=unit_id,

        # --- From first call to cluster_spike_waveforms
        pcs_pre=pcs_pre,
        labels_pre=labels_pre,
        sim_matrix_pre=sim_matrix_pre,
        cluster_eis_pre = cluster_eis_pre,
        spikes_for_plot_pre = spikes_for_plot_pre,
        n_bad_channels_pre = n_bad_channels_pre,
        contributing_original_ids_pre = contributing_original_ids_pre,

        # --- From ei_pursuit
        mean_score=mean_score,
        valid_score=valid_score,
        mean_scores_at_spikes=mean_scores_at_spikes,
        valid_scores_at_spikes=valid_scores_at_spikes,
        mean_thresh=mean_thresh,
        valid_thresh=valid_thresh,

        # --- Lag estimation and bad spike filtering
        lags=lags,
        bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
        good_mean_trace=good_mean_trace,
        threshold_ampl=-threshold_ampl,
        ref_channel=ref_channel,
        snips_bad=snips_bad,

        # --- From second clustering
        pcs_post=pcs_post,
        labels_post=labels_post,
        sim_matrix_post=sim_matrix_post,
        cluster_eis_post = cluster_eis_post,
        spikes_for_plot_post = spikes_for_plot_post,

        # --- For axis labels etc.
        window=(-20, 60),

        ei_positions=ei_positions,
        selected_channels_count=len(selected_channels),

        spikes = spikes, 
        orig_threshold = threshold,
        ks_matches = matches
    )


    # Step 10: Save unit metadata.
    # The results file is opened in append mode; any dataset already stored
    # under this unit's group is deleted and rewritten.
    try:
        with h5py.File(h5_out_path, 'a') as h5:
            group = h5.require_group(f'unit_{unit_id}')

            for name, data in [
                ('spike_times', spikes.astype(np.int32)),
                ('ei', ei.astype(np.float32)),  # EI is already baselined
                ('selected_channels', selected_channels.astype(np.int32)),
            ]:
                if name in group:
                    del group[name]
                group.create_dataset(name, data=data)

            # Peak channel = channel with the largest peak-to-peak EI amplitude.
            group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))

    except KeyboardInterrupt:
        # The `with` block has already closed the HDF5 file when we get here.
        # Re-raise so the interrupt actually stops processing instead of being
        # swallowed and letting the following steps run.
        print("\nKeyboard interrupt detected — exiting safely before write completes.")
        raise

    except Exception as e:
        # Saving is best-effort: report the failure and carry on.
        print(f"\nUnexpected error while saving unit_{unit_id}: {e}")




    if len(spikes)>=100:
        snips_full = snips_full[np.ix_(selected_channels, np.arange(snips_full.shape[1]), final_spike_inds)]
        snips_full = snips_full.transpose(2, 0, 1) # [C × T × N] → [N × C × T]

            # --- Setup ---
        residuals_per_channel = {}
        cluster_ids_per_channel = {}
        scale_factors_per_channel = {}

        for ch_idx, ch in enumerate(selected_channels):
            # Slice data for this channel
            ch_snips = snips_full[:, ch_idx, :]  # shape: (n_spikes, snip_len)
            ch_baselines = baselines[ch, :]    # shape: (n_segments,)

            # Subtract PCA cluster means
            residuals, cluster_ids, scale_factors = axolotl_utils_ram.subtract_pca_cluster_means_ram(
                snippets=ch_snips,
                baselines=ch_baselines,
                spike_times=spikes,
                segment_len=100_000,  # must match what was used to generate baselines
                n_clusters=5,
                offset_window=(-10,40)
            )

            # Store results
            residuals_per_channel[ch] = residuals
            cluster_ids_per_channel[ch] = cluster_ids
            scale_factors_per_channel[ch] = scale_factors
    else:
        
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Assume residuals_fallback is (T, N) from previous step (template-subtracted waveforms)
        # Transpose to match expected shape: (n_spikes, snip_len)
        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }


    # end_time = time.time()
    # elapsed = end_time - start_time 
    # print(f"Finished preprocessing, starting edits. Elapsed: {elapsed:.1f} seconds.")
    # Step 12: edit raw data
    write_locs = spikes + window[0]
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=raw_data.shape[0],
        dtype = np.int16,
        n_channels = n_channels,
        is_ram=True,
        is_disk=False
    )
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")


    # Step 13: Repeat until done
    unit_id += 1
    if unit_id >= max_units:
        print("Reached unit limit.")
        break



In [None]:

    # Step 12: edit raw data.
    # Each residual snippet is written starting window[0] samples relative to
    # its spike time (window[0] is negative, so writing starts before the spike).
    write_locs = spikes + window[0]
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=len(raw_data),
        dtype=np.int16,
        n_channels=n_channels,
        is_ram=True,
        is_disk=False,
    )

    # Report wall-clock time since start_time (set elsewhere in the notebook).
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")

    # Step 13: advance to the next unit id (the max_units check is disabled here).
    unit_id += 1



In [None]:
# Average the baselined snippets over the last axis and plot row 148 of the result.
ei = snips_baselined.mean(axis=2)

plt.figure(figsize=(7, 4))
plt.plot(ei[148], color='black', linewidth=1)
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
import matplotlib.pyplot as plt
from plot_ei_waveforms import plot_ei_waveforms

# One figure per cluster from the first clustering pass.
for i, cluster in enumerate(clusters_pre):
    ei = cluster['ei']
    n_spikes = len(cluster['inds'])

    # Reference channel: the cluster channel with the largest peak-to-peak EI.
    ref_ch = cluster['channels'][np.ptp(ei[cluster['channels']], axis=1).argmax()]
    ei_p2p = np.ptp(ei[ref_ch])

    plt.figure(figsize=(15, 5))
    plot_ei_waveforms(
        ei=ei,
        positions=ei_positions,
        scale=70.0,
        box_height=1.0,
        box_width=50,
        linewidth=0.5,
        alpha=0.9,
        colors='black',
    )
    plt.title(f"Cluster {i} EI — Spikes: {n_spikes}, P2P on Ref Ch ({ref_ch}): {ei_p2p:.1f}")
    plt.tight_layout()
    plt.show()


In [None]:
import importlib

import axolotl_utils_ram

# Pick up edits to the helper module without restarting the kernel.
importlib.reload(axolotl_utils_ram)

# With return_debug=True the pipeline's re-clustering step unpacks seven
# outputs from cluster_spike_waveforms; unpacking only five here would raise
# "too many values to unpack". The two extra names mirror that call.
# NOTE(review): confirm the return order against axolotl_utils_ram.
(
    clusters_pre,
    pcs_pre,
    labels_pre,
    sim_matrix_pre,
    n_bad_channels_pre,
    cluster_eis_pre,
    cluster_to_merged_group_pre,
) = axolotl_utils_ram.cluster_spike_waveforms(snips_baselined, ei, k_start=3, return_debug=True)


In [None]:
# Plot row 148 of the third debug cluster EI from the first clustering pass.
plt.figure(figsize=(7, 4))
plt.plot(cluster_eis_pre[2][148], color='black', linewidth=1)
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:

# Compare the EIs of the first two clusters from the first clustering pass.
ei_a, ei_b = (clusters_pre[k]['ei'] for k in (0, 1))

result = axolotl_utils_ram.compare_ei_subtraction(ei_a, ei_b, max_lag=3, p2p_thresh=30.0)

res = np.array(result['per_channel_residuals'])
cos_sim = np.mean(result['per_channel_cosine_sim'])

# Entries whose residual is below -10.
neg_inds = np.nonzero(res < -10)[0]

print(len(neg_inds), cos_sim, sep="\n")



In [None]:

# Compare debug cluster EIs 0 and 2 from the first clustering pass.
ei_a, ei_b = cluster_eis_pre[0], cluster_eis_pre[2]

result = axolotl_utils_ram.compare_ei_subtraction(ei_a, ei_b, max_lag=3, p2p_thresh=30.0)

res = np.array(result['per_channel_residuals'])
cos_sim = np.mean(result['per_channel_cosine_sim'])

# Entries whose residual is below -10.
neg_inds = np.nonzero(res < -10)[0]

print(len(neg_inds), cos_sim, sep="\n")


In [None]:
import matplotlib.pyplot as plt
from plot_ei_waveforms import plot_ei_waveforms

# Which debug cluster EI to draw. The title is derived from this EI rather
# than from variables left over from an earlier loop, which described a
# different cluster.
cluster_idx = 2
cluster_ei = cluster_eis_pre[cluster_idx]

# Peak channel = row with the largest peak-to-peak amplitude.
cluster_p2p = np.ptp(cluster_ei, axis=1)
peak_ch = int(np.argmax(cluster_p2p))

plt.figure(figsize=(15, 5))
plot_ei_waveforms(
    ei=cluster_ei,
    positions=ei_positions,
    scale=70.0,
    box_height=1.0,
    box_width=50,
    linewidth=0.5,
    alpha=0.9,
    colors='black',
)
# The spike count is omitted: it cannot be derived from cluster_eis_pre alone.
plt.title(f"Cluster {cluster_idx} EI — P2P on Peak Ch ({peak_ch}): {cluster_p2p[peak_ch]:.1f}")
plt.tight_layout()
plt.show()

In [None]:
# Overlay row 148 of the two EIs being compared (ei_a in black, ei_b in red).
plt.figure(figsize=(7, 4))
plt.plot(ei_a[148], color='black', linewidth=1)
plt.plot(ei_b[148], color='red', linewidth=1)
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:

# Decision thresholds for the EI comparison stored in `result`.
amp_threshold = -10
cos_threshold = 0.9

global_cosine_sim = np.mean(result['per_channel_cosine_sim'])
res = result['per_channel_residuals']
neg_inds = np.nonzero(np.array(res) < amp_threshold)[0]

# Call it two units if the mean cosine similarity is low or any residual is
# below the amplitude threshold; otherwise the same unit.
print(
    'two units'
    if (global_cosine_sim < cos_threshold or len(neg_inds) > 0)
    else 'same unit'
)
    

In [None]:

import matplotlib.pyplot as plt

# Per-channel residuals from the EI comparison stored in `result`.
plt.figure(figsize=(10, 4))
plt.bar(result['good_channels'], result['per_channel_residuals'], color='gray')
plt.axhline(0, color='black', linewidth=0.8, linestyle='--')
plt.xlabel("Channel ID")
plt.ylabel("Mean Residual (B - A)")
plt.title("Per-Channel Residuals (Masked Subtraction)")
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()

# Per-channel cosine similarity. The y-axis is labelled as such; it previously
# carried the residual label copied from the plot above.
plt.figure(figsize=(10, 4))
plt.bar(result['good_channels'], result['per_channel_cosine_sim'], color='gray')
plt.axhline(0, color='black', linewidth=0.8, linestyle='--')
plt.xlabel("Channel ID")
plt.ylabel("Cosine Similarity")
plt.title("Per-Channel cosine_sim (Masked Subtraction)")
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()


In [None]:

    # Pick a cluster from the first clustering pass and look up its spike times.
    (
        ei,
        spikes_idx,
        selected_channels,
        selected_cluster_index_pre,
    ) = axolotl_utils_ram.select_cluster_with_largest_waveform(clusters_pre, ref_channel)

    spikes_init = spike_times[spikes_idx]

    if not do_pursuit:
        # Pursuit disabled: keep the initial spikes and fill the score outputs
        # with placeholders.
        spikes = spikes_init
        mean_score = valid_score = valid_scores_at_spikes = None
        mean_thresh = valid_thresh = None
        # NOTE(review): this placeholder is the spike array itself, not None —
        # confirm that is what the diagnostics plot expects.
        mean_scores_at_spikes = spikes
    else:
        (
            spikes,
            mean_score,
            valid_score,
            mean_scores_at_spikes,
            valid_scores_at_spikes,
            mean_thresh,
            valid_thresh,
        ) = axolotl_utils_ram.ei_pursuit_ram(
            raw_data=raw_data,
            spikes=spikes_init,                 # absolute sample times
            ei_template=ei,                     # EI from selected cluster
            save_prefix='/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit0',  # set uniquely per unit
            alignment_offset=-window[0],
            fit_percentile=40,                  # percentile of spikes used to fit the Gaussian for the threshold
            sigma_thresh=5.0,                   # number of Gaussian sigmas used for the threshold
            return_debug=True,
        )

    # Step 9a: Extract full snippets from final spike times.

    # First pass: reference channel only, used to estimate per-spike lags.
    snips_ref_channel, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.array([ref_channel]),
        window=window,
    )
    # Move the last axis to the front: the lag estimator is given [N x C x T].
    snips_ref_channel = np.transpose(snips_ref_channel, (2, 0, 1))

    lags = axolotl_utils_ram.estimate_lags_by_xcorr_ram(
        snippets=snips_ref_channel,   # shape [N x C x T]
        peak_channel_idx=0,           # only the reference channel was passed in
        window=(-5, 10),              # relative to peak
        max_lag=6,                    # max xcorr shift
    )

    # Shift each spike time by its estimated lag (new array, spikes_init untouched).
    spikes = spikes + lags

    # Second pass: all channels at the lag-corrected spike times.
    snips_full, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.arange(n_channels),
        window=window,
    )


    segment_len = 100_000
    snips_baselined = snips_full.copy()  # indexed as [channel, sample, spike]
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Baseline segment that each spike falls into.
    segment_indices = spikes // segment_len  # shape: (n_spikes,)

    # Subtract, for every channel, the baseline of each spike's segment.
    # Broadcasts a (n_channels, 1, n_spikes) array over the sample axis.
    snips_baselined -= baselines[:n_channels, segment_indices][:, None, :]


    # Extract baseline-subtracted waveforms for ref_channel
    ref_snips = snips_baselined[ref_channel, :, :]  # shape: (T, N)

    # Sample index of the spike time inside a snippet. Derived from the window
    # so the mean-waveform peak and the per-spike amplitudes are always read at
    # the same sample (index 20 for window = (-20, 60)).
    peak_idx = -window[0]

    # Mean waveform over all spikes
    ref_mean = ref_snips.mean(axis=1)  # shape: (T,)
    ref_peak_amp = np.abs(ref_mean[peak_idx])  # scalar

    # Threshold at 0.75x of the mean waveform's amplitude at the spike sample
    threshold_ampl = 0.75 * ref_peak_amp

    # Amplitude of every individual spike at the same sample
    spike_amplitudes = np.abs(ref_snips[peak_idx, :])  # shape: (N,)

    # Flag bad spikes: too small
    bad_inds = np.where(spike_amplitudes < threshold_ampl)[0]

    # Create mask to keep only good spikes
    keep_mask = np.ones(spike_amplitudes.shape[0], dtype=bool)
    keep_mask[bad_inds] = False

    # --- Extract bad spike traces for plotting
    bad_spike_traces = snips_baselined[ref_channel, :, bad_inds]  # shape: (n_bad, T)

    # Re-read the bad spikes' traces from the original recording on disk
    snips_bad = axolotl_utils_ram.extract_snippets_single_channel(
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat',
        spike_times=spikes[bad_inds],
        ref_channel=ref_channel,
        window=window,
        n_channels=512,
        dtype='int16'
    )

    # Baseline-correct those traces using each bad spike's segment
    segment_indices = spikes[bad_inds] // segment_len  # shape: (n_bad,)
    snips_bad[0, :, :] -= baselines[ref_channel, segment_indices][None, :]

    # Drop the bad spikes from every per-spike array
    snips_baselined = snips_baselined[:, :, keep_mask]
    good_mean_trace = np.mean(snips_baselined[ref_channel, :, :], axis=1)
    snips_full = snips_full[:, :, keep_mask]
    valid_spike_times = valid_spike_times[keep_mask]
    spikes = spikes[keep_mask]

    spikes_for_plot_post = spikes

    # Indices of the surviving spikes in the pre-filter ordering
    final_spike_inds = np.where(keep_mask)[0]


        


In [None]:
from verify_cluster import verify_cluster

# Settings handed to verify_cluster.
params = dict(
    window=(-20, 60),
    min_spikes=100,
    ei_sim_threshold=0.75,
    k_start=4,
    k_refine=2,
)

spike_times = spikes
# NOTE(review): `dat_path` receives the baselined snippet array, not a file
# path — confirm verify_cluster accepts that.
clusters = verify_cluster(
    spike_times=spike_times,
    dat_path=snips_baselined,
    params=params,
)

print(f"Returned {len(clusters)} clean subclusters")
for i, cl in enumerate(clusters):
    print(f"  Cluster {i}: {len(cl['inds'])} spikes")

In [None]:

import importlib

import analyze_clusters

# Pick up edits to the analysis module without restarting the kernel.
importlib.reload(analyze_clusters)

# NOTE(review): `dat_path` receives the baselined snippet array, not a file
# path — confirm analyze_clusters accepts that.
analyze_clusters.analyze_clusters(
    clusters,
    spike_times=spikes,
    sampling_rate=20000,
    dat_path=snips_baselined,
    h5_path='/Volumes/Lab/Users/alexth/axolotl/201703151_kilosort_data001_spike_times.h5',
    triggers_mat_path='/Volumes/Lab/Users/alexth/axolotl/trigger_in_samples_201703151.mat',
    cluster_ids=None,
    lut=None,
    sta_depth=30,
    sta_offset=0,
    sta_chunk_size=1000,
    sta_refresh=2,
    ei_scale=3,
    ei_cutoff=0.08,
)

In [None]:
import matplotlib.pyplot as plt

# Reference-channel value of every spike at snippet sample 20, in spike order.
tmp = snips_baselined[ref_channel, 20, :].copy()
plt.plot(tmp)

# Treat every current spike as selected.
final_spike_inds = np.arange(len(spikes))


In [None]:
# Distribution of the per-spike values stored in `tmp`.
plt.figure(figsize=(6, 4))
plt.hist(tmp, bins=50, color='gray', edgecolor='black')
plt.xlabel("Amplitude")
plt.ylabel("Count")
plt.title("Histogram of tmp values")
plt.grid(True)
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Spikes whose `tmp` value is above -500.
inds = np.nonzero(tmp > -500)[0]

plt.figure(figsize=(8, 4))
# Selected snippets on channel 39, followed by the first six snippets.
plt.plot(snips_baselined[39, :, inds].T, alpha=1)
plt.plot(snips_baselined[39, :, :6], alpha=1)
plt.title(f"Overlay of {len(inds)} selected snippets on channel 39")
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.show()


In [None]:


    # Step 9b: Recluster - choose k. The baselined all-channel snippets are
    # passed in; relevant channels are presumably sub-selected inside the
    # clustering function.

    if len(spikes)<100:
        # Too few spikes to recluster: fill the diagnostics inputs with placeholders.
        pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
        labels_post = np.array([0])                    # just one fake cluster label
        sim_matrix_post = np.zeros((1, 1))             # fake 1x1 similarity matrix
        ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the "post" cluster
        selected_index_post = 0                        # only one cluster, so index is 0
        cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
        spikes_for_plot_post = np.array([0])           # placeholder spike time
        # NOTE(review): `snips` is not defined in this cell — confirm it exists upstream.
        spike_counts_post = [len(snips)]
        matches = []                                   # no matches
        # `snips_baselined` is [C x T x N]
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # residuals_fallback is (T, N); it is transposed below to (n_spikes, snip_len).
        # Force key and lookup to match the normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }

    else:
        # With return_debug=True the pipeline's own re-clustering call unpacks
        # seven outputs; unpacking five here would raise "too many values to
        # unpack". NOTE(review): confirm the order against axolotl_utils_ram.
        (
            clusters_post,
            pcs_post,
            labels_post,
            sim_matrix_post,
            n_bad_channels_post,
            cluster_eis_post,
            cluster_to_merged_group_post,
        ) = axolotl_utils_ram.cluster_spike_waveforms(snips=snips_baselined, ei=ei, k_start=2,return_debug=True)

        # Step 9c: choose the best cluster - choose similarity threshold. EI is all channels, baselined
        ei, final_spike_inds, selected_channels, selected_cluster_index_post = axolotl_utils_ram.select_cluster_by_ei_similarity_ram(clusters=clusters_post,reference_ei=ei,similarity_threshold=0.95)

        spikes = spikes[final_spike_inds]  # convert to absolute spike times
        snips_baselined = snips_baselined[:,:,final_spike_inds] # cut only the ones that survived

        # Keep channels whose EI peak-to-peak exceeds the threshold, largest first.
        p2p_threshold = 30
        ei_p2p = ei.max(axis=1) - ei.min(axis=1)
        selected_channels = np.where(ei_p2p > p2p_threshold)[0]
        selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

        # check for matching KS units
        results = []
        lag = 20
        ks_sim_threshold = 0.75

        # Run comparison
        sim = compare_eis(ks_ei_stack, ei, lag).squeeze() # shape: (num_KS_units,)
        matches = [
            {
                "unit_id": ks_unit_ids[i],
                "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                "similarity": float(sim[i]),
                "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
            }
            for i in np.where(sim > ks_sim_threshold)[0]
        ]



In [None]:
# Keep channels whose EI peak-to-peak exceeds the threshold, largest first.
p2p_threshold = 30
ei_p2p = np.ptp(ei, axis=1)
selected_channels = np.nonzero(ei_p2p > p2p_threshold)[0]
selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

# Check for matching KS units.
results = []
lag = 20
ks_sim_threshold = 0.75

# Similarity of this EI to every stacked KS EI; shape: (num_KS_units,)
sim = compare_eis(ks_ei_stack, ei, lag).squeeze()
matches = [
    {
        "unit_id": ks_unit_ids[i],
        "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
        "similarity": float(sim[i]),
        "n_spikes": int(ks_n_spikes[ks_unit_ids[i]]),
    }
    for i in np.nonzero(sim > ks_sim_threshold)[0]
]

# Placeholder "second clustering" outputs for the diagnostics plot.
pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
labels_post = np.array([0])                    # just one fake cluster label
sim_matrix_post = np.zeros((1, 1))             # fake 1x1 similarity matrix
ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the "post" cluster
cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
selected_index_post = 0                        # only one cluster, so index is 0
spikes_for_plot_post = np.array([0])           # placeholder spike time
spike_counts_post = [len(snips)]               # length of `snips` (defined elsewhere)

In [None]:


    # DIAGNOSTIC PLOTS
    # One call that hands every intermediate result of this unit to the plotting helper.
    axolotl_utils_ram.plot_unit_diagnostics(
        output_path=debug_folder,
        unit_id=unit_id,
        # first call to cluster_spike_waveforms
        pcs_pre=pcs_pre,
        labels_pre=labels_pre,
        sim_matrix_pre=sim_matrix_pre,
        cluster_eis_pre=cluster_eis_pre,
        spikes_for_plot_pre=spikes_for_plot_pre,
        # ei_pursuit outputs (placeholders when pursuit is disabled)
        mean_score=mean_score,
        valid_score=valid_score,
        mean_scores_at_spikes=mean_scores_at_spikes,
        valid_scores_at_spikes=valid_scores_at_spikes,
        mean_thresh=mean_thresh,
        valid_thresh=valid_thresh,
        # lag estimation and bad spike filtering
        lags=lags,
        bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
        good_mean_trace=good_mean_trace,
        threshold_ampl=-threshold_ampl,     # passed negated
        ref_channel=ref_channel,
        snips_bad=snips_bad,
        # second clustering
        pcs_post=pcs_post,
        labels_post=labels_post,
        sim_matrix_post=sim_matrix_post,
        cluster_eis_post=cluster_eis_post,
        spikes_for_plot_post=spikes_for_plot_post,
        # axis labels etc.
        window=(-20, 60),
        ei_positions=ei_positions,
        selected_channels_count=len(selected_channels),
        spikes=spikes,
        orig_threshold=threshold,
        ks_matches=matches,
    )


    # Step 10: Save unit metadata.
    # The results file is opened in append mode; any dataset already stored
    # under this unit's group is deleted and rewritten.
    try:
        with h5py.File(h5_out_path, 'a') as h5:
            group = h5.require_group(f'unit_{unit_id}')

            for name, data in [
                ('spike_times', spikes.astype(np.int32)),
                ('ei', ei.astype(np.float32)),  # EI is already baselined
                ('selected_channels', selected_channels.astype(np.int32)),
            ]:
                if name in group:
                    del group[name]
                group.create_dataset(name, data=data)

            # Peak channel = channel with the largest peak-to-peak EI amplitude.
            group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))

    except KeyboardInterrupt:
        # The `with` block has already closed the HDF5 file when we get here.
        # Re-raise so the interrupt stops execution instead of being swallowed
        # and letting later steps run.
        print("\nKeyboard interrupt detected — exiting safely before write completes.")
        raise

    except Exception as e:
        # Saving is best-effort: report the failure and carry on.
        print(f"\nUnexpected error while saving unit_{unit_id}: {e}")



In [None]:


    if len(spikes) < 100:
        # Too few spikes: subtract the mean waveform on the referent channel
        # only, to avoid distortion.
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Force key and lookup to match the normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        # Stored transposed, (n_spikes, snip_len), as int16.
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }
    else:
        # Keep only the selected channels and surviving spikes,
        # then reorder [C x T x N] -> [N x C x T].
        snips_full = snips_full[
            np.ix_(selected_channels, np.arange(snips_full.shape[1]), final_spike_inds)
        ].transpose(2, 0, 1)

        # Per-channel outputs, keyed by channel id.
        residuals_per_channel = {}
        cluster_ids_per_channel = {}
        scale_factors_per_channel = {}

        for ch_idx, ch in enumerate(selected_channels):
            ch_snips = snips_full[:, ch_idx, :]  # shape: (n_spikes, snip_len)
            ch_baselines = baselines[ch, :]      # shape: (n_segments,)

            # Subtract PCA cluster means on this channel.
            residuals, cluster_ids, scale_factors = axolotl_utils_ram.subtract_pca_cluster_means_ram(
                snippets=ch_snips,
                baselines=ch_baselines,
                spike_times=spikes,
                segment_len=100_000,  # must match what was used to generate baselines
                n_clusters=5,
                offset_window=(-10,40)
            )

            residuals_per_channel[ch] = residuals
            cluster_ids_per_channel[ch] = cluster_ids
            scale_factors_per_channel[ch] = scale_factors


    # Step 12: edit raw data.
    # Each residual snippet is written starting window[0] samples relative to
    # its spike time (window[0] is negative, so writing starts before the spike).
    write_locs = spikes + window[0]
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=len(raw_data),
        dtype=np.int16,
        n_channels=n_channels,
        is_ram=True,
        is_disk=False,
    )

    # Report wall-clock time since start_time (set elsewhere in the notebook).
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")

    # Step 13: advance to the next unit id (the max_units check is disabled here).
    unit_id += 1

