In [None]:
# %matplotlib widget
import os
import glob
import torch
import numpy as np
from torch.utils.data import  DataLoader
from sklearn.metrics import confusion_matrix,cohen_kappa_score,accuracy_score, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
from RESTCORE import REST
from RESTutils import create_sequences

In [None]:
# Parameters
# -- acquisition / spectral settings --
fs, epoch_length, nperseg = 512, 4, 256  # sampling rate (Hz), epoch length (s), PSD segment length (samples)

# -- sequence windowing --
# At 4 s per epoch there are 15 epochs per minute, so a 90-epoch window spans 6 minutes.
sequence_length = 90
window_size = sequence_length
step = 60  # step (in epochs) between consecutive windows, passed to create_sequences

# -- model / inference --
batch_size = 32  # number of windows per inference batch
n_classes = 3    # sleep stages: Wake, NREM, REM
f_bin = 130      # per-frame feature width given to the model as in_feat (presumably 65 PSD bins x 2 signals)
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Path to the test recording (.npz). It must be a recording that was NOT used
# for training (unseen data), otherwise the metrics below are optimistic.
ds_path = r''
if not ds_path:
    raise ValueError("Set ds_path to the .npz file of an unseen test recording.")

# Use the NpzFile as a context manager so its file handle is closed as soon as
# the arrays have been read into memory.
with np.load(ds_path) as arr:
    # NOTE(review): the previous shape comments ([n_epochs, 256*4] / [n_epochs, 1024*4])
    # disagree with the [n_epochs, frames, feat] layout assumed when EEG and EMG are
    # concatenated for the model — confirm against the preprocessing pipeline.
    EEG = arr["EEG"]
    EMG = arr["EMG"]
    score = arr["score"] - 1        # stored 1-based -> 0-based
score[score > 2] = 0                # collapse stage "?" to Wake
score = score.astype(np.int64)      # Wake=0, NREM=1, REM=2
del arr

In [None]:
# Concatenate EEG and EMG features along the last (feature) axis.
# NOTE(review): the model is built with in_feat=f_bin=130, presumably
# [n_epochs, frames, 65 + 65] after this concatenation — confirm the input layout.
epoch_tensor = np.concatenate([EEG, EMG], axis=-1).astype(np.float32)
X, Y = create_sequences(window_size, step, epoch_tensor, score)
# Keep the whole dataset on the CPU. The inference loop moves one batch at a
# time to `device`, so long recordings do not have to fit in GPU memory at once.
sequences_tensor = torch.as_tensor(X, dtype=torch.float32)
# shuffle=False is required: predictions are stitched back together in window order.
sequences_batch = DataLoader(sequences_tensor, batch_size=batch_size, shuffle=False)

In [None]:
# Locate the trained checkpoint by searching for *.pth files in the current
# working directory (for a notebook, usually the folder it was started from).
model_dir = os.getcwd()
# glob order is filesystem dependent; sort so the chosen file is reproducible.
pth_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))

if not pth_files:
    raise FileNotFoundError(f"No .pth model file found in {model_dir}.")
if len(pth_files) > 1:
    print("Warning: multiple .pth files found. Using the first one (sorted by name):")
    for candidate in pth_files:
        print(f"  {candidate}")
Model_path = pth_files[0]
print(f"Loading model from: {Model_path}")

In [None]:
# Build REST with (presumably) the same hyper-parameters used for training.
# load_state_dict is strict by default, so any architecture mismatch raises here.
model = REST(
    in_feat=f_bin,
    n_classes=n_classes,
    win_len=window_size,
    d_model=256,
    nhead=8,
    nlayers_epoch=4,
    nlayers_seq=4,
    ff=512,
    fc_hidden1=128,
    fc_hidden2=64,
    dropout=0.1,
).to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
state_dict = torch.load(Model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # evaluation mode: disables dropout

all_preds = []
with torch.no_grad():
    for batch_X in sequences_batch:
        batch_X = batch_X.to(device)
        output = model(batch_X)                  # [batch, window, n_classes]; class scores on dim 2
        predicted = torch.argmax(output, dim=2)  # [batch, window]
        # Windows overlap (step < window_size). Keep only the first `step` epochs of
        # each window so that, assuming create_sequences starts a window every
        # `step` epochs, each epoch is predicted exactly once.
        kept_preds = predicted[:, :step].cpu().numpy()
        all_preds.append(kept_preds)

if not all_preds:
    raise ValueError(
        "No sequences were produced; the recording may be shorter than window_size epochs."
    )
# Model classes are 0-based (Wake=0, NREM=1, REM=2); shift to 1-based to match the reference.
predictions = np.concatenate(all_preds, axis=0).flatten() + 1

In [None]:
# Align manual scores with the predictions; both use 1-based stage labels
# (Wake=1, NREM=2, REM=3).
predicted_score = np.asarray(predictions)
Reference_score = score + 1
Reference_score[Reference_score == 4] = 1  # defensive: treat any leftover stage "?" as Wake
# Predictions only cover the first `step` epochs of each window, so trailing
# epochs that were never predicted are dropped from the reference.
reference_score = Reference_score[:len(predicted_score)]
if len(reference_score) != len(predicted_score):
    raise ValueError(
        f"Got {len(predicted_score)} predictions but only {len(reference_score)} "
        "reference epochs; check the window_size/step alignment."
    )

In [None]:
# Stage labels are 1-based here: Wake=1, NREM=2, REM=3.
labels = [1, 2, 3]
cm = confusion_matrix(reference_score, predicted_score, labels=labels)
kappa = cohen_kappa_score(reference_score, predicted_score)
accuracy = accuracy_score(reference_score, predicted_score)
print(f"Accuracy: {accuracy:.4f}")
print(f"\nCohen's Kappa: {kappa:.3f}")
# Detailed classification report. Passing labels= keeps the rows aligned with
# target_names even when a stage (e.g. REM) never occurs in either series;
# without it, classification_report raises a ValueError in that case.
print("\nClassification Report:")
print(classification_report(reference_score, predicted_score,
                            labels=labels, target_names=["Wake", "NREM", "REM"]))
# Confusion-matrix heatmap: rows are the manual score, columns the REST prediction.
# An explicit figure/axes avoids drawing onto whatever figure happens to be current.
stage_names = ["Wake", "NREM", "REM"]
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=stage_names, yticklabels=stage_names, ax=ax)
ax.set_xlabel("REST Prediction")
ax.set_ylabel("Manual Score")
ax.set_title("Confusion Matrix (Manual vs Transformer)")
fig.tight_layout()
fig.savefig("confusion_matrix_manual_VS_transformer.tiff", dpi=600, format="tiff")
plt.show()
def plot_stage_stairs(stages, series_label, ylabel, title):
    """Draw a per-epoch step plot of 1-based sleep-stage labels.

    Parameters
    ----------
    stages : array-like
        Stage label per epoch (Wake=1, NREM=2, REM=3).
    series_label : str
        Legend entry for the plotted series.
    ylabel : str
        Y-axis label.
    title : str
        Figure title.
    """
    fig, ax = plt.subplots(figsize=(15, 5))
    ax.step(range(len(stages)), stages, where='mid', label=series_label)
    ax.set_yticks([1, 2, 3])
    ax.set_yticklabels(["Wake", "NREM", "REM"])
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    plt.show()


plot_stage_stairs(reference_score, 'Reference Labels', 'Reference Label',
                  'Reference Labels as Stairs Plot')
plot_stage_stairs(predicted_score, 'Predicted Labels', 'Predicted Label',
                  'Predicted Labels as Stairs Plot')