In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CTNet with Mamba SSM instead of Transformer:
Fully replicates your EEG-based motor imagery classification pipeline,
but replaces the Transformer encoder blocks with Mamba SSM blocks.

Author: zhaowei701@163.com
"""

import os
# GPU index list; main() forwards it to ExP as the `gpus` argument.
gpus = [0]
# Configure CUDA device enumeration ahead of the torch import below:
# order devices by PCI bus id and expose only device 0 to this process.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = '0'
import numpy as np
import pandas as pd
import random
import datetime
import time

from pandas import ExcelWriter
from torchsummary import summary
import torch
from torch.backends import cudnn
from utils import calMetrics
from utils import calculatePerClass
from utils import numberClassChannel
import math
import warnings
# Suppress every Python warning for the rest of the run.
warnings.filterwarnings("ignore")
# Turn off cuDNN auto-tuning and request deterministic cuDNN algorithms.
cudnn.benchmark = False
cudnn.deterministic = True

import torch
from torch import nn
from torch import Tensor
from einops.layers.torch import Rearrange, Reduce
from einops import rearrange, reduce, repeat
import torch.nn.functional as F

from utils import numberClassChannel
from utils import load_data_evaluate

import numpy as np
import pandas as pd
from torch.autograd import Variable
import sys

######################################
# Try importing Mamba from mamba_ssm
######################################
try:
    from mamba_ssm import Mamba
except ImportError:
    # The stand-in below is NOT a state-space model. Announce the substitution
    # loudly so results are never mistaken for real Mamba results. A plain
    # print to stderr is used because this file silences the warnings module.
    print("WARNING: mamba_ssm is not installed; 'Mamba' is replaced by a "
          "single nn.Linear stand-in (demonstration only).", file=sys.stderr)

    class Mamba(nn.Module):
        """
        Dummy fallback used when mamba_ssm is not installed.

        It accepts the same four constructor arguments this file passes to the
        real Mamba layer, but only `d_model` is used: the module is a single
        `nn.Linear(d_model, d_model)` applied to the last dimension.
        """
        def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
            super().__init__()
            # d_state, d_conv and expand are accepted for signature
            # compatibility and intentionally ignored.
            self.linear = nn.Linear(d_model, d_model)

        def forward(self, x):
            # x shape: (batch, seq_len, d_model); the shape is preserved.
            return self.linear(x)

######################################
# CNN Embedding (Patch Embedding)
######################################
class PatchEmbeddingCNN(nn.Module):
    """
    Convolutional front-end that turns a raw trial into a token sequence.

    The stack is: temporal convolution -> depthwise convolution across the
    `number_channel` electrode rows -> ELU / average pooling / dropout ->
    a (1, 16) convolution -> ELU / average pooling / dropout. The resulting
    feature map is rearranged to (batch, h * w, f2) with f2 = D * f1.

    `emb_size` is accepted but not used by this module; the token width is
    always f2.
    """
    def __init__(self, f1=16, kernel_size=64, D=2, pooling_size1=8, pooling_size2=8, 
                 dropout_rate=0.3, number_channel=22, emb_size=40):
        super().__init__()
        f2 = D * f1

        # Layers are created in execution order so parameter initialisation
        # and the Sequential indices are the same as a flat definition.
        temporal_stage = [
            nn.Conv2d(1, f1, (1, kernel_size), (1, 1), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        ]
        depthwise_stage = [
            # One group per temporal filter: each filter gets D kernels that
            # span all electrode rows.
            nn.Conv2d(f1, f2, (number_channel, 1), (1, 1), groups=f1, padding='valid', bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pooling_size1)),
            nn.Dropout(dropout_rate),
        ]
        refinement_stage = [
            nn.Conv2d(f2, f2, (1, 16), padding='same', bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pooling_size2)),
            nn.Dropout(dropout_rate),
        ]
        self.cnn_module = nn.Sequential(*temporal_stage, *depthwise_stage, *refinement_stage)

        # (batch, f2, h, w) -> (batch, h * w, f2)
        to_tokens = Rearrange('b e h w -> b (h w) e')
        self.projection = nn.Sequential(to_tokens)

    def forward(self, x: Tensor) -> Tensor:
        feature_map = self.cnn_module(x)
        return self.projection(feature_map)

######################################
# Simple Classification Head
######################################
class ClassificationHead(nn.Sequential):
    """
    Dropout (p=0.5) followed by one linear layer mapping `flatten_number`
    input features to `n_classes` outputs.
    """
    def __init__(self, flatten_number, n_classes):
        super().__init__()
        regulariser = nn.Dropout(0.5)
        projector = nn.Linear(flatten_number, n_classes)
        self.fc = nn.Sequential(regulariser, projector)

    def forward(self, x):
        logits = self.fc(x)
        return logits

######################################
# Positional Encoding (Optional)
######################################
class PositioinalEncoding(nn.Module):
    """
    Trainable additive positional encoding followed by dropout.

    Args:
        embedding: width of each token (last dimension of the input).
        length: number of positions stored; inputs may use any prefix of it.
        dropout: dropout probability applied after the addition.
    """
    def __init__(self, embedding, length=100, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        # trainable positional embeddings, shape (1, length, embedding)
        self.encoding = nn.Parameter(torch.randn(1, length, embedding))

    def forward(self, x):  # x -> [batch, seq_len, embedding]
        # Use only the first seq_len positions. The slice follows the input's
        # device instead of being forced onto CUDA, so the module also works
        # on CPU; when both already share a device this is a no-op.
        positions = self.encoding[:, : x.shape[1], :].to(x.device)
        return self.dropout(x + positions)

######################################
# Mamba Encoder Blocks
######################################
class MambaEncoderBlock(nn.Module):
    """
    Pre-norm residual block with two sub-layers:
      1) LayerNorm -> Mamba -> Dropout, added back to the input
      2) LayerNorm -> Linear/GELU/Dropout/Linear -> Dropout, added back again

    The feed-forward hidden width is `ff_expansion * d_model`; input and
    output both have `d_model` features.
    """
    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dropout=0.5, ff_expansion=4):
        super().__init__()
        hidden_width = ff_expansion * d_model

        # Sequence-mixing sub-layer
        self.mamba_norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)
        self.mamba_dropout = nn.Dropout(dropout)

        # Position-wise feed-forward sub-layer
        self.ffn_norm = nn.LayerNorm(d_model)
        widen = nn.Linear(d_model, hidden_width)
        narrow = nn.Linear(hidden_width, d_model)
        self.feedforward = nn.Sequential(widen, nn.GELU(), nn.Dropout(dropout), narrow)
        self.ffn_dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor) -> Tensor:
        # Residual around the Mamba sub-layer
        x = x + self.mamba_dropout(self.mamba(self.mamba_norm(x)))
        # Residual around the feed-forward sub-layer
        return x + self.ffn_dropout(self.feedforward(self.ffn_norm(x)))

class MambaEncoder(nn.Module):
    """
    A sequence of `depth` identically configured MambaEncoderBlock modules
    applied one after another. With `depth=0` the input is returned as is.
    """
    def __init__(self, depth=6, d_model=40, d_state=16, d_conv=4, expand=2, 
                 dropout=0.5, ff_expansion=4):
        super().__init__()
        block_config = dict(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            dropout=dropout,
            ff_expansion=ff_expansion,
        )
        self.layers = nn.ModuleList()
        for _ in range(depth):
            self.layers.append(MambaEncoderBlock(**block_config))

    def forward(self, x: Tensor) -> Tensor:
        # x shape: (batch, seq_len, d_model)
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        return hidden

######################################
# "EEGMamba" Replacing the Transformer
######################################
class EEGMamba(nn.Module):
    """
    CNN patch embedding -> positional encoding -> Mamba encoder -> linear head.

    The number of output classes comes from `numberClassChannel(database_type)`.
    The CNN emits tokens of width `eeg1_f1 * eeg1_D`, while the positional
    encoding and the encoder are built for width `emb_size`, so the two values
    have to match. `flatten_eeg1` is the number of input features of the
    classification head, i.e. the size of the flattened encoder output per
    sample. Extra keyword arguments are accepted and ignored.

    forward(x) returns `(features, out)`: the encoder output and the class
    scores computed from its flattened form.
    """
    def __init__(self, 
                 emb_size=40,
                 depth=6,
                 database_type='A',
                 eeg1_f1=20,
                 eeg1_kernel_size=64,
                 eeg1_D=2,
                 eeg1_pooling_size1=8,
                 eeg1_pooling_size2=8,
                 eeg1_dropout_rate=0.3,
                 eeg1_number_channel=22,
                 flatten_eeg1=600,
                 d_state=16, 
                 d_conv=4,
                 expand=2,
                 ff_expansion=4,
                 dropout_mamba=0.5,
                 **kwargs):
        super().__init__()
        n_classes, _n_channels = numberClassChannel(database_type)
        self.number_class = n_classes
        self.emb_size = emb_size
        self.flatten_eeg1 = flatten_eeg1

        # Sub-modules are created in the order they run in forward().
        cnn_config = dict(
            f1=eeg1_f1,
            kernel_size=eeg1_kernel_size,
            D=eeg1_D,
            pooling_size1=eeg1_pooling_size1,
            pooling_size2=eeg1_pooling_size2,
            dropout_rate=eeg1_dropout_rate,
            number_channel=eeg1_number_channel,
            emb_size=emb_size,
        )
        self.cnn = PatchEmbeddingCNN(**cnn_config)

        self.position = PositioinalEncoding(emb_size, dropout=0.1)

        encoder_config = dict(
            depth=depth,
            d_model=emb_size,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
            dropout=dropout_mamba,
            ff_expansion=ff_expansion,
        )
        self.mamba_encoder = MambaEncoder(**encoder_config)

        self.flatten = nn.Flatten()
        self.classification = ClassificationHead(flatten_eeg1, n_classes)

    def forward(self, x):
        # Tokens from the CNN, scaled by sqrt(emb_size) before positions are added.
        tokens = self.cnn(x) * math.sqrt(self.emb_size)
        tokens = self.position(tokens)

        features = self.mamba_encoder(tokens)

        flat = self.flatten(features)
        out = self.classification(flat)
        return features, out

######################################
# The rest of your pipeline remains unchanged
######################################

class ExP():
    """
    One training / evaluation experiment for a single subject.

    The constructor builds an EEGMamba model on the GPU. `train()` loads the
    subject's data, trains with segment-recombination augmentation, keeps the
    checkpoint with the lowest validation loss and reports test accuracy for
    that checkpoint.

    Args:
        nsub: subject identifier passed to the data loader; also used in the
            checkpoint file name.
        data_dir: data root handed to `load_data_evaluate`.
        result_name: directory in which the checkpoint is written.
        epochs: number of training epochs.
        number_aug: augmentation multiplier (see `interaug`).
        number_seg: number of time segments recombined by `interaug`.
        gpus: accepted for compatibility; not used by this class.
        evaluate_mode: forwarded to `load_data_evaluate` as `mode_evaluate`.
        heads: stored on the instance; not used by the Mamba model.
        emb_size, depth, eeg1_*, flatten_eeg1: forwarded to EEGMamba.
        dataset_type: dataset key given to `numberClassChannel` and the loader.
        validate_ratio: fraction of every mini-batch held out for validation.
        learning_rate: Adam learning rate.
        batch_size: mini-batch size for training, validation and test loaders.
    """
    def __init__(self, nsub, data_dir, result_name, 
                 epochs=2000, 
                 number_aug=2,
                 number_seg=8, 
                 gpus=[0], 
                 evaluate_mode = 'subject-dependent',
                 heads=4, 
                 emb_size=40,
                 depth=6, 
                 dataset_type='A',
                 eeg1_f1 = 20,
                 eeg1_kernel_size = 64,
                 eeg1_D = 2,
                 eeg1_pooling_size1 = 8,
                 eeg1_pooling_size2 = 8,
                 eeg1_dropout_rate = 0.3,
                 flatten_eeg1 = 600, 
                 validate_ratio = 0.2,
                 learning_rate = 0.001,
                 batch_size = 72,  
                 ):
        
        super(ExP, self).__init__()
        self.dataset_type = dataset_type
        self.batch_size = batch_size
        self.lr = learning_rate
        # Adam betas
        self.b1 = 0.5
        self.b2 = 0.999
        self.n_epochs = epochs
        self.nSub = nsub
        self.number_augmentation = number_aug
        self.number_seg = number_seg
        self.root = data_dir
        self.heads = heads
        self.emb_size = emb_size
        self.depth = depth
        self.result_name = result_name
        self.evaluate_mode = evaluate_mode
        self.validate_ratio = validate_ratio

        # Tensor types used to move batches onto the GPU.
        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.number_class, self.number_channel = numberClassChannel(self.dataset_type)
        
        self.model = EEGMamba(
            emb_size=self.emb_size,
            depth=self.depth, 
            database_type=self.dataset_type,
            eeg1_f1=eeg1_f1,
            eeg1_kernel_size=eeg1_kernel_size,
            eeg1_D=eeg1_D,
            eeg1_pooling_size1=eeg1_pooling_size1,
            eeg1_pooling_size2=eeg1_pooling_size2,
            eeg1_dropout_rate=eeg1_dropout_rate,
            eeg1_number_channel=self.number_channel,
            flatten_eeg1=flatten_eeg1,
        ).cuda()
        self.model_filename = self.result_name + '/model_{}.pth'.format(self.nSub)

    @staticmethod
    def _split_train_validate(data, label, number_validate):
        """Split one batch into a leading train part and a trailing, disjoint validation part.

        Returns `(train_data, train_label, val_data, val_label)`. With
        `number_validate == 0` the whole batch is training data and the
        validation part is empty.
        """
        split = max(data.shape[0] - number_validate, 0)
        return data[:split], label[:split], data[split:], label[split:]

    def interaug(self, timg, label):  
        """Create augmented trials by recombining time segments within a class.

        For every class (labels are 1-based in `label`), this builds
        `number_augmentation * int(batch_size / number_class)` artificial
        trials. Each artificial trial is assembled from `number_seg`
        consecutive time segments, every segment copied from a randomly chosen
        real trial of the same class.

        Args:
            timg: numpy array indexed as (trial, 1, channel, time).
            label: numpy array of 1-based class labels, one per trial.

        Returns:
            `(aug_data, aug_label)` as CUDA tensors; data is float, labels are
            long and 0-based. The trials are shuffled.
        """
        aug_data = []
        aug_label = []
        number_records_by_augmentation = self.number_augmentation * int(self.batch_size / self.number_class)
        # Take the trial length from the data instead of assuming 1000 samples.
        number_time_points = timg.shape[-1]
        number_segmentation_points = number_time_points // self.number_seg
        for clsAug in range(self.number_class):
            cls_idx = np.where(label == clsAug + 1)
            tmp_data = timg[cls_idx]

            tmp_aug_data = np.zeros((number_records_by_augmentation, 1, self.number_channel, number_time_points))
            for ri in range(number_records_by_augmentation):
                # One source trial per segment, drawn once per artificial trial.
                rand_idx = np.random.randint(0, tmp_data.shape[0], self.number_seg)
                for rj in range(self.number_seg):
                    start = rj * number_segmentation_points
                    stop = start + number_segmentation_points
                    tmp_aug_data[ri, :, :, start:stop] = tmp_data[rand_idx[rj], :, :, start:stop]

            aug_data.append(tmp_aug_data)
            # Every artificial trial of this class gets the class label, even
            # when the class has fewer real trials than artificial ones.
            aug_label.append(np.full(number_records_by_augmentation, clsAug + 1, dtype=label.dtype))
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        aug_shuffle = np.random.permutation(len(aug_data))
        aug_data = aug_data[aug_shuffle, :, :]
        aug_label = aug_label[aug_shuffle]

        aug_data = torch.from_numpy(aug_data).cuda().float()
        aug_label = torch.from_numpy(aug_label-1).cuda().long()
        return aug_data, aug_label

    def get_source_data(self):
        """Load, shuffle and standardize the subject's train and test data.

        Both splits get a singleton axis inserted at position 1. Train trials
        are shuffled. Train and test data are standardized with the mean and
        standard deviation of the training data.

        Returns:
            `(allData, allLabel, testData, testLabel)` as numpy arrays; labels
            are left as delivered by the loader (1-based as used by
            `interaug` and `train`).
        """
        (self.train_data,
         self.train_label, 
         self.test_data, 
         self.test_label) = load_data_evaluate(self.root, self.dataset_type, self.nSub, mode_evaluate=self.evaluate_mode)

        self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)
        self.allData = self.train_data
        self.allLabel = self.train_label[0]

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        print('-'*20, "train size：", self.train_data.shape, "test size：", self.test_data.shape)
        self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label[0]

        # standardize with training statistics only
        target_mean = np.mean(self.allData)
        target_std = np.std(self.allData)
        self.allData = (self.allData - target_mean) / target_std
        self.testData = (self.testData - target_mean) / target_std
        
        return self.allData, self.allLabel, self.testData, self.testLabel

    def train(self):
        """Train the model, keep the best checkpoint and evaluate it on the test set.

        In every epoch each mini-batch is split into a training part and a
        disjoint validation part (`validate_ratio`). The training part is
        extended with augmented trials from `interaug`. After the epoch the
        collected validation parts are scored, and the model is saved whenever
        the validation loss improves. Reported train accuracy and loss are
        those of the last mini-batch of the epoch.

        Returns:
            `(test_acc, test_label, y_pred, df_process, best_epoch)` where
            `test_label` and `y_pred` are CUDA tensors with 0-based classes and
            `df_process` is a DataFrame with one row of metrics per epoch.
        """
        img, label, test_data, test_label = self.get_source_data()
        
        img = torch.from_numpy(img)
        label = torch.from_numpy(label - 1)
        dataset = torch.utils.data.TensorDataset(img, label)
        
        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label - 1)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=False)

        # Optimizer
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_label = test_label.type(self.LongTensor)
        best_epoch = 0
        min_loss = float('inf')
        result_process = []
        # shuffle=True reshuffles every time the loader is iterated.
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)
        for e in range(self.n_epochs):
            epoch_process = {}
            epoch_process['epoch'] = e
            self.model.train()

            val_data_list = []
            val_label_list = []

            for img_batch, label_batch in self.dataloader:
                number_sample = img_batch.shape[0]
                number_validate = int(self.validate_ratio * number_sample)

                # split the batch into disjoint train / validate parts
                (train_data, train_label,
                 val_part, val_part_label) = self._split_train_validate(img_batch, label_batch, number_validate)
                val_data_list.append(val_part)
                val_label_list.append(val_part_label)

                img_batch = train_data.type(self.Tensor)
                label_batch = train_label.type(self.LongTensor)

                # data augmentation
                aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                # concat
                img_batch = torch.cat((img_batch, aug_data))
                label_batch = torch.cat((label_batch, aug_label))

                features, outputs = self.model(img_batch)
                loss = self.criterion_cls(outputs, label_batch)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            # validate
            self.model.eval()
            val_data = torch.cat(val_data_list).cuda()
            val_label = torch.cat(val_label_list).cuda()
            val_data = val_data.type(self.Tensor)
            val_label = val_label.type(self.LongTensor)

            val_dataset = torch.utils.data.TensorDataset(val_data, val_label)
            val_dataloader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=self.batch_size, shuffle=False)
            outputs_list = []
            with torch.no_grad():
                for img_v, _ in val_dataloader:
                    img_v = img_v.type(self.Tensor).cuda()
                    _, Cls = self.model(img_v)
                    outputs_list.append(Cls)
            Cls = torch.cat(outputs_list)

            val_loss = self.criterion_cls(Cls, val_label)
            val_pred = torch.max(Cls, 1)[1]
            val_acc = float((val_pred == val_label).cpu().numpy().sum()) / float(val_label.size(0))

            epoch_process['val_acc'] = val_acc
            epoch_process['val_loss'] = val_loss.detach().cpu().numpy()

            # compute training acc from last batch
            train_pred = torch.max(outputs, 1)[1]
            train_acc = float((train_pred == label_batch).cpu().numpy().sum()) / float(label_batch.size(0))
            epoch_process['train_acc'] = train_acc
            epoch_process['train_loss'] = loss.detach().cpu().numpy()

            val_loss_value = val_loss.item()
            if val_loss_value < min_loss:
                min_loss = val_loss_value
                best_epoch = e
                torch.save(self.model, self.model_filename)
                print("{}_{} train_acc: {:.4f} train_loss: {:.6f}\tval_acc: {:.6f} val_loss: {:.7f}".format(
                    self.nSub, epoch_process['epoch'], epoch_process['train_acc'],
                    epoch_process['train_loss'], epoch_process['val_acc'], epoch_process['val_loss']
                ))

            result_process.append(epoch_process)

        # load best model for final test
        # NOTE: the checkpoint is a full pickled module written by this method;
        # weights_only=False unpickles arbitrary objects, so only load files
        # produced by this run.
        from torch.serialization import safe_globals
        with safe_globals([EEGMamba, PatchEmbeddingCNN, PositioinalEncoding, ClassificationHead, 
                            MambaEncoder, MambaEncoderBlock, nn.Sequential]):
            self.model = torch.load(self.model_filename, weights_only=False)
        self.model = self.model.cuda()
        # Make sure the model that is actually evaluated runs in inference mode.
        self.model.eval()

        outputs_list = []
        with torch.no_grad():
            for img_batch, _ in self.test_dataloader:
                img_test = img_batch.type(self.Tensor).cuda()
                _, outputs = self.model(img_test)
                outputs_list.append(outputs)
        outputs = torch.cat(outputs_list)
        y_pred = torch.max(outputs, 1)[1]

        test_acc = float((y_pred == test_label).cpu().numpy().sum()) / float(test_label.size(0))
        print("epoch: ", best_epoch, '\tThe test accuracy is:', test_acc)

        df_process = pd.DataFrame(result_process)
        return test_acc, test_label, y_pred, df_process, best_epoch


def main(dirs,                
         evaluate_mode = 'subject-dependent',
         heads=8,           
         emb_size=48,       
         depth=3,           
         dataset_type='A',  
         eeg1_f1=20,
         eeg1_kernel_size=64,
         eeg1_D=2,
         eeg1_pooling_size1=8,
         eeg1_pooling_size2=8,
         eeg1_dropout_rate=0.3,
         flatten_eeg1=600,   
         validate_ratio=0.2
         ):
    """Run one ExP experiment per subject and write the results to Excel files.

    Besides its arguments, this function reads the module-level names
    N_SUBJECT, DATA_DIR, EPOCHS, N_AUG, N_SEG and gpus, which must be defined
    before it is called (the `__main__` section of this file does so).

    Files written inside `dirs`:
        result_metric.xlsx  per-subject metrics plus mean and std rows
        process_train.xlsx  per-epoch training log, one sheet per subject
        pred_true.xlsx      predicted vs. true test labels, one sheet per subject

    Args:
        dirs: output directory; created if missing.
        remaining arguments: forwarded unchanged to ExP.

    Returns:
        DataFrame with one row of metrics (in percent) per subject followed by
        a mean row and a std row.
    """
    # Race-free replacement for an exists() check followed by makedirs().
    os.makedirs(dirs, exist_ok=True)

    result_write_metric = ExcelWriter(dirs+"/result_metric.xlsx")
    process_write = ExcelWriter(dirs+"/process_train.xlsx")
    pred_true_write = ExcelWriter(dirs+"/pred_true.xlsx")

    subjects_result = []
    best_epochs = []

    for i in range(N_SUBJECT):
        starttime = datetime.datetime.now()
        # A fresh random seed per subject, printed so a run can be identified.
        seed_n = np.random.randint(2024)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        print('Subject %d' % (i+1))
        exp = ExP(i + 1, DATA_DIR, dirs, EPOCHS, N_AUG, N_SEG, gpus,
                  evaluate_mode = evaluate_mode,
                  heads=heads,
                  emb_size=emb_size,
                  depth=depth,
                  dataset_type=dataset_type,
                  eeg1_f1=eeg1_f1,
                  eeg1_kernel_size=eeg1_kernel_size,
                  eeg1_D=eeg1_D,
                  eeg1_pooling_size1=eeg1_pooling_size1,
                  eeg1_pooling_size2=eeg1_pooling_size2,
                  eeg1_dropout_rate=eeg1_dropout_rate,
                  flatten_eeg1=flatten_eeg1,
                  validate_ratio=validate_ratio
                  )
        testAcc, Y_true, Y_pred, df_process, best_epoch = exp.train()
        true_cpu = Y_true.cpu().numpy().astype(int)
        pred_cpu = Y_pred.cpu().numpy().astype(int)
        df_pred_true = pd.DataFrame({'pred': pred_cpu, 'true': true_cpu})
        df_pred_true.to_excel(pred_true_write, sheet_name=str(i+1))

        accuracy, precison, recall, f1, kappa = calMetrics(true_cpu, pred_cpu)
        # Column names (including the 'accuray' spelling) are part of the
        # written spreadsheet and are read back below, so they are kept as is.
        subject_result = {
            'accuray': accuracy*100,
            'precision': precison*100,
            'recall': recall*100,
            'f1': f1*100,
            'kappa': kappa*100
        }
        subjects_result.append(subject_result)
        df_process.to_excel(process_write, sheet_name=str(i+1))
        best_epochs.append(best_epoch)
    
        print(' THE BEST ACCURACY IS ' + str(testAcc) + "\tkappa is " + str(kappa) )
        endtime = datetime.datetime.now()
        print('subject %d duration: '%(i+1) + str(endtime - starttime))

    df_result = pd.DataFrame(subjects_result)
    process_write.close()
    pred_true_write.close()

    print('**The average Best accuracy is: ' + str(df_result['accuray'].mean()) + 
          "kappa is: " + str(df_result['kappa'].mean()) + "\n" )
    print("best epochs: ", best_epochs)

    # Append mean and std rows computed over the per-subject rows.
    mean = df_result.mean(axis=0)
    mean.name = 'mean'
    std = df_result.std(axis=0)
    std.name = 'std'
    df_result = pd.concat([df_result, pd.DataFrame(mean).T, pd.DataFrame(std).T])
    df_result.to_excel(result_write_metric, index=False)

    print('-'*9, ' all result ', '-'*9)
    print(df_result)
    print("*"*40)
    result_write_metric.close()

    return df_result


if __name__ == "__main__":
    # ---- experiment configuration (main() reads several of these as globals) ----
    DATA_DIR = r'bci2a/'
    EVALUATE_MODE = 'LOSO-No'
    N_SUBJECT = 9
    N_AUG = 3
    N_SEG = 8
    EPOCHS = 1000
    EMB_DIM = 16
    HEADS = 2
    DEPTH = 6
    TYPE = 'B'
    validate_ratio = 0.3

    # ---- CNN front-end hyper-parameters ----
    EEGNet1_F1 = 8
    EEGNet1_KERNEL_SIZE=64
    EEGNet1_D=2
    EEGNet1_POOL_SIZE1 = 8
    EEGNet1_POOL_SIZE2 = 8
    FLATTEN_EEGNet1 = 240
    # Lower CNN dropout only for the 'LOSO' evaluation mode.
    EEGNet1_DROPOUT_RATE = 0.25 if EVALUATE_MODE == 'LOSO' else 0.5

    # Hyper-parameters that the model preview and main() share verbatim.
    shared_hparams = dict(
        emb_size=EMB_DIM,
        depth=DEPTH,
        eeg1_f1=EEGNet1_F1,
        eeg1_kernel_size=EEGNet1_KERNEL_SIZE,
        eeg1_D=EEGNet1_D,
        eeg1_pooling_size1=EEGNet1_POOL_SIZE1,
        eeg1_pooling_size2=EEGNet1_POOL_SIZE2,
        eeg1_dropout_rate=EEGNet1_DROPOUT_RATE,
        flatten_eeg1=FLATTEN_EEGNet1,
    )

    parameters_list = ['A']
    for TYPE in parameters_list:
        number_class, number_channel = numberClassChannel(TYPE)
        RESULT_NAME = "CTNetMamba_{}_heads_{}_depth_{}_{}".format(TYPE, HEADS, DEPTH, int(time.time()))

        # Build a throw-away model only to print its layer summary.
        sModel = EEGMamba(
            database_type=TYPE,
            eeg1_number_channel=number_channel,
            **shared_hparams
        ).cuda()
        summary(sModel, (1, number_channel, 1000))

        print(time.asctime(time.localtime(time.time())))
        result = main(
            RESULT_NAME,
            evaluate_mode=EVALUATE_MODE,
            heads=HEADS,
            dataset_type=TYPE,
            validate_ratio=validate_ratio,
            **shared_hparams
        )
        print(time.asctime(time.localtime(time.time())))


----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1          [-1, 8, 22, 1000]             512
       BatchNorm2d-2          [-1, 8, 22, 1000]              16
            Conv2d-3          [-1, 16, 1, 1000]             352
       BatchNorm2d-4          [-1, 16, 1, 1000]              32
               ELU-5          [-1, 16, 1, 1000]               0
         AvgPool2d-6           [-1, 16, 1, 125]               0
           Dropout-7           [-1, 16, 1, 125]               0
            Conv2d-8           [-1, 16, 1, 125]           4,096
       BatchNorm2d-9           [-1, 16, 1, 125]              32
              ELU-10           [-1, 16, 1, 125]               0
        AvgPool2d-11            [-1, 16, 1, 15]               0
          Dropout-12            [-1, 16, 1, 15]               0
        Rearrange-13               [-1, 15, 16]               0
PatchEmbeddingCNN-14               [-1,

KeyboardInterrupt: 