In [67]:
from transformers import OlmoeForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
import json
import os
import numpy as np
from scipy.special import kl_div
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots

### Inference with OLMoE

In [68]:
# Model and tokenizer come from the same Hugging Face checkpoint.
MODEL_ID = "allenai/OLMoE-1B-7B-0924"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = OlmoeForCausalLM.from_pretrained(MODEL_ID)

# Switch to inference mode (e.g. dropout off). eval() returns the model itself,
# so leaving it as the last expression displays the architecture below.
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

OlmoeForCausalLM(
  (model): OlmoeModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoeDecoderLayer(
        (self_attn): OlmoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): OlmoeRMSNorm((2048,), eps=1e-05)
          (k_norm): OlmoeRMSNorm((2048,), eps=1e-05)
        )
        (mlp): OlmoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=64, bias=False)
          (experts): ModuleList(
            (0-63): 64 x OlmoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (down_proj): Linear(in

### Router logits and expert probabilities per token

In [69]:
def analyze_expert_routing(input_text, model, tokenizer, k=8, layer_idx=-1, verbose=True):
    """Inspect how a mixture-of-experts router distributes one input's tokens over experts.

    Runs a single forward pass with ``output_router_logits=True``. It prints each
    token's router logits and softmax probabilities, then each token's top-``k``
    experts. The per-token probabilities are written to the first free
    ``expert_routing_analysis_<n>.json`` in the current working directory.

    Args:
        input_text: Text to tokenize and route, treated as a single sequence.
        model: Model whose forward accepts ``output_router_logits=True`` and
            returns an object with an indexable ``router_logits`` attribute.
            If it exposes ``.device``, the inputs are moved there.
        tokenizer: Tokenizer used to encode ``input_text`` and decode single tokens.
        k: Number of top experts to report per token. It is clamped to the
            number of experts. Defaults to 8.
        layer_idx: Index into ``outputs.router_logits`` selecting the layer to
            analyze. Defaults to -1 (the last entry).
        verbose: When False, skip the full per-expert logit/probability dump.
            The tokenization and top-k summaries are still printed.

    Returns:
        ``(router_logits, analysis_results)``:
        - ``router_logits`` is the selected layer's router-logit tensor.
        - ``analysis_results`` is the dict that was saved to JSON. It has keys
          ``"input_text"`` and ``"tokens"``, a list of
          ``{"token", "id", "router_probability"}``.
    """
    inputs = tokenizer(input_text, return_tensors="pt")
    # Feed the model on its own device; CPU inputs against a GPU-resident model
    # would otherwise fail inside the forward pass.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    input_ids = inputs.input_ids[0]
    # Decode every token once and reuse the strings in all reports below.
    token_strs = [tokenizer.decode([token_id]) for token_id in input_ids]

    print("Tokenized input:")
    for token, token_id in zip(token_strs, input_ids):
        print(f"Token: '{token}', ID: {token_id.item()}")

    print(f"\nInput shape: {inputs.input_ids.shape}")

    with torch.no_grad():
        outputs = model(**inputs, output_router_logits=True)

    # NOTE(review): with a single input (batch size 1) the selected entry is
    # assumed to be (seq_len, num_experts), so row i belongs to token i.
    # Confirm this before using the function on batched inputs.
    router_logits = outputs.router_logits[layer_idx]
    # Softmax over the expert axis, computed once for all tokens.
    probabilities = F.softmax(router_logits, dim=-1)

    analysis_results = {
        "input_text": input_text,
        "tokens": []
    }

    if verbose:
        print("\nRouter logits and probabilities for each token:")
    for token_idx, (token, token_id) in enumerate(zip(token_strs, input_ids)):
        logits = router_logits[token_idx]
        token_probs = probabilities[token_idx]

        if verbose:
            print(f"Token: '{token}' (ID: {token_id.item()})")
            for expert_idx, (logit, prob) in enumerate(zip(logits, token_probs)):
                print(f"  expert {expert_idx}: logit = {logit.item():.4f}, post-softmax = {prob.item():.4f}")
            print()

        analysis_results["tokens"].append({
            "token": token,
            "id": token_id.item(),
            "router_probability": token_probs.tolist()
        })

    # Clamp so models with fewer than k experts don't make torch.topk raise.
    top_k = min(k, probabilities.shape[-1])
    top_k_probs, top_k_indices = torch.topk(probabilities, top_k, dim=-1)
    print(f"\nTop {top_k} experts for each token:")
    for token_idx, (token, token_id) in enumerate(zip(token_strs, input_ids)):
        print(f"Token: '{token}' (ID: {token_id.item()})")
        ranked = zip(top_k_probs[token_idx], top_k_indices[token_idx])
        for rank, (prob, idx) in enumerate(ranked, start=1):
            print(f"  {rank}. expert {idx.item()}: probability = {prob.item():.4f}")
        print()

    # Pick the first unused numbered filename so earlier runs are not overwritten.
    base_filename = "expert_routing_analysis"
    counter = 1
    filename = f"{base_filename}_{counter}.json"
    while os.path.exists(filename):
        counter += 1
        filename = f"{base_filename}_{counter}.json"

    with open(filename, "w") as f:
        json.dump(analysis_results, f, indent=2)

    print(f"Analysis results saved to {filename}")

    return router_logits, analysis_results

In [70]:
# Sentence 1: "pitch" used in its musical sense.
input_text = "What's the musical pitch of that note?"
last_layer_router_logits, analysis_results = analyze_expert_routing(
    input_text,
    model=model,
    tokenizer=tokenizer,
)

Tokenized input:
Token: 'What', ID: 1276
Token: ''s', ID: 434
Token: ' the', ID: 253
Token: ' musical', ID: 12256
Token: ' pitch', ID: 11288
Token: ' of', ID: 273
Token: ' that', ID: 326
Token: ' note', ID: 3877
Token: '?', ID: 32

Input shape: torch.Size([1, 9])

Router logits and probabilities for each token:
Token: 'What' (ID: 1276)
  expert 0: logit = -0.8721, post-softmax = 0.0151
  expert 1: logit = -1.5355, post-softmax = 0.0078
  expert 2: logit = -1.5983, post-softmax = 0.0073
  expert 3: logit = -2.1905, post-softmax = 0.0040
  expert 4: logit = -0.3444, post-softmax = 0.0256
  expert 5: logit = -0.5635, post-softmax = 0.0206
  expert 6: logit = -2.0390, post-softmax = 0.0047
  expert 7: logit = -2.2326, post-softmax = 0.0039
  expert 8: logit = -1.6298, post-softmax = 0.0071
  expert 9: logit = -2.6304, post-softmax = 0.0026
  expert 10: logit = -1.9941, post-softmax = 0.0049
  expert 11: logit = -0.3353, post-softmax = 0.0258
  expert 12: logit = 2.0358, post-softmax = 0.27

In [71]:
# Sentence 2: the same word "pitch", used as a verb (presenting an idea).
input_text_2 = "How will you pitch your idea to the investors?"
last_layer_router_logits_2, analysis_results_2 = analyze_expert_routing(
    input_text_2,
    model=model,
    tokenizer=tokenizer,
)


Tokenized input:
Token: 'How', ID: 2347
Token: ' will', ID: 588
Token: ' you', ID: 368
Token: ' pitch', ID: 11288
Token: ' your', ID: 634
Token: ' idea', ID: 2934
Token: ' to', ID: 281
Token: ' the', ID: 253
Token: ' investors', ID: 12946
Token: '?', ID: 32

Input shape: torch.Size([1, 10])

Router logits and probabilities for each token:
Token: 'How' (ID: 2347)
  expert 0: logit = -0.7544, post-softmax = 0.0205
  expert 1: logit = -1.9735, post-softmax = 0.0060
  expert 2: logit = -2.1020, post-softmax = 0.0053
  expert 3: logit = -1.8291, post-softmax = 0.0070
  expert 4: logit = -0.6528, post-softmax = 0.0227
  expert 5: logit = -0.7290, post-softmax = 0.0210
  expert 6: logit = -1.4088, post-softmax = 0.0106
  expert 7: logit = -2.9429, post-softmax = 0.0023
  expert 8: logit = -1.9219, post-softmax = 0.0064
  expert 9: logit = -1.5357, post-softmax = 0.0094
  expert 10: logit = -2.0978, post-softmax = 0.0053
  expert 11: logit = -0.8646, post-softmax = 0.0183
  expert 12: logit = 

### KL divergence between the router distributions of a shared token

In [72]:


def calculate_token_kl_divergence(json_file1, json_file2, target_token):
    """Compare one token's router distribution across two saved routing analyses.

    Loads two JSON files shaped like ``analyze_expert_routing`` output. In each,
    it takes the first token whose decoded string or token ID equals
    ``target_token``. It then computes KL(P1 || P2), where P1 is that token's
    router probability vector from ``json_file1`` and P2 is the one from
    ``json_file2``. The result is also written to the first free
    ``kl_divergence_analysis_<n>.json`` in the current working directory.

    Args:
        json_file1: Path to the analysis JSON that supplies P1.
        json_file2: Path to the analysis JSON that supplies P2.
        target_token: Either a decoded token string, matched exactly including
            any leading space (e.g. ``" pitch"``), or an integer token ID.

    Returns:
        A dict with these keys:
        - ``"token"``: the ``target_token`` that was passed in.
        - ``"expert_kl_divergences"``: maps ``"Expert_<i>"`` to the elementwise
          ``scipy.special.kl_div`` term p*log(p/q) - p + q. Each term is
          non-negative, and the terms sum to KL(P1 || P2) because both vectors
          are normalized.
        - ``"total_kl_divergence"``: that sum, as a Python float.

    Raises:
        ValueError: If the token is missing from either file, or if the two
            probability vectors have different lengths.
    """
    with open(json_file1, 'r') as f1, open(json_file2, 'r') as f2:
        data1 = json.load(f1)
        data2 = json.load(f2)

    def _find_token(data):
        # First match wins if the token occurs more than once in the input.
        return next(
            (t for t in data['tokens'] if t['token'] == target_token or t['id'] == target_token),
            None,
        )

    token1 = _find_token(data1)
    token2 = _find_token(data2)
    if token1 is None or token2 is None:
        raise ValueError(f"Token '{target_token}' not found in one or both inputs")

    probs1 = np.asarray(token1['router_probability'], dtype=float)
    probs2 = np.asarray(token2['router_probability'], dtype=float)
    # Without this check kl_div would broadcast or raise an opaque shape error.
    if probs1.shape != probs2.shape:
        raise ValueError(
            f"Router probability vectors differ in shape: {probs1.shape} vs {probs2.shape}"
        )

    # Re-normalize to undo rounding drift from the float32 -> JSON round trip.
    probs1 = probs1 / probs1.sum()
    probs2 = probs2 / probs2.sum()

    kl_divergences = kl_div(probs1, probs2)

    # Native floats keep the returned dict and the JSON free of numpy scalars.
    expert_kl = {f"Expert_{i}": float(kl) for i, kl in enumerate(kl_divergences)}
    total_kl = float(kl_divergences.sum())

    result = {
        "token": target_token,
        "expert_kl_divergences": expert_kl,
        "total_kl_divergence": total_kl
    }

    # Pick the first unused numbered filename so earlier runs are not overwritten.
    base_filename = "kl_divergence_analysis"
    counter = 1
    filename = f"{base_filename}_{counter}.json"
    while os.path.exists(filename):
        counter += 1
        filename = f"{base_filename}_{counter}.json"

    with open(filename, "w") as f:
        json.dump(result, f, indent=2)

    print(f"KL divergence analysis results saved to {filename}")

    return result

In [73]:
# Compare how the router treats " pitch" (token ID 11288) in the two sentences.
# NOTE(review): these paths assume the two analyze_expert_routing calls above
# were the first ones run in this directory. That function always writes to the
# next free expert_routing_analysis_<n>.json, so after a re-run this cell would
# still read the older _1/_2 files.
json_file1, json_file2 = (
    "expert_routing_analysis_1.json",
    "expert_routing_analysis_2.json",
)
target_token = 11288  # use token id

result = calculate_token_kl_divergence(json_file1, json_file2, target_token=target_token)
print(json.dumps(result, indent=2))

KL divergence analysis results saved to kl_divergence_analysis_1.json
{
  "token": 11288,
  "expert_kl_divergences": {
    "Expert_0": 0.013822967716602026,
    "Expert_1": 0.0008376724012587114,
    "Expert_2": 0.02823257784546524,
    "Expert_3": 4.672694812931852e-06,
    "Expert_4": 0.0013771582953915314,
    "Expert_5": 0.0015263539795864131,
    "Expert_6": 0.0002516992507180229,
    "Expert_7": 0.020351488535727633,
    "Expert_8": 0.0001573097508894613,
    "Expert_9": 9.316885204474663e-05,
    "Expert_10": 0.0021342830624215495,
    "Expert_11": 5.6941361584724424e-05,
    "Expert_12": 5.563769443583597e-06,
    "Expert_13": 0.003731132175420142,
    "Expert_14": 0.000656858469318105,
    "Expert_15": 0.03211484009776097,
    "Expert_16": 0.001914476454697011,
    "Expert_17": 0.015629242566049195,
    "Expert_18": 0.014715031691438046,
    "Expert_19": 0.00485888986834053,
    "Expert_20": 0.0003451279819312054,
    "Expert_21": 0.007140513588838167,
    "Expert_22": 0.04711

### Plotting expert-vs-token heatmaps

In [74]:
def create_heatmap(probabilities, tokens, title, **heatmap_kwargs):
    """Build a token-by-expert heatmap trace of router probabilities.

    Args:
        probabilities: 2-D array-like with one row per token and one column
            per expert. Lists are accepted and converted with ``np.asarray``.
        tokens: Row labels, one per row of ``probabilities``.
        title: Used as the trace ``name``. The original version accepted this
            argument but ignored it.
        **heatmap_kwargs: Extra ``go.Heatmap`` properties (e.g. ``zmin``,
            ``zmax``, ``showscale``) that override the defaults set here.

    Returns:
        A ``plotly.graph_objects.Heatmap`` trace with x labels ``"Expert <i>"``.
    """
    probabilities = np.asarray(probabilities)
    trace_kwargs = dict(
        z=probabilities,
        x=[f'Expert {i}' for i in range(probabilities.shape[1])],
        y=tokens,
        name=title,
        colorscale='Viridis',
        colorbar=dict(title='Probability'),
        hovertemplate='Token: %{y}<br>Expert: %{x}<br>Probability: %{z:.4f}<extra></extra>',
    )
    # Caller-supplied properties win over the defaults above.
    trace_kwargs.update(heatmap_kwargs)
    return go.Heatmap(**trace_kwargs)

# Router probabilities: one row per token, one column per expert, for both inputs.
probabilities_1 = F.softmax(last_layer_router_logits, dim=-1).cpu().numpy()
probabilities_2 = F.softmax(last_layer_router_logits_2, dim=-1).cpu().numpy()

# Decoded token strings, one per row of the probability matrices.
tokens_1 = [tokenizer.decode([token_id]) for token_id in tokenizer.encode(input_text, return_tensors="pt")[0]]
tokens_2 = [tokenizer.decode([token_id]) for token_id in tokenizer.encode(input_text_2, return_tensors="pt")[0]]
assert len(tokens_1) == probabilities_1.shape[0], "token count does not match router-logit rows for input 1"
assert len(tokens_2) == probabilities_2.shape[0], "token count does not match router-logit rows for input 2"

# Tabular views of the same matrices (not used by the figure below).
df_1 = pd.DataFrame(probabilities_1, index=tokens_1)
df_2 = pd.DataFrame(probabilities_2, index=tokens_2)

# Prefix each label with its position. Plotly treats string y values as
# categories, so a token repeated within one input would otherwise collapse
# onto a single row.
labels_1 = [f"{i}: {tok}" for i, tok in enumerate(tokens_1)]
labels_2 = [f"{i}: {tok}" for i, tok in enumerate(tokens_2)]

fig = make_subplots(rows=1, cols=2, subplot_titles=(f'Input 1 ({input_text})', f'Input 2 ({input_text_2})'))

heatmap_1 = create_heatmap(probabilities_1, labels_1, 'Input 1')
heatmap_2 = create_heatmap(probabilities_2, labels_2, 'Input 2')
# Give both panels one colour range and show a single colour bar. By default
# each heatmap autoscales on its own and places its colour bar at the same
# spot, so the two overlapping bars would disagree.
shared_zmax = float(max(probabilities_1.max(), probabilities_2.max()))
heatmap_1.update(zmin=0.0, zmax=shared_zmax)
heatmap_2.update(zmin=0.0, zmax=shared_zmax, showscale=False)
fig.add_trace(heatmap_1, row=1, col=1)
fig.add_trace(heatmap_2, row=1, col=2)

fig.update_layout(
    title='Expert vs Token Heatmaps for Two Inputs',
    width=2000,
    height=800,
)
# Layout-level xaxis_title/yaxis_title only reach the first subplot, so label
# the axes per subplot instead. Also list tokens top-to-bottom in sentence order.
fig.update_xaxes(title_text='Experts')
fig.update_yaxes(title_text='Tokens', row=1, col=1)
fig.update_yaxes(autorange='reversed')

fig.show()

# Save an interactive copy of the figure (optional).
fig.write_html("expert_token_heatmaps_comparison.html")
