In [1]:
from transformers import OlmoeForCausalLM, AutoTokenizer
import torch
import torch.nn.functional as F
import json
import os
import numpy as np
from scipy.special import kl_div
import plotly.graph_objects as go
import pandas as pd
from plotly.subplots import make_subplots

### Inference with OLMoE

In [2]:
# The weights and the tokenizer are pulled from the same Hugging Face checkpoint
checkpoint_name = "allenai/OLMoE-1B-7B-0924"
tokenizer = AutoTokenizer.from_pretrained(checkpoint_name)
model = OlmoeForCausalLM.from_pretrained(checkpoint_name)

# Put the model in eval mode; as the last expression of the cell this also
# displays the module tree
model.eval()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

OlmoeForCausalLM(
  (model): OlmoeModel(
    (embed_tokens): Embedding(50304, 2048, padding_idx=1)
    (layers): ModuleList(
      (0-15): 16 x OlmoeDecoderLayer(
        (self_attn): OlmoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (q_norm): OlmoeRMSNorm((2048,), eps=1e-05)
          (k_norm): OlmoeRMSNorm((2048,), eps=1e-05)
        )
        (mlp): OlmoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=64, bias=False)
          (experts): ModuleList(
            (0-63): 64 x OlmoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1024, bias=False)
              (down_proj): Linear(in

### Get the router logits and probabilities

In [44]:
def analyze_expert_routing(input_text, model, tokenizer, layer_num):
    """Run one forward pass and record the router distribution of one layer.

    The text is tokenized, the model is called with ``output_router_logits=True``
    and ``output_hidden_states=True``, and the router logits of ``layer_num`` are
    turned into per-token softmax probabilities over the experts.

    Args:
        input_text: Text to analyze. Only the first sequence of the encoded
            batch is reported, so pass a single string.
        model: Causal LM whose output exposes ``router_logits`` and
            ``hidden_states`` (e.g. the ``OlmoeForCausalLM`` loaded above).
        tokenizer: Tokenizer matching ``model``.
        layer_num: Index into ``outputs.router_logits`` (0-15 for the model
            loaded in this notebook).

    Returns:
        A tuple ``(router_logits, analysis_results)``. ``router_logits`` is the
        raw logits tensor of the selected layer (``[8, 64]`` for the 8-token
        example in this notebook). ``analysis_results`` is a dict with the keys
        ``"input_text"`` and ``"tokens"``; each token entry holds ``"token"``,
        ``"id"`` and ``"router_probability"`` (list of floats).

    Side effects:
        Prints a tokenization/shape report and writes ``analysis_results`` to
        the first unused ``expert_routing_analysis_<n>.json`` in the current
        working directory.
    """
    # Tokenize the input and move the tensors to wherever the model lives, so
    # the function also works after the model has been moved to an accelerator.
    inputs = tokenizer(input_text, return_tensors="pt")
    device = getattr(model, "device", None)
    if device is not None:
        inputs = inputs.to(device)

    # Plain Python ints/strings for the first (only) sequence of the batch
    token_ids = inputs.input_ids[0].tolist()
    tokens = [tokenizer.decode([token_id]) for token_id in token_ids]

    # Print the tokenized input
    print("Tokenized input:")
    for token, token_id in zip(tokens, token_ids):
        print(f"Token: '{token}', ID: {token_id}")

    print(f"\nInput shape: {inputs.input_ids.shape}")

    # Forward pass with output_router_logits=True and output_hidden_states=True
    with torch.no_grad():
        outputs = model(**inputs, output_router_logits=True, output_hidden_states=True)

    # Get the router logits and hidden states from the specified layer.
    # NOTE(review): hidden_states presumably has the embedding output at index 0,
    # so hidden_states[layer_num] would be the input to this layer - confirm.
    router_logits = outputs.router_logits[layer_num]
    hidden_states = outputs.hidden_states[layer_num] if outputs.hidden_states is not None else None

    # Print the size of the router logits tensor
    print(f"\nSize of router logits tensor: {router_logits.size()}")
    if hidden_states is not None:
        print(f"Size of hidden states tensor: {hidden_states.size()}")

    # Print which layer the logits are from
    print(f"Router logits are from layer {layer_num} of the model")

    # Softmax over the expert dimension, computed once for every token
    router_probabilities = F.softmax(router_logits, dim=-1)

    # Dictionary that collects the per-token results
    analysis_results = {
        "input_text": input_text,
        "tokens": []
    }

    # Row i of the router logits is assumed to belong to token i, which holds
    # for a single input sequence.
    print("\nRouter logits and probabilities for each token:")
    for token_idx, (token, token_id) in enumerate(zip(tokens, token_ids)):
        logits = router_logits[token_idx]
        probabilities = router_probabilities[token_idx]

        print(f"Token: '{token}' (ID: {token_id})")
        print(f"  Pre-softmax logits size: {logits.size()}")
        print(f"  Post-softmax probabilities size: {probabilities.size()}")

        analysis_results["tokens"].append({
            "token": token,
            "id": token_id,
            "router_probability": probabilities.tolist()
        })

    # Save the analysis results as a JSON file with a unique name: use the
    # first counter value whose file does not exist yet.
    base_filename = "expert_routing_analysis"
    counter = 1
    filename = f"{base_filename}_{counter}.json"
    while os.path.exists(filename):
        counter += 1
        filename = f"{base_filename}_{counter}.json"

    with open(filename, "w") as f:
        json.dump(analysis_results, f, indent=2)

    print(f"Analysis results saved to {filename}")

    return router_logits, analysis_results

layer_num = 0 # index of the decoder layer whose router is analyzed (0-15: the model printed above has 16 layers)

In [45]:
# Sentence 1: "match" used in the sporting sense
input_text = "The tennis match was thrilling to watch."
router_logits, analysis_results = analyze_expert_routing(
    input_text=input_text,
    model=model,
    tokenizer=tokenizer,
    layer_num=layer_num,
)

Tokenized input:
Token: 'The', ID: 510
Token: ' tennis', ID: 23354
Token: ' match', ID: 3761
Token: ' was', ID: 369
Token: ' thrilling', ID: 47330
Token: ' to', ID: 281
Token: ' watch', ID: 3698
Token: '.', ID: 15

Input shape: torch.Size([1, 8])

Size of router logits tensor: torch.Size([8, 64])
Size of hidden states tensor: torch.Size([1, 8, 2048])
Router logits are from layer 0 of the model

Router logits and probabilities for each token:
Token: 'The' (ID: 510)
  Pre-softmax logits size: torch.Size([64])
  Post-softmax probabilities size: torch.Size([64])
Token: ' tennis' (ID: 23354)
  Pre-softmax logits size: torch.Size([64])
  Post-softmax probabilities size: torch.Size([64])
Token: ' match' (ID: 3761)
  Pre-softmax logits size: torch.Size([64])
  Post-softmax probabilities size: torch.Size([64])
Token: ' was' (ID: 369)
  Pre-softmax logits size: torch.Size([64])
  Post-softmax probabilities size: torch.Size([64])
Token: ' thrilling' (ID: 47330)
  Pre-softmax logits size: torch.Si

In [4]:
# Sentence 2: "match" used in the dating sense.
# This cell must run after the cell that defines analyze_expert_routing; the
# NameError stored in its saved output comes from running it before that definition.
input_text_2 = "The online app found her a good match."
router_logits_2, analysis_results_2 = analyze_expert_routing(
    input_text=input_text_2,
    model=model,
    tokenizer=tokenizer,
    layer_num=layer_num,
)


NameError: name 'analyze_expert_routing' is not defined

### KL divergence

In [6]:

def calculate_token_kl_divergence(json_file1, json_file2, target_token):
    """Compare the router distributions of one token across two routing dumps.

    Both files must have the layout written by ``analyze_expert_routing``: a
    ``"tokens"`` list whose entries carry ``"token"``, ``"id"`` and
    ``"router_probability"``.

    Args:
        json_file1: Path of the first routing dump (the reference distribution P).
        json_file2: Path of the second routing dump (the distribution Q).
        target_token: Token string or token id to look up. If the token occurs
            several times in a file, its first occurrence is used.

    Returns:
        A dict with ``"token"``, ``"expert_kl_divergences"`` (``"Expert_<i>"``
        mapped to the element-wise ``scipy.special.kl_div`` term) and
        ``"total_kl_divergence"`` (the sum of those terms, i.e. KL(P || Q)).

    Raises:
        ValueError: If the token is missing from either file, or if the two
            probability vectors do not have the same length.

    Side effects:
        Writes the result to the first unused ``kl_divergence_analysis_<n>.json``
        in the current working directory and prints that filename.
    """
    # Load JSON data from files
    with open(json_file1, 'r') as f1, open(json_file2, 'r') as f2:
        data1 = json.load(f1)
        data2 = json.load(f2)

    def _find_token(data):
        # First entry matching either the decoded token text or the token id
        return next(
            (t for t in data['tokens'] if t['token'] == target_token or t['id'] == target_token),
            None,
        )

    token1 = _find_token(data1)
    token2 = _find_token(data2)

    if token1 is None or token2 is None:
        raise ValueError(f"Token '{target_token}' not found in one or both inputs")

    # Get router probabilities for the target token
    probs1 = np.asarray(token1['router_probability'], dtype=float)
    probs2 = np.asarray(token2['router_probability'], dtype=float)

    if probs1.shape != probs2.shape:
        raise ValueError(
            f"Router probability vectors differ in shape: {probs1.shape} vs {probs2.shape}"
        )

    # Renormalize so both vectors sum to exactly 1 (guards against rounding
    # introduced by the JSON round trip)
    probs1 = probs1 / np.sum(probs1)
    probs2 = probs2 / np.sum(probs2)

    # Element-wise terms p*log(p/q) - p + q; because both vectors sum to 1 the
    # "- p + q" parts cancel in the total, leaving KL(P || Q)
    kl_divergences = kl_div(probs1, probs2)

    # Plain Python floats keep the result independent of numpy scalar types
    expert_kl = {f"Expert_{i}": float(kl) for i, kl in enumerate(kl_divergences)}
    total_kl = float(np.sum(kl_divergences))

    result = {
        "token": target_token,
        "expert_kl_divergences": expert_kl,
        "total_kl_divergence": total_kl
    }

    # Save the result under the first unused counter. The same name pattern is
    # used for every candidate (previously the retry name also embedded the
    # target token, so the first and later files followed different schemes).
    base_filename = "kl_divergence_analysis"
    counter = 1
    filename = f"{base_filename}_{counter}.json"
    while os.path.exists(filename):
        counter += 1
        filename = f"{base_filename}_{counter}.json"

    with open(filename, "w") as f:
        json.dump(result, f, indent=2)

    print(f"KL divergence analysis results saved to {filename}")

    return result

In [7]:
# Routing dumps written by the two analyze_expert_routing calls above
json_file1 = "expert_routing_analysis_1.json"
json_file2 = "expert_routing_analysis_2.json"

# Token id of ' match' (see the tokenizer printout above); ids avoid
# whitespace ambiguity in the decoded token text
target_token = 3761

result = calculate_token_kl_divergence(
    json_file1=json_file1,
    json_file2=json_file2,
    target_token=target_token,
)
print(json.dumps(result, indent=2))

KL divergence analysis results saved to kl_divergence_analysis_1.json
{
  "token": 3761,
  "expert_kl_divergences": {
    "Expert_0": 0.0007812551380871591,
    "Expert_1": 0.00010657651477648418,
    "Expert_2": 0.0018868115937772031,
    "Expert_3": 0.003657222177318448,
    "Expert_4": 0.02627870109798979,
    "Expert_5": 0.00020753548841244591,
    "Expert_6": 0.006246837614801284,
    "Expert_7": 0.00032112401034644494,
    "Expert_8": 0.0062337563288345,
    "Expert_9": 0.009274871867047087,
    "Expert_10": 0.05506577895205873,
    "Expert_11": 0.0037240712460497646,
    "Expert_12": 0.0006189850341964769,
    "Expert_13": 0.0002855744818948073,
    "Expert_14": 0.004536782226628571,
    "Expert_15": 0.00014146661485302942,
    "Expert_16": 0.0026401448071286925,
    "Expert_17": 0.006159151167502408,
    "Expert_18": 1.6855209224618181e-06,
    "Expert_19": 0.001297812824272446,
    "Expert_20": 0.003679239357394368,
    "Expert_21": 0.003841600094424787,
    "Expert_22": 0.017

### for plotting

In [30]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import json
import os
import shutil

def generate_heatmap(input_text, model, tokenizer, output_file, num_layers=16):
    """Write an HTML page with one token-by-expert routing heatmap per layer.

    For every layer index in ``range(num_layers)`` the function calls
    ``analyze_expert_routing`` and plots the returned router probabilities
    (rows: tokens, columns: experts) as one subplot.

    Args:
        input_text: Sentence to analyze.
        model: Model forwarded to ``analyze_expert_routing``.
        tokenizer: Tokenizer forwarded to ``analyze_expert_routing``.
        output_file: Path of the HTML file to write; its parent directory is
            created if it does not exist.
        num_layers: Number of layers to plot, starting at layer 0.

    Side effects:
        Each ``analyze_expert_routing`` call runs a forward pass, prints its
        report and writes its own JSON file, so one call of this function
        produces ``num_layers`` JSON files in addition to ``output_file``.
    """
    fig = make_subplots(
        rows=num_layers, cols=1,
        subplot_titles=[f"Layer {i}" for i in range(num_layers)],
        vertical_spacing=0.02
    )

    for layer in range(num_layers):
        # Only the structured results are needed here; the raw logits are ignored
        _, analysis_results = analyze_expert_routing(input_text, model, tokenizer, layer)

        # Extract data from analysis_results
        tokens = [token['token'] for token in analysis_results['tokens']]
        probabilities = [token['router_probability'] for token in analysis_results['tokens']]

        # Derive the expert count from the data instead of assuming 64
        num_experts = len(probabilities[0]) if probabilities else 0
        expert_labels = [f"E{i}" for i in range(num_experts)]

        # Use integer row positions: token strings as y values form a category
        # axis, where a token that repeats in the sentence would share one row.
        # The token text is restored through the tick labels and the hover text.
        row_positions = list(range(len(tokens)))

        heatmap = go.Heatmap(
            z=probabilities,
            x=expert_labels,
            y=row_positions,
            customdata=[[token] * num_experts for token in tokens],
            hovertemplate=(
                "Token: %{customdata}<br>Expert: %{x}<br>"
                "Probability: %{z:.4f}<extra></extra>"
            ),
            colorscale='orrd', # old - Viridis
            zmax=1,
            zmin=0,
            colorbar=dict(
                # Edit the tickvals and ticktext here to change which values are labelled
                tickvals=[0, 0.1, 0.3, 0.6, 1],
                ticktext=['0', '0.1', '0.3', '0.6', '1']
            )
        )

        fig.add_trace(heatmap, row=layer + 1, col=1)

        # Label every 8th expert to keep the axis readable
        sparse_expert_labels = expert_labels[::8]
        fig.update_xaxes(
            title_text="Experts" if layer == num_layers - 1 else None,
            row=layer + 1, col=1,
            tickangle=45,
            tickmode='array',
            tickvals=sparse_expert_labels,
            ticktext=sparse_expert_labels
        )
        fig.update_yaxes(
            title_text="Tokens" if layer == 0 else None,
            row=layer + 1, col=1,
            tickmode='array',
            tickvals=row_positions,
            ticktext=tokens,
            side='left'
        )

    fig.update_layout(
        title_text=f"Expert Routing Heatmaps for '{input_text}'",
        height=300 * num_layers,
        width=1000,
        font=dict(size=10),
        # NOTE(review): no trace above sets coloraxis="coloraxis", so these
        # shared-colorbar settings presumably have no visible effect - confirm.
        coloraxis_colorbar=dict(
            title='Probability',
            thickness=10,
            len=0.5,
            tickvals=[0, 0.1, 0.3, 0.6, 1],
            ticktext=['0', '0.1', '0.3', '0.6', '1'],
            orientation='h'
        )
    )

    # Subplot titles are stored as layout annotations; enlarge their font
    for annotation in fig['layout']['annotations']:
        annotation['font'] = dict(size=12)

    # Make sure the target directory exists before writing
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    fig.write_html(output_file)
    print(f"Heatmaps saved to {output_file}")

In [31]:
import json

# sentences.json is expected to map a group name to a list of sentences.
# JSON text is UTF-8, so do not rely on the platform's default encoding.
with open('sentences.json', 'r', encoding='utf-8') as f:
    sentences = json.load(f)

# The heatmaps are written below plots/; create it on a fresh checkout
os.makedirs("plots", exist_ok=True)

# One HTML file per sentence, named after its group and its position in the group
for dictionary_name, sentence_list in sentences.items():
    for i, sentence in enumerate(sentence_list):
        input_text = sentence
        output_file = f"plots/{dictionary_name}_{i}.html"
        generate_heatmap(input_text, model, tokenizer, output_file)

Tokenized input:
Token: 'The', ID: 510
Token: ' stars', ID: 6114
Token: ' tw', ID: 2500
Token: 'inkled', ID: 34269
Token: ' brightly', ID: 43925
Token: ' in', ID: 275
Token: ' the', ID: 253
Token: ' night', ID: 2360
Token: ' sky', ID: 8467
Token: '.', ID: 15

Input shape: torch.Size([1, 10])

Router logits are from layer 0 of the model

Router logits and probabilities for each token:
Token: 'The' (ID: 510)
  expert 0: logit = 0.1502, post-softmax = 0.0232
  expert 1: logit = -0.7955, post-softmax = 0.0090
  expert 2: logit = -0.2742, post-softmax = 0.0152
  expert 3: logit = -0.5139, post-softmax = 0.0119
  expert 4: logit = -0.3479, post-softmax = 0.0141
  expert 5: logit = 2.1354, post-softmax = 0.1689
  expert 6: logit = 0.6977, post-softmax = 0.0401
  expert 7: logit = -1.5684, post-softmax = 0.0042
  expert 8: logit = -1.7015, post-softmax = 0.0036
  expert 9: logit = -0.0581, post-softmax = 0.0188
  expert 10: logit = -0.7583, post-softmax = 0.0094
  expert 11: logit = -3.2744, p

### plot of KL divergence

In [75]:
import plotly.graph_objects as go
import json
import plotly.express as px
import numpy as np

# Load the KL divergence results written by calculate_token_kl_divergence
with open('kl_divergence_analysis_1.json', 'r') as f:
    kl_data = json.load(f)

# Extract expert numbers ("Expert_<n>" -> n) and KL divergence values
experts = [int(expert.split('_')[1]) for expert in kl_data['expert_kl_divergences'].keys()]
kl_values = list(kl_data['expert_kl_divergences'].values())

# Color scale used for the bars
colors = px.colors.sequential.Viridis

# Create the bar chart. The bars are coloured by the raw KL values: the colour
# range spans min..max of the data automatically, so the colours match a
# min-max normalisation while the colorbar shows real KL values (and no
# division by zero occurs when all values are equal).
fig = go.Figure(data=[go.Bar(
    x=experts,
    y=kl_values,
    marker=dict(
        color=kl_values,
        colorscale=colors,
        colorbar=dict(title="KL Divergence")
    )
)])

# Update layout
fig.update_layout(
    title={
        'text': 'KL Divergence by Expert',
        'y':0.95,
        'x':0.5,
        'xanchor': 'center',
        'yanchor': 'top',
        'font': dict(size=24)
    },
    xaxis_title='Expert Number',
    yaxis_title='KL Divergence',
    bargap=0.2,
    bargroupgap=0.1,
    plot_bgcolor='white',  # Set plot background to white
    paper_bgcolor='white',  # Set paper background to white
    font=dict(family="Arial", size=14),
    xaxis=dict(showgrid=True, gridcolor='lightgrey'),  # Add light grid to x-axis
    yaxis=dict(showgrid=True, gridcolor='lightgrey'),  # Add light grid to y-axis
)

# Add a horizontal line for the mean KL divergence, spanning all bars
mean_kl = sum(kl_values) / len(kl_values)
fig.add_shape(
    type="line",
    x0=-1,
    y0=mean_kl,
    x1=len(experts),
    y1=mean_kl,
    line=dict(color="red", width=2, dash="dash"),
)

# Add annotation for mean KL divergence, just right of the last bar
fig.add_annotation(
    x=len(experts) * 1.02,
    y=mean_kl,
    text=f"Mean: {mean_kl:.4f}",
    showarrow=False,
    font=dict(size=12, color="red")
)

# Add annotation for total KL divergence, centred above the tallest bar
fig.add_annotation(
    x=len(experts) * 0.5,
    y=max(kl_values) * 1.1,
    text=f"Total KL Divergence: {kl_data['total_kl_divergence']:.4f}",
    showarrow=False,
    font=dict(size=16, color="darkblue")
)

# Highlight top 5 experts with highest KL divergence
top_5_indices = sorted(range(len(kl_values)), key=lambda i: kl_values[i], reverse=True)[:5]
for idx in top_5_indices:
    fig.add_annotation(
        x=experts[idx],
        y=kl_values[idx],
        text=f"Expert {experts[idx]}",
        showarrow=True,
        arrowhead=2,
        arrowsize=1,
        arrowwidth=2,
        arrowcolor="#636363",
        font=dict(size=10, color="black"),
        bgcolor="white",
        bordercolor="black",
        borderwidth=1,
    )

# Update size and add modebar
fig.update_layout(
    width=1000,
    height=600,
    modebar_add=["v1hovermode", "toggleSpikelines"]
)

# Show the plot
fig.show()

Print which expert each token was routed to for the given input (top-1 expert only).

In [32]:
# Relies on the `tokenizer` created in the model-loading cell above.

# Example sentence to split into tokens
input_text = "I deposited my paycheck at the bank."

# tokenize() returns the token strings (not ids); show them
tokens = tokenizer.tokenize(input_text)
print("Tokens:", tokens)


Tokens: ['I', 'Ġdeposited', 'Ġmy', 'Ġpay', 'check', 'Ġat', 'Ġthe', 'Ġbank', '.']
