## 1 - 基础版本的MoE


In [3]:
import torch 
import torch.nn as nn
import torch.nn.functional as F


In [4]:
class Expert(nn.Module):
    """One MoE expert: a single plain linear layer.

    Maps an input whose last dimension is ``input_dim`` (in this notebook,
    shaped ``(batch_size, input_dim)``) to one whose last dimension is
    ``output_dim``.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The expert's only parameters: a fully connected projection.
        self.expert = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        projected = self.expert(x)
        return projected

In [27]:
class BasicMoE(nn.Module):
    """Dense (soft) mixture of experts: every expert runs on every input.

    A linear gate produces one score per expert for each sample. Softmax turns
    those scores into per-sample mixing weights, and the result is the
    weighted sum of all experts' outputs.

    Args:
        num_experts: number of experts (each a linear layer).
        input_dim: feature size of the input.
        output_dim: feature size of each expert's output and of the result.
    """

    def __init__(self, num_experts, input_dim, output_dim):
        super(BasicMoE, self).__init__()
        # Each expert is a plain linear layer input_dim -> output_dim.
        self.experts = nn.ModuleList(
            [nn.Linear(input_dim, output_dim) for _ in range(num_experts)]
        )
        # The gate scores every expert for every sample.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Combine all experts' outputs using per-sample gate weights.

        Args:
            x: tensor of shape (batch_size, input_dim).

        Returns:
            Tensor of shape (batch_size, output_dim).

        Also prints the intermediate expert-output shapes (notebook demo).
        """
        # Gate logits: (batch_size, num_experts)
        expert_weights = self.gate(x)

        # Each list element: (batch_size, 1, output_dim)
        expert_outputs_list = [expert(x).unsqueeze(1) for expert in self.experts]
        print(expert_outputs_list[0].shape)
        # Concatenate along the experts axis: (batch_size, num_experts, output_dim)
        expert_outputs = torch.cat(expert_outputs_list, dim=1)
        print(expert_outputs.shape)

        # Softmax over the experts axis: a probability distribution per sample.
        gate_outputs = F.softmax(expert_weights, dim=1)

        # Weight each sample's expert outputs by that sample's own gate weights.
        # Note: a plain `gate_outputs @ expert_outputs` would broadcast the 2-D
        # (batch_size, num_experts) matrix over the batch, yielding
        # (batch_size, batch_size, output_dim) and mixing other samples' gate
        # weights into every output. Broadcasting (B, E, 1) * (B, E, O) keeps
        # the weighting per sample.
        weighted_expert_outputs = gate_outputs.unsqueeze(-1) * expert_outputs
        # weighted_expert_outputs: (batch_size, num_experts, output_dim)

        # Sum over the experts axis (dim 1) to eliminate it.
        outputs = torch.sum(weighted_expert_outputs, dim=1)
        # outputs: (batch_size, output_dim)
        return outputs
def test():
    """Smoke-run BasicMoE on a random batch and print the output shape.

    Uses 4 experts, 512 input features and 128 output features on a batch of 4.
    """
    inputs = torch.randn(4, 512)
    moe = BasicMoE(4, 512, 128)
    result = moe(inputs)
    print(result.shape)


test()

torch.Size([4, 1, 128])
torch.Size([4, 4, 128])
torch.Size([4, 128])


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Expert(nn.Module):
    """A mixture-of-experts expert consisting of one linear layer.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: size of the last dimension of the output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        linear = nn.Linear(input_dim, output_dim)
        self.expert = linear

    def forward(self, x):
        # Delegate directly to the wrapped linear layer.
        return self.expert(x)

class BasicMoe(nn.Module):
    """Dense (soft) mixture of experts built from ``Expert`` modules.

    Every expert processes every input. A linear gate plus softmax gives each
    sample its own mixing weights over the experts, and the output is the
    weighted sum of the experts' outputs.

    Args:
        input_dim: feature size of the input.
        output_dim: feature size of each expert's output and of the result.
        num_experts: number of experts.
    """

    def __init__(self, input_dim, output_dim, num_experts):
        super(BasicMoe, self).__init__()
        self.experts = nn.ModuleList(
            [Expert(input_dim, output_dim) for _ in range(num_experts)]
        )
        # The gate scores every expert for every sample.
        self.gate = nn.Linear(input_dim, num_experts)

    def forward(self, x):
        """Combine all experts' outputs using per-sample gate weights.

        Args:
            x: tensor of shape (batch_size, input_dim).

        Returns:
            Tensor of shape (batch_size, output_dim).
        """
        # Gate logits: (batch_size, num_experts)
        experts_scores = self.gate(x)
        # Softmax over the experts axis: (batch_size, num_experts), rows sum to 1.
        experts_weights = F.softmax(experts_scores, dim=1)

        # Each list element: (batch_size, 1, output_dim)
        expert_outputs_list = [expert(x).unsqueeze(1) for expert in self.experts]
        # Concatenate along the experts axis: (batch_size, num_experts, output_dim)
        expert_outputs = torch.cat(expert_outputs_list, dim=1)

        # Scale each sample's expert outputs by that sample's own weights.
        # A plain `experts_weights @ expert_outputs` would broadcast the 2-D
        # weight matrix over the batch and produce
        # (batch_size, batch_size, output_dim), mixing samples together.
        weighted_expert_outputs = experts_weights.unsqueeze(-1) * expert_outputs
        # weighted_expert_outputs: (batch_size, num_experts, output_dim)

        # Sum over the experts axis: (batch_size, output_dim)
        outputs = torch.sum(weighted_expert_outputs, dim=1)

        return outputs
    
def test():
    """Smoke-run BasicMoe (512 -> 128, 4 experts) on a random batch of 4 and print the output shape."""
    batch = torch.randn(4, 512)
    moe = BasicMoe(512, 128, 4)
    print(moe(batch).shape)


test()

        




torch.Size([4, 128])
