In [None]:
import os
import tkinter as tk
from tkinter import filedialog, messagebox

import joblib
import mne
import numpy as np
from sklearn.preprocessing import StandardScaler

# Pre-trained classifier used by predict_risk(). The path is relative, so it is
# looked up against the process's current working directory.
# SECURITY: joblib.load unpickles the file, which can execute arbitrary code -
# only ever load a model file that comes from a trusted source.
_MODEL_PATH = 'svm_model.pkl'
model = joblib.load(_MODEL_PATH)

def upload_eeg_file():
    """Ask the user for an EDF file, then process it and display the predicted risk.

    Opens a file dialog filtered to ``*.edf``. When a file is chosen, the
    file-name label is updated, the Cancel and Submit buttons are enabled, and
    the result label is set to either the prediction or an error message.
    Nothing happens when the dialog returns no selection.
    """
    file_path = filedialog.askopenfilename(filetypes=[("EEG files", "*.edf")])
    # A falsy result means nothing was selected.
    if not file_path:
        return

    # os.path.basename handles '/' and, on Windows, '\\' separators as well,
    # unlike splitting on '/' alone.
    file_label.config(text=f"Uploaded file: {os.path.basename(file_path)}")
    cancel_button.config(state=tk.NORMAL)
    submit_button.config(state=tk.NORMAL)

    # Both helpers return None on failure (after showing their own error dialog).
    eeg_data = process_eeg_data(file_path)
    if eeg_data is None:
        result_label.config(text="Error in processing EEG data.")
        return

    prediction = predict_risk(eeg_data)
    if prediction is None:
        result_label.config(text="Error in making prediction.")
        return

    result_label.config(
        text=f"Schizophrenia Risk: {'High' if prediction == 1 else 'Low'}"
    )

def cancel_upload():
    """Return the window to its initial "nothing uploaded" state.

    Restores the placeholder text of the file and result labels, then disables
    the Cancel and Submit buttons again.
    """
    placeholders = (
        (file_label, "No file uploaded."),
        (result_label, "Schizophrenia Risk: "),
    )
    for label, text in placeholders:
        label.config(text=text)

    for button in (cancel_button, submit_button):
        button.config(state=tk.DISABLED)

def process_eeg_data(file_path, n_channels=12, expected_feature_count=209):
    """Load an EDF recording and turn it into a flat feature vector for the model.

    Steps:
      1. Read the EDF file.
      2. Keep the first ``n_channels`` channels in file order.
      3. Band-pass filter them to 1-40 Hz.
      4. Compute per-channel features with ``extract_features`` and flatten them.
      5. Truncate or zero-pad the vector to ``expected_feature_count`` entries.

    Args:
        file_path: Path to the ``.edf`` file.
        n_channels: Number of leading channels to use (default 12).
        expected_feature_count: Length of the vector the classifier expects
            (default 209).

    Returns:
        A 1-D numpy array of length ``expected_feature_count``, or ``None`` if
        any step failed (an error dialog is shown in that case).
    """
    try:
        raw = mne.io.read_raw_edf(file_path, preload=True)

        # Select channels *before* filtering so only the channels actually used
        # are filtered (filtering is per-channel, so the result is the same).
        # `pick` replaces the legacy `pick_channels`.
        raw.pick(raw.ch_names[:n_channels])
        raw.filter(1, 40, fir_design='firwin')  # band-pass 1-40 Hz

        data = raw.get_data()  # channels x samples
        features_flat = extract_features(data).flatten()

        # NOTE(review): with 12 channels x 18 features per channel = 216
        # values, the truncation below silently drops the last 7 features of
        # the last channel to reach 209. Confirm this matches how the model
        # was trained.
        n_found = len(features_flat)
        if n_found > expected_feature_count:
            features_flat = features_flat[:expected_feature_count]
        elif n_found < expected_feature_count:
            features_flat = np.pad(
                features_flat, (0, expected_feature_count - n_found)
            )

        return features_flat
    except Exception as e:
        # GUI boundary: report any failure (unreadable file, MNE error, ...) to
        # the user instead of letting the Tk callback raise.
        messagebox.showerror("Error", f"Error in processing EEG data: {e}")
        return None

def extract_features(eeg_data, n_features=18):
    """Compute a fixed-length vector of summary features for every channel.

    For each row (channel) of ``eeg_data``, computes:
      - 12 time-domain statistics, and
      - 4 statistics of the FFT magnitude spectrum.

    The list is then truncated or zero-padded to exactly ``n_features``
    values. With the default of 18, the last two entries are always 0.

    Args:
        eeg_data: 2-D array-like, one row per channel.
        n_features: Number of features kept per channel (default 18).

    Returns:
        ``np.ndarray`` of shape ``(n_channels, n_features)`` when there is at
        least one channel (an empty input yields an empty array).
    """
    features = []
    for channel in eeg_data:
        feature_vector = [
            np.mean(channel),
            np.std(channel),
            np.min(channel),
            np.max(channel),
            np.median(channel),
            np.var(channel),
            np.percentile(channel, 25),
            np.percentile(channel, 75),
            np.ptp(channel),
            np.sqrt(np.mean(np.square(channel))),  # root mean square
            np.sum(np.abs(channel)),  # sum of absolute values
            # Fraction of NON-zero samples (a density, i.e. 1 - sparsity).
            # Left unchanged so the model's input values are not altered.
            np.count_nonzero(channel) / len(channel),
        ]

        # Compute the magnitude spectrum once and reuse it for all four stats.
        magnitude = np.abs(np.fft.fft(channel))
        feature_vector.extend([
            np.mean(magnitude),
            np.std(magnitude),
            np.max(magnitude),
            np.min(magnitude),
        ])

        # Trim, then zero-pad, to exactly n_features entries.
        feature_vector = feature_vector[:n_features]
        feature_vector.extend([0] * (n_features - len(feature_vector)))

        features.append(feature_vector)

    return np.array(features)

def predict_risk(eeg_data):
    """Classify a single feature vector with the module-level ``model``.

    Args:
        eeg_data: 1-D feature vector (as produced by ``process_eeg_data``),
            or ``None``.

    Returns:
        The model's label for that one sample. Returns ``None`` when
        ``eeg_data`` is ``None`` or when prediction raised; in the latter case
        an error dialog is shown.
    """
    if eeg_data is None:
        return None

    try:
        # predict() takes a batch; wrapping the vector in a list gives a
        # one-row 2-D input.
        labels = model.predict([eeg_data])
        return labels[0]
    except Exception as e:
        messagebox.showerror("Error", f"Prediction failed: {e}")
        return None

def submit():
    """Show the current prediction text in a dialog, or warn when no file is loaded.

    "No file loaded" is detected by the file label still showing its
    placeholder text.
    """
    if file_label.cget("text") == "No file uploaded.":
        messagebox.showwarning("Warning", "No EEG file uploaded!")
        return

    messagebox.showinfo("Prediction Result", result_label.cget("text"))

# Background colour shared by the window, frames, labels and button highlights.
_WINDOW_BG = "#f0f0f0"


def _make_button(parent, text, command, bg, fg, active_bg, **extra):
    """Create a flat, borderless button with the look shared by this window.

    Extra keyword arguments (e.g. ``state``) are passed through to ``tk.Button``.
    """
    return tk.Button(
        parent,
        text=text,
        command=command,
        bg=bg,
        fg=fg,
        relief=tk.FLAT,
        padx=20,
        pady=10,
        font=("Arial", 12),
        activebackground=active_bg,
        activeforeground=fg,
        cursor="hand2",
        highlightbackground=_WINDOW_BG,
        highlightthickness=0,
        borderwidth=0,  # `bd` is an alias of this option; setting both was redundant
        **extra,
    )


# Main window
root = tk.Tk()
root.title("Schizophrenia Risk Prediction")
root.geometry("500x500")
root.configure(bg=_WINDOW_BG)

# Frame centred in the window that holds every widget.
center_frame = tk.Frame(root, bg=_WINDOW_BG)
center_frame.place(relx=0.5, rely=0.5, anchor=tk.CENTER)

upload_button = _make_button(
    center_frame,
    "Upload EEG File",
    upload_eeg_file,
    bg="white",
    fg="black",
    active_bg="#e0e0e0",
)
upload_button.pack(pady=10)

# Shows the name of the uploaded file (or the placeholder text).
file_label = tk.Label(
    center_frame, text="No file uploaded.", bg=_WINDOW_BG, font=("Arial", 10)
)
file_label.pack(pady=10)

# Row holding the Cancel and Submit buttons.
cancel_submit_frame = tk.Frame(center_frame, bg=_WINDOW_BG)
cancel_submit_frame.pack(pady=10)

# Cancel and Submit start disabled; upload_eeg_file enables them.
cancel_button = _make_button(
    cancel_submit_frame,
    "Cancel",
    cancel_upload,
    bg="#f44336",
    fg="white",
    active_bg="#d32f2f",
    state=tk.DISABLED,
)
cancel_button.pack(side=tk.LEFT, padx=5)

submit_button = _make_button(
    cancel_submit_frame,
    "Submit",
    submit,
    bg="#4CAF50",
    fg="white",
    active_bg="#45a049",
    state=tk.DISABLED,
)
submit_button.pack(side=tk.RIGHT, padx=5)

# Shows the prediction (or an error message).
result_label = tk.Label(
    center_frame, text="Schizophrenia Risk: ", bg=_WINDOW_BG, font=("Arial", 12)
)
result_label.pack(pady=20)

# Start the Tk event loop; this call blocks until the window is closed.
root.mainloop()


Extracting EDF parameters from C:\Users\tanis\Downloads\dataverse_files\s12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 271749  =      0.000 ...  1086.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 825 samples (3.300 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.1s


Extracting EDF parameters from C:\Users\tanis\Downloads\dataverse_files\h07.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 227499  =      0.000 ...   909.996 secs...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 825 samples (3.300 s)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s
