In [1]:
%load_ext autoreload
%autoreload 2

# Create 400 neuron dataset

In [None]:
import os

import colorcet as cc
import matplotlib.pyplot as plt
import numpy as np
import torch
from analysis.data_gen_utils import (
    all_units_except,
    combine_datasets,
    download_IBL,
    extract_IBL,
    make_dataset,
)
from analysis.projections import learn_manifold_umap, pca, pca_train
from ceed.models.ceed import CEED
from mpl_toolkits.mplot3d import Axes3D

In [None]:
# IBL `pid` identifiers of the two sessions used to build the dataset
# (presumably probe insertion IDs — confirm against the IBL docs).
pid_sess1 = 'dab512bd-a02d-4c1f-8dbc-9155a163efc0'
pid_sess2 = 'febb430e-2d50-4f83-87a0-b5ffbb9a4943'

# Each session is stored in a subfolder of this root named after its pid.
# Deriving the folder from the pid (instead of retyping it) keeps the two in sync.
ibl_data_root = "/media/cat/data/IBL_data_CEED"
save_folder_sess1 = os.path.join(ibl_data_root, pid_sess1)
save_folder_sess2 = os.path.join(ibl_data_root, pid_sess2)

t_window = [0, 1200]  # in seconds
overwrite = False  # forwarded to download_IBL

# download_IBL returns the recording object and a metadata file path, which are
# consumed by extract_IBL / make_dataset below.
rec1, meta_file_sess1 = download_IBL(
    pid=pid_sess1, t_window=t_window, save_folder=save_folder_sess1, overwrite=overwrite
)
rec2, meta_file_sess2 = download_IBL(
    pid=pid_sess2, t_window=t_window, save_folder=save_folder_sess2, overwrite=overwrite
)

In [None]:
"""Extract all the data needed to make the CEED dataset.

spike_idx_sess: spike_times, channels, neurons (if use_labels=True)
geom_sess: channels x 2
chan_idx_sess: waveform extraction channels for each channel
templates_sess: templates across all channels for all neurons

Outputs are cached as .npy files in the current working directory, so later runs
can set recompute = False and skip extract_IBL.
"""
recompute = True

# Cache file names, listed in the same order as extract_IBL's return values.
cache_files_sess1 = [
    "spike_idx_sess1.npy",
    "geom_sess1.npy",
    "chan_idx_sess1.npy",
    "templates_sess1.npy",
]
cache_files_sess2 = [
    "spike_idx_sess2.npy",
    "geom_sess2.npy",
    "chan_idx_sess2.npy",
    "templates_sess2.npy",
]

# If the cache was never written (e.g. first run in a new working directory),
# recompute instead of failing on the first np.load.
if not recompute and not all(
    os.path.exists(path) for path in cache_files_sess1 + cache_files_sess2
):
    print("Cached extraction files not found in the working directory; recomputing.")
    recompute = True

if recompute:
    spike_idx_sess1, geom_sess1, chan_idx_sess1, templates_sess1 = extract_IBL(
        rec=rec1,
        meta_fp=meta_file_sess1,
        pid=pid_sess1,
        t_window=t_window,
        use_labels=True,
    )
    spike_idx_sess2, geom_sess2, chan_idx_sess2, templates_sess2 = extract_IBL(
        rec=rec2,
        meta_fp=meta_file_sess2,
        pid=pid_sess2,
        t_window=t_window,
        use_labels=True,
    )
    extracted_arrays = (
        spike_idx_sess1,
        geom_sess1,
        chan_idx_sess1,
        templates_sess1,
        spike_idx_sess2,
        geom_sess2,
        chan_idx_sess2,
        templates_sess2,
    )
    for path, array in zip(cache_files_sess1 + cache_files_sess2, extracted_arrays):
        np.save(path, array)
else:
    spike_idx_sess1, geom_sess1, chan_idx_sess1, templates_sess1 = (
        np.load(path) for path in cache_files_sess1
    )
    spike_idx_sess2, geom_sess2, chan_idx_sess2, templates_sess2 = (
        np.load(path) for path in cache_files_sess2
    )

In [None]:
# Session DY016: units to build the dataset from, and where to save it.
dy016_unit_ids_path = os.path.join(
    os.getcwd(), "400neuron_unit_ids", "dy016_unit_ids.npy"
)
selected_units_sess1 = np.load(dy016_unit_ids_path)
dataset_folder_sess1 = os.path.join(save_folder_sess1, "ds")

# Make the first training dataset from the DY016 units.
# This creates a folder with the spike, probe-channel-number, and corresponding
# channel-location datasets in the train/val/test splits, and can optionally
# also save the spatial and temporal noise covariance matrices.
inference = False
train_num = 200  # presumably spikes per unit (combined dataset is named "200spike") — confirm
val_num = 0
test_num = 200
save_covs = False  # skip writing noise covariance matrices
num_chans_extract = 21
normalize = False  # True for cell-type dataset
shift = False
save_fewer = False
(
    train_set1,
    val_set1,
    test_set1,
    train_geom_locs1,
    val_geom_locs1,
    test_geom_locs1,
    train_max_chan1,
    val_max_chan1,
    test_max_chan1,
) = make_dataset(
    rec=rec1,
    spike_index=spike_idx_sess1,
    geom=geom_sess1,
    save_path=dataset_folder_sess1,
    chan_index=chan_idx_sess1,
    templates=templates_sess1,
    unit_ids=selected_units_sess1,
    train_num=train_num,
    val_num=val_num,
    test_num=test_num,
    save_covs=save_covs,
    num_chans_extract=num_chans_extract,
    normalize=normalize,
    shift=shift,
    inference=inference,
    save_fewer=save_fewer,
)

In [None]:
# Session DY009: units to build the dataset from, and where to save it.
dy009_unit_ids_path = os.path.join(
    os.getcwd(), "400neuron_unit_ids", "dy009_unit_ids.npy"
)
selected_units_sess2 = np.load(dy009_unit_ids_path)
dataset_folder_sess2 = os.path.join(save_folder_sess2, "ds")

# Make the second training dataset from the DY009 units.
# This creates a folder with the spike, probe-channel-number, and corresponding
# channel-location datasets in the train/val/test splits, and can optionally
# also save the spatial and temporal noise covariance matrices.
# Settings intentionally match the DY016 dataset so the two can be combined.
inference = False
train_num = 200  # presumably spikes per unit (combined dataset is named "200spike") — confirm
val_num = 0
test_num = 200
save_covs = False  # skip writing noise covariance matrices
num_chans_extract = 21
normalize = False  # True for cell-type dataset
shift = False
save_fewer = False
(
    train_set2,
    val_set2,
    test_set2,
    train_geom_locs2,
    val_geom_locs2,
    test_geom_locs2,
    train_max_chan2,
    val_max_chan2,
    test_max_chan2,
) = make_dataset(
    rec=rec2,
    spike_index=spike_idx_sess2,
    geom=geom_sess2,
    save_path=dataset_folder_sess2,
    chan_index=chan_idx_sess2,
    templates=templates_sess2,
    unit_ids=selected_units_sess2,
    train_num=train_num,
    val_num=val_num,
    test_num=test_num,
    save_covs=save_covs,
    num_chans_extract=num_chans_extract,
    normalize=normalize,
    shift=shift,
    inference=inference,
    save_fewer=save_fewer,
)

In [18]:
# Merge the DY016 and DY009 training datasets into one larger dataset, so the
# model is trained on a wider diversity of units.
dataset_list = [
    dataset_folder_sess1,
    dataset_folder_sess2,
]
combined_ds_path = "/media/cat/data/IBL_data_CEED/400neuron_200spike_ds"
combine_datasets(dataset_list, combined_ds_path)

In [None]:
# Training configuration for the 400-neuron, 200-spike, 11-channel model.
data_dir = combined_ds_path
exp_name = "spikesorting_CEED_400n_paper_experiment"
# The trailing "" keeps the trailing separator of the original "/logs/" and "/saved_models/".
log_dir = os.path.join(data_dir, "logs", "")
ckpt_dir = os.path.join(data_dir, "saved_models", "")
batch_size = 512
num_extra_chans = 5  # 11 channels total (2 * 5 + 1)
save_metrics = True
epochs = 400
# Augmentation settings forwarded to CEED.train. The values look like per-augmentation
# application probabilities; NOTE(review): confirm what the smart_noise tuple means in CEED.
aug_p_dict = {
    "collide": 0.4,
    "crop_shift": 0.5,
    "amp_jitter": 0.7,
    "temporal_jitter": 0.6,
    "smart_noise": (0.5, 1.0),
}
# Subsample the 10-neuron dataset used in the paper from the 400 neurons;
# training metrics are reported on these units.
test_units = [11, 13, 16, 69, 84, 89, 277, 267, 332, 343]

print(test_units)
# Train the 400-neuron, 200-spike, 11-channel model benchmarked in the paper's
# supplement (results are very similar to the 1200-spike version).
# Pass the configured value rather than a duplicated literal so the model always
# matches num_extra_chans above.
ceed_test = CEED(num_extra_chans=num_extra_chans)
ceed_test.train(
    data_dir=data_dir,
    exp_name=exp_name,
    log_dir=log_dir,
    epochs=epochs,
    ckpt_dir=ckpt_dir,
    batch_size=batch_size,
    save_metrics=save_metrics,
    aug_p_dict=aug_p_dict,
    units_list=test_units,
)

In [None]:
# Reload the trained checkpoint and embed the test split of the evaluation units.
# The dataset path, checkpoint path and channel count are derived from the training
# configuration (combined_ds_path, ckpt_dir, exp_name, num_extra_chans), so they
# cannot drift from where and how the model was actually trained.
data_dir = combined_ds_path
ceed_test = CEED(num_extra_chans=num_extra_chans)
ceed_test.load(os.path.join(ckpt_dir, exp_name, ""))
fc_transformed_inference_data, fc_inference_labels = ceed_test.load_and_transform(
    data_dir=data_dir, units_list=test_units, file_split="test"
)

In [None]:
import numpy as np
from sklearn.metrics import adjusted_rand_score
from sklearn.mixture import GaussianMixture

# Results differ slightly from the paper with this generated dataset, because the
# channel-recording preprocessing/destriping step has moved from the IBL functions to
# analogous ones from SpikeInterface. To benchmark against the original test dataset,
# download it from https://uchicago.box.com/v/CEED-data-storage and run inference on
# the test dataset in that folder using the model checkpoint created in this notebook.

# Cluster the learned representations with a GMM over many random seeds and score each
# clustering with the adjusted Rand index (scaled to 0-100) against the true unit labels.
covariance_type = "full"
n_clusters = 10  # matches the 10 units in test_units
n_seeds = 100
# Clustering is unsupervised, so the GMM is fit and evaluated on the same representations.
reps_train = fc_transformed_inference_data
reps_test = fc_transformed_inference_data
scores = []
for seed in range(n_seeds):
    gmm = GaussianMixture(n_clusters, random_state=seed, covariance_type=covariance_type).fit(
        reps_test
    )
    gmm_cont_test_labels = gmm.predict(reps_test)
    score = adjusted_rand_score(fc_inference_labels, gmm_cont_test_labels) * 100
    scores.append(score)
    print(f"num_comps: {fc_transformed_inference_data.shape[1]}, rand_score: {score}")

# Summarise across seeds so the result doesn't have to be read off 100 printed lines.
print(f"ARI over {n_seeds} seeds: mean {np.mean(scores):.2f} +/- {np.std(scores):.2f}")