# 描述
以下是一个简化版的混合专家模型（MoE）实现，仅保留核心功能（Top-K路由和专家计算）。

In [28]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """Feed-forward expert: ``Linear(input_dim, hidden_dim) -> ReLU -> Linear(hidden_dim, input_dim)``.

    The final projection maps back to ``input_dim``, so the output feature
    size equals the input feature size.
    """

    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        # Build the projections in the same order as before (up, then down)
        # so parameter initialisation consumes the RNG identically.
        up_proj = nn.Linear(input_dim, hidden_dim)
        down_proj = nn.Linear(hidden_dim, input_dim)
        # Keep everything under ``self.net`` as a Sequential so state_dict keys
        # stay ``net.0.*`` / ``net.2.*``.
        self.net = nn.Sequential(up_proj, nn.ReLU(), down_proj)

    def forward(self, x):
        """Run ``x`` through the expert MLP and return the result."""
        out = self.net(x)
        return out

class SimpleMoE(nn.Module):
    """Minimal mixture-of-experts layer with top-k routing (no load balancing).

    A linear gate scores every expert for each token. The ``top_k`` highest
    scoring experts are selected, and their logits are renormalised with a
    softmax over those ``top_k`` entries only. The token's output is the
    weighted sum of the selected experts' outputs. Experts that receive no
    tokens are skipped.

    Args:
        num_experts: Number of experts (and number of gate outputs).
        input_dim: Feature size of the input and of the output.
        hidden_dim: Hidden size of each expert.
        top_k: Number of experts each token is routed to; must satisfy
            ``1 <= top_k <= num_experts``.

    Raises:
        ValueError: If ``top_k`` is outside ``[1, num_experts]``.
    """

    def __init__(self, num_experts=4, input_dim=64, hidden_dim=128, top_k=2):
        super().__init__()
        # Fail at construction instead of inside torch.topk on the first forward.
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k

        # Expert pool and the gating network that scores each expert per token.
        self.experts = nn.ModuleList(
            [Expert(input_dim, hidden_dim) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Route each token to its top-k experts and combine their outputs.

        Args:
            x: Tensor whose last dimension is ``input_dim``. Any number of
                leading dimensions is accepted (e.g. ``[batch, input_dim]`` or
                ``[batch, seq, input_dim]``). They are flattened into a single
                token axis for routing and restored on output.

        Returns:
            Tensor with the same shape as ``x``.
        """
        orig_shape = x.shape
        tokens = x.reshape(-1, orig_shape[-1])  # [num_tokens, input_dim]

        # Score all experts, keep the top-k per token, softmax over the kept ones only.
        gate_logits = self.gate(tokens)  # [num_tokens, num_experts]
        top_logits, indices = torch.topk(gate_logits, self.top_k, dim=-1)
        weights = F.softmax(top_logits, dim=-1)  # [num_tokens, top_k]

        results = torch.zeros_like(tokens)
        for expert_id, expert in enumerate(self.experts):
            # topk yields distinct experts per row, so each token appears at most
            # once here. slot_idx is the top-k position that selected this expert.
            token_idx, slot_idx = torch.where(indices == expert_id)
            if token_idx.numel() == 0:
                continue  # no token routed to this expert
            expert_output = expert(tokens[token_idx])
            expert_weights = weights[token_idx, slot_idx].unsqueeze(-1)
            contribution = (expert_output * expert_weights).to(results.dtype)
            results.index_add_(0, token_idx, contribution)

        return results.reshape(orig_shape)

# 测试样例
if __name__ == "__main__":
    # Demo configuration: a small batch routed through 4 experts, 2 per sample.
    n_samples, feat_dim = 4, 64
    n_experts, k = 4, 2

    # Build the model first, then draw the input, preserving the RNG draw order.
    moe = SimpleMoE(num_experts=n_experts, input_dim=feat_dim, top_k=k)
    x = torch.randn(n_samples, feat_dim)

    output = moe(x)

    # Input and output shapes should both be torch.Size([4, 64]).
    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    # First three output features of sample 0.
    print("示例输出:", output[0][:3])
    # Raw gate logits of sample 0 for every expert (before top-k and softmax).
    print("门控权重示例:", moe.gate(x)[0])

输入形状: torch.Size([4, 64])
输出形状: torch.Size([4, 64])
示例输出: tensor([-0.0109, -0.0823, -0.2853], grad_fn=<SliceBackward0>)
门控权重示例: tensor([ 0.8467,  0.7944, -0.4276, -0.3861], grad_fn=<SelectBackward0>)
