In [None]:
import os
import torch

from pytorch_lightning import Trainer, seed_everything

from bci_aic3.config import load_model_config, load_training_config
from bci_aic3.paths import (
    PROCESSED_DATA_DIR,
    MI_CONFIG_PATH,
    SSVEP_CONFIG_PATH,
)
from bci_aic3.train import BCILightningModule, setup_callbacks, create_processed_data_loaders
from bci_aic3.util import rec_cpu_count


In [2]:
# Code necessary to create reproducible runs
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
seed_everything(42, workers=True)
torch.use_deterministic_algorithms(True, warn_only=True)


Seed set to 42


In [8]:
from typing import Tuple


def train_model(config_path) -> Tuple[Trainer, BCILightningModule]:
    model_config = load_model_config(config_path)
    training_config = load_training_config(config_path)

    max_num_workers = rec_cpu_count()

    # Create data loaders
    train_loader, val_loader = create_processed_data_loaders(
        processed_data_dir=PROCESSED_DATA_DIR,
        task_type=model_config.task_type,
        batch_size=training_config.batch_size,
        num_workers=max_num_workers,
    )
    
    # Create Lightning module
    model = BCILightningModule(
        num_classes=model_config.num_classes,
        num_channels=model_config.num_channels,
        sequence_length=model_config.new_sequence_length,
        lr=training_config.learning_rate,
    )

    # Setup callbacks
    callbacks = setup_callbacks(model_config)
    
    # Create trainer
    trainer = Trainer(
        max_epochs=training_config.epochs,
        callbacks=callbacks,
        accelerator="auto",  # Automatically uses GPU if available
        devices="auto",  # Uses all available devices
        deterministic=True,  # For reproducibility
        log_every_n_steps=10,
    )
    
    trainer.fit(model, train_loader, val_loader)
    
    return trainer, model

In [None]:
trainer, model = train_model(config_path=MI_CONFIG_PATH)

In [6]:
# Train the model
trainer.fit(model, train_loader, val_loader)

p:\Programming\AIC3\repo\bci_aic3\.venv\Lib\site-packages\pytorch_lightning\callbacks\model_checkpoint.py:654: Checkpoint directory P:\Programming\AIC3\repo\bci_aic3\checkpoints\MI exists and is not empty.

  | Name           | Type               | Params | Mode 
--------------------------------------------------------------
0 | model          | EEGNet             | 36.2 K | train
1 | criterion      | CrossEntropyLoss   | 0      | train
2 | train_accuracy | MulticlassAccuracy | 0      | train
3 | val_accuracy   | MulticlassAccuracy | 0      | train
4 | train_f1       | MulticlassF1Score  | 0      | train
5 | val_f1         | MulticlassF1Score  | 0      | train
--------------------------------------------------------------
36.2 K    Trainable params
0         Non-trainable params
36.2 K    Total params
0.145     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  return F.conv2d(


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.862
Epoch 0, global step 19: 'val_f1' reached 0.41667 (best 0.41667), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4167-epoch=00.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.029 >= min_delta = 0.0. New best score: 0.832
Epoch 1, global step 38: 'val_f1' reached 0.47240 (best 0.47240), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4724-epoch=01.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 57: 'val_f1' reached 0.45455 (best 0.47240), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4545-epoch=02.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 3, global step 76: 'val_f1' reached 0.45455 (best 0.47240), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4545-epoch=03.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 95: 'val_f1' reached 0.46634 (best 0.47240), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4663-epoch=04.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 114: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.019 >= min_delta = 0.0. New best score: 0.813
Epoch 6, global step 133: 'val_f1' reached 0.47457 (best 0.47457), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4746-epoch=06.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.813
Epoch 7, global step 152: 'val_f1' reached 0.47457 (best 0.47457), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4746-epoch=07.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 171: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 190: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 10, global step 209: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 11, global step 228: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 12, global step 247: 'val_f1' reached 0.47457 (best 0.47457), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4746-epoch=12-v1.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 13, global step 266: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 14, global step 285: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 15, global step 304: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 16, global step 323: 'val_f1' reached 0.48326 (best 0.48326), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4833-epoch=16-v1.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.810
Epoch 17, global step 342: 'val_f1' reached 0.48326 (best 0.48326), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4833-epoch=17-v1.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.806
Epoch 18, global step 361: 'val_f1' reached 0.49495 (best 0.49495), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.4949-epoch=18-v1.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.058 >= min_delta = 0.0. New best score: 0.748
Epoch 19, global step 380: 'val_f1' reached 0.55929 (best 0.55929), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.5593-epoch=19-v1.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 20, global step 399: 'val_f1' reached 0.52000 (best 0.55929), saving model to 'P:\\Programming\\AIC3\\repo\\bci_aic3\\checkpoints\\MI\\eegnet-mi-best-f1-val_f1=0.5200-epoch=20-v1.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 21, global step 418: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 22, global step 437: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 23, global step 456: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 24, global step 475: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 25, global step 494: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 26, global step 513: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 27, global step 532: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 28, global step 551: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Monitored metric val_loss did not improve in the last 10 records. Best score: 0.748. Signaling Trainer to stop.
Epoch 29, global step 570: 'val_f1' was not in top 3


In [23]:
from bci_aic3.data import BCIDataset
from bci_aic3.paths import LABEL_MAPPING_PATH, RAW_DATA_DIR
from bci_aic3.util import read_json_to_dict


test = BCIDataset(csv_file="test.csv",
                  base_path=RAW_DATA_DIR,
                  task_type="MI",
                  split="test",
                  label_mapping=read_json_to_dict(LABEL_MAPPING_PATH),
                  )

100%|██████████| 50/50 [00:02<00:00, 18.30it/s]


In [36]:
import numpy as np
from bci_aic3.config import load_processing_config
from bci_aic3.preprocess import apply_all_preprocessing_steps, preprocessing_pipeline
from torch.utils.data import DataLoader, TensorDataset

processing_config = load_processing_config(MI_CONFIG_PATH)

data_loader = DataLoader(test, batch_size=len(test), shuffle=False)
data_batch = next(iter(data_loader))

data = data_batch.numpy()

processed_data = apply_all_preprocessing_steps(
    data=data, settings=processing_config
)

processed_data_path = PROCESSED_DATA_DIR / "MI" / "test_data.npy"

np.save(processed_data_path, processed_data)
print(f"Processed data successfully saved at: {processed_data_path}")

Processed data successfully saved at: P:\Programming\AIC3\repo\bci_aic3\data\processed\MI\test_data.npy


In [37]:
test_data = np.load(processed_data_path)
test_data.shape

(50, 8, 1500)

In [38]:
from pathlib import Path
from typing import Dict
from bci_aic3.paths import TRAINING_STATS_PATH
# from bci_aic3.util import load_training_stats


def load_training_stats(load_path: Path) -> Dict[str, torch.Tensor]:
    """Load training statistics from disk."""
    return torch.load(load_path, weights_only=False)

training_stats = load_training_stats(TRAINING_STATS_PATH / "mi_stats.pt")
training_stats

{'mean': array([0.9257175 , 1.1769003 , 1.1218044 , 0.9505266 , 0.90054685,
        0.8140577 , 0.7452115 , 0.5988763 ], dtype=float32),
 'std': array([199.4371  , 106.681366, 352.71765 , 100.56845 , 187.39577 ,
         83.52705 ,  99.34637 ,  69.068184], dtype=float32)}

In [39]:

from bci_aic3.util import apply_normalization


normalized_test_data = apply_normalization(test_data, training_stats["mean"], training_stats["std"])
normalized_test_data.mean(), normalized_test_data.std()

(np.float32(-0.015454504), np.float32(1.8086668))

In [51]:
test_tensor = torch.from_numpy(normalized_test_data).float()
test_dataset = TensorDataset(test_tensor, torch.empty(len(normalized_test_data)))
test_loader = DataLoader(test_dataset,
                         batch_size=len(test_dataset),
                         shuffle=False)

model.eval()

with torch.no_grad():
    data_batch, labels = next(iter(test_loader))

    preds = model(data_batch)
    

In [53]:
preds.argmax(dim=1)

tensor([1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0,
        1, 1])

In [62]:
# from bci_aic3.paths import REVERSE_LABEL_MAPPING_PATH
from bci_aic3.paths import CONFIG_DIR


REVERSE_LABEL_MAPPING_PATH = CONFIG_DIR / "reverse_label_mapping.json"


reverse_label_mapping = read_json_to_dict(REVERSE_LABEL_MAPPING_PATH)
preds_labels = [reverse_label_mapping[str(p.item())] for p in preds.argmax(dim=1)]
preds_labels

['Right',
 'Right',
 'Right',
 'Left',
 'Right',
 'Right',
 'Right',
 'Right',
 'Left',
 'Left',
 'Right',
 'Right',
 'Left',
 'Left',
 'Left',
 'Left',
 'Right',
 'Right',
 'Right',
 'Right',
 'Right',
 'Left',
 'Left',
 'Left',
 'Right',
 'Right',
 'Right',
 'Right',
 'Right',
 'Right',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Left',
 'Right',
 'Left',
 'Left',
 'Right',
 'Left',
 'Left',
 'Right',
 'Right']

In [58]:
reverse_label_mapping

{'Left': 0, 'Right': 1, 'Forward': 2, 'Backward': 3}

In [None]:
from bci_aic3.paths import RUNS_DIR


scripted = model.to_torchscript()

torch.jit.save(scripted, RUNS_DIR / "mi_scripted.pt")