# Imports

In [None]:
import pandas as pd
import numpy as np
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities
from nn_spikes import getInputsAndTargets
import plotly.express as px

In [None]:
# Load the raw training recording, then the spike-location table
# (its first CSV column becomes the DataFrame index).
data = pd.read_csv('./datasources/spikes/training_data.csv')
spikeLocations = pd.read_csv(
    './datasources/spikes/training_spike_locations.csv',
    index_col=0,
)

In [None]:
# Pre-process the training recording with spike_tools; waveforms are extracted with a
# 154-sample window from the 'signalHPSavgol' column. Returns the updated frame plus
# the indexes of the spikes it found.
data, predictedSpikeIndexes = spike_tools.dataPreProcess(
    data,
    spikeLocations,
    waveformSignalType='signalHPSavgol',
    waveformWindow=154,
)

In [None]:
# Partition the pre-processed data (and the matching spike indexes) into
# training and validation subsets.
(data_training, data_validation,
 spikeIndexes_training, spikeIndexes_validation) = spike_tools.splitData(
    data, predictedSpikeIndexes
)

# Run Neural Network Classifier

In [None]:
# Build the classifier: the input layer is sized to the length of one extracted
# waveform, and there are 4 output nodes (one per spike class).
nn = NeuralNetwork(
    input_nodes=len(data_training.loc[spikeIndexes_training[0], 'waveform']),
    hidden_nodes=735,
    output_nodes=4,
    lr=0.2,
    error_function='difference-squared',
)

# Train for 40 epochs. batchTrain hands back the trained network together with the
# training and validation learning curves (presumably one value per epoch — confirm).
nn, trainingCurve, validationCurve = batchTrain(
    nn=nn,
    epochs=40,
    data_training=data_training,
    spikeIndexes_training=spikeIndexes_training,
    data_validation=data_validation,
    spikeIndexes_validation=spikeIndexes_validation,
)

# Plot Learning Curves

In [None]:
from plotly import graph_objects as go

In [None]:
# Overlay the two learning curves: training as a thin dashed line, validation solid.
fig = go.Figure(data=[
    go.Scatter(y=trainingCurve, line=dict(width=1, dash='dash'), name='training'),
    go.Scatter(y=validationCurve, mode='lines', name='validation'),
])
fig.show()

In [None]:
# Compute (and presumably plot) the average training waveform for each of the four
# spike classes, one call per class.
for i in [0, 1, 2, 3]:
    spike_tools.getAverageWaveforms(data_training, spikeIndexes_training, classToPlot=i)

# Load submission data

In [None]:
# Load the unlabelled submission recording, rename its two columns to match the
# training data's naming, and preview the first rows.
dataSubmission = (
    pd.read_csv('./datasources/spikes/submission_data.csv')
    .set_axis(['time (s)', 'signal'], axis='columns')
)
dataSubmission.head()

In [None]:
# Pre-process the submission recording in submission mode. Peak detection runs on
# 'signalHPSavgol' at threshold 0.9, and waveforms use the same 154-sample window as in training.
# Note this rebinds predictedSpikeIndexes, which until now held the *training* spike indexes.
# NOTE(review): spikeLocations is the training-set table; presumably it is ignored when
# submission=True — confirm in spike_tools.dataPreProcess.
dataSubmission, predictedSpikeIndexes = spike_tools.dataPreProcess(
    dataSubmission,
    spikeLocations,
    submission=True,
    detectPeaksOn='signalHPSavgol',
    threshold=0.9,
    waveformSignalType='signalHPSavgol',
    waveformWindow=154,
)

In [None]:
# Display the submission rows located at the detected spike indexes (all columns).
dataSubmission.loc[predictedSpikeIndexes, :]

In [None]:
# Visual sanity check: rows 25000-49999 of the raw signal and its processed variants,
# with the detected-spike markers overlaid.
sample = dataSubmission.iloc[25000:50000]
spike_tools.plotSpikes(
    signals=[sample[column] for column in ('signal',
                                           'signalSavgol',
                                           'signalSavgolBP',
                                           'signalHP',
                                           'signalHPSavgol')],
    spikes=[sample['predictedSpike']],
)

In [None]:
# Classify every detected submission spike with the trained network and store each
# network output in a 'predictedClass' column of dataSubmission.
# NOTE(review): the first detected spike (predictedSpikeIndexes[0]) is skipped here and in
# the export cells; because the assignment below is index-aligned, that row's
# 'predictedClass' should be left as NaN — confirm that skipping it is intended.
predictions = []

# Iterate over each spike and query the trained neural network
for index in predictedSpikeIndexes[1:]:

    # Retrieve only the network inputs (the spike waveform). The target vector is
    # discarded, so the third argument (0) is presumably just a placeholder label.
    inputs, _ = getInputsAndTargets(dataSubmission.loc[index, 'waveform'], nn.output_nodes, 0)

    # Query the network for the predicted output of the given spike waveform
    predictions.append(nn.query(inputs))

# The Series is indexed by the spike rows, so each prediction lands on its own spike.
dataSubmission.loc[predictedSpikeIndexes, 'predictedClass'] = pd.Series(predictions, index=predictedSpikeIndexes[1:])

In [None]:
# Subset the submission frame to the detected-spike rows (now carrying
# 'predictedClass') and display it.
detectedSpikes = dataSubmission.loc[predictedSpikeIndexes, :]
detectedSpikes

In [None]:
# Print how many detected spikes were assigned to each class 0-3, as "class: count".
for c in range(4):
    print(f"{c}: {(detectedSpikes['predictedClass'] == c).sum()}")

In [None]:
# Write every classified spike row (all detected spikes except the first) to CSV.
dataSubmission.loc[predictedSpikeIndexes[1:], :].to_csv(
    path_or_buf='./datasources/spikes/results.csv'
)

In [None]:
# Shift the predicted class labels up by one (0-3 becomes 1-4), presumably to match a
# 1-based labelling expected by the submission — confirm.
correctClasses = 1 + dataSubmission.loc[predictedSpikeIndexes[1:], 'predictedClass']
correctClasses.head()

In [None]:
import scipy.io as spio
# Save the submission as a MATLAB .mat file containing the spike indexes ("Index")
# and the shifted class labels ("Class").
# NOTE(review): Class was shifted to 1-based but Index was not; if the consumer expects
# MATLAB-style 1-based sample indexes, Index may need +1 as well — confirm.
Name = "13243.mat"
spio.savemat(
    Name,
    mdict={
        "Index": predictedSpikeIndexes[1:],
        "Class": correctClasses.values,
    },
)

---

In [None]:
# Plot every submission waveform assigned to one class at 10% opacity, with the
# point-by-point class-average waveform drawn at full opacity on top.
classToPlot = 3

# Retrieve a dataframe containing only spike entries and select the waveform extracts for the chosen class
detectedSpikes = dataSubmission.loc[predictedSpikeIndexes]
classWaveforms = detectedSpikes.loc[detectedSpikes['predictedClass'] == classToPlot, 'waveform']

if classWaveforms.empty:
    # np.vstack cannot stack zero arrays, so report the empty class instead of crashing
    print("No spikes were classified as class {}".format(classToPlot))
else:
    # Stack the waveforms into a 2-D array (one row per waveform) and average down each
    # column, i.e. the mean value at each sample position across all waveforms of the class
    stack = np.vstack(classWaveforms.values)
    avgs = stack.mean(axis=0)

    # Derive the x axis from the actual waveform length rather than a fixed 101 points,
    # which would truncate longer windows (waveformWindow=154 is used above)
    x = np.arange(stack.shape[1])

    # Create new Plotly graph objects figure
    fig = go.Figure()

    # Individual waveforms first (faint black), then the average on top
    for trace in classWaveforms:
        fig.add_trace(go.Scatter(x=x,
                                 y=trace,
                                 mode='lines',
                                 line=dict(color='black'),
                                 opacity=0.1,
                                 ))

    fig.add_trace(go.Scatter(x=x,
                             y=avgs,
                             mode='lines', ))

    fig.show()