# Exploration

# Setup

In [None]:
import scipy.io as spio
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from scipy.signal import butter, lfilter, savgol_filter
from scipy.interpolate import CubicSpline, interp1d

In [None]:
# Load the training data CSV; index_col=0 uses the file's first column as the DataFrame index.
data = pd.read_csv('./datasources/spikes/dev/training_data_50k.csv', index_col=0)

In [None]:
# Preview the first rows to check the columns that were loaded.
data.head()

# Plot Data

In [None]:
# Line chart of the 'signal' column against the 'time (s)' column.
px.line(x=data['time (s)'], y=data['signal'])

# Function Definition

### Filter

In [None]:
def bandPassFilter(signal, lowCut=300.00, highCut=3000.00, sampleRate=25000, order=1):
    """Band-pass filter a signal with a Butterworth filter.

    Parameters
    ----------
    signal : array-like
        Samples to filter.
    lowCut, highCut : float
        Lower and upper cut-off frequencies, in the same units as
        ``sampleRate``.
    sampleRate : float
        Sampling frequency of ``signal``.
    order : int
        Filter order handed to ``scipy.signal.butter``.

    Returns
    -------
    numpy.ndarray
        The result of ``scipy.signal.lfilter`` applied to ``signal``.
    """
    # butter() takes digital cut-offs as fractions of the Nyquist frequency,
    # i.e. half the sampling rate.
    nyquist = sampleRate / 2.0
    band = [lowCut / nyquist, highCut / nyquist]

    # Butterworth coefficients for the pass band, applied in a single
    # forward pass over the signal.
    numerator, denominator = butter(order, band, btype='bandpass')
    return lfilter(numerator, denominator, signal)

### Detect spikes

In [None]:
def detectPeaks(data, threshold=1.0):
    """Return the index labels of local maxima of the filtered signal above a threshold.

    A sample counts as a peak when it is strictly greater than both of its
    immediate neighbours in ``data`` and strictly greater than ``threshold``.
    The first and last samples have only one neighbour and are never reported.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a 'signalFiltered' column.
    threshold : float
        Amplitude a peak has to exceed (exclusive).

    Returns
    -------
    pandas.Index
        Index labels of the detected peaks, in their original order.
    """
    filtered = data['signalFiltered']

    # Compare each sample with its true neighbours *before* thresholding.
    # Thresholding first would make shift() pair a sample with the previous /
    # next supra-threshold row instead of its real neighbour, which drops
    # peaks at the edge of a supra-threshold run (including every crossing
    # that is only one sample wide).
    isLocalMax = (filtered.shift(1) < filtered) & (filtered.shift(-1) < filtered)

    return data.loc[isLocalMax & (filtered > threshold)].index

### Get putative spike waveforms

In [None]:
def getSpikeWaveform(spikes, data, window=200):
    """Attach to each spike the stretch of filtered signal surrounding it.

    For every index label in ``spikes`` the 'signalFiltered' values of
    ``data`` from a quarter of ``window`` before the label to three quarters
    of ``window`` after it are stored as a list in the 'waveform' column.
    The slice is label-based and includes both end points.

    Parameters
    ----------
    spikes : pandas.DataFrame
        One row per spike. Modified in place: a 'waveform' column is added
        if missing and then filled.
    data : pandas.DataFrame
        Must contain a 'signalFiltered' column and share index labels with
        ``spikes``.
    window : int
        Total width of the extracted stretch, in index units.

    Returns
    -------
    pandas.DataFrame
        The same ``spikes`` object that was passed in.
    """
    # Offsets relative to the spike label.
    before = int(window / 4)
    after = int(3 / 4 * window)

    if 'waveform' not in spikes.columns:
        # Append a None-filled column so each cell can hold a list.
        spikes.insert(spikes.shape[1], 'waveform', None)

    for peak in spikes.index:
        segment = data.loc[peak - before:peak + after, 'signalFiltered']
        spikes.at[peak, 'waveform'] = segment.tolist()

    return spikes

### Get signal plot

In [None]:
def plotSignal(signal, peaks):
    """Show the signal as a line with the detected peaks marked on top.

    Parameters
    ----------
    signal : indexable sequence
        Values drawn on the y-axis; also looked up as ``signal[j]`` for
        every ``j`` in ``peaks``.
    peaks : iterable
        Peak positions, used as the x-coordinates of the markers.
    """
    signalTrace = go.Scatter(y=signal, mode='lines', name='Signal')

    # Red crosses at the signal value of every detected peak.
    peakMarker = dict(size=8, color='red', symbol='cross')
    peakHeights = [signal[position] for position in peaks]
    peakTrace = go.Scatter(
        x=peaks,
        y=peakHeights,
        mode='markers',
        marker=peakMarker,
        name='Detected Peaks',
    )

    fig = go.Figure()
    for trace in (signalTrace, peakTrace):
        fig.add_trace(trace)

    fig.show()

---

Add a `waveform` column holding the filtered-signal window around each detected spike.

In [None]:
# Keep only the rows flagged as peaks. The explicit .copy() gives `spikes` its
# own data, so adding and filling the 'waveform' column afterwards does not
# write into a selection derived from `data` (the pattern pandas reports with
# SettingWithCopyWarning).
spikes = data.loc[data['isPeak']==True, :].copy()
spikes.head()

In [None]:
# Store the filtered-signal window around each spike in a 'waveform' column
# (using the default window of 200), then preview the result.
spikes = getSpikeWaveform(spikes, data)
spikes.head()

In [None]:
# getSpikeWaveform was called with its default window of 200 and slices by
# label inclusively, so a waveform that is not cut off by the ends of the
# data has z + 1 samples.
z = 4*50
xRange = np.linspace(0,z, z+1)

# Take the last ten waveforms by column name rather than by column position,
# so the plot keeps working if columns are added or reordered.
sample = spikes['waveform'].iloc[-10:].tolist()

px.line(x=xRange, y=sample)

**TODO:** The waveforms still need to be smoothed; for now, work continues on developing the neural network (NN).