Playing with https://github.com/lucidrains/mixture-of-experts

Imports

In [1]:
import torch
from torch import nn
from mixture_of_experts import MoE


The example from the README

In [2]:
# Build the README's mixture-of-experts layer and push one random batch through it.
# Weight initialisation, the random inputs and the 'random' second-expert policy all
# draw from torch's global RNG, so seed it to make Restart & Run All reproducible.
torch.manual_seed(0)

DIM = 512  # model width; shared by the layer and the dummy inputs so they cannot drift apart

moe = MoE(
    dim = DIM,
    num_experts = 16,               # increase the experts (# parameters) of your model without increasing computation
    hidden_dim = DIM * 4,           # size of hidden dimension in each expert, defaults to 4 * dimension
    activation = nn.LeakyReLU,      # use your preferred activation, will default to GELU
    second_policy_train = 'random', # in top_2 gating, policy for whether to use a second-place expert
    second_policy_eval = 'random',  # all (always) | none (never) | threshold (if gate value > the given threshold) | random (if gate value > threshold * random_uniform(0, 1))
    second_threshold_train = 0.2,
    second_threshold_eval = 0.2,
    capacity_factor_train = 1.25,   # experts have fixed capacity per batch. we need some extra capacity in case gating is not perfectly balanced.
    capacity_factor_eval = 2.,      # capacity_factor_* should be set to a value >=1
    loss_coef = 1e-2                # multiplier on the auxiliary expert-balancing loss
)

inputs = torch.randn(4, 1024, DIM)  # dummy batch: (4, 1024, DIM)
out, aux_loss = moe(inputs)         # README documents these as (4, 1024, 512), (1,)

# The layer is expected to preserve the input shape; fail loudly if that ever changes.
assert out.shape == inputs.shape, f"unexpected output shape {tuple(out.shape)}"

print(out)
print(aux_loss)

tensor([[[-0.0650, -0.2893, -0.1795,  ...,  0.0168, -0.0159, -0.2139],
         [-0.1484,  0.2455, -0.0968,  ...,  0.0163,  0.0279, -0.2175],
         [-0.2172,  0.2235,  0.0852,  ...,  0.0207, -0.0909, -0.0362],
         ...,
         [-0.0885, -0.1649,  0.1277,  ...,  0.4288, -0.1716, -0.0341],
         [-0.3383,  0.1532, -0.0771,  ...,  0.0015,  0.0863, -0.1107],
         [-0.0654, -0.3442, -0.2506,  ...,  0.3595, -0.0489, -0.2234]],

        [[ 0.4771, -0.2995,  0.0559,  ..., -0.0628,  0.2009,  0.2021],
         [ 0.0434, -0.0705, -0.2034,  ...,  0.0257, -0.1011,  0.2663],
         [ 0.0995,  0.2098, -0.0025,  ..., -0.0489,  0.0253,  0.1372],
         ...,
         [-0.3076, -0.1395,  0.0884,  ..., -0.2480, -0.1746, -0.2253],
         [-0.1010, -0.0126, -0.2407,  ..., -0.1442,  0.1594,  0.0124],
         [-0.0528,  0.2720,  0.4362,  ..., -0.0626,  0.4771, -0.0281]],

        [[ 0.2273,  0.3430, -0.2094,  ..., -0.0935,  0.1137,  0.0027],
         [-0.3726,  0.2877, -0.0634,  ..., -0