In [1]:
import jax.numpy as jnp
import jax

In [26]:
# Toy problem sizes: one batch of three tokens with a single feature each.
bsz, seq_len, d = 1, 3, 1
num_experts = 4
top_k = 2
capacity = 5  # slots per expert; large enough here that no token is dropped
# Number the elements 0 .. bsz*seq_len*d - 1 so the reshape stays valid for any
# choice of sizes, not only when bsz == d == 1.
x = jnp.arange(bsz * seq_len * d).reshape((bsz, seq_len, d))
# Merge batch and sequence into a single token axis: [bsz * seq_len, d].
x_tokens = x.reshape(bsz * seq_len, d)
x_tokens

Array([[0],
       [1],
       [2]], dtype=int32)

In [27]:
# Hand-picked router scores, one row per token and one column per expert:
# [bsz * seq_len, num_experts] = [3, 4].
router_logits = jnp.array(
    [
        [0.1, 0.2, 0.0, 0.0],
        [0.0, 0.0, 0.1, 0.2],
        [0.0, 0.1, 0.2, 0.0],
    ]
)
router_logits

Array([[0.1, 0.2, 0. , 0. ],
       [0. , 0. , 0.1, 0.2],
       [0. , 0.1, 0.2, 0. ]], dtype=float32)

In [28]:
# Select each token's top_k experts; both results are [bsz * seq_len, top_k].
expert_values, expert_indices = jax.lax.top_k(router_logits, k=top_k)
# Flatten to one entry per (token, choice) pair, token-major:
# [bsz * seq_len * top_k].
expert_values_flat = jnp.reshape(expert_values, (bsz * seq_len * top_k,))
expert_indices_flat = jnp.reshape(expert_indices, (bsz * seq_len * top_k,))
expert_indices_flat

Array([1, 0, 3, 2, 2, 1], dtype=int32)

In [29]:
# One-hot expert id for every (token, choice) pair:
# [bsz * seq_len * top_k, num_experts].
oh_e = jax.nn.one_hot(expert_indices_flat, num_classes=num_experts)
oh_e

Array([[0., 1., 0., 0.],
       [1., 0., 0., 0.],
       [0., 0., 0., 1.],
       [0., 0., 1., 0.],
       [0., 0., 1., 0.],
       [0., 1., 0., 0.]], dtype=float32)

In [30]:
# Inclusive running count of assignments per expert, walking down the pair axis.
cumsum_per_experts = oh_e.cumsum(axis=0)
# Keep only the count belonging to each pair's own expert, then shift it to a
# 0-based slot index inside that expert's buffer.
positions_flat = jnp.sum(cumsum_per_experts * oh_e, axis=1) - 1
positions_flat

Array([0., 0., 0., 0., 1., 1.], dtype=float32)

In [33]:
# True for pairs whose slot index fits inside the per-expert capacity.
within_capacity = jnp.less(positions_flat, capacity)
within_capacity

Array([ True,  True,  True,  True,  True,  True], dtype=bool)

In [None]:
# positions_flat holds whole numbers stored as floats; cast to integers so the
# one-hot lookup receives proper class indices instead of relying on float
# equality.
oh_c = jax.nn.one_hot(positions_flat.astype(jnp.int32), capacity)
# Per pair, the outer product of its expert one-hot and its slot one-hot, zeroed
# for pairs that overflow capacity:
# [bsz * seq_len * top_k, num_experts, capacity].
pair_dispatch = within_capacity[:, None, None] * oh_e[:, :, None] * oh_c[:, None, :]
pair_dispatch

Array([[[0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]], dtype=float32)

In [37]:
# Split the pair axis back into (token, choice):
# [bsz * seq_len, top_k, num_experts, capacity].
dispatch = jnp.reshape(pair_dispatch, (bsz * seq_len, top_k, num_experts, capacity))
dispatch

Array([[[[0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]],


       [[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]],


       [[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]]], dtype=float32)

In [38]:
# Collapse the top_k axis so each token has a single [num_experts, capacity] mask.
# Built from pair_dispatch rather than from `dispatch` itself, so re-running
# this cell cannot sum a second axis away.
dispatch = pair_dispatch.reshape(bsz * seq_len, top_k, num_experts, capacity).sum(axis=1)
dispatch

Array([[[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 0., 0., 0.]]], dtype=float32)

In [None]:
# dispatch: [bsz * seq_len, E, C], x_tokens: [bsz * seq_len, d]
# -> expert_inputs: [E, C, d]; slot (e, c) receives the sum of the tokens
# dispatched to it.
expert_inputs = jnp.einsum("nec,nd->ecd", dispatch, x_tokens)
expert_inputs

Array([[[0.],
        [0.],
        [0.],
        [0.],
        [0.]],

       [[0.],
        [2.],
        [0.],
        [0.],
        [0.]],

       [[1.],
        [2.],
        [0.],
        [0.],
        [0.]],

       [[1.],
        [0.],
        [0.],
        [0.],
        [0.]]], dtype=float32)

In [41]:
# Experts are modelled as the identity, so their outputs equal their inputs.
expert_outputs = expert_inputs
# Scale each pair's dispatch mask by that pair's router score.
# NOTE(review): the scores are the raw top-k logits, not normalised weights —
# confirm this is intended for the toy example.
pair_combine = jnp.expand_dims(expert_values_flat, axis=(1, 2)) * pair_dispatch
pair_combine

Array([[[0. , 0. , 0. , 0. , 0. ],
        [0.2, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ]],

       [[0.1, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ]],

       [[0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0.2, 0. , 0. , 0. , 0. ]],

       [[0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0.1, 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ]],

       [[0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0. , 0.2, 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ]],

       [[0. , 0. , 0. , 0. , 0. ],
        [0. , 0.1, 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ],
        [0. , 0. , 0. , 0. , 0. ]]], dtype=float32)

In [45]:
# Sum each token's top_k weighted masks: [bsz * seq_len, num_experts, capacity].
combine = jnp.reshape(pair_combine, (bsz * seq_len, top_k, num_experts, capacity)).sum(axis=1)
# Weighted read-back from expert slots to tokens, then restore [bsz, seq_len, d].
output_tokens = jnp.einsum("nec,ecd->nd", combine, expert_outputs).reshape(bsz, seq_len, d)
output_tokens

Array([[[0. ],
        [0.3],
        [0.6]]], dtype=float32)