EEG Signals Control
-

In [None]:
import tkinter as tk
from PIL import Image, ImageTk
import numpy as np
import matplotlib.pyplot as plt
import pywt
import pandas as pd
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from scipy.signal import butter, filtfilt, savgol_filter
from scipy.stats import skew, kurtosis
import random
import joblib

## For Communication with Arduino
import serial
import time

## Serial link to the Arduino; each predicted command is written to it
arduino = serial.Serial(port='COM7', baudrate=9600, timeout=1)
# Pause before first use so the freshly opened connection can initialise.
time.sleep(2)


## Load the trained SVM classifier and the scaler applied to features before prediction.
## NOTE: joblib.load unpickles its input, which can execute arbitrary code -
## only load model files that come from a trusted source.
svm_model, scaler = (joblib.load(path) for path in ('svm_model.pkl', 'scaler.pkl'))

## Load the recorded EEG trials; the .npz archive holds one array per command
data = np.load('eeg_signal.npz')

## Map each command name to its recorded trials.
## Insertion order matters: the GUI lays out one button per key in this order.
signals_dict = {
    command: data[command]
    for command in ("forward", "back", "right", "left", "stop")
}

## Plot line / button colour for each command
colors = {
    "forward": 'blue',
    "back": 'green',
    "right": 'red',
    "left": 'purple',
    "stop": 'orange',
}

## Set up signal filters
def butter_bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Apply a zero-phase Butterworth band-pass filter to ``data``.

    An ``order``-th order Butterworth band-pass is designed with edges
    ``lowcut`` and ``highcut`` (in the same units as ``fs``). It is then run
    forwards and backwards with ``filtfilt``, so the result has no phase
    shift. The returned array has the same shape as ``data``.
    """
    half_rate = fs / 2  # Nyquist frequency
    band_edges = [lowcut / half_rate, highcut / half_rate]
    numerator, denominator = butter(order, band_edges, btype='band')
    return filtfilt(numerator, denominator, data)

def remove_outliers(signal, threshold=3):
    """Replace samples that lie far from the mean with the mean itself.

    A sample counts as an outlier when it is more than ``threshold`` standard
    deviations away from the signal mean. Outliers are replaced rather than
    dropped, so the result has the same length as ``signal``.
    """
    centre = np.mean(signal)
    spread = np.std(signal)
    is_outlier = np.abs(signal - centre) > threshold * spread
    return np.where(is_outlier, centre, signal)

def normalize_signal(signal):
    """Min-max scale ``signal`` into the range [0, 1].

    A constant signal has zero range. Dividing by it would give 0/0, i.e. an
    all-NaN array that would then propagate into the extracted features.
    Such a signal is mapped to all zeros instead.
    """
    lowest = np.min(signal)
    value_range = np.max(signal) - lowest
    if value_range == 0:
        return np.zeros_like(signal, dtype=float)
    return (signal - lowest) / value_range

def smooth_signal(data, window_length=51, polyorder=3):
    """Smooth ``data`` with a Savitzky-Golay filter.

    Each point is replaced by the value of a degree-``polyorder`` polynomial
    fitted over a ``window_length``-sample window around it.
    """
    return savgol_filter(data, window_length=window_length, polyorder=polyorder)

def preprocess_signal(signal, fs, lowcut, highcut):
    """Clean one raw trial for feature extraction.

    The pipeline runs in this order:
    1. band-pass filter between ``lowcut`` and ``highcut``;
    2. Savitzky-Golay smoothing;
    3. replace outliers with the mean;
    4. min-max normalisation to [0, 1].
    """
    band_limited = butter_bandpass_filter(signal, lowcut, highcut, fs)
    return normalize_signal(remove_outliers(smooth_signal(band_limited)))

def extract_features(signal, fs):
    """Compute the 16-value feature vector passed to the scaler and SVM.

    The order of the returned list matches ``feature_names``. It must stay in
    sync with that list, and presumably with the features the saved model was
    trained on. The features are:

    * time-domain statistics: mean, std, variance, max;
    * spectral energy: total energy, and the energy in the first 50 FFT bins;
    * distribution shape: skewness, kurtosis;
    * mean and std of the first three arrays from a 5-level ``db4`` wavelet
      decomposition (the level-5 approximation, then the two coarsest details);
    * mean of the first and second discrete differences.

    Parameters
    ----------
    signal : array of samples (presumably a preprocessed 1-D trial).
    fs : sampling rate. Currently unused; kept for interface compatibility.

    Returns
    -------
    list of 16 numeric features.
    """
    # Power spectrum computed once and reused for both energy features.
    spectrum_power = np.abs(np.fft.fft(signal)) ** 2
    features = [
        np.mean(signal), np.std(signal), np.var(signal), np.max(signal),
        np.sum(spectrum_power),  # Total energy
        # NOTE(review): named "0-50 Hz", but this sums the first 50 FFT bins. That
        # covers 0-50 Hz only when the bin width fs / len(signal) is 1 Hz. It is
        # left unchanged because the trained model presumably expects this value.
        np.sum(spectrum_power[:50]),
        skew(signal), kurtosis(signal),
    ]

    coeffs = pywt.wavedec(signal, 'db4', level=5)
    for band in coeffs[:3]:
        features.extend((np.mean(band), np.std(band)))

    first_diff = np.diff(signal)
    second_diff = np.diff(first_diff)
    features.extend((np.mean(first_diff), np.mean(second_diff)))
    return features

## Sampling rate and band-pass edges used for preprocessing (presumably in Hz)
fs = 500
lowcut = 0.5
highcut = 50

## Column names for the vector returned by extract_features, in the same order
feature_names = [
    "mean", "std", "variance", "max", "total_energy", "energy_0_50",
    "skewness", "kurtosis",
    *(f"wavelet_{stat}_{level}" for level in (1, 2, 3) for stat in ("mean", "std")),
    "derivative_1_mean", "derivative_2_mean",
]

def _predict_label(signal):
    """Preprocess one raw trial, extract its features and return the SVM's label."""
    cleaned = preprocess_signal(signal, fs, lowcut, highcut)
    feature_df = pd.DataFrame([extract_features(cleaned, fs)], columns=feature_names)
    return svm_model.predict(scaler.transform(feature_df))[0]


def _draw_signal(signal, signal_name):
    """Plot ``signal`` inside ``canvas_frame``, replacing any chart already shown there."""
    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor('#d6eaff')
    ax.set_facecolor('#edf6ff')
    ax.plot(signal, color=colors[signal_name])
    ax.set_title(f'{signal_name.capitalize()} Signal')
    ax.set_xlabel('Samples')
    ax.set_ylabel('Amplitude')
    ax.grid(True)

    ## Remove any previous graphics in the interface
    for widget in canvas_frame.winfo_children():
        widget.destroy()

    ## View chart in Tkinter
    canvas = FigureCanvasTkAgg(fig, master=canvas_frame)
    canvas.draw()
    canvas.get_tk_widget().pack()

    ## The chart is already rendered on the Tk canvas; close the pyplot
    ## figure to prevent memory consumption
    plt.close(fig)


def show_random_signal(signal_name):
    """Pick a random recorded trial for ``signal_name``, classify it and display it.

    The predicted label is shown in ``result_label`` and written to the
    Arduino over the serial link. The raw trial is then plotted in
    ``canvas_frame``. Names not present in ``signals_dict`` are ignored.
    """
    if signal_name not in signals_dict:
        return

    trials = signals_dict[signal_name]
    signal = trials[random.randint(0, len(trials) - 1)]
    prediction = _predict_label(signal)

    ## Update label to display prediction
    result_label.config(text=f"Predicted Label: {prediction}")

    ## Send prediction to Arduino. Convert with str() first: the model may
    ## return numeric labels (e.g. numpy integers), which have no .encode().
    arduino.write(str(prediction).encode())

    _draw_signal(signal, signal_name)

# ====== GUI ====== #

root = tk.Tk()
root.title("Nerve Signals")
root.geometry("1920x1080")

## Add background image. Image.open keeps the JPEG file open, so the `with`
## block closes it. resize() returns an independent, fully loaded copy.
with Image.open("Nerve Signal Image.jpg") as source_image:
    bg_image = source_image.resize((1920, 1080))
# Keep this module-level reference: Tk does not hold one itself, and a
# garbage-collected PhotoImage renders as a blank label.
bg_photo = ImageTk.PhotoImage(bg_image)
bg_label = tk.Label(root, image=bg_photo)
bg_label.place(x=0, y=0, relwidth=1, relheight=1)

## Widgets stacked over the background, packed top-to-bottom
label_colours = {"bg": "#99ccff", "fg": "#000033"}

title_label = tk.Label(root, text="Human Decision", font=("Arial", 18), **label_colours)
title_label.pack(pady=20)

## One coloured button per command, laid out left-to-right in a single row
button_frame = tk.Frame(root, bg="#395d7f")
button_frame.pack(pady=20)

for column, move in enumerate(signals_dict):
    move_button = tk.Button(
        button_frame,
        text=move.capitalize(),
        width=20,
        height=2,
        bg=colors[move],
        fg="white",
        # Bind the current name as a default argument; a plain closure
        # would see only the last loop value.
        command=lambda name=move: show_random_signal(name),
    )
    move_button.grid(row=0, column=column, padx=20)

## Chart area (filled by show_random_signal) and the prediction read-out
canvas_frame = tk.Frame(root, bg="#ffffff")
canvas_frame.pack(pady=20)

result_label = tk.Label(
    root, text="Predicted Label will appear here", font=("Arial", 14), **label_colours
)
result_label.pack(pady=10)

## Hand control to Tk; this call blocks until the window is closed
root.mainloop()


In [None]:
## Release the serial port once the GUI loop has exited
if arduino.is_open:
    arduino.close()