In [2]:
import os

import torch
from pytorch_lightning import seed_everything

In [3]:
from bci_aic3.models.eegnet import EEGNet
from bci_aic3.paths import (
    MI_CONFIG_PATH,
    MI_RUNS_DIR,
    RAW_DATA_DIR,
    SSVEP_CONFIG_PATH,
    SSVEP_RUNS_DIR
)
from bci_aic3.train import train_model
from bci_aic3.util import save_model, load_model

In [4]:
# Display the MI config path that the training cell below passes to train_model.
MI_CONFIG_PATH

WindowsPath('P:/Programming/AIC3/repo/bci_aic3/configs/mi_config.yaml')

In [None]:

# Reproducibility setup.
# cuBLAS needs a fixed workspace configuration to run deterministically on CUDA;
# it has to be in the environment before cuBLAS is first used.
SEED = 42
os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
# Seed the RNGs Lightning manages; workers=True also covers dataloader workers.
seed_everything(SEED, workers=True)
# Prefer deterministic kernels; warn_only=True makes ops lacking a deterministic
# implementation emit a warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)


In [None]:
# Train EEGNet with the MI task configuration; train_model returns (trainer, model).
trainer, model = train_model(
    model=EEGNet,
    config_path=MI_CONFIG_PATH,
)

In [None]:
# Read the validation F1 metric from the trainer's callback_metrics.
# It is None if no "val_f1" metric was recorded.
val_f1_score = trainer.callback_metrics.get("val_f1")

if val_f1_score is not None:
    # The stored value is presumably a scalar tensor (TODO confirm); float()
    # handles both 0-d/1-element tensors and plain numbers.
    val_f1_score = float(val_f1_score)

val_f1_score

In [None]:
# The model returned by train_model above was trained with MI_CONFIG_PATH, so it
# is stored with the MI runs. Saving it under SSVEP_RUNS_DIR would make the SSVEP
# inference cells load an MI-configured model.
save_model(model, MI_RUNS_DIR / "model_scripted.pt")

In [None]:
loaded_model = load_model(SSVEP_RUNS_DIR / "model_scripted.pt")  # Load the saved SSVEP model

In [None]:
from bci_aic3.inference import (
    create_inference_data_loader,
    load_and_preprocess_for_inference,
    predict_batch,
)


In [None]:
# Load and preprocess test.csv from the raw data directory for the SSVEP task,
# then wrap the result in an inference DataLoader.
test_dataset = load_and_preprocess_for_inference(
    "test.csv",
    base_path=RAW_DATA_DIR,
    task_type="SSVEP",
)
test_loader = create_inference_data_loader(test_dataset)

In [None]:
# Run the loaded model over every batch of the test loader on CPU.
preds = predict_batch(
    model=loaded_model,
    data_loader=test_loader,
    device="cpu",
)
preds

In [None]:
from bci_aic3.inference import make_inference

# End-to-end SSVEP inference directly from test.csv.
# reverse_mapping=True presumably maps class indices back to label names — confirm.
ssvep_preds = make_inference(
    model=loaded_model,
    csv_file="test.csv",
    base_path=RAW_DATA_DIR,
    task_type="SSVEP",
    reverse_mapping=True,
)
ssvep_preds

In [None]:
# Load the MI model from the MI runs directory and run end-to-end MI inference.
mi_model_path = MI_RUNS_DIR / "model_scripted.pt"
loaded_mi_model = load_model(mi_model_path)

mi_preds = make_inference(
    model=loaded_mi_model,
    csv_file="test.csv",
    base_path=RAW_DATA_DIR,
    task_type="MI",
    reverse_mapping=True,
)
mi_preds

In [None]:
import pandas as pd


# Load the raw test table; predictions are written into it per task below.
test_csv_path = RAW_DATA_DIR / "test.csv"
df = pd.read_csv(test_csv_path)

In [None]:

# Write each task's predictions into that task's rows of the test table.
ssvep_mask = df["task"] == "SSVEP"
mi_mask = df["task"] == "MI"

# Convert to plain arrays so the .loc assignment is positional. If the
# predictions were a pandas Series, .loc would align on its index instead of on
# row order, silently producing NaN or misplaced labels.
ssvep_labels = pd.Series(ssvep_preds).to_numpy()
mi_labels = pd.Series(mi_preds).to_numpy()

# NOTE(review): assumes make_inference returns predictions in the same order as
# the matching rows of test.csv — confirm in bci_aic3.inference.
assert len(ssvep_labels) == ssvep_mask.sum(), "SSVEP prediction count mismatch"
assert len(mi_labels) == mi_mask.sum(), "MI prediction count mismatch"

df.loc[ssvep_mask, "labels"] = ssvep_labels
df.loc[mi_mask, "labels"] = mi_labels
df.head()

In [None]:
from bci_aic3.paths import SUBMISSIONS_DIR


# Write only the submission columns, without the DataFrame index.
submission_cols = ['id', 'labels']
submission_path = SUBMISSIONS_DIR / 'quick_test_1_epoch.csv'
df[submission_cols].to_csv(submission_path, index=False)