MoE
Summary of Mixture of Experts (MoE) models:

Are pretrained much faster than dense models
Have faster inference than a dense model with the same total number of parameters, since only a subset of experts runs per token
Require high VRAM, as all experts are loaded in memory
Face many challenges in fine-tuning

In the context of transformer models, a MoE consists of two main elements:

Sparse MoE layers are used instead of dense feed-forward network (FFN) layers. MoE layers have a certain number of “experts” (e.g. 8), where each expert is a neural network. In practice, the experts are FFNs, but they can also be more complex networks or even a MoE itself, leading to hierarchical MoEs!

A gate network, or router, that determines which tokens are sent to which expert. For example, in the image below, the token “More” is sent to the second expert, and the token “Parameters” is sent to the first expert. As we’ll explore later, a token can also be sent to more than one expert. How to route a token to an expert is one of the big decisions when working with MoEs. The router is composed of learned parameters and is pretrained at the same time as the rest of the network.
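
As a rough illustration of the router described above, the following sketch scores each token against every expert and keeps the top-k experts per token. The sizes (`embed_dim`, `n_experts`, `top_k`) and the renormalization of the kept weights are assumptions made for this example, not a prescribed design.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical sizes, chosen for illustration only.
embed_dim, n_experts, top_k = 16, 4, 2
tokens = torch.randn(3, 5, embed_dim)        # (batch, seq_len, embed_dim)

router = nn.Linear(embed_dim, n_experts)     # learned routing weights
logits = router(tokens)                      # (batch, seq_len, n_experts)
probs = torch.softmax(logits, dim=-1)        # per-token distribution over experts

# Keep the top_k most probable experts per token and renormalize their weights
# so that the kept weights sum to 1 for every token.
gate_weights, expert_idx = torch.topk(probs, k=top_k, dim=-1)
gate_weights = gate_weights / gate_weights.sum(dim=-1, keepdim=True)

print(expert_idx.shape)    # torch.Size([3, 5, 2])
print(gate_weights.shape)  # torch.Size([3, 5, 2])
```

With `top_k = 1` this reduces to sending each token to a single expert, which is the routing used in the Switch Transformer cell below.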

In [1]:
import torch
from torch import nn
import numpy as np

In [2]:
VOCAB_SIZE = 1000
SEQ_LEN = 200
EMBED_DIM = 128
N_HEADS = 8
N_EXPERT_FFNS = 6

In [3]:
from collections import OrderedDict


In [4]:
# Implementing Switch Transformer

class MySwitchTransformer(nn.Module):
  def __init__(self, vocab_size, embed_dim, seq_len, nheads, n_experts):
    super().__init__()
    self.embedding_layer = nn.Embedding(vocab_size, embed_dim)
    self.positional_encodings = nn.Embedding(seq_len, embed_dim)
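    # Position indices 0..seq_len-1, registered as a buffer so they follow the model across devices.
    self.register_buffer("positional_inputs", torch.arange(seq_len), persistent=False)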
    self.multi_head_self_attention = nn.MultiheadAttention(embed_dim=embed_dim, num_heads=nheads, batch_first=True)
    self.layer_norm_1 = nn.LayerNorm((embed_dim))
    # Router needs to select one of the k experts we have. Each token can go to a different expert FFN.
    self.router = nn.Linear(embed_dim, n_experts)
    # The list of Expert FFNs.
    # Note that wrapping this python list inside of `nn.ModuleList` is crucial
    # otherwise these FFN parameters don't show up in model.parameters() - which
    # means that backprop won't update those weights.
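    self.expert_ffns = nn.ModuleList([
        nn.Sequential(
          OrderedDict([
          # f-string interpolation, so each expert's layers carry its index in their names.
          (f'{i}-ffn-l1', nn.Linear(embed_dim, 2048)),
          # Nonlinearity between the two projections; without it the FFN collapses to one linear map.
          (f'{i}-ffn-act', nn.ReLU()),
          (f'{i}-ffn-l2', nn.Linear(2048, embed_dim))]))
        for i in range(n_experts)])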
    self.layer_norm_2 = nn.LayerNorm((embed_dim))

  def forward(self, x):
    x = self.embedding_layer(x)
    x = x + self.positional_encodings(self.positional_inputs)
    print("x.shape: ", x.shape)
    old_x = x
    x, _ = self.multi_head_self_attention(query=x, key=x, value=x)
    x = self.layer_norm_1(x + old_x)
    print("x.shape: ", x.shape)
    h = self.router(x)
    print("h.shape: ", h.shape)
    # Top-1 routing: softmax over the router logits gives per-expert probabilities.
    # The argmax picks the expert index; the max probability is the gate value.
    gate_prob, ffn_layer_idx = torch.softmax(h, dim=-1).max(dim=-1)
    print("\nffn_layer_idx.shape: ", ffn_layer_idx.shape)

    # Initialize result of MoE to zeros, then fill it in one expert at a time.
    result_of_moe = torch.zeros_like(x)
    for expert_idx, expert_ffn in enumerate(self.expert_ffns):
      mask = ffn_layer_idx == expert_idx  # (batch, seq_len) boolean mask of tokens for this expert
      if mask.any():
        # Scaling by the gate probability lets gradients reach the router;
        # a bare argmax is not differentiable.
        result_of_moe[mask] = expert_ffn(x[mask]) * gate_prob[mask].unsqueeze(-1)

    print("result_of_moe.shape: ", result_of_moe.shape)
    # Final Add & Norm
    x = self.layer_norm_2(x + result_of_moe)
    return x
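
The comment about `nn.ModuleList` in the cell above can be checked directly. The sketch below uses two hypothetical toy modules (`PlainListExperts` and `ModuleListExperts`, not part of the notebook) that differ only in how the expert layers are stored, and compares what `parameters()` reports for each.

```python
from torch import nn

class PlainListExperts(nn.Module):
    def __init__(self, n_experts, dim):
        super().__init__()
        # Plain Python list: the Linear layers are NOT registered as submodules.
        self.experts = [nn.Linear(dim, dim) for _ in range(n_experts)]

class ModuleListExperts(nn.Module):
    def __init__(self, n_experts, dim):
        super().__init__()
        # nn.ModuleList registers every layer, so its weights are visible to optimizers.
        self.experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_experts)])

def count_params(module):
    return sum(p.numel() for p in module.parameters())

plain_count = count_params(PlainListExperts(3, 4))
registered_count = count_params(ModuleListExperts(3, 4))
print(plain_count, registered_count)  # 0 60
```

Each `nn.Linear(4, 4)` holds 16 weights and 4 biases, so three registered experts give 60 parameters, while the plain-list version exposes none of them to an optimizer.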

In [6]:
x = torch.randint(VOCAB_SIZE, (5, SEQ_LEN))
mst = MySwitchTransformer(VOCAB_SIZE, EMBED_DIM, SEQ_LEN, N_HEADS, N_EXPERT_FFNS)
mst(x).shape

x.shape:  torch.Size([5, 200, 128])
x.shape:  torch.Size([5, 200, 128])
h.shape:  torch.Size([5, 200, 6])

ffn_layer_idx.shape:  torch.Size([5, 200])
result_of_moe.shape:  torch.Size([5, 200, 128])


torch.Size([5, 200, 128])

In [7]:
t = torch.tensor([3, 5, 9, -1, 0])
torch.argmax(t)

tensor(2)

In [8]:
t = torch.tensor([[[3, 5, 9, -1, 0], [3, 5, 9, -1, 10]]])
torch.argmax(t, dim=-1)

tensor([[2, 4]])
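
Once each token has an expert index, a natural follow-up is to count how many tokens each expert receives. The sketch below reuses the values from the `argmax` example above and treats the last dimension as scores over five hypothetical experts. The use of `torch.bincount` is one possible way to do the counting, not something the notebook prescribes.

```python
import torch

# Router-style scores for 1 sequence of 2 tokens over 5 hypothetical experts
# (the same values as the argmax example above).
scores = torch.tensor([[[3, 5, 9, -1, 0], [3, 5, 9, -1, 10]]])
expert_idx = torch.argmax(scores, dim=-1)    # tensor([[2, 4]])

# Count how many tokens were routed to each expert.
tokens_per_expert = torch.bincount(expert_idx.flatten(), minlength=scores.shape[-1])
print(tokens_per_expert)  # tensor([0, 0, 1, 0, 1])
```

A very uneven count here would mean a few experts do most of the work while the others sit idle.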

In [9]:
mst.parameters()

<generator object Module.parameters at 0x7bb00bcc0ac0>
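
Because `parameters()` returns a generator, it has to be iterated to be useful. The sketch below rebuilds only the expert FFNs, assuming the same layer sizes as above (128 to 2048 to 128, six experts) and top-1 routing, and compares the parameters held in memory with those a single token actually uses. The helper `make_expert` is a hypothetical name introduced for this example.

```python
from torch import nn

EMBED_DIM, HIDDEN_DIM, N_EXPERTS = 128, 2048, 6

def make_expert():
    # Same layer shapes as the expert FFNs defined earlier (an assumption of this sketch).
    return nn.Sequential(nn.Linear(EMBED_DIM, HIDDEN_DIM), nn.Linear(HIDDEN_DIM, EMBED_DIM))

experts = nn.ModuleList([make_expert() for _ in range(N_EXPERTS)])

# named_parameters() yields (name, parameter) pairs, unlike the bare generator above.
for name, p in experts[0].named_parameters():
    print(name, tuple(p.shape), p.numel())

params_per_expert = sum(p.numel() for p in experts[0].parameters())
params_all_experts = sum(p.numel() for p in experts.parameters())
print(params_per_expert)   # 526464
print(params_all_experts)  # 3158784
```

With top-1 routing, each token passes through one expert, so about 0.53M of the 3.16M expert parameters are active per token, yet all six experts must stay loaded. This is the trade-off between inference speed and VRAM noted in the summary.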

In [10]:
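# Looking at number of parameters.
# p.name is None for every parameter here (hence the "None" column in the output);
# mst.named_parameters() would yield (name, parameter) pairs instead.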

for p in mst.parameters():
  print(p.name, p.shape, p.numel())

None torch.Size([1000, 128]) 128000
None torch.Size([200, 128]) 25600
None torch.Size([384, 128]) 49152
None torch.Size([384]) 384
None torch.Size([128, 128]) 16384
None torch.Size([128]) 128
None torch.Size([128]) 128
None torch.Size([128]) 128
None torch.Size([6, 128]) 768
None torch.Size([6]) 6
None torch.Size([2048, 128]) 262144
None torch.Size([2048]) 2048
None torch.Size([128, 2048]) 262144
None torch.Size([128]) 128
None torch.Size([2048, 128]) 262144
None torch.Size([2048]) 2048
None torch.Size([128, 2048]) 262144
None torch.Size([128]) 128
None torch.Size([2048, 128]) 262144
None torch.Size([2048]) 2048
None torch.Size([128, 2048]) 262144
None torch.Size([128]) 128
None torch.Size([2048, 128]) 262144
None torch.Size([2048]) 2048
None torch.Size([128, 2048]) 262144
None torch.Size([128]) 128
None torch.Size([2048, 128]) 262144
None torch.Size([2048]) 2048
None torch.Size([128, 2048]) 262144
None torch.Size([128]) 128
None torch.Size([2048, 128]) 262144
None torch.Size([2048]) 2