In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

在正式了解 DeepSeek MoE 之前，我们先来看一下传统的 Mixture of Experts (MoE) 模型是怎么做的。

MoE 模型中的每个专家都是一个 MLP 模型，MLP 的输入维度和输出维度相同。下面的实现采用带门控的结构：`w2(silu(w1(x)) * w3(x))`，因此包含三个线性层。

In [4]:
class MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward block used as a single expert.

    Computes ``dropout(w2(silu(w1(x)) * w3(x)))``. ``w1`` and ``w3`` are two
    parallel bias-free projections from ``dim`` to ``hidden_dim``; the SiLU of
    the first gates the second element-wise, and ``w2`` projects the product
    back to ``dim``, so the output has the same last dimension as the input.

    Args:
        dim: Size of the last dimension of both the input and the output.
        hidden_dim: Size of the intermediate (gated) representation.
        dropout: Probability passed to ``nn.Dropout`` for the output.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.):
        super().__init__()
        # Layers are created in the order w1, w2, w3 so parameter registration
        # (and therefore seeded initialisation) stays in that order.
        # Gate branch: dim -> hidden_dim, passed through SiLU in forward().
        self.w1 = nn.Linear(in_features=dim, out_features=hidden_dim, bias=False)
        # Output projection: hidden_dim -> dim.
        self.w2 = nn.Linear(in_features=hidden_dim, out_features=dim, bias=False)
        # Value branch: dim -> hidden_dim, multiplied by the gate in forward().
        self.w3 = nn.Linear(in_features=dim, out_features=hidden_dim, bias=False)
        # Dropout applied to the block's output.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the gated feed-forward transform to ``x`` (last dim ``dim``)."""
        gate = F.silu(self.w1(x))
        value = self.w3(x)
        projected = self.w2(gate * value)
        return self.dropout(projected)
    
# Smoke test: push a random batch of 64 vectors of width 128 through the MLP
# and display the output shape.
mlp_model = MLP(dim=128, hidden_dim=512)

input_tensor = torch.randn((64, 128))
output_tensor = mlp_model(input_tensor)
output_tensor.shape

torch.Size([64, 128])

In [6]:
class MoE(nn.Module):
    """Dense (soft-routed) mixture of experts.

    Every expert is evaluated on every input; the expert outputs are then
    combined with softmax gate weights. No top-k selection is performed.

    Args:
        num_experts: Number of expert MLPs.
        dim: Size of the last dimension of the input and of the output.
        hidden_dim: Hidden width passed to each expert MLP.
        dropout: Dropout probability passed to each expert MLP.
    """

    def __init__(self, num_experts: int, dim: int, hidden_dim: int, dropout: float = 0.):
        super().__init__()
        self.num_experts = num_experts
        self.experts = nn.ModuleList([MLP(dim, hidden_dim, dropout) for _ in range(num_experts)])
        # Router: one logit per expert (linear layer with bias).
        self.gate = nn.Linear(dim, num_experts)

    def forward(self, x):
        """Return the gate-weighted sum of all expert outputs.

        Args:
            x: Tensor of shape ``(..., dim)``. Any number of leading dimensions
                is accepted, e.g. ``(batch, dim)`` or ``(batch, seq, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        gate_logits = self.gate(x)
        # The prints are kept on purpose: the notebook uses them to show shapes.
        print("gate_logits.shape = ", gate_logits.shape)
        # gate_logits.shape = (..., num_experts)
        gate_probs = F.softmax(gate_logits, dim=-1)
        print("gate_probs.shape = ", gate_probs.shape)
        # gate_probs.shape = (..., num_experts)
        # Stack on the second-to-last axis (not a fixed axis 1) so the expert
        # axis always sits right before the feature axis, whatever the number
        # of leading dimensions in x.
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        print("expert_outputs.shape = ", expert_outputs.shape)
        # expert_outputs.shape = (..., num_experts, dim)
        # Broadcast the gate weights over the feature axis, then reduce over
        # the expert axis.
        return torch.sum(gate_probs.unsqueeze(-1) * expert_outputs, dim=-2)

# Run the MoE on a random batch of 64 vectors of width 128 (forward() prints
# the intermediate shapes) and display the output shape.
moe_model = MoE(num_experts=4, dim=128, hidden_dim=512)
x = torch.randn((64, 128))
y = moe_model(x)
y.shape

gate_logits.shape =  torch.Size([64, 4])
gate_probs.shape =  torch.Size([64, 4])
expert_outputs.shape =  torch.Size([64, 4, 128])


torch.Size([64, 128])