## 基础MoE

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class FFNBasicExpert(nn.Module):
    def __init__(self, feature_in, feature_out):
        super().__init__()
        self.linear = nn.Linear(feature_in, feature_out)

    def forward(self, x):
        return self.linear(x)

In [3]:
class BasicMoE(nn.Module):
    def __init__(self, feature_in, feature_out, num_experts):
        super().__init__()
        self.experts = nn.ModuleList(
            FFNBasicExpert(feature_in, feature_out) 
            for _ in range(num_experts)
        )
        # gate 即选一个 expert 的权重
        # 将 input 送入 gate，分给 num_experts 个 expert
        # 得到 num_experts 个权重
        self.gate = nn.Linear(feature_in, num_experts)

    
    def forward(self, x):
        # x: (batch_size, feature_in), feature_in or hidden_dim
        expert_weights = self.gate(x)  # (batch_size, num_experts)
        expert_out_list = [
            expert(x).unsqueeze(1) for expert in self.experts
        ]  # 每个 expert 输出一个 (batch_size, feature_out) 

        expert_output = torch.concat(
            expert_out_list, dim=1
        )  # -> (b, num_experts, feature_out)

        ' 加权 '
        # expert_weights -> (b, num_experts)
        # expert_weights = F.softmax(expert_weights, dim=1)
        expert_weights = expert_weights.unsqueeze(1)  # -> (b, 1, num_experts)
        
        output = expert_weights @ expert_output  # -> (b, 1, feature_out)
        return output.squeeze(1)

In [4]:
x = torch.rand(4, 512)  # (batch_size, feature_in)
basic_moe = BasicMoE(512, 128, 5)  # (feature_in, feature_out, num_experts)
output = basic_moe(x)
print(output.shape)  # (4, 128)

torch.Size([4, 128])


## Sparse MoE

MoE 选择 topK 个专家，对选出的 K 个专家的输出进行加权求和，并把输入样本变成 LLM 中真实的输入 shape -> (batch_size, seq_len, hidden_dim)

In [5]:
class MoEConfig:
    def __init__(
        self,
        hidden_dim,
        expert_num,
        top_k,
        shared_experts_num=2,
    ):
        self.hidden_dim = hidden_dim
        self.expert_num = expert_num
        self.top_k = top_k
        self.shared_experts_num = shared_experts_num

class MoERouter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_num)
        # select top k experts
        self.top_k = config.top_k
        self.expert_num = config.expert_num
    
    def forward(self, x):
        # 计算路由 logits
        router_logits = self.gate(x)  # (batch_size * seq_len, expert_num)
        # 计算每个专家的概率
        router_probs = F.softmax(router_logits, dim=1, dtype=torch.float32)
        # 计算 top k 个专家的输出、索引
        router_weights, selected_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1
        )  # (batch_size * seq_len, top_k)

        # 重新归一化
        router_weights /= router_weights.sum(dim=-1, keepdim=True) 
        router_weights = router_weights.to(x.dtype)

        # 生成专家 mask
        expert_mask = F.one_hot(
            selected_indices,
            num_classes=self.expert_num
        )  # -> (batch_size * seq_len, top_k, expert_num)
        expert_mask = expert_mask.permute(2, 1, 0)
        # -> (expert_num, top_k, batch_size * seq_len)

        return router_logits, router_weights, selected_indices, expert_mask


class SparseMoE(nn.Module):
    # 稀疏 MOE 模型，这里每一个 token 都会过 topk 个专家
    # 得到对应token 的 hidden_embeddings
    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.expert_num = config.expert_num
        self.top_k = config.top_k

        # initialize experts
        self.experts = nn.ModuleList(
            FFNBasicExpert(config.hidden_dim, config.hidden_dim)
            for _ in range(config.expert_num)
        )

        self.router = MoERouter(config)
    
    def forward(self, x):
        # x: (b, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = x.size() 

        # 合并前两个维度，因为不是 Sample 维度了，而是 token 维度
        # 对 token 维度计算, x -reshape-> (batch * seq_len, hidden_dim)
        hidden_state = x.view(-1, hidden_dim)

        router_logits, router_weights, selected_indices, expert_mask = self.router(hidden_state)
        # selected_indices -> (batch_size * seq_len, top_k)
        # expert_mask -> (expert_num, top_k, batch_size * seq_len)

        # 初始化 final hidden states -> (batch_size * seq_len, hidden_dim)
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            device=hidden_state.device
        )

        # 遍历每个 expert
        # 把该 expert 的 token 的 hidden_states 加到 final_hidden_states 上
        # token 总数是 batch_size * seq_len
        for expert_idx in range(self.expert_num):
            expert_layer = self.experts[expert_idx]
            # 选择该 expert 的 mask (expert_num, top_k, batch_size * seq_len)
            cur_expert_mask = expert_mask[expert_idx]
            # -> (top_k, batch_size * seq_len)

            idx, top_x = torch.where(cur_expert_mask > 0)
            '''
            idx: 选择的（topK 中）第 i 个 expert；用于选择 weights
            top_x: 是 token 在 batch * seq_len 中的索引 （1 维的值）；选择 hidden_states
            '''

            # hidden_states 的 shape 是 (b * s, hidden_dim)
            # 需要取到 top_x 对应的 hidden_states
            cur_state = hidden_state.unsqueeze(0)[:, top_x, :].reshape(-1, hidden_dim)
            # unsqueeze(0) -> (1, batch_size * seq_len, hidden_dim)
            # cur_state -> (selected_token_num, hidden_dim)
            cur_state = expert_layer(cur_state)
            cur_token_router_weight = router_weights[top_x, idx]
            # -> (selected_token_num)
            cur_token_router_weight = cur_token_router_weight.unsqueeze(-1)
            # -> (selected_token_num, 1)

            cur_hidden_states = cur_state * cur_token_router_weight
            # -> (selected_token_num, hidden_dim)

            final_hidden_states.index_add_(
                0,
                top_x,
                cur_hidden_states.to(final_hidden_states.dtype)
            )
        
        # 还原到原来的 shape
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits

In [6]:
x = torch.rand(2, 4, 16)
config = MoEConfig(16, 2, 2)
token_level_moe = SparseMoE(config)
out = token_level_moe(x)

out[0].shape, out[1].shape

(torch.Size([2, 4, 16]), torch.Size([8, 2]))

### ShareExpert SparseMoE

相较于 SparseMoE 多了 shared experts 模型 ，即共享 token，所有 token 都经过这个 shared experts 模型，然后每个 token 会 用计算的 router 权重，选择 topK 个专家，再与共享的专家输出一起加权求和

In [7]:
class SharedExpertMoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.routed_experts_moe = SparseMoE(config)
        self.shared_experts = nn.ModuleList([
            FFNBasicExpert(
                config.hidden_dim, config.hidden_dim
            ) for _ in range(config.shared_experts_num)
        ])
    
    def forward(self, x):
        # x -> (batch_size, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = x.shape
        
        # 首先过 moe 模型
        sparse_moe_out, router_logits = self.routed_experts_moe(x)

        # 针对 x 的每一个, 过 shared experts
        shared_experts_output_list = [
            expert(x) for expert in self.shared_experts
        ]

        shared_experts_output = torch.stack(
            shared_experts_output_list,
            dim=0
        ).sum(dim=0)  # -> (batch_size, seq_len, hidden_dim)
        
        output = shared_experts_output + sparse_moe_out

        return output, router_logits

In [8]:
x = torch.rand(2, 4, 16)
config = MoEConfig(16, 2, 2)
shared_experts_moe = SharedExpertMoE(config)
out = shared_experts_moe(x)

out[0].shape, out[1].shape

(torch.Size([2, 4, 16]), torch.Size([8, 2]))

## Train

In [9]:
# 测试， loss 部分为 deepseek 生成；

def switch_load_balancing_loss(router_logits: torch.Tensor, num_experts: int) -> torch.Tensor:
    """
    计算 Switch Transformers 的负载均衡损失
    
    Args:
        router_logits: shape [batch_size * sequence_length, num_experts]
        num_experts: 专家数量
    
    Returns:
        total_loss: 总损失 = auxiliary_loss + z_loss
    """
    # 计算路由概率
    router_probs = torch.softmax(router_logits, dim=-1)  # [b*s, num_experts]
    
    # 获取每个token的最优专家
    _, selected_experts = torch.topk(router_probs, k=2, dim=-1)  # [b*s]
    
    # 创建one-hot矩阵表示选中的专家
    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()  # [b*s, num_experts]
    
    # 计算每个专家的期望负载 (理想情况下应该是 1/num_experts)
    expected_load = torch.ones_like(router_probs) / num_experts
    
    # 计算实际负载 (每个专家处理的token数量除以总token数量)
    # 在batch维度上计算平均值
    actual_load = mask.mean(dim=0)  # [num_experts]
    
    # 计算auxiliary loss
    # 这会惩罚负载分布与期望负载的差异
    aux_loss = torch.sum(actual_load * router_probs.mean(dim=0)) * num_experts
    
    # 计算z_loss (可选)
    # 这会惩罚过大的路由logits
    z_loss = torch.mean(torch.square(router_logits))
    z_loss_weight = 0.001  # 可调整的超参数
    
    # 总损失
    total_loss = aux_loss + z_loss * z_loss_weight
    
    return total_loss

def test_moe_training():
    # Create a simple dataset
    batch_size = 32
    seq_len = 16
    hidden_dim = 32
    num_batches = 100
    
    # Initialize model and optimizer
    config = MOEConfig(hidden_dim=hidden_dim, 
                      expert_number=4,
                      top_k=2,
                      shared_experts_number=2)
    model = ShareExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    
    # Training loop
    model.train()
    for batch in range(num_batches):
        # Generate random input data
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)
        
        # Forward pass
        output, router_logits = model(x)

        # Compute losses
        # MSE loss for prediction
        mse_loss = F.mse_loss(output, target)
        
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
        # Combined loss
        total_loss = mse_loss + 0.01 * aux_loss
        
        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
        
        if batch % 10 == 0:
            print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
                  f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")

# Run the training test
test_moe_training()

NameError: name 'MOEConfig' is not defined