Mamba out!

In [3]:
## Simplified Mamba (selective state-space model) text classifier in PyTorch

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class MambaConfig:
    """Hyperparameters shared by the Mamba blocks and the classifier.

    Args:
        d_model: Width of the token embeddings / residual stream.
        n_layer: Number of stacked Mamba blocks.
        vocab_size: Number of rows in the embedding table.
        num_classes: Number of output classes of the classification head.
        state_dim: Size of the SSM state kept per inner channel.
        expand: Multiplier applied to ``d_model`` to obtain ``d_inner``.
        dt_rank: Rank of the low-rank projection that produces the step
            size (delta). ``"auto"`` derives it from ``d_model``.
        conv_kernel: Kernel size of the depthwise causal convolution.

    Attributes:
        d_inner: ``int(expand * d_model)``, the inner channel count.
    """

    def __init__(
        self,
        d_model=256,
        n_layer=4,
        vocab_size=30000,
        num_classes=14,
        state_dim=16,
        expand=2,
        dt_rank="auto",
        conv_kernel=4,
    ):
        self.d_model = d_model
        self.n_layer = n_layer
        self.vocab_size = vocab_size
        self.num_classes = num_classes
        self.state_dim = state_dim
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)

        if dt_rank == "auto":
            # d_model // 16 is 0 for d_model < 16. A rank of 0 leaves the
            # step-size projection with no inputs, so the step size collapses
            # to its bias and the state update degenerates. Keep it >= 1.
            dt_rank = max(1, self.d_model // 16)
        self.dt_rank = dt_rank

        self.conv_kernel = conv_kernel

class MambaBlock(nn.Module):
    """Pre-norm residual Mamba block built on a selective state-space model.

    The block normalises its input, projects it to two ``d_inner``-wide
    branches, runs one branch through a depthwise causal convolution and a
    sequential selective scan, gates it with the sigmoid of the other
    branch, projects back to ``d_model`` and adds the residual.

    Args:
        config: Object exposing ``d_model``, ``d_inner``, ``state_dim``,
            ``dt_rank`` and ``conv_kernel``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Project the input to the two inner branches (SSM path and gate).
        self.in_proj = nn.Linear(config.d_model, config.d_inner * 2, bias=False)

        # Depthwise convolution; left context comes from the padding and the
        # surplus right-hand outputs are trimmed in forward().
        self.conv1d = nn.Conv1d(
            in_channels=config.d_inner,
            out_channels=config.d_inner,
            kernel_size=config.conv_kernel,
            groups=config.d_inner,
            padding=config.conv_kernel - 1,
        )

        # Input-dependent (selective) SSM parameters: delta, B and C.
        self.x_proj = nn.Linear(
            config.d_inner, config.dt_rank + config.state_dim * 2, bias=False
        )
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # forward() computes the continuous-time matrix as -exp(self.A), so
        # the parameter has to hold log(1..state_dim). Storing 1..state_dim
        # directly would give decay rates as large as -exp(state_dim).
        log_decay = torch.log(
            torch.arange(1, config.state_dim + 1, dtype=torch.float32)
        )
        self.A = nn.Parameter(log_decay.repeat(config.d_inner, 1))
        self.D = nn.Parameter(torch.ones(config.d_inner))

        # Output projection back to the residual width.
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=False)

        # Pre-normalisation.
        self.norm = nn.LayerNorm(config.d_model)

    def forward(self, x):
        """Apply the block.

        Args:
            x: Tensor of shape ``[batch, seq_len, d_model]``.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.
        """
        residual = x
        seq_len = x.size(1)
        x = self.norm(x)

        # [batch, seq_len, d_inner * 2] -> SSM branch and gate branch.
        x, z = self.in_proj(x).chunk(2, dim=-1)

        # Causal depthwise convolution over the sequence axis. Slicing to
        # seq_len (rather than [:-(k - 1)]) also works for conv_kernel == 1,
        # where the negative-index form would yield an empty tensor.
        x = self.conv1d(x.transpose(1, 2))[:, :, :seq_len]
        x = F.silu(x.transpose(1, 2))  # [batch, seq_len, d_inner]

        # Selective parameters: [batch, seq_len, dt_rank + 2 * state_dim].
        params = self.x_proj(x)
        dt, B, C = torch.split(
            params,
            [self.config.dt_rank, self.config.state_dim, self.config.state_dim],
            dim=-1,
        )
        # The step size must be positive so that exp(A * dt) stays in (0, 1];
        # a raw linear output can be negative and make the recurrence blow up.
        dt = F.softplus(self.dt_proj(dt))  # [batch, seq_len, d_inner]

        # Discretise the state-space model.
        A = -torch.exp(self.A.float())  # [d_inner, state_dim]
        # Both are [batch, seq_len, d_inner, state_dim].
        discrete_A = torch.exp(A[None, None, :, :] * dt[:, :, :, None])
        discrete_B = dt[:, :, :, None] * B[:, :, None, :]

        # Sequential scan (simple reference implementation).
        state = torch.zeros(
            x.size(0),
            self.config.d_inner,
            self.config.state_dim,
            device=x.device,
            dtype=discrete_A.dtype,
        )
        outputs = []
        for i in range(seq_len):
            state = discrete_A[:, i] * state + discrete_B[:, i] * x[:, i, :, None]
            # [batch, d_inner, state_dim] @ [batch, state_dim, 1]
            y = (state @ C[:, i].unsqueeze(-1)).squeeze(-1) + self.D * x[:, i]
            outputs.append(y)
        x = torch.stack(outputs, dim=1)  # [batch, seq_len, d_inner]

        # Gate with the second branch.
        x = x * torch.sigmoid(z)

        # Back to [batch, seq_len, d_model] plus the residual connection.
        return self.out_proj(x) + residual

class MambaTextClassifier(nn.Module):
    """Text classifier: embedding -> stacked Mamba blocks -> mean pool -> linear.

    Args:
        config: Object exposing ``vocab_size``, ``d_model``, ``n_layer`` and
            ``num_classes`` (plus whatever ``MambaBlock`` reads).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        # Token embedding table.
        self.embedding = nn.Embedding(config.vocab_size, config.d_model)

        # Stack of Mamba blocks.
        self.layers = nn.ModuleList([
            MambaBlock(config) for _ in range(config.n_layer)
        ])

        # Classification head.
        self.classifier = nn.Linear(config.d_model, config.num_classes)

    def forward(self, input_ids, attention_mask=None):
        """Compute class logits for a batch of token-id sequences.

        Args:
            input_ids: Integer tensor of shape ``[batch, seq_len]``.
            attention_mask: Optional ``[batch, seq_len]`` tensor (bool or
                numeric) that is non-zero at real tokens and zero at padding.
                When given, padding positions are excluded from the mean
                pooling. When omitted, every position is averaged.

        Returns:
            Tensor of shape ``[batch, num_classes]``.
        """
        x = self.embedding(input_ids)  # [batch, seq_len, d_model]

        for layer in self.layers:
            x = layer(x)

        if attention_mask is None:
            pooled = x.mean(dim=1)  # [batch, d_model]
        else:
            mask = attention_mask.to(x.dtype).unsqueeze(-1)  # [batch, seq_len, 1]
            # clamp guards against division by zero for all-padding rows.
            token_count = mask.sum(dim=1).clamp(min=1)
            pooled = (x * mask).sum(dim=1) / token_count

        return self.classifier(pooled)  # [batch, num_classes]

# Example usage: build the classifier and run one forward pass on random ids.
if __name__ == "__main__":
    config = MambaConfig(
        vocab_size=30000,
        num_classes=14,
        d_model=256,
        n_layer=4
    )

    # Use the best accelerator that is actually present instead of assuming
    # Apple MPS, so the demo also runs on CUDA or CPU-only machines.
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model = MambaTextClassifier(config)
    model.to(device)

    # Dummy input; the upper bound follows the configured vocabulary size.
    batch_size = 32
    seq_len = 128
    input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len)).to(device)

    # Forward pass only, so no autograd graph is needed.
    with torch.no_grad():
        logits = model(input_ids)
    print("输出logits形状:", logits.shape)  # expected: [32, 14]

输出logits形状: torch.Size([32, 14])
