### EEGFormer for SEED dataset
#### This notebook demonstrates only preprocessing and a single model forward pass (no training).
##### EEGFormer on raw SEED data (9_1.cnt): because of memory limitations, only part of the original recording is processed — the first 16 channels and the first 300 seconds. (According to the SEED introduction file, 24 s → 264 s: label = 1; 289 s → 526 s: label = 0.)

In [21]:
import os
import sys
from pathlib import Path
from math import gcd

import numpy as np
import torch
import mne
from scipy import signal

In [22]:
# Walk up from the working directory to the first folder that contains both
# `pretrained_weights/` and `data/`; every other path is resolved from it.
root = None
for p in [Path.cwd()] + list(Path.cwd().parents):
    if (p / "pretrained_weights").exists() and (p / "data").exists():
        root = p
        break

if root is None:
    # Fail with an actionable message instead of a TypeError on `None / "data"` below.
    raise FileNotFoundError(
        "Could not find a project root containing both 'pretrained_weights' and "
        f"'data' at or above {Path.cwd()}"
    )

print("Project root:", root)

CNT_PATH = root / "data" / "seed_eeg" / "9_1.cnt"
print("CNT:", CNT_PATH)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

Project root: f:\TUD\Phd_application\Oulu\Code
CNT: f:\TUD\Phd_application\Oulu\Code\data\seed_eeg\9_1.cnt
device: cuda


In [None]:
# Stream windows to avoid huge memory allocation
def iter_seed_cnt_windows(
    raw: mne.io.BaseRaw,
    t_max_sec: float = 300.0,
    win_sec: float = 4.0,
    stride_sec: float = 1.0,
    first_n_ch: int = 16,
    target_fs: float | None = 250.0,
    zscore_mode: str = "per_window",
):

    sfreq = float(raw.info["sfreq"])
    max_samp = int(t_max_sec * sfreq)
    win_samp = int(round(win_sec * sfreq))
    stride_samp = int(round(stride_sec * sfreq))

    if target_fs is not None:
        orig_fs_int = int(round(sfreq))
        tgt_fs_int  = int(round(target_fs))
        g = gcd(orig_fs_int, tgt_fs_int)
        up = tgt_fs_int // g
        down = orig_fs_int // g
    else:
        up = down = None

    for s in range(0, max_samp - win_samp + 1, stride_samp):
        seg = raw.get_data(start=s, stop=s + win_samp)
        seg = seg[:first_n_ch].astype(np.float32)

        if target_fs is not None:
            seg = signal.resample_poly(seg, up=up, down=down, axis=1).astype(np.float32)

        if zscore_mode == "per_window":
            mean = seg.mean(axis=1, keepdims=True)
            std  = seg.std(axis=1, keepdims=True) + 1e-8
            seg = (seg - mean) / std
        elif zscore_mode == "none":
            pass
        else:
            raise ValueError("zscore_mode must be 'per_window' or 'none'")

        yield s, seg.T

In [None]:
T_MAX_SEC   = 300.0  # the first 300 seconds
FIRST_N_CH  = 16     # the first 16 EEG channels
WIN_SEC     = 4.0    # window length (s)
STRIDE_SEC  = 1.0    # stride between window starts (s)

TARGET_FS   = 250.0  # resample to 250 Hz to reduce compute
ZSCORE_MODE = "per_window"

# Open the recording lazily (preload=False); samples are read per window later.
raw = mne.io.read_raw_cnt(str(CNT_PATH), preload=False, verbose="ERROR")
# Keep EEG channels only; exclude=[] also keeps channels marked as bad.
picks = mne.pick_types(raw.info, eeg=True, eog=False, emg=False, stim=False, exclude=[])
raw.pick(picks)

# Keep the first FIRST_N_CH EEG channels (`pick` replaces the legacy `pick_channels`).
raw.pick(raw.ch_names[:FIRST_N_CH])
# The model below is built with input_channels=FIRST_N_CH, so fail early on a mismatch.
assert len(raw.ch_names) == FIRST_N_CH, (
    f"expected {FIRST_N_CH} EEG channels, found {len(raw.ch_names)}"
)

sfreq = float(raw.info["sfreq"])
print("Original sfreq:", sfreq)
print("Selected channels:", len(raw.ch_names))
print("ch_names:", raw.ch_names)

NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).
Original sfreq: 1000.0
Selected channels: 16
ch_names: ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5']


  raw = mne.io.read_raw_cnt(str(CNT_PATH), preload=False, verbose="ERROR")
  raw = mne.io.read_raw_cnt(str(CNT_PATH), preload=False, verbose="ERROR")


In [25]:
# Materialise every window from the streaming generator, then stack them into
# one (N, L, C) tensor on `device`.
all_w = []
all_start_samples = []  # window starts, in samples at the ORIGINAL sfreq

for s, w_lc in iter_seed_cnt_windows(
    raw,
    t_max_sec=T_MAX_SEC,
    win_sec=WIN_SEC,
    stride_sec=STRIDE_SEC,
    first_n_ch=FIRST_N_CH,
    target_fs=TARGET_FS,
    zscore_mode=ZSCORE_MODE,
):
    all_start_samples.append(s)
    all_w.append(w_lc)

if not all_w:
    raise ValueError(
        f"No windows produced: need at least WIN_SEC={WIN_SEC}s of data within "
        f"the first T_MAX_SEC={T_MAX_SEC}s of the recording"
    )

windows = np.stack(all_w, axis=0)
# from_numpy wraps the stacked array without the extra host copy torch.tensor() makes.
windows_torch = torch.from_numpy(windows).to(device=device, dtype=torch.float32)

print("windows_torch:", windows_torch.shape, "(N, L, C)")
print("single sample:", windows_torch[0].shape, "(L, C)")
print("num windows:", len(all_start_samples))

windows_torch: torch.Size([297, 1000, 16]) (N, L, C)
single sample: torch.Size([1000, 16]) (L, C)
num windows: 297


In [None]:
# Make the project package importable from the notebook.
if str(root) not in sys.path:
    sys.path.append(str(root))

from EEGformer.models import EEGformer

# No checkpoint is loaded in this cell; presumably the weights are randomly
# initialised (confirm in EEGformer.models). If so, the prediction below only
# shows that the forward pass runs. Seed the RNG so the printed output is
# reproducible across kernel restarts.
MODEL_INIT_SEED = 0
torch.manual_seed(MODEL_INIT_SEED)

x0 = windows_torch[0]  # a single (L, C) window

model = EEGformer(
    input=x0,
    num_cls=2,
    input_channels=FIRST_N_CH,
    kernel_size=10,
    num_blocks=1,
    num_heads_RTM=1,
    num_heads_STM=1,
    num_heads_TTM=1,
    num_submatrices=1,
    CF_second=2,
    dtype=torch.float32
).to(device)

# Inference mode (no training here)
model.eval()

with torch.no_grad():
    y = model(x0)  # output probabilities

print("y:", y)
print("predicted class:", torch.argmax(y, dim=1).item())

y: tensor([[0.6274, 0.3726]], device='cuda:0')
predicted class: 0
