In [1]:
import os
import threading
from collections import defaultdict

import numpy as np
import pandas as pd
from pythonosc import dispatcher
from pythonosc import osc_server
from tensorflow import keras

In [2]:
# Stream / model configuration.
SAMPLE_RATE: int = 256           # samples per second in each model input window
STREAMING_SAMPLE_RATE: int = 10  # rate (Hz) at which band-power values arrive over OSC
INTERVAL_TIME: int = 3           # length of one prediction window, in seconds
FEATURES: int = 20               # 5 frequency bands x 4 Muse sensors

In [3]:
# --- Global state, mutated by hsi_handler / wave_handler ------------------

# Latest horseshoe (sensor fit) values and the last fit message printed.
hsi = [4] * 4
hsi_string = ""

# Feature names, "<Band>_<sensor>", band-major.
# Note: the earlier recording files name two of these channels differently
# (Fp1 <-> AF7, Fp2 <-> AF8).
# Note 2: this order must be identical to the feature order of the earlier
# datasets the model was trained on -- the model does not reorder its inputs.
cols = [
    f"{band}_{site}"
    for band in ("Delta", "Theta", "Alpha", "Beta", "Gamma")
    for site in ("TP9", "AF7", "AF8", "TP10")
]

# One independent, initially empty list of readings per feature, in `cols` order.
Vals = defaultdict(list, ((name, []) for name in cols))

In [4]:
# Counters shared with the handlers:
#   datapoints - number of values buffered in the current window (wave_handler)
#   place      - x position of the marker drawn by plot_update
datapoints = place = 0

In [5]:
# Location of the trained Keras model. Set the MUSE_MODEL_PATH environment
# variable to load it from another machine/directory; the default is the
# original author's local path.
MODEL_PATH = os.environ.get("MUSE_MODEL_PATH", "C:\\Users\\Chanakya\\BDR 2022\\Saved_Models\\Model_1")
model = keras.models.load_model(MODEL_PATH)

In [6]:
def hsi_handler(address: str, *args):
    """OSC handler for the horseshoe (sensor fit) message.

    Stores the raw values in the global `hsi` and prints a readable fit summary
    whenever it differs from the last one (kept in the global `hsi_string`).
    A sensor counts as well fitted only when its value is exactly 1; the first
    four values are read as left ear, left forehead, right forehead, right ear.

    The "good" verdict and the list of badly fitted sensors are derived from the
    same per-sensor check, so they can never disagree (a plain sum of the four
    values equalling 4 does not imply every value is 1).

    Args:
        address: OSC address the message arrived on (unused).
        *args: horseshoe values; at least four are required (IndexError otherwise).
    """
    global hsi, hsi_string
    hsi = args

    labels = ("Left Ear. ", "Left Forehead. ", "Right Forehead. ", "Right Ear.")
    readings = [args[i] for i in range(4)]
    badly_fitted = [label for label, value in zip(labels, readings) if value != 1]

    if badly_fitted:
        hsi_string_new = "Muse Fit Bad on: " + "".join(badly_fitted)
    else:
        hsi_string_new = "Muse Fit Good"

    # Only report changes, so a steady stream of identical messages stays quiet.
    if hsi_string != hsi_string_new:
        hsi_string = hsi_string_new
        print(hsi_string)

In [7]:
# Guards the shared window buffer (Vals / datapoints): the OSC server is a
# ThreadingOSCUDPServer, so several handler calls can run at the same time.
_wave_lock = threading.Lock()


def _window_to_rows(window):
    """Stack buffered readings into a (rows, len(cols)) float array in `cols` order.

    Columns may have different lengths (e.g. a dropped or late UDP packet), so
    every column is truncated to the shortest one. Keys not in `cols` are
    ignored so the feature layout always matches what the model was trained on.
    """
    n_rows = min(len(window.get(name, [])) for name in cols)
    return np.array([window.get(name, [])[:n_rows] for name in cols], dtype=float).T


def _upsample_rows(rows):
    """Hold-upsample streamed rows to SAMPLE_RATE * INTERVAL_TIME samples.

    Output sample i takes row floor(i / (SAMPLE_RATE / STREAMING_SAMPLE_RATE)),
    clamped to the last row when fewer rows than expected were received.
    Returns an array of shape (1, SAMPLE_RATE * INTERVAL_TIME, rows.shape[1]).
    """
    step = SAMPLE_RATE / STREAMING_SAMPLE_RATE
    sample_idx = np.arange(SAMPLE_RATE * INTERVAL_TIME) // step
    row_idx = np.minimum(len(rows) - 1, sample_idx.astype(int))
    return rows[row_idx][np.newaxis, :, :]


def wave_handler(address: str, *args):
    """OSC handler for the per-band absolute power messages.

    Appends one value per sensor (TP9, AF7, AF8, TP10) to the global `Vals`
    under "<Band>_<sensor>" and counts each value in `datapoints`. Once a full
    window of FEATURES * INTERVAL_TIME * STREAMING_SAMPLE_RATE values is
    buffered, the buffer is swapped for a fresh one, upsampled to the model's
    input length, passed to `model.predict`, and the prediction is printed.

    Args:
        address: OSC address the message arrived on (unused).
        *args: args[0][0] is the band name registered with the dispatcher
            (e.g. "Delta"); args[1:5] are the four sensor values.
    """
    global Vals, datapoints

    band = args[0][0]
    window_size = FEATURES * INTERVAL_TIME * STREAMING_SAMPLE_RATE  # 20 * 3 * 10 = 600

    with _wave_lock:
        for offset, site in enumerate(("TP9", "AF7", "AF8", "TP10"), start=1):
            Vals[f"{band}_{site}"].append(args[offset])
            datapoints += 1
        # `>=` (not `==`) so an overshoot can never skip the threshold forever.
        if datapoints < window_size:
            return
        # Swap in a fresh buffer *before* processing: if anything below raises,
        # the next window still starts from zero instead of stalling.
        window = Vals
        Vals = defaultdict(list, {k: [] for k in cols})
        datapoints = 0

    rows = _window_to_rows(window)
    if len(rows) == 0:
        print("Skipping window: at least one feature received no data")
        return

    # TODO: feed the prediction to the plot (plot_update) once labels are decoded.
    pred = model.predict(_upsample_rows(rows))
    print(pred)

In [8]:
def init_plot(interval_ms: int = 100):
    """Open the current figure and redraw it with `plot_update` every `interval_ms` ms.

    Blocks in `plt.show()` until the window is closed. The animation object is
    held in a local variable for that whole time so it stays referenced.

    NOTE(review): `plt` and `FuncAnimation` are not imported in the imports
    cell; add `import matplotlib.pyplot as plt` and
    `from matplotlib.animation import FuncAnimation` there before running.
    NOTE(review): FuncAnimation calls `plot_update` with the frame number, not
    with a model prediction -- confirm this is the intended wiring.
    NOTE(review): this is started on a background thread; most matplotlib GUI
    backends expect the main thread -- confirm it renders with your backend.

    Args:
        interval_ms: delay between redraws in milliseconds (100 ms = 10 redraws/s).
    """
    animation = FuncAnimation(plt.gcf(), plot_update, interval=interval_ms)
    plt.show()

In [9]:
def plot_update(prediction):
    """Step the marker left/right according to `prediction` and redraw it.

    'label_left' moves the global `place` one step left and 'label_right' one
    step right; any other value leaves it where it is. The marker is drawn at
    (place, 0) on an x-axis fixed to [-10, 10] with one tick per unit.

    NOTE(review): `place` is not clamped, so after more than 10 net steps in
    one direction the marker leaves the visible range.
    """
    global place

    plt.cla()

    if prediction == 'label_left':
        place -= 1
    elif prediction == 'label_right':
        place += 1

    plt.plot(place, 0, 'ro')
    plt.xlim([-10, 10])
    # arange's stop is exclusive: 11 makes the tick at x=10 (the right limit) appear.
    plt.xticks(np.arange(-10, 11, 1))
    plt.yticks([])

In [None]:
if __name__ == "__main__":

    # OSC endpoint the Muse stream is sent to. Override with the MUSE_OSC_IP /
    # MUSE_OSC_PORT environment variables instead of editing this cell.
    ip = os.environ.get("MUSE_OSC_IP", "10.122.126.40")
    port = int(os.environ.get("MUSE_OSC_PORT", "5003"))

    # NOTE(review): matplotlib GUI backends generally expect the main thread;
    # confirm the plot actually renders when started from this daemon thread.
    thread = threading.Thread(target=init_plot, daemon=True)
    thread.start()

    # Init Muse listeners. Named `osc_dispatcher` so the imported `dispatcher`
    # module is not shadowed and this cell can be re-run in the same kernel.
    osc_dispatcher = dispatcher.Dispatcher()
    osc_dispatcher.map("/muse/elements/horseshoe", hsi_handler)
    for band in ("Delta", "Theta", "Alpha", "Beta", "Gamma"):
        # The band name is forwarded to wave_handler as its args[0][0].
        osc_dispatcher.map(f"/muse/elements/{band.lower()}_absolute", wave_handler, band)

    server = osc_server.ThreadingOSCUDPServer((ip, port), osc_dispatcher)
    print("Listening on UDP port " + str(port))
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Interrupting the kernel is the normal way to stop listening.
        print("Stopping OSC server")
    finally:
        # Release the UDP port so re-running the cell doesn't hit "address already in use".
        server.server_close()
    