In [None]:
# Standard modules
import glob
import os
from joblib import dump
import time

# External modules
from catch22 import catch22_all
from matplotlib import pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn import svm

# Internal modules
from main import parse_snippet, get_snippet_event

In [None]:
def plot_snippet(time_slice: np.ndarray, signal_slice: np.ndarray, title: str = ""):
    """Draw a single snippet's signal against time on a fresh figure.

    Parameters
    ----------
    time_slice : np.ndarray
        Values for the x-axis (labelled in seconds).
    signal_slice : np.ndarray
        Values for the y-axis (labelled as normalised amplitude).
    title : str, optional
        Text shown above the axes; empty by default.
    """
    _fig, ax = plt.subplots()
    ax.plot(time_slice, signal_slice)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Normalised amplitude")
    # Fixed y-range so all snippet plots share the same vertical scale.
    ax.set_ylim([-1.2, 1.2])
    ax.set_title(title)

In [None]:
# Load every left/right snippet into the `snippets` dictionary, keyed by event
# name. Each entry is a (signal_slice, time_slice, snippet_file) tuple.
snippets = {}
snippet_folder = "Snippets"
events_of_interest = ("left", "right")

# sorted() makes the load order -- and therefore the feature/label order and the
# train/test split further down -- independent of the filesystem listing order.
for snippet_file in sorted(glob.glob(f"{snippet_folder}/*.npy")):
    # get_snippet_event only receives the path, so filter before paying for
    # np.load / parse_snippet on files that would be discarded anyway.
    event = get_snippet_event(snippet_file)
    if event not in events_of_interest:
        continue
    snippet = np.load(snippet_file)
    signal_slice, time_slice = parse_snippet(snippet)
    snippets.setdefault(event, []).append((signal_slice, time_slice, snippet_file))

# Fail early with a clear message instead of a confusing error at the split step.
assert snippets, f"No left/right snippets found in {snippet_folder!r}"

In [None]:
# Build the catch22 feature matrix. `data`, `labels` and `snippet_names` are
# kept aligned index-by-index: entry i of each describes the same snippet.
data = []
labels = []
snippet_names = []
for event_name, event_snippets in snippets.items():
    for signal, _times, filename in event_snippets:
        features = catch22_all(signal)["values"]
        data.append(features)
        labels.append(event_name)
        snippet_names.append(filename)
print(f"There are {len(data)} samples to train/test on.")

In [None]:
# Initialise a random forest classifier. A fixed random_state makes the fitted
# forest (and therefore the reported accuracies) reproducible across runs.
model = RandomForestClassifier(random_state=0)

In [None]:
# Split features, labels and snippet names into aligned train/test subsets.
# random_state pins the shuffle so the split is reproducible after a kernel restart.
# NOTE(review): test_size=0.9 trains on only 10% of the samples -- confirm this is
# intended (a test fraction of 0.1-0.3 is more typical).
(
    training_data,
    test_data,
    training_labels,
    test_labels,
    training_snippet_names,
    test_snippet_names,
) = train_test_split(data, labels, snippet_names, test_size=0.9, random_state=0)

In [None]:
# Fit the classifier on the training split and report how long it took.
# perf_counter is the clock intended for measuring elapsed intervals
# (time.time is wall-clock time and can jump).
start_time = time.perf_counter()
model.fit(training_data, training_labels)
elapsed_time = time.perf_counter() - start_time
print(f"Fitting took {elapsed_time:.2f} seconds")

In [None]:
# Predict labels for the held-out test split only, so `predictions` lines up
# index-by-index with `test_labels` and `test_snippet_names` in the next cell.
start_time = time.perf_counter()
predictions = model.predict(test_data)
elapsed_time = time.perf_counter() - start_time
print(f"Predictions took {elapsed_time:.2f} seconds")

In [None]:
# Evaluate test-set accuracy and record which snippets were misclassified.
# Every prediction must correspond to the test label and snippet name at the
# same index; fail loudly if predictions were made on a different set.
assert len(predictions) == len(test_labels) == len(test_snippet_names), (
    "predictions must be made on test_data so they align with test_labels"
)
count = 0
failed = []
for prediction, label, snippet_name in zip(predictions, test_labels, test_snippet_names):
    # splitext removes the ".npy" extension. str.rstrip(".npy") would instead strip
    # any trailing '.', 'n', 'p', 'y' characters (e.g. "happy.npy" -> "ha").
    tail = os.path.splitext(os.path.basename(snippet_name))[0]
    # Compare prediction to label
    print(f"{tail}\n Prediction: {prediction}\n Label:      {label}\n")
    if prediction == label:
        count += 1
    else:
        failed.append(snippet_name)
# Print diagnostics
accuracy = 100 * count / len(predictions) if len(predictions) else 0.0
print(f"\nAccuracy: {accuracy:.2f}%")
print("The failed snippets are:")
print(*failed, sep="\n")

In [None]:
# Visualise each misclassified snippet, titled with its file path.
for failed_file in failed:
    failed_signal, failed_time = parse_snippet(np.load(failed_file))
    plot_snippet(failed_time, failed_signal, title=failed_file)

In [None]:
# Timing benchmark: classify snippets one at a time, including catch22 feature
# extraction for each, to estimate per-snippet latency.
# Distinct names are used so this cell does not clobber `data` (the feature
# matrix refitted on in the next cell) or the test-set `predictions`.
streamed_predictions = []
start_time = time.perf_counter()
for event_snippets in snippets.values():
    for signal_slice, _times, snippet_filename in event_snippets:
        features = catch22_all(signal_slice)["values"]
        streamed_predictions.append(model.predict([features]))
elapsed_time = time.perf_counter() - start_time
n_streamed = len(streamed_predictions)
per_snippet = elapsed_time / n_streamed if n_streamed else float("nan")
print(
    f"Takes {elapsed_time:.2f} seconds to 'stream' {n_streamed} snippets, "
    f"or {per_snippet:.5f} seconds per snippet"
)

In [None]:
# Refit on *all* samples and evaluate on those same samples. This measures
# training-set accuracy (how well the forest fits data it has already seen),
# not generalisation. The held-out test evaluation earlier measures that.
# Guard against `data` having been overwritten by another cell: it must still
# hold exactly one feature vector per label.
assert len(data) == len(labels) == len(snippet_names), (
    "`data` must contain one catch22 feature vector per label"
)
model.fit(data, labels)
predictions = model.predict(data)
count = 0
failed = []
for prediction, label, snippet_name in zip(predictions, labels, snippet_names):
    # splitext removes the ".npy" extension. str.rstrip(".npy") would strip a
    # character set and mangle names ending in 'n', 'p' or 'y'.
    tail = os.path.splitext(os.path.basename(snippet_name))[0]
    # Compare prediction to label
    print(f"{tail}\n Prediction: {prediction}\n Label:      {label}\n")
    if prediction == label:
        count += 1
    else:
        failed.append(snippet_name)
# Print diagnostics
accuracy = 100 * count / len(predictions) if len(predictions) else 0.0
print(f"\nTraining-set accuracy: {accuracy:.2f}%")
print("The failed snippets are:")
print(*failed, sep="\n")

In [None]:
# Persist the classifier with joblib. When run in order, this is the model
# refitted on all samples in the previous cell.
model_path = "RFC.joblib"
dump(model, model_path);
