In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Routers import SoftRouter
from Experts import BasicExpert

In [2]:
class BasicMOE(nn.Module):
    """Dense mixture of experts: every expert runs on every sample.

    For each sample, the router yields one weight per expert and the
    sample's output is the weighted sum of all expert outputs.

    Args:
        in_features: Size of each input sample; also the gate's input size.
        out_features: Size of each output sample.
        hidden_dim: Forwarded to ``SoftRouter`` as its first argument.
        num_experts: Number of ``BasicExpert`` modules to instantiate.
    """

    def __init__(self, in_features, out_features, hidden_dim, num_experts):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.num_experts = num_experts
        self.experts = nn.ModuleList(
            [BasicExpert(in_features, out_features) for _ in range(num_experts)]
        )
        self.gate = nn.Linear(in_features, num_experts)
        # NOTE(review): the router is given hidden_dim while the gate it wraps
        # consumes in_features -- confirm against SoftRouter that this is intended.
        self.router = SoftRouter(hidden_dim, num_experts, self.gate)

    def forward(self, x):
        """Mix the expert outputs independently for every sample in the batch.

        Args:
            x: Batch of inputs indexed along dim 0; each ``x[i]`` is passed to
                the router and to every expert.
                (assumes shape (batch, in_features) -- TODO confirm)

        Returns:
            Tensor whose row ``i`` is ``sum_j router(x[i])[j] * experts[j](x[i])``.
            An empty batch yields an empty ``(0, out_features)`` tensor.
        """
        if x.size(0) == 0:
            # torch.stack rejects an empty list, so handle the empty batch here.
            return x.new_zeros(0, self.out_features)

        rows = []
        for sample in x:
            expert_weights = self.router(sample)
            # Accumulate into a per-sample vector. Adding into the full
            # (batch, out_features) tensor would broadcast this sample's
            # contribution onto every row and mix samples together.
            mixed = x.new_zeros(self.out_features)
            for j in range(self.num_experts):
                mixed = mixed + expert_weights[j] * self.experts[j](sample)
            rows.append(mixed)
        return torch.stack(rows)

In [4]:
# Smoke test: build a small BasicMOE and push one random batch through it.
torch.manual_seed(0)  # fix the RNG so re-running the notebook reproduces the same tensors

in_features = 10
out_features = 5
num_experts = 3
batch_size = 32
hidden_dim = 16
model = BasicMOE(in_features, out_features, hidden_dim, num_experts)
x = torch.randn(batch_size, in_features)
y = model(x)
# One output row per input sample is expected.
assert y.shape == (batch_size, out_features), f"unexpected output shape {tuple(y.shape)}"
print(y.shape)

torch.Size([32, 5])
