基础版理解MOE

In [2]:
#基础版 
import torch

import torch.nn as nn
import torch.nn.functional as F

In [10]:
class BasicExpert(nn.Module):
    """A single expert: one affine projection from ``feature_in`` to ``feature_out``.

    Args:
        feature_in: Size of the last dimension of the input.
        feature_out: Size of the last dimension of the output.
    """

    def __init__(self, feature_in: int, feature_out: int) -> None:
        super().__init__()
        # The expert's only learnable layer.
        self.fc = nn.Linear(in_features=feature_in, out_features=feature_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the expert's linear layer to ``x`` and return the result."""
        projected = self.fc(x)
        return projected

In [12]:
class BasicMOE(nn.Module):
    """Dense mixture-of-experts layer: every expert processes every sample.

    A linear gate scores all experts for each sample. The scores are
    softmax-normalised and used as weights for a weighted sum of the
    expert outputs.

    Args:
        feature_in: Size of each input feature vector.
        feature_out: Size of each expert's (and this layer's) output vector.
        num_experts: Number of ``BasicExpert`` sub-modules to create.
    """

    def __init__(self, feature_in, feature_out, num_experts):
        super().__init__()
        # Gate output shape: (batch, num_experts).
        self.gate = nn.Linear(feature_in, num_experts)
        self.experts = nn.ModuleList(
            BasicExpert(feature_in, feature_out) for _ in range(num_experts)
        )

    def forward(self, x):
        """Mix the outputs of all experts using the gate's softmax weights.

        Args:
            x: Input tensor of shape (batch, feature_in).

        Returns:
            Tensor of shape (batch, feature_out).
        """
        # Softmax over the expert axis, then add a singleton axis so the
        # weights can be used as a batched row vector: (batch, 1, num_experts).
        expert_weights = F.softmax(self.gate(x), dim=1).unsqueeze(1)

        # Each expert yields (batch, feature_out); stacking on a new axis 1
        # gives (batch, num_experts, feature_out).
        expert_output = torch.stack(
            [expert(x) for expert in self.experts],
            dim=1,
        )

        # (batch, 1, num_experts) @ (batch, num_experts, feature_out)
        #   -> (batch, 1, feature_out); drop the singleton axis on return.
        output = expert_weights @ expert_output
        return output.squeeze(1)

def test_basic_moe():
    """Smoke-test ``BasicMOE``: run one random batch through it and print the output shape."""
    # 4 samples with 512 features each.
    sample_batch = torch.rand(4, 512)
    # 512 -> 128 features, mixed over 4 experts.
    dense_layer = BasicMOE(512, 128, 4)
    mixed = dense_layer(sample_batch)
    print(mixed.shape)



# Run the dense-MoE smoke test; it prints the shape of the layer's output.
test_basic_moe()

torch.Size([4, 128])


SparseMoE：与上面的基础版相比，SparseMoE 只为每个 token 选择 top-k 个专家，并对这 top-k 个专家的输出做加权求和；同时把输入改成大模型中真实的输入形状 (batch, seq_len, hidden_dim)。

In [38]:
class MOEConfig:
    """Plain container for the mixture-of-experts hyper-parameters.

    Attributes:
        hidden_dim: Width of the token representation; used as the input and
            output size of the experts and as the input size of the router gate.
        expert_number: Total number of routed experts.
        top_k: Number of experts selected for each token.
        shared_experts_number: Stored as given (default 2); not read by any
            code visible in this file section.
    """

    def __init__(
        self,
        hidden_dim: int,
        expert_number: int,
        top_k: int,
        shared_experts_number: int = 2,
    ) -> None:
        # Model width.
        self.hidden_dim = hidden_dim

        # Routing set-up: how many experts exist and how many each token uses.
        self.expert_number = expert_number
        self.top_k = top_k

        # Kept on the config for callers that need it.
        self.shared_experts_number = shared_experts_number


class MOERouter(nn.Module):
    """Top-k router: scores every expert per token and keeps the best ``top_k``.

    Args:
        config: Object exposing ``hidden_dim``, ``expert_number`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        # The gate scores all experts; forward() keeps only the top_k of them.
        self.gate = nn.Linear(config.hidden_dim, config.expert_number)

        self.expert_number = config.expert_number
        self.top_k = config.top_k

    def forward(self, x):
        """Route each token to its ``top_k`` highest-probability experts.

        Args:
            x: Flattened tokens of shape (batch * seq_len, hidden_dim).

        Returns:
            A 4-tuple:
              * router_logits: (batch * seq_len, expert_number) raw gate scores.
              * router_weights: (batch * seq_len, top_k) probabilities of the
                selected experts, renormalised to sum to 1 per token and cast
                to ``x.dtype``.
              * selected_experts_indices: (batch * seq_len, top_k) indices of
                the selected experts.
              * expert_mask: (expert_number, top_k, batch * seq_len) one-hot
                mask; ``expert_mask[e, k, t] == 1`` when expert ``e`` is the
                k-th choice of token ``t``.
        """
        # Example: with expert_number = 8 and top_k = 2, each token gets 8
        # scores and keeps 2 of them.
        router_logits = self.gate(x)

        # Per-token probability of each expert, computed in float32. The
        # expert axis is the last one, the same axis used by topk below.
        router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float)

        # Probabilities and indices of the top_k experts for every token.
        # Gradients flow back through the selected probability values.
        router_weights, selected_experts_indices = torch.topk(
            router_probs,
            self.top_k,
            dim=-1,
        )

        # Renormalise so the kept weights of each token sum to 1 again.
        router_weights = router_weights / router_weights.sum(
            dim=-1, keepdim=True
        )
        router_weights = router_weights.to(x.dtype)

        # (batch * seq_len, top_k, expert_number) one-hot encoding ...
        expert_mask = F.one_hot(
            selected_experts_indices,
            num_classes=self.expert_number,
        )
        # ... reordered to (expert_number, top_k, batch * seq_len) so the
        # caller can index it by expert.
        expert_mask = expert_mask.permute(2, 1, 0)

        return router_logits, router_weights, selected_experts_indices, expert_mask





class SparseMOE(nn.Module):
    """Token-level sparse mixture-of-experts layer.

    Each token is sent to the ``top_k`` experts chosen by the router, and the
    expert outputs are combined using the router's renormalised weights.

    Args:
        config: Object exposing ``hidden_dim``, ``expert_number`` and ``top_k``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.top_k = config.top_k
        self.hidden_dim = config.hidden_dim
        self.expert_number = config.expert_number

        # One hidden_dim -> hidden_dim expert per routing slot.
        self.experts = nn.ModuleList(
            BasicExpert(
                config.hidden_dim,
                config.hidden_dim,
            ) for _ in range(config.expert_number)
        )
        self.router = MOERouter(self.config)

    def forward(self, x):
        """Apply the sparse MoE layer.

        Args:
            x: Input of shape (batch, seq_len, hidden_dim). It does not need
                to be contiguous.

        Returns:
            A 2-tuple ``(final_hidden_states, router_logits)`` with shapes
            (batch, seq_len, hidden_dim) and (batch * seq_len, expert_number).
        """
        batch_size, seq_len, hidden_dim = x.size()

        # Work per token: (batch * seq_len, hidden_dim). reshape() rather than
        # view() so non-contiguous inputs (e.g. transposed tensors) also work.
        hidden_states = x.reshape(-1, hidden_dim)

        # expert_mask has shape (expert_number, top_k, batch * seq_len).
        router_logits, router_weights, _, expert_mask = self.router(
            hidden_states
        )

        # Accumulator for the weighted expert outputs, one row per token.
        final_hidden_states = torch.zeros(
            (batch_size * seq_len, hidden_dim),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )

        # Visit every expert and add its weighted output for the tokens that
        # selected it.
        for expert_idx, expert_layer in enumerate(self.experts):
            # Both results are 1-D and aligned with each other:
            #   weight_slot - which top-k slot (0 .. top_k-1) the expert holds
            #                 for the token; used to pick the router weight.
            #   token_idx   - position of the token in the flattened
            #                 batch * seq_len axis; used to pick hidden states.
            weight_slot, token_idx = torch.where(expert_mask[expert_idx])

            # (selected_token_number, hidden_dim)
            current_state = expert_layer(hidden_states[token_idx])

            # (selected_token_number, 1): broadcasts over hidden_dim below.
            current_token_router_weight = router_weights[
                token_idx, weight_slot
            ].unsqueeze(-1)

            current_hidden_states = current_state * current_token_router_weight

            # Scatter-add this expert's contribution back to its tokens.
            final_hidden_states.index_add_(
                0,
                token_idx,
                current_hidden_states.to(hidden_states.dtype),
            )

        # Restore the original (batch, seq_len, hidden_dim) layout.
        final_hidden_states = final_hidden_states.reshape(
            batch_size, seq_len, hidden_dim
        )

        return final_hidden_states, router_logits

def test_token_level_moe():
    """Smoke-test ``SparseMOE``: run one random batch through it and print both output shapes."""
    # (batch=2, seq_len=4, hidden_dim=16)
    tokens = torch.rand(2, 4, 16)
    # hidden_dim=16, expert_number=2, top_k=2
    moe_config = MOEConfig(16, 2, 2)
    sparse_layer = SparseMOE(moe_config)
    mixed_states, gate_logits = sparse_layer(tokens)
    print(mixed_states.shape, gate_logits.shape)

# Run the sparse-MoE smoke test; it prints the shapes of the layer output and the router logits.
test_token_level_moe()



torch.Size([2, 4, 16]) torch.Size([8, 2])
