**MiMo-V2-Flash核心设计**
1. 混合注意力: 128-token滑动窗口注意力(SWA) + 全局注意力(GA), ratio = 5:1
2. MoE 结构: 256个专家，每 token 激活 8 个
3. 轻量 MTP 模块: 稠密 FFN + SWA， 用于投机解码加速
4. 关键优化: RoPE位置编码、RMSNorm、FP16 混合精度

In [1]:
# 加载依赖库
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

In [2]:
# RMSNorm
class RMSNorm(nn.Module):
    """RMSNorm 归一化"""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

In [3]:
# RoPE位置编码
class RoPE(nn.Module):
    """旋转位置编码（仅应用前 64 维）"""
    def __init__(self, dim: int, max_seq_len: int = 256000):
        super().__init__()
        self.max_seq_len = max_seq_len
        # 仅对前 64 维应用 RoPE
        self.rope_dim = min(64, dim)
        theta = 1.0 / (10000 ** (torch.arange(0, self.rope_dim, 2) / self.rope_dim))
        self.register_buffer("theta", theta)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        seq_len = x.shape[1]
        device = x.device
        # 生成位置索引
        pos = torch.arange(seq_len, device=device).unsqueeze(1)
        # 计算旋转矩阵参数
        rope = pos * self.theta.unsqueeze(0)
        rope = torch.cat([rope, rope], dim=-1)
        # 仅对前 rope_dim 维应用旋转
        if self.rope_dim < self.dim:
            rope = torch.cat([rope, torch.zeros(seq_len, self.dim - self.rope_dim, device=device)], dim=-1)
        # 构建旋转矩阵并应用
        cos = rope.cos().unsqueeze(0)
        sin = rope.sin().unsqueeze(0)
        x_rope = x[..., :self.rope_dim]
        x_rope = torch.cat([x_rope[..., ::2] * cos[..., ::2] - x_rope[..., 1::2] * sin[..., 1::2],
                           x_rope[..., 1::2] * cos[..., ::2] + x_rope[..., ::2] * sin[..., 1::2]], dim=-1)
        if self.rope_dim < self.dim:
            x = torch.cat([x_rope, x[..., self.rope_dim:]], dim=-1)
        else:
            x = x_rope
        return x

In [4]:
# Expert
class Expert(nn.Module):
    """MoE 单个专家网络(稠密 FFN)"""
    def __init__(self, dim: int, hidden_dim: int = 2048):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.act = F.silu
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc2(self.act(self.fc1(x)))

In [5]:
# MoEFFN
class MoEFFN(nn.Module):
    """MoE 前馈网络(256专家，激活 8 个)"""
    def __init__(self, dim: int, num_experts: int = 256, top_k: int = 8):
        super().__init__()
        self.dim = dim
        self.num_experts = num_experts
        self.top_k = top_k
        # 专家网络
        self.experts = nn.ModuleList([Expert(dim) for _ in range(num_experts)])
        # 专家理由
        self.router = nn.Linear(dim, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape
        x_flat = x.reshape(-1, dim)  # [batch_size * seq_len, dim]

        # 路由选择专家
        router_logits = self.router(x_flat)  # [batch_size * seq_len, num_experts]
        top_k_logits, top_k_indices = torch.topk(router_logits, self.top_k, dim=-1)  # [batch_size * seq_len, top_k]
        top_k_weights = F.softmax(top_k_logits, dim=-1)  # [batch_size * seq_len, top_k]

        # 收集专家输出
        output = torch.zeros_like(x_flat)
        for i in range(self.top_k):
            expert_idx = top_k_indices[:, i]
            weight = top_k_weights[:, i].unsqueeze(-1)  # [batch_size * seq_len, 1]
            # 按专家分组计算
            for expert_id in range(self.num_experts):
                mask = (expert_idx == expert_id)
                if mask.any():
                    expert_output = self.experts[expert_id](x_flat[mask])
                    output[mask] += weight[mask] * expert_output

        return output.reshape(batch_size, seq_len, dim)

In [6]:
# SlidingWindowAttention
class SlidingWindowAttention(nn.Module):
    """滑动窗口注意力(SWA)"""
    def __init__(self, dim: int, num_heads: int = 64, num_kv_heads: int = 8, window_size: int = 128):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = dim // num_heads
        self.window_size = window_size
        # 注意力 sink 偏置
        self.sink_bias = nn.Parameter(torch.tensor(0.0))

        # QKV 投影
        self.q_proj = nn.Linear(dim, num_heads * self.head_dim)
        self.k_proj = nn.Linear(dim, num_kv_heads * self.head_dim)
        self.v_proj = nn.Linear(dim, num_kv_heads * self.head_dim)
        self.o_proj = nn.Linear(num_heads * self.head_dim, dim)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        batch_size, seq_len, dim = x.shape

        # QKV 投影
        q = self.q_proj(x).reshape(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_heads, seq_len, head_dim]
        k = self.k_proj(x).reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_kv_heads, seq_len, head_dim]
        v = self.v_proj(x).reshape(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)  # [batch_size, num_kv_heads, seq_len, head_dim]

        # 重复 KV 头以匹配 Q 头数量（GQA 机制）
        if self.num_kv_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
            v = v.repeat_interleave(self.num_heads // self.num_kv_heads, dim=1)
        
        # 滑动窗口注意力计算
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / sqrt(self.head_dim)  # [batch_size, num_heads, seq_len, seq_len]

        # 应用滑动窗口注意力掩码
        if mask is None:
            mask = torch.ones(seq_len, seq_len, device=x.device, dtype=torch.bool)
            for i in range(seq_len):
                start = max(0, i - self.window_size + 1)
                mask[i, :start] = False
        attn_weights = attn_weights.masked_fill(~mask, -1e18)

        # 应用 attention sink 偏置
        m_i = torch.max(torch.max(attn_weights, dim=-1, keepdim=True)[0], self.sink_bias)
        attn_weights = attn_weights - m_i
        attn_weights = torch.exp(attn_weights)
        sink_term = torch.exp(self.sink_bias - m_i)
        attn_weights = attn_weights / (attn_weights.sum(dim=-1, keepdim=True) + sink_term)

        # 注意力输出
        out = torch.matmul(attn_weights, v)
        out = out.transpose(1, 2).reshape(batch_size, seq_len, dim)
        out = self.o_proj(out)

        return out

# GlobalAttention
class GlobalAttention(SlidingWindowAttention):
    """全局注意力(GA) - 滑动窗口设为序列长度"""
    def __init__(self, dim: int, num_heads: int = 64, num_kv_heads: int = 4):
        super().__init__(dim, num_heads, num_kv_heads, window_size=1000000)  # 超大窗口模拟全局注意力

In [7]:
# SWABlock
class SWABlock(nn.Module):
    """滑动窗口注意力块(SWA + MoE FFN)"""
    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = SlidingWindowAttention(dim)
        self.norm2 = RMSNorm(dim)
        self.ffn = MoEFFN(dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x

In [8]:
# GABlock
class GABlock(nn.Module):
    """全局注意力块(GA + MoE FFN)"""
    def __init__(self, dim: int, use_dense_ffn: bool = False):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = GlobalAttention(dim)
        self.norm2 = RMSNorm(dim)
        # 第一个块使用稠密 FFN，其余使用 MoE FFN
        if use_dense_ffn:
            self.ffn = nn.Sequential(
                nn.Linear(dim, 16384),
                nn.SiLU(),
                nn.Linear(16384, dim)
            )
        else:
            self.ffn = MoEFFN(dim)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        return x

In [9]:
# MTP
class MTPBlock(nn.Module):
    """多任务预测块(MTP)"""
    def __init__(self, dim: int):
        super().__init__()
        self.norm1 = RMSNorm(dim)
        self.attn = SlidingWindowAttention()
        self.norm2 = RMSNorm(dim)
        # 稠密 FFN（轻量设计)
        self.ffn = nn.Sequential(
            nn.Linear(dim, 1024),
            nn.SiLU(),
            nn.Linear(1024, dim)
        )
        # MTP 预测头
        self.predict_head = nn.Linear(dim, dim)
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        x = x + self.attn(self.norm1(x), mask)
        x = x + self.ffn(self.norm2(x))
        x = self.predict_head(x)
        return x

In [12]:
# MiMoV2Flash
class MiMoV2Flash(nn.Module):
    """MiMoV2Flash 模型"""
    def __init__(self, vocab_size: int = 32000, dim: int = 4096, num_hybrid_blocks: int = 8):
        super().__init__()
        self.dim = dim
        # 嵌入层
        self.embeddings = nn.Embedding(vocab_size, dim)
        # RoPE 位置编码
        self.rope = RoPE(dim)
        # 模型主结构: 1 个 GA 块(稠密 FFN) + 混合块(5 个 SWA + 1 个 GA)
        self.layers = nn.ModuleList()
        # 第一个块：GA + 稠密 FFN
        self.layers.append(GABlock(dim, use_dense_ffn=True))
        # 混合块：5 个 SWA + 1 个 GA
        for _ in range(num_hybrid_blocks):
            self.layers.append([SWABlock(dim) for _ in range(5)])
            self.layers.append(GABlock[dim])
        # 输出归一化
        self.norm = RMSNorm(dim)
        # 语言模型头
        self.lm_head = nn.Linear(dim, vocab_size)
        # 共享权重
        self.lm_head.weight = self.embeddings.weight
        # 3 层 MTP 模块
        self.mtp_layers = nn.ModuleList(MTPBlock(dim) for _ in range(3))

    def forward(self, input_ids: torch.Tensor, use_mtp: bool = False, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Args:
            input_ids: 输出 token ID，形状为 (batch_size, seq_len)
            use_mtp: 是否使用 MTP 模块
            mask: 掩码，形状为 (batch_size, seq_len)
        """
        x = self.embeddings(input_ids)
        x = self.rope(x)
        
        # 主模型前向传播
        for layer in self.layers:
            x = layer(x, mask)
        lm_logits = self.lm_head(self.norm(x))
        
        # MTP 模块生成 draft tokens
        mtp_logits = None
        if use_mtp:
            mtp_x = x
            for mtp_layer in self.mtp_layers:
                mtp_x = mtp_layer(mtp_x, mask)
            mtp_logits = self.lm_head(self.norm(mtp_x))

        return {"lm_logits": lm_logits, "mtp_logits": mtp_logits}

In [None]:
# ------------------------------
# 测试代码
# ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")

# 初始化模型（简化版，减少专家数量以降低显存占用）
# 注：实际部署时可恢复num_experts=256
model = MiMoV2Flash(vocab_size=32000, dim=4096, num_hybrid_blocks=8).to(device)
target_num_experts = 32
for layer in model.layers:
    # 判断当前层是否包含 MoEFFN（SWABlock 和 GABlock 都有 ffn 属性）
    if hasattr(layer, "ffn") and isinstance(layer.ffn, MoEFFN):
        layer.ffn.num_experts = target_num_experts
        # 重新初始化专家网络（可选，确保参数适配新的专家数量）
        layer.ffn.experts = nn.ModuleList([Expert(layer.ffn.dim) for _ in range(target_num_experts)])

# 生成测试
batch_size = 2
seq_len = 128
input_ids = torch.randint(0, 32000, (batch_size, seq_len), device=device)

# 前向传播(开启 FP16混合精度加速)
with torch.autocast(device_type=device.type, dtype=torch.float16):
    output = model(input_ids, use_mtp=True)

# 输出结果验证
print(f"输入形状: {input_ids.shape}")
print(f"LM输出形状: {output['lm_logits'].shape}")
print(f"MTP输出形状: {output['mtp_logits'].shape}")
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

# 计算损失（模拟训练）
loss_fn = nn.CrossEntropyLoss()
lm_loss = loss_fn(output["lm_logits"].reshape(-1, 32000), input_ids[:, 1:].reshape(-1))
mtp_loss = loss_fn(output["mtp_logits"].reshape(-1, 32000), input_ids[:, 1:].reshape(-1))
total_loss = lm_loss + 0.1 * mtp_loss  # MTP损失权重0.1
total_loss.backward()

print(f"LM损失: {lm_loss.item():.4f}")
print(f"MTP损失: {mtp_loss.item():.4f}")
print("模型运行成功！")

使用设备: cuda


TypeError: list is not a Module subclass