In [57]:
import torch
from transformers import OlmoeForCausalLM, AutoTokenizer
from datasets import load_dataset
import json
import os
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    """Load an OLMoE checkpoint together with its matching tokenizer.

    args :
        model_name (str): Hugging Face hub id or local path of the checkpoint
    output : tuple (model, tokenizer)
    """
    olmoe_model = OlmoeForCausalLM.from_pretrained(model_name)
    olmoe_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return olmoe_model, olmoe_tokenizer


# Loaded once; the cells below use these notebook-level names.
model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### Split the text file into chunks that fit the model's context length

In [56]:
def prepare_text_input(file_path, chunk_size=1000, tokenizer=None):
    """Split a UTF-8 text file into consecutive chunks of at most ``chunk_size`` tokens.

    args :
        file_path (str): Path to the input text file
        chunk_size (int): Maximum number of tokens per chunk; must be positive
        tokenizer: HuggingFace tokenizer used to count tokens. If None, the text
            is split on whitespace and each word counts as one token.

    output : List of text chunks in file order (empty list for an empty file).
        With a tokenizer, each chunk is the decoded text of ``chunk_size`` token
        ids, so re-tokenizing it gives approximately (not necessarily exactly)
        ``chunk_size`` tokens.

    raises :
        ValueError: if ``chunk_size`` is not positive.
    """
    # A negative step makes range() silently yield nothing (so the caller would
    # get []), and zero raises an opaque range() error -- reject both up front.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    if tokenizer is not None:
        # Chunks are re-tokenized downstream, which adds special tokens per chunk.
        # Adding them here too would bake their literal text into the decoded
        # chunk (if the tokenizer adds any) and duplicate them on re-encoding.
        tokens = tokenizer.encode(text, add_special_tokens=False)
        return [
            tokenizer.decode(tokens[start:start + chunk_size])
            for start in range(0, len(tokens), chunk_size)
        ]

    # Fallback: whitespace "tokens"; runs of whitespace collapse to single spaces.
    words = text.split()
    return [
        ' '.join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

### Get the top-k router probabilities for each token across all layers

In [13]:
def get_router_logits(model, input_text: str, k: int = 1, tokenizer=None):
    """Return the top-``k`` experts picked by every router layer for each token.

    args :
        model: OlmoeForCausalLM model (its forward must accept
            ``output_router_logits=True`` and return ``.router_logits``)
        input_text: Text string to analyze
        k: Number of top experts to return per token
        tokenizer: Tokenizer used to encode ``input_text``. If None, the
            notebook-level ``tokenizer`` global is used.

    output : dictionary mapping layer indices to lists of
        [token_text, expert_index, router_probability]. Each token contributes
        ``k`` consecutive rows ordered from highest to lowest probability.
    """
    if tokenizer is None:
        # Backward compatibility: callers that don't pass a tokenizer rely on the
        # notebook-level global.
        tokenizer = globals().get('tokenizer')
        if tokenizer is None:
            raise NameError("no tokenizer passed and no global 'tokenizer' defined")

    inputs = tokenizer(input_text, return_tensors="pt")
    input_ids = inputs['input_ids']
    attention_mask = inputs['attention_mask']
    # Keep the inputs on the model's device (a no-op for a CPU-only model).
    device = getattr(model, 'device', None)
    if device is not None:
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

    # Inference only: without no_grad the forward pass would retain activations
    # for backprop, which is very costly for a 7B model on long chunks.
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_router_logits=True,
            return_dict=True,
        )

    seq_len = input_ids.size(1)
    # Token strings are identical for every layer, so convert and clean them once.
    # 'Ġ' is the byte-level BPE marker for a leading space.
    clean_tokens = [
        tok.replace('Ġ', '')
        for tok in tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
    ]

    all_layer_results = {}
    for layer_idx, layer_router_logits in enumerate(outputs.router_logits):
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(seq_len, -1)
        top_probs, top_indices = torch.topk(probs, k=k)
        # One bulk conversion to Python numbers instead of a .item() per element.
        top_probs, top_indices = top_probs.tolist(), top_indices.tolist()

        all_layer_results[layer_idx] = [
            [clean_tokens[i], top_indices[i][j], top_probs[i][j]]
            for i in range(seq_len)
            for j in range(k)
        ]

    return all_layer_results # Dictionary mapping layer index to list of [token, expert_number, probability]

### Create or update the router-logits JSON file with new tokens

In [58]:
def update_router_logits_json(results, json_path='router_logits_all_layers.json'):
    """Append per-layer routing results to a JSON file, creating it if needed.

    args :
        results: Dictionary mapping layer index to list of [token, expert_number, probability]
        json_path: Path to the JSON file
    output : the combined dictionary (int layer index -> list of rows) that was
        written. Calling this twice with the same results stores them twice.
    """
    combined_results = {}
    if os.path.exists(json_path):
        with open(json_path, 'r') as f:
            # JSON object keys are always strings; restore the integer layer indices.
            combined_results = {int(k): v for k, v in json.load(f).items()}

    for layer_idx, layer_tokens in results.items():
        # Extend a fresh/loaded list so the returned dict never aliases the
        # caller's per-layer lists.
        combined_results.setdefault(layer_idx, []).extend(layer_tokens)

    # Write to a sibling temp file and atomically swap it in, so an interrupted
    # write cannot corrupt the results accumulated from earlier chunks.
    # (On disk, json stores the layer keys as strings.)
    tmp_path = f'{json_path}.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(combined_results, f)
    os.replace(tmp_path, json_path)

    return combined_results

### Append one layer's routing results to a Parquet file

In [None]:
def save_to_parquet(results, layer_idx):
    """    
    args :
        layer_results : List of lists containing [token, expert_number, probability] for a specific layer
        layer_idx : Index of the layer being processed
        output_path : Path to save the Parquet file
    """
    output_path=f'expert_counts_layer_{layer_idx}.parquet'

    layer_results = results[layer_idx]
    # Create DataFrame from new results
    new_df = pd.DataFrame(layer_results, columns=['token', 'expert_number', 'probability'])
    
    # Add heading
    new_df.columns.name = f'Layer {layer_idx} Router Logits'
    
    # Convert types explicitly
    new_df['token'] = new_df['token'].astype(str)
    new_df['expert_number'] = new_df['expert_number'].astype(int)
    new_df['probability'] = new_df['probability'].astype(float)
    
    if os.path.exists(output_path):
        # Read existing dataframe and append new results
        existing_df = pd.read_parquet(output_path)
        combined_df = pd.concat([existing_df, new_df], ignore_index=True)
    else:
        # Create new dataframe if file doesn't exist
        combined_df = new_df
        
    # Save to Parquet
    combined_df.to_parquet(output_path, index=False)
    return combined_df

### Plot the expert distribution for a given layer

In [59]:
def plot_expert_distribution(layer_idx, num_experts=64, top_k=1, parquet_path=None):
    """Bar chart of the percentage of tokens routed to each expert in one layer.

    args :
        layer_idx : Layer whose saved routing rows are plotted
        num_experts : Number of experts (bars) to show, numbered 0..num_experts-1
        top_k : Rows stored per token, i.e. the ``k`` used when collecting the
            results; turns the row count into a token count
        parquet_path : File written by ``save_to_parquet``; defaults to
            ``expert_counts_layer_{layer_idx}.parquet``
    output : plotly Figure of the expert distribution for that layer
    """
    if top_k < 1:
        raise ValueError(f"top_k must be >= 1, got {top_k}")
    if parquet_path is None:
        parquet_path = f'expert_counts_layer_{layer_idx}.parquet'
    df = pd.read_parquet(parquet_path)

    # Every token contributes top_k rows, so the number of routed tokens is
    # len(df) / top_k. Counting *distinct* token strings instead would undercount
    # repeated tokens and push percentages above 100%.
    total_tokens = len(df) / top_k
    expert_counts = df['expert_number'].value_counts()

    experts = [f'{i}' for i in range(num_experts)]
    percentages = [
        expert_counts.get(i, 0) / total_tokens * 100 if total_tokens else 0.0
        for i in range(num_experts)
    ]

    fig = go.Figure(data=[
        go.Bar(
            x=experts,
            y=percentages,
            marker_color='red'
        )
    ])

    fig.update_layout(
        title=f'percentage of total tokens routed to each expert for layer {layer_idx}',
        xaxis_title='expert',
        yaxis_title='% of total tokens',
        yaxis=dict(range=[0, 100]), # Set y-axis range from 0 to 100%
        xaxis_tickangle=-45,
        bargap=0.2
    )

    return fig

### Expert distribution for a single text chunk (the first one)

In [None]:
# Read and chunk input file
file_path = 'github_oss_with_stack_texts.txt'
layer_to_plot = 0  # Analyze first layer
chunks = prepare_text_input(file_path, chunk_size=4096, tokenizer=tokenizer)
assert chunks, f'{file_path} produced no text chunks'

# update_router_logits_json and save_to_parquet append to existing files, so
# remove outputs from earlier runs; otherwise re-running this cell would store
# (and count) the same chunk again.
for stale_path in ('router_logits_all_layers.json', f'expert_counts_layer_{layer_to_plot}.parquet'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

# Process first chunk
first_chunk = chunks[0]
print(f'Processing text : {first_chunk[:100]}...')  # Print first 100 chars

# Get router logits for the chunk
results = get_router_logits(model, first_chunk)

# Save results for analysis
update_router_logits_json(results)

# Analyze routing for first few tokens in layer 0
print("\nRouting for first 5 tokens in layer 0: ")
for token, expert, prob in results[0][:5]:
    print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

# Save results to parquet for visualization
save_to_parquet(results, layer_idx=layer_to_plot)

# Plot expert distribution
fig = plot_expert_distribution(layer_idx=layer_to_plot)
fig.show()

### Expert distribution across all text chunks

In [None]:
# Read and chunk input file
file_path = 'github_oss_with_stack_texts.txt'
layer_to_plot = 0  # Analyze first layer
chunks = prepare_text_input(file_path, chunk_size=4096, tokenizer=tokenizer)
assert chunks, f'{file_path} produced no text chunks'

# Start from empty output files: both savers append, and the single-chunk cell
# above (or an earlier run of this cell) has already written chunk 0, which
# would otherwise be counted twice in the distribution.
for stale_path in ('router_logits_all_layers.json', f'expert_counts_layer_{layer_to_plot}.parquet'):
    if os.path.exists(stale_path):
        os.remove(stale_path)

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    print(f'Sample text: {chunk[:100]}...')  # Print first 100 chars

    # Get router logits for the chunk
    results = get_router_logits(model, chunk)
    all_results.append(results)

    # Save intermediate results
    update_router_logits_json(results)

# Combine results from all chunks: one flat list of rows per layer, in chunk order
combined_results = [
    [row for chunk_result in all_results for row in chunk_result[layer_idx]]
    for layer_idx in range(len(all_results[0]))
]

# Analyze routing for first few tokens in layer 0
print("\nRouting for first 5 tokens in layer 0: ")
for token, expert, prob in combined_results[0][:5]:
    print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

# Save combined results to parquet for visualization
save_to_parquet(combined_results, layer_idx=layer_to_plot)

# Plot expert distribution for all processed data
fig = plot_expert_distribution(layer_idx=layer_to_plot)
fig.show()