In [1]:
import matplotlib.pyplot as plt
import mne
import seaborn as sns
import torch
from braindecode import EEGClassifier
from sklearn.pipeline import make_pipeline
from skorch.callbacks import EarlyStopping, EpochScoring
from skorch.dataset import ValidSplit
import matplotlib.pyplot as plt
import pandas as pd
import os

from moabb.datasets import BNCI2014_001, BNCI2014_004
from moabb.evaluations import CrossSessionEvaluation
from moabb.paradigms import MotorImagery
from moabb.utils import setup_seed

from moabb.evaluations import AllRunsEvaluationModified, AllRunsEvaluationSubjectParam
from shallow import CollapsedShallowNet, SubjectOneHotNet, SubjectOneHotConvNet, SubjectDicionaryConvNet, SubjectOneHotConvNet2, SubjectAdvIndexFCNet,ShallowFBCSPNet
from shallowDict import ShallowPrivateTemporalDictNetSlow, ShallowPrivateSpatialDictNetSlow, ShallowPrivateCollapsedDictNetSlow, SubjectDicionaryFCNet

mne.set_log_level(False)

# Report the PyTorch version so logged results can be tied to the build used.
print(f"Torch Version: {torch.__version__}")

# Train on the GPU when PyTorch can see one, otherwise on the CPU.
cuda = torch.cuda.is_available()
device = "cuda" if cuda else "cpu"
print("GPU is", "AVAILABLE" if cuda else "NOT AVAILABLE")

# Seed the RNGs through moabb's helper; the same seed is reused for the
# validation split and the evaluation further down.
seed = 42
setup_seed(seed)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Hyperparameters
LEARNING_RATE = 1e-4  # Adam learning rate; parameter taken from Braindecode
WEIGHT_DECAY = 0  # Adam weight decay; parameter taken from Braindecode
BATCH_SIZE = 128  # 2**7; parameter taken from BrainDecode
EPOCH = 2  # maximum number of training epochs per fit
PATIENCE = 100  # early-stopping patience, only used when early stopping is enabled

# Preprocessing passed to the MotorImagery paradigm
fmin = 4  # band-pass lower edge (Hz)
fmax = 100  # band-pass upper edge (Hz)
tmin = 0  # epoch start (s), relative to the dataset's task interval
tmax = None  # epoch end (s); None -> use the dataset's own interval end


dataset = BNCI2014_001()
paradigm = MotorImagery(fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax)

subjects = dataset.subject_list

# X is only used to size the network input (channels = X.shape[1],
# time points = X.shape[2]). Epoching a single subject is therefore enough,
# and much cheaper than loading the whole dataset up front.
# NOTE(review): assumes every subject of this dataset yields the same channel
# count and epoch length under this paradigm — confirm if the dataset changes.
X, _, _ = paradigm.get_data(dataset=dataset, subjects=[subjects[0]])
        

def make_classifier(module, max_epochs=None, early_stopping=False):
    """Build a braindecode ``EEGClassifier`` for ``module`` with the shared settings.

    The network input/output sizes are taken from the module-level ``X``
    (channels = ``X.shape[1]``, time points = ``X.shape[2]``) and from
    ``dataset.event_id`` (number of classes). Optimizer and training settings
    come from the module-level hyperparameter constants.

    Parameters
    ----------
    module : class or instance
        Network passed to skorch's ``module`` argument. The call sites in this
        notebook pass classes.
    max_epochs : int or None, optional
        Training epochs per ``fit`` call. ``None`` (default) uses the global
        ``EPOCH``.
    early_stopping : bool, optional
        If True, add an ``EarlyStopping`` callback on ``valid_loss`` with the
        global ``PATIENCE``. Defaults to False, which matches the previous
        behaviour (the callback used to be commented out).

    Returns
    -------
    EEGClassifier
        An unfitted classifier.
    """
    callbacks = [
        EpochScoring(
            scoring="accuracy", on_train=True, name="train_acc", lower_is_better=False
        ),
        EpochScoring(
            scoring="accuracy", on_train=False, name="valid_acc", lower_is_better=False
        ),
    ]
    if early_stopping:
        callbacks.append(EarlyStopping(monitor="valid_loss", patience=PATIENCE))

    return EEGClassifier(
        module=module,
        module__n_chans=X.shape[1],  # number of input channels
        module__n_outputs=len(dataset.event_id),  # number of output classes
        module__n_times=X.shape[2],  # length of the input signal in time points
        optimizer=torch.optim.Adam,
        optimizer__lr=LEARNING_RATE,
        # Previously WEIGHT_DECAY was defined but never applied; at 0 this is
        # identical to Adam's default, so existing results are unchanged.
        optimizer__weight_decay=WEIGHT_DECAY,
        batch_size=BATCH_SIZE,
        max_epochs=EPOCH if max_epochs is None else max_epochs,
        # Keep training the existing weights on each fit instead of re-initializing.
        # NOTE(review): if the evaluation reuses this estimator across folds or
        # subjects without cloning it, weights carry over between folds — confirm
        # that AllRunsEvaluationSubjectParam clones the pipeline per fit.
        warm_start=True,
        train_split=ValidSplit(0.2, random_state=seed),
        device=device,
        callbacks=callbacks,
        verbose=1,
    )

# One classifier per candidate architecture. Only the ones wired into `pipes`
# are evaluated; the others stay defined so a different selection is one edit away.
clf = make_classifier(SubjectOneHotNet)
clf2 = make_classifier(SubjectDicionaryFCNet)
clf3 = make_classifier(CollapsedShallowNet)
clf4 = make_classifier(SubjectOneHotConvNet)
clf5 = make_classifier(SubjectOneHotConvNet2)
clf6 = make_classifier(SubjectDicionaryConvNet)
clf7 = make_classifier(SubjectAdvIndexFCNet)
clf8 = make_classifier(ShallowPrivateTemporalDictNetSlow)
clf9 = make_classifier(ShallowFBCSPNet)
clf10 = make_classifier(ShallowPrivateSpatialDictNetSlow)
clf11 = make_classifier(ShallowPrivateCollapsedDictNetSlow)

# Pipelines to evaluate. Each key is the model class name; it is reused as the
# result-file prefix and as the "pipeline" label in the plots.
pipes = {
    name: make_pipeline(estimator)
    for name, estimator in [
        ("ShallowFBCSPNet", clf9),
        ("CollapsedShallowNet", clf3),
    ]
}

# Other selections used in earlier runs. Each key is a model class name and is
# paired with the clf* built from that class above:
#   - ShallowPrivateCollapsedDictNetSlow
#   - CollapsedShallowNet, SubjectOneHotConvNet, SubjectOneHotConvNet2,
#     SubjectDicionaryConvNet
#   - CollapsedShallowNet, ShallowPrivateCollapsedDictNetSlow, ShallowFBCSPNet,
#     ShallowPrivateTemporalDictNetSlow, ShallowPrivateSpatialDictNetSlow
#   - SubjectOneHotConvNet, CollapsedShallowNet, SubjectDicionaryFCNet,
#     SubjectOneHotNet
#   - CollapsedShallowNet, ShallowFBCSPNet, ShallowPrivateCollapsedDictNetSlow
#   - SubjectOneHotNet, SubjectDicionaryFCNet, CollapsedShallowNet,
#     SubjectOneHotConvNet, SubjectOneHotConvNet2, SubjectDicionaryConvNet,
#     SubjectAdvIndexFCNet
#   - dictionary nets: ShallowPrivateSpatialDictNetSlow, SubjectDicionaryFCNet,
#     ShallowPrivateTemporalDictNetSlow, ShallowPrivateCollapsedDictNetSlow

def _save_score_barplot(df, path, title, hue=None, legend_title=None, figsize=(10, 6)):
    """Save a bar plot of ``df["score"]`` per ``df["subject"]`` to ``path``.

    When ``hue`` is given, bars are split by that column and a legend titled
    ``legend_title`` is drawn. The figure is closed after saving, so repeated
    calls inside a loop do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=figsize)
    sns.barplot(data=df, x="subject", y="score", hue=hue, palette="viridis", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Subject")
    ax.set_ylabel("Score")
    if hue is not None:
        ax.legend(title=legend_title)
    fig.savefig(path)
    plt.close(fig)


results_list = []
# Ensure the output directory exists
output_dir = f"./results_{seed}_{dataset.code}"
os.makedirs(output_dir, exist_ok=True)

# Evaluate each pipeline separately so that each one gets its own suffix,
# model-save path, CSV and plot.
for pipe_name, pipe in pipes.items():
    unique_suffix = f"{pipe_name}_braindecode_example"

    evaluation = AllRunsEvaluationSubjectParam(
        paradigm=paradigm,
        datasets=[dataset],
        suffix=unique_suffix,
        overwrite=True,
        return_epochs=True,
        random_state=seed,
        n_jobs=1,
        hdf5_path=f"{output_dir}/{pipe_name}",
        save_model=True,
    )

    # Run the evaluation process for this pipeline
    results = evaluation.process({pipe_name: pipe})

    # Save results to CSV
    results_df = pd.DataFrame(results)
    results_df.to_csv(f"{output_dir}/{pipe_name}_results.csv", index=False)

    # Save individual bar plot
    _save_score_barplot(
        results_df,
        f"{output_dir}/{pipe_name}_performance.png",
        f"Model Performance by Subject - {pipe_name}",
    )

    results_list.append(results_df)

# pd.concat raises ValueError on an empty list, which happens when `pipes` is
# empty. In that case keep an empty frame and skip the combined outputs.
if results_list:
    results_all = pd.concat(results_list)
    results_all.to_csv(f"{output_dir}/all_results.csv", index=False)

    # Save combined bar plot
    _save_score_barplot(
        results_all,
        f"{output_dir}/combined_performance.png",
        "Scores per Subject for Each Pipeline",
        hue="pipeline",
        legend_title="Pipeline",
        figsize=(12, 6),
    )
else:
    results_all = pd.DataFrame()
    print("No pipelines selected in `pipes`; skipping combined results.")

Choosing from all possible events


Torch Version: 2.2.2
GPU is NOT AVAILABLE
We try to set the tensorflow seeds, but it seems that tensorflow is not installed. Please refer to `https://www.tensorflow.org/` to install if you need to use this deep learning module.


 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2461[0m        [32m1.3928[0m       [35m0.2174[0m        [31m1.4110[0m  2.1765
      2       0.2344        1.3976       [35m0.2500[0m        [31m1.3954[0m  2.1316
      3       [36m0.2695[0m        [32m1.3892[0m       [35m0.2935[0m        [31m1.3794[0m  2.1285
      4       [36m0.2852[0m        [32m1.3821[0m       0.2500        1.3877  2.1296
      5       0.2422        1.3832       0.2500        1.3864  2.1973
      6       0.2188        1.4026       [35m0.3152[0m        1.3855  2.1656
      7       0.2773        1.3883       0.2935        [31m1.3762[0m  2.1581
      8       [36m0.3164[0m        [32m1.3782[0m       0.2609        1.3928  2.1486
      9       [36m0.3242[0m        [32m1.3645[0m       0.2500        1.3919  2.1781
     10       0.2773        1.3885       [35m0.3261[0m        1.38

BNCI2014-001-AllRuns: 100%|██████████| 1/1 [01:13<00:00, 73.50s/it]
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 '

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2109[0m        [32m1.3910[0m       [35m0.2935[0m        [31m1.4085[0m  0.3109
      2       0.2109        1.4158       0.1739        1.4367  0.3083
      3       [36m0.2344[0m        [32m1.3853[0m       0.2500        [31m1.3829[0m  0.2789
      4       [36m0.2695[0m        [32m1.3814[0m       0.2283        1.3972  0.2828
      5       0.2617        1.3844       0.2065        1.3959  0.2817
      6       0.2461        1.3995       0.2717        1.3941  0.2799
      7       [36m0.3008[0m        1.3845       0.2935        [31m1.3744[0m  0.2799
      8       0.2695        1.3849       0.2174        1.3847  0.2802
      9       0.2617        [32m1.3743[0m       0.2826        1.3857  0.2772
     10       0.2852        1.3753       0.2391        1.3809  0.2791
     11       0.2422        1.4056       [35m0.3261

BNCI2014-001-AllRuns: 100%|██████████| 1/1 [00:36<00:00, 36.25s/it]
