# Exploration

# Setup

In [1]:
import plotly.express as px
import numpy as np
import pandas as pd
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities

In [2]:
# Labelled spike locations; the CSV's first column becomes the frame index.
spikeLocations = pd.read_csv('./datasources/spikes/training_spike_locations.csv',
                             index_col=0)
# Recording itself. No index_col is passed here, so the CSV's leading unnamed
# column is kept as an ordinary "Unnamed: 0" column (visible in the preview below).
data = pd.read_csv('./datasources/spikes/training_data.csv')

In [3]:
# First three rows of the loaded recording.
data.iloc[:3]

Unnamed: 0,time (s),signal
0,0.0,-0.091413
1,4e-05,0.371792
2,8e-05,0.093167


# Function definition

In [None]:
def splitData(data, spikes, trainFraction=3/4):
    """Label `data` with the known spikes and split it into training/validation parts.

    On the first call four columns are appended to `data` *in place*:
    ``knownSpike`` (True on each labelled spike row), ``knownClass`` (that
    spike's class, 0 elsewhere), and ``predictedSpike`` / ``predictedClass``
    (False / 0 placeholders for later predictions). If ``knownSpike`` already
    exists the columns are left untouched and only the split is performed, so
    re-running the cell is safe.

    The split point is the labelled spike at position
    ``int(len(spikes) * trainFraction)`` when spikes are ordered by sample
    index, so roughly ``trainFraction`` of the labelled spikes end up in the
    training part.

    Parameters
    ----------
    data : pandas.DataFrame
        The recording, one row per sample. Spike labels are written with
        ``.loc`` (by label) while the split uses ``.iloc`` (by position), so
        this assumes ``data`` has a default 0..n-1 index.
    spikes : pandas.DataFrame
        Labelled spikes with an ``'index'`` column (row of ``data``) and a
        ``'class'`` column.
    trainFraction : float, optional
        Fraction of labelled spikes placed before the split point; must satisfy
        ``0 <= trainFraction < 1``. Defaults to 3/4.

    Returns
    -------
    tuple of pandas.DataFrame
        ``(data.iloc[:splitIndex], data.iloc[splitIndex:])``; the spike at the
        split point is the first row of the validation part.

    Raises
    ------
    ValueError
        If ``spikes`` is empty or ``trainFraction`` is outside ``[0, 1)``.
    """
    if len(spikes) == 0:
        raise ValueError("spikes must contain at least one labelled spike")
    if not 0 <= trainFraction < 1:
        raise ValueError("trainFraction must be in the range [0, 1)")

    # Sort so the split position counts spikes in time order even when the
    # labels file is not ordered by sample index.
    orderedSpikeIndexes = spikes['index'].sort_values()
    splitSpike = int(len(spikes) * trainFraction)
    # int() so the value is usable as an iloc slice bound even if the column is float.
    splitIndex = int(orderedSpikeIndexes.iloc[splitSpike])

    if 'knownSpike' not in data.columns:
        # Two columns for the known spike data, two placeholders for predictions.
        for column, default in (('knownSpike', False),
                                ('knownClass', 0),
                                ('predictedSpike', False),
                                ('predictedClass', 0)):
            data.insert(len(data.columns), column, default)

        # Store spike data in the new columns (class values aligned positionally).
        data.loc[spikes['index'], 'knownSpike'] = True
        data.loc[spikes['index'], 'knownClass'] = spikes['class'].values
    else:
        print("New columns already exist, just splitting data.")

    # Return training and validation data
    return data.iloc[:splitIndex], data.iloc[splitIndex:]

---

Use the labelled spikes to train the network: retrieve putative spike waveforms and pass them as input to the NN. First, split the training data into training and validation sets.

In [None]:
# Annotate `data` with the labelled spikes and split it into training / validation parts.
data_training, data_validation = splitData(data=data, spikes=spikeLocations)

In [None]:
# First three rows of the training part.
data_training.iloc[:3]

In [None]:
# First three rows of the validation part.
data_validation.iloc[:3]

---

# Detect spikes

### Detect spikes from the signal (without using the labels)

Band-pass filter the signal.

In [None]:
# Band-pass filter the raw signal and keep the result as a new column of `data`.
# NOTE: `data_training` / `data_validation` were sliced before this cell, so they
# do not gain this column.
filteredSignal = spike_tools.bandPassFilter(data['signal'])
data['signalFiltered'] = filteredSignal
data.iloc[:5]

Detect candidate spike peaks.

In [None]:
# detectPeaks returns a pair, unpacked into the updated frame and the peak indexes.
peakDetection = spike_tools.detectPeaks(data)
data, predictedPeakIndexes = peakDetection
data.iloc[:3]

### Get spike waveforms

In [None]:
# NOTE(review): the result is bound to `data`, mirroring the detectPeaks cell —
# confirm getSpikeWaveforms returns the updated frame.
data = spike_tools.getSpikeWaveforms(predictedPeakIndexes, data)
# Preview the frame this cell produced. (`spikes_training` is not created by any
# cell above, so it cannot be previewed here on a fresh kernel.)
data.head(3)

In [None]:
# Plot rows 20-39 of the column at position 4 of the training waveforms table.
# FIXME: `spikes_training` is not created by any cell shown above; it must be
# defined upstream before this cell runs on a fresh kernel.
sample = spikes_training.iloc[20:40, 4].tolist()
# Build the x-axis from the sample itself so x and y always have the same length.
xRange = np.arange(len(sample))

px.line(x=xRange, y=sample)

# Training

### Set up results dict

In [None]:
# Candidate hidden-layer sizes; the training loop reads each results key back
# as a size via int(hid).
hiddenLayerSizes = [200, 500, 700]
results = utilities.createResultsRepo(hiddenNodes=hiddenLayerSizes)
results

### Create and train NNs

In [None]:
# Train one network per hidden-layer size (each results key, converted with int)
# and store the trained model, curves and data frames back under that key.
# NOTE(review): `spikes_training` / `spikes_validation` are not created by any cell
# shown above — define them upstream before running this on a fresh kernel.
for hid in results.keys():
    nn = NeuralNetwork(
        input_nodes=101,
        hidden_nodes=int(hid),
        output_nodes=4,
        lr=0.1,
        error_function='difference-squared',
    )

    trained = batchTrain(
        data_training=spikes_training,
        data_validation=spikes_validation,
        nn=nn,
        epochs=20,
        plotCurves=False,
    )
    nn, trainingCurve, validationCurve, df1, df2 = trained

    entry = results[hid]
    for key, value in (('nn', nn),
                       ('trainingCurve', trainingCurve),
                       ('validationCurve', validationCurve),
                       ('trainingData', df1),
                       ('validationData', df2)):
        entry[key] = value

In [None]:
# Presumably plots the 'trainingCurve' / 'validationCurve' entries stored per
# hidden-layer size in `results` by the training loop — TODO confirm in utilities.
utilities.plotLearningCurves(results)

In [None]:
# Compare true ('class') vs. predicted ('classPrediction') labels on the
# validation data of the network trained with the '500' key.
validation500 = results['500']['validationData']
utilities.getConfusion(validation500['class'], validation500['classPrediction'])