# This demo notebook is a detailed guide through the UnitMatch pipeline.

This notebook is only recommended if you want a more detailed look at UnitMatch, or if you have unconventional data.

In [None]:
%load_ext autoreload
%autoreload 

import sys
from pathlib import Path
import UnitMatchPy.param_functions as pf
import UnitMatchPy.metric_functions as mf
import UnitMatchPy.bayes_functions as bf
import UnitMatchPy.utils as util
import numpy as np
import matplotlib.pyplot as plt
import UnitMatchPy.GUI as gui
import UnitMatchPy.save_utils as su
import UnitMatchPy.assign_unique_id as aid
import UnitMatchPy.default_params as default_params

#### Load in necessary data, individually 
(Not recommended)

In [None]:
# Alternative to the KiloSort-directory loader below: supply every path by hand.
# The paths_from_KS route below is recommended.
# Get default parameters; you can add or override entries before or after this call.
param = default_params.get_default_param()

# Load in the data
# Channel positions, i.e. the location of every active channel
channel_pos1 = np.load(r'Path\to\channel_positions.npy')
# Prepend a column of ones (inserted at column index 0 along axis 1) so each channel
# position gains an extra leading coordinate; this allows easy extension to 3-D coords.
channel_pos1 = np.insert(channel_pos1, 0, np.ones(channel_pos1.shape[0]), axis=1)

# Path to the average waveforms (RawWaveforms folder) for each session
wave_path1 = r'Path\to\RawWaveforms'
wave_path2 = r'Path\to\RawWaveforms'

# Path to a .tsv file whose second column contains 'good' for every unit you want to include
unit_label_path1 = r'Path\to\cluster_group.tsv'
unit_label_path2 = r'Path\to\cluster_group.tsv'

# The loaders expect the paths as lists, one entry per session
wave_paths = [wave_path1, wave_path2]
unit_label_paths = [unit_label_path1, unit_label_path2]
channel_pos = [channel_pos1, channel_pos1]  # one entry per session; usually the same layout

# The KiloSort-directory cell below always derives the probe geometry from the channel
# positions; do the same here so both loading routes leave `param` in the same state.
param = util.get_probe_geometry(channel_pos[0], param)

#### Load data from a KiloSort directory
Each session's directory needs to contain a channel_positions.npy file, a cluster_group.tsv file and a RawWaveforms folder.

In [None]:
# Recommended route: point UnitMatch at one KiloSort output directory per session.
# Start from the default parameters (extend or override them before or after this call).
param = default_params.get_default_param()

# One KiloSort directory per session; each must contain a 'RawWaveforms' folder
KS_dirs = [
    r'path/to/KSdir/Session1',
    r'Path/to/KSdir/Session2',
]
param['KS_dirs'] = KS_dirs

# Collect the waveform folders, unit-label files and channel positions for every session
wave_paths, unit_label_paths, channel_pos = util.paths_from_KS(KS_dirs)

# Add probe-geometry entries to param, using the first session's channel positions
param = util.get_probe_geometry(channel_pos[0], param)

In [None]:
# Read the unit labels, keep the selected units and load their waveforms plus session metadata.
# Pass good=False to get_good_units to load ALL units instead of only the 'good' ones.
good_units = util.get_good_units(unit_label_paths, good=True)
waveform, session_id, session_switch, within_session, param = util.load_good_units(
    good_units, wave_paths, param
)

# One-step alternative to the two calls above:
# waveform, session_id, session_switch, within_session, param = util.load_good_waveforms(wave_paths, unit_label_paths, param)

# clus_info bundles the unit-id / session bookkeeping passed to the later pipeline steps;
# 'original_ids' is the per-session good-unit lists concatenated into one array.
clus_info = dict(
    good_units=good_units,
    session_switch=session_switch,
    session_id=session_id,
    original_ids=np.concatenate(good_units),
)

#### Run the Unit Match process
1. Extract parameters from the waveforms, e.g. amplitudes, weighted-average waveforms and spatial decay lengths
2. Calculate metrics/scores for matching, e.g. the amplitude score and waveform similarity
3. Use the putative matches to estimate the drift between sessions (this can be done per shank for 2.0 probes)
4. Re-calculate the metrics/scores with the drift-corrected values
5. Use a naive Bayes classifier to get suggested 'matches' and 'non-matches'
6. (Optional) Run the GUI to curate the suggested matches and investigate the UnitMatch results

In [None]:
# Step 1: extract per-unit parameters from the waveforms.

# Detrend the raw waveforms before any measurement (see pf.detrend_waveform)
waveform = pf.detrend_waveform(waveform)

# Max site(s) per unit plus the associated good channel indices / positions
max_site, good_idx, good_pos, max_site_mean = pf.get_max_sites(
    waveform, channel_pos, clus_info, param
)

# Spatial decay (and its fit), d_10, average centroid, average waveform and peak time
(spatial_decay_fit, spatial_decay, d_10,
 avg_centroid, avg_waveform, peak_time) = pf.decay_and_average_waveform(
    waveform, channel_pos, good_idx, max_site, max_site_mean, clus_info, param
)

# Amplitudes, plus waveform / avg_waveform returned again after shifting using peak_time
amplitude, waveform, avg_waveform = pf.get_amplitude_shift_waveform(
    waveform, avg_waveform, peak_time, param
)

# Waveform duration, per-time-point average waveform positions and good waveform indices
waveform_duration, avg_waveform_per_tp, good_wave_idxs = pf.get_avg_waveform_per_tp(
    waveform, channel_pos, d_10, max_site_mean, amplitude, avg_waveform, clus_info, param
)



In [None]:
# Step 2: turn the extracted parameters into pairwise metrics/scores.

# Scores built from a single per-unit parameter each
amp_score = mf.get_simple_metric(amplitude)
spatial_decay_score = mf.get_simple_metric(spatial_decay)
spatial_decay_fit_score = mf.get_simple_metric(spatial_decay_fit, outlier=True)

# Waveform-shape scores: correlation and mean-squared-error of the average waveforms
wave_corr_score = mf.get_wave_corr(avg_waveform, param)
wave_mse_score = mf.get_waveforms_mse(avg_waveform, param)

# Spatial metrics from the (flipped) per-time-point average waveform positions
avg_waveform_per_tp_flip = mf.flip_dim(avg_waveform_per_tp, param)
euclid_dist = mf.get_Euclidean_dist(avg_waveform_per_tp_flip, param)
centroid_dist, centroid_var = mf.centroid_metrics(euclid_dist, param)

# Euclidean distance recomputed relative to each unit's average centroid
euclid_dist_rc = mf.get_recentered_euclidean_dist(avg_waveform_per_tp_flip, avg_centroid, param)
centroid_dist_recentered = mf.recentered_metrics(euclid_dist_rc)

# Trajectory scores (angle and distance)
traj_angle_score, traj_dist_score = mf.dist_angle(avg_waveform_per_tp_flip, param)

In [None]:
# Collate the metrics and find the putative matches (first pass, before drift correction).

# Reduce the distance array to one value per unit pair: keep only the slice where
# param['peak_loc'] - param['waveidx'] == 0, then take the NaN-ignoring minimum over axis 1.
euclid_dist = np.nanmin(
    euclid_dist[:, param['peak_loc'] - param['waveidx'] == 0, :, :].squeeze(), axis=1
)

# Array indices of the unit pairs close enough to be considered
include_these_pairs = np.argwhere(euclid_dist < param['max_dist'])

# Combine related metrics into the scores that make up the total score
centroid_overlord_score = (centroid_dist_recentered + centroid_var) / 2
waveform_score = (wave_corr_score + wave_mse_score) / 2
trajectory_score = (traj_angle_score + traj_dist_score) / 2

scores_to_include = {
    'amp_score': amp_score,
    'spatial_decay_score': spatial_decay_score,
    'centroid_overlord_score': centroid_overlord_score,
    'centroid_dist': centroid_dist,
    'waveform_score': waveform_score,
    'trajectory_score': trajectory_score,
}

total_score, predictors = mf.get_total_score(scores_to_include, param)

# Initial thresholding
thrs_opt = mf.get_threshold(total_score, within_session, euclid_dist, param, is_first_pass=True)

# Pairs above the threshold are the putative matches used to estimate drift
candidate_pairs = total_score > thrs_opt

# Use the same snake_case key that the drift-corrected pass and the Bayes set-up use
param['n_expected_matches'] = np.sum(candidate_pairs.astype(int))
prior_match = 1 - (param['n_expected_matches'] / len(include_these_pairs))


In [None]:
# Step 3: estimate the drift between sessions from the putative matches; this returns
# the per-session drifts together with drift-corrected centroids and per-time-point positions.
drifts, avg_centroid, avg_waveform_per_tp = mf.drift_n_sessions(
    candidate_pairs, session_switch, avg_centroid, avg_waveform_per_tp, total_score, param
)

In [None]:
# Step 4: recompute the spatial metrics with the drift-corrected arrays.
# Amplitude, spatial-decay and waveform-shape scores are not recomputed; the
# values from the first pass are reused.

avg_waveform_per_tp_flip = mf.flip_dim(avg_waveform_per_tp, param)
euclid_dist = mf.get_Euclidean_dist(avg_waveform_per_tp_flip, param)
centroid_dist, centroid_var = mf.centroid_metrics(euclid_dist, param)

euclid_dist_rc = mf.get_recentered_euclidean_dist(avg_waveform_per_tp_flip, avg_centroid, param)
centroid_dist_recentered = mf.recentered_metrics(euclid_dist_rc)
traj_angle_score, traj_dist_score = mf.dist_angle(avg_waveform_per_tp_flip, param)

# One distance per unit pair (same reduction as in the first pass)
euclid_dist = np.nanmin(
    euclid_dist[:, param['peak_loc'] - param['waveidx'] == 0, :, :].squeeze(), axis=1
)

# Pairs close enough to be considered, both as array indices (MATLAB: IncludeThesePairs[:,1])
# and as a 0/1 mask with the same shape and dtype as total_score
close_enough = euclid_dist < param['max_dist']
include_these_pairs = np.argwhere(close_enough)
include_these_pairs_idx = np.zeros_like(total_score)
include_these_pairs_idx[close_enough] = 1

# Rebuild the combined scores
centroid_overlord_score = (centroid_dist_recentered + centroid_var) / 2
waveform_score = (wave_corr_score + wave_mse_score) / 2
trajectory_score = (traj_angle_score + traj_dist_score) / 2

scores_to_include = {
    'amp_score': amp_score,
    'spatial_decay_score': spatial_decay_score,
    'centroid_overlord_score': centroid_overlord_score,
    'centroid_dist': centroid_dist,
    'waveform_score': waveform_score,
    'trajectory_score': trajectory_score,
}

total_score, predictors = mf.get_total_score(scores_to_include, param)
thrs_opt = mf.get_threshold(total_score, within_session, euclid_dist, param, is_first_pass=False)

# Expected number of matches and the corresponding fraction of included pairs
param['n_expected_matches'] = np.sum((total_score > thrs_opt).astype(int))
prior_match = 1 - (param['n_expected_matches'] / len(include_these_pairs))


In [None]:
# Step 5a: set up the naive Bayes classifier.

# Re-threshold the total score at the prior_match quantile of the included pairs, so that
# about n_expected_matches / len(include_these_pairs) of them lie above it.
thrs_opt = np.quantile(total_score[include_these_pairs_idx.astype(bool)], prior_match)
candidate_pairs = total_score > thrs_opt

# Class priors over all n_units**2 unit pairs (you can change these values).
# Entry 0 is 1 - n_expected_matches / n_units**2; entry 1 is its complement.
prior_match = 1 - (param['n_expected_matches'] / param['n_units'] ** 2)
Priors = np.array((prior_match, 1 - prior_match))

# Training labels: 1 for candidate matches, 0 otherwise; cond lists the label values present.
# parameter_kernels is not pre-allocated here: the next cell creates it directly
# with bf.get_parameter_kernels, which does not take a pre-allocated array.
labels = candidate_pairs.astype(int)
cond = np.unique(labels)


In [None]:
# Step 5b: build the per-score kernels for each label and apply the naive Bayes classifier.
parameter_kernels = bf.get_parameter_kernels(scores_to_include, labels, cond, param, add_one=1)
probability = bf.apply_naive_bayes(parameter_kernels, Priors, predictors, param, cond)

# Take column 1 of the classifier output (presumably the label-1 / 'match' class — confirm
# against bf.apply_naive_bayes) and reshape it into an n_units x n_units matrix.
n_units = param['n_units']
output_prob_matrix = probability[:, 1].reshape(n_units, n_units)

In [None]:
# Optional: summarise the UnitMatch output.
match_threshold = param['match_threshold']
# match_threshold = ...  # try different values here!

# Pass the threshold chosen above (rather than a hard-coded value) so changes to it take effect
util.evaluate_output(output_prob_matrix, param, within_session, session_switch, match_threshold=match_threshold)

Set a match threshold and look at the output

In [None]:
# Binarise the match-probability matrix with the chosen threshold and display it.
match_threshold = param['match_threshold']
# match_threshold = ...  # try different values here!

# 1 where the match probability exceeds the threshold, 0 elsewhere (same dtype as the input)
output_threshold = (output_prob_matrix > match_threshold).astype(output_prob_matrix.dtype)

plt.imshow(output_threshold, cmap='Greys')
plt.colorbar()

In [None]:
# Prepare everything the curation GUI displays and hand it over.
gui.process_info_for_GUI(
    output_prob_matrix, match_threshold, scores_to_include, total_score,
    amplitude, spatial_decay, avg_centroid, avg_waveform, avg_waveform_per_tp,
    good_wave_idxs, max_site, max_site_mean, waveform, within_session,
    channel_pos, clus_info, param,
)

#### Run the GUI
See GUI_Reference_Guide.md for information on how to use the GUI effectively!

In [None]:
# Launch the curation GUI. Returned values:
#   is_match, not_match - the pairs marked as a match / not a match in the GUI
#   matches_GUI         - a list of 2 sets of matches, one per CV pair; each array is
#                         symmetric, i.e. it contains both (x, y) and (y, x)
is_match, not_match, matches_GUI = gui.run_GUI()

In [None]:
# Index pairs whose match probability is above the threshold.
# All such pairs, including those where both units come from the same session:
matches_within_session = np.argwhere(output_threshold == 1)
# Only pairs that survive multiplication by within_session (excludes same-session matches)
matches = np.argwhere((output_threshold * within_session) == 1)

# curate_matches has two modes: 'and' keeps a match only if it appears in both CV sets,
# 'or' keeps it if it appears in either. It then adds every pair marked is_match and
# removes every pair marked not_match.
matches_curated = util.curate_matches(matches_GUI, is_match, not_match, mode='and')

In [None]:
# Assign unique IDs (UIDs) to the units from the match probabilities
UIDs = aid.assign_unique_id(output_prob_matrix, param, clus_info)

save_dir = r'Path/to/save/directory'
# NOTE: if you curated matches with the GUI, pass matches_curated instead of matches
su.save_to_output(
    save_dir, scores_to_include, matches, output_prob_matrix,
    avg_centroid, avg_waveform, avg_waveform_per_tp, max_site,
    total_score, output_threshold, clus_info, param,
    UIDs=UIDs, matches_curated=None, save_match_table=True,
)

# Alternative: save the output with the cross-validation pairs split up
# su.save_to_output_seperate_CV(save_dir, scores_to_include, matches, output_prob_matrix, avg_centroid, avg_waveform, avg_waveform_per_tp, max_site,
#                               total_score, output_threshold, clus_info, param, UIDs=UIDs, matches_curated=None, save_match_table=True)
