In [1]:
### First, import necessary packages

# Some standard pythonic imports
import warnings
warnings.filterwarnings('ignore')
import os,numpy as np,pandas as pd
from collections import OrderedDict
import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt
import pathlib
from os import listdir
from os.path import isfile, join
import pyxdf
import PyQt5
import time
import random
import pickle

from subfunctions import *

from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# Scikit-learn and Pyriemann ML functionalities
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
from sklearn.model_selection import cross_val_score, StratifiedShuffleSplit
from pyriemann.estimation import ERPCovariances, XdawnCovariances, Xdawn
from pyriemann.tangentspace import TangentSpace
from pyriemann.classification import MDM


# MNE functions
import mne
from mne import Epochs,find_events
from mne.datasets import sample
from mne.io import read_raw_fif
from mne.decoding import (SlidingEstimator, GeneralizingEstimator, Scaler,
                          cross_val_multiscore, LinearModel, get_coef,
                          Vectorizer, CSP)
# Real-time Functionalities
from mne_realtime import LSLClient

# Lab Streaming Layer
from pylsl import *


from joblib import dump, load

# For interactive plots
from IPython import get_ipython
get_ipython().run_line_magic('matplotlib', 'qt')

from easygui import *


### Load Prediction Model

In [13]:
current_path = pathlib.Path().absolute()
# Build paths with pathlib instead of backslash f-strings: "\M" / "\{" are invalid
# escape sequences and hard-wire Windows separators.
models_dir = current_path / "Models"

# Offer only regular files; sorted so the dialog lists them in a stable order.
model_files = sorted(entry.name for entry in models_dir.iterdir() if entry.is_file())
if not model_files:
    raise FileNotFoundError(f"No model files found in {models_dir}")

# message to be displayed
text = "Choose Model File:"
# window title
title = "Model File"
# item choices
choices = model_files

# Ask the user to pick a model file; presumably choicebox returns None on Cancel.
model_name = choicebox(text, title, choices)
if model_name is None:
    raise RuntimeError("No model file selected (dialog was cancelled).")

print(f"Model file: {model_name}")
model_path = str(models_dir / model_name)

# SECURITY: joblib.load unpickles the file, which can execute arbitrary code.
# Only load model files from a trusted source.
with open(model_path, "rb") as file:
    loaded_model = load(file)

Model file: sub-Synt_ses-Synt_task-3_Class_run-001_eeg_model


In [14]:
# Reuse the band-pass settings and channel list stored with the trained model.
# Naming note: `lowPass` holds the filter's l_freq (lower cutoff) and `highPass`
# its h_freq (upper cutoff); they are later passed to filter() as (l_freq, h_freq).
_bandpass = loaded_model.bandwidthfilter
lowPass, highPass, filterMethod = _bandpass.l_freq, _bandpass.h_freq, _bandpass.method
chNames = loaded_model.ch_names

In [4]:
def online_prediction(Data, input_shape, model):
    """Classify a batch of epochs with a trained model, cropped to the model's length.

    Parameters
    ----------
    Data : array-like
        3-D array; axis 1 is electrodes and axis 2 is time samples (per the
        checks below). Axis 0 is presumably epochs.
    input_shape : sequence of int
        Shape the model was trained on. Only ``input_shape[1]`` (number of
        electrodes) and ``input_shape[2]`` (number of samples) are used.
    model : object with a ``predict`` method
        Fitted estimator, e.g. a scikit-learn pipeline.

    Returns
    -------
    Whatever ``model.predict`` returns for the cropped data.

    Raises
    ------
    ValueError
        If ``Data`` is not 3-D, has no epochs, has a different number of
        electrodes than ``input_shape[1]``, or has fewer than
        ``input_shape[2]`` samples.
    """
    # Explicit checks instead of ``assert`` so validation still runs under ``python -O``.
    Data = np.asarray(Data)
    if Data.ndim != 3:
        raise ValueError(f'Data is not 3D (got shape {Data.shape})')
    if Data.shape[0] == 0:
        raise ValueError('Data contains no epochs')

    n_channels, n_times = input_shape[1], input_shape[2]
    if Data.shape[1] != n_channels:
        raise ValueError(
            f'Number of electrodes differs: got {Data.shape[1]}, model expects {n_channels}')
    if Data.shape[2] < n_times:
        raise ValueError(
            f'Data is too short: got {Data.shape[2]} samples, model expects {n_times}')

    # Keep only the first n_times samples so the input length matches the model's.
    return model.predict(Data[:, :, :n_times])

In [5]:
## Electrode labels for the online EEG stream, presumably in the order the stream delivers channels
ch_names = 'C3 C4 Cz FC1 FC2 FC5 FC6 CP1 CP2 CP5 CP6 O1 O2'.split()

In [23]:
# Find the Unity P300 marker stream and open an inlet on it.
# Uses the names brought in by `from pylsl import *`, like StreamInfo/StreamOutlet below.
stream_name = "P300_Markers"
# resolve_stream returns a list of matching streams (never None); empty means not found.
stream_info = resolve_stream("name", stream_name)

if stream_info:
    inlet = StreamInlet(stream_info[0])
    print(f"Connected to {stream_name}")
else:
    # Stop here: every later step needs `inlet`.
    raise RuntimeError(f"Stream {stream_name} not found")

# EEG samples pulled per 'Target Trial' marker. online_prediction rejects epochs
# shorter than loaded_model.input_shape[2].
EPOCH_N_SAMPLES = 151

epoch_list = []
realY = []  # NOTE(review): never filled or read in this notebook
i = 0       # NOTE(review): never used in this notebook
flag = True
wait_max = 5

# The __main__ guard lets this code also run as a standalone program; child
# processes (primarily on Windows) re-import the main module and must not re-run it.
if __name__ == '__main__':
    with LSLClient(info=None, host='openbcigui', wait_max=wait_max) as client:
        client_info = client.get_measurement_info()
        sfreq = int(client_info['sfreq'])

        # Outlet that sends the model's decision ('Y' / 'N') back to Unity.
        info = StreamInfo('DecisionMarkerStream', 'Markers', 1, 0, 'string', 'myuidw43536')
        outlet = StreamOutlet(info)
        markernames = ['Y', 'N']

        # Collect one band-pass-filtered epoch per 'Target Trial' marker until 'Break'.
        while flag:
            marker_sample, timestamp = inlet.pull_sample()
            marker = marker_sample[0]
            if marker == 'Target Trial':
                # NOTE(review): confirm get_data_as_epoch returns the samples that follow
                # the stimulus (needed for a P300), not the samples preceding the marker.
                epoch = client.get_data_as_epoch(n_samples=EPOCH_N_SAMPLES)
                # Epochs.filter works in place; the appended `epoch` is the filtered one.
                epoch.filter(lowPass, highPass, method=filterMethod)
                epoch_list.append(epoch)
            elif marker == 'Break':
                flag = False
            print("got %s at time %s" % (marker, timestamp))  # log every incoming marker

    print('Streams closed')
    if not epoch_list:
        raise RuntimeError("No 'Target Trial' epochs were received before 'Break'; "
                           "cannot make a decision.")
    concat_epochs = mne.concatenate_epochs(epoch_list)
    time.sleep(0.2)  # NOTE(review): purpose of this pause is not evident; kept as-is

    # Map the streamed channels positionally onto the electrode names. Then keep only
    # the channels the model was trained on.
    if len(concat_epochs.ch_names) != len(ch_names):
        raise ValueError(f"Stream has {len(concat_epochs.ch_names)} channels but "
                         f"ch_names defines {len(ch_names)} names")
    concat_epochs.rename_channels(dict(zip(concat_epochs.ch_names, ch_names)))
    concat_epochs.pick(chNames)

    data = concat_epochs.get_data()
    pred = online_prediction(data, loaded_model.input_shape, loaded_model.clsf)

    # Majority vote over all target epochs; assumes numeric 0/1 predictions — TODO confirm.
    print(f"Positive predictions: {int(np.sum(pred))} of {len(pred)}")
    if np.mean(pred) >= 0.5:
        # push_sample expects one value per channel, so the marker must be wrapped in a list.
        outlet.push_sample([markernames[0]])  # Unity shows the 'Yes' answer
    else:
        outlet.push_sample([markernames[1]])  # Unity shows the 'No' answer


Connected to P300_Markers
Client: Waiting for server to start
Looking for LSL stream openbcigui...
Found stream 'obci_eeg1' via openbcigui...
Client: Connected
got Distractor Trial at time 217902.7968646
got Distractor Trial at time 217903.5109507
got Distractor Trial at time 217904.2954575
got Distractor Trial at time 217905.1942058
got Distractor Trial at time 217905.9464431
got Distractor Trial at time 217906.7479855
got Non-Target Trial at time 217907.4133418
got Distractor Trial at time 217908.1820723
got Distractor Trial at time 217908.8440483
Not setting metadata
1 matching events found
No baseline correction applied
0 projection items activated
Setting up band-pass filter from 1 - 40 Hz

IIR filter parameters
---------------------
Butterworth bandpass zero-phase (two-pass forward and reverse) non-causal filter:
- Filter order 16 (effective, after forward-backward)
- Cutoffs at 1.00, 40.00 Hz: -6.02, -6.02 dB

got Target Trial at time 217909.6457147
got Distractor Trial at time 

### Save online data for future analysis

# The original code referenced `recording_file` and `epochs`, neither of which is defined
# in this notebook. The collected data is `concat_epochs`.
# NOTE(review): the file stem is derived from the selected model file, assuming models are
# named "<recording stem>_model" (e.g. "..._eeg_model") — confirm.
recording_stem = model_name[:-len("_model")] if model_name.endswith("_model") else model_name

# pathlib instead of backslash f-strings ("\D", "\P" are invalid escapes; Windows-only).
export_dir = current_path / "Data" / "Processed Data"
export_dir.mkdir(parents=True, exist_ok=True)

# MNE warns when an epochs filename does not end in "-epo.fif" / "_epo.fif".
fif_export_path = export_dir / f"{recording_stem}_Online_epo.fif"
concat_epochs.save(str(fif_export_path))