# Untrained model routing patterns with a sample dataset

In [1]:
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "AllenAI/OLMoE-1B-7B-0924"

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Load config only (no weights)
config = AutoConfig.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_config(config)  # On CPU

  from scipy.sparse import csr_matrix, issparse


### problems with model configuration

model = AutoModelForCausalLM.from_config(config)

#### key limitations
- no pre-trained weight
- no quantization
- no sharding
- no fp16 computing
- large model entirely in fp32 by default


### torch.nn.DataParallel is not a solution
#### problems
-  model is first fully loaded onto cuda:0
- remaining GPUs only used for gradient calculations (forward/backward)
- model memory is *not* split
=> leads to OOM error


-> solution here : use CPU
 - Inference is slow
 - But sufficient for checking routing behavior

In [2]:
model.eval()

OlmoeForCausalLM(
  (model): OlmoeModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoeDecoderLayer(
        (self_attn): OlmoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): OlmoeRMSNorm((2048,), eps=1e-05)
          (k_norm): OlmoeRMSNorm((2048,), eps=1e-05)
        )
        (mlp): OlmoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=64, bias=False)
          (experts): ModuleList(
            (0-63): 64 x OlmoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (down_proj): Linear(in

In [3]:
# Dictionary to store the router outputs
router_logs = {}

def get_router_hook(layer_name):
    def hook(module, inputs, outputs):
        if isinstance(outputs, tuple):
            val = outputs[1] 
        else:
            val = outputs
        
        if isinstance(val, torch.Tensor) and val.shape[-1] == 64:
            # Detach and move to CPU to save GPU memory
            router_logs[layer_name] = val.detach().float().cpu()
    return hook

for name, module in model.named_modules():
    if hasattr(module, "_forward_hooks"):
        module._forward_hooks.clear()

print("Attaching spy hooks...")
hook_count = 0
for name, module in model.named_modules():
    # We target the specific gate layer in the OLMoE architecture
    if name.endswith("mlp.gate"): 
        module.register_forward_hook(get_router_hook(name))
        hook_count += 1

print(f"Attached {hook_count} hooks (Expect 32 for this model).")

Attaching spy hooks...
Attached 16 hooks (Expect 32 for this model).


In [21]:
import json
from collections import defaultdict, Counter
from tqdm import tqdm

# Load test dataset
from collections import defaultdict
lang_samples = defaultdict(list)

for item in test_data:
    lang = item["lang"]
    if len(lang_samples[lang]) < 30:
        lang_samples[lang].append(item)

sampled_test_data = sum(lang_samples.values(), [])

In [22]:
sampled_test_data

[{'lang': 'eng_Latn',
  'text': '3189\nSociologist Steven Seidman, author of such books as Beyond the Closet and Romantic Longings, recently interviewed me about lesbian identities and the changing landscape of sexual politics'},
 {'lang': 'eng_Latn',
  'text': 'Starting a blog is a bit like keeping a diary, which I did for so many years–throughout my teens and twenties'},
 {'lang': 'eng_Latn',
  'text': '206d\nKristof criticized professors for fostering a culture that “glorifies arcane unintelligibility while disdaining impact and audience'},
 {'lang': 'eng_Latn',
  'text': 'f51\nKristof criticized professors for fostering a culture that “glorifies arcane unintelligibility while disdaining impact and audience'},
 {'lang': 'eng_Latn',
  'text': '3e9\nA Palestinian man whose face is covered with a kefeyya emblazoned with Arabic writing points his Kalashnikov at the viewer'},
 {'lang': 'eng_Latn',
  'text': 'Same-sex parents and their children must grapple with non-recognition in the eye

In [24]:
# Start inference loop
results = []

for idx, item in enumerate(tqdm(sampled_test_data)):
    langs = item["lang"]
    text = item["text"]
    
    if isinstance(langs, str):
        langs = [langs]
    
    router_logs.clear()  # Clear old logs

    # Tokenize (on CPU)
    inputs = tokenizer(text, return_tensors="pt")  # don't send to device

    with torch.no_grad():
        _ = model(**inputs)

    # Analyze a specific layer (14)
    target_layer = "model.layers.14.mlp.gate"
    
    if target_layer in router_logs:
        logits = router_logs[target_layer]  # shape: [seq_len, num_experts]
        expert_ids = logits.argmax(dim=-1).view(-1).tolist()  # Most chosen expert per token

        for token_idx, expert_id in enumerate(expert_ids):
            for lang in langs:
                results.append({
                    "lang": lang,
                    "sentence_id": idx,
                    "token_idx": token_idx,
                    "expert_id": expert_id
                })
        if (idx + 1) % 100 == 0:
            pd.DataFrame(results).to_csv("partial_results.csv", index=False)
            results.clear()

100%|███████████████████████████████████████████████████████████████████████████████| 120/120 [2:36:08<00:00, 78.07s/it]


In [28]:
df = pd.read_csv("partial_results.csv")
lang_expert_counter = defaultdict(Counter)

for _, row in df.iterrows():
    lang = row["lang"]
    expert_id = row["expert_id"]
    lang_expert_counter[lang][expert_id] += 1

for lang, counter in lang_expert_counter.items():
    print(f"\nTop experts used in language: {lang}")
    for expert_id, count in counter.most_common(10):
        print(f"  Expert {expert_id}: {count} times")


Top experts used in language: eng_Latn
  Expert 40: 126 times
  Expert 2: 57 times
  Expert 42: 47 times
  Expert 14: 35 times
  Expert 57: 31 times
  Expert 22: 30 times
  Expert 36: 30 times
  Expert 12: 28 times
  Expert 60: 27 times
  Expert 18: 27 times

Top experts used in language: bul_Cyrl
  Expert 11: 190 times
  Expert 42: 168 times
  Expert 24: 131 times
  Expert 2: 129 times
  Expert 60: 106 times
  Expert 28: 83 times
  Expert 1: 82 times
  Expert 4: 76 times
  Expert 44: 62 times
  Expert 6: 54 times

Top experts used in language: kor_Hang
  Expert 50: 494 times
  Expert 11: 285 times
  Expert 15: 263 times
  Expert 31: 209 times
  Expert 35: 209 times
  Expert 33: 208 times
  Expert 21: 207 times
  Expert 42: 149 times
  Expert 12: 129 times
  Expert 28: 124 times

Top experts used in language: jpn_Jpan
  Expert 17: 117 times
  Expert 2: 114 times
  Expert 21: 106 times
  Expert 49: 87 times
  Expert 26: 73 times
  Expert 6: 54 times
  Expert 18: 39 times
  Expert 3: 25

# Pre-trained model routing patterns with a sample dataset

In [32]:
from transformers import BitsAndBytesConfig

# Define the model (AllenAI OLMoE-1B-7B)
MODEL_ID = "AllenAI/OLMoE-1B-7B-0924"

# Configure 4-bit quantization to save memory
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4"
)

print(f"Loading {MODEL_ID}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quant_config,
    device_map="auto"
)
print(f"Success! Model loaded on {model.device}")

Loading AllenAI/OLMoE-1B-7B-0924...


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Success! Model loaded on cuda:0


In [35]:
# Dictionary to store the router outputs
router_logs = {}

# Clear old hooks (if re-running)
for name, module in model.named_modules():
    if hasattr(module, "_forward_hooks"):
        module._forward_hooks.clear()

# Attach new hooks to the router layers
print("Attaching spy hooks...")
hook_count = 0
for name, module in model.named_modules():
    # We target the specific gate layer in the OLMoE architecture
    if name.endswith("mlp.gate"): 
        module.register_forward_hook(get_router_hook(name))
        hook_count += 1

print(f"Attached {hook_count} hooks (Expect 32 for this model).")

Attaching spy hooks...
Attached 16 hooks (Expect 32 for this model).


In [36]:
from tqdm import tqdm
# language_expert count dictionary
language_expert_counts = defaultdict(Counter)

for item in tqdm(sampled_test_data, desc="Processing samples"):
    langs = item["lang"]
    text = item["text"]
    
    if isinstance(langs, str):
        langs = [langs]
        
    # Clear router logs before each inference
    router_logs.clear()

    # Run inference
    inputs = tokenizer(text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        _ = model(**inputs)

    # Analyze a specific layer (14)
    target_layer = "model.layers.14.mlp.gate"
    
    if target_layer in router_logs:
        logits = router_logs[target_layer] # [Seq, Experts]
        expert_ids = logits.argmax(dim=-1).view(-1).tolist()  # Top Experts

        # Count expert selection for each language tag
        for lang in langs:
            language_expert_counts[lang].update(expert_ids)

Processing samples: 100%|█████████████████████████████████████████████████████████████| 120/120 [02:02<00:00,  1.02s/it]


In [37]:
for lang, counter in language_expert_counts.items():
    print(f"\nTop experts used in language: {lang}")
    for expert_id, count in counter.most_common(10):
        print(f"  Expert {expert_id}: {count} times")


Top experts used in language: eng_Latn
  Expert 49: 65 times
  Expert 19: 47 times
  Expert 63: 37 times
  Expert 60: 34 times
  Expert 18: 32 times
  Expert 35: 25 times
  Expert 46: 24 times
  Expert 55: 23 times
  Expert 20: 23 times
  Expert 47: 22 times

Top experts used in language: bul_Cyrl
  Expert 13: 1350 times
  Expert 60: 32 times
  Expert 47: 21 times
  Expert 19: 21 times
  Expert 52: 19 times
  Expert 12: 19 times
  Expert 27: 13 times
  Expert 26: 12 times
  Expert 1: 12 times
  Expert 61: 10 times

Top experts used in language: kor_Hang
  Expert 13: 2580 times
  Expert 19: 314 times
  Expert 47: 203 times
  Expert 34: 91 times
  Expert 52: 73 times
  Expert 61: 37 times
  Expert 9: 35 times
  Expert 60: 32 times
  Expert 12: 22 times
  Expert 39: 18 times

Top experts used in language: jpn_Jpan
  Expert 13: 937 times
  Expert 24: 295 times
  Expert 19: 228 times
  Expert 36: 141 times
  Expert 47: 132 times
  Expert 34: 85 times
  Expert 61: 79 times
  Expert 52: 51 t