# BCI_EEG_Neural_Network

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [4]:
import numpy as np

In [7]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder

In [6]:
import mne
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery

In [8]:
def get_bci_data(subject_id=1):
    print("Downloading data")

    dataset = BNCI2014_001()
    dataset.subject_list = [subject_id]

    paradigm = MotorImagery(n_classes=4, fmin=8, fmax=32)

    X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject_id])

    le = LabelEncoder()
    y = le.fit_transform(y)

    print(f"Shape X {X.shape}, y {y.shape}, classes {le.classes_}")
    return X, y

In [44]:
# define model 

class BCIMLP(nn.Module):
    def __init__(self, input_features, hidden_sizes, output_size, dropout=0.3):
        super(BCIMLP, self).__init__()

        layers = []
        prev_size = input_features

        for hidden_size in hidden_sizes:
            layers.append(nn.Linear(prev_size, hidden_size))
            layers.append(nn.ReLU())
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.Dropout(dropout))
            prev_size = hidden_size

        layers.append(nn.Linear(prev_size, output_size))

        self.network = nn.Sequential(*layers)

    def forward(self, x):
        batch_size = x.size(0)
        x = x.view(batch_size, -1)
        return self.network(x)


In [45]:
# --- Configuration ---
BATCH_SIZE = 32
HIDDEN_SIZES = [512, 256]  # Two hidden layers (was HIDDEN_SIZE in RNN)
LEARNING_RATE = 0.001
EPOCHS = 50
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

In [46]:
DEVICE

device(type='mps')

In [47]:
X_raw, y_raw = get_bci_data()

Choosing from all possible events


Downloading data
Shape X (576, 22, 1001), y (576,), classes ['feet' 'left_hand' 'right_hand' 'tongue']


In [48]:
# Normalize features per channel
N, C, T = X_raw.shape
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_raw.reshape(-1, C)).reshape(N, C, T)

In [49]:
# convert to tensor
X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(DEVICE)
y_tensor = torch.tensor(y_raw, dtype=torch.long).to(DEVICE)

In [50]:
# split data 
X_train, X_test, y_train, y_test = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

In [58]:
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE)

In [52]:
input_features = C * T
input_features

22022

In [53]:
model = BCIMLP(
    input_features=input_features,
    hidden_sizes=HIDDEN_SIZES,
    output_size=4,
    dropout=0.3
).to(DEVICE)

In [54]:
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [55]:
print("\nStarting Training on", DEVICE)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")


Starting Training on mps
Total parameters: 11,409,668


In [56]:
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
    if (epoch + 1) % 10 == 0:
        print(f"Epoch [{epoch+1}/{EPOCHS}], Loss: {running_loss/len(train_loader):.4f}")

Epoch [10/50], Loss: 0.0385
Epoch [20/50], Loss: 0.0157
Epoch [30/50], Loss: 0.1713
Epoch [40/50], Loss: 0.0076
Epoch [50/50], Loss: 0.0082


In [60]:
# --- Evaluation ---
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"\nFinal Test Accuracy: {100 * correct / total:.2f}%")


Final Test Accuracy: 30.17%
