## Objective 
Save data in the form needed to generate semi-artificial (sometimes called synthetic or artificial) data

### Save traces
Traces are stored as .npy files and are NOT scaled to uV

In [1]:
# Setup
import h5py
from multiprocessing import Pool
import numpy as np

from tqdm import tqdm


def _save_traces_mea_old(task):
    """
    Worker for save_traces_mea_old: copy one chunk of frames from the recording
    into the already-allocated .npy file

    Params:
        task: tuple of (rec_path, save_path, start_frame, chan_ind, chunk_start, chunk_size, gain)
            rec_path: path to the .h5 recording; its 'sig' dataset is indexed as [channel, frame]
            save_path: path to the existing .npy file that the chunk is written into
            start_frame: recording frame that corresponds to column 0 of the .npy file
            chan_ind: channel indices (rows of 'sig') to extract
            chunk_start: first recording frame of this chunk
            chunk_size: maximum number of frames in this chunk
            gain: currently unused since the traces are stored without scaling
    """
    rec_path, save_path, start_frame, chan_ind, chunk_start, chunk_size, gain = task

    saved_traces = np.load(save_path, mmap_mode="r+")
    col_start = chunk_start - start_frame

    # Clip the chunk to the columns that remain in the saved array. Without this, the last chunk
    # reads chunk_size frames while the destination slice is shorter, which raises a shape mismatch
    num_frames = min(chunk_size, saved_traces.shape[1] - col_start)
    if num_frames <= 0:
        return

    # Context manager so that the HDF5 handle is released even if reading fails
    with h5py.File(rec_path, 'r') as recording:
        traces = recording['sig'][chan_ind, chunk_start:chunk_start+num_frames].astype("float16")  # * gain

    # Use the number of frames actually read in case the recording ends before the chunk does
    saved_traces[:, col_start:col_start+traces.shape[1]] = traces
    saved_traces.flush()
        
def save_traces_mea_old(rec_path, save_path,
                        start_ms=0, end_ms=None, samp_freq=20,  # kHz
                        default_gain=1,
                        chunk_size=100000,
                        num_processes=16):
    """
    Extract the traces of every channel that is routed to an electrode and save them as a .npy file

    This only works for the old format of Maxwell MEA .h5 files

    The saved array has dtype float16 and shape (num_channels, end_frame - start_frame).
    The traces are NOT scaled by the gain

    Params:
        rec_path: path to the .h5 recording
        save_path: path of the .npy file to create
        start_ms: start of the extracted window in ms
        end_ms: end of the extracted window in ms. If None, extract until the end of the recording.
            Values past the end of the recording are clamped to the end of the recording
        samp_freq: sampling frequency in kHz
        default_gain: gain to use if 'lsb' is not in the recording's 'settings'
        chunk_size: number of frames extracted by each task
        num_processes: number of worker processes

    Raises:
        ValueError: if the extracted window ends before it starts
    """

    start_frame = round(start_ms * samp_freq)

    # Read everything that is needed from the recording up front so that the file
    # is closed before the worker processes (which each open it themselves) are started
    with h5py.File(rec_path, 'r') as recording:
        total_frames = recording['sig'].shape[1]

        # Keep only the channels that are connected to an electrode
        chan_ind = [mapping[0]
                    for mapping in recording['mapping']  # (chan_idx, elec_id, x_cord, y_cord)
                    if mapping[1] != -1]

        # gain is forwarded to the workers, but the workers currently store the traces unscaled
        if 'lsb' in recording['settings']:
            gain = recording['settings']['lsb'][0] * 1e6
        else:
            gain = default_gain
            print(f"'lsb' not found in 'settings'. Setting gain to uV to {gain}")

    if end_ms is None:
        end_frame = total_frames
    else:
        # Clamp so that the saved array is never longer than the data that can be read
        end_frame = min(round(end_ms * samp_freq), total_frames)
    if end_frame < start_frame:
        raise ValueError(f"end frame ({end_frame}) is before start frame ({start_frame})")

    print("Alllocating memory for traces ...")
    # Create the .npy file directly on disk instead of building the full array in RAM
    # and writing it out once as zeros before the workers overwrite it
    traces = np.lib.format.open_memmap(save_path, mode="w+", dtype="float16",
                                       shape=(len(chan_ind), end_frame-start_frame))
    traces.flush()
    del traces

    print("Extracting traces ...")
    # The last chunk is clipped to end_frame so that it matches the space left in the saved array
    tasks = [(rec_path, save_path, start_frame, chan_ind,
              chunk_start, min(chunk_size, end_frame - chunk_start), gain)
             for chunk_start in range(start_frame, end_frame, chunk_size)]
    with Pool(processes=num_processes) as pool:
        for _ in tqdm(pool.imap_unordered(_save_traces_mea_old, tasks), total=len(tasks)):
            pass

In [2]:
# Extract and save the traces of every MEA recording
for rec in ('2950', '2953', '2954', '2957', '5116', '5118'):
    rec_path = f"/data/MEAprojects/DLSpikeSorter/data/{rec}/data.raw.h5"
    save_path = f"/data/MEAprojects/DLSpikeSorter/data/{rec}/traces.npy"
    save_traces_mea_old(rec_path, save_path)

Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 36/36 [00:32<00:00,  1.12it/s]


Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 36/36 [00:34<00:00,  1.05it/s]


Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 36/36 [00:48<00:00,  1.34s/it]


Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 36/36 [00:35<00:00,  1.01it/s]


Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 36/36 [00:35<00:00,  1.01it/s]


Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 37/37 [00:34<00:00,  1.08it/s]


### Find TRAINING_MEDIAN needed to run DL model on real recordings
For run_dl_model() in RT-Sort

In [36]:
# For MEA model
import numpy as np
import scipy.stats

SAMP_FREQ = 20  # kHz
FIRST_MS = 50  # Length (ms) of the window at the start of each recording used for the estimate
##
# The IQR is used because it is like standardizing by dividing by the STD,
# except that the IQR is less affected by spikes in the first 50ms

num_frames = round(FIRST_MS*SAMP_FREQ)

all_medians = []
for rec in ('2950', '2953', '2954', '2957', '5116', '5118'):
    traces = np.load(f"/data/MEAprojects/DLSpikeSorter/data/{rec}/traces.npy", mmap_mode="r")
    # IQR of each row over the first num_frames frames, summarized by the median across rows
    row_iqrs = scipy.stats.iqr(traces[:, :num_frames], axis=1)
    all_medians.append(np.median(row_iqrs))

# The reported value is the mean of the per-recording medians
print(f"Inference base scaling: {np.mean(all_medians)}")

Inference base scaling: 12.67


In [3]:
# For neuropixels model
import numpy as np
import scipy.stats

SAMP_FREQ = 30  # kHz
FIRST_MS = 50  # Length (ms) of the window at the start of each recording used for the estimate
GAIN_TO_UV = 0.195
##
# The IQR is used because it is like standardizing by dividing by the STD,
# except that the IQR is less affected by spikes in the first 50ms

num_frames = round(FIRST_MS*SAMP_FREQ)

all_medians = []
for rec in (
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592315",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592318",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592320",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592324",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592328",
    "/data/MEAprojects/buzsaki/SiegleJ/AllenInstitute_744912849/session_766640955/probe_773592330",
):
    traces = np.load(f"{rec}/traces.npy", mmap_mode="r")
    # Presumably rows are channels and columns are frames -- TODO confirm for these files
    scaled_window = traces[:, :num_frames] * GAIN_TO_UV
    row_iqrs = scipy.stats.iqr(scaled_window, axis=1)
    all_medians.append(np.median(row_iqrs))

# The reported value is the mean of the per-recording medians
print(f"Inference base scaling: {np.mean(all_medians)}")
print("TRUNCATE THIS NUMBER SO THAT THERE IS ONLY ONE DECIMAL PLACE (DO NOT ROUND)")
print("Otherwise, RT-Sort's performance on /data/MEAprojects/primary_mouse/patch_ground_truth/200724/2602/patch_rec_cell7.raw.h5 decreases dramatically (not sure whyy)")

Inference base scaling: 15.45375


### Scale sorted.npz to uV
Unit templates need to be in uV in sorted.npz, but older versions have them in the MEA's arbitrary units

In [20]:
GAIN_TO_UV = 6.29425
##
import shutil
from pathlib import Path

for rec in ['2950', '2953', '2954', '2957', '5116', '5118']:
    sorted_path = Path(f"/data/MEAprojects/DLSpikeSorter/data/{rec}/sorted.npz")

    # Keep an untouched copy of the unscaled file and always scale from that copy.
    # Scaling sorted.npz in place would multiply by GAIN_TO_UV again every time this cell is re-run.
    # NOTE: the first run treats sorted.npz as unscaled. If it was already scaled by an earlier
    # run of the in-place version of this cell, restore the unscaled file before running
    unscaled_path = sorted_path.with_name("sorted_unscaled.npz")
    if not unscaled_path.exists():
        shutil.copy2(sorted_path, unscaled_path)

    # NOTE: allow_pickle=True can execute arbitrary code while loading, so only load trusted files
    with np.load(unscaled_path, allow_pickle=True) as npz_file:
        npz = dict(npz_file)  # Reads every array into memory so the file can be closed

    for unit in npz['units']:
        unit['template'] *= GAIN_TO_UV
        unit['amplitudes'] *= GAIN_TO_UV
        # No need to change 'std_norms' since the change to the std is cancelled by the change to the amplitude that normalizes the std
    np.savez(sorted_path, **npz)
    