# Exploration

In [None]:
def evaluateSpikeDetectionPerformance(knownSpikes, predictedSpikes, windowLeading=-1, windowTrailing=35):
    """Score predicted spike locations against known spike locations.

    A known spike at sample index ``k`` is matched by a predicted index ``p`` when
    ``k - windowLeading <= p < k + windowTrailing``, i.e. ``p`` lies in the
    half-open range ``k-windowLeading:k+windowTrailing``.

    Parameters
    ----------
    knownSpikes : iterable of int
        Sample indexes of the known spikes.
    predictedSpikes : array-like of int
        Sample indexes of the predicted spikes.
    windowLeading, windowTrailing : int
        Offsets defining the matching window. The leading offset is *subtracted*,
        so the default ``-1`` starts the window one sample after the known spike.
        NOTE(review): confirm this sign convention is intended.

    Returns
    -------
    results : pandas.DataFrame
        Indexed by predicted spike index. Column ``'predictionResult'`` is ``'TP'``
        for a prediction that was the only prediction inside some known spike's
        window, and ``'FP'`` for every other prediction.
    falseNegatives : int
        Number of known spikes with no prediction inside their window.
    falsePositives : int
        Number of predictions labelled ``'FP'``.
    """
    # Accept lists / pandas Index alike; boolean masking below needs an ndarray.
    predictedSpikes = np.asarray(predictedSpikes)
    results = pd.DataFrame(index=predictedSpikes, columns=['predictionResult'])
    falseNegatives = 0

    for known in knownSpikes:
        lower = known - windowLeading
        upper = known + windowTrailing
        inWindow = (predictedSpikes >= lower) & (predictedSpikes < upper)
        matches = predictedSpikes[inWindow]

        if len(matches) == 1:
            # matches holds index *labels* (sample indexes), not row positions -> .loc
            results.loc[matches[0], 'predictionResult'] = 'TP'
        elif len(matches) == 0:
            falseNegatives += 1
        # NOTE(review): a known spike with several predictions in its window is
        # counted as neither TP nor FN, and all of those predictions end up 'FP'.

    isTruePositive = results['predictionResult'] == 'TP'
    results.loc[~isTruePositive, 'predictionResult'] = 'FP'
    falsePositives = int((~isTruePositive).sum())

    return results, falseNegatives, falsePositives

In [None]:
def classifySpikes(waveforms, nn):
    """Assign a class label to every waveform using a trained network.

    Parameters
    ----------
    waveforms : pandas.Series
        One waveform per element; each element must provide ``.tolist()``
        (e.g. a NumPy array).
    nn : object
        Anything exposing ``query(list)`` that returns a sequence of scores
        (used here with a ``NeuralNetwork``).

    Returns
    -------
    list
        For each waveform, ``argmax(scores) + 1``: the index of the
        highest-scoring output, shifted so class labels start at 1, not 0.
    """
    # Only a pandas Series is accepted (not a DataFrame or a plain list).
    assert isinstance(waveforms, pd.Series)

    return [np.argmax(nn.query(waveform.tolist())) + 1 for waveform in waveforms]

# Setup

In [None]:
import plotly.express as px
import numpy as np
import pandas as pd
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities

In [None]:
# Training recording and its labelled spike locations, read from one data directory.
SPIKES_DATA_DIR = './datasources/spikes'
data = pd.read_csv(f'{SPIKES_DATA_DIR}/training_data.csv')
spikeLocations = pd.read_csv(f'{SPIKES_DATA_DIR}/training_spike_locations.csv', index_col=0)

In [None]:
# Preview the first rows of the freshly loaded recording.
data.head(n=3)

Use the labelled spikes to train the network: retrieve the putative spike waveforms and pass them as input to the network. First we join the known spike locations onto the recording; the split into training and validation sets happens further below.

In [None]:
# Join the spike locations onto the recording. This rebinds `data`, so re-running this
# cell passes already-joined data back into joinSpikes. NOTE(review): re-run from the
# loading cell instead, or confirm joinSpikes is idempotent.
data = spike_tools.joinSpikes(data, spikeLocations)

In [None]:
# Inspect the frame after the spike locations have been joined on.
data.head(n=3)

# Detect spikes

### Detect spikes independently of the known labels

Filter signal

In [None]:
# Band-pass filter the raw signal into a new column.
# NOTE(review): cut-off units are presumably Hz; confirm against spike_tools.bandPassFilter.
data['signalFiltered'] = spike_tools.bandPassFilter(
    data['signal'],
    lowCut=300,
    highCut=3000,
    order=1,
)
data.head(n=5)

Predict peaks

In [None]:
# detectPeaks returns two values: a frame (rebound to `data`) and the indexes of the
# detected peaks. Those indexes are reused for waveform extraction and the data split.
data, predictedPeakIndexes = spike_tools.detectPeaks(data)
data.head(3)

Get spike waveforms for predicted spikes

In [None]:
# Extract waveforms around each predicted peak. NOTE(review): presumably stored in the
# 'waveform' column that is read when classifying; confirm in spike_tools.
data = spike_tools.getSpikeWaveforms(predictedPeakIndexes, data)
data.head()

Get spike waveforms for known spikes

In [None]:
# Also extract waveforms around the known (labelled) spikes.
knownSpikeIndexes = data.index[data['knownSpike'].eq(True).to_numpy()]
data = spike_tools.getSpikeWaveforms(knownSpikeIndexes, data)

Plot known and predicted spikes overlaid on the raw and filtered signal

In [None]:
# Row positions (iloc) of a 5000-row window chosen for visual inspection.
PLOT_CENTER_ROW = 1152030
PLOT_HALF_WINDOW = 2500
sample = data.iloc[PLOT_CENTER_ROW - PLOT_HALF_WINDOW:PLOT_CENTER_ROW + PLOT_HALF_WINDOW, :]
spike_tools.plotSpikes(
    [sample['signal'], sample['signalFiltered']],
    [sample['knownSpike'], sample['predictedSpike']],
)

In [None]:
# TODO(review): dead code. Either restore this plotly overlay of the predicted-spike
# waveforms in `sample`, or delete the cell.
# waves = sample[sample['predictedSpike']==True]['waveform'].tolist()

# px.line(x=np.linspace(0,100, 101), y=waves)

---

Split the data into training and validation sets (`trainingShare=0.8`), ready to pass to the neural network.

In [None]:
# Split (trainingShare=0.8) into training and validation frames, with the predicted
# peak indexes divided alongside them.
data_training, data_validation, spikeIndexes_training, spikeIndexes_validation = spike_tools.splitData(
    data,
    predictedPeakIndexes,
    trainingShare=0.8,
)

# Training

### Set up results dict

In [None]:
# Results container. The training loop iterates its keys (hidden-node counts as strings,
# converted with int()) and stores each network, its curves and its data under results[key].
results = utilities.createResultsRepo()

### Create and train NNs

In [None]:
# Train one network per hidden-layer size listed in `results` (keys are hidden-node
# counts as strings), then store each trained network with its learning curves and data.
# NOTE(review): no random seed is set here, so results presumably vary between runs.
# Seed NeuralNetwork's initialisation if reproducibility matters.
for hid in results.keys():

    nn = NeuralNetwork(input_nodes=101,   # presumably one input per waveform sample; TODO confirm
                       hidden_nodes=int(hid),
                       output_nodes=4,    # classifySpikes maps argmax over these outputs to labels 1-4
                       lr=0.1,
                       error_function='difference-squared')

    nn, trainingCurve, validationCurve, trainingData, validationData = batchTrain(
        data_training=data_training,
        data_validation=data_validation,
        spikeIndexes_training=spikeIndexes_training,
        spikeIndexes_validation=spikeIndexes_validation,
        nn=nn,
        epochs=30,
        plotCurves=False,
    )
    results[hid]['nn'] = nn
    results[hid]['trainingCurve'] = trainingCurve
    results[hid]['validationCurve'] = validationCurve
    results[hid]['trainingData'] = trainingData
    results[hid]['validationData'] = validationData

In [None]:
# Label-based (.loc, endpoints inclusive) windows from 10 before to 35 after a set of
# hand-picked spike indexes kept for manual inspection, keyed by the index as a string.
crappies = {
    str(idx): data.loc[idx - 10:idx + 35]
    for idx in (54412, 87433, 165493, 232479, 299250, 312319, 339791, 472193, 980407)
}

# Classify

In [None]:
# Waveforms of every predicted spike, in row order.
waveforms = data.loc[data['predictedSpike'].eq(True), 'waveform']

In [None]:
# Select the network trained under results key '500' (keys are hidden-node counts as
# strings, passed through int() during training). Check its type before use.
nn = results['500']['nn']
assert isinstance(nn, NeuralNetwork)

In [None]:
# Write labels back using the waveforms' own index labels. They then line up with exactly
# the rows that were classified, even if data['predictedSpike'] changed after `waveforms`
# was built. Assumes data.index is unique (it is sliced by label elsewhere).
data.loc[waveforms.index, 'predictedClass'] = classifySpikes(waveforms, nn)

In [None]:
# Predicted-spike index labels plus copies shifted by +2 and -2 samples, merged and sorted.
# NOTE(review): despite the names, truesShiftL moves indexes later (+2) and truesShiftR
# earlier (-2). `mask` is a sorted list of index values, not a boolean mask.
trues = pd.Series(data[data['predictedSpike']==True].index)
truesShiftL = trues + 2
truesShiftR = trues - 2
# Series.append was removed in pandas 2.0; pd.concat performs the same concatenation.
mask = sorted(pd.concat([trues, truesShiftL, truesShiftR]))

---

In [None]:
# Plot the stored learning curves. NOTE(review): presumably one training/validation pair
# per hidden-layer size in `results`; confirm in utilities.plotLearningCurves.
utilities.plotLearningCurves(results)