In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import os
from typing import Dict, List, Tuple
from collections import defaultdict

In [2]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # use_flash_attention_2=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
def get_top_k_tokens(hidden_states: torch.Tensor, lm_head: torch.nn.Linear, tokenizer, k: int = 5) -> List[Tuple[str, float]]:
    """ get topk tokens from hidden states using lm head """
    with torch.no_grad():
        # Ensure hidden_states has at least 2 dimensions (batch_size, num_tokens, hidden_dim)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)  # Add batch dimension
        logits = lm_head(hidden_states)  # (batch_size, num_tokens, vocab_size)
    
    # Get top-k tokens
    scores, token_ids = torch.topk(logits, k=k, dim=-1)  # (batch_size, num_tokens, k)
    
    # Decode tokens and collect results for each position
    results = []
    for pos in range(scores.size(1)):  # Iterate over token positions
        pos_results = []
        for i in range(k):
            token = tokenizer.decode(token_ids[0, pos, i])  # Decode token for this position
            score = scores[0, pos, i].item()  # Get score for this position
            pos_results.append((token, score))
        results.append(pos_results)
    
    return results

In [4]:
class DeepseekLayerAnalyzer:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.layer_outputs = defaultdict(list)
        self.moe_gate_outputs = defaultdict(list)
        self.moe_combined_outputs = defaultdict(list)
        self.expert_outputs = defaultdict(lambda: defaultdict(list))
        self.shared_expert_outputs = defaultdict(list)
        self.hooks = []
        
    def register_hooks(self):
        """Register hooks for layer outputs and MoE combination points"""
        
        def layer_output_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing layer outputs"""
                hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs
                self.layer_outputs[layer_idx].append(hidden_states.detach())
            return hook

        def moe_gate_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing MoE gate outputs before expert computation"""
                # Capture topk_idx, topk_weight, and aux_loss from gate outputs
                if isinstance(outputs, tuple):
                    topk_idx, topk_weight, aux_loss = outputs
                    self.moe_gate_outputs[layer_idx].append({
                        'topk_idx': topk_idx.detach(),
                        'topk_weight': topk_weight.detach(),
                        'aux_loss': aux_loss.detach() if aux_loss is not None else None
                    })
            return hook

        def expert_hook(layer_idx, expert_idx):
            def hook(module, inputs, outputs):
                # Get the latest gate outputs for this layer
                if not self.moe_gate_outputs[layer_idx]:
                    return print(f'no gate outputs for layer {layer_idx}')
                    
                gate_data = self.moe_gate_outputs[layer_idx][-1]
                
                # Handle 2D or 3D tensor shapes
                if len(gate_data['topk_idx'].shape) == 2:
                    batch_size = 1
                    seq_len, top_k = gate_data['topk_idx'].shape
                else:
                    batch_size, seq_len, top_k = gate_data['topk_idx'].shape
                    
                
                # Get mask for tokens where this expert was selected
                expert_mask = (gate_data['topk_idx'] == expert_idx)                
                # Flatten and find positions where this expert was selected
                selected_positions = torch.nonzero(expert_mask, as_tuple=True)
                # If no tokens selected this expert, skip
                if selected_positions[0].numel() == 0:
                    return
                    
                # Get the actual inputs routed to this expert
                # Inputs[0] shape: (total_selected_tokens, hidden_dim)
                # We need to verify these correspond to the gate's selection
                total_selected = inputs[0].shape[0] 
                # Validate we're processing the correct number of tokens
                expected_selected = expert_mask.sum().item()
                if total_selected != expected_selected:
                    print(f" expert {expert_idx} processed {total_selected} tokens but expected {expected_selected}")
                    return
                    
                # Record only if we have matching selections
                # Store position information along with the outputs
                for pos_idx, pos in enumerate(selected_positions[0]):
                    self.expert_outputs[layer_idx][expert_idx].append({
                        'position': pos.item(),
                        'input': inputs[0][pos_idx].detach(),
                        'output': outputs[pos_idx].detach()
                    })
            return hook

        def shared_expert_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing shared expert outputs"""
                self.shared_expert_outputs[layer_idx].append({
                    'input': inputs[0].detach(),
                    'output': outputs.detach()
                })
            return hook

        def moe_combine_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing final combined MoE outputs"""
                # For DeepseekMoE, this captures the weighted sum of expert outputs
                self.moe_combined_outputs[layer_idx].append({
                    'combined_output': outputs.detach(),
                    'input': inputs[0].detach()  # Original input before MoE
                })
            return hook

        # Register hooks for each layer
        for layer_idx, layer in enumerate(self.model.model.layers):
            # Hook for layer output
            hook = layer.register_forward_hook(layer_output_hook(layer_idx))
            self.hooks.append(hook)
            
            # If it's an MoE layer, add MoE-specific hooks
            if hasattr(layer.mlp, 'experts'):
                # Hook for gate mechanism
                gate_hook = layer.mlp.gate.register_forward_hook(moe_gate_hook(layer_idx))
                self.hooks.append(gate_hook)
                
                # Hook for each expert
                for expert_idx, expert in enumerate(layer.mlp.experts):
                    expert_hook_fn = expert.register_forward_hook(expert_hook(layer_idx, expert_idx))
                    self.hooks.append(expert_hook_fn)

                # Hook for shared expert if it exists
                if hasattr(layer.mlp, 'shared_experts'):
                    shared_hook = layer.mlp.shared_experts.register_forward_hook(shared_expert_hook(layer_idx))
                    self.hooks.append(shared_hook)
                
                # Hook for final combined output
                combine_hook = layer.mlp.register_forward_hook(moe_combine_hook(layer_idx))
                self.hooks.append(combine_hook)

    def analyze_tokens(self, input_ids: torch.Tensor, return_hidden_states: bool = False) -> Dict:
        """ run inference and analyze tokens at each layer and expert combination point """
        # Clear previous results
        self.layer_outputs.clear()
        self.moe_gate_outputs.clear()
        self.moe_combined_outputs.clear()
        self.expert_outputs.clear()
        self.shared_expert_outputs.clear()
        
        # Forward pass
        with torch.no_grad():
            outputs = self.model(input_ids)
        
        results = {
            'layer_predictions': {},
            'moe_analysis': {},
            'hidden_states': {} if return_hidden_states else None
        }
        
        # Analyze layer outputs
        for layer_idx, outputs in self.layer_outputs.items():
            if not outputs:  # Skip if no outputs captured
                continue
            hidden_states = outputs[-1]  # Get last captured output
            
            # Get token predictions for this layer
            top_tokens = get_top_k_tokens(hidden_states, self.model.lm_head, self.tokenizer)
            results['layer_predictions'][layer_idx] = top_tokens
            
            if return_hidden_states:
                results['hidden_states'][f'layer_{layer_idx}'] = hidden_states
        
        # Analyze MoE layers
        for layer_idx in self.moe_gate_outputs.keys():
            if not self.moe_gate_outputs[layer_idx]:
                continue
                
            gate_data = self.moe_gate_outputs[layer_idx][-1]  # Get last captured data
            combined_data = self.moe_combined_outputs[layer_idx][-1]
            
            # Initialize predictions dictionary by position
            expert_predictions_by_pos = defaultdict(dict)
            
            # Process expert outputs by position
            for expert_idx, data_list in self.expert_outputs[layer_idx].items():
                for data in data_list:
                    position = data['position']
                    # Get predictions for this expert's output at this position
                    predictions = get_top_k_tokens(
                        data['output'].unsqueeze(0),
                        self.model.lm_head,
                        self.tokenizer
                    )
                    expert_predictions_by_pos[position][expert_idx] = predictions[0]  # [0] because we only have one prediction set
            
            # Get predictions for shared expert if it exists
            if self.shared_expert_outputs[layer_idx]:
                shared_expert_predictions = get_top_k_tokens(
                    self.shared_expert_outputs[layer_idx][-1]['output'],
                    self.model.lm_head,
                    self.tokenizer
                )
            
            # Analyze expert selection and combination
            experts_analysis = {
                'selected_experts': gate_data['topk_idx'].tolist(),
                'expert_weights': gate_data['topk_weight'].tolist(),
                'aux_loss': gate_data['aux_loss'].item() if gate_data['aux_loss'] is not None else None,
                'expert_predictions_by_position': dict(expert_predictions_by_pos),
                'shared_expert_predictions': shared_expert_predictions
            }
            
            # Get token predictions from combined output
            combined_tokens = get_top_k_tokens(
                combined_data['combined_output'], 
                self.model.lm_head,
                self.tokenizer
            )
            
            experts_analysis['combined_output_tokens'] = combined_tokens
            results['moe_analysis'][layer_idx] = experts_analysis

        return results
    
    def cleanup(self):
        """remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def analyze_deepseek_moe(model, tokenizer, input_text: str, return_hidden_states: bool = False):
    """ analyze DeepSeek MoE model behavior for given input text """
    analyzer = DeepseekLayerAnalyzer(model, tokenizer)
    analyzer.register_hooks()
    
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
    try:
        results = analyzer.analyze_tokens(input_ids, return_hidden_states=return_hidden_states)
        return results
    finally:
        analyzer.cleanup()  

In [5]:
analysis = analyze_deepseek_moe(
    model, 
    tokenizer,
    "the capital of the united states is"
)   

In [6]:
# Access layer predictions
# for layer_idx, preds in analysis['layer_predictions'].items():
#     print(f"Layer {layer_idx} predictions:", preds)

# # Access MoE analysis
# for layer_idx, moe_data in analysis['moe_analysis'].items():
#     print(f"\nMoE Layer {layer_idx}:")
#     print("Selected experts:", moe_data['selected_experts'])
#     print("Expert weights:", moe_data['expert_weights'])
#     print("Combined output tokens:", moe_data['combined_output_tokens'])
#     print("Expert predictions by position:", moe_data['expert_predictions_by_position'])
#     print("Shared expert predictions:", moe_data['shared_expert_predictions'])

In [7]:
def get_layer_predictions_for_token(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['layer_predictions']:
        raise ValueError(f"Layer {layer_idx} not found in the results.")

    # Get the layer predictions for the specified token position
    layer_predictions = results['layer_predictions'][layer_idx][token_pos]
    
    # Return top 5 predictions
    return layer_predictions[:5]

def get_shared_expert_predictions_for_token(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the results.")

    # Get the MoE analysis for the specified layer
    moe_analysis = results['moe_analysis'][layer_idx]

    # Get the shared expert predictions for the specified token position
    shared_predictions = moe_analysis['shared_expert_predictions'][token_pos]
    
    # Return top 5 predictions
    return shared_predictions[:5]

def get_expert_preds(results: dict, layer_idx: int, token_pos: int) -> list:
    # Ensure the layer exists in the results
    if layer_idx not in results['moe_analysis']:
        raise ValueError(f"Layer {layer_idx} not found in the MoE analysis results.")
        
    # Get the MoE analysis for the specified layer
    expert_preds = []
    moe_analysis = results['moe_analysis'][layer_idx]
    expert_predictions = moe_analysis['expert_predictions_by_position']
    expert_toks = expert_predictions[token_pos]

    for expert_idx, preds in expert_toks.items():
        print(f'expert {expert_idx} : {preds}')
        expert_preds.append((int(expert_idx), preds))

    expert_preds = sorted(expert_preds, key=lambda x: x[0])
        
    return expert_preds

In [8]:
layer_idx = 26
token_pos = 0
layer_preds = get_layer_predictions_for_token(analysis, layer_idx=layer_idx, token_pos=token_pos)
shared_preds = get_shared_expert_predictions_for_token(analysis, layer_idx=layer_idx, token_pos=token_pos)
expert_preds = get_expert_preds(analysis, layer_idx=layer_idx, token_pos=token_pos)

print(f'layer_preds : {layer_preds}')
print(f'shared_preds : {shared_preds}')
print(f'expert_preds : {expert_preds}')

expert 31 : [(' with', 0.2666015625), (' as', 0.25244140625), (' and', 0.2509765625), (' primary', 0.232666015625), ('高于', 0.228515625)]
expert 33 : [('�', 2.34765625), ('Kind', 1.9462890625), ('ano', 1.9091796875), (' Arabia', 1.8212890625), ('kind', 1.791015625)]
expert 39 : [('1', 0.103515625), ('6', 0.09796142578125), ('2', 0.09716796875), ('4', 0.09698486328125), ('7', 0.09674072265625)]
expert 40 : [(' regarding', 144.25), (' Down', 131.0), ('anys', 128.625), (' relating', 125.25), ('allot', 122.1875)]
expert 45 : [(' stable', 0.58642578125), (' comm', 0.5537109375), ('�', 0.533203125), ('ipus', 0.52294921875), ('ter', 0.51806640625)]
expert 57 : [('and', 0.08404541015625), ('b', 0.081787109375), ('con', 0.0731201171875), ('ol', 0.07183837890625), ('ul', 0.071044921875)]
layer_preds : [('�', 21.484375), (' Hague', 20.59375), ('WER', 19.328125), (' SEA', 18.75), ('ISAM', 18.53125)]
shared_preds : [(' ', 2.54296875), (' involving', 2.533203125), ('0', 2.35546875), (' u', 2.34179687

In [9]:
def create_logit_lens_viz(analysis_results, token_pos=0):
    """Creates a heatmap visualization of predictions across model layers.
    
    Args:
        analysis_results: Dictionary containing layer predictions and MoE analysis
        token_pos: Token position to analyze (default 0)
    """
    # Initialize data structure
    n_layers = max(analysis_results['layer_predictions'].keys()) + 1
    x_positions = []  # Will store x-axis positions
    y_positions = []  # Will store layer numbers
    values = []      # Will store prediction scores
    hover_texts = [] # Will store hover text with top 5 predictions

    # Process each layer
    for layer in range(n_layers):
        # Layer predictions
        layer_preds = get_layer_predictions_for_token(analysis_results, layer, token_pos)
        x_positions.append(0)  # Layer output is at x=0
        y_positions.append(layer)
        values.append(layer_preds[0][1])  # Use score from top prediction
        hover_text = "<br>".join([f"{pred[0]}: {pred[1]:.3f}" for pred in layer_preds])
        hover_texts.append(f"Layer {layer} Output:<br>{hover_text}")

        # Expert predictions if available
        if layer in analysis_results['moe_analysis']:
            expert_preds = get_expert_preds(analysis_results, layer, token_pos)
            for expert_idx, preds in expert_preds:
                x_positions.append(expert_idx + 1)  # Offset by 1 to account for layer output
                y_positions.append(layer)
                values.append(preds[0][1])  # Use score from top prediction
                hover_text = "<br>".join([f"{pred[0]}: {pred[1]:.3f}" for pred in preds])
                hover_texts.append(f"Layer {layer} Expert {expert_idx}:<br>{hover_text}")

            # Shared expert predictions
            shared_preds = get_shared_expert_predictions_for_token(analysis_results, layer, token_pos)
            x_positions.append(65)  # Put shared experts after individual experts
            y_positions.append(layer)
            values.append(shared_preds[0][1])  # Use score from top prediction
            hover_text = "<br>".join([f"{pred[0]}: {pred[1]:.3f}" for pred in shared_preds])
            hover_texts.append(f"Layer {layer} Shared Experts:<br>{hover_text}")

    # Create heatmap with larger square markers and no grid lines
    fig = go.Figure(data=go.Scatter(
        x=x_positions,
        y=y_positions,
        mode='markers',
        marker=dict(
            size=20,  # Increased marker size
            symbol='square',
            color=values,
            colorscale='Viridis',
            showscale=True,
            colorbar=dict(title='logit value')
        ),
        text=hover_texts,
        hoverinfo='text'
    ))

    # Update layout with no grid lines
    fig.update_layout(
        title='Expert Contributions and Predictions for Token Position ' + str(token_pos),
        xaxis_title='layer output | routed experts | shared experts',
        yaxis_title='Layer',
        xaxis=dict(
            ticktext=['l_out'] + [f' {i}' for i in range(64)] + ['shared experts'],
            tickvals=[0] + list(range(1, 65)) + [65],
            tickangle=45,
            showgrid=False  # Removed grid lines
        ),
        yaxis=dict(
            autorange='reversed',  # Put layer 0 at the top
            showgrid=False  # Removed grid lines
        ),
        height=1000,
        width=1800,
        plot_bgcolor='black',
        paper_bgcolor='black',
        font=dict(color='white')
    )

    return fig

In [14]:
create_logit_lens_viz(analysis, token_pos=7)

expert 4 : [('姆斯', 0.485595703125), ('RIB', 0.43896484375), ('ece', 0.427978515625), ('нан', 0.420166015625), ('ovi', 0.39453125)]
expert 14 : [('eso', 0.1871337890625), ('ebly', 0.1788330078125), ('чик', 0.17724609375), ('光灯', 0.177001953125), ('�', 0.1756591796875)]
expert 30 : [('жите', 0.046630859375), ('赵丽', 0.03875732421875), ('стики', 0.037109375), ('FTWARE', 0.0362548828125), ('зина', 0.035614013671875)]
expert 36 : [(' yourselves', 0.0718994140625), ('�起', 0.06988525390625), ('те', 0.06427001953125), ('ENTR', 0.06396484375), ('登', 0.06390380859375)]
expert 44 : [('наги', 0.01904296875), ('ondin', 0.01690673828125), ('рад', 0.01654052734375), ('хия', 0.01605224609375), (' pude', 0.015960693359375)]
expert 45 : [('�', 0.019744873046875), ('ontal', 0.0191497802734375), ('vinguda', 0.0184173583984375), ('nech', 0.01812744140625), ('FieldLocation', 0.0169830322265625)]
expert 7 : [('reso', 0.1539306640625), ('resos', 0.1416015625), ('ublic', 0.126220703125), ('ques', 0.1259765625),