Playing with https://github.com/lucidrains/mixture-of-experts

Imports

In [2]:
import torch
from torch import nn
from mixture_of_experts import MoE


The example from the README

In [4]:
moe = MoE(
    dim = 512,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = 512 * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # use your preferred activation, will default to GELU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    loss_coef = 1e-2                # multiplier on the auxiliary expert balancing auxiliary loss
)

inputs = torch.randn(4, 1024, 512)
out, aux_loss = moe(inputs) # (4, 1024, 512), (1,)
print(out)
print(aux_loss)

tensor([[[-0.2746, -0.1281, -0.0212,  ...,  0.0909,  0.0826, -0.0075],
         [-0.0411, -0.0162, -0.2991,  ...,  0.0493,  0.2491,  0.2779],
         [ 0.2574,  0.1190, -0.1743,  ...,  0.0723,  0.3445,  0.1705],
         ...,
         [ 0.0364,  0.1023,  0.2359,  ..., -0.0697,  0.3408,  0.2871],
         [-0.0080, -0.1538,  0.3131,  ..., -0.0061,  0.2665,  0.1159],
         [-0.1432,  0.1608, -0.2175,  ..., -0.1243, -0.1145, -0.0474]],

        [[-0.1594,  0.0981, -0.0154,  ..., -0.1482, -0.0412,  0.2412],
         [-0.4920,  0.5089,  0.1198,  ...,  0.1961, -0.0150, -0.5014],
         [-0.0025,  0.0338,  0.0950,  ..., -0.1753,  0.0987, -0.0198],
         ...,
         [ 0.5218,  0.0578,  0.1279,  ...,  0.2237,  0.1803, -0.1579],
         [-0.0624, -0.4715,  0.0489,  ...,  0.1466,  0.2376, -0.1085],
         [-0.1073,  0.1775,  0.0738,  ..., -0.0948, -0.1230, -0.1071]],

        [[-0.0401,  0.0699, -0.1740,  ...,  0.0241,  0.0544,  0.2312],
         [-0.5468, -0.0066,  0.1501,  ...,  0