In [None]:
import numpy as np
import os
import h5py
import time
from scipy.spatial import KDTree
import json
from compare_eis import compare_eis


# --- Raw recording: location and on-disk layout ---
dat_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat"
n_channels = 512
dtype = np.int16

# Number of complete time samples: bytes / (bytes per value * channels).
file_size_bytes = os.path.getsize(dat_path)
total_samples = file_size_bytes // (n_channels * np.dtype(dtype).itemsize)

# Read the whole file into RAM and view it as [T, C] (values are interleaved
# sample-major, one value per channel per time point).
raw_data = np.fromfile(
    dat_path, dtype=dtype, count=n_channels * total_samples
).reshape(total_samples, n_channels)  # shape: [T, C]


# --- Parameters ---
n_channels = 512
dtype = 'int16'
max_units = 1500
amplitude_threshold = 15
window = (-20, 60)          # snippet window (start, end) in samples relative to each spike time
peak_window = 30
# Samples scanned by the thresholding step: 36_000_000 (= 360 baseline
# segments of 100_000 samples). Clipped to the recording length computed
# from the file size above, so the value never exceeds the data in RAM.
total_samples = min(total_samples, 36_000_000)
fit_offsets = (-5, 10)

do_pursuit = 0


h5_in_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_kilosort_data001_spike_times.h5'  # from MATLAB export, to get EI positions
h5_out_path = '/Volumes/Lab/Users/alexth/axolotl/results_pipeline_0528.h5'  # where to save data

debug_folder = "/Volumes/Lab/Users/alexth/axolotl/debug"

with h5py.File(h5_in_path, 'r') as f:
    # Load electrode positions
    ei_positions = f['/ei_positions'][:].T  # shape becomes [512 x 2]
    ks_vision_ids = f['/vision_ids'][:]  # shape: (N_units,)

import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)

save_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001_baseline_and_artifacts.json"

# Per-channel, per-segment baselines are cached as JSON; reuse the cache if present.
if os.path.exists(save_path):
    print("Loading baselines")
    with open(save_path, 'r') as f:
        data = json.load(f)
    baselines = np.array(data['baselines'], dtype=np.float32)
else:
    print("Computing baselines")
    # Cast to float32 so a fresh computation has the same dtype as a cache
    # load; otherwise the first run and later runs could subtract baselines
    # at different precision.
    baselines = np.asarray(
        axolotl_utils_ram.compute_baselines_int16(raw_data, segment_len=100_000),
        dtype=np.float32,
    )  # shape (512, 360)

    with open(save_path, 'w') as f:
        json.dump({
            'baselines': baselines.tolist(),
        }, f)


# get KS EIs: one dataset per Kilosort unit. The dataset name is "<name>_<id>",
# where id is 1-based and is stored here as a 0-based key.
ks_ei_path = '/Volumes/Lab/Users/alexth/axolotl/ks_eis_subset.h5'
ks_templates, ks_n_spikes = {}, {}

with h5py.File(ks_ei_path, 'r') as f:
    for k, dset in f.items():
        unit_id = int(k.split('_')[1]) - 1
        ks_templates[unit_id] = dset[:]
        # -1 marks datasets without an 'n_spikes' attribute
        ks_n_spikes[unit_id] = dset.attrs.get('n_spikes', -1)

ks_unit_ids = list(ks_templates)
ks_ei_stack = np.stack([ks_templates[uid] for uid in ks_unit_ids], axis=0)  # [N x 512 x 81]

unit_id = 0

print(f"\n=== Starting unit {unit_id} ===")

# Main extraction loop: each pass isolates one unit, exports it, applies its
# residuals via apply_residuals, and increments unit_id until max_units.
while True:

    start_time = time.time()

    # Pick the reference channel for this unit.
    # could cache scores on channels to pre-identify next one
    ref_channel = axolotl_utils_ram.find_dominant_channel_ram(
        raw_data=raw_data,
        segment_len=100_000,
        n_segments=10,
        peak_window=peak_window,
        top_k_neg=20,
        top_k_events=5,
        seed=42,
    )

    # Threshold-based spike detection on the reference channel.
    threshold, spike_times = axolotl_utils_ram.estimate_spike_threshold_ram(
        raw_data=raw_data,
        ref_channel=ref_channel,
        window=30,
        total_samples_to_read=total_samples,
        refractory=30,
        top_n=100,
    )

    print(f"Channel: {ref_channel}, Threshold: {-threshold:.1f}, Initial spikes: {len(spike_times)}")

    # All-channel snippets [C x T x N] around each detection, and their mean
    # (the EI), offset-corrected by the mean of its first 5 samples.
    snips, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spike_times,
        window=window,
        selected_channels=np.arange(n_channels),
    )

    ei = np.mean(snips, axis=2)
    ei -= ei[:, :5].mean(axis=1, keepdims=True)

    spikes_for_plot_pre = valid_spike_times

    # Step 6–7: Cluster and select dominant unit
    clusters_pre, pcs_pre, labels_pre, sim_matrix_pre, cluster_eis_pre = (
        axolotl_utils_ram.cluster_spike_waveforms(snips, ei, k_start=3, return_debug=True)
    )

    ei, spikes_idx, selected_channels, selected_cluster_index_pre = (
        axolotl_utils_ram.select_cluster_with_largest_waveform(clusters_pre, ref_channel)
    )

    # Clustering only ever saw `snips`, so spikes_idx indexes its spike axis,
    # i.e. valid_spike_times. spike_times can be longer if the extractor
    # dropped any detection, which would make spike_times[spikes_idx] wrong.
    spikes_init = np.asarray(valid_spike_times)[spikes_idx]

    # Optional template (EI) pursuit; otherwise keep the clustered spikes.
    if do_pursuit:
        (
            spikes,
            mean_score,
            valid_score,
            mean_scores_at_spikes,
            valid_scores_at_spikes,
            mean_thresh,
            valid_thresh,
        ) = axolotl_utils_ram.ei_pursuit_ram(
            raw_data=raw_data,
            spikes=spikes_init,                 # absolute sample times
            ei_template=ei,                     # EI from selected cluster
            # unique per unit, as intended; a constant prefix would make every
            # unit reuse the same save_prefix
            save_prefix=f'/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit{unit_id}',
            alignment_offset=-window[0],
            fit_percentile=40,                  # how many (percentile) spikes to take to fit Gaussian for threshold determination (left-hand side of already found spikes)
            sigma_thresh=5.0,                   # how many Gaussian sigmas to take for threshold
            return_debug=True,
        )
    else:
        # No pursuit: placeholder diagnostics for plot_unit_diagnostics.
        spikes = spikes_init
        mean_score = None
        valid_score = None
        mean_scores_at_spikes = spikes
        valid_scores_at_spikes = None
        mean_thresh = None
        valid_thresh = None

    # Step 9a: Extract full snippets from final spike times

    snips_ref_channel, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.array([ref_channel]),
        window=window,
    )
    # Keep `spikes` row-aligned with the snippets actually cut: if the
    # extractor dropped any spike, `spikes + lags` below would otherwise have
    # mismatched lengths.
    spikes = np.asarray(valid_spike_times)

    snips_ref_channel = snips_ref_channel.transpose(2, 0, 1)  # [C x T x N] -> [N x C x T]

    lags = axolotl_utils_ram.estimate_lags_by_xcorr_ram(
        snippets=snips_ref_channel,         # shape [N x C x T]
        peak_channel_idx=0,                 # 0 because the only channel that gets passed is the referent channel
        window=(-5, 10),                    # optional, relative to peak
        max_lag=6,                          # optional, max xcorr shift
    )

    spikes = spikes + lags

    snips_full, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.arange(n_channels),
        window=window,
    )
    # Same alignment after the shift: later per-spike indexing (baselines,
    # keep masks) assumes len(spikes) == snips_full.shape[2].
    spikes = np.asarray(valid_spike_times)


    # Per-segment baseline subtraction for every snippet.
    segment_len = 100_000  # must match the segment_len used to compute `baselines`
    snips_baselined = snips_full.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    segment_indices = spikes // segment_len  # shape: (n_spikes,)

    # One baseline value per (channel, spike), broadcast over the time axis;
    # same result as subtracting baselines[ch, segment_indices] channel by channel.
    snips_baselined -= baselines[:n_channels, segment_indices][:, None, :]

    # Extract baseline-subtracted waveforms for ref_channel
    ref_snips = snips_baselined[ref_channel, :, :]  # shape: (snip_len, N)

    # Mean waveform over all spikes; its negative peak is expected at the
    # alignment sample -window[0] (the spike time inside each snippet).
    ref_mean = ref_snips.mean(axis=1)  # shape: (snip_len,)
    ref_peak_amp = np.abs(ref_mean[-window[0]])  # scalar

    # Threshold at 0.66× of mean waveform peak
    threshold_ampl = 0.66 * ref_peak_amp

    # Per-spike amplitude at the same alignment sample used for ref_peak_amp
    # (derived from `window`, not a hard-coded index).
    spike_amplitudes = np.abs(ref_snips[-window[0], :])  # shape: (N,)

    # Flag bad spikes: too small
    bad_inds = np.where(spike_amplitudes < threshold_ampl)[0]

    # Create mask to keep only good spikes
    keep_mask = np.ones(spike_amplitudes.shape[0], dtype=bool)
    keep_mask[bad_inds] = False

    # --- Extract bad spike traces for plotting
    bad_spike_traces = snips_baselined[ref_channel, :, bad_inds]  # shape: (n_bad, T)

    # Original traces of the bad spikes, read from the .dat file on disk
    # (raw_data in RAM may already have had earlier units' residuals applied).
    snips_bad = axolotl_utils_ram.extract_snippets_single_channel(
        dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat',
        spike_times=spikes[bad_inds],
        ref_channel=ref_channel,
        window=window,
        n_channels=512,
        dtype='int16',
    )

    segment_indices = spikes[bad_inds] // segment_len  # shape: (n_bad,)
    snips_bad[0, :, :] -= baselines[ref_channel, segment_indices][None, :]

    # Apply to real data and snips_baselined
    snips_baselined = snips_baselined[:, :, keep_mask]
    good_mean_trace = np.mean(snips_baselined[ref_channel, :, :], axis=1)
    snips_full = snips_full[:, :, keep_mask]
    valid_spike_times = valid_spike_times[keep_mask]
    spikes = spikes[keep_mask]

    spikes_for_plot_post = spikes


    # Step 9b: Recluster - choose k. snips_baselined is all channels, baselined -
    # relevant channels will be subselected in the function.
    if len(spikes) < 100:
        # Too few spikes to recluster: placeholder "post" diagnostics and a
        # fallback to the reference channel only.
        pcs_post = np.zeros((1, 2))                             # shape: (N_spikes, 2 PCs)
        labels_post = np.array([0])                             # just one fake cluster label
        sim_matrix_post = np.zeros((1, 1))                      # fake 1×1 similarity matrix
        ei_clusters_post = [np.zeros((n_channels, snip_len))]   # fake EI for the "post" cluster
        selected_index_post = 0                                 # only one cluster, so index is 0
        cluster_eis_post = [np.zeros((n_channels, snip_len))]   # same dummy EI
        spikes_for_plot_post = np.array([0])                    # placeholder spike time
        spike_counts_post = [len(spikes)]                       # number of spikes (len(snips) counts channels)
        matches = []                                            # no matches

        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        # The ref-channel residuals for this case are built later, in the
        # else-branch of the `len(spikes) >= 100` check, which this path
        # always reaches because `spikes` is not changed in between.

    else:
        clusters_post, pcs_post, labels_post, sim_matrix_post, cluster_eis_post = (
            axolotl_utils_ram.cluster_spike_waveforms(
                snips=snips_baselined, ei=ei, k_start=2, return_debug=True
            )
        )

        # Step 9c: choose the best cluster - choose similarity threshold. EI is all channels, baselined
        ei, final_spike_inds, selected_channels, selected_cluster_index_post = (
            axolotl_utils_ram.select_cluster_by_ei_similarity_ram(
                clusters=clusters_post, reference_ei=ei, similarity_threshold=0.95
            )
        )

        spikes = spikes[final_spike_inds]                          # convert to absolute spike times
        snips_baselined = snips_baselined[:, :, final_spike_inds]  # cut only the ones that survived

        # Channels used downstream: EI peak-to-peak above p2p_threshold,
        # sorted by decreasing peak-to-peak.
        p2p_threshold = 30
        ei_p2p = ei.max(axis=1) - ei.min(axis=1)
        selected_channels = np.where(ei_p2p > p2p_threshold)[0]
        selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

        # check for matching KS units
        results = []
        lag = 20
        ks_sim_threshold = 0.75

        # Similarity to every KS template; atleast_1d keeps a single-template
        # stack from squeezing to a 0-d array that cannot be indexed.
        sim = np.atleast_1d(compare_eis(ks_ei_stack, ei, lag).squeeze())  # shape: (num_KS_units,)
        matches = [
            {
                "unit_id": ks_unit_ids[i],
                "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                "similarity": float(sim[i]),
                "n_spikes": int(ks_n_spikes[ks_unit_ids[i]]),
            }
            for i in np.where(sim > ks_sim_threshold)[0]
        ]



    # DIAGNOSTIC PLOTS

    axolotl_utils_ram.plot_unit_diagnostics(
        output_path=debug_folder,
        unit_id=unit_id,

        # --- From first call to cluster_spike_waveforms
        pcs_pre=pcs_pre,
        labels_pre=labels_pre,
        sim_matrix_pre=sim_matrix_pre,
        cluster_eis_pre=cluster_eis_pre,
        spikes_for_plot_pre=spikes_for_plot_pre,

        # --- From ei_pursuit
        mean_score=mean_score,
        valid_score=valid_score,
        mean_scores_at_spikes=mean_scores_at_spikes,
        valid_scores_at_spikes=valid_scores_at_spikes,
        mean_thresh=mean_thresh,
        valid_thresh=valid_thresh,

        # --- Lag estimation and bad spike filtering
        lags=lags,
        bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
        good_mean_trace=good_mean_trace,
        threshold_ampl=-threshold_ampl,
        ref_channel=ref_channel,
        snips_bad=snips_bad,

        # --- From second clustering
        pcs_post=pcs_post,
        labels_post=labels_post,
        sim_matrix_post=sim_matrix_post,
        cluster_eis_post=cluster_eis_post,
        spikes_for_plot_post=spikes_for_plot_post,

        # --- For axis labels etc.: the same window used to cut the snippets,
        # so the plot axes cannot drift from the extraction parameters.
        window=window,

        ei_positions=ei_positions,
        selected_channels_count=len(selected_channels),

        spikes=spikes,
        orig_threshold=threshold,
        ks_matches=matches,
    )


    # Step 10: Save unit metadata. Existing datasets for this unit are
    # replaced, so re-running a unit overwrites its previous export.
    try:
        with h5py.File(h5_out_path, 'a') as h5:
            group = h5.require_group(f'unit_{unit_id}')

            for name, data in [
                ('spike_times', spikes.astype(np.int32)),
                ('ei', ei.astype(np.float32)),  # EI is already baselined
                ('selected_channels', selected_channels.astype(np.int32)),
            ]:
                if name in group:
                    del group[name]
                group.create_dataset(name, data=data)

            group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))

    except KeyboardInterrupt:
        # The `with` block has already closed the file; re-raise so Ctrl-C
        # actually stops the `while True` loop instead of being swallowed.
        print("\nKeyboard interrupt detected — exiting safely before write completes.")
        raise

    except Exception as e:
        # Best effort: report the failed export and keep processing units.
        print(f"\nUnexpected error while saving unit_{unit_id}: {e}")



    # Per-channel residual snippets to hand to apply_residuals.
    if len(spikes) >= 100:
        # Reclustered case: keep the selected channels and the spikes that
        # survived reclustering, then reorder [C x T x N] -> [N x C x T].
        snips_full = snips_full[selected_channels][:, :, final_spike_inds].transpose(2, 0, 1)

        residuals_per_channel = {}
        cluster_ids_per_channel = {}
        scale_factors_per_channel = {}

        for ch_idx, ch in enumerate(selected_channels):
            # snips_full holds raw (not baseline-subtracted) snippets, so the
            # channel's baselines are passed along with the spike times.
            (
                residuals_per_channel[ch],
                cluster_ids_per_channel[ch],
                scale_factors_per_channel[ch],
            ) = axolotl_utils_ram.subtract_pca_cluster_means_ram(
                snippets=snips_full[:, ch_idx, :],   # (n_spikes, snip_len)
                baselines=baselines[ch, :],          # (n_segments,)
                spike_times=spikes,
                segment_len=100_000,  # must match what was used to generate baselines
                n_clusters=5,
                offset_window=(-10, 40),
            )
    else:
        # Fallback: subtract the mean waveform on the referent channel only,
        # to avoid distortion on other channels.
        template_fallback = snips_baselined[ref_channel].mean(axis=1)                   # (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # (T, N)

        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        # Stored transposed to (n_spikes, snip_len) and cast to int16.
        residuals_per_channel = {ref_channel: residuals_fallback.T.astype(np.int16)}

    # end_time = time.time()
    # elapsed = end_time - start_time 
    # print(f"Finished preprocessing, starting edits. Elapsed: {elapsed:.1f} seconds.")
    # Step 12: edit raw data
    write_locs = spikes + window[0]
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=raw_data.shape[0],
        dtype = np.int16,
        n_channels = n_channels,
        is_ram=True,
        is_disk=False
    )
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")


    # Step 13: Repeat until done
    unit_id += 1
    if unit_id >= max_units:
        print("Reached unit limit.")
        break



In [None]:
import matplotlib.pyplot as plt
# Quick look at the first 5,000 samples of channel 39 in the in-RAM data.
_fig, _ax = plt.subplots(figsize=(25, 4))
_ax.plot(raw_data[:5000, 39])
_ax.set(xlabel='Sample', ylabel='Amplitude', title='Channel 39: First 5,000 samples')
_ax.grid(True)
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --- Parameters ---
dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat'  # the dat_path passed to apply_residuals above
n_channels = 512
channel = 39
n_samples = 5000
dtype = np.int16

# --- Read the first n_samples time points (n_samples * n_channels values) ---
raw = np.fromfile(dat_path, dtype=dtype, count=n_samples * n_channels)
raw = raw.reshape(-1, n_channels)  # [time, channel]

# --- Plot: use the parameters above instead of repeating 39 / 5000 inline ---
plt.figure(figsize=(25, 4))
plt.plot(raw[:n_samples, channel])
plt.xlabel('Sample')
plt.ylabel('Amplitude')
plt.title(f'Channel {channel}: First {n_samples:,} samples')
plt.grid(True)
plt.show()


### Test - development

In [None]:
import numpy as np
import os
import h5py
import time
from scipy.spatial import KDTree
import json
from compare_eis import compare_eis


# --- Raw recording: location and on-disk layout ---
dat_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat"
n_channels = 512
dtype = np.int16

# Number of complete time samples: bytes / (bytes per value * channels).
file_size_bytes = os.path.getsize(dat_path)
total_samples = file_size_bytes // (n_channels * np.dtype(dtype).itemsize)

# Read the whole file into RAM and view it as [T, C] (values are interleaved
# sample-major, one value per channel per time point).
raw_data = np.fromfile(
    dat_path, dtype=dtype, count=n_channels * total_samples
).reshape(total_samples, n_channels)  # shape: [T, C]


# --- Parameters ---
n_channels = 512
dtype = 'int16'
max_units = 1500
amplitude_threshold = 15
window = (-20, 60)          # snippet window (start, end) in samples relative to each spike time
peak_window = 30
# Samples scanned by the thresholding step: 36_000_000 (= 360 baseline
# segments of 100_000 samples). Clipped to the recording length computed
# from the file size above, so the value never exceeds the data in RAM.
total_samples = min(total_samples, 36_000_000)
fit_offsets = (-5, 10)

do_pursuit = 0


h5_in_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_kilosort_data001_spike_times.h5'  # from MATLAB export, to get EI positions
h5_out_path = '/Volumes/Lab/Users/alexth/axolotl/results_pipeline_0528.h5'  # where to save data

debug_folder = "/Volumes/Lab/Users/alexth/axolotl/debug/0607"

with h5py.File(h5_in_path, 'r') as f:
    # Load electrode positions
    ei_positions = f['/ei_positions'][:].T  # shape becomes [512 x 2]
    ks_vision_ids = f['/vision_ids'][:]  # shape: (N_units,)

import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)

save_path = "/Volumes/Lab/Users/alexth/axolotl/201703151_data001_baseline_derivative.json"

# Per-channel, per-segment baselines are cached as JSON; reuse the cache if present.
if os.path.exists(save_path):
    print("Loading baselines")
    with open(save_path, 'r') as f:
        data = json.load(f)
    baselines = np.array(data['baselines'], dtype=np.float32)
else:
    print("Computing baselines")
    # Cast to float32 so a fresh computation has the same dtype as a cache
    # load; otherwise the first run and later runs could subtract baselines
    # at different precision.
    baselines = np.asarray(
        axolotl_utils_ram.compute_baselines_int16_deriv_robust(
            raw_data, segment_len=100_000, diff_thresh=10, trim_fraction=0.15
        ),
        dtype=np.float32,
    )  # shape (512, 360)

    with open(save_path, 'w') as f:
        json.dump({
            'baselines': baselines.tolist(),
        }, f)


# get KS EIs: one dataset per Kilosort unit. The dataset name is "<name>_<id>",
# where id is 1-based and is stored here as a 0-based key.
ks_ei_path = '/Volumes/Lab/Users/alexth/axolotl/ks_eis_subset.h5'
ks_templates, ks_n_spikes = {}, {}

with h5py.File(ks_ei_path, 'r') as f:
    for k, dset in f.items():
        unit_id = int(k.split('_')[1]) - 1
        ks_templates[unit_id] = dset[:]
        # -1 marks datasets without an 'n_spikes' attribute
        ks_n_spikes[unit_id] = dset.attrs.get('n_spikes', -1)

ks_unit_ids = list(ks_templates)
ks_ei_stack = np.stack([ks_templates[uid] for uid in ks_unit_ids], axis=0)  # [N x 512 x 81]


In [None]:
# Output folder for this session's diagnostics, and the unit counter the
# extraction loop starts from.
debug_folder = "/Volumes/Lab/Users/alexth/axolotl/debug/0615"
unit_id = 0
print(unit_id)

import sys

class Tee:
    """Minimal file-like object that duplicates every write to several streams.

    Only ``write`` and ``flush`` are provided; each call is forwarded, in
    order, to every stream given to the constructor.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, obj):
        for stream in self.files:
            stream.write(obj)

    def flush(self):
        for stream in self.files:
            stream.flush()

# Create the log file inside debug_folder (the folder is created if missing),
# rather than repeating the folder path by hand.
os.makedirs(debug_folder, exist_ok=True)
_new_log_file = open(os.path.join(debug_folder, "processing_log.txt"), "w", encoding="utf-8")

# Redirect stdout to both the original process stdout (sys.__stdout__) and the log file
sys.stdout = Tee(sys.__stdout__, _new_log_file)

# On a re-run of this cell, release the previous log's file handle; this is
# done only after stdout no longer writes to it.
if 'log_file' in globals() and not log_file.closed:
    log_file.close()
log_file = _new_log_file


In [None]:
import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)
from diptest import diptest
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from scipy.signal import find_peaks
from sklearn.cluster import KMeans

# unit_id = 1004


# Per-run state for the extraction loop: no cached channel candidates yet.
last_ref_channel = None
remaining_candidates = []

do_second_clustering = False

# ================================================================
# ---- 0.  One-time initialisation (survives re-runs of this cell)
# ================================================================
try:
    rejected_spike_log
except NameError:
    rejected_spike_log = {
        'times'   : [],   # absolute sample indices
        'channel' : [],   # ref_channel for each rejected snip
        'unit_id' : [],   # which unit's refinement produced the reject
    }

# ref_channel = 51

# if 1:
while True:

    print(f"\n=== Starting unit {unit_id} ===\n")

    start_time = time.time()

    if not remaining_candidates:
        seed=int(time.time())

        remaining_candidates,top_amplitudes = axolotl_utils_ram.find_dominant_channel_ram(
            raw_data=raw_data,
            positions=ei_positions,
            segment_len=200_000, # at 20k samples per s, this is 10s snippet, and we have 10 of them
            n_segments=20,
            peak_window=30,
            top_k_neg=40,   # number of negative peaks to keep per channel per segment
            top_k_events=100, # number of top amplitudes to average per channel
            seed=seed,
            use_negative_peak=True,
            top_n = 10,
            min_spacing = 150
        )
        channel_source='fresh'
        last_ref_channel = None
        formatted_amps = ", ".join(f"{amp:.1f}" for amp in top_amplitudes)
        print(f"New amplitudes: {formatted_amps}, on channels {remaining_candidates}")

        

    candidate = remaining_candidates.pop(0)
    ref_channel = candidate

    if last_ref_channel is not None: # From cache
        channel_source='cache'

    # Update tracker
    # ref_channel = 29
    # channel_source = 'fresh'
    last_ref_channel = ref_channel



    threshold, spike_times = axolotl_utils_ram.estimate_spike_threshold_ram(
        raw_data=raw_data,
        ref_channel=ref_channel,
        window = 30,
        total_samples_to_read = total_samples,
        refractory = 10,
        top_n = 100,
        threshold_scale = 0.5
    )

    elapsed_thresholding = time.time() - start_time 
    print(f"Threshold: {elapsed_thresholding:.1f} s")

    # spike_times = unmatched_spikes_2

    print(f"Channel: {ref_channel} (from {channel_source}), Threshold: {-threshold:.1f}, Initial spikes: {len(spike_times)}")


    snips, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spike_times,
        window=window,
        selected_channels=np.arange(n_channels)
    )

    segment_len = 100_000
    snips_baselined = snips.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    segment_indices = valid_spike_times // segment_len  # shape: (n_spikes,)

    # Loop through channels and subtract baseline per spike
    for ch in range(n_channels):
        snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]


    ei = np.mean(snips_baselined, axis=2)
    #ei -= ei[:, :5].mean(axis=1, keepdims=True)

    spikes_for_plot_pre = valid_spike_times

    # Step 6–7: Cluster and select dominant unit
    k_start = min(8, 3 + (len(spike_times) - 1) // 3000)

    clusters_pre, pcs_pre, labels_pre, sim_matrix_pre, n_bad_channels_pre, cluster_eis_pre, cluster_to_merged_group_pre  = axolotl_utils_ram.cluster_spike_waveforms(snips_baselined, ei, k_start=k_start,return_debug=True)


    elapsed_clustering = time.time() - start_time 
    print(f"Clustering: {elapsed_clustering-elapsed_thresholding:.1f} s, with k={k_start}")

    ei, spikes_idx, selected_channels, selected_cluster_index_pre = axolotl_utils_ram.select_cluster_with_largest_waveform(clusters_pre, ref_channel)

    contributing_original_ids_pre = [
        orig_id for orig_id, merged_idx in cluster_to_merged_group_pre.items()
        if merged_idx == selected_cluster_index_pre
    ]
    contributing_original_ids_pre = np.array(contributing_original_ids_pre)

    spikes_init = spike_times[spikes_idx]





    if do_pursuit:
        (
        spikes,
        mean_score,
        valid_score,
        mean_scores_at_spikes,
        valid_scores_at_spikes,
        mean_thresh,
        valid_thresh
        ) = axolotl_utils_ram.ei_pursuit_ram(
            raw_data=raw_data,
            spikes=spikes_init,                     # absolute sample times
            ei_template=ei,                    # EI from selected cluster
            save_prefix='/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit0',  # set uniquely per unit
            alignment_offset = -window[0],
            fit_percentile = 40,                # how many (percentile) spikes to take to fit Gaussian for threshold determination (left-hand side of already found spikes)
            sigma_thresh = 5.0,                  # how many Gaussian sigmas to take for threshold
            return_debug=True, 

        )
    else:
        spikes = spikes_init
        mean_score=None
        valid_score=None
        mean_scores_at_spikes=spikes
        valid_scores_at_spikes=None
        mean_thresh=None
        valid_thresh=None

    # Step 9a: Extract full snippets from final spike times

    snips_ref_channel, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.array([ref_channel]),
        window=window,
    )

    snips_ref_channel = snips_ref_channel.transpose(2, 0, 1)


    lags = axolotl_utils_ram.estimate_lags_by_xcorr_ram(
        snippets=snips_ref_channel,                # shape [N x C x T]
        peak_channel_idx=0,                 # 0 because the only channel that gets passed is the referent channel
        window=(-5, 10),                  # optional, relative to peak
        max_lag=6,                        # optional, max xcorr shift
    )

    elapsed_lags = time.time() - start_time 
    print(f"Lags: {elapsed_lags-elapsed_clustering:.1f} s")

    spikes = spikes+lags

    snips_full, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.arange(n_channels),
        window=window,
    )


    elapsed_snippet_extraction = time.time() - start_time 
    print(f"Snippet: {elapsed_snippet_extraction-elapsed_lags:.1f} s")

    segment_len = 100_000
    snips_baselined = snips_full.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    segment_indices = spikes // segment_len  # shape: (n_spikes,)

    # Loop through channels and subtract baseline per spike
    for ch in range(n_channels):
        snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]



    # --------------------------------------------------------------------
    # 0.  Gather ref-channel waveforms (baseline-subtracted)
    # --------------------------------------------------------------------
    ref_snips     = snips_baselined[ref_channel, :, :].copy()
    ref_snips     = ref_snips.T          # [N × 81]
    wave_window   = slice(15, 40)                                 # focus on peak
    waveforms     = ref_snips[:, wave_window]                     # [N × 25]

    # --------------------------------------------------------------------
    # 1.  PCA(2)  (no z-score yet)
    # --------------------------------------------------------------------
    pcs_raw = PCA(n_components=2, svd_solver='full').fit_transform(waveforms)  # [N × 2]

    # --------------------------------------------------------------------
    # 2.  Hartigan’s dip test on rotated projections
    # --------------------------------------------------------------------
    angle_step = 10
    angles = np.deg2rad(np.arange(0, 180, angle_step))

    best_p = 1.0
    best_proj = None
    for theta in angles:
        proj = pcs_raw[:, 0] * np.cos(theta) + pcs_raw[:, 1] * np.sin(theta)
        _, p = diptest(proj)
        if p < best_p:
            best_p, best_proj = p, proj

    discard_inds_bimodal = np.empty(0, dtype=int)

    if best_p < 0.05:                                   # suspected bimodality
        # histogram → first two peaks as initial centroids
        hist, bin_edges = np.histogram(best_proj, bins=30)
        peaks, _ = find_peaks(hist)
        if len(peaks) >= 2:
            centroids = bin_edges[peaks[:2]].reshape(-1, 1)
            km = KMeans(n_clusters=2, init=centroids, n_init=1, random_state=42)
            labels = km.fit_predict(best_proj[:, None])

            # keep larger lobe, discard smaller one
            counts = np.bincount(labels)
            keep_label = counts.argmax()
            discard_inds_bimodal = np.where(labels != keep_label)[0]

    # ---------------------------------------------------------------
    # 3.  Re-project kept spikes with a fresh PCA if split clusters
    # ---------------------------------------------------------------

    if discard_inds_bimodal.size:
        keep_mask  = np.ones(len(pcs_raw), dtype=bool)
        keep_mask[discard_inds_bimodal] = False
        # Apply to real data and snips_baselined
        snips_baselined = snips_baselined[:, :, keep_mask]
        snips_full = snips_full[:, :, keep_mask]
        valid_spike_times = valid_spike_times[keep_mask]
        spikes = spikes[keep_mask]
        waveforms     = waveforms[keep_mask, :]
        pcs_raw   = PCA(n_components=2, svd_solver='full').fit_transform(waveforms)
        print(f"   Split the final cluster according to Hartigan test! Started with {len(keep_mask)} spikes, ended with {len(spikes)}")

    # ---------------------------------------------------------------
    # 4.  Scale, Mahalanobis
    # ---------------------------------------------------------------

    pcs_z      = StandardScaler().fit_transform(pcs_raw)
    pcs = pcs_z
    
    d2 = np.sum(pcs_z**2, axis=1)

    from scipy.stats import chi2
    thr_vis  = chi2.ppf(0.999,  df=2)     # illustration
    thr_cut  = chi2.ppf(0.9999, df=2)     # discard

    final_outlier_inds      = np.where(d2 > thr_vis)[0]
    final_outlier_inds_max  = np.where(d2 > thr_cut)[0]

    # --------------------------------------------------------------------
    # 5.  Local Outlier Factor (semi-supervised, trained on a Mahalanobis core)
    # --------------------------------------------------------------------
    # Robust inlier core: spikes whose squared whitened-PC norm d2 lies inside
    # the 99.9 % chi-square contour (df=2).
    core_mask = d2 < chi2.ppf(0.999, df=2)

    from sklearn.neighbors import LocalOutlierFactor
    lof_n_neighbors = 20
    n_core = int(np.count_nonzero(core_mask))
    if n_core >= 2:
        # novelty=True: fit on the core only, then judge every spike against it.
        lof = LocalOutlierFactor(n_neighbors=lof_n_neighbors, novelty=True)
        lof.fit(pcs[core_mask])
        lof_pred = lof.predict(pcs)              # -1 = outlier w.r.t. core
        # score_samples(pcs) gives one score per spike, aligned with pcs.
        # negative_outlier_factor_ would only cover the core training points.
        lof_scores = lof.score_samples(pcs)      # more negative = more outlying
        lof_inds = np.where(lof_pred == -1)[0]
    else:
        # Too few core spikes to build a neighbour model: LOF flags nothing.
        lof_scores = np.zeros(len(pcs))
        lof_inds = np.array([], dtype=np.intp)

    final_outlier_inds_max = np.union1d(final_outlier_inds_max, lof_inds)




    # ------------------------------------------------------------------
    # 1.  Build mean waveform from *accepted* spikes on ref channel
    # ------------------------------------------------------------------
    pre, post        = window                     # snippet definition
    n_channels, snip_len, n_spikes = snips_baselined.shape
    accepted_mask    = np.ones(n_spikes, dtype=bool)
    accepted_mask[final_outlier_inds_max] = False
    accepted_inds    = np.where(accepted_mask)[0]

    # Ref-channel traces are (T, N): average across spikes (axis=1) -> (T,).
    # Both branches average over the SAME axis, so the degenerate case also
    # yields a time-domain waveform rather than one value per spike.
    ref_traces = snips_baselined[ref_channel]
    if accepted_inds.size:                         # normal case
        mean_ref = ref_traces[:, accepted_inds].mean(axis=1)
    else:                                          # degenerate (all rejected)
        mean_ref = ref_traces.mean(axis=1)
    mean_ref = mean_ref - mean_ref[0]              # reference to first sample

    abs_ref = np.abs(mean_ref)

    # --- thresholds relative to the template's peak |amplitude| ----------
    max_abs = abs_ref.max()
    low_thr = 0.01 * max_abs
    high_thr = 0.1 * max_abs

    # first index where |WF| > 1 % of peak  (fall back to 0)
    try:
        idx_start = int(np.where(abs_ref > low_thr)[0][0])
    except IndexError:
        idx_start = 0

    # last index where |WF| > 10 % of peak  (fall back to last sample)
    cands = np.where(abs_ref > high_thr)[0]
    idx_end = int(cands[-1]) if cands.size else snip_len - 1

    print(f"Template masking for reject spikes: {idx_start} to {idx_end}")


    # convert waveform indices → sample offsets relative to spike centre
    # global_sample = spike_time + (wave_idx + pre)
    offset_start = idx_start + pre            # pre negative
    offset_end   = idx_end   + pre            # inclusive
    if offset_start > offset_end:             # safety
        offset_start, offset_end = pre, post

    # ------------------------------------------------------------------
    # 2.  Stash rejects & apply baseline on ref channel
    # ------------------------------------------------------------------
    segment_len = 100_000                      # already used elsewhere

    reject_times = spikes[final_outlier_inds_max]
    n_rej        = len(reject_times)

    for t in reject_times:
        # ---- (a)  log entry ------------------------------------------
        rejected_spike_log['times']  .append(int(t))
        rejected_spike_log['channel'].append(ref_channel)
        rejected_spike_log['unit_id'].append(unit_id)

        # # ---- (b)  baseline value for this spike ----------------------
        # seg_id        = int(t // segment_len)
        # baseline_val  = baselines[ref_channel, seg_id]

        # # ---- (c)  overwrite ref-channel samples with baseline --------
        # t0 = int(t + offset_start)
        # t1 = int(t + offset_end + 1)           # slice end is non-inclusive
        # # bounds check
        # if t0 < 0:            t0 = 0
        # if t1 > raw_data.shape[0]:
        #     t1 = raw_data.shape[0]
        # raw_data[t0:t1, ref_channel] = baseline_val


    # --------------------------------------------------------------------
    # 2.  Mahalanobis distance  (robust covariance)
    # --------------------------------------------------------------------
    # from sklearn.covariance import MinCovDet
    # from scipy.stats import chi2

    # mcd      = MinCovDet().fit(pcs_z)
    # d2       = mcd.mahalanobis(pcs_z)                           # squared distance
    # df       = pcs_z.shape[1]                                   # degrees of freedom
    # thr_maha = chi2.ppf(0.999999, df=df)    # Illustration only
    # final_outlier_inds = np.where(d2 > thr_maha)[0]

    # thr_maha = chi2.ppf(0.9999999, df=df)    # real discards: 0.1 % most extreme  → tweak if desired
    # final_outlier_inds_max = np.where(d2 > thr_maha)[0]

    # print(f"Maha rejects: {len(maha_inds)}")

    # # --------------------------------------------------------------------
    # # 3.  Local Outlier Factor
    # # --------------------------------------------------------------------
    #     # ----------------- build robust inlier core -------------------------
    # core_mask = d2 < chi2.ppf(0.999999, df=df)       # Mahalanobis core (~0.1 % trimmed)

    # # ----------------- semi-supervised LOF on the core ------------------
    # from sklearn.neighbors import LocalOutlierFactor
    # lof = LocalOutlierFactor(n_neighbors=len(maha_inds), novelty=True)
    # lof.fit(pcs_all[core_mask])                    # train ONLY on good spikes
    # lof_pred   = lof.predict(pcs_all)              # −1 = outlier w.r.t. core
    # lof_scores = lof.negative_outlier_factor_      # more negative = more outlying
    # lof_inds   = np.where(lof_pred == -1)[0]


    # print(f"LOF rejects: {len(lof_inds)}")
    


    # --------------------------------------------------------------------
    # 4.  Assemble the two masks you asked for
    # --------------------------------------------------------------------
    # (A) illustration set – only one metric (Mahalanobis here)
    # final_outlier_inds      = maha_inds

    # (B) aggressive discard – any metric
    # final_outlier_inds_max  = np.union1d(maha_inds, lof_inds)

    # --------------------------------------------------------------------
    # 5.  Optional: visualisation helpers
    # --------------------------------------------------------------------
    # pcs_plot = np.column_stack((pc1_z, pc2_z))          # for scatter plots
    # Now you can colour by           : 
    #    * black  = inliers
    #    * red    = final_outlier_inds
    #    * green  = extra LOF-only outliers (final_outlier_inds_max \ final_outlier_inds)


    bad_inds = final_outlier_inds_max #outlier_inds

    # Create mask to keep only good spikes
    keep_mask = np.ones(spikes.shape[0], dtype=bool)
    keep_mask[bad_inds] = False

    # --- Extract bad spike traces for plotting
    bad_spike_traces_easy = snips_baselined[ref_channel, :, final_outlier_inds]  # shape: (n_bad, T)
    bad_spike_traces = snips_baselined[ref_channel, :, bad_inds]  # shape: (n_bad, T)

    if do_second_clustering:
        # Get original traces for bad_spike_traces
        snips_bad = axolotl_utils_ram.extract_snippets_single_channel(
            dat_path='/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat',
            spike_times=spikes[bad_inds],
            ref_channel=ref_channel,
            window=window,
            n_channels=512,
            dtype='int16'
        )

        segment_indices = spikes[bad_inds] // segment_len  # shape: (n_spikes,)
        snips_bad[0, :, :] -= baselines[ref_channel, segment_indices][None, :]


    # Apply to real data and snips_baselined
    snips_baselined = snips_baselined[:, :, keep_mask]
    good_mean_trace = np.mean(snips_baselined[ref_channel, :, :], axis=1)
    snips_full = snips_full[:, :, keep_mask]
    valid_spike_times = valid_spike_times[keep_mask]
    spikes = spikes[keep_mask]

    spikes_for_plot_post = spikes

    final_spike_inds = np.where(keep_mask)[0]


    elapsed_bad = time.time() - start_time 
    print(f"Distortion handling: {elapsed_bad - elapsed_snippet_extraction:.1f} s, stashed {len(bad_inds)} spikes")

    # Step 9b: Recluster - choose k. snips_full is all channels, baselined - relevant cahnnels will be subselected in the function.

    if do_second_clustering:

        if len(spikes)<100:
            pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
            labels_post = np.array([0])                    # just one fake cluster label
            sim_matrix_post = np.zeros((1, 1))             # fake 1×1 similarity matrix
            ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the “post” cluster
            selected_index_post = 0                        # only one cluster, so index is 0
            cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
            spikes_for_plot_post = np.array([0])           # placeholder spike time
            spike_counts_post = [len(snips)]               # use actual number of spikes
            matches = []                                # no matches
            # `snips_baselined` is [C x T x N]
            # We only subtract on the referent channel to avoid distortion
            template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
            residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

            # Assume residuals_fallback is (T, N) from previous step (template-subtracted waveforms)
            # Transpose to match expected shape: (n_spikes, snip_len)
            # force key and lookup to match normal case: np.int64
            ref_channel = np.int64(ref_channel)
            selected_channels = np.array([ref_channel], dtype=np.int64)
            residuals_per_channel = {
                ref_channel: residuals_fallback.T.astype(np.int16)
            }

        else:
            clusters_post, pcs_post, labels_post, sim_matrix_post, n_bad_channels_post, cluster_eis_post, cluster_to_merged_group_post  = axolotl_utils_ram.cluster_spike_waveforms(snips=snips_baselined, ei=ei, k_start=2,return_debug=True)

            # Step 9c: choose the best cluster - choose similarity threshold. EI is all channels, baselined
            ei, final_spike_inds, selected_channels, selected_cluster_index_post = axolotl_utils_ram.select_cluster_by_ei_similarity_ram(clusters=clusters_post,reference_ei=ei,similarity_threshold=0.95)
            
            contributing_original_ids_post = [
                orig_id for orig_id, merged_idx in cluster_to_merged_group_post.items()
                if merged_idx == selected_cluster_index_post
            ]
            contributing_original_ids_post = np.array(contributing_original_ids_post)


            spikes = spikes[final_spike_inds]  # convert to absolute spike times
            snips_baselined = snips_baselined[:,:,final_spike_inds] # cut only the ones that survived

            p2p_threshold = 30
            ei_p2p = ei.max(axis=1) - ei.min(axis=1)
            selected_channels = np.where(ei_p2p > p2p_threshold)[0]
            selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

            snips_full = snips_full[np.ix_(selected_channels, np.arange(snips_full.shape[1]), final_spike_inds)]
            #print("reclustered pursuit\n")

            # check for matching KS units
            results = []
            lag = 20
            ks_sim_threshold = 0.75

            # Run comparison
            sim = compare_eis(ks_ei_stack, ei, lag).squeeze() # shape: (num_KS_units,)
            matches = [
                {
                    "unit_id": ks_unit_ids[i],
                    "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                    "similarity": float(sim[i]),
                    "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
                }
                for i in np.where(sim > ks_sim_threshold)[0]
            ]
        elapsed_post_clustering= time.time() - start_time 
        print(f"Post-clustering: {elapsed_post_clustering-elapsed_bad:.1f} s")
    else:
        final_ei = np.mean(snips_baselined, axis=2)

        p2p_threshold = 30
        ei_p2p = final_ei.max(axis=1) - final_ei.min(axis=1)
        selected_channels = np.where(ei_p2p > p2p_threshold)[0]
        selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]
        snips_full = snips_full[selected_channels, :, :]
        
        # check for matching KS units
        results = []
        lag = 20
        ks_sim_threshold = 0.75

        # Run comparison
        sim = compare_eis(ks_ei_stack, final_ei, lag).squeeze() # shape: (num_KS_units,)
        matches = [
            {
                "unit_id": ks_unit_ids[i],
                "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                "similarity": float(sim[i]),
                "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
            }
            for i in np.where(sim > ks_sim_threshold)[0]
        ]

        elapsed_post_clustering= time.time() - start_time 
        print(f"Matches: {elapsed_post_clustering-elapsed_bad:.1f} s")


    
    # DIAGNOSTIC PLOTS

    import axolotl_utils_ram
    import importlib
    importlib.reload(axolotl_utils_ram)

    if do_second_clustering:

        axolotl_utils_ram.plot_unit_diagnostics(
            output_path=debug_folder,
            unit_id=unit_id,

            # --- From first call to cluster_spike_waveforms
            pcs_pre=pcs_pre,
            labels_pre=labels_pre,
            sim_matrix_pre=sim_matrix_pre,
            cluster_eis_pre = cluster_eis_pre,
            spikes_for_plot_pre = spikes_for_plot_pre,
            n_bad_channels_pre = n_bad_channels_pre,
            contributing_original_ids_pre = contributing_original_ids_pre,

            # --- From ei_pursuit
            mean_score=mean_score,
            valid_score=valid_score,
            mean_scores_at_spikes=mean_scores_at_spikes,
            valid_scores_at_spikes=valid_scores_at_spikes,
            mean_thresh=mean_thresh,
            valid_thresh=valid_thresh,

            # --- Lag estimation and bad spike filtering
            lags=lags,
            bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
            good_mean_trace=good_mean_trace,
            threshold_ampl=-threshold_ampl,
            ref_channel=ref_channel,
            snips_bad=snips_bad,

            # --- From second clustering
            pcs_post=pcs_post,
            labels_post=labels_post,
            sim_matrix_post=sim_matrix_post,
            cluster_eis_post = cluster_eis_post,
            spikes_for_plot_post = spikes_for_plot_post,
            n_bad_channels_post = n_bad_channels_post,
            contributing_original_ids_post = contributing_original_ids_post,

            # --- For axis labels etc.
            window=(-20, 60),

            ei_positions=ei_positions,
            selected_channels_count=len(selected_channels),

            spikes = spikes, 
            orig_threshold = threshold,
            ks_matches = matches
        )
    else:
        axolotl_utils_ram.plot_unit_diagnostics_single_cluster(
            output_path=debug_folder,
            unit_id=unit_id,

            # --- From first call to cluster_spike_waveforms
            pcs_pre=pcs_pre,
            labels_pre=labels_pre,
            sim_matrix_pre=sim_matrix_pre,
            cluster_eis_pre = cluster_eis_pre,
            spikes_for_plot_pre = spikes_for_plot_pre,
            n_bad_channels_pre = n_bad_channels_pre,
            contributing_original_ids_pre = contributing_original_ids_pre,

            # --- Lag estimation and bad spike filtering
            lags=lags,
            bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
            bad_spike_traces_easy=bad_spike_traces_easy,  # shape: (n_bad, T)
            pcs = pcs,
            outlier_inds_easy = final_outlier_inds,
            outlier_inds = final_outlier_inds_max,
            good_mean_trace=good_mean_trace,
            ref_channel=ref_channel,

            final_ei = final_ei,

            # --- For axis labels etc.

            ei_positions=ei_positions,

            spikes = spikes, 
            orig_threshold = threshold,
            ks_matches = matches
        )




    elapsed_diagnostic= time.time() - start_time 
    print(f"Diagnostics: {elapsed_diagnostic-elapsed_post_clustering:.1f} s")

    # Step 10: Save unit metadata
    try:
        with h5py.File(h5_out_path, 'a') as h5:
            group = h5.require_group(f'unit_{unit_id}')

            for name, data in [
                ('spike_times', spikes.astype(np.int32)),
                ('ei', ei.astype(np.float32)), # EI is already baselined
                ('selected_channels', selected_channels.astype(np.int32))
            ]:
                if name in group:
                    del group[name]
                group.create_dataset(name, data=data)

            group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))
            # group.create_dataset('spike_times', data=spikes.astype(np.int32))
            # group.create_dataset('ei', data=ei.astype(np.float32))
            # group.create_dataset('selected_channels', data=selected_channels.astype(np.int32))
            # group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))

        #print(f"Exported unit_{unit_id} with {len(spikes)} spikes.")

    except KeyboardInterrupt:
        print("\nKeyboard interrupt detected — exiting safely before write completes.")

    except Exception as e:
        print(f"\nUnexpected error while saving unit_{unit_id}: {e}")


    elapsed_saving = time.time() - start_time 
    print(f"Saving: {elapsed_saving-elapsed_diagnostic:.1f} s")



    if len(spikes)>=100:
        
        snips_full = snips_full.transpose(2, 0, 1) # [C × T × N] → [N × C × T]

            # --- Setup ---
        residuals_per_channel = {}
        cluster_ids_per_channel = {}
        scale_factors_per_channel = {}

        for ch_idx, ch in enumerate(selected_channels):
            # Slice data for this channel
            ch_snips = snips_full[:, ch_idx, :]  # shape: (n_spikes, snip_len)
            ch_baselines = baselines[ch, :]    # shape: (n_segments,)

            # Subtract PCA cluster means
            residuals, cluster_ids, scale_factors = axolotl_utils_ram.subtract_pca_cluster_means_ram(
                snippets=ch_snips,
                baselines=ch_baselines,
                spike_times=spikes,
                segment_len=100_000,  # must match what was used to generate baselines
                n_clusters=5,
                offset_window=(-10,40)
            )

            # Store results
            residuals_per_channel[ch] = residuals
            cluster_ids_per_channel[ch] = cluster_ids
            scale_factors_per_channel[ch] = scale_factors
    else:
        
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Assume residuals_fallback is (T, N) from previous step (template-subtracted waveforms)
        # Transpose to match expected shape: (n_spikes, snip_len)
        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }


    elapsed_residual = time.time() - start_time 
    print(f"Residual: {elapsed_residual-elapsed_saving:.1f} s")

    # print(f"Finished preprocessing, starting edits. Elapsed: {elapsed:.1f} seconds.")
    # Step 12: edit raw data
    write_locs = spikes + window[0]
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=raw_data.shape[0],
        dtype = np.int16,
        n_channels = n_channels,
        is_ram=True,
        is_disk=False
    )

    for t in reject_times:
        # ---- (b)  baseline value for this spike ----------------------
        seg_id        = int(t // segment_len)
        baseline_val  = baselines[ref_channel, seg_id]

        # ---- (c)  overwrite ref-channel samples with baseline --------
        t0 = int(t + offset_start)
        t1 = int(t + offset_end + 1)           # slice end is non-inclusive
        # bounds check
        if t0 < 0:            t0 = 0
        if t1 > raw_data.shape[0]:
            t1 = raw_data.shape[0]
        raw_data[t0:t1, ref_channel] = baseline_val
    
    end_time = time.time()
    elapsed = end_time - start_time

    print(f"Subtraction: {elapsed -elapsed_residual:.1f} s")
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")
    # print(f"timing. threshold: {elapsed_thresholding:.1f}, cluster: {elapsed_clustering-elapsed_thresholding:.1f},lags: {elapsed_lags-elapsed_thresholding:.1f},post: {elapsed_post_clustering-elapsed_lags:.1f},diag: {elapsed_diagnostic-elapsed_post_clustering:.1f},.\n")


    # Step 13: Repeat until done
    unit_id += 1
    # if unit_id >= max_units:
    #     print("Reached unit limit.")
    #     break




In [None]:
import matplotlib.pyplot as plt

from scipy.stats import chi2

# Re-threshold the squared whitened-PC norm d2 (chi-square, df=2) for visual
# inspection. Both cut-offs are currently the same 90 % quantile.
thr_vis  = chi2.ppf(0.9,  df=2)     # illustration
thr_cut  = chi2.ppf(0.9, df=2)     # discard

final_outlier_inds      = np.where(d2 > thr_vis)[0]
final_outlier_inds_max  = np.where(d2 > thr_cut)[0]




print(f"LOF rejects: {len(lof_inds)}")

# The red overlay shows the LOF rejects, replacing the d2 > thr_vis set above.
final_outlier_inds      = lof_inds


# pcs is the [N, 2] z-scored PC1/PC2 array from the last processed unit.
# Default colour = all spikes, green = d2 > thr_cut, red = LOF rejects.
pcs = pcs_z
plt.figure(figsize=(4, 4))
plt.scatter(pcs[:, 0], pcs[:, 1], s=5, alpha=0.7)
plt.scatter(pcs[final_outlier_inds_max, 0], pcs[final_outlier_inds_max, 1], s=15, color="green",alpha=1)
plt.scatter(pcs[final_outlier_inds, 0], pcs[final_outlier_inds, 1], s=3, color="red",alpha=1)
plt.xlabel("PC1 (z-score)")
plt.ylabel("PC2 (z-score)")
# `p` must already exist in the namespace (presumably the Hartigan test
# p-value from the pipeline). "\n" puts it on a second title line; the former
# "\H" was an invalid escape that rendered a literal backslash.
plt.title(f"PC1 vs PC2 scatter\nHartigan p = {p:.3f}")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Exploratory cell: flag spikes by |z| thresholds on PC1/PC2, measure how far
# the extreme ones sit from their nearest "good" neighbours, then hand a
# 3-level labelling to merge_similar_clusters.
# NOTE(review): pc1_z / pc2_z are not defined in this cell; they must come from
# an earlier cell or pipeline step. Confirm they index the same spikes as
# snips_baselined's last axis.
plt.figure(figsize=(6, 5))
plt.scatter(pc1_z, pc2_z, s=2, alpha=0.5)
plt.title("PC1 z-scores for Ref Channel Waveforms")
plt.xlabel("PC1")
plt.ylabel("PC2")
plt.grid(True)
plt.tight_layout()
plt.show()

thresh_pc = 6
final_outlier_mask_max = (np.abs(pc1_z) > thresh_pc) | (np.abs(pc2_z) > thresh_pc) # |z| > thresh_pc on either PC; tune this
final_outlier_inds_max = np.where(final_outlier_mask_max)[0]

print(len(final_outlier_inds_max))
# print(final_outlier_inds_max)

if final_outlier_inds_max.size > 0:
    from scipy.spatial.distance import cdist

    # Combine PC1 and PC2 z-scores
    pcs_all = np.column_stack((pc1_z, pc2_z))

    # Split into bad and good
    bad_pcs = pcs_all[final_outlier_inds_max]
    good_mask = np.ones(len(pcs_all), dtype=bool)
    good_mask[final_outlier_inds_max] = False
    good_pcs = pcs_all[good_mask]

    # Here N = neighbourhood size (at most 20). N is rebound to the total spike
    # count further below.
    N = min(20, len(good_pcs))

    bad_good_means = []
    good_good_means = []
    ratios = []

    for bad in bad_pcs:
        # Distance to all good spikes
        dists = np.linalg.norm(good_pcs - bad, axis=1)
        nearest_inds = np.argsort(dists)[:N]
        nearest_good = good_pcs[nearest_inds]

        # Mean distance bad -> nearest N good
        mean_bad_good = np.mean(dists[nearest_inds])
        bad_good_means.append(mean_bad_good)

        # Mean pairwise distance among those N good spikes (upper triangle,
        # diagonal excluded). With N < 2 the triangle is empty, so this is NaN.
        gg_dists = cdist(nearest_good, nearest_good)
        mean_good_good = np.mean(gg_dists[np.triu_indices_from(gg_dists, k=1)])
        good_good_means.append(mean_good_good)

        # Ratio > 1 means the bad spike is farther from its neighbourhood than
        # the neighbours are from each other; 1e-8 avoids division by zero.
        ratios.append(mean_bad_good / (mean_good_good + 1e-8))

    bad_good_means = np.array(bad_good_means)
    good_good_means = np.array(good_good_means)
    ratios = np.array(ratios)
    # Summary
    print(f"Mean bad-good distance: {np.mean(bad_good_means):.4f}")
    print(f"Mean good-good distance: {np.mean(good_good_means):.4f}")
    print(f"Mean ratio (bad-good / good-good): {np.mean(ratios):.4f}")

    # Optional: print distribution
    # for i, (bg, gg, r) in enumerate(zip(bad_good_means, good_good_means, ratios)):
    #     print(f"Bad spike {i}: bad-good = {bg:.4f}, good-good = {gg:.4f}, ratio = {r:.2f}")
    # final_outlier_inds_max = final_outlier_inds_max[ratios>3.5]

# Labelling for merge_similar_clusters:
#   0 = |z| <= 2 on both PCs, 1 = |z| > 2 on either PC, 2 = |z| > 6 on either PC
#   (label 2 overwrites label 1).
N = snips_baselined.shape[2] # total spike count (last axis of snips_baselined)
labels = np.zeros(N, dtype=int)
thresh_pc = 2
final_outlier_mask = (np.abs(pc1_z) > thresh_pc) | (np.abs(pc2_z) > thresh_pc) # |z| > thresh_pc on either PC; tune this
final_outlier_inds = np.where(final_outlier_mask)[0]

labels[final_outlier_inds] = 1

thresh_pc = 6
final_outlier_mask_max = (np.abs(pc1_z) > thresh_pc) | (np.abs(pc2_z) > thresh_pc) # |z| > thresh_pc on either PC; tune this
final_outlier_inds_max = np.where(final_outlier_mask_max)[0]

labels[final_outlier_inds_max] = 2

# NOTE(review): this rebinds the global `sim` that the pipeline also uses for
# KS-similarity results; re-run the pipeline cell before relying on it.
merged_clusters, sim, n_bad_channels = axolotl_utils_ram.merge_similar_clusters(snips_baselined, labels, max_lag=3, p2p_thresh=30.0, amp_thresh=-20, cos_thresh=0.9)

print(sim)
print(n_bad_channels)
print(N)
print(len(merged_clusters))
for i, arr in enumerate(merged_clusters):
    print(f"Array {i} length: {len(arr)}")



In [None]:
from verify_cluster import verify_cluster

# Settings forwarded to verify_cluster: snippet window (samples), minimum
# cluster size, EI-similarity cut-off, and (presumably) the cluster counts for
# the initial / refinement passes.
params = dict(
    window=(-20, 60),
    min_spikes=100,
    ei_sim_threshold=0.75,
    k_start=10,
    k_refine=4,
)

# Spike times of the unit most recently produced by the pipeline loop.
spike_times = spikes

# Recursive re-clustering; each returned cluster exposes an 'inds' entry.
clusters = verify_cluster(spike_times=spike_times, dat_path=dat_path, params=params)

print(f"Returned {len(clusters)} clean subclusters")
for cluster_no, cluster in enumerate(clusters):
    print(f"  Cluster {cluster_no}: {len(cluster['inds'])} spikes")

print("Success")

In [None]:
import analyze_clusters
import importlib
importlib.reload(analyze_clusters)

# plot EI (Vision style), ISI, firing rate, time course (from one pixel) and STA (single frame with strongest pixel)
# The recording and KiloSort-export paths reuse `dat_path` / `h5_in_path` from
# the setup cell, so this call stays in sync if those paths are changed there.
analyze_clusters.analyze_clusters(clusters,
                 spike_times=spike_times,
                 sampling_rate=20000,
                 dat_path=dat_path,
                 h5_path=h5_in_path,
                 triggers_mat_path='/Volumes/Lab/Users/alexth/axolotl/trigger_in_samples_201703151.mat',
                 cluster_ids=None,
                 lut=None,
                 sta_depth=30,
                 sta_offset=0,
                 sta_chunk_size=1000,
                 sta_refresh=2,
                 ei_scale=3,
                 ei_cutoff=0.08)

### END of main pipeline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from plot_ei_waveforms import plot_ei_waveforms

# --- PC1 vs PC2 scatter of the first-pass clustering, one colour per label ---
plt.figure(figsize=(6, 5))
unique_labels = np.unique(labels_pre)
print(unique_labels)
colors = plt.cm.tab10.colors  # or any colormap you like

for i, label in enumerate(unique_labels):
    mask = labels_pre == label
    color = colors[i % len(colors)]  # cycle through the palette
    plt.scatter(pcs_pre[mask, 0], pcs_pre[mask, 1], s=10, color=color, alpha=0.7, label=f"Cluster {label}")

plt.xlabel("PC1")
plt.ylabel("PC2")
plt.title("Cluster PCA Scatter")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# --- Plot EIs for each cluster ---
# The loop variable is deliberately NOT named `ei`: that would overwrite the
# pipeline's `ei` (the EI written to the HDF5 output) in the notebook namespace.
for i, cluster_ei in enumerate(cluster_eis_pre):
    plt.figure(figsize=(15,5))
    plot_ei_waveforms(
        ei=cluster_ei,         # EI of cluster i
        positions=ei_positions,
        ref_channel=ref_channel,
        scale=70.0,
        box_height=1.5,
        box_width=50,
        linewidth=0.5,
        alpha=0.9,
        colors='black'
    )


In [None]:
sampling_rate = 20000

# --- Load KiloSort spike times from the MATLAB-exported HDF5 ---
# Each /spikes/unit_<k> dataset holds spike times in seconds. They are
# flattened to 1-D, converted to integer sample indices and stored in
# all_spikes[k]. Units are visited in numeric order of <k>.
all_spikes = {}
with h5py.File(h5_in_path, 'r') as f:
    spikes_group = f['/spikes']
    unit_ids = sorted(spikes_group.keys(), key=lambda name: int(name.split('_')[1]))
    for uid in unit_ids:
        unit_index = int(uid.split('_')[1])
        raw = spikes_group[uid][:]
        # A length-1 vector and any other shape both reduce to the same flat array.
        spikes_sec = np.asarray(raw).flatten()
        spikes_samples = np.round(spikes_sec * sampling_rate).astype(np.int32)
        all_spikes[unit_index] = spikes_samples


In [None]:
# Overlay raw (int16, not baselined) ref-channel snippets around a set of spike times.
ks_spikes = all_spikes[56]
print(len(ks_spikes))

import numpy as np
import matplotlib.pyplot as plt


ref_channel = 51

# Define window around each spike: 20 samples before to 60 after (81 samples)
pre, post = 20, 60
snip_len = pre + post + 1

# Allocate array to hold all snippets
snippets = []

# NOTE(review): `ss` is not defined in this cell; it must come from elsewhere
# in the notebook. The commented alternative is ks_spikes[:1000]. Confirm which
# spike set is intended, because the count printed above is for ks_spikes.
for s in ss:  # alternative: ks_spikes[:1000]
    # keep only snippets that fit entirely inside the recording
    if s - pre >= 0 and s + post < raw_data.shape[0]:
        snippet = raw_data[s - pre : s + post + 1, ref_channel]
        snippets.append(snippet)

snippets = np.array(snippets)  # shape: [n_spikes, snip_len]

# Plot all snippets (index 0 on the x axis is `pre` samples before the spike)
plt.figure(figsize=(10, 9))
for i in range(snippets.shape[0]):
    plt.plot(snippets[i], color='black', alpha=0.1, linewidth=0.5)

plt.title(f"Ref channel {ref_channel} waveforms at {len(snippets)} spikes")
plt.xlabel("Sample index (relative to spike)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Histogram of raw (non-baselined) ref-channel samples taken exactly at the
# KiloSort spike times of unit 56; the dashed red line marks -50.
ks_spikes = all_spikes[56]
ref_channel = 51
ref_trace = raw_data[:, ref_channel]
amps = ref_trace[ks_spikes]
plt.hist(amps, bins=100)
plt.axvline(-50, color='red', linestyle='--')



In [None]:
# Subtract a per-second mean from the ref channel, then overlay the first
# 1000 KiloSort spikes of unit 56 on the corrected trace.
ks_spikes = all_spikes[56]
# ks_spikes = ks_spikes[amps > -50]

print(len(ks_spikes))

ref_channel = 51
fs = 20000  # sampling rate in Hz
segment_len = fs  # 1 second = 20,000 samples

trace = raw_data[:, ref_channel].astype(np.float32)
n_samples = len(trace)
n_segments = (n_samples + segment_len - 1) // segment_len  # ceil division

# Each segment has its own mean removed; the last segment may be shorter
# (slicing past the end is simply truncated).
trace_corrected = np.empty_like(trace)
for seg_start in range(0, n_samples, segment_len):
    seg = trace[seg_start:seg_start + segment_len]
    trace_corrected[seg_start:seg_start + segment_len] = seg - seg.mean()

# Snippets: `pre` samples before to `post` samples after each spike,
# keeping only windows that fit inside the recording.
pre, post = 20, 60
snip_len = pre + post + 1
n_total = trace_corrected.shape[0]
snippets = np.array([
    trace_corrected[s - pre : s + post + 1]
    for s in ks_spikes[:1000]
    if s - pre >= 0 and s + post < n_total
])  # shape: [n_spikes, snip_len]

plt.figure(figsize=(10, 9))
for waveform in snippets:
    plt.plot(waveform, color='black', alpha=0.1, linewidth=0.5)

plt.title(f"Ref channel {ref_channel} waveforms at {len(snippets)} spikes\n(1s segment-wise baseline subtraction)")
plt.xlabel("Sample index (relative to spike)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Histogram of the baseline-corrected ref-channel value at each KS spike time;
# the dashed red line marks -50 (the cut used in the commented-out filter above).
amps = trace_corrected[ks_spikes]
plt.hist(amps, bins=100)
plt.axvline(-50, color='red', linestyle='--')


In [None]:
import numpy as np
import h5py

unit_id_1 = 1011
max_diff = 10 # in samples
max_print = 20  # only the first few individual pairs are printed

# --- Load spike times ---
with h5py.File(h5_out_path, 'r') as h5:
    spikes_1 = h5[f'unit_{unit_id_1}']['spike_times'][:]


spikes_2 = ks_spikes
# --- Sort spike times (required by the two-pointer sweep) ---
spikes_1 = np.sort(spikes_1)
spikes_2 = np.sort(spikes_2)

# --- Match spikes within max_diff ---
# Two-pointer sweep over sorted arrays. A spike may match several partners,
# so `matches` holds (i, j) index PAIRS, not unique spikes.
matches = []
j_start = 0

for i, t1 in enumerate(spikes_1):
    # spikes_2 entries that are too early for t1 are too early for every later t1.
    while j_start < len(spikes_2) and spikes_2[j_start] < t1 - max_diff:
        j_start += 1

    # Every spikes_2[j] in [t1 - max_diff, t1 + max_diff] is a match.
    j = j_start
    while j < len(spikes_2) and spikes_2[j] <= t1 + max_diff:
        matches.append((i, j))
        j += 1

# --- Output ---
matched_i = {i for i, _ in matches}
matched_j = {j for _, j in matches}
print(f"Found {len(matches)} matched spike pairs between unit {unit_id_1} and ks "
      f"({len(matched_i)} unit spikes, {len(matched_j)} ks spikes)")
for i, j in matches[:max_print]:
    print(f"Spike1: {spikes_1[i]}, Spike2: {spikes_2[j]}")
if len(matches) > max_print:
    print(f"... ({len(matches) - max_print} more pairs not shown)")

# --- Unmatched spike times from spikes_2 (keeps spikes_2's dtype even when empty) ---
matched_mask_2 = np.zeros(len(spikes_2), dtype=bool)
matched_mask_2[list(matched_j)] = True
unmatched_spikes_2 = spikes_2[~matched_mask_2]

print(f"Found {len(unmatched_spikes_2)} unmatched spikes in unit ks")


### Read from h5 results

In [None]:
import h5py

# --- First unit: loaded into the plain names spike_times / ei / ... ---
unit_id = 40 # or whatever unit you're checking

with h5py.File(h5_out_path, 'r') as h5:
    group = h5[f'unit_{unit_id}']
    
    spike_times = group['spike_times'][:]
    ei = group['ei'][:]
    selected_channels = group['selected_channels'][:]
    
    peak_channel = group.attrs['peak_channel']

# Check shapes or values
print("Spike times:", spike_times.shape)
#print("EI shape:", ei.shape)
# print("Selected channels:", selected_channels)
# print(179 in selected_channels)
#print("Peak channel:", peak_channel)
print(spike_times)


# --- Second unit for comparison: loaded into *1-suffixed names so the first
#     unit's EI is not overwritten. The EI overlay plot reads both `ei` and `ei1`.
unit_id = 11 # or whatever unit you're checking

with h5py.File(h5_out_path, 'r') as h5:
    group = h5[f'unit_{unit_id}']
    
    spike_times1 = group['spike_times'][:]
    ei1 = group['ei'][:]
    selected_channels1 = group['selected_channels'][:]
    
    peak_channel1 = group.attrs['peak_channel']

# Check shapes or values
print("Spike times:", spike_times1.shape)

### Plot one channel: overlay of two unit EIs

In [None]:
# Overlay channel 125 of two EIs: `ei` in black, `ei1` in red.
plt.figure(figsize=(7, 4))
plt.gca().plot(ei[125, :], color='black', linewidth=1)
plt.gca().plot(ei1[125, :], color='red', linewidth=1)
plt.gca().set(xlabel="Time (samples)", ylabel="Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt


ref_channel = 51

# Define window around each spike
pre, post = 20, 60
snip_len = pre + post + 1


def _ref_channel_snippets(times, max_spikes=50):
    """Return raw ``ref_channel`` snippets around up to ``max_spikes`` spike times.

    Each snippet spans samples ``t - pre`` .. ``t + post`` (``snip_len`` samples);
    spikes whose window would leave the recording are skipped. Times are cast to
    Python ``int`` so unsigned spike-time arrays cannot wrap around in ``t - pre``.

    Returns an array of shape ``[n_kept, snip_len]`` (``[0, snip_len]`` if none fit).
    """
    kept = []
    for t in times[:max_spikes]:
        t = int(t)
        if t - pre >= 0 and t + post < raw_data.shape[0]:
            kept.append(raw_data[t - pre : t + post + 1, ref_channel])
    return np.array(kept).reshape(-1, snip_len)


# Black: first 50 spikes of spikes_1. Red: first 50 spikes_2 spikes that had no match
# in spikes_1. `snippets` keeps its previous meaning (the red set).
matched_snippets = _ref_channel_snippets(spikes_1)
snippets = _ref_channel_snippets(unmatched_spikes_2)

# x axis in samples relative to the spike time (snippet index 0 is -pre).
rel_samples = np.arange(-pre, post + 1)

plt.figure(figsize=(10, 5))
for trace in matched_snippets:
    plt.plot(rel_samples, trace, color='black', alpha=0.1, linewidth=0.5)
for trace in snippets:
    plt.plot(rel_samples, trace, color='red', alpha=0.1, linewidth=1)

# Report both counts; previously the title showed only the red (unmatched) count.
plt.title(
    f"Ref channel {ref_channel} waveforms: {len(matched_snippets)} spikes_1 (black), "
    f"{len(snippets)} unmatched spikes_2 (red)"
)
plt.xlabel("Sample (relative to spike)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Build and plot an EI from the KS spikes whose pursuit score is below 20000.
# NOTE(review): `mean_scores_at_spikes` is produced by the ei_pursuit_ram cell below
# (called with spikes=ks_spikes), so that cell must have run first for the mask to align.
snips, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
    raw_data=raw_data,
    spike_times=ks_spikes[mean_scores_at_spikes<20000],
    window=window,
    selected_channels=np.arange(512)
)

# Must match the segment_len passed to compute_baselines_int16 when `baselines` was built.
segment_len = 100_000
snips_baselined = snips.copy()  # shape (n_channels, 81, N)
n_channels, snip_len, n_spikes = snips_baselined.shape

# Determine segment index for each spike
segment_indices = valid_spike_times // segment_len  # shape: (n_spikes,)

# Loop through channels and subtract baseline per spike
# NOTE(review): in-place `-=` with float baselines fails if `snips` is an integer
# array — presumably extract_snippets_ram returns floats; confirm.
for ch in range(n_channels):
    snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]


# EI = mean baseline-subtracted waveform over spikes, per channel.
ei = np.mean(snips_baselined, axis=2)
#ei -= ei[:, :5].mean(axis=1, keepdims=True)

from plot_ei_waveforms import plot_ei_waveforms
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))
plot_ei_waveforms(
    ei=ei,
    positions=ei_positions,
    scale=70.0,
    box_height=1.0,
    box_width=50,
    linewidth=0.5,
    alpha=0.9,
    colors='black'
)
plt.title("EI")
plt.tight_layout()
plt.show()

In [None]:
# Reload the helper module (picks up edits without restarting the kernel), then run
# EI pursuit seeded with the KS spikes and the current `ei` as template.
# With return_debug=True the call returns the 7-tuple unpacked below; the
# `*_at_spikes` arrays are used by the histogram / threshold-fit cells that follow.
import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)

(
spikes,
mean_score,
valid_score,
mean_scores_at_spikes,
valid_scores_at_spikes,
mean_thresh,
valid_thresh
) = axolotl_utils_ram.ei_pursuit_ram(
    raw_data=raw_data,
    spikes=ks_spikes,                     # absolute sample times
    ei_template=ei,                    # EI from selected cluster
    save_prefix='/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit0',  # set uniquely per unit
    alignment_offset = -window[0],
    fit_percentile = 40,                # how many (percentile) spikes to take to fit Gaussian for threshold determination (left-hand side of already found spikes)
    sigma_thresh = 5.0,                  # how many Gaussian sigmas to take for threshold
    return_debug=True, 

)

In [None]:
# Compare the number of spikes returned by pursuit with the number of KS spikes
# that pass the `< 20000` score cut used when building the EI.
print(len(spikes))
ss=ks_spikes[mean_scores_at_spikes<20000]
print(len(ss))

In [None]:
import matplotlib.pyplot as plt

# Distribution of the EI-match scores that ei_pursuit_ram evaluated at the KS spike times.
plt.figure(figsize=(10, 5))
plt.hist(mean_scores_at_spikes, bins=200, alpha=0.5, label='KS spike scores', color='red')
plt.xlabel("Mean EI Match Score")
plt.ylabel("Count")
# The title previously read "Global vs. KS-aligned", but only KS-aligned scores are plotted.
plt.title("Mean EI Scores at KS Spike Times")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
# Count NaN entries in `scores`.
# NOTE(review): within this chunk `scores` is only assigned in the next cell
# (`scores = mean_scores_at_spikes`); run top-to-bottom this raises NameError or reads a
# stale array — confirm which array is intended.
print(f"NaNs in scores: {np.isnan(scores).sum()}")


In [None]:
# Derive a score threshold from the KS-spike score distribution. Fit a Gaussian to the
# lowest `fit_percentile` % of (non-NaN) scores and set threshold = mu - sigma_thresh * sigma.
# The parameter values match those passed to ei_pursuit_ram above; this cell apparently
# reproduces that threshold offline for inspection. `threshold` is later passed to
# plot_unit_diagnostics as `orig_threshold`.
from scipy.stats import norm
import numpy as np

fit_percentile = 40
sigma_thresh = 5.0

scores = mean_scores_at_spikes  # KS spike scores (alias; NaNs are still present)

# NaN scores are excluded from the fit
clean_scores = mean_scores_at_spikes[~np.isnan(mean_scores_at_spikes)]

# 1. Determine percentile cutoff
cutoff = np.percentile(clean_scores, fit_percentile, method='nearest')
print(cutoff)

# 2. Select left tail
left_tail = clean_scores[clean_scores <= cutoff]

# 3. Fit normal distribution to tail
mu, sigma = norm.fit(left_tail)

# 4. Compute final threshold
threshold = mu - sigma_thresh * sigma

print(f"Fitted mu = {mu:.3f}, sigma = {sigma:.3f}")
print(f"Computed threshold = {threshold:.3f}")
import matplotlib.pyplot as plt

# Visual check: tail histogram (density) with the fitted Gaussian and the threshold.
plt.figure(figsize=(8, 4))
plt.hist(left_tail, bins=100, density=True, alpha=0.5, label="Left tail")

# Overlay fitted Gaussian
x = np.linspace(left_tail.min(), left_tail.max(), 200)
pdf = norm.pdf(x, mu, sigma)
plt.plot(x, pdf, 'r-', label=f"Fit: μ={mu:.2f}, σ={sigma:.2f}")

plt.axvline(threshold, color='red', linestyle='--', label=f"Threshold = {threshold:.2f}")
plt.xlabel("Score")
plt.ylabel("Density")
plt.title("Fit to Left Tail of Scores (KS Spikes)")
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# Compare the spike trains of two saved units (from h5_out_path): pair spikes closer than
# ±max_diff samples, then collect unit_id_2's spikes that were never paired
# (`unmatched_spikes_2`, used by the snippet-overlay cell).
import numpy as np
import h5py

unit_id_1 = 1000
unit_id_2 = 11
max_diff = 10 # in samples

# --- Load spike times ---
with h5py.File(h5_out_path, 'r') as h5:
    spikes_1 = h5[f'unit_{unit_id_1}']['spike_times'][:]
    spikes_2 = h5[f'unit_{unit_id_2}']['spike_times'][:]

# --- Sort spike times (for efficiency) ---
# (the forward-only sweep below also REQUIRES both arrays to be sorted to be correct)
spikes_1 = np.sort(spikes_1)
spikes_2 = np.sort(spikes_2)

# --- Match spikes within max_diff ---
# Two-pointer sweep: j_start only moves forward because spikes_1 is sorted.
# Every (i, j) pair within tolerance is recorded (many-to-many), so len(matches)
# counts pairs rather than unique spikes.
matches = []
j_start = 0

for i, t1 in enumerate(spikes_1):
    # drop spikes_2 entries that are too early for this t1 (and hence for all later t1)
    while j_start < len(spikes_2) and spikes_2[j_start] < t1 - max_diff:
        j_start += 1

    j = j_start
    while j < len(spikes_2) and spikes_2[j] <= t1 + max_diff:
        if abs(t1 - spikes_2[j]) <= max_diff:
            matches.append((i, j))
        j += 1

# --- Output ---
# Prints one line per matched pair; this can be very long for large units.
print(f"Found {len(matches)} matched spikes between unit {unit_id_1} and {unit_id_2}")
for i, j in matches:
    print(f"Spike1: {spikes_1[i]}, Spike2: {spikes_2[j]}")

# --- Build set of matched indices in spikes_2 ---
matched_j = set(j for _, j in matches)

# --- Extract unmatched spike times from spikes_2 ---
unmatched_spikes_2 = [spikes_2[j] for j in range(len(spikes_2)) if j not in matched_j]

print(f"Found {len(unmatched_spikes_2)} unmatched spikes in unit {unit_id_2}")
unmatched_spikes_2 = np.array(unmatched_spikes_2)


### Plot EI

In [None]:
from plot_ei_waveforms import plot_ei_waveforms
import matplotlib.pyplot as plt

# Draw the current `ei`, each channel's waveform placed at its electrode position.
plt.figure(figsize=(15, 5))
plot_ei_waveforms(
    ei=ei,
    positions=ei_positions,
    colors='black',
    scale=70.0,
    box_width=50,
    box_height=1.0,
    linewidth=0.5,
    alpha=0.9,
)
plt.gca().set_title("EI")
plt.tight_layout()
plt.show()

In [None]:
from plot_ei_waveforms import plot_ei_waveforms
import matplotlib.pyplot as plt

# Plot the EI of every pre-recluster cluster. Each title shows the cluster's spike count
# and the peak-to-peak amplitude on its strongest channel (among the cluster's channels).
# The per-cluster EI is held in `cluster_ei`, not `ei`. The old loop variable `ei`
# overwrote the global template `ei` that the clustering/pursuit cells pass on.
# `i`, `ref_ch`, `ei_p2p`, `n_spikes` keep their previous names/values.
for i, cluster in enumerate(clusters_pre):
    cluster_ei = cluster['ei']
    ref_ch = cluster['channels'][np.argmax(np.ptp(cluster_ei[cluster['channels'], :], axis=1))]
    ei_p2p = np.ptp(cluster_ei[ref_ch, :])
    n_spikes = len(cluster['inds'])

    plt.figure(figsize=(15, 5))
    plot_ei_waveforms(
        ei=cluster_ei,
        positions=ei_positions,
        scale=70.0,
        box_height=1.0,
        box_width=50,
        linewidth=0.5,
        alpha=0.9,
        colors='black'
    )
    plt.title(f"Cluster {i} EI — Spikes: {n_spikes}, P2P on Ref Ch ({ref_ch}): {ei_p2p:.1f}")
    plt.tight_layout()
    plt.show()


In [None]:
# Reload the helper module, then cluster the baseline-subtracted snippets using `ei` as
# the reference EI. With return_debug=True it also returns the PCs, labels, cluster
# similarity matrix and per-cluster EIs used by the plotting/comparison cells below.
import axolotl_utils_ram
import importlib
importlib.reload(axolotl_utils_ram)

clusters_pre, pcs_pre, labels_pre, sim_matrix_pre, cluster_eis_pre  = axolotl_utils_ram.cluster_spike_waveforms(snips_baselined, ei, k_start=3,return_debug=True)


In [None]:
# Single-channel view: channel 148 of pre-recluster cluster EI #2.
plt.figure(figsize=(7, 4))
plt.gca().plot(cluster_eis_pre[2][148, :], color='black', linewidth=1)
plt.gca().set(xlabel="Time (samples)", ylabel="Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:

# Compare the EIs of pre-recluster clusters 0 and 1 with compare_ei_subtraction
# (lag search up to ±3 samples; p2p_thresh selects which channels are compared).
ei_a = clusters_pre[0]['ei']
ei_b = clusters_pre[1]['ei']


result = axolotl_utils_ram.compare_ei_subtraction(ei_a, ei_b, max_lag=3, p2p_thresh=30.0)

res = np.array(result['per_channel_residuals'])
# Mean of the per-channel cosine similarities over the compared channels.
cos_sim = np.mean(result['per_channel_cosine_sim'])

# Indices (into the compared-channel list) whose residual is below -10.
neg_inds = np.where(res < -10)[0]

print(len(neg_inds))
print(cos_sim)



In [None]:

# Same comparison as above, but between the debug per-cluster EIs 0 and 2
# (`cluster_eis_pre` from cluster_spike_waveforms).
ei_a = cluster_eis_pre[0]
ei_b = cluster_eis_pre[2]


result = axolotl_utils_ram.compare_ei_subtraction(ei_a, ei_b, max_lag=3, p2p_thresh=30.0)

res = np.array(result['per_channel_residuals'])
# Mean of the per-channel cosine similarities over the compared channels.
cos_sim = np.mean(result['per_channel_cosine_sim'])

# Indices (into the compared-channel list) whose residual is below -10.
neg_inds = np.where(res < -10)[0]

print(len(neg_inds))
print(cos_sim)


In [None]:
from plot_ei_waveforms import plot_ei_waveforms
import matplotlib.pyplot as plt

# Plot the EI of pre-recluster cluster 2 (from `cluster_eis_pre`).
cluster_idx = 2
cluster_ei = cluster_eis_pre[cluster_idx]

# Title statistics are computed from this EI itself. The title previously reused `i`,
# `n_spikes`, `ref_ch` and `ei_p2p` left over from the per-cluster plotting loop. It
# therefore described whichever cluster that loop visited last, not cluster 2.
# Here the reference channel is the one with the largest peak-to-peak over all channels.
cluster_p2p = np.ptp(cluster_ei, axis=1)
cluster_ref_ch = int(np.argmax(cluster_p2p))

plt.figure(figsize=(15, 5))
plot_ei_waveforms(
    ei=cluster_ei,
    positions=ei_positions,
    scale=70.0,
    box_height=1.0,
    box_width=50,
    linewidth=0.5,
    alpha=0.9,
    colors='black'
)
plt.title(f"Cluster {cluster_idx} EI — P2P on Ref Ch ({cluster_ref_ch}): {cluster_p2p[cluster_ref_ch]:.1f}")
plt.tight_layout()
plt.show()

In [None]:
# Channel 148 of the two compared EIs: ei_a in black, ei_b in red.
plt.figure(figsize=(7, 4))
plt.gca().plot(ei_a[148, :], color='black', linewidth=1)
plt.gca().plot(ei_b[148, :], color='red', linewidth=1)
plt.gca().set(xlabel="Time (samples)", ylabel="Amplitude")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:

# Verdict on the last compare_ei_subtraction `result`: the EIs are called two units
# when the mean cosine similarity is low OR any compared channel has a residual
# below amp_threshold; otherwise they are the same unit.
amp_threshold = -10
cos_threshold = 0.9

global_cosine_sim = np.mean(result['per_channel_cosine_sim'])
res = result['per_channel_residuals']
neg_inds = np.where(np.array(res) < amp_threshold)[0]

print('two units' if (global_cosine_sim < cos_threshold or len(neg_inds) > 0) else 'same unit')
    

In [None]:

import matplotlib.pyplot as plt

# Per-channel outputs of compare_ei_subtraction, plotted against the compared
# channels (`good_channels`).

# 1) Mean residual per channel
plt.figure(figsize=(10, 4))
plt.bar(result['good_channels'], result['per_channel_residuals'], color='gray')
plt.axhline(0, color='black', linewidth=0.8, linestyle='--')
plt.xlabel("Channel ID")
plt.ylabel("Mean Residual (B - A)")
plt.title("Per-Channel Residuals (Masked Subtraction)")
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()


# 2) Cosine similarity per channel
plt.figure(figsize=(10, 4))
plt.bar(result['good_channels'], result['per_channel_cosine_sim'], color='gray')
plt.axhline(0, color='black', linewidth=0.8, linestyle='--')
plt.xlabel("Channel ID")
# The y-label previously read "Mean Residual (B - A)", copied from the plot above.
plt.ylabel("Cosine Similarity")
plt.title("Per-Channel cosine_sim (Masked Subtraction)")
plt.grid(True, axis='y')
plt.tight_layout()
plt.show()

# plt.figure(figsize=(10, 4))
# plt.bar(result['good_channels'], result['p2p'], color='gray')
# plt.axhline(0, color='black', linewidth=0.8, linestyle='--')
# plt.xlabel("Channel ID")
# plt.ylabel("Mean Residual (B - A)")
# plt.title("Per-Channel amplitude")
# plt.grid(True, axis='y')
# plt.tight_layout()
# plt.show()


In [None]:

    # Pick the pre-recluster cluster with the largest waveform on `ref_channel`; its EI
    # and spikes seed this unit.
    ei, spikes_idx, selected_channels, selected_cluster_index_pre = axolotl_utils_ram.select_cluster_with_largest_waveform(clusters_pre, ref_channel)

    spikes_init = spike_times[spikes_idx]

    if do_pursuit:
        # Template-match the recording with the seed EI to find additional spikes.
        (
        spikes,
        mean_score,
        valid_score,
        mean_scores_at_spikes,
        valid_scores_at_spikes,
        mean_thresh,
        valid_thresh
        ) = axolotl_utils_ram.ei_pursuit_ram(
            raw_data=raw_data,
            spikes=spikes_init,                     # absolute sample times
            ei_template=ei,                    # EI from selected cluster
            save_prefix='/Volumes/Lab/Users/alexth/axolotl/ei_scan_unit0',  # set uniquely per unit
            alignment_offset = -window[0],
            fit_percentile = 40,                # how many (percentile) spikes to take to fit Gaussian for threshold determination (left-hand side of already found spikes)
            sigma_thresh = 5.0,                  # how many Gaussian sigmas to take for threshold
            return_debug=True, 

        )
    else:
        # Pursuit disabled: keep the seed spikes and fill the pursuit diagnostics with
        # placeholders.
        # NOTE(review): mean_scores_at_spikes is set to the spike *times*, not scores —
        # confirm plot_unit_diagnostics treats it only as a placeholder.
        spikes = spikes_init
        mean_score=None
        valid_score=None
        mean_scores_at_spikes=spikes
        valid_scores_at_spikes=None
        mean_thresh=None
        valid_thresh=None

    # Step 9a: Extract full snippets from final spike times

    # Ref-channel-only snippets, reordered to [N x C x T] for lag estimation.
    snips_ref_channel, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.array([ref_channel]),
        window=window,
    )

    snips_ref_channel = snips_ref_channel.transpose(2, 0, 1)


    lags = axolotl_utils_ram.estimate_lags_by_xcorr_ram(
        snippets=snips_ref_channel,                # shape [N x C x T]
        peak_channel_idx=0,                 # 0 because the only channel that gets passed is the referent channel
        window=(-5, 10),                  # optional, relative to peak
        max_lag=6,                        # optional, max xcorr shift
    )

    # Re-align each spike by its cross-correlation lag.
    spikes = spikes+lags

    snips_full, valid_spike_times = axolotl_utils_ram.extract_snippets_ram(
        raw_data=raw_data,
        spike_times=spikes,
        selected_channels=np.arange(n_channels),
        window=window,
    )


    # Must match the segment_len used to compute `baselines`.
    segment_len = 100_000
    snips_baselined = snips_full.copy()  # shape (n_channels, 81, N)
    n_channels, snip_len, n_spikes = snips_baselined.shape

    # Determine segment index for each spike
    # NOTE(review): indexed with `spikes`, not `valid_spike_times`; the two agree only if
    # extract_snippets_ram keeps every spike (the ks-subset cell uses valid_spike_times).
    segment_indices = spikes // segment_len  # shape: (n_spikes,)

    # Loop through channels and subtract baseline per spike
    for ch in range(n_channels):
        snips_baselined[ch, :, :] -= baselines[ch, segment_indices][None, :]


    # Extract baseline-subtracted waveforms for ref_channel
    ref_snips = snips_baselined[ref_channel, :, :]  # shape: (81, N)

    # Mean waveform over all spikes
    ref_mean = ref_snips.mean(axis=1)  # shape: (81,)
    # Value at the alignment sample (index -window[0], i.e. offset 0 from the spike time)
    ref_peak_amp = np.abs(ref_mean[-window[0]])  # scalar

    # Threshold at 0.75× of mean waveform peak
    threshold_ampl = 0.75 * ref_peak_amp

    # Every spike's value at the same alignment sample as ref_mean above. This was
    # hard-coded to 20, which only agrees with -window[0] when window[0] == -20.
    spike_amplitudes = np.abs(ref_snips[-window[0], :])  # shape: (N,)

    # Flag bad spikes: too small
    bad_inds = np.where(spike_amplitudes < threshold_ampl)[0]

    # Create mask to keep only good spikes
    keep_mask = np.ones(spike_amplitudes.shape[0], dtype=bool)
    keep_mask[bad_inds] = False

    # --- Extract bad spike traces for plotting
    # (int channel + slice + index array → advanced dims first, hence (n_bad, T))
    bad_spike_traces = snips_baselined[ref_channel, :, bad_inds]  # shape: (n_bad, T)

    # Get original traces for bad_spike_traces (re-read from the .dat file at `dat_path`)
    snips_bad = axolotl_utils_ram.extract_snippets_single_channel(
        dat_path=dat_path,
        spike_times=spikes[bad_inds],
        ref_channel=ref_channel,
        window=window,
        n_channels=512,
        dtype='int16'
    )

    segment_indices = spikes[bad_inds] // segment_len  # shape: (n_spikes,)
    snips_bad[0, :, :] -= baselines[ref_channel, segment_indices][None, :]


    # Apply to real data and snips_baselined
    snips_baselined = snips_baselined[:, :, keep_mask]
    good_mean_trace = np.mean(snips_baselined[ref_channel, :, :], axis=1)
    snips_full = snips_full[:, :, keep_mask]
    valid_spike_times = valid_spike_times[keep_mask]
    spikes = spikes[keep_mask]

    spikes_for_plot_post = spikes

    # NOTE(review): these indices refer to the pre-mask spike order, while snips_full and
    # spikes above are already masked — confirm how downstream cells use them.
    final_spike_inds = np.where(keep_mask)[0]

        


In [None]:
# Run verify_cluster on the current unit's spikes and report the subclusters it returns.
# NOTE(review): `dat_path` receives the in-memory `snips_baselined` array rather than a
# file path — confirm verify_cluster accepts pre-extracted snippets for this argument.
params = {
    'window': (-20, 60),
    'min_spikes': 100,
    'ei_sim_threshold': 0.75,
    'k_start': 4,
    'k_refine': 2
}

from verify_cluster import verify_cluster

spike_times = spikes
clusters = verify_cluster(
    spike_times=spike_times,
    dat_path=snips_baselined,
    params=params
)

print(f"Returned {len(clusters)} clean subclusters")
for i, cl in enumerate(clusters):
    print(f"  Cluster {i}: {len(cl['inds'])} spikes")

In [None]:

import analyze_clusters
import importlib
importlib.reload(analyze_clusters)


# Analyze the subclusters returned by verify_cluster.
# h5_path reuses `h5_in_path` (the KS spike-time export loaded at the top of the notebook)
# instead of repeating the literal path, so the two cannot drift apart.
# NOTE(review): `dat_path` receives the `snips_baselined` array, as in the verify_cluster
# cell — confirm analyze_clusters accepts snippets here.
analyze_clusters.analyze_clusters(clusters,
                 spike_times=spikes,
                 sampling_rate=20000,
                 dat_path=snips_baselined,
                 h5_path=h5_in_path,
                 triggers_mat_path='/Volumes/Lab/Users/alexth/axolotl/trigger_in_samples_201703151.mat',
                 cluster_ids=None,
                 lut=None,
                 sta_depth=30,
                 sta_offset=0,
                 sta_chunk_size=1000,
                 sta_refresh=2,
                 ei_scale=3,
                 ei_cutoff=0.08)

In [None]:
# Reference-channel value of every spike in snips_baselined at the alignment sample
# (index -window[0]; previously hard-coded to 20, which assumes window[0] == -20),
# plotted in spike order.
tmp = snips_baselined[ref_channel, -window[0], :].copy()
import matplotlib.pyplot as plt
plt.plot(tmp)

# Reset the final selection to "all current spikes".
final_spike_inds = np.arange(len(spikes))


In [None]:
# Distribution of the per-spike reference-channel values stored in `tmp`.
plt.figure(figsize=(6, 4))
plt.gca().hist(tmp, bins=50, color='gray', edgecolor='black')
plt.gca().set(title="Histogram of tmp values", xlabel="Amplitude", ylabel="Count")
plt.grid(True)
plt.show()


In [None]:
# Select spikes whose `tmp` value is above -500 and overlay their channel-39 snippets,
# together with the first 6 snippets of the array.
# NOTE(review): `tmp` is taken from `ref_channel`, but the traces drawn here are channel 39 —
# consistent only if ref_channel == 39. The first 6 snippets are drawn whether or not they
# were selected, although the title counts only the selection.
inds = np.where(tmp > -500)[0]
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 4))
# (int channel + slice + index array → shape (len(inds), T); .T gives one column per spike)
plt.plot(snips_baselined[39, :, inds].T, alpha=1)
plt.plot(snips_baselined[39, :, :6], alpha=1)
plt.title(f"Overlay of {len(inds)} selected snippets on channel 39")
plt.xlabel("Time (samples)")
plt.ylabel("Amplitude")
plt.grid(True)
plt.show()


In [None]:


    # Step 9b: Recluster - choose k. snips_full is all channels, baselined - relevant channels will be subselected in the function.
    # With fewer than 100 spikes, skip reclustering: fill the "post" diagnostics with
    # placeholders and build ref-channel-only residuals (spike minus mean waveform).


    if len(spikes)<100:
        pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
        labels_post = np.array([0])                    # just one fake cluster label
        sim_matrix_post = np.zeros((1, 1))             # fake 1×1 similarity matrix
        ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the “post” cluster
        selected_index_post = 0                        # only one cluster, so index is 0
        cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
        spikes_for_plot_post = np.array([0])           # placeholder spike time
        # Actual number of spikes. This was len(snips), which is the length of the
        # first (channel) axis of an unrelated [C x T x N] snippet array.
        spike_counts_post = [len(spikes)]
        matches = []                                # no matches
        # `snips_baselined` is [C x T x N]
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Assume residuals_fallback is (T, N) from previous step (template-subtracted waveforms)
        # Transpose to match expected shape: (n_spikes, snip_len)
        # force key and lookup to match normal case: np.int64
        # (astype(np.int16) truncates fractional residuals and does not clip out-of-range values)
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }

    else:
        clusters_post, pcs_post, labels_post, sim_matrix_post, cluster_eis_post  = axolotl_utils_ram.cluster_spike_waveforms(snips=snips_baselined, ei=ei, k_start=2,return_debug=True)

        # Step 9c: choose the best cluster - choose similarity threshold. EI is all channels, baselined
        ei, final_spike_inds, selected_channels, selected_cluster_index_post = axolotl_utils_ram.select_cluster_by_ei_similarity_ram(clusters=clusters_post,reference_ei=ei,similarity_threshold=0.95)


        spikes = spikes[final_spike_inds]  # convert to absolute spike times
        snips_baselined = snips_baselined[:,:,final_spike_inds] # cut only the ones that survived

        # Keep channels whose EI peak-to-peak exceeds p2p_threshold, strongest first.
        p2p_threshold = 30
        ei_p2p = ei.max(axis=1) - ei.min(axis=1)
        selected_channels = np.where(ei_p2p > p2p_threshold)[0]
        selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

        #print("reclustered pursuit\n")

        # check for matching KS units
        results = []
        lag = 20
        ks_sim_threshold = 0.75

        # Run comparison
        sim = compare_eis(ks_ei_stack, ei, lag).squeeze() # shape: (num_KS_units,)
        matches = [
            {
                "unit_id": ks_unit_ids[i],
                "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
                "similarity": float(sim[i]),
                "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
            }
            for i in np.where(sim > ks_sim_threshold)[0]
        ]



In [None]:
# Debug version of the "reclustered" branch: pick strong channels, match the current EI
# against the KS EI stack, then fill the "post" diagnostics with placeholders.

# Keep channels whose EI peak-to-peak exceeds p2p_threshold, strongest first.
p2p_threshold = 30
ei_p2p = ei.max(axis=1) - ei.min(axis=1)
selected_channels = np.where(ei_p2p > p2p_threshold)[0]
selected_channels = selected_channels[np.argsort(ei_p2p[selected_channels])[::-1]]

#print("reclustered pursuit\n")

# check for matching KS units
results = []
lag = 20
ks_sim_threshold = 0.75

# Run comparison
sim = compare_eis(ks_ei_stack, ei, lag).squeeze() # shape: (num_KS_units,)
matches = [
    {
        "unit_id": ks_unit_ids[i],
        "vision_id": int(ks_vision_ids[ks_unit_ids[i]].item()),
        "similarity": float(sim[i]),
        "n_spikes": int(ks_n_spikes[ks_unit_ids[i]])
    }
    for i in np.where(sim > ks_sim_threshold)[0]
]

pcs_post = np.zeros((1, 2))                    # shape: (N_spikes, 2 PCs)
labels_post = np.array([0])                    # just one fake cluster label
sim_matrix_post = np.zeros((1, 1))             # fake 1×1 similarity matrix
ei_clusters_post = [np.zeros((512, 81))]       # fake EI for the “post” cluster
selected_index_post = 0                        # only one cluster, so index is 0
cluster_eis_post = [np.zeros((512, 81))]       # same dummy EI
spikes_for_plot_post = np.array([0])           # placeholder spike time
# Actual number of spikes. This was len(snips), which is the length of the first
# (channel) axis of an unrelated [C x T x N] snippet array.
spike_counts_post = [len(spikes)]

In [None]:


    # DIAGNOSTIC PLOTS
    # Write the per-unit diagnostic figure(s) to `debug_folder`.

    axolotl_utils_ram.plot_unit_diagnostics(
        output_path=debug_folder,
        unit_id=unit_id,

        # --- From first call to cluster_spike_waveforms
        pcs_pre=pcs_pre,
        labels_pre=labels_pre,
        sim_matrix_pre=sim_matrix_pre,
        cluster_eis_pre = cluster_eis_pre,
        spikes_for_plot_pre = spikes_for_plot_pre,

        # --- From ei_pursuit
        mean_score=mean_score,
        valid_score=valid_score,
        mean_scores_at_spikes=mean_scores_at_spikes,
        valid_scores_at_spikes=valid_scores_at_spikes,
        mean_thresh=mean_thresh,
        valid_thresh=valid_thresh,

        # --- Lag estimation and bad spike filtering
        lags=lags,
        bad_spike_traces=bad_spike_traces,  # shape: (n_bad, T)
        good_mean_trace=good_mean_trace,
        threshold_ampl=-threshold_ampl,
        ref_channel=ref_channel,
        snips_bad=snips_bad,

        # --- From second clustering
        pcs_post=pcs_post,
        labels_post=labels_post,
        sim_matrix_post=sim_matrix_post,
        cluster_eis_post = cluster_eis_post,
        spikes_for_plot_post = spikes_for_plot_post,

        # --- For axis labels etc.
        # The snippet window actually used for extraction (was a duplicated literal (-20, 60)).
        window=window,

        ei_positions=ei_positions,
        selected_channels_count=len(selected_channels),

        spikes = spikes, 
        orig_threshold = threshold,
        ks_matches = matches
    )


    # Step 10: Save unit metadata
    # Existing datasets are deleted and re-created, so re-running a unit overwrites it.
    try:
        with h5py.File(h5_out_path, 'a') as h5:
            group = h5.require_group(f'unit_{unit_id}')

            for name, data in [
                ('spike_times', spikes.astype(np.int32)),
                ('ei', ei.astype(np.float32)), # EI is already baselined
                ('selected_channels', selected_channels.astype(np.int32))
            ]:
                if name in group:
                    del group[name]
                group.create_dataset(name, data=data)

            # channel with the largest EI peak-to-peak
            group.attrs['peak_channel'] = int(np.argmax(np.ptp(ei, axis=1)))

        #print(f"Exported unit_{unit_id} with {len(spikes)} spikes.")

    except KeyboardInterrupt:
        print("\nKeyboard interrupt detected — exiting safely before write completes.")
        # Re-raise so the interrupt actually stops the run, as the message says. Swallowing
        # it let processing continue. The `with` block has already closed the file here.
        raise

    except Exception as e:
        # Best-effort save: report the failure and carry on with the next unit.
        print(f"\nUnexpected error while saving unit_{unit_id}: {e}")



In [None]:


    # Build per-channel residual snippets for the final spikes and write them back with
    # apply_residuals. With >= 100 spikes, PCA-cluster means are subtracted on every
    # selected channel. Otherwise the ref-channel mean waveform is subtracted (fallback).
    if len(spikes)>=100:
        # Keep selected channels and final spikes, then reorder to [N × C × T].
        # NOTE(review): final_spike_inds must index the spike axis of the current snips_full
        # (already keep-masked in the seed cell) — confirm alignment.
        snips_full = snips_full[np.ix_(selected_channels, np.arange(snips_full.shape[1]), final_spike_inds)]
        snips_full = snips_full.transpose(2, 0, 1) # [C × T × N] → [N × C × T]

        # --- Setup ---
        residuals_per_channel = {}
        cluster_ids_per_channel = {}
        scale_factors_per_channel = {}

        for ch_idx, ch in enumerate(selected_channels):
            # Slice data for this channel
            ch_snips = snips_full[:, ch_idx, :]  # shape: (n_spikes, snip_len)
            ch_baselines = baselines[ch, :]    # shape: (n_segments,)

            # Subtract PCA cluster means
            # (snippets are passed un-baselined together with this channel's segment baselines)
            residuals, cluster_ids, scale_factors = axolotl_utils_ram.subtract_pca_cluster_means_ram(
                snippets=ch_snips,
                baselines=ch_baselines,
                spike_times=spikes,
                segment_len=100_000,  # must match what was used to generate baselines
                n_clusters=5,
                offset_window=(-10,40)
            )

            # Store results
            residuals_per_channel[ch] = residuals
            cluster_ids_per_channel[ch] = cluster_ids
            scale_factors_per_channel[ch] = scale_factors
    else:
        
        # We only subtract on the referent channel to avoid distortion
        template_fallback = np.mean(snips_baselined[ref_channel], axis=1)  # shape: (T,)
        residuals_fallback = snips_baselined[ref_channel] - template_fallback[:, None]  # shape: (T, N)

        # Assume residuals_fallback is (T, N) from previous step (template-subtracted waveforms)
        # Transpose to match expected shape: (n_spikes, snip_len)
        # force key and lookup to match normal case: np.int64
        ref_channel = np.int64(ref_channel)
        selected_channels = np.array([ref_channel], dtype=np.int64)
        residuals_per_channel = {
            ref_channel: residuals_fallback.T.astype(np.int16)
        }


    # end_time = time.time()
    # elapsed = end_time - start_time 
    # print(f"Finished preprocessing, starting edits. Elapsed: {elapsed:.1f} seconds.")
    # Step 12: edit raw data
    # Each residual snippet is written starting at spike time + window[0] (first snippet sample).
    # NOTE(review): with is_ram=True / is_disk=False this presumably edits raw_data in memory
    # only and leaves the *_sub.dat file untouched — confirm in apply_residuals.
    write_locs = spikes + window[0]
    axolotl_utils_ram.apply_residuals(
        raw_data=raw_data,
        dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat',
        residual_snips_per_channel=residuals_per_channel,
        write_locs=write_locs,
        selected_channels=selected_channels,
        total_samples=raw_data.shape[0],
        dtype = np.int16,
        n_channels = n_channels,
        is_ram=True,
        is_disk=False
    )
    # NOTE(review): `start_time` is not assigned in this chunk (the earlier assignment is
    # commented out); it must come from the enclosing loop or an earlier cell.
    end_time = time.time()
    elapsed = end_time - start_time
    print(f"Processed unit {unit_id} with {len(spikes)} final spikes in {elapsed:.1f} seconds.\n")


    # Step 13: Repeat until done
    unit_id += 1
    # if unit_id >= max_units:
    #     print("Reached unit limit.")
    #     break

