In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary


from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset

from src.dataset.MI_dataset_single_subject import MI_Dataset as MI_Dataset_single_subject

from config.default import cfg


from models.conditioned_eegnet import ConditionedEEGNet

from utils.eval import accuracy
from utils.model import print_parameters

%load_ext autoreload
%autoreload 2


In [16]:
# Cross-subject split: the model is trained on one set of subjects and
# evaluated on a disjoint set it never sees during training.
train_subjects = [1, 2, 3, 4, 5, 6]
test_subjects = [7, 8, 9]

# Run indices to load for each subject (subject id -> list of run indices).
train_runs = {
    1: list(range(6)),
    2: list(range(6)),
    3: list(range(6)),
    4: list(range(3)),     # only runs 0-2 are used for subject 4
    5: list(range(6)),
    6: list(range(6)),
}
test_runs = {
    7: list(range(6)),
    8: list(range(1, 6)),  # run 0 is not used for subjects 8 and 9
    9: list(range(1, 6)),
}

batch_size = 64

In [17]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [18]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch) so that
# dataset shuffling and weight initialisation are repeatable.
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)

# On GPU, additionally seed CUDA and force deterministic cuDNN kernels
# (disabling autotuning, which can pick different algorithms run to run).
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [19]:
# Build one dataset per training subject and concatenate them into a single
# training set. The model is constructed from ONE (channels, time_steps) pair,
# so every training subject must share it. Check this while loading rather than
# failing later (e.g. when batches are collated or in the model's forward pass).
train_datasets = []

for subject in train_subjects:
    dataset = MI_Dataset_single_subject(subject, train_runs[subject], return_subject_id=True, device=device, verbose=False)
    subject_shape = (dataset.channels, dataset.time_steps)
    if train_datasets and subject_shape != (channels, time_steps):
        raise ValueError(
            f"Subject {subject} has (channels, time_steps)={subject_shape}, "
            f"but previously loaded subjects have {(channels, time_steps)}"
        )
    train_datasets.append(dataset)
    channels, time_steps = subject_shape
train_dataset = ConcatDataset(train_datasets)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(f"Train dataset: {len(train_dataset)} samples")

25
25
25
25
25
25
Train dataset: 1584 samples


In [20]:
# One dataset per held-out subject, concatenated and iterated in a fixed order
# (no shuffling) so evaluation is repeatable.
test_datasets = [
    MI_Dataset_single_subject(
        subject,
        test_runs[subject],
        return_subject_id=True,
        device=device,
        verbose=False,
    )
    for subject in test_subjects
]
test_dataset = ConcatDataset(test_datasets)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(f"Test dataset: {len(test_dataset)} samples")

25
25
25
Test dataset: 768 samples


In [29]:
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

# Peek at a single training batch. Each batch's features are a sequence whose
# first element is the EEG tensor (the datasets return extra items because
# return_subject_id=True). Nothing is printed if the loader is empty.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    sample_features, sample_labels = first_batch
    print(sample_features[0].shape)

Train dataset: 1584 samples
Test dataset: 768 samples
torch.Size([64, 3, 401])


In [25]:
# Size the network from the (channels, time_steps) of the loaded training data;
# `Module.to` moves the parameters in place and returns the same module.
model = ConditionedEEGNet(channels=channels, samples=time_steps, num_classes=4).to(device)
print_parameters(model)

eeg_processor.conv1.weight.... --> 1024
eeg_processor.bn1.weight...... --> 16
eeg_processor.bn1.bias........ --> 16
eeg_processor.dw_conv1.weight. --> 96
eeg_processor.bn2.weight...... --> 32
eeg_processor.bn2.bias........ --> 32
eeg_processor.sep_conv1.weight --> 512
eeg_processor.conv2.weight.... --> 1024
eeg_processor.bn3.weight...... --> 32
eeg_processor.bn3.bias........ --> 32
subject_processor.fn1.weight.. --> 96
subject_processor.fn1.bias.... --> 16
query.weight.................. --> 12288
key.weight.................... --> 512
value.weight.................. --> 12288
fn1.weight.................... --> 4096
fn1.bias...................... --> 128
fn2.weight.................... --> 512
fn2.bias...................... --> 4
eeg_fn.weight................. --> 12288
eeg_fn.bias................... --> 32

Total Parameter Count:........ --> 45076


In [26]:
# Show the full module tree with each layer's hyper-parameters.
print(repr(model))

ConditionedEEGNet(
  (eeg_processor): EEGNet(
    (conv1): Conv2d(1, 16, kernel_size=(1, 64), stride=(1, 1), padding=(0, 32), bias=False)
    (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (dw_conv1): Conv2d(16, 32, kernel_size=(3, 1), stride=(1, 1), groups=16, bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (activation): ELU(alpha=1.0)
    (avg_pool1): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
    (dropout1): Dropout(p=0.5, inplace=False)
    (sep_conv1): Conv2d(32, 32, kernel_size=(1, 16), stride=(1, 1), padding=(0, 8), groups=32, bias=False)
    (conv2): Conv2d(32, 32, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (avg_pool2): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
    (flatten): Flatten(start_dim=1, end_dim=-1)
  )
  (subject_processor): FeedForward(
    (

In [32]:
# Loss and optimiser; hyper-parameters come from `cfg['train']`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=cfg['train']['learning_rate'],
    weight_decay=cfg['train']['weight_decay'],
)

n_epochs = cfg['train']['n_epochs']
log_every = 10       # report train accuracy every `log_every` epochs
report_test = False  # set True to also report accuracy on test_dataloader

# Training loop
epoch_loss = 0.0  # defined up front so the summary below works even if n_epochs == 0
for epoch in range(n_epochs):
    # Put Dropout/BatchNorm back into training mode at the start of every epoch.
    # NOTE(review): `accuracy` is called between epochs below; if it switches the
    # model to eval mode and does not restore it, training would otherwise continue
    # with dropout disabled and BatchNorm running statistics frozen.
    model.train()
    epoch_loss = 0.0  # sum of per-batch mean losses (not an average)

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        # batch_features[0] is the EEG batch; batch_features[1] is presumably the
        # subject id (datasets are built with return_subject_id=True) — confirm.
        outputs = model(batch_features[0], batch_features[1])
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if (epoch + 1) % log_every == 0:
        train_accuracy = accuracy(model, train_dataloader)
        message = f"Epoch {epoch + 1}/{n_epochs}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%"
        if report_test:
            message += f", Test accuracy: {accuracy(model, test_dataloader):.2f}%"
        print(message)

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
if report_test:
    print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

Epoch 10/200, Loss: 31.80426251888275, Train accuracy: 36.11%
Epoch 20/200, Loss: 32.11807405948639, Train accuracy: 33.14%
Epoch 30/200, Loss: 31.69123125076294, Train accuracy: 36.36%
Epoch 40/200, Loss: 32.04901099205017, Train accuracy: 34.72%
Epoch 50/200, Loss: 31.984686851501465, Train accuracy: 35.98%
Epoch 60/200, Loss: 31.990182399749756, Train accuracy: 34.79%
Epoch 70/200, Loss: 31.802629351615906, Train accuracy: 33.90%
Epoch 80/200, Loss: 31.799349188804626, Train accuracy: 35.61%
Epoch 90/200, Loss: 32.042041301727295, Train accuracy: 34.47%
Epoch 100/200, Loss: 32.09911668300629, Train accuracy: 34.34%
Epoch 110/200, Loss: 31.986289381980896, Train accuracy: 31.88%
Epoch 120/200, Loss: 32.12982439994812, Train accuracy: 34.03%
Epoch 130/200, Loss: 31.931111216545105, Train accuracy: 33.84%
Epoch 140/200, Loss: 31.95639967918396, Train accuracy: 33.21%
Epoch 150/200, Loss: 31.923577547073364, Train accuracy: 36.11%
Epoch 160/200, Loss: 31.925588250160217, Train accuracy:

In [None]:
# Persist only the learned weights (state_dict), not the pickled module object.
checkpoint_path = 'model_state_all_subjectsv2.pth'
torch.save(model.state_dict(), checkpoint_path)