# Imports

In [None]:
import pandas as pd
import numpy as np
from nn_spikes import NeuralNetwork, batchTrain
from spike_tools import classifySpikesMLP, getSpikeWaveforms
import plotly.express as px
from simulated_annealing import anneal

In [None]:
# Full training data and the spike locations table (first CSV column as index).
data = pd.read_csv('./datasources/spikes/training_data.csv')
spikeLocations = pd.read_csv('./datasources/spikes/training_spike_locations.csv', index_col=0)

# Training split: promote the first CSV column to the index and label it 'index'.
data_training = pd.read_csv('./datasources/spikes/dev/data_training_SA.csv')
data_training = data_training.set_index(data_training.columns[0]).rename_axis('index')
data_training.head(3)

# Validation split: same treatment as the training split.
data_validation = pd.read_csv('./datasources/spikes/dev/data_validation_SA.csv')
data_validation = data_validation.set_index(data_validation.columns[0]).rename_axis('index')
data_validation.head(3)

# Spike indexes for the training split: drop the first CSV column (it only
# served as the index) and flatten the remaining values into a 1-D array.
spikeIndexes_training = pd.read_csv('./datasources/spikes/dev/spikeIndexes_training_SA.csv')
spikeIndexes_training = spikeIndexes_training.iloc[:, 1:].to_numpy().flatten()
spikeIndexes_training[:5]

# Spike indexes for the validation split, prepared the same way.
spikeIndexes_validation = pd.read_csv('./datasources/spikes/dev/spikeIndexes_validation_SA.csv')
spikeIndexes_validation = spikeIndexes_validation.iloc[:, 1:].to_numpy().flatten()
spikeIndexes_validation[:5]

# Run simulated annealing optimiser

In [None]:
# Candidate solution handed to the optimiser, ordered as: epochs, hidden_nodes, lr
solution = [15, 500, 0.2]

# Simulated annealing optimisation; the tuning options are collected in one
# mapping and unpacked as keyword arguments.
anneal_settings = {'iterations': 4, 'alpha': 0.6, 'variation': 0.3}
results = anneal(solution, spikeLocations, **anneal_settings)

In [None]:
# Tabulate the optimiser output with named columns, indexed by temperature.
result_columns = ['Temperature', 'iteration', 'Solution', 'Error']
df = pd.DataFrame(results, columns=result_columns).set_index('Temperature')

In [None]:
# Display the results table (a notebook renders a cell's final expression).
df

In [None]:
# Line chart of the 'Error' column of the results table.
error_series = df['Error']
px.line(error_series, y="Error")

In [None]:
# Report 99.9 minus the last field of the last entry in `results`.
last_error = results[-1][-1]
print("Best performance", 99.9 - last_error)

### Predict on validation dataset

In [None]:
# Classify the validation spike waveforms and store the predicted class for
# each spike row in `data_validation`.
waveforms = data_validation.loc[spikeIndexes_validation, 'waveform']
# classifySpikesMLP is imported by name (`from spike_tools import ...`); the
# `spike_tools` module itself is never bound, so the function is called directly.
# NOTE(review): other cells treat `results` as a sequence of rows
# (`results[-1][-1]`, `pd.DataFrame(results, ...)`), so the
# `results['1100']['nn']` lookup looks stale -- confirm where the trained
# network is actually stored before running this cell.
predictions = classifySpikesMLP(waveforms, results['1100']['nn'])
# `.at` addresses a single scalar cell only; `.loc` is the label-based accessor
# for assigning to many rows at once. The predictions are converted to a plain
# array so they are assigned in order rather than aligned on a Series index.
data_validation.loc[spikeIndexes_validation, 'predictedClass'] = pd.Series(predictions).to_numpy()

In [None]:
# Display the validation rows selected by the validation spike indexes.
data_validation.loc[spikeIndexes_validation]