## 2 - SparseMoE

In [1]:
import torch
import torch.nn as nn  
import torch.nn.functional as F

In [2]:
class Expert(nn.Module):
    """A single expert: one plain linear projection.

    Args:
        input_dim: Size of the last dimension of the incoming tensor.
        output_dim: Size of the last dimension of the produced tensor.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Ordinary fully connected layer; it acts on the last dimension, so
        # an input shaped (batch_size, input_dim) works.
        self.expert = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project ``x`` through the linear layer and return the result."""
        projected = self.expert(x)
        return projected

In [3]:
class BasicMoE(nn.Module):
    """Dense mixture of experts: every expert runs on every sample.

    Each sample's output is the softmax-gated weighted sum of all expert
    outputs for that same sample.

    Args:
        num_experts: Number of linear experts.
        input_dim: Feature size of the input.
        output_dim: Feature size produced by each expert.
    """

    def __init__(self, num_experts, input_dim, output_dim):
        super().__init__()
        self.experts = nn.ModuleList(
            [nn.Linear(input_dim, output_dim) for _ in range(num_experts)]
        )
        # One gating logit per expert.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Mix the expert outputs for a batch.

        Args:
            x: Tensor of shape (batch_size, input_dim).

        Returns:
            Tensor of shape (batch_size, output_dim).
        """
        # (batch_size, num_experts)
        expert_weights = self.gate(x)

        # Each element: (batch_size, 1, output_dim)
        expert_outputs_list = [expert(x).unsqueeze(1) for expert in self.experts]
        print(expert_outputs_list[0].shape)  # shape trace for the walkthrough

        # (batch_size, num_experts, output_dim)
        expert_outputs = torch.cat(expert_outputs_list, dim=1)
        print(expert_outputs.shape)  # shape trace for the walkthrough

        # Turn the gate logits into a probability distribution over experts.
        gate_outputs = F.softmax(expert_weights, dim=1)

        # Batched matmul, one row vector per sample:
        #   (batch_size, 1, num_experts) @ (batch_size, num_experts, output_dim)
        #   -> (batch_size, 1, output_dim)
        # Multiplying the un-expanded 2-D gate matrix would broadcast it
        # across the batch and blend gate weights of different samples.
        weighted_expert_outputs = gate_outputs.unsqueeze(1) @ expert_outputs

        # Drop the singleton dimension -> (batch_size, output_dim)
        return weighted_expert_outputs.squeeze(1)
def test():
    """Smoke-test BasicMoE on a random batch and print the output shape."""
    batch = torch.randn(4, 512)
    moe = BasicMoE(num_experts=4, input_dim=512, output_dim=128)
    print(moe(batch).shape)


test()

torch.Size([4, 1, 128])
torch.Size([4, 4, 128])
torch.Size([4, 128])


In [11]:
class MoEconfig:
    """Plain container for the mixture-of-experts hyper-parameters.

    Args:
        top_k: Number of experts selected for each token.
        expert_num: Total number of experts.
        hidden_dim: Hidden dimension of the experts.
        share_expert_num: Number of shared experts (defaults to 2).
    """

    def __init__(self, top_k, expert_num, hidden_dim, share_expert_num=2):
        # Routing settings: experts chosen per token / experts available.
        self.top_k, self.expert_num = top_k, expert_num
        # Number of shared experts.
        self.share_expert_num = share_expert_num
        # Hidden dimension of every expert.
        self.hidden_dim = hidden_dim

class MoErouter(nn.Module):
    """Top-k router that assigns each token to ``top_k`` experts.

    Args:
        hidden_num: Feature size of the incoming hidden states.
        expert_num: Number of experts to route between.
        top_k: Number of experts picked per token.
    """

    def __init__(self, hidden_num, expert_num, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_num, expert_num)
        self.expert_num = expert_num
        self.top_k = top_k

    def forward(self, hidden_state):
        """Route a flat batch of tokens.

        Args:
            hidden_state: Tensor of shape (batch_size * seq_len, hidden_dim).

        Returns:
            A 4-tuple ``(router_logits, router_weights, selected_experts,
            experts_mask)``:
            * router_logits: raw gate output, (tokens, expert_num).
            * router_weights: top-k probabilities renormalised to sum to 1
              per token, (tokens, top_k), cast to ``hidden_state.dtype``.
            * selected_experts: indices of the chosen experts, (tokens, top_k).
            * experts_mask: one-hot selection, (expert_num, top_k, tokens).
        """
        # Linear gate: one logit per expert for every token.
        router_logits = self.gate(hidden_state)

        # Probability over experts, then keep the k most likely ones.
        probs = router_logits.softmax(dim=-1)
        top_probs, selected_experts = probs.topk(self.top_k, dim=-1)

        # Renormalise so the kept probabilities sum to one for each token.
        normaliser = top_probs.sum(dim=-1, keepdim=True)
        router_weights = (top_probs / normaliser).to(hidden_state.dtype)

        # One-hot is (tokens, top_k, expert_num); permute reverses the axes to
        # (expert_num, top_k, tokens) so the mask can be indexed by expert.
        one_hot = F.one_hot(selected_experts, num_classes=self.expert_num)
        experts_mask = one_hot.permute(2, 1, 0)

        return router_logits, router_weights, selected_experts, experts_mask

        
        
class SparseMoE(nn.Module):
    """Token-level sparse mixture of experts.

    Every token is sent to the ``top_k`` experts chosen by the router and
    the expert outputs are combined with the router's normalised weights.

    Args:
        config: Object exposing ``expert_num``, ``share_expert_num``,
            ``hidden_dim`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        self.expert_num = config.expert_num
        self.share_expert_num = config.share_expert_num
        self.hidden_dim = config.hidden_dim
        # One routed expert per slot; each maps hidden_dim -> hidden_dim.
        self.experts = nn.ModuleList(
            [Expert(self.hidden_dim, self.hidden_dim) for _ in range(self.expert_num)]
        )
        self.router = MoErouter(self.hidden_dim, self.expert_num, config.top_k)

    def forward(self, x):
        """Apply the routed experts.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tuple ``(final_hidden_states, router_logits)`` where the first
            has the shape of ``x`` and the second is
            (batch_size * seq_len, expert_num).
        """
        batch_size, seq_len, hidden_dim = x.size()
        # Flatten to (batch_size * seq_len, hidden_dim); reshape (not view)
        # so non-contiguous inputs are accepted as well.
        hidden_state = x.reshape(-1, hidden_dim)
        router_logits, router_weights, _, expert_mask = self.router(hidden_state)

        # The accumulator must share dtype and device with the tokens:
        # index_add_ rejects a source whose dtype differs from the
        # destination, so a hard-coded float32 buffer breaks other dtypes.
        final_hidden_states = torch.zeros_like(hidden_state)

        for expert_idx, expert_layer in enumerate(self.experts):
            # expert_mask[expert_idx] is (top_k, tokens): rows are the top-k
            # slot, columns the token position that picked this expert.
            slot_idx, token_idx = torch.where(expert_mask[expert_idx])
            expert_input = hidden_state[token_idx]
            # Scale each expert output by the routing weight of its slot.
            gate_weight = router_weights[token_idx, slot_idx].unsqueeze(-1)
            weighted_output = expert_layer(expert_input) * gate_weight
            # Scatter-add the contributions back to their token rows.
            final_hidden_states.index_add_(
                0, token_idx, weighted_output.to(hidden_state.dtype)
            )

        final_hidden_states = final_hidden_states.view(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits


def test_token_level_moe():
    """Smoke-test SparseMoE and print the shapes of both of its outputs."""
    tokens = torch.rand(2, 4, 16)
    moe = SparseMoE(MoEconfig(top_k=2, expert_num=2, hidden_dim=16))
    hidden_states, router_logits = moe(tokens)
    print(hidden_states.shape, router_logits.shape)


test_token_level_moe()


torch.Size([2, 4, 16]) torch.Size([8, 2])


## DeepSeekMoE


In [12]:

class DeepSeekMoE(nn.Module):
    """Routed sparse experts plus always-on shared experts.

    The output is the SparseMoE result with the sum of all shared-expert
    outputs added to it.

    Args:
        config: Object exposing ``hidden_dim`` and ``share_expert_num`` in
            addition to whatever ``SparseMoE`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        # NOTE: despite its name this attribute holds the SparseMoE module,
        # not a config; the name is kept so existing attribute access and
        # state-dict keys stay valid.
        self.DeepSeekMoEconfig = SparseMoE(config)
        # Shared experts: applied to every token, no routing.
        self.shared_experts = nn.ModuleList(
            [
                Expert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.share_expert_num)
            ]
        )

    def forward(self, x):
        """Run routed and shared experts.

        Args:
            x: Tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            Tuple ``(output, router_logits)``; ``output`` is
            (batch_size, seq_len, hidden_dim) and ``router_logits`` is
            (batch_size * seq_len, expert_num).
        """
        # Routed path first.
        sparse_moe_out, router_logits = self.DeepSeekMoEconfig(x)

        # With no shared experts there is nothing to add; torch.stack would
        # raise on an empty list, so return the routed output directly.
        if len(self.shared_experts) == 0:
            return sparse_moe_out, router_logits

        # Each shared expert yields (batch_size, seq_len, hidden_dim);
        # stack them on a new leading axis and sum it away.
        shared_expert_out = torch.stack(
            [expert(x) for expert in self.shared_experts], dim=0
        ).sum(dim=0)

        return sparse_moe_out + shared_expert_out, router_logits
    
def test_DeepSeekMoE():
    """Smoke-test DeepSeekMoE and print the shapes of both of its outputs."""
    tokens = torch.rand(2, 4, 16)
    moe = DeepSeekMoE(MoEconfig(top_k=2, expert_num=2, hidden_dim=16))
    hidden_states, router_logits = moe(tokens)
    print(hidden_states.shape, router_logits.shape)


test_DeepSeekMoE()



torch.Size([2, 4, 16]) torch.Size([8, 2])


In [13]:

def switch_load_balancing_loss(
    router_logits: torch.Tensor,
    num_experts: int,
    top_k: int = 2,
    z_loss_weight: float = 0.001,
) -> torch.Tensor:
    """Compute a Switch-Transformers-style load-balancing loss.

    Args:
        router_logits: Shape [batch_size * sequence_length, num_experts].
        num_experts: Number of experts.
        top_k: Number of experts selected per token when measuring the
            actual load. Defaults to 2, the value previously hard-coded.
        z_loss_weight: Multiplier applied to the z-loss term. Defaults to
            0.001, the value previously hard-coded.

    Returns:
        Scalar tensor: ``auxiliary_loss + z_loss_weight * z_loss``.
    """
    # Routing probabilities, [tokens, num_experts].
    router_probs = torch.softmax(router_logits, dim=-1)

    # Indices of the top-k experts for every token, [tokens, top_k].
    _, selected_experts = torch.topk(router_probs, k=top_k, dim=-1)

    # One-hot selection, [tokens, top_k, num_experts].
    mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()

    # Fraction of tokens sent to each expert, per top-k slot: [top_k, num_experts].
    actual_load = mask.mean(dim=0)

    # Mean routing probability of each expert, [num_experts]; it broadcasts
    # against every top-k slot of ``actual_load``.
    mean_probs = router_probs.mean(dim=0)

    # Auxiliary loss: load fraction times mean probability, summed over
    # slots and experts, scaled by the number of experts.
    aux_loss = torch.sum(actual_load * mean_probs) * num_experts

    # z-loss: penalises large router logits.
    z_loss = torch.mean(torch.square(router_logits))

    return aux_loss + z_loss * z_loss_weight

def test_moe_training():
    """Train DeepSeekMoE on random data for 100 steps, logging every 10th.

    Each step regresses random targets with an MSE loss and adds the
    load-balancing loss scaled by 0.01.
    """
    # Problem size.
    batch_size, seq_len, hidden_dim = 32, 16, 32
    num_batches = 100

    # Model and optimizer.
    config = MoEconfig(
        top_k=2,
        expert_num=4,
        hidden_dim=hidden_dim,
        share_expert_num=2,
    )
    model = DeepSeekMoE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    sample_shape = (batch_size, seq_len, hidden_dim)
    for batch in range(num_batches):
        # Fresh random inputs and targets each step.
        x = torch.randn(sample_shape)
        target = torch.randn(sample_shape)

        # Forward pass.
        output, router_logits = model(x)

        # Prediction loss plus the weighted load-balancing term.
        mse_loss = F.mse_loss(output, target)
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_num)
        total_loss = mse_loss + 0.01 * aux_loss

        # Backward pass and parameter update.
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Report only on every tenth step.
        if batch % 10:
            continue
        print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
              f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")


# Run the training test
test_moe_training()

Batch 0, Loss: 1.8591 (MSE: 1.8390, Aux: 2.0148)
Batch 10, Loss: 1.6580 (MSE: 1.6379, Aux: 2.0058)
Batch 20, Loss: 1.4946 (MSE: 1.4745, Aux: 2.0068)
Batch 30, Loss: 1.3924 (MSE: 1.3724, Aux: 2.0051)
Batch 40, Loss: 1.2909 (MSE: 1.2708, Aux: 2.0165)
Batch 50, Loss: 1.2380 (MSE: 1.2179, Aux: 2.0148)
Batch 60, Loss: 1.2032 (MSE: 1.1831, Aux: 2.0148)
Batch 70, Loss: 1.1571 (MSE: 1.1369, Aux: 2.0175)
Batch 80, Loss: 1.1068 (MSE: 1.0866, Aux: 2.0238)
Batch 90, Loss: 1.0898 (MSE: 1.0696, Aux: 2.0190)
