In [None]:
from moabb.datasets import PhysionetMI, Cho2017, BNCI2014_001, Weibo2014
from moabb.paradigms import LeftRightImagery
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
import pandas as pd
from mne import create_info
from mne.io import RawArray
from moabb.datasets.base import BaseDataset
import mne
import os


# %%
# Root folder of the competition data (label CSVs plus per-task EEG folders).
# Set the MTCAIC3_DATA_PATH environment variable to use another location
# without editing this file; the original hard-coded path stays the default.
data_path = os.environ.get(
    "MTCAIC3_DATA_PATH",
    "/home/zeyadcode/Workspace/ai_projects/eeg_detection/data/mtcaic3",
)


class CompetitionDataset(BaseDataset):
    """MOABB dataset wrapper around the competition CSV recordings.

    Labels are read from ``<data_path>/<split>.csv`` and EEG samples from
    ``<data_path>/<task>/<split>/S<id>/<session>/EEGdata.csv``, where
    ``<task>`` is ``"MI"`` or ``"SSVEP"`` depending on the paradigm.

    Args:
        split: Name of the label CSV / sub-directory to read. With
            ``"validation"`` the subject number is shifted by 30 before it is
            looked up (subject 1 -> ``S31``).
    """

    # Every session file is cut into this many equally long trials.
    _TRIALS_PER_SESSION = 10
    # Validation subjects are stored after the training subjects (S31, S32, ...).
    _VALIDATION_SUBJECT_OFFSET = 30

    def __init__(self, split="train"):
        super().__init__(
            subjects=list(range(1, 31)),  # List of subject IDs
            sessions_per_subject=1,  # Number of sessions per subject
            events={"left_hand": 1, "right_hand": 2},
            code="CompetitionDataset",
            interval=[0, 4],  # Time interval for trials
            paradigm="imagery",  # "ssvep" or "imagery" or "p300"
            doi=None,
        )
        self.split = split

    def _subject_row(self, subject):
        """Map a MOABB subject number to the numeric id used on disk."""
        if self.split == "validation":
            return subject + self._VALIDATION_SUBJECT_OFFSET
        return subject

    def _session_to_raw(self, eeg_data, session_trials, channel_names, sfreq):
        """Build one annotated ``RawArray`` from a session's labelled trials.

        Args:
            eeg_data: 2-D array (samples, channels) holding the whole session.
            session_trials: Label rows (``trial``, ``label``) of this session.
            channel_names: Channel names for the MNE info, in column order.
            sfreq: Sampling frequency passed to MNE.

        Returns:
            The ``RawArray`` with one annotation per trial, or ``None`` when
            ``session_trials`` has no rows.
        """
        trial_length = eeg_data.shape[0] // self._TRIALS_PER_SESSION

        segments = []
        events = []
        cursor = 0  # sample index of the next trial inside the concatenation
        for _, trial_row in session_trials.iterrows():
            trial_num = int(trial_row["trial"])  # trial numbers are 1-indexed
            segment = eeg_data[(trial_num - 1) * trial_length : trial_num * trial_length]
            segments.append(segment)

            # Anything that is not "Left" is treated as a right-hand trial.
            label = "left_hand" if trial_row["label"] == "Left" else "right_hand"
            events.append([cursor, 0, self.event_id[label]])
            cursor += segment.shape[0]

        if not segments:
            return None

        continuous_data = np.vstack(segments).T  # Shape: (channels, samples)
        info = create_info(channel_names, sfreq, ["eeg"] * len(channel_names))

        # Scale by 1e-6: MNE stores EEG in volts.
        # NOTE(review): this assumes the CSV values are microvolts - confirm.
        raw = RawArray(continuous_data * 1e-6, info, verbose=False)

        event_desc = {code: name for name, code in self.event_id.items()}
        annotations = mne.annotations_from_events(
            np.array(events), sfreq=sfreq, event_desc=event_desc
        )
        raw.set_annotations(annotations)
        return raw

    def _get_single_subject_data(self, subject):
        """
            Return data for one subject - THIS IS THE KEY METHOD
            tips for Motor Imagery:
                include C3, CZ, C4
                may include FZ, PZ
                don't include PO7, PO8, Oz

            Returns a dict ``{"session_id": {"0": raw}}`` with one entry per
            session whose ``EEGdata.csv`` exists. When the subject has no label
            rows, or no session could be loaded, a warning is printed and the
            placeholder ``{"0": {"0": None}}`` is returned.

            Raises:
                ValueError: if the paradigm is neither "ssvep" nor "imagery".
        """
        ch_names = ["FZ", "C3", "CZ", "C4", "PZ", "PO7", "OZ", "PO8"]
        moabb_channels = ["Fz", "C3", "Cz", "C4", "Pz", "PO7", "Oz", "PO8"]
        if self.paradigm == "ssvep":
            task = "SSVEP"
        elif self.paradigm == "imagery":
            task = "MI"
        else:
            raise ValueError(f"got unexpected paradigm {self.paradigm}")

        subject_row = self._subject_row(subject)
        subject_tag = f"S{subject_row}"

        # Load the label rows that belong to this subject and task.
        labels_df = pd.read_csv(
            os.path.join(data_path, f"{self.split}.csv"),
            usecols=["subject_id", "trial_session", "trial", "task", "label"],
        )
        task_df = labels_df[(labels_df["task"] == task) & (labels_df["subject_id"] == subject_tag)]

        if task_df.empty:
            print(f"\n\n\nWARNING TASK DF EMPTY {subject} AT ROW {subject_row} AT SPLIT {self.split}\n\n\n")
            return {"0": {"0": None}}  # No data for this subject

        sfreq = 250

        sessions = {}
        for session_id in task_df["trial_session"].unique():
            session_trials = task_df[task_df["trial_session"] == session_id]

            fp = os.path.join(data_path, task, self.split, subject_tag, str(session_id), "EEGdata.csv")
            if not os.path.exists(fp):
                continue

            # `usecols` ignores the order of the list and keeps the file's
            # column order, so re-index to line the columns up with
            # ch_names / moabb_channels.
            eeg_data = pd.read_csv(fp, usecols=ch_names)[ch_names].values

            raw = self._session_to_raw(eeg_data, session_trials, moabb_channels, sfreq)
            if raw is None:
                continue

            sessions[str(session_id)] = {"0": raw}

        # An empty dict means no session file was found; report it the same
        # way as a subject without labels instead of returning {} silently.
        if not sessions:
            print(f"\n\n\nWARNING NO EEG SESSIONS LOADED FOR {subject} AT ROW {subject_row} AT SPLIT {self.split}\n\n\n")
            return {"0": {"0": None}}

        # Required format: {"session_id": {"run_id": raw}}
        return sessions

    def data_path(self, subject, path=None, force_update=False, update_path=None, verbose=None):
        """Return the ``EEGdata.csv`` paths found for ``subject``.

        Uses the same task folder name and validation subject offset as
        ``_get_single_subject_data`` so both methods look at the same
        directory. ``path``, ``force_update``, ``update_path`` and ``verbose``
        are accepted but not used.

        Returns:
            List of existing ``EEGdata.csv`` paths, ordered by session folder
            name; empty when the subject folder does not exist.
        """
        task = "MI" if self.paradigm == "imagery" else "SSVEP"
        subject_dir = os.path.join(data_path, task, self.split, f"S{self._subject_row(subject)}")

        subject_paths = []
        if os.path.exists(subject_dir):
            for session in sorted(os.listdir(subject_dir)):
                eeg_file = os.path.join(subject_dir, session, "EEGdata.csv")
                if os.path.exists(eeg_file):
                    subject_paths.append(eeg_file)

        return subject_paths


def load_combined_moabb_data(datasets, paradigm_config=None, subjects_per_dataset=None):
    """
    Load and combine multiple MOABB datasets for DANN training.

    Every dataset is epoched with the same ``LeftRightImagery`` paradigm, the
    string labels are mapped to 0/1, and subject ids are shifted per dataset so
    they do not collide. All epoch arrays are cut to a common length before
    they are concatenated.

    Args:
        datasets: List of MOABB dataset instances (at least one)
        paradigm_config: Dict with paradigm parameters (channels, tmin, tmax, resample)
        subjects_per_dataset: Dict mapping dataset class names to subject lists, or None for all

    Returns:
        X: Combined feature array
        class_labels: Binary class labels (0 = left_hand, 1 = right_hand)
        domain_labels: Dataset-specific subject IDs (continuous across datasets)
        dataset_info: Metadata about each dataset, keyed by dataset class name
            (two datasets of the same class share one key; the later one wins)

    Raises:
        ValueError: if ``datasets`` is empty or a label other than
            "left_hand"/"right_hand" is returned by the paradigm.
    """
    if paradigm_config is None:
        paradigm_config = {
            "channels": ["Cz", "C3", "C4"],
            "tmin": 0.0,
            "tmax": 4.0,
            "resample": 250,
        }

    paradigm = LeftRightImagery(**paradigm_config)
    label_to_class = {"left_hand": 0, "right_hand": 1}

    all_X = []
    all_class_labels = []
    all_domain_labels = []
    dataset_info = {}

    current_subject_offset = 0

    for dataset in datasets:
        dataset_name = dataset.__class__.__name__
        print(f"\nProcessing dataset: {dataset_name}")

        # Get subjects for this dataset
        if subjects_per_dataset and dataset_name in subjects_per_dataset:
            subjects = subjects_per_dataset[dataset_name]
        else:
            subjects = dataset.subject_list

        print(f"Original subject range: {min(subjects)} to {max(subjects)}")

        # Load data for this dataset
        X, labels, metadata = paradigm.get_data(dataset, subjects=subjects)

        # Convert string labels to binary
        class_labels = []
        for label in labels:
            if label not in label_to_class:
                raise ValueError(f"Unexpected label {label}")
            class_labels.append(label_to_class[label])

        # Shift the subject ids in one vectorised step so they do not clash
        # with ids of the datasets processed before this one.
        domain_labels = np.array(metadata["subject"]) + current_subject_offset

        # Offset for the next dataset starts after this dataset's largest id
        max_subject_in_dataset = max(metadata["subject"])
        n_subjects = len(set(metadata["subject"]))

        # Store dataset info
        dataset_info[dataset_name] = {
            "original_subject_range": (min(subjects), max(subjects)),
            "adjusted_subject_range": (current_subject_offset + min(subjects), current_subject_offset + max_subject_in_dataset),
            "n_trials": len(X),
            "n_subjects": n_subjects,
            "subject_offset": current_subject_offset,
        }

        print(f"Adjusted subject range: {dataset_info[dataset_name]['adjusted_subject_range']}")
        print(f"Number of trials: {len(X)}")
        print(f"Number of subjects: {n_subjects}")

        # Accumulate data
        all_X.append(X)
        all_class_labels.extend(class_labels)
        all_domain_labels.extend(domain_labels)

        current_subject_offset += max_subject_in_dataset

    if not all_X:
        raise ValueError("datasets must contain at least one dataset")

    # Cut every dataset to a common number of samples so they can be stacked.
    # Start from the shortest epoch actually returned, then apply the nominal
    # window length when the config specifies it completely.
    n_samples = min(x.shape[-1] for x in all_X)
    tmin = paradigm_config.get("tmin")
    tmax = paradigm_config.get("tmax")
    sfreq = paradigm_config.get("resample")
    if tmin is not None and tmax is not None and sfreq is not None:
        n_samples = min(n_samples, int((tmax - tmin) * sfreq))

    all_X = [x[:, :, :n_samples] for x in all_X]

    # Combine all data
    combined_X = np.concatenate(all_X, axis=0)
    combined_class_labels = np.array(all_class_labels)
    combined_domain_labels = np.array(all_domain_labels)

    print("\n=== COMBINED DATASET SUMMARY ===")
    print(f"Total trials: {len(combined_X)}")
    print(f"Feature shape: {combined_X.shape}")
    print(f"Class distribution: {np.bincount(combined_class_labels)}")
    print(f"Subject range: {min(combined_domain_labels)} to {max(combined_domain_labels)}")
    print(f"Total unique subjects: {len(np.unique(combined_domain_labels))}")

    return combined_X, combined_class_labels, combined_domain_labels, dataset_info


In [None]:
%load_ext autoreload  
%autoreload 2  
  
import random  
import numpy as np  
import torch  
import torch.nn as nn  
from torch.autograd import Function  
import optuna  
from modules.utils import evaluate_model  
import matplotlib.pyplot as plt  
from braindecode.models import EEGInceptionMI, EEGSimpleConv, MSVTNet, FBCNet, ATCNet
from braindecode import EEGClassifier  
import torch.optim as optim  
  
# dataset related  
from modules import CompetitionDataset, load_combined_moabb_data  
from torch.utils.data import DataLoader, TensorDataset  
from moabb.datasets import BNCI2014_001, PhysionetMI, Weibo2014, Cho2017  # 250 hz  
from braindecode.datasets import MOABBDataset, create_from_X_y  
from braindecode.preprocessing import preprocess, Preprocessor, exponential_moving_standardize, filterbank  
from skorch.helper import predefined_split  
from braindecode.datasets.base import BaseConcatDataset  
import random
import mne
import numpy as np
  
# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
device  # last expression of the cell: shows the selected device

In [None]:
# Batch size constant for the notebook.
# NOTE(review): the classifier cells visible here pass literal batch sizes
# (64 / 32) instead of this name - confirm whether it is still used.
batch_size = 64
# Reproducibility helper: defined right after the imports so it can be called
# before anything random happens.
def set_random_seeds(seed=42):
    """Seed Python's, NumPy's and PyTorch's random generators with ``seed``.

    When CUDA is available the CUDA generators are seeded as well and cuDNN
    is switched to deterministic mode with benchmarking turned off.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # Nothing GPU-related to configure on a CPU-only machine.
    if not torch.cuda.is_available():
        return

    for cuda_seed_fn in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        cuda_seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed the random generators once, before any dataset or model is created,
# so those steps start from a known state.
set_random_seeds(42)

In [None]:
# Training pool: two public MOABB motor-imagery datasets plus the
# competition's own training split.
moabb_train_datasets = [
    PhysionetMI(imagined=True),  # 109 subjects
    Weibo2014(),  # 10 subjects, 64 channels
    CompetitionDataset(),
]
# Held-out data: the competition's validation split only.
train_val = [CompetitionDataset(split="validation")]

# Epoching settings shared by both loads (BEST so far: tmin=1, tmax=4, resample=250).
eeg_channels = ["Fz", "C3", "Cz", "C4", "Pz"]
epoch_settings = dict(channels=eeg_channels, tmin=1.0, tmax=4.0, resample=250)

# Training data. To restrict subjects, additionally pass e.g.
#   subjects_per_dataset={
#       "PhysionetMI": list(range(1, 50)),
#       "Weibo2014": list(range(1, 11)),
#       "CompetitionDataset": list(range(1, 31)),
#   }
X_train, y_train, domain_labels_train, info_train = load_combined_moabb_data(
    moabb_train_datasets,
    paradigm_config=dict(epoch_settings),
)

# Validation data, epoched with exactly the same settings.
X_val, y_val, domain_labels_val, info_val = load_combined_moabb_data(
    train_val,
    paradigm_config=dict(epoch_settings),
)

In [15]:
from Models import FilterBankRTSClassifier
from sklearn.metrics import classification_report, accuracy_score
# Hyper-parameters for the filter-bank classifier evaluated in this cell.
# NOTE(review): n_bands and max_features are assigned here but are not passed
# to the constructor below - confirm whether that is intended.
n_bands = 4
filter_order = 3
# NOTE(review): the data-loading cell requests resample=250 while fs is 100
# here - confirm what `fs` means for FilterBankRTSClassifier.
fs = 100
n_estimators = 300
max_depth = None
min_samples_split = 7
min_samples_leaf = 3
max_features = 'sqrt'

# Build the classifier from the values above.
# Pasted tuning log kept for reference (its values differ from the ones used
# in this cell):
# Trial 67 finished with value: 0.6263345734944465 and parameters:
# {'window_length': 250, 'stride': 250, 'tmin': 0, 'ch_FZ': 1, 'ch_C3': 0,
#  'ch_CZ': 1, 'ch_C4': 0, 'ch_PZ': 1, 'ch_PO7': 1, 'ch_OZ': 1, 'ch_PO8': 0,
#  'n_bands': 4, 'min_freq': 11, 'max_freq': 40, 'filter_order': 3, 'fs': 125,
#  'n_estimators': 200, 'max_depth': None, 'min_samples_split': 6,
#  'min_samples_leaf': 2, 'max_features': 'sqrt'}.
# Best is trial 67 with value: 0.6263345734944465.
clf = FilterBankRTSClassifier(
    fs=fs,
    order=filter_order,
    n_estimators=n_estimators,
    max_depth=max_depth,
    min_samples_split=min_samples_split,
    min_samples_leaf=min_samples_leaf
)

# Fit on the combined training data
clf.fit(X_train, y_train)

# Score the predictions on the validation split
y_pred = clf.predict(X_val)
val_acc = accuracy_score(y_val, y_pred)

print(f"Validation accuracy: {val_acc:.4f}")

# Per-class precision / recall / F1 on the validation split
print("\nClassification Report:")
print(classification_report(y_val, y_pred)) 

Validation accuracy: 0.4800

Classification Report:
              precision    recall  f1-score   support

           0       0.55      0.39      0.46        28
           1       0.43      0.59      0.50        22

    accuracy                           0.48        50
   macro avg       0.49      0.49      0.48        50
weighted avg       0.50      0.48      0.48        50



In [None]:
# Turn the arrays into tensors for the torch / skorch cells below.
# torch.as_tensor accepts both NumPy arrays and tensors, so this cell can be
# re-run after X_train / y_train have already been rebound to tensors
# (torch.from_numpy would raise a TypeError on the second run).
X_train = torch.as_tensor(X_train).float()  # FloatTensor, 3-D epochs array (presumably (N, C, T))
y_train = torch.as_tensor(y_train).long()  # LongTensor of shape (N,): one class index per trial

train_dataset = TensorDataset(X_train, y_train)

# The validation arrays keep their NumPy form; tensors get a `_t` suffix.
X_val_t = torch.as_tensor(X_val).float()
y_val_t = torch.as_tensor(y_val).long()

val_dataset = TensorDataset(X_val_t, y_val_t)


In [None]:
# EEGSimpleConv on the 5 selected channels, fed 250 Hz epochs.
model = EEGSimpleConv(
    n_outputs=2,  # left vs. right hand motor imagery
    n_chans=5,
    sfreq=250,
    resampling_freq=80,  # noted as the optimal resampling frequency for this model
    feature_maps=96,  # recommended range is [64-144]
    n_convs=2,  # cross-subject setting: [2-4]
    kernel_size=8,  # cross-subject setting: [5-8]
)

# skorch settings: Adam + cross-entropy, validated on the fixed held-out set
# rather than on an internal split.
simpleconv_fit_settings = dict(
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    optimizer__lr=0.0001,
    batch_size=64,
    max_epochs=100,
    train_split=predefined_split(val_dataset),
    device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=2,
    callbacks=["accuracy"],
)
clf = EEGClassifier(model, **simpleconv_fit_settings)

clf.fit(X_train, y_train)

In [None]:
# EEGInceptionMI configured for 3-second windows (the 1 s -> 4 s epoch) at 250 Hz.
model = EEGInceptionMI(
    n_outputs=2,  # number of classes (adjust to the task)
    n_chans=5,  # number of channels
    sfreq=250,
    input_window_seconds=3.0,  # best so far: 3-second windows, 1 -> 4
    n_filters=12,  # optimized parameter
    n_convs=5,
    kernel_unit_s=0.1,
)

# skorch wrapper: Adam + cross-entropy, validated on the fixed held-out set.
inception_fit_settings = dict(
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    optimizer__lr=0.001,
    batch_size=64,
    max_epochs=100,
    train_split=predefined_split(val_dataset),
    device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=2,
    callbacks=['accuracy'],
)
clf = EEGClassifier(model, **inception_fit_settings)

clf.fit(X_train, y_train)

In [None]:
# Sanity check of the training tensors before building the network.
print(X_train.shape, y_train.shape)

# ATCNet for 3-second windows at 250 Hz on the 5 selected channels.
model = ATCNet(
    n_outputs=2,  # left vs. right hand motor imagery
    n_chans=5,
    sfreq=250,
    input_window_seconds=3.0,
    n_windows=5,
    # attention block
    att_head_dim=8,
    att_num_heads=2,
    # temporal convolution block
    tcn_depth=2,
    tcn_kernel_size=4,
    tcn_n_filters=32,
)

# skorch wrapper: Adam + cross-entropy, validated on the fixed held-out set.
atcnet_fit_settings = dict(
    criterion=torch.nn.CrossEntropyLoss,
    optimizer=torch.optim.Adam,
    optimizer__lr=0.0005,  # lower learning rate for attention models
    batch_size=32,  # smaller batches for memory efficiency
    max_epochs=100,
    train_split=predefined_split(val_dataset),
    device="cuda" if torch.cuda.is_available() else "cpu",
    verbose=2,
    callbacks=["accuracy"],
)
clf = EEGClassifier(model, **atcnet_fit_settings)

clf.fit(X_train, y_train)