## Objective
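Generate figure 1B of the RT-Sort manuscript: the mean waveform template of one sorted unit, drawn at each electrode's location on the array, with electrodes that pass the amplitude and normalized-std curation thresholds highlighted. The figure is saved as an SVG.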

In [1]:
import matplotlib.pyplot as plt
import numpy as np

from spikeinterface.extractors import MaxwellRecordingExtractor

from tqdm import tqdm

# Raw Maxwell recording; it is opened below to read the electrode (channel) locations.
REC_PATH = "/data/MEAprojects/RT-Sort/figures/data/2953/data.raw.h5"
# Per-unit template statistics. The first axis is indexed by unit in the cells below;
# each unit's slice should be shape (201, 1020) = (n_samples, n_elecs) and scaled to uV
TEMPLATES_AVERAGE_PATH = "/data/MEAprojects/RT-Sort/figures/data/2953/templates_average.npy"
TEMPLATES_STD_PATH = "/data/MEAprojects/RT-Sort/figures/data/2953/templates_std.npy"

In [2]:
# Value the y-coordinates are mirrored about.
# NOTE(review): presumably the height of the electrode array in um -- confirm.
Y_FLIP_ORIGIN = 2100

recording = MaxwellRecordingExtractor(REC_PATH)
# Work on an explicit copy: the flip below is done in place, and it must never
# write into an array that the recording object may still be holding on to.
locations = recording.get_channel_locations().copy()
# Need to invert y-coords so the layout is drawn the right way up
locations[:, 1] = Y_FLIP_ORIGIN - locations[:, 1]

In [3]:
# Locate candidate units: for every unit, take the electrode carrying its most
# negative template sample (max-amp channel) and report the unit if that
# electrode lies inside the bounding box below.
X_MIN, X_MAX = 400, 450
Y_MIN, Y_MAX = 1500, 1600
##
all_templates = np.load(TEMPLATES_AVERAGE_PATH, mmap_mode="r")
for unit_idx, unit_template in enumerate(all_templates):
    # Minimum over samples gives one trough value per electrode
    trough_per_elec = np.min(unit_template, axis=0)
    max_amp_chan = np.argmin(trough_per_elec)
    elec_x, elec_y = locations[max_amp_chan]
    if not (X_MIN <= elec_x <= X_MAX):
        continue
    if not (Y_MIN <= elec_y <= Y_MAX):
        continue
    print(unit_idx, locations[max_amp_chan])

92 [ 437.5 1557.5]
93 [ 437.5 1557.5]
97 [ 402.5 1522.5]
111 [ 437.5 1557.5]


In [30]:
# ---- Figure parameters -------------------------------------------------------
# XLIM = (190, 550)  # when using 2ms before and after
XLIM = (175, 560)  # when using 1ms before and 2ms after
YLIM = (1400, 1700)
AMP_THRESH = 19  # min |peak| of the mean template for an electrode to be highlighted
NORM_STD_THRESH = 0.6  # max (std at the peak frame) / |peak| for an electrode to be highlighted
UNIT_IDX = 111

PEAK_FRAME = 100  # frame of the spike peak in the stored templates
# Window of each template that is drawn.
# 80:-60 -> 1ms before and 2ms after; use CROP_START = 60 for 2ms before and after.
CROP_START = 80
CROP_STOP = -60
# Horizontal stretch applied to the frame axis before adding the electrode x-position
# (0.5 for 1ms before / 2ms after; 0.35 was used for 2ms before and after).
TIME_SCALE = 0.5
# Vertical scaling applied to the template before adding the electrode y-position.
# NOTE(review): the origin of 3 / 6.295 is not documented here -- confirm.
AMP_SCALE = 3 / 6.295

HIGHLIGHT_COLOR = "#7542ff"
DEFAULT_COLOR = "black"

SAVE_PATH = "/data/MEAprojects/RT-Sort/figures/1B_1ms_before_2ms_after.svg"
##

# ---- Load the unit's templates as (n_elecs, n_samples) -----------------------
templates_mean = np.load(TEMPLATES_AVERAGE_PATH, mmap_mode="r")[UNIT_IDX].T
templates_std = np.load(TEMPLATES_STD_PATH, mmap_mode="r")[UNIT_IDX].T

# ---- Curation: which electrodes get highlighted ------------------------------
# Per electrode: frame of the largest absolute deflection, its amplitude, and
# the std at that same frame. Computed on the full (uncropped) template.
arange = np.arange(templates_mean.shape[0])
peaks = np.argmax(np.abs(templates_mean), axis=1)
amps = np.abs(templates_mean[arange, peaks])
peak_stds = templates_std[arange, peaks]
# A flat electrode has amps == 0; silence the divide warning there. Such an
# electrode fails the amplitude test anyway, so the resulting mask is unchanged.
with np.errstate(divide="ignore", invalid="ignore"):
    curation = (amps >= AMP_THRESH) & (peak_stds / amps <= NORM_STD_THRESH)

# ---- Crop to the plotted window ----------------------------------------------
templates_mean = templates_mean[:, CROP_START:CROP_STOP]
# Index of the peak inside the cropped window, so the peak sits on the electrode's x-position
pre_peak_frames = PEAK_FRAME - CROP_START

# ---- Plot each electrode's waveform at the electrode's location --------------
fig, ax = plt.subplots(1)
for chan in tqdm(range(len(curation))):
    elec_x, elec_y = locations[chan]
    # Skip electrodes outside the plotted region (strict bounds)
    if not (XLIM[0] < elec_x < XLIM[1] and YLIM[0] < elec_y < YLIM[1]):
        continue

    color = HIGHLIGHT_COLOR if curation[chan] else DEFAULT_COLOR

    # Scaling produces a new array, so the read-only memmap is never written to
    trace = templates_mean[chan] * AMP_SCALE + elec_y
    time_axis = (np.arange(trace.size, dtype=float) - pre_peak_frames) * TIME_SCALE + elec_x

    ax.plot(time_axis, trace, color=color, alpha=1)

ax.set_ylim(YLIM)
ax.set_xlim(XLIM)

# Hide top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Increase thickness of the bottom and left spines
ax.spines["bottom"].set_linewidth(2.5)
ax.spines["left"].set_linewidth(2.5)

# Increase thickness of tick marks
ax.tick_params(axis='both', direction='out', length=6, width=2.5, colors='black')

# Hide labels
ax.set_title("")
ax.set_xlabel("")
ax.set_ylabel("")

ax.set_xticks([])
ax.set_yticks([])

fig.savefig(SAVE_PATH, format="svg")

# Show before closing: closing the figure first leaves plt.show() with nothing to render.
plt.show()
plt.close(fig)

100%|██████████| 1020/1020 [00:00<00:00, 34752.65it/s]
