<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/deepseek_moe.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# https://www.bilibili.com/video/BV1uUPieDEK1?spm_id_from=333.788.videopod.sections&vd_source=1fecee762931e992c96e5e166be13b76
# https://github.com/hkproj/pytorch-transformer/blob/main/train.py
# https://github.com/hkproj/pytorch-transformer/blob/main/model.py
# https://github.com/XihWang/LLM_RethinkFun


In [2]:
# Classical MoE (Mixture of Experts): every token is sent through its top-k experts one at a time.

In [None]:
import torch
from torch import nn

In [None]:
class ExpertNetwork(nn.Module):
    """A single feed-forward expert: Linear -> ReLU -> Linear.

    Args:
        hidden_size: Width of the features going in and coming out.
        intermediate_size: Width of the hidden projection between the two
            linear layers.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.hidden_size, self.intermediate_size = hidden_size, intermediate_size

        # Projection into the intermediate width, then back to the model width.
        self.linear1 = nn.Linear(hidden_size, intermediate_size)
        self.linear2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x):
        """Run ``x`` through the expert; the last dimension must be ``hidden_size``."""
        activated = torch.relu(self.linear1(x))
        return self.linear2(activated)

class Router(nn.Module):
    """Scores each token against every expert and keeps the ``top_k`` best.

    Args:
        hidden_size: Feature width of each token.
        expert_num: Number of experts to score.
        top_k: Number of experts selected per token.
    """

    def __init__(self, hidden_size, expert_num, top_k):
        super().__init__()
        self.hidden_size = hidden_size
        self.top_k = top_k
        self.router = nn.Linear(hidden_size, expert_num)

    def forward(self, x):
        """Return ``(topk_weight, topk_idx)``, each of shape (num_tokens, top_k)."""
        # Collapse all leading dimensions so every row is one token.
        tokens = x.view(-1, self.hidden_size)
        probs = nn.functional.softmax(self.router(tokens), dim=1)
        topk_weight, topk_idx = torch.topk(probs, k=self.top_k, dim=1, sorted=False)
        # Renormalize the selected weights so they sum to 1 for each token.
        normalizer = topk_weight.sum(dim=-1, keepdim=True)
        return topk_weight / normalizer, topk_idx

class MOELayer(nn.Module):
    """Reference mixture-of-experts layer that dispatches one token at a time.

    Args:
        hidden_size: Feature width of each token.
        intermediate_size: Hidden width inside each expert.
        expert_num: Number of experts in the layer.
        top_k: Number of experts each token is routed to.
    """

    def __init__(self, hidden_size, intermediate_size, expert_num, top_k):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.expert_num = expert_num
        self.top_k = top_k

        expert_list = []
        for _ in range(self.expert_num):
            expert_list.append(ExpertNetwork(hidden_size, intermediate_size))
        self.experts = nn.ModuleList(expert_list)
        self.router = Router(hidden_size, expert_num, top_k)

    def forward(self, x):  # shape of x is (batch_size, seq_len, hidden_size)
        """Return the weighted sum of each token's top-k expert outputs, shaped like ``x``."""
        batch_size, seq_len, _ = x.size()
        token_num = batch_size * seq_len
        flat_tokens = x.view(token_num, self.hidden_size)

        # The router yields the top-k weights and expert indices, both (token_num, top_k).
        weights, chosen = self.router(flat_tokens)

        # Accumulator for the mixed expert outputs, one row per token.
        mixed = torch.zeros_like(flat_tokens)
        for t in range(token_num):
            token = flat_tokens[t]
            for slot in range(self.top_k):
                selected_expert = self.experts[chosen[t][slot]]
                mixed[t] += weights[t][slot] * selected_expert(token)

        return mixed.view(batch_size, seq_len, self.hidden_size)

In [None]:
# Constants
HIDDEN_SIZE = 4096
INTERMEDIATE_SIZE = 2048
EXPERT_NUM = 8
TOP_K = 2

# Example usage: a batch of 2 sequences with 11 tokens each.
# The feature width comes from HIDDEN_SIZE (rather than a repeated literal) so
# the input keeps matching the layer if the constant above is changed.
inputs = torch.randn((2, 11, HIDDEN_SIZE))
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K)
outputs = moe_layer(inputs)
print(outputs.size())

In [None]:
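# DeepSeek-style MoE: routed experts plus optional shared experts, with an auxiliary loss attached during training.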

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AddAuxiliaryLoss(torch.autograd.Function):
    """
    Autograd hook that passes ``x`` through unchanged while attaching ``loss``
    to the graph: on backward, ``loss`` receives a gradient of one (if it
    requires grad) and ``x`` receives the incoming gradient untouched.
    """

    @staticmethod
    def forward(ctx, x, loss):
        # The auxiliary loss must be a single value.
        assert loss.numel() == 1
        # Remember what backward needs to build the gradient for ``loss``.
        ctx.required_aux_loss = loss.requires_grad
        ctx.dtype = loss.dtype
        return x

    @staticmethod
    def backward(ctx, grad_output):
        if not ctx.required_aux_loss:
            return grad_output, None
        unit_grad = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, unit_grad

class DeepseekMLP(nn.Module):
    """
    Two-layer feed-forward network (Linear -> GELU -> Linear) used as an expert.

    Args:
        config: Object exposing ``hidden_size``, the input/output feature width.
        intermediate_size: Width of the hidden projection.
    """

    def __init__(self, config, intermediate_size):
        super().__init__()
        model_width = config.hidden_size
        self.fc1 = nn.Linear(model_width, intermediate_size)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(intermediate_size, model_width)

    def forward(self, x):
        projected = self.fc1(x)
        activated = self.act(projected)
        return self.fc2(activated)

class MoeGate(nn.Module):
    """
    Gate that scores tokens against the routed experts and picks the top-k.

    Reads ``n_routed_experts``, ``num_experts_per_tok`` and ``hidden_size``
    from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.num_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        self.w_gating = nn.Linear(config.hidden_size, self.num_experts)

    def forward(self, hidden_states):
        """Return ``(topk_idx, topk_weight, aux_loss)`` for the flattened tokens.

        ``topk_idx`` and ``topk_weight`` have shape (num_tokens, top_k);
        ``aux_loss`` is a scalar tensor.
        """
        # One row per token, regardless of how many leading dimensions came in.
        tokens = hidden_states.view(-1, hidden_states.size(-1))
        scores = F.softmax(self.w_gating(tokens), dim=-1)

        # Keep the k highest-scoring experts and rescale their weights to sum to 1.
        topk_weight, topk_idx = torch.topk(scores, self.top_k, dim=-1)
        topk_weight = topk_weight / topk_weight.sum(dim=-1, keepdim=True)

        # Auxiliary term built from the per-expert mean score and mean squared
        # score over all tokens: (E * sum(mean(s^2)) - sum(mean(s))) / E.
        mean_score = scores.mean(dim=0)
        mean_sq_score = (scores ** 2).mean(dim=0)
        aux_loss = (self.num_experts * mean_sq_score.sum() - mean_score.sum()) / self.num_experts

        return topk_idx, topk_weight, aux_loss

class DeepseekMoE(nn.Module):
    """
    Mixture of Experts (MoE) module with optional shared experts.

    Each token is processed by its top-k routed experts (chosen by the gate)
    and the weighted results are summed. When ``config.n_shared_experts`` is
    not None, a shared expert is applied to every token and added on top.

    Reads ``num_experts_per_tok``, ``n_routed_experts``,
    ``moe_intermediate_size`` and ``n_shared_experts`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList(
            [DeepseekMLP(config, intermediate_size=config.moe_intermediate_size)
             for _ in range(config.n_routed_experts)]
        )
        self.gate = MoeGate(config)

        if config.n_shared_experts is not None:
            # All shared experts are fused into one MLP that is n_shared_experts times wider.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        """Apply the routed (and optional shared) experts.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size.

        Returns:
            Tensor with the same shape as ``hidden_states``. In training mode
            the gate's auxiliary loss is attached to the autograd graph of the
            routed output.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        # Duplicate every token once per selected expert so that row j lines up
        # with flat_topk_idx[j]. This is required in eval mode too: the boolean
        # masks below have num_tokens * top_k entries.
        hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
        # Every row is written below, because each entry of flat_topk_idx
        # matches exactly one expert.
        y = torch.empty_like(hidden_states)
        for i, expert in enumerate(self.experts):
            idx = flat_topk_idx == i
            if idx.any():
                y[idx] = expert(hidden_states[idx])

        # (num_tokens, top_k, hidden) weighted by the gate, summed over the experts.
        y = (y.view(topk_weight.shape + (-1,)) * topk_weight.unsqueeze(-1)).sum(dim=1)
        y = y.view(*orig_shape)

        if self.training:
            y = AddAuxiliaryLoss.apply(y, aux_loss)

        # Shared experts contribute in both modes so that training and
        # inference compute the same function and the shared weights get gradients.
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

In [None]:
class DeepseekMoE(nn.Module):
    """
    A mixed expert module containing shared experts.

    Each token is processed by its top-k routed experts (chosen by the gate)
    and the weighted results are summed. When ``config.n_shared_experts`` is
    not None, a shared expert is applied to every token and added on top.
    Training uses a mask-based dispatch; evaluation uses ``moe_infer``, which
    groups tokens by expert.

    Reads ``num_experts_per_tok``, ``n_routed_experts``,
    ``moe_intermediate_size`` and ``n_shared_experts`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_tok = config.num_experts_per_tok
        self.experts = nn.ModuleList(
            [DeepseekMLP(config, intermediate_size=config.moe_intermediate_size)
             for _ in range(config.n_routed_experts)]
        )
        self.gate = MoeGate(config)

        if config.n_shared_experts is not None:
            # All shared experts are fused into one MLP that is n_shared_experts times wider.
            intermediate_size = config.moe_intermediate_size * config.n_shared_experts
            self.shared_experts = DeepseekMLP(config=config, intermediate_size=intermediate_size)

    def forward(self, hidden_states):
        """Apply the routed (and optional shared) experts.

        Args:
            hidden_states: Tensor whose last dimension is the hidden size.

        Returns:
            Tensor with the same shape as ``hidden_states``. In training mode
            the gate's auxiliary loss is attached to the autograd graph of the
            routed output.
        """
        identity = hidden_states
        orig_shape = hidden_states.shape
        topk_idx, topk_weight, aux_loss = self.gate(hidden_states)
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
        flat_topk_idx = topk_idx.view(-1)

        if self.training:
            # Duplicate every token once per selected expert so that row j
            # lines up with flat_topk_idx[j].
            hidden_states = hidden_states.repeat_interleave(self.num_experts_per_tok, dim=0)
            y = torch.empty_like(hidden_states)
            for i, expert in enumerate(self.experts):
                mask = flat_topk_idx == i
                y[mask] = expert(hidden_states[mask])
            y = (y.view(topk_weight.shape + (-1,)) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape)
            y = AddAuxiliaryLoss.apply(y, aux_loss)
        else:
            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)

        # Shared experts contribute in both modes so that training and
        # inference compute the same function and the shared weights get gradients.
        if self.config.n_shared_experts is not None:
            y = y + self.shared_experts(identity)
        return y

    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Routed-expert pass that runs each expert once on all of its tokens.

        Args:
            x: (num_tokens, hidden) token features.
            flat_expert_indices: (num_tokens * top_k,) expert index for each
                (token, slot) pair, with a token's slots stored consecutively.
            flat_expert_weights: (num_tokens * top_k, 1) gate weight for each pair.

        Returns:
            (num_tokens, hidden) tensor holding, for every token, the weighted
            sum of its selected experts' outputs.
        """
        expert_cache = torch.zeros_like(x)
        # Sorting the (token, slot) pairs by expert turns each expert's
        # workload into one contiguous slice of ``order``.
        order = flat_expert_indices.argsort()
        # A token owns top_k consecutive flat positions, so integer division
        # recovers the token row for each sorted position.
        token_rows = order // self.num_experts_per_tok
        # Cumulative pair counts give the end offset of every expert's slice;
        # minlength keeps experts that received no tokens in the list.
        slice_ends = torch.bincount(
            flat_expert_indices, minlength=len(self.experts)
        ).cumsum(0).tolist()

        start = 0
        for expert, end in zip(self.experts, slice_ends):
            if end > start:
                rows = token_rows[start:end]
                weighted = expert(x[rows]) * flat_expert_weights[order[start:end]]
                # Accumulate, since each token collects output from top_k experts.
                expert_cache.index_add_(0, rows, weighted)
            start = end
        return expert_cache

In [None]:
# class Config:
#     def __init__(self):
#         self.hidden_size = 768
#         self.moe_intermediate_size = 3072
#         self.n_routed_experts = 4
#         self.n_shared_experts = 1  # Can be None if no shared experts
#         self.num_experts_per_tok = 2  # Top-k experts

# # Instantiate config and MoE module
# config = Config()
# moe_layer = DeepseekMoE(config)

# # Simulate input: batch of sequences
# batch_size = 2
# seq_length = 5
# hidden_size = config.hidden_size
# input_tensor = torch.randn(batch_size, seq_length, hidden_size)

# # Forward pass
# output = moe_layer(input_tensor)

# print("Input shape:", input_tensor.shape)
# print("Output shape:", output.shape)

In [None]:
# Experiment: compare the average inference time of MOELayer and DeepseekMoE.

In [None]:
import torch
import time

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Mock Config for DeepseekMoE
class Config:
    """Plain attribute holder carrying the settings DeepseekMoE reads.

    Note that the ``intermediate_size`` argument is stored under the
    attribute name ``moe_intermediate_size``.
    """

    def __init__(self, hidden_size, intermediate_size, n_routed_experts, n_shared_experts, num_experts_per_tok):
        # Layer widths.
        self.hidden_size, self.moe_intermediate_size = hidden_size, intermediate_size
        # Expert counts (n_shared_experts may be None to disable shared experts).
        self.n_routed_experts, self.n_shared_experts = n_routed_experts, n_shared_experts
        # Number of routed experts selected per token.
        self.num_experts_per_tok = num_experts_per_tok

# Problem size shared by both layers in the benchmark
BATCH_SIZE, SEQ_LENGTH = 16, 128
HIDDEN_SIZE, INTERMEDIATE_SIZE = 768, 3072
EXPERT_NUM, TOP_K = 4, 2

# Random activations of shape (batch, sequence, hidden), created on the CPU
# and then moved to the selected device.
input_tensor = torch.randn(BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
input_tensor = input_tensor.to(device)


In [None]:
def benchmark(model, input_tensor, n_iters=2, n_warmup=10):
    """Measure the average forward-pass latency of ``model`` on ``input_tensor``.

    The model is switched to eval mode and moved to the device that holds
    ``input_tensor``; both changes persist after the call. Gradients are
    disabled for every forward pass.

    Args:
        model: The ``nn.Module`` to time.
        input_tensor: Input passed to ``model`` on every iteration.
        n_iters: Number of timed forward passes (must be positive).
        n_warmup: Number of untimed forward passes run first.

    Returns:
        Average wall-clock seconds per timed forward pass.

    Raises:
        ValueError: If ``n_iters`` is not positive.
    """
    if n_iters <= 0:
        raise ValueError("n_iters must be a positive integer")

    target = input_tensor.device
    on_cuda = target.type == "cuda"

    model.eval()
    model.to(target)
    if on_cuda:
        torch.cuda.empty_cache()

    with torch.no_grad():
        # Warmup
        for _ in range(n_warmup):
            model(input_tensor)

        # CUDA kernels launch asynchronously: drain pending work before
        # starting the clock and again before stopping it, otherwise only the
        # launch overhead would be measured.
        if on_cuda:
            torch.cuda.synchronize(target)
        start = time.perf_counter()
        for _ in range(n_iters):
            model(input_tensor)
        if on_cuda:
            torch.cuda.synchronize(target)
        end = time.perf_counter()

    return (end - start) / n_iters


In [None]:
# Build the two layers under comparison: the per-token MOELayer first, then
# DeepseekMoE configured with the same sizes plus one shared expert.
moe_layer = MOELayer(HIDDEN_SIZE, INTERMEDIATE_SIZE, EXPERT_NUM, TOP_K)

config = Config(
    HIDDEN_SIZE,
    INTERMEDIATE_SIZE,
    n_routed_experts=EXPERT_NUM,
    n_shared_experts=1,
    num_experts_per_tok=TOP_K,
)
deepseek_moe = DeepseekMoE(config)

# Time each layer on the same input tensor.
moe_time = benchmark(moe_layer, input_tensor)
deepseek_time = benchmark(deepseek_moe, input_tensor)

# Report the average seconds per forward pass.
print(f"MOELayer avg inference time: {moe_time:.6f}s")
print(f"DeepseekMoE avg inference time: {deepseek_time:.6f}s")
