# Imports

In [None]:
import pandas as pd
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities, KNN_spikes
from nn_spikes import getInputsAndTargets
import plotly.express as px

In [None]:
# Raw training recording, plus the table of annotated spike locations
# (its first CSV column is used as the index).
data = pd.read_csv(
    './datasources/spikes/training_data.csv',
)
spikeLocations = pd.read_csv(
    './datasources/spikes/training_spike_locations.csv',
    index_col=0,
)

In [None]:
# Presumably splits the recording into training/validation portions and
# extracts per-spike waveforms (inferred from the returned names and the
# waveformWindow argument -- see spike_tools.dataPreProcess to confirm).
(
    data_training,
    data_validation,
    spikeIndexes_training,
    spikeIndexes_validation,
) = spike_tools.dataPreProcess(
    data,
    spikeLocations,
    waveformWindow=100,
)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=3)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=2)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=1)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=0)

# Run Neural Network Classifier

In [None]:
# Hidden-layer sizes to try. The training loop below converts each key of the
# returned repo back to an int, so keys are presumably derived from these sizes.
hiddenLayerSizes = [500]
results = utilities.createResultsRepo(hiddenLayerSizes)

In [None]:
# Train one fresh MLP per hidden-layer size in `results`, storing the trained
# network and its training/validation curves under that key. After the loop,
# `nn` is bound to the network from the LAST iteration; the next cell uses it.
for hid in results:
    nn = NeuralNetwork(
        input_nodes=61,
        hidden_nodes=int(hid),
        output_nodes=4,
        lr=0.1,
        error_function='difference-squared',
    )

    nn, trainingCurve, validationCurve = batchTrain(
        data_training=data_training,
        data_validation=data_validation,
        spikeIndexes_training=spikeIndexes_training,
        spikeIndexes_validation=spikeIndexes_validation,
        nn=nn,
        epochs=30,
        plotCurves=False,
    )

    entry = results[hid]
    entry['nn'] = nn
    entry['trainingCurve'] = trainingCurve
    entry['validationCurve'] = validationCurve

In [None]:
# Classify the validation waveforms flagged as predicted spikes, using the
# network left in `nn` by the training loop (the last hidden size trained).
# `== True` is kept deliberately: the column's dtype is not known here.
isPredictedSpike = data_validation['predictedSpike'] == True  # noqa: E712
waveforms = data_validation.loc[isPredictedSpike, 'waveform']
predictionsMLP = spike_tools.classifySpikesMLP(waveforms, nn)

# Run KNN Classifier

In [None]:
def _relabelSpikes(df, spikeIndexes):
    """Overwrite ``df['knownClass']`` at each spike index with a derived label.

    The label is the third value returned by ``getInputsAndTargets`` (called with
    the spike's waveform, 4 classes, and the ``knownClass`` window around the
    spike), plus one. The +1 presumably shifts a 0-based class index to the
    encoding used by ``knownClass`` -- confirm against nn_spikes.

    All labels are computed BEFORE any are written back. The original loop
    wrote into ``knownClass`` while later iterations were still reading
    ``knownClass`` windows, so a spike whose window contained an earlier spike's
    index could see an already-overwritten value.

    NOTE(review): still mutates ``df`` in place and is not idempotent. Re-running
    this cell re-derives labels from already-relabelled values.
    """
    labels = []
    for index in spikeIndexes:
        # .loc label slicing includes both endpoints: rows index-10 .. index+5.
        window = df.loc[index - 10:index + 5, 'knownClass']
        _, _, label = getInputsAndTargets(df.loc[index, 'waveform'], 4, window)
        labels.append(label + 1)

    for index, label in zip(spikeIndexes, labels):
        df.loc[index, 'knownClass'] = label


_relabelSpikes(data_validation, spikeIndexes_validation)
_relabelSpikes(data_training, spikeIndexes_training)

Run KNN over a grid of neighbour counts and PCA component counts

In [None]:
# NOTE(review): this rebinds `results`, discarding the MLP results dict built
# earlier in the notebook. The name is kept because later cells in this
# notebook read the KNN predictions from `results`.
results = {}
neighborCounts = [3, 4, 5, 6, 7, 8]
componentCounts = [3, 4, 5, 6, 7, 8, 9, 10]
for n_neighbors in neighborCounts:
    for n_components in componentCounts:
        trainWaveforms = data_training.loc[spikeIndexes_training, 'waveform'].to_list()
        validWaveforms = data_validation.loc[spikeIndexes_validation, 'waveform'].to_list()
        trainLabels = data_training.loc[spikeIndexes_training, 'knownClass'].to_list()
        # componentsTraining / cs are overwritten on every pass. After the loop
        # they hold the values from the last configuration (neighbours=8,
        # components=10), which later cells plot.
        predictionsKNN, componentsTraining, cs = KNN_spikes.KNN_classifier(
            trainWaveforms,
            validWaveforms,
            trainLabels,
            n_components=int(n_components),
            n_neighbors=n_neighbors,
        )
        results[f'comps:{n_components}_neigh:{n_neighbors}'] = predictionsKNN

Calculate the validation accuracy of every KNN configuration and store it in a new dict

In [None]:
# Validation accuracy (fraction of spikes whose predicted class equals the
# known class) for each KNN configuration. Side effect: each configuration's
# predictions are written into data_validation['predictedClass'], so after
# the loop that column holds the last configuration's predictions.
accuracies = {}
knownValidation = data_validation.loc[spikeIndexes_validation, 'knownClass']
nValidationSpikes = len(data_validation.loc[spikeIndexes_validation])
for key, predictions in results.items():
    data_validation.loc[spikeIndexes_validation, 'predictedClass'] = predictions
    predictedValidation = data_validation.loc[spikeIndexes_validation, 'predictedClass']
    matches = sum(knownValidation == predictedValidation)
    acc = matches / nValidationSpikes
    accuracies[key] = acc

Convert the accuracy dict into a DataFrame (one row per configuration) that plotly can plot

In [None]:
# One row per configuration key, with a single 'accuracy' column.
df = pd.Series(accuracies).to_frame(name='accuracy')

In [None]:
# Scatter of validation accuracy per KNN configuration. No x is given, so the
# row index (the configuration keys) is used. Left as the cell's final bare
# expression so the notebook displays the figure.
px.scatter(df, y="accuracy")

In [None]:
# Plot `cs` as returned by the last KNN_classifier call in the grid-search loop
# (presumably the cumulative explained-variance ratio of the PCA step -- confirm
# in KNN_spikes). The original used `np.arange`, but numpy is never imported in
# this notebook, so that raised NameError. The built-in `range` yields the same
# 0..len(cs)-1 index without a new dependency.
df = pd.DataFrame(
    data=cs,
    index=range(len(cs)),
    columns=['cumulative sum of explained variance'],
)
px.line(df, y="cumulative sum of explained variance")

In [None]:
# Hard-coded choice of KNN configuration, presumably picked from the accuracy
# scatter above (confirm). Its predictions become the validation
# 'predictedClass' used by the accuracy and confusion cells below.
selectedConfig = 'comps:8_neigh:6'
data_validation.loc[spikeIndexes_validation, 'predictedClass'] = results[selectedConfig]

In [None]:
# Two-row frame: first and second PCA components of the training set, taken
# from the LAST grid-search iteration (components=10), not necessarily the
# configuration selected above. The following cells transpose it for plotting.
transposedComponents = componentsTraining.T
df = pd.DataFrame(
    data=(transposedComponents[0], transposedComponents[1]),
    index=['1st component', '2nd component'],
)

In [None]:
# Display the components table transposed, so each component is a column.
df.T

In [None]:
# Scatter of the first vs second PCA component. Left as the final bare
# expression so the notebook displays the figure.
px.scatter(df.T, x="1st component", y="2nd component")

In [None]:
# Number of validation spikes whose predicted class equals the known class.
knownValidation = data_validation.loc[spikeIndexes_validation, 'knownClass']
predictedValidation = data_validation.loc[spikeIndexes_validation, 'predictedClass']
matches = sum(knownValidation == predictedValidation)

In [None]:
# Accuracy of the selected KNN configuration on the validation spikes.
# `acc` stays the final bare expression so the notebook displays it.
nValidationSpikes = len(data_validation.loc[spikeIndexes_validation])
acc = matches / nValidationSpikes
acc

---

In [None]:
# Display the validation rows at the spike indexes (known vs predicted class).
data_validation.loc[spikeIndexes_validation]

In [None]:
# Compare known vs predicted classes on the validation spikes via
# utilities.getConfusion (presumably a confusion matrix -- see utilities).
# Left as the final bare expression so its result is displayed.
utilities.getConfusion(actual=data_validation.loc[spikeIndexes_validation, 'knownClass'], 
                       predicted=data_validation.loc[spikeIndexes_validation, 'predictedClass'])