## Random Forest

In [10]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal
import mne
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
import joblib
import torch, torch.nn as nn, torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ------------------- Paths -------------------
ROOT = os.getcwd()
EDF_DIR = '../data/synthetic_edf'
META_PATH = '../data/Metadata.csv'
OUT_DS = '../results/dataset_epoch_features.pkl'
MODEL_PATH = '../results/baseline_models_256.pkl'
DEEP_MODEL_PATH = '../results/deep_baseline_256.pkl'

os.makedirs('../results', exist_ok=True)

# ------------------- EDF & Feature Extraction Utilities -------------------
def load_raw_edf(path, preload=True, verbose=False):
    return mne.io.read_raw_edf(path, preload=preload, verbose=verbose)

def ensure_resample(raw, target_sfreq):
    if int(raw.info['sfreq']) != int(target_sfreq):
        raw = raw.copy().resample(int(target_sfreq))
    return raw

def pick_available_channels(raw):
    picks = {'eeg': mne.pick_types(raw.info, eeg=True),
             'eog': mne.pick_types(raw.info, eog=True),
             'emg': mne.pick_types(raw.info, emg=True),
             'ecg': mne.pick_types(raw.info, ecg=True)}
    chans = {}
    for k, v in picks.items():
        chans[k] = [raw.ch_names[i] for i in v] if len(v) > 0 else []
    return chans

def epoch_raw_nonoverlap(raw, epoch_s=30, picks='eeg'):
    if isinstance(picks, str) and picks == 'eeg':
        idx = mne.pick_types(raw.info, eeg=True)
    elif isinstance(picks, list):
        idx = [raw.ch_names.index(ch) for ch in picks if ch in raw.ch_names]
    else:
        idx = mne.pick_types(raw.info, eeg=True)
    data = raw.get_data(picks=idx)
    sf = int(raw.info['sfreq'])
    n_samples = sf * epoch_s
    n_epochs = data.shape[1] // n_samples
    epochs = []
    for e in range(n_epochs):
        start = e * n_samples; stop = start + n_samples
        epochs.append(data[:, start:stop])
    return np.array(epochs)

def bandpower_vector(epoch, sfreq):
    bands = {'delta': (0.5, 4), 'theta': (4, 8), 'alpha': (8, 12), 'beta': (12, 30), 'gamma': (30, 45)}
    feats = []
    for ch in range(epoch.shape[0]):
        f, Pxx = signal.welch(epoch[ch], fs=sfreq, nperseg=256)
        total = np.trapz(Pxx, f) + 1e-12
        for lo, hi in bands.values():
            idx = np.logical_and(f >= lo, f <= hi)
            feats.append(np.trapz(Pxx[idx], f[idx]) / total)
    return np.array(feats)

def hjorth_params(signal_epoch):
    diff1 = np.diff(signal_epoch)
    diff2 = np.diff(diff1)
    var0 = np.var(signal_epoch) + 1e-12
    var1 = np.var(diff1) + 1e-12
    var2 = np.var(diff2) + 1e-12
    mobility = np.sqrt(var1 / var0)
    complexity = np.sqrt(var2 / var1) / mobility if mobility > 0 else 0.0
    return mobility, complexity

def spectral_entropy(epoch_ch, sfreq):
    f, Pxx = signal.welch(epoch_ch, fs=sfreq, nperseg=256)
    P = Pxx / (Pxx.sum() + 1e-12)
    return -np.sum(P * np.log2(P + 1e-12))

def extract_epoch_features(epoch, sfreq):
    feats = bandpower_vector(epoch, sfreq).tolist()
    for ch in range(epoch.shape[0]):
        m, c = hjorth_params(epoch[ch])
        e = spectral_entropy(epoch[ch], sfreq)
        feats.extend([m, c, e])
    return np.array(feats)

# ------------------- Dataset Creation -------------------
if not os.path.exists(META_PATH):
    print('Metadata not found!')
    sys.exit()

meta = pd.read_csv(META_PATH)
all_X, all_y = [], []

for i, row in tqdm(meta.iterrows(), total=len(meta)):
    fname = row['Filename']
    edf_path = os.path.join(EDF_DIR, fname)
    if not os.path.exists(edf_path):
        print('Missing', edf_path)
        continue
    try:
        raw = load_raw_edf(edf_path, preload=True)
        raw = ensure_resample(raw, 256)
        chans = pick_available_channels(raw)
        picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
        if len(picks) == 0: continue
        epochs = epoch_raw_nonoverlap(raw, 30, picks)
        for ep in epochs:
            feat = extract_epoch_features(ep, 256)
            all_X.append(feat)
            all_y.append(row['Last sleep stage'])
    except Exception as e:
        print('Error processing', fname, e)

X = np.vstack(all_X)
y = np.array(all_y)
print('Dataset shape:', X.shape)

joblib.dump({'X': X, 'y': y}, OUT_DS)
print('Saved dataset to', OUT_DS)

# ------------------- Train Full Model -------------------
le = LabelEncoder()
y_enc = le.fit_transform(y)
scaler = StandardScaler()
Xs = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(Xs, y_enc, test_size=0.2, stratify=y_enc, random_state=42)

rf = RandomForestClassifier(n_estimators=300, random_state=42)
rf.fit(X_train, y_train)
print('RandomForest report:\n', classification_report(y_test, rf.predict(X_test), target_names=le.classes_))

joblib.dump({'rf': rf, 'scaler': scaler, 'le': le}, MODEL_PATH)
print('Saved model to', MODEL_PATH)

# ------------------- Prediction -------------------
def predict_from_edf_full(edf_path):
    saved = joblib.load(MODEL_PATH)
    scaler, le, rf = saved['scaler'], saved['le'], saved['rf']
    raw = load_raw_edf(edf_path, preload=True)
    chans = pick_available_channels(raw)
    picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
    epochs = epoch_raw_nonoverlap(raw, 30, picks)
    for i, ep in enumerate(epochs):
        feat = extract_epoch_features(ep, 256).reshape(1, -1)
        feat_scaled = scaler.transform(feat)
        pred = rf.predict(feat_scaled)
        print(f"Epoch {i+1} [Full 256] prediction:", le.inverse_transform(pred)[0])

  0%|          | 0/500 [00:00<?, ?it/s]

100%|██████████| 500/500 [00:24<00:00, 20.33it/s]


Dataset shape: (1000, 256)
Saved dataset to ../results/dataset_epoch_features.pkl
RandomForest report:
               precision    recall  f1-score   support

          N1       0.69      0.79      0.73        42
          N2       0.76      0.76      0.76        42
          N3       0.83      0.71      0.77        35
         REM       0.69      0.79      0.73        42
        Wake       0.88      0.72      0.79        39

    accuracy                           0.76       200
   macro avg       0.77      0.75      0.76       200
weighted avg       0.77      0.76      0.76       200

Saved model to ../results/baseline_models_256.pkl


In [11]:
import os, sys
import numpy as np
import pandas as pd
from scipy import signal
import mne
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
import joblib

# ------------------- Paths -------------------
ROOT = os.getcwd()
EDF_DIR = '../data/synthetic_edf'
META_PATH = '../data/Metadata.csv'
OUT_DS = '../results/dataset_epoch_features_8feat.pkl'
MODEL_PATH_8FEAT = '../results/baseline_models_8feat.pkl'

os.makedirs('../results', exist_ok=True)

# ------------------- EDF & Feature Extraction -------------------
def load_raw_edf(path, preload=True, verbose=False):
    return mne.io.read_raw_edf(path, preload=preload, verbose=verbose)

def pick_available_channels(raw):
    picks = {'eeg': mne.pick_types(raw.info, eeg=True),
             'eog': mne.pick_types(raw.info, eog=True),
             'emg': mne.pick_types(raw.info, emg=True),
             'ecg': mne.pick_types(raw.info, ecg=True)}
    chans = {}
    for k, v in picks.items():
        chans[k] = [raw.ch_names[i] for i in v] if len(v) > 0 else []
    return chans

def epoch_raw_nonoverlap(raw, epoch_s=30, picks='eeg'):
    if isinstance(picks, str) and picks == 'eeg':
        idx = mne.pick_types(raw.info, eeg=True)
    elif isinstance(picks, list):
        idx = [raw.ch_names.index(ch) for ch in picks if ch in raw.ch_names]
    else:
        idx = mne.pick_types(raw.info, eeg=True)
    data = raw.get_data(picks=idx)
    sf = int(raw.info['sfreq'])
    n_samples = sf * epoch_s
    n_epochs = data.shape[1] // n_samples
    epochs = []
    for e in range(n_epochs):
        start = e * n_samples; stop = start + n_samples
        epochs.append(data[:, start:stop])
    return np.array(epochs)

def bandpower_vector_8feat(epoch, sfreq):
    """Compute only 8 key features (e.g., from EEG delta/theta/alpha/beta bands)."""
    bands = [(0.5, 4), (4, 8), (8, 12), (12, 30)]
    feats = []
    for ch in range(min(2, epoch.shape[0])):  # only first 2 channels
        f, Pxx = signal.welch(epoch[ch], fs=sfreq, nperseg=256)
        total = np.trapz(Pxx, f) + 1e-12
        for lo, hi in bands:
            idx = np.logical_and(f >= lo, f <= hi)
            feats.append(np.trapz(Pxx[idx], f[idx]) / total)
    return np.array(feats)

# ------------------- Dataset Creation -------------------
if not os.path.exists(META_PATH):
    print('Metadata not found!')
    sys.exit()

meta = pd.read_csv(META_PATH)
all_X, all_y = [], []

for i, row in tqdm(meta.iterrows(), total=len(meta)):
    fname = row['Filename']
    edf_path = os.path.join(EDF_DIR, fname)
    if not os.path.exists(edf_path):
        print('Missing', edf_path)
        continue
    try:
        raw = load_raw_edf(edf_path, preload=True)
        chans = pick_available_channels(raw)
        picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
        if len(picks) == 0: continue
        epochs = epoch_raw_nonoverlap(raw, 30, picks)
        for ep in epochs:
            feat = bandpower_vector_8feat(ep, raw.info['sfreq'])
            all_X.append(feat)
            all_y.append(row['Last sleep stage'])
    except Exception as e:
        print('Error processing', fname, e)

X = np.vstack(all_X)
y = np.array(all_y)
print('8-feature dataset shape:', X.shape)

joblib.dump({'X': X, 'y': y}, OUT_DS)
print('Saved dataset to', OUT_DS)

# ------------------- Train 8-Feature Model -------------------
le = LabelEncoder()
y_enc = le.fit_transform(y)
scaler = StandardScaler()
Xs = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(Xs, y_enc, test_size=0.2, stratify=y_enc, random_state=42)

rf = RandomForestClassifier(n_estimators=150, random_state=42)
rf.fit(X_train, y_train)
print('8-feature RandomForest report:\n', classification_report(y_test, rf.predict(X_test), target_names=le.classes_))

joblib.dump({'rf': rf, 'scaler': scaler, 'le': le}, MODEL_PATH_8FEAT)
print('Saved 8-feature model to', MODEL_PATH_8FEAT)

# ------------------- Prediction -------------------
def predict_from_edf_8feat(edf_path):
    saved = joblib.load(MODEL_PATH_8FEAT)
    scaler, le, rf = saved['scaler'], saved['le'], saved['rf']
    raw = load_raw_edf(edf_path, preload=True)
    chans = pick_available_channels(raw)
    picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
    epochs = epoch_raw_nonoverlap(raw, 30, picks)
    for i, ep in enumerate(epochs):
        feat = bandpower_vector_8feat(ep, raw.info['sfreq']).reshape(1, -1)
        feat_scaled = scaler.transform(feat)
        pred = rf.predict(feat_scaled)
        print(f"Epoch {i+1} [8-Feature] prediction:", le.inverse_transform(pred)[0])



100%|██████████| 500/500 [00:06<00:00, 75.15it/s]


8-feature dataset shape: (1000, 8)
Saved dataset to ../results/dataset_epoch_features_8feat.pkl
8-feature RandomForest report:
               precision    recall  f1-score   support

          N1       0.24      0.29      0.26        42
          N2       0.25      0.29      0.27        42
          N3       0.28      0.23      0.25        35
         REM       0.31      0.31      0.31        42
        Wake       0.19      0.15      0.17        39

    accuracy                           0.26       200
   macro avg       0.25      0.25      0.25       200
weighted avg       0.25      0.26      0.25       200

Saved 8-feature model to ../results/baseline_models_8feat.pkl


## CNN

In [12]:
# ------------------- CNN Model -------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ----- Prepare dataset -----
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Reshape for CNN: (samples, channels=1, features)
X_train_cnn = X_train.reshape(X_train.shape[0], 1, X_train.shape[1])
X_test_cnn = X_test.reshape(X_test.shape[0], 1, X_test.shape[1])

train_ds = EEGDataset(X_train_cnn, y_train)
test_ds = EEGDataset(X_test_cnn, y_test)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# ----- Define CNN -----
class EEGCNN(nn.Module):
    def __init__(self, num_classes):
        super(EEGCNN, self).__init__()
        self.conv1 = nn.Conv1d(1, 32, kernel_size=5)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(64)
        self.pool = nn.AdaptiveAvgPool1d(16)
        self.fc1 = nn.Linear(64 * 16, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        
    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

# ----- Train CNN -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EEGCNN(num_classes=len(le.classes_)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 200
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for Xb, yb in train_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model(Xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}")

# ----- Evaluate CNN -----
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for Xb, yb in test_loader:
        Xb = Xb.to(device)
        preds = model(Xb).argmax(1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(yb.numpy())

print('CNN Classification Report:\n', classification_report(y_true, y_pred, target_names=le.classes_))

# ----- Save CNN model -----
torch.save({'model_state_dict': model.state_dict(),
            'scaler': scaler,
            'label_encoder': le}, DEEP_MODEL_PATH)
print(f'Saved CNN model to {DEEP_MODEL_PATH}')


Epoch 1/200, Loss: 1.6478
Epoch 2/200, Loss: 1.5904
Epoch 3/200, Loss: 1.5689
Epoch 4/200, Loss: 1.5597
Epoch 5/200, Loss: 1.5413
Epoch 6/200, Loss: 1.5477
Epoch 7/200, Loss: 1.5320
Epoch 8/200, Loss: 1.5250
Epoch 9/200, Loss: 1.5271
Epoch 10/200, Loss: 1.5148
Epoch 11/200, Loss: 1.5101
Epoch 12/200, Loss: 1.5168
Epoch 13/200, Loss: 1.4995
Epoch 14/200, Loss: 1.5091
Epoch 15/200, Loss: 1.5104
Epoch 16/200, Loss: 1.5067
Epoch 17/200, Loss: 1.4934
Epoch 18/200, Loss: 1.4843
Epoch 19/200, Loss: 1.4865
Epoch 20/200, Loss: 1.4843
Epoch 21/200, Loss: 1.4854
Epoch 22/200, Loss: 1.4814
Epoch 23/200, Loss: 1.4836
Epoch 24/200, Loss: 1.4794
Epoch 25/200, Loss: 1.4757
Epoch 26/200, Loss: 1.4801
Epoch 27/200, Loss: 1.4620
Epoch 28/200, Loss: 1.4726
Epoch 29/200, Loss: 1.4622
Epoch 30/200, Loss: 1.4518
Epoch 31/200, Loss: 1.4689
Epoch 32/200, Loss: 1.4608
Epoch 33/200, Loss: 1.4656
Epoch 34/200, Loss: 1.4560
Epoch 35/200, Loss: 1.4395
Epoch 36/200, Loss: 1.4578
Epoch 37/200, Loss: 1.4465
Epoch 38/2

In [13]:
import os, sys
import numpy as np
import pandas as pd
from scipy import signal
import mne
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import joblib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ------------------- Paths -------------------
ROOT = os.getcwd()
EDF_DIR = '../data/synthetic_edf'
META_PATH = '../data/Metadata.csv'
OUT_DS = '../results/dataset_epoch_features_8feat.pkl'
DEEP_MODEL_PATH_8FEAT = '../results/cnn_model_8feat.pt'

os.makedirs('../results', exist_ok=True)

# ------------------- EDF & Feature Extraction -------------------
def load_raw_edf(path, preload=True, verbose=False):
    return mne.io.read_raw_edf(path, preload=preload, verbose=verbose)

def pick_available_channels(raw):
    picks = {'eeg': mne.pick_types(raw.info, eeg=True),
             'eog': mne.pick_types(raw.info, eog=True),
             'emg': mne.pick_types(raw.info, emg=True),
             'ecg': mne.pick_types(raw.info, ecg=True)}
    chans = {}
    for k, v in picks.items():
        chans[k] = [raw.ch_names[i] for i in v] if len(v) > 0 else []
    return chans

def epoch_raw_nonoverlap(raw, epoch_s=30, picks='eeg'):
    if isinstance(picks, str) and picks == 'eeg':
        idx = mne.pick_types(raw.info, eeg=True)
    elif isinstance(picks, list):
        idx = [raw.ch_names.index(ch) for ch in picks if ch in raw.ch_names]
    else:
        idx = mne.pick_types(raw.info, eeg=True)
    data = raw.get_data(picks=idx)
    sf = int(raw.info['sfreq'])
    n_samples = sf * epoch_s
    n_epochs = data.shape[1] // n_samples
    epochs = []
    for e in range(n_epochs):
        start = e * n_samples
        stop = start + n_samples
        epochs.append(data[:, start:stop])
    return np.array(epochs)

def bandpower_vector_8feat(epoch, sfreq):
    """Compute only 8 features using 4 EEG bands and first 2 channels."""
    bands = [(0.5, 4), (4, 8), (8, 12), (12, 30)]
    feats = []
    for ch in range(min(2, epoch.shape[0])):  # only first 2 channels
        f, Pxx = signal.welch(epoch[ch], fs=sfreq, nperseg=256)
        total = np.trapz(Pxx, f) + 1e-12
        for lo, hi in bands:
            idx = np.logical_and(f >= lo, f <= hi)
            feats.append(np.trapz(Pxx[idx], f[idx]) / total)
    return np.array(feats)

# ------------------- Dataset Creation -------------------
if not os.path.exists(META_PATH):
    print('Metadata not found!')
    sys.exit()

meta = pd.read_csv(META_PATH)
all_X, all_y = [], []

for i, row in tqdm(meta.iterrows(), total=len(meta)):
    fname = row['Filename']
    edf_path = os.path.join(EDF_DIR, fname)
    if not os.path.exists(edf_path):
        print('Missing', edf_path)
        continue
    try:
        raw = load_raw_edf(edf_path, preload=True)
        chans = pick_available_channels(raw)
        picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
        if len(picks) == 0:
            continue
        epochs = epoch_raw_nonoverlap(raw, 30, picks)
        for ep in epochs:
            feat = bandpower_vector_8feat(ep, raw.info['sfreq'])
            all_X.append(feat)
            all_y.append(row['Last sleep stage'])
    except Exception as e:
        print('Error processing', fname, e)

X = np.vstack(all_X)
y = np.array(all_y)
print('8-feature dataset shape:', X.shape)

joblib.dump({'X': X, 'y': y}, OUT_DS)
print('Saved dataset to', OUT_DS)

# ------------------- Train 8-Feature CNN Model -------------------
le = LabelEncoder()
y_enc = le.fit_transform(y)
scaler = StandardScaler()
Xs = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(Xs, y_enc, test_size=0.2, stratify=y_enc, random_state=42)

# Dataset for PyTorch
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)  # shape: (N, 1, 8)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = EEGDataset(X_train, y_train)
test_ds = EEGDataset(X_test, y_test)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False)

# CNN Model
class EEGCNN8(nn.Module):
    def __init__(self, num_classes):
        super(EEGCNN8, self).__init__()
        self.conv1 = nn.Conv1d(1, 32, kernel_size=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=3)
        self.bn2 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.3)
        
        # dynamically determine flattened size
        self._to_linear = None
        self.convs(torch.randn(1, 1, 8))  # run dummy forward to compute _to_linear
        
        self.fc1 = nn.Linear(self._to_linear, 64)
        self.fc2 = nn.Linear(64, num_classes)

    def convs(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        if self._to_linear is None:
            self._to_linear = x.view(1, -1).shape[1]
        return x

    def forward(self, x):
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EEGCNN8(num_classes=len(le.classes_)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 25
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for Xb, yb in train_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model(Xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}")

# Evaluate CNN
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for Xb, yb in test_loader:
        Xb = Xb.to(device)
        preds = model(Xb).argmax(1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(yb.numpy())

print('8-feature CNN Report:\n', classification_report(y_true, y_pred, target_names=le.classes_))

# Save model
torch.save({'model_state_dict': model.state_dict(),
            'scaler': scaler,
            'label_encoder': le}, DEEP_MODEL_PATH_8FEAT)
print(f'Saved 8-feature CNN model to {DEEP_MODEL_PATH_8FEAT}')

# ------------------- Prediction -------------------
def predict_from_edf_8feat_cnn(edf_path):
    saved = torch.load(DEEP_MODEL_PATH_8FEAT, map_location='cpu')
    scaler, le = saved['scaler'], saved['label_encoder']

    model = EEGCNN8(num_classes=len(le.classes_))
    model.load_state_dict(saved['model_state_dict'])
    model.eval()

    raw = load_raw_edf(edf_path, preload=True)
    chans = pick_available_channels(raw)
    picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
    epochs = epoch_raw_nonoverlap(raw, 30, picks)
    for i, ep in enumerate(epochs):
        feat = bandpower_vector_8feat(ep, raw.info['sfreq']).reshape(1, -1)
        feat_scaled = scaler.transform(feat)
        feat_tensor = torch.tensor(feat_scaled, dtype=torch.float32).unsqueeze(1)
        with torch.no_grad():
            pred = model(feat_tensor).argmax(1).item()
        print(f"Epoch {i+1} [CNN 8-Feature] prediction:", le.inverse_transform([pred])[0])


100%|██████████| 500/500 [00:06<00:00, 72.70it/s]


8-feature dataset shape: (1000, 8)
Saved dataset to ../results/dataset_epoch_features_8feat.pkl
Epoch 1/25, Loss: 1.6270
Epoch 2/25, Loss: 1.5970
Epoch 3/25, Loss: 1.5866
Epoch 4/25, Loss: 1.5684
Epoch 5/25, Loss: 1.5644
Epoch 6/25, Loss: 1.5676
Epoch 7/25, Loss: 1.5626
Epoch 8/25, Loss: 1.5461
Epoch 9/25, Loss: 1.5454
Epoch 10/25, Loss: 1.5410
Epoch 11/25, Loss: 1.5298
Epoch 12/25, Loss: 1.5264
Epoch 13/25, Loss: 1.5312
Epoch 14/25, Loss: 1.5173
Epoch 15/25, Loss: 1.5117
Epoch 16/25, Loss: 1.5209
Epoch 17/25, Loss: 1.5151
Epoch 18/25, Loss: 1.5071
Epoch 19/25, Loss: 1.5048
Epoch 20/25, Loss: 1.5037
Epoch 21/25, Loss: 1.5016
Epoch 22/25, Loss: 1.5013
Epoch 23/25, Loss: 1.4949
Epoch 24/25, Loss: 1.4863
Epoch 25/25, Loss: 1.5000
8-feature CNN Report:
               precision    recall  f1-score   support

          N1       0.15      0.10      0.12        42
          N2       0.21      0.36      0.27        42
          N3       0.31      0.14      0.20        35
         REM       0.26

## LSTM

In [14]:
# ------------------- LSTM Model -------------------
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ----- Prepare dataset -----
class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

# Reshape for LSTM: (samples, timesteps, features)
X_train_lstm = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
X_test_lstm = X_test.reshape(X_test.shape[0], X_test.shape[1], 1)

train_ds = EEGDataset(X_train_lstm, y_train)
test_ds = EEGDataset(X_test_lstm, y_test)

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=32, shuffle=False)

# ----- Define LSTM -----
class EEGLSTM(nn.Module):
    def __init__(self, input_size=1, hidden_size=128, num_layers=2, num_classes=3, dropout=0.3):
        super(EEGLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc1 = nn.Linear(hidden_size, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out, _ = self.lstm(x)
        out = out[:, -1, :]  # last timestep
        out = self.dropout(self.relu(self.fc1(out)))
        out = self.fc2(out)
        return out

# ----- Train LSTM -----
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EEGLSTM(num_classes=len(le.classes_)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 200
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for Xb, yb in train_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model(Xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}")

# ----- Evaluate LSTM -----
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for Xb, yb in test_loader:
        Xb = Xb.to(device)
        preds = model(Xb).argmax(1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(yb.numpy())

print('LSTM Classification Report:\n', classification_report(y_true, y_pred, target_names=le.classes_))

# ----- Save LSTM model -----
LSTM_MODEL_PATH = "../results/lstm_eeg_model.pth"
torch.save({
    'model_state_dict': model.state_dict(),
    'scaler': scaler,
    'label_encoder': le
}, LSTM_MODEL_PATH)

print(f'Saved LSTM model to {LSTM_MODEL_PATH}')


Epoch 1/200, Loss: 1.6102
Epoch 2/200, Loss: 1.6113
Epoch 3/200, Loss: 1.6086
Epoch 4/200, Loss: 1.6087
Epoch 5/200, Loss: 1.6084
Epoch 6/200, Loss: 1.6084
Epoch 7/200, Loss: 1.6086
Epoch 8/200, Loss: 1.6075
Epoch 9/200, Loss: 1.6087
Epoch 10/200, Loss: 1.6076
Epoch 11/200, Loss: 1.6085
Epoch 12/200, Loss: 1.6054
Epoch 13/200, Loss: 1.6067
Epoch 14/200, Loss: 1.6010
Epoch 15/200, Loss: 1.5993
Epoch 16/200, Loss: 1.5978
Epoch 17/200, Loss: 1.5998
Epoch 18/200, Loss: 1.5876
Epoch 19/200, Loss: 1.5967
Epoch 20/200, Loss: 1.5925
Epoch 21/200, Loss: 1.5896
Epoch 22/200, Loss: 1.5828
Epoch 23/200, Loss: 1.5901
Epoch 24/200, Loss: 1.5837
Epoch 25/200, Loss: 1.5856
Epoch 26/200, Loss: 1.5781
Epoch 27/200, Loss: 1.5784
Epoch 28/200, Loss: 1.5814
Epoch 29/200, Loss: 1.5747
Epoch 30/200, Loss: 1.5767
Epoch 31/200, Loss: 1.5733
Epoch 32/200, Loss: 1.5700
Epoch 33/200, Loss: 1.5602
Epoch 34/200, Loss: 1.5687
Epoch 35/200, Loss: 1.5699
Epoch 36/200, Loss: 1.5627
Epoch 37/200, Loss: 1.5565
Epoch 38/2

In [15]:
import os, sys
import numpy as np
import pandas as pd
from scipy import signal
import mne
from tqdm import tqdm
from sklearn.preprocessing import LabelEncoder, StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import joblib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ------------------- Paths -------------------
ROOT = os.getcwd()
EDF_DIR = '../data/synthetic_edf'
META_PATH = '../data/Metadata.csv'
OUT_DS = '../results/dataset_epoch_features_8feat.pkl'
DEEP_MODEL_PATH_8FEAT_LSTM = '../results/lstm_model_8feat.pt'

os.makedirs('../results', exist_ok=True)

# ------------------- EDF & Feature Extraction -------------------
def load_raw_edf(path, preload=True, verbose=False):
    return mne.io.read_raw_edf(path, preload=preload, verbose=verbose)

def pick_available_channels(raw):
    picks = {'eeg': mne.pick_types(raw.info, eeg=True),
             'eog': mne.pick_types(raw.info, eog=True),
             'emg': mne.pick_types(raw.info, emg=True),
             'ecg': mne.pick_types(raw.info, ecg=True)}
    chans = {}
    for k, v in picks.items():
        chans[k] = [raw.ch_names[i] for i in v] if len(v) > 0 else []
    return chans

def epoch_raw_nonoverlap(raw, epoch_s=30, picks='eeg'):
    if isinstance(picks, str) and picks == 'eeg':
        idx = mne.pick_types(raw.info, eeg=True)
    elif isinstance(picks, list):
        idx = [raw.ch_names.index(ch) for ch in picks if ch in raw.ch_names]
    else:
        idx = mne.pick_types(raw.info, eeg=True)
    data = raw.get_data(picks=idx)
    sf = int(raw.info['sfreq'])
    n_samples = sf * epoch_s
    n_epochs = data.shape[1] // n_samples
    epochs = []
    for e in range(n_epochs):
        start = e * n_samples
        stop = start + n_samples
        epochs.append(data[:, start:stop])
    return np.array(epochs)

def bandpower_vector_8feat(epoch, sfreq):
    """Compute only 8 features using 4 EEG bands and first 2 channels."""
    bands = [(0.5, 4), (4, 8), (8, 12), (12, 30)]
    feats = []
    for ch in range(min(2, epoch.shape[0])):  # only first 2 channels
        f, Pxx = signal.welch(epoch[ch], fs=sfreq, nperseg=256)
        total = np.trapz(Pxx, f) + 1e-12
        for lo, hi in bands:
            idx = np.logical_and(f >= lo, f <= hi)
            feats.append(np.trapz(Pxx[idx], f[idx]) / total)
    return np.array(feats)

# ------------------- Dataset Creation -------------------
if not os.path.exists(META_PATH):
    print('Metadata not found!')
    sys.exit()

meta = pd.read_csv(META_PATH)
all_X, all_y = [], []

for i, row in tqdm(meta.iterrows(), total=len(meta)):
    fname = row['Filename']
    edf_path = os.path.join(EDF_DIR, fname)
    if not os.path.exists(edf_path):
        print('Missing', edf_path)
        continue
    try:
        raw = load_raw_edf(edf_path, preload=True)
        chans = pick_available_channels(raw)
        picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
        if len(picks) == 0:
            continue
        epochs = epoch_raw_nonoverlap(raw, 30, picks)
        for ep in epochs:
            feat = bandpower_vector_8feat(ep, raw.info['sfreq'])
            all_X.append(feat)
            all_y.append(row['Last sleep stage'])
    except Exception as e:
        print('Error processing', fname, e)

X = np.vstack(all_X)
y = np.array(all_y)
print('8-feature dataset shape:', X.shape)

joblib.dump({'X': X, 'y': y}, OUT_DS)
print('Saved dataset to', OUT_DS)

# ------------------- Train 8-Feature LSTM Model -------------------
le = LabelEncoder()
y_enc = le.fit_transform(y)
scaler = StandardScaler()
Xs = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(Xs, y_enc, test_size=0.2, stratify=y_enc, random_state=42)

# Dataset for PyTorch
class EEGDataset(Dataset):
    def __init__(self, X, y):
        # Reshape for LSTM -> (samples, timesteps, features)
        self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(-1)  # shape: (N, 8, 1)
        self.y = torch.tensor(y, dtype=torch.long)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

train_ds = EEGDataset(X_train, y_train)
test_ds = EEGDataset(X_test, y_test)
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=16, shuffle=False)

# LSTM Model
class EEGLSTM8(nn.Module):
    def __init__(self, input_size=1, hidden_size=64, num_layers=2, num_classes=3, dropout=0.3):
        super(EEGLSTM8, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, dropout=dropout)
        self.fc1 = nn.Linear(hidden_size, 64)
        self.fc2 = nn.Linear(64, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        out, _ = self.lstm(x)   # (batch, seq_len, hidden)
        out = out[:, -1, :]     # last time step
        out = self.dropout(self.relu(self.fc1(out)))
        return self.fc2(out)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EEGLSTM8(num_classes=len(le.classes_)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

EPOCHS = 25
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for Xb, yb in train_loader:
        Xb, yb = Xb.to(device), yb.to(device)
        optimizer.zero_grad()
        out = model(Xb)
        loss = criterion(out, yb)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}/{EPOCHS}, Loss: {total_loss/len(train_loader):.4f}")

# Evaluate LSTM
model.eval()
y_true, y_pred = [], []
with torch.no_grad():
    for Xb, yb in test_loader:
        Xb = Xb.to(device)
        preds = model(Xb).argmax(1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(yb.numpy())

print('8-feature LSTM Report:\n', classification_report(y_true, y_pred, target_names=le.classes_))

# Save model
torch.save({'model_state_dict': model.state_dict(),
            'scaler': scaler,
            'label_encoder': le}, DEEP_MODEL_PATH_8FEAT_LSTM)
print(f'Saved 8-feature LSTM model to {DEEP_MODEL_PATH_8FEAT_LSTM}')

# ------------------- Prediction -------------------
def predict_from_edf_8feat_lstm(edf_path):
    saved = torch.load(DEEP_MODEL_PATH_8FEAT_LSTM, map_location='cpu')
    scaler, le = saved['scaler'], saved['label_encoder']

    model = EEGLSTM8(num_classes=len(le.classes_))
    model.load_state_dict(saved['model_state_dict'])
    model.eval()

    raw = load_raw_edf(edf_path, preload=True)
    chans = pick_available_channels(raw)
    picks = chans['eeg'] + chans['eog'] + chans['emg'] + chans['ecg']
    epochs = epoch_raw_nonoverlap(raw, 30, picks)
    for i, ep in enumerate(epochs):
        feat = bandpower_vector_8feat(ep, raw.info['sfreq']).reshape(1, -1)
        feat_scaled = scaler.transform(feat)
        feat_tensor = torch.tensor(feat_scaled, dtype=torch.float32).unsqueeze(-1)
        with torch.no_grad():
            pred = model(feat_tensor).argmax(1).item()
        print(f"Epoch {i+1} [LSTM 8-Feature] prediction:", le.inverse_transform([pred])[0])


100%|██████████| 500/500 [00:06<00:00, 74.22it/s]


8-feature dataset shape: (1000, 8)
Saved dataset to ../results/dataset_epoch_features_8feat.pkl
Epoch 1/25, Loss: 1.6109
Epoch 2/25, Loss: 1.6089
Epoch 3/25, Loss: 1.6081
Epoch 4/25, Loss: 1.6097
Epoch 5/25, Loss: 1.6082
Epoch 6/25, Loss: 1.6097
Epoch 7/25, Loss: 1.6094
Epoch 8/25, Loss: 1.6092
Epoch 9/25, Loss: 1.6066
Epoch 10/25, Loss: 1.6059
Epoch 11/25, Loss: 1.6000
Epoch 12/25, Loss: 1.6051
Epoch 13/25, Loss: 1.5996
Epoch 14/25, Loss: 1.5987
Epoch 15/25, Loss: 1.5987
Epoch 16/25, Loss: 1.6005
Epoch 17/25, Loss: 1.5905
Epoch 18/25, Loss: 1.5952
Epoch 19/25, Loss: 1.5929
Epoch 20/25, Loss: 1.5865
Epoch 21/25, Loss: 1.5802
Epoch 22/25, Loss: 1.5867
Epoch 23/25, Loss: 1.5858
Epoch 24/25, Loss: 1.5749
Epoch 25/25, Loss: 1.5815
8-feature LSTM Report:
               precision    recall  f1-score   support

          N1       0.19      0.12      0.15        42
          N2       0.22      0.45      0.29        42
          N3       0.00      0.00      0.00        35
         REM       0.2

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
