# RNN (LSTM) based EEG Data Prediction

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

In [2]:
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.preprocessing import LabelEncoder

In [3]:
import mne
from moabb.datasets import BNCI2014_001
from moabb.paradigms import MotorImagery

In [4]:
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print("Using device: ", device)

Using device:  mps


In [5]:
# Load the data and preprocess BCI data

def get_bci_data(subject_id=1):
    print(f"Downloading/Loading data for Subject {subject_id}...")

    # we use the famous BNCI 2014 dataset 
    dataset = BNCI2014_001()
    dataset.subject_list = [subject_id]

    # define paradigm (Left vs Right Hand)
    # filter frequencies between 8-32 Hz
    paradigm = MotorImagery(n_classes=2, fmin=8, fmax=32)

    # Get the data X and label y
    # X shape: (n_trials, n_channels, n_timepoints)
    X, y, metadata = paradigm.get_data(dataset=dataset, subjects=[subject_id])

    # Encode labels (Left_hand -> 0, Right_hand -> 1)
    le = LabelEncoder()
    y = le.fit_transform(y)

    print(f"Data loaded: {X.shape}, Classes: {le.classes_}")
    return X, y


In [6]:
get_bci_data()

Choosing from all possible events


Downloading/Loading data for Subject 1...
Data loaded: (576, 22, 1001), Classes: ['feet' 'left_hand' 'right_hand' 'tongue']


(array([[[ 5.52359238e+00,  6.05479173e+00,  5.55222300e+00, ...,
          -5.72856071e-01, -2.82827100e+00, -5.00875118e+00],
         [ 1.91658213e+00,  1.74039129e+00,  8.80343196e-01, ...,
           5.06184938e-01, -3.31466548e+00, -6.65938125e+00],
         [ 3.41951163e+00,  3.84870584e+00,  3.50987378e+00, ...,
          -1.52997242e+00, -5.09625352e+00, -7.88780877e+00],
         ...,
         [-1.18565960e+00, -1.51034610e+00, -2.41203985e+00, ...,
          -2.33472371e+00, -6.69460626e+00, -9.35521896e+00],
         [-2.35736743e+00, -3.05259983e+00, -3.87318513e+00, ...,
          -2.61970359e+00, -6.21065549e+00, -8.26673539e+00],
         [-1.09144736e+00, -1.01641231e+00, -1.73219282e+00, ...,
           6.21774064e-02, -3.33972413e+00, -5.84536201e+00]],
 
        [[-5.94328262e+00, -5.49236729e+00, -3.71652245e+00, ...,
           9.71111097e+00,  8.34966036e+00,  5.93386064e+00],
         [-6.39517030e+00, -6.41812583e+00, -4.88865165e+00, ...,
           7.27015452

In [7]:
# Pytorch RNN model
class BCIRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, output_size):
        super(BCIRNN, self).__init__()

        # we use LSTM instead of vanilla RNN because EEG sequences are long
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            output_size=output_size,
            batch_first=True,
            dropout=0.3
        )

        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        #. x shape: (batch, time_stamps, channels)

        # pass through LSTM
        # out shape: (batch, time_stamps, hidden_size)
        out, (hn, cn) = self.lstm(x)

        # we take the output of the LAST time step to make the prediction
        last_time_Step = out[:, -1, :]

        out = self.fc(last_time_Step)
        return out

In [9]:
# main execution

# configuration
BATCH_SIZE = 32
HIDDEN_SIZE = 64
NUM_LAYERS = 2
LEARNING_RATE = 0.001
EPOCHS = 50
DEVICE = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')

print(DEVICE)

mps


In [10]:
# get data
# X_raw shape is usually (Trials, Channels, Time) -> (576, 22, 1001)
X_raw, y_raw = get_bci_data(subject_id=1)

X_raw, y_raw

Choosing from all possible events


Downloading/Loading data for Subject 1...
Data loaded: (576, 22, 1001), Classes: ['feet' 'left_hand' 'right_hand' 'tongue']


(array([[[ 5.52359238e+00,  6.05479173e+00,  5.55222300e+00, ...,
          -5.72856071e-01, -2.82827100e+00, -5.00875118e+00],
         [ 1.91658213e+00,  1.74039129e+00,  8.80343196e-01, ...,
           5.06184938e-01, -3.31466548e+00, -6.65938125e+00],
         [ 3.41951163e+00,  3.84870584e+00,  3.50987378e+00, ...,
          -1.52997242e+00, -5.09625352e+00, -7.88780877e+00],
         ...,
         [-1.18565960e+00, -1.51034610e+00, -2.41203985e+00, ...,
          -2.33472371e+00, -6.69460626e+00, -9.35521896e+00],
         [-2.35736743e+00, -3.05259983e+00, -3.87318513e+00, ...,
          -2.61970359e+00, -6.21065549e+00, -8.26673539e+00],
         [-1.09144736e+00, -1.01641231e+00, -1.73219282e+00, ...,
           6.21774064e-02, -3.33972413e+00, -5.84536201e+00]],
 
        [[-5.94328262e+00, -5.49236729e+00, -3.71652245e+00, ...,
           9.71111097e+00,  8.34966036e+00,  5.93386064e+00],
         [-6.39517030e+00, -6.41812583e+00, -4.88865165e+00, ...,
           7.27015452

In [14]:
# format data for pytorch
# RNN expects (Batch, time, channels). MNE gives (Batch, Channel, Time)
X_raw = np.transpose(X_raw, (0, 2, 1))
X_raw.shape

(576, 1001, 22)

In [23]:
# Normalization
# we reshape to (Trials*Time, Channels) to fit scaler, then reshape back
N, T, C = X_raw.shape
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_raw.reshape(-1, C)).reshape(N, T, C)

X_scaled.shape

(576, 1001, 22)

In [24]:
# Convert to tensor 
X_tensor = torch.tensor(X_scaled, dtype=torch.float32).to(DEVICE)
y_tensor = torch.tensor(y_raw, dtype=torch.float32).to(DEVICE)

In [29]:
X_tensor.to(DEVICE), y_tensor.to(DEVICE)

(tensor([[[ 1.1110e+00,  4.0182e-01,  6.6990e-01,  ..., -2.2847e-01,
           -4.4973e-01, -2.0024e-01],
          [ 1.2178e+00,  3.6491e-01,  7.5395e-01,  ..., -2.9111e-01,
           -5.8245e-01, -1.8645e-01],
          [ 1.1167e+00,  1.8469e-01,  6.8759e-01,  ..., -4.6508e-01,
           -7.3909e-01, -3.1797e-01],
          ...,
          [-1.1504e-01,  1.0629e-01, -2.9939e-01,  ..., -4.5017e-01,
           -4.9981e-01,  1.1736e-02],
          [-5.6860e-01, -6.9433e-01, -9.9780e-01,  ..., -1.2914e+00,
           -1.1853e+00, -6.1334e-01],
          [-1.0071e+00, -1.3952e+00, -1.5445e+00,  ..., -1.8047e+00,
           -1.5778e+00, -1.0737e+00]],
 
         [[-1.1950e+00, -1.3398e+00, -1.3647e+00,  ..., -6.3509e-01,
           -2.3351e-01, -1.5581e-02],
          [-1.1044e+00, -1.3446e+00, -1.4517e+00,  ..., -9.0114e-01,
           -4.1786e-01, -2.8145e-01],
          [-7.4723e-01, -1.0241e+00, -1.1936e+00,  ..., -1.0655e+00,
           -5.9936e-01, -5.8158e-01],
          ...,
    

In [33]:
# Split data
X_train, X_test, y_train, y_test = train_test_split(X_tensor, y_tensor, test_size=0.2, random_state=42)

In [34]:
X_train.shape, X_test.shape

(torch.Size([460, 1001, 22]), torch.Size([116, 1001, 22]))

In [35]:
# Data Loaders
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=BATCH_SIZE)

In [49]:
for X_train, y_train in train_loader:
    print(X_train.to(DEVICE).shape, y_train.to(DEVICE).shape)

torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([32, 1001, 22]) torch.Size([32])
torch.Size([12, 1001, 22]) torch.Size([12])
