In [3]:
import matplotlib.pyplot as plt
import mne
import seaborn as sns
import torch
from braindecode import EEGClassifier
from sklearn.pipeline import make_pipeline
from skorch.callbacks import EarlyStopping, EpochScoring
from skorch.dataset import ValidSplit
import matplotlib.pyplot as plt
import pandas as pd
import os

from moabb.datasets import BNCI2014_001, BNCI2014_004
from moabb.evaluations import CrossSessionEvaluation
from moabb.paradigms import MotorImagery
from moabb.utils import setup_seed

from moabb_Private_Encoder_Thesis.moabb.evaluations import SubjectParamEvaluation

#from moabb.evaluations import SubjectParamEvaluation
from shallow import CollapsedShallowNet ,ShallowFBCSPNet
from shallowDict import ShallowPrivateTemporalDictNetSlow, ShallowPrivateSpatialDictNetSlow, ShallowPrivateCollapsedDictNetSlow, SubjectDicionaryFCNet

mne.set_log_level(False)  # False -> MNE "WARNING" level, hiding INFO chatter

# ---------------------------------------------------------------------------
# Environment report
# ---------------------------------------------------------------------------
print(f"Torch Version: {torch.__version__}")

# Train on the GPU when PyTorch can see one, otherwise on the CPU.
cuda = torch.cuda.is_available()
if cuda:
    device = "cuda"
    print("GPU is", "AVAILABLE")
else:
    device = "cpu"
    print("GPU is", "NOT AVAILABLE")

# ---------------------------------------------------------------------------
# Reproducibility
# ---------------------------------------------------------------------------
seed = 100
setup_seed(seed)  # MOABB helper: seeds the global RNG backends it can find

# cuDNN: deterministic kernels and no autotuner, so GPU runs are repeatable.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# ---------------------------------------------------------------------------
# Training hyper-parameters (lr / weight decay / batch size follow Braindecode)
# ---------------------------------------------------------------------------
LEARNING_RATE = 0.0001  # 1e-4
WEIGHT_DECAY = 0        # L2 penalty; 0 disables it
BATCH_SIZE = 128        # 2**7
EPOCH = 2               # max training epochs per fit
PATIENCE = 100          # early-stopping patience, in epochs

# Band-pass limits (Hz) and epoch window (s) handed to the MOABB paradigm;
# tmax=None presumably keeps the dataset's own trial end -- confirm in MOABB docs.
fmin, fmax = 4, 100
tmin, tmax = 0, None

# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------
dataset = BNCI2014_001()
paradigm = MotorImagery(fmin=fmin, fmax=fmax, tmin=tmin, tmax=tmax)

# X's shape is what make_classifier uses to size the networks.
# NOTE(review): this loads every subject's epochs just to read the shape;
# passing `subjects=[...]` to get_data would be much cheaper if X is not
# needed elsewhere.
X, _, _ = paradigm.get_data(dataset=dataset)

subjects = dataset.subject_list
        

def make_classifier(module, n_chans=None, n_outputs=None, n_times=None,
                    early_stopping_patience=None):
    """Wrap ``module`` in a braindecode ``EEGClassifier`` with the notebook's settings.

    Parameters
    ----------
    module : type
        Network class to train. skorch forwards the ``module__*`` arguments
        below (``n_chans``, ``n_outputs``, ``n_times``) to its constructor.
    n_chans : int, optional
        Number of input channels. Defaults to ``X.shape[1]`` of the globally
        loaded data (read at call time).
    n_outputs : int, optional
        Number of output classes. Defaults to ``len(dataset.event_id)``.
    n_times : int, optional
        Length of the input signal in time points. Defaults to ``X.shape[2]``.
    early_stopping_patience : int, optional
        When given (e.g. ``PATIENCE``), an ``EarlyStopping`` callback on
        ``valid_loss`` with this patience is added. ``None`` (default) keeps
        early stopping disabled, as before.

    Returns
    -------
    EEGClassifier
        An unfitted classifier: Adam with ``LEARNING_RATE`` and
        ``WEIGHT_DECAY``, batches of ``BATCH_SIZE``, at most ``EPOCH`` epochs
        per fit, and a seeded 20 % hold-out validation split.
    """
    # Resolve defaults lazily so the globals are read when the classifier is
    # built, exactly as the original hard-coded expressions did.
    if n_chans is None:
        n_chans = X.shape[1]  # number of input channels
    if n_outputs is None:
        n_outputs = len(dataset.event_id)  # number of output classes
    if n_times is None:
        n_times = X.shape[2]  # length of the input signal in time points

    # Build fresh callback objects per classifier so none are shared.
    callbacks = []
    if early_stopping_patience is not None:
        callbacks.append(
            EarlyStopping(monitor="valid_loss", patience=early_stopping_patience)
        )
    callbacks += [
        EpochScoring(
            scoring="accuracy", on_train=True, name="train_acc", lower_is_better=False
        ),
        EpochScoring(
            scoring="accuracy", on_train=False, name="valid_acc", lower_is_better=False
        ),
    ]

    return EEGClassifier(
        module=module,
        module__n_chans=n_chans,
        module__n_outputs=n_outputs,
        module__n_times=n_times,
        optimizer=torch.optim.Adam,
        optimizer__lr=LEARNING_RATE,
        # Previously WEIGHT_DECAY was defined but never forwarded; at 0 this
        # matches Adam's default, so current results are unchanged.
        optimizer__weight_decay=WEIGHT_DECAY,
        batch_size=BATCH_SIZE,
        max_epochs=EPOCH,
        # Keep training the existing module on each fit instead of re-initializing it.
        warm_start=True,
        train_split=ValidSplit(0.2, random_state=seed),
        device=device,
        callbacks=callbacks,
        verbose=1,
    )


# One unfitted classifier per architecture, created in a fixed order.
clf1, clf2, clf3, clf4, clf5, clf6 = (
    make_classifier(SubjectDicionaryFCNet),
    make_classifier(CollapsedShallowNet),
    make_classifier(ShallowPrivateTemporalDictNetSlow),
    make_classifier(ShallowFBCSPNet),
    make_classifier(ShallowPrivateSpatialDictNetSlow),
    make_classifier(ShallowPrivateCollapsedDictNetSlow),
)

# Label -> single-step sklearn pipeline. The evaluation loop uses each label
# for its suffix and output file names. To run only a subset (e.g. just the
# Spatial and Collapsed dictionary variants), drop entries from this mapping.
pipes = dict(
    zip(
        (
            "SubjectDicionaryFCNet",
            "CollapsedShallowNet",
            "ShallowPrivateTemporalDictNetSlow",
            "ShallowFBCSPNet",
            "ShallowPrivateSpatialDictNetSlow",
            "ShallowPrivateCollapsedDictNetSlow",
        ),
        map(make_pipeline, (clf1, clf2, clf3, clf4, clf5, clf6)),
    )
)


def save_score_barplot(scores, path, title, hue=None, legend_title=None,
                       figsize=(10, 6)):
    """Draw a per-subject bar plot of ``score`` and save it to ``path``.

    Parameters
    ----------
    scores : pandas.DataFrame
        Results with at least ``subject`` and ``score`` columns (plus the
        ``hue`` column when one is given).
    path : str
        Image file to write.
    title : str
        Figure title.
    hue : str, optional
        Column that splits the bars within each subject (e.g. ``"pipeline"``).
    legend_title : str, optional
        When given, the legend is redrawn with this title.
    figsize : tuple of float
        Figure size in inches.
    """
    fig, ax = plt.subplots(figsize=figsize)
    # NOTE(review): seaborn >= 0.13 warns when `palette` is used without `hue`;
    # kept unchanged so the notebook also works with older seaborn releases.
    sns.barplot(data=scores, x="subject", y="score", hue=hue, palette="viridis", ax=ax)
    ax.set_title(title)
    ax.set_xlabel("Subject")
    ax.set_ylabel("Score")
    if legend_title is not None:
        ax.legend(title=legend_title)
    fig.savefig(path)
    # Close this specific figure: plots are produced in a loop and would
    # otherwise accumulate in memory.
    plt.close(fig)


results_list = []
# Ensure the output directory exists
output_dir = f"./results_{seed}_{dataset.code}"
os.makedirs(output_dir, exist_ok=True)

# Evaluate each pipeline on its own so it gets a dedicated suffix, HDF5 path,
# CSV and plot -- results already written survive a failure in a later pipeline.
for pipe_name, pipe in pipes.items():
    unique_suffix = f"{pipe_name}_braindecode_example"

    evaluation = SubjectParamEvaluation(
        paradigm=paradigm,
        datasets=[dataset],
        suffix=unique_suffix,
        overwrite=True,
        return_epochs=True,
        random_state=seed,
        n_jobs=1,
        hdf5_path=os.path.join(output_dir, pipe_name),
        save_model=True,
    )

    # Run the evaluation process for this pipeline
    results = evaluation.process({pipe_name: pipe})

    results_df = pd.DataFrame(results)
    results_df.to_csv(os.path.join(output_dir, f"{pipe_name}_results.csv"), index=False)

    save_score_barplot(
        results_df,
        os.path.join(output_dir, f"{pipe_name}_performance.png"),
        title=f"Model Performance by Subject - {pipe_name}",
    )

    results_list.append(results_df)

if not results_list:
    raise ValueError("`pipes` is empty: no pipeline was evaluated.")

# ignore_index gives the combined frame a unique 0..n-1 index instead of the
# repeated per-pipeline indices.
results_all = pd.concat(results_list, ignore_index=True)
results_all.to_csv(os.path.join(output_dir, "all_results.csv"), index=False)

save_score_barplot(
    results_all,
    os.path.join(output_dir, "combined_performance.png"),
    title="Scores per Subject for Each Pipeline",
    hue="pipeline",
    legend_title="Pipeline",
    figsize=(12, 6),
)

Choosing from all possible events


Torch Version: 2.2.2
GPU is NOT AVAILABLE
We try to set the tensorflow seeds, but it seems that tensorflow is not installed. Please refer to `https://www.tensorflow.org/` to install if you need to use this deep learning module.


 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2617[0m        [32m1.3873[0m       [35m0.2826[0m        [31m1.4064[0m  0.3282
      2       0.2266        1.4298       0.2065        1.5215  0.3227
      3       0.2266        1.4043       0.2500        1.4441  0.3110
      4       0.2344        1.3974       [35m0.3804[0m        [31m1.3722[0m  0.3108
      5       0.2578        1.3958       0.1630        1.3963  0.2937
      6       0.2422        1.3894       0.2609        1.3964  0.3103
      7       0.2461        [32m1.3856[0m       0.2283        1.3905  0.2955
      8       0.2305        1.4098       0.2065        1.4027  0.3021
      9       [36m0.3438[0m        [32m1.3610[0m       0.2609        [31m1.3544[0m  0.2885
     10       0.3008        1.3828       0.3478        1.3760  0.3165
     11       0.2773        1.4012       0.2826        1.3931  0.2988


BNCI2014-001-SubjectParam: 100%|██████████| 1/1 [00:39<00:00, 39.17s/it]
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2617[0m        [32m1.3864[0m       [35m0.1957[0m        [31m1.4277[0m  0.3644
      2       0.2617        1.4025       0.1848        [31m1.4189[0m  0.3241
      3       0.1914        1.3978       [35m0.2283[0m        [31m1.3944[0m  0.3291
      4       0.2266        1.3941       0.2283        1.3990  0.3219
      5       [36m0.2852[0m        1.3921       0.2065        [31m1.3905[0m  0.2947
      6       0.2227        1.4035       [35m0.2391[0m        1.4070  0.2899
      7       0.2656        1.3992       0.2391        [31m1.3897[0m  0.2989
      8       0.2695        [32m1.3846[0m       [35m0.2935[0m        [31m1.3777[0m  0.2920
      9       [36m0.3125[0m        [32m1.3633[0m       [35m0.3478[0m        [31m1.3481[0m  0.3199
     10       0.3008        1.3766       0.2717        1.3971  0.336

BNCI2014-001-SubjectParam: 100%|██████████| 1/1 [00:41<00:00, 41.22s/it]
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2734[0m        [32m1.3932[0m       [35m0.2174[0m        [31m1.3977[0m  3.0458
      2       0.2305        1.4124       [35m0.3043[0m        [31m1.3838[0m  2.3102
      3       0.2500        1.3991       0.2174        1.4201  2.5299
      4       0.2266        1.3969       0.2065        1.3885  2.3478
      5       0.2148        1.4007       0.2065        1.4154  2.2621
      6       [36m0.2812[0m        [32m1.3786[0m       0.2391        1.4012  2.6497
      7       0.2500        1.3936       0.2826        1.3856  2.4751
      8       0.1836        1.3990       0.2174        1.3905  2.2542
      9       0.1758        1.4108       0.1957        1.4192  3.0387
     10       [36m0.2930[0m        1.3806       0.1630        1.3870  2.5044
     11       [36m0.3164[0m        1.3919       0.2500        1.3854  2.3140


BNCI2014-001-SubjectParam: 100%|██████████| 1/1 [01:24<00:00, 84.06s/it]
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2227[0m        [32m1.3906[0m       [35m0.2174[0m        [31m1.5253[0m  3.1783
      2       [36m0.3281[0m        [32m1.3697[0m       [35m0.3043[0m        [31m1.4502[0m  2.8401
      3       0.2188        1.4006       0.1413        1.4796  2.2206
      4       0.2617        1.3873       [35m0.3370[0m        [31m1.3972[0m  2.3898
      5       0.2617        1.3824       0.2609        [31m1.3918[0m  2.6122
      6       0.2539        1.3892       0.1739        1.4492  2.3230
      7       0.2266        1.3982       0.2174        [31m1.3757[0m  2.2823
      8       0.2891        1.3743       [35m0.4130[0m        [31m1.3579[0m  2.5646
      9       0.2656        1.3957       0.2609        1.4163  2.2126
     10       0.2969        1.3801       0.1739        1.3906  2.2314
     11       0.3164        1.3731

BNCI2014-001-SubjectParam: 100%|██████████| 1/1 [01:21<00:00, 81.58s/it]
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2617[0m        [32m1.3869[0m       [35m0.2065[0m        [31m1.4709[0m  2.7410
      2       0.2148        1.4295       [35m0.2609[0m        [31m1.4625[0m  2.3416
      3       [36m0.3164[0m        [32m1.3765[0m       0.2391        [31m1.3930[0m  2.2811
      4       0.2305        1.3947       [35m0.3152[0m        [31m1.3860[0m  2.5246
      5       0.2461        1.3847       0.2283        1.3948  2.5251
      6       0.2617        1.3824       0.3152        1.3877  2.3391
      7       0.2383        1.3915       0.3152        [31m1.3696[0m  2.4945
      8       0.2344        1.4040       0.2717        1.4103  2.4039
      9       0.2617        1.4078       0.1957        1.3767  2.2804
     10       0.2734        1.3843       0.2391        1.3843  2.5873
     11       0.2539        1.3995       0.2391     

BNCI2014-001-SubjectParam: 100%|██████████| 1/1 [01:22<00:00, 82.11s/it]
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}")
 'left_hand': 12
 'right_hand': 12
 'feet': 12
 'tongue': 12>
  warn(f"warnEpochs {epochs}

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.2734[0m        [32m1.3939[0m       [35m0.2500[0m        [31m1.4686[0m  0.5284
      2       0.2422        1.4176       [35m0.2609[0m        [31m1.4217[0m  0.3172
      3       0.2617        1.3980       [35m0.2717[0m        [31m1.3945[0m  0.3055
      4       [36m0.2812[0m        [32m1.3822[0m       0.2283        [31m1.3869[0m  0.2995
      5       0.2266        1.3970       0.1957        1.4027  0.3166
      6       0.2656        1.3921       0.2717        1.4282  0.3439
      7       0.1875        1.4002       0.2609        1.3990  0.3355
      8       0.2344        1.3895       [35m0.2826[0m        1.3938  0.3488
      9       0.2383        1.3854       0.2174        1.4192  0.3682
     10       0.2461        [32m1.3814[0m       [35m0.2935[0m        1.3905  0.3005
     11       0.1484        1.4225

BNCI2014-001-SubjectParam: 100%|██████████| 1/1 [00:40<00:00, 40.82s/it]
