In [1]:
import serial
import numpy as np
import matplotlib.pyplot as plt
import time
import pandas as pd
from scipy.io import wavfile

%matplotlib notebook



def read_arduino(ser,inputBufferSize):
#     data = ser.readline((inputBufferSize+1)*2)
    data = ser.read((inputBufferSize+1)*2)
    out =[(int(data[i])) for i in range(0,len(data))]
    return out

def process_data(data):
    data_in = np.array(data)
    result = []
    i = 1
    while i < len(data_in)-1:
        if data_in[i] > 127:
            # Found beginning of frame
            # Extract one sample from 2 bytes
            intout = (np.bitwise_and(data_in[i],127))*128
            i = i + 1
            intout = intout + data_in[i]
            result = np.append(result,intout)
        i=i+1
    return np.flip(np.array(result)-512)



# ser.read works by waiting for <inputBufferSize> bytes from the port

def read_arduinbro(wav_array, inputBufferSize, k):
#    data = ser.readline(inputBufferSize)
    if inputBufferSize*(k+1) < len(wav_array):
        data = wav_array[(inputBufferSize*(k)):(inputBufferSize*(k+1))]
    else:
        data = wav_array[(inputBufferSize*(k))::]
    return np.flip(data)





In [2]:
# # use this to find ports
# from serial.tools import list_ports

# ports = list_ports.comports()
# for port in ports:
#     print(port)
    



# # Read example data
# baudrate = 230400
# cport = '/dev/cu.usbmodem142301'  # set the correct port before you run it
# ser = serial.Serial(port=cport, baudrate=baudrate)   

# inputBufferSize = 10000 # 20000 = 1 second

 

In [3]:
# Load training set

import os

path = "/Users/billydodds/Documents/Uni/DATA3888/Aqua10/Datasets/Eye Data/"

files = os.listdir(path)


waves = {}
labels = {}

for file in files:
    if file[-4::] == ".wav":
        samprate, wav_array = wavfile.read(path+file)
        print(samprate)
        waves[file[:-4]] = wav_array
    elif file[-4::] == ".txt":
        labels_dat = pd.read_csv(path+file, sep=",\t", skiprows=1)
        labels_dat.columns = ["label", "time"]
        # Change depending on whether L is coded as 1 or as 2
        labels_dat.label = ["L" if label == 1 else "R" for label in labels_dat.label]
        
        labels[file[:-4].replace(".", "")] = labels_dat
        

print(waves.keys(), labels.keys())


assert set(waves.keys()).difference(set(labels.keys())) == set()




10000
10000
10000
dict_keys(['left-middle-right', 'left-middle-right2', 'left-right']) dict_keys(['left-right', 'left-middle-right', 'left-middle-right2'])


  labels_dat = pd.read_csv(path+file, sep=",\t", skiprows=1)


In [5]:
# Simulate stream

# with wave.open('/Users/billydodds/Documents/Uni/DATA3888/Aqua10/Spiker_box_Louis/Short/LLL_L2.wav','r') as wav:
#     framerate = wav.getframerate()
#     nframes = wav.getnframes()
#     samprate = 10000.0
#     wav_array = np.array(struct.unpack("<" + "h"*nframes, wav.readframes(nframes)))
    

In [6]:
from scipy import signal

def classify_event(arr, downsample_rate=10, window_size_seconds=0.3):
    arr_ds = arr[0::downsample_rate]
    
    # Smooth wave
    window_length = int(window_size_seconds*samprate/downsample_rate + 1)
    filtered_arr = signal.savgol_filter(arr_ds, window_length, 1)
    
    # Get max min
    max_locs = np.array(signal.argrelextrema(filtered_arr, np.greater)[0])
    min_locs = np.array(signal.argrelextrema(filtered_arr, np.less)[0])
    
    
#     mins = arr_ds[min_locs]
#     mins = arr_ds[min_locs]
    
    
    max_min_locs = np.append(max_locs, min_locs)
    max_min = arr_ds[max_min_locs]

    # Sort vals    
    top_3 = sorted(max_min, key=abs, reverse=True)[0:3]

    if np.sum(np.sign(top_3)) == 1:
        return "L"
    elif np.sum(np.sign(top_3)) == -1:
        return "R"
    else:
        return "_"



In [7]:

# SINCE THE STREAMING DATA COMES IN FROM THE RIGHT TO THE LEFT ON THE PLOT, 
# THE RESULTANT PLOT IS ACTUALLY REFLECTED ABOUT THE Y AXIS



def streaming_classifier(
    wav_array,
    window_size = 1.5, # time plotted in window [s]
    N_loops_over_window = 15, # implicitly defines buffer to be 1/x of the window
    samprate = 10000,
    total_time = None,  # max time
    hyp_detection_buffer_end = 0.3, # seconds - how much time to shave off either end of the window in order to define the middle portion
    hyp_detection_buffer_start = 0.7,
    hyp_event_threshold = 1800, # crossings per second
    hyp_event_history = 5,
    hyp_consecutive_triggers = 3,
    hyp_consecutive_reset = 1,
    plot = False,
    store_events = False, 
    verbose=False
):
    
    
    if total_time is None:
        total_time = len(wav_array)/samprate
    if store_events:
        predictions_storage = []
    
    predictions = ""
    predictions_timestamps = []

    
    # Initialise variables
    inputBufferSize = int(window_size/N_loops_over_window * samprate)
    N_loops =(total_time*samprate)//inputBufferSize  # len(wav_array)//inputBufferSize 
    T_acquire = inputBufferSize/samprate    # length of time that data is acquired for 
    N_loops_over_window = window_size/T_acquire    # total number of loops to cover desire time window


    # Initialise plot
    if plot:
        min_y = -2000 #np.min(wav_array)
        max_y = 2000 #np.max(wav_array)
        fig = plt.figure()
        ax1 = fig.add_subplot(1,1,1)
        plt.ion()
        fig.show()
        fig.canvas.draw()


    # Hyperparameter conversions
    hyp_detection_buffer_start_ind = int(round(hyp_detection_buffer_start * samprate))
    hyp_detection_buffer_end_ind = int(round(hyp_detection_buffer_end * samprate))


    event_history = np.array([False]*hyp_event_history)
    primed = True

    for k in range(0,int(N_loops)):
        
        # Simulate stream
        data_temp = read_arduinbro(wav_array, inputBufferSize, k)

        # Stream
    #     data = read_arduino(ser,inputBufferSize)
    #     data_temp = process_data(data)


        if k < N_loops_over_window:
            if k==0:
                data_plot = data_temp
            else:
                data_plot = np.append(data_temp,data_plot)

            continue
        else:
            data_plot = np.roll(data_plot,len(data_temp))
            data_plot[0:len(data_temp)] = data_temp


        ### CLASSIFIER ###
        
        ## EVENT DETECTION ##

        interval = data_plot[hyp_detection_buffer_start_ind:-hyp_detection_buffer_end_ind] # Take middle part of window


    #     test_stat = np.sum(interval[0:-1] * interval[1::] <= 0) # Calculate test stat (zero crossings) 
        test_stat = np.max(interval) - np.min(interval) # Calculate test stat (range) 
        test_stat = test_stat/(len(interval)/samprate) # convert to crossings per second


        is_event = (test_stat > hyp_event_threshold) # Test threshold

        ## KEEP HISTORY ##

        event_history[1::] = event_history[0:-1]
        event_history[0] = is_event


        ## Classification

        if np.all(event_history[0:hyp_consecutive_triggers]) and primed:
            prediction = classify_event(data_plot)
            print(f"CONGRATULATIONS, ITS AN {prediction}!") if verbose else None

            if store_events:
                predictions_storage.append(data_plot)
                
            predictions += prediction
            
            
            end_time = round(k*inputBufferSize/samprate, 2)
            start_time = round(end_time - window_size, 2)
            predictions_timestamps.append((start_time, end_time))
            
            timer = 20

            primed = False
        elif np.all(~event_history[0:hyp_consecutive_reset]):
            primed = True



        ## PLOT ###

        if plot:
            t = (min(k+1,N_loops_over_window))*inputBufferSize/samprate*np.linspace(0,1,(data_plot).size)
            ax1.clear()
            # Debugging Annotations
            if np.all(event_history[0:hyp_consecutive_triggers]) and timer >0:
                ax1.annotate(f"ITS AN {prediction}!!!", (window_size/2, max_y-50))
                timer -= 1
            
            ax1.annotate(f"{event_history}", (window_size/2, max_y-70))
            ax1.set_xlim(0, window_size)
            ax1.set_ylim(min_y, max_y)
            plt.xlabel('time [s]')
            ax1.plot(t,data_plot)
            fig.canvas.draw()    
            plt.show()
    
    if store_events:
        return predictions, predictions_timestamps, predictions_storage
    else:
        return predictions, predictions_timestamps
    
    

In [8]:
def plot_predictions(wav_array, labels_dat, predictions, predictions_timestamps, ax, i,
                     title="", before_buffer = 0.7, after_buffer = 1, actual_alpha=0.2,
                     wave_alpha=1, pred_alpha = 0.5):
    
    time_seq = np.linspace(1, len(wav_array), len(wav_array))/samprate

    

    left_events_bool = np.array([False]*len(time_seq))
    for time in labels_dat.time[labels_dat.label == "L"]:
        left_events_bool = ( (time_seq > time - before_buffer) & (time_seq < time+after_buffer) ) | left_events_bool

    right_events_bool = np.array([False]*len(time_seq))
    for time in labels_dat.time[labels_dat.label == "R"]:
        right_events_bool = ( (time_seq > time - before_buffer) & (time_seq < time+after_buffer) ) | right_events_bool
        
    left_preds_bool = np.array([False]*len(time_seq))
    right_preds_bool = np.array([False]*len(time_seq))
    idk_preds_bool = np.array([False]*len(time_seq))
    for pred, times in zip(predictions, predictions_timestamps):
        if pred == "L":
            left_preds_bool = ( (time_seq > times[0]) & (time_seq < times[1]) ) | left_preds_bool
        elif pred == "R":
            right_preds_bool = ( (time_seq > times[0]) & (time_seq < times[1]) ) | right_preds_bool
        else:
            idk_preds_bool = ( (time_seq > times[0]) & (time_seq < times[1]) ) | idk_preds_bool
            

    ax[i].plot(time_seq, wav_array, alpha=wave_alpha)

    
    # Plot actuals
    ax[i].fill_between(time_seq, 2500, -2500,
                     where = left_events_bool,
                     color = 'g',
                     label = "L",
                     alpha=actual_alpha)

    ax[i].fill_between(time_seq, 2500, -2500,
                     where = right_events_bool,
                     color = 'r',
                     label = "R",
                     alpha=actual_alpha)
    
    # Plot predictions
    ax[i].fill_between(time_seq, 2500, -2500,
                     where = left_preds_bool,
                     color = 'g',
                     label = "Pred L",
                     alpha=pred_alpha)

    ax[i].fill_between(time_seq, 2500, -2500,
                     where = right_preds_bool,
                     color = 'r',
                     label = "Pred R",
                     alpha=pred_alpha)
    
    ax[i].fill_between(time_seq, 2500, -2500,
                     where = idk_preds_bool,
                     color = 'y',
                     label = "Pred idk",
                     alpha=pred_alpha)
    
    ax[i].set_title(key)
    ax[i].legend()

In [10]:
from copy import deepcopy


before_buffer = 0.7
after_buffer = 1


fig, ax = plt.subplots(len(waves), 1)
fig.set_size_inches(10, 7)





i=0
for key in waves.keys():
    
    predictions, predictions_timestamps, predictions_storage = streaming_classifier(waves[key],
                                                                                    window_size=2,
                                                                                    N_loops_over_window = 20,
                                                                                    hyp_detection_buffer_end = 0.1, # seconds - how much time to shave off either end of the window in order to define the middle portion
                                                                                    hyp_detection_buffer_start = 1,
                                                                                    hyp_event_history = 10,
                                                                                    hyp_consecutive_triggers = 6,
                                                                                    store_events=True)
    
    plot_predictions(waves[key], labels[key], predictions, predictions_timestamps, ax, i, title=key, actual_alpha=0)
    ax[i].set_xlim(-13)
    i+=1
    
    
    actual_times = [(time-before_buffer, time+after_buffer) for time in labels[key].time]
    actual_leftovers = deepcopy(actual_times)
    pred_leftovers = deepcopy(predictions_timestamps)
    
    print(pred_leftovers)
    print(predictions_timestamps)
    
    for act_times in actual_times:
        for pred_times in predictions_timestamps:
            if act_times[0] < pred_times[1] and act_times[1] > pred_times[0] and pred_times in pred_leftovers:
                actual_leftovers.remove(act_times)
                pred_leftovers.remove(pred_times)
                
                
    print("File: ", key)
    print(f"False negative triggers: {len(actual_leftovers)}", actual_leftovers)
    print(f"False positive triggers: {len(pred_leftovers)}",  pred_leftovers)
    
    
        


<IPython.core.display.Javascript object>

[(1.9, 3.9), (5.5, 7.5), (9.4, 11.4), (12.4, 14.4), (15.7, 17.7), (18.5, 20.5), (21.4, 23.4), (25.6, 27.6), (29.9, 31.9), (33.5, 35.5), (36.4, 38.4)]
[(1.9, 3.9), (5.5, 7.5), (9.4, 11.4), (12.4, 14.4), (15.7, 17.7), (18.5, 20.5), (21.4, 23.4), (25.6, 27.6), (29.9, 31.9), (33.5, 35.5), (36.4, 38.4)]
File:  left-middle-right
False negative triggers: 0 []
False positive triggers: 0 []
[(4.1, 6.1), (7.4, 9.4), (10.7, 12.7), (13.8, 15.8), (16.6, 18.6), (19.2, 21.2), (21.8, 23.8), (24.1, 26.1), (26.7, 28.7), (29.3, 31.3), (31.8, 33.8), (33.9, 35.9)]
[(4.1, 6.1), (7.4, 9.4), (10.7, 12.7), (13.8, 15.8), (16.6, 18.6), (19.2, 21.2), (21.8, 23.8), (24.1, 26.1), (26.7, 28.7), (29.3, 31.3), (31.8, 33.8), (33.9, 35.9)]
File:  left-middle-right2
False negative triggers: 0 []
False positive triggers: 0 []
[(2.4, 4.4), (5.1, 7.1), (7.6, 9.6), (9.5, 11.5), (12.0, 14.0), (13.8, 15.8), (16.3, 18.3), (18.1, 20.1), (20.7, 22.7), (23.1, 25.1), (25.6, 27.6), (27.5, 29.5)]
[(2.4, 4.4), (5.1, 7.1), (7.6, 9.6), 

In [88]:
print(len(predictions_storage))


fig, ax = plt.subplots(12, 3)

fig.set_size_inches(5, 15)    
    
    
i=0
for key in waves.keys():
    
    window_size = 2
    predictions, predictions_timestamps, predictions_storage = streaming_classifier(waves[key],
                                                                                    window_size=2,
                                                                                    N_loops_over_window = 20,
                                                                                    hyp_detection_buffer_end = 0.1, # seconds - how much time to shave off either end of the window in order to define the middle portion
                                                                                    hyp_detection_buffer_start = 1,
                                                                                    hyp_event_history = 7,
                                                                                    hyp_consecutive_triggers = 6,
                                                                                    store_events=True)
    
    
    t = np.linspace(0, window_size, window_size*samprate)
    for arr, label in zip(predictions_storage, labels[key].label):
        t_ds, smooth_arr = smooth_wave(arr,t, downsample_rate=10, window_size_seconds=0.3)
        
        ax[i//3, i%3].plot(t, np.flip(arr))
        ax[i//3, i%3].plot(t_ds, np.flip(smooth_arr))
        
        ax[i//3, i%3].set_ylim(-2500, 2500)
        
        ax[i//3, i%3].set_title(label)
        
        
        
        i+=1
    


12


<IPython.core.display.Javascript object>

In [86]:
def smooth_wave(arr,t, downsample_rate=10, window_size_seconds=0.3):
    arr_ds = arr[0::downsample_rate]
    t_ds = t[0::downsample_rate]
    
    # Smooth wave
    window_length = int(window_size_seconds*samprate/downsample_rate + 1)
    filtered_arr = signal.savgol_filter(arr_ds, window_length, 1)
    
    return t_ds, filtered_arr