# Clean Notebook: load raw data, preprocess, and train (MI or SSVEP)

In [1]:
import os

import torch
from pytorch_lightning import seed_everything

from bci_aic3.preprocess import preprocess_and_save
from bci_aic3.data import load_raw_data
from bci_aic3.config import load_model_config, load_processing_config
from bci_aic3.paths import (
    LABEL_MAPPING_PATH,
    MI_CONFIG_PATH,
    MI_RUNS_DIR,
    RAW_DATA_DIR,
    SSVEP_CONFIG_PATH,
    SSVEP_RUNS_DIR,
)
from bci_aic3.util import read_json_to_dict
from bci_aic3.train import train_and_save

In [None]:
# Reproducibility setup.
# cuBLAS needs a fixed workspace configuration for its deterministic code paths,
# so export it before any CUDA work is done.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# Seed the random number generators; workers=True extends seeding to dataloader workers.
seed_everything(42, workers=True)
# Prefer deterministic kernels. warn_only=True emits a warning instead of raising
# when an operation has no deterministic implementation.
torch.use_deterministic_algorithms(True, warn_only=True)

# Allow reduced internal precision for float32 matmuls ("medium"), trading some
# accuracy for speed on GPUs that support it.
torch.set_float32_matmul_precision("medium")

Seed set to 42


In [3]:
# Task to run. The config-selection cell accepts "MI" or "SSVEP" (compared via .upper()).
task_type = "MI"

In [4]:
# Resolve the per-task processing config and run directory.
# Both names are reset first so a failed validation never leaves values from a
# previous execution of this cell lying around in the kernel.
config_path = None
save_path = None
if task_type.upper() == "MI":
    config_path, save_path = MI_CONFIG_PATH, MI_RUNS_DIR
elif task_type.upper() == "SSVEP":
    config_path, save_path = SSVEP_CONFIG_PATH, SSVEP_RUNS_DIR
else:
    # The message reports the value exactly as the user spelled it.
    raise ValueError(
        f"Invalid task_type: {task_type}.\nValid task_type (MI) or (SSVEP)"
    )

# Validation above is case-insensitive, so keep the canonical upper-case spelling:
# the loader below and the preprocessing/training cells receive exactly the value
# that was validated here. Re-running this cell is idempotent.
task_type = task_type.upper()

processing_config = load_processing_config(config_path)
label_mapping = read_json_to_dict(LABEL_MAPPING_PATH)

# Only the first two returned datasets are used in this notebook; the third is
# discarded (presumably the test split, judging by the three progress bars — TODO confirm).
train, val, _ = load_raw_data(
    base_path=RAW_DATA_DIR,
    task_type=task_type,
    label_mapping=label_mapping,
)

  0%|          | 0/2400 [00:00<?, ?it/s]

100%|██████████| 2400/2400 [01:50<00:00, 21.70it/s]
100%|██████████| 50/50 [00:02<00:00, 21.57it/s]
100%|██████████| 50/50 [00:02<00:00, 21.35it/s]


In [5]:
# Run preprocessing on the train and validation datasets. The helper prints the
# pipeline steps and the resulting array shapes; presumably it also writes the
# processed data to disk, per its name — confirm in bci_aic3.preprocess.
# NOTE(review): the recorded output shows validation going from 50 to 20 trials
# although the listed testing pipeline has no artifact-removal step — confirm
# that this reduction is intended.
preprocess_and_save(
    train_dataset=train,
    val_dataset=val,
    task_type=task_type,
    processing_config=processing_config,
)


Training Pipeline steps:
1. notch_filter: MNENotchFilter
2. bandpass_filter: BandPassFilter
3. temporal_crop: TemporalCrop
4. artifact_removal: StatisticalArtifactRemoval
5. channel_normalizer: ChannelWiseNormalizer
6. reshaper: EEGReshaper
7. unsqueeze: FunctionTransformer

Testing Pipeline steps:
1. notch_filter: MNENotchFilter
2. bandpass_filter: BandPassFilter
3. temporal_crop: TemporalCrop
4. channel_normalizer: ChannelWiseNormalizer
5. reshaper: EEGReshaper
6. unsqueeze: FunctionTransformer

Original train data shape: (2400, 8, 2250)
Original validation data shape: (50, 8, 2250)

Transformed train data shape: (1826, 1, 8, 2250)
Transformed train labels shape: (1826,)

Transformed validation data shape: (20, 1, 8, 2250)
Transformed validation labels shape: (20,)

Train Channel Means (should be zero mean):
 [ 7.2979239e-10 -6.8336792e-10  6.6108419e-10  9.4241637e-10
 -1.0844752e-09  1.0287659e-09 -1.4312659e-09 -5.7937716e-10]

Train Channel Standard Deviations (should be unit st

In [None]:
# Train the model for the selected task. Per the recorded log below, this creates a
# temporary run directory and saves the top-3 checkpoints ranked by val_f1.
train_and_save(task_type=task_type)

Created temporary run directory: /home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress


  return F.conv2d(
/home/Crim/AIC3/bci_aic3/.venv/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:209: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/Crim/AIC3/bci_aic3/.venv/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA GeF

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved. New best score: 0.678
Epoch 0, global step 58: 'val_f1' reached 0.46667 (best 0.46667), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.4667-epoch=00.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.672
Epoch 1, global step 116: 'val_f1' reached 0.37500 (best 0.46667), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.3750-epoch=01.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 2, global step 174: 'val_f1' reached 0.47917 (best 0.47917), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.4792-epoch=02.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Metric val_loss improved by 0.002 >= min_delta = 0.0. New best score: 0.670
Epoch 3, global step 232: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 4, global step 290: 'val_f1' reached 0.43574 (best 0.47917), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.4357-epoch=04.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 5, global step 348: 'val_f1' reached 0.53964 (best 0.53964), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.5396-epoch=05.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 6, global step 406: 'val_f1' reached 0.56044 (best 0.56044), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.5604-epoch=06.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 7, global step 464: 'val_f1' reached 0.56044 (best 0.56044), saving model to '/home/Crim/AIC3/bci_aic3/run/MI/ATCNet-20250628_214307-inprogress/checkpoints/atcnet-mi-best-f1-val_f1=0.5604-epoch=07.ckpt' as top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 8, global step 522: 'val_f1' was not in top 3


Validation: |          | 0/? [00:00<?, ?it/s]

Epoch 9, global step 580: 'val_f1' was not in top 3
