# Exploration

# Setup

In [96]:
import scipy.io as spio
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from scipy.signal import butter, lfilter
from nn_general import NeuralNetwork, batchTrain, test

In [5]:
# Raw recording, read as-is from CSV (no index column is specified).
data = pd.read_csv('./datasources/spikes/training_data.csv')
# Labelled spikes; the first CSV column is used as the DataFrame index.
spikeLocations = pd.read_csv('./datasources/spikes/training_spike_locations.csv', index_col=0)

In [6]:
# Preview the first rows of the raw recording.
data.head()

Unnamed: 0,time (s),signal
0,0.0,-0.091413
1,4e-05,0.371792
2,8e-05,0.093167
3,0.00012,-0.02453
4,0.00016,0.182676


# Plot Data

In [None]:
# Plot the whole recording: the 'signal' column against the 'time (s)' column.
px.line(x=data['time (s)'], y=data['signal'])

# Function Definition

### Filter

In [91]:
def bandPassFilter(signal, lowCut=300.00, highCut=3000.00, sampleRate=25000, order=1):
    """Band-pass filter a signal with a Butterworth filter.

    Parameters
    ----------
    signal : array-like
        Samples to filter.
    lowCut, highCut : float
        Lower and upper cutoff frequencies, in the same unit as ``sampleRate``.
    sampleRate : float
        Sampling frequency of ``signal``.
    order : int
        Order passed to ``scipy.signal.butter``.

    Returns
    -------
    numpy.ndarray
        The filtered samples, as returned by ``scipy.signal.lfilter``.
    """
    # butter() expects the cutoffs as fractions of the Nyquist frequency
    # (half the sampling rate) when no sampling frequency is passed to it.
    nyquist = 0.5 * sampleRate
    passBand = [lowCut / nyquist, highCut / nyquist]

    # Numerator / denominator coefficients of the band-pass filter.
    numerator, denominator = butter(order, passBand, btype='bandpass')

    # Single forward pass of the filter over the samples.
    return lfilter(numerator, denominator, signal)

### Detect spikes

In [8]:
def detectPeaks(data, threshold=1.0):
    """Return the index labels of local maxima that exceed ``threshold``.

    A sample is a peak when it is strictly greater than both of its immediate
    neighbours in ``data['signalFiltered']`` and strictly greater than
    ``threshold``. The first and last rows have only one neighbour and are
    therefore never reported; flat-topped maxima (equal neighbours) are not
    reported either.

    The neighbour comparison is done on the full signal *before* thresholding.
    Thresholding first would compare each sample against the previous/next
    supra-threshold row instead of its true neighbours, which drops genuine
    peaks (for example the first sample above threshold, or a peak that is
    followed by a larger value in the next supra-threshold run).

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``'signalFiltered'`` column.
    threshold : float
        Minimum amplitude (exclusive) for a local maximum to count as a peak.

    Returns
    -------
    pandas.Index
        Index labels of ``data`` at which peaks were found.
    """
    signal = data['signalFiltered']

    # Strict local maximum with respect to the adjacent samples; comparisons
    # against the NaN produced by shift() at either end evaluate to False.
    isLocalMax = (signal.shift(1) < signal) & (signal.shift(-1) < signal)

    peaks = data.loc[isLocalMax & (signal > threshold)]
    return peaks.index

### Get putative spike waveforms

In [83]:
def getSpikeWaveform(spikes, data, window=100):
    """Attach a band-pass filtered waveform snippet to every row of ``spikes``.

    For each index label in ``spikes``, the ``'signal'`` values of ``data`` are
    taken from ``label - int(window/4)`` to ``label + int(3/4*window)`` with a
    label-based ``.loc`` slice (both ends included), filtered with
    ``bandPassFilter`` using its default settings, and stored in the
    ``'waveform'`` column of that row.

    Parameters
    ----------
    spikes : pandas.DataFrame
        One row per spike; its index labels are looked up in ``data``.
        Modified in place: a ``'waveform'`` column is appended if missing.
    data : pandas.DataFrame
        Must contain a ``'signal'`` column.
    window : int
        Controls how far the snippet extends before and after each spike.

    Returns
    -------
    pandas.DataFrame
        The same ``spikes`` object that was passed in.
    """
    # Create the target column once; None gives it object dtype so each cell
    # can hold a whole sequence.
    if 'waveform' not in spikes.columns:
        spikes.insert(len(spikes.columns), 'waveform', None)

    # A quarter of the window precedes the spike, three quarters follow it.
    samplesBefore = int(window/4)
    samplesAfter = int(3/4*window)

    for spikeLabel in spikes.index:
        rawSnippet = data.loc[spikeLabel-samplesBefore:spikeLabel+samplesAfter, 'signal'].tolist()
        spikes.at[spikeLabel, 'waveform'] = bandPassFilter(rawSnippet)

    return spikes

### Get signal plot

In [10]:
def plotSignal(signal, peaks):
    """Show ``signal`` as a line with red cross markers at ``peaks``.

    Parameters
    ----------
    signal : pandas.Series or sequence
        Values to plot. Each entry of ``peaks`` is looked up with
        ``signal[j]`` to obtain the marker height.
    peaks : iterable
        Positions of the markers on the x-axis; also used as keys into
        ``signal``.

    Returns
    -------
    None
        The figure is displayed with ``fig.show()``.
    """
    # Markers are placed at the values in `peaks`, which for a pandas Series
    # are index labels. Plot the line against those same labels; otherwise the
    # line would be drawn at positions 0..n-1 and would not line up with the
    # markers whenever the Series index does not start at 0 (e.g. a tail slice).
    signalX = signal.index if isinstance(signal, pd.Series) else None

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=signalX,
        y=signal,
        mode='lines',
        name='Signal'
    ))

    fig.add_trace(go.Scatter(
        x=peaks,
        y=[signal[j] for j in peaks],
        mode='markers',
        marker=dict(
            size=8,
            color='red',
            symbol='cross'
        ),
        name='Detected Peaks'
    ))

    fig.show()

---

Use the labelled spikes to train the network by retrieving the putative spike waveforms and passing them as input to the neural network. Before that, we split the labelled data into training and validation sets.

In [28]:
# Row position three quarters of the way through the labelled spikes.
splitSpike = int(len(spikeLocations)*3/4)
# Value of that spike's 'index' column (presumably its sample position in `data` — confirm).
splitIndex = spikeLocations.iloc[splitSpike]['index']
splitIndex

1081118

In [34]:
# Split at the boundary spike: everything before it is training, the rest is
# validation. The slices must be anchored at the start (`[:k]` / `[k:]`);
# negative slicing (`[:-k]` / `[-k:]`) would instead cut k rows from the end,
# giving training only the first quarter of the spikes and cutting the signal
# at a point unrelated to the boundary spike.
# Assumes `data` keeps its default 0..N-1 row order, so that a row position
# equals the sample index stored in spikeLocations['index'] — confirm.
data_training = data.iloc[:splitIndex]
data_validation = data.iloc[splitIndex:]
spikes_training = spikeLocations.iloc[:splitSpike]
spikes_validation = spikeLocations.iloc[splitSpike:]

In [None]:
# Validation signal with markers at the labelled validation spike indices.
plotSignal(data_validation['signal'], spikes_validation['index'])

In [62]:
# Small subset for quick visual checks: the first 6500 samples and the first 20 labelled spikes.
# NOTE(review): nothing here guarantees that all 20 spikes fall within the first 6500 samples — confirm.
data_tiny = data.iloc[:6500]
spikes_tiny = spikeLocations.iloc[:20]

In [None]:
# Subset signal with markers at the first labelled spike indices.
plotSignal(data_tiny['signal'], spikes_tiny['index'])

### Get Spikes

In [92]:
# Flag every row of `data` whose label appears in the labelled spike list.
# Note: this adds the 'isPeak' column to `data` in place.
data['isPeak'] = False
data.loc[spikeLocations['index'], 'isPeak'] = True

In [107]:
# Take an explicit copy of the flagged rows: `spikes` gets new columns below and
# in later cells, and writing into a slice of `data` can raise pandas'
# SettingWithCopyWarning.
spikes = data.loc[data['isPeak']].copy()

# Attach each spike's class by matching on the sample index rather than by row
# position. Positional assignment is only correct if spikeLocations happens to
# be sorted by 'index', because the rows selected from `data` are in data order.
spikes['class'] = spikeLocations.set_index('index')['class'].loc[spikes.index].to_numpy()
spikes.head()

Unnamed: 0,time (s),signal,isPeak,class
731,0.02924,-0.121707,True,1
752,0.03008,1.123094,True,2
903,0.03612,0.006133,True,4
945,0.0378,0.777869,True,1
1110,0.0444,0.243796,True,2


In [108]:
# Window size handed to getSpikeWaveform for every labelled spike.
z=100
# Attach a filtered waveform snippet to each spike row.
spikes = getSpikeWaveform(spikes, data, window=z)
spikes.head()

Unnamed: 0,time (s),signal,isPeak,class,waveform
731,0.02924,-0.121707,True,1,"[0.04701036258506873, 0.08765796884474528, 0.0..."
752,0.03008,1.123094,True,2,"[0.009818887357405522, -0.05587195448249055, -..."
903,0.03612,0.006133,True,4,"[0.04635972436554307, 0.08411508952172055, 0.0..."
945,0.0378,0.777869,True,1,"[0.8601496474533747, 1.9916123234301182, 2.349..."
1110,0.0444,0.243796,True,2,"[0.057789209189221175, 0.16409504417357673, 0...."


In [110]:
# x-axis with one point per waveform sample: 0..z inclusive, i.e. z+1 points.
xRange = np.linspace(0,z, z+1)
# Pick the waveforms by column name; a positional column index silently selects
# a different column as soon as the column layout of `spikes` changes.
sample = spikes['waveform'].iloc[-10:].tolist()

px.line(x=xRange, y=sample)

# Training

In [97]:
# One result slot per hidden-layer size (keys are the sizes as strings): the
# two learning curves start empty and the trained network is filled in later.
results = {
    hiddenSize: {'training': [], 'validation': [], 'nn': None}
    for hiddenSize in ('900', '700', '200')
}
results

{'900': {'training': [], 'validation': [], 'nn': None},
 '700': {'training': [], 'validation': [], 'nn': None},
 '200': {'training': [], 'validation': [], 'nn': None}}

In [100]:
# Training settings passed to batchTrain in the training loop.
epochs=55
batchSize=5000

In [None]:
# Train one network per hidden-layer size listed in `results` and store the
# trained network together with its training and validation curves.
# NOTE(review): input_nodes is fixed at 100 here — confirm it matches the
# length of the inputs that batchTrain feeds to the network.
for hid in results:

    nn = NeuralNetwork(
        input_nodes=100,
        hidden_nodes=int(hid),  # keys of `results` are the sizes as strings
        output_nodes=4,
        lr=0.2,
        error_function='difference-squared',
    )

    # batchTrain hands back the network plus the two curves.
    nn, trainingCurve, validationCurve = batchTrain(
        data_training=data_training,
        data_validation=data_validation,
        nn=nn,
        batchSize=batchSize,
        epochs=epochs,
    )

    results[hid].update({
        'nn': nn,
        'training': trainingCurve,
        'validation': validationCurve,
    })