<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/deepseek_moe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# https://www.bilibili.com/video/BV1uUPieDEK1?spm_id_from=333.788.videopod.sections&vd_source=1fecee762931e992c96e5e166be13b76
# https://github.com/hkproj/pytorch-transformer/blob/main/train.py
# https://github.com/hkproj/pytorch-transformer/blob/main/model.py
# https://github.com/XihWang/LLM_RethinkFun


In [2]:
# classical MOE

In [3]:
import torch
from torch import nn

In [4]:
class ExpertNetwork(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        x = self.linear1(x)
        x = nn.functional.relu(x)
        output = self.linear2(x)
        return output

class Router(nn.Module):
    def __init__(self, hidden_size, expert_num, top_k):
        super().__init__()
        self.router = nn.Linear(hidden_size, expert_num)
        self.top_k = top_k
        self.hidden_size = hidden_size

    def forward(self, x):
        x = x.view(-1, self.hidden_size)
        x = self.router(x)
        x = nn.functional.softmax(x, dim=1)
        topk_weight, topk_idx = torch.topk(x, k=self.top_k, dim=1, sorted=False)
        # 对topK权重重新归一化
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)
        return topk_weight, topk_idx

class MOELayer(nn.Module):
    def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.expert_num = expert_num
        self.top_k = top_k
        self.experts = nn.ModuleList(
            [ExpertNetwork(hidden_size, intermediate_size) for _ in range(self.expert_num)]
        )
        self.router = Router(hidden_size, expert_num, top_k)

    def forward(self, x):  # shape of x is (batch_size, seq_len, hidden_size)
        batch_size, seq_len, _ = x.size()
        token_num = batch_size * seq_len
        x_flat = x.view(token_num, self.hidden_size)

        # 通过路由器获得top-k专家选择的权重和索引，形状均为(N, top_k)
        topk_weight, topk_idx = self.router(x_flat)

        # 初始化输出张量
        output = torch.zeros_like(x_flat)
        for token_idx in range(token_num):
            for expert_idx in range(self.top_k):
                expert = self.experts[topk_idx[token_idx][expert_idx]]
                output[token_idx] += topk_weight[token_idx][expert_idx] * expert(x_flat[token_idx])

        output = output.view(batch_size, seq_len, self.hidden_size)
        return output

In [5]:
# Constants
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 2048
EXPERT_NUM = 8
TOP_K = 2

# Example usage
inputs = torch.randn((2, 11, 4096))
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K)
outputs = moe_layer(inputs)
print(outputs.size())

torch.Size([2, 11, 4096])


In [6]:
# deepseek MOE
# can run
# lack audloss from load balancing
# does not match size of moe above, therefore cannot do benchmark test

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Dummy config for testing
class Config:
    def __init__(self):
        self.moe_intermediate_size = 512
        self.n_routed_experts = 4
        self.n_shared_experts = 1
        self.num_experts_per_tok = 2

config = Config()

# Expert MLP
class DeepseekMLP(nn.Module):
    def __init__(self, config, intermediate_size):
        super().__init__()
        self.fc1 = nn.Linear(768, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, 768)

    def forward(self, x):
        return self.fc2(F.relu(self.fc1(x)))

# Gating mechanism
class MoeGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.linear = nn.Linear(768, config.n_routed_experts)
        self.top_k = config.num_experts_per_tok

    def forward(self, x):
        logits = self.linear(x)
        topk_weight, topk_idx = torch.topk(F.softmax(logits, dim=-1), self.top_k)
        aux_loss = torch.tensor(0.0, device=x.device)  # Dummy placeholder
        return topk_idx, topk_weight, aux_loss

# MoE Module
class DeepseekMoe(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList([
            DeepseekMLP(config, intermediate_size=config.moe_intermediate_size)
            for _ in range(config.n_routed_experts)
        ])
        self.shared_experts = (
            DeepseekMLP(config, intermediate_size=config.moe_intermediate_size)
            if config.n_shared_experts > 0 else None
        )
        self.gate = MoeGate(config)

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_dim = hidden_states.shape
        input_flat = hidden_states.view(-1, hidden_dim)  # (B * T, H)

        topk_idx, topk_weight, aux_loss = self.gate(input_flat)  # (B * T, K)
        flat_topk_idx = topk_idx.view(-1)  # (B * T * K)

        # Repeat inputs for each top-k expert
        expanded_inputs = input_flat.repeat_interleave(self.num_experts_per_tok, dim=0)  # (B * T * K, H)
        expert_outputs = torch.zeros_like(expanded_inputs)

        # Route inputs to experts
        for i, expert in enumerate(self.experts):
            mask = (flat_topk_idx == i)
            if mask.any():
                expert_outputs[mask] = expert(expanded_inputs[mask])

        # Weight expert outputs
        weighted_output = expert_outputs.view(batch_size * seq_len, self.num_experts_per_tok, hidden_dim)
        weighted_output = (weighted_output * topk_weight.unsqueeze(-1)).sum(dim=1)  # (B * T, H)
        output = weighted_output.view(batch_size, seq_len, hidden_dim)

        # Add shared expert output if present
        if self.shared_experts is not None:
            output += self.shared_experts(hidden_states)

        return output

# Run test
if __name__ == "__main__":
    model = DeepseekMoe(config)
    dummy_input = torch.randn(2, 10, 768)  # (batch_size=2, seq_len=10, hidden_dim=768)
    output = model(dummy_input)
    print("✅ Output shape:", output.shape)  # should be [2, 10, 768]

✅ Output shape: torch.Size([2, 10, 768])


In [10]:
# experiment

In [37]:
import torch
import time

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Mock Config for DeepseekMoE
class Config:
    def __init__(self, hidden_size, intermediate_size, n_routed_experts, n_shared_experts, num_experts_per_tok):
        self.hidden_size = hidden_size
        self.moe_intermediate_size = intermediate_size
        self.n_routed_experts = n_routed_experts
        self.n_shared_experts = n_shared_experts
        self.num_experts_per_tok = num_experts_per_tok

# Input dimensions
BATCH_SIZE = 16
SEQ_LENGTH = 128
HIDDEN_SIZE = 768
INTERMEDIATE_SIZE = 3072
EXPERT_NUM = 4
TOP_K = 2

# Input tensor
input_tensor = torch.randn(BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE).to(device)

In [38]:
def benchmark(model, input_tensor, n_iters=2):
    model.eval()
    model.to(device)
    torch.cuda.empty_cache()
    with torch.no_grad():
        # Warmup
        for _ in range(10):
            _ = model(input_tensor)

        # Measure
        start = time.time()
        for _ in range(n_iters):
            _ = model(input_tensor)
        # torch.cuda.synchronize()
        end = time.time()

    avg_time = (end - start) / n_iters
    return avg_time

In [39]:
# Initialize MOELayer
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K)

# Initialize DeepseekMoE
config = Config(HIDDEN_SIZE, INTERMEDIATE_SIZE, n_routed_experts=EXPERT_NUM, n_shared_experts=1, num_experts_per_tok=TOP_K)
deepseek_moe = DeepseekMoE(config)

# Benchmark
moe_time = benchmark(moe_layer, input_tensor)
print(f"MOELayer avg inference time: {moe_time:.6f}s")

MOELayer avg inference time: 9.419379s
