In [1]:
import torch
from torch.nn import Module, ModuleList, Linear, ReLU

## Simple MoE

- gating: a linear map (hidden_size -> num_experts), followed by a softmax to turn the logits into per-expert probabilities
- expert: an MLP (hidden -> intermediate -> hidden)
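- shapes: x (bs, seqlen, hidden_size); expert_weights (bs, seqlen, num_experts); each expert output (bs, seqlen, hidden_size); all expert outputs stacked (bs, seqlen, num_experts, hidden_size); the final output is the expert-weighted sum over the num_experts axis, (bs, seqlen, hidden_size)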
- no fancy mechanisms like top-k 

In [2]:
class SimpleGating(Module):
    """Dense router that assigns every expert a weight for every token.

    A single linear layer maps the last dimension of the input from
    ``hidden_size`` to ``num_experts`` logits, and a softmax over that last
    dimension normalises the logits so the weights sum to 1.
    """

    def __init__(self, hidden_size, num_experts):
        super().__init__()
        # The attribute name determines the state_dict keys ("ln.weight", "ln.bias").
        self.ln = Linear(in_features=hidden_size, out_features=num_experts)

    def forward(self, x):
        """Return expert weights of shape ``(..., num_experts)`` for ``x`` of shape ``(..., hidden_size)``."""
        logits = self.ln(x)
        return logits.softmax(dim=-1)


class Expert(Module):
    """Bias-free two-layer feed-forward block: hidden -> intermediate -> hidden.

    Only the last dimension of the input is transformed; a ReLU sits between
    the two projections.
    """

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.up_proj = Linear(in_features=hidden_size, out_features=intermediate_size, bias=False)
        self.down_proj = Linear(in_features=intermediate_size, out_features=hidden_size, bias=False)
        self.act_fn = ReLU()

    def forward(self, x):
        """Compute ``down_proj(act_fn(up_proj(x)))``."""
        return self.down_proj(self.act_fn(self.up_proj(x)))
    
    
class SimpleMoELayer(Module):
    """Dense mixture-of-experts layer (no top-k routing).

    Every expert processes every token; the outputs are combined as a
    weighted sum using the softmax weights produced by the gating network.

    Args:
        hidden_size: size of the last dimension of the input and output.
        intermediate_size: width of each expert's inner projection.
        num_experts: number of experts (and of gating outputs).
    """

    def __init__(self, hidden_size, intermediate_size, num_experts):
        super().__init__()
        self.gating = SimpleGating(hidden_size, num_experts)
        self.experts = ModuleList([
            Expert(hidden_size, intermediate_size)
            for _ in range(num_experts)
        ])

    def forward(self, x):
        """Mix the expert outputs for ``x``.

        Args:
            x: tensor of shape ``(..., hidden_size)``, e.g.
                ``(bs, seqlen, hidden_size)``. Any number of leading
                dimensions is accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # (..., num_experts)
        expert_weights = self.gating(x)
        # Stack on the second-to-last axis rather than a hard-coded dim=2, so
        # the expert axis lines up with the gating weights for any input rank:
        # (..., num_experts, hidden_size)
        expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=-2)
        # Broadcast the weights over hidden_size, then reduce the expert axis.
        weighted_outputs = expert_outputs * expert_weights.unsqueeze(-1)
        return weighted_outputs.sum(dim=-2)
        

        

In [3]:
# Smoke test: a random (bs=8, seqlen=128, hidden=32) batch should come back with the same shape.
m = SimpleMoELayer(hidden_size=32, intermediate_size=256, num_experts=6)
m(torch.rand(8, 128, 32)).shape

torch.Size([8, 128, 32])