In [None]:
import os
import glob
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import  DataLoader, TensorDataset,Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.utils.class_weight import compute_class_weight
import gc
from tqdm.notebook import tqdm
from RESTCORE import REST
from RESTutils import create_sequences,get_oversampled_indices, FocalLoss

In [None]:
# Hyper-parameters for the REST model and the training run. Change as needed;
# the dataset and checkpoint paths are resolved in the following cell.
fs = 512  # Sampling frequency
epoch_length = 4  # Epoch length in seconds
window_size = 90  # Window size for the sliding window (also the model's win_len)
step = 60  # Step size for the sliding window
nperseg = 256  # Segment length for PSD computation
batch_size = 128  # Batch size for training and validation loaders
n_epochs = 100  # Maximum number of training epochs
f_bin = 130  # Per-frame feature width, passed to the model as in_feat
n_classes = 3  # Number of sleep stages (e.g., Wake, NREM, REM)
WeightedLoss = True  # Use the class-weighted focal loss instead of plain cross-entropy
OversampleRAM = True  # Train on an oversampled index subset of the training windows
seed = 42  # Seed for NumPy / PyTorch so that Restart & Run All is repeatable

# Seed before the model is built: weight initialisation and DataLoader
# shuffling both draw from the global generators.
np.random.seed(seed)
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = REST(
    in_feat=f_bin,
    n_classes=n_classes,  # keep the output head in sync with the setting above
    win_len=window_size,
    d_model=256,
    nhead=8,
    nlayers_epoch=4,
    nlayers_seq=4,
    ff=512,
    fc_hidden1=128,
    fc_hidden2=64,
    dropout=0.1
).to(device)


In [None]:
# Locate the training dataset (.npz) in the current working directory and
# decide where the trained weights (.pth) will be written.
script_dir = os.getcwd()

# glob returns matches in filesystem order, which is not stable across
# machines; sort so that "the first one" is always the same file.
npz_files = sorted(glob.glob(os.path.join(script_dir, "*.npz")))
if not npz_files:
    raise FileNotFoundError("No .npz training dataset found in the script directory.")
if len(npz_files) > 1:
    print("Multiple .npz files found, using the first one.")

ds_path = npz_files[0]
print(f"Using dataset: {ds_path}")

Model_path = os.path.join(script_dir, "model_general.pth")
print(f"Model will be saved to: {Model_path}")

# Uncomment the following two variables to set your own paths if needed
# ds_path = r"" # insert path to your dataset here
# Model_path=r"" # insert path to save the model here

In [None]:

# Load the EEG/EMG feature arrays and the per-epoch scores from the archive.
# NOTE(review): later cells reduce EEG over axes (1, 2) and concatenate EEG and
# EMG on the last axis, so both are presumably 3-D [n_epochs, frames, features]
# rather than raw down-sampled traces — confirm against the preprocessing step.
with np.load(ds_path) as arr:  # context manager closes the underlying file
    EEG = arr["EEG"]
    EMG = arr["EMG"]
    # Cast BEFORE the label arithmetic: with an unsigned or narrow integer
    # dtype, "- 1" wraps around and "-100" cannot be stored, which would let
    # artefact labels leak into training as bogus class ids.
    score = arr["score"].astype(np.int64)

score = score - 1  # shift stored labels to 0-based (wake=0, NREM=1, REM=2, Artefact=3)
# Anything outside the valid class range becomes -100, the ignore label used by
# the loss and by the evaluation mask.
score[(score < 0) | (score >= 3)] = -100
del arr
gc.collect()

In [None]:
# Exclude epochs with abnormally high EEG power: compute the mean squared value
# of every epoch, take the 99th percentile as cut-off, and give every epoch
# above it the ignore label.
eeg_power = np.square(EEG).mean(axis=(1, 2))
eeg_thresh = np.percentile(eeg_power, 99)
suspect_epochs = np.flatnonzero(eeg_power > eeg_thresh)  # indices of the rejected epochs
score[suspect_epochs] = -100

In [None]:
# Report how many epochs carry each label value (the ignore label -100 included).
unique_vals, counts = np.unique(score, return_counts=True)
for position, val in enumerate(unique_vals):
    count = counts[position]
    print(f"Label {val}: {count} samples")

In [None]:
# Stack EEG and EMG features along the last axis to form the per-epoch input.
# NOTE(review): the original note gives the result as
# [n_epochs, frames=5, feat=65*2] — confirm against the dataset.
epoch_tensor = np.concatenate([EEG, EMG], axis=-1).astype(np.float32)
labels = score  # one label per epoch

# Cut the recording into sliding windows (window_size / step from the config cell).
X, Y = create_sequences(window_size, step, epoch_tensor, labels)
del EEG, EMG, epoch_tensor

# Split the windows into training and validation sets.
# NOTE(review): if create_sequences emits overlapping windows when
# step < window_size (60 vs 90 here), this shuffled split places windows that
# share epochs on both sides, which would inflate validation scores. Consider
# splitting by recording/time before windowing.
X_train, X_val, Y_train, Y_val = train_test_split(
    X, Y, test_size=0.2, random_state=42
)
del X, Y

# Convert to PyTorch tensors. as_tensor reuses the NumPy buffer when the dtype
# already matches, avoiding a second full copy of the data in RAM.
X_train = torch.as_tensor(X_train, dtype=torch.float32)
X_val = torch.as_tensor(X_val, dtype=torch.float32)
Y_train = torch.as_tensor(Y_train, dtype=torch.long)
Y_val = torch.as_tensor(Y_val, dtype=torch.long)

# Training set, optionally wrapped in an index-level oversampled view.
train_dataset = TensorDataset(X_train, Y_train)
if OversampleRAM:
    # Only compute the indices when they are actually used.
    oversampled_idx = get_oversampled_indices(Y_train.numpy(), repeat_factor=3)
    train_dataset = Subset(train_dataset, oversampled_idx)

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,  # shuffles across the repeated indices as well
)

# Validation data is never shuffled so predictions stay in a fixed order.
val_dataset = TensorDataset(X_val, Y_val)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Gather every training label and drop the ignore label (-100) before deriving
# class weights. The weights are based on the un-oversampled training labels.
train_labels_flat = Y_train.reshape(-1).cpu().numpy()
del X_train, X_val, Y_train, Y_val
train_labels_flat = train_labels_flat[train_labels_flat != -100]
if train_labels_flat.size == 0:
    raise ValueError("No valid (non -100) labels in the training set; cannot compute class weights.")

# Balanced weights are computed for the classes that actually occur and then
# scattered into a vector of length n_classes, so entry i always belongs to
# class i even when some class is absent from the training split. Absent
# classes keep a neutral weight of 1.
present_classes = np.unique(train_labels_flat)
balanced_weights = compute_class_weight('balanced', classes=present_classes, y=train_labels_flat)
class_weights = np.ones(n_classes, dtype=np.float32)
class_weights[present_classes] = balanced_weights
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)

# Loss: class-weighted focal loss, or plain cross-entropy; both skip -100 targets.
if WeightedLoss:
    criterion = FocalLoss(alpha=class_weights, gamma=2, ignore_index=-100)
else:
    criterion = nn.CrossEntropyLoss(ignore_index=-100)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)

In [None]:
# Training loop: one optimisation pass over train_loader per epoch, followed by
# a validation pass. The weights with the best validation accuracy are written
# to Model_path; training stops early after `patience` consecutive epochs
# without improvement.
best_val_accuracy = 0.0
patience = 50  # consecutive non-improving epochs tolerated before stopping
epochs_without_improvement = 0
for epoch in range(n_epochs):
    model.train()
    train_loss = 0.0
    for batch_X, batch_Y in tqdm(train_loader, desc=f"Epoch {epoch+1}/{n_epochs} - Training"):
        batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
        optimizer.zero_grad()
        output = model(batch_X)  # Shape: [batch_size, sequence_length, n_classes]
        # reshape (not view) so a non-contiguous model output cannot raise here
        loss = criterion(output.reshape(-1, n_classes), batch_Y.reshape(-1))
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    # Validation
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_X, batch_Y in tqdm(val_loader, desc=f"Epoch {epoch+1}/{n_epochs} - Validation"):
            batch_X, batch_Y = batch_X.to(device), batch_Y.to(device)
            output = model(batch_X)  # Shape: [batch_size, sequence_length, n_classes]
            loss = criterion(output.reshape(-1, n_classes), batch_Y.reshape(-1))
            val_loss += loss.item()

            # Accuracy over scored positions only: targets equal to -100 are
            # ignored by the loss, so they must not count in the denominator
            # either (they can never be predicted and would deflate accuracy).
            predicted = output.argmax(dim=2)  # Shape: [batch_size, sequence_length]
            valid = batch_Y != -100
            total += valid.sum().item()
            correct += ((predicted == batch_Y) & valid).sum().item()

    # Per-epoch averages; the guards avoid ZeroDivisionError on empty loaders
    # or a validation set without any scored position.
    train_loss /= max(len(train_loader), 1)
    val_loss /= max(len(val_loader), 1)
    val_accuracy = 100 * correct / total if total else 0.0
    print(f"Epoch {epoch+1}/{n_epochs}, "
          f"Train Loss: {train_loss:.4f}, "
          f"Val Loss: {val_loss:.4f}, "
          f"Val Accuracy: {val_accuracy:.2f}%")

    # Save the best model and restart the patience count on every improvement.
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        epochs_without_improvement = 0
        torch.save(model.state_dict(), Model_path)
        print(f"New best model saved with accuracy {best_val_accuracy:.2f}%")
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print("Early stopping triggered.")
            break
print(f"Training complete. Best validation accuracy: {best_val_accuracy:.2f}%")

In [None]:
# Reload the best checkpoint and evaluate it on the validation set.
# map_location lets the weights load on the current device even if they were
# saved from a different one (e.g. trained on GPU, evaluated on CPU).
model.load_state_dict(torch.load(Model_path, map_location=device, weights_only=True))
model.eval()

# Store all predictions and targets
all_preds = []
all_targets = []

with torch.no_grad():
    for batch_X, batch_Y in val_loader:
        batch_X = batch_X.to(device)
        output = model(batch_X)  # Shape: [B, W, C]
        preds = torch.argmax(output, dim=2)  # [B, W]

        all_preds.append(preds.cpu().reshape(-1))
        all_targets.append(batch_Y.reshape(-1))  # never moved off the CPU

# Concatenate all batches
all_preds = torch.cat(all_preds).numpy()
all_targets = torch.cat(all_targets).numpy()

# Filter out ignored labels (-100)
mask = all_targets != -100
all_preds = all_preds[mask]
all_targets = all_targets[mask]

# Pass the class ids explicitly: without `labels`, a stage that is absent from
# the validation data makes the three target_names mismatch the inferred
# classes (ValueError) and shrinks the confusion matrix.
stage_names = ["Wake", "NREM", "REM"]
stage_ids = list(range(len(stage_names)))

# Print results
print("\nClassification Report (Validation Set):")
print(classification_report(all_targets, all_preds, labels=stage_ids, target_names=stage_names))

print("Confusion Matrix:")
print(confusion_matrix(all_targets, all_preds, labels=stage_ids))