In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class MLP(nn.Module):
    """
    Feed-forward stack of fully connected layers.

    For every width in ``fc_dims`` the network appends, in order, a
    ``Linear`` layer, a ``ReLU`` and a ``Dropout``. The last width is treated
    the same way, so the output is also rectified and subject to dropout.

    Args:
        fc_dims: iterable of layer output widths, applied first to last.
        input_dim: size of the last dimension of the input tensor.
        dropout: drop probability handed to every ``nn.Dropout``.
    """
    def __init__(self, fc_dims, input_dim, dropout):
        super().__init__()
        # widths[i] -> widths[i + 1] is the shape of the i-th Linear layer.
        widths = [input_dim, *fc_dims]
        stack = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            stack.extend((
                nn.Linear(n_in, n_out),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ))
        self.fc = nn.Sequential(*stack)

    def forward(self, x):
        """Run ``x`` through the whole layer stack and return the result."""
        hidden = self.fc(x)
        return hidden

In [5]:
# Smoke test: two 256-unit layers applied to a batch of 64 vectors of width 128.
mlp = MLP(fc_dims=[256, 256], input_dim=128, dropout=0.5)
x = torch.randn((64, 128))
output = mlp(x)
print(output.shape)  # expected: torch.Size([64, 256])

torch.Size([64, 256])


In [3]:
class MoE(nn.Module):
    """
    Mixture of Experts: a softmax-gated blend of parallel MLP experts.

    A linear gate maps each input vector to one logit per expert. The softmax
    of those logits weights the experts' outputs, which are then summed.

    Note: the attribute names ``export_num`` / ``export_net`` ("export" is a
    misspelling of "expert") are kept as-is so existing state_dict keys and
    external attribute access keep working.

    Args:
        moe_arch: pair ``(export_num, export_arch)`` -- the number of experts
            and the list of layer widths passed to each expert ``MLP`` as its
            ``fc_dims``.
        inp_dim: size of the last dimension of the input tensor.
        dropout: dropout probability forwarded to every expert ``MLP``.

    Raises:
        ValueError: if the number of experts is smaller than 1.
    """
    def __init__(self, moe_arch, inp_dim, dropout):
        super().__init__()
        export_num, expert_arch = moe_arch
        if export_num < 1:
            # Fail here with a clear message instead of inside torch.stack
            # during the first forward pass.
            raise ValueError(
                "MoE needs at least one expert, got export_num={}".format(export_num)
            )
        self.export_num = export_num
        self.gate_net = nn.Linear(inp_dim, export_num)
        self.export_net = nn.ModuleList(
            [MLP(expert_arch, inp_dim, dropout) for _ in range(export_num)]
        )

    def forward(self, x):
        """
        Blend the expert outputs with the gate's softmax weights.

        Args:
            x: tensor of shape ``(..., inp_dim)``; any number of leading
                dimensions is accepted.

        Returns:
            Tensor of shape ``(..., emb)`` where ``emb`` is the output width
            of each expert.
        """
        # Negative dims keep the maths independent of how many leading
        # (batch-like) dimensions the input carries.
        gate = F.softmax(self.gate_net(x), dim=-1)  # (..., export_num)
        experts = torch.stack(
            [net(x) for net in self.export_net], dim=-2
        )  # (..., export_num, emb)
        # (..., 1, export_num) @ (..., export_num, emb) -> (..., 1, emb)
        out = torch.matmul(gate.unsqueeze(dim=-2), experts).squeeze(dim=-2)
        return out

In [7]:
# Smoke test: 4 experts, each an MLP with two 256-unit layers, on a (64, 128) batch.
moe = MoE(moe_arch=(4, [256, 256]), inp_dim=128, dropout=0.5)
x = torch.randn((64, 128))
output = moe(x)
print(output.shape)  # expected: torch.Size([64, 256])

torch.Size([64, 256])
