In [None]:
import os

os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ["PYTORCH_TRANSFORMERS_SDP_BACKEND"] = "flash"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import json
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go

In [None]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    """Load a causal-LM checkpoint together with its matching tokenizer.

    Args:
        model_name: Identifier handed unchanged to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple. No device placement happens here; the
        caller decides where the model should live.
    """
    causal_lm = AutoModelForCausalLM.from_pretrained(model_name)
    return causal_lm, AutoTokenizer.from_pretrained(model_name)

# Load the default checkpoint once. Later cells read these notebook-level names:
# `get_router_logits` uses the global `tokenizer`, and the swap/processing cells use `model`.
model, tokenizer = load_model()

In [None]:
def print_expert_weights(model, layer_idx, expert_idx):
    """
    Print the weights of a specific expert MLP at a given layer.

    Args:
        model: The OLMoE model
        layer_idx: Index of the layer containing the expert
        expert_idx: Index of the expert within the layer

    Raises:
        KeyError: If the model's state dict has no
            ``model.layers.<layer>.mlp.experts.<expert>.<proj>.weight`` entry.
    """
    # Take a single snapshot: every state_dict() call builds a fresh mapping
    # over the whole model, so calling it once per projection is wasted work.
    state = model.state_dict()
    prefix = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}'

    sections = (
        ("Gate Projection", "gate_proj"),
        ("Up Projection", "up_proj"),
        ("Down Projection", "down_proj"),
    )
    for title, proj_name in sections:
        print(f"\n{title}:")
        print(state[f'{prefix}.{proj_name}.weight'])


In [None]:
def prepare_text_input(file_path, chunk_size=1000, tokenizer=None):
    """
    Read a UTF-8 text file and split it into consecutive chunks.

    args :
        file_path (str): Path to the input text file
        chunk_size (int): Number of tokens (or whitespace-separated words when
            no tokenizer is given) per chunk; must be >= 1
        tokenizer: Object exposing ``encode(text)`` and ``decode(ids)``
            (e.g. a HuggingFace tokenizer). If None, the text is split on whitespace.

    output : List of text chunks. The last chunk may be shorter; an empty file
        yields an empty list.

    raises :
        ValueError: if chunk_size is smaller than 1.

    NOTE(review): in tokenizer mode each chunk is decoded back to text, so
    re-tokenizing a chunk later is not guaranteed to give exactly chunk_size
    tokens, and whether ``encode`` adds special tokens depends on the
    tokenizer's configuration — confirm for the tokenizer in use.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # `is not None` rather than truthiness: tokenizer objects may define
    # __len__, and an explicitly supplied tokenizer must never be skipped.
    if tokenizer is not None:
        token_ids = tokenizer.encode(text)
        # Slice the id list directly; a tensor round-trip adds nothing here.
        return [
            tokenizer.decode(token_ids[start:start + chunk_size])
            for start in range(0, len(token_ids), chunk_size)
        ]

    # Simple whitespace tokenization
    words = text.split()
    return [
        ' '.join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

In [None]:
def get_router_logits(model, input_text: str, k: int = 1, tokenizer=None):
    """
    Run one forward pass and report the top-k routed experts for every token.

    args :
        model: OlmoeForCausalLM model. It is used on whatever device it
            already lives on; this function never moves it.
        input_text: Text string to analyze
        k: Number of top experts to return per token
        tokenizer: Tokenizer used to encode ``input_text`` and to map ids back
            to token strings. When None, the notebook-level ``tokenizer`` is used.

    output : dictionary mapping layer indices to lists of
        [token_text, expert_index, router_probability]. Each token contributes
        k consecutive rows, highest probability first.

    raises :
        NameError: if no tokenizer is passed and no module-level ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward compatible with callers that rely on the notebook global.
        tokenizer = globals().get('tokenizer')
        if tokenizer is None:
            raise NameError("no tokenizer was passed and no module-level 'tokenizer' is defined")

    # Follow the model's current placement instead of forcing it to CPU, so a
    # model the caller moved to the GPU stays there across calls.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    encoded = tokenizer(input_text, return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    input_ids = encoded['input_ids']
    seq_len = input_ids.size(1)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=encoded['attention_mask'],
            output_router_logits=True,
            return_dict=True,
        )

    # Token strings are the same for every layer, so convert and clean them once.
    tokens = [
        token.replace('Ġ', '')
        for token in tokenizer.convert_ids_to_tokens(input_ids[0].tolist())
    ]

    all_layer_results = {}
    for layer_idx, layer_router_logits in enumerate(outputs.router_logits):
        # Softmax in float32 so half-precision logits do not coarsen the probabilities.
        probs = torch.nn.functional.softmax(layer_router_logits.detach().float(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(seq_len, -1)
        top_probs, top_indices = torch.topk(probs, k=k)

        # One bulk transfer per layer instead of an .item() call per value.
        top_probs = top_probs.cpu().tolist()
        top_indices = top_indices.cpu().tolist()

        all_layer_results[layer_idx] = [
            [token, expert, prob]
            for token, experts_row, probs_row in zip(tokens, top_indices, top_probs)
            for expert, prob in zip(experts_row, probs_row)
        ]

    return all_layer_results  # Dictionary mapping layer index to list of [token, expert_number, probability]

In [None]:
def update_router_logits_json(results, domain, device='cpu'):
    """
    Append new routing rows to the per-domain ``<domain>_all_layers.json`` file.

    args :
        results: Dictionary mapping layer index to list of [token, expert_number, probability]
        domain: One of 'arxiv', 'github', 'math', 'physics', 'biology', 'legal', 'swap'
        device: Accepted for backward compatibility and ignored. Merging JSON
            rows is list work, and the rows contain token strings, which
            cannot be packed into a tensor.
    output : the merged dictionary (integer layer keys) that was written to disk.
        ``results`` itself is not modified.

    raises :
        ValueError: if ``domain`` is not one of the known domains.
    """
    known_domains = ('arxiv', 'github', 'math', 'physics', 'biology', 'legal', 'swap')
    if domain not in known_domains:
        # Fail with a clear message instead of an unbound `json_path` later on.
        raise ValueError(f"Unknown domain {domain!r}; expected one of {known_domains}")
    json_path = f'{domain}_all_layers.json'

    existing_results = {}
    if os.path.exists(json_path):
        # Explicit UTF-8: the file is written with ensure_ascii=False, so it
        # may hold non-ASCII token text the platform default cannot represent.
        with open(json_path, 'r', encoding='utf-8') as f:
            try:
                # JSON object keys are always strings; restore integer layer indices.
                existing_results = {int(k): v for k, v in json.load(f).items()}
            except json.JSONDecodeError:
                print(f"Warning: {json_path} is empty or corrupted. Starting with an empty dictionary.")

    # Extend into lists owned by `existing_results` so the caller's `results`
    # lists are never aliased or mutated.
    for layer_idx, layer_tokens in results.items():
        existing_results.setdefault(layer_idx, []).extend(layer_tokens)

    # Write to a sibling temp file and swap it in, so an interrupted write
    # cannot truncate the accumulated results.
    tmp_path = json_path + '.tmp'
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(existing_results, f, indent=4, ensure_ascii=False)
    os.replace(tmp_path, json_path)

    return existing_results

In [None]:
def plot_expert_distribution(layer_idx, domain, device='cpu', num_experts=64):
    """
    Build a bar chart of the share of routed tokens per expert for one layer.

    The rows are read from ``<domain>_all_layers.json`` in the working directory.

    args :
        layer_idx: Layer whose rows are plotted (looked up as a string key in the JSON)
        domain: One of 'arxiv', 'github', 'math', 'physics', 'biology', 'legal', 'swap'
        device: Accepted for backward compatibility and ignored. The rows hold
            token strings, so they cannot be turned into a tensor, and counting
            them is plain Python work.
        num_experts: Number of experts shown on the x axis (default 64)
    output : plotly Figure with the expert distribution for that layer

    raises :
        ValueError: if ``domain`` is not one of the known domains.
        KeyError: if the JSON file has no entry for ``layer_idx``.
    """
    known_domains = ('arxiv', 'github', 'math', 'physics', 'biology', 'legal', 'swap')
    if domain not in known_domains:
        # Fail with a clear message instead of an unbound `json_path` later on.
        raise ValueError(f"Unknown domain {domain!r}; expected one of {known_domains}")
    json_path = f'{domain}_all_layers.json'

    with open(json_path, 'r', encoding='utf-8') as file:
        data = json.load(file)

    # JSON object keys are strings, hence str(layer_idx).
    layer_results = data[str(layer_idx)]

    total_assignments = len(layer_results)
    print(f'Total assignments: {total_assignments}')

    # Count how many rows were routed to each expert.
    expert_counts = defaultdict(int)
    for _, expert, _ in layer_results:
        expert_counts[int(expert)] += 1

    print(f'Expert counts: {expert_counts}')
    print(f'Total experts: {len(expert_counts)}')
    print('Expert count for l0', expert_counts.get(0, 0))

    experts = [str(i) for i in range(num_experts)]
    if total_assignments:
        percentages = [
            expert_counts.get(i, 0) / total_assignments * 100
            for i in range(num_experts)
        ]
    else:
        # No rows for this layer: draw an empty chart instead of dividing by zero.
        percentages = [0.0] * num_experts

    fig = go.Figure(data=[
        go.Bar(
            x=experts,
            y=percentages,
            textposition='auto',
            marker_color='red'  # any plotly color works: hex code, RGB, or color name
        )
    ])

    fig.update_layout(
        title=f'percentage of total tokens routed to each expert for layer {layer_idx}',
        xaxis_title='expert',
        yaxis_title='% of total tokens',
        yaxis=dict(range=[0, 100]),  # fixed 0-100% scale so layers are comparable
        xaxis_tickangle=-45,
        bargap=0.2
    )

    return fig


In [None]:
def swap_experts(model, expert_idx, target_layer_idx, source_layer_idx=0, source_expert_idx=0):
    """
    Exchange the MLP weights of two experts, possibly in different layers.

    The gate, up and down projection ``weight`` parameters of the two experts
    are exchanged in place on ``model``. Because it is an exchange, calling the
    function again with the same arguments restores the original assignment.

    Args:
        model: The OLMoE model (decoder layers are read from ``model.model.layers``)
        expert_idx: Index of the expert in target layer to swap with
        target_layer_idx: Index of the layer containing the expert to swap with
        source_layer_idx: Index of the source layer (default 0)
        source_expert_idx: Index of the source expert (default 0)

    Returns:
        Dict describing the swap:
        ``{'swapped_experts': {'source': {'layer', 'expert'}, 'target': {'layer', 'expert'}}}``

    Raises:
        ValueError: If a layer index is >= the number of layers, or an expert
            index is >= the number of experts in its own layer.
    """
    decoder_layers = model.model.layers

    num_layers = len(decoder_layers)
    if target_layer_idx >= num_layers or source_layer_idx >= num_layers:
        raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")

    source_experts = decoder_layers[source_layer_idx].mlp.experts
    target_experts = decoder_layers[target_layer_idx].mlp.experts

    # Check each index against its own layer, so a target layer with fewer
    # experts than the source is reported here rather than as an IndexError.
    if source_expert_idx >= len(source_experts):
        raise ValueError(
            f"Expert index out of range. Layer {source_layer_idx} has {len(source_experts)} experts."
        )
    if expert_idx >= len(target_experts):
        raise ValueError(
            f"Expert index out of range. Layer {target_layer_idx} has {len(target_experts)} experts."
        )

    source_expert = source_experts[source_expert_idx]
    target_expert = target_experts[expert_idx]

    # Exchange the Parameter objects themselves; no tensor data is copied.
    for proj_name in ('gate_proj', 'up_proj', 'down_proj'):
        source_proj = getattr(source_expert, proj_name)
        target_proj = getattr(target_expert, proj_name)
        source_proj.weight, target_proj.weight = target_proj.weight, source_proj.weight

    return {
        'swapped_experts': {
            'source': {
                'layer': source_layer_idx,
                'expert': source_expert_idx
            },
            'target': {
                'layer': target_layer_idx,
                'expert': expert_idx
            }
        }
    }

In [None]:
# Expert indices to exchange, paired position by position.
# NOTE(review): the trailing "layer 15" remarks do not match the layer indices
# used below (0 and 7) — confirm which layers are intended.
top_experts_list = [0, 36, 38, 41, 50, 59, 22, 10] # layer 15
bottom_experts_list = [1, 2, 3, 6, 7, 8, 9, 18] # layer 15

top_layer_idx = 0
bottom_layer_idx = 7

# swap_experts exchanges weights in place, so running this cell a second time
# on the same model swaps every pair back.
for top_expert, bottom_expert in zip(top_experts_list, bottom_experts_list):
    swap_experts(
        model,
        expert_idx=top_expert,
        target_layer_idx=top_layer_idx,
        source_layer_idx=bottom_layer_idx,
        source_expert_idx=bottom_expert,
    )
    print(f"Swapped experts {top_expert} and {bottom_expert} in layers {top_layer_idx} and {bottom_layer_idx}")

In [None]:
# Read the domain corpus and split it into ~1024-token text chunks.
file_path = 'data/physics_arxiv_200k.txt'
domain = 'physics'
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Use the GPU when one is present.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
use_cuda = device.type == 'cuda'

# Process all chunks, keeping the per-chunk results in memory for the next cell.
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')

    # no_grad: this is pure inference, so no autograd graph is needed.
    # torch.autocast replaces the deprecated torch.cuda.amp.autocast and is
    # enabled only on CUDA, so CPU-only hosts run in full precision without warnings.
    with torch.no_grad(), torch.autocast(device_type=device.type, enabled=use_cuda):
        results = get_router_logits(model, chunk)
    all_results.append(results)

    # Persist after every chunk so an interrupted run keeps what was processed.
    # Each call rereads and rewrites the whole domain JSON file.
    update_router_logits_json(results, domain=domain)

    # Release cached GPU blocks every 10 chunks (only meaningful on CUDA).
    if use_cuda and (i + 1) % 10 == 0:
        torch.cuda.empty_cache()

In [None]:
domain = 'physics'

# One entry per router layer, taken from the first chunk's result instead of a
# hard-coded 16, so the loops below always match what was actually collected.
num_layers = len(all_results[0]) if all_results else 0

# Combine results from all chunks: combined_results[layer] is the concatenation
# of that layer's [token, expert, probability] rows over every chunk.
combined_results = []
for layer_idx in range(num_layers):
    layer_combined = []
    for chunk_result in all_results:
        layer_rows = chunk_result.get(layer_idx)
        # Extend with the layer's rows — not with the chunk dict itself, which
        # would add its integer keys instead of routing rows.
        if isinstance(layer_rows, (list, tuple)):
            layer_combined.extend(layer_rows)
    combined_results.append(layer_combined)

# Make sure the output directory exists before any figure is written.
plot_dir = f'plots/{domain}'
os.makedirs(plot_dir, exist_ok=True)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
    layer_results = combined_results[layer_to_plot]

    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue

    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip anything that is not a [token, expert, probability] row
        if not isinstance(token_info, (tuple, list)):
            continue

        token, expert, prob = token_info
        # The rows hold plain Python numbers, so no device transfer is needed.
        print(f"Token: {token}, Expert: {expert}, Probability: {float(prob):.3f}")

    # The plot is built from the domain's JSON file on disk (everything saved
    # so far for this domain), not from `combined_results`.
    fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'{plot_dir}/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'{plot_dir}/{domain}_layer{layer_to_plot}_expert_dist.png')