In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary


from torch.utils.data import DataLoader

from src.dataset.MI_dataset_all_subjects import MI_Dataset as MI_Dataset_all_subjects
from src.dataset.MI_dataset_single_subject import MI_Dataset as MI_Dataset_single_subject

from config.default import cfg


from models.eegnet import EEGNet

from utils.eval import accuracy

%load_ext autoreload
%autoreload 2


In [2]:
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [3]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
subject = 1
train_runs = [0,1,2,3,4]
test_runs = [5]


train_dataset = MI_Dataset_single_subject(subject, train_runs, device = device)
test_dataset = MI_Dataset_single_subject(subject, test_runs, device = device)

train_dataloader = DataLoader(train_dataset,  batch_size=cfg['train']['batch_size'], shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset,  batch_size=cfg['train']['batch_size'], shuffle=False, drop_last=True)

25
Before:  (5, 48, 3, 401)
After:  torch.Size([240, 3, 401])
25
Before:  (1, 48, 3, 401)
After:  torch.Size([48, 3, 401])


In [5]:
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

for features, label in train_dataloader:
    print(features.shape)
    print(label.shape)
    break
    


Train dataset: 240 samples
Test dataset: 48 samples
torch.Size([48, 3, 401])
torch.Size([48])


In [6]:
train_dataset.time_steps

401

In [7]:
channels = train_dataset.channels
samples = train_dataset.time_steps
model = EEGNet(channels = channels, samples= samples, num_classes = 4)
model.to(device)
summary(model, input_size=(5, 10, *next(iter(train_dataloader))[0][0].shape));

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 16, 3, 402]           1,024
       BatchNorm2d-2           [-1, 16, 3, 402]              32
            Conv2d-3           [-1, 32, 1, 402]              96
       BatchNorm2d-4           [-1, 32, 1, 402]              64
               ELU-5           [-1, 32, 1, 402]               0
         AvgPool2d-6           [-1, 32, 1, 100]               0
           Dropout-7           [-1, 32, 1, 100]               0
            Conv2d-8           [-1, 32, 1, 101]             512
            Conv2d-9           [-1, 32, 1, 101]           1,024
      BatchNorm2d-10           [-1, 32, 1, 101]              64
              ELU-11           [-1, 32, 1, 101]               0
        AvgPool2d-12            [-1, 32, 1, 12]               0
          Dropout-13            [-1, 32, 1, 12]               0
          Flatten-14                  [

In [8]:
# Test forward pass
model(next(iter(train_dataloader))[0]);

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['train']['learning_rate'], weight_decay=cfg['train']['weight_decay'])

# Training loop
for epoch in range(cfg['train']['n_epochs']):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        print(batch_features.shape)
        optimizer.zero_grad()
        outputs = model(batch_features)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 9:
        train_accuracy = accuracy(model, train_dataloader)
        test_accuracy = accuracy(model, test_dataloader)
        print(f"Epoch {epoch + 1}/{cfg['train']['n_epochs']}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])
torch.Size([48, 3, 401])


In [10]:
torch.save(model.state_dict(), 'model_state.pth')