In [1]:
import torch
from transformers import OlmoeForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import Optional

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    model = OlmoeForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
dataset = load_dataset("cais/mmlu", "astronomy")
print(f'dataset {dataset.keys()}')
print(dataset['test'][0])
print(f"Test set: {len(dataset['test'])} examples")
print(f"Validation set: {len(dataset['validation'])} examples")
print(f'dev set: {(dataset["dev"][2])} examples')

dataset dict_keys(['test', 'validation', 'dev'])
{'question': 'What is true for a type-Ia ("type one-a") supernova?', 'subject': 'astronomy', 'choices': ['This type occurs in binary systems.', 'This type occurs in young galaxies.', 'This type produces gamma-ray bursts.', 'This type produces high amounts of X-rays.'], 'answer': 0}
Test set: 152 examples
Validation set: 16 examples
dev set: {'question': 'Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?', 'subject': 'astronomy', 'choices': ['10000 times more', '100 times more', '1000 times more', '10 times more'], 'answer': 0} examples


In [5]:
def get_router_logits(model, input_text: str, layer_idx: Optional[int] = None, k: int = 1):
    """
    Get router logits for each token in the input text.
    
    Args:
        model: OlmoeForCausalLM model
        input_text: Text string to analyze
        layer_idx: Optional int specifying which layer to analyze. If None, analyze all layers.
        k: Number of top experts to return per token
        
    Returns:
        List of lists, where each inner list contains:
        [token_text, expert_index, router_probability] for each token
    """
    # Tokenize input text
    inputs = tokenizer(input_text, return_tensors="pt")
    
    # Forward pass with router logits enabled
    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        output_router_logits=True,
        return_dict=True,
    )
    
    # Get router logits for the specified layer(s)
    router_logits = outputs.router_logits
    if layer_idx is not None:
        router_logits = [router_logits[layer_idx]]
        
    results = []
    for layer_router_logits in router_logits:
        # Apply softmax to get probabilities
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(inputs['input_ids'].size(1), -1)
        # Get top k probabilities and indices for each token
        top_probs, top_indices = torch.topk(probs, k=k)
        
        # Convert token IDs to text
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        
        # Create list of [token, expert, prob] for each token
        for i in range(len(tokens)):
            for j in range(k):
                # Clean special characters from token text
                clean_token = tokens[i].replace('Ġ', '')
                results.append([
                    clean_token,
                    top_indices[i][j].item(),
                    top_probs[i][j].item()
                ])
    
    return results # a list of lists, where each inner list contains [token, expert_number, probability]

In [6]:
# input_text = dataset['test'][0]['question'] # retrieve question from dataset
input_text = 'Let \\[f(x) = \\left\\{\n\\begin{array}{cl} ax+3, &\\text{ if }x>2, \\\\\nx-5 &\\text{ if } -2 \\le x \\le 2, \\\\\n2x-b &\\text{ if } x <-2.\n\\end{array}\n\\right.\\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper).'
print(f'input_text: {input_text}')
results = get_router_logits(model, input_text, layer_idx=0)
print(f'results: {results}')

input_text: Let \[f(x) = \left\{
\begin{array}{cl} ax+3, &\text{ if }x>2, \\
x-5 &\text{ if } -2 \le x \le 2, \\
2x-b &\text{ if } x <-2.
\end{array}
\right.\]Find $a+b$ if the piecewise function is continuous (which means that its graph can be drawn without lifting your pencil from the paper).
results: [['Let', 57, 0.1621944010257721], ['\\[', 57, 0.09745539724826813], ['f', 54, 0.09076446294784546], ['(', 14, 0.09150239080190659], ['x', 54, 0.13038761913776398], [')', 6, 0.07337653636932373], ['=', 49, 0.08847833424806595], ['\\', 33, 0.11007364839315414], ['left', 29, 0.07177269458770752], ['\\{', 47, 0.09915725141763687], ['Ċ', 54, 0.08345343172550201], ['\\', 6, 0.0734686404466629], ['begin', 29, 0.04893956705927849], ['{', 47, 0.11041700839996338], ['array', 8, 0.0676625519990921], ['}{', 6, 0.05729150399565697], ['cl', 0, 0.15870289504528046], ['}', 55, 0.06711704283952713], ['ax', 0, 0.1668691635131836], ['+', 44, 0.08225768059492111], ['3', 2, 0.0885232612490654], [',', 36, 0.

In [8]:
from collections import defaultdict
import plotly.graph_objects as go

# Create a dictionary to store expert counts
expert_counts = defaultdict(int)

# Count how many tokens went to each expert
total_tokens = len(set(token for token, _, _ in results))
# unpacking the results list into token, expert, prob
for token, expert, prob in results:
    # print(f'token: {token}, expert: {expert}, prob: {prob}')
    expert_counts[expert] += 1

# Convert to lists for plotting and calculate percentages
experts = [f'{i}' for i in range(64)]
percentages = [expert_counts[i]/total_tokens * 100 for i in range(64)]

# Create bar chart
fig = go.Figure(data=[
    go.Bar(
        x=experts,
        y=percentages,
        # text=[f'{p:.1f}%' for p in percentages],
        textposition='auto',
    )
])

fig.update_layout(
    title='percentage of total tokens routed to each expert',
    xaxis_title='expert',
    yaxis_title='% of total tokens',
    xaxis_tickangle=-45,
    bargap=0.2
)

fig.show()
