# This demo notebook is a guide to the simplest way to run UnitMatch in Python

In [1]:
# %load_ext autoreload
# %autoreload 

import sys
from pathlib import Path

import UMPy.Bayes_fun as bf
import UMPy.utils as util
import UMPy.Overlord as ov
import numpy as np
import matplotlib.pyplot as plt
import UMPy.Save_utils as su
import UMPy.GUI as gui

  MainAx.set_xlabel('Xpos ($\mu$m)', size = 14)
  MainAx.set_ylabel('Ypos ($\mu$m)', size = 14)


#### Set params and give the path to the input data (typically a KiloSort directory containing the extracted RawWaveform folder)

In [None]:
# Start from the default UnitMatch parameter set; entries can be added to or
# overridden in `param` either before or after this point.
param = util.get_default_param()

# Paths to the KiloSort output directory of each session (one entry per session).
# If a directory lacks channel_positions.npy etc., see the detailed example,
# which shows how to supply the paths separately.
KSdirs = [
    r'path/to/KSdir/Session1',
    r'Path/to/KSdir/Session2',
]

# Unpacked (per the names used downstream) as waveform paths, unit-label paths
# and channel positions.
WavePaths, UnitLabelPaths, ChannelPos = util.paths_fromKS(KSdirs)

#### Run the Unit Match process
1. Extract parameters from the waveforms, e.g. amplitudes, weighted-average waveforms and spatial decay lengths
2. Calculate metrics/scores for matching, e.g. amplitude score and waveform similarity
3. Using putative matches, estimate the drift correction between sessions (can be done per shank for 2.0 probes)
4. Re-calculate the metrics/scores with the drift-corrected metrics
5. Use a naive Bayes classifier to get suggested 'matches' and 'non-matches'
6. (Optional) Run the GUI to curate the suggested matches and investigate the UnitMatch results

In [None]:
# Read in the data, keep only the good units and extract the session metadata.
# load_good_waveforms does loading and good-unit selection in a single call; it
# also returns `param`, which replaces the dict that was passed in.
waveform, SessionID, SessionSwitch, WithinSession, GoodUnits, param = util.load_good_waveforms(WavePaths, UnitLabelPaths, param)

# ClusInfo bundles all unit-ID / session related information.
# 'OriginalID' is the concatenation of the per-session GoodUnits.
ClusInfo = {'GoodUnits' : GoodUnits, 'SessionSwitch' : SessionSwitch, 'SessionID' : SessionID,
            'OriginalID' : np.concatenate(GoodUnits) }

# Step 1: extract parameters from the waveforms.
ExtractedWaveProperties = ov.extract_parameters(waveform, ChannelPos, ClusInfo, param)

# Steps 2-4: compute the metric scores used for matching.
# niter=2 presumably covers the score -> drift-correct -> re-score passes
# described in the step list — confirm in UMPy.Overlord.
TotalScore, CandidatePairs, Scores2Include, Predictors = ov.extract_metric_scores(ExtractedWaveProperties, SessionSwitch, WithinSession, param, niter = 2)

# Step 5: naive Bayes classification.
# nExpectedMatches / nUnits**2 is the expected fraction of matching pairs among
# all nUnits x nUnits pairs, so despite its name `priorMatch` is the prior of the
# NON-match class. Priors is therefore ordered (non-match, match), presumably
# aligned with Cond = [0, 1] below — confirm against UMPy.Bayes_fun.
priorMatch = 1 - (param['nExpectedMatches'] / param['nUnits']**2)
Priors = np.array((priorMatch, 1 - priorMatch))

# Training labels for the kernels: the candidate-pair matrix as integers.
labels = CandidatePairs.astype(int)
Cond = np.unique(labels)

# get_ParameterKernels builds and returns the full kernel array, so no
# pre-allocation is needed here.
ParameterKernels = bf.get_ParameterKernels(Scores2Include, labels, Cond, param, addone = 1)

Probability = bf.apply_naive_bayes(ParameterKernels, Priors, Predictors, param, Cond)

# Column 1 of Probability is used downstream as the match probability; reshape
# it into an nUnits x nUnits matrix of pairwise match probabilities.
Output = Probability[:, 1].reshape(param['nUnits'], param['nUnits'])

In [None]:
# Report on the classifier output. NOTE: evaluation here uses a fixed threshold
# of 0.75, independent of the MatchThreshold selected below.
util.evaluate_output(Output, param, WithinSession, SessionSwitch, MatchThreshold=0.75)

# Threshold that turns match probabilities into binary match calls.
# Try different values here by overriding MatchThreshold.
MatchThreshold = param['MatchThreshold']

# Binary match matrix with the same shape and dtype as Output: 1 where the
# match probability exceeds the threshold, 0 elsewhere (NaN compares False -> 0).
OutputThreshold = (Output > MatchThreshold).astype(Output.dtype)

plt.imshow(OutputThreshold, cmap='Greys')


In [None]:
# Pull the individual waveform properties out of the extraction results,
# looking the keys up in the order listed.
(Amplitude, SpatialDecay, AvgCentroid, AvgWaveform, AvgWaveformPerTP,
 WaveIdx, MaxSite, MaxSiteMean) = (
    ExtractedWaveProperties[key]
    for key in ('Amplitude', 'SpatialDecay', 'AvgCentroid', 'AvgWaveform',
                'AvgWaveformPerTP', 'WaveIdx', 'MaxSite', 'MaxSiteMean')
)

# Hand everything the GUI needs over to UMPy.GUI (argument order matters:
# the call is positional).
gui.process_info_for_GUI(
    Output, MatchThreshold, Scores2Include, TotalScore,
    Amplitude, SpatialDecay, AvgCentroid, AvgWaveform, AvgWaveformPerTP,
    WaveIdx, MaxSite, MaxSiteMean,
    waveform, WithinSession, ChannelPos, ClusInfo, param,
)


In [None]:
# Optional: launch the curation GUI. Its three return values are unpacked as
# below — presumably the user-marked matches, non-matches and the GUI's match
# list; confirm in UMPy.GUI.run_GUI.
(IsMatch,
 NotMatch,
 MatchesGUI) = gui.run_GUI()

In [None]:
# (row, col) index pairs of every entry of the binary match matrix equal to 1,
# i.e. every unit pair whose match probability exceeded MatchThreshold.
# This must test OutputThreshold: comparing the scalar MatchThreshold instead
# makes argwhere return no (row, col) pairs, so no matches would be saved.
Matches = np.argwhere(OutputThreshold == 1)

SaveDir = r'Path/to/save/directory'
# Save the UnitMatch results; MatchesCurated=None means no GUI-curated matches
# are passed in this call.
su.save_to_output(SaveDir, Scores2Include, Matches, Output, AvgCentroid, AvgWaveform, AvgWaveformPerTP, MaxSite,
                   TotalScore, OutputThreshold, ClusInfo, param, MatchesCurated = None, SaveMatchTable = True)