## Objective
Become acquainted with recordings and how to use them for training and testing the DL model.

## Imports

In [2]:
# region Set up notebook imports
%load_ext autoreload
%autoreload 2
# Reload a module after changes have been made
from importlib import reload
# endregion

# Imports
import numpy as np
import torch

from utils import random_seed
import data, meta

ModuleNotFoundError: No module named 'torch'

In [2]:
import sys

# Show every module search-path entry, one per line
for path_entry in sys.path:
    print(path_entry)
    

/data/MEAprojects/DLSpikeSorter/src
/data/MEAprojects/DLSpikeSorter/src/analysis/recording
/home/mea/anaconda3/envs/dl_ss/lib/python38.zip
/home/mea/anaconda3/envs/dl_ss/lib/python3.8
/home/mea/anaconda3/envs/dl_ss/lib/python3.8/lib-dynload

/home/mea/anaconda3/envs/dl_ss/lib/python3.8/site-packages


In [1]:
%autoreload 
import data
data.test()

UsageError: Line magic function `%autoreload` not found.


## Load and format one recording

#### File system notes:
1. Delete `spike_band_scaled_filtered.dat` from the `spikesort_matlab4` folder because it is not needed and only takes up storage space

In [2]:
from spikeinterface import load_extractor
from spikeinterface.extractors import BinaryRecordingExtractor

In [3]:
# Work with the first entry of meta.SI_MOUSE_PATHS
MOUSE_PATH = meta.SI_MOUSE_PATHS[0]
##
# Location of the saved spike-band extractor inside the mouse folder
spike_band_path = MOUSE_PATH / "spike_band.si"
recording_si = load_extractor(spike_band_path)  # type: BinaryRecordingExtractor
# np.save(MOUSE_PATH / "recording.npy", recording_si.get_traces(return_scaled=False).T)

In [None]:
# Format all SI mouse recordings
# from tqdm import tqdm
# for mouse_path in tqdm(meta.SI_MOUSE_PATHS[1:]):
#     recording_si = load_extractor(mouse_path / "spike_band.si")
#     np.save(mouse_path / "recording.npy", recording_si.get_traces(return_scaled=False).T)
#     (mouse_path / "spikesort_matlab4/spike_band_scaled_filtered.dat").unlink()

In [9]:
# Named `sorted_npz` rather than `sorted` so the builtin `sorted()` is not shadowed
# for every cell that runs afterwards in this kernel.
# NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
sorted_npz = np.load(MOUSE_PATH / "spikesort_matlab4/results/sorted.npz", allow_pickle=True)
# Number of elements stored under the "spike_times" key
sorted_npz["spike_times"].size

1473546

## Load MultiRecordingDataset

In [4]:
# Build the dataset from the first SI mouse recording only; every option is passed
# by keyword (see data.MultiRecordingDataset for their meaning).
multi_rec = data.MultiRecordingDataset(
    samples_per_waveform=2,
    front_buffer=60,
    end_buffer=60,
    num_wfs_probs=[0, 1],
    isi_wf_min=3,
    isi_wf_max=None,
    data_paths=meta.SI_MOUSE_PATHS[:1],
    thresh_amp=0,
    thresh_std=9999,
    sample_size=300,
    start=0,
    ms_before=3,
    ms_after=3,
    device="cpu",
    dtype=torch.float16,
    mmap_mode="r",
)

In [5]:
# Mapping returned by the dataset; it is passed to plot_sample() in later cells.
# Its exact structure is defined in data.MultiRecordingDataset (not shown here).
alpha_to_waveform_dict = multi_rec.get_alpha_to_waveform_dict()

In [9]:
# Fetch the first sample through the dataset's indexing.
# NOTE(review): the saved output of this cell is "please work" followed by an
# AssertionError with no message; the failing assertion is not in this notebook,
# so it presumably comes from the `data` module -- investigate before relying on it.
sample = multi_rec[0]
# multi_rec.plot_sample(*sample, alpha_to_waveform_dict)

please work


AssertionError: 

## Load RecordingCrossVal

In [None]:
# Build the cross-validation container over the recordings under ROOT.
# NOTE: ROOT must already be defined in the kernel; it is not set in this cell.
rec_cross_val = data.RecordingCrossVal(
    sample_size=200,
    front_buffer=40,
    end_buffer=40,
    num_wfs_probs=[0, 0, 1],
    isi_wf_min=5,
    isi_wf_max=None,
    thresh_amp=5,  #amp=1.2
    thresh_std=2,
    samples_per_waveform=2,
    data_root=ROOT,
    mmap_mode="r",
    device="cpu",
    as_datasets=True,
)
# rec_cross_val.summary()

In [None]:
# Pick one cross-validation fold by its key and unpack it
FOLD = "2954"
##
rec, train, val = rec_cross_val[FOLD]
# The percentage is len(val) relative to len(train), so it is labelled "Val/Train"
# (the previous label "Train/Val" described the inverse ratio).
print(f"Val Recording: {rec} - Train: {len(train)} samples - Val: {len(val)} - Val/Train: {len(val)/len(train) * 100:.1f}%")
alpha_to_wf_dict = train.get_alpha_to_waveform_dict()

In [None]:
# Plot samples
# `plot` is not imported by the notebook's import cell, so bring it in here;
# otherwise plot.set_dpi() raises NameError on a fresh kernel.
# NOTE(review): assumes `plot` resolves like `data`/`meta` (src/ on sys.path) -- confirm.
import plot

NUM_SAMPLES = 1
plot.set_dpi(400)
random_seed(2)  # fixed seed so the same samples are drawn on every run
##
# Candidate indices are the odd ones: 1, 3, 5, ...
sample_ind = np.random.choice(range(1, len(train), 2), NUM_SAMPLES)  # Only plot samples with waveforms
for i in sample_ind:
    train.plot_sample(*train[i], alpha_to_wf_dict)

In [None]:
# Plot template waveforms
# NOTE: ROOT must already be defined in the kernel; it is not set in this cell.
wf_dataset = data.WaveformDataset(
    ROOT / "2954/sorted.npz",
    thresh_amp=1.2,
    thresh_std=2,
    plotting=True,
)
print(f"Number of waveforms: {len(wf_dataset)}")
wf_dataset.plot_waveforms()

In [None]:
from src.plot import set_dpi
from src.utils import FACTOR_UV
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.axes._axes import Axes
from matplotlib.figure import Figure
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm
def plot_waveform_channels(kilosort_npz_path, thresh_amp, thresh_std,
                           scale_v, scale_h, xlim=None, ylim=None,
                           curation_colors=("red", "black")):
    """
    Plot the mean waveforms (which are defined on different channels) of a unit at their spatial locations
    and color code them based on whether they are curated
    #
    Only the first unit stored in the file is plotted: the unit loop exits after
    showing one figure. The loaded templates are never modified; scaled and shifted
    copies are drawn instead.
    #
    :param kilosort_npz_path:
        spikesort_matlab4.py --> sorted.npz
        Must contain the keys "locations" and "units"; each unit is indexed by
        "template", "amplitudes" and "std_norms"
    :param thresh_amp:
        Minimum amplitude (in arbitrary scaled down units)
    :param thresh_std:
        Minimum standard deviation in amplitude divided by amplitude
    :param scale_v:
        Vertical stretch factor of wf
    :param scale_h:
        Horizontal stretch factor of wf
    :param xlim:
        of plot
        None --> use auto-generated lim
    :param ylim:
        of plot
        None --> use auto-generated lim
    :param curation_colors:
        (failed color, passed color)
    """

    # NOTE: allow_pickle=True unpickles arbitrary objects; only load trusted files.
    # The context manager closes the .npz archive as soon as both arrays are read.
    with np.load(kilosort_npz_path, allow_pickle=True) as npz:
        locations = npz["locations"]
        units = npz["units"]

    # Iterate through units
    for i_u, unit in enumerate(units):
        # Set up plot
        fig, ax = plt.subplots(1)  # type: Figure, Axes
        ax.set_title(i_u)
        ax.set_aspect("equal")

        if xlim is not None:
            ax.set_xlim(*xlim)
        if ylim is not None:
            ax.set_ylim(*ylim)

        # Get waveform data
        templates = unit["template"].T   # (n_channels, n_samples)
        n_samples = templates.shape[1]
        center_i = n_samples // 2
        # A channel passes curation when both thresholds are satisfied
        is_curated = (unit["amplitudes"] >= thresh_amp) & (unit["std_norms"] <= thresh_std)  # type: np.ndarray

        # Sample positions relative to the template center, stretched horizontally.
        # Identical for every channel, so computed once per unit.
        x_rel = (np.arange(n_samples, dtype=float) - center_i) * scale_h

        # Plot waveforms
        for i, wf in enumerate(templates):
            color = curation_colors[int(is_curated[i])]
            loc = locations[i]

            # Build new arrays instead of using *= / += on `wf`: `wf` is a view into
            # the loaded template, so in-place ops would overwrite the data and would
            # raise a casting error for integer templates with float scale/locations.
            x = x_rel + loc[0]
            y = wf * scale_v + loc[1]
            ax.plot(x, y, color=color)

        # Add scalebar
        # NOTE(review): the bar length is FACTOR_UV * thresh_amp in data units while
        # the label is the fixed string '20 m' -- confirm the label matches the length.
        fontprops = fm.FontProperties(size=18)
        scalebar = AnchoredSizeBar(ax.transData,
                                   FACTOR_UV * thresh_amp, '20 m', 'lower right',
                                   pad=0.1,
                                   color='black',
                                   frameon=False,
                                   size_vertical=1,
                                   fontproperties=fontprops)
        ax.add_artist(scalebar)

        plt.tight_layout()
        plt.show()
        # Stop after the first unit (see docstring)
        break

# Plot the first unit found in 2954/sorted.npz with unscaled waveforms, restricted
# to the x/y window below. ROOT must already be defined in the kernel.
plot_waveform_channels(
    kilosort_npz_path=ROOT / "2954/sorted.npz",
    thresh_amp=1.2, thresh_std=2,
    scale_v=1, scale_h=1,
    xlim=(700, 800), ylim=(250, 500),
    curation_colors=("red", "black")
)