In [108]:
! pip install paho-mqtt scipy

import joblib
import numpy as np
import paho.mqtt.client as mqtt
from Functions import Data_path, Plot_data ,CCA, ROC
from Functions.Filtering import filtering
from Functions.Common_average_reference import car 
from Functions.CCA_Feature_Extraction import cca_feature_extraction
# from Functions.CCA_FE_Single_Trial import cca_feature_extraction
from Functions.Feature_selections import feature_selecions 

import sys
# Print NumPy arrays in full instead of summarising them with "..."; note that this
# can produce very long output for large arrays (e.g. per-trial predictions).
np.set_printoptions(threshold=sys.maxsize)


Defaulting to user installation because normal site-packages is not writeable


In [109]:
def count_values(arr):
    """
    Count occurrences of the labels 0, 1 and 2.

    Parameters:
    arr (array-like): Labels to count, e.g. the output of ``model.predict``.
        Both NumPy arrays and plain Python lists are accepted.

    Returns:
    tuple: ``(count_zeros, count_ones, count_twos)``.
    """
    # np.asarray is required: for a plain list, ``arr == 0`` is a single bool
    # (False), so every count would silently come out as 0.
    labels = np.asarray(arr)
    count_zeros = np.count_nonzero(labels == 0)
    count_ones = np.count_nonzero(labels == 1)
    count_twos = np.count_nonzero(labels == 2)

    return count_zeros, count_ones, count_twos


# Load the trained classifier from disk (the MQTT cell further down reloads the same file).
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code —
# only load model files you trust.
saved_model = joblib.load("trained_model.pkl")



# Step 1: Initialize parameters and system


In [110]:
# --- Filtering -------------------------------------------------------------
order = 4                                # filter order passed to `filtering`
notch_freq = 50                          # notch frequency (Hz)
quality_factor = 20                      # quality factor of the notch filter
subbands = [[12, 16, 20], [14, 18, 22]]
f_low = np.min(subbands) - 1             # pass band widened by 1 below the lowest sub-band edge
f_high = np.max(subbands) + 1            # pass band widened by 1 above the highest sub-band edge
notch_filter = "on"
filter_active = "on"
type_filter = "bandpass"

# --- Windowing / sampling --------------------------------------------------
window_size = 256  # samples per window (1 s at fs = 256)
overlap = 128      # samples shared by consecutive windows
fs = 256           # Sampling frequency

###############     CCA     ################
# Each CCA setting is defined exactly once (they were previously duplicated above).
num_harmonic = 4          # Number of harmonics for each stimulation frequency
f_stim = [13, 21, 17]     # Stimulation frequencies
num_channel = [0, 1]      # Channels used by CCA (presumably indices into the channel axis — confirm)

############### feature selection ###########
num_features = 24                  # upper bound on selected features; per the original note, larger values are capped by the selector
type_feature_selection = "anova"   # var, anova, mi, ufs, rfe, rf, l1fs, tfs, fs, ffs, bfs


# Step 2: Define a function to process data chunks

In [111]:
def process_data_chunk(data_chunk):
    """
    Run one EEG window through the inference front-end.

    The window is filtered with the module-level filter settings
    (f_low/f_high, order, fs, notch and filter switches), re-referenced with
    `car`, and reduced to CCA features against the stimulation frequencies
    in `f_stim`.

    Parameters:
    data_chunk (np.ndarray): Raw EEG window; presumably
        (samples, channels, trials) — confirm against `filtering`.

    Returns:
    The feature matrix produced by `cca_feature_extraction`.
    """
    filter_args = (f_low, f_high, order, fs, notch_freq, quality_factor,
                   filter_active, notch_filter, type_filter)
    referenced = car(filtering(data_chunk, *filter_args))
    return cca_feature_extraction(referenced, fs, f_stim, num_channel, num_harmonic)

# Scale real-time features with the scaler fitted at training time.
def normalize_inference(x_real_time, scaler_filename="normalize.joblib", scaler=None):
    """
    Apply the training-time scaler to real-time features.

    Parameters:
    x_real_time (array-like): Features to scale, either 2-D
        ``(n_samples, n_features)`` or 1-D.
    scaler_filename (str): joblib file holding the fitted scaler; only read
        when ``scaler`` is None.
    scaler (object, optional): An already-loaded scaler exposing
        ``transform``. Passing it avoids re-reading the file on every call
        (e.g. once per MQTT message).

    Returns:
    np.ndarray: Output of ``scaler.transform`` on the (possibly reshaped) input.
    """
    if scaler is None:
        # NOTE(review): joblib.load unpickles the file — only load trusted files.
        scaler = joblib.load(scaler_filename)

    if not hasattr(x_real_time, "ndim"):
        x_real_time = np.asarray(x_real_time)

    if x_real_time.ndim == 1:
        n_features = getattr(scaler, "n_features_in_", None)
        if n_features is not None and n_features > 1 and x_real_time.size == n_features:
            # A 1-D vector whose length matches the fitted feature count is a
            # single sample; reshaping to (-1, 1) would make the scaler reject it.
            x_real_time = x_real_time.reshape(1, -1)
        else:
            # Otherwise keep the original interpretation: many samples of one feature.
            x_real_time = x_real_time.reshape(-1, 1)

    return scaler.transform(x_real_time)

# Step 3: Simulate real-time data streaming

In [112]:
def real_time_processing(streaming_data):
    """
    Simulate streaming inference by sliding a window over recorded data.

    Windows of `window_size` samples are taken along axis 0 with a hop of
    `window_size - overlap` samples. Each full window is passed to
    `process_data_chunk`; a trailing partial window is skipped.

    Parameters:
    streaming_data (np.ndarray): Recording with samples along axis 0
        (e.g. (samples, channels, trials)); any number of trailing axes works.

    Returns:
    list: The feature array of every processed window, in order.

    Raises:
    ValueError: If `overlap >= window_size`. The hop would then be
        non-positive and the window would never advance.
    """
    step = window_size - overlap
    if step <= 0:
        raise ValueError(
            f"overlap ({overlap}) must be smaller than window_size ({window_size})"
        )

    features = []
    num_samples = streaming_data.shape[0]
    # Only start positions that leave room for a complete window.
    for start_idx in range(0, num_samples - window_size + 1, step):
        data_chunk = streaming_data[start_idx:start_idx + window_size]
        features_chunk = process_data_chunk(data_chunk)
        print("Processed chunk features:", features_chunk.shape)
        features.append(features_chunk)
    return features

# Step 4: Data acquisition using MQTT

In [107]:
import json

MQTT_BROKER = "localhost"
MQTT_PORT = 1883
MQTT_TOPIC = "DAQ"

# The callbacks in this cell use the paho-mqtt 1.x signatures
# (e.g. on_connect(client, userdata, flags, rc)). paho-mqtt 2.x added a
# `callback_api_version` argument; depending on the 2.x release, a bare
# Client() is either rejected or falls back to VERSION1 with a warning (see
# this cell's output). Request VERSION1 explicitly when paho exposes it.
# NOTE(review): VERSION1 is deprecated in 2.x. Migrating to VERSION2 means
# changing the callback signatures at the same time.
if hasattr(mqtt, "CallbackAPIVersion"):
    client = mqtt.Client(mqtt.CallbackAPIVersion.VERSION1)
else:
    client = mqtt.Client()

# NOTE(review): duplicates the load in the earlier model cell. It is kept so
# that re-running this cell picks up the current model file. joblib.load
# unpickles, so only load trusted files.
saved_model = joblib.load("trained_model.pkl")


def on_connect(client, userdata, flags, rc):
    """
    paho-mqtt connect callback (1.x signature): report the result and subscribe.

    Subscribing here rather than once after `connect` renews the subscription
    whenever paho reconnects.

    Parameters:
    client: The paho client that connected.
    userdata: User data configured on the client (unused).
    flags: Connection flags reported by paho (unused).
    rc (int): Connection result code; 0 means success.
    """
    print(f"Connected to MQTT broker with result code: {rc}")
    if rc != 0:
        # The broker refused the connection, so there is no session to subscribe on.
        print(f"MQTT connection refused (result code {rc}); not subscribing to {MQTT_TOPIC!r}")
        return
    client.subscribe(MQTT_TOPIC)

def on_message(client, userdata, msg):
    """
    paho-mqtt message callback: classify one EEG payload and print the command.

    The payload is JSON with a ``frequency`` field (only printed) and an
    ``eeg_data`` nested list. The EEG goes through filtering, common average
    referencing, CCA feature extraction and normalisation. The trained model
    predicts one label per feature row, and a majority vote over those
    labels is printed as the command ("0", "1" or "2").

    Parameters:
    client: The paho client (unused).
    userdata: User data configured on the client (unused).
    msg: The received MQTT message; its ``payload`` bytes hold the JSON.
    """
    import traceback  # local import keeps this callback self-contained

    try:
        message = json.loads(msg.payload.decode())

        frequency = message['frequency']
        eeg_data = np.array(message['eeg_data'])

        print(f"Received EEG data for frequency {frequency}")
        print(f"EEG data shape: {eeg_data.shape}")

        filtered_data = filtering(eeg_data, f_low, f_high, order, fs, notch_freq, quality_factor,
                                  filter_active, notch_filter, type_filter)

        data_car = car(filtered_data)
        print('CAR data : ', data_car.shape)

        # For total data (1280, 8, 480) -> features will be (480, 6) for two channels
        features_extraction = cca_feature_extraction(data_car, fs, f_stim, num_channel, num_harmonic)
        print("Extracted feature : ", features_extraction.shape)

        norm_features = normalize_inference(features_extraction)

        feature_output = saved_model.predict(norm_features)
        z1, o1, t1 = count_values(feature_output)

        print(f"predicted result = {feature_output}")
        print(f"Result : Zeros: {z1}, Ones: {o1}, Twos: {t1}")

        # Majority vote over the predicted labels. "0" needs a strict majority
        # over both other labels; otherwise ties resolve towards the higher label.
        if z1 > o1 and z1 > t1:
            command = "0"
        elif o1 > t1:
            command = "1"
        else:
            command = "2"
        print(command)

    except Exception as e:
        # Keep exceptions inside the callback, because an error escaping it may
        # stop paho's network loop. Print the full traceback so failures (bad
        # JSON, shape mismatches in the scaler/model) can be diagnosed.
        print(f"Error processing message: {e}")
        traceback.print_exc()



def connect_mqtt():
    """
    Attach the callbacks to the module-level `client`, connect to the broker
    and start paho's network loop.

    `loop_start` runs the loop in a separate thread and returns immediately,
    so the caller is responsible for keeping the process alive.
    """
    client.on_message = on_message
    client.on_connect = on_connect
    client.connect(MQTT_BROKER, port=MQTT_PORT, keepalive=60)
    client.loop_start()

# Initialize the MQTT client and connect, then idle until the kernel is interrupted.
import time

connect_mqtt()

try:
    while True:
        # Messages are handled on paho's network thread (started by
        # loop_start); the main thread only has to stay alive. Sleeping rather
        # than `pass` avoids pinning a CPU core at 100% while idling.
        time.sleep(1)
except KeyboardInterrupt:
    pass
finally:
    # Always stop the network thread and close the connection, not only on Ctrl-C.
    client.loop_stop()
    client.disconnect()


  client = mqtt.Client()


Connected to MQTT broker with result code: 0
Received EEG data for frequency 13Hz
EEG data shape: (1280, 8, 1)
CAR data :  (1280, 8, 1)
Extracted feature :  (1, 6)
Error processing message: Found array with dim 3. MinMaxScaler expected <= 2.
