### Inference with OLMoE

Load the `allenai/OLMoE-1B-7B-0924` mixture-of-experts model and its tokenizer.

In [4]:
from transformers import OlmoeForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F

# Hugging Face Hub identifier of the checkpoint; shared by tokenizer and model.
MODEL_ID = "allenai/OLMoE-1B-7B-0924"

# Tokenizer first (cheap), then the weights (the slow, multi-shard load).
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = OlmoeForCausalLM.from_pretrained(MODEL_ID)

# Inference only: put the module in evaluation mode. `eval()` returns the
# module itself, and as the cell's last expression its architecture is displayed.
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

OlmoeForCausalLM(
  (model): OlmoeModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoeDecoderLayer(
        (self_attn): OlmoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): OlmoeRMSNorm((2048,), eps=1e-05)
          (k_norm): OlmoeRMSNorm((2048,), eps=1e-05)
        )
        (mlp): OlmoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=64, bias=False)
          (experts): ModuleList(
            (0-63): 64 x OlmoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (down_proj): Linear(in

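### Router logits and probabilities

Run a forward pass with `output_router_logits=True` and inspect, for each input token, how the last layer's router scores every expert (raw logits and post-softmax probabilities), followed by the top-k experts per token.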

In [5]:
# Prepare input text
input_text = "this is a sample input"

# Tokenize the input (a single sequence, so the batch size is 1)
inputs = tokenizer(input_text, return_tensors="pt")

# Token ids of the one sequence, plus their decoded strings. Each id is decoded
# once here and reused by every printout below.
token_ids = inputs.input_ids[0]
token_strs = [tokenizer.decode([token_id]) for token_id in token_ids]

# Print the tokenized input
print("Tokenized input:")
for token, token_id in zip(token_strs, token_ids):
    print(f"Token: '{token}', ID: {token_id.item()}")

print("\nInput shape:", inputs.input_ids.shape)


# Forward pass with output_router_logits=True so `outputs.router_logits` holds
# one router-logit tensor per decoder layer.
with torch.no_grad():
    outputs = model(**inputs, output_router_logits=True)

# Get the router logits from the last layer: one row per token, one column per
# expert. Rows are indexed by token position below, which only lines up when
# the batch holds exactly one sequence, so fail loudly otherwise.
last_layer_router_logits = outputs.router_logits[-1]
assert last_layer_router_logits.shape[0] == token_ids.shape[0], (
    "expected one router-logit row per input token (batch size 1), got "
    f"{tuple(last_layer_router_logits.shape)} for {token_ids.shape[0]} tokens"
)

# Softmax over the expert axis, computed once for all tokens.
router_probs = F.softmax(last_layer_router_logits, dim=-1)

# Print router logits and probabilities for each token
for token_idx, (token, token_id) in enumerate(zip(token_strs, token_ids)):
    logits = last_layer_router_logits[token_idx]
    probabilities = router_probs[token_idx]

    print(f"Token: '{token}' (ID: {token_id.item()})")
    for expert_idx, (logit, prob) in enumerate(zip(logits, probabilities)):
        print(f"  expert {expert_idx}: logit = {logit.item():.4f}, post-softmax = {prob.item():.4f}")
    print()

# Print the top-k experts for each token. By default, k is the number of
# experts the model config routes each token to (falls back to 8).
k = getattr(model.config, "num_experts_per_tok", 8)
print(f"\nTop {k} experts for each token:")
for token_idx, (token, token_id) in enumerate(zip(token_strs, token_ids)):
    top_k_probs, top_k_indices = torch.topk(router_probs[token_idx], k)

    print(f"Token: '{token}' (ID: {token_id.item()})")
    for i, (prob, idx) in enumerate(zip(top_k_probs, top_k_indices)):
        # The value shown is the expert's post-softmax routing probability.
        print(f"  {i+1}. expert {idx.item()}: post-softmax = {prob.item():.4f}")
    print()

Tokenized input:
Token: 'this', ID: 2520
Token: ' is', ID: 310
Token: ' a', ID: 247
Token: ' sample', ID: 3410
Token: ' input', ID: 3280

Input shape: torch.Size([1, 5])
Token: 'this' (ID: 2520)
  expert 0: logit = -0.2219, post-softmax = 0.0253
  expert 1: logit = -0.3173, post-softmax = 0.0230
  expert 2: logit = -1.1706, post-softmax = 0.0098
  expert 3: logit = -2.2602, post-softmax = 0.0033
  expert 4: logit = -1.0775, post-softmax = 0.0107
  expert 5: logit = -1.1462, post-softmax = 0.0100
  expert 6: logit = -1.5410, post-softmax = 0.0068
  expert 7: logit = -2.2310, post-softmax = 0.0034
  expert 8: logit = -0.3846, post-softmax = 0.0215
  expert 9: logit = -1.9906, post-softmax = 0.0043
  expert 10: logit = -1.7290, post-softmax = 0.0056
  expert 11: logit = -1.3310, post-softmax = 0.0083
  expert 12: logit = -1.0972, post-softmax = 0.0105
  expert 13: logit = -1.5821, post-softmax = 0.0065
  expert 14: logit = -1.4272, post-softmax = 0.0076
  expert 15: logit = -1.4622, post-

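### Plotting the router probabilities

Visualize the last layer's post-softmax router probabilities as a token × expert heatmap (also saved as an HTML file).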

In [6]:
import plotly.graph_objects as go
import pandas as pd

# Convert the last layer's router logits to probabilities: one row per token,
# one column per expert. Upcast to float32 first, because NumPy has no
# bfloat16 dtype and `.numpy()` would fail for a reduced-precision model.
probabilities = F.softmax(last_layer_router_logits.float(), dim=-1).cpu().numpy()

# Create a list of tokens
tokens = [tokenizer.decode([token_id]) for token_id in inputs.input_ids[0]]

# Tabular (tokens x experts) view of the same matrix, for inspection
df = pd.DataFrame(probabilities, index=tokens)

# Plotly treats string axis values as categories, so repeated tokens would be
# drawn on one shared row. Prefix each label with its position to keep one row
# per token; repr() makes leading spaces in tokens visible.
token_labels = [f"{pos}: {tok!r}" for pos, tok in enumerate(tokens)]

# Create the heatmap using Plotly
fig = go.Figure(data=go.Heatmap(
    z=probabilities,
    x=[f'Expert {i}' for i in range(probabilities.shape[1])],
    y=token_labels,
    colorscale='Viridis',
    colorbar=dict(title='Probability'),
    hovertemplate='Token: %{y}<br>Expert: %{x}<br>Probability: %{z:.4f}<extra></extra>'
))

# Update layout
fig.update_layout(
    title='Expert vs Token Heatmap',
    xaxis_title='Experts',
    yaxis_title='Tokens',
    # Heatmap rows are drawn bottom-up by default; reverse the axis so the
    # tokens read top-to-bottom in input order.
    yaxis_autorange='reversed',
    width=1000,
    height=800,
)

# Show the plot
fig.show()

# Save the plot as an HTML file (optional)
fig.write_html("expert_token_heatmap.html")
