# Imports

In [None]:
import pandas as pd
import numpy as np
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities, KNN_spikes
from nn_spikes import getInputsAndTargets
import plotly.express as px

In [None]:
# Load the training data and the spike-location table; the spike-location
# CSV's first column is used as its index.
data = pd.read_csv(filepath_or_buffer='./datasources/spikes/training_data.csv')
spikeLocations = pd.read_csv(
    filepath_or_buffer='./datasources/spikes/training_spike_locations.csv',
    index_col=0,
)

In [None]:
# Split the data into training and validation frames, plus the row labels of
# the spikes in each (later cells index the frames with `.loc[...]` on them).
# NOTE(review): waveformWindow=100 presumably sets the extracted waveform length — confirm.
(
    data_training,
    data_validation,
    spikeIndexes_training,
    spikeIndexes_validation,
) = spike_tools.dataPreProcess(data, spikeLocations, waveformWindow=100)

In [None]:
# Average waveform for class 3, restricted to the first 200 training spike indexes.
# NOTE(review): presumably renders a plot (cell output) — confirm in spike_tools.
spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=3)

In [None]:
# Average waveform for class 2, restricted to the first 200 training spike indexes.
# NOTE(review): presumably renders a plot (cell output) — confirm in spike_tools.
spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=2)

In [None]:
# Average waveform for class 1, restricted to the first 200 training spike indexes.
# NOTE(review): presumably renders a plot (cell output) — confirm in spike_tools.
spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=1)

In [None]:
# Average waveform for class 0, restricted to the first 200 training spike indexes.
# NOTE(review): presumably renders a plot (cell output) — confirm in spike_tools.
spike_tools.getAverageWaveforms(data_training, spikeIndexes_training[:200], classToPlot=0)

# Run Neural Network Classifier

In [None]:
# Results container keyed by hidden-layer size; only a single size (1100) is tried.
# NOTE(review): the training loop calls int(key) and later cells use results['1100'],
# so the keys appear to be strings — confirm in utilities.createResultsRepo.
results = utilities.createResultsRepo([1100])

In [None]:
# Train one network per hidden-layer size registered in `results`, then store
# the trained model and its training/validation curves back into that entry.
for hiddenKey, entry in results.items():
    # Input layer width = number of elements in the first training waveform.
    firstWaveform = data_training.loc[spikeIndexes_training[0], 'waveform']
    untrainedNet = NeuralNetwork(
        input_nodes=len(firstWaveform),
        hidden_nodes=int(hiddenKey),
        output_nodes=4,  # one output per class; classes 0-3 are plotted above
        lr=0.1,
        error_function='difference-squared',
    )

    trainedNet, trainCurve, valCurve = batchTrain(
        data_training=data_training,
        data_validation=data_validation,
        spikeIndexes_training=spikeIndexes_training,
        spikeIndexes_validation=spikeIndexes_validation,
        nn=untrainedNet,
        epochs=40,
        plotCurves=False,
    )
    entry['nn'] = trainedNet
    entry['trainingCurve'] = trainCurve
    entry['validationCurve'] = valCurve

### Plot learning curves

In [None]:
# Plot the training/validation curves stored for each hidden-layer size in `results`.
utilities.plotLearningCurves(results)

In [None]:
# Peek at the first few validation rows that correspond to spikes.
data_validation.loc[spikeIndexes_validation].head()

### Predict on validation dataset

In [None]:
# Classify every validation spike with the trained 1100-hidden-node network
# and store the results in a 'predictedClass' column on the spike rows.
waveforms = data_validation.loc[spikeIndexes_validation, 'waveform']
predictions = spike_tools.classifySpikesMLP(waveforms, results['1100']['nn'])
# `.at` is documented for a single row/column pair only. Assigning to a list of
# row labels belongs to `.loc`, which also creates the column if it is missing.
# Rows not listed are then filled with NaN, so the new column is float-typed.
data_validation.loc[spikeIndexes_validation, 'predictedClass'] = pd.Series(predictions).values

In [None]:
# Show all validation spike rows, now including the 'predictedClass' column.
data_validation.loc[spikeIndexes_validation]

# Plot confusion matrix

In [None]:
# Ground-truth labels of the validation spikes, shifted by +1 (class 0 -> 1).
# NOTE(review): the +1 presumably matches the labelling getConfusion expects — confirm.
actual = 1 + data_validation['assignedKnownClass'].loc[spikeIndexes_validation].values
actual[:10]

In [None]:
# Predicted labels of the validation spikes, shifted by +1 to match `actual`.
# 'predictedClass' was written only on spike rows, so the other rows are NaN and
# the column is float. Every spike row holds a prediction, so casting to int is
# safe here and yields integer labels (e.g. 2, not 2.0) for the confusion matrix.
predicted = data_validation.loc[spikeIndexes_validation, 'predictedClass'].astype(int).values + 1
predicted[:10]

In [None]:
# Compare actual vs predicted labels, converted to plain Python lists.
# NOTE(review): presumably builds/plots a confusion matrix — confirm in utilities.
utilities.getConfusion(actual.tolist(), predicted.tolist())