In [1]:
from moe import MoE
import torch
from torch import nn
from torch.optim import Adam

In [2]:
def train(x, y, model, loss_fn, optim):
    """Run a single optimisation step of ``model`` on the batch ``(x, y)``.

    Args:
        x: input batch; cast to float before being passed to the model.
        y: targets, passed unchanged to ``loss_fn``.
        model: module whose forward call returns ``(prediction, aux_loss)``.
        loss_fn: callable ``loss_fn(prediction, y)`` returning a scalar loss.
        optim: optimizer holding the model's parameters.

    Returns:
        The same ``model`` object, after its parameters have been updated.
    """
    # Explicitly (re-)enable training mode: eval() in this notebook calls
    # model.eval() and nothing else switches the model back, so without this
    # any training done after an evaluation would run in eval mode.
    model.train()
    # model returns the prediction and the loss that encourages all experts to have equal importance and load
    y_hat, aux_loss = model(x.float())
    # calculate prediction loss
    loss = loss_fn(y_hat, y)
    # combine losses
    total_loss = loss + aux_loss
    # clear stale gradients, backpropagate the combined loss, then update
    optim.zero_grad()
    total_loss.backward()
    optim.step()

    print("Training results - loss: {:.2f}, aux_loss: {:.3f}".format(loss.item(), aux_loss.item()))
    return model

In [3]:
def eval(x, y, model, loss_fn):
    """Evaluate ``model`` on the batch ``(x, y)`` and print the losses.

    Puts the model in eval mode (it is NOT switched back afterwards) and runs
    the forward pass with ``train=False``. Nothing is returned; the prediction
    loss and the auxiliary loss are only printed.

    NOTE: the name shadows Python's builtin ``eval``; it is kept so existing
    callers keep working.

    Args:
        x: input batch; cast to float before being passed to the model.
        y: targets, passed unchanged to ``loss_fn``.
        model: module whose forward call accepts ``train=False`` and returns
            ``(prediction, aux_loss)``.
        loss_fn: callable ``loss_fn(prediction, y)`` returning a scalar loss.
    """
    model.eval()
    # No backward pass happens here, so skip building the autograd graph.
    with torch.no_grad():
        # model returns the prediction and the loss that encourages all experts to have equal importance and load
        y_hat, aux_loss = model(x.float(), train=False)
        loss = loss_fn(y_hat, y)
    print("Evaluation results - loss: {:.2f}, aux_loss: {:.3f}".format(loss.item(), aux_loss.item()))

In [4]:
def dummy_data(batch_size, input_size, num_classes):
    """Build one random batch for smoke-testing.

    Returns a pair ``(x, y)``: ``x`` is a ``(batch_size, input_size)`` tensor
    from ``torch.rand`` and ``y`` is a ``(batch_size,)`` tensor of integer
    labels drawn by ``torch.randint`` from ``[0, num_classes)``.
    """
    # features are drawn first, then labels (same RNG consumption order as before)
    features = torch.rand(batch_size, input_size)

    # draw labels as a column, then keep that single column as a flat vector
    label_column = torch.randint(low=0, high=num_classes, size=(batch_size, 1))
    labels = label_column[:, 0]
    return features, labels

In [5]:
# arguments
input_size = 1000   # features per example
num_classes = 20    # passed to MoE as its second argument and used as the label range
num_experts = 10
hidden_size = 64
batch_size = 5
k = 4               # forwarded to MoE as k= (presumably experts selected per example - confirm in moe.py)
seed = 0

# Seed torch's RNG before anything stochastic runs (parameter initialisation,
# the dummy batch, and whatever randomness noisy_gating=True adds), so that a
# Restart & Run All reproduces the same numbers.
torch.manual_seed(seed)

# instantiate the MoE layer
model = MoE(input_size, num_classes, num_experts, hidden_size, k=k, noisy_gating=True)

# NOTE(review): NLLLoss expects log-probabilities as input; this assumes MoE's
# prediction is already in log space - confirm against moe.py.
loss_fn = nn.NLLLoss()
optim = Adam(model.parameters(), lr=1e-3)

# one fixed random batch, reused by the training cell
x, y = dummy_data(batch_size, input_size, num_classes)

In [11]:
# train: repeatedly fit the same fixed dummy batch (x, y)
num_train_steps = 20
# A named loop variable is used instead of `_`, which IPython also uses for
# the previous cell output (rebinding it is believed to stop that shortcut
# from updating - see the IPython output-caching docs).
for step in range(num_train_steps):
    model = train(x, y, model, loss_fn, optim)

Training results - loss: 0.01, aux_loss: 0.042
Training results - loss: 0.01, aux_loss: 0.045
Training results - loss: 0.01, aux_loss: 0.044
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.044
Training results - loss: 0.01, aux_loss: 0.045
Training results - loss: 0.01, aux_loss: 0.042
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.045
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.044
Training results - loss: 0.01, aux_loss: 0.044
Training results - loss: 0.01, aux_loss: 0.042
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.041
Training results - loss: 0.01, aux_loss: 0.044
Training results - loss: 0.01, aux_loss: 0.043
Training results - loss: 0.01, aux_loss: 0.043


In [51]:
# evaluate on a freshly drawn dummy batch
# Distinct names keep the training batch (x, y) intact, so re-running the
# training cell after this one still trains on the original data.
x_eval, y_eval = dummy_data(batch_size, input_size, num_classes)
eval(x_eval, y_eval, model, loss_fn)

Evaluation results - loss: 0.24, aux_loss: 0.038
