In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict
import pandas as pd
from tqdm.auto import tqdm

In [31]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    # Print CUDA details
    print(f"CUDA Device: {torch.cuda.get_device_name()}")
    print(f"CUDA Memory Allocated: {torch.cuda.memory_allocated()/1024**2:.2f}MB")
    print(f"CUDA Memory Reserved: {torch.cuda.memory_reserved()/1024**2:.2f}MB")

In [32]:
DEVICE

device(type='cpu')

In [None]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # use_flash_attention_2=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.to(DEVICE)
    return model, tokenizer

model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [24]:
model.eval()

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

In [5]:
def get_moe_metadata(model, input_ids):
    """Get both router logits and expert indices for all MoE layers."""
    router_logits_list = []
    expert_indices_list = []
    
    def hook_fn(module, input, output):
        # output contains: (topk_idx, topk_weight, aux_loss)
        hidden_states = input[0]
        
        logits = torch.matmul(hidden_states, module.weight.T)
        router_logits_list.append(logits.detach())
        
        # store expert indices actually used for routing
        expert_indices_list.append(output[0].detach())
        
        return output
    
    hooks = []
    for layer_idx, layer in enumerate(model.model.layers):
        if layer.mlp.__class__.__name__ == 'DeepseekMoE':
            hook = layer.mlp.gate.register_forward_hook(hook_fn)
            hooks.append(hook)

    with torch.no_grad():
        model(input_ids)
    
    for hook in hooks:
        hook.remove()

    moe_metadata = {
        'router_logits': torch.stack(router_logits_list) if router_logits_list else None,
        'expert_indices': torch.stack(expert_indices_list) if expert_indices_list else None
    }
    
    if moe_metadata['router_logits'] is not None:
        print(f"Router logits shape: {moe_metadata['router_logits'].shape}")
    if moe_metadata['expert_indices'] is not None:
        print(f"Expert indices shape: {moe_metadata['expert_indices'].shape}")
    
    return moe_metadata

def prepare_prompt(prompt, tokenizer, max_tokens=2048):
    """
    Prepare a prompt for processing, splitting if necessary to fit within model context.
    
    Args:
        prompt: The text prompt to prepare
        tokenizer: The model's tokenizer
        max_tokens: Maximum number of tokens per chunk (default: 2048)
        
    Returns:
        List of prompts that fit within token limit
    """
    # Tokenize the prompt to get token count
    tokens = tokenizer.encode(prompt)
    
    # If prompt is small enough, return as is
    if len(tokens) <= max_tokens:
        return [prompt]
    
    # Split into manageable chunks
    prepared_prompts = []
    
    # Decode tokens into chunks
    start_idx = 0
    while start_idx < len(tokens):
        end_idx = min(start_idx + max_tokens, len(tokens))
        chunk_tokens = tokens[start_idx:end_idx]
        chunk_text = tokenizer.decode(chunk_tokens)
        prepared_prompts.append(chunk_text)
        start_idx = end_idx
    
    print(f"Long prompt detected! Split into {len(prepared_prompts)} chunks.")
    return prepared_prompts

def process_text_file_for_expert_counts(file_path, model, tokenizer, output_path=None, max_tokens=4096):
    """
    Process a text file to analyze MoE routing and count tokens per expert in each layer.
    Saves a PyTorch file with expert token counts.
    
    Args:
        file_path: Path to text file with prompts (one per line)
        model: DeepSeek MoE model
        tokenizer: DeepSeek tokenizer
        output_path: Path to save PyTorch results (default: based on input filename)
        max_tokens: Maximum tokens per prompt chunk
    """
    # Load the text file
    with open(file_path, 'r', encoding='utf-8') as f:
        raw_prompts = [line.strip() for line in f.readlines() if line.strip()]
    
    print(f"Loaded {len(raw_prompts)} raw prompts from {file_path}")
    
    # Prepare prompts (handle large prompts by splitting)
    prompts = []
    for raw_prompt in raw_prompts:
        prepared_chunks = prepare_prompt(raw_prompt, tokenizer, max_tokens)
        prompts.extend(prepared_chunks)
    
    print(f"Processing {len(prompts)} prepared prompts (after splitting large ones)")
    
    # Set default output path if not provided
    if output_path is None:
        output_path = file_path.replace('.txt', '_expert_data.pt')
    
    # Initialize counter for expert usage
    # Structure: {layer_num: {expert_id: count}}
    expert_counts = {}
    
    # Calculate total tokens for progress bar
    total_tokens = 0
    for prompt in prompts:
        tokens = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
        total_tokens += tokens.size(1)
    
    print(f"Total tokens to process: {total_tokens}")
    
    # Initialize progress bar
    pbar = tqdm(total=total_tokens, desc="Processing tokens")
    processed_tokens = 0
    
    # Process each prompt
    for prompt in prompts:
        # Tokenize the prompt
        tokens = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
        seq_len = tokens.size(1)
        
        # Get MoE routing metadata
        moe_metadata = get_moe_metadata(model, tokens)
        
        if moe_metadata['expert_indices'] is None:
            print("No MoE layers detected or no routing information available")
            processed_tokens += seq_len
            pbar.update(seq_len)
            continue
        
        # Extract expert indices
        expert_indices = moe_metadata['expert_indices']  # shape: [num_layers, seq_len, top_k]
        num_moe_layers = expert_indices.size(0)
        
        # Initialize counter for this batch if needed
        for layer_idx in range(num_moe_layers):
            layer_num = layer_idx + 1  # 1-based layer indexing
            if layer_num not in expert_counts:
                expert_counts[layer_num] = {}
        
        # Count token routing for each layer
        for layer_idx in range(num_moe_layers):
            layer_num = layer_idx + 1  # 1-based layer indexing
            
            # Process each token in sequence
            for token_idx in range(seq_len):
                # Get experts selected for this token in this layer
                selected_experts = expert_indices[layer_idx, token_idx].cpu().numpy().tolist()
                
                # Count each expert
                for expert_id in selected_experts:
                    if expert_id not in expert_counts[layer_num]:
                        expert_counts[layer_num][expert_id] = 0
                    expert_counts[layer_num][expert_id] += 1
        
        # Update progress bar
        processed_tokens += seq_len
        pbar.update(seq_len)
    
    # Close progress bar
    pbar.close()
    
    # Convert counts to a simple tensor format for saving
    expert_token_counts = {}
    for layer_num in sorted(expert_counts.keys()):
        layer_data = expert_counts[layer_num]
        # Create a tensor with counts for each expert (assuming 64 experts)
        counts = torch.zeros(64)
        for expert_id, count in layer_data.items():
            counts[expert_id] = count
        expert_token_counts[layer_num] = counts
    
    # Save just the token counts per expert
    torch.save(expert_token_counts, output_path)
    print(f"Expert token counts saved to {output_path}")
    
    # Create a DataFrame for visualization purposes
    rows = []
    for layer_num in sorted(expert_counts.keys()):
        layer_data = expert_counts[layer_num]
        for expert_id in range(64):  # Assuming 64 experts
            count = layer_data.get(expert_id, 0)
            rows.append({
                'layer': layer_num,
                'expert_id': expert_id,
                'token_count': count
            })
    
    df = pd.DataFrame(rows)
    return df

def visualize_expert_counts(df, file_base, output_dir=None):
    """
    Visualize the token counts per expert for each layer using Plotly.
    
    Args:
        df: DataFrame with expert count data
        file_base: Base name for plot titles
        output_dir: Directory to save visualizations
    """
    import os
    import plotly.graph_objects as go
    
    # Set default output directory
    if output_dir is None:
        output_dir = "."
    
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    
    # Extract unique layers
    layers = sorted(df['layer'].unique())
    
    # Create a plot for each layer
    for layer in layers:
        layer_df = df[df['layer'] == layer]
        
        # Sort by expert_id for consistent visualization
        layer_df = layer_df.sort_values('expert_id')
        
        # Create plotly figure
        fig = go.Figure()
        
        # Add bar trace
        fig.add_trace(go.Bar(
            x=layer_df['expert_id'],
            y=layer_df['token_count'],
            marker_color='steelblue'
        ))
        
        # Update layout
        fig.update_layout(
            title=f'{file_base} - Layer {layer} Expert Usage',
            xaxis_title='Expert ID',
            yaxis_title='Token Count',
            yaxis_gridcolor='rgba(0,0,0,0.1)',
            width=1000,
            height=600
        )
        
        # Save the plot
        # fig.write_image(f"{output_dir}/{file_base}_layer{layer}_expert_usage.png")
    
    print(f"Expert usage visualizations saved to {output_dir}")

def analyze_text_file_routing(model, tokenizer, file_path):
    """
    Main function to analyze MoE routing for a text file.
    
    Args:
        file_path: Path to text file with prompts (one per line)
        model_name: Name of DeepSeek MoE model to use
    """
    # Process the file for expert counts and save as PyTorch file
    df = process_text_file_for_expert_counts(file_path, model, tokenizer)
    
    # Visualize the results
    file_base = os.path.basename(file_path).replace('.txt', '')
    visualize_expert_counts(df, file_base)
    
    print(f"Analysis completed for {file_path}")
    return df

In [6]:
file_path = "data-ext/test.txt"
df = analyze_text_file_routing(model, tokenizer, file_path)


Loaded 6 raw prompts from data-ext/test.txt
Processing 6 prepared prompts (after splitting large ones)
Total tokens to process: 517


Processing tokens:   0%|          | 0/517 [00:00<?, ?it/s]

Router logits shape: torch.Size([27, 1, 68, 64])
Expert indices shape: torch.Size([27, 68, 6])
Router logits shape: torch.Size([27, 1, 67, 64])
Expert indices shape: torch.Size([27, 67, 6])
Router logits shape: torch.Size([27, 1, 45, 64])
Expert indices shape: torch.Size([27, 45, 6])
Router logits shape: torch.Size([27, 1, 254, 64])
Expert indices shape: torch.Size([27, 254, 6])
Router logits shape: torch.Size([27, 1, 50, 64])
Expert indices shape: torch.Size([27, 50, 6])
Router logits shape: torch.Size([27, 1, 33, 64])
Expert indices shape: torch.Size([27, 33, 6])
Expert token counts saved to data-ext/test_expert_data.pt


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Expert usage visualizations saved to .
Analysis completed for data-ext/test.txt


In [8]:
data = torch.load("data-ext/test_expert_data.pt")


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [28]:
data[1]

tensor([ 68.,  58.,  62.,  24.,  58.,  98.,  51., 130.,  67.,  55.,  30.,  21.,
         37.,  55., 106.,  26., 171.,  27.,  60.,  20.,  66.,  40.,  51.,  46.,
        100.,  29.,  48.,  39.,  51.,  25.,  36.,  13.,  19.,  45.,  57.,  24.,
         20.,  41.,  17.,  26.,  74.,  16.,  46.,  76.,  45.,  47., 104.,  21.,
         52., 115.,  22.,  56.,  33.,  43.,   9.,  13.,  27.,  44.,  27.,  42.,
         64.,  19.,  21.,  69.])

In [29]:
data[27]

tensor([146.,  84.,  41.,  58.,  31.,  34.,  37.,  37.,  38., 151.,  23.,  48.,
         25.,  72.,  35.,  16.,  55.,  50.,  15.,   6.,  10.,  23.,  31.,  42.,
         18.,  34.,  25.,  42.,  78.,  23., 130.,  21.,  63.,  57.,  28.,  15.,
         91.,  68.,  26.,  41.,  42.,  24.,  62.,  25.,  38.,  37.,  25.,  22.,
         16.,  64.,  31.,  76.,  60.,  58., 117.,  39., 111.,  36.,  19.,  25.,
         14.,  79.,  71., 143.])