In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, cohen_kappa_score
import mne

In [2]:
# ==================== CONFIG ====================
DATA_DIR = 'BCICIV_2a/'          # Folder containing A01T.gdf ... A09E.gdf
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 32
EPOCHS = 50
LR = 0.001
T_MIN = 0.0      # Start of epoch relative to cue
T_MAX = 4.0      # 4-second trials (standard for this dataset)

In [3]:
class EEGDataset(Dataset):
    def __init__(self, features, labels):
        self.features = torch.tensor(features, dtype=torch.float32)   # [trials, channels, time]
        self.labels   = torch.tensor(labels,   dtype=torch.long)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]

In [4]:
class BiGRUModel(nn.Module):
    def __init__(self, num_channels=22, num_classes=4, hidden_dim=128, dropout=0.5):
        super().__init__()
        # Spatial convolution to reduce channel dimension (common in EEG DL)
        self.spatial = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(num_channels, 1), stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Dropout2d(dropout)
        )

        self.bigru = nn.GRU(input_size=32,
                            hidden_size=hidden_dim,
                            num_layers=2,
                            batch_first=True,
                            bidirectional=True,
                            dropout=dropout)

        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        # x: [batch, channels, time]
        batch = x.size(0)
        x = x.unsqueeze(1)                     # [batch, 1, channels, time]
        x = self.spatial(x)                    # [batch, 32, 1, time]
        x = x.squeeze(2)                       # [batch, 32, time]
        x = x.permute(0, 2, 1)                 # [batch, time, 32] for GRU
        _, hn = self.bigru(x)                  # hn: [4, batch, hidden] (2 layers × bidirectional)
        hn = hn.view(2, 2, batch, -1)          # separate layers & directions
        hn = hn[-1]                            # take top layer
        hn = hn.transpose(0, 1).contiguous().view(batch, -1)  # [batch, hidden*2]
        out = self.classifier(hn)
        return out

In [5]:
def read_data(path):
    raw = mne.io.read_raw_gdf(path, eog=['EOG-left', 'EOG-central', 'EOG-right'], preload=True)
    raw.drop_channels(['EOG-left', 'EOG-central', 'EOG-right'])
    
    # Common preprocessing
    raw.filter(l_freq=8, h_freq=30)  # Bandpass mu/beta
    raw.set_eeg_reference('average')  # Or 'CAR'
    
    events, event_id = mne.events_from_annotations(raw)
    
    # Standard event ids for classes (769=left, 770=right, 771=foot, 772=tongue)
    picks = mne.pick_types(raw.info, eeg=True, eog=False)
    
    epochs = mne.Epochs(raw, events, event_id={'769':7, '770':8, '771':9, '772':10},
                        tmin=0, tmax=4,  # Full 4s trial
                        picks=picks, baseline=(0, 0),  # Or (None, 0) for no baseline
                        preload=True, reject=dict(eeg=100e-6))  # Drop bad trials
    
    labels = epochs.events[:, -1] - 6  # Make 1-4 → 0-3 for PyTorch
    features = epochs.get_data()
    
    return features, labels

In [6]:
from pathlib import Path
import numpy as np

def load_all_subjects(data_dir='BCICIV_2a', subject_ids=range(1, 10)):
    all_features = []
    all_labels = []
    subject_indices = []
    
    for subject_id in subject_ids:
        # Format filename (A01T, A02T, etc.)
        filename = f'A{subject_id:02d}T.gdf'
        filepath = Path(data_dir) / filename
        
        try:
            print(f"Loading {filename}...")
            features, labels = read_data(str(filepath))
            
            all_features.append(features)
            all_labels.append(labels)
            
            # Track which subject each trial belongs to
            subject_indices.extend([subject_id] * len(labels))
            
            print(f"  Loaded {len(labels)} trials from subject {subject_id}")
            print(f"  Shape: {features.shape}")
            
        except Exception as e:
            print(f"  Error loading {filename}: {e}")
            continue
    
    # Combine all data
    all_features = np.concatenate(all_features, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)
    subject_indices = np.array(subject_indices)
    
    print(f"\n{'='*50}")
    print(f"Total: {len(all_labels)} trials from {len(set(subject_indices))} subjects")
    print(f"Features shape: {all_features.shape}")
    print(f"Labels shape: {all_labels.shape}")
    print(f"Label distribution: {np.bincount(all_labels)}")
    print(f"{'='*50}")
    
    return all_features, all_labels, subject_indices

# Usage
features, labels, subjects = load_all_subjects()

Loading A01T.gdf...
Extracting GDF parameters from BCICIV_2a\A01T.gdf...
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG
Creating raw.info structure...
Reading 0 ... 672527  =      0.000 ...  2690.108 secs...


  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
  Error loading A04T.gdf: No matching events found for 771 (event id 9)
Loading A05T.gdf...
Extracting GDF parameters from BCICIV_2a\A05T.gdf...
Setting channel info

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

  next(self.gen)


Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 8 - 30 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 8.00
- Lower transition bandwidth: 2.00 Hz (-6 dB cutoff frequency: 7.00 Hz)
- Upper passband edge: 30.00 Hz
- Upper transition bandwidth: 7.50 Hz (-6 dB cutoff frequency: 33.75 Hz)
- Filter length: 413 samples (1.652 s)

EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Used Annotations descriptions: [np.str_('1023'), np.str_('1072'), np.str_('276'), np.str_('277'), np.str_('32766'), np.str_('768'), np.str_('769'), np.str_('770'), np.str_('771'), np.str_('772')]
Not setting metadata
288 matching events found
Applying baseline correction (mode: mean)
0 projection items activated
Using data from

In [7]:
valid_mask = labels > 0
features = features[valid_mask]
labels = labels[valid_mask]
labels = labels - 1  # Now labels are 0-3

if features.ndim == 4:
    print("Detected 4D array, squeezing...")
    features = features.squeeze()
    print(f"After squeeze: {features.shape}")

X_temp, X_test, y_temp, y_test = train_test_split(
    features, labels, 
    test_size=0.15, 
    random_state=42, 
    stratify=labels
)

# Second split: separate train and validation from temp
X_train, X_val, y_train, y_val = train_test_split(
    X_temp, y_temp, 
    test_size=0.176,  # 0.176 * 0.85 ≈ 0.15 of total
    random_state=42, 
    stratify=y_temp
)

In [None]:
# Create datasets and loaders
train_ds = EEGDataset(X_train, y_train)
val_ds   = EEGDataset(X_val,   y_val)
test_ds  = EEGDataset(X_test,  y_test)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE)

# Training loop
model = BiGRUModel().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LR)

best_val_acc = 0.0
for epoch in range(1, EPOCHS + 1):
    model.train()
    for data, target in train_loader:
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        print(target)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    val_preds, val_true = [], []
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(DEVICE)
            output = model(data)
            pred = output.argmax(dim=1)
            val_preds.extend(pred.cpu().numpy())
            val_true.extend(target.numpy())
    val_acc = accuracy_score(val_true, val_preds)

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), 'EEG_Classification_Model.pth')

    if epoch % 10 == 0 or epoch == EPOCHS:
        print(f"Epoch {epoch:2d} | Val Acc: {val_acc:.4f}")

# Load best model and test
model.load_state_dict(torch.load('EEG_Classification_Model.pth'))
model.eval()
test_preds, test_true = [], []
with torch.no_grad():
    for data, target in test_loader:
        data = data.to(DEVICE)
        output = model(data)
        pred = output.argmax(dim=1)
        test_preds.extend(pred.cpu().numpy())
        test_true.extend(target.numpy())

acc = accuracy_score(test_true, test_preds)
kappa = cohen_kappa_score(test_true, test_preds)

print(f"\nFinal Test Results:")
print(f"Accuracy: {acc:.4f} | Kappa: {kappa:.4f}")

tensor([1, 0, 0, 3, 1, 0, 3, 2, 2, 2, 3, 2, 1, 3, 0, 1, 0, 3, 3, 1, 2, 0, 3, 3,
        0, 0, 0, 2, 2, 1, 2, 3])
tensor([0, 2, 1, 3, 0, 1, 0, 3, 0, 0, 2, 1, 2, 2, 3, 1, 1, 1, 0, 0, 1, 2, 2, 2,
        3, 3, 0, 1, 3, 3, 0, 3])
tensor([3, 0, 2, 2, 2, 3, 2, 0, 0, 0, 2, 1, 0, 1, 0, 2, 3, 0, 2, 2, 2, 1, 1, 2,
        1, 1, 1, 1, 2, 2, 0, 0])
tensor([3, 0, 1, 3, 0, 0, 0, 3, 2, 1, 1, 0, 1, 2, 3, 0, 0, 3, 0, 2, 0, 0, 3, 0,
        1, 2, 2, 1, 3, 2, 2, 1])
tensor([3, 0, 3, 3, 3, 2, 1, 0, 1, 3, 3, 1, 1, 0, 1, 3, 2, 1, 2, 0, 1, 2, 2, 1,
        2, 2, 2, 1, 3, 2, 0, 0])
tensor([1, 0, 3, 2, 2, 1, 3, 2, 2, 0, 0, 3, 2, 1, 1, 0, 3, 2, 0, 1, 3, 3, 3, 0,
        0, 0, 3, 2, 0, 1, 1, 1])
tensor([0, 2, 2, 2, 1, 2, 1, 3, 2, 3, 2, 0, 1, 2, 0, 0, 3, 3, 1, 0, 2, 1, 0, 0,
        2, 2, 3, 3, 2, 2, 0, 0])
tensor([3, 0, 0, 0, 0, 2, 2, 1, 0, 1, 3, 0, 1, 1, 1, 0, 2, 3, 2, 2, 3, 0, 1, 3,
        3, 2, 2, 0, 2, 2, 1, 3])
tensor([3, 1, 1, 1, 0, 2, 1, 2, 1, 0, 2, 2, 0, 3, 0, 1, 2, 2, 0, 2, 3, 1, 1, 3,
        3, 1, 0,