In [6]:
import mne
import numpy as np
import tensorflow as tf
import os
from mne.preprocessing import ICA
from lime import lime_tabular

import lime
from lime.lime_tabular import LimeTabularExplainer

# Function to rename channels and drop specified channels based on conditions
def process_channels(raw_data):
    """
    Process and standardize EEG channels to keep only the 17 most common channels.
    """
    print(f"Initial channels: {raw_data.ch_names}")

    # Initialize a list to hold channels to drop
    channels_to_drop = []

    # Create mapping for channel renaming
    rename_map = {}
    for name in raw_data.ch_names:
        if any(x in name for x in ['23A-23R', '24A-24R', 'A2-A1']):
            channels_to_drop.append(name)
        else:
            new_name = name.replace('EEG ', '').replace('-LE', '')
            rename_map[name] = new_name

    # Drop unwanted channels
    if channels_to_drop:
        print(f"Dropping channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Rename remaining channels
    raw_data.rename_channels(rename_map)

    print(f"Final channels: {raw_data.ch_names}")

    # Define the 17 most common channels
    expected_channels = [
        'Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'Fp2',
        'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Cz', 'Pz'
    ]

    # Keep only the expected channels
    channels_to_keep = set(expected_channels)
    channels_to_drop = [
        ch for ch in raw_data.ch_names if ch not in channels_to_keep]

    if channels_to_drop:
        print(
            f"Dropping channels to keep only the expected 17 channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Verify we have the expected number of channels (should be 17)
    if len(raw_data.ch_names) != len(expected_channels):
        print(
            f"Warning: Expected {len(expected_channels)} channels, got {len(raw_data.ch_names)}")
        print(f"Missing: {set(expected_channels) - set(raw_data.ch_names)}")

    return raw_data


def read_eeg_file(file_path):
    # Read the raw data
    raw_data = mne.io.read_raw_edf(file_path, preload=True)
    return raw_data


def bandpass_filter(data, l_freq, h_freq, notch_freq=None):
    filtered_data = data.copy()

    # Apply bandpass first
    filtered_data.filter(l_freq=l_freq, h_freq=h_freq,
                         method='fir', phase='zero')

    # If using notch, apply with wider bandwidth
    if notch_freq is not None:
        filtered_data.notch_filter(freqs=notch_freq, notch_widths=2.0)

    return filtered_data


ica_channels = ['Fp1', 'Fp2']


def preprocess_ICA(epochs, n_components):
    """
    Apply Independent Component Analysis (ICA) to the epochs data.
    
    Parameters:
    -----------
    epochs : mne.Epochs
        The epochs data to process.
    n_components : int
        The number of components to extract.
    
    Returns:
    --------
    ica : ICA
        The fitted ICA object.
    """
    print(
        f"Preprocessing ICA for {len(epochs)} epochs...")  # Print the number of epochs being processed

    ica = ICA(n_components=n_components, random_state=97, max_iter=800)
    # Use the epochs directly
    ica.fit(epochs.copy().pick_channels(ica_channels))
    return ica


def create_epochs(processed_data, duration=5.0, overlap=1.0):
    """
    Create epochs from continuous EEG data and format for CNN input
    
    Parameters:
    -----------
    processed_data : mne.io.Raw
        The raw EEG data
    duration : float
        Duration of each epoch in seconds
    overlap : float
        Overlap between epochs in seconds
    
    Returns:
    --------
    epochs_array : numpy.ndarray
        The epoched data formatted for CNN (samples, channels, timepoints, 1)
    """

    # Create epochs
    epochs = mne.make_fixed_length_epochs(
        processed_data,
        duration=duration,
        overlap=overlap,
        preload=True
    )

    # Drop bad epochs
    epochs.drop_bad()

    # Get data and reshape for CNN
    # Shape will be (n_epochs, n_channels, n_timepoints)
    data = epochs.get_data()

    # Add channel dimension for CNN: (n_epochs, n_channels, n_timepoints, 1)
    data = data[..., np.newaxis]

    return data


def preprocess_eeg(raw_data, l_freq=0.5, h_freq=60.0, notch_freq=50.0, n_components=2, epoch_duration=5.0, epoch_overlap=1.0):
    """
    Complete EEG preprocessing pipeline: filter -> bad channel removal -> epoching -> ICA -> baseline correction
    """
    try:
        # print(f"\nProcessing file: {filename}")

        # Make a copy of the raw data to prevent modifications to original
        processed_raw = raw_data.copy()

        # 1. Bandpass filtering
        print("1. Applying bandpass filter...")
        try:
            bandpass_filter(processed_raw, l_freq, h_freq, notch_freq)
            print("Bandpass filtering completed")
        except Exception as e:
            print(f"Error during bandpass filtering: {str(e)}")
            return None

        # 2. Bad channel removal
        print("2. Removing bad channels...")
        try:
            processed_raw = process_channels(raw_data=processed_raw)
            print("Bad channels removed")
        except Exception as e:
            print(f"Error during bad channel removal: {str(e)}")
            return None

        # 3. Epoching
        print("3. Creating epochs...")
        try:
            epochs = mne.make_fixed_length_epochs(
                processed_raw,
                duration=epoch_duration,
                overlap=epoch_overlap,
                preload=True
            )

            # Drop bad epochs
            epochs.drop_bad()

            # Get data and reshape for CNN
            data = epochs.get_data()
            data = data[..., np.newaxis]  # Add channel dimension for CNN

            print(f"Epoching completed. Final data shape: {data.shape}")
        except Exception as e:
            print(f"Error during epoching: {str(e)}")
            return None

        # 4. ICA
        print("4. Applying ICA...")
        try:
            # Pass the epochs object
            ica = preprocess_ICA(epochs, n_components)
            ica.apply(epochs)  # Apply ICA to the epochs
            print("ICA completed")
        except Exception as e:
            print(f"Error during ICA: {str(e)}")
            return None

        # 5. Baseline correction
        print("5. Applying baseline correction...")
        try:
            # Apply baseline correction over the entire epoch
            epochs.apply_baseline((None, None))
            print("Baseline correction completed")
        except Exception as e:
            print(f"Error during baseline correction: {str(e)}")
            return None

        return data  # Return the processed data

    except Exception as e:
        print(f"General preprocessing error: {str(e)}")
        return None


def process_eeg_data(file_path, model):
    try:
        # Read the EEG file
        raw_data = read_eeg_file(file_path)

        # Preprocess the EEG data using your existing preprocessing function
        processed_data = preprocess_eeg(
            raw_data,
            l_freq=0.5,
            h_freq=60.0,
            notch_freq=50.0,
            n_components=2,
            epoch_duration=5.0,
            epoch_overlap=1.0
        )

        # Ensure processed_data is in the correct shape
        processed_data = np.squeeze(processed_data)

        # Reshape data to match model's expected input
        reshaped_data = processed_data.reshape(
            processed_data.shape[0], processed_data.shape[1], -1)

        # Convert to tensor and add batch dimension
        input_tensor = tf.convert_to_tensor(reshaped_data, dtype=tf.float32)

        # Ensure the input tensor is 3D: (n_epochs, n_channels, n_timepoints)
        if len(input_tensor.shape) == 2:  # If it's 2D, add a channel dimension
            input_tensor = tf.expand_dims(input_tensor, axis=-1)

        # Make prediction
        predict_fn = model.signatures['serving_default']
        predictions = predict_fn(input_tensor)
        output_key = list(predictions.keys())[0]
        preds = predictions[output_key].numpy()

        # Apply a threshold to determine class predictions
        threshold = 0.5
        class_predictions = (preds[0] > threshold).astype(int)

        print(preds)

        return [{
            "class": str(class_prediction),
            "probability": float(prob)
        } for class_prediction, prob in zip(class_predictions, preds[0])]

    except Exception as e:
        print(f"Error processing EEG data: {str(e)}")
        return [{"error": f"Failed to process EEG data: {str(e)}"}]

    

def load_model(model_path):
    try:
        # Print path to debug
        print(f"Looking for model at: {model_path}")
        print(f"Directory exists: {os.path.exists(model_path)}")

        # Load model
        model = tf.saved_model.load(model_path)
        print("Model loaded successfully")
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None
    

    # Specify the path to your EDF file and model
edf_file_path = './edf_dataset_2/MDD_S6_EC.edf'
model_path = './saved_model/1d_cnn/1st_model/'

# Load the model
model = load_model(model_path)

# Example usage
if model is not None:
    results = process_eeg_data(edf_file_path, model)
    print(results)
    
else:
    print("Model could not be loaded. Predictions cannot be made.")

# Define your actual channel names (replace with your specific channel labels)
channel_names = [
    'Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'Fp2',
    'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Cz', 'Pz'
]


def explain_prediction(file_path, model):
    # First get the processed data using existing function
    raw_data = read_eeg_file(file_path)
    processed_data = preprocess_eeg(
        raw_data,
        l_freq=0.5,
        h_freq=60.0,
        notch_freq=50.0,
        n_components=2,
        epoch_duration=5.0,
        epoch_overlap=1.0
    )

    # Reshape the processed data to 2D format for LIME
    flattened_data = processed_data.reshape(processed_data.shape[0], -1)

    # Create wrapper function for model predictions
    def model_predict(data):
        # Reshape back to model's expected format
        reshaped_data = data.reshape(-1, 17, 1280)
        data_tensor = tf.convert_to_tensor(reshaped_data, dtype=tf.float32)
        predict_fn = model.signatures['serving_default']
        predictions = predict_fn(data_tensor)
        output_key = list(predictions.keys())[0]
        probs = predictions[output_key].numpy()

        # Convert single probability to two-class probability distribution
        two_class_probs = np.zeros((probs.shape[0], 2))
        two_class_probs[:, 1] = probs.flatten()
        two_class_probs[:, 0] = 1 - probs.flatten()
        return two_class_probs

    # Create LIME explainer
    explainer = LimeTabularExplainer(
        training_data=flattened_data,
        feature_names=[f"{channel}_{t}" for channel in channel_names
                       for t in range(processed_data.shape[2])],
        class_names=['Normal', 'MDD'],
        mode='classification'
    )

    # Get explanation for first instance
    instance_to_explain = flattened_data[0]
    explanation = explainer.explain_instance(
        instance_to_explain,
        model_predict,
        num_features=10
    )

    return explanation

# Usage
explanation = explain_prediction(edf_file_path, model)
# View results
print(explanation.as_list())

Looking for model at: ./saved_model/1d_cnn/1st_model/
Directory exists: True
Model loaded successfully
Extracting EDF parameters from /Users/hansandreanto/Development/capstone-project/edf_dataset_2/MDD_S6_EC.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 77311  =      0.000 ...   301.996 secs...
1. Applying bandpass filter...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 60 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 60.00 Hz
- Upper transition bandwidth: 15.00 Hz (-6 dB cutoff frequency: 67.50 Hz)
- Filter length: 1691 samples (6.605 sec)

Setting up band-stop filter from 48 - 52 Hz

FIR 

  logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
  logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
  l_freq = cast(l_freq)
  h_freq = cast(h_freq)
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
  msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
  logger.info('- Lower passband edge: %0.2f' % (l_freq,))
  msg += ' (%s cutoff frequency: %0.2f Hz)' % (
  logger.info('- Upper passband edge: %0.2f Hz' % (h_freq,))
  msg += ' (%s cutoff frequency: %0.2f Hz)' % (
  float(min(h_check, l_check)),)
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,


ICA completed
5. Applying baseline correction...
Applying baseline correction (mode: mean)
Baseline correction completed
[[0.8810837 ]
 [0.87768567]
 [0.82781464]
 [0.92183346]
 [0.883632  ]
 [0.8369908 ]
 [0.8843969 ]
 [0.88429546]
 [0.81707436]
 [0.8676273 ]
 [0.9148272 ]
 [0.922066  ]
 [0.8101219 ]
 [0.7273692 ]
 [0.83795357]
 [0.8766037 ]
 [0.6507389 ]
 [0.7897131 ]
 [0.91054595]
 [0.9141953 ]
 [0.8906377 ]
 [0.9163796 ]
 [0.7924859 ]
 [0.90501493]
 [0.88552177]
 [0.85247445]
 [0.7833756 ]
 [0.7813268 ]
 [0.9116129 ]
 [0.9105847 ]
 [0.9056789 ]
 [0.85142154]
 [0.9209563 ]
 [0.9153217 ]
 [0.89807105]
 [0.87573   ]
 [0.8899173 ]
 [0.88219935]
 [0.89894015]
 [0.8908807 ]
 [0.82569623]
 [0.9207742 ]
 [0.9201482 ]
 [0.89967376]
 [0.841427  ]
 [0.8329116 ]
 [0.90828705]
 [0.89366376]
 [0.9021751 ]
 [0.8609036 ]
 [0.8522613 ]
 [0.8883957 ]
 [0.76264393]
 [0.8311075 ]
 [0.90432096]
 [0.8921838 ]
 [0.9181649 ]
 [0.9403085 ]
 [0.6598822 ]
 [0.9329643 ]
 [0.92724264]
 [0.9146163 ]
 [0.9036095

  logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
  logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
  l_freq = cast(l_freq)
  h_freq = cast(h_freq)
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
  msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
  logger.info('- Lower passband edge: %0.2f' % (l_freq,))
  msg += ' (%s cutoff frequency: %0.2f Hz)' % (
  logger.info('- Upper passband edge: %0.2f Hz' % (h_freq,))
  msg += ' (%s cutoff frequency: %0.2f Hz)' % (
  float(min(h_check, l_check)),)
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,


Fitting ICA took 0.1s.
Applying ICA to Epochs instance
    Transforming to ICA space (2 components)
    Zeroing out 0 ICA components
    Projecting back using 2 PCA components
ICA completed
5. Applying baseline correction...
Applying baseline correction (mode: mean)
Baseline correction completed
[('P4_964 <= -0.00', -0.010750992255603732), ('P3_21 > 0.00', -0.010234267848386814), ('-0.00 < T6_507 <= 0.00', -0.009097230785845444), ('0.00 < C4_988 <= 0.00', -0.009006919096134611), ('Pz_985 > 0.00', 0.008721911964479094), ('O2_572 > 0.00', -0.008524995418641722), ('Fp2_990 > 0.00', -0.008196115947299043), ('Cz_761 <= -0.00', -0.008064980799966497), ('0.00 < C4_167 <= 0.00', 0.007947887588628652), ('-0.00 < F4_233 <= 0.00', -0.004187127098497771)]


In [None]:
# # Ensure processed_data is in the correct shape
# processed_data = np.squeeze(processed_data)

# # Reshape data to match model's expected input
# reshaped_data = processed_data.reshape(
#     processed_data.shape[0], processed_data.shape[1], -1)

# # Convert to tensor and add batch dimension
# input_tensor = tf.convert_to_tensor(reshaped_data, dtype=tf.float32)

# # Ensure the input tensor is 3D: (n_epochs, n_channels, n_timepoints)
# if len(input_tensor.shape) == 2:  # If it's 2D, add a channel dimension
#     input_tensor = tf.expand_dims(input_tensor, axis=-1)

# # Make prediction
# predictions = predict_fn(input_tensor)
# output_key = list(predictions.keys())[0]
# preds = predictions[output_key].numpy()

# # Debugging: Print the raw predictions output
# print(f"Raw predictions output: {preds}")

# # Apply a threshold to determine class predictions
# threshold = 0.55
# class_predictions = (preds[0] > threshold).astype(int)

# return [{
#     "class": str(class_prediction),
#     "probability": float(prob)
# } for class_prediction, prob in zip(class_predictions, preds[0])]

In [None]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

# Assuming you have the following variables from the previous code
# X_data, y_labels, class_counts

# Concatenate the X_data list into a single numpy array
X = np.concatenate(X_data, axis=0)

# Convert y_labels to a numpy array
y = np.array(y_labels)

# Define the CNN model
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=X.shape[1:]))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# Compile the model
model.compile(optimizer='adam', loss='binary_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(X, y, epochs=10, batch_size=32)

y_pred = (model.predict(X) > 0.5).astype(int)
acc = accuracy_score(y, y_pred)
f1 = f1_score(y, y_pred)
precision = precision_score(y, y_pred)
recall = recall_score(y, y_pred)

print(f"Accuracy: {acc:.2f}")
print(f"F1-score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

In [None]:
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import BatchNormalization
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
import matplotlib.pyplot as plt

# Assuming you have the following variables from the previous code
# X_data, y_labels, class_counts

# Concatenate the X_data list into a single numpy array
X = np.concatenate(X_data, axis=0)

# Convert y_labels to a numpy array
y = np.array(y_labels)

# Check the input shape
print(f"Input shape: {X.shape}")

# Define the CNN model
model = Sequential()
model.add(Conv2D(64, (3, 3), activation='relu', input_shape=X.shape[1:]))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(256, (3, 3), activation='relu'))
# Adjust the pooling size to match the input shape
model.add(MaxPooling2D((1, 2)))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1, activation='sigmoid'))
model.add(BatchNormalization())
model.add(Dropout(0.5))

# Compile the model
model.compile(optimizer=Adam(learning_rate=0.0001),
              loss='binary_crossentropy', metrics=['accuracy'])

# Train the model
history = model.fit(X, y, epochs=10, batch_size=32, validation_split=0.2)

# Evaluate the model
y_pred = (model.predict(X) > 0.5).astype(int)
acc = accuracy_score(y, y_pred)
f1 = f1_score(y, y_pred)
precision = precision_score(y, y_pred)
recall = recall_score(y, y_pred)

print(f"Accuracy: {acc:.2f}")
print(f"F1-score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

# Plot the training and validation accuracy
plt.figure(figsize=(8, 6))
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model Accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper left')
plt.show()

# Plot the training and validation loss
plt.figure(figsize=(8, 6))
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model Loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['Train', 'Validation'], loc='upper right')
plt.show()

In [None]:
# def preprocess_eeg(raw_data, l_freq=0.5, h_freq=50.0, n_components=5, epoch_duration=5.0, epoch_overlap=1.0):
#     """
#     Complete EEG preprocessing pipeline: bandpass -> ICA -> epoching

#     Parameters:
#     -----------
#     raw_data : mne.io.Raw
#         Raw EEG data
#     l_freq : float
#         Lower frequency bound for bandpass filter
#     h_freq : float
#         Higher frequency bound for bandpass filter
#     n_components : int
#         Number of ICA components
#     epoch_duration : float
#         Duration of each epoch in seconds
#     epoch_overlap : float
#         Overlap between epochs in seconds

#     Returns:
#     --------
#     epochs_array : numpy.ndarray
#         Preprocessed and epoched data ready for CNN (samples, channels, timepoints, 1)
#     """
#     try:
#         print(f"Processing file: {raw_data.filenames}")

#         processed_raw = process_channels(raw_data=raw_data)

#         # 1. Bandpass filtering
#         print("Applying bandpass filter...")
#         bandpass_filter(processed_raw, 0.5, 45)

#         # 2. ICA
#         print("Applying ICA...")
#         preprocess_ICA(processed_raw, 5)

#         # Find and remove EOG artifacts
#         # eog_indices, eog_scores = ica.find_bads_eog(raw_data)
#         # if eog_indices:
#         #     print(f'Found {len(eog_indices)} EOG components')
#         #     ica.exclude = eog_indices
#         # ica.apply(raw_data)

#         # 3. Epoching
#         print("Creating epochs...")
#         create_epochs(processed_raw)


#     except Exception as e:
#         print(f"Error during preprocessing: {str(e)}")
#         return None

In [1]:
# Import Required Libraries
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras.models import load_model  # Save model in SavedModel format
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout
from tensorflow.keras.models import Sequential
import re  # Import the regular expressions module
import numpy as np
import matplotlib.pyplot as plt
import requests

# To unzip the edf_dataset
import zipfile
import os

# EDFlib and Data Preprocesing module
from mne.preprocessing import ICA, create_eog_epochs
import mne
from pyedflib import highlevel
import pyedflib as plib

import requests


def download_file(url, save_path):

    # Check if the file already exists
    if os.path.exists(save_path):
        print(f"File already exists at '{save_path}'. Skipping download.")
        return  # Exit the function if the file exists

    # Send a GET request to the URL
    response = requests.get(url)

    # Check if the request was successful
    if response.status_code == 200:
        # Open the file in binary write mode and save the content
        with open(save_path, 'wb') as file:
            file.write(response.content)
        print(f"File downloaded successfully and saved to '{save_path}'")
    else:
        print(f"Failed to download file. Status code: {response.status_code}")


# Specify the URL and the path where you want to save the file
url = 'https://figshare.com/ndownloader/articles/4244171/versions/2'
# Change this to your desired path
save_path = './edf_dataset.zip'

# Call the function to download the file
download_file(url, save_path)


def unzip_file(zip_file_path, extract_to_folder):

    # Check if the directory exist

    if os.path.exists(extract_to_folder):
        print(f"Directory '{extract_to_folder} already exists")
        return  # Exit the function if the directory

    # Create the directory if it doesn't exist
    os.makedirs(extract_to_folder, exist_ok=True)

    # Open the zip file
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        # Extract all the contents into the specified folder
        zip_ref.extractall(extract_to_folder)


# Specify the path to the zip file and the extraction folder
zip_file_path = './edf_dataset.zip'
# Change this if needed
extract_to_folder = './edf_dataset_2'

# Call the function to unzip
unzip_file(zip_file_path, extract_to_folder)


# edf_directory
edf_directory = "./edf_dataset_2"
# Loop through all files in the specified directory
for filename in os.listdir(edf_directory):
    # Check if the filename contains spaces
    if ' ' in filename:
        # Create a new filename by replacing multiple spaces with a single underscore
        # Replace one or more spaces with a single underscore
        new_filename = re.sub(r'\s+', '_', filename)

        # Get the full path for the old and new filenames
        old_file = os.path.join(edf_directory, filename)
        new_file = os.path.join(edf_directory, new_filename)

        # Rename the file
        os.rename(old_file, new_file)
        print(f'Renamed: "{filename}" to "{new_filename}"')
    elif '6931959' in filename:
        new_filename = filename.replace('6921959_', '')

        # Get the full path for the old and new filenames
        old_file = os.path.join(edf_directory, filename)
        new_file = os.path.join(edf_directory, new_filename)

        # Rename the file
        os.rename(old_file, new_file)
        print(f'Renamed: "{filename}" to "{new_filename}"')
    elif '6921143' in filename:
        new_filename = filename.replace('6921143_', '')
        # Get the full path for the old and new filenames
        old_file = os.path.join(edf_directory, filename)
        new_file = os.path.join(edf_directory, new_filename)

        # Rename the file
        os.rename(old_file, new_file)
        print(f'Renamed: "{filename}" to "{new_filename}"')
print("Renaming complete.")

# Function to rename channels and drop specified channels based on conditions


def process_channels(raw_data):
    """
    Process and standardize EEG channels to keep only the 17 most common channels.
    """
    print(f"Initial channels: {raw_data.ch_names}")

    # Initialize a list to hold channels to drop
    channels_to_drop = []

    # Create mapping for channel renaming
    rename_map = {}
    for name in raw_data.ch_names:
        if any(x in name for x in ['23A-23R', '24A-24R', 'A2-A1']):
            channels_to_drop.append(name)
        else:
            new_name = name.replace('EEG ', '').replace('-LE', '')
            rename_map[name] = new_name

    # Drop unwanted channels
    if channels_to_drop:
        print(f"Dropping channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Rename remaining channels
    raw_data.rename_channels(rename_map)

    print(f"Final channels: {raw_data.ch_names}")

    # Define the 17 most common channels
    expected_channels = [
        'Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'Fp2',
        'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Cz', 'Pz'
    ]

    # Keep only the expected channels
    channels_to_keep = set(expected_channels)
    channels_to_drop = [
        ch for ch in raw_data.ch_names if ch not in channels_to_keep]

    if channels_to_drop:
        print(
            f"Dropping channels to keep only the expected 17 channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Verify we have the expected number of channels (should be 17)
    if len(raw_data.ch_names) != len(expected_channels):
        print(
            f"Warning: Expected {len(expected_channels)} channels, got {len(raw_data.ch_names)}")
        print(f"Missing: {set(expected_channels) - set(raw_data.ch_names)}")

    return raw_data


all_edf_files = os.listdir(edf_directory)
ec_file_path = [i for i in all_edf_files if i.endswith('EC.edf')]
eo_file_path = [i for i in all_edf_files if i.endswith('EO.edf')]
task_file_path = [i for i in all_edf_files if i.endswith('TASK.edf')]

print(len(all_edf_files), len(ec_file_path),
      len(eo_file_path), len(task_file_path))





# Directory containing the EDF files
edf_directory = "./edf_dataset_2"  # Adjust this path to your dataset location

# Initialize lists
processed_raw_data = []
class_counts = {'Healthy': 0, 'MDD': 0}

# Read all EDF files
for filename in os.listdir(edf_directory):
    if filename.endswith('.edf'):
        file_path = os.path.join(edf_directory, filename)
        try:
            # Read the raw data
            raw_data = read_data(file_path)

            if raw_data is not None:
                processed_raw_data.append(raw_data)
                print(f"Successfully loaded: {filename}")
            else:
                print(f"Failed to load: {filename}")

        except Exception as e:
            print(f"Error loading {filename}: {str(e)}")
            continue

print(f"\nTotal files loaded: {len(processed_raw_data)}")


def read_data(file_path):
    data = mne.io.read_raw_edf(file_path, preload=True)
    data.set_eeg_reference()
    return data


def bandpass_filter(data, l_freq, h_freq):
    # Adjust the filter parameters as needed
    data.filter(l_freq=l_freq, h_freq=h_freq)


def preprocess_ICA(raw, n_components):
    print(f"Preprocessing ICA: {raw.filenames}")

    ica = ICA(n_components=n_components, random_state=97,
              max_iter=800)
    ica.fit(raw)
    return ica


def create_epochs(processed_data, duration=5.0, overlap=1.0):
    """
    Create epochs from continuous EEG data and format for CNN input
    
    Parameters:
    -----------
    processed_data : mne.io.Raw
        The raw EEG data
    duration : float
        Duration of each epoch in seconds
    overlap : float
        Overlap between epochs in seconds
    
    Returns:
    --------
    epochs_array : numpy.ndarray
        The epoched data formatted for CNN (samples, channels, timepoints, 1)
    """

    # Create epochs
    epochs = mne.make_fixed_length_epochs(
        processed_data,
        duration=duration,
        overlap=overlap,
        preload=True
    )

    # Drop bad epochs
    epochs.drop_bad()

    # Get data and reshape for CNN
    # Shape will be (n_epochs, n_channels, n_timepoints)
    data = epochs.get_data()

    # Add channel dimension for CNN: (n_epochs, n_channels, n_timepoints, 1)
    data = data[..., np.newaxis]

    return data
def preprocess_eeg(raw_data, l_freq=0.5, h_freq=50.0, n_components=5, epoch_duration=5.0, epoch_overlap=1.0):
    """
    Complete EEG preprocessing pipeline: bandpass -> ICA -> epoching
    """
    try:
        print(f"\nProcessing file: {raw_data.filenames}")

        # Make a copy of the raw data to prevent modifications to original

        processed_raw = process_channels(raw_data=raw_data)

        # 1. Bandpass filtering
        print("1. Applying bandpass filter...")
        try:
            bandpass_filter(processed_raw, l_freq, h_freq)
            print("Bandpass filtering completed")
        except Exception as e:
            print(f"Error during bandpass filtering: {str(e)}")
            return None

        # 2. ICA
        print("2. Applying ICA...")
        try:
            ica = preprocess_ICA(processed_raw, n_components)
            ica.apply(processed_raw)
            print("ICA completed")
        except Exception as e:
            print(f"Error during ICA: {str(e)}")
            return None

        # 3. Epoching
        print("3. Creating epochs...")
        try:
            epochs = mne.make_fixed_length_epochs(
                processed_raw,
                duration=epoch_duration,
                overlap=epoch_overlap,
                preload=True
            )

            # Drop bad epochs
            epochs.drop_bad()

            # Get data and reshape for CNN
            data = epochs.get_data()
            data = data[..., np.newaxis]  # Add channel dimension for CNN

            print(f"Epoching completed. Final data shape: {data.shape}")
            return data

        except Exception as e:
            print(f"Error during epoching: {str(e)}")
            return None

    except Exception as e:
        print(f"General preprocessing error: {str(e)}")
        return None


# Now process each raw data file
X_data = []
y_labels = []
class_counts = {'Healthy': 0, 'MDD': 0}

print("\nStarting preprocessing pipeline...")
for raw_data in processed_raw_data:
    filename = os.path.basename(raw_data.filenames[0])
    print(f"\n{'='*50}")
    print(f"Processing: {filename}")
    print(f"Initial data info:")
    print(f"Channels: {raw_data.ch_names}")
    print(f"Sample rate: {raw_data.info['sfreq']} Hz")
    print(f"Duration: {raw_data.n_times / raw_data.info['sfreq']:.2f} seconds")

    # Apply complete preprocessing pipeline
    try:
        # First, process the channels
        # Make a copy to prevent modifying original
        raw_data = process_channels(raw_data.copy())
        print(f"Channels after processing: {raw_data.ch_names}")

        # Apply complete preprocessing pipeline
        processed_data = preprocess_eeg(
            raw_data,
            l_freq=0.5,
            h_freq=50.0,
            n_components=5,
            epoch_duration=5.0,
            epoch_overlap=1.0
        )

        if processed_data is not None:
            print(f"Processed data shape: {processed_data.shape}")

            # Create label (1 for MDD, 0 for Healthy)
            label = 1 if filename.startswith('MDD') else 0

            # Update counts with number of epochs
            if label == 1:
                class_counts['MDD'] += processed_data.shape[0]
            else:
                class_counts['Healthy'] += processed_data.shape[0]

            # Append to lists
            X_data.append(processed_data)
            y_labels.extend([label] * processed_data.shape[0])
            print(f"Successfully processed {filename}")
        else:
            print(f"Failed to process {filename}")

    except Exception as e:
        print(f"Error processing {filename}: {str(e)}")
        continue

    print(f"{'='*50}\n")


# Print final summary
print("\nProcessing Summary:")
print(f"Total files processed: {len(processed_raw_data)}")
print(f"Total epochs: {len(y_labels)}")
print(f"Successfully processed files: {len(X_data)}")
print(f"Failed files: {len(processed_raw_data) - len(X_data)}")
print("\nClass distribution:")
print(f"MDD epochs: {class_counts['MDD']}")
print(f"Healthy epochs: {class_counts['Healthy']}")
# print(f"\nChannels used: {raw_data.ch_names}")

# # Optional: Print class balance percentage
# total_epochs = class_counts['MDD'] + class_counts['Healthy']
# print("\nClass balance:")
# print(f"MDD: {(class_counts['MDD']/total_epochs)*100:.2f}%")
# print(f"Healthy: {(class_counts['Healthy']/total_epochs)*100:.2f}%")

# Check if we have any processed data
if len(X_data) > 0:
    try:
        # Print shapes before concatenation
        print("\nArray shapes before concatenation:")
        for i, arr in enumerate(X_data):
            print(f"Array {i}: shape {arr.shape}")

        # Concatenate data
        X = np.concatenate(X_data, axis=0)
        y = np.array(y_labels)

        # Print final information
        print("\nFinal Dataset Information:")
        print(f"Total samples: {len(X)}")
        print(f"Healthy samples: {class_counts['Healthy']}")
        print(f"MDD samples: {class_counts['MDD']}")
        print(f"Input shape: {X.shape}")
        print(f"Labels shape: {y.shape}")

        # Print channel information
        print(f"\nChannels used: {raw_data.ch_names}")

        # Print class balance
        total_epochs = class_counts['MDD'] + class_counts['Healthy']
        print("\nClass balance:")
        print(f"MDD: {(class_counts['MDD']/total_epochs)*100:.2f}%")
        print(f"Healthy: {(class_counts['Healthy']/total_epochs)*100:.2f}%")

    except Exception as e:
        print(f"\nError during final processing: {str(e)}")
        print("Checking individual arrays for inconsistencies...")

        # Find arrays with different shapes
        base_shape = X_data[0].shape[1:]
        for i, arr in enumerate(X_data):
            if arr.shape[1:] != base_shape:
                print(
                    f"Mismatch at index {i}: expected {base_shape}, got {arr.shape[1:]}")
else:
    print("\nNo data was successfully processed!")


# Print original shape
print("Original shape:", X.shape)  # Should be (18312, 19, 1280, 1)

learning_rate = 0.001
# Remove the extra dimension
X = X.squeeze(-1)  # Shape becomes (18312, 19, 1280)
print("Reshaped data shape:", X.shape)

# Convert labels to numpy array
y = np.array(y_labels)

# Define the CNN model
model = Sequential([
    # Input layer - Conv1D
    Conv1D(
        filters=32,
        kernel_size=11,
        activation='relu',
        padding='same',  # Add padding to prevent size issues
        input_shape=(17, 1280)
    ),

    # Second Conv1D layer
    Conv1D(
        filters=64,
        kernel_size=11,
        activation='relu',
        padding='same'  # Add padding
    ),

    # Dropout layer
    Dropout(0.2),

    # MaxPooling layer
    MaxPooling1D(pool_size=4),

    # Flatten layer
    Flatten(),

    # Dense layer
    Dense(100, activation='relu'),

    # Output layer
    Dense(1, activation='sigmoid')
])

# Compile the model
model.compile(
    optimizer=Adam(learning_rate=learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# Print model summary
print("\nModel Architecture:")
model.summary()


# early_stopping = EarlyStopping(
#     monitor='val_loss', patience=5, restore_best_weights=True)

# Train the model with validation split
history = model.fit(
    X, y,
    epochs=25,
    batch_size=32,
    validation_split=0.2,  # Added validation split
    # callbacks=[early_stopping],
    verbose=1
)


# Evaluate the model
y_pred = (model.predict(X) > 0.5).astype(int)
acc = accuracy_score(y, y_pred)
f1 = f1_score(y, y_pred)
precision = precision_score(y, y_pred)
recall = recall_score(y, y_pred)

print("\nModel Performance:")
print(f"Accuracy: {acc:.2f}")
print(f"F1-score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

MODEL_PATH = './saved_model/1d_cnn/1st_model'
tf.keras.models.save_model(model, MODEL_PATH)
print(f"Model saved to: {MODEL_PATH}")
model = load_model(MODEL_PATH)


# Load the saved model
MODEL_PATH = './saved_model/1d_cnn/1st_model'
model = load_model(MODEL_PATH)

# Function to read the EEG data from an EDF file


def read_eeg_file(file_path):
    # Read the raw data
    raw_data = mne.io.read_raw_edf(file_path, preload=True)
    return raw_data


# Specify the path to your EDF file
edf_file_path = './edf_dataset_2/MDD_S1_EC.edf'

# Read the EEG file
raw_data = read_eeg_file(edf_file_path)

# Preprocess the EEG data using your existing preprocessing function
# Assuming your preprocessing function is defined as follows:
# def preprocess_eeg(raw_data, l_freq=0.5, h_freq=50.0, n_components=5, epoch_duration=5.0, epoch_overlap=1.0):
#     # Your existing preprocessing code here
processed_data = preprocess_eeg(
    raw_data,
    l_freq=0.5,
    h_freq=50.0,
    n_components=5,
    epoch_duration=5.0,
    epoch_overlap=1.0
)

# Make predictions
predictions = model.predict(processed_data)
# Convert probabilities to class labels
threshold = 0.5
predicted_classes = (predictions >= threshold).astype(int)

# Print the predicted classes
print("Predicted classes:", predicted_classes)

# Print the predictions
print("Model predictions:", predictions)

KeyboardInterrupt: 

In [None]:
import mne
import numpy as np
import tensorflow as tf


def read_eeg_file(file_path):
    # Read the raw data
    raw_data = mne.io.read_raw_edf(file_path, preload=True)
    return raw_data


def process_eeg_data(file_path):
    try:
        # Read the EEG file
        raw_data = read_eeg_file(file_path)

        # Preprocess the EEG data using your existing preprocessing function
        processed_data = preprocess_eeg(
            raw_data,
            l_freq=0.5,
            h_freq=50.0,
            n_components=5,
            epoch_duration=5.0,
            epoch_overlap=1.0
        )

        # Ensure processed_data is in the correct shape
        # Assuming processed_data is in the shape (n_epochs, n_channels, n_timepoints, 1)
        # Shape becomes (n_epochs, n_channels, n_timepoints)
        processed_data = np.squeeze(processed_data)

        # Reshape data to match model's expected input
        # Flatten the data to (n_epochs, n_timepoints) if necessary
        # Shape becomes (n_epochs, n_channels, n_timepoints)
        reshaped_data = processed_data.reshape(
            processed_data.shape[0], processed_data.shape[1], -1)

        # Convert to tensor and add batch dimension
        input_tensor = tf.convert_to_tensor(reshaped_data, dtype=tf.float32)

        # Ensure the input tensor is 3D: (n_epochs, n_channels, n_timepoints)
        if len(input_tensor.shape) == 2:  # If it's 2D, add a channel dimension
            # Shape becomes (n_epochs, n_timepoints, 1)
            input_tensor = tf.expand_dims(input_tensor, axis=-1)

        # Make prediction
        predictions = predict_fn(input_tensor)
        output_key = list(predictions.keys())[0]
        preds = predictions[output_key].numpy()

        return [{
            "class": str(i),
            "probability": float(prob)
        } for i, prob in enumerate(preds[0])]

    except Exception as e:
        print(f"Error processing EEG data: {str(e)}")
        return [{"error": f"Failed to process EEG data: {str(e)}"}]


# Specify the path to your EDF file
edf_file_path = './edf_dataset_2/H_S9_TASK.edf'

# Example usage
results = process_eeg_data(edf_file_path)
print(results)

In [1]:
import mne
import numpy as np
import tensorflow as tf
import os

# Function to rename channels and drop specified channels based on conditions


def process_channels(raw_data):
    """
    Process and standardize EEG channels to keep only the 17 most common channels.
    """
    print(f"Initial channels: {raw_data.ch_names}")

    # Initialize a list to hold channels to drop
    channels_to_drop = []

    # Create mapping for channel renaming
    rename_map = {}
    for name in raw_data.ch_names:
        if any(x in name for x in ['23A-23R', '24A-24R', 'A2-A1']):
            channels_to_drop.append(name)
        else:
            new_name = name.replace('EEG ', '').replace('-LE', '')
            rename_map[name] = new_name

    # Drop unwanted channels
    if channels_to_drop:
        print(f"Dropping channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Rename remaining channels
    raw_data.rename_channels(rename_map)

    print(f"Final channels: {raw_data.ch_names}")

    # Define the 17 most common channels
    expected_channels = [
        'Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'Fp2',
        'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Cz', 'Pz'
    ]

    # Keep only the expected channels
    channels_to_keep = set(expected_channels)
    channels_to_drop = [
        ch for ch in raw_data.ch_names if ch not in channels_to_keep]

    if channels_to_drop:
        print(
            f"Dropping channels to keep only the expected 17 channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Verify we have the expected number of channels (should be 17)
    if len(raw_data.ch_names) != len(expected_channels):
        print(
            f"Warning: Expected {len(expected_channels)} channels, got {len(raw_data.ch_names)}")
        print(f"Missing: {set(expected_channels) - set(raw_data.ch_names)}")

    return raw_data


def read_data(file_path):
    data = mne.io.read_raw_edf(file_path, preload=True)
    data.set_eeg_reference()
    return data


def bandpass_filter(data, l_freq, h_freq, notch_freq=None):
    filtered_data = data.copy()

    # Apply bandpass first
    filtered_data.filter(l_freq=l_freq, h_freq=h_freq,
                         method='fir', phase='zero')

    # If using notch, apply with wider bandwidth
    if notch_freq is not None:
        filtered_data.notch_filter(freqs=notch_freq, notch_widths=2.0)

    return filtered_data


ica_channels = ['Fp1', 'Fp2']


def preprocess_ICA(epochs, n_components):
    """
    Apply Independent Component Analysis (ICA) to the epochs data.
    
    Parameters:
    -----------
    epochs : mne.Epochs
        The epochs data to process.
    n_components : int
        The number of components to extract.
    
    Returns:
    --------
    ica : ICA
        The fitted ICA object.
    """
    print(
        f"Preprocessing ICA for {len(epochs)} epochs...")  # Print the number of epochs being processed

    ica = ICA(n_components=n_components, random_state=97, max_iter=800)
    # Use the epochs directly
    ica.fit(epochs.copy().pick_channels(ica_channels))
    return ica


def create_epochs(processed_data, duration=5.0, overlap=1.0):
    """
    Create epochs from continuous EEG data and format for CNN input
    
    Parameters:
    -----------
    processed_data : mne.io.Raw
        The raw EEG data
    duration : float
        Duration of each epoch in seconds
    overlap : float
        Overlap between epochs in seconds
    
    Returns:
    --------
    epochs_array : numpy.ndarray
        The epoched data formatted for CNN (samples, channels, timepoints, 1)
    """

    # Create epochs
    epochs = mne.make_fixed_length_epochs(
        processed_data,
        duration=duration,
        overlap=overlap,
        preload=True
    )

    # Drop bad epochs
    epochs.drop_bad()

    # Get data and reshape for CNN
    # Shape will be (n_epochs, n_channels, n_timepoints)
    data = epochs.get_data()

    # Add channel dimension for CNN: (n_epochs, n_channels, n_timepoints, 1)
    data = data[..., np.newaxis]

    return data


def preprocess_eeg(raw_data, l_freq=0.5, h_freq=60.0, notch_freq=50.0, n_components=2, epoch_duration=5.0, epoch_overlap=1.0):
    """
    Complete EEG preprocessing pipeline: filter -> bad channel removal -> epoching -> ICA -> baseline correction
    """
    try:
        # print(f"\nProcessing file: {filename}")

        # Make a copy of the raw data to prevent modifications to original
        processed_raw = raw_data.copy()

        # 1. Bandpass filtering
        print("1. Applying bandpass filter...")
        try:
            bandpass_filter(processed_raw, l_freq, h_freq, notch_freq)
            print("Bandpass filtering completed")
        except Exception as e:
            print(f"Error during bandpass filtering: {str(e)}")
            return None

        # 2. Bad channel removal
        print("2. Removing bad channels...")
        try:
            processed_raw = process_channels(raw_data=processed_raw)
            print("Bad channels removed")
        except Exception as e:
            print(f"Error during bad channel removal: {str(e)}")
            return None

        # 3. Epoching
        print("3. Creating epochs...")
        try:
            epochs = mne.make_fixed_length_epochs(
                processed_raw,
                duration=epoch_duration,
                overlap=epoch_overlap,
                preload=True
            )

            # Drop bad epochs
            epochs.drop_bad()

            # Get data and reshape for CNN
            data = epochs.get_data()
            data = data[..., np.newaxis]  # Add channel dimension for CNN

            print(f"Epoching completed. Final data shape: {data.shape}")
        except Exception as e:
            print(f"Error during epoching: {str(e)}")
            return None

        # 4. ICA
        print("4. Applying ICA...")
        try:
            # Pass the epochs object
            ica = preprocess_ICA(epochs, n_components)
            ica.apply(epochs)  # Apply ICA to the epochs
            print("ICA completed")
        except Exception as e:
            print(f"Error during ICA: {str(e)}")
            return None

        # 5. Baseline correction
        print("5. Applying baseline correction...")
        try:
            # Apply baseline correction over the entire epoch
            epochs.apply_baseline((None, None))
            print("Baseline correction completed")
        except Exception as e:
            print(f"Error during baseline correction: {str(e)}")
            return None

        return data  # Return the processed data

    except Exception as e:
        print(f"General preprocessing error: {str(e)}")
        return None


def read_eeg_file(file_path):
    # Read the raw data
    raw_data = mne.io.read_raw_edf(file_path, preload=True)
    return raw_data


def load_model(model_path):
    try:
        # Print path to debug
        print(f"Looking for model at: {model_path}")
        print(f"Directory exists: {os.path.exists(model_path)}")

        # Load model
        model = tf.saved_model.load(model_path)
        print("Model loaded successfully")
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None


def process_eeg_data(file_path, model):
    try:
        # Read the EEG file
        raw_data = read_eeg_file(file_path)

        # Preprocess the EEG data using your existing preprocessing function
        processed_data = preprocess_eeg(
            raw_data,
            l_freq=0.5,
            h_freq=60.0,
            notch_freq=50.0,
            n_components=2,
            epoch_duration=5.0,
            epoch_overlap=1.0
        )

        # Ensure processed_data is in the correct shape
        processed_data = np.squeeze(processed_data)

        # Reshape data to match model's expected input
        reshaped_data = processed_data.reshape(
            processed_data.shape[0], processed_data.shape[1], -1)

        # Convert to tensor and add batch dimension
        input_tensor = tf.convert_to_tensor(reshaped_data, dtype=tf.float32)

        # Ensure the input tensor is 3D: (n_epochs, n_channels, n_timepoints)
        if len(input_tensor.shape) == 2:  # If it's 2D, add a channel dimension
            input_tensor = tf.expand_dims(input_tensor, axis=-1)

        # Make prediction
        predict_fn = model.signatures['serving_default']
        predictions = predict_fn(input_tensor)
        output_key = list(predictions.keys())[0]
        preds = predictions[output_key].numpy()

        # Apply a threshold to determine class predictions
        threshold = 0.5
        class_predictions = (preds[0] > threshold).astype(int)

        print(preds)

        return [{
            "class": str(class_prediction),
            "probability": float(prob)
        } for class_prediction, prob in zip(class_predictions, preds[0])]

    except Exception as e:
        print(f"Error processing EEG data: {str(e)}")
        return [{"error": f"Failed to process EEG data: {str(e)}"}]


# Specify the path to your EDF file and model
edf_file_path = './edf_dataset_2/MDD_S6_EC.edf'
model_path = './saved_model/1d_cnn/1st_model/'

# Load the model
model = load_model(model_path)

# Example usage
if model is not None:
    results = process_eeg_data(edf_file_path, model)
    print(results)
else:
    print("Model could not be loaded. Predictions cannot be made.")

Looking for model at: ./saved_model/1d_cnn/1st_model/
Directory exists: True
Model loaded successfully
Extracting EDF parameters from /Users/hansandreanto/Development/capstone-project/edf_dataset_2/MDD_S6_EC.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 77311  =      0.000 ...   301.996 secs...
1. Applying bandpass filter...
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 0.5 - 60 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 0.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 0.25 Hz)
- Upper passband edge: 60.00 Hz
- Upper transition bandwidth: 15.00 Hz (-6 dB cutoff frequency: 67.50 Hz)
- Filter length: 1691 samples (6.605 sec)



  logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
  logger.info('Setting up band-pass filter from %0.2g - %0.2g Hz'
  l_freq = cast(l_freq)
  h_freq = cast(h_freq)


Setting up band-stop filter from 48 - 52 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandstop filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 48.50
- Lower transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 48.25 Hz)
- Upper passband edge: 51.50 Hz
- Upper transition bandwidth: 0.50 Hz (-6 dB cutoff frequency: 51.75 Hz)
- Filter length: 1691 samples (6.605 sec)

Bandpass filtering completed
2. Removing bad channels...
Initial channels: ['EEG Fp1-LE', 'EEG F3-LE', 'EEG C3-LE', 'EEG P3-LE', 'EEG O1-LE', 'EEG F7-LE', 'EEG T3-LE', 'EEG T5-LE', 'EEG Fz-LE', 'EEG Fp2-LE', 'EEG F4-LE', 'EEG C4-LE', 'EEG P4-LE', 'EEG O2-LE', 'EEG F8-LE', 'EEG T4-LE', 'EEG T6-LE', 'EEG Cz-LE', 'EEG Pz-LE', 'EEG A2-A1']
Dropping channels: ['EEG A2-A1']
Final channels: ['Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'T5', 'Fz', 'Fp2', 'F4', 'C4', 'P4', 'O2', 'F8

  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
  msg += ' from %0.2g - %0.2g Hz' % (h_freq, l_freq)
  logger.info('- Lower passband edge: %0.2f' % (l_freq,))
  msg += ' (%s cutoff frequency: %0.2f Hz)' % (
  logger.info('- Upper passband edge: %0.2f Hz' % (h_freq,))
  msg += ' (%s cutoff frequency: %0.2f Hz)' % (
  float(min(h_check, l_check)),)
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,
  this_h = firwin(this_N, (prev_freq + this_freq) / 2.,


In [None]:
import mne
import numpy as np
import tensorflow as tf
import os

# Function to rename channels and drop specified channels based on conditions


def process_channels(raw_data):
    """
    Process and standardize EEG channels to keep only the 17 most common channels.
    """
    print(f"Initial channels: {raw_data.ch_names}")

    # Initialize a list to hold channels to drop
    channels_to_drop = []

    # Create mapping for channel renaming
    rename_map = {}
    for name in raw_data.ch_names:
        if any(x in name for x in ['23A-23R', '24A-24R', 'A2-A1']):
            channels_to_drop.append(name)
        else:
            new_name = name.replace('EEG ', '').replace('-LE', '')
            rename_map[name] = new_name

    # Drop unwanted channels
    if channels_to_drop:
        print(f"Dropping channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Rename remaining channels
    raw_data.rename_channels(rename_map)

    print(f"Final channels: {raw_data.ch_names}")

    # Define the 17 most common channels
    expected_channels = [
        'Fp1', 'F3', 'C3', 'P3', 'O1', 'F7', 'T3', 'Fp2',
        'F4', 'C4', 'P4', 'O2', 'F8', 'T4', 'T6', 'Cz', 'Pz'
    ]

    # Keep only the expected channels
    channels_to_keep = set(expected_channels)
    channels_to_drop = [
        ch for ch in raw_data.ch_names if ch not in channels_to_keep]

    if channels_to_drop:
        print(
            f"Dropping channels to keep only the expected 17 channels: {channels_to_drop}")
        raw_data.drop_channels(channels_to_drop)

    # Verify we have the expected number of channels (should be 17)
    if len(raw_data.ch_names) != len(expected_channels):
        print(
            f"Warning: Expected {len(expected_channels)} channels, got {len(raw_data.ch_names)}")
        print(f"Missing: {set(expected_channels) - set(raw_data.ch_names)}")

    return raw_data


def read_data(file_path):
    data = mne.io.read_raw_edf(file_path, preload=True)
    data.set_eeg_reference()
    return data


def bandpass_filter(data, l_freq, h_freq, notch_freq=None):
    filtered_data = data.copy()

    # Apply bandpass first
    filtered_data.filter(l_freq=l_freq, h_freq=h_freq,
                         method='fir', phase='zero')

    # If using notch, apply with wider bandwidth
    if notch_freq is not None:
        filtered_data.notch_filter(freqs=notch_freq, notch_widths=2.0)

    return filtered_data


ica_channels = ['Fp1', 'Fp2']


def preprocess_ICA(epochs, n_components):
    """
    Apply Independent Component Analysis (ICA) to the epochs data.
    
    Parameters:
    -----------
    epochs : mne.Epochs
        The epochs data to process.
    n_components : int
        The number of components to extract.
    
    Returns:
    --------
    ica : ICA
        The fitted ICA object.
    """
    print(
        f"Preprocessing ICA for {len(epochs)} epochs...")  # Print the number of epochs being processed

    ica = ICA(n_components=n_components, random_state=97, max_iter=800)
    # Use the epochs directly
    ica.fit(epochs.copy().pick_channels(ica_channels))
    return ica


def create_epochs(processed_data, duration=5.0, overlap=1.0):
    """
    Create epochs from continuous EEG data and format for CNN input
    
    Parameters:
    -----------
    processed_data : mne.io.Raw
        The raw EEG data
    duration : float
        Duration of each epoch in seconds
    overlap : float
        Overlap between epochs in seconds
    
    Returns:
    --------
    epochs_array : numpy.ndarray
        The epoched data formatted for CNN (samples, channels, timepoints, 1)
    """

    # Create epochs
    epochs = mne.make_fixed_length_epochs(
        processed_data,
        duration=duration,
        overlap=overlap,
        preload=True
    )

    # Drop bad epochs
    epochs.drop_bad()

    # Get data and reshape for CNN
    # Shape will be (n_epochs, n_channels, n_timepoints)
    data = epochs.get_data()

    # Add channel dimension for CNN: (n_epochs, n_channels, n_timepoints, 1)
    data = data[..., np.newaxis]

    return data


def preprocess_eeg(raw_data, l_freq=0.5, h_freq=60.0, notch_freq=50.0, n_components=2, epoch_duration=5.0, epoch_overlap=1.0):
    """
    Complete EEG preprocessing pipeline: filter -> bad channel removal -> epoching -> ICA -> baseline correction
    """
    try:
        # print(f"\nProcessing file: {filename}")

        # Make a copy of the raw data to prevent modifications to original
        processed_raw = raw_data.copy()

        # 1. Bandpass filtering
        print("1. Applying bandpass filter...")
        try:
            bandpass_filter(processed_raw, l_freq, h_freq, notch_freq)
            print("Bandpass filtering completed")
        except Exception as e:
            print(f"Error during bandpass filtering: {str(e)}")
            return None

        # 2. Bad channel removal
        print("2. Removing bad channels...")
        try:
            processed_raw = process_channels(raw_data=processed_raw)
            print("Bad channels removed")
        except Exception as e:
            print(f"Error during bad channel removal: {str(e)}")
            return None

        # 3. Epoching
        print("3. Creating epochs...")
        try:
            epochs = mne.make_fixed_length_epochs(
                processed_raw,
                duration=epoch_duration,
                overlap=epoch_overlap,
                preload=True
            )

            # Drop bad epochs
            epochs.drop_bad()

            # Get data and reshape for CNN
            data = epochs.get_data()
            data = data[..., np.newaxis]  # Add channel dimension for CNN

            print(f"Epoching completed. Final data shape: {data.shape}")
        except Exception as e:
            print(f"Error during epoching: {str(e)}")
            return None

        # 4. ICA
        print("4. Applying ICA...")
        try:
            # Pass the epochs object
            ica = preprocess_ICA(epochs, n_components)
            ica.apply(epochs)  # Apply ICA to the epochs
            print("ICA completed")
        except Exception as e:
            print(f"Error during ICA: {str(e)}")
            return None

        # 5. Baseline correction
        print("5. Applying baseline correction...")
        try:
            # Apply baseline correction over the entire epoch
            epochs.apply_baseline((None, None))
            print("Baseline correction completed")
        except Exception as e:
            print(f"Error during baseline correction: {str(e)}")
            return None

        return data  # Return the processed data

    except Exception as e:
        print(f"General preprocessing error: {str(e)}")
        return None


def read_eeg_file(file_path):
    # Read the raw data
    raw_data = mne.io.read_raw_edf(file_path, preload=True)
    return raw_data


def load_model(model_path):
    try:
        # Print path to debug
        print(f"Looking for model at: {model_path}")
        print(f"Directory exists: {os.path.exists(model_path)}")

        # Load model
        model = tf.saved_model.load(model_path)
        print("Model loaded successfully")
        return model
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None


def process_eeg_data(file_path, model):
    try:
        # Read the EEG file
        raw_data = read_eeg_file(file_path)

        # Preprocess the EEG data using your existing preprocessing function
        processed_data = preprocess_eeg(
            raw_data,
            l_freq=0.5,
            h_freq=60.0,
            notch_freq=50.0,
            n_components=2,
            epoch_duration=5.0,
            epoch_overlap=1.0
        )

        # Ensure processed_data is in the correct shape
        processed_data = np.squeeze(processed_data)

        # Reshape data to match model's expected input
        reshaped_data = processed_data.reshape(
            processed_data.shape[0], processed_data.shape[1], -1)

        # Convert to tensor and add batch dimension
        input_tensor = tf.convert_to_tensor(reshaped_data, dtype=tf.float32)

        # Ensure the input tensor is 3D: (n_epochs, n_channels, n_timepoints)
        if len(input_tensor.shape) == 2:  # If it's 2D, add a channel dimension
            input_tensor = tf.expand_dims(input_tensor, axis=-1)

        # Make prediction
        predict_fn = model.signatures['serving_default']
        predictions = predict_fn(input_tensor)
        output_key = list(predictions.keys())[0]
        preds = predictions[output_key].numpy()

        # Apply a threshold to determine class predictions
        threshold = 0.5
        class_predictions = (preds[0] > threshold).astype(int)

        print(preds)

        return [{
            "class": str(class_prediction),
            "probability": float(prob)
        } for class_prediction, prob in zip(class_predictions, preds[0])]

    except Exception as e:
        print(f"Error processing EEG data: {str(e)}")
        return [{"error": f"Failed to process EEG data: {str(e)}"}]


# Specify the path to your EDF file and model
edf_file_path = './edf_dataset_2/MDD_S6_EC.edf'
model_path = './saved_model/1d_cnn/1st_model/'

# Load the model
model = load_model(model_path)

# Example usage
if model is not None:
    results = process_eeg_data(edf_file_path, model)
    print(results)
else:
    print("Model could not be loaded. Predictions cannot be made.")

In [None]:
# Parameter Tuning
from tensorflow.keras.models import load_model  # Save model in SavedModel format
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
# Define the model
model = Sequential([
    # First Conv Block
    Conv1D(
        filters=16,
        kernel_size=3,
        activation='relu',
        padding='same',
        input_shape=(17, 1280),
        kernel_regularizer=tf.keras.regularizers.l2(0.001)
    ),
    BatchNormalization(),
    MaxPooling1D(pool_size=2),
    Dropout(0.3),

    # Second Conv Block
    Conv1D(
        filters=32,
        kernel_size=3,
        activation='relu',
        padding='same',
        kernel_regularizer=tf.keras.regularizers.l2(0.001)
    ),
    BatchNormalization(),
    MaxPooling1D(pool_size=2),
    Dropout(0.4),

    # Third Conv Block
    Conv1D(
        filters=64,
        kernel_size=3,
        activation='relu',
        padding='same',
        kernel_regularizer=tf.keras.regularizers.l2(0.001)
    ),
    BatchNormalization(),
    Dropout(0.4),

    # Dense Layers
    Flatten(),
    Dense(128, activation='relu',
          kernel_regularizer=tf.keras.regularizers.l2(0.001)),
    BatchNormalization(),
    Dropout(0.5),
    Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.001)),
    Dense(1, activation='sigmoid')
])
# Compile with lower learning rate
optimizer = Adam(learning_rate=0.001)
model.compile(
    optimizer=optimizer,
    loss='binary_crossentropy',
    metrics=['accuracy', tf.keras.metrics.Precision(),
             tf.keras.metrics.Recall()]
)

callbacks = [
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        min_delta=0.0001
    ),
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-6
    ),
    ModelCheckpoint(
        'best_model.keras',  # Use the `.keras` extension
        save_best_only=True,
        monitor='val_loss'
    )
]

# Data preprocessing
X = (X - X.mean()) / (X.std() + 1e-7)
# Train the model
history = model.fit(
    X, y,
    epochs=30,
    batch_size=32,
    validation_split=0.2,
    callbacks=callbacks,
    class_weight={0: 1, 1: 1.2}
)
# Evaluate the model
y_pred = (model.predict(X) > 0.5).astype(int)
acc = accuracy_score(y, y_pred)
f1 = f1_score(y, y_pred)
precision = precision_score(y, y_pred)
recall = recall_score(y, y_pred)

print("\nModel Performance:")
print(f"Accuracy: {acc:.2f}")
print(f"F1-score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")

# Plot training history

plt.figure(figsize=(12, 4))

# Plot accuracy
plt.subplot(1, 2, 1)
plt.plot(history.history['accuracy'], label='Training Accuracy')
plt.plot(history.history['val_accuracy'], label='Validation Accuracy')
plt.title('Model Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()

# Plot loss
plt.subplot(1, 2, 2)
plt.plot(history.history['loss'], label='Training Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.tight_layout()
plt.show()
MODEL_PATH = './saved_model/1d_cnn/2nd_model'
tf.saved_model.save(model, MODEL_PATH)
print(f"Model saved to: {MODEL_PATH}")