In [2]:
import os, sys
sys.path.insert(0, os.path.dirname(os.getcwd()))
sys.path.insert(0, os.path.dirname(os.path.dirname(os.getcwd())))
sys.path.insert(0, os.path.join(os.path.dirname(os.getcwd()), 'DeepUnitMatch'))
import UnitMatchPy.default_params as default_params
import UnitMatchPy.utils as util
import UnitMatchPy.overlord as ov
import UnitMatchPy.bayes_functions as bf
import matplotlib.pyplot as plt
import numpy as np
from DeepUnitMatch.utils import param_fun
from DeepUnitMatch.testing import test
from DeepUnitMatch.preprocess import split_units
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.save_utils as su
import UnitMatchPy.metric_functions as mf
from DeepUnitMatch.testing import test
from DeepUnitMatch.utils import helpers

In [3]:
# ======================
# User inputs (edit me)
# ======================
# Paths below are raw strings (r"..."), so use SINGLE backslashes:
# r"C:\\a" would keep both backslashes in the path.

# Paths to the Kilosort output directories (one per session)
KS_dirs = [
    r'H:\FigShare_UnitMatch\Mouse1\2019-11-21\Probe0\1',
    r'H:\FigShare_UnitMatch\Mouse1\2019-11-22\Probe0\1',
]
# KS_dirs = [r'path/to/KSdir/Session1', r'path/to/KSdir/Session2']

# Where to save final results
# save_dir = r"path\to\save\results"
save_dir = r"C:\Users\EnnyB\Documents\TMP\results"

# Where to write/read DeepUnitMatch preprocessed HDF5s.
# A `processed_waveforms/` sub-folder is created inside this directory.
# save_path = r"path\to\save"
save_path = r"C:\Users\EnnyB\Documents\TMP"

# Model/inference settings
device = "cpu"

# Match threshold on the output probabilities (lower threshold -> more matches but
# also more false positives) - recommended 0.5
thresh = 0.5


In [4]:
# Load the data exactly as UnitMatchPy does.

# Start from UnitMatchPy's default parameters (extra entries may be added before or after).
param = default_params.get_default_param()
param['KS_dirs'] = KS_dirs  # KS_dirs comes from the "User inputs" cell

wave_paths, unit_label_paths, channel_pos = util.paths_from_KS(KS_dirs)
# Probe geometry is derived from the first session's channel positions
param = util.get_probe_geometry(channel_pos[0], param)

# STEP 0 of the UnitMatchPy example notebook: load the waveforms of good units only
(waveform, session_id, session_switch,
 within_session, good_units, param) = util.load_good_waveforms(
    wave_paths,
    unit_label_paths,
    param,
    good_units_only=True,
)
param['good_units'] = good_units

Using cluster_group.tsv
Using cluster_group.tsv


In [None]:
# DeepUnitMatch preprocessing: get_snippets saves the waveforms as HDF5 files,
# one per session, under `<save_path>/processed_waveforms` (save_path: "User inputs" cell).
#
# NOTE(review): the saved output of this cell is
#   "ValueError: setting an array element with a sequence ... The detected shape was (2,) + inhomogeneous part"
# i.e. a 2-element (presumably one entry per session) ragged sequence was turned into an
# array somewhere in this cell. The traceback is not shown, so the failing call is unknown
# - TODO confirm (e.g. per-session arrays of different lengths passed where one array is expected).

# Cluster IDs, flattened into the same order as `waveform`
unit_ids = np.concatenate(param["good_units"]).squeeze()
snippets, positions = param_fun.get_snippets(
    waveform,
    channel_pos,
    session_id,
    save_path=save_path,
    unit_ids=unit_ids,
)

model = test.load_trained_model(device=device)  # the trained DeepUnitMatch network

data_dir = os.path.join(save_path, 'processed_waveforms')  # where get_snippets stored its output
sim_matrix = test.inference(model, data_dir)  # run the preprocessed data through the network

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 1 dimensions. The detected shape was (2,) + inhomogeneous part.

In [None]:
# Visualise the DeepUnitMatch similarity matrix (rows and columns are units from all sessions)
fig, ax = plt.subplots()
im = ax.imshow(sim_matrix, cmap='viridis', aspect='auto')
ax.set_title('DeepUnitMatch similarity matrix')
ax.set_xlabel('Unit index')
ax.set_ylabel('Unit index')
fig.colorbar(im, ax=ax, label='Similarity')
plt.show()

In [None]:
# Use the same Naive Bayes as in UnitMatchPy.
#
# For every session pair (r1 < r2) two scores -- the DeepUnitMatch similarity and the
# drift-corrected centroid distance -- are combined with UnitMatchPy's Naive Bayes, and the
# resulting block is written into the full `probs` / `distance_matrix` matrices.
# NOTE: within-session blocks are shared between pairs, so with more than two sessions they
# hold the values of the last pair that contains that session.

clus_info = {'good_units' : param['good_units'], 'session_switch' : session_switch, 'session_id' : session_id,
            'original_ids' : np.concatenate(param['good_units']) }
extracted_wave_properties = ov.extract_parameters(waveform, channel_pos, clus_info, param)                  # contains spatial locations
# 1 for pairs of units from different sessions, 0 for pairs within the same session
within_session = 1 - (session_id[:, None] == session_id).astype(int)
sessions = np.unique(session_id)
probs = np.zeros(sim_matrix.shape)
distance_matrix = np.zeros(sim_matrix.shape)

for r1 in sessions:
    for r2 in sessions:
        if r1 >= r2:
            continue

        # --- Restrict to the units of sessions r1 and r2 ---
        # Units are assumed to be grouped by session (r1's units before r2's), as the
        # label construction below also assumes.
        mask = np.isin(session_id, [r1, r2])
        n = int(np.sum(mask))
        n_r1 = len(param['good_units'][r1])
        n_r2 = len(param['good_units'][r2])
        if n_r1 + n_r2 != n:
            raise ValueError(f"sessions {r1}-{r2}: good_units lists {n_r1 + n_r2} units but session_id selects {n}")
        sim_mat = sim_matrix[mask][:, mask]

        # --- DeepUnitMatch candidate matches -> (n x n) 0/1 label matrix ---
        df = helpers.create_dataframe([param['good_units'][r1], param['good_units'][r2]], sim_mat, session_list=[r1, r2])
        matches = test.get_matches(df, sim_mat, session_id[mask], data_dir, dist_thresh=50)

        labels = np.eye(n)  # every unit is labelled as matching itself
        subsessionid = np.repeat([r1, r2], [n_r1, n_r2])
        for (recses1, recses2), group in matches.groupby(by=['RecSes1', 'RecSes2']):
            # assumes the rows of `group` are in row-major (unit in recses1, unit in recses2) order
            asmatrix = group['match'].values.reshape(len(param['good_units'][recses1]), len(param['good_units'][recses2])).astype(int)
            labels[np.ix_(subsessionid == recses1, subsessionid == recses2)] = asmatrix

        # --- Drift-corrected centroid distances for this pair ---
        # The arrays below only contain the units of r1 and r2, so the session boundaries
        # handed to the drift correction are expressed in this pair-local indexing
        # ([0, n_r1, n_r1 + n_r2], pair index 0) rather than the global `session_switch`
        # indexed by r1, which only lines up for the first pair (r1 == 0, r2 == 1).
        # NOTE(review): assumes drift_correct_session_pair reads session_switch[did:did + 3]
        # and returns only the corrected per-timepoint waveforms - confirm against its API.
        pair_session_switch = np.array([0, n_r1, n_r1 + n_r2])
        avg_centroid = extracted_wave_properties['avg_centroid'][:, mask, :]
        avg_waveform_per_tp = extracted_wave_properties['avg_waveform_per_tp'][:, mask, :, :]
        avg_waveform_per_tp = mf.drift_correct_session_pair(labels.astype(bool), pair_session_switch, avg_centroid, avg_waveform_per_tp, 0, param)
        avg_waveform_per_tp_flip = mf.flip_dim(avg_waveform_per_tp, param, n)
        euclid_dist = mf.get_Euclidean_dist(avg_waveform_per_tp_flip, param, n)
        centroid_dist, _ = mf.centroid_metrics(euclid_dist, param)

        # --- Diagnostics: labels next to centroid distances, one figure per pair ---
        fig, (ax_labels, ax_dist) = plt.subplots(1, 2, figsize=(10, 4))
        im = ax_labels.imshow(labels, cmap='viridis', aspect='auto')
        ax_labels.set_title(f'Labels: sessions {r1}-{r2}')
        fig.colorbar(im, ax=ax_labels)
        im = ax_dist.imshow(centroid_dist, cmap='viridis', aspect='auto')
        ax_dist.set_title(f'Centroid dist: sessions {r1}-{r2}')
        fig.colorbar(im, ax=ax_dist)
        fig.tight_layout()
        plt.show()

        # --- Naive Bayes on (similarity, centroid distance) ---
        if centroid_dist.shape != (n, n):
            raise ValueError(f"sessions {r1}-{r2}: centroid_dist has shape {centroid_dist.shape}, expected {(n, n)}")
        scores_to_incl = {
            'similarity': sim_mat,
            'distance': centroid_dist,
        }
        priors = np.array([1 - 2 / n, 2 / n])  # [P(no match), P(match)] per unit pair
        label_values = np.unique(labels)
        parameter_kernels = bf.get_parameter_kernels(scores_to_incl, labels, label_values, param)
        predictors = np.stack(list(scores_to_incl.values()), axis=2)
        probability = bf.apply_naive_bayes(parameter_kernels, priors, predictors, param, label_values)
        prob_matrix = probability[:, 1].reshape(n, n)

        pair_block = np.ix_(mask, mask)
        probs[pair_block] = prob_matrix
        distance_matrix[pair_block] = centroid_dist

In [None]:
# Visualise the Naive Bayes match-probability matrix (rows and columns are units from all sessions)
fig, ax = plt.subplots()
im = ax.imshow(probs, cmap='viridis', aspect='auto')
ax.set_title('Match probability')
ax.set_xlabel('Unit index')
ax.set_ylabel('Unit index')
fig.colorbar(im, ax=ax, label='P(match)')
plt.show()

In [None]:
# UnitMatchPy evaluation function.
# Uses the same threshold (`thresh`, from the "User inputs" cell) that selects the final
# matches below, so the evaluation and the final matching stay consistent when it is edited.
util.evaluate_output(probs, param, within_session, session_switch, match_threshold=thresh)

In [None]:
# Process the output probability matrix to get the final set of matches (across sessions).
# thresh is defined in the "User inputs" cell above
final_matches = test.directional_filter(probs, session_id, thresh)

fig, ax = plt.subplots()
im = ax.imshow(final_matches, cmap='viridis')
ax.set_title(f'Final matches (threshold = {thresh})')
ax.set_xlabel('Unit index')
ax.set_ylabel('Unit index')
fig.colorbar(im, ax=ax)
plt.show()

# Divide the final number of matches by 2 to account for double counting in the matrix
n_matches = np.sum(final_matches) // 2
print(f" Found {n_matches} matches in these sessions using the threshold of {thresh}.")

In [None]:
# Performance check: the AUC measures the agreement between the DeepUnitMatch matches
# and a functional score - here, correlations of ISI histograms.
isicorr = test.ISI_correlations(param)
auc = test.AUC(final_matches, isicorr, session_id)
print("AUC for DeepUnitMatch matches: {:.3f}".format(auc))

In [None]:
# Tracking with UnitMatch's tracking function: every neuron gets a unique ID that
# persists across sessions.
UIDs = aid.assign_unique_id(probs, param, clus_info)

# Save the results (save_dir is defined in the "User inputs" cell).
# First pull the waveform properties out of the extracted-parameter dictionary.
(amplitude, spatial_decay, avg_centroid, avg_waveform,
 avg_waveform_per_tp, wave_idx, max_site, max_site_mean) = (
    extracted_wave_properties[key]
    for key in ('amplitude', 'spatial_decay', 'avg_centroid', 'avg_waveform',
                'avg_waveform_per_tp', 'good_wave_idxs', 'max_site', 'max_site_mean')
)

su.save_to_output(
    save_dir,
    {"distance": distance_matrix},
    np.argwhere(final_matches),  # (row, col) indices of the non-zero entries of final_matches
    probs,
    avg_centroid,
    avg_waveform,
    avg_waveform_per_tp,
    max_site,
    distance_matrix,
    final_matches,
    clus_info,
    param,
    UIDs=UIDs,
    matches_curated=None,
    save_match_table=True,
)