# Prepare EEG dataset for REVE

## Resample EEG from 256 Hz to 200 Hz

In [5]:
import os
import numpy as np
from scipy.signal import resample

# --------------------------------------------------------------
# Configuration: dataset location and windowing of the stored data
# --------------------------------------------------------------
data_root = "/gpfs1/pi/djangraw/mindless_reading/data"
sfreq = 256          # sampling rate (Hz) of the windows on disk
window_seconds = 2   # window length in seconds
window_size = int(window_seconds * sfreq)  # samples per window; used in file/dir names

# --------------------------------------------------------------
# Subject discovery: every directory directly under data_root
# whose name starts with "s", in sorted order
# --------------------------------------------------------------
all_subjects = [
    entry
    for entry in os.listdir(data_root)
    if entry.startswith("s") and os.path.isdir(os.path.join(data_root, entry))
]
all_subjects.sort()

# ==========================
# Main loop over subjects
# ==========================
# Output sampling rate (Hz). The resampling ratio below is derived from this
# and from the `sfreq` config value, so it can no longer drift out of sync
# with `sfreq` the way a literal `200 / 256` would if the config is edited.
target_sfreq = 200
# Number of leading entries along axis 1 kept as EEG.
# NOTE(review): assumes the first 64 columns are the EEG channels -- confirm
# against the subject's col_names file.
n_eeg_channels = 64

for subject_id in all_subjects:
    subject_dir = os.path.join(
        data_root,
        subject_id,
        "ml_data",
        f"{window_size}window_datasets",
    )

    # Windowed recording for this subject; the slicing below requires a 3-D array.
    data_file = os.path.join(subject_dir, f"{subject_id}_{window_size}windowed_data.npy")
    data = np.load(data_file)
    eeg_data = data[:, :n_eeg_channels, :]

    # Samples per window after resampling (512 -> 400 for 256 Hz -> 200 Hz).
    n_samples_out = int(eeg_data.shape[-1] * target_sfreq / sfreq)
    # scipy.signal.resample is FFT-based and is applied along the last axis only.
    eeg_data = resample(eeg_data, n_samples_out, axis=-1)

    # The output file name carries the NEW window length in samples; it is
    # written next to the source file, inside the original-length directory.
    np.save(os.path.join(subject_dir, f"{subject_id}_{eeg_data.shape[-1]}windowed_eeg_data.npy"), eeg_data)

# REVE

## Load 200 Hz EEG as a dataset

In [10]:
import os
import numpy as np
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

# --------------------------------------------------------------
# Configuration: dataset location and windowing of the resampled data
# --------------------------------------------------------------
data_root = "/gpfs1/pi/djangraw/mindless_reading/data"
sfreq = 200          # sampling rate (Hz) of the resampled EEG files
window_seconds = 2   # window length in seconds
window_size = int(window_seconds * sfreq)  # samples per resampled window

# Accumulators filled by the per-subject loop
X_all = []        # one data array per subject
y_all = []        # one label array per subject
chan_list = None  # channel names, read from the first subject only

# --------------------------------------------------------------
# Subject discovery: every directory directly under data_root
# whose name starts with "s", in sorted order
# --------------------------------------------------------------
all_subjects = [
    entry
    for entry in os.listdir(data_root)
    if entry.startswith("s") and os.path.isdir(os.path.join(data_root, entry))
]
all_subjects.sort()

# ==========================
# Main loop over subjects
# ==========================
# The per-subject directory and the labels file are named after the window
# length in samples at 256 Hz, not the resampled length in `window_size`.
# Computed once here instead of repeating `256*window_seconds` inline.
orig_window_size = 256 * window_seconds

for subject_id in all_subjects:
    subject_dir = os.path.join(
        data_root,
        subject_id,
        "ml_data",
        f"{orig_window_size}window_datasets",
    )

    # Load channel list once
    if chan_list is None:
        col_file = os.path.join(subject_dir, f"{subject_id}_col_names.npy")
        # NOTE(review): allow_pickle=True unpickles the file contents, which can
        # execute arbitrary code -- only load col_names files you trust.
        col_names = np.load(col_file, allow_pickle=True)
        chan_list = col_names[:64].tolist()
        # Standardize channel names ("Afz" to "AFz")
        chan_list = ["AFz" if ch == "Afz" else ch for ch in chan_list]

    data_file = os.path.join(subject_dir, f"{subject_id}_{window_size}windowed_eeg_data.npy")
    labels_file = os.path.join(subject_dir, f"{subject_id}_{orig_window_size}windowed_labels.npy")
    data = np.load(data_file)
    labels = np.load(labels_file)

    # Validate per subject: mismatches of opposite sign in two subjects would
    # cancel out in the concatenated totals and silently misalign every label
    # after the first bad subject.
    if len(data) != len(labels):
        raise ValueError(
            f"{subject_id}: {len(data)} EEG windows but {len(labels)} labels "
            f"({data_file} vs {labels_file})"
        )

    X_all.append(data)
    y_all.append(labels)

# Join the per-subject arrays along the first axis
X_all = np.concatenate(X_all)
y_all = np.concatenate(y_all)

print("Final dataset shape:", X_all.shape, y_all.shape)

# Stratified 80/10/10 split: hold out 20 % first, then halve the hold-out
# into validation and test. Both calls use the same fixed random_state.
X_train, X_tmp, y_train, y_tmp = train_test_split(
    X_all,
    y_all,
    test_size=0.2,
    stratify=y_all,
    random_state=42,
)
X_val, X_test, y_val, y_test = train_test_split(
    X_tmp,
    y_tmp,
    test_size=0.5,
    stratify=y_tmp,
    random_state=42,
)

# Wrap each split in a Hugging Face Dataset and group them in a DatasetDict
def _as_hf_dataset(features, targets):
    """Build a Dataset with a "data" column and a "labels" column from one split."""
    return Dataset.from_dict({"data": list(features), "labels": list(targets)})


train_ds = _as_hf_dataset(X_train, y_train)
val_ds = _as_hf_dataset(X_val, y_val)
test_ds = _as_hf_dataset(X_test, y_test)

dataset = DatasetDict({"train": train_ds, "val": val_ds, "test": test_ds})

# Have both columns returned in the "torch" format when items are accessed
dataset.set_format("torch", columns=["data", "labels"])
print(dataset)

Final dataset shape: (5604, 64, 400) (5604,)
DatasetDict({
    train: Dataset({
        features: ['data', 'labels'],
        num_rows: 4483
    })
    val: Dataset({
        features: ['data', 'labels'],
        num_rows: 560
    })
    test: Dataset({
        features: ['data', 'labels'],
        num_rows: 561
    })
})


## Load REVE model from HF

In [7]:
from transformers import AutoModel

# Both checkpoints are loaded with trust_remote_code=True.
# NOTE(review): that flag runs Python code downloaded from the Hub repo;
# consider pinning `revision=` to a reviewed commit.
hub_kwargs = {"trust_remote_code": True}
pos_bank = AutoModel.from_pretrained("brain-bzh/reve-positions", **hub_kwargs)
model = AutoModel.from_pretrained("brain-bzh/reve-large", **hub_kwargs)

2026-02-09 10:35:02.761418: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


flash_attn not found, install it with `pip install flash_attn` if you want to use it


## Model and dataset setup

In [11]:
import torch
from transformers import set_seed
from functools import partial

# Seed all RNGs and pin cuDNN to deterministic kernels BEFORE any randomly
# initialised layer is created. Seeding after building the head (as this cell
# previously did) leaves the Linear layer's initial weights unseeded, so they
# differ on every kernel restart.
set_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

ch_num = len(chan_list)
# Flattened feature size fed to the classification head.
# NOTE(review): 512 is hard-coded -- confirm it matches the per-token feature
# width that brain-bzh/reve-large passes to final_layer, and that a window of
# `window_seconds` seconds yields `window_seconds` tokens per channel.
dim = ch_num * window_seconds * 512

# Replace the model's final layer with a 2-class classification head.
model.final_layer = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.RMSNorm(dim),
    torch.nn.Dropout(0.1),
    torch.nn.Linear(dim, 2),
)

# Training parameters
batch_size = 64
n_epochs = 20
lr = 1e-3
# Positions for the channel names in chan_list, looked up once and shared by
# every batch.
positions = pos_bank(chan_list)

def collate(batch, positions):
    """Assemble a list of dataset items into one batch dictionary.

    Args:
        batch: sequence of items, each indexable by "data" (a tensor) and
            "labels" (a scalar label).
        positions: position tensor shared by every item in the batch.

    Returns:
        dict with keys "sample" (item data stacked along a new first axis),
        "label" (labels as an int64 tensor) and "pos" (`positions` repeated
        len(batch) times via ``repeat(len(batch), 1, 1)``).
    """
    n_items = len(batch)
    samples = []
    targets = []
    for item in batch:
        samples.append(item["data"])
        targets.append(item["labels"])
    return {
        "sample": torch.stack(samples),
        "label": torch.tensor(targets).long(),
        "pos": positions.repeat(n_items, 1, 1),
    }
# Bind the shared positions so DataLoader can call collate with the batch alone
collate_fn = partial(collate, positions=positions)

# Only the training split is shuffled; val and test keep dataset order
train_loader = torch.utils.data.DataLoader(
    dataset["train"],
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
val_loader = torch.utils.data.DataLoader(
    dataset["val"],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)
test_loader = torch.utils.data.DataLoader(
    dataset["test"],
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_fn,
)