In [2]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import os
from typing import Dict, List, Tuple
from collections import defaultdict

In [3]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # use_flash_attention_2=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model("Qwen/Qwen1.5-MoE-A2.7B")

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

In [4]:
def get_top_k_tokens(hidden_states: torch.Tensor, lm_head: torch.nn.Linear, tokenizer, k: int = 5) -> List[Tuple[str, float]]:
    """ get topk tokens from hidden states using lm head """
    with torch.no_grad():
        # Ensure hidden_states has at least 2 dimensions (batch_size, num_tokens, hidden_dim)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)  # Add batch dimension
        logits = lm_head(hidden_states)  # (batch_size, num_tokens, vocab_size)
    
    # Get top-k tokens
    scores, token_ids = torch.topk(logits, k=k, dim=-1)  # (batch_size, num_tokens, k)
    
    # Decode tokens and collect results for each position
    results = []
    for pos in range(scores.size(1)):  # Iterate over token positions
        pos_results = []
        for i in range(k):
            token = tokenizer.decode(token_ids[0, pos, i])  # Decode token for this position
            score = scores[0, pos, i].item()  # Get score for this position
            pos_results.append((token, score))
        results.append(pos_results)
    
    return results

In [5]:
class DeepseekLayerAnalyzer:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.layer_outputs = defaultdict(list)
        self.moe_gate_outputs = defaultdict(list)
        self.moe_combined_outputs = defaultdict(list)
        self.expert_outputs = defaultdict(lambda: defaultdict(list))
        self.shared_expert_outputs = defaultdict(list)
        self.hooks = []
        
    def register_hooks(self):
        """Register hooks for layer outputs and MoE combination points"""
        
        def layer_output_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing layer outputs"""
                hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs
                self.layer_outputs[layer_idx].append(hidden_states.detach())
            return hook

        def moe_gate_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing MoE gate outputs before expert computation"""
                # Capture topk_idx, topk_weight, and aux_loss from gate outputs
                if isinstance(outputs, tuple):
                    topk_idx, topk_weight, aux_loss = outputs
                    self.moe_gate_outputs[layer_idx].append({
                        'topk_idx': topk_idx.detach(),
                        'topk_weight': topk_weight.detach(),
                        'aux_loss': aux_loss.detach() if aux_loss is not None else None
                    })
            return hook

        def expert_hook(layer_idx, expert_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing expert outputs"""
                # Get the latest gate outputs for this layer
                if not self.moe_gate_outputs[layer_idx]:
                    return print(f'no gate outputs for layer {layer_idx}')
            
                gate_data = self.moe_gate_outputs[layer_idx][-1]
                
                # Handle 2D or 3D tensor shapes
                if len(gate_data['topk_idx'].shape) == 2:
                    batch_size = 1
                    seq_len, top_k = gate_data['topk_idx'].shape
                else:
                    batch_size, seq_len, top_k = gate_data['topk_idx'].shape
                
                # Get mask for tokens where this expert was selected
                expert_mask = (gate_data['topk_idx'] == expert_idx)                
                # Flatten and find positions where this expert was selected
                selected_positions = torch.nonzero(expert_mask, as_tuple=True)
                # If no tokens selected this expert, skip
                if selected_positions[0].numel() == 0:
                    return
                    
                # Get the actual inputs routed to this expert
                # Inputs[0] shape: (total_selected_tokens, hidden_dim)
                total_selected = inputs[0].shape[0] 
                # Validate we're processing the correct number of tokens
                expected_selected = expert_mask.sum().item()
                if total_selected != expected_selected:
                    print(f" expert {expert_idx} processed {total_selected} tokens but expected {expected_selected}")
                    return
                    
                # Get the full hidden states from outputs
                # outputs shape: (total_selected_tokens, hidden_dim)
                hidden_states = outputs
                if isinstance(outputs, tuple):
                    hidden_states = outputs[0]
                    
                # Record data for each selected position
                for pos_idx, pos in enumerate(selected_positions[0]):
                    token_data = {
                        'position': pos.item(),
                        'input': inputs[0][pos_idx].detach(),
                        'output': outputs[pos_idx].detach(),
                        'hidden_state': hidden_states[pos_idx].detach()  # Store full hidden state
                    }
                    
                    # Get the corresponding gate weight for this position
                    # Find which expert slot (in top_k) this expert was selected for this position
                    expert_slots = (gate_data['topk_idx'][pos.item()] == expert_idx).nonzero(as_tuple=True)[0]
                    if len(expert_slots) > 0:
                        token_data['gate_weight'] = gate_data['topk_weight'][pos.item()][expert_slots[0]].item()
                    
                    self.expert_outputs[layer_idx][expert_idx].append(token_data)
            return hook
        
        def shared_expert_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing shared expert outputs"""
                self.shared_expert_outputs[layer_idx].append({
                    'input': inputs[0].detach(),
                    'output': outputs.detach()
                })
            return hook

        def moe_combine_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing final combined MoE outputs"""
                # For DeepseekMoE, this captures the weighted sum of expert outputs
                self.moe_combined_outputs[layer_idx].append({
                    'combined_output': outputs.detach(),
                    'input': inputs[0].detach()  # Original input before MoE
                })
            return hook

        # Register hooks for each layer
        for layer_idx, layer in enumerate(self.model.model.layers):
            # Hook for layer output
            hook = layer.register_forward_hook(layer_output_hook(layer_idx))
            self.hooks.append(hook)
            
            # If it's an MoE layer, add MoE-specific hooks
            if hasattr(layer.mlp, 'experts'):
                # Hook for gate mechanism
                gate_hook = layer.mlp.gate.register_forward_hook(moe_gate_hook(layer_idx))
                self.hooks.append(gate_hook)
                
                # Hook for each expert
                for expert_idx, expert in enumerate(layer.mlp.experts):
                    expert_hook_fn = expert.register_forward_hook(expert_hook(layer_idx, expert_idx))
                    self.hooks.append(expert_hook_fn)

                # Hook for shared expert if it exists
                if hasattr(layer.mlp, 'shared_experts'):
                    shared_hook = layer.mlp.shared_experts.register_forward_hook(shared_expert_hook(layer_idx))
                    self.hooks.append(shared_hook)
                
                # Hook for final combined output
                combine_hook = layer.mlp.register_forward_hook(moe_combine_hook(layer_idx))
                self.hooks.append(combine_hook)

    def analyze_tokens(self, input_ids: torch.Tensor, return_hidden_states: bool = False) -> Dict:
        """ run inference and analyze tokens at each layer and expert combination point """

        self.moe_combined_outputs.clear()
        self.expert_outputs.clear()
        self.shared_expert_outputs.clear()
        
        # Forward pass
        with torch.no_grad():
            outputs = self.model(input_ids)
        
        results = {
            'layer_predictions': {},
            'moe_analysis': {},
            'hidden_states': {} if return_hidden_states else None
        }
        
        # Analyze layer outputs
        for layer_idx, outputs in self.layer_outputs.items():
            if not outputs:  # Skip if no outputs captured
                continue
            hidden_states = outputs[-1]  # Get last captured output
            
            # Get token predictions for this layer
            top_tokens = get_top_k_tokens(hidden_states, self.model.lm_head, self.tokenizer)
            results['layer_predictions'][layer_idx] = top_tokens
            
            if return_hidden_states:
                results['hidden_states'][f'layer_{layer_idx}'] = hidden_states
        
        # Analyze MoE layers
        for layer_idx in self.moe_gate_outputs.keys():
            if not self.moe_gate_outputs[layer_idx]:
                continue
                
            gate_data = self.moe_gate_outputs[layer_idx][-1]  # Get last captured data
            combined_data = self.moe_combined_outputs[layer_idx][-1]
            
            # Initialize predictions dictionary by position
            expert_predictions_by_pos = defaultdict(dict)
            expert_hidden_states_by_pos = defaultdict(dict)
            
            # Process expert outputs by position
            for expert_idx, data_list in self.expert_outputs[layer_idx].items():
                for data in data_list:
                    position = data['position']
                    # Get predictions for this expert's output at this position
                    predictions = get_top_k_tokens(
                        data['output'].unsqueeze(0),
                        self.model.lm_head,
                        self.tokenizer
                    )
                    expert_predictions_by_pos[position][expert_idx] = predictions[0]  # [0] because we only have one prediction set

                    # Store hidden states
                    expert_hidden_states_by_pos[position][expert_idx] = {
                    'hidden_state': data['hidden_state'].tolist(),
                    'gate_weight': data.get('gate_weight', None)
                }
            
            # Get predictions for shared expert if it exists
            if self.shared_expert_outputs[layer_idx]:
                shared_expert_predictions = get_top_k_tokens(
                    self.shared_expert_outputs[layer_idx][-1]['output'],
                    self.model.lm_head,
                    self.tokenizer
                )
            
            # Update experts_analysis dictionary to include hidden states
            experts_analysis = {
            'selected_experts': gate_data['topk_idx'].tolist(),
            'expert_weights': gate_data['topk_weight'].tolist(),
            'aux_loss': gate_data['aux_loss'].item() if gate_data['aux_loss'] is not None else None,
            'expert_predictions_by_position': dict(expert_predictions_by_pos),
            'expert_hidden_states_by_position': dict(expert_hidden_states_by_pos),
            'shared_expert_predictions': shared_expert_predictions if self.shared_expert_outputs[layer_idx] else None
        }
            
            # Get token predictions from combined output
            combined_tokens = get_top_k_tokens(
                combined_data['combined_output'], 
                self.model.lm_head,
                self.tokenizer
            )
            
            experts_analysis['combined_output_tokens'] = combined_tokens
            results['moe_analysis'][layer_idx] = experts_analysis

        return results
    
    def cleanup(self):
        """remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def analyze_deepseek_moe(model, tokenizer, input_text: str, return_hidden_states: bool = False):
    """ analyze DeepSeek MoE model behavior for given input text """
    analyzer = DeepseekLayerAnalyzer(model, tokenizer)
    analyzer.register_hooks()
    
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
    try:
        results = analyzer.analyze_tokens(input_ids, return_hidden_states=return_hidden_states)
        return results
    finally:
        analyzer.cleanup()  

In [6]:
prompt = """import pandas as pd

df = pd."""
analysis = analyze_deepseek_moe(
    model, 
    tokenizer,
    input_text=prompt
)

no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for layer 0
no gate outputs for 

AttributeError: 'tuple' object has no attribute 'detach'

In [6]:
# layer_idx = 1  # example layer
# token_pos = 4
# expert_idx = 3
# hidden_state_list = analysis['moe_analysis'][layer_idx]['expert_hidden_states_by_position'][token_pos][expert_idx]['hidden_state']
# print(f'hidden_state_list : {len(hidden_state_list)}')


In [7]:
def get_post_attn_ln_inputs(model, tokenizer, text):
    """places a hook on the post attention layernorm to retrieve its inputs"""
    # Store inputs in a dict mapping layer idx -> inputs
    post_attn_ln_inputs = {}
    hooks = []
    
    def hook_post_attn_ln(module, input, output, layer_idx):
        if layer_idx not in post_attn_ln_inputs:
            post_attn_ln_inputs[layer_idx] = []
        post_attn_ln_inputs[layer_idx].append([x.detach() for x in input])
    
    # Register hooks on post attention layernorm for each layer
    for i, layer in enumerate(model.model.layers):
        hooks.append(
            layer.post_attention_layernorm.register_forward_hook(
                lambda m, i, o, idx=i: hook_post_attn_ln(m, i, o, idx)
            )
        )
    
    try:
        # Run inference
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        model(input_ids)
        
        return post_attn_ln_inputs
        
    finally:
        # Clean up hooks
        for hook in hooks:
            hook.remove()

In [8]:
post_attn_ln_inputs = get_post_attn_ln_inputs(model, tokenizer, text=prompt)

In [9]:
def verify_moe_equation(model, tokenizer, layer_idx, text="the quick brown fox", tolerance=1e-5):
    """
    Verifies that MoE activations add up correctly:
    final_layer_output = residual_stream_hidden_states_before_experts + moe_output
    """
    analyzer = DeepseekLayerAnalyzer(model, tokenizer)
    analyzer.register_hooks()
    
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    try:
        # Run the analysis to populate data structures
        analyzer.analyze_tokens(input_ids)
        
        # Get moe outputs and residual
        if layer_idx not in analyzer.moe_combined_outputs or not analyzer.moe_combined_outputs[layer_idx]:
            print("No combined output data found")
            return False

        residual = post_attn_ln_inputs[layer_idx][-1][0][0]
        print(f"Residual shape: {residual.shape}")
        
        # Get the combined MoE output directly from the hook
        moe_output = analyzer.moe_combined_outputs[layer_idx][-1]['combined_output'][0]
        print(f"MoE output shape: {moe_output.shape}")

        # Get the final layer output
        final_layer_output = analyzer.layer_outputs[layer_idx][-1][0]
        print(f"Final layer output shape: {final_layer_output.shape}")
        
        # Full equation: final_output = residual + moe_output
        lhs = residual + moe_output
        print(f"lhs shape: {lhs.shape}")
        print(f'lhs {lhs}')
        rhs = final_layer_output
        print(f"rhs shape: {rhs.shape}")
        print(f'rhs {rhs}')
        # Check if close
        is_close = torch.allclose(lhs, rhs, rtol=tolerance, atol=tolerance)
        print(f'is_close {is_close}')
        if not is_close:
            max_diff = (lhs - rhs).abs().max().item()
            print(f"Maximum difference: {max_diff:.6f}")
            print("\nDetailed component analysis:")
            print(f"Residual max value: {residual.abs().max().item():.6f}")
            print(f"MoE output max value: {moe_output.abs().max().item():.6f}")
            print(f"Final output max value: {final_layer_output.abs().max().item():.6f}")
            
        return is_close
        
    finally:
        analyzer.cleanup()

In [10]:
# # Verify equation for a specific layer
# is_valid = verify_moe_equation(model, tokenizer, layer_idx=26, )
# print(f"MoE equation holds: {is_valid}")

In [11]:
# Access layer predictions
for layer_idx, preds in analysis['layer_predictions'].items():
    print(f"Layer {layer_idx} predictions:", preds)

# Access MoE analysis
for layer_idx, moe_data in analysis['moe_analysis'].items():
    print(f"\nMoE Layer {layer_idx}:")
    print("Selected experts:", moe_data['selected_experts'])
    print("Expert weights:", moe_data['expert_weights'])
    print("Combined output tokens:", moe_data['combined_output_tokens'])
    print("Expert predictions by position:", moe_data['expert_predictions_by_position'])
    print("Shared expert predictions:", moe_data['shared_expert_predictions'])

Layer 0 predictions: [[('开辟', 1.3857421875), ('wk', 1.361328125), ('始终坚持', 1.2900390625), ('晗', 1.2763671875), ('固执', 1.2021484375)], [('懿', 1.1943359375), ('吸取', 1.0966796875), ('KR', 1.046875), ('HCI', 1.0390625), ('摧', 1.0341796875)], [('不惜', 1.2099609375), ('GO', 1.1455078125), ('VF', 1.1416015625), ('дата', 1.0908203125), ('^-_', 1.0751953125)], [('UTO', 0.7080078125), ('wk', 0.69091796875), ('加快推进', 0.67626953125), ('建立起', 0.67041015625), ('巍', 0.65478515625)], [(' ', 1.232421875), (' (', 1.16796875), (' n', 1.1494140625), ('-', 1.0810546875), (' q', 1.0126953125)], [('开辟', 0.93310546875), ('固执', 0.91552734375), ('WN', 0.91259765625), ('GK', 0.91064453125), ('wk', 0.904296875)], [('GK', 0.90625), ('固执', 0.90576171875), ('开辟', 0.890625), ('wk', 0.880859375), ('WN', 0.87939453125)], [(' CWE', 1.0224609375), (' NV', 0.96923828125), ('NV', 0.94970703125), ('© ', 0.9296875), ('HV', 0.91162109375)], [('ADO', 0.85205078125), ('�乐', 0.84033203125), ('吸取', 0.763671875), ('ASA', 0.75634765

In [12]:
def get_layer_predictions_for_token(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['layer_predictions']:
        raise ValueError(f"Layer {layer_idx} not found in the results.")

    # Get the layer predictions for the specified token position
    layer_predictions = results['layer_predictions'][layer_idx][token_pos]
    
    # Return top 5 predictions
    return layer_predictions[:5]

def get_shared_expert_predictions_for_token(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the results.")

    # Get the MoE analysis for the specified layer
    moe_analysis = results['moe_analysis'][layer_idx]

    # Get the shared expert predictions for the specified token position
    shared_predictions = moe_analysis['shared_expert_predictions'][token_pos]
    
    # Return top 5 predictions
    return shared_predictions[:5]

def get_expert_preds(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the MoE analysis results.")
        
    # Get the MoE analysis for the specified layer
    expert_preds = []
    moe_analysis = results['moe_analysis'][layer_idx]
    expert_predictions = moe_analysis['expert_predictions_by_position']
    expert_toks = expert_predictions[token_pos]

    for expert_idx, preds in expert_toks.items():
        print(f'expert {expert_idx} : {preds}')
        expert_preds.append((int(expert_idx), preds))

    expert_preds = sorted(expert_preds, key=lambda x: x[0])
        
    return expert_preds

In [13]:
def get_highest_pred(results: dict, model, tokenizer, layer_idx: int, token_pos: int, k: int = 5) -> list:
    """gets predictions from expert with highest weight after adding residual """
    # Ensure layer exists in results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the MoE analysis results.")
        
    # Get MoE analysis for the layer
    moe_analysis = results['moe_analysis'][layer_idx]
    
    # Get expert weights for this token
    expert_weights = {}
    selected_experts = moe_analysis['selected_experts'][token_pos]
    expert_weights_list = moe_analysis['expert_weights'][token_pos]
    
    if isinstance(selected_experts, int):
        # Handle case where only one expert is selected
        expert_weights[selected_experts] = 1.0
    else:
        # Handle case where multiple experts are selected with weights
        for expert_idx, weight in zip(selected_experts, expert_weights_list):
            expert_weights[expert_idx] = weight
                
    # Find expert with highest weight
    # max_weight_expert = max(expert_weights.items(), key=lambda x: x[1])[0]
    max_weight_expert = sorted(expert_weights.items(), key=lambda x: x[1], reverse=True)[k][0]
    print(f'max_weight_expert {max_weight_expert}')
    
    # Get hidden state from this expert
    expert_hidden_state = moe_analysis['expert_hidden_states_by_position'][token_pos][max_weight_expert]['hidden_state']
    
    # Get residual stream
    residual = post_attn_ln_inputs[layer_idx][-1][0][0][token_pos]
    
    # Convert expert hidden state to tensor if it's a list
    if isinstance(expert_hidden_state, list):
        expert_hidden_state = torch.tensor(expert_hidden_state, dtype=torch.float16)
    
    # Add residual to expert hidden state
    combined = residual + expert_hidden_state
    
    # Get logits
    logits = model.lm_head(combined.unsqueeze(0))
    
    # Get top 5 predictions
    topk = torch.topk(logits[0], k=5)
    scores = topk.values.tolist()
    tokens = topk.indices.tolist()
    
    # Convert to token strings
    predictions = []
    for score, token in zip(scores, tokens):
        token_str = tokenizer.decode(token)
        predictions.append((token_str, score))
        
    return predictions

In [14]:
layer_idx = 25
token_pos = 10
k = 5 # which max expert to use
get_highest_pred(analysis, model, tokenizer, layer_idx=layer_idx, token_pos=token_pos, k=k)

max_weight_expert 24


[('read', 25.25),
 ('DataFrame', 17.96875),
 (' read', 16.71875),
 ('READ', 14.4609375),
 ('reads', 14.0625)]

In [15]:
layer_idx = 25
token_pos = 10
layer_preds = get_layer_predictions_for_token(analysis, layer_idx=layer_idx, token_pos=token_pos)
shared_preds = get_shared_expert_predictions_for_token(analysis, layer_idx=layer_idx, token_pos=token_pos)
expert_preds = get_expert_preds(analysis, layer_idx=layer_idx, token_pos=token_pos)

print(f'layer_preds : {layer_preds}')
print(f'shared_preds : {shared_preds}')
print(f'expert_preds : {expert_preds}')

expert 14 : [('opter', 0.7607421875), ('ктри', 0.73974609375), ('opters', 0.6318359375), ('tric', 0.60693359375), ('sort', 0.59033203125)]
expert 24 : [('ities', 0.288330078125), ('Audi', 0.286865234375), ('7', 0.284423828125), ('百', 0.2705078125), ('0', 0.27001953125)]
expert 36 : [('RET', 0.43359375), (' lim', 0.38525390625), ('acad', 0.371826171875), ('icast', 0.362060546875), ('Oper', 0.35498046875)]
expert 41 : [(' involving', 1.830078125), ('ostes', 1.6259765625), ("'>&", 1.509765625), ('adoc', 1.50390625), ('wau', 1.5029296875)]
expert 43 : [('cciones', 0.274169921875), ('析', 0.270751953125), (' зи', 0.2705078125), (' compress', 0.26806640625), ('排', 0.265380859375)]
expert 49 : [('aulay', 0.291259765625), ('ardown', 0.263671875), ('fer', 0.23193359375), ('backer', 0.2159423828125), ('退', 0.2147216796875)]
layer_preds : [('read', 25.234375), ('DataFrame', 18.109375), (' read', 16.75), ('io', 14.6640625), ('READ', 14.46875)]
shared_preds : [(',', 4.7734375), (' ', 4.31640625), ('

In [16]:
def create_logit_lens_viz(analysis_results, token_pos=0):
    """Creates a heatmap visualization of predictions across model layers.
    
    Args:
        analysis_results: Dictionary containing layer predictions and MoE analysis
        token_pos: Token position to analyze (default 0)
    """
    # Initialize data structure
    n_layers = max(analysis_results['layer_predictions'].keys()) + 1
    x_positions = []  # Will store x-axis positions
    y_positions = []  # Will store layer numbers
    values = []      # Will store prediction scores
    hover_texts = [] # Will store hover text with top 5 predictions

    # Process each layer
    for layer in range(n_layers):
        # Layer predictions
        layer_preds = get_layer_predictions_for_token(analysis_results, layer, token_pos)
        x_positions.append(0)  # Layer output is at x=0
        y_positions.append(layer)
        values.append(layer_preds[0][1])  # Use score from top prediction
        hover_text = "<br>".join([f"{pred[0]}: {pred[1]:.3f}" for pred in layer_preds])
        hover_texts.append(f"Layer {layer} Output:<br>{hover_text}")

        # Expert predictions if available
        if layer in analysis_results['moe_analysis']:
            expert_preds = get_expert_preds(analysis_results, layer, token_pos)
            moe_data = analysis_results['moe_analysis'][layer]
            
            # Create expert weight lookup for this layer's position
            expert_weights = {}
            for pos_idx, experts in enumerate(moe_data['selected_experts']):
                if pos_idx == token_pos:  # Only for our token position
                    for expert_idx, weight in zip(experts, moe_data['expert_weights'][pos_idx]):
                        expert_weights[expert_idx] = weight
            
            for expert_idx, preds in expert_preds:
                x_positions.append(expert_idx + 1)  # Offset by 1 to account for layer output
                y_positions.append(layer)
                values.append(preds[0][1])  # Use score from top prediction
                
                # Add expert weight to hover text if available
                weight_text = ""
                if expert_idx in expert_weights:
                    weight_text = f"<br>Expert Weight: {expert_weights[expert_idx]:.3f}"
                    
                hover_text = "<br>".join([f"{pred[0]}: {pred[1]:.3f}" for pred in preds])
                hover_texts.append(f"Layer {layer} Expert {expert_idx}:{weight_text}<br>{hover_text}")

            # Shared expert predictions
            shared_preds = get_shared_expert_predictions_for_token(analysis_results, layer, token_pos)
            x_positions.append(65)  # Put shared experts after individual experts
            y_positions.append(layer)
            values.append(shared_preds[0][1])  # Use score from top prediction
            hover_text = "<br>".join([f"{pred[0]}: {pred[1]:.3f}" for pred in shared_preds])
            hover_texts.append(f"Layer {layer} Shared Experts:<br>{hover_text}")

    # Create heatmap with larger square markers and no grid lines
    fig = go.Figure(data=go.Scatter(
        x=x_positions,
        y=y_positions,
        mode='markers',
        marker=dict(
            size=20,  # Increased marker size
            symbol='square',
            color=values,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title='logit value')
        ),
        text=hover_texts,
        hoverinfo='text'
    ))

    # Update layout with no grid lines
    fig.update_layout(
        title='Expert Contributions and Predictions for Token Position ' + str(token_pos),
        xaxis_title='layer output | routed experts | shared experts',
        yaxis_title='Layer',
        xaxis=dict(
            ticktext=['l_out'] + [f' {i}' for i in range(64)] + ['shared experts'],
            tickvals=[0] + list(range(1, 65)) + [65],
            tickangle=45,
            showgrid=False  # Removed grid lines
        ),
        yaxis=dict(
            autorange='reversed',  # Put layer 0 at the top
            showgrid=False  # Removed grid lines
        ),
        height=1000,
        width=1800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white')
    )

    return fig

In [17]:
create_logit_lens_viz(analysis, token_pos=10)

expert 6 : [('pless', 0.03759765625), ('vill', 0.035919189453125), ('дове', 0.035003662109375), ('blockList', 0.033660888671875), ('elev', 0.0328369140625)]
expert 7 : [('spaces', 0.0430908203125), ('uega', 0.04052734375), ('ито', 0.039093017578125), ('actly', 0.03765869140625), ('noma', 0.03717041015625)]
expert 21 : [('负', 0.06341552734375), ('thr', 0.0628662109375), ('TB', 0.06158447265625), ('хай', 0.060455322265625), ('дни', 0.06005859375)]
expert 46 : [('ovent', 0.053741455078125), ('yses', 0.052398681640625), ('нев', 0.0517578125), ('itures', 0.050262451171875), ('palette', 0.048797607421875)]
expert 49 : [('rc', 0.1719970703125), ('autoritat', 0.169677734375), ('rita', 0.1673583984375), ('eye', 0.1671142578125), ('erca', 0.1668701171875)]
expert 53 : [('gem', 0.040435791015625), ('indi', 0.039764404296875), ('inada', 0.038116455078125), ('速', 0.03741455078125), (' hers', 0.03704833984375)]
expert 2 : [('snap', 0.0219268798828125), ('heads', 0.021514892578125), ('head', 0.021087