# Exploration

# Setup

In [None]:
import scipy.io as spio
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import pandas as pd
from scipy.signal import butter, lfilter

In [None]:
# Raw recording; the first CSV column becomes the DataFrame index.
data = pd.read_csv(
    './datasources/spikes/training_data_50k.csv',
    index_col=0,
)

In [None]:
# Labelled spike locations; the first CSV column becomes the DataFrame index.
locations = pd.read_csv(
    './datasources/spikes/training_spike_locations.csv',
    index_col=0,
)

In [None]:
# Preview the first five rows of the recording.
data.head(n=5)

In [None]:
# Preview the first five labelled spike locations.
locations.head(n=5)

# Plot Data

In [None]:
# Raw signal plotted against its time column.
px.line(
    x=data['time (s)'],
    y=data['signal'],
)

Check how the labelled spikes are distributed across classes.

In [None]:
# Number of labelled spikes in each of the classes 1-4, keyed by the class
# label as a string.
classCount = {
    str(label): int((locations['class'] == label).sum())
    for label in range(1, 5)
}

In [None]:
# Display the per-class spike counts.
classCount

# Function Definitions

### Filter

In [None]:
def bandPassFilter(signal, lowCut=300.00, highCut=3000.00, sampleRate=25000, order=1, zeroPhase=False):
    """Band-pass filter ``signal`` with a digital Butterworth filter.

    Parameters
    ----------
    signal : array_like
        Samples to filter (a pandas Series is accepted).
    lowCut, highCut : float
        Lower and upper pass-band edges, in Hz.
    sampleRate : float
        Sampling rate of ``signal``, in Hz.
    order : int
        Order of the Butterworth design.
    zeroPhase : bool, optional
        False (default) filters causally with ``scipy.signal.lfilter``. That
        introduces a frequency-dependent phase delay, so filtered features
        lag the raw signal. True filters forward and backward with
        ``scipy.signal.filtfilt``. That removes the phase delay, but the
        effective filter order is doubled, and the signal must be longer
        than filtfilt's default padding.

    Returns
    -------
    numpy.ndarray
        Filtered samples, same length as ``signal``.
    """
    # butter() without an ``fs`` argument expects cut-offs normalised to the
    # Nyquist frequency (half the sampling rate).
    nyq = 0.5 * sampleRate
    low = lowCut / nyq
    high = highCut / nyq

    # Butterworth coefficients in transfer-function (b, a) form.
    b, a = butter(order, [low, high], btype='bandpass')

    if zeroPhase:
        from scipy.signal import filtfilt  # only needed for the zero-phase path
        return filtfilt(b, a, signal)
    return lfilter(b, a, signal)

### Detect spikes

In [None]:
def detectPeaks(data, threshold=1.0):
    """Return the index labels of local maxima of ``data['signalFiltered']`` above ``threshold``.

    A sample is a peak when its value is strictly greater than ``threshold``
    and strictly greater than both the immediately preceding and following
    samples. Neighbours are taken from the full signal, so a sample is always
    compared with its adjacent samples in time. The first and last samples are
    never reported, because each lacks a neighbour. Flat-topped maxima (equal
    adjacent values) are not detected.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain a ``'signalFiltered'`` column; rows are assumed to be in
        time order.
    threshold : float
        Value a peak must strictly exceed.

    Returns
    -------
    pandas.Index
        Labels from ``data.index`` of the detected peaks, in row order.
    """
    filtered = data['signalFiltered']
    # Compare against neighbours *before* thresholding. Shifting after
    # dropping sub-threshold rows would compare samples across gaps and miss
    # peaks at the start of a supra-threshold run.
    isLocalMax = (filtered.shift(1) < filtered) & (filtered.shift(-1) < filtered)
    isPeak = (filtered > threshold) & isLocalMax
    return data.index[isPeak.to_numpy()]

### Get putative spike waveforms

In [None]:
def getSpikeWaveform(signal, spikeIndex, preceding=15, succeeding=15):
    """Cut a window of samples around a spike.

    The window is the half-open slice
    ``[spikeIndex - preceding, spikeIndex + succeeding)``. It holds
    ``preceding`` samples before the spike, the spike sample itself and
    ``succeeding - 1`` samples after it, i.e. ``preceding + succeeding``
    samples when the spike is away from the edges.

    Near the edges the window is truncated rather than wrapped. The start is
    clamped to 0; a negative start would otherwise be read as an offset from
    the end of the signal and typically yield an empty slice. Slicing past
    the end simply stops at the last sample.

    NOTE(review): for a pandas Series, whether integer slice bounds are
    positional or label-based depends on the index type and pandas version.
    Pass a NumPy array, or ensure a 0-based RangeIndex, so that spike
    indices and positions coincide.

    Parameters
    ----------
    signal : sliceable sequence
        The samples (e.g. a NumPy array).
    spikeIndex : int
        Position of the spike sample.
    preceding, succeeding : int
        Window extent before and after ``spikeIndex`` (see above).

    Returns
    -------
    Same type as ``signal``
        The sliced window.
    """
    start = max(spikeIndex - preceding, 0)
    return signal[start:spikeIndex + succeeding]

### Plot signal with detected peaks

In [None]:
def plotSignal(signal, peaks):
    """Display ``signal`` as a line with red crosses at the ``peaks``.

    Parameters
    ----------
    signal : pandas.Series or sequence
        Samples to draw. For a Series, its index is used as the x axis.
    peaks : iterable
        Keys into ``signal`` at which to place markers: index labels for a
        Series (e.g. the output of ``detectPeaks``), positions for a plain
        sequence.

    The figure is shown with ``fig.show()``; nothing is returned.
    """
    # Markers are placed at x = peak key, so the line must use the same
    # coordinates. Without an explicit x, Plotly draws the line at 0..N-1,
    # which only matches the markers when the Series index is a 0-based range.
    lineX = signal.index if isinstance(signal, pd.Series) else None

    fig = go.Figure()

    fig.add_trace(go.Scatter(
        x=lineX,
        y=signal,
        mode='lines',
        name='Signal'
    ))

    fig.add_trace(go.Scatter(
        x=peaks,
        y=[signal[j] for j in peaks],
        mode='markers',
        marker=dict(
            size=8,
            color='red',
            symbol='cross'
        ),
        name='Detected Peaks'
    ))

    fig.show()

---

In [None]:
# Band-pass filter the raw signal with the default settings and store it
# as a new column, then preview the result.
filteredSignal = bandPassFilter(data['signal'])
data['signalFiltered'] = filteredSignal
data.head()

In [None]:
# Index labels of filtered-signal maxima above the default threshold.
peakIndexes = detectPeaks(data, threshold=1.0)
peakIndexes

In [None]:
# Draw the filtered trace: the peaks were detected (and thresholded) on
# 'signalFiltered', so the markers belong on this curve's maxima.
plotSignal(data['signalFiltered'], peakIndexes)

In [None]:
# Boolean column: True on rows whose index label is a detected peak.
data['isPeak'] = data.index.isin(peakIndexes)

In [None]:
# Persist the annotated data. Create the output directory first, because
# to_csv does not create missing parent directories.
from pathlib import Path

outputPath = Path('./datasources/spikes/dev/training_data_50k.csv')
outputPath.parent.mkdir(parents=True, exist_ok=True)
data.to_csv(outputPath)