### load spike times and EI positions

In [1]:
import h5py
import numpy as np

# --- Parameters ---
dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat'  # Update this
h5_in_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_kilosort_data001_spike_times.h5'  # from MATLAB export

sampling_rate = 20000  # Hz

# --- Load spike times from HDF5 ---
# Result: all_spikes maps the integer suffix of each '/spikes/<prefix>_<n>' dataset
# to that unit's spike times expressed in samples (int32).
all_spikes = {}
with h5py.File(h5_in_path, 'r') as f:
    spikes_group = f['/spikes']
    # Visit units in numeric order of their suffix rather than lexicographic order.
    unit_ids = sorted(spikes_group.keys(), key=lambda name: int(name.split('_')[1]))
    for uid in unit_ids:
        unit_index = int(uid.split('_')[1])
        raw = f[f'/spikes/{uid}'][:]

        # A 1-D dataset holding a single entry is unwrapped before flattening;
        # anything else is flattened as-is.
        single_entry = raw.ndim == 1 and raw.shape[0] == 1
        payload = raw[0] if single_entry else raw
        spikes_sec = np.array(payload).flatten()

        # Seconds -> nearest sample index.
        spikes_samples = np.round(spikes_sec * sampling_rate).astype(np.int32)
        all_spikes[unit_index] = spikes_samples

    # Electrode positions are stored transposed relative to how they are used below.
    ei_positions = f['/ei_positions'][:].T  # shape becomes [512 x 2]



### load parasol list from mat file

In [2]:
import scipy.io

mat_path = '/Volumes/Lab/Users/alexth/axolotl/parasol_list_data001.mat'
mat_contents = scipy.io.loadmat(mat_path)

# loadmat hands MATLAB vectors back as 2-D arrays, so flatten to 1-D.
# ravel() is used instead of squeeze() because squeeze() collapses a
# one-element list to a 0-d array, which cannot be iterated by the
# `set(parasol_list)` call in the ordering cell below.
parasol_list = mat_contents['parasol_list'].ravel()

### order parasol list by EI amplitude

In [3]:
import h5py
import numpy as np

h5_ei_path = '/Volumes/Lab/Users/alexth/axolotl/ks_eis_subset.h5'

# Read every template in the file, keyed by the integer after the first '_' in its name.
with h5py.File(h5_ei_path, 'r') as f:
    ks_templates = {int(name.split('_')[1]): f[name][:] for name in f.keys()}
ks_unit_ids = list(ks_templates)
ks_ei_stack = np.stack([ks_templates[unit] for unit in ks_unit_ids], axis=0)  # [N x 512 x 81]

# Peak-to-peak amplitude of each channel's waveform, per unit
peak_to_peak = np.max(ks_ei_stack, axis=2) - np.min(ks_ei_stack, axis=2)  # shape: [N x 512]

# Largest peak-to-peak over all channels = the unit's amplitude
unit_max_amp = np.max(peak_to_peak, axis=1)  # shape: [N]

# Row order that lists units from largest to smallest amplitude
sorted_indices = np.flip(np.argsort(unit_max_amp))

# Unit IDs and amplitudes rearranged into that order
ks_unit_ids_sorted = [ks_unit_ids[row] for row in sorted_indices]
unit_max_amp_sorted = np.take(unit_max_amp, sorted_indices)

# Keep only parasol units, preserving the amplitude ordering
parasol_set = set(parasol_list)
parasol_list_ordered = [unit for unit in ks_unit_ids_sorted if unit in parasol_set]

### Generate EIs from original file for the parasol list

In [None]:
from extract_data_snippets import extract_snippets
import numpy as np
import h5py

# not baseline corrected!
# For every parasol unit (largest amplitude first) this cell averages raw-data
# snippets around the unit's spikes and stores the mean as dataset 'unit_<id>'
# in ei_save_path. The output file is opened with 'w', so it is overwritten.

# --- Parameters ---
window = (-60, 90)  # handed to extract_snippets (presumably samples relative to each spike - confirm)
n_channels = 512
dtype = 'int16'
min_isi_samples = 60  # a spike closer than this to the previously KEPT spike is dropped

# Path to raw data
dat_path_orig = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001.dat'

# Output file path
ei_save_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_parasol_eis.h5'

with h5py.File(ei_save_path, 'w') as f:
    for ks_id in parasol_list_ordered:
        spike_times = all_spikes[ks_id]

        # A unit without spikes has no EI; indexing spike_times[0] below would raise.
        if len(spike_times) == 0:
            print(f"Skipping unit {ks_id}: no spikes")
            continue

        # Remove ISI violations. This is deliberately a sequential (greedy) pass:
        # each spike is compared with the last spike that was kept, not with its
        # immediate predecessor, so it cannot be replaced by a plain np.diff mask.
        # NOTE(review): assumes spike_times is sorted ascending - confirm for the export.
        cleaned_spikes = [spike_times[0]]
        for t in spike_times[1:]:
            if t - cleaned_spikes[-1] >= min_isi_samples:
                cleaned_spikes.append(t)
        cleaned_spikes = np.array(cleaned_spikes)

        print(f"Removed {len(spike_times) - len(cleaned_spikes)} spikes from unit {ks_id}")

        snips = extract_snippets(dat_path_orig, cleaned_spikes, window=window, n_channels=n_channels, dtype=dtype)
        # Average over the last axis of the snippet stack (one slice per spike).
        ei = np.mean(snips, axis=2)  # shape: [512 x T]
        f.create_dataset(f'unit_{ks_id}', data=ei)


Removed 1 spikes from unit 54
Removed 3 spikes from unit 187
Removed 6 spikes from unit 801
Removed 0 spikes from unit 265
Removed 8 spikes from unit 216
Removed 12 spikes from unit 648
Removed 2 spikes from unit 337
Removed 2 spikes from unit 124
Removed 1 spikes from unit 157
Removed 11 spikes from unit 440
Removed 0 spikes from unit 6
Removed 169 spikes from unit 179
Removed 3 spikes from unit 386
Removed 1 spikes from unit 856
Removed 0 spikes from unit 596
Removed 7 spikes from unit 190
Removed 2 spikes from unit 158
Removed 0 spikes from unit 766
Removed 4 spikes from unit 649
Removed 0 spikes from unit 297
Removed 0 spikes from unit 132
Removed 13 spikes from unit 534
Removed 0 spikes from unit 315
Removed 17 spikes from unit 584
Removed 9 spikes from unit 438
Removed 50 spikes from unit 253
Removed 2 spikes from unit 350
Removed 0 spikes from unit 196
Removed 4 spikes from unit 258
Removed 0 spikes from unit 716
Removed 0 spikes from unit 141
Removed 1 spikes from unit 97
Remov

### subtract one by one and plot results

In [5]:
import prepare_subtraction_templates
import importlib
importlib.reload(prepare_subtraction_templates)
import subtract_unit_from_raw
importlib.reload(subtract_unit_from_raw)
from extract_data_snippets import extract_snippets
import numpy as np
import h5py
import matplotlib.pyplot as plt
from plot_ei_waveforms import plot_ei_waveforms  # assuming this is already imported
import os

# --- Parameters ---
window = (-60, 90)
n_channels = 512
dtype = 'int16'
template_window = (-20, 60)
subtraction_window = (-10, 30)
EI_threshold = 10
n_bins = 3
template_center = 20
total_samples = 36_000_000  # adjust as needed
min_isi_samples = 60  # a spike closer than this to the previously KEPT spike is dropped
# Index inside an EI of `window` length where the spike sample is expected
# (60 for (-60, 90)); assumes extract_snippets starts each snippet at t + window[0].
expected_center = -window[0]

# NOTE: this rebinds `dat_path`, which the first cell set to the original recording.
# NOTE(review): the "before" and "after" EIs below are both read from this same file
# around the subtract_unit_from_raw call, so that call appears to modify the file in
# place. If so, re-running this cell subtracts the units a second time - presumably it
# should start from a fresh copy of the original .dat; confirm.
dat_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_sub.dat'
ei_mod_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_parasol_eis_mod.h5'
ei_sub_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_parasol_eis_sub.h5'
ei_orig_path = '/Volumes/Lab/Users/alexth/axolotl/201703151_data001_parasol_eis.h5'
save_dir = '/Volumes/Lab/Users/alexth/axolotl/201703151_parasol_ei_WVs'
os.makedirs(save_dir, exist_ok=True)

# The original-EI file is opened once for the whole loop (it is only read), and it is
# opened first so that a missing input fails before the two output files are truncated.
with h5py.File(ei_orig_path, 'r') as f_orig, \
        h5py.File(ei_mod_path, 'w') as f_mod, \
        h5py.File(ei_sub_path, 'w') as f_sub:
    for ks_id in parasol_list_ordered:
        spike_times = all_spikes[ks_id]

        # A unit without spikes has nothing to subtract; spike_times[0] below would raise.
        if len(spike_times) == 0:
            print(f"Skipping unit {ks_id}: no spikes")
            continue

        # Remove ISI violations (<60 samples apart). Greedy on purpose: each spike is
        # compared with the last spike kept, not with its immediate predecessor.
        cleaned_spikes = [spike_times[0]]
        for t in spike_times[1:]:
            if t - cleaned_spikes[-1] >= min_isi_samples:
                cleaned_spikes.append(t)
        cleaned_spikes = np.array(cleaned_spikes)
        print(f"Removed {len(spike_times) - len(cleaned_spikes)} spikes from unit {ks_id}")

        # --- EI before subtraction ---
        snips_mod = extract_snippets(dat_path, cleaned_spikes, window=window, n_channels=n_channels, dtype=dtype)
        ei_mod = np.mean(snips_mod, axis=2)
        # Baseline: subtract each channel's mean over its first 5 samples.
        ei_mod -= ei_mod[:, :5].mean(axis=1, keepdims=True)
        f_mod.create_dataset(f'unit_{ks_id}', data=ei_mod)

        # Load original EI (same 5-sample baseline correction)
        ei_orig = f_orig[f'unit_{ks_id}'][:]
        ei_orig -= ei_orig[:, :5].mean(axis=1, keepdims=True)

        # Find the channel with the largest peak-to-peak amplitude and measure how far
        # its most negative sample sits from the expected spike position.
        ei_ptp = ei_orig.max(axis=1) - ei_orig.min(axis=1)
        ref_channel = np.argmax(ei_ptp)
        ref_waveform = ei_orig[ref_channel]
        # int() keeps the shift a plain Python scalar, so adding it below leaves the
        # spike-time dtype unchanged whether or not a shift is needed.
        true_peak_sample = int(np.argmin(ref_waveform))  # sample index of negative deflection
        align_shift = true_peak_sample - expected_center

        if align_shift != 0:
            print(f"Adjusting spike times by {align_shift} samples (soma peak at {true_peak_sample})")
        # Always a new array, so cleaned_spikes itself is never modified.
        spike_times_aligned = cleaned_spikes + align_shift

        # --- Get templates ---
        # NOTE(review): (10, 50) equals subtraction_window shifted by template_center,
        # i.e. it looks like the same window in template-relative indices - confirm
        # against prepare_subtraction_templates before deriving it from the parameters.
        ei_template, ei_channels, x_shifts, templates_by_chan, template_ids_by_chan, y_shifts_by_chan = prepare_subtraction_templates.prepare_subtraction_templates(
            dat_path=dat_path,
            spike_times=spike_times_aligned,
            EI_threshold=EI_threshold,
            window=template_window,
            max_lag=3,
            n_bins=n_bins,
            n_channels=n_channels,
            dtype=dtype,
            subtraction_window=(10, 50),
            save_dir=save_dir,
            ks_id=ks_id
        )

        # --- Subtract ---
        subtract_unit_from_raw.subtract_unit_from_raw(
            dat_path=dat_path,
            spike_times=spike_times_aligned,
            x_shifts=x_shifts,
            templates=templates_by_chan,
            template_ids=template_ids_by_chan,
            ei_waveform=ei_template,
            unit_id=ks_id,
            target_channels=ei_channels,
            start_sample=0,
            total_samples=total_samples,
            template_center=template_center,
            subtraction_window=subtraction_window,
            n_channels=n_channels,
            dtype=dtype
        )

        # --- EI after subtraction ---
        # Uses the unshifted cleaned_spikes, like the "before" EI, so the three EIs
        # plotted below share the same time axis.
        snips_sub = extract_snippets(dat_path, cleaned_spikes, window=window, n_channels=n_channels, dtype=dtype)
        ei_sub = np.mean(snips_sub, axis=2)
        ei_sub -= ei_sub[:, :5].mean(axis=1, keepdims=True)
        f_sub.create_dataset(f'unit_{ks_id}', data=ei_sub)

        # Plot and save: original (black), before subtraction (red), after subtraction (blue)
        plt.figure(figsize=(20, 20))
        plot_ei_waveforms([ei_orig, ei_mod, ei_sub], ei_positions,
                        colors=['black', 'red', 'blue'], scale=70, box_height=1.0, box_width=50)

        plt.title(f"Unit {ks_id} | Removed {len(spike_times) - len(cleaned_spikes)} spikes", fontsize=16)
        save_path = os.path.join(save_dir, f'unit_{ks_id}.png')
        plt.savefig(save_path, bbox_inches='tight', pad_inches=0.1, dpi=150)

        # Close each figure so memory does not grow with the number of units.
        plt.close()


Removed 1 spikes from unit 54
54
Adjusting spike times by -8 samples (soma peak at 52)
23 18
Unit 54 subtraction complete.
Removed 3 spikes from unit 187
187
Adjusting spike times by -1 samples (soma peak at 59)
23 18
Unit 187 subtraction complete.
Removed 6 spikes from unit 801
801
Adjusting spike times by -14 samples (soma peak at 46)
23 18
Unit 801 subtraction complete.
Removed 0 spikes from unit 265
265
Adjusting spike times by -7 samples (soma peak at 53)
23 18
Unit 265 subtraction complete.
Removed 8 spikes from unit 216
216
Adjusting spike times by -1 samples (soma peak at 59)
23 18
Unit 216 subtraction complete.
Removed 12 spikes from unit 648
648
17 23
Unit 648 subtraction complete.
Removed 2 spikes from unit 337
337
Adjusting spike times by -1 samples (soma peak at 59)
24 18
Unit 337 subtraction complete.
Removed 2 spikes from unit 124
124
Adjusting spike times by -1 samples (soma peak at 59)
23 18
Unit 124 subtraction complete.
Removed 1 spikes from unit 157
157
Adjusting sp