# Imports

In [1]:
import pandas as pd
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities, KNN_spikes
from nn_spikes import getInputsAndTargets
import plotly.express as px

In [2]:
# Read the training data CSV, then the spike-locations CSV (whose first
# column is used as the row index).
data = pd.read_csv(filepath_or_buffer='./datasources/spikes/training_data.csv')
spikeLocations = pd.read_csv(
    filepath_or_buffer='./datasources/spikes/training_spike_locations.csv',
    index_col=0,
)

In [3]:
# Pre-process the loaded data (waveformWindow=100) and unpack the four
# results: the training and validation frames plus the spike indexes for each.
(
    data_training,
    data_validation,
    spikeIndexes_training,
    spikeIndexes_validation,
) = spike_tools.dataPreProcess(data, spikeLocations, waveformWindow=100)

3528 peaks detected.


In [4]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=3)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=2)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=1)

In [None]:
#spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=0)

# Run Neural Network Classifier

In [None]:
# Create the results store for the neural-network runs. The training loop
# later casts each key of this store to int and uses it as the hidden-layer
# size. Presumably the keys come from the values passed here -- confirm
# against utilities.createResultsRepo.
hidden_layer_options = [500]
results = utilities.createResultsRepo(hidden_layer_options)

In [None]:
# Build and train one network per key in `results`, then store the trained
# network and both curves returned by batchTrain back under that key.
# `nn` is left bound to the last trained network once the loop finishes.
for hid in results.keys():
    hidden_size = int(hid)  # the key is cast to int before being used as the hidden-layer size

    nn = NeuralNetwork(
        input_nodes=61,
        hidden_nodes=hidden_size,
        output_nodes=4,
        lr=0.1,
        error_function='difference-squared',
    )

    # batchTrain hands back the network together with the training and
    # validation curves; curve plotting is switched off here.
    nn, trainingCurve, validationCurve = batchTrain(
        data_training=data_training,
        data_validation=data_validation,
        spikeIndexes_training=spikeIndexes_training,
        spikeIndexes_validation=spikeIndexes_validation,
        nn=nn,
        epochs=30,
        plotCurves=False,
    )

    run = results[hid]
    run['nn'] = nn
    run['trainingCurve'] = trainingCurve
    run['validationCurve'] = validationCurve

In [None]:
# Take the waveforms of the validation rows whose 'predictedSpike' flag is
# True and classify them with `nn` (the network left bound by the training loop).
is_predicted_spike = data_validation['predictedSpike'] == True
waveforms = data_validation.loc[is_predicted_spike, 'waveform']
predictionsMLP = spike_tools.classifySpikesMLP(waveforms, nn)

# Run KNN Classifier

In [5]:
# Recompute the 'knownClass' value at every spike index, handling the
# validation frame first and the training frame second. Each frame is
# modified in place.
for frame, spike_indexes in (
    (data_validation, spikeIndexes_validation),
    (data_training, spikeIndexes_training),
):
    for index in spike_indexes:
        # Known classes in the rows labelled index-10 through index+5
        # (.loc label slices include both end points).
        nearby_classes = frame.loc[index - 10:index + 5, 'knownClass']
        _, _, label = getInputsAndTargets(frame.loc[index, 'waveform'], 4, nearby_classes)
        # The stored class is the returned label shifted up by one.
        frame.loc[index, 'knownClass'] = label + 1

Run KNN for a range of neighbors and components

In [52]:
# Run the KNN classifier for every combination of neighbour count and
# component count, keeping the validation predictions of each run under a
# 'comps:<n_components>_neigh:<n_neighbors>' key.
# NOTE: this rebinds `results`, replacing the store used by the neural-network cells.
# After the loops, `componentsTraining` and `cs` hold the values from the last run only.
results = {}
neighbor_options = [3, 4, 5, 6, 7, 8]
component_options = [3, 4, 5, 6, 7, 8, 9, 10]

for n_neighbors in neighbor_options:
    for n_components in component_options:
        training_waveforms = data_training.loc[spikeIndexes_training, 'waveform'].to_list()
        validation_waveforms = data_validation.loc[spikeIndexes_validation, 'waveform'].to_list()
        training_labels = data_training.loc[spikeIndexes_training, 'knownClass'].to_list()

        predictionsKNN, componentsTraining, cs = KNN_spikes.KNN_classifier(
            training_waveforms,
            validation_waveforms,
            training_labels,
            n_components=int(n_components),
            n_neighbors=n_neighbors,
        )

        results[f'comps:{n_components}_neigh:{n_neighbors}'] = predictionsKNN

Total Variance Explained:  0.7202922304590138
Total Variance Explained:  0.8074610921536692
Total Variance Explained:  0.8859467544485391
Total Variance Explained:  0.9255395477758733
Total Variance Explained:  0.9513512530071864
Total Variance Explained:  0.9700928919026987
Total Variance Explained:  0.9800149427646909
Total Variance Explained:  0.9858643597331648
Total Variance Explained:  0.7202922304590139
Total Variance Explained:  0.8074610921536698
Total Variance Explained:  0.8859467544485412
Total Variance Explained:  0.9255395477758742
Total Variance Explained:  0.9513512530071874
Total Variance Explained:  0.9700928919026983
Total Variance Explained:  0.9800149427646911
Total Variance Explained:  0.9858643597331657
Total Variance Explained:  0.7202922304590139
Total Variance Explained:  0.8074610921536707
Total Variance Explained:  0.8859467544485404
Total Variance Explained:  0.9255395477758731
Total Variance Explained:  0.9513512530071866
Total Variance Explained:  0.97009

Calculate the validation accuracy for every KNN configuration and store the results in a new dict

In [53]:
# For each stored set of predictions: write it into the 'predictedClass'
# column at the validation spike indexes, count the rows where it equals
# 'knownClass', and record the fraction of matches under the same key.
# The column is overwritten on every pass, so it ends up holding the
# predictions of the last key.
accuracies = {}
for key, predictions in results.items():
    data_validation.loc[spikeIndexes_validation, 'predictedClass'] = predictions
    known_classes = data_validation.loc[spikeIndexes_validation, 'knownClass']
    predicted_classes = data_validation.loc[spikeIndexes_validation, 'predictedClass']
    matches = (known_classes == predicted_classes).sum()
    acc = matches / len(known_classes)
    accuracies[key] = acc

Convert the accuracies dict into a DataFrame that plotly can plot

In [54]:
# One row per configuration key, with a single 'accuracy' column.
df = pd.Series(accuracies, name='accuracy').to_frame()

In [51]:
# Scatter the accuracy of each configuration against the frame's index.
px.scatter(data_frame=df, y="accuracy")

In [60]:
# Plot `cs`, the cumulative explained-variance values returned by the last
# KNN run, against their position (0, 1, 2, ...).
# The positional index is built with the built-in range() rather than
# np.arange: numpy is never imported in this notebook, so the np.arange call
# raised NameError on a fresh kernel.
df = pd.DataFrame(data=cs, index=range(len(cs)), columns=['cumulative sum of explained variance'])
px.line(df, y="cumulative sum of explained variance")

In [61]:
# Load the predictions of one chosen configuration (8 components, 6 neighbours)
# into the 'predictedClass' column for the cells that follow.
selected_predictions = results['comps:8_neigh:6']
data_validation.loc[spikeIndexes_validation, 'predictedClass'] = selected_predictions

In [62]:
# Two-row frame holding the first two rows of the transposed training
# components, labelled as the 1st and 2nd component.
components_by_row = componentsTraining.T
df = pd.DataFrame(
    data=(components_by_row[0], components_by_row[1]),
    index=['1st component', '2nd component'],
)

In [63]:
# Show the components with one row per sample and one column per component.
df.transpose()

Unnamed: 0,1st component,2nd component
0,-2.012972,9.068497
1,-7.835550,7.907335
2,7.631584,3.126444
3,0.445331,-1.652373
4,6.664038,1.585088
...,...,...
2820,0.952042,-3.071266
2821,-8.814219,6.978024
2822,4.719762,12.564016
2823,0.558915,-2.554398


In [64]:
# Scatter the 2nd component against the 1st, one point per row of the transposed frame.
px.scatter(data_frame=df.transpose(), x="1st component", y="2nd component")

In [65]:
# Count the validation spikes whose predicted class equals the known class.
known_classes = data_validation.loc[spikeIndexes_validation, 'knownClass']
predicted_classes = data_validation.loc[spikeIndexes_validation, 'predictedClass']
matches = (known_classes == predicted_classes).sum()

In [66]:
# Accuracy = matching spikes divided by the number of validation spike rows.
validation_spike_count = len(data_validation.loc[spikeIndexes_validation])
acc = matches / validation_spike_count
acc

0.9388335704125178

---

In [None]:
# Display every column of the validation rows at the spike indexes.
data_validation.loc[spikeIndexes_validation, :]

In [67]:
# Compare known and predicted classes at the validation spike indexes.
actual_classes = data_validation.loc[spikeIndexes_validation, 'knownClass']
predicted_classes = data_validation.loc[spikeIndexes_validation, 'predictedClass']
utilities.getConfusion(actual=actual_classes, predicted=predicted_classes)

              precision    recall  f1-score   support

           1       0.95      0.97      0.96       172
           2       0.93      0.89      0.91       187
           3       0.91      0.96      0.94       188
           4       0.97      0.94      0.95       156

    accuracy                           0.94       703
   macro avg       0.94      0.94      0.94       703
weighted avg       0.94      0.94      0.94       703

