In [1]:
import torch
import numpy as np
import plotly.graph_objects as go
from collections import defaultdict
import torch.nn.functional as F
import json
import os
from typing import Dict, Tuple, List, Optional
from transformers import AutoModelForCausalLM, AutoTokenizer
import math

In [2]:
def get_device():
    """Pick the compute device: the first CUDA GPU when available, else the CPU.

    Choosing CUDA also flips process-wide PyTorch switches (TF32 for matmul
    and cuDNN, cuDNN benchmarking) and empties the CUDA caching allocator.

    Returns:
        torch.device: ``cuda:0`` or ``cpu``.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")

    # TF32 trades a little precision for speed on Ampere-class GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    # Release cached-but-unused allocator blocks before heavy work starts.
    torch.cuda.empty_cache()
    # Let cuDNN benchmark candidate kernels and keep the fastest.
    torch.backends.cudnn.benchmark = True
    return torch.device("cuda:0")

In [3]:
# Load the DeepSeek-MoE 16B base checkpoint in fp16 together with its tokenizer.
# trust_remote_code=True lets transformers execute the modelling code that is
# distributed with the checkpoint.
CHECKPOINT = "deepseek-ai/deepseek-moe-16b-base"
model = AutoModelForCausalLM.from_pretrained(
    CHECKPOINT,
    trust_remote_code=True,
    torch_dtype=torch.float16,
)
tokenizer = AutoTokenizer.from_pretrained(CHECKPOINT, trust_remote_code=True)
# Inference mode (disables dropout); as the last expression this also echoes the module tree.
model.eval()

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

In [4]:
class MOEExpertLens:
    """Replay a DeepSeek-MoE style decoder from raw weights and inspect its experts.

    For every MoE layer and token position the analyzer records which routed
    experts the router ranks highest, the softmax weight given to each, and the
    vocabulary tokens each expert's output favours when projected through
    ``lm_head`` (a logit-lens view).

    NOTE(review): the replay is an approximation of the real forward pass.
    Rotary position embeddings, shared (always-on) experts if the checkpoint
    has any, and the final model norm are not applied here, and routing weights
    are a softmax over the selected experts only — confirm against the model's
    own implementation before drawing quantitative conclusions.
    """

    def __init__(self, state_dict: Dict[str, torch.Tensor], tokenizer, device=None,
                 num_heads: int = 32, top_k: int = 7, rms_norm_eps: float = 1e-6):
        """Initialize the MoE Expert analyzer.

        Args:
            state_dict: Parameter name -> tensor mapping (e.g. ``model.state_dict()``).
                Every tensor is moved to ``device``.
            tokenizer: Object exposing ``decode(list_of_ids) -> str``.
            device: Target device; ``get_device()`` is consulted when omitted.
            num_heads: Number of attention heads. The default keeps the value
                that used to be hard-coded.
                NOTE(review): confirm against ``model.config.num_attention_heads``.
            top_k: How many routed experts to inspect per token. The default
                keeps the value that used to be hard-coded.
                NOTE(review): confirm against ``model.config.num_experts_per_tok``.
            rms_norm_eps: Epsilon inside the RMSNorm square root.
                NOTE(review): assumed 1e-6 — confirm against ``model.config.rms_norm_eps``.

        Raises:
            ValueError: if the hidden size is not divisible by ``num_heads`` or
                ``top_k`` is not positive.
        """
        self.device = device if device is not None else get_device()
        self.state_dict = {k: v.to(self.device) for k, v in state_dict.items()}
        self.tokenizer = tokenizer

        embed = self.state_dict["model.embed_tokens.weight"]
        self.vocab_size, self.hidden_size = embed.shape
        # All activations are kept in the embedding dtype.
        self.dtype = embed.dtype

        if num_heads <= 0 or self.hidden_size % num_heads != 0:
            raise ValueError(
                f"hidden size {self.hidden_size} is not divisible by num_heads={num_heads}"
            )
        if top_k <= 0:
            raise ValueError(f"top_k must be positive, got {top_k}")
        self.num_heads = num_heads
        self.top_k = top_k
        self.rms_norm_eps = rms_norm_eps
        # Decoder layers present in the checkpoint, in execution order.
        self.layer_indices = self._find_layer_indices()

    def _find_layer_indices(self) -> List[int]:
        """Return the sorted decoder layer indices found in the state dict."""
        prefix, suffix = "model.layers.", ".input_layernorm.weight"
        found = set()
        for key in self.state_dict:
            if key.startswith(prefix) and key.endswith(suffix):
                middle = key[len(prefix):-len(suffix)]
                if middle.isdigit():
                    found.add(int(middle))
        return sorted(found)

    def _gated_mlp(self, prefix: str, hidden_state: torch.Tensor) -> torch.Tensor:
        """Apply the SiLU-gated MLP whose weights live under ``prefix``.

        Computes ``down_proj(silu(gate_proj(x)) * up_proj(x))`` with bias-free
        linear maps; leading dimensions of ``hidden_state`` are preserved.
        """
        gate_proj = self.state_dict[f"{prefix}.gate_proj.weight"]
        up_proj = self.state_dict[f"{prefix}.up_proj.weight"]
        down_proj = self.state_dict[f"{prefix}.down_proj.weight"]

        hidden_state = hidden_state.to(self.dtype)
        gated = F.silu(F.linear(hidden_state, gate_proj)) * F.linear(hidden_state, up_proj)
        return F.linear(gated, down_proj)

    def _process_expert(self, layer_idx: int, expert_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Process hidden state through one routed expert's weights."""
        return self._gated_mlp(f"model.layers.{layer_idx}.mlp.experts.{expert_idx}", hidden_state)

    def _get_router_output(self, layer_idx: int, hidden_state: torch.Tensor) -> torch.Tensor:
        """Get router logits (one score per routed expert) for a layer."""
        router_weights = self.state_dict[f"model.layers.{layer_idx}.mlp.gate.weight"]
        hidden_state = hidden_state.to(self.dtype)
        return F.linear(hidden_state, router_weights)

    def _process_attention(self, layer_idx: int, hidden_state: torch.Tensor,
                         attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run causal multi-head self-attention for one layer.

        Args:
            layer_idx: Decoder layer whose q/k/v/o projections are used.
            hidden_state: Tensor of shape ``[batch, seq_len, hidden]``.
            attention_mask: Optional additive mask broadcastable to
                ``[batch, heads, seq_len, seq_len]``; applied on top of the
                built-in causal mask.

        Returns:
            Tensor of shape ``[batch, seq_len, hidden]`` in ``self.dtype``.

        NOTE(review): queries and keys are used without rotary position
        embeddings, so the result is position-agnostic apart from causality.
        """
        q_proj = self.state_dict[f"model.layers.{layer_idx}.self_attn.q_proj.weight"]
        k_proj = self.state_dict[f"model.layers.{layer_idx}.self_attn.k_proj.weight"]
        v_proj = self.state_dict[f"model.layers.{layer_idx}.self_attn.v_proj.weight"]
        o_proj = self.state_dict[f"model.layers.{layer_idx}.self_attn.o_proj.weight"]

        hidden_state = hidden_state.to(self.dtype)

        batch_size, seq_len, hidden_dim = hidden_state.shape
        num_heads = self.num_heads
        head_dim = hidden_dim // num_heads

        def split_heads(projected: torch.Tensor) -> torch.Tensor:
            # [batch, seq, hidden] -> [batch, heads, seq, head_dim]
            return projected.view(batch_size, seq_len, num_heads, head_dim).transpose(1, 2)

        # Scores and softmax are computed in float32 for numerical stability.
        q = split_heads(F.linear(hidden_state, q_proj)).float()
        k = split_heads(F.linear(hidden_state, k_proj)).float()
        v = split_heads(F.linear(hidden_state, v_proj)).float()

        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(head_dim)

        # Causal mask: position i may only attend to positions <= i.
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool, device=hidden_state.device),
            diagonal=1,
        )
        attn_weights.masked_fill_(causal_mask, float('-inf'))

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        attn_weights = F.softmax(attn_weights, dim=-1)
        context = torch.matmul(attn_weights, v).to(self.dtype)

        # Merge heads back and apply the output projection.
        context = context.transpose(1, 2).contiguous().view(batch_size, seq_len, hidden_dim)
        return F.linear(context, o_proj)

    def _apply_layer_norm(self, layer_idx: int, hidden_state: torch.Tensor, norm_type: str) -> torch.Tensor:
        """Apply the layer's RMSNorm (``input_layernorm`` or ``post_attention_layernorm``).

        The checkpoint's norm modules are RMSNorm: activations are scaled by
        the reciprocal root-mean-square of the last dimension and multiplied by
        the learned weight. Unlike ``F.layer_norm`` there is no mean
        subtraction. Statistics are computed in float32.
        """
        weight = self.state_dict[f"model.layers.{layer_idx}.{norm_type}.weight"]
        x = hidden_state.to(torch.float32)
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.rms_norm_eps)
        return weight * x.to(self.dtype)

    def _project_to_vocab(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Project hidden state to vocabulary logits through ``lm_head``."""
        lm_head_weights = self.state_dict["lm_head.weight"]
        hidden_state = hidden_state.to(self.dtype)
        return F.linear(hidden_state, lm_head_weights)

    def analyze_text(self, input_ids: torch.Tensor) -> Dict:
        """Analyze text through the expert lens, propagating hidden states layer by layer.

        Every decoder layer found in the state dict is executed. Layers with a
        router (``mlp.gate.weight``) are analysed and reported; layers with a
        plain dense MLP are only used to propagate the hidden state.

        Args:
            input_ids: Token ids of shape ``[batch, seq_len]``. Token labels and
                routing decisions are taken from the first row only, so a batch
                size of 1 is expected.

        Returns:
            ``{"layer_<i>": {"tokens": {key: entry}}}`` for each MoE layer ``i``.
            ``key`` is the decoded token text (suffixed with ``@<position>``
            when the same text already occurred earlier in the sequence) and
            ``entry`` is ``{"position", "token", "expert_outputs"}`` where each
            expert output holds ``expert_id``, ``weight`` and ``top_tokens``
            (``(token_text, probability)`` pairs over the expert's top-5 logits).
        """
        batch_size, seq_len = input_ids.shape
        # Advanced indexing copies, so the embedding matrix itself is never modified.
        hidden_states = self.state_dict["model.embed_tokens.weight"][input_ids]
        results = {}

        token_texts = [self.tokenizer.decode([input_ids[0, pos].item()]) for pos in range(seq_len)]

        for layer_idx in self.layer_indices:
            prefix = f"model.layers.{layer_idx}"

            # Attention sub-block (pre-norm, residual); causality is enforced inside.
            normed_states = self._apply_layer_norm(layer_idx, hidden_states, "input_layernorm")
            hidden_states = hidden_states + self._process_attention(layer_idx, normed_states)

            normed_states = self._apply_layer_norm(layer_idx, hidden_states, "post_attention_layernorm")

            if f"{prefix}.mlp.gate.weight" not in self.state_dict:
                # Dense layer: nothing to report, but its output feeds the next layer.
                if f"{prefix}.mlp.gate_proj.weight" in self.state_dict:
                    hidden_states = hidden_states + self._gated_mlp(f"{prefix}.mlp", normed_states)
                continue

            router_logits = self._get_router_output(layer_idx, normed_states)
            num_selected = min(self.top_k, router_logits.shape[-1])

            layer_results = {"tokens": {}}
            moe_output = torch.zeros_like(normed_states)

            for pos in range(seq_len):
                token = token_texts[pos]
                token_state = normed_states[:, pos:pos+1]  # Keep batch dimension

                top_k_experts = torch.topk(router_logits[:, pos], k=num_selected, dim=-1)
                expert_weights = F.softmax(top_k_experts.values, dim=-1)

                expert_outputs = []
                token_output = torch.zeros_like(token_state)
                for rank, expert_id in enumerate(top_k_experts.indices[0].tolist()):
                    weight = expert_weights[0, rank].item()
                    # Each selected expert is evaluated once and reused for both
                    # the logit-lens report and the residual update.
                    expert_output = self._process_expert(layer_idx, expert_id, token_state)
                    token_output += weight * expert_output

                    logits = self._project_to_vocab(expert_output).squeeze(1)
                    top_tokens = torch.topk(logits, k=min(5, logits.shape[-1]))
                    top_probs = F.softmax(top_tokens.values[0], dim=-1)

                    expert_outputs.append({
                        "expert_id": expert_id,
                        "weight": weight,
                        "top_tokens": [
                            (self.tokenizer.decode([tok_id.item()]), prob.item())
                            for tok_id, prob in zip(top_tokens.indices[0], top_probs)
                        ]
                    })

                # Repeated token text must not overwrite an earlier position's entry.
                key = token if token not in layer_results["tokens"] else f"{token}@{pos}"
                layer_results["tokens"][key] = {
                    "position": pos,
                    "token": token,
                    "expert_outputs": expert_outputs
                }
                moe_output[:, pos:pos+1] = token_output

            # Residual connection around the routed experts.
            hidden_states = hidden_states + moe_output
            results[f"layer_{layer_idx}"] = layer_results

        return results

In [13]:
def visualize_layer_analysis(tokenizer, results: Dict, token_position: int, input_text: str,
                             total_experts: int = 64):
    """
    Creates a plotly scatter grid of expert activations across layers for a specific token.

    One marker is drawn per (layer, expert) pair. Experts that were selected for
    the token are coloured by their routing weight and carry hover text with
    their top tokens; all other experts are drawn with weight 0.

    Args:
        tokenizer: Object exposing ``encode(text)`` and ``decode(list_of_ids)``;
            used only to label the token being plotted.
        results: Mapping with ``"layer_<n>"`` keys, each holding
            ``{"tokens": {key: {"position": int, "expert_outputs": [...]}}}``.
            Every ``layer_<n>`` entry present is plotted, in numeric order.
        token_position: Index of the token (in the tokenized ``input_text``) to plot.
        input_text: The text that produced ``results``.
        total_experts: Number of expert columns to draw (default 64).

    Raises:
        KeyError: if ``results`` holds no ``layer_<n>`` entries.
        IndexError: if ``token_position`` is outside the tokenized text or a
            layer has no entry for that position.
    """
    # Re-tokenize only to recover a readable label for the selected position.
    token_ids = tokenizer.encode(input_text)
    token = tokenizer.decode([token_ids[token_position]])
    print(f"Visualizing token: {token}")

    layer_prefix = "layer_"
    layer_indices = sorted(
        int(key[len(layer_prefix):])
        for key in results
        if isinstance(key, str)
        and key.startswith(layer_prefix)
        and key[len(layer_prefix):].isdigit()
    )
    if not layer_indices:
        raise KeyError("results contains no 'layer_<n>' entries to visualize")

    layer_nums = []
    expert_ids = []
    weights = []
    hover_texts = []

    for layer_idx in layer_indices:
        layer_data = results[f"layer_{layer_idx}"]
        token_data = next(
            (entry for entry in layer_data["tokens"].values()
             if entry["position"] == token_position),
            None,
        )
        if token_data is None:
            raise IndexError(
                f"no analysis entry for token position {token_position} in layer_{layer_idx}"
            )

        # expert_id -> recorded data for the experts selected in this layer.
        expert_map = {exp["expert_id"]: exp for exp in token_data["expert_outputs"]}

        for expert_id in range(total_experts):
            layer_nums.append(layer_idx)
            expert_ids.append(expert_id)

            expert_data = expert_map.get(expert_id)
            if expert_data is None:
                # Expert was not selected: zero weight, no hover text.
                weights.append(0)
                hover_texts.append(None)
                continue

            weight = expert_data["weight"]
            top_tokens_text = "<br>".join(
                f"{top_token}: {prob:.3f}"
                for top_token, prob in expert_data["top_tokens"][:5]
            )
            hover_texts.append(
                f"Layer: {layer_idx}<br>Expert: {expert_id}<br>Weight: {weight:.3f}<br>Top tokens:<br>{top_tokens_text}"
            )
            weights.append(weight)

    # Scatter plot laid out as a layer x expert grid.
    fig = go.Figure(data=go.Scatter(
        x=expert_ids,
        y=layer_nums,
        mode='markers',
        marker=dict(
            size=9,
            color=weights,
            colorscale=[
                [0, 'rgba(24, 21, 23, 0.8)'],  # Very dark/transparent for zero weights
                [0.0001, 'rgb(68,1,84)'],      # Start of Viridis colorscale
                [1, 'rgb(253,231,37)']         # End of Viridis colorscale
            ],
            cmin=0,  # Set minimum of color scale to 0
            cmax=1,  # Set maximum of color scale to 1
            showscale=True,
            colorbar=dict(
                title='Weight',
                tickmode='linear',
                tick0=0,
                dtick=0.2
            ),
        ),
        text=hover_texts,
        hoverinfo='text',
        hovertemplate='%{text}<extra></extra>',
    ))

    # Dark theme layout.
    fig.update_layout(
        template='plotly_dark',
        title=f'Expert Activations for Token "{token}" at Position {token_position}',
        xaxis_title='Expert ID',
        yaxis_title='Layer',
        yaxis=dict(autorange='reversed'),  # Reverse y-axis so the first layer is at the top
        width=1200,
        height=800,
        showlegend=False,
        plot_bgcolor='black',
        paper_bgcolor='black'
    )

    # Faint grid lines; pad the x-axis by one expert on the left.
    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)',
                     range=[-1, total_experts])
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.2)')

    fig.show()

In [6]:
def apply_moe_logit_lens_with_active_expert_outputs(model, inputs, tokenizer, num_active_experts=7):
    """Print, layer by layer, the highest-gated experts and their top-5 outputs.

    Args:
        model: Object exposing ``model.embed_tokens``, ``model.layers`` and ``lm_head``.
        inputs: Token ids fed to ``model.model.embed_tokens``.
        tokenizer: Object exposing ``decode``.
        num_active_experts: How many of the highest-gated experts to report per sample.

    Returns:
        None. Everything is reported through ``print``.

    NOTE(review): this assumes ``layer.self_attn`` accepts three positional
    tensors and returns a tensor, and that ``layer.mlp`` returns a mapping with
    ``"gate_values"`` and ``"expert_outputs"`` — TODO confirm against the
    actual model class. The running state is replaced by each norm's output,
    so the attention residual is added to the normalized input.
    """
    model.eval()

    hidden = model.model.embed_tokens(inputs)

    for layer_number, block in enumerate(model.model.layers, start=1):
        print(f"Layer {layer_number}")

        hidden = block.input_layernorm(hidden)
        # Self-attention over the normalized state, then the residual sum.
        hidden = block.self_attn(hidden, hidden, hidden) + hidden
        hidden = block.post_attention_layernorm(hidden)

        moe_result = block.mlp(hidden)
        gate_values = moe_result["gate_values"]
        expert_outputs = moe_result["expert_outputs"]

        for sample_idx, sample_gates in enumerate(gate_values):
            # Experts ordered from largest to smallest gate value.
            ranked_experts = sample_gates.argsort(descending=True)[:num_active_experts]
            print(f"  Batch {sample_idx + 1}:")
            for expert_idx in ranked_experts:
                expert_weight = sample_gates[expert_idx].item()
                best = expert_outputs[sample_idx, :, expert_idx].topk(5, dim=-1)
                top_indices, top_scores = best.indices, best.values
                decoded_tokens = tokenizer.decode(top_indices.tolist())
                print(f"    Expert {expert_idx}: Weight: {expert_weight:.4f}, Tokens: {decoded_tokens}, Scores: {top_scores.tolist()}")

    # Logit lens on the final state.
    final_top = model.lm_head(hidden).topk(5, dim=-1)
    decoded_final_tokens = tokenizer.decode(final_top.indices.tolist())
    print(f"Final Layer: Tokens: {decoded_final_tokens}, Scores: {final_top.values.tolist()}")

In [7]:
def get_device():
    """Return the preferred torch device: ``cuda:0`` when CUDA is usable, otherwise ``cpu``.

    When CUDA is chosen, TF32 (matmul and cuDNN) and cuDNN benchmarking are
    switched on globally and the CUDA allocator cache is emptied.

    NOTE: this cell re-defines the ``get_device`` from an earlier cell with the
    same behaviour; the later definition is the one in effect afterwards.
    """
    cuda_ready = torch.cuda.is_available()
    if cuda_ready:
        # TF32 gives faster matmuls/convolutions on Ampere GPUs (A100, A6000, ...).
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        # Free cached allocator memory that is no longer in use.
        torch.cuda.empty_cache()
        # cuDNN benchmarking for better performance.
        torch.backends.cudnn.benchmark = True
    return torch.device("cuda:0" if cuda_ready else "cpu")

In [8]:
def apply_moe_logit_lens_with_active_expert_outputs(model, inputs, tokenizer, num_active_experts=7):
    """Walk the decoder layers and print the most active experts with their top-5 outputs.

    This cell repeats a definition from an earlier cell with the same
    behaviour; the later definition is the one in effect afterwards.

    Args:
        model: Object exposing ``model.embed_tokens``, ``model.layers`` and ``lm_head``.
        inputs: Token ids passed to the embedding layer.
        tokenizer: Object exposing ``decode``.
        num_active_experts: Number of highest-gated experts printed per sample.

    Returns:
        None; results are printed.

    NOTE(review): relies on ``layer.self_attn(x, x, x)`` returning a tensor and
    on ``layer.mlp(x)`` returning ``{"gate_values": ..., "expert_outputs": ...}``
    — TODO confirm the model really exposes this interface.
    """

    def report_sample(sample_idx, sample_gates, expert_outputs):
        # Print the strongest experts of one sample, highest gate first.
        strongest = sample_gates.argsort(descending=True)[:num_active_experts]
        print(f"  Batch {sample_idx + 1}:")
        for expert_idx in strongest:
            expert_weight = sample_gates[expert_idx].item()
            expert_output = expert_outputs[sample_idx, :, expert_idx]
            top5 = expert_output.topk(5, dim=-1)
            decoded_tokens = tokenizer.decode(top5.indices.tolist())
            print(f"    Expert {expert_idx}: Weight: {expert_weight:.4f}, Tokens: {decoded_tokens}, Scores: {top5.values.tolist()}")

    model.eval()
    state = model.model.embed_tokens(inputs)

    for position, decoder_layer in enumerate(model.model.layers):
        print(f"Layer {position + 1}")

        state = decoder_layer.input_layernorm(state)
        attended = decoder_layer.self_attn(state, state, state)
        state = attended + state  # residual on the normalized state
        state = decoder_layer.post_attention_layernorm(state)

        moe_output = decoder_layer.mlp(state)
        gate_values = moe_output["gate_values"]
        expert_outputs = moe_output["expert_outputs"]

        for sample_idx, sample_gates in enumerate(gate_values):
            report_sample(sample_idx, sample_gates, expert_outputs)

    # Project the final state to vocabulary logits and report the top-5.
    logits = model.lm_head(state)
    top5_final = logits.topk(5, dim=-1)
    decoded_final_tokens = tokenizer.decode(top5_final.indices.tolist())
    print(f"Final Layer: Tokens: {decoded_final_tokens}, Scores: {top5_final.values.tolist()}")

In [9]:
def logit_lens(model, input_tokens):
    """
    Collect the gate-weighted sum of expert outputs for every token at every MoE layer.

    Args:
        model (torch.nn.Module): Object exposing ``embed_tokens`` and ``layers``;
            a layer counts as MoE when ``layer.mlp`` has an ``experts`` attribute.
        input_tokens (torch.Tensor): Token ids, shape [batch_size, seq_len].

    Returns:
        dict: ``"layer_{i}_token_{j}_activated_experts"`` -> tensor holding the sum
        over experts of ``expert(x_j) * gate(x)[:, j, expert]``, where ``x`` is the
        input to layer ``i``. Layers without experts contribute no entries.
    """
    collected = {}
    seq_len = input_tokens.shape[1]

    # Running hidden state: starts as the embeddings, then each layer's output.
    hidden = model.embed_tokens(input_tokens)

    for layer_idx, layer in enumerate(model.layers):
        next_hidden = layer(hidden)

        if hasattr(layer.mlp, "experts"):
            # Gate scores are computed on the layer's input, one score per expert.
            gate_scores = layer.mlp.gate(hidden)

            for token_idx in range(seq_len):
                token_vec = hidden[:, token_idx, :]
                # [batch_size, num_experts, feature_dim]
                per_expert = torch.stack(
                    [expert(token_vec) for expert in layer.mlp.experts], dim=1
                )
                # [batch_size, num_experts, 1] so it broadcasts over features.
                mixing = gate_scores[:, token_idx, :].unsqueeze(-1)
                key = f"layer_{layer_idx}_token_{token_idx}_activated_experts"
                collected[key] = (per_expert * mixing).sum(dim=1)

        hidden = next_hidden

    return collected

In [16]:
# Build the analyzer from the loaded weights, tokenize a short prompt and run
# the per-layer expert analysis on it.
text = "the quick brown fox"
analyzer = MOEExpertLens(model.state_dict(), tokenizer)
encoded = tokenizer(text, return_tensors="pt")
input_ids = encoded.input_ids.to(get_device())
results = analyzer.analyze_text(input_ids)

In [17]:
# Plot expert activations for three token positions of the prompt.
for position in (0, 1, 4):
    visualize_layer_analysis(tokenizer, results, token_position=position, input_text=text)

Visualizing token: <｜begin▁of▁sentence｜>


Visualizing token: the


Visualizing token:  fox
