In [9]:
get_ipython().system('pip install PyWavelets braindecode moabb')



In [1]:

import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from numpy import multiply
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import Preprocessor, preprocess, exponential_moving_standardize, create_windows_from_events, SetEEGReference
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score, confusion_matrix, cohen_kappa_score
from sklearn.model_selection import train_test_split
from sklearn.utils import shuffle

In [2]:


# ---------- SpiTranNet  ----------

class SpikingNeuronCell(nn.Module):
    def __init__(self, threshold=0.3, decay=0.9, temp=1.2):
        super().__init__()
        self.threshold = threshold
        self.decay = decay
        self.temp = temp

    def forward(self, x):
        if not x.requires_grad:
           x = x.clone().detach().requires_grad_()

        mem_pot = x * self.decay
        sigmoid = torch.sigmoid((mem_pot - self.threshold) * self.temp)
        spike = torch.where(mem_pot > self.threshold,
                            torch.ones_like(mem_pot),
                            torch.zeros_like(mem_pot))


        spike = spike.clone().detach().requires_grad_()

        surrogate_grad = sigmoid * (1.0 - sigmoid) * self.temp
        spike.register_hook(lambda grad: grad * surrogate_grad)

        return spike



class SpikingMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.spike = SpikingNeuronCell()

    def forward(self, x):
        attn_output, _ = self.attn(x, x, x)
        return self.spike(attn_output)

def positional_encoding(seq_len, d_model, device):
    pos = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
    i = torch.arange(d_model // 2, dtype=torch.float32, device=device)
    angle_rates = 1 / (10000 ** (2 * i / d_model))
    angle_rads = pos * angle_rates
    sin = torch.sin(angle_rads)
    cos = torch.cos(angle_rads)
    pos_encoding = torch.cat([sin, cos], dim=-1)
    return pos_encoding.unsqueeze(0)  # shape: (1, seq_len, d_model)

class SpiTranNet(nn.Module):
    def __init__(self, input_channels=22, input_length=1000, num_classes=2):
        super().__init__()
        self.input_length = input_length

        self.conv1 = nn.Conv1d(input_channels, 64, kernel_size=7, padding=3)
        self.pool1 = nn.MaxPool1d(kernel_size=4)

        self.conv2 = nn.Conv1d(64, 128, kernel_size=7, padding=3)
        self.pool2 = nn.MaxPool1d(kernel_size=4)

        self.conv3 = nn.Conv1d(128, 128, kernel_size=7, padding=3)
        self.pool3 = nn.MaxPool1d(kernel_size=4)

        self.dropout = nn.Dropout(0.5)

        self.seq_len = input_length // 64
        self.transformer_dim = 128
        self.attn = SpikingMultiHeadAttention(embed_dim=self.transformer_dim, num_heads=2)
        self.norm1 = nn.LayerNorm(self.transformer_dim)
        self.norm2 = nn.LayerNorm(self.transformer_dim)

        self.ffn = nn.Sequential(
            nn.Linear(self.transformer_dim, 128),
            nn.ReLU(),
            SpikingNeuronCell(),
            nn.Linear(128, self.transformer_dim)
        )

        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(self.seq_len * self.transformer_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.conv1(x)
        x = torch.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = torch.relu(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = torch.relu(x)
        x = self.pool3(x)

        x = self.dropout(x)
        x = x.permute(0, 2, 1)

        pos_enc = positional_encoding(x.size(1), x.size(2), x.device)
        x = x + pos_enc

        residual = x
        x = self.attn(x)
        x = self.norm1(x + residual)

        residual = x
        x = self.ffn(x)
        x = self.norm2(x + residual)

        x = self.flatten(x)
        return self.classifier(x)

# ---------- general setting ----------
sr = 250
input_length = 1000
input_channels = 22
num_classes = 2

def load_dataframe(subject_id):
    dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])

    low_cut_hz = 8.
    high_cut_hz = 30.
    factor = 1e6
    factor_new = 1e-3
    init_block_size = 1000

    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(lambda data: data * factor),
        Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        Preprocessor(exponential_moving_standardize, factor_new=factor_new, init_block_size=init_block_size),
        SetEEGReference()
    ]

    preprocess(dataset, preprocessors, n_jobs=-1)
    windows_dataset = create_windows_from_events(dataset, trial_start_offset_samples=0, trial_stop_offset_samples=0, preload=True)
    XYSet = [windows_dataset.split('session')['0train']]
    XYSet_ = [windows_dataset.split('session')['1test']]
    return XYSet, XYSet_

def extract(XYSet, selected_classes=[0, 1]):
    X, Y = [], []
    for ds in XYSet:
        for window in ds:
            x = window[0]
            y = window[1]
            if isinstance(y, list):
                y = y[0]
            if int(y) in selected_classes:
                X.append(x)
                Y.append(int(y))
    return np.array(X), np.array(Y)

def fp(x, y, device):
    x = torch.tensor(x, dtype=torch.float32).to(device)
    y = torch.tensor(y, dtype=torch.long).to(device)
    return x, y

def evaluate_model(model, test_loader):
    model.eval()
    y_true, y_pred = [], []
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            outputs = model(X_batch)
            _, predicted = torch.max(outputs, 1)
            y_true.extend(y_batch.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='binary')
    rec = recall_score(y_true, y_pred, average='binary')
    kappa = cohen_kappa_score(y_true, y_pred)
    tn, fp_, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    spec = tn / (tn + fp_) if (tn + fp_) > 0 else 0
    return {"accuracy": acc, "precision": prec, "recall": rec, "specificity": spec, "kappa": kappa}

def train_model(model, train_loader, test_loader, num_epochs=100, lr=1e-4, device="cuda"):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)
    best = {"accuracy": 0}
    for epoch in range(num_epochs):
        model.train()
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            loss = criterion(model(X_batch), y_batch)
            loss.backward()
            optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}")
        metrics = evaluate_model(model, test_loader)
        if metrics["accuracy"] > best["accuracy"]:
            best = metrics
    return best

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    results = []
    for i in range(1, 9):
        XYSet, XYSet_ = load_dataframe(i)
        XTrain, YTrain = extract(XYSet)
        XTest, YTest = extract(XYSet_)

        XTrain = XTrain.transpose(0, 1, 2) if XTrain.shape[1] < XTrain.shape[2] else XTrain
        XTest = XTest.transpose(0, 1, 2) if XTest.shape[1] < XTest.shape[2] else XTest

        X_train, y_train = fp(XTrain, YTrain, device)
        X_test, y_test = fp(XTest, YTest, device)

        X_test_final, X_extra_train, y_test_final, y_extra_train = train_test_split(X_test.cpu().numpy(), y_test.cpu().numpy(), test_size=0.6, stratify=y_test.cpu().numpy())

        X_train = torch.cat([X_train, torch.tensor(X_extra_train).to(device)])
        y_train = torch.cat([y_train, torch.tensor(y_extra_train).to(device)])
        X_test = torch.tensor(X_test_final).to(device)
        y_test = torch.tensor(y_test_final).to(device)

        train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=16, shuffle=True)
        test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=16, shuffle=False)

        model = SpiTranNet(input_channels=22, input_length=1000, num_classes=2).to(device)
        metrics = train_model(model, train_loader, test_loader, device=device)
        print(f"Subject {i} → {metrics}")
        results.append(metrics)

    print("\nAverage Results:")
    for k in results[0]:
        avg = sum(r[k] for r in results) / len(results)
        print(f"{k}: {avg:.4f}")

if __name__ == "__main__":
    main()


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A01T.mat'.
100%|#####################################| 42.8M/42.8M [00:00<00:00, 42.8GB/s]
SHA256 hash of downloaded file: 054f02e70cf9c4ada1517e9b9864f45407939c1062c6793516585c6f511d0325
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A01E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A01E.mat'.
100%|#####################################| 43.8M/43.8M [00:00<00:00, 33.6GB/s]
SHA256 hash of downloaded file: 53d415f39c3d7b0c88b894d7b08d99bcdfe855ede63831d3691af1a45607fb62
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('

Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Epoch 10
Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoc

Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A02T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A02T.mat'.


Epoch 100
Subject 1 → {'accuracy': 0.9824561403508771, 'precision': 0.9666666666666667, 'recall': 1.0, 'specificity': 0.9642857142857143, 'kappa': 0.9648798521256932}


100%|#####################################| 43.1M/43.1M [00:00<00:00, 43.4GB/s]
SHA256 hash of downloaded file: 5ddd5cb520b1692c3ba1363f48d98f58f0e46f3699ee50d749947950fc39db27
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A02E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A02E.mat'.
100%|#####################################| 44.2M/44.2M [00:00<00:00, 43.7GB/s]
SHA256 hash of downloaded file: d63c454005d3a9b41d8440629482e855afc823339bdd0b5721842a7ee9cc7b12
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10
Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A03T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A03T.mat'.


Epoch 100
Subject 2 → {'accuracy': 0.7894736842105263, 'precision': 0.8148148148148148, 'recall': 0.7586206896551724, 'specificity': 0.8214285714285714, 'kappa': 0.5793357933579336}


100%|#####################################| 44.1M/44.1M [00:00<00:00, 21.9GB/s]
SHA256 hash of downloaded file: 7e731ee8b681d5da6ecb11ae1d4e64b1653c7f15aad5d6b7620b25ce53141e80
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A03E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A03E.mat'.
100%|#####################################| 42.3M/42.3M [00:00<00:00, 21.1GB/s]
SHA256 hash of downloaded file: d4229267ec7624fa8bd3af5cbebac17f415f7c722de6cb676748f8cb3b717d97
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10
Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A04T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A04T.mat'.


Epoch 100
Subject 3 → {'accuracy': 0.9649122807017544, 'precision': 0.9655172413793104, 'recall': 0.9655172413793104, 'specificity': 0.9642857142857143, 'kappa': 0.9298029556650247}


100%|#############################################| 37.2M/37.2M [00:00<?, ?B/s]
SHA256 hash of downloaded file: 15850d81b95fc88cc8b9589eb9b713d49fa071e28adaf32d675b3eaa30591d6e
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A04E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A04E.mat'.
100%|#####################################| 41.7M/41.7M [00:00<00:00, 19.6GB/s]
SHA256 hash of downloaded file: 81916dff2c12997974ba50ffc311da006ea66e525010d010765f0047e771c86a
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10
Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A05T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A05T.mat'.


Epoch 100
Subject 4 → {'accuracy': 0.8245614035087719, 'precision': 0.8518518518518519, 'recall': 0.7931034482758621, 'specificity': 0.8571428571428571, 'kappa': 0.6494464944649447}


100%|#############################################| 42.5M/42.5M [00:00<?, ?B/s]
SHA256 hash of downloaded file: 77387d3b669f4ed9a7c1dac4dcba4c2c40c8910bae20fb961bb7cf5a94912950
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A05E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A05E.mat'.
100%|#####################################| 44.4M/44.4M [00:00<00:00, 44.1GB/s]
SHA256 hash of downloaded file: 8b357470865610c28b2f1d351beac247a56a856f02b2859d650736eb2ef77808
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A06T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A06T.mat'.


Epoch 100
Subject 5 → {'accuracy': 0.7192982456140351, 'precision': 0.6857142857142857, 'recall': 0.8275862068965517, 'specificity': 0.6071428571428571, 'kappa': 0.4363411619283065}


100%|#####################################| 44.6M/44.6M [00:00<00:00, 22.3GB/s]
SHA256 hash of downloaded file: 4dc3be1b0d60279134d1220323c73c68cf73799339a7fb224087a3c560a9a7e2
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A06E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A06E.mat'.
100%|#####################################| 43.4M/43.4M [00:00<00:00, 33.3GB/s]
SHA256 hash of downloaded file: bf67a40621b74b6af7a986c2f6edfff7fc2bbbca237aadd07b575893032998d1
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10
Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A07T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A07T.mat'.


Epoch 100
Subject 6 → {'accuracy': 0.8070175438596491, 'precision': 0.7931034482758621, 'recall': 0.8214285714285714, 'specificity': 0.7931034482758621, 'kappa': 0.6141538461538462}


100%|#####################################| 42.8M/42.8M [00:00<00:00, 41.1GB/s]
SHA256 hash of downloaded file: 43b6bbef0be78f0ac2b66cb2d9679091f1f5b7f0a5d4ebef73d2c7cc8e11aa96
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A07E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A07E.mat'.
100%|#####################################| 42.2M/42.2M [00:00<00:00, 41.9GB/s]
SHA256 hash of downloaded file: b9aaec73dcee002fab84ee98e938039a67bf6a3cbf4fc86d5d8df198cfe4c323
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10


  _warn_prf(average, modifier, msg_start, len(result))


Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90


Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A08T.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A08T.mat'.


Epoch 100
Subject 7 → {'accuracy': 0.9298245614035088, 'precision': 0.9285714285714286, 'recall': 0.9285714285714286, 'specificity': 0.9310344827586207, 'kappa': 0.8596059113300493}


100%|#####################################| 45.0M/45.0M [00:00<00:00, 32.5GB/s]
SHA256 hash of downloaded file: 7a4b3bd602d5bc307d3f4527fca2cf076659e94aca584dd64f6286fd413a82f2
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
Downloading data from 'http://bnci-horizon-2020.eu/database/data-sets/001-2014/A08E.mat' to file 'C:\Users\alira\mne_data\MNE-bnci-data\database\data-sets\001-2014\A08E.mat'.
100%|#####################################| 46.3M/46.3M [00:00<00:00, 45.7GB/s]
SHA256 hash of downloaded file: 0eedbd89790c7d621c8eef68065ddecf80d437bbbcf60321d9253e2305f294f7
Use this value as the 'known_hash' argument of 'pooch.retrieve' to ensure that the file hasn't changed if it is downloaded again in the future.
  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Epoch 10
Epoch 20
Epoch 30
Epoch 40
Epoch 50
Epoch 60
Epoch 70
Epoch 80
Epoch 90
Epoch 100
Subject 8 → {'accuracy': 0.9122807017543859, 'precision': 0.9285714285714286, 'recall': 0.896551724137931, 'specificity': 0.9285714285714286, 'kappa': 0.8246153846153846}

Average Results:
accuracy: 0.8662
precision: 0.8669
recall: 0.8739
specificity: 0.8584
kappa: 0.7323


# Second Code

# main code :

# both Subject-Wise And all-Subject Training with Metrics Calculation and Visualizations (9subjects)

###

In [13]:
# Full modified script with enhanced plotting & documentation
# Cell 1: Imports and Setup
%matplotlib inline
import os
import time
import json
from typing import Dict, Tuple, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    confusion_matrix,
    cohen_kappa_score,
    roc_curve,
    auc,
)
from sklearn.model_selection import train_test_split

from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
    Preprocessor,
    preprocess,
    exponential_moving_standardize,
    create_windows_from_events,
    SetEEGReference,
)

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

plt.rcParams.update({
    "font.family": "Helvetica",
    "font.size": 12,
    "axes.titlesize": 16,
    "axes.labelsize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
    "figure.dpi": 600,
    "savefig.dpi": 600,
    "savefig.bbox": "tight",
    "savefig.format": "png",
})
sns.set_style("whitegrid")

# -------------------
# Cell 2: Model Definition (SpiTranNet)  -- unchanged
class SpikingNeuronCell(nn.Module):
    def __init__(self, threshold=0.3, decay=0.9, temp=1.2):
        super().__init__()
        self.threshold = threshold
        self.decay = decay
        self.temp = temp

    def forward(self, x):
        if not x.requires_grad:
            x = x.clone().detach().requires_grad_()
        mem_pot = x * self.decay
        sigmoid = torch.sigmoid((mem_pot - self.threshold) * self.temp)
        spike = torch.where(mem_pot > self.threshold,
                            torch.ones_like(mem_pot),
                            torch.zeros_like(mem_pot))
        spike = spike.clone().detach().requires_grad_()
        surrogate_grad = (sigmoid * (1.0 - sigmoid) * self.temp).detach()
        spike.register_hook(lambda grad: grad * surrogate_grad)
        return spike

class SpikingMultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.spike = SpikingNeuronCell()

    def forward(self, x):
        attn_output, _ = self.attn(x, x, x)
        return self.spike(attn_output)

def positional_encoding(seq_len, d_model, device):
    pos = torch.arange(seq_len, dtype=torch.float32, device=device).unsqueeze(1)
    i = torch.arange(d_model // 2, dtype=torch.float32, device=device)
    angle_rates = 1 / (10000 ** (2 * i / d_model))
    angle_rads = pos * angle_rates
    sin = torch.sin(angle_rads)
    cos = torch.cos(angle_rads)
    pos_encoding = torch.cat([sin, cos], dim=-1)
    return pos_encoding.unsqueeze(0)

class SpiTranNet(nn.Module):
    def __init__(self, input_channels=22, input_length=1000, num_classes=2):
        super().__init__()
        self.input_length = input_length
        self.conv_block = nn.Sequential(
            nn.Conv1d(input_channels, 64, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Conv1d(64, 128, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Conv1d(128, 128, kernel_size=7, stride=1, padding=3),
            nn.ReLU(),
            nn.MaxPool1d(kernel_size=4, stride=4),
            nn.Dropout(0.5)
        )
        self.transformer_dim = 128
        self.attn = SpikingMultiHeadAttention(embed_dim=self.transformer_dim, num_heads=2)
        self.norm1 = nn.LayerNorm(self.transformer_dim)
        self.norm2 = nn.LayerNorm(self.transformer_dim)
        self.ffn = nn.Sequential(
            nn.Linear(self.transformer_dim, 128),
            nn.ReLU(),
            SpikingNeuronCell(),
            nn.Linear(128, self.transformer_dim)
        )
        self.dropout = nn.Dropout(0.5)
        self.seq_len = input_length // 64
        self.flatten = nn.Flatten()
        self.classifier = nn.Sequential(
            nn.Linear(self.seq_len * self.transformer_dim, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        x = self.conv_block(x)
        x = x.permute(0, 2, 1)
        pos_enc = positional_encoding(x.size(1), x.size(2), x.device)
        x = x + pos_enc
        residual = x
        x = self.attn(x)
        x = self.norm1(x + residual)
        residual = x
        x = self.ffn(x)
        x = self.norm2(x + residual)
        x = self.dropout(x)
        x = self.flatten(x)
        return self.classifier(x)

# -------------------
# Cell 3: Data Loading & Preprocessing  -- unchanged
sr = 250
input_length = 1000
input_channels = 22
num_classes = 2

def load_dataframe(subject_id):
    dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=[subject_id])
    low_cut_hz = 8.
    high_cut_hz = 30.
    factor = 1e6
    factor_new = 1e-3
    init_block_size = 1000
    preprocessors = [
        Preprocessor('pick_types', eeg=True, meg=False, stim=False),
        Preprocessor(lambda data: data * factor),
        Preprocessor('filter', l_freq=low_cut_hz, h_freq=high_cut_hz),
        Preprocessor(exponential_moving_standardize, factor_new=factor_new, init_block_size=init_block_size),
        SetEEGReference()
    ]
    preprocess(dataset, preprocessors, n_jobs=-1)
    windows_dataset = create_windows_from_events(dataset, trial_start_offset_samples=0, trial_stop_offset_samples=0, preload=True)
    XYSet = [windows_dataset.split('session')['0train']]
    XYSet_ = [windows_dataset.split('session')['1test']]
    return XYSet, XYSet_

def extract(XYSet, selected_classes=[0, 1]):
    X, Y = [], []
    for ds in XYSet:
        for window in ds:
            x = window[0]
            y = window[1]
            if isinstance(y, list):
                y = y[0]
            if int(y) in selected_classes:
                X.append(x)
                Y.append(int(y))
    return np.array(X), np.array(Y)

def fp(x, y, device):
    x = torch.tensor(x, dtype=torch.float32).to(device)
    y = torch.tensor(y, dtype=torch.long).to(device)
    return x, y

# -------------------
# Cell 4: Metrics and Evaluation  -- unchanged compute functions
def evaluate_with_probs(model, loader):
    model.eval()
    y_true_list = []
    y_pred_list = []
    y_prob_list = []
    with torch.no_grad():
        for Xb, yb in loader:
            outputs = model(Xb)
            probs = torch.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, dim=1)
            y_true_list.extend(yb.cpu().numpy().tolist())
            y_pred_list.extend(preds.cpu().numpy().tolist())
            if probs.size(1) >= 2:
                pos_probs = probs[:, 1].cpu().numpy().tolist()
            else:
                pos_probs = torch.sigmoid(outputs[:, 0]).cpu().numpy().tolist()
            y_prob_list.extend(pos_probs)
    return y_true_list, y_pred_list, y_prob_list

def compute_metrics_from_preds(y_true, y_pred, y_prob=None) -> Dict:
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    kappa = cohen_kappa_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    if cm.size == 4:
        tn, fp_, fn, tp = cm.ravel()
    else:
        tn = int(cm[0, 0]) if cm.shape[0] > 0 and cm.shape[1] > 0 else 0
        fp_ = int(cm[0, 1]) if cm.shape[0] > 0 and cm.shape[1] > 1 else 0
        fn = int(cm[1, 0]) if cm.shape[0] > 1 and cm.shape[1] > 0 else 0
        tp = int(cm[1, 1]) if cm.shape[0] > 1 and cm.shape[1] > 1 else 0
    spec = tn / (tn + fp_) if (tn + fp_) > 0 else 0.0
    results = {
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1_score": float(f1),
        "kappa": float(kappa),
        "specificity": float(spec),
        "tn": int(tn),
        "fp": int(fp_),
        "fn": int(fn),
        "tp": int(tp)
    }
    if y_prob is not None and len(set(y_true)) > 1:
        fpr, tpr, _ = roc_curve(y_true, y_prob)
        roc_auc = auc(fpr, tpr)
        results["roc_auc"] = float(roc_auc)
        results["fpr"] = fpr.tolist()
        results["tpr"] = tpr.tolist()
    else:
        results["roc_auc"] = None
        results["fpr"] = []
        results["tpr"] = []
    return results

# -------------------
# Cell 5: Plotting / Saving functions (enhanced)
def save_png_pdf(fig, out_path_base):
    """Save figure in PNG and PDF with high resolution."""
    png_path = out_path_base + ".png"
    pdf_path = out_path_base + ".pdf"
    fig.savefig(png_path, dpi=600, bbox_inches="tight")
    try:
        fig.savefig(pdf_path, dpi=600, bbox_inches="tight")
    except Exception as e:
        # PDF save may fail in some environments, but continue
        print("Warning: saving PDF failed:", e)


def save_confusion_matrix(y_true, y_pred, save_path, title="Confusion Matrix", normalize=False):
    import matplotlib.pyplot as plt
    import seaborn as sns
    import numpy as np
    from sklearn.metrics import confusion_matrix

    cm = confusion_matrix(y_true, y_pred)
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]

    num_classes = cm.shape[0]
    fig_size = max(6, num_classes * 2)  # larger to ensure all rows visible
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))

    # Draw heatmap without annotations
    sns.heatmap(cm, annot=False, cmap="Blues", cbar=True, square=True, ax=ax)

    # Add numbers manually (centered)
    for i in range(num_classes):
        for j in range(num_classes):
            val = cm[i, j]
            display_val = f"{val:.2f}" if normalize else f"{int(val)}"
            ax.text(j + 0.5, i + 0.5, display_val,
                    ha="center", va="center",
                    color="white" if cm[i, j] > cm.max()/2 else "black",
                    fontsize=12, fontweight='bold')

    ax.set_xlabel("Predicted label", fontsize=12)
    ax.set_ylabel("True label", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.set_xticks(np.arange(num_classes)+0.5)
    ax.set_yticks(np.arange(num_classes)+0.5)
    ax.set_xticklabels(range(num_classes))
    ax.set_yticklabels(range(num_classes), rotation=0)

    plt.tight_layout()
    save_png_pdf(fig, save_path)
    plt.close(fig)




def save_epoch_history_excel(history: Dict, out_dir: str, filename_base="epoch_history"):
    # history is expected to be dict: keys -> lists of same length
    os.makedirs(out_dir, exist_ok=True)
    # Save CSV
    df = pd.DataFrame(history)
    csv_path = os.path.join(out_dir, filename_base + ".csv")
    df.to_csv(csv_path, index_label="epoch")
    # Save Excel (sheet)
    xlsx_path = os.path.join(out_dir, filename_base + ".xlsx")
    try:
        with pd.ExcelWriter(xlsx_path, engine="openpyxl") as writer:
            df.to_excel(writer, sheet_name="history", index_label="epoch")
    except Exception as e:
        # still ok if to_excel fails
        print("Warning: to_excel failed:", e)
    return csv_path, xlsx_path

def save_loss_plot(history: Dict, out_path_base: str, title="Training & Validation Loss"):
    fig, ax = plt.subplots(figsize=(8,5))
    epochs = range(1, len(history.get("loss", [])) + 1)
    if history.get("loss", None):
        ax.plot(epochs, history["loss"], label="Training Loss", linewidth=2)
    if history.get("val_loss", None):
        ax.plot(epochs, history["val_loss"], label="Validation Loss", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend()
    save_png_pdf(fig, out_path_base)
    plt.close(fig)

def save_acc_plot(history: Dict, out_path_base: str, title="Training & Validation Accuracy"):
    fig, ax = plt.subplots(figsize=(8,5))
    epochs = range(1, len(history.get("train_acc", [])) + 1)
    if history.get("train_acc", None):
        ax.plot(epochs, history["train_acc"], label="Training Accuracy", linewidth=2)
    if history.get("val_acc", None):
        ax.plot(epochs, history["val_acc"], label="Validation Accuracy", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0,1.0)
    ax.set_title(title)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend()
    save_png_pdf(fig, out_path_base)
    plt.close(fig)

def save_kappa_plot(history: Dict, out_path_base: str, title="Training & Validation Kappa"):
    fig, ax = plt.subplots(figsize=(8,5))
    epochs = range(1, len(history.get("train_kappa", [])) + 1)
    if history.get("train_kappa", None):
        ax.plot(epochs, history["train_kappa"], label="Training Kappa", linewidth=2)
    if history.get("val_kappa", None):
        ax.plot(epochs, history["val_kappa"], label="Validation Kappa", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Cohen's Kappa")
    ax.set_ylim(-1,1)
    ax.set_title(title)
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend()
    save_png_pdf(fig, out_path_base)
    plt.close(fig)

def save_roc_curve(y_true, y_prob, out_path_base: str, title="ROC Curve"):
    # Only if there are at least two classes
    if y_prob is None or len(set(y_true)) < 2:
        print("ROC not saved: insufficient classes or missing probabilities")
        return None
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    fig, ax = plt.subplots(figsize=(6,6))
    ax.plot(fpr, tpr, label=f'AUC = {roc_auc:.3f}', linewidth=2)
    ax.plot([0,1],[0,1],'--', color='gray', linewidth=1)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, linestyle="--", linewidth=0.5)
    save_png_pdf(fig, out_path_base)
    plt.close(fig)
    return {"fpr": fpr.tolist(), "tpr": tpr.tolist(), "roc_auc": float(roc_auc)}

def save_metrics_json(metrics: Dict, out_path):
    with open(out_path, "w") as f:
        json.dump(metrics, f, indent=4)

def save_combined_loss_acc(train_loss, val_loss, train_acc, val_acc, out_path):
    plt.figure(figsize=(10, 6))
    
    # Plot Loss
    plt.subplot(2, 1, 1)
    plt.plot(train_loss, label="Train Loss")
    plt.plot(val_loss, label="Validation Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.legend()
    plt.title("Training vs Validation Loss")

    # Plot Accuracy
    plt.subplot(2, 1, 2)
    plt.plot(train_acc, label="Train Accuracy")
    plt.plot(val_acc, label="Validation Accuracy")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.legend()
    plt.title("Training vs Validation Accuracy")

    plt.tight_layout()
    plt.savefig(out_path + ".png", dpi=300, bbox_inches="tight")
    plt.close()

def save_combined_kappa(train_kappa, val_kappa, out_path):
    """
    Save a combined plot of training vs validation Cohen's kappa.

    Parameters
    ----------
    train_kappa : list or array
        Training kappa values per epoch
    val_kappa : list or array
        Validation kappa values per epoch
    out_path : str
        Path prefix to save the plot (without extension)
    """
    plt.figure(figsize=(8, 5))
    plt.plot(train_kappa, label="Train Kappa", linewidth=2)
    plt.plot(val_kappa, label="Validation Kappa", linewidth=2, linestyle="--")
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Cohen's Kappa", fontsize=12)
    plt.title("Training vs Validation Kappa", fontsize=14, fontweight="bold")
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.6)
    plt.tight_layout()
    plt.savefig(out_path + ".png", dpi=300, bbox_inches="tight")
    plt.close()


def save_summary_table(metrics_dict, out_path, title="Summary Metrics Table"):
    """
    Save a summary table as a PNG image with clear formatting.

    Parameters
    ----------
    metrics_dict : dict
        Dictionary of metrics, e.g. { "Subject1": {"Acc":0.9, "F1":0.8, ...}, ... }
    out_path : str
        Path prefix to save PNG (without extension)
    title : str
        Title above the table
    """
    # Convert dictionary to DataFrame
    df = pd.DataFrame(metrics_dict).T  # subjects as rows
    df = df.round(4)  # round to 4 decimals for clarity

    fig, ax = plt.subplots(figsize=(max(8, len(df.columns)*1.2), max(4, len(df)*0.5 + 2)))
    ax.axis("off")

    # Title
    plt.title(title, fontsize=14, fontweight="bold", pad=20)

    # Create table
    table = ax.table(cellText=df.values,
                     rowLabels=df.index,
                     colLabels=df.columns,
                     cellLoc='center',
                     rowLoc='center',
                     loc='center')

    # Formatting
    table.auto_set_font_size(False)
    table.set_fontsize(12)
    table.scale(1.2, 1.2)  # make table bigger

    # Style header cells
    for (row, col), cell in table.get_celld().items():
        cell.set_linewidth(1)
        if row == 0:  # header row
            cell.set_text_props(weight='bold', color="white")
            cell.set_facecolor("#4B8BBE")
        elif col == -1:  # row index
            cell.set_text_props(weight='bold')
            cell.set_facecolor("#f0f0f0")
        else:
            cell.set_facecolor("white")

    plt.tight_layout()
    plt.savefig(out_path + ".png", dpi=300, bbox_inches="tight")
    plt.close()


# -------------------
# Cell 6: Training function (unchanged) - returns best_metrics, history, (y_true_final, y_pred_final, y_prob_final)
def train_model_with_history(model, train_loader, test_loader, num_epochs=100, lr=1e-4, device="cuda",
                             out_dir="results_subject", checkpoint_name="best_model.pth"):
    os.makedirs(out_dir, exist_ok=True)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    best_acc = -1.0
    best_metrics = None
    history = {"loss": [], "val_loss": [], "train_acc": [], "val_acc": [], "train_kappa": [], "val_kappa": []}

    for epoch in range(1, num_epochs + 1):
        model.train()
        running_loss = 0.0
        n_samples = 0
        y_true_train_epoch = []
        y_pred_train_epoch = []
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            running_loss += float(loss.item()) * X_batch.size(0)
            n_samples += X_batch.size(0)
            _, preds = torch.max(outputs, dim=1)
            y_true_train_epoch.extend(y_batch.cpu().numpy())
            y_pred_train_epoch.extend(preds.cpu().numpy())
        epoch_loss = running_loss / (n_samples if n_samples > 0 else 1.0)
        train_metrics_epoch = compute_metrics_from_preds(y_true_train_epoch, y_pred_train_epoch)
        train_acc_epoch = train_metrics_epoch["accuracy"]
        train_kappa_epoch = train_metrics_epoch["kappa"]

        # Validation evaluation
        y_true_val, y_pred_val, y_prob_val = evaluate_with_probs(model, test_loader)
        metrics_val = compute_metrics_from_preds(y_true_val, y_pred_val, y_prob_val)
        val_acc_epoch = metrics_val["accuracy"]
        val_kappa_epoch = metrics_val["kappa"]

        # Compute validation loss
        model.eval()
        val_loss = 0.0
        val_samples = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                val_loss += float(loss.item()) * X_batch.size(0)
                val_samples += X_batch.size(0)
        val_loss = val_loss / (val_samples if val_samples > 0 else 1.0)

        history["loss"].append(epoch_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc_epoch)
        history["val_acc"].append(val_acc_epoch)
        history["train_kappa"].append(train_kappa_epoch)
        history["val_kappa"].append(val_kappa_epoch)

        print(f"[Epoch {epoch}/{num_epochs}] loss={epoch_loss:.4f} val_loss={val_loss:.4f} train_acc={train_acc_epoch:.4f} val_acc={val_acc_epoch:.4f} train_kappa={train_kappa_epoch:.4f} val_kappa={val_kappa_epoch:.4f}")

        if val_acc_epoch > best_acc:
            best_acc = val_acc_epoch
            best_metrics = metrics_val
            torch.save(model.state_dict(), os.path.join(out_dir, checkpoint_name))

    ckpt_path = os.path.join(out_dir, checkpoint_name)
    if os.path.exists(ckpt_path):
        state = torch.load(ckpt_path, map_location=next(model.parameters()).device)
        try:
            model.load_state_dict(state)
        except Exception as e:
            print("Warning: could not load saved state_dict:", e)

    y_true_final, y_pred_final, y_prob_final = evaluate_with_probs(model, test_loader)
    final_metrics = compute_metrics_from_preds(y_true_final, y_pred_final, y_prob_final)
    return best_metrics, history, (y_true_final, y_pred_final, y_prob_final)

# -------------------
# Cell 8: Main function for per-subject and all-subject metrics (enhanced plotting & combined outputs)
# -------------------
# Helper: Save summary table as PNG (clean formatting)

# def save_summary_table_png(df, save_path, title="Summary Metrics per Subject"):
#     import matplotlib.pyplot as plt
#     import numpy as np

#     # Round numeric columns to 4 digits
#     df_display = df.copy()
#     for col in df_display.columns:
#         if np.issubdtype(df_display[col].dtype, np.number):
#             df_display[col] = df_display[col].round(4)

#     n_rows, n_cols = df_display.shape
#     fig_height = max(2, 0.6 * n_rows)
#     fig_width = max(8, 1.5 * n_cols)
#     fig, ax = plt.subplots(figsize=(fig_width, fig_height))
#     ax.axis('off')

#     # Create table
#     tbl = ax.table(cellText=df_display.values,
#                    colLabels=df_display.columns,
#                    cellLoc='center',  # center text in all cells
#                    loc='center')

#     tbl.auto_set_font_size(False)
#     tbl.set_fontsize(10)
#     tbl.scale(1.2, 1.2)  # scale width and height of cells

#     # Make header bold
#     for (i, j), cell in tbl.get_celld().items():
#         if i == 0:  # header row
#             cell.set_text_props(weight='bold', ha='center', va='center')
#         else:
#             cell.set_text_props(ha='center', va='center')

#     ax.set_title(title, fontsize=14, pad=20)
#     plt.tight_layout()
#     fig.savefig(save_path + ".png", dpi=300)
#     plt.close(fig)


def save_combined_subject_training_curves_simple(df_combined_stats, out_dir="results_combined"):
    import matplotlib.pyplot as plt
    import os

    os.makedirs(out_dir, exist_ok=True)
    epochs = np.arange(1, df_combined_stats.shape[0]+1)

    # Loss
    fig, ax = plt.subplots(figsize=(9,5))
    if "loss_mean" in df_combined_stats.columns:
        ax.plot(epochs, df_combined_stats["loss_mean"], label="Train Loss", linewidth=2)
    if "val_loss_mean" in df_combined_stats.columns:
        ax.plot(epochs, df_combined_stats["val_loss_mean"], label="Val Loss", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Train & Validation Loss (Combined Subjects)")
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend()
    plt.tight_layout()
    fig.savefig(os.path.join(out_dir, "combined_loss.png"), dpi=300)
    plt.close(fig)

    # Accuracy
    fig, ax = plt.subplots(figsize=(9,5))
    if "train_acc_mean" in df_combined_stats.columns:
        ax.plot(epochs, df_combined_stats["train_acc_mean"], label="Train Accuracy", linewidth=2)
    if "val_acc_mean" in df_combined_stats.columns:
        ax.plot(epochs, df_combined_stats["val_acc_mean"], label="Validation Accuracy", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy")
    ax.set_ylim(0,1.0)
    ax.set_title("Train & Validation Accuracy (Combined Subjects)")
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend()
    plt.tight_layout()
    fig.savefig(os.path.join(out_dir, "combined_accuracy.png"), dpi=300)
    plt.close(fig)

    # Kappa
    fig, ax = plt.subplots(figsize=(9,5))
    if "train_kappa_mean" in df_combined_stats.columns:
        ax.plot(epochs, df_combined_stats["train_kappa_mean"], label="Train Kappa", linewidth=2)
    if "val_kappa_mean" in df_combined_stats.columns:
        ax.plot(epochs, df_combined_stats["val_kappa_mean"], label="Validation Kappa", linewidth=2)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Cohen's Kappa")
    ax.set_ylim(-1,1)
    ax.set_title("Train & Validation Kappa (Combined Subjects)")
    ax.grid(True, linestyle="--", linewidth=0.5)
    ax.legend()
    plt.tight_layout()
    fig.savefig(os.path.join(out_dir, "combined_kappa.png"), dpi=300)
    plt.close(fig)

# Cell 8: Main function for per-subject and all-subject metrics (enhanced plotting & combined outputs)
def run_all_subjects(subject_ids, num_epochs=50, batch_size=32, lr=1e-4, device="cuda"):
    all_subject_metrics = []
    all_subject_y_true = []
    all_subject_y_pred = []
    all_subject_y_prob = []

    # For combined plotting across histories
    histories_by_subject = {}  # subj_id -> history dict
    final_metrics_by_subject = {}  # subj_id -> final_metrics
    train_vs_test_acc = []  # list of dicts {subject, train_acc_final_epoch, test_acc}
    roc_records = []  # list of (subject_id, fpr, tpr, auc)
    summary_rows = []

    for subj_id in subject_ids:
        print(f"\n=== Processing Subject {subj_id} ===\n")
        out_dir = f"results_subject_{subj_id}"
        os.makedirs(out_dir, exist_ok=True)

        # Load data
        XY_train, XY_test = load_dataframe(subj_id)
        X_train, y_train = extract(XY_train)
        X_test, y_test = extract(XY_test)

        # Convert to tensors
        X_train_t, y_train_t = fp(X_train, y_train, device)
        X_test_t, y_test_t = fp(X_test, y_test, device)

        train_dataset = TensorDataset(X_train_t, y_train_t)
        test_dataset = TensorDataset(X_test_t, y_test_t)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

        # Initialize model
        model = SpiTranNet(input_channels=X_train.shape[1], input_length=X_train.shape[2], num_classes=num_classes)
        model.to(device)

        # Train model with history
        best_metrics, history, (y_true_final, y_pred_final, y_prob_final) = train_model_with_history(
            model, train_loader, test_loader, num_epochs=num_epochs, lr=lr, device=device, out_dir=out_dir
        )

        # Save epoch history Excel and CSV
        csv_path, xlsx_path = save_epoch_history_excel(history, out_dir, filename_base="epoch_history")

        # Save separate plots: loss, acc, kappa
        save_loss_plot(history, os.path.join(out_dir, "loss"))
        save_acc_plot(history, os.path.join(out_dir, "accuracy"))
        save_kappa_plot(history, os.path.join(out_dir, "kappa"))

        # Save combined loss+acc (keep original combined layout for compatibility)
        save_combined_loss_acc(history["loss"], history["val_loss"], history["train_acc"], history["val_acc"], os.path.join(out_dir, "loss_acc"))
        save_combined_kappa(history["train_kappa"], history["val_kappa"], os.path.join(out_dir, "kappa_combined"))

        # Save confusion matrix (final test)
        save_confusion_matrix(y_true_final, y_pred_final, os.path.join(out_dir, "confusion_matrix"), title=f"Confusion Matrix Subject {subj_id}")

        # Save ROC curve (using final test y_true_final / y_prob_final)
        roc_info = save_roc_curve(y_true_final, y_prob_final, os.path.join(out_dir, "roc_curve"), title=f"ROC Curve Subject {subj_id}")
        if roc_info is not None:
            roc_records.append((subj_id, roc_info["fpr"], roc_info["tpr"], roc_info["roc_auc"]))

        # Save summary metrics (final test) as JSON & CSV
        final_metrics = compute_metrics_from_preds(y_true_final, y_pred_final, y_prob_final)
        save_metrics_json(final_metrics, os.path.join(out_dir, f"summary_metrics_final_subject_{subj_id}.json"))
        pd.DataFrame([final_metrics]).to_csv(os.path.join(out_dir, f"summary_metrics_final_subject_{subj_id}.csv"), index=False)

        # Store metrics
        all_subject_metrics.append(final_metrics)
        all_subject_y_true.extend(y_true_final)
        all_subject_y_pred.extend(y_pred_final)
        all_subject_y_prob.extend(y_prob_final)

        # Keep history and final metrics for combined plotting
        histories_by_subject[subj_id] = history
        final_metrics_by_subject[subj_id] = final_metrics

        # Train-final epoch accuracy (from history) and test accuracy
        train_acc_final_epoch = history["train_acc"][-1] if len(history["train_acc"])>0 else None
        test_acc = final_metrics["accuracy"]
        train_vs_test_acc.append({"subject": subj_id, "train_final_acc": train_acc_final_epoch, "test_acc": test_acc})

        # Summary table row
        row = {"subject": subj_id}
        for k,v in final_metrics.items():
            row[k] = v
        summary_rows.append(row)

    # -------------------
    # Compute combined metrics for all subjects
    print("\n=== Computing Combined Metrics Across All Subjects ===\n")
    combined_metrics = compute_metrics_from_preds(all_subject_y_true, all_subject_y_pred, all_subject_y_prob)
    combined_dir = "results_combined"
    os.makedirs(combined_dir, exist_ok=True)

    # Save combined metrics CSV/JSON
    pd.DataFrame([combined_metrics]).to_csv(os.path.join(combined_dir, "metrics_combined.csv"), index=False)
    with open(os.path.join(combined_dir, "metrics_combined.json"), "w") as f:
        json.dump(combined_metrics, f, indent=4)

    # Combined confusion matrix
    save_confusion_matrix(all_subject_y_true, all_subject_y_pred, os.path.join(combined_dir, "confusion_matrix_combined"), title="Combined Confusion Matrix")

    # Combined ROC plot
    roc_info_combined = None
    if combined_metrics["roc_auc"] is not None:
        roc_info_combined = save_roc_curve(all_subject_y_true, all_subject_y_prob, os.path.join(combined_dir, "roc_curve_combined"), title="Combined ROC Curve")

    # -------------------
    # Save clean summary table PNG + CSV/Excel
    summary_df = pd.DataFrame(summary_rows).sort_values("subject")
    combined_row = {"subject": "combined"}
    for k,v in combined_metrics.items():
        combined_row[k] = v
    summary_df_with_combined = pd.concat([summary_df, pd.DataFrame([combined_row])], ignore_index=True, sort=False)

    # CSV & Excel
    summary_df_with_combined.to_csv(os.path.join(combined_dir, "summary_table_all_subjects.csv"), index=False)
    try:
        with pd.ExcelWriter(os.path.join(combined_dir, "summary_table_all_subjects.xlsx"), engine="openpyxl") as writer:
            summary_df_with_combined.to_excel(writer, sheet_name="summary", index=False)
    except Exception as e:
        print("Warning: to_excel for summary table failed:", e)

    # # PNG (clean)
    # try:
    #     save_summary_table_png(summary_df_with_combined, os.path.join(combined_dir, "summary_table_all_subjects"),
    #                            title="Summary metrics per subject and combined")
    # except Exception as e:
    #     print("Warning: saving PNG of summary table failed:", e)

    # -------------------
    # -------------------
    # Save combined training curves for all subjects (train/val loss, acc, kappa)
    metrics = ["loss","val_loss","train_acc","val_acc","train_kappa","val_kappa"]
    max_epochs = max([len(h.get("loss",[])) for h in histories_by_subject.values()]) if len(histories_by_subject)>0 else 0
    combined_stats = {}
    for m in metrics:
        stacked = []
        for subj_id, hist in histories_by_subject.items():
            arr = hist.get(m, [])
            arr_padded = arr + [np.nan]*(max_epochs - len(arr))
            stacked.append(arr_padded)
        if len(stacked) > 0:
            stacked_arr = np.vstack(stacked)
            mean = np.nanmean(stacked_arr, axis=0)
            std = np.nanstd(stacked_arr, axis=0)
        else:
            mean = []
            std = []
        combined_stats[m + "_mean"] = mean
        combined_stats[m + "_std"] = std
    df_combined_stats = pd.DataFrame(combined_stats)

    # Save the combined training curves
    save_combined_subject_training_curves_simple(df_combined_stats, out_dir=combined_dir)

    # -------------------
    # Additional aggregated plots (multi-subject ROC, boxplot, train vs test acc) omitted for brevity
    # (they remain the same as your previous enhanced code)

    return all_subject_metrics, combined_metrics

# -------------------
# Cell 9: Run the pipeline
subject_ids = [1,2,3,4,5,6,7,8,9]  # example, replace with your actual subject IDs
all_subject_metrics, combined_metrics = run_all_subjects(subject_ids, num_epochs=100, batch_size=32, lr=1e-4, device="cuda")
print("\nPer-subject metrics:", all_subject_metrics)
print("\nCombined metrics:", combined_metrics)




=== Processing Subject 1 ===



  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7079 val_loss=0.6924 train_acc=0.4931 val_acc=0.5972 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7039 val_loss=0.6952 train_acc=0.5208 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7552 val_loss=0.6921 train_acc=0.4375 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7273 val_loss=0.7222 train_acc=0.5139 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7533 val_loss=0.7177 train_acc=0.4514 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7020 val_loss=0.6962 train_acc=0.5208 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7689 val_loss=0.7091 train_acc=0.4236 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7103 val_loss=0.6990 train_acc=0.5139 val_acc=0.5000 tr

  warn('Preprocessing choices with lambda functions cannot be saved.')


Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
Used Annotations descriptions: ['feet', 'left_hand', 'right_hand', 'tongue']
[Epoch 1/100] loss=0.7107 val_loss=0.7072 train_acc=0.4861 val_acc=0.5000 tr