In [None]:
import os
import glob
import mne
import torch
import numpy as np
from scipy.io import savemat
from torch.utils.data import  DataLoader
from tqdm.notebook import tqdm
from RESTCORE import REST
from RESTutils import compute_powers, data_process,create_sequences

In [None]:
# Parameters, make sure they match those used in training
fs = 512  # Sampling frequency (Hz) expected by the feature pipeline / model
epoch_length = 4  # Epoch length in seconds
window_size = 90  # Number of epochs in a sequence
step = 60  # Stride (in epochs) between consecutive, overlapping sequences
batch_size = 128  # Number of sequences per inference batch
n_classes = 3  # Number of sleep stages (e.g., Wake, NREM, REM)
f_bin = 130  # Frequency bin for PSD computation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Architecture hyper-parameters must match the checkpoint exactly; with the
# default strict loading, load_state_dict rejects missing/unexpected keys.
model = REST(
    in_feat=f_bin,
    n_classes=n_classes,
    win_len=window_size,
    d_model=256,
    nhead=8,
    nlayers_epoch=4,
    nlayers_seq=4,
    ff=512,
    fc_hidden1=128,
    fc_hidden2=64,
    dropout=0.1
).to(device)

# IMPORTANT: set this to the folder containing the EDF files. It is searched
# recursively and each predicted score is saved in the same folder as its EDF.
edf_folder = ""  # e.g. r"/path/to/edf_folder"
if not os.path.isdir(edf_folder):
    raise NotADirectoryError(
        f"edf_folder must point to an existing directory, got {edf_folder!r}"
    )

In [None]:
# Locate the trained checkpoint (*.pth) in the notebook's working directory.
script_dir = os.getcwd()
# Sort so the choice is deterministic: glob returns files in an arbitrary,
# filesystem-dependent order.
pth_files = sorted(glob.glob(os.path.join(script_dir, "*.pth")))

if not pth_files:
    raise FileNotFoundError("No .pth model file found in script directory.")
if len(pth_files) > 1:
    print("Warning: multiple .pth files found. Using the first one.")
Model_path = pth_files[0]
print(f"Loading model from: {Model_path}")

# map_location remaps tensors saved on a GPU onto `device`, so a CUDA-trained
# checkpoint still loads on a CPU-only machine.
state_dict = torch.load(Model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)  # Load the trained weights
model.to(device)  # Move the model to the selected device
model.eval()  # Evaluation mode: disables dropout for inference

In [None]:
# Score every EDF under edf_folder (recursively) and save <name>_REST.mat next
# to each file, containing the predicted 'score' and the computed 'power'.
edf_files = sorted(glob.glob(os.path.join(edf_folder, "**", "*.edf"), recursive=True))
for fp_edf in tqdm(edf_files):
    file_name = os.path.splitext(os.path.basename(fp_edf))[0]
    save_folder = os.path.dirname(fp_edf)
    save_path = os.path.join(save_folder, file_name + "_REST.mat")

    raw = mne.io.read_raw_edf(fp_edf, preload=True)

    # data_process takes no sampling-rate argument, so it presumably assumes
    # `fs` Hz like the trained model; an off-rate file would silently yield
    # mis-scaled features, so skip it instead.
    sfreq = raw.info['sfreq']
    if not np.isclose(sfreq, fs):
        print(f"Skipping {fp_edf}: sampling rate {sfreq} Hz does not match expected {fs} Hz")
        continue

    channel_name = raw.info.ch_names
    # EEG: channels whose name contains 'RF' but not 'LP'; EMG: names containing 'EMG'.
    EEG_channel = [i for i, name in enumerate(channel_name) if 'RF' in name and 'LP' not in name]
    EMG_channel = [i for i, name in enumerate(channel_name) if 'EMG' in name]

    if not EEG_channel or not EMG_channel:
        print(f"Skipping {fp_edf}: Missing 'RF' or 'EMG' channel")
        continue

    EEG = raw.get_data(EEG_channel)
    EMG = raw.get_data(EMG_channel)

    power = compute_powers(EEG, EMG, sfreq=fs)
    EEG_STFT, EMG_STFT = data_process(EEG, EMG)
    STFT = np.concatenate((EEG_STFT, EMG_STFT), axis=-1)
    X = create_sequences(window_size, step, STFT)
    if len(X) == 0:
        # Presumably the recording is shorter than one window; without this
        # guard np.concatenate([]) below would raise.
        print(f"Skipping {fp_edf}: no {window_size}-epoch sequences could be built")
        continue

    # Keep the full tensor on the CPU and move one batch at a time, so long
    # recordings need not fit in GPU memory all at once.
    sequences_tensor = torch.as_tensor(X, dtype=torch.float32)
    sequences_batch = DataLoader(sequences_tensor, batch_size=batch_size, shuffle=False)

    all_preds = []
    with torch.no_grad():
        for batch_X in sequences_batch:
            output = model(batch_X.to(device))
            # assumes output is (batch, window_size, n_classes) — TODO confirm
            predicted = torch.argmax(output, dim=2)
            # Sequences overlap by window_size - step epochs; keep only the
            # first `step` epochs of each (presumably so each epoch is scored once).
            # NOTE(review): if create_sequences starts a window every `step`
            # epochs, the last window_size - step epochs of the final sequence
            # are never scored — confirm against create_sequences.
            all_preds.append(predicted[:, :step].cpu().numpy())

    # Shift the 0-based class indices so stages are labelled from 1.
    score = (np.concatenate(all_preds, axis=0).flatten() + 1).astype(np.int64)

    savemat(save_path, {'score': score, 'power': power})
    print(f"Saved: {save_path}")
