In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

In [2]:
def load_model(model_name, torch_dtype=torch.float16, device_map=None):
    """Load a causal-LM checkpoint and its tokenizer from the Hugging Face Hub.

    Args:
        model_name: Hub repo id or local path, e.g. "deepseek-ai/deepseek-moe-16b-base".
        torch_dtype: dtype for the model weights (default float16, as before).
        device_map: forwarded to ``from_pretrained``. ``None`` (default) keeps the
            previous behaviour; ``"auto"`` lets accelerate place/shard the weights
            across available GPUs (requires the ``accelerate`` package).

    Returns:
        (model, tokenizer) tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch_dtype,
        device_map=device_map,
        # Executes modeling code bundled with the checkpoint repo; only use with
        # repositories you trust.
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

# Checkpoint under study; change here to analyse a different MoE model.
MODEL_NAME = "deepseek-ai/deepseek-moe-16b-base"
model, tokenizer = load_model(MODEL_NAME)

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [11]:
def get_moe_metadata(model, input_ids):
    """Run one forward pass and capture routing data from every MoE layer.

    A forward hook is attached to the ``gate`` of each layer whose ``mlp`` class is
    named ``DeepseekMoE``. For every call it records:
      * the gate's input hidden states,
      * raw router logits recomputed as ``hidden_states @ gate.weight.T``,
      * the first element of the gate's output (the selected expert indices).

    Args:
        model: model exposing ``model.model.layers[i].mlp`` (DeepSeek-MoE layout)
            and ``model.device``.
        input_ids: token-id tensor (e.g. ``tokenizer.encode(..., return_tensors="pt")``);
            moved to ``model.device`` before the forward pass.

    Returns:
        dict with keys 'router_logits', 'expert_indices', 'hidden_states'. Each value
        is the per-layer tensors stacked along a new dim 0, or None if no MoE gate ran.
        For the batch-1 run in this notebook the shapes were
        [L, 1, seq, n_experts], [L, seq, top_k] and [L, 1, seq, hidden].
    """
    router_logits_list = []
    expert_indices_list = []
    hidden_states_list = []

    def hook_fn(module, inputs, output):
        # The gate's output is a tuple; element 0 holds the chosen expert indices.
        hidden_states = inputs[0].detach()
        # Recompute the router logits from the gate's own weight matrix.
        router_logits_list.append(torch.matmul(hidden_states, module.weight.T))
        expert_indices_list.append(output[0].detach())
        hidden_states_list.append(hidden_states)
        # Returning None leaves the gate's output unchanged.

    hooks = [
        layer.mlp.gate.register_forward_hook(hook_fn)
        for layer in model.model.layers
        if layer.mlp.__class__.__name__ == 'DeepseekMoE'
    ]
    try:
        with torch.no_grad():
            model(input_ids.to(model.device))
    finally:
        # Always detach the hooks: if the forward pass raises, leftover hooks would
        # keep appending activations to this call's lists on every later forward.
        for hook in hooks:
            hook.remove()

    moe_metadata = {
        'router_logits': torch.stack(router_logits_list) if router_logits_list else None,
        'expert_indices': torch.stack(expert_indices_list) if expert_indices_list else None,
        'hidden_states': torch.stack(hidden_states_list) if hidden_states_list else None
    }

    if moe_metadata['router_logits'] is not None:
        print(f"Router logits shape: {moe_metadata['router_logits'].shape}")
    if moe_metadata['expert_indices'] is not None:
        print(f"Expert indices shape: {moe_metadata['expert_indices'].shape}")
    if moe_metadata['hidden_states'] is not None:
        print(f"Hidden states shape: {moe_metadata['hidden_states'].shape}")

    return moe_metadata

In [49]:
def get_expert_outputs(model, moe_metadata):
    """Re-run each token's routed (top-k) experts on the captured hidden states.

    For every MoE layer, the hidden state that entered the router is fed through each
    expert the router selected for that token.

    The results are raw expert outputs:
      * they are NOT scaled by the router weights;
      * the shared experts are not included.

    Args:
        model: DeepSeek-MoE style model. MoE layers are identified exactly as in
            ``get_moe_metadata`` (``layer.mlp`` whose class is named ``DeepseekMoE``),
            in model order.
        moe_metadata: dict from ``get_moe_metadata``. Uses:
            * 'expert_indices' [L, num_tokens, top_k];
            * 'hidden_states' [L, batch, seq, hidden] — only batch index 0 is read,
              so this assumes a single-sequence batch.

    Returns:
        float32 tensor [L, num_tokens, top_k, hidden] on ``model.device``.

    Raises:
        ValueError: if routing data is missing or the number of layers in the metadata
            does not match the number of MoE layers in the model.
    """
    expert_indices = moe_metadata['expert_indices']
    hidden_states = moe_metadata['hidden_states']
    if expert_indices is None or hidden_states is None:
        raise ValueError("moe_metadata contains no routing data (no MoE layers fired)")

    num_layers, num_tokens, top_k = expert_indices.shape
    hidden_dim = hidden_states.shape[-1]

    # Derive the MoE layers from the model instead of hard-coding 27 layers and a
    # +1 offset, which silently assumed exactly one leading dense layer.
    moe_layers = [
        layer.mlp for layer in model.model.layers
        if layer.mlp.__class__.__name__ == 'DeepseekMoE'
    ]
    if len(moe_layers) != num_layers:
        raise ValueError(
            f"moe_metadata has {num_layers} layers but the model has "
            f"{len(moe_layers)} DeepseekMoE layers"
        )

    # Pre-allocate: [layers, tokens, top_k, hidden_dim]
    all_expert_outputs = torch.zeros(
        (num_layers, num_tokens, top_k, hidden_dim),
        device=model.device
    )

    with torch.no_grad():
        for layer_idx, moe_layer in enumerate(moe_layers):
            experts = moe_layer.experts
            layer_hidden = hidden_states[layer_idx, 0]  # [seq, hidden]
            for token_idx in range(num_tokens):
                token_hidden = layer_hidden[token_idx].unsqueeze(0)  # [1, hidden]
                selected = expert_indices[layer_idx, token_idx].tolist()
                for slot, expert_id in enumerate(selected):
                    all_expert_outputs[layer_idx, token_idx, slot] = (
                        experts[expert_id](token_hidden).squeeze(0)
                    )

    print(f"Expert outputs shape: {all_expert_outputs.shape}")
    return all_expert_outputs

In [50]:
def project_expert_outputs(model, expert_outputs, apply_final_norm=False):
    """Project expert outputs into vocabulary space through the LM head ("logit lens").

    Args:
        model: model exposing ``lm_head`` (an ``nn.Linear``-like module with
            ``weight`` and ``out_features``).
        expert_outputs: tensor [num_layers, num_tokens, num_experts, hidden_dim].
            Any leading shape works, since the projection acts on the last dim only.
        apply_final_norm: if True, pass the vectors through ``model.model.norm``
            before the LM head (common logit-lens practice). Defaults to False, which
            projects raw outputs as before.
            NOTE(review): assumes the final norm lives at ``model.model.norm`` — confirm
            for the checkpoint in use.

    Returns:
        Tensor [num_layers, num_tokens, num_experts, vocab_size] in the LM head's
        dtype, on the LM head's device.
    """
    lm_head = model.lm_head
    model_dtype = lm_head.weight.dtype
    print(f'vocab_size: {lm_head.out_features}')

    # Expert outputs may be float32 (e.g. from get_expert_outputs) while the head is
    # float16, so cast to the head's dtype and device before projecting.
    vectors = expert_outputs.to(device=lm_head.weight.device, dtype=model_dtype)

    with torch.no_grad():
        if apply_final_norm:
            vectors = model.model.norm(vectors)
        # nn.Linear applies to the last dimension, so a single call projects every
        # (layer, token, expert) vector instead of one call per vector.
        expert_logits = lm_head(vectors)

    print(f"Expert logits shape: {expert_logits.shape}")
    return expert_logits

In [51]:
def get_expert_topk_tokens(expert_logits, tokenizer, k=5):
    """Get the top-k vocabulary tokens for every (layer, token, expert) slot.

    Args:
        expert_logits: tensor [num_layers, num_tokens, num_experts, vocab_size].
        tokenizer: object with ``batch_decode`` (e.g. a Hugging Face tokenizer).
        k: number of tokens to keep per slot; clamped to ``vocab_size``.

    Returns nested dictionary:
    {
        layer_idx: {
            token_idx: {
                expert_idx: {
                    'tokens': [decoded tokens],
                    'scores': [corresponding scores],
                    'ids': [token ids]
                }, ...
            }, ...
        }, ...
    }
    """
    num_layers, num_tokens, num_experts, vocab_size = expert_logits.shape
    k = min(k, vocab_size)

    # One topk over the vocab dim for all slots at once, then a single device->host
    # transfer, instead of a topk + .cpu() round-trip per expert.
    topk_scores, topk_ids = torch.topk(expert_logits, k, dim=-1)
    all_scores = topk_scores.cpu().tolist()
    all_ids = topk_ids.cpu().tolist()

    results = {}
    for layer_idx in range(num_layers):
        layer_results = {}
        for token_idx in range(num_tokens):
            token_results = {}
            for expert_idx in range(num_experts):
                ids = all_ids[layer_idx][token_idx][expert_idx]
                token_results[expert_idx] = {
                    # Each id is decoded on its own, giving one string per token.
                    'tokens': tokenizer.batch_decode(ids),
                    'scores': all_scores[layer_idx][token_idx][expert_idx],
                    'ids': ids
                }
            layer_results[token_idx] = token_results
        results[layer_idx] = layer_results

    return results

In [52]:
# Routed-expert logit lens: run the prompt once, re-run each token's top-k routed
# experts on the captured hidden states, and project their raw outputs to the vocab.
input_txt = "the quick brown fox"
# Keep the ids on the model's device so this cell and the later shared-expert
# cell also work when the model is loaded onto a GPU.
input_ids = tokenizer.encode(input_txt, return_tensors="pt").to(model.device)
moe_metadata = get_moe_metadata(model, input_ids)
expert_outputs = get_expert_outputs(model, moe_metadata)
expert_logits = project_expert_outputs(model, expert_outputs)
expert_topk_tokens = get_expert_topk_tokens(expert_logits, tokenizer)

Router logits shape: torch.Size([27, 1, 5, 64])
Expert indices shape: torch.Size([27, 5, 6])
Hidden states shape: torch.Size([27, 1, 5, 2048])
expert_indices shape: torch.Size([27, 5, 6])
Expert outputs shape: torch.Size([27, 5, 6, 2048])
vocab_size: 102400
Expert logits shape: torch.Size([27, 5, 6, 102400])


In [48]:
# Layer 26 (the last MoE layer), token position 1: top-5 tokens for each of the
# 6 routed experts selected there. Bare expression -> rich display instead of print.
expert_topk_tokens[26][1]

{0: {'tokens': [' nether', ' outermost', 'ymap', 'ermis', 'ограф'], 'scores': [12.6640625, 12.5, 11.671875, 11.1171875, 10.8984375], 'ids': [90704, 99790, 91064, 97648, 36224]}, 1: {'tokens': ['irat', 'baid', ' Pallars', 'ntic', ' remains'], 'scores': [11.3671875, 11.1953125, 10.609375, 10.5390625, 10.1640625], 'ids': [83977, 62627, 43386, 6466, 7544]}, 2: {'tokens': [' court', ' cause', ' victim', ' amount', 'мани'], 'scores': [10.5390625, 10.421875, 10.375, 10.328125, 10.1640625], 'ids': [6518, 4309, 17180, 3744, 27802]}, 3: {'tokens': ['��', 'odox', 'Sec', 'Co', 'Bind'], 'scores': [7.78515625, 6.84375, 6.78125, 6.640625, 6.33984375], 'ids': [689, 35024, 8508, 8854, 22641]}, 4: {'tokens': [' Braves', ' Warriors', ' Bruins', ' Reds', ' Bears'], 'scores': [18.640625, 18.546875, 18.21875, 17.78125, 17.484375], 'ids': [97762, 51354, 98696, 77886, 50243]}, 5: {'tokens': ['息', 'ipre', 'ivil', ' S', '\tS'], 'scores': [9.5234375, 9.4453125, 9.3125, 9.1875, 8.609375], 'ids': [3714, 71905, 552

In [84]:
def get_shared_expert_outputs(model, input_ids, tokens_first=False):
    """Run one forward pass and compute the shared-expert output of every MoE layer.

    A forward hook on each ``DeepseekMoE`` layer's ``gate`` captures the gate's input
    hidden states (treated here as the input to the MoE layer). After the pass, that
    layer's ``shared_experts`` module is applied to the captured states.

    Args:
        model: DeepSeek-MoE style model exposing ``model.model.layers[i].mlp`` and
            ``model.device``.
        input_ids: token-id tensor; moved to ``model.device`` before the forward pass.
        tokens_first: layout of 'shared_expert_outputs'.
            * False (default, original layout): [num_moe_layers, 1, seq_len, hidden].
              The singleton shared axis comes BEFORE the sequence axis. This is the
              opposite of get_expert_outputs' [layers, tokens, experts, hidden].
            * True: [num_moe_layers, seq_len, 1, hidden], which matches that layout.
              project_expert_outputs / get_expert_topk_tokens then index it as
              token then expert.

    Returns:
        dict with
          * 'shared_expert_outputs': tensor in the layout described above;
          * 'hidden_states': captured inputs with a leading batch dim of 1 squeezed
            away ([num_moe_layers, seq_len, hidden] for a batch of one).
        Both values are None if no MoE gate ran (instead of raising from
        ``torch.stack`` on an empty list).
    """
    hidden_states_list = []
    moe_layers = [
        layer.mlp for layer in model.model.layers
        if layer.mlp.__class__.__name__ == 'DeepseekMoE'
    ]

    def hook_fn(module, inputs, output):
        # Capture the hidden states entering the gate; the output is left unchanged.
        hidden_states_list.append(inputs[0].squeeze(0).detach())

    hooks = [moe_layer.gate.register_forward_hook(hook_fn) for moe_layer in moe_layers]
    try:
        with torch.no_grad():
            model(input_ids.to(model.device))
    finally:
        # Remove hooks even if the forward pass fails, so they do not accumulate.
        for hook in hooks:
            hook.remove()

    if not hidden_states_list:
        return {'shared_expert_outputs': None, 'hidden_states': None}

    with torch.no_grad():
        # shared_experts is applied once per layer to that layer's captured input.
        per_layer = [
            moe_layer.shared_experts(hidden_states_list[layer_idx])
            for layer_idx, moe_layer in enumerate(moe_layers)
        ]

    # Stacked: [num_moe_layers, seq_len, hidden]; insert the n_shared=1 axis.
    stacked = torch.stack(per_layer, dim=0)
    shared_outputs_tensor = stacked.unsqueeze(2 if tokens_first else 1)

    return {
        'shared_expert_outputs': shared_outputs_tensor,
        'hidden_states': torch.stack(hidden_states_list)
    }

In [85]:
# Get shared expert outputs separately
shared_data = get_shared_expert_outputs(model, input_ids)

print(shared_data['shared_expert_outputs'].shape)
print(f"Shared expert outputs : {len(shared_data['shared_expert_outputs'])}")
print(f"First layer shared outputs shape: {shared_data['shared_expert_outputs'][0].shape}")

torch.Size([27, 1, 5, 2048])
Shared expert outputs : 27
First layer shared outputs shape: torch.Size([1, 5, 2048])


In [86]:
# NOTE: shared outputs are laid out [layers, n_shared=1, seq_len, hidden]. In the
# results below the "token" axis is therefore the single shared-expert slot and the
# "expert" axis is the sequence position. These assignments also overwrite the
# routed-expert `expert_logits` / `expert_topk_tokens` computed earlier.
shared_outputs_for_lens = shared_data['shared_expert_outputs']
expert_logits = project_expert_outputs(model, shared_outputs_for_lens)
expert_topk_tokens = get_expert_topk_tokens(expert_logits, tokenizer)

vocab_size: 102400
Expert logits shape: torch.Size([27, 1, 5, 102400])


In [100]:
# Shared-expert logit lens: layer 19, shared-expert slot 0, sequence position 4
# (the last of the 5 prompt tokens). The shared results are indexed
# [layer][shared_slot][position].
inspect_layer, inspect_shared_slot, inspect_position = 19, 0, 4
expert_topk_tokens[inspect_layer][inspect_shared_slot][inspect_position]

{'tokens': ['ა', ' <!--[', 'ა�', 'ELY', '\tandroid'],
 'scores': [1.1953125, 1.1025390625, 1.1005859375, 1.0888671875, 1.0205078125],
 'ids': [46554, 69586, 56166, 70939, 97199]}