In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from analysis.data_gen_utils import download_IBL, extract_IBL, make_dataset, combine_datasets, all_units_except

### Download and destripe AP data from two IBL sessions, each identified by its PID.

In [None]:
# IBL PIDs of the two sessions used to build the CEED datasets.
pid_sess1 = 'dab512bd-a02d-4c1f-8dbc-9155a163efc0'
pid_sess2 = 'febb430e-2d50-4f83-87a0-b5ffbb9a4943'

# Each session is stored under <ibl_data_root>/<pid>. Deriving the folder from the PID
# keeps the two from drifting apart if either is edited.
ibl_data_root = '/media/cat/data/IBL_data_CEED'
save_folder_sess1 = f'{ibl_data_root}/{pid_sess1}'
save_folder_sess2 = f'{ibl_data_root}/{pid_sess2}'

t_window = [0, 1000]  # in seconds
overwrite = False  # passed to download_IBL; presumably False reuses existing downloads — TODO confirm
rec1, meta_file_sess1 = download_IBL(pid=pid_sess1, t_window=t_window, save_folder=save_folder_sess1, overwrite=overwrite)
rec2, meta_file_sess2 = download_IBL(pid=pid_sess2, t_window=t_window, save_folder=save_folder_sess2, overwrite=overwrite)

In [None]:
# Extract all the data needed to make a CEED dataset for both sessions:
#   spike_idx_sess: spike_times, channels, neurons (if use_labels=True)
#   geom_sess: channels x 2
#   chan_idx_sess: waveform extraction channels for each channel
#   templates_sess: templates across all channels for all neurons
# Results are cached as .npy files in the current working directory. Extraction runs when
# recompute=True OR when any cache file is missing, so Restart & Run All on a fresh checkout
# does not crash on np.load.
recompute = False

cache_files = [f'{name}_sess{sess}.npy'
               for sess in (1, 2)
               for name in ('spike_idx', 'geom', 'chan_idx', 'templates')]
need_extract = recompute or not all(Path(fp).exists() for fp in cache_files)

if need_extract:
    spike_idx_sess1, geom_sess1, chan_idx_sess1, templates_sess1 = extract_IBL(rec=rec1, 
                                                                               meta_fp=meta_file_sess1, 
                                                                               pid=pid_sess1, 
                                                                               t_window=t_window, 
                                                                               use_labels=True)
    spike_idx_sess2, geom_sess2, chan_idx_sess2, templates_sess2 = extract_IBL(rec=rec2, 
                                                                               meta_fp=meta_file_sess2, 
                                                                               pid=pid_sess2, 
                                                                               t_window=t_window,
                                                                               use_labels=True)
    np.save('spike_idx_sess1.npy', spike_idx_sess1)
    np.save('geom_sess1.npy', geom_sess1)
    np.save('chan_idx_sess1.npy', chan_idx_sess1)
    np.save('templates_sess1.npy', templates_sess1)
    np.save('spike_idx_sess2.npy', spike_idx_sess2)
    np.save('geom_sess2.npy', geom_sess2)
    np.save('chan_idx_sess2.npy', chan_idx_sess2)
    np.save('templates_sess2.npy', templates_sess2)
else:
    spike_idx_sess1 = np.load('spike_idx_sess1.npy')
    geom_sess1 = np.load('geom_sess1.npy')
    chan_idx_sess1 = np.load('chan_idx_sess1.npy')
    templates_sess1 = np.load('templates_sess1.npy')
    spike_idx_sess2 = np.load('spike_idx_sess2.npy')
    geom_sess2 = np.load('geom_sess2.npy')
    chan_idx_sess2 = np.load('chan_idx_sess2.npy')
    templates_sess2 = np.load('templates_sess2.npy')

In [None]:
# Plot one denoised template on the extraction channels of its peak channel
# (not all of these channels are used to train CEED).
template_id = 454
template = templates_sess1[template_id]
# Peak-to-peak along axis 0 gives one amplitude per channel (axis 1 is indexed by channel below).
max_channel = np.argmax(np.ptp(template, 0))

fig, ax = plt.subplots()
ax.plot(template[:, chan_idx_sess1[max_channel]])
ax.set(title=f'Session 1 template {template_id} (peak channel {max_channel})',
       xlabel='sample index', ylabel='amplitude');

In [None]:
save_fewer: bool = True  # forwarded as save_fewer= to every make_dataset call below

In [None]:
# Session 1: which units to draw spikes from, and where to write the dataset.
# Unit templates can be inspected to decide whether a unit's spikes belong in a CEED dataset.
selected_units_sess1 = np.arange(400)
dataset_folder_sess1 = save_folder_sess1 + '/ds'

# Training-dataset settings. make_dataset creates a folder with the spike, probe channel number,
# and corresponding channel location datasets for the train / val / test splits, and can
# optionally also save spatial and temporal noise covariance matrices.
train_num, val_num, test_num = 400, 0, 100
inference = False
save_covs = True
num_chans_extract = 21
normalize = False  # True for a cell-type dataset
shift = True

ds_kwargs_sess1 = dict(
    rec=rec1,
    spike_index=spike_idx_sess1,
    geom=geom_sess1,
    save_path=dataset_folder_sess1,
    chan_index=chan_idx_sess1,
    templates=templates_sess1,
    unit_ids=selected_units_sess1,
    train_num=train_num,
    val_num=val_num,
    test_num=test_num,
    save_covs=save_covs,
    num_chans_extract=num_chans_extract,
    normalize=normalize,
    shift=shift,
    inference=inference,
    save_fewer=save_fewer,
)
(train_set1, val_set1, test_set1,
 train_geom_locs1, val_geom_locs1, test_geom_locs1,
 train_max_chan1, val_max_chan1, test_max_chan1) = make_dataset(**ds_kwargs_sess1)

In [None]:
# Session 2: build a larger dataset from most of the recording's units, excluding units that
# have poor spikes in this recording or that Kilosort recognized poorly. As a first pass, save
# out ALL units (unit_ids=None) and plot some spikes per unit (plot=True), then inspect the
# plots to decide which units have good, useful spikes for training.
selected_units_sess2 = None
dataset_folder_sess2 = save_folder_sess2 + '/ds'

# Make a dataset for training. This creates a folder with the spike, probe channel number, and
# corresponding channel location datasets for the train, val, test splits, and optionally
# also saves spatial and temporal noise covariance matrices.
inference = False
train_num = 400
val_num = 0
test_num = 100
save_covs = True
num_chans_extract = 21
normalize = False
shift = True

# Every output carries the session-2 suffix, matching the session-1 naming (train_geom_locs1, ...).
train_set2, val_set2, test_set2, train_geom_locs2, val_geom_locs2, \
test_geom_locs2, train_max_chan2, val_max_chan2, test_max_chan2 = make_dataset(rec=rec2, 
                                                                               spike_index=spike_idx_sess2,
                                                                               geom=geom_sess2, 
                                                                               save_path=dataset_folder_sess2, 
                                                                               chan_index=chan_idx_sess2, 
                                                                               templates=templates_sess2, 
                                                                               unit_ids=selected_units_sess2, 
                                                                               train_num=train_num, 
                                                                               val_num=val_num, 
                                                                               test_num=test_num, 
                                                                               save_covs=save_covs, 
                                                                               num_chans_extract=num_chans_extract, 
                                                                               plot=True,
                                                                               normalize=normalize, 
                                                                               shift=shift, 
                                                                               inference=inference,
                                                                               save_fewer=save_fewer)

In [None]:
# After looking at the plotted units in session 2's dataset folder (under wf_plots), exclude
# the units judged bad. If too many units turn out bad, define selected_units_sess2 directly
# as a list of the good units instead of excluding the bad ones.
bad_units_sess2 = [130, 250, 340]
selected_units_sess2 = all_units_except(spike_index=spike_idx_sess2, exclude_units=bad_units_sess2)

# Rebuild the same dataset, overwriting the previous files with ones that omit the excluded units.
# Every output carries the session-2 suffix, matching the session-1 naming (train_geom_locs1, ...).
train_set2, val_set2, test_set2, train_geom_locs2, val_geom_locs2, \
test_geom_locs2, train_max_chan2, val_max_chan2, test_max_chan2 = make_dataset(rec=rec2, 
                                                                               spike_index=spike_idx_sess2,
                                                                               geom=geom_sess2, 
                                                                               save_path=dataset_folder_sess2, 
                                                                               chan_index=chan_idx_sess2, 
                                                                               templates=templates_sess2, 
                                                                               unit_ids=selected_units_sess2, 
                                                                               train_num=train_num, 
                                                                               val_num=val_num, 
                                                                               test_num=test_num, 
                                                                               save_covs=save_covs, 
                                                                               num_chans_extract=num_chans_extract, 
                                                                               plot=False,
                                                                               normalize=normalize, 
                                                                               shift=shift, 
                                                                               inference=inference,
                                                                               save_fewer=save_fewer)

In [None]:
# Merge the session 1 and session 2 training datasets into one dataset with more unit diversity.
dataset_list = [dataset_folder_sess1, dataset_folder_sess2]
combined_ds_path = '/media/cat/data/IBL_data_CEED/combined'
combine_datasets(dataset_list, combined_ds_path)

In [None]:
# Session 1 units to draw inference spikes from, and where to save the inference dataset.
selected_units_inference = np.arange(400)
inference_ds_path = save_folder_sess1 + '/ds_inference'

# Make an inference dataset. With inference=True, make_dataset returns only the test split
# (spikes, channel locations, max channels). save_covs optionally also saves spatial and
# temporal noise covariance matrices.
inference = True
test_num = 500
save_covs = True
num_chans_extract = 21
normalize = False
shift = True
# Bind train/val sizes here instead of silently inheriting them from earlier cells. These are
# the same values the session-2 cells set, so results are unchanged; they are presumably
# ignored when inference=True — TODO confirm in make_dataset.
train_num = 400
val_num = 0

test_set, test_geom_locs, test_max_chan = make_dataset(rec=rec1, 
                                                       spike_index=spike_idx_sess1,
                                                       geom=geom_sess1, 
                                                       save_path=inference_ds_path, 
                                                       chan_index=chan_idx_sess1, 
                                                       templates=templates_sess1, 
                                                       unit_ids=selected_units_inference, 
                                                       train_num=train_num, 
                                                       val_num=val_num, 
                                                       test_num=test_num, 
                                                       save_covs=save_covs, 
                                                       num_chans_extract=num_chans_extract, 
                                                       normalize=normalize, 
                                                       shift=shift, 
                                                       inference=inference,
                                                       save_fewer=save_fewer)

In [None]:
# Plot the mean inference-set waveform of one unit.
# assumes labels_test.npy is row-aligned with test_set (one label per spike) — TODO confirm
labels_test = np.load(inference_ds_path + '/labels_test.npy')

template_id = 145
unit_mask = labels_test == template_id
# np.mean over zero rows would silently return NaN (with a RuntimeWarning); fail loudly instead.
assert unit_mask.any(), f'unit {template_id} has no spikes in the inference test set'

fig, ax = plt.subplots()
ax.plot(np.mean(test_set[unit_mask], 0).T)
ax.set(title=f'Mean inference-set waveform, unit {template_id}',
       xlabel='sample index', ylabel='amplitude');