In [None]:
from pyOpenBCI import OpenBCICyton
import numpy as np
import time
import matplotlib.pyplot as plt
import pyeeg
import threading
from neurodsp import filt
from scipy import signal
import os
import pickle
import pyxdf
from enum import Enum
from helperFunctions import epochByMarkIndex, loadData, getArticleSectionData, loadxdf, writeToPickle, loadPickle
from constants import StreamType
from dataAnalysisFunctions import getIntervals, getPowerRatio

In [None]:
from neurodsp import spectral, plts
from neurodsp.plts.time_series import plot_time_series


In [None]:
from neurodsp.plts.filt import plot_filter_properties
from neurodsp.filt.utils import compute_frequency_response
from neurodsp.plts.filt import plot_frequency_response

In [None]:
# Acquisition / processing defaults shared by SamePlotter and EEGSampler.
DEFAULT_FS = 250                        # sampling rate (Hz) used for buffers, filters and PSDs
DEFAULT_SECONDS = 3                     # rolling-buffer length in seconds
DEFAULT_CHANNELS = list(range(8))       # indices into each incoming sample that are kept
DEFAULT_BINNING = [4, 8, 12, 20, 30]    # band edges handed to getIntervals / getPowerRatio
DEFAULT_POWER_UPDATE_SECONDS = 1        # minimum seconds between power-update checks
# Counts -> microvolts: 4500000 uV reference / gain 24 / (2**23 - 1) full-scale counts.
# NOTE(review): these look like OpenBCI Cyton (ADS1299) defaults — confirm against the board config.
SCALE_FACTOR_EEG = 4500000 / 24 / (2**23 - 1)  # uV/count
F_RANGE = (0.5, 50)                     # bandpass edges (Hz) used by EEGSampler

In [None]:
class SamePlotter:
    """Live line plot of a (timepoints x lines) array, one line per column, redrawn in place.

    Args:
        fs: sampling rate; together with ``buffer_seconds`` fixes the number of plotted points.
        buffer_seconds: seconds of data shown.
        num_lines: number of columns (lines) to draw.
        x_label, y_label: axis labels.
        x_lim: x-axis limits. ``None`` (the default) means ``(0, fs * buffer_seconds)``, so the
            axis always spans one unit per plotted timepoint. The old fixed default
            ``(0, DEFAULT_FS * DEFAULT_SECONDS)`` mislabelled plotters built with a different
            ``fs`` or ``buffer_seconds``.
        y_lim: y-axis limits.
    """
    def __init__(self, fs=DEFAULT_FS, buffer_seconds=DEFAULT_SECONDS, num_lines=len(DEFAULT_CHANNELS), x_label="timepoints", y_label="output", x_lim=None, y_lim=(0,1)):
        if x_lim is None:
            x_lim = (0, fs * buffer_seconds)
        self.fs = fs
        self.buffer_seconds = buffer_seconds
        self.num_lines = num_lines
        fig, ax = plt.subplots(1, 1)  # figsize=(20,10)
        ax.set_xlabel(x_label)
        ax.set_ylabel(y_label)
        ax.set_xlim(x_lim[0], x_lim[1])
        ax.set_ylim(y_lim[0], y_lim[1])
        # One x value per buffered timepoint, spread evenly across the x-axis limits.
        self.x = np.linspace(x_lim[0], x_lim[1], self.fs * self.buffer_seconds)
        self.ax = ax
        self.fig = fig

    def __prepend_zeros(self, data, complete_length):
        """Return a (complete_length, num_lines) array: zero rows first, then the rows of ``data``."""
        new_data = np.zeros((complete_length, self.num_lines))
        n_rows = len(data)
        if n_rows:
            # reshape keeps 2-D rows as-is and broadcasts a 1-D input's scalars across each row,
            # matching the previous element-by-element row assignment.
            new_data[complete_length - n_rows:] = np.asarray(data).reshape(n_rows, -1)
        return new_data

    def plot(self, y_datas):
        """Draw ``y_datas`` (rows = timepoints, columns = lines).

        Inputs shorter than ``fs * buffer_seconds`` rows are padded with leading zero rows.
        The first call creates one line per column; later calls update those lines in place.
        """
        complete_length = self.fs * self.buffer_seconds
        if len(y_datas) < complete_length:
            y_datas = self.__prepend_zeros(y_datas, complete_length)

        if self.ax.lines:
            # Update the existing Line2D objects rather than adding new lines every frame.
            for i, line in enumerate(self.ax.lines):
                line.set_ydata(y_datas[:, i])
        else:
            for column in range(self.num_lines):
                self.ax.plot(self.x, y_datas[:, column])
        # Push the redraw through so interactive backends refresh while the caller keeps looping.
        self.fig.canvas.draw()
        self.fig.canvas.flush_events()
            

In [None]:
class EEGSampler:
    """ Holds a buffer of EEG data and accepts a single data sample as input to append to the buffer.

    Each call to ``push_data_sample`` shifts every per-sample buffer down by one row and writes
    the newest values into row 0. Row 0 is therefore always the most recent sample.

    Buffers (``fs * buffer_seconds`` rows; one column per entry of ``channels``):
        raw_buffer        -- input sample (scaled by SCALE_FACTOR_EEG when ``live``)
        dc_removed_buffer -- raw sample minus the per-channel mean of the previously buffered raw samples
        buffer            -- dc_removed sample after the 58 Hz low-pass and the F_RANGE bandpass

    Args:
        fs: sampling rate used for buffer sizing and filter design.
        buffer_seconds: buffer length in seconds.
        power_update_seconds: minimum wall-clock seconds between power-update checks.
        channels: indices into each incoming sample that are kept, in column order.
        binning: band edges passed to getIntervals / getPowerRatio.
        live: True -> samples expose ``channels_data`` (scaled by SCALE_FACTOR_EEG);
              False -> samples are plain sequences used as-is.
        model: optional classifier that drives ``focus`` / ``focus_average``.
        focus_window: number of newest samples handed to ``model`` (546 was previously hard-coded).
    """
    def __init__(self, fs=DEFAULT_FS, buffer_seconds=DEFAULT_SECONDS, power_update_seconds=DEFAULT_POWER_UPDATE_SECONDS, channels=DEFAULT_CHANNELS, binning=DEFAULT_BINNING, live=True, model=None, focus_window=546):
        self.fs = fs
        self.buffer_seconds = buffer_seconds
        self.channels = channels  # These are the channels we want

        self.live = live
        self.model = model
        self.focus_window = focus_window

        buffer_len = fs * buffer_seconds
        num_channels = len(channels)
        self.buffer = np.zeros((buffer_len, num_channels))
        self.raw_buffer = np.zeros((buffer_len, num_channels))
        self.dc_removed_buffer = np.zeros((buffer_len, num_channels))

        self.binning = binning
        self.intervals = getIntervals(self.binning)
        self.power_bin_values = np.zeros((buffer_len, len(self.intervals), num_channels))
        self.focus = np.zeros(buffer_len)
        self.focus_average = np.zeros(buffer_len)
        self.power_update_seconds = power_update_seconds
        self.last_updated = time.time()
        self.mean = np.zeros(num_channels)
        self.count_samples = 0
        self.num_focused = 0

        ## Realtime filtering (one sample at a time with carried-over lfilter state)
        ## https://stackoverflow.com/questions/40483518/how-to-real-time-filter-with-scipy-and-lfilter
        ## https://www.drdobbs.com/real-time-filters/184401931?pgno=1
        # lfilter is stateful, so every channel needs its own state row. A single shared state
        # would filter the channels as one interleaved signal.

        # Bandpass: 2nd-order Butterworth over F_RANGE.
        self.bandpass_filter_b, self.bandpass_filter_a = filt.design_iir_filter(fs, pass_type='bandpass', f_range=F_RANGE, butterworth_order=2)
        self.bandpass_filter_z = np.tile(signal.lfilter_zi(self.bandpass_filter_b, self.bandpass_filter_a), (num_channels, 1))

        # Despite the "notch" attribute names (kept for compatibility), this is an 8th-order
        # Butterworth LOW-PASS at 58 Hz. A 58-62 Hz FIR bandstop and a 50 Hz low-pass were tried
        # here previously.
        self.notch_filter_b, self.notch_filter_a = filt.design_iir_filter(fs, pass_type='lowpass', f_range=58, butterworth_order=8)
        self.notch_filter_z = np.tile(signal.lfilter_zi(self.notch_filter_b, self.notch_filter_a), (num_channels, 1))

    def __update_focus(self):
        """Classify the newest ``focus_window`` samples and update focus / its 1-second average."""
        # (channels x time) window of the newest filtered samples, wrapped in a batch of one.
        data = np.transpose(self.buffer[:self.focus_window])
        # NOTE(review): predict_classes is a Keras-style API, while model_t is loaded with
        # torch.load — confirm the loaded object actually provides it.
        self.focus[0] = -1 if self.model.predict_classes(np.array([data]))[0] == 0 else 1
        # Mean of the newest fs focus values (one second of predictions).
        self.focus_average[0] = np.mean(self.focus[:self.fs])

    def __update_power_bin(self):
        """Recompute band-power ratios for every channel from the buffered filtered data."""
        for chan in range(len(self.channels)):
            eeg_data = self.buffer[:self.count_samples, chan]
            self.power_bin_values[0, :, chan] = getPowerRatio(eeg_data, self.binning, eeg_fs=self.fs)

    def push_data_sample(self, sample):
        """Append one sample: scale it, remove the running mean, filter it, and update focus."""
        if self.count_samples > 0:
            self.mean = np.mean(self.raw_buffer[:self.count_samples], axis=0)
        # Count samples up to a full buffer (this is for the mean calculation).
        if self.count_samples < self.fs * self.buffer_seconds:
            self.count_samples += 1

        # Get the scaled channel data.
        if self.live:
            raw_eeg_data = np.array(sample.channels_data) * SCALE_FACTOR_EEG
        else:
            raw_eeg_data = np.array(sample)

        # Shift history down one row. np.roll wraps the OLDEST row into row 0, so the filtered
        # rows are cleared to avoid leaving stale values for channels skipped below.
        self.raw_buffer = np.roll(self.raw_buffer, 1, 0)
        self.buffer = np.roll(self.buffer, 1, 0)
        self.dc_removed_buffer = np.roll(self.dc_removed_buffer, 1, 0)
        self.power_bin_values = np.roll(self.power_bin_values, 1, 0)
        self.focus = np.roll(self.focus, 1, 0)
        self.focus_average = np.roll(self.focus_average, 1, 0)
        self.buffer[0] = 0
        self.dc_removed_buffer[0] = 0

        for i, chan in enumerate(self.channels):
            self.raw_buffer[0, i] = raw_eeg_data[chan]
            value = self.raw_buffer[0, i]
            if value == 0:
                # Exact zeros are skipped (e.g. the zero-filled slots in offline replays): the
                # filtered rows stay 0 and this channel's filter state does not advance.
                continue
            dc_removed = value - self.mean[i]
            self.dc_removed_buffer[0, i] = dc_removed
            lowpassed, self.notch_filter_z[i] = signal.lfilter(self.notch_filter_b, self.notch_filter_a, [dc_removed], zi=self.notch_filter_z[i])
            filtered, self.bandpass_filter_z[i] = signal.lfilter(self.bandpass_filter_b, self.bandpass_filter_a, lowpassed, zi=self.bandpass_filter_z[i])
            # lfilter returns a length-1 array; store the scalar.
            self.buffer[0, i] = filtered[0]

        if (self.model is not None) and (self.count_samples > self.focus_window):
            self.__update_focus()

        # Refresh band power if the update interval has elapsed.
        now = time.time()
        if (now - self.last_updated) > self.power_update_seconds:
            self.last_updated = now
            # Band-power refresh is currently disabled; re-enable to populate power_bin_values.
            # self.__update_power_bin()
            

In [None]:
# Import model
# NOTE(review): torch.load unpickles arbitrary Python objects — only load files you trust.
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which can reject a pickled full
# model object; if loading fails, pass weights_only=False explicitly (trusted files only).
# NOTE(review): model_t is passed as `model=` to EEGSampler, whose focus update calls
# `.predict_classes(...)`. That is not a torch.nn.Module method — confirm what this file contains.
import torch
model_t = torch.load("models/model_t_combined_filtered.pickle")

## Live data collection

In [None]:
# board.stop_stream()

In [None]:
# board = OpenBCICyton()
# sampler = EEGSampler(buffer_seconds=10)
# eeg_thread = threading.Thread(target=board.start_stream, args=(sampler.push_data_sample,))
# eeg_thread.start()
# print("thread started")

In [None]:
# %matplotlib notebook

# NUM_TIMES = 250
# sameplotter = SamePlotter(buffer_seconds=10, y_lim=(-400, 400))
# for i in range(NUM_TIMES) :
#     arr_to_plot = sampler.buffer
#     sameplotter.plot(arr_to_plot)
#     time.sleep(0.05)

# board.stop_stream()

In [None]:
# writeToPickle(sampler.raw_buffer, "sampler_raw_buffer_eye_close_2.pickle")
# writeToPickle(sampler.buffer, "sampler_buffer_small_eye_close_2.pickle")

### Visualize Live Data

In [None]:
# Filtered signal: Welch PSD of channel 0 over the newest third of the live buffer (row 0 = newest).
# NOTE(review): `sampler` is only created in the commented-out live-collection cell; run it first.
newest_filtered = sampler.buffer[:len(sampler.buffer)//3, 0]
freqs, psds = spectral.compute_spectrum(newest_filtered, DEFAULT_FS, method='welch', avg_type='mean', nperseg=DEFAULT_FS*2)
shown = 3 * len(freqs) // 5  # lowest 60% of the frequency bins
plt.plot(freqs[:shown], np.log(psds[:shown]))
plt.title("PSD after filtering")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power log(uV^2 / Hz)")
plt.ylim(-8, 12)
plt.show()


In [None]:
# Original (dc removed): Welch PSD of channel 0 over the newest third of the live DC-removed buffer.
# NOTE(review): `sampler` is only created in the commented-out live-collection cell; run it first.
newest_unfiltered = sampler.dc_removed_buffer[:len(sampler.buffer)//3, 0]
freqs, psds = spectral.compute_spectrum(newest_unfiltered, DEFAULT_FS, method='welch', avg_type='mean', nperseg=DEFAULT_FS*2)
shown = 3 * len(freqs) // 5  # lowest 60% of the frequency bins
plt.plot(freqs[:shown], np.log(psds[:shown]))
plt.title("PSD before filter")
plt.xlabel("Frequency (Hz)")
plt.ylabel("Power log(uV^2 / Hz)")
plt.ylim(-8, 12)
plt.show()


In [None]:
# Filtered signal: time trace of channel 2 over the newest third of the live buffer.
filtered_trace = sampler.buffer[:len(sampler.buffer)//3, 2]
plt.plot(filtered_trace)
plt.title("uV after filter")
plt.xlabel(f"Timepoints Before fs: {DEFAULT_FS}")
plt.ylabel("uV")
plt.ylim(-400, 400)
plt.show()

In [None]:
# Original (dc removed): time trace of channel 2 over the newest third of the DC-removed buffer.
unfiltered_trace = sampler.dc_removed_buffer[:len(sampler.buffer)//3, 2]
plt.plot(unfiltered_trace)
plt.title("uV before filter")
plt.xlabel(f"Timepoints Before fs: {DEFAULT_FS}")
plt.ylabel("uV")
plt.ylim(-400, 400)
plt.show()

## Load and analyze recorded live data for offline replay testing

In [None]:
# Load previously recorded sampler buffers for offline replay.
# Presumably raw_buffer / buffer dumps written with writeToPickle, as in the commented-out save cell — confirm.
# NOTE(review): loadPickle comes from helperFunctions; if it unpickles, only load trusted files.
sampler_raw_buffer = loadPickle("realtime_data/sampler_raw_buffer_small_noise.pickle")
sampler_buffer = loadPickle("realtime_data/sampler_buffer_small_noise.pickle")

In [None]:
# Testing recorded samples with model
%matplotlib notebook
another_sampler = EEGSampler(live=False, model=model_t)
sameplotter = SamePlotter(y_lim=(-400, 400))
original_data = sampler_raw_buffer
for i in range(len(original_data) - 1, -1, -1): 
    sample = [original_data[i, 0], original_data[i, 1], 0, 0, 0, 0, original_data[i, 2], original_data[i, 3]]
    another_sampler.push_data_sample(sample)
    if i % 50 == 0: 
        arr_to_plot = another_sampler.buffer #np.transpose(np.array([another_sampler.focus_average]))
        #arr_to_plot = another_sampler.power_bin_values[:, :, 2]
        sameplotter.plot(arr_to_plot)
    
    

In [None]:
# PSD of the replayed sampler's filtered output (channel 0, newest third of the buffer).
# Filters in effect: (fs, pass_type='lowpass', f_range=58, butterworth_order=8), then the F_RANGE bandpass.
segment = another_sampler.buffer[:len(another_sampler.buffer)//3, 0]
# The default 3 s buffer gives only 250 samples here, fewer than 2*fs. Clamp nperseg explicitly
# instead of relying on Welch silently shrinking it.
freqs, psds = spectral.compute_spectrum(segment, DEFAULT_FS, method='welch', avg_type='mean', nperseg=min(DEFAULT_FS*2, len(segment)))
n_show = 3 * len(freqs) // 5  # lowest 60% of the frequency bins
# Explicit figure: after `%matplotlib notebook` the SamePlotter figure stays current, and a bare
# plt.plot could draw into it.
fig, ax = plt.subplots()
ax.plot(freqs[:n_show], np.log(psds[:n_show]))
ax.set_title("PSD after filtering")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Power log(uV^2 / Hz)")
ax.set_ylim(-8, 9)
plt.show()



In [None]:
# Original: PSD of the replayed sampler's raw input (channel 0, newest third of the buffer).
segment = another_sampler.raw_buffer[:len(another_sampler.buffer)//3, 0]
# The default 3 s buffer gives only 250 samples here, fewer than 2*fs; clamp nperseg explicitly.
freqs, psds = spectral.compute_spectrum(segment, DEFAULT_FS, method='welch', avg_type='mean', nperseg=min(DEFAULT_FS*2, len(segment)))
n_show = 3 * len(freqs) // 5  # lowest 60% of the frequency bins
# Explicit figure so the PSD does not land in a still-open SamePlotter figure.
fig, ax = plt.subplots()
ax.plot(freqs[:n_show], np.log(psds[:n_show]))
ax.set_title("PSD before filtering")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Power log(uV^2 / Hz)")
ax.set_ylim(-8, 9)
plt.show()




## Load pre-recorded data used for training

In [None]:
# Files that work with this notebook:
#   P001/part_P001_block_S004.xdf
#   P001/part_P001_block_S005.xdf
# The code below loads the matching .pickle from ../data/filtered_data/ (loadxdf(path) is the .xdf alternative).
foldername = "P001"
filename = "part_P001_block_S004"
path = f"../data/filtered_data/{foldername}/{filename}.pickle"
full_recording = loadPickle(path)

# Presumably selects the "response" section (index 0) — see helperFunctions.getArticleSectionData.
loaded_data = getArticleSectionData("response", 0, full_recording)
loaded_eeg_data = loaded_data['eeg']['data']
loaded_eeg_time = loaded_data['eeg']['time']
df = loadData(datatype='dataframe_filtered', foldername=foldername, filename=filename)



In [None]:
# Testing filtered (no ICA cleaning) data compared with actual focus
%matplotlib notebook
loaded_sampler = EEGSampler(live=False, model=model_t)
sameplotter = SamePlotter(y_lim=(-400, 400))
original_data = loaded_eeg_data[::-1]
original_time = loaded_eeg_time[::-1]
actual_focus = np.zeros((len(loaded_sampler.focus)))
prev_df_time = 0
for i in range(len(original_data) - 1, -1, -1): 
    current_time = original_time[i]
    elem = df[(df['time'] <= current_time) & (df['time'] >= prev_df_time)]
    if len(elem) > 0: 
        #print(len(elem))
        elem = elem.iloc[-1]
        actual_focus = np.roll(actual_focus, 1, 0)
        actual_focus[0] = 1 if elem['trial_time'] <= 0.7 else 0 if elem['trial_time'] <= 0.9 else -1
        prev_df_time = elem['time']
    else :
        print("no df")
    sample = [original_data[i, 0], original_data[i, 1], 0, 0, 0, 0, original_data[i, 6], original_data[i, 7]]
    loaded_sampler.push_data_sample(sample)
    if i % 50 == 0: 
        arr_to_plot = loaded_sampler.buffer #np.transpose(np.append([loaded_sampler.focus_average], [actual_focus], axis=0))
        #arr_to_plot = loaded_sampler.power_bin_values[:, :, 2]
        sameplotter.plot(arr_to_plot)
    
    

In [None]:
# PSD of loaded data: replayed recording before filtering (DC-removed, channel 0, newest third).
segment = loaded_sampler.dc_removed_buffer[:len(loaded_sampler.buffer)//3, 0]
# The default 3 s buffer gives only 250 samples here, fewer than 2*fs; clamp nperseg explicitly.
freqs, psds = spectral.compute_spectrum(segment, DEFAULT_FS, method='welch', avg_type='mean', nperseg=min(DEFAULT_FS*2, len(segment)))
n_show = 3 * len(freqs) // 5  # lowest 60% of the frequency bins
# Explicit figure so the PSD does not land in a still-open SamePlotter figure.
fig, ax = plt.subplots()
ax.plot(freqs[:n_show], np.log(psds[:n_show]))
ax.set_title("PSD before filter")
ax.set_xlabel("Frequency (Hz)")
ax.set_ylabel("Power log(uV^2 / Hz)")
ax.set_ylim(-8, 9)
plt.show()