In [1]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CTNet-Mamba: A Convolution-Mamba Network for EEG-Based Motor Imagery Classification

author: zhaowei701@163.com
"""
import math

import os
import time
import random
from barebones_hymba.barebones_hymba_block import HymbaBlock
import flash_attn

import numpy as np
from zeta.nn import MultiQueryAttention
from vision_mamba import Vim

import torch
import torch.nn as nn
from torch.backends import cudnn
from torchsummary import summary

from einops import rearrange
from einops.layers.torch import Rearrange

from utils import load_data_evaluate, numberClassChannel, calMetrics

# Use the real ``Mamba`` layer from mamba_ssm when it is installed; otherwise
# define a shape-preserving stand-in so the rest of the module still imports.
try:
    from mamba_ssm import Mamba
except ImportError:
    class Mamba(nn.Module):
        """Fallback for ``mamba_ssm.Mamba``: one ``d_model -> d_model`` linear map.

        ``d_state``, ``d_conv`` and ``expand`` are accepted only so call sites
        written for the real layer keep working; they have no effect here.
        """

        def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
            super().__init__()
            self.linear = nn.Linear(d_model, d_model)

        def forward(self, x):
            projected = self.linear(x)
            return projected

from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
from torchinfo import summary



# -----------------------------------------------------------------------------
# PatchEmbeddingCNN
# -----------------------------------------------------------------------------
class PatchEmbeddingCNN(nn.Module):
    """Convolutional stem that turns a raw trial into a sequence of feature tokens.

    Stages: temporal conv (``f1`` filters, length ``kernel_size``) -> grouped
    spatial conv spanning all ``number_channel`` rows (``D`` filters per
    temporal filter) -> ELU -> avg-pool -> dropout -> (1, 16) conv -> ELU ->
    avg-pool -> dropout, then the feature map is flattened into tokens.

    Input ``[B, 1, number_channel, T]``; output ``[B, seq, D * f1]``.
    ``emb_size`` is accepted but not used.
    """

    def __init__(self, f1=16, kernel_size=64, D=2,
                 pooling_size1=8, pooling_size2=8,
                 dropout_rate=0.3, number_channel=22, emb_size=40):
        super().__init__()
        f2 = D * f1
        temporal_stage = [
            nn.Conv2d(1, f1, (1, kernel_size), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        ]
        spatial_stage = [
            nn.Conv2d(f1, f2, (number_channel, 1), groups=f1, bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pooling_size1)),
            nn.Dropout(dropout_rate),
        ]
        refine_stage = [
            nn.Conv2d(f2, f2, (1, 16), padding='same', bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pooling_size2)),
            nn.Dropout(dropout_rate),
        ]
        self.cnn_module = nn.Sequential(*temporal_stage, *spatial_stage, *refine_stage)
        # [B, f2, H, W] -> [B, H*W, f2]: every spatial position becomes one token.
        self.projection = Rearrange('b e h w -> b (h w) e')

    def forward(self, x):
        feature_map = self.cnn_module(x)
        tokens = self.projection(feature_map)
        return tokens

# -----------------------------------------------------------------------------
# LinearAttention
# -----------------------------------------------------------------------------
def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    return False if val is None else True

class LinearAttention(nn.Module):
    """Multi-head linear attention, linear in sequence length.

    Queries are softmax-normalised over the feature axis and keys over the
    sequence axis. A per-head context ``K^T V`` (``dim_head x dim_head``) is
    formed once, and each query position reads from it: ``out = Q (K^T V)``.

    Args:
        dim: feature size of the input and output (last axis of ``x``).
        heads: number of attention heads.
        dim_head: feature size of each head.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, *, heads=4, dim_head=64, dropout=0.0):
        super().__init__()
        inner = heads * dim_head
        self.heads, self.scale = heads, dim_head**-0.5
        self.to_qkv = nn.Linear(dim, inner*3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner, dim), nn.Dropout(dropout))

    def _split_heads(self, t):
        # [b, n, h*d] -> [b*h, n, d]; same layout as einops 'b n (h d) -> (b h) n d'.
        b, n, _ = t.shape
        return (t.reshape(b, n, self.heads, -1)
                 .permute(0, 2, 1, 3)
                 .reshape(b * self.heads, n, -1))

    def _merge_heads(self, t, batch):
        # [b*h, n, d] -> [b, n, h*d]; same layout as einops '(b h) n d -> b n (h d)'.
        _, n, d = t.shape
        return (t.reshape(batch, self.heads, n, d)
                 .permute(0, 2, 1, 3)
                 .reshape(batch, n, self.heads * d))

    def forward(self, x, mask=None):
        """Attend over the sequence axis of ``x``.

        Args:
            x: tensor of shape [batch, seq, dim].
            mask: optional boolean tensor [batch, seq]; True marks positions
                that may be attended to (used as keys/values).

        Returns:
            Tensor of shape [batch, seq, dim].
        """
        batch = x.shape[0]
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))

        if mask is not None:
            # Repeat per head with the same (b h) ordering used by _split_heads.
            key_mask = mask.bool().repeat_interleave(self.heads, dim=0).unsqueeze(-1)
            # Mask *before* the sequence softmax so padded keys get exactly zero
            # weight and the remaining key weights still sum to one.
            k = k.masked_fill(~key_mask, torch.finfo(k.dtype).min)
            v = v.masked_fill(~key_mask, 0.0)

        q = (q * self.scale).softmax(dim=-1)
        k = k.softmax(dim=-2)

        # Global context per head: sum_n k_n^T v_n -> [b*h, d, d].
        context = torch.einsum('bnd,bne->bde', k, v)
        # Each query reads the context: [b*h, n, d].
        out = torch.einsum('bnd,bde->bne', q, context)
        return self.to_out(self._merge_heads(out, batch))

# -----------------------------------------------------------------------------
# MambaTransformerblock: stacked Mamba + self-attention + feed-forward layers
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn

class MambaTransformerblock(nn.Module):
    """Stack of ``depth`` pre-norm residual layers: Mamba -> self-attention -> FFN.

    Every sub-layer computes ``x = x + Dropout(sublayer(LayerNorm(x)))`` and a
    final LayerNorm is applied to the output of the stack.

    Args:
        dim: model width (last axis of the input); must be divisible by
            ``heads`` (required by ``nn.MultiheadAttention``).
        heads: number of attention heads.
        dim_head: accepted for API compatibility but unused --
            ``nn.MultiheadAttention`` uses ``dim // heads`` features per head.
        dropout: dropout inside attention / FFN and on every residual branch.
        ff_mult: hidden-width multiplier of the feed-forward sub-layer.
        d_state: state size passed to ``MambaBlock``.
        depth: number of layers.
    """

    def __init__(self, dim, heads, dim_head, dropout=0.1,
                 ff_mult=4, d_state=16, depth=4):
        super().__init__()
        self.layers = nn.ModuleList([
            self._create_layer(dim, heads, ff_mult, d_state, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)

    def _create_layer(self, dim, heads, ff_mult, d_state, dropout):
        """Build the modules of one layer, keyed by sub-layer role."""
        return nn.ModuleDict({
            # 1) Mamba
            'mamba_norm':   nn.LayerNorm(dim),
            'mamba':        MambaBlock(dim, d_state=d_state),
            'mamba_dropout':nn.Dropout(dropout),

            # 2) PyTorch MultiheadAttention
            'attn_norm':    nn.LayerNorm(dim),
            'attn':         nn.MultiheadAttention(embed_dim=dim,
                                                  num_heads=heads,
                                                  dropout=dropout,
                                                  batch_first=True),
            'attn_dropout': nn.Dropout(dropout),

            # 3) FFN
            'ffn_norm':     nn.LayerNorm(dim),
            'feedforward':  nn.Sequential(
                                 nn.Linear(dim, ff_mult * dim),
                                 nn.GELU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(ff_mult * dim, dim)
                             ),
            'ffn_dropout':  nn.Dropout(dropout),
        })

    def forward(self, x):
        """Apply all layers to ``x`` of shape [B, L, dim].

        The residual additions require each sub-layer (including ``MambaBlock``)
        to preserve the [B, L, dim] shape.
        """
        for layer in self.layers:
            # (1) Mamba sublayer
            m = layer['mamba_norm'](x)
            x = x + layer['mamba_dropout'](layer['mamba'](m))

            # (2) Multi-head self-attention sublayer
            a = layer['attn_norm'](x)
            attn_out = layer['attn'](a, a, a, need_weights=False)
            # Stock nn.MultiheadAttention returns (output, weights), with weights
            # None when need_weights=False. A bare tensor is also accepted in case
            # the module has been patched to return one.
            if isinstance(attn_out, tuple):
                attn_out = attn_out[0]
            x = x + layer['attn_dropout'](attn_out)

            # (3) Feed-forward sublayer
            f = layer['ffn_norm'](x)
            x = x + layer['ffn_dropout'](layer['feedforward'](f))

        return self.norm(x)



class SummaryWrapper(nn.Module):
    """Adapter around a model whose forward returns a ``(features, logits)`` pair.

    It inserts the singleton axis the wrapped model expects
    (``[B, C, T] -> [B, 1, C, T]``) and hands back only the second element of
    the returned pair, so callers get a single tensor.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        with_axis = x.unsqueeze(dim=1)
        _features, logits = self.model(with_axis)
        return logits


# -----------------------------------------------------------------------------
# EEGMambaTransformer
# -----------------------------------------------------------------------------
class EEGMambaTransformer(nn.Module):
    """Classify EEG with a Vision-Mamba (``Vim``) model by treating it as an image.

    ``forward`` takes ``[B, 1, channels, time]``, bilinearly resizes the
    channels x time map to ``image_size``, copies it into 3 identical image
    channels and returns ``(None, logits)`` from ``Vim``.

    Args:
        num_classes: number of output classes for ``Vim``.
        image_size: (H, W) of the pseudo-image (an int means square). The input
            is resized to exactly this size, so it always matches what ``Vim``
            was built for.
    """

    def __init__(self, num_classes=4, image_size=(32, 32)):
        super().__init__()
        self.num_classes = num_classes
        # NOTE(review): dataset 'A' is hard-coded here, independent of num_classes.
        self.number_class, self.number_channel = numberClassChannel('A')
        # Previously forward() always resized to (32, 32), which broke any other
        # image_size. Remember the requested size instead.
        if isinstance(image_size, int):
            self.image_size = (image_size, image_size)
        else:
            self.image_size = tuple(image_size)

        # NOTE(review): dim/d_state of 256 make this model memory hungry; the
        # notebook output shows a CUDA OOM with this configuration.
        self.model = Vim(
            dim=256,
            dt_rank=32,
            dim_inner=256,
            d_state=256,
            num_classes=num_classes,
            image_size=image_size,  # 2D image size
            patch_size=8,
            channels=3,
            dropout=0.1,
            depth=6,
        )

    def forward(self, x):
        """Map ``x`` of shape [B, 1, channels, time] to ``(None, logits)``."""
        x = x.squeeze(1)  # [B, channels, time]

        # Treat the channels x time map as a 1-channel image and resize it to
        # image_size (no zero padding involved).
        # NOTE(review): plain bilinear downsampling (e.g. 1000 -> 32 time steps)
        # without antialiasing is likely to alias heavily.
        x = torch.nn.functional.interpolate(
            x.unsqueeze(1), size=self.image_size, mode='bilinear', align_corners=False
        )  # [B, 1, H, W]
        x = x.repeat(1, 3, 1, 1)  # [B, 3, H, W]

        return None, self.model(x)

# -----------------------------------------------------------------------------
# Experiment wrapper
# -----------------------------------------------------------------------------
class ExP:
    """Train, early-stop and test ``EEGMambaTransformer`` for one subject.

    Data comes from ``load_data_evaluate``. ``validate_ratio`` of the training
    trials is held out for checkpoint selection (lowest validation loss).
    Training batches are enlarged with segment-recombination surrogates
    (``interaug``), and the selected checkpoint is scored on the test set with
    ``calMetrics``.

    NOTE: heads, emb_size, depth, d_state, transformer_depth, mamba_depth,
    flatten_eeg1 and the ``eeg1_*`` arguments are accepted so a shared CONFIG
    dict can be splatted in, but this class does not use them.
    """

    def __init__(self, nsub, data_dir, result_name,
                 epochs=300, number_aug=3, number_seg=8,
                 evaluate_mode='subject-dependent',
                 heads=2, emb_size=16, depth=6,
                 d_state=16, transformer_depth=1, mamba_depth=3,
                 dataset_type='A', eeg1_f1=8, eeg1_kernel_size=64,
                 eeg1_D=2, eeg1_pooling_size1=8, eeg1_pooling_size2=8,
                 eeg1_dropout_rate=0.5, flatten_eeg1=15*16,
                 validate_ratio=0.3, learning_rate=1e-3, batch_size=72,
                 early_stopping_patience=100 ):
        self.nSub = nsub
        self.dataset_type = dataset_type
        self.evaluate_mode = evaluate_mode
        self.data_dir = data_dir
        self.result_name = result_name
        self.n_epochs = epochs
        # Segment recombination: surrogate multiplier and segments per trial.
        self.number_aug = number_aug
        self.number_seg = number_seg
        self.patience = early_stopping_patience
        self.no_improve = 0
        self.best_loss = float('inf')
        self.batch_size = batch_size
        self.validate_ratio = validate_ratio
        self.number_class, self.number_channel = numberClassChannel(self.dataset_type)

        self.criterion = nn.CrossEntropyLoss().cuda()

        self.model = EEGMambaTransformer(
            num_classes=self.number_class,
            image_size=(32, 32)
        ).cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        os.makedirs(self.result_name, exist_ok=True)
        self.model_filename = os.path.join(self.result_name, f"model_{self.nSub}.pth")

    def _recombine_segments(self, timg, label):
        """Build shuffled segment-recombination surrogates (numpy only).

        For every class, ``number_aug * (batch_size // number_class)`` surrogate
        trials are created. The time axis is cut into ``number_seg`` consecutive
        segments, and each segment is copied from a randomly chosen real trial
        of that class.

        Args:
            timg: array [trials, 1, channels, time].
            label: 1-based class labels, shape [trials].

        Returns:
            ``(data, labels)``: float32 [N, 1, channels, time] and 0-based int64
            [N], in random order.
        """
        n_classes = self.number_class
        recs = self.number_aug * (self.batch_size // n_classes)
        _, _, n_chan, n_time = timg.shape
        # Integer segment boundaries: 125-sample slices for 1000 samples / 8
        # segments; any remainder is spread over the segments when uneven.
        bounds = [(j * n_time) // self.number_seg for j in range(self.number_seg + 1)]

        aug_data, aug_label = [], []
        for cls in range(n_classes):
            data = timg[label == cls + 1]
            if len(data) == 0:
                continue  # no real trials of this class to recombine from
            tmp = np.zeros((recs, 1, n_chan, n_time), dtype=np.float32)
            for i in range(recs):
                for start, stop in zip(bounds[:-1], bounds[1:]):
                    ridx = random.randrange(data.shape[0])
                    tmp[i, 0, :, start:stop] = data[ridx, 0, :, start:stop]
            aug_data.append(tmp)
            # Exactly one label per surrogate; slicing the real labels would come
            # up short when a class has fewer than ``recs`` trials.
            aug_label.append(np.full(recs, cls, dtype=np.int64))

        if not aug_data:
            return (np.zeros((0, 1, n_chan, n_time), dtype=np.float32),
                    np.zeros((0,), dtype=np.int64))
        aug_data = np.concatenate(aug_data).astype(np.float32)
        aug_label = np.concatenate(aug_label)
        perm = np.random.permutation(len(aug_data))
        return aug_data[perm], aug_label[perm]

    def interaug(self, timg, label):
        """Return segment-recombination surrogates as CUDA tensors.

        Args:
            timg: array [trials, 1, channels, time] to draw segments from.
            label: 1-based class labels for ``timg``.

        Returns:
            ``(data, labels)``: float32 data and 0-based int64 labels on the GPU.
        """
        aug_data, aug_label = self._recombine_segments(timg, label)
        return (
            torch.from_numpy(aug_data).cuda(),
            torch.from_numpy(aug_label).long().cuda()
        )

    def get_source_data(self):
        """Load this subject's data, add a singleton axis and z-score it.

        Both splits are normalised with the training set's global mean and std.

        Returns:
            ``(tr_x, tr_y, te_x, te_y)``; x arrays are float32
            [trials, 1, channels, time], and y arrays are flattened labels.
        """
        tr_x, tr_y, te_x, te_y = load_data_evaluate(
            self.data_dir, self.dataset_type, self.nSub,
            mode_evaluate=self.evaluate_mode
        )
        tr_x = np.expand_dims(tr_x, axis=1).astype(np.float32)
        te_x = np.expand_dims(te_x, axis=1).astype(np.float32)
        tr_y = tr_y.reshape(-1)
        te_y = te_y.reshape(-1)
        m, s = tr_x.mean(), tr_x.std()
        tr_x = (tr_x - m) / s
        te_x = (te_x - m) / s
        return tr_x, tr_y, te_x, te_y

    def _train_epoch(self, train_loader, aug_src_x, aug_src_y):
        """Run one optimisation pass; return (loss, accuracy, peak GPU MB, samples/s)."""
        self.model.train()
        epoch_train_loss = 0.0
        train_correct = 0
        total_samples = 0

        # Time the epoch and track its peak GPU memory.
        torch.cuda.synchronize()
        start_time = time.time()
        torch.cuda.reset_peak_memory_stats()

        for xb, yb in train_loader:
            xb, yb = xb.cuda().float(), yb.cuda().long()
            # Each real mini-batch is concatenated with a fresh set of surrogates.
            aug_x, aug_y = self.interaug(aug_src_x, aug_src_y)
            xb = torch.cat([xb, aug_x])
            yb = torch.cat([yb, aug_y])

            _, out = self.model(xb)
            loss = self.criterion(out, yb)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_train_loss += loss.item()
            train_correct += (out.argmax(dim=1) == yb).sum().item()
            total_samples += yb.size(0)

        torch.cuda.synchronize()
        elapsed = time.time() - start_time
        mem_used = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB
        speed = total_samples / elapsed if elapsed > 0 else 0
        train_loss = epoch_train_loss / len(train_loader)
        train_acc = train_correct / total_samples
        return train_loss, train_acc, mem_used, speed

    def _validate(self, val_loader):
        """Return (mean per-batch loss, accuracy) on the held-out split."""
        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        n_val = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.cuda().float(), yb.cuda().long()
                _, out = self.model(xb)
                val_loss += self.criterion(out, yb).item()
                val_correct += (out.argmax(1) == yb).sum().item()
                n_val += yb.size(0)
        return val_loss / len(val_loader), val_correct / n_val

    def _test(self):
        """Load the saved checkpoint and return its test-set accuracy."""
        self.model.load_state_dict(torch.load(self.model_filename))
        self.model.eval()
        preds, trues = [], []
        with torch.no_grad():
            for xb, yb in self.test_loader:
                _, out = self.model(xb.cuda())
                preds.append(out.argmax(1).cpu().numpy())
                trues.append(yb.numpy())
        preds = np.concatenate(preds)
        trues = np.concatenate(trues)
        acc, *_ = calMetrics(trues, preds)
        print(f"Subject {self.nSub} final accuracy: {acc:.4f}")
        return acc

    def train(self):
        """Train with early stopping on validation loss, then test the best checkpoint.

        Returns:
            Test accuracy (the first value returned by ``calMetrics``).

        Raises:
            ValueError: if ``validate_ratio`` leaves no training or no
                validation trials.
        """
        tr_x, tr_y, te_x, te_y = self.get_source_data()

        # Hold out part of the training trials for checkpoint selection.
        dataset_size = len(tr_x)
        val_size = int(self.validate_ratio * dataset_size)
        train_size = dataset_size - val_size
        if val_size == 0 or train_size == 0:
            raise ValueError(
                f"validate_ratio={self.validate_ratio} gives {train_size} training and "
                f"{val_size} validation trials out of {dataset_size}; both must be > 0"
            )
        train_ds, val_ds = torch.utils.data.random_split(
            torch.utils.data.TensorDataset(torch.from_numpy(tr_x), torch.from_numpy(tr_y - 1)),
            [train_size, val_size]
        )
        # Surrogates are built from the training subset only; drawing segments
        # from held-out trials would leak validation data into training.
        train_idx = np.asarray(train_ds.indices)
        aug_src_x, aug_src_y = tr_x[train_idx], tr_y[train_idx]

        test_ds = torch.utils.data.TensorDataset(torch.from_numpy(te_x), torch.from_numpy(te_y - 1))
        self.test_loader = torch.utils.data.DataLoader(test_ds, batch_size=self.batch_size)
        # shuffle=True draws a new order every time the loader is iterated, so
        # one loader serves all epochs.
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=self.batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_ds, batch_size=self.batch_size)

        best_loss = float('inf')
        saved = False
        self.no_improve = 0
        for e in range(self.n_epochs):
            train_loss, train_acc, mem_used, speed = self._train_epoch(
                train_loader, aug_src_x, aug_src_y
            )
            val_loss, val_acc = self._validate(val_loader)

            print(f"Epoch {e+1}/{self.n_epochs} | "
                  f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | "
                  f"Mem: {mem_used:.2f}MB | Speed: {speed:.2f} samples/s")

            if val_loss < best_loss:
                best_loss = val_loss
                self.best_loss = best_loss
                # Patience counts *consecutive* epochs without improvement.
                self.no_improve = 0
                torch.save(self.model.state_dict(), self.model_filename)
                saved = True
            else:
                self.no_improve += 1
                if self.no_improve >= self.patience:
                    print(f"Stopping early at epoch {e+1} (no improvement in {self.patience} epochs).")
                    break

        if not saved:
            # e.g. every validation loss was NaN: keep the final weights so the
            # test phase still has a checkpoint to load.
            torch.save(self.model.state_dict(), self.model_filename)

        return self._test()

# -----------------------------------------------------------------------------
def main(result_dir, DATA_DIR, N_SUBJECT, **cfg):
    """Summarise the model once, then train and test one ExP per subject.

    Args:
        result_dir: directory for per-subject checkpoints (created if missing).
        DATA_DIR: dataset root passed to ExP.
        N_SUBJECT: subjects 1..N_SUBJECT are run in turn.
        **cfg: keyword arguments forwarded to ExP; ``dataset_type`` (default
            'A') also selects the class/channel counts for the summary.

    Returns:
        List of per-subject test accuracies.
    """
    os.makedirs(result_dir, exist_ok=True)

    number_class, number_channel = numberClassChannel(cfg.get('dataset_type', 'A'))
    sModel = EEGMambaTransformer(
        num_classes=number_class,
        image_size=(32, 32)
    ).cuda()
    # ``summary`` is torchinfo's (imported after torchsummary), whose input size
    # includes the batch axis: [batch, 1, channels, time].
    summary(sModel, (1, 1, number_channel, 1000))
    # The summary model is never trained; free its GPU memory before the
    # per-subject models are allocated.
    del sModel
    torch.cuda.empty_cache()

    print(time.asctime(time.localtime(time.time())))

    accs = []
    for sub in range(1, N_SUBJECT + 1):
        # Seed drawn from numpy's global RNG and printed so a subject's run
        # can be repeated.
        seed_n = np.random.randint(2024)
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        print(f"seed is {seed_n}")
        print(f"Subject {sub}")

        exp = ExP(sub, DATA_DIR, result_dir, **cfg)
        accs.append(exp.train())
        # Release this subject's model/optimiser before the next ExP builds its own.
        del exp
        torch.cuda.empty_cache()

    print("Average accuracy:", np.mean(accs))
    return accs

if __name__ == "__main__":
    # Prefer deterministic cuDNN kernels over autotuned ones.
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Keyword arguments forwarded to ExP for every subject.
    CONFIG = {
        "emb_size": 128, "depth": 1, "heads": 4,
        "d_state": 32, "transformer_depth": 1, "mamba_depth": 2,
        "dataset_type": 'A',
        "eeg1_f1": 8, "eeg1_kernel_size": 64,
        "eeg1_D": 16, "eeg1_pooling_size1": 8, "eeg1_pooling_size2": 8,
        "eeg1_dropout_rate": 0.5, "flatten_eeg1": 15 * 128,
        "epochs": 1000, "number_aug": 3, "number_seg": 8,
        "validate_ratio": 0.3, "learning_rate": 1e-3, "batch_size": 72,
    }

    # Run location and scope; results go to a fresh, timestamped directory.
    DATA_DIR = "bci2a/"
    RESULT_DIR = f"CTNet_Mamba_{int(time.time())}"
    N_SUBJECT = 9

    main(RESULT_DIR, DATA_DIR, N_SUBJECT, **CONFIG)
    print("Done.")


  from .autonotebook import tqdm as notebook_tqdm


Patch embedding: torch.Size([1, 16, 256])
Cls tokens: torch.Size([1, 1, 256])
torch.Size([1, 16, 256])
Conv1d: tensor([[[0.5589, 0.5844, 0.5376,  ..., 0.6169, 0.6532, 0.6044],
         [0.6766, 0.6771, 0.6502,  ..., 0.6393, 0.7033, 0.6912],
         [0.5860, 0.6117, 0.6059,  ..., 0.7396, 0.5281, 0.5852],
         ...,
         [0.8196, 0.9678, 0.9183,  ..., 0.7591, 1.0497, 0.8515],
         [0.6529, 0.7233, 0.7075,  ..., 0.6982, 0.7302, 0.7037],
         [0.6631, 0.6507, 0.7012,  ..., 0.7229, 0.6045, 0.6367]]],
       device='cuda:0')
Conv1d: tensor([[[0.7498, 0.8072, 0.7791,  ..., 0.6276, 0.8242, 0.7291],
         [0.8065, 0.7854, 0.8366,  ..., 0.8757, 0.9061, 0.7766],
         [0.7791, 0.7467, 0.7335,  ..., 0.7866, 0.6657, 0.7630],
         ...,
         [0.6508, 0.6218, 0.7448,  ..., 0.7284, 0.6653, 0.7428],
         [0.8520, 0.7383, 0.8200,  ..., 0.8150, 0.7878, 0.9361],
         [0.7867, 0.6099, 0.7925,  ..., 0.6493, 0.6867, 0.7017]]],
       device='cuda:0')
Layer: torch.Size([1,

OutOfMemoryError: CUDA out of memory. Tried to allocate 1.12 GiB. GPU 0 has a total capacty of 14.57 GiB of which 566.00 MiB is free. Process 180055 has 9.21 GiB memory in use. Including non-PyTorch memory, this process has 4.80 GiB memory in use. Of the allocated memory 4.63 GiB is allocated by PyTorch, and 41.88 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Show current GPU memory usage and utilisation.
!nvidia-smi

In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CTNet-Mamba: A Convolution-Mamba Network for EEG-Based Motor Imagery Classification

author: zhaowei701@163.com
"""
import math
import sys
import os
import time
import random
from barebones_hymba.barebones_hymba_block import HymbaBlock
import flash_attn

import numpy as np
from zeta.nn import MultiQueryAttention
from vision_mamba import Vim

import torch
import torch.nn as nn
from torch.backends import cudnn
from torchsummary import summary

from einops import rearrange
from einops.layers.torch import Rearrange
import torch.nn.functional as F

from utils import load_data_evaluate, numberClassChannel, calMetrics

# Use the real ``Mamba`` layer from mamba_ssm when it is installed; otherwise
# define a shape-preserving stand-in so the rest of the module still imports.
try:
    from mamba_ssm import Mamba
except ImportError:
    class Mamba(nn.Module):
        """Fallback for ``mamba_ssm.Mamba``: one ``d_model -> d_model`` linear map.

        ``d_state``, ``d_conv`` and ``expand`` are accepted only so call sites
        written for the real layer keep working; they have no effect here.
        """

        def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
            super().__init__()
            self.linear = nn.Linear(d_model, d_model)

        def forward(self, x):
            projected = self.linear(x)
            return projected

from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
from torchinfo import summary




# Everything printed from here on goes to a log file under logs/.
os.makedirs("logs", exist_ok=True)

# Timestamped so successive runs do not overwrite each other's log.
log_filename = os.path.join("logs", f"vimambaprinted_{time.strftime('%Y%m%d_%H%M%S')}.log")

# Remember the kernel's real streams once, so re-running this cell does not
# treat an already-redirected log file as the "original". Restore with:
#     sys.stdout, sys.stderr = _ORIG_STDOUT, _ORIG_STDERR
_ORIG_STDOUT = globals().get("_ORIG_STDOUT", sys.stdout)
_ORIG_STDERR = globals().get("_ORIG_STDERR", sys.stderr)
if sys.stdout is not _ORIG_STDOUT:
    sys.stdout.close()  # close the log opened by a previous run of this cell

# Line-buffered so the log stays current even if the run dies (e.g. CUDA OOM).
sys.stdout = open(log_filename, "w", buffering=1)
sys.stderr = sys.stdout  # Optional: log errors too

# -----------------------------------------------------------------------------
# ConvFeatureExtractor and sliding-window pseudo-image builder
# -----------------------------------------------------------------------------
class ConvFeatureExtractor(nn.Module):
    """Convolutional front-end mapping raw EEG to a temporal feature map.

    Stages: temporal conv (``f1`` filters, length ``kernel_size``) -> grouped
    spatial conv spanning all ``in_channels`` rows (``D`` filters per temporal
    filter) -> ELU -> avg-pool(``pool1``) -> dropout -> (1, 16) conv -> ELU ->
    avg-pool(``pool2``) -> dropout.

    Input ``[B, 1, in_channels, T]``; output ``[B, D * f1, 1, T']``.
    """

    def __init__(self, in_channels=22, f1=16, D=2, kernel_size=64, pool1=8, pool2=8, dropout=0.3):
        super().__init__()
        f2 = D * f1
        layers = []
        # temporal filtering
        layers += [
            nn.Conv2d(1, f1, (1, kernel_size), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        ]
        # spatial filtering across all input rows, then downsample in time
        layers += [
            nn.Conv2d(f1, f2, (in_channels, 1), groups=f1, bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pool1)),
            nn.Dropout(dropout),
        ]
        # second temporal conv, then downsample again
        layers += [
            nn.Conv2d(f2, f2, (1, 16), padding='same', bias=False),
            nn.BatchNorm2d(f2),
            nn.ELU(),
            nn.AvgPool2d((1, pool2)),
            nn.Dropout(dropout),
        ]
        self.features = nn.Sequential(*layers)

    def forward(self, x):
        feature_map = self.features(x)
        return feature_map  # [B, f2, 1, time]

def sliding_projected_windows(x, cnn_model, num_windows=5, window_size=32, out_size=(32, 32)):
    """Cut CNN features into evenly spaced windows, each resized to a 3-channel image.

    Args:
        x: EEG batch of shape [B, C, T].
        cnn_model: callable mapping [B, 1, C, T] -> [B, F, 1, T'].
        num_windows: number of windows to produce (>= 1).
        window_size: window length along T'. It is clipped to T' when the
            feature sequence is shorter, in which case every window covers the
            whole sequence.
        out_size: (H, W) each window is bilinearly resized to.

    Returns:
        List of ``num_windows`` tensors of shape [B, 3, H, W].

    Raises:
        ValueError: if ``x`` is not 3-D or ``num_windows < 1``.
    """
    if x.dim() != 3:
        raise ValueError(f"expected x of shape [B, C, T], got {tuple(x.shape)}")
    if num_windows < 1:
        raise ValueError(f"num_windows must be >= 1, got {num_windows}")

    feats = cnn_model(x.unsqueeze(1))  # [B, F, 1, T']
    feats = feats.squeeze(2)  # [B, F, T']
    T_proj = feats.size(-1)

    # With ConvFeatureExtractor's defaults (two /8 pools), 1000 samples become
    # only 15 steps. A fixed 32-step window would then need negative start
    # offsets, so clip the window to the available length instead.
    win = min(window_size, T_proj)
    step = 0 if num_windows == 1 else (T_proj - win) // (num_windows - 1)

    windows = []
    for i in range(num_windows):
        start = i * step
        w = feats[:, :, start:start + win].unsqueeze(1)  # [B, 1, F, win]
        w = F.interpolate(w, size=out_size, mode='bilinear', align_corners=False)  # [B, 1, H, W]
        windows.append(w.repeat(1, 3, 1, 1))  # [B, 3, H, W]
    return windows

class EEGMambaTransformer(nn.Module):
    """CNN features -> sliding-window pseudo-images -> shared Vim -> averaged logits.

    ``forward`` takes ``[B, 1, C, T]`` EEG and returns ``(None, logits)``, where
    ``logits`` is the mean of the ``Vim`` outputs over ``num_windows`` windows.

    Args:
        num_classes: number of output classes.
        image_size: image size ``Vim`` is built for.
        num_windows: number of feature windows classified per trial.
    """

    def __init__(self, num_classes=4, image_size=(32, 32), num_windows=5):
        super().__init__()
        self.num_windows = num_windows
        # NOTE(review): dataset 'A' is hard-coded here.
        self.number_class, self.number_channel = numberClassChannel('A')
        # Size the grouped spatial convolution from the dataset's channel count
        # instead of relying on ConvFeatureExtractor's default of 22.
        self.cnn = ConvFeatureExtractor(in_channels=self.number_channel)

        # NOTE(review): sliding_projected_windows resizes each window to 32x32,
        # so image_size should stay (32, 32) to match.
        self.vim = Vim(
            dim=128, dt_rank=16, dim_inner=128, d_state=64,
            num_classes=num_classes, image_size=image_size,
            patch_size=8, channels=3, dropout=0.1, depth=3
        )

    def forward(self, x):
        """Map ``x`` of shape [B, 1, C, T] to ``(None, logits)``."""
        x = x.squeeze(1)  # [B, C, T]
        windows = sliding_projected_windows(x, self.cnn, self.num_windows)
        # Every window goes through the same Vim; the logits are averaged.
        logits = [self.vim(patch) for patch in windows]
        out = torch.stack(logits).mean(dim=0)
        return None, out

# -----------------------------------------------------------------------------
# Experiment wrapper
# -----------------------------------------------------------------------------
class ExP:
    """Train, early-stop and test ``EEGMambaTransformer`` for one subject.

    Data comes from ``load_data_evaluate``. ``validate_ratio`` of the training
    trials is held out for checkpoint selection (lowest validation loss).
    Training batches are enlarged with segment-recombination surrogates
    (``interaug``), and the selected checkpoint is scored on the test set with
    ``calMetrics``.

    NOTE: heads, emb_size, depth, d_state, transformer_depth, mamba_depth,
    flatten_eeg1 and the ``eeg1_*`` arguments are accepted so a shared CONFIG
    dict can be splatted in, but this class does not use them.
    """

    def __init__(self, nsub, data_dir, result_name,
                 epochs=300, number_aug=3, number_seg=8,
                 evaluate_mode='subject-dependent',
                 heads=2, emb_size=16, depth=6,
                 d_state=16, transformer_depth=1, mamba_depth=3,
                 dataset_type='A', eeg1_f1=8, eeg1_kernel_size=64,
                 eeg1_D=2, eeg1_pooling_size1=8, eeg1_pooling_size2=8,
                 eeg1_dropout_rate=0.5, flatten_eeg1=15*16,
                 validate_ratio=0.3, learning_rate=1e-3, batch_size=72,
                 early_stopping_patience=100 ):
        self.nSub = nsub
        self.dataset_type = dataset_type
        self.evaluate_mode = evaluate_mode
        self.data_dir = data_dir
        self.result_name = result_name
        self.n_epochs = epochs
        # Segment recombination: surrogate multiplier and segments per trial.
        self.number_aug = number_aug
        self.number_seg = number_seg
        self.patience = early_stopping_patience
        self.no_improve = 0
        self.best_loss = float('inf')
        self.batch_size = batch_size
        self.validate_ratio = validate_ratio
        self.number_class, self.number_channel = numberClassChannel(self.dataset_type)

        self.criterion = nn.CrossEntropyLoss().cuda()

        self.model = EEGMambaTransformer(
            num_classes=self.number_class,
            image_size=(32, 32)
        ).cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        os.makedirs(self.result_name, exist_ok=True)
        self.model_filename = os.path.join(self.result_name, f"model_{self.nSub}.pth")

    def _recombine_segments(self, timg, label):
        """Build shuffled segment-recombination surrogates (numpy only).

        For every class, ``number_aug * (batch_size // number_class)`` surrogate
        trials are created. The time axis is cut into ``number_seg`` consecutive
        segments, and each segment is copied from a randomly chosen real trial
        of that class.

        Args:
            timg: array [trials, 1, channels, time].
            label: 1-based class labels, shape [trials].

        Returns:
            ``(data, labels)``: float32 [N, 1, channels, time] and 0-based int64
            [N], in random order.
        """
        n_classes = self.number_class
        recs = self.number_aug * (self.batch_size // n_classes)
        _, _, n_chan, n_time = timg.shape
        # Integer segment boundaries: 125-sample slices for 1000 samples / 8
        # segments; any remainder is spread over the segments when uneven.
        bounds = [(j * n_time) // self.number_seg for j in range(self.number_seg + 1)]

        aug_data, aug_label = [], []
        for cls in range(n_classes):
            data = timg[label == cls + 1]
            if len(data) == 0:
                continue  # no real trials of this class to recombine from
            tmp = np.zeros((recs, 1, n_chan, n_time), dtype=np.float32)
            for i in range(recs):
                for start, stop in zip(bounds[:-1], bounds[1:]):
                    ridx = random.randrange(data.shape[0])
                    tmp[i, 0, :, start:stop] = data[ridx, 0, :, start:stop]
            aug_data.append(tmp)
            # Exactly one label per surrogate; slicing the real labels would come
            # up short when a class has fewer than ``recs`` trials.
            aug_label.append(np.full(recs, cls, dtype=np.int64))

        if not aug_data:
            return (np.zeros((0, 1, n_chan, n_time), dtype=np.float32),
                    np.zeros((0,), dtype=np.int64))
        aug_data = np.concatenate(aug_data).astype(np.float32)
        aug_label = np.concatenate(aug_label)
        perm = np.random.permutation(len(aug_data))
        return aug_data[perm], aug_label[perm]

    def interaug(self, timg, label):
        """Return segment-recombination surrogates as CUDA tensors.

        Args:
            timg: array [trials, 1, channels, time] to draw segments from.
            label: 1-based class labels for ``timg``.

        Returns:
            ``(data, labels)``: float32 data and 0-based int64 labels on the GPU.
        """
        aug_data, aug_label = self._recombine_segments(timg, label)
        return (
            torch.from_numpy(aug_data).cuda(),
            torch.from_numpy(aug_label).long().cuda()
        )

    def get_source_data(self):
        """Load this subject's data, add a singleton axis and z-score it.

        Both splits are normalised with the training set's global mean and std.

        Returns:
            ``(tr_x, tr_y, te_x, te_y)``; x arrays are float32
            [trials, 1, channels, time], and y arrays are flattened labels.
        """
        tr_x, tr_y, te_x, te_y = load_data_evaluate(
            self.data_dir, self.dataset_type, self.nSub,
            mode_evaluate=self.evaluate_mode
        )
        tr_x = np.expand_dims(tr_x, axis=1).astype(np.float32)
        te_x = np.expand_dims(te_x, axis=1).astype(np.float32)
        tr_y = tr_y.reshape(-1)
        te_y = te_y.reshape(-1)
        m, s = tr_x.mean(), tr_x.std()
        tr_x = (tr_x - m) / s
        te_x = (te_x - m) / s
        return tr_x, tr_y, te_x, te_y

    def _train_epoch(self, train_loader, aug_src_x, aug_src_y):
        """Run one optimisation pass; return (loss, accuracy, peak GPU MB, samples/s)."""
        self.model.train()
        epoch_train_loss = 0.0
        train_correct = 0
        total_samples = 0

        # Time the epoch and track its peak GPU memory.
        torch.cuda.synchronize()
        start_time = time.time()
        torch.cuda.reset_peak_memory_stats()

        for xb, yb in train_loader:
            xb, yb = xb.cuda().float(), yb.cuda().long()
            # Each real mini-batch is concatenated with a fresh set of surrogates.
            aug_x, aug_y = self.interaug(aug_src_x, aug_src_y)
            xb = torch.cat([xb, aug_x])
            yb = torch.cat([yb, aug_y])

            _, out = self.model(xb)
            loss = self.criterion(out, yb)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            epoch_train_loss += loss.item()
            train_correct += (out.argmax(dim=1) == yb).sum().item()
            total_samples += yb.size(0)

        torch.cuda.synchronize()
        elapsed = time.time() - start_time
        mem_used = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB
        speed = total_samples / elapsed if elapsed > 0 else 0
        train_loss = epoch_train_loss / len(train_loader)
        train_acc = train_correct / total_samples
        return train_loss, train_acc, mem_used, speed

    def _validate(self, val_loader):
        """Return (mean per-batch loss, accuracy) on the held-out split."""
        self.model.eval()
        val_loss = 0.0
        val_correct = 0
        n_val = 0
        with torch.no_grad():
            for xb, yb in val_loader:
                xb, yb = xb.cuda().float(), yb.cuda().long()
                _, out = self.model(xb)
                val_loss += self.criterion(out, yb).item()
                val_correct += (out.argmax(1) == yb).sum().item()
                n_val += yb.size(0)
        return val_loss / len(val_loader), val_correct / n_val

    def _test(self):
        """Load the saved checkpoint and return its test-set accuracy."""
        self.model.load_state_dict(torch.load(self.model_filename))
        self.model.eval()
        preds, trues = [], []
        with torch.no_grad():
            for xb, yb in self.test_loader:
                _, out = self.model(xb.cuda())
                preds.append(out.argmax(1).cpu().numpy())
                trues.append(yb.numpy())
        preds = np.concatenate(preds)
        trues = np.concatenate(trues)
        acc, *_ = calMetrics(trues, preds)
        print(f"Subject {self.nSub} final accuracy: {acc:.4f}")
        return acc

    def train(self):
        """Train with early stopping on validation loss, then test the best checkpoint.

        Returns:
            Test accuracy (the first value returned by ``calMetrics``).

        Raises:
            ValueError: if ``validate_ratio`` leaves no training or no
                validation trials.
        """
        tr_x, tr_y, te_x, te_y = self.get_source_data()

        # Hold out part of the training trials for checkpoint selection.
        dataset_size = len(tr_x)
        val_size = int(self.validate_ratio * dataset_size)
        train_size = dataset_size - val_size
        if val_size == 0 or train_size == 0:
            raise ValueError(
                f"validate_ratio={self.validate_ratio} gives {train_size} training and "
                f"{val_size} validation trials out of {dataset_size}; both must be > 0"
            )
        train_ds, val_ds = torch.utils.data.random_split(
            torch.utils.data.TensorDataset(torch.from_numpy(tr_x), torch.from_numpy(tr_y - 1)),
            [train_size, val_size]
        )
        # Surrogates are built from the training subset only; drawing segments
        # from held-out trials would leak validation data into training.
        train_idx = np.asarray(train_ds.indices)
        aug_src_x, aug_src_y = tr_x[train_idx], tr_y[train_idx]

        test_ds = torch.utils.data.TensorDataset(torch.from_numpy(te_x), torch.from_numpy(te_y - 1))
        self.test_loader = torch.utils.data.DataLoader(test_ds, batch_size=self.batch_size)
        # shuffle=True draws a new order every time the loader is iterated, so
        # one loader serves all epochs.
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=self.batch_size, shuffle=True)
        val_loader = torch.utils.data.DataLoader(val_ds, batch_size=self.batch_size)

        best_loss = float('inf')
        saved = False
        self.no_improve = 0
        for e in range(self.n_epochs):
            train_loss, train_acc, mem_used, speed = self._train_epoch(
                train_loader, aug_src_x, aug_src_y
            )
            val_loss, val_acc = self._validate(val_loader)

            print(f"Epoch {e+1}/{self.n_epochs} | "
                  f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | "
                  f"Mem: {mem_used:.2f}MB | Speed: {speed:.2f} samples/s")

            if val_loss < best_loss:
                best_loss = val_loss
                self.best_loss = best_loss
                # Patience counts *consecutive* epochs without improvement.
                self.no_improve = 0
                torch.save(self.model.state_dict(), self.model_filename)
                saved = True
            else:
                self.no_improve += 1
                if self.no_improve >= self.patience:
                    print(f"Stopping early at epoch {e+1} (no improvement in {self.patience} epochs).")
                    break

        if not saved:
            # e.g. every validation loss was NaN: keep the final weights so the
            # test phase still has a checkpoint to load.
            torch.save(self.model.state_dict(), self.model_filename)

        return self._test()

# -----------------------------------------------------------------------------
def main(result_dir, DATA_DIR, N_SUBJECT, **cfg):
    """Train and test one ExP per subject and report the mean accuracy.

    Args:
        result_dir: directory for per-subject checkpoints (created if missing).
        DATA_DIR: dataset root passed to ExP.
        N_SUBJECT: subjects 1..N_SUBJECT are run in turn.
        **cfg: keyword arguments forwarded to ExP.

    Returns:
        List of per-subject test accuracies.
    """
    os.makedirs(result_dir, exist_ok=True)
    # No throw-away model is built here: every ExP builds its own, and an extra
    # copy left on the GPU only adds memory pressure.

    print(time.asctime(time.localtime(time.time())))

    accs = []
    for sub in range(1, N_SUBJECT + 1):
        # Seed drawn from numpy's global RNG and printed so a subject's run
        # can be repeated.
        seed_n = np.random.randint(2024)
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        print(f"seed is {seed_n}")
        print(f"Subject {sub}")

        exp = ExP(sub, DATA_DIR, result_dir, **cfg)
        accs.append(exp.train())
        # Release this subject's model/optimiser before the next ExP builds its own.
        del exp
        torch.cuda.empty_cache()

    print("Average accuracy:", np.mean(accs))
    return accs

if __name__ == "__main__":
    # Prefer deterministic cuDNN kernels over autotuned ones.
    cudnn.benchmark = False
    cudnn.deterministic = True

    # Keyword arguments forwarded to ExP for every subject.
    CONFIG = {
        "emb_size": 128, "depth": 1, "heads": 4,
        "d_state": 32, "transformer_depth": 1, "mamba_depth": 2,
        "dataset_type": 'A',
        "eeg1_f1": 8, "eeg1_kernel_size": 64,
        "eeg1_D": 16, "eeg1_pooling_size1": 8, "eeg1_pooling_size2": 8,
        "eeg1_dropout_rate": 0.5, "flatten_eeg1": 15 * 128,
        "epochs": 1000, "number_aug": 3, "number_seg": 8,
        "validate_ratio": 0.3, "learning_rate": 1e-3, "batch_size": 36,
    }

    # Run location and scope; results go to a fresh, timestamped directory.
    DATA_DIR = "bci2a/"
    RESULT_DIR = f"CTNet_Mamba_{int(time.time())}"
    N_SUBJECT = 9

    main(RESULT_DIR, DATA_DIR, N_SUBJECT, **CONFIG)
    print("Done.")


  from .autonotebook import tqdm as notebook_tqdm


In [1]:
# Show current GPU memory usage and utilisation.
!nvidia-smi

Wed Apr 30 08:02:42 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:8B:00.0 Off |                    0 |
| N/A   51C    P8              9W /   70W |       4MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                