# Setup

In [None]:
import plotly.express as px
import numpy as np
import pandas as pd
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities
import pickle

In [None]:
# Training recording and the spike locations for the training set; the first
# column of the locations CSV becomes the DataFrame index.
trainingDataPath = './datasources/spikes/training_data.csv'
spikeLocationsPath = './datasources/spikes/training_spike_locations.csv'
data = pd.read_csv(trainingDataPath)
spikeLocations = pd.read_csv(spikeLocationsPath, index_col=0)

In [None]:
# Split the training recording into training and validation sets, labelled with
# the spike locations read from training_spike_locations.csv.
# `spikeLocations` is the frame loaded in the data-loading cell; the name that was
# previously passed here (`predictedPeakIndexes`) is not defined in this notebook.
# NOTE(review): confirm dataPreProcess accepts this DataFrame directly rather than,
# e.g., a plain array of indexes.
data_training, data_validation, spikeIndexes_training, spikeIndexes_validation = spike_tools.dataPreProcess(
    data,
    spikeLocations=spikeLocations,
)

# Training

### Set up results dict

In [None]:
# Hidden-layer sizes to try. The training loop uses each key of `results` as
# `hidden_nodes`, so the repo is presumably keyed by these sizes.
hiddenLayerSizes = [500]
results = utilities.createResultsRepo(hiddenLayerSizes)

### Create and train NNs

In [None]:
# Train one network per entry in `results` and store the trained network and
# its training/validation curves back into that entry.
# NOTE: `nn` stays bound to the last network trained; the prediction cell
# uses it directly.
for hiddenSize, entry in results.items():
    nn = NeuralNetwork(
        input_nodes=101,
        hidden_nodes=int(hiddenSize),  # int() in case the keys are strings
        output_nodes=4,
        lr=0.1,
        error_function='difference-squared',
    )

    nn, trainingCurve, validationCurve = batchTrain(
        data_training=data_training,
        data_validation=data_validation,
        spikeIndexes_training=spikeIndexes_training,
        spikeIndexes_validation=spikeIndexes_validation,
        nn=nn,
        epochs=30,
        plotCurves=False,
    )

    entry['nn'] = nn
    entry['trainingCurve'] = trainingCurve
    entry['validationCurve'] = validationCurve

# Load Submission Data

In [None]:
# Load the submission recording and name its two columns 'time (s)' and 'signal'
# (set_axis raises ValueError if the CSV does not have exactly two columns).
data_submission = pd.read_csv('./datasources/spikes/submission_data.csv').set_axis(
    ['time (s)', 'signal'], axis=1
)

In [None]:
# Preprocess the submission recording in submission mode (no spike locations
# are passed). NOTE(review): 1.25 is presumably the spike-detection threshold;
# confirm its meaning in spike_tools.dataPreProcess.
detectionThreshold = 1.25
data_submission, predictedSpikeIndexes_submission = spike_tools.dataPreProcess(
    data_submission,
    threshold=detectionThreshold,
    submission=True,
)

Extract the waveforms of the predicted spikes

In [None]:
# Select the 'waveform' value of every row flagged as a predicted spike.
isPredictedSpike = data_submission['predictedSpike'].eq(True)
waveforms = data_submission.loc[isPredictedSpike, 'waveform']

Classify the extracted waveforms with the trained network

In [None]:
# Classify the extracted waveforms with `nn`, which is the network the training
# loop trained last.
# NOTE(review): `.iloc[1:]` skips the first predicted spike, so its waveform is
# never classified; confirm this is intentional.
waveformsToClassify = waveforms.iloc[1:]
predictions = spike_tools.classifySpikesMLP(waveformsToClassify, nn)

---

In [None]:
# Show the preprocessed submission frame as this cell's output.
data_submission

In [None]:
import plotly.graph_objects as go

# Raw and filtered submission signal over rows 2000-7999, for a visual comparison.
plotWindow = slice(2000, 8000)
signals = [data_submission[column][plotWindow] for column in ('signal', 'signalFiltered')]

In [None]:
# Overlay each signal as a semi-transparent line so both remain visible.
fig = go.Figure(
    data=[
        go.Scatter(x=trace.index, y=trace, mode='lines', name=trace.name, opacity=0.5)
        for trace in signals
    ]
)
fig.show()