In [26]:
from src.dataset.Moabb2BGenerator_One_Person import Moabb2BGenerator
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

from torch.utils.data import TensorDataset, DataLoader
from torch.utils.data import ConcatDataset


from config.default import cfg

from src.dataset.MI_dataset_single_subject import MI_Dataset as MI_Dataset_single_subject

from models.conditioned_eegnet import ConditionedEEGNet

from utils.eval import accuracy
from utils.model import print_parameters

%load_ext autoreload
%autoreload 2



The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
subjects = [1,2,3,4,5,6,7,8,9]
batch_size=64

In [3]:
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cpu')

In [4]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [34]:
train_datasets_2b = []

for subject in subjects:
    dataset = Moabb2BGenerator(subject, return_subject_id=True, device=device,runs=[0,1,2])
    train_datasets_2b.append(dataset)
    channels = dataset.channels
    time_steps = dataset.time_steps
train_dataset_2b = ConcatDataset(train_datasets_2b)

train_dataloader_2b = DataLoader(train_dataset_2b, batch_size=batch_size, shuffle=True)
print(f"Train dataset: {len(train_datasets_2b)} samples")

Train dataset: 9 samples


In [35]:
test_datasets_2b = []
for subject in subjects:
    test_datasets_2b.append(Moabb2BGenerator(subject, runs=[3,4],return_subject_id=True, device=device))
test_dataset_2b = ConcatDataset(test_datasets_2b)

test_dataloader_2b = DataLoader(test_dataset_2b, batch_size=batch_size, shuffle=False)
print(f"Test dataset: {len(test_datasets_2b)} samples")

Test dataset: 9 samples


In [37]:
print(f"Train dataset: {len(train_dataset_2b)} samples")
print(f"Test dataset: {len(test_dataset_2b)} samples")

for feature, label in train_dataloader_2b:
    print(feature[0].shape)
    break

Train dataset: 1296 samples
Test dataset: 864 samples
torch.Size([64, 3, 401])


In [38]:
channels =  channels
samples =  time_steps
num_classes_2b = 2

# Modell betöltése 3 csatornával, 2 kimenettel
new_eeg_model = ConditionedEEGNet( channels=channels, samples=samples, num_classes=num_classes_2b, num_subjects=len(subjects))

In [40]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(new_eeg_model.parameters(), lr=cfg['train']['learning_rate'], weight_decay=cfg['train']['weight_decay'])

# Training loop
for epoch in range(cfg['train']['n_epochs']):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader_2b:
        optimizer.zero_grad()
        outputs = new_eeg_model(batch_features[0],batch_features[1])
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 9:
        train_accuracy = accuracy(new_eeg_model, train_dataloader_2b)
        test_accuracy = accuracy(new_eeg_model, test_dataloader_2b)
        print(f"Epoch {epoch + 1}/{cfg['train']['n_epochs']}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(new_eeg_model, train_dataloader_2b):.2f}%')
print(f'Final test accuracy: {accuracy(new_eeg_model, test_dataloader_2b):.2f}%')

Epoch 10/200, Loss: 11.29024949669838, Train accuracy: 72.53%, Test accuracy: 67.36%
Epoch 20/200, Loss: 10.957565128803253, Train accuracy: 73.30%, Test accuracy: 67.82%
Epoch 30/200, Loss: 11.083590507507324, Train accuracy: 73.69%, Test accuracy: 68.87%
Epoch 40/200, Loss: 11.032596319913864, Train accuracy: 73.77%, Test accuracy: 69.10%
Epoch 50/200, Loss: 11.10489371418953, Train accuracy: 71.68%, Test accuracy: 68.06%
Epoch 60/200, Loss: 11.454714268445969, Train accuracy: 72.38%, Test accuracy: 67.94%
Epoch 70/200, Loss: 11.383916020393372, Train accuracy: 70.52%, Test accuracy: 67.01%
Epoch 80/200, Loss: 11.300152122974396, Train accuracy: 69.75%, Test accuracy: 68.40%
Epoch 90/200, Loss: 11.205676168203354, Train accuracy: 70.91%, Test accuracy: 64.70%
Epoch 100/200, Loss: 11.36954739689827, Train accuracy: 71.37%, Test accuracy: 66.67%
Epoch 110/200, Loss: 11.340649604797363, Train accuracy: 70.83%, Test accuracy: 65.05%
Epoch 120/200, Loss: 11.232110887765884, Train accuracy

In [49]:
torch.save(new_eeg_model.state_dict(), 'model_state_2b_all_subjects.pth')

## Loading data for 2a dataset

In [41]:
subjects = [1,2,3,4,5,6,7,8,9]
train_runs = {
                1:[0, 1, 2, 3, 4],
                2:[0, 1, 2, 3, 4],
                3:[0, 1, 2, 3, 4],
                4:[0, 1],
                5:[0, 1, 2, 3, 4],
                6:[0, 1, 2, 3, 4],
                7:[0, 1, 2, 3, 4],
                8:[0, 1, 2, 3, 4],
                9:[0, 1, 2, 3, 4]
        }
test_runs = {
                1:[5],
                2:[5],
                3:[5],
                4:[2],
                5:[5],
                6:[5],
                7:[5],
                8:[5],
                9:[5]
}

batch_size = 64

In [42]:
test_datasets = []
for subject in subjects:
    test_datasets.append(MI_Dataset_single_subject(subject, test_runs[subject],return_subject_id=True, device=device, verbose=False))
test_dataset = ConcatDataset(test_datasets)

test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(f"Test dataset: {len(test_dataset)} samples")

25
25
25
25
25
25
25
25
25
Test dataset: 432 samples


In [43]:
train_datasets = []

for subject in subjects:
    dataset = MI_Dataset_single_subject(subject, train_runs[subject], return_subject_id=True, device=device, verbose=False)
    train_datasets.append(dataset)
    channels = dataset.channels
    time_steps = dataset.time_steps
train_dataset = ConcatDataset(train_datasets)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
print(f"Train dataset: {len(train_dataset)} samples")

25
25
25
25
25
25
25
25
25
Train dataset: 2016 samples


In [44]:
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

for feature, label in train_dataloader:
    # print(feature[0].shape)
    # print(feature[1].shape)
    # print(label)
    print(feature[1])
    break

Train dataset: 2016 samples
Test dataset: 432 samples
tensor([6, 4, 8, 6, 4, 0, 6, 7, 6, 2, 0, 4, 8, 2, 8, 1, 5, 8, 1, 2, 1, 1, 1, 6,
        3, 4, 7, 4, 7, 0, 4, 0, 4, 0, 6, 8, 6, 2, 6, 8, 3, 0, 2, 2, 7, 5, 0, 5,
        1, 7, 4, 1, 4, 4, 6, 8, 0, 7, 6, 5, 6, 1, 6, 5])


In [45]:
def load_model(model_path: str, channels: int, samples: int, num_classes: int) -> torch.nn.Module:
    # ConditionedEEGNet példányosítása
    model = ConditionedEEGNet(channels=channels, samples=samples, num_classes=num_classes,num_subjects=len(subjects))

    # Modell súlyainak betöltése, kihagyva az fn2 réteget
    model_weights = torch.load(model_path, map_location=device)
    model_weights = {k: v for k, v in model_weights.items() if 'fn2' not in k}
    model.load_state_dict(model_weights, strict=False)

    # Az fn2 réteg cseréje,4 kimenetre
    in_features = model.fn2.in_features 
    model.fn2 = nn.Linear(in_features, 4)  

    model.to(device)
    return model

In [46]:
model = load_model(model_path="model_state_2b_all_subjects.pth", channels = channels, samples = samples, num_classes = 4 )

In [47]:
model.to(device)
print_parameters(model)

eeg_processor.conv1.weight.... --> 1024
eeg_processor.bn1.weight...... --> 16
eeg_processor.bn1.bias........ --> 16
eeg_processor.dw_conv1.weight. --> 96
eeg_processor.bn2.weight...... --> 32
eeg_processor.bn2.bias........ --> 32
eeg_processor.sep_conv1.weight --> 512
eeg_processor.conv2.weight.... --> 1024
eeg_processor.bn3.weight...... --> 32
eeg_processor.bn3.bias........ --> 32
subject_processor.fn1.weight.. --> 144
subject_processor.fn1.bias.... --> 16
query.weight.................. --> 12288
key.weight.................... --> 512
value.weight.................. --> 12288
fn1.weight.................... --> 4096
fn1.bias...................... --> 128
fn2.weight.................... --> 512
fn2.bias...................... --> 4
eeg_fn.weight................. --> 12288
eeg_fn.bias................... --> 32

Total Parameter Count:........ --> 45124


## Training 2a model with all subjects and 2b weigths

In [48]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['train']['learning_rate'], weight_decay=cfg['train']['weight_decay'])

# Training loop
for epoch in range(cfg['train']['n_epochs']):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(batch_features[0], batch_features[1])
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 9:
        train_accuracy = accuracy(model, train_dataloader)
        test_accuracy = accuracy(model, test_dataloader)
        print(f"Epoch {epoch + 1}/{cfg['train']['n_epochs']}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

Epoch 10/200, Loss: 42.40646207332611, Train accuracy: 35.17%, Test accuracy: 33.33%
Epoch 20/200, Loss: 42.41975271701813, Train accuracy: 36.06%, Test accuracy: 32.18%
Epoch 30/200, Loss: 42.033458948135376, Train accuracy: 35.76%, Test accuracy: 35.42%
Epoch 40/200, Loss: 42.2812123298645, Train accuracy: 34.28%, Test accuracy: 33.56%
Epoch 50/200, Loss: 41.92963111400604, Train accuracy: 34.97%, Test accuracy: 33.33%
Epoch 60/200, Loss: 42.421858072280884, Train accuracy: 36.11%, Test accuracy: 33.56%
Epoch 70/200, Loss: 42.07186138629913, Train accuracy: 35.12%, Test accuracy: 32.18%
Epoch 80/200, Loss: 42.292325139045715, Train accuracy: 34.97%, Test accuracy: 33.56%
Epoch 90/200, Loss: 42.151368141174316, Train accuracy: 36.36%, Test accuracy: 35.42%
Epoch 100/200, Loss: 42.27779138088226, Train accuracy: 35.66%, Test accuracy: 34.49%
Epoch 110/200, Loss: 41.98583507537842, Train accuracy: 34.82%, Test accuracy: 35.65%
Epoch 120/200, Loss: 42.23501908779144, Train accuracy: 35.8

In [50]:
torch.save(model.state_dict(), 'model_state_2a_transfer_learning_all_subjects.pth')