参考资料：[LLM MOE的进化之路，从普通简化 MOE，到 sparse_moe，再到 deepseek 使用的 share_expert_sparse_moe](https://bruceyuan.com/llms-zero-to-hero/the-way-of-moe-model-evolution.html)

# 1.基础版本MOE

In [1]:
import torch
import torch.nn as nn 
import torch.nn.functional as F

In [2]:
class BasicExpert(nn.Module):
    """A single expert: one affine projection from feature_in to feature_out.

    An expert can be as simple as a single linear layer (as here); it could
    also be an MLP, or a more elaborate MLP with a SwiGLU activation.
    """

    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.linear = nn.Linear(feature_in, feature_out)

    def forward(self, x):
        """Apply the linear projection over the last dimension of ``x``."""
        projected = self.linear(x)
        return projected

In [None]:
class BasicMOE(nn.Module):
    """Dense mixture of experts: every expert runs on every input row.

    The gate yields a softmax distribution over the experts for each row of
    ``x``; the result is the gate-weighted sum of all expert outputs.
    """

    def __init__(self, feature_in, feature_out, expert_number):
        super().__init__()
        self.experts = nn.ModuleList(
            [BasicExpert(feature_in, feature_out) for _ in range(expert_number)]
        )
        # The gate maps each input row to one logit per expert; after softmax
        # these logits become the mixing weights.
        self.gate = nn.Linear(feature_in, expert_number)

    def forward(self, x):
        """Mix all expert outputs.

        x: (batch, feature_in) -> returns (batch, feature_out).
        """
        # (batch, 1, expert_number): mixing weights laid out as a row vector
        mix = F.softmax(self.gate(x), dim=-1).unsqueeze(1)
        # (batch, expert_number, feature_out): expert outputs stacked on dim 1
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)
        # (1 x E) @ (E x feature_out) per batch row -> (batch, 1, feature_out)
        combined = mix @ stacked
        return combined.squeeze(1)


def test_basic_moe():
    """Smoke test: run BasicMOE(4 -> 3, 2 experts) on a random (2, 4) batch."""
    inputs = torch.rand(2, 4)
    model = BasicMOE(4, 3, 2)
    print(model(inputs))


test_basic_moe()


tensor([[-0.0433, -0.2434,  0.2969],
        [-0.2261, -0.2531,  0.2997]], grad_fn=<SqueezeBackward1>)


# 2.SparseMoe
Gate的作用从“直接作为每个expert的权重”变成了“选择每个expert的概率”，然后只取概率最高的top_k个专家。  
流程：Gate生成logits → softmax得到概率 → 取top_k个概率并重新归一化 → 生成专家掩码mask。  
mask形状是(expert_number, top_k, batch_size * seq_len)  

计算输出时，首先建立一个全零矩阵，形状为(batch_size * seq_len, hidden_dim)。  
然后让每个专家去找分配给它的token。以专家0为例：通过`idx, top_x = torch.where(mask[0])`得到`top_x`（token在batch_size * seq_len中的位置）和`idx`（专家0是该token的第几个top_k选择），把第top_x个token送入专家0，再乘上对应的路由权重`router_weights[top_x, idx]`，  
然后用`index_add_`累加到前面建立的全零矩阵中，所有专家处理完后输出累加后的矩阵。

In [None]:
class MOEConfig:
    """Plain hyper-parameter container for the sparse MoE layers.

    hidden_dim: token feature size (input and output size of every expert).
    expert_num: number of routed experts.
    top_k: number of experts the router selects per token.
    shared_expert_num: number of always-active shared experts (default 2).
    """

    def __init__(self, hidden_dim, expert_num, top_k, shared_expert_num=2):
        self.hidden_dim, self.expert_num = hidden_dim, expert_num
        self.top_k, self.shared_expert_num = top_k, shared_expert_num

class MOERouter(nn.Module):
    """Token router: scores every expert and keeps the top-k per token."""

    def __init__(self, config):
        super().__init__()
        # Linear gate producing one logit per expert for each token.
        self.gate = nn.Linear(config.hidden_dim, config.expert_num)
        self.expert_num = config.expert_num
        self.topk = config.top_k

    def forward(self, x):
        """Route flattened tokens to experts.

        x: (b * s, hidden_dim).

        Returns a 4-tuple:
            router_logits: (b * s, expert_num) raw gate outputs.
            router_weights: (b * s, top_k) top-k probabilities renormalised to
                sum to 1 per token, cast back to x.dtype.
            selected_experts: (b * s, top_k) indices of the chosen experts.
            expert_mask: (expert_num, top_k, b * s) one-hot routing mask.
        """
        router_logits = self.gate(x)
        # Softmax is computed in float32 (dtype=torch.float).
        probs = F.softmax(router_logits, dim=-1, dtype=torch.float)
        top_probs, selected_experts = probs.topk(self.topk, dim=-1)

        # Renormalise so each token's selected weights sum to 1.
        router_weights = (top_probs / top_probs.sum(dim=-1, keepdim=True)).to(x.dtype)

        # (b * s, top_k, expert_num) -> (expert_num, top_k, b * s)
        one_hot = F.one_hot(selected_experts, num_classes=self.expert_num)
        expert_mask = one_hot.permute(2, 1, 0)
        return router_logits, router_weights, selected_experts, expert_mask

class SparseMOE(nn.Module):
    """Token-level sparse MoE: each token is processed only by its top-k experts.

    Every expert maps hidden_dim -> hidden_dim. A MOERouter selects top_k
    experts per token, and the expert outputs are combined using the
    renormalised router weights.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.topk = config.top_k
        # One routed expert per config.expert_num.
        self.experts = nn.ModuleList(
            BasicExpert(config.hidden_dim, config.hidden_dim)
            for _ in range(config.expert_num)
        )
        self.router = MOERouter(config)

    def forward(self, x):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: (batch_size, seq_len, hidden_dim).

        Returns:
            (final_hidden_states, router_logits):
                final_hidden_states has the same shape as x;
                router_logits is (batch_size * seq_len, expert_num), e.g. for
                an auxiliary load-balancing loss.
        """
        batch_size, seq_len, hidden_dim = x.shape
        tokens = x.reshape(batch_size * seq_len, hidden_dim)
        router_logits, router_weights, _, expert_mask = self.router(tokens)

        final_hidden_states = torch.zeros_like(tokens)
        for expert_idx, expert_layer in enumerate(self.experts):
            # expert_mask[expert_idx] is (top_k, batch_size * seq_len).
            # `slot` is the top-k position (row) and `token_idx` is the
            # token's index in the flattened batch (column).
            slot, token_idx = torch.where(expert_mask[expert_idx])

            # (selected_token_number, hidden_dim), scaled by each token's
            # router weight for this expert.
            expert_out = expert_layer(tokens[token_idx])
            weighted = expert_out * router_weights[token_idx, slot].unsqueeze(-1)

            # Scatter-add into each token's row. A token routed to several
            # experts accumulates all of their weighted contributions.
            final_hidden_states.index_add_(0, token_idx, weighted.to(tokens.dtype))

        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits

In [40]:
def test_token_level_moe():
    """Smoke test SparseMOE on a random (batch=2, seq=4, hidden=16) input.

    Prints the shape of the mixed hidden states and of the router logits.
    """
    x = torch.rand(2, 4, 16)
    token_level_moe = SparseMOE(MOEConfig(16, 4, 2))
    hidden_states, router_logits = token_level_moe(x)
    print(hidden_states.shape, router_logits.shape)


test_token_level_moe()

torch.Size([2, 4, 16]) torch.Size([8, 4])


# 3.ShareExpert SparseMOE（deepseek 版本）
这个版本的MOE比SparseMOE多了共享专家：共享专家会处理所有token，所有共享专家的输出直接相加，再与SparseMOE的输出（其内部已按路由权重加权）相加，得到最终结果。

In [51]:
class ShareExpertMOE(nn.Module):
    """DeepSeek-style MoE: always-active shared experts plus a routed SparseMOE.

    Every token goes through all shared experts, whose outputs are summed
    without weighting. It also goes through the routed SparseMOE. The two
    results are added.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.routed_experts_moe = SparseMOE(config)
        self.shared_experts = nn.ModuleList(
            [
                BasicExpert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.shared_expert_num)
            ]
        )

    def forward(self, x):
        """Combine the shared-expert and routed-expert outputs.

        Args:
            x: (batch_size, seq_len, hidden_dim).

        Returns:
            (output, router_logits): output has the same shape as x;
            router_logits are the routed SparseMOE's logits.
        """
        # Routed (sparse) part. SparseMOE also unpacks x as a 3-D tensor.
        routed_output, router_logits = self.routed_experts_moe(x)

        if not self.shared_experts:
            # shared_expert_num == 0: torch.stack([]) would raise, so the
            # layer degenerates to the plain routed SparseMOE.
            return routed_output, router_logits

        # (shared_expert_num, b, s, hidden_dim) -> sum over experts -> (b, s, hidden_dim)
        shared_output = torch.stack(
            [expert(x) for expert in self.shared_experts], dim=0
        ).sum(dim=0)

        return shared_output + routed_output, router_logits

In [52]:
def test_share_expert_moe():
    """Smoke test ShareExpertMOE (2 routed experts, top-2) on a (2, 4, 16) input.

    Prints the output shape and the router-logits shape.
    """
    x = torch.rand(2, 4, 16)
    model = ShareExpertMOE(MOEConfig(16, 2, 2))
    output, router_logits = model(x)
    print(output.shape, router_logits.shape)


test_share_expert_moe()

torch.Size([2, 4, 16]) torch.Size([8, 2])


# 4.训练方式

In [55]:
# Test code; the loss function below was originally generated with DeepSeek.
def switch_load_balancing_loss(
    router_logits: torch.Tensor,
    num_experts: int,
    *,
    top_k: int = 2,
    z_loss_weight: float = 0.001,
) -> torch.Tensor:
    """Compute the MoE load-balancing auxiliary loss plus the router z-loss.

    aux_loss = num_experts * sum_{k, e} f[k, e] * P[e], where
        f[k, e] is the fraction of tokens whose k-th choice is expert e, and
        P[e] is the mean router probability assigned to expert e.
    When both f and P are uniform, aux_loss equals the number of selected
    experts per token.

    z_loss = mean over tokens of logsumexp(router_logits) ** 2 (the ST-MoE
    router z-loss). It penalises large router logits.

    Args:
        router_logits: (num_tokens, num_experts), e.g. batch_size * seq_len tokens.
        num_experts: number of routed experts (size of the last dim of router_logits).
        top_k: experts counted per token; clamped to num_experts. The default
            of 2 matches the previous hard-coded value.
        z_loss_weight: multiplier applied to the z-loss term.

    Returns:
        Scalar tensor: aux_loss + z_loss_weight * z_loss.
    """
    router_probs = torch.softmax(router_logits, dim=-1)  # (num_tokens, num_experts)

    # torch.topk raises if k > num_experts, so clamp (e.g. for a single expert).
    k = min(top_k, num_experts)
    _, selected_experts = torch.topk(router_probs, k=k, dim=-1)  # (num_tokens, k)

    # One-hot of each token's k choices: (num_tokens, k, num_experts)
    mask = F.one_hot(selected_experts, num_experts).float()

    # f[k, e]: fraction of tokens whose k-th choice is expert e -> (k, num_experts)
    tokens_per_expert = mask.mean(dim=0)
    # P[e]: mean routing probability per expert -> (num_experts,)
    router_prob_per_expert = router_probs.mean(dim=0)

    aux_loss = torch.sum(tokens_per_expert * router_prob_per_expert) * num_experts

    # logsumexp (not the raw logits) is what enters the softmax normaliser.
    z_loss = torch.logsumexp(router_logits, dim=-1).square().mean()

    return aux_loss + z_loss * z_loss_weight

def test_moe_training():
    """Train ShareExpertMOE for 100 steps on random regression data.

    Each step draws a fresh random input/target pair and minimises
    MSE + 0.01 * load-balancing loss with Adam (lr=0.001). The losses are
    printed every 10 steps.
    """
    batch_size, seq_len, hidden_dim = 32, 16, 32
    num_batches = 100

    # Model: 4 routed experts (top-2) plus 2 shared experts.
    config = MOEConfig(
        hidden_dim=hidden_dim,
        expert_num=4,
        top_k=2,
        shared_expert_num=2,
    )
    model = ShareExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for batch in range(num_batches):
        # A fresh random (input, target) pair for every step.
        inputs = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)

        output, router_logits = model(inputs)

        # Prediction loss plus a small load-balancing term.
        mse_loss = F.mse_loss(output, target)
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_num)
        total_loss = mse_loss + 0.01 * aux_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")


# Run the training test
test_moe_training()

Batch 0, Loss: 1.9102 (MSE: 1.8899, Aux: 2.0244)
Batch 10, Loss: 1.7203 (MSE: 1.7000, Aux: 2.0341)
Batch 20, Loss: 1.5251 (MSE: 1.5047, Aux: 2.0345)
Batch 30, Loss: 1.3906 (MSE: 1.3702, Aux: 2.0343)
Batch 40, Loss: 1.3177 (MSE: 1.2972, Aux: 2.0422)
Batch 50, Loss: 1.2340 (MSE: 1.2137, Aux: 2.0300)
Batch 60, Loss: 1.1510 (MSE: 1.1308, Aux: 2.0262)
Batch 70, Loss: 1.1400 (MSE: 1.1196, Aux: 2.0397)
Batch 80, Loss: 1.1433 (MSE: 1.1229, Aux: 2.0393)
Batch 90, Loss: 1.0955 (MSE: 1.0750, Aux: 2.0519)
