In [1]:
import torch
from transformers import OlmoeForCausalLM, AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json
import os
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    """Load a causal language model together with its tokenizer.

    args :
        model_name (str): Checkpoint identifier handed to ``from_pretrained``

    output : (model, tokenizer) tuple; the model is placed on "cuda" when
             CUDA is available and on "cpu" otherwise
    """
    if torch.cuda.is_available():
        target_device = "cuda"
    else:
        target_device = "cpu"

    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    loaded_model = loaded_model.to(target_device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer

# Load the default checkpoint once. Later cells read these module-level names:
# `get_router_logits` tokenizes with the global `tokenizer`, and the processing
# cell passes `model` and `tokenizer` on to the helper functions.
model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### split a text file into chunks that fit within the model's context length

In [4]:
def prepare_text_input(file_path, chunk_size=1000, tokenizer=None):
    """Read a text file and split it into chunks of about ``chunk_size`` tokens.

    args :
        file_path (str): Path to the input text file (read as UTF-8)
        chunk_size (int): Number of tokens per chunk (whitespace-separated
            words when no tokenizer is given); must be positive
        tokenizer: object exposing ``encode(text)`` and ``decode(ids)``, e.g. a
            HuggingFace tokenizer (if None, will split on whitespace)

    output : List of text chunks; the last chunk may be shorter than
             chunk_size, and an empty file yields an empty list

    raises :
        ValueError: if chunk_size is not positive
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # Read the full text file
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    if tokenizer is not None:
        # Slicing token ids is pure list work, so it stays on the CPU; the ids
        # are decoded straight back to text and never need to live on a GPU.
        # NOTE(review): encode() runs with the tokenizer's default special-token
        # handling and the decoded chunks are tokenized again downstream --
        # confirm this does not duplicate BOS/EOS markers for your tokenizer.
        units = list(tokenizer.encode(text))
        rebuild = tokenizer.decode
    else:
        # Simple whitespace tokenization
        units = text.split()
        rebuild = ' '.join

    return [
        rebuild(units[start:start + chunk_size])
        for start in range(0, len(units), chunk_size)
    ]

### get the router logits for each token across all layers

In [5]:
def get_router_logits(model, input_text: str, k: int = 1, tokenizer=None):
    """Run one forward pass and report the top-k routed experts per token.

    args :
        model: OlmoeForCausalLM model (must accept ``output_router_logits=True``
            and expose ``router_logits`` on its output)
        input_text: Text string to analyze
        k: Number of top experts to return per token
        tokenizer: Tokenizer used to encode ``input_text``; when None the
            notebook-level ``tokenizer`` variable is used (original behavior)

    output : dictionary mapping layer indices to lists of
             [token_text, expert_index, router_probability]; each token
             contributes k consecutive rows per layer

    raises :
        ValueError: if no tokenizer is passed and no module-level ``tokenizer``
            exists
    """
    if tokenizer is None:
        # Backward compatibility: earlier callers relied on the global tokenizer.
        tokenizer = globals().get('tokenizer')
        if tokenizer is None:
            raise ValueError(
                "No tokenizer supplied and no module-level `tokenizer` is defined"
            )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    # Tokenize input text and move the tensors the model consumes to the device
    encoded = tokenizer(input_text, return_tensors="pt")
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # Inference only: skip autograd bookkeeping so activations are not retained
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_router_logits=True,
            return_dict=True,
        )

    seq_len = input_ids.size(1)

    # Token texts are identical for every layer, so convert and clean them once.
    raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    token_texts = [token.replace('Ġ', '') for token in raw_tokens]

    all_layer_results = {}
    for layer_idx, layer_router_logits in enumerate(outputs.router_logits):
        # Apply softmax to get probabilities
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(seq_len, -1)
        # Get top k probabilities and indices for each token
        top_probs, top_indices = torch.topk(probs, k=k)

        # One bulk transfer to Python objects instead of an .item() per element
        top_probs = top_probs.float().cpu().tolist()
        top_indices = top_indices.cpu().tolist()

        all_layer_results[layer_idx] = [
            [token_text, expert, prob]
            for token_text, experts, token_probs in zip(token_texts, top_indices, top_probs)
            for expert, prob in zip(experts, token_probs)
        ]

    return all_layer_results  # Dictionary mapping layer index to list of [token, expert_number, probability]

### update/create the router logits json file with new tokens

In [53]:
def update_router_logits_json(results, domain, device='cuda'):
    """Append per-layer routing rows to ``<domain>_all_layers.json``.

    args :
        results: Dictionary mapping layer index to list of [token, expert_number, probability]
        domain: String naming the dataset (e.g., 'arxiv', 'github', 'math');
            it selects the file ``f'{domain}_all_layers.json'`` in the
            current working directory
        device: Kept for backward compatibility and ignored. The rows hold
            token strings, which cannot be stored in a tensor, so there is
            nothing to move to a GPU.

    output : dictionary (int layer index -> list of rows) holding the rows
             already on disk followed by the new ones; the same content is
             written back to the JSON file. ``results`` is not modified.

    raises :
        ValueError: if domain is not a non-empty string
    """
    if not isinstance(domain, str) or not domain:
        raise ValueError(f"domain must be a non-empty string, got {domain!r}")
    json_path = f'{domain}_all_layers.json'

    # Start from whatever is already stored for this domain
    existing_results = {}
    if os.path.exists(json_path):
        with open(json_path, 'r', encoding='utf-8') as f:
            try:
                # JSON object keys are strings; convert them back to integers
                existing_results = {int(key): rows for key, rows in json.load(f).items()}
            except json.JSONDecodeError:
                print(f"Warning: {json_path} is empty or corrupted. Starting with an empty dictionary.")

    # Combine existing and new results for each layer. New layers get their own
    # list so the returned dictionary never aliases the caller's lists.
    for layer_idx, layer_tokens in results.items():
        existing_results.setdefault(int(layer_idx), []).extend(layer_tokens)

    # ensure_ascii=False writes raw non-ASCII tokens, so pin the file encoding
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(existing_results, f, indent=4, ensure_ascii=False)

    return existing_results

### plot the expert distribution for a particular layer

In [54]:
def plot_expert_distribution(layer_idx, domain, device='cuda', num_experts=64):
    """Plot the share of routing assignments each expert received in one layer.

    args :
        layer_idx: Index of the layer to plot (looked up as ``str(layer_idx)``
            in the JSON file)
        domain: String naming the dataset; rows are read from
            ``f'{domain}_all_layers.json'`` in the current working directory
        device: Kept for backward compatibility and ignored. Counting is plain
            Python, and the rows hold token strings that cannot be put in a
            tensor.
        num_experts: Number of experts shown on the x-axis (default 64)

    output : plotly Figure with one bar per expert giving the percentage of
             rows routed to it; all bars are 0 when the layer has no rows

    raises :
        FileNotFoundError: if the JSON file for ``domain`` does not exist
        KeyError: if ``layer_idx`` is not present in the file
    """
    json_path = f'{domain}_all_layers.json'

    # Read JSON file
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # Extract layer results: a list of [token, expert, probability] rows
    layer_results = data[str(layer_idx)]

    # Count how many rows went to each expert
    total_assignments = len(layer_results)
    print(f'Total assignments: {total_assignments}')

    expert_counts = defaultdict(int)
    for _, expert, _ in layer_results:
        expert_counts[int(expert)] += 1

    print(f'Expert counts: {expert_counts}')
    print(f'Total experts: {len(expert_counts)}')
    # .get() avoids inserting missing experts into the defaultdict as a side effect
    print('Expert count for l0', expert_counts.get(0, 0))

    # Convert to lists for plotting and calculate percentages
    experts = [f'{i}' for i in range(num_experts)]
    if total_assignments:
        percentages = [
            expert_counts.get(i, 0) / total_assignments * 100
            for i in range(num_experts)
        ]
    else:
        # Nothing recorded for this layer: draw an empty chart instead of dividing by zero
        percentages = [0.0] * num_experts

    # Create bar chart
    fig = go.Figure(data=[
        go.Bar(
            x=experts,
            y=percentages,
            textposition='auto',
            marker_color='red'  # You can use any color here - hex code, RGB, or color name
        )
    ])

    fig.update_layout(
        title=f'percentage of total tokens routed to each expert for layer {layer_idx}',
        xaxis_title='expert',
        yaxis_title='% of total tokens',
        yaxis=dict(range=[0, 100]),  # Set y-axis range from 0 to 100%
        xaxis_tickangle=-45,
        bargap=0.2
    )

    return fig


### expert distribution for a single text input

In [57]:
# NOTE: disabled example that runs the pipeline on a single chunk of one file.
# It is kept for reference only; every line below is commented out.

# # Read and chunk input file
# file_path = 'arxiv.txt'
# domain = 'arxiv'
# chunks = prepare_text_input(file_path, chunk_size=4096, tokenizer=tokenizer)
# print(f'Number of chunks: {len(chunks)}')
# # Process a single chunk (index 3, i.e. the fourth chunk, despite the variable name)
# first_chunk = chunks[3]
# # print(f'Processing text : {first_chunk[:100]}...')  # Print first 100 chars

# # Get router logits for the chunk
# results = get_router_logits(model, first_chunk)

# # Save results for analysis
# update_router_logits_json(results, domain=domain)

# # Choose which layer to visualize
# layer_to_plot = 0  # Analyze first layer

# # Plot expert distribution
# fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
# fig.show()

### expert distribution across all chunks of the input text

In [55]:
# Read and chunk input text file
file_path = 'github.txt'
domain = 'github'
chunks = prepare_text_input(file_path, chunk_size=4096, tokenizer=tokenizer)

# Pick the device once; mixed precision is only meaningful on CUDA
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # Move model to GPU if available
use_amp = device.type == 'cuda'

# NOTE: update_router_logits_json appends to '<domain>_all_layers.json', so
# re-running this cell adds the same rows to the file again.

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')

    # Inference only: no autograd graph, and autocast only where it applies.
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast().
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_amp):
        results = get_router_logits(model, chunk)
    all_results.append(results)

    # Save intermediate results. The rows contain token strings, so keep the
    # JSON update on the CPU path; the 'cuda' path tries to build a tensor
    # from those strings.
    update_router_logits_json(results, domain=domain, device='cpu')

    # Clear GPU cache periodically
    if use_amp and (i + 1) % 10 == 0:  # Every 10 chunks
        torch.cuda.empty_cache()

Processing chunk 1/2
Processing chunk 2/2


In [56]:
layer_to_plot = 0  # Analyze first layer

# Combine results from all chunks: combined_results[layer] is the concatenation
# of every chunk's [token, expert, probability] rows for that layer. The rows
# are plain Python values (str, int, float), so nothing needs to move between
# devices. Always index the chunk dictionary by layer -- extending with the
# dictionary itself would add its keys instead of its rows.
num_layers = len(all_results[0]) if all_results else 0
combined_results = [
    [row for chunk_result in all_results for row in chunk_result[layer_idx]]
    for layer_idx in range(num_layers)
]

# Analyze routing for first few tokens in the selected layer
print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
layer_results = combined_results[layer_to_plot] if combined_results else []
for token, expert, prob in layer_results[:5]:  # Limit to first 5 tokens
    print(f"Token: {token}, Expert: {expert}, Probability: {float(prob):.3f}")

# Plot expert distribution for all processed data. The plot reads the rows
# accumulated in the domain's JSON file, not `combined_results`. The rows
# contain token strings, so use the CPU path; the 'cuda' path tries to build a
# tensor from them.
fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain, device='cpu')
fig.show()

# Clear GPU cache
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Routing for first 5 tokens in layer 0: 
Token: import, Expert: 57, Probability: 0.241
Token: *, Expert: 8, Probability: 0.218
Token: as, Expert: 21, Probability: 0.204
Token: THREE, Expert: 21, Probability: 0.212
Token: from, Expert: 25, Probability: 0.113
Total assignments: 5197
Expert counts: defaultdict(<class 'int'>, {57: 2, 8: 32, 21: 1218, 25: 82, 6: 41, 26: 8, 14: 43, 2: 147, 3: 26, 54: 284, 53: 98, 4: 6, 27: 60, 42: 92, 15: 442, 44: 70, 29: 195, 36: 109, 52: 285, 37: 42, 62: 40, 16: 27, 17: 63, 55: 59, 48: 22, 33: 89, 56: 70, 61: 56, 59: 68, 32: 19, 13: 51, 30: 54, 60: 122, 41: 255, 49: 86, 18: 44, 31: 231, 40: 65, 47: 159, 35: 48, 50: 56, 45: 8, 58: 30, 10: 48, 24: 10, 20: 5, 1: 20, 5: 16, 22: 6, 63: 6, 7: 7, 51: 56, 11: 4, 43: 14, 19: 1})
Total experts: 55
Expert count for l0 0
