In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import mne
import tensorflow as tf
import gc
import pickle
from scipy.signal import welch
from tensorflow.keras.models import Sequential
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.layers import GRU, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.layers import Masking
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
from collections import Counter
import optuna
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns

# Directory holding the per-session MNE epoch files (*.fif) read by the loading cell.
data_dir = "preprocessed_epochs"

# Frequency bands as (low, high) edges, presumably in Hz, keyed by band name.
# NOTE(review): not referenced by any cell visible here — confirm it is still needed.
frequency_bands = dict(
    theta=(4, 8),
    alpha=(8, 13),
    beta=(13, 30),
)

# Optional channel-to-region grouping, currently disabled:
# region_channels = {
#     "frontal": ["Fp1", "Fp2", "Fz", "F3", "F4"],
#     "central": ["Cz", "C3", "C4"],
#     "parietal": ["Pz", "P3", "P4"],
#     "occipital": ["O1", "O2"],
# }

In [None]:
X_data = []
y_data = []
max_epochs = 75  # every session ends up with exactly this many epochs (truncated or zero-padded)


def _pad_or_truncate(session_array, target_len):
    """Return `session_array` cut or zero-padded along axis 0 to exactly `target_len` rows.

    Padding rows are filled with 0; all remaining axes are left untouched.
    """
    n_rows = len(session_array)
    if n_rows > target_len:
        return session_array[:target_len]
    pad_width = [(0, target_len - n_rows)] + [(0, 0)] * (session_array.ndim - 1)
    return np.pad(session_array, pad_width, mode="constant", constant_values=0)


# Load every .fif session, skip short ones, and standardise each session on its own.
for fif_name in os.listdir(data_dir):
    if not fif_name.endswith(".fif"):
        continue

    session_epochs = mne.read_epochs(os.path.join(data_dir, fif_name), preload=True)
    if len(session_epochs) < 10:
        print(f"Skipping patient {fif_name}: less than 10 epochs.")
        continue

    raw_signal = session_epochs.get_data()
    n_epochs, n_channels, n_times = raw_signal.shape
    # Flatten each epoch to one row so StandardScaler standardises every (channel, time) column
    # across this session's epochs, then restore the 3-D layout.
    standardised = StandardScaler().fit_transform(raw_signal.reshape(n_epochs, -1))
    X_data.append(standardised.reshape(n_epochs, n_channels, n_times))
    # Label 1 for files from session 2, 0 otherwise.
    y_data.append(1 if "ses-2" in fif_name else 0)

# Stack into one (sessions, max_epochs, channels, times) array and cache it with the labels.
X_padded = np.array([_pad_or_truncate(session, max_epochs) for session in X_data])
y_data = np.array(y_data)
save_path = "RNN_padded_data"
with open(save_path, "wb") as file:
    pickle.dump((X_padded, y_data), file)


In [None]:
save_path = 'RNN_padded_data'
# NOTE(review): pickle.load can execute arbitrary code — only load the file written by the loading cell.
with open(save_path, "rb") as file:
    X_padded, y_data = pickle.load(file)

# Held-out sessions as (patient_id, session_id). Session "2" is labelled 1 (Sleep Deprivation),
# session "1" is labelled 0 (Normal Sleep).
# NOTE(review): patients 10, 19 and 34 appear in both the test and the validation lists (different
# sessions), and the *other* session of every held-out patient falls through to training. That is
# subject-level leakage — confirm it is intended before reporting test metrics as "unseen subjects".
test_patients_sd = ["52", "18", "29", "17", "34"]
test_patients_ns = ["01", "19", "30", "65", "10"]
validate_patients_sd = ["55", "10", "22", "68", "19", "42", "63", "14"]
validate_patients_ns = ["13", "25", "69", "24", "33", "38", "67", "34"]

test_sessions = [(patient, "2") for patient in test_patients_sd] + [(patient, "1") for patient in test_patients_ns]
validate_sessions = [(patient, "2") for patient in validate_patients_sd] + [(patient, "1") for patient in validate_patients_ns]

# Must match the "fewer than 10 epochs" skip rule of the loading cell.
MIN_EPOCHS_PER_SESSION = 10
# X_padded is (sessions, max_epochs, channels, times) as built by the loading cell.
n_channels, n_times = X_padded.shape[2], X_padded.shape[3]


def _parse_session_key(file_name):
    """Return (patient_id, session_id) parsed from a file name.

    Takes the text after the first '-' in each of the first two '_'-separated fields,
    e.g. 'sub-01_ses-2_...' -> ('01', '2'). The naming scheme is assumed from this
    parsing rule — verify against the actual files.
    """
    subject_field, session_field = file_name.split("_")[:2]
    return subject_field.split("-")[1], session_field.split("-")[1]


# Rebuild the list of files that actually produced a row of X_padded. The loading cell skipped
# short sessions, so a plain os.listdir() of data_dir is NOT aligned with X_padded.
# os.listdir is iterated unsorted here, exactly as in the loading cell, so the order matches as
# long as the directory has not changed in between (checked by the length test below).
kept_file_names = []
for file_name in os.listdir(data_dir):
    if not file_name.endswith(".fif"):
        continue
    # preload=False avoids loading the signal; only the event list is needed for the count.
    header = mne.read_epochs(os.path.join(data_dir, file_name), preload=False, verbose=False)
    if len(header.events) >= MIN_EPOCHS_PER_SESSION:
        kept_file_names.append(file_name)

if len(kept_file_names) != len(X_padded):
    raise RuntimeError(
        f"{len(kept_file_names)} usable .fif files but {len(X_padded)} padded sessions; "
        "re-run the loading cell so the pickle matches data_dir."
    )

split_parts = {"train": [], "val": [], "test": []}
split_labels = {"train": [], "val": [], "test": []}
test_patient_data = []
validate_patient_data = []
train_patient_data = []
split_sessions = {"train": train_patient_data, "val": validate_patient_data, "test": test_patient_data}

# One iteration per SESSION: each row of X_padded is paired with its own file name and label.
for file_name, session_epochs, stored_label in zip(kept_file_names, X_padded, y_data):
    session_key = _parse_session_key(file_name)
    label = 1 if session_key[1] == '2' else 0
    if label != stored_label:
        raise RuntimeError(f"Label mismatch for {file_name}: X_padded rows are not aligned with file names.")

    # (epochs, channels, times) -> (epochs, times, channels): a real transpose, so each sample is
    # one epoch that the GRU walks through in time, seeing one channel vector per step.
    # NOTE(review): sequence length is n_times; with use_cudnn=False long epochs train slowly.
    session_epochs = session_epochs.transpose(0, 2, 1)
    # Drop the all-zero epochs added as padding by the loading cell. This assumes no genuine
    # standardised epoch is exactly all zeros.
    keep = np.any(session_epochs != 0, axis=(1, 2))
    session_epochs = session_epochs[keep].astype(np.float32)

    if session_key in test_sessions:
        split = "test"
    elif session_key in validate_sessions:
        split = "val"
    else:
        split = "train"

    split_parts[split].append(session_epochs)
    split_labels[split].append(np.full(len(session_epochs), label))
    split_sessions[split].append({
        "features": session_epochs,  # (n_real_epochs, n_times, n_channels)
        "label": label,
    })


def _stack(parts, labels):
    """Concatenate per-session arrays; return correctly shaped empty arrays if there are none."""
    if not parts:
        return np.empty((0, n_times, n_channels), dtype=np.float32), np.empty((0,), dtype=int)
    return np.concatenate(parts), np.concatenate(labels)


X_train, y_train = _stack(split_parts["train"], split_labels["train"])
X_val, y_val = _stack(split_parts["val"], split_labels["val"])
X_test, y_test = _stack(split_parts["test"], split_labels["test"])

print(f"X_train shape: {X_train.shape}, y_train shape: {y_train.shape}")
print(f"X_val shape: {X_val.shape}, y_val shape: {y_val.shape}")
print(f"X_test shape: {X_test.shape}, y_test shape: {y_test.shape}")

# Balanced class weights keyed by the actual class LABEL. The previous dict(enumerate(...))
# keyed them by position in np.unique's output, which maps weights to the wrong class whenever
# y_train does not contain both 0 and 1.
class_labels = np.unique(y_train)
class_weights = compute_class_weight('balanced', classes=class_labels, y=y_train)
class_weight_dict = {int(label): float(weight) for label, weight in zip(class_labels, class_weights)}

# Epoch budget of the hyperparameter search.
# NOTE: this rebinds `epochs`, which earlier in the notebook held the last mne.Epochs object read.
epochs = 40
# (timesteps, features) of one sample, taken from the data so the model follows its layout.
input_shape = (X_train.shape[1], X_train.shape[2])

# Shared callback instances. Keras resets their counters in on_train_begin, so reusing them
# across the successive fit() calls below is safe.
early_stopping = EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

lr_scheduler = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.5,
    patience=3,
    verbose=0,
    min_lr=1e-6
)
def build_gru_model(input_shape, gru_units, dropouts):
    """Build the (uncompiled) stacked-GRU binary classifier.

    Args:
        input_shape: (timesteps, features) of one sample; zero timesteps are masked.
        gru_units: width of the first GRU layer; the next two use gru_units // 2 and // 4.
        dropouts: dropout rate passed to every GRU layer.

    Returns:
        A Sequential model: three GRUs (only the last collapses the sequence), a
        128-64-32 ReLU head with 0.2 dropout after each layer, and one sigmoid output unit.
    """
    layers = [Masking(mask_value=0.0, input_shape=input_shape)]

    gru_widths = (gru_units, gru_units // 2, gru_units // 4)
    last_gru = len(gru_widths) - 1
    for position, width in enumerate(gru_widths):
        layers.append(
            GRU(units=width, return_sequences=position != last_gru, dropout=dropouts, use_cudnn=False)
        )

    for width in (128, 64, 32):
        layers += [Dense(units=width, activation='relu'), Dropout(0.2)]

    layers.append(Dense(units=1, activation='sigmoid'))
    return Sequential(layers)

def objective(trial):
    """Optuna objective: train one GRU configuration and return its validation accuracy.

    Samples GRU width, learning rate, GRU dropout and batch size from `trial`, trains on
    (X_train, y_train) with the shared early-stopping / LR callbacks monitoring
    (X_val, y_val), then returns the accuracy from model.evaluate on the validation set
    (the study maximizes it).
    """
    # Keras keeps global state per built model; without clearing it, memory grows with every trial.
    tf.keras.backend.clear_session()

    gru_units = trial.suggest_int('gru_units', 64, 512, step=64)
    learning_rate = trial.suggest_float('learning_rate', 1e-5, 1e-3, log=True)
    dropouts = trial.suggest_float('dropouts', 0.1, 0.5)
    batch_size = trial.suggest_categorical('batch_size', [16, 32, 64])

    model = build_gru_model(input_shape, gru_units, dropouts)
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss='binary_crossentropy',
        metrics=['accuracy']
    )

    model.fit(
        X_train, y_train,
        validation_data=(X_val, y_val),
        epochs=40,
        batch_size=batch_size,
        class_weight=class_weight_dict,
        callbacks=[early_stopping, lr_scheduler],
        verbose=0
    )

    _, val_accuracy = model.evaluate(X_val, y_val, verbose=0)

    # Release this trial's model before the next one is built.
    del model
    gc.collect()
    return val_accuracy

# Seed Python/NumPy/TensorFlow and the TPE sampler so a rerun proposes the same trials
# (training itself may still vary slightly where the backend is non-deterministic, e.g. on GPU).
SEED = 42
tf.keras.utils.set_random_seed(SEED)

study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=SEED))
study.optimize(objective, n_trials=20)
study.trials_dataframe().to_csv("optuna_study_results.csv")
print("Best Hyperparameters:")
print(study.best_params)

best_params = study.best_params
gru_units = best_params['gru_units']
learning_rate = best_params['learning_rate']
dropouts = best_params['dropouts']
batch_size = best_params['batch_size']

# Build the final model from a clean Keras state rather than on top of the last trial's.
tf.keras.backend.clear_session()
model = build_gru_model(input_shape, gru_units, dropouts)
model.compile(
    optimizer=Adam(learning_rate=learning_rate),
    loss='binary_crossentropy',
    metrics=['accuracy']
)

# NOTE(review): the search trained for up to 40 epochs, the final fit for up to 30 — confirm intended.
history = model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=30,
    batch_size=batch_size,
    class_weight=class_weight_dict,
    callbacks=[early_stopping, lr_scheduler],
    verbose=1
)

test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")



In [None]:
# Validation metrics of the final model, followed by its training curves.
loss, accuracy = model.evaluate(X_val, y_val, verbose=0)
print(f"Validation Loss: {loss:.4f}")
print(f"Validation Accuracy: {accuracy:.4f}")

# One panel per metric: (train key, val key, train label, val label, y-axis label, title).
curve_specs = (
    ('loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Loss', 'Loss Over Epochs'),
    ('accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Accuracy', 'Accuracy Over Epochs'),
)
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for ax, (train_key, val_key, train_label, val_label, y_label, title) in zip(axes, curve_specs):
    ax.plot(history.history[train_key], label=train_label)
    ax.plot(history.history[val_key], label=val_label)
    ax.legend()
    ax.set_xlabel('Epochs')
    ax.set_ylabel(y_label)
    ax.set_title(title)

plt.show()

In [None]:
# Index = label: 0 -> session 1 (normal sleep), 1 -> session 2 (sleep deprivation).
CLASS_NAMES = ["Normal Sleep", "Sleep Deprivation"]


def _plot_confusion_matrix(matrix, title):
    """Draw a confusion-matrix heatmap with CLASS_NAMES on both axes."""
    plt.figure(figsize=(6, 5))
    sns.heatmap(matrix, annot=True, fmt="d", cmap="Blues",
                xticklabels=CLASS_NAMES,
                yticklabels=CLASS_NAMES)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title(title)
    plt.show()


def _report(y_true, y_pred, level):
    """Print the classification report and confusion matrix for `level`, plot it, and return the matrix."""
    print(f"\n{level}-Level Classification Report:")
    # labels=[0, 1] keeps report and matrix two-class (matching CLASS_NAMES) even when one
    # class is missing from y_true/y_pred; without it target_names raises or mislabels.
    print(classification_report(y_true, y_pred, labels=[0, 1], target_names=CLASS_NAMES))
    matrix = confusion_matrix(y_true, y_pred, labels=[0, 1])
    print(f"\n{level}-Level Confusion Matrix:")
    print(matrix)
    _plot_confusion_matrix(matrix, f"{level}-Level Confusion Matrix")
    return matrix


test_loss, test_accuracy = model.evaluate(X_test, y_test, verbose=0)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy (Epoch-Level): {test_accuracy:.4f}")

y_test_pred = model.predict(X_test)
y_test_prob = y_test_pred.ravel()
y_test_pred_binary = (y_test_pred > 0.5).astype(int).flatten()

conf_matrix_epoch = _report(y_test, y_test_pred_binary, "Epoch")

# Session-level prediction = majority vote over that session's epoch predictions. X_test was
# built by concatenating test_patient_data in order, so consecutive slices map back to sessions.
session_predictions = []
session_labels = []
start_idx = 0

for data in test_patient_data:
    num_epochs = data["features"].shape[0]
    if num_epochs == 0:
        print("Skipping a test session with no epochs.")
        continue
    stop_idx = start_idx + num_epochs

    sd_votes = int(y_test_pred_binary[start_idx:stop_idx].sum())
    if 2 * sd_votes != num_epochs:
        majority_label = int(2 * sd_votes > num_epochs)
    else:
        # Exact tie: decide by the mean predicted probability instead of whichever class
        # happened to be predicted first.
        majority_label = int(y_test_prob[start_idx:stop_idx].mean() > 0.5)

    session_predictions.append(majority_label)
    session_labels.append(data["label"])
    start_idx = stop_idx

assert start_idx == len(y_test_pred_binary), "test_patient_data is not aligned with X_test"

session_accuracy = accuracy_score(session_labels, session_predictions)
print(f"\nSession-Level Accuracy: {session_accuracy:.4f}")

conf_matrix_session = _report(session_labels, session_predictions, "Session")


In [None]:
# Shape of the cached padded array: (sessions, max_epochs, channels, times).
padded_shape = X_padded.shape
print(padded_shape)