# Exploration

# Setup

In [None]:
import scipy.io as spio
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from scipy.signal import butter, lfilter
from nn_spikes import NeuralNetwork, batchTrain, test
import utilities

In [None]:
# Load the training data CSV into a DataFrame with the default integer index.
data = pd.read_csv('./datasources/spikes/training_data.csv')
# Load the spike-location table; index_col=0 turns the first CSV column into
# the row index instead of a regular column.
spikeLocations = pd.read_csv('./datasources/spikes/training_spike_locations.csv', index_col=0)

In [None]:
# Display the first rows of the loaded data as a quick sanity check.
data.head()

# Function Definition

### Filter

In [None]:
def bandPassFilter(signal, lowCut=300.00, highCut=3000.00, sampleRate=25000, order=1):
    """Apply a Butterworth band-pass filter to ``signal``.

    Args:
        signal: Sequence of samples to filter.
        lowCut: Lower cut-off frequency, in the same units as ``sampleRate``.
        highCut: Upper cut-off frequency, in the same units as ``sampleRate``.
        sampleRate: Sampling rate of ``signal``.
        order: Filter order handed to ``scipy.signal.butter``.

    Returns:
        The filtered samples, as returned by ``scipy.signal.lfilter``.
    """
    # Without an explicit `fs`, butter() expects the cut-offs as fractions of
    # the Nyquist frequency, i.e. half the sampling rate.
    nyquist = 0.5 * sampleRate
    passBand = [lowCut / nyquist, highCut / nyquist]

    numerator, denominator = butter(order, passBand, btype='bandpass')
    return lfilter(numerator, denominator, signal)

### Detect spikes

In [None]:
def detectPeaks(data, threshold=1.0):
    """Return the index labels of local maxima that exceed ``threshold``.

    A row is reported when its ``'signalFiltered'`` value is strictly greater
    than both of its immediate neighbours and strictly greater than
    ``threshold``. Neighbours are taken from the full signal, so two peaks
    separated by samples below the threshold are both found. The first and
    last rows have only one neighbour and are never reported.

    Args:
        data: DataFrame with a numeric ``'signalFiltered'`` column.
        threshold: Amplitude a sample must exceed to count as a peak.

    Returns:
        pandas.Index holding the labels of the rows of ``data`` that are peaks.
    """
    filtered = data['signalFiltered']

    # Compare against the true neighbours *before* thresholding. Thresholding
    # first would make shift() compare each sample with the previous
    # above-threshold row, which hides isolated peaks.
    isLocalMax = (filtered.shift(1) < filtered) & (filtered.shift(-1) < filtered)

    return data.loc[isLocalMax & (filtered > threshold)].index

### Get putative spike waveforms

In [None]:
def getSpikeWaveform(spikes, data, window=100):
    """Store a band-pass-filtered waveform snippet for every row of ``spikes``.

    For each index label of ``spikes`` the ``'signal'`` column of ``data`` is
    sliced by label from ``int(window/4)`` before the spike to
    ``int(3/4*window)`` after it (both ends included), passed through
    ``bandPassFilter`` and written to the ``'waveform'`` column.

    Args:
        spikes: DataFrame whose index labels identify spike positions in
            ``data``. It is modified in place.
        data: DataFrame with a ``'signal'`` column.
        window: Nominal width of the extracted snippet.

    Returns:
        The same ``spikes`` object, now carrying a ``'waveform'`` column.
    """
    if 'waveform' not in spikes.columns:
        # Append the column as the last one, initialised to None.
        spikes.insert(len(spikes.columns), 'waveform', None)

    # A quarter of the window lies before the spike, three quarters after it.
    samplesBefore = int(window/4)
    samplesAfter = int(3/4*window)

    for spikeIndex in spikes.index:
        snippet = data.loc[spikeIndex-samplesBefore:spikeIndex+samplesAfter, 'signal'].tolist()
        spikes.at[spikeIndex, 'waveform'] = bandPassFilter(snippet)

    return spikes

### Get signal plot

In [None]:
def plotSignal(signal, peaks):
    """Show ``signal`` as a line plot with the samples at ``peaks`` marked.

    Args:
        signal: Indexable sequence of samples, drawn as a line.
        peaks: Iterable of positions into ``signal``; each one is drawn as a
            red cross at ``(position, signal[position])``.
    """
    peakMarker = dict(
        size=8,
        color='red',
        symbol='cross'
    )

    signalTrace = go.Scatter(y=signal, mode='lines', name='Signal')
    peakTrace = go.Scatter(
        x=peaks,
        y=[signal[position] for position in peaks],
        mode='markers',
        marker=peakMarker,
        name='Detected Peaks'
    )

    fig = go.Figure()
    fig.add_trace(signalTrace)
    fig.add_trace(peakTrace)
    fig.show()

In [None]:
def plotLearningCurves(results):
    """Show the training and validation curves of every entry in ``results``.

    Args:
        results: Mapping from a label to a dict holding ``'trainingCurve'``
            and ``'validationCurve'`` sequences. Each label contributes a
            thin dashed training trace and a solid validation trace.
    """
    fig = go.Figure()

    for label, entry in results.items():
        legendPrefix = str(label)

        trainingTrace = go.Scatter(
            y=entry['trainingCurve'],
            line=dict(width=1, dash='dash'),
            name=legendPrefix+' training'
        )
        validationTrace = go.Scatter(
            y=entry['validationCurve'],
            mode='lines',
            name=legendPrefix+' validation'
        )

        fig.add_trace(trainingTrace)
        fig.add_trace(validationTrace)

    fig.show()

In [None]:
def createResultsRepo(hiddenNodes=(200, 500, 700, 900)):
    """Build an empty results dictionary with one entry per hidden-layer size.

    Args:
        hiddenNodes: Iterable of hidden-node counts. The default is an
            immutable tuple so it cannot be altered between calls.

    Returns:
        Dict keyed by ``str(count)``. Every value is a fresh dict with empty
        ``'trainingCurve'`` / ``'validationCurve'`` lists and ``None`` for
        ``'trainingData'``, ``'validationData'`` and ``'nn'``.
    """
    # The inner literal is evaluated once per key, so no lists are shared
    # between entries.
    return {
        str(nodes): {
            'trainingCurve': [],
            'validationCurve': [],
            'trainingData': None,
            'validationData': None,
            'nn': None,
        }
        for nodes in hiddenNodes
    }

In [None]:
def plotConfusion(matrix, x=(0, 1, 2, 3), y=(0, 1, 2, 3)):
    """Show ``matrix`` as an annotated heat map titled "Confusion matrix".

    Args:
        matrix: 2-D sequence of cell values; each cell is also printed as text.
        x: Tick labels for the horizontal axis, captioned "Predicted value".
        y: Tick labels for the vertical axis, captioned "Real value".
    """
    # Imported locally: the file's top-level imports do not bind `ff`, so the
    # function would otherwise fail with NameError.
    import plotly.figure_factory as ff

    # Annotation text must be strings, one per cell.
    cellText = [[str(value) for value in row] for row in matrix]

    # Set up the figure.
    fig = ff.create_annotated_heatmap(matrix,
                                      x=list(x),
                                      y=list(y),
                                      annotation_text=cellText,
                                      colorscale='Viridis')

    fig.update_layout(title_text='<i><b>Confusion matrix</b></i>')

    # Axis captions are added as paper-anchored annotations.
    fig.add_annotation(dict(font=dict(color="black", size=14),
                            x=0.5,
                            y=-0.15,
                            showarrow=False,
                            text="Predicted value",
                            xref="paper",
                            yref="paper"))

    fig.add_annotation(dict(font=dict(color="black", size=14),
                            x=-0.35,
                            y=0.5,
                            showarrow=False,
                            text="Real value",
                            textangle=-90,
                            xref="paper",
                            yref="paper"))

    # Widen the left margin to make room for the vertical-axis caption.
    fig.update_layout(margin=dict(t=50, l=200))

    # Show the colour bar next to the heat map.
    fig['data'][0]['showscale'] = True
    fig.show()

---

Use the labelled spikes to train the network: first retrieve the putative spike waveforms, then pass them as input to the neural network. Before that, we split the training data into training and validation sets.

### Split data into training and validation

In [None]:
# Position (row number) in spikeLocations of the first validation spike:
# the first three quarters of the labelled spikes go to training.
splitSpike = int(len(spikeLocations)*3/4)
# Value of the 'index' column for that spike, used to cut the raw data.
# NOTE(review): assumes spikeLocations is ordered by 'index' — confirm.
splitIndex = spikeLocations.iloc[splitSpike]['index']
print("Split index: {}\nSplit spike: {}".format(splitIndex, splitSpike))

In [None]:
# Positional split of the raw data: rows before splitIndex are training data,
# the remaining rows are validation data.
# NOTE(review): iloc is positional while splitIndex is a value of the 'index'
# column — confirm both use the same (0-based) numbering.
data_training = data.iloc[:splitIndex]
data_validation = data.iloc[splitIndex:]
print("training size: {}\nvalidation size: {}".format(data_training.shape[0], data_validation.shape[0]))

In [None]:
# Add a boolean 'isPeak' column, then set it to True on every row whose index
# label appears in the 'index' column of spikeLocations.
data['isPeak'] = False
data.loc[spikeLocations['index'], 'isPeak'] = True

### Get all spikes

In [None]:
# Keep only the rows flagged as peaks (in the row order of `data`).
spikes = data.loc[data['isPeak']==True, :]
# Append the class labels positionally as the last column.
# NOTE(review): this pairs the i-th flagged row with the i-th row of
# spikeLocations, so it assumes spikeLocations is ordered by 'index' and has
# no duplicate indices — confirm.
spikes.insert(len(spikes.columns), 'class', spikeLocations['class'].values)
spikes.head()

#### Get training and validation spikes

In [None]:
# Split the labelled spikes by position: the first splitSpike rows are used
# for training, the rest for validation.
spikes_training = spikes[:splitSpike]
spikes_validation = spikes[splitSpike:]

print("training size: {}\nvalidation size: {}".format(spikes_training.shape[0], spikes_validation.shape[0]))

### Get spike waveforms

In [None]:
# Window size passed to getSpikeWaveform for both splits.
z=100
# The snippets are always cut from the full `data` frame, for both splits.
spikes_training = getSpikeWaveform(spikes_training, data, window=z)
spikes_validation = getSpikeWaveform(spikes_validation, data, window=z)
spikes_training.head()

In [None]:
# z + 1 evenly spaced x positions from 0 to z.
xRange = np.linspace(0,z, z+1)
# Rows 20-39 of the column at position 4.
# NOTE(review): position 4 is presumably the 'waveform' column — confirm
# against the column order of spikes_training.
sample = spikes_training.iloc[20:40, 4].tolist()

px.line(x=xRange, y=sample)

# Training

### Set up results dict

In [None]:
# One empty results entry per hidden-layer size to be trained.
results = createResultsRepo(hiddenNodes=[200, 500, 700])
results

### Create and train NNs

In [None]:
# Train one network per entry of `results`; each key is the hidden-node count
# stored as a string.
for hid, entry in results.items():

    # 101 inputs correspond to the z + 1 samples of a waveform cut with z = 100.
    nn = NeuralNetwork(input_nodes=101,
                       hidden_nodes=int(hid),
                       output_nodes=4,
                       lr=0.1,
                       error_function='difference-squared')

    nn, trainingCurve, validationCurve, df1, df2 = batchTrain(
        data_training=spikes_training,
        data_validation=spikes_validation,
        nn=nn,
        epochs=20,
        plotCurves=False,
    )

    # Store everything batchTrain returned under this hidden-layer size.
    entry.update({
        'nn': nn,
        'trainingCurve': trainingCurve,
        'validationCurve': validationCurve,
        'trainingData': df1,
        'validationData': df2,
    })

In [None]:
# Plot the curves of all trained networks. This calls the implementation in
# the `utilities` module, not the plotLearningCurves defined in this notebook.
utilities.plotLearningCurves(results)

In [None]:
# Compare the 'class' column with the 'classPrediction' column of the
# validation data stored for the 500-hidden-node network.
utilities.getConfusion(results['500']['validationData']['class'], results['500']['validationData']['classPrediction'])