# Preliminary MoE Routing Demo (Switch Transformer)

This notebook is a small exploratory test to understand **Mixture-of-Experts (MoE) routing** in a real model.

Using **Switch-Base-8**, we inspect **token-level expert routing decisions** at a **single encoder MoE layer**.  
We capture router logits, selected experts, and routing entropy for short inputs, and align them with the tokenizer’s actual subword tokens.

This notebook is **purely observational**:
- no training
- no fine-tuning
- no architectural modification

The goal is simply to verify that expert routing can be intercepted and interpreted before scaling to larger MoE models.


In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_name = "google/switch-base-8"

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype="auto"
)

model.eval()
print("Switch MoE loaded correctly.")

Switch MoE loaded correctly.

List all modules with "router" in the name to find the routing components.

In [2]:
for name, module in model.named_modules():
    if "router" in name.lower():
        print(name, "->", type(module))
print("Switch MoE router modules identified.")

encoder.block.1.layer.1.mlp.router -> <class 'transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router'>
encoder.block.1.layer.1.mlp.router.classifier -> <class 'torch.nn.modules.linear.Linear'>
encoder.block.3.layer.1.mlp.router -> <class 'transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router'>
encoder.block.3.layer.1.mlp.router.classifier -> <class 'torch.nn.modules.linear.Linear'>
encoder.block.5.layer.1.mlp.router -> <class 'transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router'>
encoder.block.5.layer.1.mlp.router.classifier -> <class 'torch.nn.modules.linear.Linear'>
encoder.block.7.layer.1.mlp.router -> <class 'transformers.models.switch_transformers.modeling_switch_transformers.SwitchTransformersTop1Router'>
encoder.block.7.layer.1.mlp.router.classifier -> <class 'torch.nn.modules.linear.Linear'>
encoder.block.9.layer.1.mlp.router -> <class 'transforme
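In the listing above, routers appear only in odd-numbered encoder blocks (1, 3, 5, 7, 9, ...), so only every other feed-forward layer in the visible part of the encoder is an MoE layer. The sketch below recovers those block indices from module names. It assumes the `{encoder|decoder}.block.N.layer.M.mlp.router` naming seen in the output. The hard-coded list is copied from that output; in the notebook one would pass `[name for name, _ in model.named_modules()]` instead. `moe_blocks` is a hypothetical helper, not part of `transformers`.

```python
import re

# Module names copied from the listing above (which is truncated after block 9).
router_names = [
    "encoder.block.1.layer.1.mlp.router",
    "encoder.block.1.layer.1.mlp.router.classifier",
    "encoder.block.3.layer.1.mlp.router",
    "encoder.block.3.layer.1.mlp.router.classifier",
    "encoder.block.5.layer.1.mlp.router",
    "encoder.block.5.layer.1.mlp.router.classifier",
    "encoder.block.7.layer.1.mlp.router",
    "encoder.block.7.layer.1.mlp.router.classifier",
]

# Match only the router module itself (not its .classifier child) and capture
# the stack name ("encoder"/"decoder") and the block index.
ROUTER_RE = re.compile(r"^(encoder|decoder)\.block\.(\d+)\.layer\.\d+\.mlp\.router$")


def moe_blocks(names):
    """Hypothetical helper: return {stack: sorted block indices} for router modules."""
    blocks = {}
    for name in names:
        match = ROUTER_RE.match(name)
        if match:
            blocks.setdefault(match.group(1), []).append(int(match.group(2)))
    return {stack: sorted(indices) for stack, indices in blocks.items()}


print(moe_blocks(router_names))  # -> {'encoder': [1, 3, 5, 7]}
```

Matching the full name with `$` at the end keeps the `.classifier` children out of the result, so each MoE layer is counted once.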

Tokenize a single example and print ids, tokens, and attention mask to see the subword split.

In [3]:
# Minimal step: tokenize a single input and inspect tokens
text = "question: Where is Paris?, context: Paris is the capital of France."

encoding = tokenizer(text, return_tensors="pt")
input_ids = encoding["input_ids"][0].tolist()
attention_mask = encoding["attention_mask"][0].tolist()

print("Input text:")
print(text)
print("\nToken IDs:")
print(input_ids)
print("\nTokens:")
print(tokenizer.convert_ids_to_tokens(input_ids))
print("\nAttention Mask:")
print(attention_mask)

Input text:
question: Where is Paris?, context: Paris is the capital of France.

Token IDs:
[822, 10, 2840, 19, 1919, 58, 6, 2625, 10, 1919, 19, 8, 1784, 13, 1410, 5, 1]

Tokens:
['▁question', ':', '▁Where', '▁is', '▁Paris', '?', ',', '▁context', ':', '▁Paris', '▁is', '▁the', '▁capital', '▁of', '▁France', '.', '</s>']

Attention Mask:
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]


Move tokenized tensors onto the same device as the model. The model might be on GPU while the tokenizer outputs are on CPU, and PyTorch requires both to be on the same device for a forward pass.

In [4]:
# Move tokenized tensors to the model device
encoding = tokenizer(text, return_tensors="pt")
input_ids = encoding["input_ids"].to(model.device)
attention_mask = encoding["attention_mask"].to(model.device)

print("input_ids device:", input_ids.device)
print("attention_mask device:", attention_mask.device)

input_ids device: cuda:0
attention_mask device: cuda:0


In [5]:
import torch

# Encoder forward only (no hooks)
with torch.no_grad(), torch.autocast(model.device.type, dtype=model.dtype):
    encoder_outputs = model.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask
    )

print("Encoder output last_hidden_state:", encoder_outputs.last_hidden_state.shape)
print("This consists of the batch dimension (1), sequence length, and hidden state dimension.")


Encoder output last_hidden_state: torch.Size([1, 17, 768])
This consists of the batch dimension (1), sequence length, and hidden state dimension.



### Forward hook method


- **`router_classifier`**  
  The **linear layer inside the router** that maps each token’s hidden state (768 dims) to one logit per expert (8 experts).  
  As an `nn.Linear`, its weight has shape `(num_experts, hidden_dim)`, i.e. `(8, 768)`.

- **`classifier_outputs`**  
  A list that **stores the outputs captured by the hook**.  
  A forward hook fires on every call of its module, so outputs are appended here rather than overwritten.

- **`classifier_hook(module, module_inputs, module_outputs)`**  
  A function that PyTorch **calls automatically after the classifier's forward pass**.
  - `module`: the classifier layer itself (`nn.Linear`)
  - `module_inputs`: a tuple of the positional inputs **entering the classifier**, here the hidden states  
    (one vector per token)
  - `module_outputs`: the **raw logits over experts**  
    (shape `[num_tokens, num_experts]`, i.e. `(17, 8)` for this input)

- **Tuple handling inside the hook**  
  `nn.Linear` returns a plain tensor, but some modules return tuples such as `(output, extra_info)`.  
  The `isinstance` check unwraps such tuples defensively, so only the logits tensor is stored.

- **`register_forward_hook(...)`**  
  Attaches the hook to the classifier.  
  From this point on, **whenever the classifier executes**, the hook is triggered.

- **`with torch.no_grad()`**  
  Disables gradient tracking because we are **observing, not training**.  
  Saves memory and computation.

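- **`torch.autocast(...)`**  
  Runs eligible operations in `model.dtype`, the precision the checkpoint was loaded in via `torch_dtype="auto"` (e.g. FP16 or bfloat16 on GPU).  
  This keeps activations in the same dtype as the weights and avoids dtype mismatches during inference.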

- **`model.encoder(...)`**  
  Executes a single encoder forward pass.  
  This is what **actually triggers the router and the hook**.

- **`handle.remove()`**  
  Detaches the hook immediately after the pass to avoid duplicate captures.

- **`classifier_outputs[0]`**  
  The captured tensor of **per-token expert logits**.  
  Each row corresponds to one token, each column to one expert.

- **Final prints**  
  Confirm the tensor shape and print the logits of one token (index 13, `▁of`) across all 8 experts.
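The hook lifecycle described above can be tried without downloading the model. This is a toy sketch that needs only PyTorch: `toy_classifier` is a hypothetical stand-in with the same shapes as the Switch-Base-8 router classifier (768 → 8) but random weights, and the bias-free layer is an assumption made for illustration.

```python
import torch
from torch import nn

# Toy stand-in for the router classifier: hidden_dim=768 -> num_experts=8.
# Shapes match Switch-Base-8; the random weights do not.
torch.manual_seed(0)
toy_classifier = nn.Linear(768, 8, bias=False)
captured = []


def capture_hook(module, module_inputs, module_outputs):
    # module_inputs is a tuple of positional args; module_outputs is the logits tensor.
    captured.append(module_outputs.detach())


hidden_states = torch.randn(17, 768)  # 17 "tokens", as in the example sentence

handle = toy_classifier.register_forward_hook(capture_hook)
try:
    with torch.no_grad():
        toy_classifier(hidden_states)
finally:
    # Remove the hook even if the forward pass raises.
    handle.remove()

# A second call after removal no longer triggers the hook.
with torch.no_grad():
    toy_classifier(hidden_states)

print(toy_classifier.weight.shape)     # torch.Size([8, 768]) = (num_experts, hidden_dim)
print(captured[0].shape, len(captured))  # torch.Size([17, 8]) 1
```

Wrapping the forward pass in `try`/`finally` is a small design choice. It guarantees `handle.remove()` runs even if the forward pass fails, so a re-run cell does not leave a stale hook attached.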


In [6]:
# Hook the router classifier to capture per-expert logits
router_classifier = model.encoder.block[1].layer[1].mlp.router.classifier
classifier_outputs = []

# Define a hook function to capture the outputs of the router classifier
def classifier_hook(module, module_inputs, module_outputs):
    if isinstance(module_outputs, tuple):
        # Ensure we only capture the logits
        module_outputs = module_outputs[0]
    classifier_outputs.append(module_outputs)


handle = router_classifier.register_forward_hook(classifier_hook)

with torch.no_grad(), torch.autocast(model.device.type, dtype=model.dtype):
    _ = model.encoder(
        input_ids=input_ids,
        attention_mask=attention_mask
    )

handle.remove()

logits = classifier_outputs[0].cpu()
print("Classifier logits shape:", tuple(logits.shape))
print("Classifier logits sample:", logits[13, :8].tolist())

Classifier logits shape: (17, 8)
Classifier logits sample: [0.05078125, -2.0, -0.17578125, 2.8125, -0.154296875, -0.173828125, 0.470703125, 0.279296875]


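Convert the logits into a concrete routing summary. Switch uses top-1 routing, so the argmax over experts gives the expert each token is sent to (unless that expert's capacity is exceeded and the token is dropped). We also compute the entropy of the softmax distribution over experts to measure router confidence, then align these values with the tokenizer's subword tokens in a table.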
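For reference, the entropy in nats is $H(p) = -\sum_i p_i \ln p_i$ with $p = \mathrm{softmax}(\text{logits})$. For 8 experts it ranges from 0 (all mass on one expert) to $\ln 8 \approx 2.079$ (uniform routing), and the per-token entropies in the table can be read against that ceiling. The stdlib-only sketch below applies the formula to the token-13 (`▁of`) logits printed by `In [6]`. The helper names `softmax` and `entropy_nats` are illustrative and not part of the notebook.

```python
import math

# Sample logits for token 13 ("▁of") copied from the output of In [6].
sample_logits = [0.05078125, -2.0, -0.17578125, 2.8125,
                 -0.154296875, -0.173828125, 0.470703125, 0.279296875]


def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def entropy_nats(probs):
    # H(p) = -sum_i p_i * ln(p_i); zero-probability terms contribute 0.
    return -sum(p * math.log(p) for p in probs if p > 0)


probs = softmax(sample_logits)
top_expert = max(range(len(probs)), key=probs.__getitem__)
h = entropy_nats(probs)
h_max = math.log(len(probs))  # uniform routing over 8 experts

print(f"top expert: {top_expert}, p = {probs[top_expert]:.3f}")
print(f"entropy: {h:.3f} nats (max {h_max:.3f}, normalized {h / h_max:.3f})")
```

The notebook's `F.softmax(...) * F.log_softmax(...)` computes the same quantity. Using `log_softmax` rather than `log(softmax(...))` avoids taking the log of probabilities that underflow to zero.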

In [7]:
import pandas as pd
import torch.nn.functional as F

# Compute the top-1 expert and routing entropy (in nats) per token
logits = logits.float()  # upcast from the model dtype before softmax
experts = logits.argmax(dim=-1)
entropy = -(F.softmax(logits, dim=-1) * F.log_softmax(logits, dim=-1)).sum(dim=-1)

tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
df = pd.DataFrame({
    "TOKEN": tokens,
    "EXPERT": experts.cpu().numpy(),
    "LOGIT": logits.max(dim=-1).values.cpu().numpy(),  # the winning (max) logit
    "ENTROPY": entropy.cpu().numpy()
})
df

|   | TOKEN | EXPERT | LOGIT | ENTROPY |
|---|---|---|---|---|
| 0 | ▁question | 4 | 2.0 | 1.733095 |
| 1 | : | 2 | 2.46875 | 1.472535 |
| 2 | ▁Where | 7 | 2.375 | 1.368643 |
| 3 | ▁is | 4 | 2.71875 | 1.297853 |
| 4 | ▁Paris | 7 | 2.265625 | 1.361608 |
| 5 | ? | 2 | 2.359375 | 1.484937 |
| 6 | , | 2 | 2.0625 | 1.598489 |
| 7 | ▁context | 5 | 2.65625 | 1.720571 |
| 8 | : | 2 | 2.796875 | 1.35567 |
| 9 | ▁Paris | 7 | 2.921875 | 1.281017 |
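A quick way to summarize the table is to count how often each expert is chosen and to check whether repeated surface tokens are routed consistently. The stdlib sketch below uses only rows 0 to 9 shown above. With the full DataFrame in the notebook, `df["EXPERT"].value_counts()` and `df.groupby("TOKEN")["EXPERT"].unique()` give the same information directly.

```python
from collections import Counter

# (token, expert) pairs for rows 0-9 of the routing table above.
rows = [
    ("▁question", 4), (":", 2), ("▁Where", 7), ("▁is", 4), ("▁Paris", 7),
    ("?", 2), (",", 2), ("▁context", 5), (":", 2), ("▁Paris", 7),
]

# How often each expert is selected.
usage = Counter(expert for _, expert in rows)
print(usage.most_common())  # -> [(2, 4), (7, 3), (4, 2), (5, 1)]

# Which experts each repeated surface token was routed to.
experts_by_token = {}
for token, expert in rows:
    experts_by_token.setdefault(token, set()).add(expert)

token_counts = Counter(token for token, _ in rows)
repeated = {tok: sorted(experts_by_token[tok]) for tok, n in token_counts.items() if n > 1}
print(repeated)  # -> {':': [2], '▁Paris': [7]}
```

Because the router acts on contextual hidden states rather than token IDs, identical tokens are not guaranteed to share an expert. In these ten rows, both occurrences of `:` and of `▁Paris` happen to do so.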


Next, we load the prompts from `Prompt_base.jsonl` and `Prompt_context.jsonl` as model inputs.

In [12]:
from pathlib import Path
data_dir = Path("/content/drive/MyDrive/Individual-Project-25-26/stage1/data")
base_path = data_dir / "Prompt_base.jsonl"
context_path = data_dir / "Prompt_context.jsonl"

In [13]:
import json

def load_prompts(path):
    prompts = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            obj = json.loads(line)
            prompts.append(obj["prompt"])
    return prompts

base_prompts = load_prompts(base_path)
context_prompts = load_prompts(context_path)

prompt_sets = {
    "base": base_prompts,
    "context": context_prompts,
}

print("Loaded base prompts:", len(base_prompts))
print("Loaded context prompts:", len(context_prompts))

Loaded base prompts: 100
Loaded context prompts: 100
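`load_prompts` assumes each non-blank line is a JSON object with a `"prompt"` key. The context prompts span several lines, so their newlines must be stored escaped (`\n`) inside the JSON string, which `json.dumps` does automatically. The self-contained sketch below writes a hypothetical two-record file in that assumed format to a temporary directory and reads it back with the same loader logic. The file name `Prompt_sample.jsonl` and the records are illustrative only.

```python
import json
import tempfile
from pathlib import Path


def load_prompts(path):
    # Same logic as In [13]: one JSON object per line, blank lines skipped.
    prompts = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            prompts.append(json.loads(line)["prompt"])
    return prompts


# Hypothetical records; the second contains a newline, like the context prompts.
records = [
    {"prompt": "Where was the city originally located?"},
    {"prompt": "Context:\nParis is the capital of France."},
]

with tempfile.TemporaryDirectory() as tmp:
    sample_path = Path(tmp) / "Prompt_sample.jsonl"
    with sample_path.open("w", encoding="utf-8") as f:
        for rec in records:
            # json.dumps escapes the newline, so each record stays on one line.
            f.write(json.dumps(rec, ensure_ascii=False) + "\n")
        f.write("\n")  # a trailing blank line, which the loader skips
    loaded = load_prompts(sample_path)

print(loaded)
```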


In [14]:
def tokenize_prompts(prompts):
    tokenized = []
    for prompt in prompts:
        encoding = tokenizer(prompt, return_tensors="pt")
        tokenized.append({
            "input_ids": encoding["input_ids"][0].tolist(),
            "attention_mask": encoding["attention_mask"][0].tolist(),
        })
    return tokenized



base_tokenized = tokenize_prompts(base_prompts)
context_tokenized = tokenize_prompts(context_prompts)

print("Base tokenized:", len(base_tokenized))
print("Context tokenized:", len(context_tokenized))

Base tokenized: 100
Context tokenized: 100
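The prompts are tokenized one at a time, so every attention mask is all ones. If they are later batched, shorter sequences get padded. Padded positions still pass through the feed-forward (MoE) layers, because the attention mask only affects attention. Their router decisions should therefore be excluded from any expert statistics. The stdlib sketch below illustrates this with two hypothetical helpers, `pad_batch` and `count_experts`. It assumes right-padding with id 0, which is the `<pad>` id in T5-family tokenizers such as the one Switch uses. The second sequence reuses the IDs for `▁Where ▁is ▁Paris ? </s>` from `In [3]`, and the per-token experts are made-up values, not model outputs.

```python
from collections import Counter

PAD_ID = 0  # <pad> id in T5-family tokenizers


def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad lists of token ids to a common length and build attention masks."""
    max_len = max(len(seq) for seq in sequences)
    input_ids = [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]
    attention_mask = [[1] * len(seq) + [0] * (max_len - len(seq)) for seq in sequences]
    return input_ids, attention_mask


def count_experts(experts, attention_mask):
    """Count expert assignments over real tokens only (mask == 1)."""
    counts = Counter()
    for row_experts, row_mask in zip(experts, attention_mask):
        counts.update(e for e, m in zip(row_experts, row_mask) if m == 1)
    return counts


sequences = [
    [2840, 47, 8, 690, 5330, 1069, 58, 1],  # base example 99 from the output below
    [2840, 19, 1919, 58, 1],                # ▁Where ▁is ▁Paris ? </s>
]
ids, mask = pad_batch(sequences)

# Made-up top-1 experts per position; the padded positions in row 2 all got expert 6.
experts = [[7, 4, 1, 3, 3, 0, 2, 5], [7, 4, 7, 2, 6, 6, 6, 6]]

counts = count_experts(experts, mask)
print(mask[1])    # -> [1, 1, 1, 1, 1, 0, 0, 0]
print(counts[6])  # -> 1 (a naive count would give 4)
```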


In [None]:
def inspect_tokenized(tokenized, prompts, label, idx=99):
    if not tokenized:
        print(f"{label}: no items")
        return
    idx = max(0, min(idx, len(tokenized) - 1))
    item = tokenized[idx]
    ids = item["input_ids"]
    mask = item["attention_mask"]
    print(f"{label} example {idx}:")
    print("prompt:", prompts[idx])
    print("num tokens:", len(ids))
    print("input_ids:", ids[:30])
    print("attention_mask:", mask[:30])
    print("tokens:", tokenizer.convert_ids_to_tokens(ids[:30]))

inspect_tokenized(base_tokenized, base_prompts, "base", idx=99)
inspect_tokenized(context_tokenized, context_prompts, "context", idx=99)

base example 99:
prompt: Where was the city originally located?
num tokens: 8
input_ids: [2840, 47, 8, 690, 5330, 1069, 58, 1]
attention_mask: [1, 1, 1, 1, 1, 1, 1, 1]
tokens: ['▁Where', '▁was', '▁the', '▁city', '▁originally', '▁located', '?', '</s>']
context example 0:
prompt: You must answer the question using ONLY the information provided in the context below.
If the answer cannot be determined from the context, respond with "Not answerable from the given context."
Do not use any external knowledge.

Context:
If Mammalia is considered as the crown group, its origin can be roughly dated as the first known appearance of animals more closely related to some extant mammals than to others. Ambondro is more closely related to monotremes than to therian mammals while Amphilestes and Amphitherium are more closely related to the therians; as fossils of all three genera are dated about 167 million years ago in the Middle Jurassic, this is a reasonable estimate for the appearance of the crown 