# This demo notebook is a detailed guide through the Unit Match pipeline.

This notebook is only recommended if you want a more detailed look at Unit Match, or have unconventional data.

In [None]:
%load_ext autoreload
%autoreload 

import sys
from pathlib import Path
sys.path[0] = str(Path(sys.path[0]).parent)

import Param_fun as pf
import Metrics_fun as mf
import Bayes_fun as bf
import utils as util
import numpy as np
import matplotlib.pyplot as plt
import GUI as gui
import Save_utils as su

#### Load in necessary data, individually 
(Not recommended)

In [None]:
# Paths to the individual files can be supplied directly, although using the
# paths_fromKS function (see the KiloSort section) is the recommended route.

# Start from the default parameters; your own can be added before or after.
param = util.get_default_param()

# Channel positions, i.e. the locations of the active channels.
ChannelPos1 = np.load(r'Path\to\channel_positions.npy')
# Make the positions 3-D by inserting a column of ones at index 0 along
# axis 1, which allows easy extension to 3-D coordinates.
ChannelPos1 = np.insert(
    ChannelPos1,
    0,
    np.ones(ChannelPos1.shape[0]),
    axis=1,
)

# Path to the average waveforms of each session.
WavePath1 = r'Path\to\RawWaveforms'
WavePath2 = r'Path\to\RawWaveforms'

# Path to a tsv file per session, where the second column contains 'good'
# for every unit you want to include.
UnitLabelPath1 = r'Path\to\cluster_group.tsv'
UnitLabelPath2 = r'Path\to\cluster_group.tsv'

# The paths must be collected into lists, one entry per session.
WavePaths = [WavePath1, WavePath2]
UnitLabelPaths = [UnitLabelPath1, UnitLabelPath2]
# Channel positions are wanted per session, but will usually be the same.
ChannelPos = [ChannelPos1, ChannelPos1]

#### Load data from a KiloSort directory
This directory needs to have a channel_positions.npy, a cluster_group.tsv and a RawWaveforms folder per session

In [None]:
# Start from the default parameters; your own can be added before or after.
param = util.get_default_param()

# One KiloSort directory per session (each containing 'RawWaveforms').
KSdirs = [
    r'path/to/KSdir/Session1',
    r'Path/to/KSdir/Session2',
]
WavePaths, UnitLabelPaths, ChannelPos = util.paths_fromKS(KSdirs)

In [None]:
# Read in the data, select the good units and extract the metadata.

# Use good=False to load in ALL units.
GoodUnits = util.get_good_units(UnitLabelPaths, good=True)
(
    waveform,
    SessionID,
    SessionSwitch,
    WithinSession,
    param,
) = util.load_good_units(GoodUnits, WavePaths, param)

# One-step version of the two calls above:
#waveform, SessionID, SessionSwitch, WithinSession, GoodUnits, param = util.load_good_waveforms(WavePaths, UnitLabelPaths, param)

# ClusInfo contains all unit id / session related info.
ClusInfo = {
    'GoodUnits': GoodUnits,
    'SessionSwitch': SessionSwitch,
    'SessionID': SessionID,
    'OriginalID': np.concatenate(GoodUnits),
}

#### Run the Unit Match process
1. Extract parameters from the waveforms e.g Amplitudes, weighted average waveforms and Spatial Decay lengths
2. Calculate metrics/scores for matching e.g Amplitude Score and Waveform similarity
3. Using putative matches, find an estimate of the drift correction between sessions (can be done per shank for 2.0 probes)
4. Re-calculate metrics/scores with the drift-corrected metrics
5. Use a naive Bayes classifier to get suggested 'matches' and 'non-matches'
6. (Optional) Run the GUI to curate the suggested matches and investigate the UnitMatch results

In [None]:
# Extract parameters from the waveforms.

waveform = pf.detrend_waveform(waveform)

MaxSite, goodidx, goodpos, MaxSiteMean = pf.get_max_sites(
    waveform, ChannelPos, ClusInfo, param
)

(
    SpatialDecayFit,
    SpatialDecay,
    d_10,
    AvgCentroid,
    AvgWaveform,
    PeakTime,
) = pf.decay_and_average_Waveform(
    waveform, ChannelPos, goodidx, MaxSite, MaxSiteMean, ClusInfo, param
)

# Note: this call rebinds both waveform and AvgWaveform.
Amplitude, waveform, AvgWaveform = pf.get_amplitude_shift_Waveform(
    waveform, AvgWaveform, PeakTime, param
)

WaveformDuration, AvgWaveformPerTP, WaveIdx = pf.avg_Waveform_PerTP(
    waveform, ChannelPos, d_10, MaxSiteMean, Amplitude, AvgWaveform, ClusInfo, param
)


In [None]:
# Compute the metrics/scores from the extracted parameters.

# Scores computed directly from per-unit parameters.
AmpScore = mf.get_simple_metric(Amplitude)
SpatialDecayScore = mf.get_simple_metric(SpatialDecay)
SpatialDecayFitScore = mf.get_simple_metric(SpatialDecayFit, outlier=True)

# Scores computed from the average waveforms.
WVcorrScore = mf.get_WVcorr(AvgWaveform, param)
WFMSEscore = mf.get_WaveformMSE(AvgWaveform, param)

# Distance metrics computed from the flipped per-time-point average waveforms.
AvgWaveformPerTPFlip = mf.flip_dim(AvgWaveformPerTP, param)
EuclDist = mf.get_Euclidean_dist(AvgWaveformPerTPFlip, param)

CentroidDist, CentroidVar = mf.Centroid_metrics(EuclDist, param)

EuclDistRC = mf.get_recentered_Euclidean_dist(
    AvgWaveformPerTPFlip, AvgCentroid, param
)

CentroidDistRecentered = mf.recentered_metrics(EuclDistRC)
TrajAngleScore, TrajDistScore = mf.dist_angle(AvgWaveformPerTPFlip, param)

In [None]:
# Collate the metrics and find the putative matches.

# Reduce EuclDist: keep the entries where PeakLoc - waveidx == 0, squeeze,
# then take the NaN-ignoring minimum along axis 1.
EuclDist = np.nanmin(
    EuclDist[:, param['PeakLoc'] - param['waveidx'] == 0, :, :].squeeze(),
    axis=1,
)

# Array indices of the pairs to include (distance below MaxDist).
IncludeThesePairs = np.argwhere(EuclDist < param['MaxDist'])

# Combine related scores by averaging them pairwise.
CentroidOverlordScore = (CentroidDistRecentered + CentroidVar) / 2
WaveformScore = (WVcorrScore + WFMSEscore) / 2
TrajectoryScore = (TrajAngleScore + TrajDistScore) / 2

# Dictionary of the scores to include in the total score.
Scores2Include = {
    'AmpScore': AmpScore,
    'SpatialDecayScore': SpatialDecayScore,
    'CentroidOverlord': CentroidOverlordScore,
    'CentroidDist': CentroidDist,
    'WaveformScore': WaveformScore,
    'TrajectoryScore': TrajectoryScore,
}

TotalScore, Predictors = mf.get_total_score(Scores2Include, param)

# Initial thresholding (first pass).
ThrsOpt = mf.get_threshold(
    TotalScore, WithinSession, EuclDist, param, IsFirstPass=True
)

# Pairs scoring above the threshold are the candidate matches; their count
# is stored as the expected number of matches and used for the prior.
CandidatePairs = TotalScore > ThrsOpt
param['nExpectedMatches'] = np.sum(CandidatePairs.astype(int))
priorMatch = 1 - (param['nExpectedMatches'] / len(IncludeThesePairs))


In [None]:
# Drift correction: rebinds AvgCentroid and AvgWaveformPerTP with the
# values returned by drift_nSessions.
drifts, AvgCentroid, AvgWaveformPerTP = mf.drift_nSessions(
    CandidatePairs,
    SessionSwitch,
    AvgCentroid,
    AvgWaveformPerTP,
    TotalScore,
    param,
)

In [None]:
# Redo the metric extraction with the drift-corrected arrays.

AvgWaveformPerTPFlip = mf.flip_dim(AvgWaveformPerTP, param)
EuclDist = mf.get_Euclidean_dist(AvgWaveformPerTPFlip, param)

CentroidDist, CentroidVar = mf.Centroid_metrics(EuclDist, param)

EuclDistRC = mf.get_recentered_Euclidean_dist(
    AvgWaveformPerTPFlip, AvgCentroid, param
)

CentroidDistRecentered = mf.recentered_metrics(EuclDistRC)
TrajAngleScore, TrajDistScore = mf.dist_angle(AvgWaveformPerTPFlip, param)

# Reduce EuclDist: keep the entries where PeakLoc - waveidx == 0, squeeze,
# then take the NaN-ignoring minimum along axis 1.
EuclDist = np.nanmin(
    EuclDist[:, param['PeakLoc'] - param['waveidx'] == 0, :, :].squeeze(),
    axis=1,
)

# Array indices of pairs to include, in ML its IncludeThesePairs[:,1]
IncludeThesePairs = np.argwhere(EuclDist < param['MaxDist'])
# Same selection as a 0/1 mask with the shape and dtype of TotalScore.
IncludeThesePairs_idx = np.zeros_like(TotalScore)
IncludeThesePairs_idx[EuclDist < param['MaxDist']] = 1

# Combine related scores by averaging them pairwise.
CentroidOverlordScore = (CentroidDistRecentered + CentroidVar) / 2
WaveformScore = (WVcorrScore + WFMSEscore) / 2
TrajectoryScore = (TrajAngleScore + TrajDistScore) / 2

# Dictionary of the scores to include in the total score.
Scores2Include = {
    'AmpScore': AmpScore,
    'SpatialDecayScore': SpatialDecayScore,
    'CentroidOverlord': CentroidOverlordScore,
    'CentroidDist': CentroidDist,
    'WaveformScore': WaveformScore,
    'TrajectoryScore': TrajectoryScore,
}

TotalScore, Predictors = mf.get_total_score(Scores2Include, param)
ThrsOpt = mf.get_threshold(
    TotalScore, WithinSession, EuclDist, param, IsFirstPass=False
)

# Update the expected number of matches and the prior from the new scores.
param['nExpectedMatches'] = np.sum((TotalScore > ThrsOpt).astype(int))
priorMatch = 1 - (param['nExpectedMatches'] / len(IncludeThesePairs))


In [None]:
# Set up the Bayes analysis.

# Threshold = the priorMatch quantile of TotalScore over the included pairs.
ThrsOpt = np.quantile(
    TotalScore[IncludeThesePairs_idx.astype(bool)],
    priorMatch,
)
CandidatePairs = TotalScore > ThrsOpt

# Priors over all nUnits**2 pairs; the value of the priors can be changed.
priorMatch = 1 - (param['nExpectedMatches'] / param['nUnits'] ** 2)
Priors = np.array((priorMatch, 1 - priorMatch))

# Integer labels for the candidate pairs and the distinct label values.
labels = CandidatePairs.astype(int)
Cond = np.unique(labels)
ScoreVector = param['ScoreVector']
# NaN-filled array of shape (len(ScoreVector), number of scores, len(Cond)).
ParameterKernels = np.full(
    (len(ScoreVector), len(Scores2Include), len(Cond)),
    np.nan,
)


In [None]:
# Run the Bayes analysis.
ParameterKernels = bf.get_ParameterKernels(
    Scores2Include, labels, Cond, param, addone=1
)

Probability = bf.apply_naive_bayes(
    ParameterKernels, Priors, Predictors, param, Cond
)

# Take column 1 of Probability and reshape it into an nUnits x nUnits matrix.
Output = Probability[:, 1].reshape(param['nUnits'], param['nUnits'])

In [None]:
# Optional function to summarise the output, here using a match threshold of 0.75
util.evaluate_output(Output, param, WithinSession, SessionSwitch, MatchThreshold = 0.75)

Set a match threshold and look at the output

In [None]:
# Binarise the output: 1 where Output exceeds the match threshold, 0 elsewhere.
MatchThreshold = param['MatchThreshold']
#MatchThreshold = try different values here!
OutputThreshold = np.zeros_like(Output)
OutputThreshold[Output > MatchThreshold] = 1

# 'grays' is not a registered matplotlib colormap name (the lookup is
# case-sensitive) and raises a ValueError; 'Greys' is the valid name.
plt.imshow(OutputThreshold, cmap='Greys')
#plt.imshow(Output)
plt.colorbar()

In [None]:
# Calculate the data needed by the GUI and send it to the GUI.
gui.process_info_for_GUI(
    Output,
    MatchThreshold,
    Scores2Include,
    TotalScore,
    Amplitude,
    SpatialDecay,
    AvgCentroid,
    AvgWaveform,
    AvgWaveformPerTP,
    WaveIdx,
    MaxSite,
    MaxSiteMean,
    waveform,
    WithinSession,
    ChannelPos,
    ClusInfo,
    param,
)

#### Run the GUI
Look at GUI_Reference_Guide.md for information on how to effectively use the GUI!

In [None]:
# MatchesGUI is a list of 2 sets of matches, one for each CV
# each array is symmetric, e.g. it will have both (x,y) and (y,x) as a match
IsMatch, NotMatch, MatchesGUI = gui.run_GUI()

In [None]:
# All index pairs where the probability is above the threshold.
# Matches keeps within-session matches.
Matches = np.argwhere(OutputThreshold == 1)
# matches excludes within-session matches via the WithinSession mask.
matches = np.argwhere((OutputThreshold * WithinSession) == 1)

# currate_matches has 2 modes, 'And' / 'Or', which return a match if it
# appears in both / in one of the CV pairs. It then adds all the matches
# selected as IsMatch and removes all the matches in NotMatch.
MatchesCurrated = util.currate_matches(
    MatchesGUI,
    IsMatch,
    NotMatch,
    Mode='And',
)

In [None]:
# Directory handed to the save functions as their first argument
SaveDir = r'Path\to\Save\directory'

In [None]:
# Save the base output.
# NOTE(review): MatchesCurated is passed as None here; presumably the curated
# matches from the GUI step can be supplied instead - confirm in Save_utils.
su.save_to_output(
    SaveDir,
    Scores2Include,
    Matches,
    Output,
    AvgCentroid,
    AvgWaveform,
    AvgWaveformPerTP,
    MaxSite,
    TotalScore,
    OutputThreshold,
    ClusInfo,
    param,
    MatchesCurated=None,
    SaveMatchTable=True,
)

# Save separate CV output: option to save the data so that the cross
# verification pairs are split up.
#su.save_to_output_seperate_CV(SaveDir, Scores2Include, Matches, Output, AvgCentroid, AvgWaveform, AvgWaveformPerTP, MaxSite,
#                   TotalScore, MatchThreshold, ClusInfo, param, MatchesCurated = None, SaveMatchTable = True)
