In [None]:
# Map of the electrodes chosen for this unit: magenta dots, each tagged with
# its channel index slightly to the right of the dot.
sel_xy = ei_positions[selected_channels_plot]
plt.figure(figsize=(6, 6))
plt.scatter(sel_xy[:, 0], sel_xy[:, 1], c='m', s=80)
for ch in selected_channels_plot:
    px, py = ei_positions[ch]
    plt.text(px + 5, py, str(ch), fontsize=8)
plt.title(f"Selected electrodes for unit {unit_idx}")
plt.xlabel("x")
plt.ylabel("y")
plt.axis("equal")
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Plot the EI waveform of a few example channels (the first 5 entries of
# selected_by_ei), marking masked samples in red.
# Rows of ei_waveforms / ei_masks follow the order of selected_by_ei.
channels_to_plot = selected_by_ei[:5]
n_plot = len(channels_to_plot)

fig, axes = plt.subplots(n_plot, 1, figsize=(10, 6), squeeze=False)

# channels_to_plot is a prefix of selected_by_ei, so the loop position is
# already the row index; no per-channel np.where search is needed (and a
# search would pick the wrong row if selected_by_ei contained duplicates).
for ei_idx, ch in enumerate(channels_to_plot):
    wf = ei_waveforms[ei_idx, :]
    # Cast to bool so wf[mask] is a boolean selection even if masks are stored
    # as 0/1 numbers (numeric masks would otherwise act as integer indices).
    mask = np.asarray(ei_masks[ei_idx, :], dtype=bool)

    ax = axes[ei_idx, 0]
    ax.plot(wf, label=f'Electrode {ch}', color='black')
    ax.plot(np.flatnonzero(mask), wf[mask], 'ro', markersize=3, label='masked')
    ax.axhline(0, color='gray', linestyle='--', linewidth=0.5)
    ax.legend(loc='upper right')

fig.suptitle("EI waveforms (masked in red)")
plt.tight_layout()
plt.show()


In [None]:
# Summary of the detection stage. All inputs (peaks, mean_score, valid_score,
# score_threshold, final_spike_times) come from earlier cells.
n_final = len(final_spike_times)
print(f"Detected {n_final} spike candidates.")

print("Total local peaks:", len(peaks))
print("Above threshold:", np.sum(mean_score[peaks] > score_threshold))
# NOTE(review): the "> 3" channel-count cut is hard-coded here; confirm it
# matches the threshold actually used during detection.
print("Valid channel count > 3:", np.sum(valid_score[peaks] > 3))
print("Final spike count:", n_final)

print("score_threshold:", score_threshold)

print("Num spikes:", n_final)
# np.min / np.max raise ValueError on an empty array, so guard that case.
if n_final > 0:
    print("Min:", np.min(final_spike_times))
    print("Max:", np.max(final_spike_times))
    print("First 20:", final_spike_times[:20])
else:
    print("No spikes detected; min/max are undefined.")

In [None]:
import scipy.io as sio

# Export the per-sample score arrays to a .mat file for inspection in MATLAB:
# mean_score as float32 and valid_score as int32.
score_debug = {
    'mean_score': mean_score.astype(np.float32),
    'valid_score': valid_score.astype(np.int32),
}
sio.savemat('/Volumes/Lab/Users/alexth/axolotl/score_debug.mat', score_debug)


In [None]:
import numpy as np
import scipy.io as sio

# Export the detected spike times as an int64 array for use in MATLAB.
spike_times_out = np.asarray(final_spike_times, dtype=np.int64)
sio.savemat('/Volumes/Lab/Users/alexth/axolotl/final_spike_times.mat',
            {'final_spike_times': spike_times_out})

In [None]:
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans

# Clustering hyperparameters.
N_PCA_COMPONENTS = 10   # requested PCA dimensionality (clamped below)
RANDOM_STATE = 1        # seeds both PCA's randomized solver and k-means

# Subset to selected channels
snips_sel = snips[selected_by_ei, :, :]  # [C x T x N]
n_spikes = snips_sel.shape[2]

# Reshape to [N x (C*T)] for PCA: one row per spike
snips_flat = snips_sel.transpose(2, 0, 1).reshape(n_spikes, -1)

# PCA cannot return more than min(n_samples, n_features) components; clamp so
# a small spike set does not raise. random_state makes the randomized SVD
# solver (picked automatically for large inputs) reproducible.
n_components = min(N_PCA_COMPONENTS, *snips_flat.shape)
pca = PCA(n_components=n_components, random_state=RANDOM_STATE)
pcs = pca.fit_transform(snips_flat)

# Run k-means clustering (k is reused by later cells)
k = 8
kmeans = KMeans(n_clusters=k, n_init=10, random_state=RANDOM_STATE)
labels = kmeans.fit_predict(pcs)


In [None]:
import matplotlib.pyplot as plt

# Scatter of every spike in the plane of the first two principal components.
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(pcs[:, 0], pcs[:, 1], s=5, alpha=0.6)
ax.set(xlabel="PC1", ylabel="PC2", title="PCA: PC1 vs PC2")
ax.grid(True)
plt.show()


In [None]:
import torch

# Number of leading samples averaged per channel to estimate the baseline that
# is subtracted from each cluster EI.
BASELINE_SAMPLES = 5

snips_torch = torch.from_numpy(snips)  # shape [C x T x N]
# torch.mean is only defined for floating/complex tensors; promote integer
# snippets (e.g. raw int16) so the per-cluster average does not raise.
if not snips_torch.is_floating_point():
    snips_torch = snips_torch.float()

# Convert the cluster labels once instead of on every loop iteration.
labels_torch = torch.as_tensor(labels)

ei_per_cluster = []
for i in range(k):
    inds = torch.where(labels_torch == i)[0]
    if inds.numel() == 0:
        # Empty cluster: zero placeholder with the same shape as ei_template.
        ei = torch.zeros_like(torch.from_numpy(ei_template))
    else:
        snips_i = snips_torch[:, :, inds]
        ei = torch.mean(snips_i, dim=2)  # average over spikes -> [C x T]
        ei = ei - ei[:, :BASELINE_SAMPLES].mean(dim=1, keepdim=True)
    ei_per_cluster.append(ei.numpy())


In [None]:
import matplotlib.pyplot as plt

# Grid of per-cluster EI waveforms: one row per cluster, one column per channel.
n_clusters = len(ei_per_cluster)
top_channels = selected_by_ei[:5]  # just a few for visual clarity

# squeeze=False keeps axs 2-D even with a single cluster or a single channel,
# so the axs[i, j] indexing below is always valid.
fig, axs = plt.subplots(n_clusters, len(top_channels), figsize=(15, 2.5 * n_clusters),
                        sharex=True, sharey=True, squeeze=False)

for i, ei in enumerate(ei_per_cluster):
    for j, ch in enumerate(top_channels):
        # ei is indexed by channel id (rows span all channels, like snips).
        axs[i, j].plot(ei[ch, :])
        axs[i, j].set_title(f"Cl {i} | Ch {ch}")
plt.tight_layout()
plt.show()

In [None]:
from save_eis_for_matlab import save_eis_for_matlab

# Hand every cluster EI to the MATLAB exporter; the destination is kept in
# save_path so it can be reused.
save_path = '/Volumes/Lab/Users/alexth/axolotl/eis_for_matlab.mat'
save_eis_for_matlab(ei_per_cluster, save_path)

In [None]:
from plot_ei_python import plot_ei_python

# Spatial EI plot for cluster 0 (negative deflections red, positive black).
ei = ei_per_cluster[0]
plot_ei_python(
    ei,
    ei_positions,
    label=selected_by_ei,
    scale=25,
    neg_color='red',
    pos_color='black',
)


In [None]:
import math

import matplotlib.pyplot as plt
from plot_ei_python import plot_ei_python

# One spatial EI panel per cluster. The layout is derived from the number of
# cluster EIs rather than assuming exactly 8 clusters in a 4x2 grid.
n_eis = len(ei_per_cluster)
cols = 2
rows = max(1, math.ceil(n_eis / cols))

fig, axs = plt.subplots(rows, cols, figsize=(15, 2.5 * max(n_eis, 1)), squeeze=False)
axs = axs.flatten()

for i, ax in enumerate(axs[:n_eis]):
    ei = ei_per_cluster[i]
    title = f"Cluster {i}, {np.sum(labels == i)} spikes"
    plot_ei_python(ei, ei_positions, label=selected_by_ei, scale=25,
                   neg_color='red', pos_color='black', ax=ax)
    ax.set_title(title, fontsize=10)

# Hide leftover panels when n_eis is not a multiple of cols.
for ax in axs[n_eis:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()


In [None]:
from compare_eis import compare_eis

# Similarity between each cluster EI and the original template.
similarity_matrix = compare_eis(ei_per_cluster, ei_template)

# Scope the compact print formatting to this display only, instead of changing
# numpy's global print options for every later cell.
with np.printoptions(precision=2, suppress=True):
    print(similarity_matrix)

In [None]:
import importlib

import refine_cluster

# Re-execute refine_cluster from disk so edits are picked up without a kernel restart.
importlib.reload(refine_cluster)

In [None]:
from plot_ei_python import plot_ei_python

# Spatial plot of the EI held in `ei_numpy`.
# NOTE(review): `ei_numpy` is not defined in any visible cell; it must come from
# an earlier or deleted cell — confirm its origin before trusting this plot.
plot_ei_python(
    ei_numpy, ei_positions,
    label=selected_by_ei, scale=25, neg_color='red', pos_color='black',
)

In [None]:
import matplotlib.pyplot as plt

# Number of leading samples to show. Clipped to the array length so the title
# stays truthful when mean_score is shorter than the requested window.
n_show = min(10_000, len(mean_score))

plt.figure(figsize=(12, 4))
plt.plot(mean_score[:n_show])
plt.title(f"First {n_show:,} values of mean_score")
plt.xlabel("Sample index")
plt.ylabel("Template match score")
plt.grid(True)
plt.tight_layout()
plt.show()

In [None]:
from run_template_scoring_gpu import run_template_scoring_gpu

# Score the template against the recording at dat_path and (re)define
# mean_score / max_score / valid_score.
# NOTE(review): cells earlier in this notebook already read mean_score and
# valid_score, so on a fresh kernel this cell must run before them.
block_megabytes = 1  # NOTE(review): unused — not passed below (block_size=None)
selected_channels = selected_by_ei  # from earlier
start_sample = 0
total_samples = 36_000_000  # number of samples to score from start_sample (a full-length run, not a small test)

# Real template, mask and norm arrays built earlier; the template is restricted
# to the selected channels before scoring.
mean_score, max_score, valid_score = run_template_scoring_gpu(
    dat_path,
    ei_template=ei_template[selected_channels, :],
    ei_masks=ei_masks,
    ei_norms=ei_norms,
    selected_channels=selected_channels,
    start_sample=start_sample,
    total_samples=total_samples,
    dtype='int16',
    block_size=None
)

In [None]:
import matplotlib.pyplot as plt

# valid_score is treated elsewhere in this notebook as a per-sample channel
# count (thresholded as "Valid channel count > 3" and saved as int32), not as a
# dot-product score, so label the axes accordingly.
# NOTE(review): confirm against run_template_scoring_gpu's definition.
plt.figure(figsize=(12, 4))
plt.plot(valid_score)
plt.title("Valid channel count across time")
plt.xlabel("Timepoint (sample index)")
plt.ylabel("Valid channel count")
plt.grid(True)
plt.show()