In [8]:
import torch
import torch.nn as nn

In [9]:
class basicExpert(nn.Module):
  def __init__(self, feature_in, feature_out):
    super().__init__()

    self.expert = nn.Linear(feature_in, feature_out)

  def forward(self, x):
    output = self.expert(x)
    return output


In [10]:
class basicMOE(nn.Module):
  def __init__(self, feature_in, feature_out, num_experts):
    super().__init__()

    self.experts = nn.ModuleList(
        [
            nn.Linear(feature_in, feature_out) for _ in range(num_experts)
        ]
    )

    self.gate = nn.Linear(feature_in, num_experts)

  def forward(self, x):
    # x.size = (batch, feature_in)
    print(f'shape of x: {x.shape}')
    expert_weight = self.gate(x) # weight.shape = (batch, num_experts)
    expert_weight = expert_weight.unsqueeze(1) # weight.shape = (batch, 1, num_experts)
    print(f'shape of weight: {expert_weight.shape}')

    expert_output = [
        expert(x).unsqueeze(1) for expert in self.experts
    ]

    expert_output = torch.cat(expert_output, dim = 1) # output.shape = (batch, num_experts, feature_out)
    print(f'shape of output: {expert_output.shape}')

    output = expert_weight @ expert_output
    # print(f'shape of output: {output.shape}')
    return output.squeeze()

def testMOE(x, feature_in, feature_out, num_experts):
  x = torch.randn([5,2])
  moe = basicMOE(feature_in, feature_out, number_experts)
  output = moe(x)
  print(f'test shape: {output.shape}')

batch = 5
feature_in = 2
feature_out = 4
number_experts = 3
x = torch.randn([batch, feature_in])
testMOE(x, feature_in, feature_out, number_experts)

shape of x: torch.Size([5, 2])
shape of weight: torch.Size([5, 1, 3])
shape of output: torch.Size([5, 3, 4])
test shape: torch.Size([5, 4])
