# LLM的MoE-混合专家不同版本

## MoE（Mixture of Experts）设计想解决什么？  
1. 模型容量越大效果越好，但稠密模型的训练和推理成本会随参数量近似线性增长。
2. MoE 把“一个大网络”拆成“多个小网络（专家）”，每次只激活其中极少数几个专家，于是：
 * 参数总量继续膨胀 → 表达能力↑
 * 计算量几乎不变（只激活 k 个专家）→ 成本≈常数


## 核心模块：
### 1. 专家网络（Expert Network）

通常是普通 FFN（两个线性层 + 非线性），也可以是更复杂的子模型。

### 2. 路由/门控函数（Router / Gate）

决定每条输入 token 该去哪些专家。
数学形式：
$$ g(x) = \mathrm{softmax}(W_r x) \in \mathbb{R}^E $$
其中 $E$ 是专家数，$W_r \in \mathbb{R}^{E \times d}$ 是路由权重矩阵，$d$ 是输入 $x$ 的隐藏维度。

### 3. Top-k 路由策略

Switch-Transformer 用 k=1（Top-1 路由）
GShard 用 k=2（Top-2 路由）
只保留最大的 k 个概率，其余置零 $\rightarrow$ 稀疏。

### 4. 负载均衡损失（Load-Balancing Loss）

防止所有 token 都涌向少数“学霸”专家。
经典实现：
$$ L_{aux} = \alpha \cdot \sum_{i=1}^{E} f_i \cdot P_i $$
- $f_i$：当前 batch 中被分配给专家 $i$ 的 token 比例
- $P_i$：当前 batch 中所有 token 路由到专家 $i$ 的平均门控概率
- $\alpha$：辅助损失的权重系数

当路由均匀（$f_i = P_i = 1/E$）时 $L_{aux}$ 取得最小值，因此该损失会促使 token 在各专家之间均匀分配。

# 1. 基础版本MoE
![basic MOE](./img/basic-moe-model.png)

In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F

这里简单使用一个Linear层来作为这个专家网络，实际上需要换成FFN。

In [15]:
class BasicExpert(nn.Module):
    """A minimal expert: one affine map from ``feature_in`` to ``feature_out``.

    This is a stand-in for the feed-forward network a real MoE layer would use.
    """

    def __init__(self, feature_in, feature_out):
        super().__init__()
        # Keep the attribute name ``fc`` so state_dict keys stay ``fc.weight`` / ``fc.bias``.
        self.fc = nn.Linear(in_features=feature_in, out_features=feature_out)

    def forward(self, x):
        """Project the last dimension of ``x`` from ``feature_in`` to ``feature_out``."""
        projected = self.fc(x)
        return projected

In [16]:
class BasicMOE(nn.Module):
    """Dense mixture of experts: every expert runs and outputs are gate-weighted.

    Args:
        feature_in: Size of the input feature (hidden) dimension.
        feature_out: Size of each expert's output feature dimension.
        num_experts: Number of experts to mix.
    """

    def __init__(self, feature_in, feature_out, num_experts):
        super().__init__()
        # The gate maps each input vector to one logit per expert.
        self.gate = nn.Linear(feature_in, num_experts)
        self.experts = nn.ModuleList(
            [BasicExpert(feature_in, feature_out) for _ in range(num_experts)]
        )

    def forward(self, x):
        """Return the gate-weighted sum of all expert outputs.

        Args:
            x: Tensor of shape (..., feature_in), e.g. (b, f_in) or (b, s, f_in).

        Returns:
            Tensor of shape (..., feature_out).
        """
        # (..., num_experts). Normalise over the expert axis (last dim) rather
        # than a hard-coded dim=1, so inputs with extra leading dims also work.
        gate_probs = F.softmax(self.gate(x), dim=-1)

        # Stack the per-expert outputs on a new axis just before the features:
        # (..., num_experts, feature_out).
        expert_output = torch.stack(
            [expert(x) for expert in self.experts], dim=-2
        )

        # (..., 1, num_experts) @ (..., num_experts, feature_out)
        #   -> (..., 1, feature_out) -> (..., feature_out)
        mixed = gate_probs.unsqueeze(-2) @ expert_output
        return mixed.squeeze(-2)

In [17]:
# test
def test_BasicMOE():
    """Smoke-test BasicMOE and print the output shape.

    Feeds 4 random samples of width 512 through a 4-expert model whose
    experts project to width 128.
    """
    batch_size, feature_in, feature_out, num_experts = 4, 512, 128, 4
    sample = torch.rand(batch_size, feature_in)
    model = BasicMOE(feature_in, feature_out, num_experts)
    print(model(sample).shape)


test_BasicMOE()

torch.Size([4, 128])


# 2.SparseMoE(现代大模型)
与基础版的区别：SparseMoE 只为每个 token 选择 top-K 个专家，并对这些专家的输出按路由权重加权求和；同时输入改为大模型的真实输入形状 (batch_size, seq_len, hidden_dim)。

### 1).topk()



* 选出对应维度（默认dim=-1）的topK（前K最大值），返回值和索引

In [18]:
# Draw 11 random values, then keep the two largest along the last dimension.
a = torch.randn(11)
top2_value, indices = torch.topk(a, k=2, dim=-1)
print(a, "\n", top2_value, indices)

tensor([ 1.5430,  0.7799,  1.8441, -0.3242, -0.7400, -1.2461, -0.0387,  1.0213,
         1.3348, -0.4689, -0.0103]) 
 tensor([1.8441, 1.5430]) tensor([2, 0])


In [19]:
import torch
# 45 consecutive floats arranged as 3 slices, each of shape (3, 5).
a = torch.arange(45, dtype=torch.float).reshape(3, 3, 5)
print('a shape:', a.shape)        # torch.Size([3, 3, 5])
print('a:\n', a)

# topk along dim=0 compares the 3 slices element-wise and keeps the 2 largest,
# so only dim 0 shrinks (3 -> 2); the remaining dims are unchanged.
values, indices = torch.topk(a, k=2, dim=0)
print('values shape:', values.shape) # torch.Size([2, 3, 5])
print('indices shape:', indices.shape) # torch.Size([2, 3, 5])

print('values:\n', values)
print('indices:\n', indices)

a shape: torch.Size([3, 3, 5])
a:
 tensor([[[ 0.,  1.,  2.,  3.,  4.],
         [ 5.,  6.,  7.,  8.,  9.],
         [10., 11., 12., 13., 14.]],

        [[15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.]],

        [[30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44.]]])
values shape: torch.Size([2, 3, 5])
indices shape: torch.Size([2, 3, 5])
values:
 tensor([[[30., 31., 32., 33., 34.],
         [35., 36., 37., 38., 39.],
         [40., 41., 42., 43., 44.]],

        [[15., 16., 17., 18., 19.],
         [20., 21., 22., 23., 24.],
         [25., 26., 27., 28., 29.]]])
indices:
 tensor([[[2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2],
         [2, 2, 2, 2, 2]],

        [[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1]]])


### 2).one_hot()

* torch.nn.functional.one_hot(tensor, num_classes=-1) -> Tensor
* 参数：  
 tensor：任意形状的 整型 张量，元素取值范围 [0, num_classes-1]。num_classes（可选）类别总数。不填（缺省 -1）时，函数内部用 tensor.max()+1 推断  
* 返回值：  
 新张量比输入 多一维，最后一维长度 = num_classes，
 对应位置为 1，其余为 0，dtype 为 torch.int64


In [20]:
# Four class ids; the largest is 3, so one_hot infers num_classes = 4.
labels = torch.tensor([0, 2, 1, 3])   # shape (4,)
one_hot = F.one_hot(labels, num_classes=-1)  # -1 (the default) infers the class count
print(one_hot)
print(one_hot.shape)

tensor([[1, 0, 0, 0],
        [0, 0, 1, 0],
        [0, 1, 0, 0],
        [0, 0, 0, 1]])
torch.Size([4, 4])


SparseMoE实现代码

In [21]:
# 设置参数
class MOEConfig:
    """Plain container for the hyper-parameters read by the MoE modules.

    Args:
        hidden_dim: Feature width used for the router and expert linear layers.
        expert_number: Number of routed experts.
        top_k: Number of experts selected per token.
        shared_expert_number: Number of always-on shared experts (default 2).
    """

    def __init__(
        self,
        hidden_dim,
        expert_number,
        top_k,
        shared_expert_number=2,
    ):
        # Store every argument unchanged under the same attribute name.
        self.hidden_dim, self.expert_number = hidden_dim, expert_number
        self.top_k, self.shared_expert_number = top_k, shared_expert_number

In [22]:
# 如何为每个token选专家
class MOERouter(nn.Module):
    """Select the top-k experts for every token and compute their mixing weights.

    Args:
        config: Object exposing ``hidden_dim``, ``expert_number`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        # One logit per expert; forward() keeps only the top_k highest.
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        """Route tokens to experts.

        Args:
            x: Token features of shape (num_tokens, hidden_dim). Any extra
                leading dimensions, e.g. (b, s, hidden_dim), are flattened
                into the token axis first.

        Returns:
            A tuple ``(router_logits, router_weights, selected_expert_indices,
            expert_mask)`` where, with ``n`` the number of tokens:

            * router_logits: (n, expert_number) raw gate outputs.
            * router_weights: (n, top_k) probabilities of the chosen experts,
              renormalised to sum to 1 per token, in ``x``'s dtype.
            * selected_expert_indices: (n, top_k) ids of the chosen experts.
            * expert_mask: (expert_number, top_k, n) one-hot mask; entry
              ``[e, k, t]`` is 1 when expert ``e`` is token ``t``'s k-th choice.
        """
        # Flatten to (n, hidden_dim) so the permute below always sees 3 dims.
        tokens = x.reshape(-1, x.size(-1))
        router_logits = self.gate(tokens)

        # Probabilities over the expert axis. Use dim=-1 so the axis matches
        # the topk call below; computed in float32 for numerical stability.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # Gradients flow through the selected values (not through the indices).
        router_weights, selected_expert_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1,
        )  # both (n, top_k)

        # Renormalise so each token's kept weights sum to 1, then cast back
        # to the input dtype (the softmax above was forced to float32).
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_expert_indices,
            num_classes=self.expert_number,
        )  # (n, top_k, expert_number)

        # Put the expert axis first so callers can index one expert at a time.
        expert_mask = expert_mask.permute(2, 1, 0)  # (expert_number, top_k, n)

        return router_logits, router_weights, selected_expert_indices, expert_mask

In [23]:
# 主模型
class SparseMOE(nn.Module):
    """Sparse mixture of experts: each token is processed only by its top-k experts.

    Args:
        config: Object exposing ``hidden_dim``, ``expert_number`` and ``top_k``
            (it is also passed unchanged to ``MOERouter``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

        # Routed experts; each maps hidden_dim -> hidden_dim.
        self.experts = nn.ModuleList(
            [
                BasicExpert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.expert_number)
            ]
        )

        self.router = MOERouter(config)

    def forward(self, x):
        """Apply the sparse MoE layer.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            A tuple ``(final_hidden_states, router_logits)`` with shapes
            (batch_size, seq_len, hidden_dim) and
            (batch_size, seq_len, expert_number).
        """
        batch_size, seq_len, hidden_dim = x.size()

        # Work per token: (b * s, hidden_dim). reshape (not view) so that
        # non-contiguous inputs, e.g. a transposed tensor, are accepted too.
        hidden_states = x.reshape(-1, hidden_dim)

        router_logits, router_weights, _, expert_masks = self.router(
            hidden_states
        )

        # Accumulator for the weighted expert outputs, (b * s, hidden_dim).
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # Visit every expert and add its weighted output for the tokens that
        # selected it.
        for expert_idx, expert_layer in enumerate(self.experts):
            # expert_masks[expert_idx] has shape (top_k, b * s).
            # router_weight_idx: which top-k slot (0 .. top_k - 1) the expert
            #   occupies for the token -> column into router_weights.
            # top_x: position of the token in the flattened (b * s) axis
            #   -> row into hidden_states / final_hidden_states.
            router_weight_idx, top_x = torch.where(expert_masks[expert_idx])

            # Gather this expert's tokens: (selected_token_number, hidden_dim).
            current_states = expert_layer(hidden_states[top_x])

            # (selected_token_number,) -> (selected_token_number, 1) so it
            # broadcasts over the hidden dimension.
            current_token_router_weight = router_weights[
                top_x, router_weight_idx
            ].unsqueeze(-1)

            current_hidden_states = current_states * current_token_router_weight

            # index_add_ needs matching dtypes; the cast is a no-op when the
            # expert output already has the accumulator's dtype.
            final_hidden_states.index_add_(
                0,
                top_x,
                current_hidden_states.to(final_hidden_states.dtype),
            )

        # Restore the (batch_size, seq_len, ...) layout for both outputs.
        final_hidden_states = final_hidden_states.reshape(
            batch_size, seq_len, hidden_dim
        )
        router_logits = router_logits.reshape(batch_size, seq_len, -1)
        return final_hidden_states, router_logits

In [24]:
# 测试代码
# Smoke test for SparseMOE
if __name__ == "__main__":
    # Fixed seed so parameter init and the random input are reproducible.
    torch.manual_seed(42)

    # Configuration: 8 experts, each token routed to 2 of them.
    hidden_dim, expert_num, top_k = 64, 8, 2
    config = MOEConfig(hidden_dim=hidden_dim, expert_number=expert_num, top_k=top_k)

    # Build the model first, then the input (both draw from the seeded RNG).
    moe = SparseMOE(config)
    batch, seq_len = 2, 4
    x = torch.randn(batch, seq_len, hidden_dim)

    # Forward pass
    out, router_logits = moe(x)

    print("Input shape :", x.shape)
    print("Output shape:", out.shape)            # expected (2, 4, 64)
    print("Router logits shape:", router_logits.shape)  # expected (2, 4, 8)
    print("Output nan? ", torch.isnan(out).any())

Input shape : torch.Size([2, 4, 64])
Output shape: torch.Size([2, 4, 64])
Router logits shape: torch.Size([2, 4, 8])
Output nan?  tensor(False)


# 3.share_expert_sparseMoE

In [25]:
class ShareExpertMOE(nn.Module):
    """Sparse MoE plus shared experts that process every token.

    Args:
        config: Object exposing ``hidden_dim`` and ``shared_expert_number``
            (it is also passed unchanged to ``SparseMOE``).
    """

    def __init__(self, config):
        super().__init__()

        self.moe_model = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.shared_expert_number)
            ]
        )

    def forward(self, x):
        """Return the routed output plus the summed shared-expert outputs.

        Args:
            x: Tensor of shape (b, s, hidden_dim).

        Returns:
            A tuple ``(output, router_logits)``: ``output`` has the shape of
            ``x``; ``router_logits`` is returned unchanged from the sparse MoE.
        """
        # Routed (sparse) path first.
        sparse_moe_out, router_logits = self.moe_model(x)

        # With no shared experts there is nothing to add; torch.stack would
        # raise on an empty list, so return the routed output directly.
        if len(self.shared_experts) == 0:
            return sparse_moe_out, router_logits

        # Every shared expert sees every token; each output is (b, s, hidden_dim).
        shared_experts_out = torch.stack(
            [expert(x) for expert in self.shared_experts], dim=0
        ).sum(dim=0)

        # Combine the routed and shared contributions.
        return sparse_moe_out + shared_experts_out, router_logits


def test_share_expert_moe():
    """Smoke-test ShareExpertMOE and print both output shapes.

    Uses a (2, 4, 16) random input with 2 routed experts and top_k = 2.
    """
    x = torch.rand(2, 4, 16)
    config = MOEConfig(hidden_dim=16, expert_number=2, top_k=2)
    share_expert_moe = ShareExpertMOE(config)
    mixed_out, logits = share_expert_moe(x)
    print(mixed_out.shape, logits.shape)


test_share_expert_moe()

torch.Size([2, 4, 16]) torch.Size([2, 4, 2])
