In [11]:
from __future__ import annotations

import re
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Any

import numpy as np
from scipy.io import loadmat

In [31]:
@dataclass
class SubjectData:
    """Epoched motor-imagery trials extracted from one subject's `.mat` file.

    Attributes:
        subject_id: Integer parsed from the `s<N>.mat` file name.
        sampling_rate: The file's `eeg["srate"]` value (samples per second).
        n_trials: The file's `eeg["n_imagery_trials"]` value; epochs per class.
        n_samples: Width of each epoch window, in samples.
        X_left_raw: Epochs cut from `eeg["imagery_left"]`,
            shape (n_trials, N_EEG_CHANNELS, n_samples).
        X_right_raw: Epochs cut from `eeg["imagery_right"]`, same shape.
    """

    subject_id: int
    sampling_rate: int
    n_trials: int
    n_samples: int
    X_left_raw: np.ndarray   # (n_trials, N_EEG_CHANNELS, n_samples)
    X_right_raw: np.ndarray  # (n_trials, N_EEG_CHANNELS, n_samples)


class Dataset:
    """Load every `s<N>.mat` file in a directory and cut fixed-length imagery epochs.

    For each subject file the `eeg` struct is read, imagery onsets are taken
    from the samples where `imagery_event == 1`, and the window
    [onset + T_START, onset + T_END) is copied out of both imagery streams for
    the first `n_imagery_trials` onsets. Only the first `N_EEG_CHANNELS` rows
    of each stream are kept.

    Attributes:
        dataset_path: Directory containing the subject files.
        subject_data: Dict mapping subject id -> SubjectData, ordered by
            ascending subject id.
    """

    _FNAME_RE = re.compile(r"^s(\d+)\.mat$")

    # Paper window for classification: 0.5–2.5 seconds after onset
    T_START = 0.5
    T_END = 2.5

    # EEG channels 1–64 (65–68 are EMG)
    N_EEG_CHANNELS = 64

    def __init__(self, dataset_path):
        """Resolve `dataset_path` and eagerly load all subject files found in it.

        Raises:
            ValueError: see `_load_subject_data`.
        """
        self.dataset_path = Path(dataset_path)
        self.subject_data = self._load_subject_data()

    def _load_subject_data(self) -> Dict[int, SubjectData]:
        """Load and epoch every subject file under `self.dataset_path`.

        Returns:
            Dict mapping subject id -> SubjectData, in ascending id order.

        Raises:
            ValueError: if the directory does not exist, contains no `s#.mat`
                file, contains two files with the same subject id, or a file
                is missing the `eeg` key, has too few onsets/channels, or has
                an epoch window that falls outside the recorded stream.
        """
        if not self.dataset_path.is_dir():
            raise ValueError(f"Could not find directory `{self.dataset_path}`")

        return {
            subject_id: self._load_subject(subject_id, path)
            for subject_id, path in self._find_subject_files()
        }

    def _find_subject_files(self):
        """Return `(subject_id, path)` pairs for all `s#.mat` files, sorted by id."""
        found = {}
        for p in self.dataset_path.iterdir():
            if not p.is_file():
                continue
            m = self._FNAME_RE.match(p.name)
            if m is None:
                continue
            subject_id = int(m.group(1))
            # `s1.mat` and `s01.mat` parse to the same id; refuse to pick one
            # silently, since which one wins would depend on directory order.
            if subject_id in found:
                raise ValueError(
                    f"Subject id {subject_id} is used by both "
                    f"`{found[subject_id].name}` and `{p.name}`"
                )
            found[subject_id] = p
        if not found:
            raise ValueError(f"No subject files like `s#.mat` found in `{self.dataset_path}`")

        # iterdir() yields entries in arbitrary order; sort so the resulting
        # dict (and anything iterating over it) is reproducible across runs.
        return sorted(found.items())

    def _eeg_rows(self, stream, field, fname):
        """Return the first `N_EEG_CHANNELS` rows of a 2-D (channels, samples) stream."""
        stream = np.asarray(stream)
        if stream.ndim != 2 or stream.shape[0] < self.N_EEG_CHANNELS:
            raise ValueError(
                f"`{fname}`: '{field}' has shape {stream.shape}, expected "
                f"(channels >= {self.N_EEG_CHANNELS}, samples)"
            )
        return stream[: self.N_EEG_CHANNELS, :]

    def _load_subject(self, subject_id, p):
        """Read one subject file and cut its left/right imagery epochs."""
        mat = loadmat(p, simplify_cells=True)
        if "eeg" not in mat:
            raise ValueError(f"`{p.name}` missing key 'eeg'")
        eeg = mat["eeg"]
        srate = int(eeg["srate"])
        n_trials = int(eeg["n_imagery_trials"])
        onsets = np.flatnonzero(np.asarray(eeg["imagery_event"]) == 1)
        if len(onsets) < n_trials:
            raise ValueError(f"`{p.name}`: found {len(onsets)} imagery onsets, expected at least {n_trials}")

        # Window bounds in samples, relative to each onset.
        start = int(round(self.T_START * srate))
        end = int(round(self.T_END * srate))
        win = end - start
        if win <= 0:
            raise ValueError("T_END must be > T_START")

        left_stream = self._eeg_rows(eeg["imagery_left"], "imagery_left", p.name)
        right_stream = self._eeg_rows(eeg["imagery_right"], "imagery_right", p.name)
        X_left = np.empty(
            (n_trials, self.N_EEG_CHANNELS, win),
            dtype=left_stream.dtype
        )
        X_right = np.empty(
            (n_trials, self.N_EEG_CHANNELS, win),
            dtype=right_stream.dtype
        )
        # Bound against the shorter stream so a truncated right stream is
        # reported as out-of-bounds instead of failing on the array assignment.
        stream_len = min(left_stream.shape[1], right_stream.shape[1])
        for i, onset in enumerate(onsets[:n_trials]):
            a = int(onset) + start
            b = int(onset) + end
            if a < 0 or b > stream_len:
                raise ValueError(
                    f"`{p.name}`: trial {i} window [{a}:{b}] out of bounds "
                    f"for stream length {stream_len}"
                )
            X_left[i] = left_stream[:, a:b]
            X_right[i] = right_stream[:, a:b]

        return SubjectData(
            subject_id=subject_id,
            sampling_rate=srate,
            n_trials=n_trials,
            n_samples=win,
            X_left_raw=X_left,
            X_right_raw=X_right,
        )

    def print_info(self):
        """Print subject count, epoch window, channel count and trial-tensor size in MB."""
        total_bytes = sum(
            (s.X_left_raw.nbytes + s.X_right_raw.nbytes)
            for s in self.subject_data.values()
        )
        print(f"Total subjects: {len(self.subject_data)}")
        print(f"Epoch window: {self.T_START:.3f}–{self.T_END:.3f} s")
        print(f"Channels stored: {self.N_EEG_CHANNELS}")
        print(f"Total subject data (trial tensors only): {total_bytes / 1024**2:.2f} MB")


In [32]:
# Dataset location, written relative to the current user's home directory so the
# notebook does not hardcode one machine's absolute path.
_DATASET_DIR_TEMPLATE = '~/data/BCI/1_EEG_LRhandMI'
DATASET_PATH = os.path.expanduser(_DATASET_DIR_TEMPLATE)
DATASET_PATH

'/Users/anna/data/BCI/1_EEG_LRhandMI'

In [33]:
# Constructing Dataset loads and epochs every subject file under DATASET_PATH
# (the loading happens inside __init__); print_info then reports the totals.
ds = Dataset(DATASET_PATH)
ds.print_info()

Total subjects: 18
Epoch window: 0.500–2.500 s
Channels stored: 64
Total subject data (trial tensors only): 910.00 MB


Total EEG data: 4219.19 MB


In [3]:
# Open a single subject file directly to inspect its raw structure.
s52_path = os.path.join(DATASET_PATH, 's52.mat')
mat = loadmat(s52_path, simplify_cells=True)

# Names starting with "__" are skipped so only the data variables are listed.
data_keys = [name for name in mat if not name.startswith("__")]
print(data_keys)

['eeg']


In [5]:
eeg = mat["eeg"]

print("eeg keys:", list(eeg.keys()))


def _summarize_field(value):
    """Return a short one-line description of a field value.

    Plain scalars show their value, lists/tuples their length, anything with a
    `shape` attribute its shape and dtype; everything else yields "".
    """
    if isinstance(value, (str, int, float, bool)):
        return f"value={value}"
    if isinstance(value, (list, tuple)):
        return f"len={len(value)}"
    shape = getattr(value, "shape", None)
    if shape is None:
        return ""
    dtype = getattr(value, "dtype", None)
    return f"shape={shape} dtype={dtype}"


# One aligned row per field: name, Python type name, short summary.
for field, value in eeg.items():
    print(f"{field:20s} {type(value).__name__:15s} {_summarize_field(value)}")


eeg keys: ['noise', 'rest', 'srate', 'movement_left', 'movement_right', 'movement_event', 'n_movement_trials', 'imagery_left', 'imagery_right', 'n_imagery_trials', 'frame', 'imagery_event', 'comment', 'subject', 'bad_trial_indices', 'psenloc', 'senloc']
noise                ndarray         shape=(5,) dtype=object
rest                 ndarray         shape=(68, 34048) dtype=float32
srate                int             value=512
movement_left        ndarray         shape=(68, 71680) dtype=float32
movement_right       ndarray         shape=(68, 71680) dtype=float32
movement_event       ndarray         shape=(71680,) dtype=uint8
n_movement_trials    int             value=20
imagery_left         ndarray         shape=(68, 358400) dtype=float32
imagery_right        ndarray         shape=(68, 358400) dtype=float32
n_imagery_trials     int             value=100
frame                ndarray         shape=(2,) dtype=int16
imagery_event        ndarray         shape=(358400,) dtype=uint8
comment  

In [8]:
# Sample indices at which the imagery event channel equals 1.
events = eeg["imagery_event"]
is_onset = events == 1
onsets = np.nonzero(is_onset)[0]
len(onsets)

100

In [None]:
fs = eeg["srate"]

# Epoch bounds in samples, relative to each imagery onset. These assignments
# were missing from the notebook, so this cell and the epoching cell below
# raised NameError on a fresh kernel.
# NOTE(review): values reconstructed from the recorded outputs (t_end == 1536
# with srate 512, epoch width 1536 => t_start == 0) — confirm the intended window.
t_start = 0
t_end = 3 * fs

t_end

1536

In [10]:
# Cut one [t_start, t_end) window per onset out of the left-imagery stream
# (all rows of the stream are kept here) and stack them along a new first axis.
left_stream = eeg["imagery_left"]
trial_onsets = onsets[:eeg["n_imagery_trials"]]

X_left_raw = np.stack(
    [left_stream[:, onset + t_start : onset + t_end] for onset in trial_onsets]
)
X_left_raw.shape

(100, 68, 1536)