# Reading Training Data

In [None]:
import pandas as pd
import numpy as np

# Load the continuous EEG recording and the per-trial labels.
# These paths are machine-specific — replace them with your own locations.
EEG_PATH = "F:/Graduation/Sessions/final/final_EEGdata.csv"
LABELS_PATH = 'F:/trial_labels.csv'

eeg_raw = pd.read_csv(EEG_PATH)
labels = pd.read_csv(LABELS_PATH)

# segment_data() reads these two columns from every label row; fail here with a
# clear message rather than deep inside the segmentation loop.
missing_label_columns = {'start_time', 'direction'} - set(labels.columns)
if missing_label_columns:
    raise KeyError(f"labels file is missing columns: {sorted(missing_label_columns)}")

# Keep only the 4 channels used for classification. Their order matters: the
# feature vector is built channel by channel in this order.
channels = ['C3', 'C4', 'CZ', 'PZ']
eeg_data = eeg_raw[channels]
eeg_data.head()

# Data Segmentation 

In [None]:
import numpy as np

def segment_data(data, labels, start_time, end_time, sampling_rate):
    """Cut one fixed-length window per labelled trial out of a continuous recording.

    Parameters
    ----------
    data : pandas.DataFrame
        Continuous recording, one row per sample.
    labels : pandas.DataFrame
        One row per trial with a ``start_time`` column (seconds) and a
        ``direction`` column (the class label).
    start_time, end_time : float
        Window bounds in seconds. Only their difference (the window length)
        is used; see the NOTE below.
    sampling_rate : float
        Samples per second of ``data``.

    Returns
    -------
    segments : list of pandas.DataFrame
        One window per label row, each exactly
        ``round((end_time - start_time) * sampling_rate)`` rows long.
    segment_labels : numpy.ndarray
        The ``direction`` of each trial, aligned with ``segments``.

    Raises
    ------
    ValueError
        If the window length is not positive, or a trial's window starts
        before sample 0 or runs past the end of ``data``.
    """
    # NOTE(review): each window starts at the trial's own ``start_time`` value;
    # the ``start_time`` argument only contributes to the window length and does
    # not offset the window. The GUI cell instead crops new recordings to
    # START_TIME..END_TIME seconds from the start of the file. If the labels'
    # start_time marks trial onset (not window onset), training and inference
    # windows are misaligned — confirm against how the labels file was made.

    # Round rather than truncate: int(0.29 * 100) == 28 because the float
    # product is 28.999999999999996.
    segment_length = int(round((end_time - start_time) * sampling_rate))
    if segment_length <= 0:
        raise ValueError(
            f"end_time ({end_time}) must be after start_time ({start_time})"
        )

    n_samples = len(data)
    segments = []
    for trial_start in labels['start_time']:
        start_index = int(round(trial_start * sampling_rate))
        stop_index = start_index + segment_length
        # A window that leaves the recording would be silently truncated by
        # iloc, yielding a shorter segment and different-sized features.
        if start_index < 0 or stop_index > n_samples:
            raise ValueError(
                f"Trial starting at {trial_start}s needs samples "
                f"[{start_index}, {stop_index}) but the recording has "
                f"{n_samples} samples"
            )
        segments.append(data.iloc[start_index:stop_index])

    return segments, np.array(labels['direction'].tolist())

# Define your parameters
START_TIME = 2.5  # in seconds
END_TIME = 7.5  # in seconds
SAMPLING_RATE = 250  # in Hz

# Segment the data
segments, segment_labels = segment_data(eeg_data, labels, START_TIME, END_TIME, SAMPLING_RATE)

# Summarise instead of printing every segment: print(segments) emits one
# DataFrame repr per trial and floods the notebook output. More than one entry
# in the size list means some windows were cut short by the end of the recording.
segment_sizes = sorted({len(segment) for segment in segments})
print(f"{len(segments)} segments; samples per segment: {segment_sizes}")
pd.Series(segment_labels).value_counts()

# Feature Extraction

In [100]:
from scipy.signal import welch
import pywt
import scipy
import numpy as np

# Sampling rate (Hz) read by calculate_psd(). This re-states the value set in the
# segmentation cell so this cell can run on its own; keep the two in sync.
SAMPLING_RATE = 250  # Modify as per your data's sampling rate

def calculate_psd(data, fs=None):
    """Return the Welch power spectral density of a 1-D signal.

    Parameters
    ----------
    data : array_like
        Signal samples.
    fs : float, optional
        Sampling rate in Hz. When omitted, the notebook-level ``SAMPLING_RATE``
        is used (looked up at call time, so existing callers are unaffected).

    Returns
    -------
    numpy.ndarray
        The PSD values only; the matching frequency bins are discarded. With
        scipy's default 256-sample segments, any signal of at least 256 samples
        yields 129 values, so feature vectors have a fixed length.
    """
    if fs is None:
        fs = SAMPLING_RATE
    _, psd = welch(data, fs=fs)
    return psd

def calculate_hjorth(activity, mobility):
    """Intended to return a Hjorth "complexity" value — see the NOTE before using it.

    NOTE(review): this helper is not called in the cells shown here, and as
    written it looks broken:
    - ``activity`` is never used.
    - ``np.diff(mobility, axis=0)`` raises for a scalar ``mobility`` (the
      per-channel mobility computed in the feature-extraction code is a
      scalar), so this only runs when ``mobility`` is at least 1-D.
    - For an array ``m`` it returns sqrt(diff(m)**2 + m[:-1]**2) / m[:-1],
      which is not the usual Hjorth complexity (commonly defined as
      mobility(first derivative) / mobility(signal), which needs the signal
      itself, not activity/mobility). Confirm the intent, or delete it.
    """
    complexity = np.sqrt(np.diff(mobility, axis=0)**2 + mobility[:-1]**2) / mobility[:-1]
    return complexity

def calculate_wavelet_transform(data, wavelet='db4', level=5):
    """Return the mean energy of each band of a multilevel discrete wavelet transform.

    Parameters
    ----------
    data : array_like
        1-D signal.
    wavelet : str, optional
        Wavelet name passed to ``pywt.wavedec`` (default ``'db4'``).
    level : int, optional
        Decomposition level (default 5).

    Returns
    -------
    list of numpy.float64
        ``level + 1`` values, one per coefficient array in the order
        ``pywt.wavedec`` returns them (approximation band first, then the
        detail bands from coarsest to finest). Each value is the mean of the
        squared coefficients.
    """
    bands = pywt.wavedec(data, wavelet, level=level)
    return [np.mean(band ** 2) for band in bands]

def _channel_features(channel_data):
    """Return the features of one channel of one segment: [activity, mobility, *psd, *wavelet]."""
    psd = calculate_psd(channel_data)
    wavelet = calculate_wavelet_transform(channel_data)
    # Hjorth activity (variance) and mobility (sqrt(var(diff) / var)).
    activity = np.var(channel_data)
    # A flat channel has zero variance; 0/0 would give NaN (plus a
    # RuntimeWarning), which the classifier may reject. Report 0 instead.
    mobility = np.sqrt(np.var(np.diff(channel_data)) / activity) if activity > 0 else 0.0
    return [activity, mobility, *psd, *wavelet]


def extract_features(segments, channels):
    """Build the feature matrix for a list of EEG segments.

    Parameters
    ----------
    segments : iterable of pandas.DataFrame
        Each segment must contain every column named in ``channels``.
    channels : list of str
        Channel columns to use. The per-channel features are concatenated in
        this order, so training and inference must pass the same order.

    Returns
    -------
    numpy.ndarray
        One row per segment. Each row holds, for every channel in turn:
        activity, mobility, the Welch PSD values and the wavelet band energies.
        Rows have equal length as long as the segments do.
    """
    features = []
    for segment in segments:
        row = []
        for channel in channels:
            row.extend(_channel_features(segment[channel].values))
        features.append(row)
    return np.array(features)

# Build the feature matrix: one row per segment (from the segmentation cell).
features = extract_features(segments, channels)

# Catch a ragged or label-misaligned matrix here rather than deep inside
# scikit-learn.
assert features.ndim == 2 and features.shape[0] == len(segment_labels), features.shape
features.shape

# Model Training

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Hold out 20% of the trials for testing. stratify keeps the class proportions
# the same in both splits; with few trials an unstratified split can leave the
# test set badly unbalanced. (scikit-learn requires at least 2 trials per class.)
X_train, X_test, y_train, y_test = train_test_split(
    features, segment_labels, test_size=0.2, random_state=42, stratify=segment_labels
)

# Initialize and train the model (fixed seed so results are reproducible).
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Predict and evaluate on the held-out trials.
predictions = model.predict(X_test)
accuracy = accuracy_score(y_test, predictions)
print(f'Model accuracy: {accuracy}')

# GUI & new-input classification

In [172]:
import tkinter as tk
from tkinter import filedialog, Label
from PIL import Image, ImageTk
import pandas as pd

# Global state shared by the GUI callbacks below.
# new_data: the most recently loaded recording (a DataFrame from pd.read_csv),
# or None until the user loads a CSV.
new_data = None
# Must list the same channels, in the same order, as used for training: the
# feature vector is built channel by channel, so a different order silently
# scrambles the model's inputs.
channels = ['C3', 'C4', 'CZ', 'PZ']  # Update with the actual channels you're using

def load_data():
    """Ask the user for a CSV recording and store it in the global ``new_data``.

    Cancelling the dialog leaves ``new_data`` unchanged. A file that cannot be
    read is reported in the status label instead of escaping the Tk callback,
    where Tk's default handler would only print the traceback to the console.
    """
    global new_data
    file_path = filedialog.askopenfilename()
    if not file_path:  # dialog cancelled ('' or an empty tuple)
        return
    try:
        loaded = pd.read_csv(file_path)
    except Exception as e:
        status_label.config(text=f"Could not load file: {e}")
        return
    new_data = loaded
    status_label.config(text="File loaded successfully.")
        
def extract_new_features(channels):
    """Return the feature vector for the currently loaded recording (global ``new_data``).

    Delegates to ``extract_features`` so inference computes exactly the same
    per-channel features, in the same order, as training did.

    Parameters
    ----------
    channels : list of str
        Channel columns to use, in the same order as during training.

    Returns
    -------
    numpy.ndarray
        1-D feature vector.

    Raises
    ------
    ValueError
        If no recording has been loaded yet.
    """
    if new_data is None:
        raise ValueError("No data loaded. Please load a CSV file first.")
    return extract_features([new_data], channels)[0]

def process_and_classify():
    """Classify the loaded recording and show the predicted direction in the GUI.

    Crops the loaded recording to the START_TIME..END_TIME window (seconds
    from the start of the file), extracts the same features used for training
    and runs the trained ``model`` on them. Results and errors are reported
    through the GUI labels rather than raised.

    The global ``new_data`` is left untouched, so pressing the button again
    classifies the same window instead of re-cropping an already cropped frame.
    """
    if new_data is None:
        status_label.config(text="No data loaded. Please load a CSV file first.")
        return

    # With the notebook's values (2.5 s, 7.5 s, 250 Hz) this selects rows 625:1875.
    start_index = int(round(START_TIME * SAMPLING_RATE))
    stop_index = int(round(END_TIME * SAMPLING_RATE))
    try:
        if len(new_data) < stop_index:
            raise ValueError(
                f"recording has {len(new_data)} samples, need at least {stop_index}"
            )
        window = new_data.iloc[start_index:stop_index][channels]
        # extract_features returns a (1, n_features) matrix, which predict expects.
        window_features = extract_features([window], channels)
        prediction = model.predict(window_features)

        # Update the label text and image based on the prediction
        result_label.config(text=str(prediction[0]))
        update_image(prediction[0])
        # Clear any stale message from an earlier failed attempt.
        status_label.config(text="Classification complete.")
    except Exception as e:
        status_label.config(text=f"Error during processing: {e}")

def update_image(direction):
    """Show the picture for a predicted direction in ``direction_label``.

    A label equal to 'left' (case-insensitive) shows picture1; any other
    label shows picture4.
    """
    # str() so non-string class labels (e.g. numeric codes) don't fail on .lower().
    if str(direction).lower() == 'left':
        image_path = "F:/Graduation/picture1.png"
    else:
        image_path = "F:/Graduation/picture4.png"
    # ImageTk.PhotoImage copies the pixels into Tk, so the file can be closed
    # right away instead of leaking an open handle per prediction.
    with Image.open(image_path) as image:
        photo = ImageTk.PhotoImage(image)
    direction_label.config(image=photo)
    # Tk keeps no Python reference to the PhotoImage; without this one it is
    # garbage-collected and the label goes blank.
    direction_label.image = photo

# Build the classifier window: two action buttons, the prediction text, the
# direction picture and a status line, stacked top to bottom.
root = tk.Tk()
root.title("EEG Direction Classifier")

load_button = tk.Button(root, text="Load CSV", command=load_data)
process_button = tk.Button(root, text="Process and Classify", command=process_and_classify)
result_label = tk.Label(root, text="Result: None", font=('Helvetica', 14))
direction_label = tk.Label(root)  # image is set later by update_image()
status_label = tk.Label(root, text="Awaiting action", font=('Helvetica', 10))

# pack() order decides the on-screen order.
for widget in (load_button, process_button, result_label, direction_label, status_label):
    widget.pack()

# Hand control to Tk; this call blocks until the window is closed.
root.mainloop()
