# CNN based EEG BCI prediction

In [1]:
import torch 
import torch.nn as nn
import torch.optim as optim 
from torch.utils.data import DataLoader, TensorDataset

In [2]:
import numpy as np

In [5]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder

In [4]:
import mne
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery

In [6]:
def get_bci_data(subject_id=1):
    print(f"Downloading/Loading data for Subject {subject_id}...")

    dataset = BNCI2014_001()
    dataset.subject_list = [subject_id]

    # define paradigm
    paradigm = MotorImagery(n_classes=2, fmin=8, fmax=32)

    # get the data 
    X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject_id])

    # encode labels
    encoder = LabelEncoder()
    y = encoder.fit_transform(y)

    print(f"Data Loaded: {X.shape}, Classes: {encoder.classes_}")
    return X, y

In [7]:
get_bci_data()

Choosing from all possible events


Downloading/Loading data for Subject 1...
Data Loaded: (576, 22, 1001), Classes: ['feet' 'left_hand' 'right_hand' 'tongue']


(array([[[ 5.52359238e+00,  6.05479173e+00,  5.55222300e+00, ...,
          -5.72856071e-01, -2.82827100e+00, -5.00875118e+00],
         [ 1.91658213e+00,  1.74039129e+00,  8.80343196e-01, ...,
           5.06184938e-01, -3.31466548e+00, -6.65938125e+00],
         [ 3.41951163e+00,  3.84870584e+00,  3.50987378e+00, ...,
          -1.52997242e+00, -5.09625352e+00, -7.88780877e+00],
         ...,
         [-1.18565960e+00, -1.51034610e+00, -2.41203985e+00, ...,
          -2.33472371e+00, -6.69460626e+00, -9.35521896e+00],
         [-2.35736743e+00, -3.05259983e+00, -3.87318513e+00, ...,
          -2.61970359e+00, -6.21065549e+00, -8.26673539e+00],
         [-1.09144736e+00, -1.01641231e+00, -1.73219282e+00, ...,
           6.21774064e-02, -3.33972413e+00, -5.84536201e+00]],
 
        [[-5.94328262e+00, -5.49236729e+00, -3.71652245e+00, ...,
           9.71111097e+00,  8.34966036e+00,  5.93386064e+00],
         [-6.39517030e+00, -6.41812583e+00, -4.88865165e+00, ...,
           7.27015452