<a href="https://colab.research.google.com/github/stavco9/moe-llm-presentation/blob/main/presentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Seminar in Large Language Models and Information Theory (3968)

**Master’s Program · Computer Science**  
_Reichman University_

---

**Presenters** · Noam Delbari & Stav Cohen  
**Supervisor** · Dr. Alon Kipnis

# Mixture-of-Experts Language Models – High-Level Overview

Modern large language models (LLMs) can exceed **100 B parameters**, yet a dense model pushes every token through all of them, even though any single token needs only a small slice of that capacity. **Mixture-of-Experts (MoE)** layers make this sparsity explicit: instead of applying one huge feed-forward network to every token, we keep a *team* of specialised experts and let a lightweight **router** choose the best few on-the-fly. The result is a model that *behaves* like a colossal dense transformer but *runs* like a much smaller network.

This presentation will guide you from the basics to cutting-edge MoE tricks:

1. **Transformers in 5 minutes** – the encoder–decoder “conveyor belt” and why self-attention scales as $O(N^2)$.  
2. **Why Mixture-of-Experts?** – intuition, historical roots, and how sparsity slashes compute while boosting capacity.  
3. **Switch Transformer** – the first practical sparse MoE layer and its load-balancing loss.  
4. **State of the art** – Mixtral, DeepSeek, and PEER: top-2 gating, sub-experts, and retrieval-aware routing.  
5. **Hands-on demo** – training a toy MoE in PyTorch and visualising how tokens find their experts.  
6. **Take-aways** – when to reach for MoE, training pitfalls, and open research directions.

By the end you’ll understand *how* MoE squeezes more intelligence out of the same GPU budget and *why* it is becoming a key ingredient of next-generation LLMs.


## 1 Background:

### 1.1 Transformers

<figure>
  <img src="switch_plots/transformer.png" width="300"/>
  <figcaption><b>Figure 1</b> – Transformer architecture.</figcaption>
</figure>


#### Scaled Dot-Product Attention

The **Scaled Dot-Product Attention** module computes attention scores between queries $Q$ and keys $K$, scales them by $\tfrac{1}{\sqrt{d_k}}$ to stabilize gradients, applies an optional mask (e.g. to ignore padding or future tokens), and then uses softmax-normalized scores to weight the values $V$. This core operation lets each position in the sequence gather information from all other positions.


In [None]:
# %%
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledDotProductAttention(nn.Module):
    """
    Compute scaled dot-product attention:
      Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
    """
    def __init__(self, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, Q, K, V, mask=None):
        # Q, K, V: (batch, n_heads, seq_len, d_k)
        d_k = Q.size(-1)
        # Divide by sqrt(d_k) so the softmax inputs do not grow with head size
        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)
        context = torch.matmul(attn, V)
        return context, attn


#### Multi-Head Attention

The **Multi-Head Attention** module runs several scaled-dot products in parallel (“heads”), allowing the model to jointly attend to information from different representation subspaces. It projects the input into multiple query/key/value spaces, applies attention in each head, concatenates the results, then applies a final linear projection plus residual & layer-norm.


In [None]:
# %%
class MultiHeadAttention(nn.Module):
    """
    Multi-head wrapper around ScaledDotProductAttention.
    """
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_k = d_model // n_heads
        self.n_heads = n_heads

        # Linear projections
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)

        self.attn = ScaledDotProductAttention(dropout)
        self.fc_out = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None, x_kv=None):
        # x supplies the queries; x_kv supplies keys/values and defaults to x
        # (self-attention). Pass the encoder output as x_kv for cross-attention.
        if x_kv is None:
            x_kv = x
        B, T, D = x.size()
        S = x_kv.size(1)
        # project & reshape to (batch, n_heads, seq_len, d_k)
        Q = self.W_Q(x).view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_K(x_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_V(x_kv).view(B, S, self.n_heads, self.d_k).transpose(1, 2)

        context, attn_weights = self.attn(Q, K, V, mask=mask)

        # concat heads & project back, then residual + layer norm
        context = context.transpose(1, 2).contiguous().view(B, T, D)
        out = self.dropout(self.fc_out(context))
        out = self.layer_norm(out + x)
        return out, attn_weights


#### Position-wise Feed-Forward Network

The **Position-wise Feed-Forward** block applies a two-layer MLP to each position independently:
$$
\text{FFN}(x) = \mathrm{Dropout}\bigl(W_2\,\mathrm{ReLU}(W_1 x + b_1) + b_2\bigr)
$$
A residual connection and layer normalization follow to stabilize training.


In [None]:
# %%
class PositionwiseFeedForward(nn.Module):
    """
    Two-layer feed-forward network applied per position.
    """
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff, d_model),
            nn.Dropout(dropout),
        )
        self.layer_norm = nn.LayerNorm(d_model)

    def forward(self, x):
        out = self.net(x)
        return self.layer_norm(out + x)


#### Transformer Encoder Layer

A single encoder layer consists of:
1. **Self-Attention** (Multi-Head Attention over the input sequence)  
2. **Feed-Forward** (Position-wise MLP)  
Each sublayer uses its own residual connection and layer normalization.


In [None]:
# %%
class TransformerEncoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn       = PositionwiseFeedForward(d_model, d_ff, dropout)

    def forward(self, x, src_mask=None):
        x, _ = self.self_attn(x, mask=src_mask)
        x     = self.ffn(x)
        return x


#### Transformer Decoder Layer

Each decoder layer has three sublayers:
1. **Masked Self-Attention** (prevents future-token attention)  
2. **Cross-Attention** (attends to encoder output)  
3. **Feed-Forward**  
Residual connections and layer norms are applied around each sublayer.
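
As a minimal illustration of the masking in sublayer 1, the sketch below builds a causal mask in plain Python. It assumes the convention used by `ScaledDotProductAttention` above, where positions whose mask value is 0 are filled with `-inf` before the softmax; the helper name `causal_mask` is hypothetical and not part of the notebook.

```python
def causal_mask(seq_len):
    """Lower-triangular mask: entry [i][j] is 1 if position i may attend to j."""
    return [[1 if j <= i else 0 for j in range(seq_len)] for i in range(seq_len)]


mask = causal_mask(4)
for row in mask:
    print(row)  # each row allows one more position than the row before it
```

In PyTorch the same pattern is usually built with `torch.tril(torch.ones(T, T))` and passed as `tgt_mask`.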


In [None]:
# %%
class TransformerDecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn  = MultiHeadAttention(d_model, n_heads, dropout)
        self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ffn        = PositionwiseFeedForward(d_model, d_ff, dropout)

    def forward(self, x, enc_out, tgt_mask=None, memory_mask=None):
        x, _ = self.self_attn(x, mask=tgt_mask)
        x, _ = self.cross_attn(x=x, mask=memory_mask, x_kv=enc_out)
        x    = self.ffn(x)
        return x


#### Full Transformer Model

The top-level `Transformer` class orchestrates:
1. **Embedding + Positional Encoding** of source and target tokens  
2. **Encoder Stack** of $N$ encoder layers  
3. **Decoder Stack** of $M$ decoder layers  
4. **Final Linear** projection to vocabulary logits for next-token prediction


In [None]:
# %%
class Transformer(nn.Module):
    def __init__(self, 
                 src_vocab_size, tgt_vocab_size,
                 d_model=512, n_heads=8, d_ff=2048, 
                 num_enc_layers=6, num_dec_layers=6, dropout=0.1):
        super().__init__()
        # embeddings + positional encoding
        self.src_tok_emb = nn.Embedding(src_vocab_size, d_model)
        self.tgt_tok_emb = nn.Embedding(tgt_vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout)

        # encoder & decoder stacks
        self.enc_layers = nn.ModuleList([
            TransformerEncoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(num_enc_layers)
        ])
        self.dec_layers = nn.ModuleList([
            TransformerDecoderLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(num_dec_layers)
        ])

        # output projection
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def encode(self, src, src_mask=None):
        x = self.positional_encoding(self.src_tok_emb(src))
        for layer in self.enc_layers:
            x = layer(x, src_mask=src_mask)
        return x

    def decode(self, tgt, memory, tgt_mask=None, memory_mask=None):
        x = self.positional_encoding(self.tgt_tok_emb(tgt))
        for layer in self.dec_layers:
            x = layer(x, enc_out=memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
        return x

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, memory_mask=None):
        memory = self.encode(src, src_mask)
        out    = self.decode(tgt, memory, tgt_mask, memory_mask)
        return self.fc_out(out)
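
The `PositionalEncoding` class used by `Transformer` is not defined in this section. As a hedged sketch of what such a module typically adds to the embeddings, the plain-Python function below computes the standard sinusoidal table; the function name and the choice of the sinusoidal variant (rather than learned positions) are assumptions, not the notebook's actual implementation.

```python
import math


def sinusoidal_positional_encoding(max_len, d_model):
    """Return a max_len x d_model table of sinusoidal position signals."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # Each pair of dimensions (i, i+1) shares one frequency
            angle = pos / (10000 ** (i / d_model))
            table[pos][i] = math.sin(angle)
            if i + 1 < d_model:
                table[pos][i + 1] = math.cos(angle)
    return table


pe = sinusoidal_positional_encoding(max_len=4, d_model=6)
print(pe[0])  # position 0: every sine is 0 and every cosine is 1
```

A module version would store this table as a buffer, add the first `seq_len` rows to the token embeddings, and apply dropout.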


---

### 1.2 Mixture-of-Experts (MoE)

Mixture-of-Experts (MoE) is a machine learning technique in which multiple expert networks (learners) divide a problem space into dedicated regions. Rather than one large network that handles every task, the model is split into many experts, each of which specialises in part of the problem.

#### A Brief History of MoEs
According to the Hugging Face blog post at https://huggingface.co/blog/moe: <br>
* The roots of MoEs go back to the 1991 paper *Adaptive Mixtures of Local Experts*.
* The main idea was a supervised procedure for a system composed of separate networks, each handling a different subset of the training cases.
* Each separate network, or expert, specializes in a different region of the input space.

Between 2010 and 2015, as deep neural networks came into wide use, two research areas contributed to later MoE advances:
* **Experts as components:** In the traditional MoE setup, the whole system consists of a gating network and multiple experts.
  * MoEs as the whole model have been explored in ***SVMs***, ***Gaussian Processes***, and other methods.
  * The work by Eigen, Ranzato, and Sutskever (Google and NYU) explored MoEs as components of deeper networks. This allows MoEs to serve as layers in a multilayer network, making it possible for the model to be both large and efficient.
* **Conditional computation:** Traditional networks process all input data through every layer. In this period, Yoshua Bengio, a well-known deep learning researcher at the Université de Montréal, investigated approaches that dynamically activate or deactivate components based on the input token.

These works led to exploring mixtures of experts in the context of NLP:
* Concretely, Shazeer et al. (2017, with “et al.” including Geoffrey Hinton and Jeff Dean, Google’s Chuck Norris) scaled this idea to a 137B LSTM (the de-facto NLP architecture back then, introduced by Hochreiter and Schmidhuber) by introducing sparsity, which kept inference very fast even at high scale.
* This work focused on translation but faced many challenges, such as high communication costs and training instabilities.

#### Some terms
1. ***Expert***
A small, specialized model trained for a particular area. It can be a neural network, a decision tree, or another algorithm. In our case, experts are small neural networks.

2. ***A Mixture of Experts (MoE) model***
A model that combines the predictions of multiple experts to solve complex problems.
- Each expert is trained on a specific domain or task, and a "gating network" or "router" selects the most appropriate experts for a given input.

3. ***"Gating network" / "router"***
A component (a tiny linear layer) in the large model that determines which experts should be activated for a particular input. It is trained jointly with the experts.

#### What do we achieve from that?
The main benefit of the MoE architecture is that it lets large-scale models, even those with many billions of parameters, reduce computation costs during pre-training and run faster at inference time.

#### How does it work?
It achieves this by activating only the experts needed for a given input, rather than running the entire network for every input.

#### An illustration of a standard MoE network

<figure>
<img src="https://github.com/stavco9/moe-llm-presentation/blob/main/moe.png?raw=1" alt="MoE network" width="400" height="400">
<img src="https://github.com/stavco9/moe-llm-presentation/blob/main/01_moe_layer.png?raw=1" alt="MoE layer" width="500" height="600">
</figure>
In the left image, the red experts are the ones that are active.

---

####  Mixture-of-Experts implementation in popular LLMs:

##### Mixtral 8×7B (Mistral)
**Overview** – Mixtral augments a 7-billion-parameter base transformer with sparse MoE feed-forward blocks to reach **46.7 B total parameters** while keeping the *per-token* cost of roughly a 13 B dense model. It achieves this by letting each token consult only two of eight experts in every layer.

* **MoE layout** – Every transformer block swaps its dense FFN for **8 identical SwiGLU experts**  
  (hidden size = 14 336).  
* **Router** – A learned linear gate $W_g \in \mathbb{R}^{d_{\text{model}}\times 8}$.

  1. For each token $x$: logits $l = x W_g$
  2. Keep the **top-2** logits, mask the rest to $-\infty$
  3. Weights $w = \text{softmax}(l_{\text{top-2}})$

* **Capacity constraint**

  $$
    \text{capacity} \;=\; \Bigl\lceil \alpha \,\frac{T \cdot K}{N}\Bigr\rceil,
    \qquad \alpha \approx 1.25
  $$

  where $T$=tokens in the batch, $K=2$ (top-k), $N=8$ experts.  
  Overflow tokens are “zero-routed”.
* **Auxiliary Switch loss** equalises both (i) probability mass and (ii) actual token counts per expert.
* **Compute cost** – Only $2/8 = 25\%$ of the FFN compute runs per token, so inference costs about as much as a 12.9 B-parameter dense model, while training exploits the full **46.7 B** capacity.
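
The router steps and the capacity formula above can be traced by hand for a single token. The following is a plain-Python sketch, not Mixtral's actual code: the function names and the example logits are made up for illustration.

```python
import math


def top_k_gate(logits, k=2):
    """Keep the k largest logits and softmax over them; other experts get no weight."""
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)            # subtract the max for numerical stability
    exps = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exps.values())
    return {i: exps[i] / z for i in top}


def expert_capacity(tokens, k, n_experts, alpha=1.25):
    """Capacity formula from the text: ceil(alpha * T * K / N)."""
    return math.ceil(alpha * tokens * k / n_experts)


logits = [0.1, 2.0, -1.0, 0.5, 1.0, -0.3, 0.0, 0.2]  # hypothetical logits, 8 experts
weights = top_k_gate(logits, k=2)
print(weights)                     # only experts 1 and 4 receive weight
print(expert_capacity(64, 2, 8))   # 20
```

Masking the other six logits to $-\infty$ before the softmax gives the same result as taking the softmax over the two survivors, which is what the sketch does.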


##### DeepSeek-MoE 16 B

**Overview** – DeepSeek 16 B pushes specialisation further by *slicing* each FFN into many narrow **sub-experts**: 64 are sparsely routed and 2 are always on. This fine-grained design lets the gate compose highly specific mixtures without raising compute beyond that of a standard 16 B dense transformer.

* **Layer composition** – Each MoE layer holds **64 routed sub-experts** + **2 always-on shared sub-experts**.
* **Sub-expert** – Same two-layer SwiGLU as a full FFN expert but at **¼ hidden width**, so each is **4× cheaper**.
* **Router (routed experts only)**

  1. Score the 64 routed sub-experts with a linear gate.  
  2. Keep the **top-6**, apply softmax → weights.  
  3. Aggregate their weighted outputs.

* **Final token output** = weighted sum of **6 routed sub-experts** **+** deterministic outputs of **2 shared sub-experts** ⇒ every token ultimately sees **8 sub-experts**.
* **Capacity & balance** – Same ceiling formula as Mixtral (shared experts are exempt); auxiliary loss encourages even traffic across the 64 routed sub-experts.
* **Why sub-experts?** Splitting the FFN into many narrow experts gives the router a richer palette of highly specialised functions while staying within the original FLOPs budget.  
  Ablations show that removing just a few high-traffic sub-experts hurts perplexity far more than in classic MoE setups, signalling **stronger expert specialisation** and lower redundancy.
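
To make the "6 routed + 2 shared" bookkeeping concrete, here is a small plain-Python sketch. It only follows the description above (top-6 of 64 routed sub-experts, softmax over the survivors, 2 shared sub-experts with no gating); the function name, the shared-expert labels, and the logits are hypothetical.

```python
import math


def deepseek_style_route(logits, top_k=6, n_shared=2):
    """Return (routed weights, shared ids) for one token.

    routed weights: softmax over the top_k routed sub-experts.
    shared ids: sub-experts that run for every token without gating.
    """
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    m = max(logits[i] for i in top)
    exps = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exps.values())
    routed = {i: e / z for i, e in exps.items()}
    shared = [f"shared_{j}" for j in range(n_shared)]
    return routed, shared


logits = [((7 * i) % 64) / 10.0 for i in range(64)]  # made-up, distinct scores
routed, shared = deepseek_style_route(logits)
active = len(routed) + len(shared)
print(active)           # 8 sub-experts process this token
print(active * 0.25)    # 2.0 full-width FFN equivalents at quarter width
```

The last line is the arithmetic behind the FLOPs claim: eight sub-experts at ¼ width cost about as much as two full-width experts.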

---

## 2 Switch transformer

### 2.1 Introduction:

Large language models (LLMs) have achieved striking gains by growing from millions
to billions of parameters—yet *dense* scaling makes **every** parameter participate
in **every** forward-pass.  
Compute (FLOPs), memory traffic, and wall-time therefore grow linearly with model
size, and the trillion-parameter frontier strains even the largest clusters.

**Mixture-of-Experts (MoE)** layers offer a different path: *conditional
computation*.  
A lightweight *gating network* selects **one** or **few** specialized “experts”
(MLPs) per token, so only a *subset* of parameters is active each step.
Switch Transformers (Fedus *et&nbsp;al.*, 2022) refine this idea to make it practical
at unprecedented scale.

#### Why previous MoE attempts struggled  

| Bottleneck in earlier MoE work | What it means | Switch Transformer’s remedy |
| :-- | :-- | :-- |
| **Unstable top-k routing** (k > 1) | When every token is split across *k* experts (k = 2,4…), the soft mixture may starve some experts of gradient signal → divergence in very deep models. | **k = 1 “switch” routing**: each token is sent to exactly one expert chosen by `argmax` over gate logits. This keeps gradients intact and halves the routing tensor size. |
| **Cross-device communication** | Prior systems sliced *one* expert across many GPUs/TPUs → every step required an All-to-All of hidden states. | **Expert-parallel layout**: each device *owns* one whole expert. Tokens are grouped by destination expert, transferred *once*, processed locally, then regrouped—minimising traffic. |
| **Token imbalance (hot-spot experts)** | Popular tokens (e.g., punctuation) can overload a few experts, leaving others idle and blowing up memory. | **Auxiliary load-balancing loss**: penalises correlation between (i) fraction of tokens routed to expert *i* and (ii) gate probability mass on expert *i*.  <br>$$ L_{\text{aux}} = \alpha\,N \sum_{i=1}^{N} f_i\,P_i \quad\text{(Eq.\;4)} $$ |

*Definitions*  
* **Expert** – an independent feed-forward sub-network (here, a position-wise MLP).  
* **Gating network** – a tiny linear layer that produces *N* logits per token.  
* **Routing** – assigning tokens to experts based on gate probabilities.  
* **FLOPs / token** – floating-point operations needed for one token’s forward-pass.

---

#### Research Questions

This paper explicitly addresses several critical research questions:

1. **Efficiency vs. Model Capacity:**  
   *Does increasing model size through sparse routing (more experts) consistently improve language model performance (perplexity) without proportional computational cost?*

2. **Training Stability:**  
   *Can models at trillion-parameter scale be stably trained using lower-precision arithmetic (like bfloat16)?*

3. **Downstream Generalization:**  
   *Do improvements obtained during language modeling pre-training generalize to downstream tasks (e.g., QA, summarization)?*

4. **Model Compression and Deployment:**  
   *Is it feasible to distill sparse models into smaller, dense models while preserving performance improvements?*

---


### 2.2 Model architecture overview  

A Switch Transformer layer is identical to a standard Transformer block **except** that the
dense Feed-Forward Network (FFN) is replaced by a **Switch-FFN** (sparse Mixture-of-Experts).  
Figure 2 from the paper illustrates the encoder block with two example tokens (${x_1,x_2}$).  
Only the shaded *Switch-FFN* (light-blue) differs from a dense model.

<figure>
  <img src="switch_plots/figure2_switch_arch.png" width="600"/>
  <figcaption><b>Figure 2</b> – Switch-FFN inside the Transformer block (Fedus et al., 2022).</figcaption>
</figure>


#### 2.2.1 Switch Transformer Module

The `SwitchTransformer` builds on the classic Transformer encoder by:

1. **Embedding + Positional Encoding**  
   - Converts token IDs into $d_{\text{model}}$-dimensional vectors and adds sinusoidal (or learned) timing signals.

2. **Layer Stack**  
   - Clones a prototype `SwitchTransformerLayer` $L$ times, enabling sparse, per-token expert routing at each depth.  
   - Gathers routing statistics (`counts`, `load`, `dropped`, `max_p`) for every layer to support the auxiliary load-balancing loss.

3. **Final Normalization**  
   - Applies a top-level `LayerNorm` to stabilize outputs across the full depth.

This architecture maintains the Transformer’s depth and self-attention benefits while drastically increasing capacity via $N$ experts, all at roughly constant computational cost per token.  


In [None]:
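# %%
import copy
import math

import torch
import torch.nn as nn


def clone_module_list(module: nn.Module, n: int) -> nn.ModuleList:
    # Helper used below but not defined in this section (assumed behaviour):
    # n independent deep copies, so every layer has its own parameters.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


class SwitchTransformer(nn.Module):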
    """
    Stacks multiple SwitchTransformerLayer’s into a full encoder.
    """
    def __init__(self, *,
                 src_vocab_size: int, d_model: int, n_heads: int,
                 d_ff: int, n_layers: int,
                 capacity_factor: float, drop_tokens: bool,
                 is_scale_prob: bool, n_experts: int,
                 max_seq_len: int, dropout_prob: float):
        super().__init__()
        # Token embedding + positional encoding
        self.token_emb = nn.Embedding(src_vocab_size, d_model)
        self.pos_enc   = PositionalEncoding(d_model, dropout_prob, max_seq_len)

        # Stack of SwitchTransformerLayers
        layer = SwitchTransformerLayer(
            d_model=d_model, n_heads=n_heads, d_ff=d_ff,
            capacity_factor=capacity_factor, drop_tokens=drop_tokens,
            is_scale_prob=is_scale_prob, n_experts=n_experts,
            dropout_prob=dropout_prob
        )
        self.layers    = clone_module_list(layer, n_layers)
        self.norm      = nn.LayerNorm(d_model)

    def forward(self, src: torch.Tensor, src_mask: torch.Tensor = None):
        # 1) Embed & add positional encoding
        x = self.token_emb(src) * math.sqrt(self.token_emb.embedding_dim)
        x = self.pos_enc(x)

        # 2) Apply each SwitchTransformerLayer
        all_stats = []
        for layer in self.layers:
            x, counts, load, dropped, max_p = layer(x, mask=src_mask)
            all_stats.append({
                'counts': counts,
                'load': load,
                'dropped': dropped,
                'max_p': max_p
            })

        # 3) Final layer normalization
        x = self.norm(x)
        return x, all_stats


#### 2.2.2 Switch TransformerLayer Module

This layer mirrors a standard Transformer block but replaces the dense position-wise FFN with a sparse “SwitchFeedForward” mixture-of-experts:

1. **Self-Attention Sub-Layer**  
   - Applies pre-norm, multi-head self-attention (`MultiHeadAttention`) to allow tokens to mix contextual information.  
   - Adds a residual connection and dropout.

2. **Mixture-of-Experts Feed-Forward Sub-Layer**  
   - Applies pre-norm, then routes each token through exactly one of $N$ expert FFNs via `SwitchFeedForward`.  
   - Collects routing statistics (`counts`, `load`, `dropped`, `max_p`) for auxiliary losses.  
   - Adds a residual connection and dropout.

Each sublayer uses its own `LayerNorm` and the same `dropout_prob` for stability and regularization.  


In [None]:
# %%
class SwitchTransformerLayer(nn.Module):
    """
    A single Transformer block using SwitchFeedForward instead of a fixed FFN.
    """
    def __init__(self, *, d_model: int, n_heads: int, d_ff: int,
                 capacity_factor: float, drop_tokens: bool,
                 is_scale_prob: bool, n_experts: int, dropout_prob: float):
        super().__init__()
        # 1) Multi-head self-attention sublayer
        self.attn       = MultiHeadAttention(d_model, n_heads, dropout_prob)
        # 2) Mixture-of-experts feed-forward sublayer
        base_ffn        = PositionwiseFeedForward(d_model, d_ff, dropout_prob)
        self.feed_forward = SwitchFeedForward(
            capacity_factor=capacity_factor,
            drop_tokens=drop_tokens,
            is_scale_prob=is_scale_prob,
            n_experts=n_experts,
            expert=base_ffn,
            d_model=d_model
        )
        # Layer norms & dropout
        self.norm1      = nn.LayerNorm(d_model)
        self.norm2      = nn.LayerNorm(d_model)
        self.dropout    = nn.Dropout(dropout_prob)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        # 1) Pre-norm + self-attention
        z1, _      = self.attn(self.norm1(x), mask=mask)
        x          = x + self.dropout(z1)

        # 2) Pre-norm + switch-FFN
        z2, counts, load, dropped, max_p = self.feed_forward(self.norm2(x))
        x          = x + self.dropout(z2)

        return x, counts, load, dropped, max_p


#### 2.2.3 Switch FeedForward Module
Initializes the mixture-of-experts block by:
- Replicating the base feed-forward network into `n_experts` separate experts.
- Defining a lightweight router (`self.switch`) that maps each token embedding of size `d_model` to unnormalized logits for each expert.
- Preparing to convert those logits into a probability distribution with `Softmax`.


In [None]:
# %%
import copy


class SwitchFeedForward(nn.Module):
    def __init__(self, *, capacity_factor: float, drop_tokens: bool,
                 is_scale_prob: bool, n_experts: int,
                 expert: nn.Module, d_model: int):
        super().__init__()
        self.capacity_factor = capacity_factor
        self.drop_tokens     = drop_tokens
        self.is_scale_prob   = is_scale_prob
        self.n_experts       = n_experts
        # 1) Create N independent copies of the base FFN (deep copies, no weight sharing)
        self.experts = nn.ModuleList([copy.deepcopy(expert) for _ in range(n_experts)])
        # 2) Router: a linear map to produce logits for each expert
        self.switch  = nn.Linear(d_model, n_experts)
        self.softmax = nn.Softmax(dim=-1)


##### Forward function

This function implements the forward pass of the `SwitchFeedForward` mixture-of-experts layer. It accepts an input tensor `x` of shape `[seq_len, batch, d_model]`, where:

- `seq_len` is the sequence length  
- `batch` is the batch size  
- `d_model` is the model’s hidden dimension  

Inside, each token is separated out for per-token routing across experts; the method ultimately returns both the transformed output and auxiliary routing statistics.


##### Flatten Inputs

Reshapes input from shape `[seq_len, batch, d_model]` into `[T, d_model]` with `T = seq_len × batch`, so that each token is routed independently.


In [None]:
# %%
    def forward(self, x: torch.Tensor):
        seq_len, batch, d_model = x.shape
        # Collapse sequence and batch dims for per-token routing
        # (reshape also handles non-contiguous inputs, unlike view)
        flat_x = x.reshape(-1, d_model)

##### Compute Raw Routing Logits

Each token embedding `x` is linearly projected to produce unnormalized scores:
$$
h_i(x) = \bigl[W_{\text{switch}}\,x + b\bigr]_i,\quad i=1,\dots,N.
$$


In [None]:
# %%
        logits = self.switch(flat_x)

##### Normalize into Routing Probabilities

Applies softmax to the logits to yield a distribution over experts:
$$
p_i(x) = \frac{\exp\bigl(h_i(x)\bigr)}{\sum_{j=1}^N \exp\bigl(h_j(x)\bigr)},\quad \sum_i p_i(x)=1.
$$


In [None]:
# %%
        route_prob = self.softmax(logits)


##### Pick Top-1 Expert

Routes each token to the expert with maximum probability:
$$
i^*(x) = \arg\max_i\,p_i(x),\qquad p_{i^*}(x)=\max_i p_i(x).
$$


In [None]:
# %%
        prob_max, routes = torch.max(route_prob, dim=-1)

##### Compute Expert Capacity

Determines the per-expert capacity:
$$
C = \Bigl\lfloor \alpha\times\frac{T}{N}\Bigr\rfloor,
$$
where $T$ is total tokens, $N$ experts, and $\alpha$ is `capacity_factor`.


In [None]:
# %%
        capacity = int(self.capacity_factor * flat_x.size(0) / self.n_experts)

##### Enforce Capacity & Optionally Drop

1. **Counts:** number of tokens routed to each expert.  
2. If `drop_tokens=True`, any expert that receives more than $C$ tokens keeps a random subset of $C$ and drops the rest; dropped tokens skip the expert and are passed through unchanged later in the forward pass.


In [None]:
# %%
        # Number of tokens routed to each expert (before any dropping)
        counts  = torch.bincount(routes, minlength=self.n_experts)
        dropped = []
        if self.drop_tokens:
            for i in range(self.n_experts):
                idxs = (routes==i).nonzero(as_tuple=True)[0]
                if idxs.numel() > capacity:
                    # Shuffle, then mark everything beyond the capacity as dropped
                    perm = idxs[torch.randperm(idxs.numel(), device=idxs.device)]
                    dropped.append(perm[capacity:])

##### Dispatch Tokens to Experts
Applies each expert $E_i$ only to its assigned subset $\mathcal{I}_i$:
$$
y_i = E_i\bigl(x_{\mathcal{I}_i}\bigr).
$$


In [None]:
# %%
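        final_out = flat_x.new_zeros(flat_x.shape)
        # Tokens dropped for exceeding capacity must not be sent to an expert
        keep = torch.ones(flat_x.size(0), dtype=torch.bool, device=flat_x.device)
        if dropped:
            keep[torch.cat(dropped)] = False
        for i in range(self.n_experts):
            idxs = ((routes==i) & keep).nonzero(as_tuple=True)[0]
            if idxs.numel() > 0:
                out_i = self.experts[i](flat_x[idxs])
                final_out[idxs] = out_i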


##### Handle Dropped Tokens

Tokens dropped due to capacity limits are passed through unchanged:
$$
y_{\mathrm{bypass}}(x) = x.
$$


In [None]:
# %%
        if dropped:
            all_dropped = torch.cat(dropped)
            final_out[all_dropped] = flat_x[all_dropped]


#####  Scale by Gate Value

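Modulates each token’s expert output by its gate value so that the router receives gradients. With `is_scale_prob=True` the output is scaled by the gate probability:
$$
\tilde y(x) = p_{i^*}(x)\,y_{i^*}(x).
$$
Otherwise it is multiplied by $p_{i^*}(x)/\mathrm{stop\_grad}\bigl(p_{i^*}(x)\bigr)$, which equals 1 in the forward pass but still passes gradients to the router.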


In [None]:
# %%
        if self.is_scale_prob:
            final_out = final_out * prob_max.unsqueeze(-1)
        else:
            final_out = final_out * (prob_max/prob_max.detach()).unsqueeze(-1)


#####  Restore Shape & Return

Reshapes the result back to `[seq_len, batch, d_model]` and returns auxiliary statistics for load-balancing loss.


In [None]:
# %%
        output = final_out.view(seq_len, batch, d_model)
        return output, counts, route_prob.sum(0), sum(len(d) for d in dropped), prob_max

####  2.2.4 **Capacity and Hard Limit**

<figure>
  <img src="switch_plots/figure3_capacity_dynamics.png" width="650"/>
  <figcaption><b>Figure 3 — Token-routing dynamics under two capacity factors.</b></figcaption>
</figure>

**What the diagram shows**

*Left panel (capacity 1.0)*  
* 12 tokens must be routed across three experts (rows).  
* Each expert’s capacity is exactly the ideal load $B/N = 4$ tokens.  
* Because many tokens happen to share the same favourite expert, that expert
  overflows — the dotted red boxes show the **dropped** tokens that will bypass
  this layer.

*Right panel (capacity 1.5)*  
* The capacity per expert is now $4 \times 1.5 = 6$ tokens, giving 50 % slack.  
* All 12 tokens fit; no red overflow boxes, but a few **empty white slots** indicate
  wasted compute/communication.

**Take-away** A small slack margin (the paper standardises on 1.25) almost
eliminates overflows yet keeps extra FLOPs and bandwidth modest.  Capacity is
therefore the runtime *circuit-breaker* that guarantees fixed memory and latency
even when the router’s token-to-expert distribution is skewed.

**Why does capacity matter?**

Routing decisions can lead to unbalanced token assignments, with some experts overloaded while others remain underutilized. Without mitigation, overloaded experts would cause out-of-memory errors, degraded performance, and unpredictable latency.

Switch Transformer addresses this via a **hard capacity limit** on each expert:

$$
\text{capacity per expert} = \left\lfloor \frac{B}{N} \times \text{capacity\_factor} \right\rfloor
$$

- $ B $: Total number of tokens in the micro-batch
- $ N $: Number of experts
- $\text{capacity\_factor}$: Typically set to $1.25$
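
The Figure 3 scenario (12 tokens, 3 experts) can be replayed in a few lines of plain Python. This is a sketch under stated assumptions: the routing list is made up, the helper names are hypothetical, and the rounding mirrors the `int(...)` truncation in the forward pass above (for these numbers the rounding direction does not matter).

```python
import random


def expert_capacity(tokens, n_experts, capacity_factor):
    # Truncates like the int(...) call in the forward pass
    return int(tokens / n_experts * capacity_factor)


def apply_capacity(routes, n_experts, capacity, seed=0):
    """Split token indices into (kept per expert, dropped) under a hard capacity."""
    rng = random.Random(seed)
    kept, dropped = {}, []
    for e in range(n_experts):
        idxs = [t for t, r in enumerate(routes) if r == e]
        rng.shuffle(idxs)                      # overflow is dropped at random
        kept[e] = sorted(idxs[:capacity])
        dropped.extend(idxs[capacity:])
    return kept, sorted(dropped)


# 12 tokens, 3 experts; expert 0 is the favourite of 6 tokens (made-up routing)
routes = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 2]
for factor in (1.0, 1.5):
    cap = expert_capacity(len(routes), 3, factor)
    kept, dropped = apply_capacity(routes, 3, cap)
    print(f"capacity_factor={factor}: capacity={cap}, dropped={len(dropped)}")
```

With factor 1.0 two of expert 0's tokens overflow; with factor 1.5 nothing is dropped, but 6 of the 18 reserved slots stay empty.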

---



####  2.2.5 **Load-Balancing Losses**

Early attempts at Mixture-of-Experts models encountered a critical problem known as **expert collapse**, where a small number of experts dominated token assignments, starving others of training signal. Two auxiliary losses are used to prevent collapse and keep routing stable: the load-balancing loss introduced with Switch Transformers, and the router z-loss added in the follow-up ST-MoE work (Zoph et al., 2022):

##### 1. **Auxiliary Load-Balancing Loss**

$$
L_{\text{aux}} = \alpha N \sum_{i=1}^{N} f_i P_i,\quad
f_i = \frac{\text{tokens routed to expert } i}{\text{tokens in the batch}},\quad
P_i = \text{average gate probability for expert } i
$$

Encourages even token distribution by penalizing high correlation between an expert’s frequency of selection and router confidence.

You can think of $\mathbf f^\top \mathbf P$ as the **expected confidence** the router has in the experts it actually chooses.  
- If the router always picks highly-confident experts (i.e. $\mathbf P$ is very “peaked”), then $\mathbf f^\top \mathbf P$ will be high.  
- By **penalizing** $\sum_{i=1}^N f_i\,P_i = \mathbf f^\top \mathbf P$, we force the router to **spread its confidence** more evenly across all experts.  
- As a result, the actual assignments $\mathbf f$ become more balanced, improving expert utilization and reducing hardware bottlenecks.
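
A minimal sketch of the load-balancing loss in plain Python may make the formula concrete. The function name, the default `alpha=0.01`, and the two toy batches are assumptions made for illustration; $f_i$ is computed from top-1 (argmax) assignments and $P_i$ from the mean router probabilities.

```python
def load_balancing_loss(router_probs, alpha=0.01):
    """router_probs: one probability list per token, each summing to 1."""
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])

    # f_i: fraction of tokens whose top-1 expert is i.
    counts = [0] * num_experts
    for probs in router_probs:
        counts[probs.index(max(probs))] += 1
    f = [c / num_tokens for c in counts]

    # P_i: mean router probability assigned to expert i.
    P = [sum(p[i] for p in router_probs) / num_tokens for i in range(num_experts)]

    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, P))

# Balanced toy batch: each of 4 tokens prefers a different expert.
balanced = [
    [0.7, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, 0.1],
    [0.1, 0.1, 0.7, 0.1],
    [0.1, 0.1, 0.1, 0.7],
]
# Collapsed toy batch: every token prefers expert 0.
collapsed = [[0.7, 0.1, 0.1, 0.1]] * 4

print(load_balancing_loss(balanced))   # about alpha
print(load_balancing_loss(collapsed))  # about 2.8 * alpha
```

In the balanced batch both $\mathbf f$ and $\mathbf P$ are uniform, so the loss sits at roughly $\alpha$; in the collapsed batch it is almost three times larger.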


##### 2. **Z-Loss (Logit Regularization)**


$$
L_z = \beta \sum_{x \in \text{tokens}}\sum_{i=1}^{N}\Bigl(h_i(x) - \mathrm{stop\_grad}\bigl[\log\!\sum_{j=1}^{N} e^{h_j(x)}\bigr]\Bigr)^2
$$

**Intuition:**  
- Encourages router logits $h_i$ to remain small in magnitude, preventing overly confident selections (extremely large or small logits can destabilize softmax distributions).
- Maintains stable gradient flow and well-conditioned router softmax.


1. **Definition of** $\mathrm{stop\_grad}(x)$:  
   $$
     \mathrm{stop\_grad}(x)\;=\;x
     \quad\text{(forward value)}\,,\qquad
     \frac{\partial\,\mathrm{stop\_grad}(x)}{\partial x}=0
     \quad\text{(blocks gradients)}
   $$

2. **Per-logit Z-Loss term** for a target logit $h_i(x)$:  
   $$
     \mathcal{L}_Z
     \;=\;\beta\;\bigl(h_i(x)\;-\;\mathrm{stop\_grad}\bigl[\log\!\sum_{j=1}^N e^{h_j(x)}\bigr]\bigr)^{2}.
   $$

3. **Relation to log-softmax**:  
   $$
     h_i(x)-\log\!\sum_{j=1}^N e^{h_j(x)}
     =\log\!\bigl(e^{h_i(x)}\bigr)\;-\;\log\!\sum_{j=1}^N e^{h_j(x)}
     =\log p_i(x),
   $$  
   so this difference directly measures the **log-probability** of expert $i$ under the softmax.
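
The log-softmax identity in step 3 can be verified numerically. The sketch below uses only the standard library; the helper names and the example logits are assumptions for illustration. Subtracting the maximum before exponentiating is the usual way to keep the log-sum-exp stable when logits are large.

```python
import math

def log_sum_exp(logits):
    """Numerically stable log(sum_j exp(h_j))."""
    m = max(logits)
    return m + math.log(sum(math.exp(h - m) for h in logits))

def log_softmax(logits):
    """h_i - log-sum-exp(h), i.e. log p_i for every expert i."""
    lse = log_sum_exp(logits)
    return [h - lse for h in logits]

logits = [2.0, 0.5, -1.0, 0.0]          # hypothetical router logits for one token
log_probs = log_softmax(logits)
probs = [math.exp(lp) for lp in log_probs]

print(log_probs)
print(sum(probs))                        # should be 1 up to rounding
print(log_softmax([1000.0, 999.0]))      # large logits do not overflow
```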

---


### 2.3 Experiments & Results
#### Baseline Models Used for Comparison  

The authors compare against **two families** of baseline models, each serving a distinct purpose.

##### Dense T5 Series
**T5** stands for **“Text-to-Text Transfer Transformer.”**  
Released by Google in late 2019, it introduced a simple yet powerful idea: *cast every NLP task—translation, summarisation, QA, sentiment, …—as feeding one piece of text in and predicting another piece of text out.*  
This unification plus large-scale span-corruption pre-training on the **C4** web corpus produced a strong encoder–decoder baseline.

**Why the authors picked T5:**

1. **Like-for-like objective and codebase** – eliminates spurious gains from task formulation or optimiser tweaks.  
2. **Widely reported benchmarks** – GLUE / SuperGLUE / SQuAD scores for T5 are standard yard-sticks, so improvements are easy to contextualise.  
3. **Scales up smoothly** – letting the paper test whether sparse scaling beats a dense model simply made *bigger* (e.g., T5-Large or T5-XXL).

Thus, throughout the experiments T5 provides a **clean, well-understood dense baseline** against which the efficiency and quality of the Switch (sparse) approach can be judged.


| Model | Params | Compute | Why chosen |
|-------|--------|---------|-----------|
| **T5-Base** | 223 M | 1× FLOPs/token (reference) | Same size class as 2-expert Switch; establishes a dense baseline that already fits on a single TPU/GPU. |
| **T5-Large** | 739 M | ≈ 3.5× the FLOPs/token of Switch-Base | Represents a “scale-up dense” strategy within the same architecture and codebase. |
| **T5-XXL** | 11 B | 6.3 T FLOPs per sequence | State-of-the-art dense model at publication time; tests whether sparse can outpace *very* large dense models under the same cluster budget. |

*Rationale* – All T5 variants share the *exact* training objective, tokenizer, and optimizer code. That isolates the effect of **conditional vs dense compute**.

##### MoE Transformer (Top-2 Routing)

| Variant | Experts | Routing | Why chosen |
|---------|---------|---------|-----------|
| **MoE-Transformer (Shazeer et al.)** | 128 | top-2 | Prevailing MoE design before Switch; higher FLOPs because two experts fire per token. |

*Rationale* – Validates whether **single-expert** routing is genuinely more efficient/stable than the established top-k approach.
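
The difference between top-1 and top-2 routing can be sketched in a few lines. This is a simplified illustration, not the code of either paper: the function names and the example logits are assumptions, and real implementations add noise, capacity limits, and batching.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(h - m) for h in logits]
    total = sum(exps)
    return [e / total for e in exps]

def route_top_k(logits, k):
    """Return [(expert_index, gate_weight), ...] for the k most probable experts."""
    probs = softmax(logits)
    top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    # For k > 1 the gate weights are renormalised over the chosen experts.
    # For k = 1 the raw router probability is kept, so the gate still
    # carries a gradient signal back to the router.
    norm = sum(probs[i] for i in top) if k > 1 else 1.0
    return [(i, probs[i] / norm) for i in top]

logits = [1.0, 3.0, 2.0, 0.0]    # hypothetical router logits for one token
print(route_top_k(logits, 1))    # one expert runs for this token
print(route_top_k(logits, 2))    # two experts run, so expert FLOPs double
```

Each returned pair corresponds to one expert forward pass for the token, which is why top-2 routing costs roughly twice the expert compute of top-1.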

##### Why these baselines are fair  

* **Same tokenizer and data** → eliminates corpus effects.  
* **Same optimizer hyper-params** (where feasible) → isolates architectural difference.  
* **FLOP-matched pairs** (Switch-Base vs T5-Base, Switch-Large vs T5-Large) → asks:  
  > *“Given the **same compute budget**, which architecture learns faster / better?”*  
* **Higher-FLOP dense models** (T5-Large, -XXL) → test the critique  
  > *“Just spend more FLOPs on dense; why bother with sparsity?”*  
* **Legacy MoE top-2** → ensures the improvement isn’t merely “MoE vs dense” but due to the **Switch simplification**.

Using this spectrum of baselines, the paper demonstrates that Switch Transformers outperform:

1. **Compute-matched dense** models (fair efficiency test),  
2. **Heavier dense** models (efficiency-vs-quality frontier), and  
3. **Previous sparse** architectures (methodological advance).

This comprehensive baseline suite strengthens the claim that **conditional compute via single-expert routing is a superior scaling path**.

####  Scaling Properties 
The paper’s **Scaling Properties** section asks:  
> *“If we keep FLOPs / token roughly constant, how far can we improve quality by adding more experts (i.e., more parameters)?”*  

To answer, the authors run three tightly-controlled experiments.

#####  Step-Basis Scaling

<figure>
  <img src="switch_plots/figure_4_step_basis_scaling.png" width="600"/>
  <figcaption><b>Figure: Scaling Switch Transformer (perplexity vs. training steps).</b></figcaption>
</figure>



| | |
|---|---|
| **Setup** | Pre-train **Switch-Base** models with 2 → 256 experts (223 M → 14.7 B parameters) **for a fixed 100 k steps**. FLOPs/token stay constant because each token still activates one expert. |
| **Purpose** | Is parameter-only scaling (via experts) a free win when compute is fixed? |
| **Main results** | Perplexity **drops monotonically** as experts double. The 64-expert model matches T5-Base quality **7.5 × sooner** in steps. |
| **Interpretation** | Extra capacity (parameters) is effectively used even though compute is unchanged. Sparse routing is therefore a *new scaling axis* orthogonal to FLOPs. |
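
A back-of-the-envelope count shows why parameters grow with the number of experts while per-token compute does not. The sketch below rests on assumptions that are not stated in the text above: each expert is taken to be a feed-forward block of about 4.72 M weights (for example 2 × 768 × 3072), and 12 layers are assumed to be MoE layers. Under those assumptions the totals land close to the parameter counts quoted in this section.

```python
def switch_base_params(num_experts,
                       dense_params=223_000_000,     # T5-Base size from the table
                       params_per_expert=4_718_592,  # assumed: 2 * 768 * 3072
                       num_moe_layers=12):           # assumed number of MoE layers
    """Total parameters when each MoE layer holds `num_experts` FFN copies."""
    extra_experts = num_experts - 1  # one FFN per layer is already in the dense count
    return dense_params + extra_experts * params_per_expert * num_moe_layers

for n in (1, 2, 64, 128, 256):
    total = switch_base_params(n)
    print(f"{n:>3} experts: {total / 1e9:.1f} B parameters, 1 expert active per token")
```

Only one expert per MoE layer runs for any given token, so the active parameter count, and hence FLOPs per token, stays at the dense level however many experts are added.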

#####  Time-Basis Scaling

<figure>
  <img src="switch_plots/figure_5_time_basis_scaling.png" width="300"/>
  <figcaption><b>Figure: Scaling Switch Transformer (perplexity vs. wall-clock training time).</b></figcaption>
</figure>


| | |
|---|---|
| **Setup** | Measure **wall-clock minutes** to reach target perplexities on identical TPU pods. Same model family as above. |
| **Purpose** | Extra experts add routing overhead (softmax, All-to-All). Do they erase the step advantage? |
| **Main results** | Sparse models *still* win: the paper reports that the 64-expert Switch-Base reaches T5-Base quality in about **one-seventh of the wall-clock time** (≈ 7× faster). |
| **Conclusion** | Routing + communication overhead is small relative to the gains from parameter scaling. Sparse models give **real-time savings**, not just step savings. |

---


#####  Sparse vs. “Just Make the Dense Model Bigger”

<figure>
  <img src="switch_plots/figure_6_parameters_basis_scaling.png" width="600"/>
  <figcaption><b>Figure: Sample efficiency of Switch Transformer vs. T5 variants.</b></figcaption>
</figure>


| | |
|---|---|
| **Setup** | Compare **Switch-Base (64 e)** against **T5-Large** which spends **3.5× more FLOPs/token** than Switch-Base. |
| **Purpose** | Critics could argue “dense scaling already works—just spend more FLOPs.” |
| **Main results** | Despite T5-Large’s heavier compute, Switch-Base is **2.5× faster** to the same perplexity and *still* ends lower. |
| **Conclusion** | Conditional computation **dominates** naive dense scaling in the speed/quality trade-off. You can’t buy the same improvement just by burning more FLOPs per token. |

---


####  Down-Stream Experiments

The authors ran several focused studies to verify that the pre-training gains of Switch Transformers **transfer** to real tasks and to understand how best to fine-tune, regularise, and deploy very large sparse models. Two of them are covered here: fine-tuning and distillation.


#####  Fine-Tuning Benchmark Suite

<figure>
  <img src="switch_plots/table_5_fine_tuning_results.png" width="400"/>
  <figcaption><b>Table: Fine-tuning results. T5 baselines vs. Switch models across
  a diverse set of natural language tests.</b></figcaption>
</figure>


- **GLUE —** A bundle of nine sentence‐level and sentence-pair evaluations (sentiment, paraphrase, natural-language inference, etc.) that together gauge broad language understanding in English.

- **SuperGLUE —** A harder successor to GLUE featuring multi-sentence reasoning tasks such as BoolQ, ReCoRD, and WSC; designed to test deeper compositional reasoning and commonsense.

- **SQuAD v1.1 —** Reading-comprehension question answering on Wikipedia passages; checks the model’s ability to locate and extract exact answer spans from context.

- **XSum —** Single-sentence abstractive news summarisation; evaluates whether the system can condense an article into one concise, fluent sentence while preserving key facts.

- **Winogrande —** Commonsense pronoun-resolution puzzles; measures the model’s grasp of implicit world knowledge needed to resolve ambiguous references.

- **TriviaQA (closed-book) —** Open-domain factoid QA answered without external documents; probes how much factual knowledge is stored internally in the model’s parameters.

- **ANLI —** Adversarial natural-language inference collected via model-in-the-loop annotation; assesses robustness to deliberately tricky NLI examples.

- **ARC (Easy & Challenge) —** Multiple-choice grade-school science-exam questions; tests logical reasoning over short factual statements rather than surface pattern matching.



| | |
|---|---|
| **Setup** | Fine-tuned FLOP-matched pairs: **Switch-Base (7 B) vs T5-Base (0.2 B)** and **Switch-Large (26 B) vs T5-Large (0.7 B)** on GLUE, SuperGLUE, SQuAD, XSum, Winogrande, TriviaQA, ANLI, ARC. Dropout: 0.1 non-expert, 0.4 expert; 100 k training steps. |
| **Purpose** | Confirm that sparse pre-training advantages appear on diverse NLU, QA, and summarisation tasks. |
| **Main results** | • **+4.4 pp SuperGLUE** (Base) and +2 pp (Large).<br>• Closed-book TriviaQA +6 pp.<br>• Gains on Winogrande, XSum; mixed on ARC. |
| **Conclusion** | Pre-training gains **transfer broadly**; sparse capacity especially helps knowledge-heavy tasks. |

####  Distillation for Deployment

##### How knowledge distillation works in general  

1. **Run the teacher** – pass each input through the (large) teacher model  
   * obtain either the **logits** $z^{(T)}$ or the softened probabilities  
     $\sigma(z^{(T)}/T)$ at temperature $T>1$.

2. **Run the student** – pass the *same* input through the small model  
   to produce logits $z^{(S)}$.

3. **Blend two losses**

$$
\mathcal{L}
  = \lambda \; \underbrace{\mathrm{KL}\!\bigl[\sigma(z^{(T)}/T)\;\|\;\sigma(z^{(S)}/T)\bigr]}_{\text{soft-target loss}}
  + (1-\lambda)\; \underbrace{\mathrm{CE}\!\bigl[y,\;\sigma(z^{(S)})\bigr]}_{\text{hard-target loss}}
$$

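* $y$ – ground-truth labels  
* $\sigma$ – softmax  
* $T$ – temperature; the superscripts $(T)$ and $(S)$ mark teacher and student and are unrelated to it  
* $\lambda$ – weight that trades off “mimic the teacher” vs. “fit the labels”  
* Back-propagate **only through the student**; the teacher is frozen.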

4. **Optimise the student** until validation perplexity or task metric plateaus.  
   The student thus learns a compressed approximation of the teacher’s behaviour
   while still respecting the original task labels.
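
The blended loss above can be written out for a single example in plain Python. This is a sketch under stated assumptions: the function names are hypothetical, the logits are made up, and the defaults `lam=0.5` and `temperature=2.0` are illustrative rather than prescribed values.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL[p || q] for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(teacher_logits, student_logits, label, lam=0.5, temperature=2.0):
    # Soft-target term: match the teacher's temperature-softened distribution.
    soft = kl_divergence(softmax(teacher_logits, temperature),
                         softmax(student_logits, temperature))
    # Hard-target term: ordinary cross-entropy against the true label.
    hard = -math.log(softmax(student_logits)[label])
    return lam * soft + (1 - lam) * hard

teacher = [2.0, 1.0, 0.1]   # hypothetical teacher logits
student = [0.0, 0.0, 0.0]   # an untrained student predicting uniformly
print(distillation_loss(teacher, student, label=0))
print(distillation_loss(teacher, teacher, label=0))  # soft term vanishes
```

In a real training loop the teacher logits would be computed without gradient tracking, so only the student's parameters are updated.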


##### Switch-Transformer distillation  

* **Teacher models** – sparse Switch-Base variants  
  * 3.8 B, 7.4 B, and 14.7 B parameters (64, 128, and 256 experts)  
  * already pre-trained on C4; one run fine-tuned on SuperGLUE.

* **Student model** – dense T5-Base, 223 M parameters.

* **Weight initialisation** – copy all **non-expert weights** (embeddings, attention,
  residual projections) from teacher to student; randomly init the rest.

* **Soft / hard mix** –  
  $\lambda = 0.25$ for the soft-target KL term, $1-\lambda = 0.75$ for the standard cross-entropy.

* **Temperature** – $T = 2.0$ to soften the teacher’s probability distribution.

* **Training data & length** –  
  * C4 span-corruption for language modelling distillation (150 k steps).  
  * SuperGLUE labelled set for task-specific distillation (same step budget).

* **Optimiser & schedule** – identical Adafactor settings used in pre-training; no extra tricks.

* **Outcome** – student keeps **≈30 %** of the teacher’s quality gain while shrinking
  model size by **95–99 %**, demonstrating a deployable path for Switch-Transformer
  knowledge.
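
Two quantities in the outcome above are simple ratios, sketched here for clarity. The function names are assumptions, and the quality scores passed to `gain_retained` are hypothetical numbers picked only to illustrate the calculation; the parameter counts are the ones listed for the teachers and the student.

```python
def compression(student_params, teacher_params):
    """Fraction of the teacher's parameters removed by distillation."""
    return 1.0 - student_params / teacher_params

def gain_retained(baseline, distilled, teacher):
    """Share of the teacher's improvement over the baseline that the
    distilled student keeps. Works for lower- or higher-is-better metrics."""
    return (baseline - distilled) / (baseline - teacher)

student = 223e6
for teacher_size in (3.8e9, 7.4e9, 14.7e9):
    print(f"{teacher_size / 1e9:.1f} B teacher: {compression(student, teacher_size):.1%} smaller")

# Hypothetical scores: baseline student 10.0, teacher 8.0, distilled student 9.4.
print(f"{gain_retained(10.0, 9.4, 8.0):.0%} of the gain retained")
```

With the rounded parameter counts used here the compression comes out at roughly 94 %, 97 %, and 98.5 %, in line with the 95–99 % range quoted above.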


<figure>
  <img src="switch_plots/table_6_distillition.png" width="400"/>
  <figcaption><b>Table: Distillation results. Sparse Switch-Base teachers distilled
  into a dense T5-Base student.</b></figcaption>
</figure>


| | |
|---|---|
| **What was done** | Distilled sparse **Switch-Base 3.8 B / 7.4 B / 14.7 B** teachers into a 223 M dense T5-Base student. Tricks: initialise student with teacher’s non-expert weights + 0.25 × soft-loss + 0.75 × hard-loss. |
| **Why** | Massive trillion-parameter models are hard to deploy; distillation offers a lighter alternative. |
| **Main results** | • **≈ 30 % of teacher gain retained** at 95–99 % compression.<br>• Fine-tuned SuperGLUE distillation keeps 30 % gain on a 97 % compressed model. |
| **Interpretation** | Distillation provides a **practical path** from huge sparse teachers to deployable dense students while preserving a meaningful slice of quality improvement. |


---



### 2.4 Switch Extensions: MixLoRA & MoE-Mamba

#### MixLoRA: LoRA-Based Sparse MoE (Li et al., 2024)

**At a glance**  
A parameter-efficient sparse MoE that injects LoRA adapters as experts into a frozen backbone for multi-task fine-tuning on modest GPUs.

**What is LoRA?**  
LoRA (Low-Rank Adaptation) freezes a large weight matrix $W\in\mathbb{R}^{d\times d}$ and learns two much smaller matrices $B\in\mathbb{R}^{d\times r}$ and $A\in\mathbb{R}^{r\times d}$ with $r\ll d$, such that
$$
W' = W + B\,A.
$$
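
A tiny pure-Python sketch illustrates the update and the parameter savings. It assumes $B$ has shape $d \times r$ and $A$ has shape $r \times d$, so that $B\,A$ matches the shape of $W$; the helper names and the toy matrices are illustrative only.

```python
def matmul(X, Y):
    """Plain nested-list matrix product."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_weight(W, B, A):
    """W' = W + B A, with W frozen and only B and A trainable."""
    BA = matmul(B, A)
    return [[w + d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]

def lora_param_counts(d, r):
    """(parameters in the full d x d matrix, trainable LoRA parameters)."""
    return d * d, 2 * d * r

# Toy example with d = 3 and rank r = 1.
W = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
B = [[1.0], [2.0], [3.0]]    # d x r
A = [[0.1, 0.0, -0.1]]       # r x d
print(lora_weight(W, B, A))

full, trainable = lora_param_counts(d=4096, r=8)
print(full, trainable, f"{trainable / full:.2%} of the full matrix")
```

For a 4096 × 4096 matrix and rank 8, the adapter trains 65,536 values instead of about 16.8 million, which is what makes LoRA adapters cheap enough to use as experts.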

**Switch-inspired element**  
- **Auxiliary load-balancing loss** to encourage uniform expert utilization.

**Key improvements**  
- **Top-2 gating** for richer expert selection.  
- **Independent attention-layer adapters** for per-layer specialization.  
- **40 % GPU memory reduction** and **30 % lower latency**.  
- **9 % accuracy gain** on multi-task benchmarks.

#### MoE-Mamba: SSM-Infused Sparse MoE (Pióro et al., 2024)

**At a glance**  
A hybrid SSM/MoE model that interleaves Mamba State-Space Model blocks with true Switch-style sparse experts.

**What is an SSM?**  
A State-Space Model evolves a hidden state $h_t$ via  
$$
h_t = A\,h_{t-1} + B\,u_t,\quad
y_t = C\,h_t + D\,u_t,
$$  
where $u_t$ is the input and $y_t$ the output.
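
The recurrence can be unrolled by hand for the scalar case. The sketch below is a minimal illustration with fixed, made-up coefficients and a hypothetical function name; real SSM layers use vector states and learned matrices.

```python
def run_ssm(inputs, A, B, C, D, h0=0.0):
    """Scalar state-space recurrence: h_t = A h_{t-1} + B u_t, y_t = C h_t + D u_t."""
    h = h0
    outputs = []
    for u in inputs:
        h = A * h + B * u
        outputs.append(C * h + D * u)
    return outputs

# Impulse response: a single 1 followed by zeros decays geometrically with A.
print(run_ssm([1.0, 0.0, 0.0, 0.0], A=0.5, B=1.0, C=1.0, D=0.0))
# [1.0, 0.5, 0.25, 0.125]
```

The hidden state acts as a fading memory of past inputs: with $A = 0.5$ each earlier input contributes half as much at every subsequent step.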

**Innovation in Mamba**  
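- **Input-dependent parameterization** (selectivity): the discretisation step $\Delta$ and the matrices $B$ and $C$ are computed from the input, giving content-based gating.  
- **Hardware-aware parallel scan** that replaces the FFT-based convolution of earlier SSMs, with compute that scales linearly in sequence length $N$.  
- **Diagonal state matrix $A$** for compact, efficient long-range modeling.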

**Switch-inspired elements**  
- **Single-expert routing** ($k=1$) with capacity-factor buffering.  
- **Auxiliary load-balancing loss** at every MoE layer.  
- **Low-precision routing** (float32 jitter) for stable training.

**Key improvements**  
- **Sequential interleaving** of SSM and MoE blocks.  
- **Extensive ablations** on expert count, placement, and parameter ratios.  
- Matches Mamba's perplexity in **2.35× fewer training steps** and outperforms Transformer-MoE.

---


### 2.5 Contributions

The paper's main contributions are:

- Introducing **single-expert (top-1) routing**, drastically simplifying routing and reducing computational overhead.
- Presenting clear evidence of performance improvements across diverse NLP benchmarks, validating the model’s practical utility.
- Proposing effective model compression through knowledge distillation, enabling deployment of large model capabilities into significantly smaller models.

Together, these findings validate the viability of conditional computation at unprecedented scale.
