In [1]:
import os
import pickle
from collections import Counter

import matplotlib.pyplot as plt
import mne
import numpy as np
from joblib import Parallel, delayed
from pykalman import KalmanFilter
from scipy.signal import welch
from sklearn.decomposition import PCA
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Directory holding the preprocessed per-session epoch files (*.fif).
data_dir = "preprocessed_epochs"

# Held-out test subjects. The training cell pairs the two lists index by index:
# test_patients_ns[i] (normal-sleep session, ses-1) and test_patients_sd[i]
# (sleep-deprived session, ses-2) together form test fold i.
# NOTE(review): "10" and "19" appear in both lists, so each of those subjects is
# tested once per condition, in two different folds.
test_patients_ns = [
    "01", "19", "30", "65", "10", "13",
    "25", "69", "24", "33", "38", "67",
]
test_patients_sd = [
    "52", "18", "29", "17", "34", "55",
    "10", "22", "68", "19", "42", "63",
]


In [None]:
# (patient_id, session) for every held-out session: NS test subjects contribute
# session "1", SD test subjects contribute session "2".
# NOTE(review): not referenced by any later cell in this notebook.
ns_sessions = [(patient, "1") for patient in test_patients_ns]
sd_sessions = [(patient, "2") for patient in test_patients_sd]
test_sessions = ns_sessions + sd_sessions

def process_epoch_em(epoch, n_components, n_iter):
    """
    Reduce one epoch with PCA, fit a linear-Gaussian state-space model to the
    component time courses with EM, and return the Kalman-smoothed states.

    Parameters
    ----------
    epoch : ndarray
        One entry of ``epochs.get_data()``; assumed (n_channels, n_times) - TODO confirm.
        It is transposed so that PCA treats every time sample as an observation.
    n_components : int
        Number of principal components kept; also the state and observation
        dimension of the Kalman filter.
    n_iter : int
        Number of EM iterations.

    Returns
    -------
    ndarray
        Transposed smoothed state means, i.e. (n_components, n_times).
    """
    # Rows = time samples, columns = principal components.
    # TODO: consider a variance fraction (e.g. 0.95) instead of a fixed component count.
    component_series = PCA(n_components=n_components).fit_transform(epoch.T)

    # Identity dynamics and observation model with small noise covariances; the small
    # regularisation was needed for EM to converge.
    # NOTE(review): with pykalman's default em_vars, EM re-estimates the covariances and
    # the initial state only - the transition/observation matrices stay identity.
    state_space = KalmanFilter(
        transition_matrices=np.eye(n_components),
        observation_matrices=np.eye(n_components),
        transition_covariance=np.eye(n_components) * 1e-4,
        observation_covariance=np.eye(n_components) * 1e-4,
    )
    fitted_model = state_space.em(component_series, n_iter=n_iter)

    smoothed_means = fitted_model.smooth(component_series)[0]
    return smoothed_means.T

# Kalman-smooth every epoch of every preprocessed session and cache the result to disk.
N_PCA_COMPONENTS = 25  # state/observation dimension passed to process_epoch_em
N_EM_ITER = 7          # EM iterations per epoch
N_JOBS = 14            # parallel workers for the per-epoch fits
MIN_EPOCHS = 10        # sessions with fewer epochs are skipped: too little data, may skew results

smoothed_data_dict = {}

# os.listdir order is arbitrary; sorting makes the processing order, and therefore
# the key order of the saved dict (which later fixes training-sample order), reproducible.
for file_name in sorted(os.listdir(data_dir)):
    if not file_name.endswith(".fif"):
        continue

    # Expects names like "sub-<id>_ses-<n>_...fif" - TODO confirm naming scheme.
    subject_part, session_part = file_name.split("_")[:2]
    patient_id = subject_part.split("-")[1]
    session = session_part.split("-")[1]

    epochs = mne.read_epochs(os.path.join(data_dir, file_name), preload=True)

    if len(epochs) < MIN_EPOCHS:
        continue

    # One EM fit + smoothing pass per epoch, run in parallel.
    smoothed_data = np.array(
        Parallel(n_jobs=N_JOBS)(
            delayed(process_epoch_em)(epoch, n_components=N_PCA_COMPONENTS, n_iter=N_EM_ITER)
            for epoch in epochs.get_data()
        )
    )

    smoothed_data_dict[f"{patient_id}_ses-{session}"] = {
        "smoothed_data": smoothed_data,
        "label": 0 if session == "1" else 1,  # 0 = NS, 1 = SD
    }
    print(f"Processed {file_name}: {smoothed_data.shape}")

output_file = "KF_smoothed_data.pkl"
with open(output_file, "wb") as f:
    pickle.dump(smoothed_data_dict, f)

print(f"Saved smoothed data to {output_file}.")

In [None]:
# EXTRACTING FEATURES
# For every epoch of every session: theta/alpha/beta band power of each KF-smoothed
# component, the three pairwise band-power ratios, and each component's temporal
# mean and variance.
frequency_bands = {
    "theta": (4, 8),
    "alpha": (8, 13),
    "beta": (13, 30),
}

features_file = "KF_smoothed_data.pkl"

# Only load pickles produced by this notebook: unpickling can execute arbitrary code.
with open(features_file, 'rb') as f:
    smoothed_data_dict = pickle.load(f)

patient_features_dict = {}
# NOTE(review): hard-coded; must equal the sampling rate of the preprocessed epochs - confirm.
sfreq = 500


def extract_epoch_features(smoothed_data, sfreq):
    """Return the (n_epochs, 8 * n_components) feature matrix of one session.

    `smoothed_data` is indexed (epoch, component, time). Column blocks, in order,
    each with one column per component: theta, alpha, beta power; theta/beta,
    alpha/beta, theta/alpha ratios; temporal mean; temporal variance.
    """
    # Welch along the time axis for all epochs at once; 2 s segments -> 0.5 Hz bins.
    freqs, psd = welch(smoothed_data, sfreq, nperseg=sfreq * 2, axis=-1)
    # Band edges are inclusive at both ends, so the 8 Hz and 13 Hz bins are shared by
    # the adjacent bands.
    power = {
        band_name: psd[..., (freqs >= fmin) & (freqs <= fmax)].mean(axis=-1)
        for band_name, (fmin, fmax) in frequency_bands.items()
    }
    theta_power, alpha_power, beta_power = power["theta"], power["alpha"], power["beta"]
    eps = 1e-10  # keeps the ratios finite when a band has zero power
    return np.hstack([
        theta_power,
        alpha_power,
        beta_power,
        theta_power / (beta_power + eps),
        alpha_power / (beta_power + eps),
        theta_power / (alpha_power + eps),
        smoothed_data.mean(axis=-1),
        smoothed_data.var(axis=-1),
    ])


# Keys are kept exactly as in the smoothed-data dict ("<patient_id>_ses-<n>").
for patient_session, data in smoothed_data_dict.items():
    patient_features_dict[patient_session] = {
        "features": extract_epoch_features(data["smoothed_data"], sfreq),
        "label": data["label"],
    }

output_file = "KF_extracted_features.pkl"
with open(output_file, 'wb') as f:
    pickle.dump(patient_features_dict, f)

print(f"Saved extracted features to {output_file}.")

In [None]:
# TRAIN MODEL AND LEAVE 2 PATIENT OUT FOLD
# Fold i holds out test_patients_ns[i]'s NS session (ses-1) and test_patients_sd[i]'s
# SD session (ses-2). For every SVM configuration we report epoch-level accuracy and
# session-level accuracy (majority vote over each held-out session's epochs).
features_file = "KF_extracted_features.pkl"

# True  -> drop *all* sessions of the two held-out subjects from training, so the
#          split is genuinely subject-wise, as the cell title says.
# False -> earlier behaviour: only the two test sessions are held out, so each test
#          subject's other session stays in the training set.
EXCLUDE_TEST_SUBJECTS_FROM_TRAINING = True

# Only load pickles produced by this notebook: unpickling can execute arbitrary code.
with open(features_file, 'rb') as f:
    patient_features_dict = pickle.load(f)

# zip() silently truncates, so the NS and SD test lists must pair up one-to-one.
assert len(test_patients_ns) == len(test_patients_sd), "NS/SD test lists must have equal length"
pairs = list(zip(test_patients_ns, test_patients_sd))

# loop through different svm configs, to see which one is the best
svm_configs = [
    {"kernel": "linear", "C": 1.0, "class_weight": "balanced"},
    {"kernel": "rbf", "C": 1.0, "gamma": "scale", "class_weight": "balanced"},
    {"kernel": "poly", "C": 1.0, "degree": 3, "class_weight": "balanced"},
    {"kernel": "sigmoid", "C": 1.0, "class_weight": "balanced"}
]


def _split_fold(features_by_session, ns_patient, sd_patient, exclude_test_subjects):
    """Split per-session features into training epochs and the fold's held-out sessions.

    Keys of `features_by_session` are "<patient_id>_ses-<n>"; values hold "features"
    (one row per epoch) and "label". Returns (X_train, y_train, held_out), where
    held_out lists the held-out session dicts in dictionary order.
    """
    held_out_keys = {(ns_patient, "ses-1"), (sd_patient, "ses-2")}
    test_subjects = {ns_patient, sd_patient}
    train_features, train_labels, held_out = [], [], []
    for unique_patient_session, data in features_by_session.items():
        patient_id, session = unique_patient_session.split("_")
        if (patient_id, session) in held_out_keys:
            held_out.append(data)
        elif exclude_test_subjects and patient_id in test_subjects:
            continue  # other session of a test subject: keep it out of training
        else:
            train_features.append(data["features"])
            train_labels.extend([data["label"]] * data["features"].shape[0])
    return np.vstack(train_features), np.array(train_labels), held_out


def _majority_vote(y_pred, held_out):
    """Collapse epoch predictions into one prediction per held-out session.

    `y_pred` must be ordered like np.vstack of the sessions' features. On a tie the
    label Counter encountered first wins. Returns (session_labels, session_predictions).
    """
    session_labels, session_predictions = [], []
    start_idx = 0
    for data in held_out:
        num_epochs = data["features"].shape[0]
        session_pred = y_pred[start_idx:start_idx + num_epochs]
        session_predictions.append(Counter(session_pred).most_common(1)[0][0])
        session_labels.append(data["label"])
        start_idx += num_epochs
    return session_labels, session_predictions


results = {}

for config in svm_configs:
    print(f"\nEvaluating SVM with configuration: {config}")

    total_correct_sessions = 0
    total_correct_epochs = 0
    # Counted from the sessions actually evaluated rather than assumed from the list
    # lengths: a test session can be missing (e.g. skipped for having < 10 epochs).
    total_sessions = 0
    total_epochs = 0
    for fold, (ns_patient, sd_patient) in enumerate(pairs):
        X_train, y_train, test_patient_data = _split_fold(
            patient_features_dict, ns_patient, sd_patient, EXCLUDE_TEST_SUBJECTS_FROM_TRAINING
        )
        if not test_patient_data:
            print(f"Fold {fold + 1}: no held-out sessions found for {ns_patient}/{sd_patient}; skipping.")
            continue

        X_test = np.vstack([data["features"] for data in test_patient_data])
        y_test = np.array([
            data["label"]
            for data in test_patient_data
            for _ in range(data["features"].shape[0])
        ])

        # Scaler and PCA are fit on the training fold only, then applied to the test fold.
        scaler = StandardScaler()
        X_train_scaled = scaler.fit_transform(X_train)
        X_test_scaled = scaler.transform(X_test)

        pca = PCA(n_components=0.95)  # Retain 95% variance
        X_train_reduced = pca.fit_transform(X_train_scaled)
        X_test_reduced = pca.transform(X_test_scaled)

        # train svm
        svm = SVC(**config, random_state=42)
        svm.fit(X_train_reduced, y_train)

        # evaluate per epoch
        y_pred = svm.predict(X_test_reduced)
        epoch_accuracy = accuracy_score(y_test, y_pred)
        print(f"Fold {fold + 1} Epoch-Level Accuracy: {epoch_accuracy:.2f}")
        total_correct_epochs += int(np.sum(y_test == y_pred))
        total_epochs += len(y_test)

        # evaluate per session (majority vote over the session's epoch predictions)
        session_labels, session_predictions = _majority_vote(y_pred, test_patient_data)
        session_accuracy = accuracy_score(session_labels, session_predictions)
        print(f"Fold {fold + 1} Session-Level Accuracy: {session_accuracy:.2f}")
        total_correct_sessions += int(np.sum(np.array(session_labels) == np.array(session_predictions)))
        total_sessions += len(session_labels)

        print("Session-Level Classification Report:")
        print(classification_report(session_labels, session_predictions))

    overall_epoch_accuracy = total_correct_epochs / total_epochs if total_epochs else float("nan")
    overall_session_accuracy = total_correct_sessions / total_sessions if total_sessions else float("nan")
    results[repr(config)] = {
        "epoch_accuracy": overall_epoch_accuracy,
        "session_accuracy": overall_session_accuracy,
    }

    print(f"\nOverall Epoch-Level Accuracy: {overall_epoch_accuracy:.2f}")
    print(f"Total Correct Predictions (Epoch-Level): {total_correct_epochs}/{total_epochs}")
    print(f"\nOverall Session-Level Accuracy: {overall_session_accuracy:.2f}")
    print(f"Total Correct Predictions (Session-Level): {total_correct_sessions}/{total_sessions}")

In [None]:
# Group-average PSD (2-40 Hz) of the raw EEG and of the KF-smoothed components.
fmin, fmax = 2, 40  # Hz
# Sampling rate of the KF-smoothed data.
# NOTE(review): hard-coded - must match the preprocessed epochs.
sfreq = 500


def _mean_psd_in_range(signals, fs, fmin, fmax):
    """Welch PSD of each row of `signals`, restricted to [fmin, fmax] Hz and averaged over rows.

    Segments are 2 s long (nperseg = 2 * fs), giving 0.5 Hz bins at any sampling
    rate, so the frequency grid matches across recordings.
    Returns (freqs_in_range, mean_psd_in_range).
    """
    freqs, psd = welch(signals, fs, nperseg=int(round(fs * 2)), axis=1)
    in_range = (freqs >= fmin) & (freqs <= fmax)
    return freqs[in_range], psd[:, in_range].mean(axis=0)


# RAW SIGNAL PSD: one channel-averaged curve per available subject/session recording.
psd_raw = []
for i in range(1, 72):
    patient_id = f"sub-{i:02d}"
    for session in ['1', '2']:
        eeg_path = os.path.join("Sleep_dep_dataset", patient_id, "ses-" + session, "eeg", patient_id + "_ses-" + session + "_task-eyesopen_eeg.set")
        if not os.path.exists(eeg_path):
            continue
        raw = mne.io.read_raw_eeglab(eeg_path, preload=True)
        # Use the file's own sampling rate rather than assuming it equals `sfreq`.
        # NOTE(review): get_data() returns every channel in the file - consider picks="eeg".
        freqs_filtered, mean_psd = _mean_psd_in_range(raw.get_data(), raw.info["sfreq"], fmin, fmax)
        psd_raw.append(mean_psd)

assert psd_raw, "no raw EEG recordings found under Sleep_dep_dataset/"
psd_raw = np.array(psd_raw)
mean_raw = np.mean(psd_raw, axis=0)
sem_raw = np.std(psd_raw, axis=0) / np.sqrt(psd_raw.shape[0])


# SMOOTHED SIGNAL PSD: each session's epochs are concatenated in time before Welch.
# NOTE(review): concatenation introduces discontinuities at epoch boundaries;
# averaging per-epoch PSDs would avoid them.
features_file = "KF_smoothed_data.pkl"
with open(features_file, 'rb') as f:
    smoothed_data_dict = pickle.load(f)

psd_smoothed = []
for data in smoothed_data_dict.values():
    smoothed_data_continuous = np.concatenate(data['smoothed_data'], axis=-1)
    freqs_filtered, mean_psd = _mean_psd_in_range(smoothed_data_continuous, sfreq, fmin, fmax)
    psd_smoothed.append(mean_psd)

psd_smoothed = np.array(psd_smoothed)
mean_smoothed = np.mean(psd_smoothed, axis=0)
sem_smoothed = np.std(psd_smoothed, axis=0) / np.sqrt(psd_smoothed.shape[0])

In [None]:
# Group-average PSD of the KF-smoothed components vs the raw EEG, with ±1 SEM bands.
fig, ax = plt.subplots(figsize=(10, 6))

# mean ± SEM is formed in linear power and only then converted to dB, so the band is
# (correctly) asymmetric on the log scale instead of mirroring the upper half.
for label, color, mean_psd, sem_psd in [
    ('KF Smoothed Signal', 'blue', mean_smoothed, sem_smoothed),
    ('Original Signal', 'orange', mean_raw, sem_raw),
]:
    mean_db = 10 * np.log10(mean_psd)
    upper_db = 10 * np.log10(mean_psd + sem_psd)
    lower_linear = mean_psd - sem_psd
    # If mean - SEM is not positive its log is undefined; leave a gap (NaN) there.
    lower_db = 10 * np.log10(np.where(lower_linear > 0, lower_linear, np.nan))
    ax.plot(freqs_filtered, mean_db, label=label, color=color)
    ax.fill_between(freqs_filtered, lower_db, upper_db, alpha=0.2, color=color)

ax.set_xlabel('Frequency (Hz)')
ax.set_ylabel('PSD (dB)')
ax.set_title('KF Smoothed Signal vs Original Signal (PSD in dB)')
ax.legend()
ax.grid(True)
plt.show()


In [None]:
# Group-average PSD split by condition: NS (ses-1 / label 0) vs SD (ses-2 / label 1).
# Uses fmin, fmax (2-40 Hz) and sfreq (KF-smoothed sampling rate) from the previous PSD cell.

# RAW SIGNAL PSD per condition.
psd_raw_ns = []
psd_raw_sd = []
for i in range(1, 72):
    patient_id = f"sub-{i:02d}"
    for session in ['1', '2']:
        eeg_path = os.path.join("Sleep_dep_dataset", patient_id, "ses-" + session, "eeg", patient_id + "_ses-" + session + "_task-eyesopen_eeg.set")
        if not os.path.exists(eeg_path):
            continue
        raw = mne.io.read_raw_eeglab(eeg_path, preload=True)
        # The file's own sampling rate; 2 s segments keep 0.5 Hz bins at any rate,
        # so the frequency grid matches the smoothed-signal PSD below.
        raw_sfreq = raw.info["sfreq"]
        freqs, psd = welch(raw.get_data(), raw_sfreq, nperseg=int(round(raw_sfreq * 2)), axis=1)
        freq_mask = (freqs >= fmin) & (freqs <= fmax)
        freqs_filtered = freqs[freq_mask]
        mean_psd = psd[:, freq_mask].mean(axis=0)  # average over channels
        if session == '1':  # normal sleep
            psd_raw_ns.append(mean_psd)
        else:  # sleep deprivation
            psd_raw_sd.append(mean_psd)

psd_raw_ns = np.array(psd_raw_ns)
psd_raw_sd = np.array(psd_raw_sd)

mean_raw_ns = np.mean(psd_raw_ns, axis=0)
sem_raw_ns = np.std(psd_raw_ns, axis=0) / np.sqrt(psd_raw_ns.shape[0])

mean_raw_sd = np.mean(psd_raw_sd, axis=0)
sem_raw_sd = np.std(psd_raw_sd, axis=0) / np.sqrt(psd_raw_sd.shape[0])

# SMOOTHED SIGNAL PSD per condition (epochs concatenated in time per session).
features_file = "KF_smoothed_data.pkl"
with open(features_file, 'rb') as f:
    smoothed_data_dict = pickle.load(f)

psd_smoothed_ns = []
psd_smoothed_sd = []
for data in smoothed_data_dict.values():
    smoothed_data_continuous = np.concatenate(data['smoothed_data'], axis=-1)
    freqs, psd = welch(smoothed_data_continuous, sfreq, nperseg=sfreq * 2, axis=1)
    freq_mask = (freqs >= fmin) & (freqs <= fmax)
    freqs_filtered = freqs[freq_mask]
    mean_psd = psd[:, freq_mask].mean(axis=0)  # average over components
    if data['label'] == 0:
        psd_smoothed_ns.append(mean_psd)
    else:
        psd_smoothed_sd.append(mean_psd)

psd_smoothed_ns = np.array(psd_smoothed_ns)
psd_smoothed_sd = np.array(psd_smoothed_sd)

mean_smoothed_ns = np.mean(psd_smoothed_ns, axis=0)
sem_smoothed_ns = np.std(psd_smoothed_ns, axis=0) / np.sqrt(psd_smoothed_ns.shape[0])

mean_smoothed_sd = np.mean(psd_smoothed_sd, axis=0)
sem_smoothed_sd = np.std(psd_smoothed_sd, axis=0) / np.sqrt(psd_smoothed_sd.shape[0])

In [None]:
# NS vs SD group-average PSDs: KF-smoothed components (left) and raw EEG (right), ±1 SEM.
fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

panels = [
    (axes[0], 'Smoothed Signal: NS vs SD', [
        ('NS (Smoothed)', 'blue', mean_smoothed_ns, sem_smoothed_ns),
        ('SD (Smoothed)', 'orange', mean_smoothed_sd, sem_smoothed_sd),
    ]),
    (axes[1], 'Raw Signal: NS vs SD', [
        ('NS (Raw)', 'blue', mean_raw_ns, sem_raw_ns),
        ('SD (Raw)', 'orange', mean_raw_sd, sem_raw_sd),
    ]),
]

for ax, title, curves in panels:
    for label, color, mean_psd, sem_psd in curves:
        # mean ± SEM is formed in linear power, then converted to dB, so the band is
        # asymmetric on the log scale rather than a mirrored copy of its upper half.
        mean_db = 10 * np.log10(mean_psd)
        upper_db = 10 * np.log10(mean_psd + sem_psd)
        lower_linear = mean_psd - sem_psd
        # If mean - SEM is not positive its log is undefined; leave a gap (NaN) there.
        lower_db = 10 * np.log10(np.where(lower_linear > 0, lower_linear, np.nan))
        ax.plot(freqs_filtered, mean_db, label=label, color=color)
        ax.fill_between(freqs_filtered, lower_db, upper_db, alpha=0.2, color=color)
    ax.set_title(title)
    ax.set_xlabel('Frequency (Hz)')
    ax.legend()
    ax.grid(True)

axes[0].set_ylabel('PSD (dB)')  # y axis is shared, so only the left panel is labelled

plt.tight_layout()
plt.show()
