In [None]:
import torch
from pytorch_lightning import Trainer, seed_everything
from torcheval.metrics.functional import multiclass_f1_score

from bci_aic3.config import load_model_config, load_training_config
from bci_aic3.data import BCIDataset, load_data
from bci_aic3.inference import load_models, make_inference, predict_batch
from bci_aic3.models.eegnet import EEGNet
from bci_aic3.paths import (
    CHECKPOINTS_DIR,
    CONFIG_DIR,
    LABEL_MAPPING_PATH,
    MODELS_DIR,
    RAW_DATA_DIR,
    REVERSE_LABEL_MAPPING_PATH,
)
from bci_aic3.train import BCILightningModule, create_data_loaders, setup_callbacks
from bci_aic3.util import load_model, read_json_to_dict, rec_cpu_count, save_model

In [None]:
# Experiment configuration: the model and training configs are both read from
# the same YAML file under CONFIG_DIR.
config_file = 'mi_config.yaml'
SEED = 42  # fixed seed; `deterministic=True` alone does not make runs repeatable

# Seed Python/NumPy/torch RNGs (and dataloader workers) BEFORE any loader or
# model is created, so data shuffling and weight initialisation are repeatable.
seed_everything(SEED, workers=True)

model_config = load_model_config(CONFIG_DIR / config_file)
training_config = load_training_config(CONFIG_DIR / config_file)

max_num_workers = rec_cpu_count()

# Create data loaders
train_loader, val_loader, test_loader = create_data_loaders(
    base_path=RAW_DATA_DIR,
    task_type=model_config.task_type,
    batch_size=training_config.batch_size,
    num_workers=max_num_workers,
)
print("Loaded the data...")

In [None]:
# Wrap the model and training configs in the LightningModule that the Trainer drives.
model = BCILightningModule(model_config=model_config, training_config=training_config)

# Callbacks are built from the model config alone.
# NOTE(review): presumably includes the ModelCheckpoint that writes the
# eegnet-mi-best-f1-*.ckpt files; confirm in bci_aic3.train.setup_callbacks.
callbacks = setup_callbacks(model_config)

In [None]:
# Build the Trainer. Hardware is auto-detected: the GPU is used when present,
# otherwise the CPU, and every visible device of that kind is used.
trainer = Trainer(
    accelerator="auto",
    devices="auto",
    max_epochs=training_config.epochs,  # epoch budget comes from the training config
    callbacks=callbacks,
    log_every_n_steps=10,
    # Force deterministic algorithms; repeatable runs additionally need a fixed seed.
    deterministic=True,
)

In [None]:
# Run training, validating on the held-out split between epochs.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Reload the best checkpoint for evaluation.
# Prefer the checkpoint written by the `trainer.fit` run above. A hardcoded
# filename pins one past run (epoch=00, "-v1") and silently goes stale after
# retraining. The original path is kept only as a fallback for when no
# checkpoint was saved in this session (e.g. the fit cell was skipped).
FALLBACK_CHECKPOINT = "../checkpoints/MI/eegnet-mi-best-f1-epoch=00-val_f1=0.5192-v1.ckpt"

# `trainer.checkpoint_callback` is the first ModelCheckpoint attached to the
# Trainer. Its `best_model_path` is "" until a checkpoint has been saved.
# NOTE(review): assumes that callback monitors validation F1; confirm in setup_callbacks.
checkpoint_callback = trainer.checkpoint_callback
best_checkpoint_path = (
    checkpoint_callback.best_model_path if checkpoint_callback is not None else ""
) or FALLBACK_CHECKPOINT

# NOTE(review): strict=False ignores missing/unexpected state-dict keys, so a
# mismatched checkpoint loads without error. Confirm this is intentional.
loaded_model = BCILightningModule.load_from_checkpoint(best_checkpoint_path, strict=False)
print(f"Loaded checkpoint: {best_checkpoint_path}")

In [None]:
# Evaluate the reloaded checkpoint on the *whole* validation split.
# Predictions and labels are gathered across batches, and macro-F1 is computed
# once. The mean of per-batch macro-F1 scores is not the split-level macro-F1.
loaded_model.eval()

val_pred_batches, val_label_batches = [], []
num_classes = None
with torch.no_grad():
    for data, labels in val_loader:
        # load_from_checkpoint may restore weights onto the device they were
        # saved from (e.g. a GPU), while the loader yields CPU tensors. Move the
        # batch to wherever the model lives.
        logits = loaded_model(data.to(loaded_model.device))
        # Dim 1 indexes classes (argmax over it gives the predicted label), so
        # the class count is read from the model output, not hardcoded.
        num_classes = logits.shape[1]
        val_pred_batches.append(torch.argmax(logits, dim=1).cpu())
        val_label_batches.append(labels.cpu())

if num_classes is None:
    raise RuntimeError("val_loader yielded no batches; cannot compute F1")

val_preds = torch.cat(val_pred_batches)
val_labels = torch.cat(val_label_batches)
val_f1 = multiclass_f1_score(val_preds, val_labels, num_classes=num_classes, average="macro")
val_f1

In [None]:
# Peek at the test split without indexing the Dataset with a slice.
# `dataset[:]` relies on BCIDataset.__getitem__ accepting slices and returning
# something with `.shape`, which the torch Dataset protocol does not guarantee.
# Report the split size and the shape of each element of the first collated
# batch instead.
first_test_batch = next(iter(test_loader))
if not isinstance(first_test_batch, (list, tuple)):
    first_test_batch = (first_test_batch,)
len(test_loader.dataset), [tuple(getattr(item, "shape", ())) for item in first_test_batch]