In [1]:
import numpy as np
from braindecode.datasets import MOABBDataset

subject_id = [1,2,3,4]
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[1,2,3,4,5,6,7,8,9])



from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

transforms = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(
        lambda data, factor: np.multiply(data, factor),  # Convert from V to uV
        factor=1e6,
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Transform the data
preprocess(dataset, transforms, n_jobs=-1)


  warn('Preprocessing choices with lambda functions cannot be saved.')


<braindecode.datasets.moabb.MOABBDataset at 0x1dd7c4f11f0>

In [2]:
from braindecode.preprocessing import create_windows_from_events

trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']

In [3]:
import torch
from braindecode.models import ShallowFBCSPNet
from braindecode.util import set_random_seeds


cuda = torch.cuda.is_available()  # check if GPU is available, if True chooses to use it
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True
seed = 20200222
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))
# Extract number of chans and time steps from dataset
n_channels = windows_dataset[0][0].shape[0]
input_window_samples = windows_dataset[0][0].shape[1]

print("n_classes: ", n_classes)
print("n_channels:", n_channels)
print("input_window_samples size:", input_window_samples)

n_classes:  4
n_channels: 22
input_window_samples size: 1125


In [4]:
# The ShallowFBCSPNet is a `nn.Sequential` model

model = ShallowFBCSPNet(
    n_channels,
    n_classes,
    input_window_samples=input_window_samples,
    final_conv_length="auto",
)

# Display torchinfo table describing the model
print(model)

# Send model to GPU
if cuda:
    model.cuda()



Layer (type (var_name):depth-idx)        Input Shape               Output Shape              Param #                   Kernel Shape
ShallowFBCSPNet (ShallowFBCSPNet)        [1, 22, 1125]             [1, 4]                    --                        --
├─Ensure4d (ensuredims): 1-1             [1, 22, 1125]             [1, 22, 1125, 1]          --                        --
├─Rearrange (dimshuffle): 1-2            [1, 22, 1125, 1]          [1, 1, 1125, 22]          --                        --
├─CombinedConv (conv_time_spat): 1-3     [1, 1, 1125, 22]          [1, 40, 1101, 1]          36,240                    --
├─BatchNorm2d (bnorm): 1-4               [1, 40, 1101, 1]          [1, 40, 1101, 1]          80                        --
├─Expression (conv_nonlin_exp): 1-5      [1, 40, 1101, 1]          [1, 40, 1101, 1]          --                        --
├─AvgPool2d (pool): 1-6                  [1, 40, 1101, 1]          [1, 40, 69, 1]            --                        [75, 1]
├─Express



In [5]:
splitted = windows_dataset.split("session")
train_set = splitted['0train']  # Session train
test_set = splitted['1test']  # Session evaluation

from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

lr = 0.0625 * 0.01
weight_decay = 0
batch_size = 64
n_epochs = 100


In [6]:

from tqdm import tqdm
# Define a method for training one epoch


def train_one_epoch(
        dataloader: DataLoader, model: Module, loss_fn, optimizer,
        scheduler: LRScheduler, epoch: int, device, print_batch_stats=True
):
    model.train()  # Set the model to training mode
    train_loss, correct = 0, 0

    progress_bar = tqdm(enumerate(dataloader), total=len(dataloader),
                        disable=not print_batch_stats)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        print(X.shape)
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()  # update the model weights
        optimizer.zero_grad()

        train_loss += loss.item()
        correct += (pred.argmax(1) == y).sum().item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Epoch {epoch}/{n_epochs}, "
                f"Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {loss.item():.6f}"
            )

    # Update the learning rate
    scheduler.step()

    correct /= len(dataloader.dataset)
    return train_loss / len(dataloader), correct


In [7]:

@torch.no_grad()
def test_model(
    dataloader: DataLoader, model: Module, loss_fn, print_batch_stats=True
):
    size = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()  # Switch to evaluation mode
    test_loss, correct = 0, 0

    if print_batch_stats:
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        progress_bar = enumerate(dataloader)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        batch_loss = loss_fn(pred, y).item()

        test_loss += batch_loss
        correct += (pred.argmax(1) == y).type(torch.float).sum().item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Batch {batch_idx + 1}/{len(dataloader)}, "
                f"Loss: {batch_loss:.6f}"
            )

    test_loss /= n_batches
    correct /= size

    print(
        f"Test Accuracy: {100 * correct:.1f}%, Test Loss: {test_loss:.6f}\n"
    )
    return test_loss, correct

In [None]:
# Define the optimization
optimizer = torch.optim.AdamW(model.parameters(),
                              lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                       T_max=n_epochs - 1)
# Define the loss function
# We used the NNLoss function, which expects log probabilities as input
# (which is the case for our model output)
loss_fn = torch.nn.NLLLoss()

# train_set and test_set are instances of torch Datasets, and can seamlessly be
# wrapped in data loaders.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")

    train_loss, train_accuracy = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device,
    )

    test_loss, test_accuracy = test_model(test_loader, model, loss_fn)

    print(
        f"Train Accuracy: {100 * train_accuracy:.2f}%, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Test Accuracy: {100 * test_accuracy:.1f}%, "
        f"Average Test Loss: {test_loss:.6f}\n"
    )

Epoch 1/100: 

  0%|          | 0/41 [00:00<?, ?it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 1/41, Loss: 1.721020:   2%|▏         | 1/41 [00:00<00:30,  1.30it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 2/41, Loss: 1.782541:   5%|▍         | 2/41 [00:01<00:26,  1.50it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 3/41, Loss: 1.776343:   7%|▋         | 3/41 [00:01<00:23,  1.63it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 4/41, Loss: 1.736924:  10%|▉         | 4/41 [00:02<00:21,  1.73it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 5/41, Loss: 1.613262:  12%|█▏        | 5/41 [00:02<00:20,  1.79it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 6/41, Loss: 1.463803:  15%|█▍        | 6/41 [00:03<00:19,  1.81it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 7/41, Loss: 1.521159:  17%|█▋        | 7/41 [00:04<00:18,  1.80it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 8/41, Loss: 1.701540:  20%|█▉        | 8/41 [00:04<00:18,  1.81it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 9/41, Loss: 1.642441:  22%|██▏       | 9/41 [00:05<00:17,  1.80it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 10/41, Loss: 1.683417:  24%|██▍       | 10/41 [00:05<00:17,  1.81it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 11/41, Loss: 1.805015:  27%|██▋       | 11/41 [00:06<00:16,  1.84it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 12/41, Loss: 1.559020:  29%|██▉       | 12/41 [00:06<00:15,  1.86it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 13/41, Loss: 1.488359:  32%|███▏      | 13/41 [00:07<00:14,  1.88it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 14/41, Loss: 1.450093:  34%|███▍      | 14/41 [00:07<00:14,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 15/41, Loss: 1.733809:  37%|███▋      | 15/41 [00:08<00:13,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 16/41, Loss: 1.560249:  39%|███▉      | 16/41 [00:08<00:13,  1.91it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 17/41, Loss: 1.636397:  41%|████▏     | 17/41 [00:09<00:12,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 18/41, Loss: 1.474145:  44%|████▍     | 18/41 [00:09<00:12,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 19/41, Loss: 1.543110:  46%|████▋     | 19/41 [00:10<00:11,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 20/41, Loss: 1.568367:  49%|████▉     | 20/41 [00:10<00:10,  1.92it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 21/41, Loss: 1.671234:  51%|█████     | 21/41 [00:11<00:10,  1.93it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 22/41, Loss: 1.548132:  54%|█████▎    | 22/41 [00:11<00:09,  1.93it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 23/41, Loss: 1.507244:  56%|█████▌    | 23/41 [00:12<00:09,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 24/41, Loss: 1.543331:  59%|█████▊    | 24/41 [00:13<00:08,  1.93it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 25/41, Loss: 1.529670:  61%|██████    | 25/41 [00:13<00:08,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 26/41, Loss: 1.457732:  63%|██████▎   | 26/41 [00:14<00:07,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 27/41, Loss: 1.463516:  66%|██████▌   | 27/41 [00:14<00:07,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 28/41, Loss: 1.423893:  68%|██████▊   | 28/41 [00:15<00:06,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 29/41, Loss: 1.379427:  71%|███████   | 29/41 [00:15<00:06,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 30/41, Loss: 1.524713:  73%|███████▎  | 30/41 [00:16<00:05,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 31/41, Loss: 1.354858:  76%|███████▌  | 31/41 [00:16<00:05,  1.92it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 32/41, Loss: 1.610585:  78%|███████▊  | 32/41 [00:17<00:04,  1.93it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 33/41, Loss: 1.506715:  80%|████████  | 33/41 [00:17<00:04,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 34/41, Loss: 1.439613:  83%|████████▎ | 34/41 [00:18<00:03,  1.94it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 35/41, Loss: 1.536733:  85%|████████▌ | 35/41 [00:18<00:03,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 36/41, Loss: 1.495076:  88%|████████▊ | 36/41 [00:19<00:02,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 37/41, Loss: 1.516355:  90%|█████████ | 37/41 [00:19<00:02,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 38/41, Loss: 1.342138:  93%|█████████▎| 38/41 [00:20<00:01,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 39/41, Loss: 1.480828:  95%|█████████▌| 39/41 [00:20<00:01,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 1/100, Batch 40/41, Loss: 1.434785:  98%|█████████▊| 40/41 [00:21<00:00,  1.95it/s]

torch.Size([32, 22, 1125])


Epoch 1/100, Batch 41/41, Loss: 1.597588: 100%|██████████| 41/41 [00:21<00:00,  1.91it/s]
Batch 41/41, Loss: 1.236457: 100%|██████████| 41/41 [00:04<00:00,  8.84it/s]


Test Accuracy: 32.6%, Test Loss: 1.538076

Train Accuracy: 30.21%, Average Train Loss: 1.556712, Test Accuracy: 32.6%, Average Test Loss: 1.538076

Epoch 2/100: 

  0%|          | 0/41 [00:00<?, ?it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 1/41, Loss: 1.290418:   2%|▏         | 1/41 [00:00<00:20,  1.97it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 2/41, Loss: 1.252190:   5%|▍         | 2/41 [00:01<00:19,  1.96it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 3/41, Loss: 1.319565:   7%|▋         | 3/41 [00:01<00:19,  1.96it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 4/41, Loss: 1.381554:  10%|▉         | 4/41 [00:02<00:18,  1.95it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 5/41, Loss: 1.398466:  12%|█▏        | 5/41 [00:02<00:19,  1.86it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 6/41, Loss: 1.225463:  15%|█▍        | 6/41 [00:03<00:20,  1.74it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 7/41, Loss: 1.310601:  17%|█▋        | 7/41 [00:03<00:18,  1.79it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 8/41, Loss: 1.399083:  20%|█▉        | 8/41 [00:04<00:18,  1.79it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 9/41, Loss: 1.413269:  22%|██▏       | 9/41 [00:04<00:18,  1.77it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 10/41, Loss: 1.331901:  24%|██▍       | 10/41 [00:05<00:17,  1.80it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 11/41, Loss: 1.497215:  27%|██▋       | 11/41 [00:05<00:16,  1.83it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 12/41, Loss: 1.225773:  29%|██▉       | 12/41 [00:06<00:15,  1.85it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 13/41, Loss: 1.455521:  32%|███▏      | 13/41 [00:07<00:15,  1.85it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 14/41, Loss: 1.645365:  34%|███▍      | 14/41 [00:07<00:14,  1.87it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 15/41, Loss: 1.593895:  37%|███▋      | 15/41 [00:08<00:13,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 16/41, Loss: 1.318624:  39%|███▉      | 16/41 [00:08<00:13,  1.87it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 17/41, Loss: 1.492740:  41%|████▏     | 17/41 [00:09<00:13,  1.83it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 18/41, Loss: 1.437199:  44%|████▍     | 18/41 [00:09<00:12,  1.79it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 19/41, Loss: 1.556988:  46%|████▋     | 19/41 [00:10<00:11,  1.84it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 20/41, Loss: 1.288996:  49%|████▉     | 20/41 [00:10<00:11,  1.86it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 21/41, Loss: 1.266029:  51%|█████     | 21/41 [00:11<00:10,  1.87it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 22/41, Loss: 1.467161:  54%|█████▎    | 22/41 [00:11<00:10,  1.85it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 23/41, Loss: 1.412782:  56%|█████▌    | 23/41 [00:12<00:09,  1.86it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 24/41, Loss: 1.516493:  59%|█████▊    | 24/41 [00:12<00:09,  1.88it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 25/41, Loss: 1.433221:  61%|██████    | 25/41 [00:13<00:08,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 26/41, Loss: 1.321258:  63%|██████▎   | 26/41 [00:14<00:07,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 27/41, Loss: 1.469710:  66%|██████▌   | 27/41 [00:14<00:07,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 28/41, Loss: 1.394316:  68%|██████▊   | 28/41 [00:15<00:06,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 29/41, Loss: 1.248919:  71%|███████   | 29/41 [00:15<00:06,  1.91it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 30/41, Loss: 1.392508:  73%|███████▎  | 30/41 [00:16<00:05,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 31/41, Loss: 1.524311:  76%|███████▌  | 31/41 [00:16<00:05,  1.88it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 32/41, Loss: 1.512575:  78%|███████▊  | 32/41 [00:17<00:04,  1.89it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 33/41, Loss: 1.472494:  80%|████████  | 33/41 [00:17<00:04,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 34/41, Loss: 1.392101:  83%|████████▎ | 34/41 [00:18<00:03,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 35/41, Loss: 1.221487:  85%|████████▌ | 35/41 [00:18<00:03,  1.91it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 36/41, Loss: 1.430100:  88%|████████▊ | 36/41 [00:19<00:02,  1.91it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 37/41, Loss: 1.424350:  90%|█████████ | 37/41 [00:19<00:02,  1.92it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 38/41, Loss: 1.323293:  93%|█████████▎| 38/41 [00:20<00:01,  1.91it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 39/41, Loss: 1.414932:  95%|█████████▌| 39/41 [00:20<00:01,  1.90it/s]

torch.Size([64, 22, 1125])


Epoch 2/100, Batch 40/41, Loss: 1.365500:  98%|█████████▊| 40/41 [00:21<00:00,  1.86it/s]

torch.Size([32, 22, 1125])


Epoch 2/100, Batch 41/41, Loss: 1.248881: 100%|██████████| 41/41 [00:21<00:00,  1.89it/s]
Batch 41/41, Loss: 1.279742: 100%|██████████| 41/41 [00:04<00:00,  8.28it/s]


Test Accuracy: 39.1%, Test Loss: 1.307987

Train Accuracy: 37.31%, Average Train Loss: 1.392372, Test Accuracy: 39.1%, Average Test Loss: 1.307987

Epoch 3/100: 

  0%|          | 0/41 [00:00<?, ?it/s]

torch.Size([64, 22, 1125])


Epoch 3/100, Batch 1/41, Loss: 1.239675:   2%|▏         | 1/41 [00:00<00:22,  1.77it/s]

torch.Size([64, 22, 1125])


In [None]:
# Assuming 'model' is your trained Braindecode model
torch.save(model, "braindecode_model.pth")
torch.save(model.state_dict(), "braindecode_model_state.pth")
