In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

## Sparse MoE

Sparse MoE：为每个 token 选出 top-K 个专家，并用路由权重对这 K 个专家的输出加权求和；输入采用 LLM 中真实的 shape：(batch_size, seq_len, hidden_dim)。

In [3]:
class MoEConfig:
    """Hyper-parameters shared by the MoE router and the sparse MoE layer.

    Args:
        hidden_dim: feature size of each token.
        expert_num: number of routed experts.
        top_k: number of experts selected per token.
        shared_experts_num: number of shared experts (default 2). Not read by
            the router / sparse layer in this section; presumably reserved for
            a shared-expert variant — confirm.
    """

    def __init__(self, hidden_dim, expert_num, top_k, shared_experts_num=2):
        self.hidden_dim, self.expert_num, self.top_k = hidden_dim, expert_num, top_k
        self.shared_experts_num = shared_experts_num

class MoERouter(nn.Module):
    """Top-k gating network for a sparse mixture-of-experts layer.

    A linear gate scores every expert for each token; the ``top_k`` most
    probable experts are kept and their probabilities are renormalised so
    they sum to 1 per token.
    """

    def __init__(self, config):
        # nn.Module.__init__ must run before any submodule (self.gate) is
        # assigned, otherwise PyTorch raises AttributeError.
        super().__init__()
        self.gate = nn.Linear(config.hidden_dim, config.expert_num)
        self.top_k = config.top_k
        self.expert_num = config.expert_num

    def forward(self, x):
        """Select the top-k experts for every token.

        Args:
            x: 2-D token features (num_tokens, hidden_dim); SparseMoE passes
                (batch_size * seq_len, hidden_dim). The permute below requires
                exactly 2-D input.

        Returns:
            router_logits: (num_tokens, expert_num) raw gate scores.
            router_weights: (num_tokens, top_k) renormalised probabilities of
                the selected experts, cast to x.dtype.
            selected_indices: (num_tokens, top_k) selected expert ids, ordered
                by descending probability.
            expert_mask: (expert_num, top_k, num_tokens) one-hot mask where
                expert_mask[e, k, t] == 1 iff token t's k-th choice is expert e.
        """
        router_logits = self.gate(x)  # (num_tokens, expert_num)

        # Softmax over the expert dimension, in float32 for numerical stability.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32)

        router_weights, selected_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1,
        )  # (num_tokens, top_k)

        # Renormalise so the selected experts' weights sum to 1 for each token.
        router_weights = router_weights / router_weights.sum(dim=-1, keepdim=True)
        router_weights = router_weights.to(x.dtype)

        expert_mask = F.one_hot(
            selected_indices,
            num_classes=self.expert_num,
        )  # (num_tokens, top_k, expert_num)
        # Put experts first so the MoE layer can slice one expert at a time:
        # -> (expert_num, top_k, num_tokens)
        expert_mask = expert_mask.permute(2, 1, 0)

        return router_logits, router_weights, selected_indices, expert_mask


class SparseMoE(nn.Module):
    """Sparse mixture-of-experts layer: each token is processed by its top-k experts.

    A MoERouter picks ``config.top_k`` of ``config.expert_num`` experts per
    token. The selected experts' outputs are combined with the renormalised
    router weights.
    """

    def __init__(self, config):
        super().__init__()
        self.hidden_dim = config.hidden_dim
        self.expert_num = config.expert_num
        self.top_k = config.top_k

        # One expert per routed slot. BasicExpert is defined outside this
        # section; presumably it maps hidden_dim -> hidden_dim (verify).
        self.experts = nn.ModuleList(
            BasicExpert(config.hidden_dim, config.hidden_dim)
            for _ in range(config.expert_num)
        )

        self.router = MoERouter(config)

    def forward(self, x):
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: tensor of shape (batch_size, seq_len, hidden_dim).

        Returns:
            final_hidden_states: (batch_size, seq_len, hidden_dim), with the
                same dtype as x.
            router_logits: (batch_size * seq_len, expert_num) raw gate logits
                as returned by the router.
        """
        batch_size, seq_len, hidden_dim = x.size()

        # Work at token level: (batch_size * seq_len, hidden_dim). reshape
        # (rather than view) also accepts non-contiguous inputs.
        hidden_states = x.reshape(-1, hidden_dim)

        router_logits, router_weights, selected_indices, expert_mask = self.router(hidden_states)
        # expert_mask: (expert_num, top_k, batch_size * seq_len)

        # Accumulate in the input's dtype/device so the output matches x
        # (a bare torch.zeros would default to float32).
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        for expert_idx, expert_layer in enumerate(self.experts):
            # slot_idx: which of the token's top-k slots chose this expert
            #           (column into router_weights);
            # token_idx: flat token position in [0, batch_size * seq_len).
            slot_idx, token_idx = torch.where(expert_mask[expert_idx] > 0)

            expert_out = expert_layer(hidden_states[token_idx])
            # (n_selected, hidden_dim) * (n_selected, 1)
            weighted = expert_out * router_weights[token_idx, slot_idx].unsqueeze(-1)

            # Scatter-add each weighted expert output back onto its token row.
            final_hidden_states.index_add_(
                0,
                token_idx,
                weighted.to(final_hidden_states.dtype),
            )

        # Restore the original (batch_size, seq_len, hidden_dim) shape.
        final_hidden_states = final_hidden_states.reshape(batch_size, seq_len, hidden_dim)
        return final_hidden_states, router_logits