# CNN based EEG BCI prediction

In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader, TensorDataset

In [2]:
import numpy as np

In [3]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder

In [4]:
import mne
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery

In [5]:
def get_bci_data(subject_id=1):
    print(f"Downloading/Loading data for Subject {subject_id}...")

    dataset = BNCI2014_001()
    dataset.subject_list = [subject_id]

    # define paradigm
    paradigm = MotorImagery(n_classes=2, fmin=8, fmax=32)

    # get the data 
    X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject_id])

    # encode labels
    encoder = LabelEncoder()
    y = encoder.fit_transform(y)

    print(f"Data Loaded: {X.shape}, Classes: {encoder.classes_}")
    return X, y

In [7]:
# pytorch CNN model

class BCICNN(nn.Module):
    def __init__(self, input_channels, base_filters=64, num_classes=4, dropout=0.3):
        super(BCICNN, self).__init__()

        # block 1
        self.conv1 = nn.Conv1d(in_channels=input_channels, out_channels=base_filters, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm1d(base_filters)
        self.pool1 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(dropout)

        # block 2
        self.conv2 = nn.Conv1d(in_channels=base_filters, out_channels=base_filters*2, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm1d(base_filters*2)
        self.pool2 = nn.MaxPool1d(2)
        self.dropout = nn.Dropout(dropout)

        # block 3
        self.conv3 = nn.Conv1d(in_channels=base_filters*2, out_channels=base_filters*4, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(base_filters*4)
        self.adaptive_pool = nn.AdaptiveAvgPool1d(1)
        self.dropout = nn.Dropout(0.5)

        # classification head
        self.fc = nn.Linear(in_features=base_filters*4, out_features=num_classes)

    def forward(self, x):
        # x shape: (batch, channels, time)

        # block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)
        x = self.pool1(x)
        x = self.dropout(x)

        # block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)
        x = self.pool2(x)
        x = self.dropout(x)

        # block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = torch.relu(x)
        x = self.adaptive_pool(x)
        x = self.dropout(x)

        x = x.squeeze(-1)
        x = self.fc(x)
        return x

In [8]:
BATCH_SIZE = 32
NUM_FILTERS = 64
LEARNING_RATE = 0.001
EPOCHS = 50
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [9]:
X_raw, y_raw = get_bci_data(subject_id=1)

X_raw.shape, y_raw.shape

Choosing from all possible events


Downloading/Loading data for Subject 1...
Data Loaded: (576, 22, 1001), Classes: ['feet' 'left_hand' 'right_hand' 'tongue']


((576, 22, 1001), (576,))

In [10]:
N, C, T = X_raw.shape
scaler = StandardScaler()
X_raw = scaler.fit_transform(X_raw.reshape(-1, C)).reshape(N, C, T)

In [12]:
# convert to tensors
X_tensor = torch.tensor(X_raw, dtype=torch.float32).to(DEVICE)
y_tensor = torch.tensor(y_raw, dtype=torch.long).to(DEVICE)

In [15]:
# split data
X_train, X_test, y_train, y_test = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

In [16]:
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE)

In [19]:
model = BCICNN(
    input_channels=C,
    base_filters=NUM_FILTERS,
    num_classes=4,
    dropout=0.3
).to(DEVICE)

In [20]:
model

BCICNN(
  (conv1): Conv1d(22, 64, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout): Dropout(p=0.5, inplace=False)
  (conv2): Conv1d(64, 128, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn2): BatchNorm1d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool2): MaxPool1d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv1d(128, 256, kernel_size=(3,), stride=(1,), padding=(1,))
  (bn3): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (adaptive_pool): AdaptiveAvgPool1d(output_size=1)
  (fc): Linear(in_features=256, out_features=4, bias=True)
)

In [21]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# training loop

for epoch in range(EPOCHS):
    model.train()
    running_loss = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    if (epoch+1) % 10 == 0:
        print(f"Epoch: {epoch+1} | Loss: {running_loss/len(train_loader):.4f}")

Epoch: 10 | Loss: 0.3689


In [None]:
# Evaluate
