# 完全从零写一个 MOE 大模型（LLM）
Build a miniMoE-LLM from scratch
从零开始构建 MoE 模型，从基础讲起，共讲解三个版本
1. 基础版本，理解 MoE
2. SparseMoE，了解大模型怎么训练 MoE LLM
3. ShareExpert SparseMoE，了解 DeepSeek 训练 MoE 模型的算法


## 版本1： 基础版MOE


In [1]:
from dataclasses import dataclass
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class BasicExpert(nn.Module):
    """Minimal expert network: a single affine projection.

    An expert can be as simple as one linear layer (as here), a full MLP,
    or a gated MLP such as SwiGLU; this class keeps the simplest form.

    Args:
        feature_in: Size of the last dimension of the input.
        feature_out: Size of the last dimension of the output.
    """

    def __init__(self, feature_in: int, feature_out: int) -> None:
        super().__init__()
        self.linear = nn.Linear(in_features=feature_in, out_features=feature_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` from ``feature_in`` to ``feature_out`` features."""
        projected = self.linear(x)
        return projected


In [3]:
class BasicMOE(nn.Module):
    """Dense mixture of experts: every expert runs on every input.

    A linear gate produces one score per expert; the softmax of those
    scores is used to form a weighted sum of all expert outputs.

    Args:
        feature_in: Input feature size.
        feature_out: Output feature size of every expert.
        expert_number: How many experts to create.
    """

    def __init__(self, feature_in, feature_out, expert_number):
        super().__init__()
        expert_layers = []
        for _ in range(expert_number):
            expert_layers.append(BasicExpert(feature_in, feature_out))
        self.experts = nn.ModuleList(expert_layers)
        # The gate scores each expert for a given input.
        self.gate = nn.Linear(feature_in, expert_number)

    def forward(self, x):
        """Mix all expert outputs for ``x`` of shape (batch, feature_in).

        Returns:
            Tensor of shape (batch, feature_out).
        """
        # (batch, expert_number): mixing weights, one row per sample.
        gate_probs = F.softmax(self.gate(x), dim=-1)

        # (batch, expert_number, feature_out): one slice per expert.
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)

        # (batch, 1, expert_number) @ (batch, expert_number, feature_out)
        mixed = gate_probs.unsqueeze(1) @ stacked
        return mixed.squeeze(1)

In [4]:
def test_basic_moe():
    """Smoke-test BasicMOE: print the output and its shape for a tiny batch."""
    sample = torch.rand(2, 4)

    # feature_in=4, feature_out=3, expert_number=2
    moe = BasicMOE(4, 3, 2)
    mixed = moe(sample)
    print(mixed)
    print(mixed.shape)


test_basic_moe()

tensor([[-0.7145,  0.0609, -0.8199],
        [-0.6657, -0.3102, -0.5012]], grad_fn=<SqueezeBackward1>)
torch.Size([2, 3])


## 版本2： SparseMoE

In [5]:
class MOERouter(nn.Module):
    """Top-k router: picks experts per token and returns their gate weights.

    Args:
        hidden_dim: Feature size of each token fed to the gate.
        expert_number: Total number of experts to choose from.
        top_k: Number of experts activated for each token.
    """

    def __init__(self, hidden_dim, expert_number, top_k):
        super().__init__()
        self.gate = nn.Linear(hidden_dim, expert_number)
        self.expert_number = expert_number
        self.top_k = top_k

    def forward(self, hidden_states):
        """Route a 2-D batch of tokens.

        Args:
            hidden_states: Token features; callers in this file pass a
                flattened (b * s, hidden_dim) tensor.

        Returns:
            A 4-tuple ``(router_logits, router_weights, selected_experts,
            expert_mask)``:

            * router_logits: raw gate scores, (b * s, expert_number).
            * router_weights: top-k probabilities renormalised to sum to 1
              per token, cast to the input dtype, (b * s, top_k).
            * selected_experts: indices of the chosen experts, (b * s, top_k).
            * expert_mask: one-hot of the chosen experts permuted to
              (expert_number, top_k, b * s).
        """
        router_logits = self.gate(hidden_states)

        # Softmax is computed in float32 regardless of the input dtype.
        probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # torch.topk returns the k largest values along the last dim
        # together with their indices, e.g.
        #   torch.topk(torch.arange(1., 6.), 3)
        #   -> values=[5., 4., 3.], indices=[4, 3, 2]
        topk = torch.topk(probs, self.top_k, dim=-1)
        top_probs, selected_experts_idx = topk.values, topk.indices

        # Renormalise so the kept weights of each token sum to one.
        normalizer = top_probs.sum(dim=-1, keepdim=True)
        router_weights = (top_probs / normalizer).to(hidden_states.dtype)

        # F.one_hot appends a class dimension, e.g. an index tensor of
        # shape (3, 2) with 3 classes becomes (3, 2, 3). Here the result is
        # (b * s, top_k, expert_number), then reordered so the expert
        # dimension comes first.
        one_hot = F.one_hot(selected_experts_idx, num_classes=self.expert_number)
        expert_mask = one_hot.permute(2, 1, 0)

        return router_logits, router_weights, selected_experts_idx, expert_mask


@dataclass(eq=False)
class MOEconfig:
    """Hyper-parameters shared by the MoE layers in this file.

    ``eq=False`` keeps the identity-based equality and hashing of a plain
    class; the dataclass only contributes ``__init__`` and ``__repr__``.
    """

    # Feature size of each token.
    hidden_dim: int
    # Number of routed experts.
    expert_number: int
    # Number of routed experts activated per token.
    top_k: int
    # Number of always-on shared experts.
    shared_experts_number: int = 2

class SparseMOE(nn.Module):
    """Token-level sparse mixture of experts.

    Each token is sent to the ``top_k`` experts chosen by the router, and
    the expert outputs are combined using the router weights.

    Args:
        config: Object exposing ``hidden_dim``, ``expert_number`` and
            ``top_k`` attributes.
    """

    def __init__(self, config):
        super().__init__()

        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number
        self.top_k = config.top_k

        self.experts = nn.ModuleList(
            [
                BasicExpert(self.hidden_dim, self.hidden_dim)
                for _ in range(self.expert_number)
            ]
        )

        self.router = MOERouter(self.hidden_dim, self.expert_number, self.top_k)

    def forward(self, x):
        """Apply the sparse MoE layer.

        Args:
            x: Tensor of shape (b, s, hidden_dim).

        Returns:
            A pair ``(final_hidden_states, router_logits)`` where
            ``final_hidden_states`` has shape (b, s, hidden_dim) and
            ``router_logits`` has shape (b * s, expert_number).
        """
        batch_size, seq_len, hidden_dim = x.size()

        # Flatten batch and sequence: routing is per token, not per sample.
        # reshape (not view) so non-contiguous inputs are accepted too.
        hidden_states = x.reshape(-1, hidden_dim)  # (b * s, hidden_dim)

        # router_weights: (b * s, top_k)
        # expert_mask:    (expert_number, top_k, b * s)
        router_logits, router_weights, _, expert_mask = self.router(hidden_states)

        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for expert_idx, expert_layer in enumerate(self.experts):
            # expert_mask[expert_idx] is (top_k, b * s). torch.where yields
            # row indices (which top-k slot picked this expert) and column
            # indices (which token positions were routed here).
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Tokens assigned to this expert: (selected_token_number, hidden_dim)
            current_states = hidden_states[top_x]

            # Scale each expert output by that token's weight for this expert.
            current_hidden_states = (
                expert_layer(current_states)
                * router_weights[top_x, idx].unsqueeze(-1)
            )

            # Accumulate exactly once. The earlier version ran index_add_
            # and then ``final_hidden_states[top_x] += ...`` as well, which
            # added every expert contribution twice.
            final_hidden_states.index_add_(
                0, top_x, current_hidden_states.to(final_hidden_states.dtype)
            )

        # Restore the (b, s, hidden_dim) layout.
        final_hidden_states = final_hidden_states.reshape(
            batch_size, seq_len, hidden_dim
        )

        return final_hidden_states, router_logits



In [8]:
def test_token_level_moe():
    """Smoke-test SparseMOE: print the shapes of its two outputs."""
    tokens = torch.rand(2, 4, 16)
    # hidden_dim=16, expert_number=2, top_k=2
    moe = SparseMOE(MOEconfig(16, 2, 2))
    hidden, logits = moe(tokens)
    print(hidden.shape, logits.shape)


test_token_level_moe()

torch.Size([2, 4, 16]) torch.Size([8, 2])


## 版本3：ShareExpert SparseMoE （deepseek 版本）

In [100]:
class ShareExpertMOE(nn.Module):
    """Sparse MoE plus always-on shared experts (DeepSeek-style layout).

    The routed output of ``SparseMOE`` is added to the summed output of
    ``config.shared_experts_number`` experts that see every token.

    Args:
        config: Object exposing ``hidden_dim`` and ``shared_experts_number``
            in addition to whatever ``SparseMOE`` reads from it.
    """

    def __init__(self, config):
        super().__init__()
        self.moe_model = SparseMOE(config)
        self.share_experts = nn.ModuleList(
            [
                BasicExpert(config.hidden_dim, config.hidden_dim)
                for _ in range(config.shared_experts_number)
            ]
        )

    def forward(self, x):
        """Combine routed and shared expert outputs.

        Args:
            x: Tensor of shape (b, s, hidden_dim).

        Returns:
            A pair ``(output, router_logits)``; ``output`` has the shape of
            ``x`` and ``router_logits`` is passed through from the sparse
            MoE unchanged.
        """
        # 1. Routed experts.
        sparse_moe_out, router_logits = self.moe_model(x)

        # 2. Shared experts: every one of them processes every token.
        # Summed incrementally so that zero shared experts is valid (the
        # previous torch.stack([...]) raised on an empty list) and no
        # stacked copy is allocated.
        share_experts_out = None
        for expert in self.share_experts:
            expert_out = expert(x)  # (b, s, hidden_dim)
            if share_experts_out is None:
                share_experts_out = expert_out
            else:
                share_experts_out = share_experts_out + expert_out

        if share_experts_out is None:
            return sparse_moe_out, router_logits

        # 3. Add routed and shared contributions.
        return sparse_moe_out + share_experts_out, router_logits





def test_share_expert_moe():
    """Smoke-test ShareExpertMOE: print the shapes of its two outputs."""
    tokens = torch.rand(2, 4, 16)
    # hidden_dim=16, expert_number=2, top_k=2 (shared experts use the default)
    model = ShareExpertMOE(MOEconfig(16, 2, 2))
    hidden, logits = model(tokens)
    print(hidden.shape, logits.shape)


test_share_expert_moe()

torch.Size([2, 4, 16]) torch.Size([8, 2])


In [48]:
# 测试， loss 部分为 deepseek 生成；

def switch_load_balancing_loss(
    router_logits: torch.Tensor,
    num_experts: int,
    top_k: int = 2,
    z_loss_weight: float = 0.001,
) -> torch.Tensor:
    """Compute a Switch-Transformer-style load-balancing loss.

    Args:
        router_logits: Raw gate scores, shape [batch_size * sequence_length,
            num_experts].
        num_experts: Number of experts.
        top_k: Number of experts selected per token. Should match the
            router that produced ``router_logits``; defaults to 2, the value
            that used to be hard-coded here.
        z_loss_weight: Multiplier applied to the logit-magnitude penalty.

    Returns:
        Scalar tensor ``aux_loss + z_loss_weight * z_loss``.
    """
    # Routing probabilities: [b*s, num_experts]
    router_probs = torch.softmax(router_logits, dim=-1)

    # Indices of the top_k experts of every token: [b*s, top_k]
    _, selected_experts = torch.topk(router_probs, k=top_k, dim=-1)

    # One-hot dispatch mask: [b*s, top_k, num_experts]
    dispatch_mask = torch.nn.functional.one_hot(selected_experts, num_experts).float()

    # Fraction of tokens dispatched to each expert, summed over the top_k
    # slots: [num_experts]. The entries add up to top_k.
    actual_load = dispatch_mask.sum(dim=1).mean(dim=0)

    # Mean routing probability assigned to each expert: [num_experts]
    mean_prob = router_probs.mean(dim=0)

    # Auxiliary loss: dot product of load and mean probability, scaled by
    # the number of experts.
    aux_loss = torch.sum(actual_load * mean_prob) * num_experts

    # Penalty on large router logits (mean of squared logits).
    z_loss = torch.mean(torch.square(router_logits))

    return aux_loss + z_loss * z_loss_weight

def test_moe_training():
    """Train ShareExpertMOE on random data for a few steps as a smoke test.

    Inputs and targets are freshly sampled noise on every step; the loss is
    MSE plus 0.01 times the load-balancing loss. Progress is printed every
    10 batches.
    """
    # Problem size
    batch_size, seq_len, hidden_dim = 32, 16, 32
    num_batches = 100

    # Model and optimizer
    config = MOEconfig(
        hidden_dim=hidden_dim,
        expert_number=4,
        top_k=2,
        shared_experts_number=2,
    )
    model = ShareExpertMOE(config)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    model.train()
    for batch in range(num_batches):
        # Random input and regression target for this step
        x = torch.randn(batch_size, seq_len, hidden_dim)
        target = torch.randn(batch_size, seq_len, hidden_dim)

        output, router_logits = model(x)

        # Prediction loss plus weighted load-balancing term
        mse_loss = F.mse_loss(output, target)
        aux_loss = switch_load_balancing_loss(router_logits, config.expert_number)
        total_loss = mse_loss + 0.01 * aux_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        if batch % 10 != 0:
            continue
        print(f"Batch {batch}, Loss: {total_loss.item():.4f} "
              f"(MSE: {mse_loss.item():.4f}, Aux: {aux_loss.item():.4f})")


# Run the training test
test_moe_training()

Batch 0, Loss: 2.3664 (MSE: 2.3461, Aux: 2.0238)
Batch 10, Loss: 2.0350 (MSE: 2.0147, Aux: 2.0313)
Batch 20, Loss: 1.8910 (MSE: 1.8707, Aux: 2.0292)
Batch 30, Loss: 1.6639 (MSE: 1.6436, Aux: 2.0329)
Batch 40, Loss: 1.4933 (MSE: 1.4728, Aux: 2.0466)
Batch 50, Loss: 1.4352 (MSE: 1.4148, Aux: 2.0453)
Batch 60, Loss: 1.3297 (MSE: 1.3091, Aux: 2.0618)
Batch 70, Loss: 1.2803 (MSE: 1.2596, Aux: 2.0741)
Batch 80, Loss: 1.2275 (MSE: 1.2067, Aux: 2.0810)
Batch 90, Loss: 1.2094 (MSE: 1.1883, Aux: 2.1043)


In [59]:
# Toy router logits filled in place from a uniform distribution over 0.0 to 10.0.
router_logits = tensor = torch.empty(2 * 3 , 3).uniform_(0.0, 10.0) # (b * s, expert_number) => (6, 3)

In [60]:
router_logits # (b * s, expert_number)

tensor([[0.9214, 4.9050, 8.1670],
        [5.8767, 9.7485, 8.7586],
        [8.9079, 6.8320, 8.6373],
        [4.1807, 3.6645, 0.3205],
        [0.4258, 5.4426, 3.3150],
        [7.6354, 7.5148, 5.6424]])

In [61]:
 # Compute routing probabilities with a softmax over the expert dimension.
router_probs = torch.softmax(router_logits, dim=-1)  # [b*s, num_experts] => (6,3)

In [62]:
router_probs # [b*s, num_experts] => (6,3)

tensor([[6.8650e-04, 3.6872e-02, 9.6244e-01],
        [1.4954e-02, 7.1816e-01, 2.6688e-01],
        [5.2956e-01, 6.6430e-02, 4.0401e-01],
        [6.1809e-01, 3.6889e-01, 1.3020e-02],
        [5.8860e-03, 8.8829e-01, 1.0582e-01],
        [4.9441e-01, 4.3822e-01, 6.7377e-02]])

In [72]:
# Pick the top-2 experts of every token, with their probabilities.
selected_experts_weight, selected_experts = torch.topk(router_probs, k=2, dim=-1)  # [b*s, top_k] => (6,2)


In [96]:
selected_experts_weight

tensor([[0.9624, 0.0369],
        [0.7182, 0.2669],
        [0.5296, 0.4040],
        [0.6181, 0.3689],
        [0.8883, 0.1058],
        [0.4944, 0.4382]])

In [95]:
selected_experts

tensor([[2, 1],
        [1, 2],
        [0, 2],
        [0, 1],
        [1, 2],
        [0, 1]])

In [74]:
# One-hot encode the selected expert indices.
mask = torch.nn.functional.one_hot(selected_experts, 3).float()  # [b*s, top_k, num_experts] => (6,2,3)

In [98]:
mask[:,:,0]

tensor([[0., 0.],
        [0., 0.],
        [1., 0.],
        [1., 0.],
        [0., 0.],
        [1., 0.]])

In [75]:
mask.shape

torch.Size([6, 2, 3])

In [77]:
# Move the expert dimension to the front so expert_mask[e] selects one expert.
expert_mask = mask.permute(2, 1, 0)  # (expert_number, top_k, b * s) => (3, 2, 6)

In [86]:
expert_mask[0].shape # (2, 6)

torch.Size([2, 6])

In [89]:
expert_mask[0]

tensor([[0., 0., 1., 1., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [87]:
# expert_mask[0] is (top_k, b * s): idx receives the row indices (top-k slot)
# and top_x the column indices (token position) of the non-zero entries.
idx, top_x = torch.where(expert_mask[0])

In [88]:
idx

tensor([0, 0, 0])

In [90]:
top_x

tensor([2, 3, 5])

In [94]:
selected_experts_weight

tensor([[0.9624, 0.0369],
        [0.7182, 0.2669],
        [0.5296, 0.4040],
        [0.6181, 0.3689],
        [0.8883, 0.1058],
        [0.4944, 0.4382]])

In [99]:
# selected_experts_weight is (b * s, top_k), so it must be indexed as
# [token, slot]: top_x holds token positions and idx holds top-k slots.
# Indexing as [idx, top_x] puts token positions on the size-2 dimension
# and raises "IndexError: index 2 is out of bounds for dimension 1".
selected_experts_weight[top_x, idx].unsqueeze(-1)

tensor([[0.5296],
        [0.6181],
        [0.4944]])