# In [1]:
import torch.nn as nn

# -------------------------------------------------------------------------
# Monkey-patch MultiheadAttention so forward() returns only `attn_output`.
#
# WARNING: this changes nn.MultiheadAttention for the whole process. Code that
# unpacks or indexes the usual (attn_output, attn_weights) tuple -- e.g. torch's
# nn.TransformerEncoderLayer, which takes `[0]` of the result -- silently gets
# the wrong tensor once the patch is active.
#
# The patch is idempotent: if forward() is already our wrapper (cell re-run or
# module reload), we wrap the genuine original again. Without this guard the
# second run would wrap the wrapper, whose global lookup of
# `_original_mha_forward` then resolves to itself and recurses forever.
# -------------------------------------------------------------------------
import functools

_MHA_PATCH_FLAG = "_returns_attn_output_only"

if getattr(nn.MultiheadAttention.forward, _MHA_PATCH_FLAG, False):
    # Already patched: recover the real forward stored by functools.wraps.
    _original_mha_forward = nn.MultiheadAttention.forward.__wrapped__
else:
    _original_mha_forward = nn.MultiheadAttention.forward


@functools.wraps(_original_mha_forward)
def _mha_forward_no_weights(self, query, key, value, *args, **kwargs):
    """Call the original MultiheadAttention.forward and return only attn_output."""
    attn_output, _ = _original_mha_forward(self, query, key, value, *args, **kwargs)
    return attn_output


# Mark the wrapper so a re-run can recognise it and avoid double-wrapping.
setattr(_mha_forward_no_weights, _MHA_PATCH_FLAG, True)
nn.MultiheadAttention.forward = _mha_forward_no_weights


# In [2]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
CTNet-Mamba: A Convolution-Mamba Network for EEG-Based Motor Imagery Classification

author: zhaowei701@163.com
"""
import math

import os
import time
import random
from barebones_hymba.barebones_hymba_block import HymbaBlock
import flash_attn

import numpy as np
from zeta.nn import MultiQueryAttention

import torch
import torch.nn as nn
from torch.backends import cudnn
from torchsummary import summary

from einops import rearrange
from einops.layers.torch import Rearrange

from utils import load_data_evaluate, numberClassChannel, calMetrics

# Try to import real MambaBlock (fallback if missing)
try:
    from mamba_ssm import Mamba
except ImportError:
    class Mamba(nn.Module):
        def __init__(self, d_model, d_state=16, d_conv=4, expand=2):
            super().__init__()
            self.linear = nn.Linear(d_model, d_model)
        def forward(self, x):
            return self.linear(x)

from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention
from torchinfo import summary



# -----------------------------------------------------------------------------
# PatchEmbeddingCNN
# -----------------------------------------------------------------------------
class PatchEmbeddingCNN(nn.Module):
    """Convolutional front end that turns an EEG trial into a token sequence.

    Input is expected as ``[B, 1, number_channel, T]``. The output is
    ``[B, seq, D * f1]``, where ``seq`` is ``T`` after the two average-pooling
    stages.

    Args:
        f1: number of temporal filters in the first convolution.
        kernel_size: temporal kernel length of the first convolution.
        D: depth multiplier; the grouped spatial conv emits ``D * f1`` maps.
        pooling_size1, pooling_size2: temporal average-pooling factors.
        dropout_rate: dropout after each pooling stage.
        number_channel: number of EEG electrodes (height of the spatial kernel).
        emb_size: accepted but not used; the token width is ``D * f1``.
    """

    def __init__(self, f1=16, kernel_size=64, D=2,
                 pooling_size1=8, pooling_size2=8,
                 dropout_rate=0.3, number_channel=22, emb_size=40):
        super().__init__()
        n_maps = f1 * D

        # Per-electrode temporal filtering.
        stem = [
            nn.Conv2d(1, f1, (1, kernel_size), padding='same', bias=False),
            nn.BatchNorm2d(f1),
        ]
        # Depthwise conv spanning all electrodes (groups == input channels).
        depthwise = [
            nn.Conv2d(f1, n_maps, (number_channel, 1), groups=f1, bias=False),
            nn.BatchNorm2d(n_maps),
            nn.ELU(),
            nn.AvgPool2d((1, pooling_size1)),
            nn.Dropout(dropout_rate),
        ]
        # Second temporal convolution on the spatially mixed maps.
        temporal_refine = [
            nn.Conv2d(n_maps, n_maps, (1, 16), padding='same', bias=False),
            nn.BatchNorm2d(n_maps),
            nn.ELU(),
            nn.AvgPool2d((1, pooling_size2)),
            nn.Dropout(dropout_rate),
        ]
        self.cnn_module = nn.Sequential(*stem, *depthwise, *temporal_refine)
        self.projection = Rearrange('b e h w -> b (h w) e')

    def forward(self, x):
        feature_maps = self.cnn_module(x)        # [B, D*f1, 1, seq]
        tokens = self.projection(feature_maps)   # [B, seq, D*f1]
        return tokens

# -----------------------------------------------------------------------------
# LinearAttention
# -----------------------------------------------------------------------------
def exists(val):
    """Return True for any value other than ``None`` (falsy values included)."""
    if val is None:
        return False
    return True

class LinearAttention(nn.Module):
    """Multi-head linear attention.

    Per head it computes ``softmax_feat(q * scale) @ (softmax_seq(k)^T @ v)``.
    Keys first aggregate the values into a ``[d, d]`` context, and each query
    then reads from that context. Cost is linear in sequence length.

    Args:
        dim: input/output feature size.
        heads: number of heads.
        dim_head: per-head feature size.
        dropout: dropout applied after the output projection.
    """

    def __init__(self, dim, *, heads=4, dim_head=64, dropout=0.0):
        super().__init__()
        inner = heads * dim_head
        self.heads, self.scale = heads, dim_head**-0.5
        self.to_qkv = nn.Linear(dim, inner*3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner, dim), nn.Dropout(dropout))

    def _split_heads(self, t):
        """[b, n, h*d] -> [b*h, n, d] (batch-major, like einops '(b h)')."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.heads, -1).transpose(1, 2).reshape(b * self.heads, n, -1)

    def _merge_heads(self, t):
        """[b*h, n, d] -> [b, n, h*d]; inverse of _split_heads."""
        bh, n, d = t.shape
        b = bh // self.heads
        return t.reshape(b, self.heads, n, d).transpose(1, 2).reshape(b, n, self.heads * d)

    def forward(self, x, mask=None):
        """Apply linear attention.

        Args:
            x: ``[b, n, dim]`` input.
            mask: optional bool ``[b, n]``; ``True`` marks positions whose keys
                contribute to the context.

        Returns:
            ``[b, n, dim]`` tensor.
        """
        q, k, v = (self._split_heads(t) for t in self.to_qkv(x).chunk(3, dim=-1))
        # The scale is applied before the softmax, so it acts as a temperature.
        q = (q * self.scale).softmax(dim=-1)   # normalise over features
        k = k.softmax(dim=-2)                  # normalise over sequence
        if mask is not None:
            # One mask row per (batch, head) pair, batch-major like _split_heads.
            mask = mask.repeat_interleave(self.heads, dim=0)
            k = k.masked_fill(~mask.unsqueeze(-1), 0.)
        # Keys aggregate values into a per-head [d, d] context; queries read it.
        context = torch.einsum('b n d, b n e -> b d e', k, v)
        out = torch.einsum('b n d, b d e -> b n e', q, context)
        return self.to_out(self._merge_heads(out))

# -----------------------------------------------------------------------------
# TransformerBlock
# -----------------------------------------------------------------------------
import torch
import torch.nn as nn

class MambaTransformerblock(nn.Module):
    """Pre-norm stack of (Mamba -> self-attention -> feed-forward) residual layers.

    Each of the ``depth`` layers applies three residual sublayers. Each sublayer
    has its own LayerNorm on the input and dropout on the output. A final
    LayerNorm is applied after the last layer.

    Args:
        dim: model width; input and output are ``[B, L, dim]``.
        heads: nn.MultiheadAttention heads (must divide ``dim``).
        dim_head: accepted for interface compatibility; not used.
        dropout: dropout for every sublayer output, inside attention and the FFN.
        ff_mult: hidden-width multiplier of the feed-forward network.
        d_state: state size passed to MambaBlock.
        depth: number of layers.
    """

    def __init__(self, dim, heads, dim_head, dropout=0.1,
                 ff_mult=4, d_state=16, depth=4):
        super().__init__()
        self.layers = nn.ModuleList([
            self._create_layer(dim, heads, ff_mult, d_state, dropout)
            for _ in range(depth)
        ])
        self.norm = nn.LayerNorm(dim)

    def _create_layer(self, dim, heads, ff_mult, d_state, dropout):
        """Build the norms, sublayers and dropouts of one layer."""
        return nn.ModuleDict({
            # 1) Mamba
            'mamba_norm':   nn.LayerNorm(dim),
            'mamba':        MambaBlock(dim, d_state=d_state),
            'mamba_dropout':nn.Dropout(dropout),

            # 2) PyTorch MultiheadAttention
            'attn_norm':    nn.LayerNorm(dim),
            'attn':         nn.MultiheadAttention(embed_dim=dim,
                                                  num_heads=heads,
                                                  dropout=dropout,
                                                  batch_first=True),
            'attn_dropout': nn.Dropout(dropout),

            # 3) FFN
            'ffn_norm':     nn.LayerNorm(dim),
            'feedforward':  nn.Sequential(
                                 nn.Linear(dim, ff_mult * dim),
                                 nn.GELU(),
                                 nn.Dropout(dropout),
                                 nn.Linear(ff_mult * dim, dim)
                             ),
            'ffn_dropout':  nn.Dropout(dropout),
        })

    def forward(self, x):
        """Apply all layers to ``x`` ([B, L, dim]) and return the normalised output."""
        for layer in self.layers:
            # (1) Mamba sublayer
            m = layer['mamba_norm'](x)
            x = x + layer['mamba_dropout'](layer['mamba'](m))

            # (2) Multi-head self-attention sublayer
            a = layer['attn_norm'](x)
            attn_out = layer['attn'](a, a, a, need_weights=False)
            # Stock nn.MultiheadAttention returns (output, weights); the
            # module-level monkey-patch makes it return only the output.
            # Accept both so this block works whether or not the patch ran.
            if isinstance(attn_out, tuple):
                attn_out = attn_out[0]
            x = x + layer['attn_dropout'](attn_out)

            # (3) Feed-forward sublayer
            f = layer['ffn_norm'](x)
            x = x + layer['ffn_dropout'](layer['feedforward'](f))

        return self.norm(x)



class SummaryWrapper(nn.Module):
    """Expose only the logits of a ``(features, logits)`` model.

    Takes ``[B, C, T]`` input, inserts a singleton dimension at index 1 to form
    ``[B, 1, C, T]``, runs the wrapped model and returns its second output.
    Useful for tools that expect a single-tensor output.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        with_plane_dim = x.unsqueeze(1)   # [B, C, T] -> [B, 1, C, T]
        _features, logits = self.model(with_plane_dim)
        return logits


# -----------------------------------------------------------------------------
# EEGMambaTransformer
# -----------------------------------------------------------------------------
class EEGMambaTransformer(nn.Module):
    """CNN patch embedding, a stack of HymbaBlocks, then a linear classifier.

    forward() returns ``(features, logits)``. ``features`` is the output of the
    last HymbaBlock, and ``logits = classification(flatten(features))``.

    Implicit shape constraints (NOTE(review): none are checked here):
      * PatchEmbeddingCNN emits ``eeg1_D * eeg1_f1`` features per token, but
        the HymbaBlocks are built with ``hidden_size=emb_size``; these must be
        equal. The defaults (2 * 8 = 16 vs. emb_size=40) are not.
      * ``flatten_eeg1`` must equal the flattened feature count, presumably
        seq_len * emb_size if HymbaBlock preserves [B, seq, emb] — TODO confirm.
      * HymbaBlock gets a hard-coded ``seq_length=15``. This presumably matches
        1000-sample inputs pooled by 8 and then 8 (1000 // 8 // 8 = 15); it must
        be revisited if the input length or pooling sizes change.

    ``transformer_depth`` and ``mamba_depth`` are accepted but not used.
    """
    def __init__(self, emb_size=40, depth=6, heads=4,
                 d_state=16, transformer_depth=1, mamba_depth=3,
                 database_type='A', eeg1_f1=8, eeg1_kernel_size=64,
                 eeg1_D=2, eeg1_pooling_size1=8, eeg1_pooling_size2=8,
                 eeg1_dropout_rate=0.5, flatten_eeg1=15*16):
        super().__init__()
        # Number of output classes and EEG electrodes for the chosen dataset.
        self.number_class, self.number_channel = numberClassChannel(database_type)
        # NOTE(review): not used in forward().
        self.norm = nn.LayerNorm(emb_size)

        self.cnn = PatchEmbeddingCNN(
            f1=eeg1_f1, kernel_size=eeg1_kernel_size,
            D=eeg1_D, pooling_size1=eeg1_pooling_size1,
            pooling_size2=eeg1_pooling_size2,
            dropout_rate=eeg1_dropout_rate,
            number_channel=self.number_channel,
            emb_size=emb_size
        )

        # Only consumed by self.mamba_transformer below.
        dim_head = emb_size // heads
        # The sequence model actually used in forward().
        self.hymba_layers = nn.ModuleList([
            HymbaBlock(
                mamba_expand=2,         # expansion factor for the Mamba branch
                hidden_size=emb_size,             # model width; must equal eeg1_D * eeg1_f1
                num_attention_heads=heads,         # query heads
                num_key_value_heads=max(1, heads//2),  # fewer KV heads than query heads (presumably grouped-query attention)
                conv_kernel_size=3,
                time_step_rank=d_state,            # presumably the rank of the SSM time-step projection — confirm in HymbaBlock
                ssm_state_size=d_state,            # SSM state dimension
                attention_window_size=64,
                modify_attention_mask=False,
                num_meta_tokens=0,
                seq_length=15,  # hard-coded token count; see class docstring
                use_positional_embedding=True,
                rope_base=10000,
            )
            for _ in range(depth)
        ])
        # NOTE(review): constructed but never used in forward(). Its parameters
        # still appear in model summaries and are handed to the optimizer.
        self.mamba_transformer = MambaTransformerblock(
            dim=emb_size, heads=heads,
            dim_head=dim_head,
            dropout=eeg1_dropout_rate, ff_mult=4,
            d_state=d_state,
            depth = depth,
        )
        # NOTE(review): misspelled and not used in forward().
        self.drouput = nn.Dropout(0.25)
        self.flatten = nn.Flatten()
        self.classification = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(flatten_eeg1, self.number_class)
        )

    def forward(self, x):
        """Return ``(features, logits)`` for an input batch.

        Assumes ``x`` is ``[B, 1, number_channel, T]``, as used by main()'s
        summary call — TODO confirm for other callers.
        """
        x = self.cnn(x)                       # [B, seq, eeg1_D * eeg1_f1]
        # Scale token embeddings by sqrt(feature dim) before the sequence model.
        x = x * math.sqrt(x.shape[-1])

        for block in self.hymba_layers:
            x = block(x)                     # hybrid SSM + attention
        return x, self.classification(self.flatten(x))

# -----------------------------------------------------------------------------
# Experiment wrapper
# -----------------------------------------------------------------------------
class ExP:
    """Train and evaluate one EEGMambaTransformer for a single subject.

    train() does the following:
      * splits the subject's training data into train/validation parts;
      * trains with on-the-fly segmentation-and-recombination (S&R)
        augmentation drawn from the train part only;
      * keeps the checkpoint with the lowest selection loss and optionally
        stops early;
      * reports accuracy on the held-out test set.

    Requires a CUDA device: the criterion, the model and all batches are moved
    with ``.cuda()``.
    """

    def __init__(self, nsub, data_dir, result_name,
                 epochs=300, number_aug=3, number_seg=8,
                 evaluate_mode='subject-dependent',
                 heads=2, emb_size=16, depth=6,
                 d_state=16, transformer_depth=1, mamba_depth=3,
                 dataset_type='A', eeg1_f1=8, eeg1_kernel_size=64,
                 eeg1_D=2, eeg1_pooling_size1=8, eeg1_pooling_size2=8,
                 eeg1_dropout_rate=0.5, flatten_eeg1=15*16,
                 validate_ratio=0.3, learning_rate=1e-3, batch_size=72,
                 early_stopping_patience=100):
        """
        Args:
            nsub: subject index passed to the data loader and used in file names.
            data_dir: dataset directory.
            result_name: directory for the best-model checkpoint (created).
            epochs: maximum number of training epochs.
            number_aug: S&R multiplier; every batch gets
                ``number_aug * (batch_size // number_class)`` synthetic trials
                per class.
            number_seg: number of time segments recombined by S&R.
            evaluate_mode: forwarded to ``load_data_evaluate(mode_evaluate=...)``.
            validate_ratio: fraction of training trials held out for validation.
            early_stopping_patience: stop after this many epochs without a new
                best selection loss; ``None`` or ``0`` disables early stopping.
            The remaining arguments are forwarded to EEGMambaTransformer / Adam.
        """
        self.nSub = nsub
        self.dataset_type = dataset_type
        self.data_dir = data_dir
        self.result_name = result_name
        self.n_epochs = epochs
        self.number_aug = number_aug
        self.number_seg = number_seg
        self.evaluate_mode = evaluate_mode
        self.patience = early_stopping_patience
        self.no_improve = 0
        self.best_loss = float('inf')
        self.batch_size = batch_size
        self.validate_ratio = validate_ratio

        self.criterion = nn.CrossEntropyLoss().cuda()

        self.model = EEGMambaTransformer(
            emb_size=emb_size, depth=depth, heads=heads,
            d_state=d_state,
            transformer_depth=transformer_depth,
            mamba_depth=mamba_depth,
            database_type=dataset_type,
            eeg1_f1=eeg1_f1, eeg1_kernel_size=eeg1_kernel_size,
            eeg1_D=eeg1_D, eeg1_pooling_size1=eeg1_pooling_size1,
            eeg1_pooling_size2=eeg1_pooling_size2,
            eeg1_dropout_rate=eeg1_dropout_rate,
            flatten_eeg1=flatten_eeg1
        ).cuda()

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=learning_rate)
        os.makedirs(self.result_name, exist_ok=True)
        self.model_filename = os.path.join(self.result_name, f"model_{self.nSub}.pth")

    # ------------------------------------------------------------------ data
    @staticmethod
    def _segment_recombine(data, n_records, n_segments):
        """Build ``n_records`` trials by stitching time segments of random trials.

        ``data`` is ``[N, 1, channels, samples]``. For every output trial and
        every segment, a source trial is drawn with ``random.randrange`` and
        its segment is copied. Trailing samples that do not fill a whole
        segment stay zero.
        """
        n_trials, _, n_channels, n_samples = data.shape
        seg_len = n_samples // n_segments
        out = np.zeros((n_records, 1, n_channels, n_samples), dtype=np.float32)
        for i in range(n_records):
            for j in range(n_segments):
                src = random.randrange(n_trials)
                window = slice(j * seg_len, (j + 1) * seg_len)
                out[i, 0, :, window] = data[src, 0, :, window]
        return out

    def interaug(self, timg, label):
        """Segmentation-and-recombination augmentation.

        Args:
            timg: ``[N, 1, channels, samples]`` trials to recombine.
            label: 1-based class labels of ``timg``, shape ``[N]``.

        Returns:
            ``(data, labels)`` tensors on the model's device, shuffled together;
            labels are 0-based longs.

        Raises:
            ValueError: if some class has no trial in ``timg``.
        """
        n_class = self.model.number_class
        n_records = self.number_aug * (self.batch_size // n_class)
        aug_data, aug_label = [], []
        for cls in range(n_class):
            class_trials = timg[label == cls + 1]
            if len(class_trials) == 0:
                raise ValueError(f"no trials of class {cls + 1} available for augmentation")
            aug_data.append(self._segment_recombine(class_trials, n_records, self.number_seg))
            # Every synthetic trial belongs to `cls`. Build the labels directly so
            # their count matches the data even if the class has < n_records trials.
            aug_label.append(np.full(n_records, cls + 1, dtype=label.dtype))
        aug_data = np.concatenate(aug_data)
        aug_label = np.concatenate(aug_label)
        perm = np.random.permutation(len(aug_data))
        device = next(self.model.parameters()).device
        return (
            torch.from_numpy(aug_data[perm]).to(device),
            torch.from_numpy(aug_label[perm] - 1).long().to(device),
        )

    def get_source_data(self):
        """Load, reshape and z-score the subject's train/test data.

        Returns ``tr_x, tr_y, te_x, te_y``. The x arrays gain a singleton
        dimension at axis 1 and are float32. Presumably they are
        ``[N, 1, channels, samples]`` if load_data_evaluate returns
        ``[N, channels, samples]`` — confirm. Labels are flattened to 1-D; the
        rest of this class treats them as 1-based.
        """
        tr_x, tr_y, te_x, te_y = load_data_evaluate(
            self.data_dir, self.dataset_type, self.nSub,
            mode_evaluate=self.evaluate_mode
        )
        tr_x = np.expand_dims(tr_x, axis=1).astype(np.float32)
        te_x = np.expand_dims(te_x, axis=1).astype(np.float32)
        tr_y = tr_y.reshape(-1)
        te_y = te_y.reshape(-1)
        # Normalisation statistics come from the training data only.
        # NOTE(review): they include the trials train() later holds out for validation.
        m, s = tr_x.mean(), tr_x.std()
        tr_x = (tr_x - m) / s
        te_x = (te_x - m) / s
        return tr_x, tr_y, te_x, te_y

    # -------------------------------------------------------------- training
    def _train_epoch(self, train_loader, aug_src_x, aug_src_y):
        """One optimisation pass; returns (loss, acc, peak MB, samples/s)."""
        self.model.train()
        total_loss, correct, seen = 0.0, 0, 0

        torch.cuda.synchronize()
        start_time = time.time()
        torch.cuda.reset_peak_memory_stats()

        for xb, yb in train_loader:
            xb, yb = xb.cuda().float(), yb.cuda().long()
            aug_x, aug_y = self.interaug(aug_src_x, aug_src_y)
            xb = torch.cat([xb, aug_x])
            yb = torch.cat([yb, aug_y])

            _, out = self.model(xb)
            loss = self.criterion(out, yb)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Weight by batch size so the epoch loss is a per-sample mean.
            total_loss += loss.item() * yb.size(0)
            correct += (out.argmax(dim=1) == yb).sum().item()
            seen += yb.size(0)

        torch.cuda.synchronize()
        elapsed = time.time() - start_time
        mem_used = torch.cuda.max_memory_allocated() / (1024 ** 2)  # MB
        speed = seen / elapsed if elapsed > 0 else 0
        denom = max(seen, 1)
        return total_loss / denom, correct / denom, mem_used, speed

    def _evaluate(self, loader):
        """Per-sample mean loss and accuracy of the model on a non-empty loader."""
        self.model.eval()
        total_loss, correct, seen = 0.0, 0, 0
        with torch.no_grad():
            for xb, yb in loader:
                xb, yb = xb.cuda().float(), yb.cuda().long()
                _, out = self.model(xb)
                total_loss += self.criterion(out, yb).item() * yb.size(0)
                correct += (out.argmax(1) == yb).sum().item()
                seen += yb.size(0)
        return total_loss / seen, correct / seen

    def _test(self):
        """Load the saved checkpoint, report and return test accuracy."""
        self.model.load_state_dict(torch.load(self.model_filename))
        self.model.eval()
        preds, trues = [], []
        with torch.no_grad():
            for xb, yb in self.test_loader:
                _, out = self.model(xb.cuda())
                preds.append(out.argmax(1).cpu().numpy())
                trues.append(yb.numpy())
        preds = np.concatenate(preds)
        trues = np.concatenate(trues)
        acc, *_ = calMetrics(trues, preds)
        print(f"Subject {self.nSub} final accuracy: {acc:.4f}")
        return acc

    def train(self):
        """Train with checkpoint selection, then evaluate on the test set.

        The selection loss is the validation loss. If ``validate_ratio`` leaves
        no validation trials, the training loss is used instead.

        Returns:
            Test accuracy (first value of calMetrics) of the best checkpoint.
        """
        tr_x, tr_y, te_x, te_y = self.get_source_data()

        # Create validation split
        dataset_size = len(tr_x)
        val_size = int(self.validate_ratio * dataset_size)
        train_size = dataset_size - val_size
        train_ds, val_ds = torch.utils.data.random_split(
            torch.utils.data.TensorDataset(torch.from_numpy(tr_x), torch.from_numpy(tr_y-1)),
            [train_size, val_size]
        )
        # Augmentation may only recombine training-split trials. Drawing from
        # the full tr_x would leak validation trials into training and make the
        # validation loss used for checkpoint selection optimistic.
        train_idx = np.asarray(train_ds.indices, dtype=np.int64)
        aug_src_x, aug_src_y = tr_x[train_idx], tr_y[train_idx]

        test_ds = torch.utils.data.TensorDataset(torch.from_numpy(te_x), torch.from_numpy(te_y-1))
        self.test_loader = torch.utils.data.DataLoader(test_ds, batch_size=self.batch_size)
        train_loader = torch.utils.data.DataLoader(train_ds, batch_size=self.batch_size, shuffle=True)
        val_loader = (torch.utils.data.DataLoader(val_ds, batch_size=self.batch_size)
                      if len(val_ds) > 0 else None)

        self.best_loss = float('inf')
        self.no_improve = 0
        checkpoint_saved = False
        for e in range(self.n_epochs):
            train_loss, train_acc, mem_used, speed = self._train_epoch(
                train_loader, aug_src_x, aug_src_y)

            if val_loader is not None:
                val_loss, val_acc = self._evaluate(val_loader)
                selection_loss = val_loss
            else:
                val_loss = val_acc = float('nan')
                selection_loss = train_loss

            print(f"Epoch {e+1}/{self.n_epochs} | "
                  f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f} | "
                  f"Mem: {mem_used:.2f}MB | Speed: {speed:.2f} samples/s")

            if selection_loss < self.best_loss:
                self.best_loss = selection_loss
                self.no_improve = 0
                torch.save(self.model.state_dict(), self.model_filename)
                checkpoint_saved = True
            else:
                self.no_improve += 1
                if self.patience and self.no_improve >= self.patience:
                    print(f"Early stopping at epoch {e+1}: no improvement for "
                          f"{self.patience} epochs")
                    break

        if not checkpoint_saved:
            # E.g. zero epochs, or every loss was NaN: still leave a checkpoint to test.
            torch.save(self.model.state_dict(), self.model_filename)
        return self._test()

# -----------------------------------------------------------------------------
def main(result_dir, DATA_DIR, N_SUBJECT, **cfg):
    """Print a model summary, then run one ExP (train + test) per subject.

    Args:
        result_dir: output directory for checkpoints (created if missing).
        DATA_DIR: dataset directory forwarded to ExP.
        N_SUBJECT: subjects 1..N_SUBJECT are processed in order.
        **cfg: hyper-parameters forwarded to EEGMambaTransformer and ExP.

    Returns:
        List of per-subject test accuracies.
    """
    os.makedirs(result_dir, exist_ok=True)

    _, number_channel = numberClassChannel(cfg['dataset_type'])
    model = EEGMambaTransformer(
        emb_size=cfg['emb_size'], depth=cfg['depth'], heads=cfg['heads'],
        d_state=cfg['d_state'],
        transformer_depth=cfg['transformer_depth'],
        mamba_depth=cfg['mamba_depth'],
        database_type=cfg['dataset_type'],
        eeg1_f1=cfg['eeg1_f1'], eeg1_kernel_size=cfg['eeg1_kernel_size'],
        eeg1_D=cfg['eeg1_D'], eeg1_pooling_size1=cfg['eeg1_pooling_size1'],
        eeg1_pooling_size2=cfg['eeg1_pooling_size2'],
        eeg1_dropout_rate=cfg['eeg1_dropout_rate'],
        flatten_eeg1=cfg['flatten_eeg1']
    ).cuda()

    summary(
        model,
        input_size=(1, 1, number_channel, 1000),  # batch, channel=1, EEG channels, time
        col_names=["input_size", "output_size", "num_params"],
        depth=3,
        device="cuda"
    )
    # This instance exists only for the summary. Release it so its GPU memory
    # is not held while the per-subject models train.
    del model
    torch.cuda.empty_cache()
    print(time.asctime(time.localtime(time.time())))

    accs = []
    for sub in range(1, N_SUBJECT+1):
        # Draw a fresh seed per subject and print it so the run can be reproduced.
        seed_n = np.random.randint(2024)
        random.seed(seed_n); np.random.seed(seed_n)
        torch.manual_seed(seed_n); torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)

        print(f"seed is {seed_n}")
        print(f"Subject {sub}")

        exp = ExP(sub, DATA_DIR, result_dir, **cfg)
        accs.append(exp.train())

    print("Average accuracy:", np.mean(accs))
    return accs

if __name__ == "__main__":
    # Prefer reproducibility: deterministic cuDNN kernels and no autotuner.
    cudnn.deterministic = True
    cudnn.benchmark = False

    CONFIG = {
        # model architecture
        "emb_size": 128, "depth": 1, "heads": 4,
        "d_state": 32, "transformer_depth": 1, "mamba_depth": 2,
        "dataset_type": 'A',
        # CNN front end (eeg1_D * eeg1_f1 must equal emb_size)
        "eeg1_f1": 8, "eeg1_kernel_size": 64,
        "eeg1_D": 16, "eeg1_pooling_size1": 8, "eeg1_pooling_size2": 8,
        "eeg1_dropout_rate": 0.5, "flatten_eeg1": 15 * 128,
        # training / augmentation
        "epochs": 1000, "number_aug": 3, "number_seg": 8,
        "validate_ratio": 0.3, "learning_rate": 1e-3, "batch_size": 72,
    }

    DATA_DIR = "bci2a/"
    RESULT_DIR = f"CTNet_Mamba_{int(time.time())}"
    N_SUBJECT = 9

    main(result_dir=RESULT_DIR, DATA_DIR=DATA_DIR, N_SUBJECT=N_SUBJECT, **CONFIG)
    print("Done.")


# Captured notebook output (not code):
#   from .autonotebook import tqdm as notebook_tqdm
#
# NameError: name 'sys' is not defined