In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import os
from typing import Dict, List, Tuple
from collections import defaultdict

In [None]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # use_flash_attention_2=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

In [3]:
def get_top_k_tokens(hidden_states: torch.Tensor, lm_head: torch.nn.Linear, tokenizer, k: int = 5) -> List[Tuple[str, float]]:
    """ get topk tokens from hidden states using lm head """
    with torch.no_grad():
        # Ensure hidden_states has at least 2 dimensions (batch_size, num_tokens, hidden_dim)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)  # Add batch dimension
            # Apply RMS normalization like in DeepseekRMSNorm
            variance = hidden_states.pow(2).mean(-1, keepdim=True) 
            hidden_states = hidden_states * torch.rsqrt(variance + 1e-6)

        logits = lm_head(hidden_states)  # (batch_size, num_tokens, vocab_size)
    
    # Get top-k tokens
    scores, token_ids = torch.topk(logits, k=k, dim=-1)  # (batch_size, num_tokens, k)
    
    # Decode tokens and collect results for each position
    results = []
    for pos in range(scores.size(1)):  # Iterate over token positions
        pos_results = []
        for i in range(k):
            token = tokenizer.decode(token_ids[0, pos, i])  # Decode token for this position
            score = scores[0, pos, i].item()  # Get score for this position
            pos_results.append((token, score))
        results.append(pos_results)
    
    return results

In [4]:
class DeepseekLayerAnalyzer:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.layer_outputs = defaultdict(list)
        self.moe_gate_outputs = defaultdict(list)
        self.moe_combined_outputs = defaultdict(list)
        self.expert_outputs = defaultdict(lambda: defaultdict(list))
        self.shared_expert_outputs = defaultdict(list)
        self.hooks = []
        
    def register_hooks(self):
        """Register hooks for layer outputs and MoE combination points"""
        
        def layer_output_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing layer outputs"""
                hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs
                self.layer_outputs[layer_idx].append(hidden_states.detach())
            return hook

        def moe_gate_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing MoE gate outputs before expert computation"""
                # Capture topk_idx, topk_weight, and aux_loss from gate outputs
                if isinstance(outputs, tuple):
                    topk_idx, topk_weight, aux_loss = outputs
                    self.moe_gate_outputs[layer_idx].append({
                        'topk_idx': topk_idx.detach(),
                        'topk_weight': topk_weight.detach(),
                        'aux_loss': aux_loss.detach() if aux_loss is not None else None
                    })
            return hook

        def expert_hook(layer_idx, expert_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing expert outputs"""
                # Get the latest gate outputs for this layer
                if not self.moe_gate_outputs[layer_idx]:
                    return print(f'no gate outputs for layer {layer_idx}')
            
                gate_data = self.moe_gate_outputs[layer_idx][-1]
                
                # Handle 2D or 3D tensor shapes
                if len(gate_data['topk_idx'].shape) == 2:
                    batch_size = 1
                    seq_len, top_k = gate_data['topk_idx'].shape
                else:
                    batch_size, seq_len, top_k = gate_data['topk_idx'].shape
                
                # Get mask for tokens where this expert was selected
                expert_mask = (gate_data['topk_idx'] == expert_idx)                
                # Flatten and find positions where this expert was selected
                selected_positions = torch.nonzero(expert_mask, as_tuple=True)
                # If no tokens selected this expert, skip
                if selected_positions[0].numel() == 0:
                    return
                    
                # Get the actual inputs routed to this expert
                # Inputs[0] shape: (total_selected_tokens, hidden_dim)
                total_selected = inputs[0].shape[0] 
                # Validate we're processing the correct number of tokens
                expected_selected = expert_mask.sum().item()
                if total_selected != expected_selected:
                    print(f" expert {expert_idx} processed {total_selected} tokens but expected {expected_selected}")
                    return
                    
                # Get the full hidden states from outputs
                # outputs shape: (total_selected_tokens, hidden_dim)
                hidden_states = outputs
                if isinstance(outputs, tuple):
                    hidden_states = outputs[0]
                    
                # Record data for each selected position
                for pos_idx, pos in enumerate(selected_positions[0]):
                    token_data = {
                        'position': pos.item(),
                        'input': inputs[0][pos_idx].detach(),
                        'output': outputs[pos_idx].detach(),
                        'hidden_state': hidden_states[pos_idx].detach()  # Store full hidden state
                    }
                    
                    # Get the corresponding gate weight for this position
                    # Find which expert slot (in top_k) this expert was selected for this position
                    expert_slots = (gate_data['topk_idx'][pos.item()] == expert_idx).nonzero(as_tuple=True)[0]
                    if len(expert_slots) > 0:
                        token_data['gate_weight'] = gate_data['topk_weight'][pos.item()][expert_slots[0]].item()
                    
                    self.expert_outputs[layer_idx][expert_idx].append(token_data)
            return hook
        
        def shared_expert_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing shared expert outputs"""
                self.shared_expert_outputs[layer_idx].append({
                    'input': inputs[0].detach(),
                    'output': outputs.detach()
                })
            return hook

        def moe_combine_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing final combined MoE outputs"""
                # For DeepseekMoE, this captures the weighted sum of expert outputs
                self.moe_combined_outputs[layer_idx].append({
                    'combined_output': outputs.detach(),
                    'input': inputs[0].detach()  # Original input before MoE
                })
            return hook

        # Register hooks for each layer
        for layer_idx, layer in enumerate(self.model.model.layers):
            # Hook for layer output
            hook = layer.register_forward_hook(layer_output_hook(layer_idx))
            self.hooks.append(hook)
            
            # If it's an MoE layer, add MoE-specific hooks
            if hasattr(layer.mlp, 'experts'):
                # Hook for gate mechanism
                gate_hook = layer.mlp.gate.register_forward_hook(moe_gate_hook(layer_idx))
                self.hooks.append(gate_hook)
                
                # Hook for each expert
                for expert_idx, expert in enumerate(layer.mlp.experts):
                    expert_hook_fn = expert.register_forward_hook(expert_hook(layer_idx, expert_idx))
                    self.hooks.append(expert_hook_fn)

                # Hook for shared expert if it exists
                if hasattr(layer.mlp, 'shared_experts'):
                    shared_hook = layer.mlp.shared_experts.register_forward_hook(shared_expert_hook(layer_idx))
                    self.hooks.append(shared_hook)
                
                # Hook for final combined output
                combine_hook = layer.mlp.register_forward_hook(moe_combine_hook(layer_idx))
                self.hooks.append(combine_hook)

    def analyze_tokens(self, input_ids: torch.Tensor, return_hidden_states: bool = False) -> Dict:
        """ run inference and analyze tokens at each layer and expert combination point """

        self.moe_combined_outputs.clear()
        self.expert_outputs.clear()
        self.shared_expert_outputs.clear()
        
        # Forward pass
        with torch.no_grad():
            outputs = self.model(input_ids)
        
        results = {
            'layer_predictions': {},
            'moe_analysis': {},
            'hidden_states': {} if return_hidden_states else None
        }
        
        # Analyze layer outputs
        for layer_idx, outputs in self.layer_outputs.items():
            if not outputs:  # Skip if no outputs captured
                continue
            hidden_states = outputs[-1]  # Get last captured output
            
            
            
            
            # Get token predictions for this layer
            top_tokens = get_top_k_tokens(hidden_states, self.model.lm_head, self.tokenizer)
            results['layer_predictions'][layer_idx] = top_tokens
            
            if return_hidden_states:
                results['hidden_states'][f'layer_{layer_idx}'] = hidden_states
        
        # Analyze MoE layers
        for layer_idx in self.moe_gate_outputs.keys():
            if not self.moe_gate_outputs[layer_idx]:
                continue
                
            gate_data = self.moe_gate_outputs[layer_idx][-1]  # Get last captured data
            combined_data = self.moe_combined_outputs[layer_idx][-1]
            
            # Initialize predictions dictionary by position
            expert_predictions_by_pos = defaultdict(dict)
            expert_hidden_states_by_pos = defaultdict(dict)
            
            # Process expert outputs by position
            for expert_idx, data_list in self.expert_outputs[layer_idx].items():
                for data in data_list:
                    position = data['position']
                    # Get predictions for this expert's output at this position
                    predictions = get_top_k_tokens(
                        data['output'].unsqueeze(0),
                        self.model.lm_head,
                        self.tokenizer
                    )
                    expert_predictions_by_pos[position][expert_idx] = predictions[0]  # [0] because we only have one prediction set

                    # Store hidden states
                    expert_hidden_states_by_pos[position][expert_idx] = {
                    'hidden_state': data['hidden_state'].tolist(),
                    'gate_weight': data.get('gate_weight', None)
                }
            
            # Get predictions for shared expert if it exists
            if self.shared_expert_outputs[layer_idx]:
                shared_expert_predictions = get_top_k_tokens(
                    self.shared_expert_outputs[layer_idx][-1]['output'],
                    self.model.lm_head,
                    self.tokenizer
                )
            
            # Update experts_analysis dictionary to include hidden states
            experts_analysis = {
            'selected_experts': gate_data['topk_idx'].tolist(),
            'expert_weights': gate_data['topk_weight'].tolist(),
            'aux_loss': gate_data['aux_loss'].item() if gate_data['aux_loss'] is not None else None,
            'expert_predictions_by_position': dict(expert_predictions_by_pos),
            'expert_hidden_states_by_position': dict(expert_hidden_states_by_pos),
            'shared_expert_predictions': shared_expert_predictions if self.shared_expert_outputs[layer_idx] else None
        }
            
            # Get token predictions from combined output
            combined_tokens = get_top_k_tokens(
                combined_data['combined_output'], 
                self.model.lm_head,
                self.tokenizer
            )
            
            experts_analysis['combined_output_tokens'] = combined_tokens
            results['moe_analysis'][layer_idx] = experts_analysis

        return results
    
    def cleanup(self):
        """remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def analyze_deepseek_moe(model, tokenizer, input_text: str, return_hidden_states: bool = False):
    """ analyze DeepSeek MoE model behavior for given input text """
    analyzer = DeepseekLayerAnalyzer(model, tokenizer)
    analyzer.register_hooks()
    
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
    try:
        results = analyzer.analyze_tokens(input_ids, return_hidden_states=return_hidden_states)
        return results
    finally:
        analyzer.cleanup()  

In [5]:
prompt = '''One might expect language modeling performance to depend on model architecture, the size of neural models, the computing power used to train them, and the data available for this'''

analysis = analyze_deepseek_moe(
    model, 
    tokenizer,
    input_text=prompt
)

In [None]:
input_ids = tokenizer(text="venir", return_tensors="pt",).input_ids
print(f'input_ids : {input_ids}')

In [7]:
# layer_idx = 1  # example layer
# token_pos = 4
# expert_idx = 3
# hidden_state_list = analysis['moe_analysis'][layer_idx]['expert_hidden_states_by_position'][token_pos][expert_idx]['hidden_state']
# print(f'hidden_state_list : {len(hidden_state_list)}')


In [8]:
def get_post_attn_ln_inputs(model, tokenizer, text):
    """places a hook on the post attention layernorm to retrieve its inputs"""
    # Store inputs in a dict mapping layer idx -> inputs
    post_attn_ln_inputs = {}
    hooks = []
    
    def hook_post_attn_ln(module, input, output, layer_idx):
        if layer_idx not in post_attn_ln_inputs:
            post_attn_ln_inputs[layer_idx] = []
        post_attn_ln_inputs[layer_idx].append([x.detach() for x in input])
    
    # Register hooks on post attention layernorm for each layer
    for i, layer in enumerate(model.model.layers):
        hooks.append(
            layer.post_attention_layernorm.register_forward_hook(
                lambda m, i, o, idx=i: hook_post_attn_ln(m, i, o, idx)
            )
        )
    
    try:
        # Run inference
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        model(input_ids)
        
        return post_attn_ln_inputs
        
    finally:
        # Clean up hooks
        for hook in hooks:
            hook.remove()

In [9]:
post_attn_ln_inputs = get_post_attn_ln_inputs(model, tokenizer, text=prompt)

In [10]:
def verify_moe_equation(model, tokenizer, layer_idx, text="the quick brown fox", tolerance=1e-5):
    """
    Verifies that MoE activations add up correctly:
    final_layer_output = residual_stream_hidden_states_before_experts + moe_output
    """
    analyzer = DeepseekLayerAnalyzer(model, tokenizer)
    analyzer.register_hooks()
    
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    try:
        # Run the analysis to populate data structures
        analyzer.analyze_tokens(input_ids)
        # Get moe outputs and residual
        if layer_idx not in analyzer.moe_combined_outputs or not analyzer.moe_combined_outputs[layer_idx]:
            print("No combined output data found")
            return False

        residual = post_attn_ln_inputs[layer_idx][-1][0][0]
        print(f"Residual shape: {residual.shape}")
        
        # Get the combined MoE output directly from the hook
        moe_output = analyzer.moe_combined_outputs[layer_idx][-1]['combined_output'][0]
        print(f"MoE output shape: {moe_output.shape}")

        # Get the final layer output
        final_layer_output = analyzer.layer_outputs[layer_idx][-1][0]
        print(f"Final layer output shape: {final_layer_output.shape}")
        
        # Full equation: final_output = residual + moe_output
        lhs = residual + moe_output
        print(f"lhs shape: {lhs.shape}")
        print(f'lhs {lhs}')
        rhs = final_layer_output
        print(f"rhs shape: {rhs.shape}")
        print(f'rhs {rhs}')
        # Check if close
        is_close = torch.allclose(lhs, rhs, rtol=tolerance, atol=tolerance)
        print(f'is_close {is_close}')
        if not is_close:
            max_diff = (lhs - rhs).abs().max().item()
            print(f"Maximum difference: {max_diff:.6f}")
            print("\nDetailed component analysis:")
            print(f"Residual max value: {residual.abs().max().item():.6f}")
            print(f"MoE output max value: {moe_output.abs().max().item():.6f}")
            print(f"Final output max value: {final_layer_output.abs().max().item():.6f}")
            
        return is_close
        
    finally:
        analyzer.cleanup()

In [11]:
# # Verify equation for a specific layer
# is_valid = verify_moe_equation(model, tokenizer, layer_idx=26, )
# print(f"MoE equation holds: {is_valid}")

In [None]:
# Access layer predictions
for layer_idx, preds in analysis['layer_predictions'].items():
    print(f"Layer {layer_idx} predictions:", preds)

# Access MoE analysis
for layer_idx, moe_data in analysis['moe_analysis'].items():
    print(f"\nMoE Layer {layer_idx}:")
    print("Selected experts:", moe_data['selected_experts'])
    print("Expert weights:", moe_data['expert_weights'])
    print("Combined output tokens:", moe_data['combined_output_tokens'])
    print("Expert predictions by position:", moe_data['expert_predictions_by_position'])
    print("Shared expert predictions:", moe_data['shared_expert_predictions'])

In [13]:
def get_layer_predictions_for_token(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['layer_predictions']:
        raise ValueError(f"Layer {layer_idx} not found in the results.")

    # Get the layer predictions for the specified token position
    layer_predictions = results['layer_predictions'][layer_idx][token_pos]
    
    # Return top 5 predictions
    return layer_predictions[:5]

def get_shared_expert_predictions_for_token(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the results.")

    # Get the MoE analysis for the specified layer
    moe_analysis = results['moe_analysis'][layer_idx]

    # Get the shared expert predictions for the specified token position
    shared_predictions = moe_analysis['shared_expert_predictions'][token_pos]
    
    # Return top 5 predictions
    return shared_predictions[:5]

def get_expert_preds(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the MoE analysis results.")
        
    # Get the MoE analysis for the specified layer
    expert_preds = []
    moe_analysis = results['moe_analysis'][layer_idx]
    expert_predictions = moe_analysis['expert_predictions_by_position']
    expert_toks = expert_predictions[token_pos]

    for expert_idx, preds in expert_toks.items():
        print(f'expert {expert_idx} : {preds}')
        expert_preds.append((int(expert_idx), preds))

    expert_preds = sorted(expert_preds, key=lambda x: x[0])
        
    return expert_preds

In [14]:
def get_highest_pred(results: dict, model, tokenizer, layer_idx: int, token_pos: int, k: int = 5) -> list:
    """gets predictions from expert with highest weight after adding residual """
    # Ensure layer exists in results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the MoE analysis results.")
        
    # Get MoE analysis for the layer
    moe_analysis = results['moe_analysis'][layer_idx]
    
    # Get expert weights for this token
    expert_weights = {}
    selected_experts = moe_analysis['selected_experts'][token_pos]
    expert_weights_list = moe_analysis['expert_weights'][token_pos]
    
    if isinstance(selected_experts, int):
        # Handle case where only one expert is selected
        expert_weights[selected_experts] = 1.0
    else:
        # Handle case where multiple experts are selected with weights
        for expert_idx, weight in zip(selected_experts, expert_weights_list):
            expert_weights[expert_idx] = weight
                
    # Find expert with highest weight
    # max_weight_expert = max(expert_weights.items(), key=lambda x: x[1])[0]
    max_weight_expert = sorted(expert_weights.items(), key=lambda x: x[1], reverse=True)[k][0]
    print(f'max_weight_expert {max_weight_expert}')
    
    # Get hidden state from this expert
    expert_hidden_state = moe_analysis['expert_hidden_states_by_position'][token_pos][max_weight_expert]['hidden_state']
    
    # Get residual stream
    residual = post_attn_ln_inputs[layer_idx][-1][0][0][token_pos]
    
    # Convert expert hidden state to tensor if it's a list
    if isinstance(expert_hidden_state, list):
        expert_hidden_state = torch.tensor(expert_hidden_state, dtype=torch.float16)
    
    # Add residual to expert hidden state
    combined = residual + expert_hidden_state
    
    # Get logits
    logits = model.lm_head(combined.unsqueeze(0))
    
    # Get top 5 predictions
    topk = torch.topk(logits[0], k=5)
    scores = topk.values.tolist()
    tokens = topk.indices.tolist()
    
    # Convert to token strings
    predictions = []
    for score, token in zip(scores, tokens):
        token_str = tokenizer.decode(token)
        predictions.append((token_str, score))
        
    return predictions

In [15]:
def get_highest_pred_combined(results: dict, model, tokenizer, res_stream: bool, layer_idx: int, token_pos: int, k: int = 6) -> list:
    """gets predictions from combining top k experts weighted by their router weights, excluding shared experts"""
    # Ensure layer exists in results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the MoE analysis results.")
        
    # Get MoE analysis for the layer
    moe_analysis = results['moe_analysis'][layer_idx]
    
    # Get expert weights for this token
    expert_weights = {}
    selected_experts = moe_analysis['selected_experts'][token_pos]
    expert_weights_list = moe_analysis['expert_weights'][token_pos]
    
    if isinstance(selected_experts, int):
        # Handle case where only one expert is selected
        expert_weights[selected_experts] = 1.0
    else:
        # Handle case where multiple experts are selected with weights
        for expert_idx, weight in zip(selected_experts, expert_weights_list):
            # Skip shared experts
            if expert_idx < 0:  # Shared experts typically have negative indices
                continue
            expert_weights[expert_idx] = weight
                
    # Get top k experts by weight
    top_k_experts = sorted(expert_weights.items(), key=lambda x: x[1], reverse=True)[:k]
    print(f'Using top {k} experts: {[expert[0] for expert in top_k_experts]}')
    
    # Initialize combined hidden state
    combined_hidden_state = None
    
    # Combine hidden states from top k experts, weighted by their router weights
    for expert_idx, weight in top_k_experts:
        expert_hidden_state = moe_analysis['expert_hidden_states_by_position'][token_pos][expert_idx]['hidden_state']
        
        # Convert expert hidden state to tensor if it's a list
        if isinstance(expert_hidden_state, list):
            expert_hidden_state = torch.tensor(expert_hidden_state, dtype=torch.float16)
            
        # Weight the expert's hidden state by its router weight
        weighted_hidden_state = expert_hidden_state * weight
        
        if combined_hidden_state is None:
            combined_hidden_state = weighted_hidden_state
        else:
            combined_hidden_state += weighted_hidden_state
    
    # to include residual stream in the combined hidden state
    if res_stream == True:
        # Get residual stream
        residual = post_attn_ln_inputs[layer_idx][-1][0][0][token_pos]
        
        # Add residual to combined expert hidden states
        combined = combined_hidden_state + residual
    else:
        combined = combined_hidden_state
    
    # Get logits
    logits = model.lm_head(combined.unsqueeze(0))
    
    # Get top 5 predictions
    topk = torch.topk(logits[0], k=5)
    scores = topk.values.tolist()
    tokens = topk.indices.tolist()
    
    # Convert to token strings
    predictions = []
    for score, token in zip(scores, tokens):
        token_str = tokenizer.decode(token)
        predictions.append((token_str, score))
        
    return predictions

In [41]:
n = len(tokenizer.encode(prompt))
print(f'n : {n}')
layer_idx = 16
token_pos = n-1
k =4 # which max expert to use
get_highest_pred_combined(analysis, model, tokenizer, res_stream = False, layer_idx=layer_idx, token_pos=token_pos, k=k)

n : 33
Using top 4 experts: [42, 40, 21, 63]


[('政', 0.0418701171875),
 ('rivia', 0.03973388671875),
 ('icacy', 0.03839111328125),
 ('ilar', 0.0374755859375),
 (' Grac', 0.036895751953125)]

In [None]:
layer_idx = 25
token_pos = 2
layer_preds = get_layer_predictions_for_token(analysis, layer_idx=layer_idx, token_pos=token_pos)
shared_preds = get_shared_expert_predictions_for_token(analysis, layer_idx=layer_idx, token_pos=token_pos)
expert_preds = get_expert_preds(analysis, layer_idx=layer_idx, token_pos=token_pos)

print(f'layer_preds : {layer_preds}')
print(f'shared_preds : {shared_preds}')
print(f'expert_preds : {expert_preds}')

In [27]:
def create_logit_lens_viz(analysis_results, token_pos=0, color='sunset', layers_to_plot=[]):
    """Creates a modern heatmap visualization of predictions across model layers.
    
    Args:
        analysis_results: Analysis results dictionary
        token_pos: Position of token to analyze (default: 0)
        layers_to_plot: List of layer numbers to plot. If empty, plots all layers (default: [])
    """
    # Initialize data structures
    n_layers = max(analysis_results['layer_predictions'].keys()) + 1
    data = []
    all_values = []
    
    # Process each layer
    for layer in range(n_layers):
        if layers_to_plot and layer not in layers_to_plot:
            continue
            
        row = {
            'layer': layer,
            'l_out': None,
            'expert_1': None,
            'expert_2': None,
            'expert_3': None,
            'expert_4': None,
            'expert_5': None,
            'expert_6': None,
            'shared': None,
            'expert_residual': None,
            'top_6': None,
            'top_6_res': None,
            'predictions': {
                'l_out': [],
                'expert_1': [],
                'expert_2': [],
                'expert_3': [],
                'expert_4': [],
                'expert_5': [],
                'expert_6': [],
                'shared': [],
                'expert_residual': [],
                'top_6': [],
                'top_6_res': []
            },
            'expert_ids': {},
            'top_tokens': {
                'l_out': '',
                'expert_1': '',
                'expert_2': '',
                'expert_3': '',
                'expert_4': '',
                'expert_5': '',
                'expert_6': '',
                'shared': '',
                'expert_residual': '',
                'top_6': '',
                'top_6_res': ''
            }
        }
        
        # Layer predictions
        layer_preds = get_layer_predictions_for_token(analysis_results, layer, token_pos)
        if layer_preds:
            row['l_out'] = layer_preds[0][1]
            row['predictions']['l_out'] = layer_preds
            row['top_tokens']['l_out'] = layer_preds[0][0]
            all_values.append(layer_preds[0][1])
        
        # Expert predictions
        if layer in analysis_results['moe_analysis']:
            expert_preds = get_expert_preds(analysis_results, layer, token_pos)
            moe_data = analysis_results['moe_analysis'][layer]
            
            expert_weights = {}
            for pos_idx, experts in enumerate(moe_data['selected_experts']):
                if pos_idx == token_pos:
                    for expert_idx, weight in zip(experts, moe_data['expert_weights'][pos_idx]):
                        expert_weights[expert_idx] = weight
            
            expert_logits = []
            for expert_idx in expert_weights:
                expert_data = next((p for e, p in expert_preds if e == expert_idx), None)
                if expert_data and expert_data[0]:
                    expert_logits.append((expert_idx, expert_data[0][1]))
            
            sorted_experts = sorted(expert_logits, key=lambda x: x[1], reverse=True)[:6]
            
            for i, (expert_idx, _) in enumerate(sorted_experts, 1):
                expert_data = next((p for e, p in expert_preds if e == expert_idx), None)
                if expert_data:
                    expert_key = f'expert_{i}'
                    row[expert_key] = expert_data[0][1]
                    row['predictions'][expert_key] = expert_data
                    row['expert_ids'][expert_key] = expert_idx
                    row['top_tokens'][expert_key] = expert_data[0][0]
                    all_values.append(expert_data[0][1])
            
            # Shared expert predictions
            shared_preds = get_shared_expert_predictions_for_token(analysis_results, layer, token_pos)
            if shared_preds:
                row['shared'] = shared_preds[0][1]
                row['predictions']['shared'] = shared_preds
                row['top_tokens']['shared'] = shared_preds[0][0]
                all_values.append(shared_preds[0][1])

            # Expert + Residual predictions
            expert_residual_preds = get_highest_pred(analysis_results, model, tokenizer, layer_idx=layer, token_pos=token_pos, k=0)
            if expert_residual_preds:
                row['expert_residual'] = expert_residual_preds[0][1]
                row['predictions']['expert_residual'] = expert_residual_preds
                row['top_tokens']['expert_residual'] = expert_residual_preds[0][0]
                all_values.append(expert_residual_preds[0][1])

            # Top 6 experts combined predictions without residual
            top_6_preds = get_highest_pred_combined(analysis_results, model, tokenizer, res_stream=False, layer_idx=layer, token_pos=token_pos, k=6)
            if top_6_preds:
                row['top_6'] = top_6_preds[0][1]
                row['predictions']['top_6'] = top_6_preds
                row['top_tokens']['top_6'] = top_6_preds[0][0]
                all_values.append(top_6_preds[0][1])

            # Top 6 experts combined predictions with residual
            top_6_res_preds = get_highest_pred_combined(analysis_results, model, tokenizer, res_stream=True, layer_idx=layer, token_pos=token_pos, k=6)
            if top_6_res_preds:
                row['top_6_res'] = top_6_res_preds[0][1]
                row['predictions']['top_6_res'] = top_6_res_preds
                row['top_tokens']['top_6_res'] = top_6_res_preds[0][0]
                all_values.append(top_6_res_preds[0][1])

        data.append(row)
    
    valid_values = [v for v in all_values if v is not None]
    vmin = min(valid_values) if valid_values else 0
    vmax = max(valid_values) if valid_values else 1
    
    fig = go.Figure()
    
    columns = ['l_out'] + [f'expert_{i}' for i in range(1, 7)] + ['shared', 'expert_residual', 'top_6', 'top_6_res']
    x_positions = list(range(len(columns)))
    
    # Use Plotly's built-in colorscale
    colorscale = color

    # Create y positions with increased spacing
    y_positions = list(range(len(data)))
    
    # Create the base heatmap for all columns at once
    z_matrix = []
    for row in data:
        z_row = []
        for col in columns:
            # For layer 0, set expert columns to None to make them grey
            if row['layer'] == 0 and (col.startswith('expert_') or col == 'shared' or col == 'expert_residual' or col == 'top_6' or col == 'top_6_res'):
                z_row.append(float('nan'))
            else:
                z_row.append(row[col] if row[col] is not None else float('nan'))
        z_matrix.append(z_row)
    # Add main heatmap
    fig.add_trace(go.Heatmap(
        z=z_matrix,
        x=x_positions,
        y=y_positions,
        colorscale=colorscale,
        showscale=True,
        zmin=vmin,
        zmax=vmax,
        colorbar=dict(
            title='logit value',
            titleside='right',
            y=0.5,
            thickness=20,
            len=0.5,
            tickfont=dict(family='JetBrains Mono', size=12),
            titlefont=dict(family='JetBrains Mono', size=14)
        ),
        hoverongaps=False
    ))
    
    # Add text overlays
    for i, row in enumerate(data):
        for j, col in enumerate(columns):
            # Skip text overlay for experts in layer 0
            if row['layer'] == 0 and (col.startswith('expert_') or col == 'shared' or col == 'expert_residual' or col == 'top_6' or col == 'top_6_res'):
                continue
                
            if row[col] is not None:
                # Calculate text color based on background
                normalized_value = (row[col] - vmin) / (vmax - vmin) if vmax != vmin else 0
                text_color = 'white' if normalized_value > 0.5 else 'black'
                
                # Add expert ID if applicable
                if col.startswith('expert_') and col != 'expert_residual':
                    expert_id = str(row['expert_ids'].get(col, ''))
                    fig.add_trace(go.Scatter(
                        x=[j - 0.4],  # Increased x offset to move expert ID more to the left
                        y=[i + 0.35],  # Removed y offset to center vertically
                        mode='text',
                        text=[expert_id],
                        textposition='middle center',
                        textfont=dict(family='JetBrains Mono', color=text_color, size=12),  # Increased from 12
                        hoverinfo='skip',
                        showlegend=False
                    ))
                # Add token text
                token_text = row['top_tokens'][col]
                fig.add_trace(go.Scatter(
                    x=[j],
                    y=[i],
                    mode='text',
                    text=[token_text],
                    textposition='middle center',
                    textfont=dict(family='JetBrains Mono', color=text_color, size=14),  # Increased from 12
                    hoverinfo='skip',
                    showlegend=False
                ))
    
    # Calculate dimensions with improved scaling
    n_rows = len(data)
    base_width = 1400  # Increased width
    min_height_per_row = 80  # Significantly increased height per row
    plot_height = max(min_height_per_row * n_rows, 600)  # Ensure minimum height
    
    # Update layout
    fig.update_layout(
        xaxis=dict(
            ticktext=['layer output'] + [f'exp {i}' for i in range(1, 7)] + ['shared', 'exp + residual', 'top 6', 'top 6 + residual'],
            tickvals=x_positions,
            tickangle=0,
            showgrid=False,
            zeroline=False,
            title='',
            color='black',
            range=[-0.5, len(columns) - 0.5],
            constrain='domain',
            tickfont=dict(family='JetBrains Mono', size=14),
            side='bottom'  # Move x-axis labels to bottom
        ),
        yaxis=dict(
            autorange='reversed',
            showgrid=False,
            zeroline=False,
            title=dict(
                text='Layer',
                font=dict(family='JetBrains Mono', size=16)
            ),
            color='black',
            range=[-0.5, len(data) - 0.5],
            ticktext=[str(row['layer']) for row in data],
            tickvals=y_positions,
            tickfont=dict(family='JetBrains Mono', size=14),
            constrain='domain'
        ),
        plot_bgcolor='white',
        paper_bgcolor='white',
        width=base_width,
        height=plot_height,
        margin=dict(l=80, r=120, t=50, b=100),  # Adjusted margins for bottom x-axis
        hovermode='closest'
    )
    
    # Add cell borders
    for i in range(len(data)):
        for j in range(len(columns)):
            fig.add_shape(
                type="rect",
                x0=j-0.5,
                y0=i-0.5,
                x1=j+0.5,
                y1=i+0.5,
                # line=dict(color="black", width=4),
                fillcolor="rgba(0,0,0,0)",
                layer="above"  # Changed from "below" to "above"
            )
    
    # Add hover text
    for i, row in enumerate(data):
        for j, col in enumerate(columns):
            # Skip hover for experts in layer 0
            if row['layer'] == 0 and (col.startswith('expert_') or col == 'shared' or col == 'expert_residual' or col == 'top_6' or col == 'top_6_res'):
                continue
                
            if row[col] is not None:
                preds = row['predictions'][col]
                if preds:
                    pred_text = "<br>".join([f"{token}: {score:.3f}" for token, score in preds[:5]])
                    if col.startswith('expert_') and col != 'expert_residual':
                        real_id = row['expert_ids'].get(col)
                        hover_text = f"Layer {row['layer']} expert {real_id} (logit: {row[col]:.3f}):<br>{pred_text}"
                    else:
                        hover_text = f"Layer {row['layer']} {col} (logit: {row[col]:.3f}):<br>{pred_text}"
                    
                    fig.add_trace(go.Scatter(
                        x=[j],
                        y=[i],
                        mode='markers',
                        marker=dict(opacity=0),
                        hovertext=hover_text,
                        hoverinfo='text',
                        showlegend=False
                    ))
    
    return fig

In [None]:
l = [2, 6, 10, 14, 22, 25, 27]
# l = []
x = len(tokenizer.encode(prompt))
print(x)
fig = create_logit_lens_viz(analysis, 
                            token_pos=x-1, 
                            color='blues', 
                            layers_to_plot=l)
# save the fig
i = 1
while os.path.exists(f'logit_lens_viz_{i}.html'):
    i += 1
fig.write_html(f'logit_lens_viz_{i}.html')
fig.write_image(f'logit_lens_viz_{i}.png')
fig.show()