## DeepUnitMatch training notebook

This demo notebook walks you through training DeepUnitMatch on your own data. Training is optional, since the pretrained network is able to generalise to unseen datasets, but training on your own data is likely to improve performance.

You should only proceed if you have a substantial amount of Neuropixels 2.0 data you wish to use for training.


In [None]:
import os
import sys
import threading

# Workaround for the "OMP: Error #15" crash caused by duplicate OpenMP runtimes
# (seen with PyTorch/MKL on some systems). Intel describes it as unsafe; remove it if you don't need it.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

import torch
import matplotlib.pyplot as plt

# Make the DeepUnitMatch and UnitMatchPy packages importable from a source checkout
# (paths are relative to the directory this notebook is run from)
sys.path.insert(0, os.path.dirname(os.getcwd()))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))
sys.path.insert(0, os.path.join(os.path.dirname(os.getcwd()), 'DeepUnitMatch'))

from DeepUnitMatch.utils import AE_npdataset, npdataset, param_fun
from DeepUnitMatch.train.train_AE import run_training
from DeepUnitMatch.train.train_finetune import run_finetune
from DeepUnitMatch.testing import test
import UnitMatchPy.default_params as default_params
import UnitMatchPy.utils as util

### Data loading

You need a list of the Kilosort directories you want to use for training. If you have run Bombcell on this data (recommended), these directories should also contain the Bombcell results.

In [None]:
# Load the data the same way as UnitMatch

# Get the default parameters; you can add or override your own before or after this call
param = default_params.get_default_param()

# Paths to the Kilosort directory of each session.
# If you don't have a directory containing channel_positions.npy etc., see the detailed example
# for supplying the paths separately.
KS_dirs = [r'path/to/KSdir/Session1', r'path/to/KSdir/Session2', r'path/to/KSdir/Session3', r'path/to/KSdir/Session4']

# These four paths are placeholders: for training, KS_dirs should list a large number of sessions

param['KS_dirs'] = KS_dirs
wave_paths, unit_label_paths, channel_pos = util.paths_from_KS(KS_dirs)
param = util.get_probe_geometry(channel_pos[0], param)

# STEP 0 from the UnitMatchPy example notebook: load the waveforms of good units only
waveform, session_id, session_switch, within_session, good_units, param = util.load_good_waveforms(
    wave_paths, unit_label_paths, param, good_units_only=True)
param['good_units'] = good_units
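
With many sessions, writing out `KS_dirs` by hand becomes error-prone. Below is a minimal sketch of collecting them programmatically. It assumes, hypothetically, that every session's Kilosort output sits in its own folder directly below one root directory, and it treats a folder as usable only if it contains `channel_positions.npy` (the one file this notebook names). `collect_ks_dirs` is an illustrative helper, not part of UnitMatchPy or DeepUnitMatch.

```python
from pathlib import Path


def collect_ks_dirs(root, pattern="*"):
    """Return sorted Kilosort directories under `root` that contain channel_positions.npy.

    Hypothetical helper: assumes one Kilosort output folder per session directly
    below `root`, and uses channel_positions.npy as the marker of a usable folder.
    """
    root = Path(root)
    # Only consider sub-directories matching the glob pattern, in a stable order
    candidates = sorted(p for p in root.glob(pattern) if p.is_dir())
    return [str(p) for p in candidates if (p / "channel_positions.npy").is_file()]


# Example (hypothetical path): KS_dirs = collect_ks_dirs(r"path/to/KSdir")
```

Folders that do not contain `channel_positions.npy` are silently skipped, so it is worth printing `len(KS_dirs)` to confirm that the expected number of sessions was found.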

In [None]:
save_path = r"path/to/save_directory"
# The waveform snippets used by DeepUnitMatch are saved here, in a folder called "processed_waveforms".
# Make sure save_path/processed_waveforms contains no existing data, otherwise it may be overwritten!

snippets, positions = param_fun.get_snippets(waveform, channel_pos, session_id, save_path=save_path)
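
The comment above asks you to check for existing output by hand. The following minimal sketch performs that check in code, to be run before `param_fun.get_snippets`. `check_output_dir` is a hypothetical helper that relies only on the `processed_waveforms` folder name given in the comment.

```python
import os


def check_output_dir(save_path, subdir="processed_waveforms"):
    """Raise if save_path/subdir already exists and is not empty.

    Hypothetical helper: call it before param_fun.get_snippets so that
    snippets from an earlier run are not silently overwritten.
    """
    out_dir = os.path.join(save_path, subdir)
    # An empty or missing folder is fine; any existing file is treated as a conflict
    if os.path.isdir(out_dir) and os.listdir(out_dir):
        raise FileExistsError(
            f"{out_dir} already contains data; choose another save_path or clear it first."
        )
    return out_dir
```

Failing loudly is a deliberate choice. If the helper deleted the old folder instead, a mistyped `save_path` could wipe snippets from a previous experiment.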

### Training

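Training takes place in two steps: autoencoder pretraining, then contrastive learning finetuning. Each step runs in a background thread so the notebook stays responsive, and the finetuning cell should only be run once pretraining has finished.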

In [None]:
# Autoencoder pretraining

EXPERIMENT_NAME = "DeepUnitMatch_DEMO"  # Name for the whole experiment (shared by both training steps)

AEdataset = AE_npdataset.AE_NeuropixelsDataset(save_path, batch_size=32)
# total_epoch=3 is only a quick test; 300 epochs are recommended for real training
training_thread = threading.Thread(
    target=run_training,
    args=(EXPERIMENT_NAME, AEdataset),
    kwargs={'lr': 1e-5, 'save_freq': 1, 'total_epoch': 3, 'cont': False, 'batchsize': 32},
)
training_thread.start()

In [None]:
# Contrastive learning finetuning

CLdataset = npdataset.NeuropixelsDataset(save_path, batch_size=32, mode='train')

# Only start finetuning once autoencoder pretraining has finished
if training_thread is not None and training_thread.is_alive():
    print("Autoencoder training is still running; re-run this cell once it has finished.")
else:
    # total_epoch=3 is only a quick test; 50 epochs are recommended for real training
    training_thread = threading.Thread(
        target=run_finetune,
        args=(EXPERIMENT_NAME, CLdataset),
        kwargs={'lr_enc': 2e-5, 'lr_proj': 1.1e-4, 'save_freq': 1, 'total_epoch': 3, 'cont': False, 'batchsize': 32},
    )
    training_thread.start()
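
The finetuning cell only checks whether the autoencoder thread is still alive, so you have to come back and re-run it after pretraining ends. The sketch below shows an alternative under the assumption that both steps can simply run back to back: it chains them inside a single background thread. `pretrain` and `finetune` are hypothetical stand-ins that only record the call order. In the notebook you would instead pass `run_training` and `run_finetune` with the arguments used above, after building both `AEdataset` and `CLdataset`.

```python
import threading


def run_pipeline(steps):
    """Run each (function, args, kwargs) step in order inside the calling thread.

    An exception in one step ends the loop, so later steps never start
    with a missing or half-trained model.
    """
    for func, args, kwargs in steps:
        func(*args, **kwargs)


# Hypothetical stand-ins for run_training / run_finetune that record the call order
log = []

def pretrain(name, dataset, **kwargs):
    log.append(("pretrain", name))

def finetune(name, dataset, **kwargs):
    log.append(("finetune", name))

steps = [
    (pretrain, ("DeepUnitMatch_DEMO", None), {"total_epoch": 3}),
    (finetune, ("DeepUnitMatch_DEMO", None), {"total_epoch": 3}),
]
pipeline_thread = threading.Thread(target=run_pipeline, args=(steps,))
pipeline_thread.start()
pipeline_thread.join()  # wait for both steps; omit this line to keep the notebook free
print(log)
```

Because the second step only starts after the first one returns, the `is_alive()` check is no longer needed. The trade-off is that the finetuning settings must be chosen before pretraining starts.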

### Load the trained model for inference

In [None]:
read_path = r"path/to/your/saved/model/checkpoint"  # Path to your trained model checkpoint

model = test.load_trained_model(device="cpu", read_path=read_path)
# Compute the similarity matrix for the saved waveform snippets and plot it as a heatmap
sim_matrix = test.inference(model, os.path.join(save_path, 'processed_waveforms'))
plt.imshow(sim_matrix, cmap='viridis', aspect='auto')
plt.colorbar(label='Similarity')
plt.title('DeepUnitMatch similarity matrix')
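
To look beyond the heatmap, the following minimal sketch lists the most similar other unit for each unit. It assumes `sim_matrix` is a square array (or CPU tensor convertible with `np.asarray`) indexed by unit along both axes. It is only an exploratory check, not the matching step of the DeepUnitMatch pipeline, and `top_partner` is a hypothetical helper.

```python
import numpy as np


def top_partner(sim_matrix):
    """For each row (unit), return the column index and value of its highest
    off-diagonal similarity."""
    sim = np.asarray(sim_matrix, dtype=float).copy()
    np.fill_diagonal(sim, -np.inf)  # ignore each unit's similarity to itself
    best_idx = np.argmax(sim, axis=1)
    best_val = sim[np.arange(sim.shape[0]), best_idx]
    return best_idx, best_val


# Toy 3x3 example: units 0 and 2 are each other's most similar unit
toy = np.array([[1.0, 0.2, 0.9],
                [0.2, 1.0, 0.4],
                [0.9, 0.4, 1.0]])
idx, val = top_partner(toy)
print(idx, val)
```

Running `top_partner(sim_matrix)` on your own output gives a quick list of candidate pairs to inspect before moving on to the full pipeline.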

From here, you can follow the standard demo notebook to use your trained model in the full DeepUnitMatch pipeline.