In [1]:
import torch
from transformers import OlmoeForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import Optional

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    """Load a causal LM and its tokenizer from the same checkpoint.

    Args:
        model_name: Checkpoint identifier handed to both `from_pretrained` calls.

    Returns:
        Tuple (model, tokenizer).
    """
    # The model is loaded first, then the tokenizer (tuple elements evaluate left to right).
    return (
        OlmoeForCausalLM.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
    )

# Instantiate the default checkpoint once; later cells reuse `model` and `tokenizer`.
model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
def get_tokens_from_dataset(tokenizer, dataset_name, category, split="test", max_length: Optional[int] = None):
    """
    Load a dataset from HuggingFace datasets, tokenize one column and save the result.

    Args:
        tokenizer : Tokenizer for the model; called with the list of texts and
            padding/truncation/max_length/return_tensors keyword arguments
        dataset_name : Name of dataset on HuggingFace hub (e.g. "wikitext", "c4")
        category : Category/subset of the dataset (e.g. "abstract_algebra" for MMLU)
        split : Dataset split to load (default: "test")
        max_length : Optional maximum sequence length forwarded to the tokenizer.
            With the default (None) the tokenizer decides; if it has no predefined
            maximum it warns and does not truncate at all.
    Returns :
        Tuple (output_path, inputs)
        output_path : Path to the file written with torch.save (".pkl" suffix)
        inputs : Whatever the tokenizer returned for the entire split
    Raises :
        ValueError : If `split` is not one of the splits of the loaded dataset
    """

    # Load dataset with category (all splits are loaded, then one is selected)
    dataset = load_dataset(dataset_name, category)

    # Check if split exists in dataset
    if split not in dataset:
        available_splits = list(dataset.keys())
        raise ValueError(f"Split '{split}' not found in dataset. Available splits: {available_splits}")

    split_data = dataset[split]

    # Prefer the conventional 'text' column; otherwise fall back to the first
    # declared feature (e.g. 'question' for cais/mmlu).
    # NOTE(review): the fallback assumes the first feature holds strings - confirm per dataset.
    text_field = 'text' if 'text' in split_data.features else next(iter(split_data.features))

    # Always materialise a plain list: depending on the datasets version a column
    # lookup may return a lazy column object rather than a list.
    texts = list(split_data[text_field])

    # Tokenize all texts at once, padding to the longest sequence
    inputs = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )

    # Save tokenized inputs with sanitized path
    # Replace forward slashes with underscores to avoid directory issues
    safe_dataset_name = dataset_name.replace('/', '_')
    output_path = f"{safe_dataset_name}_{category}_{split}_tokenized.pkl"
    torch.save(inputs, output_path)

    return output_path, inputs


In [4]:
# Tokenize the MMLU "anatomy" test split and show what the tokenizer produced.
tokenized_path, inputs = get_tokens_from_dataset(
    tokenizer,
    dataset_name="cais/mmlu",
    category="anatomy",
    split="test",
)
print(f'inputs: {inputs}')

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


inputs: {'input_ids': tensor([[   34, 15411,  8479,  ...,     1,     1,     1],
        [   34,   346,    69,  ...,     1,     1,     1],
        [ 7371,   273,   253,  ...,     1,     1,     1],
        ...,
        [43228,   273, 28833,  ...,     1,     1,     1],
        [  688,  1821,    13,  ...,     1,     1,     1],
        [ 7371,   273,   253,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]])}


In [5]:
# Inspect the MMLU "astronomy" subset: available splits, a sample record, split sizes.
dataset = load_dataset("cais/mmlu", "astronomy")
print(f'dataset {dataset.keys()}')

print(dataset['test'][0])
print(f"Test set: {len(dataset['test'])} examples")
print(f"Validation set: {len(dataset['validation'])} examples")
# Report the dev split size like the other splits, and show the sample
# record under its own label instead of tagging a single record as "examples".
print(f"Dev set: {len(dataset['dev'])} examples")
print(f"dev example: {dataset['dev'][2]}")

dataset dict_keys(['test', 'validation', 'dev'])
{'question': 'What is true for a type-Ia ("type one-a") supernova?', 'subject': 'astronomy', 'choices': ['This type occurs in binary systems.', 'This type occurs in young galaxies.', 'This type produces gamma-ray bursts.', 'This type produces high amounts of X-rays.'], 'answer': 0}
Test set: 152 examples
Validation set: 16 examples
dev set: {'question': 'Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?', 'subject': 'astronomy', 'choices': ['10000 times more', '100 times more', '1000 times more', '10 times more'], 'answer': 0} examples


In [11]:
def get_router_logits(model, input_text: str, layer_idx: Optional[int] = None, k: int = 1):
    """
    Get the top-k router probabilities and expert indices for each token of a text.

    Relies on the notebook-level `tokenizer` variable to tokenize `input_text`.

    Args:
        model: OlmoeForCausalLM model
        input_text: Text string to analyze (a single text; the reshape below
            assumes batch_size=1)
        layer_idx: Optional int specifying which layer to analyze. If None, analyze all layers.
        k: Number of top experts to return per token

    Returns:
        List with a single dictionary containing:
        - tokens: The tokenizer output for `input_text` (input_ids and attention_mask)
        - router_probs: Per analyzed layer, per token, the top-k router probabilities
        - router_indices: Per analyzed layer, per token, the top-k expert indices
    """
    # Tokenize input text
    inputs = tokenizer(input_text, return_tensors="pt")

    # Put the inputs on the model's device so the forward pass also works
    # when the model is not on CPU.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    # Forward pass with router logits enabled; gradients are never needed here,
    # so skip building the autograd graph.
    with torch.no_grad():
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            output_router_logits=True,
            return_dict=True,
        )

    # Get router logits for the specified layer(s)
    router_logits = outputs.router_logits
    if layer_idx is not None:
        router_logits = [router_logits[layer_idx]]

    seq_len = inputs['input_ids'].size(1)
    layer_probs = []
    layer_indices = []

    for layer_router_logits in router_logits:
        # Apply softmax over the expert dimension to get probabilities
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(seq_len, -1)
        # Get top k probabilities and indices for each token
        top_probs, top_indices = torch.topk(probs, k=k)

        layer_probs.append(top_probs.tolist())
        layer_indices.append(top_indices.tolist())

    return [{
        "tokens": inputs,
        "router_probs": layer_probs,
        "router_indices": layer_indices
    }]

In [12]:
# Take the first test question and look at its top-1 expert per token in layer 0.
first_test_example = dataset['test'][0]
input_text = first_test_example['question']
print(f'input_text: {input_text}')

results = get_router_logits(model, input_text, layer_idx=0)
print(f'results: {results}')


input_text: What is true for a type-Ia ("type one-a") supernova?
results: [{'tokens': {'input_ids': tensor([[ 1276,   310,  2032,   323,   247,  1511,    14,    42,    66,  5550,
           881,   581,    14,    66,  2807, 13708,  8947,    32]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}, 'router_probs': [[[0.1212511956691742], [0.08246838301420212], [0.0790310949087143], [0.10655879229307175], [0.08738203346729279], [0.09577271342277527], [0.12028155475854874], [0.21959717571735382], [0.27170330286026], [0.08207882195711136], [0.10128765553236008], [0.13019609451293945], [0.17984884977340698], [0.13717882335186005], [0.08026150614023209], [0.0579301081597805], [0.07837120443582535], [0.051885128021240234]]], 'router_indices': [[[57], [14], [48], [50], [6], [32], [55], [27], [35], [6], [18], [55], [1], [27], [4], [49], [26], [6]]]}]
