In [353]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [354]:
tokens = 4
d_model = 8
top_k = 2
n_experts = 3

In [355]:
inputs = torch.rand(1, tokens, d_model)
inputs

tensor([[[0.7733, 0.1707, 0.9376, 0.9425, 0.5589, 0.4675, 0.3773, 0.5024],
         [0.0874, 0.6406, 0.7324, 0.9194, 0.3231, 0.4237, 0.4044, 0.1464],
         [0.4526, 0.0482, 0.5630, 0.1766, 0.1429, 0.3552, 0.4647, 0.2547],
         [0.6644, 0.2857, 0.5988, 0.2055, 0.3879, 0.9472, 0.5443, 0.9692]]])

In [356]:
expert_selector_matrix = nn.Linear(d_model, n_experts)
experts_weights = expert_selector_matrix(inputs)

In [357]:
experts_weights

tensor([[[-0.1679, -0.0511, -0.3551],
         [-0.2789, -0.0860, -0.2402],
         [-0.1010, -0.1184,  0.1149],
         [-0.2519, -0.2553, -0.1444]]], grad_fn=<ViewBackward0>)

In [358]:
weights, indices = experts_weights.topk(top_k, dim = -1)

In [359]:
weights, indices

(tensor([[[-0.0511, -0.1679],
          [-0.0860, -0.2402],
          [ 0.1149, -0.1010],
          [-0.1444, -0.2519]]], grad_fn=<TopkBackward0>),
 tensor([[[1, 0],
          [1, 2],
          [2, 0],
          [2, 0]]]))

In [360]:
normalized_weights = F.softmax(weights, dim = -1)
importance = normalized_weights.sum(-2)
print(normalized_weights)
print(importance)

tensor([[[0.5292, 0.4708],
         [0.5385, 0.4615],
         [0.5537, 0.4463],
         [0.5269, 0.4731]]], grad_fn=<SoftmaxBackward0>)
tensor([[2.1482, 1.8518]], grad_fn=<SumBackward1>)


In [361]:
sparse_weights = torch.zeros(experts_weights.size())  # same shape, all zeros
sparse_weights

tensor([[[0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.],
         [0., 0., 0.]]])

In [362]:
class ONE_EXPERT_NEURAL_NETWORK(nn.Module):

  def __init__(self):
    super().__init__()
    self.experts = nn.Sequential(
    nn.Linear(d_model, 2 * d_model),
    nn.ReLU(),
    nn.Linear(2 * d_model, d_model)
    )

  def forward(self, x):
    return self.experts(x)


all_experts = nn.ModuleList([ONE_EXPERT_NEURAL_NETWORK() for _ in range(n_experts)])

In [363]:
# scatter the top-k values into the zero tensor at the right indices
expert_weights = sparse_weights.scatter_(-1, indices, normalized_weights)
expert_weights, indices

(tensor([[[0.4708, 0.5292, 0.0000],
          [0.0000, 0.5385, 0.4615],
          [0.4463, 0.0000, 0.5537],
          [0.4731, 0.0000, 0.5269]]], grad_fn=<ScatterBackward0>),
 tensor([[[1, 0],
          [1, 2],
          [2, 0],
          [2, 0]]]))

In [364]:
inputs

tensor([[[0.7733, 0.1707, 0.9376, 0.9425, 0.5589, 0.4675, 0.3773, 0.5024],
         [0.0874, 0.6406, 0.7324, 0.9194, 0.3231, 0.4237, 0.4044, 0.1464],
         [0.4526, 0.0482, 0.5630, 0.1766, 0.1429, 0.3552, 0.4647, 0.2547],
         [0.6644, 0.2857, 0.5988, 0.2055, 0.3879, 0.9472, 0.5443, 0.9692]]])

In [365]:
all_tokens = torch.zeros(tokens, d_model)

for i, ids in enumerate(indices[0]):

  first_k = ids[0].item()
  second_k = ids[-1].item()

  input_ids = inputs[0][i].unsqueeze(0)

  first_k_output = all_experts[first_k](input_ids)
  second_k_output = all_experts[second_k](input_ids)

  first_k_output= first_k_output * expert_weights[0][i][first_k].unsqueeze(0).unsqueeze(0)     # [1] * [1, 8]
  second_k_output= second_k_output * expert_weights[0][i][second_k].unsqueeze(0).unsqueeze(0)  # [1] * [1, 8]

  tokens = first_k_output+second_k_output   # [1,8] + [1,8] = [1,8]

  all_tokens[i] = tokens

In [366]:
all_tokens

tensor([[ 0.0534, -0.3903,  0.0649, -0.0102, -0.0659, -0.0460,  0.1174, -0.1802],
        [ 0.0649, -0.2237,  0.1212, -0.1301, -0.0035, -0.1679,  0.1406,  0.1070],
        [-0.0540, -0.1567,  0.0401,  0.1086,  0.0949, -0.0487, -0.0576, -0.0858],
        [-0.0961, -0.1538,  0.0360,  0.0354,  0.0565, -0.0297, -0.1194,  0.0133]],
       grad_fn=<CopySlices>)