Notebook to train the model.
Uses attention.

In [1]:
# Packages
from lib.models.EEG_Net_CNN import EEG_Net_CNN
import matplotlib.pyplot as plt
from lib.utils import load_data, train, test
from lib.Datasets import EEGDataset
from lib.DataObject import DataObject
import lib.DataObjectUtils as util
import torch
import pickle
import torch.nn as nn
from lib.DataHandler import DataAcquisitionHandler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data import WeightedRandomSampler
import datetime

pygame 2.5.1 (SDL 2.28.2, Python 3.11.5)
Hello from the pygame community. https://www.pygame.org/contribute.html


In [2]:
# load files
filepath = "C:/Users/c25th/code/P300_BCI_Speller/"
date = "2023-12-04"

train_filename = "data/dataloaders/train_loader_"+str(date)+".txt"
val_filename = "data/dataloaders/val_loader_"+str(date)+".txt"
test_filename = "data/dataloaders/test_loader_"+str(date)+".txt"

train_loader = torch.load(filepath + train_filename)
val_loader = torch.load(filepath + val_filename)
test_loader = torch.load(filepath + test_filename)

# Print dimentions of data
for s, l in train_loader:
    print(s.size())
    print(l.size())
    break

torch.Size([32, 8, 250])
torch.Size([32])


In [84]:
# create p300Model
class EEG_Net_Attention(torch.nn.Module):
    """
    Expecting input of shape (batch_size, channels, readings)
    input = [32, 8, 250] = [batch_size, channels, readings]
    batch_size: number of samples in a batch
    channels: number of channels in a sample (8)
    readings: number of readings in a channel (len())
    """
    
    def __init__(self, num_channels=8, num_classes=1, input_length=250):
        super(EEG_Net_Attention, self).__init__()
        total_len = num_channels*input_length

        # Attention
        self.attention1 = nn.MultiheadAttention(embed_dim=input_length, num_heads=1)
        self.attention2 = nn.MultiheadAttention(embed_dim=input_length, num_heads=1)
        self.attention3 = nn.MultiheadAttention(embed_dim=input_length, num_heads=1)

        # Flatten
        self.flatten = nn.Flatten()

        # Linear
        self.linear1 = nn.Sequential(
            nn.Linear(in_features=total_len, out_features=int(total_len*2)),
            nn.ReLU()
        )
        self.linear2 = nn.Sequential(
            nn.Linear(in_features=int(total_len*2), out_features=int(total_len/2)),
            nn.ReLU()
        )
        self.linear3 = nn.Sequential(
            nn.Linear(in_features=int(total_len/2), out_features=int(total_len/4)),
            nn.ReLU()
        )

        # Out
        self.out = nn.Sequential(
            nn.Linear(in_features=int(total_len/4), out_features=num_classes),
            nn.Softmax(dim=1)
        )

        

    def forward(self, x):
        # Attention
        x, _ = self.attention1(x, x, x)
        x = nn.ReLU()(x)
        x, _ = self.attention2(x, x, x)
        x = nn.ReLU()(x)
        x, _ = self.attention3(x, x, x)
        x = nn.ReLU()(x)

        # Flatten
        x = self.flatten(x)

        # Linear
        x = self.linear1(x)
        x = self.linear2(x)
        x = self.linear3(x)

        # Out
        x = self.out(x)

        return x

In [97]:
# initialize model
model = EEG_Net_Attention()

# loss function
loss_fn = torch.nn.MSELoss()

# optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [104]:
def train(train_dataloader, val_dataloader, model, loss_fn, optimizer, num_epochs, print_every=100):
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for batch, (X, y) in enumerate(train_dataloader):
            optimizer.zero_grad()
            pred = model(X)
            loss = loss_fn(pred.squeeze(), y.float())
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            correct += (pred.argmax(0) == y).type(torch.float).sum().item()
            total += y.size(0)

            if batch % print_every == print_every-1 and print_every > 0:
                print(f'Epoch {epoch+1} - Batch {batch+1}/{len(train_dataloader)} - Loss: {loss.item():.4f}')

        avg_train_loss = running_loss / len(train_dataloader)
        avg_train_acc = correct / total
        
        train_losses.append(avg_train_loss)
        train_accuracies.append(avg_train_acc)
        
        # Validation phase
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for X, y in val_dataloader:
                pred = model(X)
                loss = loss_fn(pred.squeeze(), y.float())
                
                val_loss += loss.item()
                correct += (pred.argmax(1) == y).type(torch.float).sum().item()
                total += y.size(0)
        
        avg_val_loss = val_loss / len(val_dataloader)
        avg_val_acc = correct / total
        
        val_losses.append(avg_val_loss)
        val_accuracies.append(avg_val_acc)
        
        print(f'Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f} - Train Accuracy: {avg_train_acc:.4f} - Val Loss: {avg_val_loss:.4f} - Val Accuracy: {avg_val_acc:.4f}')

    # Plotting
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.title("Loss")
    plt.plot(train_losses, label='Train')
    plt.plot(val_losses, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.title("Accuracy")
    plt.plot(train_accuracies, label='Train')
    plt.plot(val_accuracies, label='Validation')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.show()

    return train_losses, val_losses, train_accuracies, val_accuracies

In [105]:
# Train & test model
train_losses, val_losses, train_accuracies, val_accuracies = train(
    train_loader, val_loader, model, loss_fn, optimizer, 100)
print("Maximum Train Accuracy: ", max(train_accuracies))
test(test_loader, model, loss_fn)

  return F.mse_loss(input, target, reduction=self.reduction)


Epoch 1 - Train Loss: 0.5196 - Train Accuracy: 0.5129 - Val Loss: 0.4556 - Val Accuracy: 0.4493
Epoch 2 - Train Loss: 0.5058 - Train Accuracy: 0.5129 - Val Loss: 0.4451 - Val Accuracy: 0.4493
Epoch 3 - Train Loss: 0.5058 - Train Accuracy: 0.5129 - Val Loss: 0.4451 - Val Accuracy: 0.4493
Epoch 4 - Train Loss: 0.5196 - Train Accuracy: 0.5129 - Val Loss: 0.4493 - Val Accuracy: 0.4493


KeyboardInterrupt: 

In [7]:
print("Maximum Train Accuracy: ", max(train_accuracies))
test(test_loader, model, loss_fn)

Maximum Train Accuracy:  0.536441828881847
Test Error: 
 Accuracy: 46.6%, Avg loss: 0.701727 

