# Imports

In [None]:
# Import data processing libraries
import numpy as np
import pandas as pd

# PCA for dimensionality reduction and KNN classifier
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier

# Function Definition

In [None]:
def KNN_classifier(data_training, data_validation, labels_training, n_components=4, n_neighbors=4):
    """Classify validation samples with a PCA -> min-max scaling -> KNN pipeline.

    Every transformer (PCA and the min-max scaler) is fitted on the training
    data only and then applied unchanged to the validation data, so both sets
    live in the same feature space when distances are computed.

    Parameters
    ----------
    data_training : array-like
        Training samples, one row per sample (anything `PCA.fit_transform` accepts).
    data_validation : array-like
        Validation samples with the same number of features as `data_training`.
    labels_training : array-like
        Class label for each training sample.
    n_components : int, default 4
        Number of principal components to keep.
    n_neighbors : int, default 4
        Number of neighbours (k) used by the classifier.

    Returns
    -------
    predictions : ndarray
        Predicted class for each validation sample.
    componentsTraining : ndarray
        The (un-normalised) principal-component scores of the training data.

    Side effects
    ------------
    Prints the total explained variance ratio of the retained components.
    """
    # Fit PCA once on the training data and project it in the same pass
    # (the original fitted the model twice: fit() followed by fit_transform()).
    pca = PCA(n_components=n_components)
    componentsTraining = pca.fit_transform(data_training)

    # Print the total variance explained
    print("Total Variance Explained: ", np.sum(pca.explained_variance_ratio_))

    # Project the validation data onto the components learned from the training data
    componentsValidation = pca.transform(data_validation)

    # Normalise the datasets. The scaler is fitted on the training components only;
    # the validation components must reuse that fit (transform, not fit_transform),
    # otherwise each set is scaled by its own min/max and the distances are distorted.
    min_max_scaler = MinMaxScaler()
    normalisedTraining = min_max_scaler.fit_transform(componentsTraining)
    normalisedValidation = min_max_scaler.transform(componentsValidation)

    # Create a KNN classifier using the (p=2) Euclidean norm and fit it on the training data
    knn = KNeighborsClassifier(n_neighbors=n_neighbors, p=2)
    knn.fit(normalisedTraining, labels_training)

    # Apply trained classifier to validation data
    predictions = knn.predict(normalisedValidation)

    return predictions, componentsTraining

# Setup

In [None]:
import plotly.express as px
import numpy as np
import pandas as pd
from nn_spikes import NeuralNetwork, batchTrain, test
import spike_tools, utilities

In [None]:
# Load the training data and the training spike locations from CSV.
# For the spike-locations file, the first column is used as the DataFrame index (index_col=0).
data = pd.read_csv('./datasources/spikes/training_data.csv')
spikeLocations = pd.read_csv('./datasources/spikes/training_spike_locations.csv', index_col=0)

In [None]:
# Preview the first three rows of the freshly loaded data.
data.head(3)

Use the labelled spikes to train the network by first retrieving the putative spike waveforms and passing them as input to the NN. First, we will split the training data into training and validation sets.

In [None]:
# Combine the spike locations with the signal data; `data` is rebound to the result,
# so re-running this cell feeds the already-joined frame back into joinSpikes.
# NOTE(review): joinSpikes is a project helper — presumably it adds the
# `knownSpike` / `knownClass` columns read further down; confirm in spike_tools.
data = spike_tools.joinSpikes(data, spikeLocations)

In [None]:
# Preview the first three rows after joining the spike locations.
data.head(3)

# Detect spikes

### Detect spikes yourself

Filter signal

In [None]:
# Band-pass filter the `signal` column and store the result in a new `signalFiltered` column.
# NOTE(review): lowCut/highCut are presumably in Hz — confirm against spike_tools.bandPassFilter.
data['signalFiltered'] = spike_tools.bandPassFilter(data['signal'], lowCut=300, highCut=3000,order=1)
# Preview the first rows, now including the filtered signal.
data.head()

Predict peaks

In [None]:
# Run peak detection; the helper returns the updated frame and the indexes of the predicted peaks.
# NOTE(review): presumably this adds the `predictedSpike` column used when plotting below — confirm.
data, predictedPeakIndexes = spike_tools.detectPeaks(data)
data.head(3)

Get spike waveforms for predicted spikes

In [None]:
# Attach spike waveforms for the predicted peak indexes; `data` is rebound to the result.
# NOTE(review): presumably this populates the `waveform` column read later — confirm.
data = spike_tools.getSpikeWaveforms(predictedPeakIndexes, data)
data.head()

Get spike waveforms for known spikes

In [None]:
# Index labels of every row flagged as a known spike.
knownSpikeIndexes = data[data['knownSpike']==True].index
# Attach spike waveforms for the known spikes as well.
# NOTE(review): this call receives the frame already processed for predicted spikes;
# whether the earlier waveforms are preserved or overwritten depends on the helper — confirm.
data = spike_tools.getSpikeWaveforms(knownSpikeIndexes, data)

Plot spikes overlapped on original signal

In [None]:
# Take a 5000-row positional window centred on row 1152030 (2500 rows either side).
# The centre position is hard-coded; it must lie within the loaded data for the window to be non-empty.
sample = data.iloc[1152030-2500:1152030+2500, :]
# Pass the raw and filtered signals together with the known and predicted spike flags to the plot helper.
spike_tools.plotSpikes([sample['signal'], sample['signalFiltered']], [sample['knownSpike'], sample['predictedSpike']])

In [None]:
# waves = sample[sample['predictedSpike']==True]['waveform'].tolist()

# px.line(x=np.linspace(0,100, 101), y=waves)

---

Create the training and validation datasets, ready to pass to the classifier.

In [None]:
# Split the data and the known-spike indexes into training and validation parts.
# NOTE(review): trainingShare=0.8 presumably means 80% training / 20% validation; whether the
# split is random (and seeded) is not visible here — confirm in spike_tools.splitData.
data_training, data_validation, spikeIndexes_training, spikeIndexes_validation = spike_tools.splitData(data, knownSpikeIndexes, trainingShare=0.8)

# Run KNN Classifier

In [None]:
# Run the PCA + KNN pipeline: train on the training-spike waveforms and their known classes,
# and predict a class for each validation-spike waveform. Predictions come back in the
# same order as the validation waveforms, i.e. the order of spikeIndexes_validation.
predictions, componentsTraining = KNN_classifier(data_training.loc[spikeIndexes_training, 'waveform'].to_list(), 
                                                 data_validation.loc[spikeIndexes_validation, 'waveform'].to_list(), 
                                                 data_training.loc[spikeIndexes_training, 'knownClass'].to_list())

In [None]:
# First column of the training PCA output: the first-component score of every training waveform.
componentsTraining.T[0]

In [None]:
# PCA scores of the first five training waveforms (one row per waveform, one column per component).
componentsTraining[:5]

In [None]:
# Scatter the first two principal components of the training waveforms.
# Axis labels and a title are set so the figure can be read on its own.
px.scatter(
    x=componentsTraining.T[0],
    y=componentsTraining.T[1],
    labels={'x': 'Principal component 1', 'y': 'Principal component 2'},
    title='Training spike waveforms projected onto the first two principal components',
)

In [None]:
# Write the predictions into a `predictedClass` column on the validation rows.
# This relies on `predictions` being ordered like spikeIndexes_validation, which holds because
# the validation waveforms passed to KNN_classifier were selected with that same index.
data.loc[spikeIndexes_validation, 'predictedClass'] = predictions

In [None]:
# Display the validation rows, now including the `predictedClass` column.
data.loc[spikeIndexes_validation]

In [None]:
# Count the validation spikes whose predicted class equals the known class.
# Uses the vectorised Series.sum() rather than the builtin sum(), which would
# iterate the boolean Series element by element in Python.
matches = (
    data.loc[spikeIndexes_validation, 'knownClass']
    == data.loc[spikeIndexes_validation, 'predictedClass']
).sum()

In [None]:
# Validation accuracy: the fraction of validation spikes whose predicted class matches the known class.
acc = matches/len(data.loc[spikeIndexes_validation])
acc