In [None]:
from moabb.datasets.fake import FakeDataset
from moabb.datasets.base import BaseDataset
from moabb.paradigms import SSVEP  
from moabb.pipelines import SSVEP_CCA, SSVEP_MsetCCA
import os
import pandas as pd
import mne
import numpy as np

In [122]:
# --- Dataset configuration ---------------------------------------------------
n_subjects = 30
n_sessions = 8
start_time = 1  # start (s) of the dataset interval passed to BaseDataset
end_time = 5    # end (s) of the dataset interval passed to BaseDataset
split = "train"  # selects the data sub-directory and the metadata CSV name

# Metadata label -> stimulus frequency, used as the MOABB event name.
label_to_freq = {"Left": "10", "Right": "13", "Forward": "7", "Backward": "8"}
# Event name -> integer code written into the stim channel.
event_mapping = {"10": 1, "13": 2, "7": 3, "8": 4}
# Classifier output index -> event name, derived from event_mapping's insertion
# order so the two mappings cannot drift apart. NOTE(review): this matches the
# classifiers' integer outputs only while the paradigm keeps events in this same
# order (it picks the first n_classes) — re-check if event_mapping is reordered.
event_mapping_decoder = np.vectorize(dict(enumerate(event_mapping)).get)


class CompetitionDataset(BaseDataset):
    """MOABB dataset wrapper for the MTC-AIC3 SSVEP recordings.

    Reads one ``EEGdata.csv`` per (subject, session) from
    ``./data/mtcaic3/SSVEP/<split>/S<subject>/<session>/`` and looks up each
    trial's label in ``./data/mtcaic3/<split>.csv``. Relies on the notebook-level
    configuration globals ``n_subjects``, ``n_sessions``, ``start_time``,
    ``end_time``, ``split``, ``event_mapping`` and ``label_to_freq``.
    """

    # EEG columns read from each session CSV; any other columns are ignored.
    EEG_COLUMNS = ["FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8"]
    SFREQ = 250  # sampling frequency (Hz)
    N_TRIALS = 10  # trials per session file
    TRIAL_SAMPLES = 1750  # samples per trial (7 s at SFREQ)

    def __init__(self):
        super().__init__(
            subjects=list(range(1, n_subjects + 1)),
            sessions_per_subject=n_sessions,
            events=event_mapping,
            code="Competition",
            interval=[start_time, end_time],
            paradigm="ssvep",
        )

        self.base_path = "./data/mtcaic3/SSVEP"
        self.metadata_path = os.path.join('./data/mtcaic3', f"{split}.csv")

    def data_path(self, subject, path=None, force_update=False, update_path=None, verbose=None):  # type: ignore
        """Return the existing session CSV paths for ``subject``, in session order.

        The extra parameters only mirror ``BaseDataset.data_path``'s signature;
        files are read from ``self.base_path`` and nothing is downloaded.
        Missing session files are reported with a printed warning and skipped.
        """
        subject_dir = os.path.join(self.base_path, split, f"S{int(subject)}")
        csv_files = []

        for session in range(1, n_sessions + 1):
            csv_file = os.path.join(subject_dir, str(session), "EEGdata.csv")
            if os.path.exists(csv_file):
                csv_files.append(csv_file)
            else:
                print(f"Warning: {csv_file} does not exist for subject {subject}, session {session}")

        return csv_files

    def _get_single_subject_data(self, subject):  # type: ignore
        """Load every available session of ``subject`` as MNE ``RawArray`` objects.

        Returns:
            ``{session_name: {"0": raw}}``. ``session_name`` is the zero-based
            session number as a string. Each ``raw`` holds ``EEG_COLUMNS`` plus a
            ``stim`` channel carrying ``event_mapping`` codes at each trial start.

        Raises:
            ValueError: if a trial has no usable metadata label.
        """
        sessions = {}
        ch_names = self.EEG_COLUMNS + ["stim"]
        ch_types = ["eeg"] * len(self.EEG_COLUMNS) + ["stim"]

        for csv_file in self.data_path(subject):
            # Take the session number from the ".../<session>/EEGdata.csv" path
            # instead of the list position. data_path() skips missing files, so a
            # positional index would shift later sessions onto the wrong labels.
            session_id = int(os.path.basename(os.path.dirname(csv_file)))

            df = pd.read_csv(csv_file, usecols=self.EEG_COLUMNS)
            # usecols keeps file order; re-select to force EEG_COLUMNS order.
            # NOTE(review): MNE treats EEG samples as volts; if the CSV stores
            # microvolts, this data should be scaled by 1e-6 — confirm.
            eeg_data = df[self.EEG_COLUMNS].to_numpy().T  # (n_channels, n_times)

            info = mne.create_info(ch_names=ch_names, ch_types=ch_types, sfreq=self.SFREQ)  # type: ignore

            # One non-zero marker at the first sample of each fixed-length trial.
            # NOTE(review): trial 0's marker lands on sample 0. mne.find_events
            # ignores a non-zero first sample unless initial_event=True, so the
            # first trial of each session may be dropped downstream — confirm.
            stim_data = np.zeros(len(df))
            for trial in range(self.N_TRIALS):
                trial_start = trial * self.TRIAL_SAMPLES
                if trial_start >= len(df):
                    break  # recording shorter than N_TRIALS full trials
                trial_label = self._get_trial_label(subject, session_id, trial + 1)
                stim_data[trial_start] = event_mapping[trial_label]

            full_data = np.vstack([eeg_data, stim_data[np.newaxis, :]])
            raw = mne.io.RawArray(data=full_data, info=info, verbose=False)

            sessions[str(session_id - 1)] = {"0": raw}  # single run per session

        return sessions

    def _get_trial_label(self, subject_id, session_id, trial_idx):
        """Return the stimulus frequency (a key of ``event_mapping``) for one trial.

        Args:
            subject_id: subject number (matched against ``"S<subject_id>"``).
            session_id: one-based session number (``trial_session`` column).
            trial_idx: one-based trial number (``trial`` column).

        Raises:
            ValueError: if the metadata has no SSVEP row for this trial, or its
                label is not a key of ``label_to_freq``.
        """
        metadata_df = self._load_metadata()
        subject_str = f"S{subject_id}"

        trial_row = metadata_df[
            (metadata_df["subject_id"] == subject_str)
            & (metadata_df["trial_session"] == session_id)
            & (metadata_df["trial"] == trial_idx)
            & (metadata_df["task"] == "SSVEP")  # redundant with _load_metadata; kept as a safety filter
        ]
        if trial_row.empty:
            raise ValueError(
                f"No SSVEP metadata row for {subject_str}, session {session_id}, trial {trial_idx}"
            )

        label = trial_row.iloc[0]["label"]
        if label not in label_to_freq:
            # Previously this returned None and wrote NaN into the stim channel.
            raise ValueError(
                f"Unknown label {label!r} for {subject_str}, session {session_id}, trial {trial_idx}"
            )
        return label_to_freq[label]

    def _load_metadata(self):
        """Read the metadata CSV on first use, keep only SSVEP rows, and cache it."""
        if not hasattr(self, "_metadata_df"):
            self._metadata_df = pd.read_csv(self.metadata_path)
            self._metadata_df = self._metadata_df[self._metadata_df["task"] == "SSVEP"]

        return self._metadata_df


dataset = CompetitionDataset()

# Load subject 1 straight through the dataset's own loader and display the
# resulting {session: {run: Raw}} mapping.
subject1_sessions = dataset._get_single_subject_data(1)
subject1_sessions

The dataset class name 'CompetitionDataset' must be an abbreviation of its code 'Competition'. See moabb.datasets.base.is_abbrev for more information.


{'0': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '1': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '2': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '3': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '4': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '5': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '6': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>},
 '7': {'0': <RawArray | 9 x 17500 (70.0 s), ~1.2 MiB, data loaded>}}

In [None]:
# Keep only the first 3 event classes; epochs end at tmax=4 s.
paradigm = SSVEP(n_classes=3, tmax=4)

# Epoched data for subjects 1 and 2, and the event-name -> code mapping kept by the paradigm.
X, y, metadata = paradigm.get_data(dataset, subjects=[1, 2])
freqs = paradigm.used_events(dataset)

# Give the CCA classifier the same time window the paradigm epochs with.
interval = [paradigm.tmin, paradigm.tmax]
cca_clf = SSVEP_CCA(
    interval=interval,
    freqs=freqs,
    n_harmonics=3,
)
cca_clf.fit(X, y)
y_pred = cca_clf.predict(X)

Choosing the first 3 classes from all possible events
 '10': 5
 '13': 0
 '7': 1>
  warn(f"warnEpochs {epochs}")
 '10': 4
 '13': 0
 '7': 1>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 5
 '7': 1>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 1
 '7': 1>
  warn(f"warnEpochs {epochs}")
 '10': 3
 '13': 3
 '7': 0>
  warn(f"warnEpochs {epochs}")
 '10': 1
 '13': 5
 '7': 1>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 4
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 3
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 3
 '13': 3
 '7': 0>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 1
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 5
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 1
 '13': 3
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 4
 '13': 2
 '7': 1>
  warn(f"warnEpochs {epochs}")
 '10': 2
 '13': 2
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 3
 '13': 4
 '7': 2>
  warn(f"warnEpochs {epochs}")
 '10': 4
 '13': 2
 '7': 3>
  warn(f"warnEpochs {epochs}")
 '10': 5
 '13': 0


In [124]:
# Example with SSVEP_MsetCCA
msetcca_clf = SSVEP_MsetCCA(freqs=freqs, n_filters=2)
msetcca_clf.fit(X, y)
y_pred_mset = msetcca_clf.predict(X)

print(f"CCA predictions: {y_pred}")
print(f"MsetCCA predictions: {y_pred_mset}")

KeyboardInterrupt: 

In [None]:
# Turn CCA's integer outputs back into event names, then report the fraction that match.
decoded_y_pred = event_mapping_decoder(y_pred)
correct = int(np.count_nonzero(y == decoded_y_pred))
print(correct / len(y))

0.5839598997493735

