In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:

#basic_expert
class BasicExpert(nn.Module):
    def __init__(self, hidden_dim, dropout):
        super().__init__()
        
        mid_dim = hidden_dim * 8 // 3
        self.up = nn.Linear(hidden_dim, mid_dim)
        self.down = nn.Linear(mid_dim, hidden_dim)
        self.gated = nn.Linear(hidden_dim, mid_dim)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        
        out = self.dropout(
            self.down(
               F.silu(self.up(x)) * self.gated(x)
            )
        )
        return out

In [None]:
#basic_moe
class BasicMoE(nn.Module):
    def __init__(self, num_experts, hidden_dim, dropout):
        super().__init__()
        self.experts = nn.ModuleList(
            [BasicExpert(hidden_dim, dropout)  for _ in range(num_experts)]
        )
        self.gate = nn.Linear(hidden_dim, num_experts)
        
    def forward(self, x):
        expert_weights = self.gate(x)
        expert_out_list = [
            expert(x).unsqueeze(1) for expert in self.experts
        ]
        #batch, num_experts, hidden_dim
        expert_out = torch.cat(expert_out_list, dim=1)
        #batch, 1, num_experts
        expert_weights.unsqueeze(dim=1)
        out = expert_weights @ expert_out
        
        return out.squeeze()

In [21]:
def test_basic_moe():
    x = torch.rand(2, 4)

    basic_moe = BasicMoE(2, 4, 0.5)
    out = basic_moe(x)
    print(out)
test_basic_moe()

tensor([[[-0.0152,  0.0000, -0.0787,  0.0000],
         [ 0.1217,  0.0000, -0.2513,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0128],
         [ 0.0000,  0.0000,  0.0000, -0.1028]]], grad_fn=<SqueezeBackward0>)
