In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json
import os
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go
import numpy as np

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    """Load a causal language model and its matching tokenizer.

    args :
        model_name (str): model id or path handed to ``from_pretrained``

    output : (model, tokenizer) tuple
    """
    # Tuple elements are evaluated left to right, so the model is still
    # loaded before the tokenizer.
    return (
        AutoModelForCausalLM.from_pretrained(model_name),
        AutoTokenizer.from_pretrained(model_name),
    )

# Load the default checkpoint once; later cells use `model` and `tokenizer` as notebook-level globals.
model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### split the text file into chunks of tokens that fit the model's context length

In [3]:
def prepare_text_input(file_path, chunk_size=1000, tokenizer=None):
    """Read a UTF-8 text file and split it into chunks of about ``chunk_size`` tokens.

    args :
        file_path (str): Path to the input text file
        chunk_size (int): Number of tokens (or whitespace-separated words) per chunk; must be >= 1
        tokenizer: object exposing ``encode(text)`` and ``decode(ids)``, e.g. a HuggingFace
            tokenizer (if None, will split on whitespace)

    output : List of text chunks. With a tokenizer, each chunk is the decoded text of
        ``chunk_size`` consecutive token ids (the last chunk may be shorter). Without one,
        each chunk is ``chunk_size`` words joined by single spaces.

    raises : ValueError if ``chunk_size`` is smaller than 1
    """
    # range() rejects a zero step with an unhelpful message and silently yields
    # nothing for a negative one, so validate up front.
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size!r}")

    # Read the full text file
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Compare against None explicitly: tokenizers may define __len__, so a plain
    # truthiness test would depend on the vocabulary size.
    if tokenizer is not None:
        token_ids = tokenizer.encode(text)
        # Slice the id list directly; no tensor round trip is needed for decoding.
        return [
            tokenizer.decode(token_ids[start:start + chunk_size])
            for start in range(0, len(token_ids), chunk_size)
        ]

    # Simple whitespace tokenization
    words = text.split()
    return [
        ' '.join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

### get the top-k router probabilities for each token across all layers

In [4]:
def get_router_logits(model, input_text: str, k: int = 1):
    """Run one forward pass and report the top-k routed experts for every token.

    Note: tokenization uses the notebook-level ``tokenizer`` global, not an argument.
    The model is moved to the CPU as a side effect.

    args :
        model: OlmoeForCausalLM model
        input_text: Text string to analyze
        k: Number of top experts to return per token

    output : dictionary mapping layer indices to lists of
        [token_text, expert_index, router_probability]; each token contributes
        ``k`` consecutive rows per layer, ordered by decreasing probability
    """
    device = "cpu"
    model = model.to(device)

    # Tokenize input text (a single sequence, so batch_size == 1)
    encoded = tokenizer(input_text, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    input_ids = encoded['input_ids']

    # Inference only: disable autograd so no graph is kept alive for the pass.
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=encoded['attention_mask'],
            output_router_logits=True,
            return_dict=True,
        )

    # The token strings are the same for every layer, so build them once.
    seq_len = input_ids.size(1)
    raw_tokens = tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    # Strip the 'Ġ' marker from the token text
    clean_tokens = [token.replace('Ġ', '') for token in raw_tokens]

    all_layer_results = {}
    for layer_idx, layer_router_logits in enumerate(outputs.router_logits):
        # Softmax over the expert dimension gives routing probabilities
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(seq_len, -1)
        top_probs, top_indices = torch.topk(probs, k=k)

        # Plain Python numbers keep the result JSON-serializable
        top_probs = top_probs.cpu().tolist()
        top_indices = top_indices.cpu().tolist()

        all_layer_results[layer_idx] = [
            [clean_tokens[i], top_indices[i][j], top_probs[i][j]]
            for i in range(len(clean_tokens))
            for j in range(k)
        ]

    return all_layer_results  # Dictionary mapping layer index to list of [token, expert_number, probability]

### update/create the router logits json file with new tokens

In [5]:
def update_router_logits_json(results, domain, device='cpu'):
    """Append per-layer routing rows to ``{domain}_all_layers.json`` (created if missing).

    args :
        results: Dictionary mapping layer index to list of [token, expert_number, probability]
        domain: String indicating the domain (e.g., 'arxiv', 'code'); used as the file-name prefix
        device: Kept for backward compatibility and ignored. The rows contain token
            strings, which cannot be turned into a tensor, so merging is done in plain Python.
    output : dictionary {layer index (int): combined rows}, which is also what is written
        to the JSON file. The lists in ``results`` are copied, not aliased.
    """

    json_path = f'{domain}_all_layers.json'

    # Start from whatever is already on disk, if anything
    existing_results = {}
    if os.path.exists(json_path):
        # The file is written with ensure_ascii=False, so it must be read back as
        # UTF-8 rather than in the platform default encoding.
        with open(json_path, 'r', encoding='utf-8') as f:
            try:
                # JSON object keys are always strings; convert them back to integers
                existing_results = {int(k): v for k, v in json.load(f).items()}
            except json.JSONDecodeError:
                print(f"Warning: {json_path} is empty or corrupted. Starting with an empty dictionary.")

    # Combine existing and new results for each layer. Keys are normalized to int so
    # that e.g. 0 and "0" cannot end up as two separate entries in the output file.
    for layer_idx, layer_tokens in results.items():
        existing_results.setdefault(int(layer_idx), []).extend(layer_tokens)

    # Save updated results (json.dump turns the integer keys into strings)
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(existing_results, f, indent=4, ensure_ascii=False)

    return existing_results

### expert distribution bar graphs for a particular layer

In [6]:
def plot_expert_distribution_bar_graph(layer_idx, domain, device='cpu', num_experts=64):
    """Bar chart of the share of routing assignments each expert received in one layer.

    args :
        layer_idx: Layer whose rows are read from ``{domain}_all_layers.json``
        domain: String indicating the domain; used as the JSON file-name prefix
        device: Kept for backward compatibility and ignored. The rows contain token
            strings, which cannot be turned into a tensor, so counting is done in plain Python.
        num_experts: Number of experts shown on the x axis (default 64)
    output : plotly Figure with one bar per expert; bar height is the percentage of
        the layer's rows assigned to that expert (all zeros if the layer has no rows)
    """
    json_path = f'{domain}_all_layers.json'

    # Read JSON file (UTF-8, since tokens may contain non-ASCII characters)
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # JSON object keys are strings, hence str(layer_idx)
    layer_results = data[str(layer_idx)]

    total_assignments = len(layer_results)
    print(f'Total assignments: {total_assignments}')

    # Count how many rows went to each expert
    expert_counts = defaultdict(int)
    for _, expert, _ in layer_results:
        expert_counts[int(expert)] += 1

    print(f'Expert counts: {expert_counts}')
    print(f'Total experts: {len(expert_counts)}')
    print(f'Expert count for l0', expert_counts.get(0, 0))

    # Convert to lists for plotting and calculate percentages.
    # An empty layer yields all-zero bars instead of a ZeroDivisionError.
    experts = [f'{i}' for i in range(num_experts)]
    if total_assignments:
        percentages = [
            expert_counts.get(i, 0) / total_assignments * 100
            for i in range(num_experts)
        ]
    else:
        percentages = [0.0] * num_experts

    # Create bar chart
    fig = go.Figure(data=[
        go.Bar(
            x=experts,
            y=percentages,
            textposition='auto',
            marker_color='red'  # You can use any color here - hex code, RGB, or color name
        )
    ])

    fig.update_layout(
        title=f'percentage of total tokens routed to each expert for layer {layer_idx}',
        xaxis_title='expert',
        yaxis_title='% of total tokens',
        yaxis=dict(range=[0, 100]), # Set y-axis range from 0 to 100%
        xaxis_tickangle=-45,
        bargap=0.2
    )

    return fig

### expert distribution heat map

In [None]:
def plot_expert_distribution_heatmap(domain, device='cpu', num_layers=16, num_experts=64):
    """Heatmap of the share of routing assignments per expert, one row per layer.

    args :
        domain: String indicating the domain; rows are read from ``{domain}_all_layers.json``
        device: Kept for backward compatibility and ignored. The rows contain token
            strings, which cannot be turned into a tensor, so counting is done in plain Python.
        num_layers: Number of layers (heatmap rows), default 16
        num_experts: Number of experts (heatmap columns), default 64
    output : plotly Figure whose z matrix has shape (num_layers, num_experts); each cell
        is the percentage of that layer's rows assigned to that expert. Layers that are
        missing from the file or have no rows stay at zero.
    """
    json_path = f'{domain}_all_layers.json'

    # Read JSON file (UTF-8, since tokens may contain non-ASCII characters)
    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # One row per layer, one column per expert
    expert_matrix = np.zeros((num_layers, num_experts))

    for layer in range(num_layers):
        # JSON object keys are strings
        layer_results = data.get(str(layer))
        if not layer_results:
            # Missing or empty layer: leave the row at zero (avoids dividing by zero)
            continue

        total_assignments = len(layer_results)

        # Count expert assignments for this layer; indices outside the plotted
        # range still count towards the total but have no column of their own.
        for _, expert, _ in layer_results:
            expert = int(expert)
            if 0 <= expert < num_experts:
                expert_matrix[layer, expert] += 1

        # Turn counts into percentages of this layer's rows
        expert_matrix[layer] = expert_matrix[layer] / total_assignments * 100

    # Create and return a single heatmap
    fig = go.Figure(data=go.Heatmap(
        z=expert_matrix,
        x=[str(i) for i in range(num_experts)],
        y=[str(i) for i in range(num_layers)],
        colorscale='Reds'
    ))

    fig.update_layout(
        title='Distribution of Tokens Across Experts and Layers',
        xaxis_title='Expert Index',
        yaxis_title='Layer',
        width=800,
        height=800,
        xaxis=dict(
            tickangle=-45,
            constrain='domain'
        ),
        yaxis=dict(
            scaleanchor='x',
            scaleratio=1
        )
    )

    return fig

### expert distribution for all text input and plotting

In [11]:
# Read the input text file and split it into chunks of at most 1024 tokens each
file_path = 'data/github.txt'
domain = 'github'  # also the prefix of the output file: {domain}_all_layers.json
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Use CPU device
device = 'cpu'
model = model.to(device)  # Move model to CPU

# Process all chunks, keeping the per-chunk results in memory as well as on disk.
# NOTE: update_router_logits_json appends to an existing {domain}_all_layers.json,
# so re-running this cell without deleting that file records the same tokens again.
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    
    # Top-1 expert per token for every layer (k defaults to 1)
    results = get_router_logits(model, chunk)
    all_results.append(results)
    
    # Save intermediate results after every chunk
    update_router_logits_json(results, domain=domain)

Processing chunk 1/3
Processing chunk 2/3
Processing chunk 3/3


In [None]:
# Combine results from all chunks: combined_results[layer] is one flat list of
# [token, expert, prob] rows covering every chunk.
combined_results = []
if all_results:
    for layer_idx in range(len(all_results[0])):  # For each layer
        layer_combined = []
        for chunk_result in all_results:
            layer_rows = chunk_result[layer_idx]
            # Skip anything that is not a list of rows
            if not isinstance(layer_rows, (list, tuple)):
                continue
            # Extend with this layer's rows. Extending with chunk_result itself
            # would add the dict's keys (layer indices) instead of the rows.
            layer_combined.extend(layer_rows)
        combined_results.append(layer_combined)

# Derived from the data so the loop below can never index past combined_results
# (OLMoE has 16 layers).
num_layers = len(combined_results)

# Output directory follows the domain and is created if missing
plot_dir = f'plots/{domain}'
os.makedirs(plot_dir, exist_ok=True)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    layer_results = combined_results[layer_to_plot]

    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue

    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip if token_info is not a tuple/list
        if not isinstance(token_info, (tuple, list)):
            continue

        # Rows hold plain Python values, so they can be printed directly
        token, expert, prob = token_info
        print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

    # Plot expert distribution for all processed data. The plot is built from the
    # JSON file on disk, so no autocast / GPU context is involved.
    fig = plot_expert_distribution_bar_graph(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'{plot_dir}/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'{plot_dir}/{domain}_layer{layer_to_plot}_expert_dist.png')


In [None]:
domain = 'github'

# Create the output directory up front so saving does not fail on a fresh checkout
os.makedirs('plots', exist_ok=True)

# Plot expert distribution across all layers from {domain}_all_layers.json
fig = plot_expert_distribution_heatmap(domain=domain)
fig.show()

# Save plot as HTML and image
fig.write_html(f'plots/{domain}_expert_dist.html')
fig.write_image(f'plots/{domain}_expert_dist.png')