# Mixture-of-Experts Implementation
**References**: 
- Outrageously Large Neural Networks: The Sparsely-Gated Mixture-of-Experts Layer (https://arxiv.org/abs/1701.06538)

**Purpose**: The capacity of a neural network to absorb information is limited by its number of parameters. Conditional computation, where parts of the network are active on a per-example basis, has been proposed in theory as a way of dramatically increasing model capacity without a proportional increase in computation. In practice, however, there are significant algorithmic and performance challenges.

**Approach**: In this work, Shazeer et al. address those challenges and finally realize the promise of conditional computation, achieving greater than 1000x improvements in model capacity with only minor losses in computational efficiency on modern GPU clusters.

They introduce a Sparsely-Gated Mixture-of-Experts layer (MoE), consisting of up to thousands of feed-forward sub-networks. A trainable gating network determines a sparse combination of these experts to use for each example.
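
In the paper's notation, with $G(x)$ the output of the gating network and $E_i(x)$ the output of the $i$-th of $n$ experts, the layer output is a gate-weighted sum of the expert outputs:

$$
y = \sum_{i=1}^{n} G(x)_i \, E_i(x)
$$

Because $G(x)$ is sparse, any expert with $G(x)_i = 0$ does not need to be evaluated for that example, which is where the compute savings come from.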

They apply the MoE to the tasks of language modeling and machine translation, where model capacity is critical for absorbing the vast quantities of knowledge available in the training corpora. 

They present model architectures in which a MoE with up to 137 billion parameters is applied convolutionally between stacked LSTM layers. 

**Result**: On large language modeling and machine translation benchmarks, these models achieve significantly better results than state-of-the-art at lower computational cost.

**Notes**:

*Theory*

Conditional computation
- Exploiting scale in both training data and model size has been central to the success of deep learning. However, in a typical dense model the entire network is activated for every example, so training cost grows roughly quadratically as both model size and the number of training examples increase. Unfortunately, advances in computing power and distributed computation fall short of meeting that demand.

<div style="text-align:center;">
  <img src="2025-08-21_MoE.png" alt="diagram" style="width:50%;">
</div>

- The main problem and motivation: conditional computation is promising in theory, but no prior work had demonstrated massive improvements in model capacity, training time, or model quality.

Sparsely-Gated Mixture-of-Experts Layer
- Note that there are MANY ways to implement the mixture-of-experts idea. The biggest engineering and practical question was which implementation works best. The key perspective was the idea of using multiple MoEs, each with its own gating network, as parts of a deep model.
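
As one concrete option among those many, the paper's noisy top-k gating computes `H(x) = x·W_g + StandardNormal() * Softplus(x·W_noise)` and then `G(x) = Softmax(KeepTopK(H(x), k))`. Below is a minimal sketch of that gate. The class name `NoisyTopKGate`, the default `k=2`, and the choice to add noise only in training mode are my assumptions, not details taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NoisyTopKGate(nn.Module):
  """Sketch of noisy top-k gating. Names and defaults are illustrative assumptions."""

  def __init__(self, D, num_experts, k=2):
    super().__init__()
    assert 1 <= k <= num_experts, "k must be between 1 and num_experts"
    self.k = k
    self.w_gate = nn.Linear(D, num_experts, bias=False)   # W_g
    self.w_noise = nn.Linear(D, num_experts, bias=False)  # W_noise

  def forward(self, X_TD):
    # X_TD: (T, D), one row per token; E = num_experts
    clean_TE = self.w_gate(X_TD)
    if self.training:
      # assumption: noise is applied only in training mode
      std_TE = F.softplus(self.w_noise(X_TD))
      logits_TE = clean_TE + torch.randn_like(clean_TE) * std_TE
    else:
      logits_TE = clean_TE

    # KeepTopK: keep the k largest logits per token, set the rest to -inf
    top_vals_Tk, top_idx_Tk = logits_TE.topk(self.k, dim=-1)
    masked_TE = torch.full_like(logits_TE, float("-inf"))
    masked_TE = masked_TE.scatter(-1, top_idx_Tk, top_vals_Tk)

    # softmax turns the -inf entries into exact zeros
    gates_TE = torch.softmax(masked_TE, dim=-1)
    return gates_TE, top_idx_Tk
```

Each row of `gates_TE` has exactly `k` non-zero entries that sum to 1, so only `k` experts need to run per token.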

**Questions**

Q1: Doesn't splitting each batch's tokens across experts create computational challenges?
- A1: Yes. See DeepSpeed-MoE and Megatron-LM for how this is handled in practice.
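
To make the cost behind Q1 concrete: after routing, every expert receives a different (and smaller) slice of the batch, so tokens have to be gathered per expert and scattered back afterwards. The sketch below assumes top-1 routing and uses a hypothetical helper, `dispatch_by_expert`, to group token positions by expert. It is only an illustration and not how DeepSpeed-MoE or Megatron-LM implement dispatch.

```python
import torch


def dispatch_by_expert(expert_idx_T, num_experts):
  """Group token positions by their assigned expert (top-1 routing assumed)."""
  # number of tokens routed to each expert
  counts_E = torch.bincount(expert_idx_T, minlength=num_experts)
  # token positions ordered by expert id; stable keeps the original token order per expert
  _, order_T = torch.sort(expert_idx_T, stable=True)
  # one tensor of token positions per expert (possibly empty)
  groups = torch.split(order_T, counts_E.tolist())
  return counts_E, groups


expert_idx = torch.tensor([2, 0, 2, 2, 1, 2])  # expert chosen for each of 6 tokens
counts, groups = dispatch_by_expert(expert_idx, num_experts=4)
print(counts.tolist())               # [1, 1, 4, 0]
print([g.tolist() for g in groups])  # [[1], [4], [0, 2, 3, 5], []]
```

Here expert 2 gets four tokens while expert 3 gets none. That kind of imbalance, together with the smaller per-expert batches, is what the paper's load-balancing losses and the systems work referenced in A1 are meant to address.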

## Self-Attention Decoder Class
Starting from the Self-Attention Decoder class implemented previously, since I think it can be easily updated to a MoE architecture. Specifically, I think the only change the MoE architecture needs is replacing the FFN class with a MoE layer.

In [4]:
!pip install torch
!pip install numpy






In [5]:
# Write attention

import torch
import torch.nn as nn
import math

"""
Dimension key:

B = batch_size
N = sequence length
M = memory length (length of sequence being attended to - K/V)
D = model dimension
H = number of heads
K = head dimension
"""
class AttentionBlock(nn.Module):

  def __init__(self, D, H, dropout_p):
    super().__init__()

    self.ln1 = nn.LayerNorm(D) # normalize over the last dimension which is expected to be of size D
    self.attn = SelfAttention(D, H, dropout_p)
    self.ln2 = nn.LayerNorm(D)
    self.ffn = FeedForward(D, dropout_p)

  def forward(self, X):
    X = X + self.attn(self.ln1(X))
    X = X + self.ffn(self.ln2(X))
    return X


class FeedForward(nn.Module):

  def __init__(self, D, dropout_p, expansion=4):
    super().__init__()
    hidden = expansion * D
    self.fc1 = nn.Linear(D, hidden, bias=True) # bias=True to give higher degrees of freedom
    self.act = nn.GELU() # Used in BERT and GPT. Smoother than ReLU
    self.fc2 = nn.Linear(hidden, D, bias=True)
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, X_BND):
    # F = hidden width (expansion * D); distinct from H (number of heads) in the dimension key
    Z1_BNF = self.fc1(X_BND) # (B,N,D) @ (D,F) -> (B,N,F)
    A_BNF = self.act(Z1_BNF) # (B,N,F)
    Z2_BND = self.fc2(A_BNF) # (B,N,F) @ (F,D) -> (B,N,D)
    Y_BND = self.dropout(Z2_BND) # (B,N,D)

    # values a manual backward pass would need to cache:
    # X_BND, Z1_BNF, A_BNF, Z2_BND, dropout_mask

    return Y_BND

  # TODO: implement backprop


class SelfAttention(nn.Module):

  def __init__(self, D, H, dropout_p):
    super().__init__()
    assert D % H == 0, "D must be divisible by H"

    self.Wq = nn.Linear(D, D)
    self.Wk = nn.Linear(D, D)
    self.Wv = nn.Linear(D, D)
    self.Wo = nn.Linear(D, D)
    self.H = H
    self.dropout = nn.Dropout(dropout_p)

  def forward(self, X):
    B, N, D = X.shape
    H = self.H
    K = D // H
    M = N

    Q_BND = self.Wq(X)
    K_BMD = self.Wk(X)
    V_BMD = self.Wv(X)

    # split to heads
    Q_BHNK = Q_BND.reshape(B, N, H, K).permute(0, 2, 1, 3)
    K_BHMK = K_BMD.reshape(B, M, H, K).permute(0, 2, 1, 3)
    V_BHMK = V_BMD.reshape(B, M, H, K).permute(0, 2, 1, 3)

    # causal mask: position n may attend only to positions m <= n
    mask_NM = torch.tril(torch.ones(N, M, device=X.device, dtype=torch.bool))

    # compute attention scores
    logits_BHNM = Q_BHNK @ K_BHMK.transpose(-2, -1) # same as torch.einsum("bhnk,bhmk->bhnm", Q_BHNK, K_BHMK)
    logits_BHNM = logits_BHNM / math.sqrt(K)
    logits_BHNM = logits_BHNM.masked_fill(~mask_NM, float("-inf"))
    weights_BHNM = torch.softmax(logits_BHNM, dim=-1)
    O_BHNK = weights_BHNM @ V_BHMK # same as torch.einsum("bhnm,bhmk->bhnk", weights_BHNM, V_BHMK)
    O_BND = O_BHNK.permute(0, 2, 1, 3).contiguous().reshape(B, N, D)

    Y_BND = self.Wo(O_BND)
    Y_BND = self.dropout(Y_BND) # for regularization of the final representation
    return Y_BND


class Decoder(nn.Module):

  def __init__(self, D, H, dropout_p, vocab_size, num_layers, max_seq_len):
    super().__init__()

    self.tok_embd = nn.Embedding(vocab_size, D)
    self.pos_embd = nn.Embedding(max_seq_len, D) # GPT paper uses learned positional embeddings. Easier to implement as well
    self.dropout = nn.Dropout(dropout_p)

    self.layers = nn.ModuleList([AttentionBlock(D, H, dropout_p) for _ in range(num_layers)])
    self.ln = nn.LayerNorm(D)

    self.lm_head = nn.Linear(D, vocab_size, bias=False)
    self.lm_head.weight = self.tok_embd.weight # weight tying

  def forward(self, input_ids):
    """
    input_ids: (B, N) # input tokens
    returns logits: (B, N, vocab_size)
    """
    B, N = input_ids.shape
    pos = torch.arange(N, device=input_ids.device).unsqueeze(0) # (1, N)

    X = self.tok_embd(input_ids) + self.pos_embd(pos) # (B, N, D)
    X = self.dropout(X)

    for layer in self.layers:
      X = layer(X)

    X = self.ln(X)
    logits = self.lm_head(X) # (B, N, vocab_size)

    # apply temperature (usually goes in generate() though)
    # logits /= temperature # lower temperature makes higher values go up farther (spikier). higher temperature makes it more diffuse

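    # apply top-k (usually goes in generate() though; left here it also filters the logits a training loss would see)
    k = min(50, logits.size(-1)) # torch.topk raises if k exceeds vocab_size
    values, _ = torch.topk(logits, k) # (B, N, k)
    min_topk = values[:, :, -1] # (B, N)
    min_topk = min_topk.unsqueeze(-1) # (B, N, 1)
    mask = logits < min_topk # (B, N, vocab_size)
    logits = logits.masked_fill(mask, float("-inf")) # remove everything except top-k tokens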

    return logits


# TESTS
def _test_sa():
  B, N, D = 2, 4, 16
  H = 8
  dropout_p = 0.1
  X = torch.randn(B, N, D)
  sa = SelfAttention(D, H, dropout_p)
  Y = sa(X)
  assert Y.shape == (B, N, D)

def _test_ff():
  B, N, D = 2, 4, 16
  H = 8
  dropout_p = 0.1
  X = torch.randn(B, N, D)
  ff = FeedForward(D, dropout_p)
  Y = ff(X)
  assert Y.shape == (B, N, D)

def _test_block():
  B, N, D = 2, 4, 16
  H = 8
  dropout_p = 0.1
  X = torch.randn(B, N, D)
  block = AttentionBlock(D, H, dropout_p)
  Y = block(X)
  assert Y.shape == (B, N, D)

def _test_decoder():
  B, N, D = 2, 4, 16
  H = 8
  dropout_p = 0.1
  vocab_size = 500
  num_layers = 4
  max_seq_len = 128
  input_ids = torch.randint(high=vocab_size, size=(B, N))
  decoder = Decoder(D, H, dropout_p, vocab_size, num_layers, max_seq_len)
  logits = decoder(input_ids)
  assert logits.shape == (B, N, vocab_size)


_test_sa()
_test_ff()
_test_block()
_test_decoder()
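
Following the plan stated at the top of this section (swap the FFN for a MoE), here is a minimal sketch of what a drop-in replacement for `FeedForward` could look like. `Expert`, `MoEFeedForward`, and the defaults `num_experts=4`, `k=2` are hypothetical names and values I picked for illustration. The gate is a plain linear layer with top-k selection (no noise and no load-balancing loss), and the Python loop over experts favors readability over speed.

```python
import torch
import torch.nn as nn


class Expert(nn.Module):
  """One expert with the same shape as FeedForward: D -> expansion*D -> D."""

  def __init__(self, D, dropout_p, expansion=4):
    super().__init__()
    self.net = nn.Sequential(
      nn.Linear(D, expansion * D),
      nn.GELU(),
      nn.Linear(expansion * D, D),
      nn.Dropout(dropout_p),
    )

  def forward(self, X):
    return self.net(X)


class MoEFeedForward(nn.Module):
  """Hypothetical drop-in for FeedForward: (B, N, D) -> (B, N, D)."""

  def __init__(self, D, dropout_p, num_experts=4, k=2):
    super().__init__()
    assert 1 <= k <= num_experts, "k must be between 1 and num_experts"
    self.k = k
    self.gate = nn.Linear(D, num_experts, bias=False)  # assumption: noise-free gate
    self.experts = nn.ModuleList([Expert(D, dropout_p) for _ in range(num_experts)])

  def forward(self, X_BND):
    B, N, D = X_BND.shape
    X_TD = X_BND.reshape(B * N, D)  # flatten to T = B*N tokens

    # pick k experts per token and normalize their logits into gate weights
    top_vals_Tk, top_idx_Tk = self.gate(X_TD).topk(self.k, dim=-1)
    gates_Tk = torch.softmax(top_vals_Tk, dim=-1)

    Y_TD = torch.zeros_like(X_TD)
    for e, expert in enumerate(self.experts):
      # tok: tokens routed to expert e; slot: which of their k choices it was
      tok, slot = (top_idx_Tk == e).nonzero(as_tuple=True)
      if tok.numel() == 0:
        continue  # no tokens for this expert in this batch
      weighted = gates_Tk[tok, slot].unsqueeze(-1) * expert(X_TD[tok])
      Y_TD = Y_TD.index_add(0, tok, weighted)  # scatter results back to token positions

    return Y_TD.reshape(B, N, D)
```

To try it in the decoder above, `AttentionBlock.__init__` would construct `self.ffn = MoEFeedForward(D, dropout_p)` instead of `FeedForward(D, dropout_p)`; nothing else in the block needs to change because the input and output shapes match.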