### Model Definition

In [26]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import mne
import glob
import os
from pathlib import Path
from preprocessing import preprocess_record, preprocess_all
from features import extract_features_all_epochs
from sklearn.model_selection import train_test_split

In [27]:
class SleepStageCNN(nn.Module):
    def __init__(self, input_length, n_channels=1, n_classes=5):
        super().__init__()

        self.conv1 = nn.Conv1d(
            in_channels=n_channels,
            out_channels=32,
            kernel_size=7,
            padding=3
        )
        self.pool1 = nn.MaxPool1d(kernel_size=2)

        self.conv2 = nn.Conv1d(
            in_channels=32,
            out_channels=64,
            kernel_size=5,
            padding=2
        )

        # Global average pooling will reduce temporal dimension
        self.gap = nn.AdaptiveAvgPool1d(1)

        self.fc = nn.Linear(64, n_classes)

    def forward(self, x):
        # x shape: (batch, channels, samples)
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.gap(x)          # → (batch, 64, 1)
        x = x.squeeze(-1)        # → (batch, 64)
        x = self.fc(x)
        return x


### Instantiate Model

In [28]:
input_length = 3000  # example for 30s @ 100Hz
model = SleepStageCNN(input_length=input_length, n_channels=1, n_classes=5)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


SleepStageCNN(
  (conv1): Conv1d(1, 32, kernel_size=(7,), stride=(1,), padding=(3,))
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (gap): AdaptiveAvgPool1d(output_size=1)
  (fc): Linear(in_features=64, out_features=5, bias=True)
)

### Optimizer, Loss, LR Schedule and Early Stopping

In [29]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

lr_scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=5,
    min_lr=1e-6
)

In [30]:
#early stopping helper
class EarlyStopping:
    def __init__(self, patience=10, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = None
        self.counter = 0
        self.should_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0


### Data reading

In [31]:
def load_sleep_edf(psg_file, hyp_file, channel="EEG Fpz-Cz", epoch_length=30):
    """
    Returns:
        X: np array (n_epochs, samples)
        y: np array (n_epochs,)
    """

    # --- Load PSG ---
    raw = mne.io.read_raw_edf(psg_file, preload=True)
    raw.pick_channels([channel])
    raw.filter(0.3, 35)
    ann = mne.read_annotations(hyp_file)
    raw.set_annotations(ann, emit_warning=False)

    # Map annotation labels → integers
    # You may customize this mapping
    stage_map = {
        "Sleep stage W": 0,
        "Sleep stage 1": 1,
        "Sleep stage 2": 2,
        "Sleep stage ?": 2,
        "Sleep stage 3": 3,
        "Sleep stage 4": 3,# or 3/4 combined
        "Sleep stage R": 4,
    }

    # Create events from annotations
    events, event_ids = mne.events_from_annotations(raw, chunk_duration=epoch_length)

    # Epoch the PSG data
    epochs = mne.Epochs(raw, events, event_ids, tmin=0, tmax=epoch_length, baseline=None, preload=True)

    # Convert to (n_epochs, samples)
    X = epochs.get_data().squeeze(1)  # remove channel dimension if only one channel

    # Convert event IDs to labels
    y = np.array([stage_map.get(k, -1) for k in epochs.events[:, 2]])

    # Remove undefined stages (movement, unknown)
    mask = y != -1
    X = X[mask]
    y = y[mask]

    return X, y


In [32]:
from pathlib import Path
RAW_DIR = Path("sample_data")

def get_record_pairs(raw_dir):
    """Return list of (psg_file, hyp_file) pairs."""
    psg_files = {}
    hyp_files = {}

    for f in raw_dir.glob("*.edf"):
        name = f.name
        if "-PSG" in name:
            key = name.split("-")[0]   # e.g., "SC4001E0"
            psg_files[key] = f
        if "Hypnogram" in name:
            key = name.split("-")[0]   # e.g., "SC4001EC"
            # Normalize key to match PSG key
            key = key.replace("EC", "E0").replace("EH", "E0")
            hyp_files[key] = f

    pairs = []
    for key in psg_files:
        if key in hyp_files:
            pairs.append((psg_files[key], hyp_files[key]))
        else:
            print(f"⚠️ Missing hypnogram for {key}")

    return pairs

pairs = get_record_pairs(RAW_DIR)
print("Found pairs:", len(pairs))

Found pairs: 5


In [33]:
import os
X_all, y_all = [], []

for (psg, hyp) in pairs:
    psg_path = str(psg)
    hyp_path = str(hyp)

    print(f"Processing {os.path.basename(psg)}")
    X, y = preprocess_record(psg_path, hyp_path)
    X_all.append(X)
    y_all.append(y)

Processing SC4002E0-PSG.edf
Extracting EDF parameters from /Users/onogantsog/Code/stageclassification/sample_data/SC4002E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8489999  =      0.000 ... 84899.990 secs...


  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  raw.set_annotations(ann)


Processing SC4011E0-PSG.edf
Extracting EDF parameters from /Users/onogantsog/Code/stageclassification/sample_data/SC4011E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8405999  =      0.000 ... 84059.990 secs...


  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  raw.set_annotations(ann)


Processing SC4021E0-PSG.edf
Extracting EDF parameters from /Users/onogantsog/Code/stageclassification/sample_data/SC4021E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8411999  =      0.000 ... 84119.990 secs...


  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  raw.set_annotations(ann)


Processing SC4012E0-PSG.edf
Extracting EDF parameters from /Users/onogantsog/Code/stageclassification/sample_data/SC4012E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8549999  =      0.000 ... 85499.990 secs...


  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  raw.set_annotations(ann)


Processing SC4001E0-PSG.edf
Extracting EDF parameters from /Users/onogantsog/Code/stageclassification/sample_data/SC4001E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 7949999  =      0.000 ... 79499.990 secs...


  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  raw.set_annotations(ann)


In [35]:
X_all2 = np.concatenate(X_all, axis=0)
y_all2 = np.concatenate(y_all, axis=0)
df_xall = pd.DataFrame(X_all2)
df_yall = pd.DataFrame(y_all2)

In [40]:
from sklearn.model_selection import train_test_split

X_train, X_val, y_train, y_val = train_test_split(
    X_all2, y_all2, test_size=0.2, random_state=42, stratify=y_all2
)

X_train = X_train[:, np.newaxis, :]  # add channel dim
X_val   = X_val[:, np.newaxis, :]

X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)

X_val = torch.tensor(X_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.long)


### Training loop

In [41]:
batch_size = 64

train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
val_loader   = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)

early_stopping = EarlyStopping(patience=10)

num_epochs = 100

for epoch in range(num_epochs):
    model.train()
    total_loss = 0

    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)
            preds = model(X)
            loss = criterion(preds, y)
            val_loss += loss.item()

    val_loss /= len(val_loader)

    lr_scheduler.step(val_loss)

    print(f"Epoch {epoch+1}/{num_epochs}, Val Loss: {val_loss:.4f}")

    early_stopping(val_loss)
    if early_stopping.should_stop:
        print("Early stopping triggered.")
        break


Epoch 1/100, Val Loss: 1.0075
Epoch 2/100, Val Loss: 1.0073
Epoch 3/100, Val Loss: 1.0084
Epoch 4/100, Val Loss: 1.0123
Epoch 5/100, Val Loss: 1.0145
Epoch 6/100, Val Loss: 1.0099
Epoch 7/100, Val Loss: 1.0072
Epoch 8/100, Val Loss: 1.0083
Epoch 9/100, Val Loss: 1.0072
Epoch 10/100, Val Loss: 1.0084
Epoch 11/100, Val Loss: 1.0079
Epoch 12/100, Val Loss: 1.0074
Epoch 13/100, Val Loss: 1.0078
Epoch 14/100, Val Loss: 1.0076
Epoch 15/100, Val Loss: 1.0071
Epoch 16/100, Val Loss: 1.0072
Epoch 17/100, Val Loss: 1.0069
Epoch 18/100, Val Loss: 1.0071
Epoch 19/100, Val Loss: 1.0072
Epoch 20/100, Val Loss: 1.0075
Epoch 21/100, Val Loss: 1.0074
Epoch 22/100, Val Loss: 1.0076
Epoch 23/100, Val Loss: 1.0089
Epoch 24/100, Val Loss: 1.0072
Epoch 25/100, Val Loss: 1.0070
Epoch 26/100, Val Loss: 1.0070
Epoch 27/100, Val Loss: 1.0072
Early stopping triggered.


In [43]:
TEST_DIR = Path("testing_data")
pairs = get_record_pairs(TEST_DIR)
print("Found pairs:", len(pairs))

Found pairs: 1


In [51]:
X_test_list = []
y_test_list = []

for (psg, hyp) in pairs:
    psg_path = str(psg)
    hyp_path = str(hyp)

    print(f"Processing {os.path.basename(psg)}")
    test_X, test_y = preprocess_record(psg_path, hyp_path)

    X_test_list.append(test_X)
    y_test_list.append(test_y)

# Concatenate epochs
X_test = np.concatenate(X_test_list, axis=0)  # (N, 3000)
y_test = np.concatenate(y_test_list, axis=0)  # (N,)


Processing SC4061E0-PSG.edf
Extracting EDF parameters from /Users/onogantsog/Code/stageclassification/testing_data/SC4061E0-PSG.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 8309999  =      0.000 ... 83099.990 secs...


  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)
  raw = mne.io.read_raw_edf(psg_path, preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


  raw.set_annotations(ann)


In [52]:
# Add channel dimension
X_test = X_test[:, None, :]   # (N, 1, 3000)

# Convert to torch
X_test = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test = torch.tensor(y_test, dtype=torch.long).to(device)
lengths = {x.shape[1] for x in X_test_list}
print("Unique epoch lengths:", lengths)


Unique epoch lengths: {3000}


In [45]:
X_test = torch.tensor(X_test, dtype=torch.float32).to(device)
y_test = torch.tensor(y_test, dtype=torch.long).to(device)


ValueError: expected sequence of length 2770 at dim 0 (got 3000)

In [53]:
model.eval()


SleepStageCNN(
  (conv1): Conv1d(1, 32, kernel_size=(7,), stride=(1,), padding=(3,))
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv1d(32, 64, kernel_size=(5,), stride=(1,), padding=(2,))
  (gap): AdaptiveAvgPool1d(output_size=1)
  (fc): Linear(in_features=64, out_features=5, bias=True)
)

In [54]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import numpy as np

y_true = []
y_pred = []

with torch.no_grad():
    outputs = model(X_test)
    preds = torch.argmax(outputs, dim=1)

    y_true = y_test.cpu().numpy()
    y_pred = preds.cpu().numpy()


In [55]:
print("Test Accuracy:", accuracy_score(y_true, y_pred))
print(classification_report(
    y_true, y_pred,
    target_names=["W", "N1", "N2", "N3", "REM"]
))


Test Accuracy: 0.7469314079422382
              precision    recall  f1-score   support

           W       0.75      1.00      0.86      2069
          N1       0.00      0.00      0.00        56
          N2       0.00      0.00      0.00       407
          N3       0.00      0.00      0.00       136
         REM       0.00      0.00      0.00       102

    accuracy                           0.75      2770
   macro avg       0.15      0.20      0.17      2770
weighted avg       0.56      0.75      0.64      2770



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [56]:
cm = confusion_matrix(y_true, y_pred)
print(cm)


[[2069    0    0    0    0]
 [  56    0    0    0    0]
 [ 407    0    0    0    0]
 [ 136    0    0    0    0]
 [ 102    0    0    0    0]]
