In [11]:
from transformers import OlmoeForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
import json
import os
import numpy as np
from scipy.special import kl_div
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots

### Inference with OLMoE

In [12]:
# Single source of truth for the checkpoint, so the model and the tokenizer
# are always loaded from the same Hugging Face repository.
MODEL_NAME = "allenai/OLMoE-1B-7B-0924"

# Load the model and tokenizer
model = OlmoeForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Set the model to eval mode. `eval()` returns the module itself, so leaving it
# as the last expression of the cell also displays the module tree.
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

OlmoeForCausalLM(
  (model): OlmoeModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoeDecoderLayer(
        (self_attn): OlmoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): OlmoeRMSNorm((2048,), eps=1e-05)
          (k_norm): OlmoeRMSNorm((2048,), eps=1e-05)
        )
        (mlp): OlmoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=64, bias=False)
          (experts): ModuleList(
            (0-63): 64 x OlmoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (down_proj): Linear(in

### Get the router logits and probabilities

In [28]:
def analyze_expert_routing(input_text, model, tokenizer, layer_num, top_k=8):
    """Print and save the router distribution of one MoE layer for every token.

    The text is tokenized, the model is run once with
    ``output_router_logits=True`` and the logits of layer ``layer_num`` are
    turned into per-token expert probabilities with a softmax. Three reports
    are printed (the tokenization, every expert's logit/probability per token,
    and the ``top_k`` experts per token), and the per-token probabilities are
    written to ``expert_routing_analysis_<n>.json`` in the current working
    directory, where ``<n>`` is the first counter value not already on disk.

    Args:
        input_text: Text to analyze.
        model: Model called as ``model(**inputs, output_router_logits=True)``;
            its output must expose ``router_logits`` indexable by layer.
        tokenizer: Tokenizer matching ``model``.
        layer_num: Index into ``outputs.router_logits``.
        top_k: Number of highest-probability experts listed per token.
            Clamped to the number of experts, so routers with fewer than
            ``top_k`` experts no longer make ``torch.topk`` fail.

    Returns:
        Tuple ``(router_logits, analysis_results)``: the raw logits of the
        selected layer and the dictionary that was dumped to JSON.

    Raises:
        ValueError: If ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        raise ValueError("top_k must be a positive integer")

    # Tokenize the input
    inputs = tokenizer(input_text, return_tensors="pt")
    token_ids = inputs.input_ids[0]

    # Decode every token once; the text is reused by all reports below
    tokens = [tokenizer.decode([token_id]) for token_id in token_ids]

    # Print the tokenized input
    print("Tokenized input:")
    for token, token_id in zip(tokens, token_ids):
        print(f"Token: '{token}', ID: {token_id.item()}")

    print(f"\nInput shape: {inputs.input_ids.shape}")

    # Forward pass with output_router_logits=True
    with torch.no_grad():
        outputs = model(**inputs, output_router_logits=True)

    # Get the router logits from the specified layer
    # NOTE(review): rows are assumed to line up with the tokens of the first
    # (and only) sequence, i.e. a single un-batched input - confirm before
    # passing a batch of texts.
    router_logits = outputs.router_logits[layer_num]

    # Print which layer the logits are from
    print(f"\nRouter logits are from layer {layer_num} of the model")

    # Softmax over the expert dimension, computed once for all tokens
    router_probs = F.softmax(router_logits, dim=-1)

    # Dictionary that ends up in the JSON dump
    analysis_results = {
        "input_text": input_text,
        "tokens": []
    }

    # Print router logits and probabilities for each token
    print("\nRouter logits and probabilities for each token:")
    for token, token_id, logits, probabilities in zip(tokens, token_ids, router_logits, router_probs):
        print(f"Token: '{token}' (ID: {token_id.item()})")
        for expert_idx, (logit, prob) in enumerate(zip(logits, probabilities)):
            print(f"  expert {expert_idx}: logit = {logit.item():.4f}, post-softmax = {prob.item():.4f}")
        print()

        analysis_results["tokens"].append({
            "token": token,
            "id": token_id.item(),
            "router_probability": probabilities.tolist()
        })

    # Print the top-k experts for each token; k can never exceed the number
    # of experts, otherwise torch.topk raises.
    k = min(top_k, router_probs.shape[-1])
    all_top_probs, all_top_indices = torch.topk(router_probs, k, dim=-1)
    print(f"\nTop {k} experts for each token:")
    for token, token_id, top_k_probs, top_k_indices in zip(tokens, token_ids, all_top_probs, all_top_indices):
        print(f"Token: '{token}' (ID: {token_id.item()})")
        for i, (prob, idx) in enumerate(zip(top_k_probs, top_k_indices)):
            print(f"  {i+1}. expert {idx.item()}: probability = {prob.item():.4f}")
        print()

    # Save the analysis results as a JSON file with a unique name: bump the
    # counter until a file name is found that does not exist yet.
    base_filename = "expert_routing_analysis"
    counter = 1
    filename = f"{base_filename}_{counter}.json"
    while os.path.exists(filename):
        counter += 1
        filename = f"{base_filename}_{counter}.json"

    with open(filename, "w") as f:
        json.dump(analysis_results, f, indent=2)

    print(f"Analysis results saved to {filename}")

    return router_logits, analysis_results

layer_num = 0  # index of the decoder layer whose router is analyzed (the model has 16 layers: 0-15)

In [30]:
# First prompt containing the token ' pitch'
input_text = "How will you pitch your idea to the investors?"

# NOTE: despite the "last_layer" prefix, the returned logits belong to the
# layer selected by `layer_num`.
last_layer_router_logits, analysis_results = analyze_expert_routing(
    input_text,
    model,
    tokenizer,
    layer_num,
)

Tokenized input:
Token: 'How', ID: 2347
Token: ' will', ID: 588
Token: ' you', ID: 368
Token: ' pitch', ID: 11288
Token: ' your', ID: 634
Token: ' idea', ID: 2934
Token: ' to', ID: 281
Token: ' the', ID: 253
Token: ' investors', ID: 12946
Token: '?', ID: 32

Input shape: torch.Size([1, 10])

Router logits are from layer 0 of the model

Router logits and probabilities for each token:
Token: 'How' (ID: 2347)
  expert 0: logit = -3.0995, post-softmax = 0.0008
  expert 1: logit = -0.1792, post-softmax = 0.0145
  expert 2: logit = -0.1681, post-softmax = 0.0147
  expert 3: logit = -0.3178, post-softmax = 0.0126
  expert 4: logit = 0.1484, post-softmax = 0.0202
  expert 5: logit = -0.4750, post-softmax = 0.0108
  expert 6: logit = 0.0466, post-softmax = 0.0182
  expert 7: logit = 0.7465, post-softmax = 0.0367
  expert 8: logit = -1.4290, post-softmax = 0.0042
  expert 9: logit = -0.9157, post-softmax = 0.0070
  expert 10: logit = -0.5113, post-softmax = 0.0104
  expert 11: logit = -0.1142, p

In [31]:
# Second prompt containing the token ' pitch'
input_text_2 = "What's the musical pitch of that note?"

# NOTE: despite the "last_layer" prefix, the returned logits belong to the
# layer selected by `layer_num`.
last_layer_router_logits_2, analysis_results_2 = analyze_expert_routing(
    input_text_2,
    model,
    tokenizer,
    layer_num,
)


Tokenized input:
Token: 'What', ID: 1276
Token: ''s', ID: 434
Token: ' the', ID: 253
Token: ' musical', ID: 12256
Token: ' pitch', ID: 11288
Token: ' of', ID: 273
Token: ' that', ID: 326
Token: ' note', ID: 3877
Token: '?', ID: 32

Input shape: torch.Size([1, 9])

Router logits are from layer 0 of the model

Router logits and probabilities for each token:
Token: 'What' (ID: 1276)
  expert 0: logit = -3.1594, post-softmax = 0.0007
  expert 1: logit = -0.1499, post-softmax = 0.0145
  expert 2: logit = -0.1624, post-softmax = 0.0143
  expert 3: logit = -0.5962, post-softmax = 0.0093
  expert 4: logit = 0.0160, post-softmax = 0.0171
  expert 5: logit = -0.7897, post-softmax = 0.0076
  expert 6: logit = -0.2797, post-softmax = 0.0127
  expert 7: logit = 0.4158, post-softmax = 0.0255
  expert 8: logit = -0.4473, post-softmax = 0.0108
  expert 9: logit = -1.2607, post-softmax = 0.0048
  expert 10: logit = -0.7840, post-softmax = 0.0077
  expert 11: logit = -0.9305, post-softmax = 0.0066
  exp

### KL divergence

In [32]:


def calculate_token_kl_divergence(json_file1, json_file2, target_token):
    """Compare the router distributions of one token across two routing dumps.

    Both files are expected to have a ``"tokens"`` list whose entries carry
    ``"token"``, ``"id"`` and ``"router_probability"``. In each file the first
    entry whose ``"token"`` or ``"id"`` equals ``target_token`` is used.

    The two probability vectors are renormalized to sum to 1 and passed to
    ``scipy.special.kl_div`` (first file as ``x``, second as ``y``). That
    function evaluates ``x*log(x/y) - x + y`` element-wise, so the per-expert
    values are not the plain ``x*log(x/y)`` terms; their sum equals the
    KL divergence because both inputs are normalized.

    The result is also written to ``kl_divergence_analysis_<n>.json`` in the
    current working directory, using the first counter value not yet on disk.

    Args:
        json_file1: Path of the first routing dump.
        json_file2: Path of the second routing dump.
        target_token: Token text or token id to look up in both dumps.

    Returns:
        Dict with ``"token"``, ``"expert_kl_divergences"`` (``"Expert_<i>"``
        mapped to the element-wise value) and ``"total_kl_divergence"``.

    Raises:
        ValueError: If the token is missing from either file.
    """
    def _first_match(records):
        # First entry matching target_token by decoded text or by id
        for record in records:
            if record['token'] == target_token or record['id'] == target_token:
                return record
        return None

    # Read both routing dumps
    with open(json_file1, 'r') as handle_a, open(json_file2, 'r') as handle_b:
        routing_a = json.load(handle_a)
        routing_b = json.load(handle_b)

    # Locate the token in each dump
    entry_a = _first_match(routing_a['tokens'])
    entry_b = _first_match(routing_b['tokens'])

    if not (entry_a and entry_b):
        raise ValueError(f"Token '{target_token}' not found in one or both inputs")

    # Router probabilities, renormalized so each vector sums to 1
    dist_a = np.array(entry_a['router_probability'])
    dist_b = np.array(entry_b['router_probability'])
    dist_a = dist_a / np.sum(dist_a)
    dist_b = dist_b / np.sum(dist_b)

    # Element-wise divergence terms, one per expert
    per_expert = kl_div(dist_a, dist_b)

    expert_kl = {}
    for expert_idx, term in enumerate(per_expert):
        expert_kl[f"Expert_{expert_idx}"] = term

    result = {
        "token": target_token,
        "expert_kl_divergences": expert_kl,
        "total_kl_divergence": np.sum(per_expert)
    }

    # Pick the first numbered file name that is still free
    stem = "kl_divergence_analysis"
    counter = 1
    while os.path.exists(f"{stem}_{counter}.json"):
        counter += 1
    filename = f"{stem}_{counter}.json"

    with open(filename, "w") as out:
        json.dump(result, out, indent=2)

    print(f"KL divergence analysis results saved to {filename}")

    return result

In [19]:
# Routing dumps to compare.
# NOTE(review): these names assume the two analysis runs wrote counters 1 and 2,
# i.e. no older expert_routing_analysis_*.json existed in the working
# directory - re-running the analysis cells produces higher counters.
json_file1 = "expert_routing_analysis_1.json"
json_file2 = "expert_routing_analysis_2.json"

# Token id of ' pitch', the token that occurs in both prompts
target_token = 11288

result = calculate_token_kl_divergence(
    json_file1=json_file1,
    json_file2=json_file2,
    target_token=target_token,
)
print(json.dumps(result, indent=2))

KL divergence analysis results saved to kl_divergence_analysis_1.json
{
  "token": 11288,
  "expert_kl_divergences": {
    "Expert_0": 0.009453152321398161,
    "Expert_1": 0.0010285437104599988,
    "Expert_2": 0.019584900520401495,
    "Expert_3": 4.717867089935879e-06,
    "Expert_4": 0.0011487151426666509,
    "Expert_5": 0.001311021631701912,
    "Expert_6": 0.00022197336573581824,
    "Expert_7": 0.039499575999606906,
    "Expert_8": 0.0001660485313056282,
    "Expert_9": 8.936434448928235e-05,
    "Expert_10": 0.0015351983129790721,
    "Expert_11": 5.294601192123516e-05,
    "Expert_12": 5.617166773176294e-06,
    "Expert_13": 0.0025187000490325545,
    "Expert_14": 0.0005910418768493853,
    "Expert_15": 0.017322743208425296,
    "Expert_16": 0.0025415154277968455,
    "Expert_17": 0.010948062989407141,
    "Expert_18": 0.024757676308353425,
    "Expert_19": 0.007494485665606062,
    "Expert_20": 0.00031472358405389665,
    "Expert_21": 0.011316649822444161,
    "Expert_22": 0

### Plotting

In [34]:
def create_heatmap(probabilities, tokens, title, layer_num):
    """Build a Plotly heatmap trace with experts on x and tokens on y.

    Args:
        probabilities: 2-D array used as the ``z`` values; its second
            dimension sets the number of expert columns.
        tokens: Labels for the y axis.
        title: Currently unused; kept so existing callers keep working.
        layer_num: Currently unused; kept so existing callers keep working.

    Returns:
        A ``go.Heatmap`` trace using the Viridis colorscale.
    """
    expert_count = probabilities.shape[1]
    expert_labels = []
    for expert_idx in range(expert_count):
        expert_labels.append(f'Expert {expert_idx}')

    trace_options = {
        'z': probabilities,
        'x': expert_labels,
        'y': tokens,
        'colorscale': 'Viridis',
        'colorbar': {'title': 'Probability'},
        # <extra></extra> hides the secondary hover box
        'hovertemplate': 'Token: %{y}<br>Expert: %{x}<br>Probability: %{z:.4f}<extra></extra>',
    }
    return go.Heatmap(**trace_options)

# Convert router logits to probabilities for both inputs
probabilities_1 = F.softmax(last_layer_router_logits, dim=-1).cpu().numpy()
probabilities_2 = F.softmax(last_layer_router_logits_2, dim=-1).cpu().numpy()

# Create lists of tokens for both inputs
tokens_1 = [tokenizer.decode([token_id]) for token_id in tokenizer.encode(input_text, return_tensors="pt")[0]]
tokens_2 = [tokenizer.decode([token_id]) for token_id in tokenizer.encode(input_text_2, return_tensors="pt")[0]]

# Create DataFrames for the heatmaps
df_1 = pd.DataFrame(probabilities_1, index=tokens_1)
df_2 = pd.DataFrame(probabilities_2, index=tokens_2)

# Create subplots
fig = make_subplots(rows=1, cols=2, subplot_titles=(f'Input 1 ({input_text}) - Layer {layer_num}', f'Input 2 ({input_text_2}) - Layer {layer_num}'))

# Add heatmaps to subplots
fig.add_trace(create_heatmap(probabilities_1, tokens_1, f'Input 1 - Layer {layer_num}', layer_num), row=1, col=1)
fig.add_trace(create_heatmap(probabilities_2, tokens_2, f'Input 2 - Layer {layer_num}', layer_num), row=1, col=2)

# Update layout
fig.update_layout(
    title=f'Expert vs Token Heatmaps for Two Inputs - Layer {layer_num}',
    xaxis_title='Experts',
    yaxis_title='Tokens',
    width=2000,
    height=800,
)

# Show the plot
fig.show()

# Save the plot as an HTML file (optional)
fig.write_html(f"expert_token_heatmaps_comparison_layer_{layer_num}.html")