In [1]:
import torch
from transformers import OlmoeForCausalLM, AutoTokenizer
from datasets import load_dataset
from typing import Optional

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    model = OlmoeForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
dataset = load_dataset("cais/mmlu", "astronomy")
print(f'dataset {dataset.keys()}')
print(dataset['test'][0])
print(f"Test set: {len(dataset['test'])} examples")
print(f"Validation set: {len(dataset['validation'])} examples")
print(f'dev set: {(dataset["dev"][2])} examples')

dataset dict_keys(['test', 'validation', 'dev'])
{'question': 'What is true for a type-Ia ("type one-a") supernova?', 'subject': 'astronomy', 'choices': ['This type occurs in binary systems.', 'This type occurs in young galaxies.', 'This type produces gamma-ray bursts.', 'This type produces high amounts of X-rays.'], 'answer': 0}
Test set: 152 examples
Validation set: 16 examples
dev set: {'question': 'Say the pupil of your eye has a diameter of 5 mm and you have a telescope with an aperture of 50 cm. How much more light can the telescope gather than your eye?', 'subject': 'astronomy', 'choices': ['10000 times more', '100 times more', '1000 times more', '10 times more'], 'answer': 0} examples


In [4]:
def get_router_logits(model, input_text: str, layer_idx: Optional[int] = None, k: int = 1):
    """
    Get router logits for each token in the input text.
    
    Args:
        model: OlmoeForCausalLM model
        input_text: Text string to analyze
        layer_idx: Optional int specifying which layer to analyze. If None, analyze all layers.
        k: Number of top experts to return per token
        
    Returns:
        List of lists, where each inner list contains:
        [token_text, expert_index, router_probability] for each token
    """
    # Tokenize input text
    inputs = tokenizer(input_text, return_tensors="pt")
    
    # Forward pass with router logits enabled
    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        output_router_logits=True,
        return_dict=True,
    )
    
    # Get router logits for the specified layer(s)
    router_logits = outputs.router_logits
    if layer_idx is not None:
        router_logits = [router_logits[layer_idx]]
        
    results = []
    for layer_router_logits in router_logits:
        # Apply softmax to get probabilities
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(inputs['input_ids'].size(1), -1)
        # Get top k probabilities and indices for each token
        top_probs, top_indices = torch.topk(probs, k=k)
        
        # Convert token IDs to text
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
        
        # Create list of [token, expert, prob] for each token
        for i in range(len(tokens)):
            for j in range(k):
                # Clean special characters from token text
                clean_token = tokens[i].replace('Ġ', '')
                results.append([
                    clean_token,
                    top_indices[i][j].item(),
                    top_probs[i][j].item()
                ])
    
    return results # a list of lists, where each inner list contains [token, expert_number, probability]

In [22]:
# input_text = dataset['test'][0]['question'] # retrieve question from dataset
input_text = 'Square ABCD has its center at $(8,-8)$ and has an area of 4 square units. The top side of the square is horizontal. The square is then dilated with the dilation center at (0,0) and a scale factor of 2. What are the coordinates of the vertex of the image of square ABCD that is farthest from the origin? Give your answer as an ordered pair.'
print(f'input_text: {input_text}')
results = get_router_logits(model, input_text, layer_idx=0)
print(f'results: {results}')

input_text: Square ABCD has its center at $(8,-8)$ and has an area of 4 square units. The top side of the square is horizontal. The square is then dilated with the dilation center at (0,0) and a scale factor of 2. What are the coordinates of the vertex of the image of square ABCD that is farthest from the origin? Give your answer as an ordered pair.
results: [['S', 56, 0.16368664801120758], ['quare', 56, 0.10670456290245056], ['AB', 17, 0.09328509867191315], ['CD', 25, 0.09634575992822647], ['has', 3, 0.07668595761060715], ['its', 48, 0.1469525843858719], ['center', 48, 0.15707798302173615], ['at', 47, 0.07011132687330246], ['$(', 26, 0.11473149806261063], ['8', 54, 0.11328016221523285], [',-', 54, 0.07860477268695831], ['8', 27, 0.1017904132604599], [')$', 25, 0.07222343981266022], ['and', 53, 0.0700511634349823], ['has', 2, 0.1669037938117981], ['an', 48, 0.08551409095525742], ['area', 4, 0.075767382979393], ['of', 46, 0.12270712852478027], ['4', 55, 0.09616360068321228], ['square', 

In [23]:
import pandas as pd

def save_to_parquet(results, output_path='expert_counts.parquet'):
    """
    Convert router logits results to a Parquet file.
    
    Args:
        results : List of lists containing [token, expert_number, probability]
        output_path : Path to save the Parquet file
    """
    # Create DataFrame from results
    df = pd.DataFrame(results, columns=['token', 'expert_number', 'probability'])
    
    # Convert types explicitly
    df['token'] = df['token'].astype(str)
    df['expert_number'] = df['expert_number'].astype(int) 
    df['probability'] = df['probability'].astype(float)
    
    # Save to Parquet
    df.to_parquet(output_path, index=False)
    return df

save_to_parquet(results)

Unnamed: 0,token,expert_number,probability
0,S,56,0.163687
1,quare,56,0.106705
2,AB,17,0.093285
3,CD,25,0.096346
4,has,3,0.076686
...,...,...,...
78,as,22,0.069170
79,an,62,0.086820
80,ordered,4,0.084952
81,pair,59,0.123608


In [25]:
from collections import defaultdict
import plotly.graph_objects as go


# # Create a dictionary to store expert counts
# expert_counts = defaultdict(int)

# # Count how many tokens went to each expert
# total_tokens = len(set(token for token, _, _ in results))
# # unpacking the results list into token, expert, prob
# for token, expert, prob in results:
#     # print(f'token: {token}, expert: {expert}, prob: {prob}')
#     expert_counts[expert] += 1


def plot_expert_distribution(parquet_path='expert_counts.parquet'):
    """
    Read expert counts from Parquet file and create distribution plot.
    
    Args:
        parquet_path: Path to the Parquet file containing expert counts
    """
    # Read parquet file
    df = pd.read_parquet(parquet_path)
    
    # Create a dictionary to store expert counts
    expert_counts = defaultdict(int)
    
    # Count how many tokens went to each expert
    total_tokens = len(set(df['token']))
    
    # Count occurrences of each expert
    for expert in df['expert_number']:
        expert_counts[expert] += 1
    
    # Convert to lists for plotting and calculate percentages
    experts = [f'{i}' for i in range(64)]
    percentages = [expert_counts[i]/total_tokens * 100 for i in range(64)]
    # Create bar chart
    fig = go.Figure(data=[
        go.Bar(
            x=experts,
            y=percentages,
            textposition='auto',
            marker_color='red'  # You can use any color here - hex code, RGB, or color name
        )
    ])
    
    fig.update_layout(
        title='percentage of total tokens routed to each expert',
        xaxis_title='expert',
        yaxis_title='% of total tokens',
        yaxis=dict(range=[0, 100]), # Set y-axis range from 0 to 100%
        xaxis_tickangle=-45,
        bargap=0.2
    )
    
    return fig

fig = plot_expert_distribution()
fig.show()