In [25]:
import torch
import torch.nn as nn


In [32]:
# Sample a 2x3 tensor from the standard normal distribution and show it with its shape.
a = torch.randn(2, 3)
print(a, a.shape, sep="\n")

tensor([[ 1.6756, -0.3473,  0.1252],
        [-1.2120,  0.1400, -0.1038]])
torch.Size([2, 3])


In [33]:
# Second 2x3 standard-normal sample, same shape as `a` so the two can be concatenated.
b = torch.randn(2, 3)
print(b)


tensor([[-1.6690,  1.1527,  1.1440],
        [-0.6459, -0.2247,  2.2357]])


In [36]:
# Concatenate along dim 0 (rows): two (2, 3) tensors become one (4, 3) tensor.
c = torch.cat((a, b), dim=0)
print(c, c.shape, sep="\n")

tensor([[ 1.6756, -0.3473,  0.1252],
        [-1.2120,  0.1400, -0.1038],
        [-1.6690,  1.1527,  1.1440],
        [-0.6459, -0.2247,  2.2357]])
torch.Size([4, 3])


In [None]:
class basicExpert(nn.Module):
  """A single expert: one affine layer mapping feature_in -> feature_out.

  Args:
    feature_in: size of the last dimension of the input.
    feature_out: size of the last dimension of the output.
  """

  def __init__(self, feature_in, feature_out):
    super().__init__()
    # The attribute name `expert` is part of the state_dict keys.
    self.expert = nn.Linear(feature_in, feature_out)

  def forward(self, x):
    """Apply the expert's linear layer to `x` and return the result."""
    return self.expert(x)


In [52]:
class basicMOE(nn.Module):
  """Dense mixture-of-experts layer: every expert processes every input.

  A linear gate produces one score per expert for each sample. The scores are
  normalized with a softmax over the expert axis, and the layer returns the
  resulting weighted sum of the expert outputs.

  Args:
    feature_in: size of the last dimension of the input.
    feature_out: size of the last dimension produced by each expert.
    num_experts: number of parallel ``nn.Linear`` experts.
  """

  def __init__(self, feature_in, feature_out, num_experts):
    super().__init__()

    self.experts = nn.ModuleList(
        [
            nn.Linear(feature_in, feature_out) for _ in range(num_experts)
        ]
    )

    self.gate = nn.Linear(feature_in, num_experts)

  def forward(self, x):
    """Mix the expert outputs for a batch of inputs.

    Args:
      x: tensor of shape (batch, feature_in).

    Returns:
      Tensor of shape (batch, 1, feature_out). The singleton middle axis is
      left in place (not squeezed).
    """
    # x.shape = (batch, feature_in)
    print(f'shape of x: {x.shape}')

    # Normalize the gate logits so each sample's expert weights are
    # non-negative and sum to 1; without this the "mixture" is an arbitrary
    # linear combination whose scale depends on unbounded gate scores.
    gate_logits = self.gate(x)  # (batch, num_experts)
    expert_weight = torch.softmax(gate_logits, dim=-1).unsqueeze(1)  # (batch, 1, num_experts)
    print(f'shape of weight: {expert_weight.shape}')

    # Stack per-expert outputs along a new axis 1: (batch, num_experts, feature_out)
    expert_output = torch.stack([expert(x) for expert in self.experts], dim=1)
    print(f'shape of output: {expert_output.shape}')

    # (batch, 1, num_experts) @ (batch, num_experts, feature_out) -> (batch, 1, feature_out)
    output = expert_weight @ expert_output
    return output

def testMOE(x, feature_in, feature_out, num_experts):
  """Build a basicMOE and run one forward pass on `x`, printing the output shape.

  Args:
    x: input tensor; its last dimension must equal `feature_in`.
    feature_in: input feature size passed to basicMOE.
    feature_out: output feature size passed to basicMOE.
    num_experts: number of experts passed to basicMOE.
  """
  # Use the caller's `x` and `num_experts` rather than a hard-coded tensor and
  # a module-level global, so the function works for any configuration.
  moe = basicMOE(feature_in, feature_out, num_experts)
  output = moe(x)
  print(f'test shape: {output.shape}')

# Demo configuration: 5 samples, 2 input features, 4 output features, 3 experts.
batch, feature_in, feature_out, number_experts = 5, 2, 4, 3

# Random input of shape (batch, feature_in), then a single smoke-test run.
x = torch.randn(batch, feature_in)
testMOE(x, feature_in, feature_out, num_experts=number_experts)



shape of x: torch.Size([5, 2])
shape of weight: torch.Size([5, 1, 3])
shape of output: torch.Size([5, 3, 4])
test shape: torch.Size([5, 1, 4])
