Notebook to train the model.

Dataset:
    lib/Datasets.py:
        EEGDataset

Models:
    lib/models/EEG_Net_CNN.py : EEG_Net_CNN - a simple convolutional neural network [conv -> conv -> conv -> dense] with ELU activation, pooling, and batch normalization

In [1]:
# Packages
from lib.models.EEG_Net_CNN import EEG_Net_CNN
from lib.utils import load_data
from lib.DataObject import DataObject
import lib.DataObjectUtils as util
import torch
import pickle
import torch.nn as nn
from lib.DataHandler import DataAcquisitionHandler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import WeightedRandomSampler

pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [5]:
# Get Data

# Get list of handler objects
dir = ["C:/Users/c25th/code/P300_BCI_Speller/data/box_data_handler", "C:/Users/c25th/code/P300_BCI_Speller/data/keyboard_data_handler"]
handler_list = []
for d in dir:
    handler_list.extend(load_data(d))

# Make all Handler objects into one single DataObject
data = DataObject(handler_list.pop(0).get_data())
for handler in handler_list:
    new_data_obj = DataObject(handler.get_data())
    data.accept(util.AddDataVisitor(new_data_obj))

# Apply Filters
data.accept(util.BandpassFilterVisitor(low=0.1, high=10))
data.accept(util.BandstopFilterVisitor(low=49, high=51))

# Extract data
key_data, box_data = data.get_data(decorator=util.MakeTensorWindowsDataDecorator())

# Verify everything worked as expected
sample = box_data[0]
print("List of samples:" ,len(box_data))
print("Sample - (channels, label):", len(box_data[0]))
print("Channels:", len(box_data[0][0]))
max_channel_len = 0
min_channel_len = 10000000
for sample in box_data:
    for channel in sample[0]:
        if len(channel) > max_channel_len:
            max_channel_len = len(channel)
        if len(channel) < min_channel_len:
            min_channel_len = len(channel)
print("Max Channel Len - [reading_1, ...]:", max_channel_len)
print("Min Channel Len - [reading_1, ...]:", min_channel_len)
print("Sample example:", sample)

List of samples: 354
Sample - (channels, label): 2
Channels: 24
Max Channel Len - [reading_1, ...]: 250
Min Channel Len - [reading_1, ...]: 250
Sample example: (tensor([[ 2.4500e+02,  2.4600e+02,  2.4700e+02,  ...,  2.3600e+02,
          2.3700e+02,  2.3800e+02],
        [ 8.3703e+03,  8.4933e+03,  8.5348e+03,  ...,  8.4436e+03,
          8.5442e+03,  8.3925e+03],
        [-6.4812e+00, -6.1533e+00, -5.2448e+00,  ..., -3.8378e+00,
         -3.7535e+00, -3.5997e+00],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00],
        [ 1.6987e+09,  1.6987e+09,  1.6987e+09,  ...,  1.6987e+09,
          1.6987e+09,  1.6987e+09],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  0.0000e+00]]), tensor(1))
