# MMEEG Tutorial

This tutorial shows how to use this repository to train your own model. Implement your model and register it in the models registry, as is done for the existing models (EEGConformer and EEGNet); the remaining steps stay the same. This keeps the training code abstract and makes quick, config-driven experimentation easy.
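The registry is what lets a config dict select a class by name. The sketch below is a deliberately tiny, hypothetical stand-in for illustration only; it is not MMEngine's `Registry`, which also supports scopes, parent registries and more. It assumes, as `MODELS.build(cfg.model)` suggests, that the config's `type` key names a registered class and the remaining keys become constructor arguments. `ToyRegistry`, `TOY_MODELS` and `MyEEGModel` are made-up names.

```python
from typing import Any, Callable, Dict, Type


class ToyRegistry:
    """Hypothetical, minimal registry mapping class names to classes."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._classes: Dict[str, Type[Any]] = {}

    def register_module(self) -> Callable[[Type[Any]], Type[Any]]:
        # Used as a decorator: @TOY_MODELS.register_module()
        def decorator(cls: Type[Any]) -> Type[Any]:
            if cls.__name__ in self._classes:
                raise KeyError(f'{cls.__name__} is already registered in {self.name}')
            self._classes[cls.__name__] = cls
            return cls
        return decorator

    def build(self, cfg: Dict[str, Any]) -> Any:
        # Copy so the caller's config is not mutated, then split off 'type'.
        args = dict(cfg)
        cls_name = args.pop('type')
        if cls_name not in self._classes:
            raise KeyError(f'{cls_name} is not registered in {self.name}')
        return self._classes[cls_name](**args)


TOY_MODELS = ToyRegistry('models')


@TOY_MODELS.register_module()
class MyEEGModel:
    """Placeholder for your own model class (hypothetical)."""

    def __init__(self, num_classes: int, num_channels: int = 64) -> None:
        self.num_classes = num_classes
        self.num_channels = num_channels


toy_model = TOY_MODELS.build(dict(type='MyEEGModel', num_classes=4))
print(type(toy_model).__name__, toy_model.num_classes, toy_model.num_channels)
```

In the repository you would decorate your class with the real `MODELS` registry from `models.registry` instead (for example `@MODELS.register_module()`, assuming it is an MMEngine `Registry`).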

Start off by importing the necessary dependencies.

```python
from torch.optim import SGD
from torch.utils.data import DataLoader

from mmengine.evaluator import BaseMetric
from mmengine.runner import Runner
from mmengine.config import Config

from models.registry import MODELS
from datasets.registry import DATASETS
```



Next, select the config for the model you wish to train. See the configs folder for examples.
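MMEngine loads a `.py` config by executing it and exposing its top-level variables as attributes, so `cfg.model`, `cfg.dataset` and `cfg.val_dataset` used in this tutorial are simply variables defined in the config file. A hypothetical skeleton is sketched below. The `type` values assume classes registered under the names `MMEEGNet` and `EEGDataset`; every other key (and the file name) is a placeholder, so check the files in the configs folder for the real arguments.

```python
# configs/my_experiment_config.py  (hypothetical file name)
# Each top-level dict becomes an attribute of the loaded Config,
# e.g. cfg.model, cfg.dataset, cfg.val_dataset.

model = dict(
    type='MMEEGNet',         # assumed registered name of the model class
    num_classes=4,           # placeholder argument
)

dataset = dict(
    type='EEGDataset',       # dataset class shipped with the repository
    data_root='data/train',  # placeholder argument
)

val_dataset = dict(
    type='EEGDataset',
    data_root='data/val',    # placeholder argument
)
```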

```python
# Pick one config; uncomment the first line to train EEGConformer instead.
# cfg = Config.fromfile('configs/eeg_conformer_config.py')
cfg = Config.fromfile('configs/eegnet_config.py')
model = MODELS.build(cfg.model)
```

```
First blocks out shape: torch.Size([1, 16, 1, 15])
```


```python
model
```

```
MMEEGNet(
  (data_preprocessor): BaseDataPreprocessor()
  (model): eegNet(
    (firstBlocks): Sequential(
      (0): Sequential(
        (0): Conv2d(1, 8, kernel_size=(1, 125), stride=(1, 1), padding=(0, 62), bias=False)
        (1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): Conv2dWithConstraint(8, 16, kernel_size=(64, 1), stride=(1, 1), groups=8, bias=False)
        (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (4): ELU(alpha=1.0)
        (5): AvgPool2d(kernel_size=(1, 4), stride=4, padding=0)
        (6): Dropout(p=0.5, inplace=False)
      )
      (1): Sequential(
        (0): Conv2d(16, 16, kernel_size=(1, 22), stride=(1, 1), padding=(0, 11), groups=16, bias=False)
        (1): Conv2d(16, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (3): ELU(alpha=1.0)
        (4): AvgPool2d(kernel
```
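The `First blocks out shape: torch.Size([1, 16, 1, 15])` line printed when the model is built can be checked by hand from the layer parameters above, using the usual output-size rule for convolution and pooling, $\lfloor (W + 2p - k)/s \rfloor + 1$. This sketch assumes a 64 × 480 input (matching `dataset[0][0].shape` below) and assumes the truncated second `AvgPool2d` uses kernel and stride `(1, 8)`. That value is not visible in the printout; it is simply one consistent with the printed width of 15.

```python
def out_len(length: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Output length along one axis for Conv2d / AvgPool2d (dilation 1, floor mode)."""
    return (length + 2 * padding - kernel) // stride + 1


height, width = 64, 480  # assumed input: 64 EEG channels x 480 time samples

# Block 0
width = out_len(width, kernel=125, padding=62)  # temporal Conv2d (1, 125): 480 -> 480
height = out_len(height, kernel=64)             # depthwise Conv2dWithConstraint (64, 1): 64 -> 1
width = out_len(width, kernel=4, stride=4)      # AvgPool2d (1, 4): 480 -> 120

# Block 1
width = out_len(width, kernel=22, padding=11)   # depthwise Conv2d (1, 22): 120 -> 121
width = out_len(width, kernel=1)                # pointwise Conv2d (1, 1): 121 -> 121
width = out_len(width, kernel=8, stride=8)      # AvgPool2d, ASSUMED (1, 8): 121 -> 15

channels = 16
print((1, channels, height, width))  # (1, 16, 1, 15)
```

Note that the even-sized kernel 22 with padding 11 lengthens the sequence by one sample, from 120 to 121.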

You can also define custom metrics for your experiment if required. Below is an example implementation of the Accuracy metric.

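```python
class Accuracy(BaseMetric):
    """Classification accuracy (in %) accumulated over the validation set."""

    def process(self, data_batch, data_samples):
        # data_samples is what the model returns in predict mode;
        # here it is assumed to be a (scores, labels) pair for one batch
        score, gt = data_samples
        self.results.append({
            'batch_size': len(gt),
            'correct': (score.argmax(dim=1) == gt).sum().cpu(),
        })

    def compute_metrics(self, results):
        # Pool counts over all batches before dividing
        total_correct = sum(item['correct'] for item in results)
        total_size = sum(item['batch_size'] for item in results)
        return dict(accuracy=100 * total_correct / total_size)
```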
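`compute_metrics` sums correct predictions and sample counts over all batches before dividing, rather than averaging per-batch accuracies. This matters because the last validation batch is usually smaller than `batch_size`. The plain-Python sketch below mirrors that logic without PyTorch or MMEngine, using made-up batch counts purely for illustration.

```python
# Hypothetical per-batch results in the same format that Accuracy.process appends.
results = [
    {'batch_size': 32, 'correct': 8},
    {'batch_size': 32, 'correct': 8},
    {'batch_size': 4, 'correct': 4},  # smaller final batch
]

# Pooled accuracy, as in Accuracy.compute_metrics
total_correct = sum(item['correct'] for item in results)
total_size = sum(item['batch_size'] for item in results)
pooled = 100 * total_correct / total_size

# Naive mean of per-batch accuracies (over-weights the small final batch)
per_batch = [100 * item['correct'] / item['batch_size'] for item in results]
naive = sum(per_batch) / len(per_batch)

print(f'pooled: {pooled:.2f}%  naive mean: {naive:.2f}%')
# pooled: 29.41%  naive mean: 50.00%
```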

The repository ships with a standard `EEGDataset`, so there is no need to reimplement it if your data uses the same format. Otherwise, you can implement your own Dataset class and register it with the `DATASETS` registry, then build your training, validation and test datasets from the config as shown below. See the `EEGDataset` implementation in the datasets folder for an example.
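A custom dataset only needs the map-style `__len__`/`__getitem__` interface that `torch.utils.data.DataLoader` expects. The sketch below assumes, based on `dataset[0][0].shape` below, that each item is a `(data, label)` pair with `data` shaped `(1, channels, samples)`. It uses NumPy arrays, which the default collate function converts to tensors; the class name, constructor arguments and random placeholder data are all hypothetical. In the repository you would also decorate the class with `@DATASETS.register_module()` (assuming `DATASETS` is an MMEngine `Registry`) so that `DATASETS.build` can find it by name.

```python
import numpy as np


class RandomEEGDataset:
    """Hypothetical map-style dataset returning (data, label) pairs.

    Random arrays stand in for real recordings. In the repository, add
    @DATASETS.register_module() above the class so a config can select it.
    """

    def __init__(self, num_samples=100, num_channels=64, num_timepoints=480,
                 num_classes=4, seed=0):
        rng = np.random.default_rng(seed)
        # Shape (N, 1, channels, time): the leading 1 is the input channel
        # consumed by the first Conv2d(1, ...) layer of the models.
        self.data = rng.standard_normal(
            (num_samples, 1, num_channels, num_timepoints)).astype(np.float32)
        self.labels = rng.integers(0, num_classes, size=num_samples)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # DataLoader's default collate turns float32 arrays and ints into tensors.
        return self.data[idx], int(self.labels[idx])


ds = RandomEEGDataset()
print(len(ds), ds[0][0].shape)  # 100 (1, 64, 480)
```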

```python
dataset = DATASETS.build(cfg.dataset)
```

```python
val_dataset = DATASETS.build(cfg.val_dataset)
```

```python
dataset[0][0].shape
```

```
torch.Size([1, 64, 480])
```

Next, create the dataloaders that the MMEngine runner will use. The default `torch.utils.data.DataLoader` is used here.

```python
train_dataloader = DataLoader(batch_size=32,
                              shuffle=True,
                              dataset=dataset)
```

```python
val_dataloader = DataLoader(batch_size=32,
                            shuffle=False,
                            dataset=val_dataset)
```
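The number of iterations per epoch follows from the dataset size and `batch_size`: with the default `drop_last=False`, a `DataLoader` over a map-style dataset yields ⌈N / batch_size⌉ batches, the last one possibly partial. The training log below reports `[8/8]` validation iterations, so with `batch_size=32` the validation set holds between 225 and 256 samples. A small helper (the name `num_batches` is made up) makes the arithmetic explicit:

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style dataset, matching len(DataLoader)."""
    if drop_last:
        return num_samples // batch_size
    # Integer ceiling division: a partial final batch still counts
    return (num_samples + batch_size - 1) // batch_size


print(num_batches(225, 32), num_batches(256, 32), num_batches(257, 32))  # 8 8 9
```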



And that's it! The model is ready to train. You can set the optimizer, the number of epochs, how often to run validation, and so on. More options for the MMEngine `Runner` can be found in the official [documentation](https://mmengine.readthedocs.io/en/latest/tutorials/runner.html).

```python
runner = Runner(
    model=model,
    work_dir='./work_dir',
    train_dataloader=train_dataloader,
    optim_wrapper=dict(optimizer=dict(type=SGD, lr=0.001, momentum=0.9)),
    train_cfg=dict(by_epoch=True, max_epochs=5, val_interval=1),
    val_dataloader=val_dataloader,
    val_cfg=dict(),
    val_evaluator=dict(type=Accuracy),
)
runner.train()
```

```
04/03 02:40:31 - mmengine - INFO -
------------------------------------------------------------
System environment:
    sys.platform: linux
    Python: 3.10.9 (main, Jan 11 2023, 15:21:40) [GCC 11.2.0]
    CUDA available: True
    numpy_random_seed: 332588215
    GPU 0: NVIDIA A40
    CUDA_HOME: /usr/local/cuda
    NVCC: Cuda compilation tools, release 11.8, V11.8.89
    GCC: gcc (GCC) 11.2.0
    PyTorch: 1.13.1+cu117
    PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201402
  - Intel(R) Math Kernel Library Version 2020.0.0 Product Build 20191122 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=comput

04/03 02:40:34 - mmengine - INFO - Exp name: 20230403_024031
04/03 02:40:34 - mmengine - INFO - Saving checkpoint at 1 epochs
04/03 02:40:34 - mmengine - INFO - Epoch(val) [1][8/8]  accuracy: 22.4900
04/03 02:40:34 - mmengine - INFO - Exp name: 20230403_024031
04/03 02:40:34 - mmengine - INFO - Saving checkpoint at 2 epochs
04/03 02:40:34 - mmengine - INFO - Epoch(val) [2][8/8]  accuracy: 24.4980
04/03 02:40:35 - mmengine - INFO - Exp name: 20230403_024031
04/03 02:40:35 - mmengine - INFO - Saving checkpoint at 3 epochs
04/03 02:40:35 - mmengine - INFO - Epoch(val) [3][8/8]  accuracy: 26.9076
04/03 02:40:35 - mmengine - INFO - Exp name: 20230403_024031
04/03 02:40:35 - mmengine - INFO - Saving checkpoint at 4 epochs
04/03 02:40:35 - mmengine - INFO - Epoch(val) [4][8/8]  accuracy: 32.1285
04/03 02:40:35 - mmengine - INFO
```
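Beyond the arguments used above, the MMEngine `Runner` accepts further keyword arguments, for example `param_scheduler` for learning-rate schedules, `default_hooks` to adjust built-in hooks such as checkpointing, and `randomness` to fix the seed. The values below are illustrative only, not tuned for this repository, and the name `extra_runner_kwargs` is made up; pass the dict alongside the arguments used above, e.g. `Runner(model=model, ..., **extra_runner_kwargs)`.

```python
# Illustrative extra keyword arguments for mmengine.runner.Runner.
# Values are examples only, not recommendations for this repository.
extra_runner_kwargs = dict(
    # Multiply the learning rate by gamma=0.1 at the epoch milestones 3 and 4.
    param_scheduler=dict(type='MultiStepLR', by_epoch=True,
                         milestones=[3, 4], gamma=0.1),
    # Save a checkpoint every 2 epochs and keep only the 3 most recent ones.
    default_hooks=dict(checkpoint=dict(type='CheckpointHook',
                                       interval=2, max_keep_ckpts=3)),
    # Fix the random seed for more reproducible runs.
    randomness=dict(seed=42),
)
```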
