In [None]:
!pip install braindecode moabb ncps

In [1]:
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm.notebook import tqdm
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import create_windows_from_events

In [2]:
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids = [i for i in range(1, 10)])

from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

preprocessors = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(
        lambda data, factor: np.multiply(data, factor),  # Convert from V to uV
        factor=1e6,
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Preprocess the data
preprocess(dataset, preprocessors, n_jobs=-1)
trial_start_offset_seconds = -0.5
# Extract sampling frequency, check that they are same in all datasets
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all([ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets])
# Calculate the trial start offset in samples.
trial_start_offset_samples = int(trial_start_offset_seconds * sfreq)

  set_config(key, get_config("MNE_DATA"))
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01T.mat'.


MNE_DATA is not already configured. It will be set to default location in the home directory - /root/mne_data
All datasets will be downloaded to this location, if anything is already downloaded, please move manually to this location


100%|█████████████████████████████████████| 42.8M/42.8M [00:00<00:00, 39.7GB/s]
SHA256 hash of downloaded file: 054f02e70cf9c4ada1517e9b9864f45407939c1062c6793516585c6f511d0325
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01E.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A01E.mat'.
100%|█████████████████████████████████████| 43.8M/43.8M [00:00<00:00, 50.7GB/s]
SHA256 hash of downloaded file: 53d415f39c3d7b0c88b894d7b08d99bcdfe855ede63831d3691af1a45607fb62
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A02T.mat' to file '/root/mne_data/MNE-bnci-data/database/data-sets/001-2014/A02T.mat'.
100%|█████████████████████

In [None]:
# Create windows using braindecode function for this. It needs parameters to define how
# trials should be used.
windows_dataset = create_windows_from_events(
    dataset,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
    preload=True,
)
splitted = windows_dataset.split("session")
train_set = splitted['0train']  # Session train
test_set = splitted['1test']  # Session evaluation
batch_size = 16
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)
progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
for batch_idx, (X, y, _) in progress_bar:
  print(X.shape, y.shape)
  print(y)
  break


In [4]:
class EEGWindowDataset(torch.utils.data.Dataset):
    def __init__(self, bd_dataset):
        self.bd_dataset = bd_dataset

    def __len__(self):
        return len(self.bd_dataset)

    def __getitem__(self, idx):
        X, y, _ = self.bd_dataset[idx]
        # X: (channels, time) -> (time, channels)
        X = torch.from_numpy(X).float().permute(1, 0)
        y = torch.tensor(y).long()
        return X, y

In [5]:
train_dataset = EEGWindowDataset(train_set)
test_dataset = EEGWindowDataset(test_set)

train_loader = DataLoader(train_dataset, batch_size=4096, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=4096)

In [6]:
from torch import nn
from ncps.torch import CfC

n_channels = train_set[0][0].shape[0]  # e.g., 22
n_classes = 4  # e.g., 4

class EEGCfCNet(nn.Module):
    def __init__(self, input_size, hidden_size, n_classes):
        super().__init__()
        self.rnn = CfC(input_size, hidden_size)
        self.hidden_size = hidden_size
        self.classifier = nn.Linear(hidden_size, n_classes)

    def forward(self, x):
        # x: (batch, time, features)
        h0 = torch.zeros(x.size(0), self.hidden_size, device=x.device)
        out, _ = self.rnn(x, h0)
        # Use last time step output
        out = out[:, -1, :]
        return self.classifier(out)

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = EEGCfCNet(input_size=n_channels, hidden_size=64, n_classes=n_classes).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def evaluate(model, dataloader, criterion):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            output = model(X)
            loss = criterion(output, y)
            total_loss += loss.item() * y.size(0)
            preds = output.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    avg_loss = total_loss / total
    accuracy = correct / total
    return avg_loss, accuracy

num_epochs = 100
for epoch in range(num_epochs):
    model.train()
    running_loss = 0
    running_correct = 0
    running_total = 0
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for X, y in loop:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * y.size(0)
        preds = output.argmax(dim=1)
        running_correct += (preds == y).sum().item()
        running_total += y.size(0)

        loop.set_postfix(train_loss=running_loss/running_total,
                         train_acc=running_correct/running_total)

    train_loss = running_loss / running_total
    train_acc = running_correct / running_total

    val_loss, val_acc = evaluate(model, test_loader, criterion)

    print(f"Epoch {epoch+1} summary: "
          f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
          f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

Epoch 1/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1 summary: Train Loss: 1.3952, Train Acc: 0.2326, Val Loss: 1.3954, Val Acc: 0.2384


Epoch 2/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2 summary: Train Loss: 1.3891, Train Acc: 0.2515, Val Loss: 1.3947, Val Acc: 0.2384


Epoch 3/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3 summary: Train Loss: 1.3855, Train Acc: 0.2670, Val Loss: 1.3922, Val Acc: 0.2369


Epoch 4/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4 summary: Train Loss: 1.3816, Train Acc: 0.2797, Val Loss: 1.3910, Val Acc: 0.2465


Epoch 5/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5 summary: Train Loss: 1.3791, Train Acc: 0.2870, Val Loss: 1.3911, Val Acc: 0.2542


Epoch 6/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6 summary: Train Loss: 1.3770, Train Acc: 0.2994, Val Loss: 1.3922, Val Acc: 0.2519


Epoch 7/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7 summary: Train Loss: 1.3751, Train Acc: 0.3044, Val Loss: 1.3938, Val Acc: 0.2481


Epoch 8/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8 summary: Train Loss: 1.3735, Train Acc: 0.3044, Val Loss: 1.3951, Val Acc: 0.2454


Epoch 9/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9 summary: Train Loss: 1.3715, Train Acc: 0.3098, Val Loss: 1.3959, Val Acc: 0.2423


Epoch 10/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 10 summary: Train Loss: 1.3692, Train Acc: 0.3140, Val Loss: 1.3968, Val Acc: 0.2469


Epoch 11/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 11 summary: Train Loss: 1.3671, Train Acc: 0.3214, Val Loss: 1.3980, Val Acc: 0.2454


Epoch 12/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 12 summary: Train Loss: 1.3654, Train Acc: 0.3264, Val Loss: 1.3994, Val Acc: 0.2423


Epoch 13/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 13 summary: Train Loss: 1.3634, Train Acc: 0.3291, Val Loss: 1.4007, Val Acc: 0.2485


Epoch 14/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 14 summary: Train Loss: 1.3609, Train Acc: 0.3318, Val Loss: 1.4021, Val Acc: 0.2465


Epoch 15/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 15 summary: Train Loss: 1.3583, Train Acc: 0.3410, Val Loss: 1.4033, Val Acc: 0.2434


Epoch 16/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 16 summary: Train Loss: 1.3559, Train Acc: 0.3395, Val Loss: 1.4041, Val Acc: 0.2419


Epoch 17/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 17 summary: Train Loss: 1.3533, Train Acc: 0.3422, Val Loss: 1.4047, Val Acc: 0.2400


Epoch 18/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 18 summary: Train Loss: 1.3505, Train Acc: 0.3484, Val Loss: 1.4057, Val Acc: 0.2415


Epoch 19/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 19 summary: Train Loss: 1.3474, Train Acc: 0.3526, Val Loss: 1.4079, Val Acc: 0.2446


Epoch 20/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 20 summary: Train Loss: 1.3439, Train Acc: 0.3600, Val Loss: 1.4109, Val Acc: 0.2415


Epoch 21/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 21 summary: Train Loss: 1.3402, Train Acc: 0.3623, Val Loss: 1.4135, Val Acc: 0.2488


Epoch 22/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 22 summary: Train Loss: 1.3365, Train Acc: 0.3603, Val Loss: 1.4151, Val Acc: 0.2485


Epoch 23/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 23 summary: Train Loss: 1.3322, Train Acc: 0.3634, Val Loss: 1.4166, Val Acc: 0.2461


Epoch 24/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 24 summary: Train Loss: 1.3278, Train Acc: 0.3711, Val Loss: 1.4204, Val Acc: 0.2469


Epoch 25/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 25 summary: Train Loss: 1.3229, Train Acc: 0.3750, Val Loss: 1.4253, Val Acc: 0.2450


Epoch 26/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 26 summary: Train Loss: 1.3175, Train Acc: 0.3800, Val Loss: 1.4294, Val Acc: 0.2411


Epoch 27/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 27 summary: Train Loss: 1.3117, Train Acc: 0.3897, Val Loss: 1.4345, Val Acc: 0.2438


Epoch 28/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 28 summary: Train Loss: 1.3057, Train Acc: 0.3920, Val Loss: 1.4398, Val Acc: 0.2481


Epoch 29/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 29 summary: Train Loss: 1.2997, Train Acc: 0.3978, Val Loss: 1.4537, Val Acc: 0.2434


Epoch 30/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 30 summary: Train Loss: 1.2994, Train Acc: 0.3862, Val Loss: 1.4507, Val Acc: 0.2438


Epoch 31/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 31 summary: Train Loss: 1.2904, Train Acc: 0.3974, Val Loss: 1.4564, Val Acc: 0.2419


Epoch 32/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 32 summary: Train Loss: 1.2831, Train Acc: 0.4012, Val Loss: 1.4654, Val Acc: 0.2361


Epoch 33/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 33 summary: Train Loss: 1.2770, Train Acc: 0.4136, Val Loss: 1.4682, Val Acc: 0.2404


Epoch 34/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 34 summary: Train Loss: 1.2699, Train Acc: 0.4186, Val Loss: 1.4695, Val Acc: 0.2512


Epoch 35/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 35 summary: Train Loss: 1.2627, Train Acc: 0.4225, Val Loss: 1.4780, Val Acc: 0.2500


Epoch 36/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 36 summary: Train Loss: 1.2551, Train Acc: 0.4329, Val Loss: 1.4810, Val Acc: 0.2481


Epoch 37/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 37 summary: Train Loss: 1.2478, Train Acc: 0.4333, Val Loss: 1.4843, Val Acc: 0.2442


Epoch 38/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 38 summary: Train Loss: 1.2392, Train Acc: 0.4379, Val Loss: 1.4922, Val Acc: 0.2496


Epoch 39/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 39 summary: Train Loss: 1.2318, Train Acc: 0.4572, Val Loss: 1.5010, Val Acc: 0.2442


Epoch 40/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 40 summary: Train Loss: 1.2222, Train Acc: 0.4576, Val Loss: 1.5083, Val Acc: 0.2419


Epoch 41/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 41 summary: Train Loss: 1.2154, Train Acc: 0.4599, Val Loss: 1.5136, Val Acc: 0.2446


Epoch 42/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 42 summary: Train Loss: 1.2067, Train Acc: 0.4626, Val Loss: 1.5237, Val Acc: 0.2438


Epoch 43/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 43 summary: Train Loss: 1.2064, Train Acc: 0.4738, Val Loss: 1.5290, Val Acc: 0.2461


Epoch 44/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 44 summary: Train Loss: 1.1925, Train Acc: 0.4734, Val Loss: 1.5355, Val Acc: 0.2431


Epoch 45/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 45 summary: Train Loss: 1.1834, Train Acc: 0.4823, Val Loss: 1.5453, Val Acc: 0.2488


Epoch 46/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 46 summary: Train Loss: 1.1793, Train Acc: 0.4861, Val Loss: 1.5466, Val Acc: 0.2450


Epoch 47/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 47 summary: Train Loss: 1.1679, Train Acc: 0.4931, Val Loss: 1.5525, Val Acc: 0.2438


Epoch 48/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 48 summary: Train Loss: 1.1637, Train Acc: 0.4923, Val Loss: 1.5685, Val Acc: 0.2388


Epoch 49/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 49 summary: Train Loss: 1.1525, Train Acc: 0.5008, Val Loss: 1.5714, Val Acc: 0.2465


Epoch 50/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 50 summary: Train Loss: 1.1497, Train Acc: 0.5069, Val Loss: 1.5708, Val Acc: 0.2450


Epoch 51/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 51 summary: Train Loss: 1.1392, Train Acc: 0.5150, Val Loss: 1.5775, Val Acc: 0.2442


Epoch 52/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 52 summary: Train Loss: 1.1285, Train Acc: 0.5120, Val Loss: 1.5877, Val Acc: 0.2465


Epoch 53/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 53 summary: Train Loss: 1.1277, Train Acc: 0.5170, Val Loss: 1.5972, Val Acc: 0.2500


Epoch 54/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 54 summary: Train Loss: 1.1250, Train Acc: 0.5228, Val Loss: 1.5946, Val Acc: 0.2500


Epoch 55/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 55 summary: Train Loss: 1.1155, Train Acc: 0.5255, Val Loss: 1.6005, Val Acc: 0.2562


Epoch 56/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 56 summary: Train Loss: 1.1113, Train Acc: 0.5216, Val Loss: 1.6185, Val Acc: 0.2461


Epoch 57/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 57 summary: Train Loss: 1.1020, Train Acc: 0.5309, Val Loss: 1.6272, Val Acc: 0.2573


Epoch 58/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 58 summary: Train Loss: 1.0987, Train Acc: 0.5285, Val Loss: 1.6347, Val Acc: 0.2492


Epoch 59/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 59 summary: Train Loss: 1.0849, Train Acc: 0.5428, Val Loss: 1.6448, Val Acc: 0.2612


Epoch 60/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 60 summary: Train Loss: 1.0863, Train Acc: 0.5370, Val Loss: 1.6546, Val Acc: 0.2519


Epoch 61/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 61 summary: Train Loss: 1.0704, Train Acc: 0.5486, Val Loss: 1.6553, Val Acc: 0.2539


Epoch 62/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 62 summary: Train Loss: 1.0771, Train Acc: 0.5386, Val Loss: 1.6582, Val Acc: 0.2450


Epoch 63/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 63 summary: Train Loss: 1.0662, Train Acc: 0.5540, Val Loss: 1.6605, Val Acc: 0.2527


Epoch 64/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 64 summary: Train Loss: 1.0520, Train Acc: 0.5586, Val Loss: 1.6804, Val Acc: 0.2554


Epoch 65/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 65 summary: Train Loss: 1.0463, Train Acc: 0.5671, Val Loss: 1.6962, Val Acc: 0.2562


Epoch 66/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 66 summary: Train Loss: 1.0405, Train Acc: 0.5733, Val Loss: 1.6943, Val Acc: 0.2558


Epoch 67/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 67 summary: Train Loss: 1.0291, Train Acc: 0.5679, Val Loss: 1.6988, Val Acc: 0.2523


Epoch 68/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 68 summary: Train Loss: 1.0296, Train Acc: 0.5741, Val Loss: 1.7070, Val Acc: 0.2515


Epoch 69/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 69 summary: Train Loss: 1.0192, Train Acc: 0.5779, Val Loss: 1.7178, Val Acc: 0.2500


Epoch 70/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 70 summary: Train Loss: 1.0039, Train Acc: 0.5853, Val Loss: 1.7359, Val Acc: 0.2531


Epoch 71/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 71 summary: Train Loss: 1.0047, Train Acc: 0.5872, Val Loss: 1.7317, Val Acc: 0.2612


Epoch 72/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 72 summary: Train Loss: 1.0253, Train Acc: 0.5737, Val Loss: 1.7251, Val Acc: 0.2523


Epoch 73/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 73 summary: Train Loss: 1.0238, Train Acc: 0.5733, Val Loss: 1.7329, Val Acc: 0.2585


Epoch 74/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 74 summary: Train Loss: 1.0009, Train Acc: 0.5818, Val Loss: 1.7557, Val Acc: 0.2477


Epoch 75/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 75 summary: Train Loss: 1.0454, Train Acc: 0.5498, Val Loss: 1.7563, Val Acc: 0.2558


Epoch 76/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 76 summary: Train Loss: 1.0239, Train Acc: 0.5675, Val Loss: 1.7566, Val Acc: 0.2577


Epoch 77/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 77 summary: Train Loss: 0.9983, Train Acc: 0.5930, Val Loss: 1.7591, Val Acc: 0.2523


Epoch 78/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 78 summary: Train Loss: 1.0113, Train Acc: 0.5687, Val Loss: 1.7647, Val Acc: 0.2519


Epoch 79/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 79 summary: Train Loss: 1.0706, Train Acc: 0.5405, Val Loss: 1.7636, Val Acc: 0.2562


Epoch 80/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 80 summary: Train Loss: 1.0774, Train Acc: 0.5451, Val Loss: 1.7693, Val Acc: 0.2527


Epoch 81/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 81 summary: Train Loss: 1.0862, Train Acc: 0.5394, Val Loss: 1.7856, Val Acc: 0.2442


Epoch 82/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 82 summary: Train Loss: 1.0694, Train Acc: 0.5459, Val Loss: 1.7885, Val Acc: 0.2446


Epoch 83/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 83 summary: Train Loss: 1.0356, Train Acc: 0.5590, Val Loss: 1.8053, Val Acc: 0.2438


Epoch 84/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 84 summary: Train Loss: 1.0394, Train Acc: 0.5567, Val Loss: 1.7978, Val Acc: 0.2481


Epoch 85/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 85 summary: Train Loss: 1.0321, Train Acc: 0.5648, Val Loss: 1.7852, Val Acc: 0.2612


Epoch 86/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 86 summary: Train Loss: 1.0048, Train Acc: 0.5810, Val Loss: 1.7736, Val Acc: 0.2650


Epoch 87/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 87 summary: Train Loss: 1.0227, Train Acc: 0.5656, Val Loss: 1.7644, Val Acc: 0.2616


Epoch 88/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 88 summary: Train Loss: 0.9953, Train Acc: 0.5799, Val Loss: 1.7897, Val Acc: 0.2531


Epoch 89/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 89 summary: Train Loss: 1.0326, Train Acc: 0.5633, Val Loss: 1.7952, Val Acc: 0.2473


Epoch 90/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 90 summary: Train Loss: 1.0229, Train Acc: 0.5702, Val Loss: 1.7923, Val Acc: 0.2566


Epoch 91/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 91 summary: Train Loss: 1.0110, Train Acc: 0.5791, Val Loss: 1.7968, Val Acc: 0.2485


Epoch 92/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 92 summary: Train Loss: 1.0213, Train Acc: 0.5606, Val Loss: 1.7973, Val Acc: 0.2446


Epoch 93/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 93 summary: Train Loss: 1.0167, Train Acc: 0.5613, Val Loss: 1.8082, Val Acc: 0.2369


Epoch 94/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 94 summary: Train Loss: 1.0316, Train Acc: 0.5556, Val Loss: 1.8022, Val Acc: 0.2400


Epoch 95/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 95 summary: Train Loss: 1.0033, Train Acc: 0.5745, Val Loss: 1.7930, Val Acc: 0.2527


Epoch 96/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 96 summary: Train Loss: 0.9832, Train Acc: 0.5880, Val Loss: 1.7856, Val Acc: 0.2477


Epoch 97/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 97 summary: Train Loss: 0.9881, Train Acc: 0.5903, Val Loss: 1.7931, Val Acc: 0.2504


Epoch 98/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 98 summary: Train Loss: 0.9895, Train Acc: 0.5860, Val Loss: 1.7961, Val Acc: 0.2519


Epoch 99/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 99 summary: Train Loss: 0.9820, Train Acc: 0.5976, Val Loss: 1.7927, Val Acc: 0.2488


Epoch 100/100:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 100 summary: Train Loss: 0.9628, Train Acc: 0.6053, Val Loss: 1.8035, Val Acc: 0.2531


In [8]:
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for X, y in tqdm(test_loader):
        X, y = X.to(device), y.to(device)
        out = model(X)
        pred = out.argmax(dim=1)
        correct += (pred == y).sum().item()
        total += y.size(0)
print(f"Test accuracy: {correct/total:.2%}")

  0%|          | 0/1 [00:00<?, ?it/s]

Test accuracy: 25.31%
