In [None]:
# imports
import numpy as np
from scipy.io import loadmat
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.metrics import f1_score, confusion_matrix
import pandas as pd

In [None]:
# Load the recording and its ground-truth spike times.
# True when analysing a sample recording from the .mat file (later cells use this flag
# to decide whether ground-truth comparisons are run).
sample_recording = True

# The .mat file stores the raw trace under 'data' and the ground-truth spike times under 'spike_times'.
all_data = loadmat('data/sample_1.mat')
sample_data, spike_times = (np.array(all_data[key]) for key in ('data', 'spike_times'))

# Alternative for a real recording (loaded from .npy instead of the .mat file):
#   sample_data = np.array([np.load('data/recording.npy')])

n_ground_truth = len(spike_times[0][0][0])
print('# of actual spikes (ground truth):', n_ground_truth)

In [None]:
# Analysis parameters.
seconds = 120  # length of the recording to analyse, in seconds
sr = 24000     # sample rate in Hz: 24 kHz for the sample recordings, 32051 for the real recording

n_samples = round(seconds * sr)
data = sample_data[0][:n_samples]

# Ground-truth spike times are sample indices (later cells convert them with `/ sr`), so keep
# the spikes whose index falls inside the analysed window. Slicing the array with
# `[:n_samples]` would instead cut by *number of spikes*, not by time.
all_ground_truth_spikes = np.asarray(spike_times[0][0][0])
ground_truth_spikes = all_ground_truth_spikes[all_ground_truth_spikes < len(data)]

# Global statistics of the analysed window, used to build the detection threshold.
mean_data = np.mean(data)
std_data = np.std(data)

In [None]:
def detect_spikes(data, multiplier, min_gap=10):
    """Detect spike onsets by amplitude thresholding.

    A sample is "above threshold" when it is >= mean(data) + multiplier * std(data).
    An above-threshold sample starts a new spike only if it lies more than ``min_gap``
    samples after the previous above-threshold sample. Runs of closer samples are merged
    into one event, and only the event's first sample is reported.

    Args:
        data: 1-D array-like of samples (assumed single-channel — TODO confirm for other inputs).
        multiplier: number of standard deviations above the mean used as threshold.
        min_gap: merge gap in samples; defaults to 10, the value previously hard-coded.

    Returns:
        np.ndarray of sample indices of spike onsets; empty if nothing reaches the threshold.
    """
    data = np.asarray(data)
    threshold = np.mean(data) + multiplier * np.std(data)
    above = np.where(data >= threshold)[0]
    if above.size == 0:
        return above

    # the first crossing always starts an event; later ones only after a gap > min_gap
    starts_event = np.concatenate(([True], np.diff(above) > min_gap))
    return above[starts_event]


def calculate_f1_score(detected_spikes, ground_truth_spikes, sr, tolerance=0.01, n_samples=None):
    """Sample-level F1 score and confusion matrix of detected vs. ground-truth spikes.

    Each spike index is expanded into a window of +/- ``tolerance`` seconds inside a binary
    per-sample mask. F1 and the confusion matrix are computed over the samples of the two masks,
    not over individual spikes.

    Args:
        detected_spikes: sample indices of detected spikes.
        ground_truth_spikes: sample indices of ground-truth spikes.
        sr: sample rate in Hz, used to convert ``tolerance`` to samples.
        tolerance: half-width of each spike window in seconds (0.001 = 1 ms).
        n_samples: length of the masks. If None, falls back to ``len(data)`` of the
            notebook-level ``data`` array (the original behaviour); pass it explicitly
            to avoid relying on that global.

    Returns:
        (f1, conf_matrix), where conf_matrix is always 2x2 with rows/columns ordered [0, 1].
    """
    if n_samples is None:
        n_samples = len(data)
    half_width = tolerance * sr

    def _spike_mask(spike_indices):
        # Binary mask with 1 within +/- half_width samples of every spike.
        mask = np.zeros(n_samples, dtype=int)
        for idx in np.asarray(spike_indices):
            # clamp at 0: a negative slice start would wrap to the end of the
            # array and silently drop spikes that occur near the beginning
            lo = max(int(idx - half_width), 0)
            hi = int(idx + half_width)
            mask[lo:hi] = 1
        return mask

    detected_binary = _spike_mask(detected_spikes)
    ground_truth_binary = _spike_mask(ground_truth_spikes)

    # zero_division=0 yields the same score as the default without the warning
    # when nothing is detected
    f1 = f1_score(ground_truth_binary, detected_binary, zero_division=0)
    # fixed labels keep the matrix 2x2 even if one class is absent from both masks
    conf_matrix = confusion_matrix(ground_truth_binary, detected_binary, labels=[0, 1])

    return f1, conf_matrix


# Sweep the threshold multiplier and keep the value with the best sample-level F1 score.
# Rounding removes float drift from arange (e.g. 0.30000000000000004).
multipliers = np.round(np.arange(0.0, 4.0, 0.1), 1)
f1_scores = []
best_f1, best_conf, best_multiplier = -np.inf, None, None

for multiplier in multipliers:
    detected_spikes = detect_spikes(data, multiplier)
    f1, conf = calculate_f1_score(detected_spikes, ground_truth_spikes, sr)
    f1_scores.append(f1)

    # Track score, matrix and multiplier together. Strict '>' keeps the first
    # maximum, the same one np.argmax(f1_scores) would pick.
    if f1 > best_f1:
        best_f1, best_conf, best_multiplier = f1, conf, multiplier


# plot F1 scores as a function of the multiplier
plt.figure(figsize=(10, 6))
plt.plot(multipliers, f1_scores, marker='o', linestyle='-', color='b')
plt.xlabel('Multiplier')
plt.ylabel('F1 Score')
plt.title('Sample-level F1 score vs. threshold multiplier')
plt.grid()
plt.show()

# for better formatting
conf_df = pd.DataFrame(best_conf, index=["Actual No Spike", "Actual Spike"], columns=["Predicted No Spike", "Predicted Spike"])

print(f"Best multiplier: {best_multiplier}, F1 score: {best_f1}")
print(conf_df)

In [None]:
# Re-run detection on the analysed window with the best multiplier from the sweep.
# detect_spikes merges back-to-back threshold crossings into a single spike onset.
threshold = mean_data + best_multiplier * std_data
print(threshold)
spike_times_start_only = detect_spikes(data, best_multiplier).tolist()
if not spike_times_start_only:
    # fail here with an explanation instead of an IndexError / empty KMeans input later
    raise ValueError(f'No samples reach the threshold {threshold}; nothing to cluster.')

In [None]:
# Only plot when the analysed window is shorter than 5 s; longer traces are unreadable.
if seconds < 5:
    # one timestamp per actual sample (robust if the recording is shorter than `seconds`)
    time_axis = np.arange(len(data)) / sr

    plt.figure(figsize=(25, 6))
    plt.plot(time_axis, data)
    plt.xlim(0, seconds)

    # shade each detected spike from 0.8 ms before to 1 ms after its threshold crossing
    for i, spike in enumerate(spike_times_start_only):
        plt.axvspan(spike / sr - .0008, spike / sr + .001, facecolor='r', alpha=0.2)
        plt.text(spike / sr - 0.006, 110 if i % 2 == 0 else 100, i + 1, c='r')

    # ground-truth spike times (converted to seconds) for reference, limited to the plotted window
    ground_truth_seconds = ground_truth_spikes / sr
    for i, spike in enumerate(ground_truth_seconds[ground_truth_seconds <= seconds]):
        plt.axvline(spike, c='k', alpha=1)
        plt.text(spike - 0.006, 85 if i % 2 == 0 else 75, i + 1, c='b')

    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')
    plt.show()

In [None]:
# Cut a fixed-length waveform window around every detected spike. These windows are the
# feature vectors clustered below, so all of them must have the same length.
spikes = []
detected_spike_times = []
for spike in spike_times_start_only:
    start = round(spike - .0008 * sr)  # 0.8 ms before the threshold crossing
    end = round(spike + .001 * sr)     # 1 ms after the threshold crossing
    # Skip spikes too close to either edge of the recording: a negative start wraps around
    # in the slice, and a window past the end is truncated. Either gives ragged input that
    # KMeans cannot use.
    if start < 0 or end > len(data):
        continue
    spikes.append(data[start:end])
    detected_spike_times.append(start)

number_of_spikes = len(spikes)

# amplitude range over all windows (always including 0), shared y-limits for later plots
ymax = max([0, *(window.max() for window in spikes)])
ymin = min([0, *(window.min() for window in spikes)])


if sample_recording:
    # A detection is a hit if some ground-truth spike lies in [start - 10, start + 43] samples.
    # With sorted ground truth, only the first spike >= start - 10 needs to be checked.
    gt_in_window = np.sort(ground_truth_spikes[ground_truth_spikes < len(data)])
    starts = np.asarray(detected_spike_times, dtype=float)
    first_candidate = np.searchsorted(gt_in_window, starts - 10, side='left')
    has_candidate = first_candidate < len(gt_in_window)
    hit_or_miss = np.zeros(len(starts))
    hit_or_miss[has_candidate] = gt_in_window[first_candidate[has_candidate]] <= starts[has_candidate] + 43
    # Normalise by the ground-truth spikes inside the analysed window only. Several
    # detections can hit the same spike, so this is not strictly recall.
    if len(gt_in_window):
        print('hits / ground-truth spikes:', hit_or_miss.sum() / len(gt_in_window))

# determine best number of clusters via elbow plot
inertias = []
k_values = range(1, 11)
for k in k_values:
    # fixed random_state so the elbow curve is reproducible across runs
    kmeans = KMeans(n_clusters=k, n_init='auto', random_state=0)
    kmeans.fit(spikes)
    inertias.append(kmeans.inertia_)

plt.plot(k_values, inertias, marker='o')
plt.xlabel('Number of clusters (k)')
plt.ylabel('Inertia')
plt.title('Elbow plot')
plt.show()


In [None]:
from sklearn.metrics import silhouette_samples, silhouette_score
import matplotlib.cm as cm

# Silhouette analysis for k = 2..10: one plot per k with every sample's silhouette coefficient
# grouped by cluster, and the mean score as a dashed red line.
range_n_clusters = list(range(2, 11))
silhouette_averages = {}
for n_clusters in range_n_clusters:
    fig, ax1 = plt.subplots(1, 1)
    fig.set_size_inches(5, 5)

    ax1.set_xlim([-0.1, 1])
    # room for every sample plus a 10-row gap between (and around) the clusters
    ax1.set_ylim([0, len(spikes) + (n_clusters + 1) * 10])

    # fixed random_state so the silhouette plots are reproducible across runs
    clusterer = KMeans(n_clusters=n_clusters, n_init="auto", random_state=0)
    cluster_labels = clusterer.fit_predict(spikes)
    silhouette_avg = silhouette_score(spikes, cluster_labels)
    silhouette_averages[n_clusters] = silhouette_avg

    sample_silhouette_values = silhouette_samples(spikes, cluster_labels)

    y_lower = 10
    for i in range(n_clusters):
        ith_cluster_silhouette_values = np.sort(sample_silhouette_values[cluster_labels == i])
        size_cluster_i = ith_cluster_silhouette_values.shape[0]
        y_upper = y_lower + size_cluster_i

        color = cm.nipy_spectral(float(i) / n_clusters)
        ax1.fill_betweenx(
            np.arange(y_lower, y_upper),
            0,
            ith_cluster_silhouette_values,
            facecolor=color,
            edgecolor=color,
            alpha=0.7,
        )

        # Label the silhouette plots with their cluster numbers at the middle
        ax1.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))

        # Compute the new y_lower for next plot
        y_lower = y_upper + 10  # 10 for the 0 samples

    ax1.set_title(f"Silhouette plot for k = {n_clusters} (mean = {silhouette_avg:.3f})")
    ax1.set_xlabel("The silhouette coefficient values")
    ax1.set_ylabel("Cluster label")

    # The vertical line for average silhouette score of all the values
    ax1.axvline(x=silhouette_avg, color="red", linestyle="--")

    ax1.set_yticks([])  # Clear the yaxis labels / ticks
    ax1.set_xticks([-0.1, 0, 0.2, 0.4, 0.6, 0.8, 1])

plt.show()

# summarise the mean scores so the choice of k below is easy to justify
best_k = max(silhouette_averages, key=silhouette_averages.get)
print(f"Highest mean silhouette score: k = {best_k} ({silhouette_averages[best_k]:.3f})")

In [None]:
# Final clustering with the number of clusters chosen from the elbow / silhouette plots.
n_clusters = 3
# fixed random_state so the cluster labels (and therefore the colours) are reproducible
kmeans_final = KMeans(n_init='auto', n_clusters=n_clusters, random_state=0)
classes = kmeans_final.fit_predict(spikes)
# tab10 has 10 distinct colours; indexing modulo its length avoids an IndexError for any n_clusters
colors = plt.cm.tab10.colors

# Plot (up to) the first 20 spike windows, coloured by cluster and on a shared y-range,
# to sanity-check the clustering.
plt.figure(figsize=(30, 10))
for i, spike in enumerate(spikes[:20]):
    plt.subplot(5, 4, i + 1)
    plt.grid()
    plt.ylim(ymin, ymax)
    plt.plot(spike, c=colors[classes[i] % len(colors)])
plt.show()

In [None]:
# Mean waveform of each cluster with a +/- 1 std band computed per sample of the window.
spike_matrix = np.array(spikes)       # one row per spike window
x = np.arange(spike_matrix.shape[1])  # actual window length, whatever the sample rate

plt.figure(figsize=(5 * n_clusters, 4))

print('# of spikes:', number_of_spikes)

for i in range(n_clusters):
    classIndices = np.where(classes == i)[0]
    cluster_waveforms = spike_matrix[classIndices]

    meanSpike = cluster_waveforms.mean(axis=0)
    # axis=0 gives a per-sample spread; without it the band is a single constant width
    deviation = cluster_waveforms.std(axis=0)

    plt.subplot(1, n_clusters, i + 1)
    plt.title(f'# of spikes in cluster: {len(classIndices)}')
    plt.plot(x, meanSpike)
    plt.fill_between(x, meanSpike - deviation, meanSpike + deviation, alpha=0.1, color='r')
    plt.grid()
plt.show()