In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
import os
from typing import Dict, List, Tuple
from collections import defaultdict

In [None]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # use_flash_attention_2=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

In [3]:
def get_top_k_tokens(hidden_states: torch.Tensor, lm_head: torch.nn.Linear, tokenizer, k: int = 5) -> List[Tuple[str, float]]:
    """ get topk tokens from hidden states using lm head """
    with torch.no_grad():
        # Ensure hidden_states has at least 2 dimensions (batch_size, num_tokens, hidden_dim)
        if hidden_states.dim() == 2:
            hidden_states = hidden_states.unsqueeze(0)  # Add batch dimension
        logits = lm_head(hidden_states)  # (batch_size, num_tokens, vocab_size)
    
    # Get top-k tokens
    scores, token_ids = torch.topk(logits, k=k, dim=-1)  # (batch_size, num_tokens, k)
    
    # Decode tokens and collect results for each position
    results = []
    for pos in range(scores.size(1)):  # Iterate over token positions
        pos_results = []
        for i in range(k):
            token = tokenizer.decode(token_ids[0, pos, i])  # Decode token for this position
            score = scores[0, pos, i].item()  # Get score for this position
            pos_results.append((token, score))
        results.append(pos_results)
    
    return results

In [4]:
class DeepseekLayerAnalyzer:
    """ Analyzes the behavior of a DeepSeek MoE model by capturing and analyzing outputs from different layers and experts.
    
    Args:
        model: The DeepSeek MoE model to analyze
        tokenizer: The tokenizer associated with the model
        
    Attributes:
        layer_outputs (defaultdict): Stores outputs from each model layer
        moe_gate_outputs (defaultdict): Stores gate outputs from MoE layers
        moe_combined_outputs (defaultdict): Stores combined outputs after expert computation
        expert_outputs (defaultdict): Stores individual expert outputs per layer and position
        shared_expert_outputs (defaultdict): Stores outputs from shared experts if present
        hooks (list): List of registered PyTorch hooks
        
    Methods:
        register_hooks(): Sets up hooks to capture layer and expert outputs
        analyze_tokens(input_ids, return_hidden_states): Runs inference and analyzes token behavior
        cleanup(): Removes all registered hooks
    """
    
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.layer_outputs = defaultdict(list)
        self.moe_gate_outputs = defaultdict(list)
        self.moe_combined_outputs = defaultdict(list)
        self.expert_outputs = defaultdict(lambda: defaultdict(list))
        self.shared_expert_outputs = defaultdict(list)
        self.hooks = []        
    def register_hooks(self):
        """Register hooks for layer outputs and MoE combination points"""
        
        def layer_output_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing layer outputs"""
                hidden_states = outputs[0] if isinstance(outputs, tuple) else outputs
                self.layer_outputs[layer_idx].append(hidden_states.detach())
            return hook

        def moe_gate_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing MoE gate outputs before expert computation"""
                # Capture topk_idx, topk_weight, and aux_loss from gate outputs
                if isinstance(outputs, tuple):
                    topk_idx, topk_weight, aux_loss = outputs
                    self.moe_gate_outputs[layer_idx].append({
                        'topk_idx': topk_idx.detach(),
                        'topk_weight': topk_weight.detach(),
                        'aux_loss': aux_loss.detach() if aux_loss is not None else None
                    })
            return hook

        def expert_hook(layer_idx, expert_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing expert outputs"""
                # Get the latest gate outputs for this layer
                if not self.moe_gate_outputs[layer_idx]:
                    return print(f'no gate outputs for layer {layer_idx}')
            
                gate_data = self.moe_gate_outputs[layer_idx][-1]
                
                # Handle 2D or 3D tensor shapes
                if len(gate_data['topk_idx'].shape) == 2:
                    batch_size = 1
                    seq_len, top_k = gate_data['topk_idx'].shape
                else:
                    batch_size, seq_len, top_k = gate_data['topk_idx'].shape
                
                # Get mask for tokens where this expert was selected
                expert_mask = (gate_data['topk_idx'] == expert_idx)                
                # Flatten and find positions where this expert was selected
                selected_positions = torch.nonzero(expert_mask, as_tuple=True)
                # If no tokens selected this expert, skip
                if selected_positions[0].numel() == 0:
                    return
                    
                # Get the actual inputs routed to this expert
                # Inputs[0] shape: (total_selected_tokens, hidden_dim)
                total_selected = inputs[0].shape[0] 
                # Validate we're processing the correct number of tokens
                expected_selected = expert_mask.sum().item()
                if total_selected != expected_selected:
                    print(f" expert {expert_idx} processed {total_selected} tokens but expected {expected_selected}")
                    return
                    
                # Get the full hidden states from outputs
                # outputs shape: (total_selected_tokens, hidden_dim)
                hidden_states = outputs
                if isinstance(outputs, tuple):
                    hidden_states = outputs[0]
                    
                # Record data for each selected position
                for pos_idx, pos in enumerate(selected_positions[0]):
                    token_data = {
                        'position': pos.item(),
                        'input': inputs[0][pos_idx].detach(),
                        'output': outputs[pos_idx].detach(),
                        'hidden_state': hidden_states[pos_idx].detach()  # Store full hidden state
                    }
                    
                    # Get the corresponding gate weight for this position
                    # Find which expert slot (in top_k) this expert was selected for this position
                    expert_slots = (gate_data['topk_idx'][pos.item()] == expert_idx).nonzero(as_tuple=True)[0]
                    if len(expert_slots) > 0:
                        token_data['gate_weight'] = gate_data['topk_weight'][pos.item()][expert_slots[0]].item()
                    
                    self.expert_outputs[layer_idx][expert_idx].append(token_data)
            return hook
        
        def shared_expert_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing shared expert outputs"""
                self.shared_expert_outputs[layer_idx].append({
                    'input': inputs[0].detach(),
                    'output': outputs.detach()
                })
            return hook

        def moe_combine_hook(layer_idx):
            def hook(module, inputs, outputs):
                """Hook for capturing final combined MoE outputs"""
                # For DeepseekMoE, this captures the weighted sum of expert outputs
                self.moe_combined_outputs[layer_idx].append({
                    'combined_output': outputs.detach(),
                    'input': inputs[0].detach()  # Original input before MoE
                })
            return hook

        # Register hooks for each layer
        for layer_idx, layer in enumerate(self.model.model.layers):
            # Hook for layer output
            hook = layer.register_forward_hook(layer_output_hook(layer_idx))
            self.hooks.append(hook)
            
            # If it's an MoE layer, add MoE-specific hooks
            if hasattr(layer.mlp, 'experts'):
                # Hook for gate mechanism
                gate_hook = layer.mlp.gate.register_forward_hook(moe_gate_hook(layer_idx))
                self.hooks.append(gate_hook)
                
                # Hook for each expert
                for expert_idx, expert in enumerate(layer.mlp.experts):
                    expert_hook_fn = expert.register_forward_hook(expert_hook(layer_idx, expert_idx))
                    self.hooks.append(expert_hook_fn)

                # Hook for shared expert if it exists
                if hasattr(layer.mlp, 'shared_experts'):
                    shared_hook = layer.mlp.shared_experts.register_forward_hook(shared_expert_hook(layer_idx))
                    self.hooks.append(shared_hook)
                
                # Hook for final combined output
                combine_hook = layer.mlp.register_forward_hook(moe_combine_hook(layer_idx))
                self.hooks.append(combine_hook)

    def analyze_tokens(self, input_ids: torch.Tensor, return_hidden_states: bool = False) -> Dict:
        """ run inference and analyze tokens at each layer and expert combination point """

        self.moe_combined_outputs.clear()
        self.expert_outputs.clear()
        self.shared_expert_outputs.clear()
        
        # Forward pass
        with torch.no_grad():
            outputs = self.model(input_ids)
        
        results = {
            'layer_predictions': {},
            'moe_analysis': {},
            'hidden_states': {} if return_hidden_states else None
        }
        
        # Analyze layer outputs
        for layer_idx, outputs in self.layer_outputs.items():
            if not outputs:  # Skip if no outputs captured
                continue
            hidden_states = outputs[-1]  # Get last captured output
            
            # Get token predictions for this layer
            top_tokens = get_top_k_tokens(hidden_states, self.model.lm_head, self.tokenizer)
            results['layer_predictions'][layer_idx] = top_tokens
            
            if return_hidden_states:
                results['hidden_states'][f'layer_{layer_idx}'] = hidden_states
        
        # Analyze MoE layers
        for layer_idx in self.moe_gate_outputs.keys():
            if not self.moe_gate_outputs[layer_idx]:
                continue
                
            gate_data = self.moe_gate_outputs[layer_idx][-1]  # Get last captured data
            combined_data = self.moe_combined_outputs[layer_idx][-1]
            
            # Initialize predictions dictionary by position
            expert_predictions_by_pos = defaultdict(dict)
            expert_hidden_states_by_pos = defaultdict(dict)
            
            # Process expert outputs by position
            for expert_idx, data_list in self.expert_outputs[layer_idx].items():
                for data in data_list:
                    position = data['position']
                    # Get predictions for this expert's output at this position
                    predictions = get_top_k_tokens(
                        data['output'].unsqueeze(0),
                        self.model.lm_head,
                        self.tokenizer
                    )
                    expert_predictions_by_pos[position][expert_idx] = predictions[0]  # [0] because we only have one prediction set

                    # Store hidden states
                    expert_hidden_states_by_pos[position][expert_idx] = {
                    'hidden_state': data['hidden_state'].tolist(),
                    'gate_weight': data.get('gate_weight', None)
                }
            
            # Get predictions for shared expert if it exists
            if self.shared_expert_outputs[layer_idx]:
                shared_expert_predictions = get_top_k_tokens(
                    self.shared_expert_outputs[layer_idx][-1]['output'],
                    self.model.lm_head,
                    self.tokenizer
                )
            
            # Update experts_analysis dictionary to include hidden states
            experts_analysis = {
            'selected_experts': gate_data['topk_idx'].tolist(),
            'expert_weights': gate_data['topk_weight'].tolist(),
            'aux_loss': gate_data['aux_loss'].item() if gate_data['aux_loss'] is not None else None,
            'expert_predictions_by_position': dict(expert_predictions_by_pos),
            'expert_hidden_states_by_position': dict(expert_hidden_states_by_pos),
            'shared_expert_predictions': shared_expert_predictions if self.shared_expert_outputs[layer_idx] else None
        }
            
            # Get token predictions from combined output
            combined_tokens = get_top_k_tokens(
                combined_data['combined_output'], 
                self.model.lm_head,
                self.tokenizer
            )
            
            experts_analysis['combined_output_tokens'] = combined_tokens
            results['moe_analysis'][layer_idx] = experts_analysis

        return results
    
    def cleanup(self):
        """remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks.clear()

def analyze_deepseek_moe(model, tokenizer, input_text: str, return_hidden_states: bool = False):
    """ analyze DeepSeek MoE model behavior for given input text """
    analyzer = DeepseekLayerAnalyzer(model, tokenizer)
    analyzer.register_hooks()
    
    input_ids = tokenizer(input_text, return_tensors="pt").input_ids
    
    try:
        results = analyzer.analyze_tokens(input_ids, return_hidden_states=return_hidden_states)
        return results
    finally:
        analyzer.cleanup()  


In [14]:
# Get prompt and tokenize
prompt = "translate the following text to french: The best is yet to come = Le meilleur reste à"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
n = input_ids.shape[1]
token_pos = n-1  # Last token

analysis = analyze_deepseek_moe(model, tokenizer, prompt, return_hidden_states=True)

# Create analyzer and get hidden states
analyzer = DeepseekLayerAnalyzer(model, tokenizer)
analyzer.register_hooks()


results = analyzer.analyze_tokens(input_ids, return_hidden_states=True)

In [15]:
def get_post_attn_ln_inputs(model, tokenizer, text):
    """places a hook on the post attention layernorm to retrieve its inputs"""
    # Store inputs in a dict mapping layer idx -> inputs
    post_attn_ln_inputs = {}
    hooks = []
    
    def hook_post_attn_ln(module, input, output, layer_idx):
        if layer_idx not in post_attn_ln_inputs:
            post_attn_ln_inputs[layer_idx] = []
        post_attn_ln_inputs[layer_idx].append([x.detach() for x in input])
    
    # Register hooks on post attention layernorm for each layer
    for i, layer in enumerate(model.model.layers):
        hooks.append(
            layer.post_attention_layernorm.register_forward_hook(
                lambda m, i, o, idx=i: hook_post_attn_ln(m, i, o, idx)
            )
        )
    
    try:
        # Run inference
        input_ids = tokenizer(text, return_tensors="pt").input_ids
        model(input_ids)
        
        return post_attn_ln_inputs
        
    finally:
        # Clean up hooks
        for hook in hooks:
            hook.remove()

In [16]:
post_attn_ln_inputs = get_post_attn_ln_inputs(model, tokenizer, text=prompt)

In [25]:
def get_expert_hidden_states_by_weight(analysis, post_attn_ln_inputs, token_pos, k=0):
    """Gets hidden states (with residual added) from the k-th highest weighted expert for all MoE layers"""
    hidden_states_by_layer = {}
    
    # For each layer that has MoE
    for layer_idx in analysis['moe_analysis'].keys():
        moe_analysis = analysis['moe_analysis'][layer_idx]
        
        # Get expert weights for this token
        expert_weights = {}
        selected_experts = moe_analysis['selected_experts'][token_pos]
        expert_weights_list = moe_analysis['expert_weights'][token_pos]
        
        # Map experts to their weights
        for expert_idx, weight in zip(selected_experts, expert_weights_list):
            expert_weights[expert_idx] = weight
            
        # Sort by weight and get k-th highest
        sorted_experts = sorted(expert_weights.items(), key=lambda x: x[1], reverse=True)
        if k >= len(sorted_experts):
            print(f"Warning: Layer {layer_idx} only has {len(sorted_experts)} experts, but k={k} requested")
            continue
            
        target_expert = sorted_experts[k][0]  # Get the k-th expert's id
        
        # Get hidden state for this expert
        expert_hidden_state = moe_analysis['expert_hidden_states_by_position'][token_pos][target_expert]['hidden_state']
        
        # Get residual stream for this token
        try:
            residual = post_attn_ln_inputs[layer_idx][-1][0][0][token_pos]  # Added [0] index
        except (KeyError, IndexError) as e:
            print(f"Warning: Could not get residual for layer {layer_idx}, skipping. Error: {e}")
            continue
            
        # Convert to tensor if needed
        if isinstance(expert_hidden_state, list):
            expert_hidden_state = torch.tensor(expert_hidden_state, dtype=torch.float16)
            
        # Add residual
        combined = residual + expert_hidden_state
        hidden_states_by_layer[layer_idx] = combined
        
    return hidden_states_by_layer

In [34]:
# Get final layer index and hidden state
final_layer_idx = max([int(k) for k in results["layer_predictions"].keys()])
final_hidden_state = analyzer.layer_outputs[final_layer_idx][-1][0][token_pos]

print("Shape of final hidden state :", final_hidden_state)

Shape of final hidden state : tensor([-2.1328, -2.8848,  0.5552,  ...,  1.1221, -2.6934, -1.1260],
       dtype=torch.float16)


In [32]:

n = len(tokenizer.encode(prompt))
token_pos = n-1  # Last token
k = 0  # Get highest weighted expert

hidden_states = get_expert_hidden_states_by_weight(analysis,
                                                   post_attn_ln_inputs=post_attn_ln_inputs,
                                                   token_pos=token_pos,
                                                   k=k)

print(f'hidden_states : {hidden_states}')

hidden_states : {1: tensor([ 0.0231,  0.0817, -0.1342,  ..., -0.0058, -0.0729, -0.0549],
       dtype=torch.float16), 2: tensor([ 0.1169,  0.1069, -0.2629,  ...,  0.0451, -0.1361, -0.0201],
       dtype=torch.float16), 3: tensor([ 0.1360,  0.1041, -0.3918,  ...,  0.1637, -0.0724, -0.0071],
       dtype=torch.float16), 4: tensor([ 0.2000,  0.1331, -0.7676,  ...,  0.1451, -0.1857, -0.0319],
       dtype=torch.float16), 5: tensor([-0.0033,  0.0835, -0.7563,  ...,  0.2800, -0.2732,  0.0765],
       dtype=torch.float16), 6: tensor([-0.1963,  0.1505, -0.6743,  ...,  0.3457, -0.3904,  0.1470],
       dtype=torch.float16), 7: tensor([-0.1293,  0.2632, -0.8887,  ...,  0.3867, -0.1835, -0.1223],
       dtype=torch.float16), 8: tensor([ 0.0015,  0.3374, -0.9141,  ...,  0.2864, -0.3770,  0.1615],
       dtype=torch.float16), 9: tensor([-0.0732,  0.2460, -0.8462,  ...,  0.4705, -0.7827,  0.3425],
       dtype=torch.float16), 10: tensor([ 0.0711,  0.2905, -0.9800,  ...,  0.3428, -0.2551,  0.2808],
 

In [33]:
import torch.nn.functional as F

# cosine similarity for each layer
sim = {}
for layer_idx, layer_hidden in hidden_states.items():
    # Normalize both vectors for cosine similarity
    similarity = F.cosine_similarity(final_hidden_state.unsqueeze(0),
                                   layer_hidden.unsqueeze(0))
    sim[layer_idx] = similarity.item()

# Print results sorted by layer
for layer_idx in sorted(sim.keys()):
    print(f"Layer {layer_idx} similarity: {sim[layer_idx]:.4f}")

Layer 1 similarity: 0.0064
Layer 2 similarity: 0.0087
Layer 3 similarity: 0.0131
Layer 4 similarity: 0.0206
Layer 5 similarity: 0.0040
Layer 6 similarity: -0.0043
Layer 7 similarity: 0.0240
Layer 8 similarity: 0.0075
Layer 9 similarity: 0.0093
Layer 10 similarity: -0.0001
Layer 11 similarity: 0.0196
Layer 12 similarity: 0.0048
Layer 13 similarity: 0.0221
Layer 14 similarity: 0.0418
Layer 15 similarity: 0.0453
Layer 16 similarity: 0.0729
Layer 17 similarity: 0.0902
Layer 18 similarity: 0.1382
Layer 19 similarity: 0.1611
Layer 20 similarity: 0.1864
Layer 21 similarity: 0.2150
Layer 22 similarity: 0.2467
Layer 23 similarity: 0.3972
Layer 24 similarity: 0.4072
Layer 25 similarity: 0.4253
Layer 26 similarity: 0.5107
Layer 27 similarity: 0.7437


In [52]:
import plotly.graph_objects as go

def plot_cosine_similarities(similarities, k):
    """
    Create an interactive line plot of cosine similarities across layers.
    
    Args:
    similarities: dict with layer_idx -> cosine_similarity_value
    k: which expert (by weight rank) was used
    """
    # Set plot size
    plot_size = 500  # Size in pixels for both width and height
    
    # Sort layers for x-axis
    layers = sorted(similarities.keys())
    sim_values = [similarities[layer] for layer in layers]
    
    # Create figure
    fig = go.Figure()
    
    # Add line plot
    fig.add_trace(go.Scatter(
        x=layers,
        y=sim_values,
        mode='lines+markers',
        name=f'Expert rank {k}',
        hovertemplate=
        "<b>Layer %{x}</b><br>" +
        "Cosine Similarity: %{y:.4f}<br>" +
        "<extra></extra>",  # Removes secondary box
        line=dict(width=2),
        marker=dict(size=8)
    ))
    
    # Update layout
    fig.update_layout(
        title=dict(
            text=f'Cosine Similarity with Final Hidden State<br><sup>Using {k+1}th highest weighted expert per layer</sup>',
            x=0.5,  # Center title
            y=0.95
        ),
        xaxis=dict(
            title='Layer',
            gridcolor='rgba(128,128,128,0.2)',
            tickmode='linear',
            dtick=1,  # Show every layer number
            range=[0, 27.5]  # Set x-axis range from 0 to 27
        ),
        yaxis=dict(
            title='Cosine Similarity',
            gridcolor='rgba(128,128,128,0.2)',
            range=[-0.2, 1]  # Updated range for cosine similarity
        ),
        plot_bgcolor='white',
        hovermode='x unified',  # Shows all points at a given x-coordinate
        showlegend=False,
        width=plot_size,
        height=plot_size
    )
    
    # Add grid
    fig.update_xaxes(showgrid=True, gridwidth=1)
    fig.update_yaxes(showgrid=True, gridwidth=1)
    
    return fig

# Use the function
fig = plot_cosine_similarities(sim, k)
fig.show()