In [33]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict
import pandas as pd
from tqdm.auto import tqdm


In [34]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if DEVICE.type == "cuda":
    # Print CUDA details
    print(f"CUDA Device: {torch.cuda.get_device_name()}")
    print(f"CUDA Memory Allocated: {torch.cuda.memory_allocated()/1024**2:.2f}MB")
    print(f"CUDA Memory Reserved: {torch.cuda.memory_reserved()/1024**2:.2f}MB")

In [35]:
DEVICE

device(type='cpu')

In [36]:
def load_model(model_name):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        # use_flash_attention_2=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model.to(DEVICE)
    return model, tokenizer

tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

In [5]:
model.eval()

DeepseekForCausalLM(
  (model): DeepseekModel(
    (embed_tokens): Embedding(102400, 2048)
    (layers): ModuleList(
      (0): DeepseekDecoderLayer(
        (self_attn): DeepseekSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): DeepseekRotaryEmbedding()
        )
        (mlp): DeepseekMLP(
          (gate_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (up_proj): Linear(in_features=2048, out_features=10944, bias=False)
          (down_proj): Linear(in_features=10944, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): DeepseekRMSNorm()
        (post_attention_layernorm): DeepseekRMSNorm()
      )
      (1-27): 27 x DeepseekDecod

In [5]:
def get_moe_metadata(model, input_ids):
    """Get both router logits and expert indices for all MoE layers."""
    router_logits_list = []
    expert_indices_list = []
    
    def hook_fn(module, input, output):
        # output contains: (topk_idx, topk_weight, aux_loss)
        hidden_states = input[0]
        
        logits = torch.matmul(hidden_states, module.weight.T)
        router_logits_list.append(logits.detach())
        
        # store expert indices actually used for routing
        expert_indices_list.append(output[0].detach())
        
        return output
    
    hooks = []
    for layer_idx, layer in enumerate(model.model.layers):
        if layer.mlp.__class__.__name__ == 'DeepseekMoE':
            hook = layer.mlp.gate.register_forward_hook(hook_fn)
            hooks.append(hook)

    with torch.no_grad():
        model(input_ids)
    
    for hook in hooks:
        hook.remove()

    moe_metadata = {
        'router_logits': torch.stack(router_logits_list) if router_logits_list else None,
        'expert_indices': torch.stack(expert_indices_list) if expert_indices_list else None
    }
    
    if moe_metadata['router_logits'] is not None:
        print(f"Router logits shape: {moe_metadata['router_logits'].shape}")
    if moe_metadata['expert_indices'] is not None:
        print(f"Expert indices shape: {moe_metadata['expert_indices'].shape}")
    
    return moe_metadata

def prepare_prompt(prompt, tokenizer, max_tokens=2048):
    """
    Prepare a prompt for processing, splitting if necessary to fit within model context.
    
    Args:
        prompt: The text prompt to prepare
        tokenizer: The model's tokenizer
        max_tokens: Maximum number of tokens per chunk (default: 2048)
        
    Returns:
        List of prompts that fit within token limit
    """
    # Check if the input is a list of lines/prompts
    if isinstance(prompt, list):
        all_prepared_prompts = []
        for single_prompt in prompt:
            # Process each line/prompt individually
            prepared_chunks = prepare_prompt(single_prompt, tokenizer, max_tokens)
            all_prepared_prompts.extend(prepared_chunks)
        return all_prepared_prompts
    
    # Process a single prompt
    tokens = tokenizer.encode(prompt)
    
    # If prompt is small enough, return as is
    if len(tokens) <= max_tokens:
        return [prompt]
    
    # Split into manageable chunks
    prepared_prompts = []
    
    # Decode tokens into chunks
    start_idx = 0
    while start_idx < len(tokens):
        end_idx = min(start_idx + max_tokens, len(tokens))
        chunk_tokens = tokens[start_idx:end_idx]
        chunk_text = tokenizer.decode(chunk_tokens)
        prepared_prompts.append(chunk_text)
        start_idx = end_idx
    
    print(f"Long prompt detected! Split into {len(prepared_prompts)} chunks.")
    return prepared_prompts

def process_text_file_for_expert_counts(file_path, model, tokenizer, output_path=None, max_tokens=4096):
    """
    Process a text file to analyze MoE routing and count tokens per expert in each layer.
    Saves a PyTorch file with expert token counts.
    
    Args:
        file_path: Path to text file with prompts (one per line)
        model: DeepSeek MoE model
        tokenizer: DeepSeek tokenizer
        output_path: Path to save PyTorch results (default: based on input filename)
        max_tokens: Maximum tokens per prompt chunk
    """
    # Load the text file
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    
    # Check if this is a GitHub code file
    if 'github.txt' in file_path:
        import re
        # Find all code blocks using the file pattern
        file_pattern = re.compile(r'.*\b\w+\.(js|py|c|cpp|java|ts|rb|go|rs|cs|swift|kt|php)$', re.MULTILINE)
        
        # Find all matches (file headers)
        matches = list(file_pattern.finditer(content))
        
        # Extract code blocks between file headers
        raw_prompts = []
        for i in range(len(matches)):
            start_pos = matches[i].start()
            # If this is the last match, go to the end of the file
            if i == len(matches) - 1:
                end_pos = len(content)
            else:
                end_pos = matches[i+1].start()
            
            # Extract the code block including the file header
            code_block = content[start_pos:end_pos].strip()
            raw_prompts.append(code_block)
    else:
        # Regular text file processing (one prompt per line)
        raw_prompts = [line.strip() for line in content.split('\n') if line.strip()]
    
    print(f"Loaded {len(raw_prompts)} raw prompts from {file_path}")
    
    # Prepare prompts (handle large prompts by splitting)
    prompts = []
    for raw_prompt in raw_prompts:
        prepared_chunks = prepare_prompt(raw_prompt, tokenizer, max_tokens)
        prompts.extend(prepared_chunks)
    
    print(f"Processing {len(prompts)} prepared prompts (after splitting large ones)")
    
    # Set default output path if not provided
    if output_path is None:
        output_path = file_path.replace('.txt', '_expert_data.pt')
    
    # Initialize counter for expert usage
    # Structure: {layer_num: {expert_id: count}}
    expert_counts = {}
    
    # Calculate total tokens for progress bar
    total_tokens = 0
    for prompt in prompts:
        tokens = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
        total_tokens += tokens.size(1)
    
    print(f"Total tokens to process: {total_tokens}")
    
    # Initialize progress bar
    pbar = tqdm(total=total_tokens, desc="Processing tokens")
    processed_tokens = 0
    
    # Process each prompt
    for prompt in prompts:
        # Tokenize the prompt
        tokens = tokenizer.encode(prompt, return_tensors="pt").to(DEVICE)
        seq_len = tokens.size(1)
        
        # Get MoE routing metadata
        moe_metadata = get_moe_metadata(model, tokens)
        
        if moe_metadata['expert_indices'] is None:
            print("No MoE layers detected or no routing information available")
            processed_tokens += seq_len
            pbar.update(seq_len)
            continue
        
        # Extract expert indices
        expert_indices = moe_metadata['expert_indices']  # shape: [num_layers, seq_len, top_k]
        num_moe_layers = expert_indices.size(0)
        
        # Initialize counter for this batch if needed
        for layer_idx in range(num_moe_layers):
            layer_num = layer_idx + 1  # 1-based layer indexing
            if layer_num not in expert_counts:
                expert_counts[layer_num] = {}
        
        # Count token routing for each layer
        for layer_idx in range(num_moe_layers):
            layer_num = layer_idx + 1  # 1-based layer indexing
            
            # Process each token in sequence
            for token_idx in range(seq_len):
                # Get experts selected for this token in this layer
                selected_experts = expert_indices[layer_idx, token_idx].cpu().numpy().tolist()
                
                # Count each expert
                for expert_id in selected_experts:
                    if expert_id not in expert_counts[layer_num]:
                        expert_counts[layer_num][expert_id] = 0
                    expert_counts[layer_num][expert_id] += 1
        
        # Update progress bar
        processed_tokens += seq_len
        pbar.update(seq_len)
    
    # Close progress bar
    pbar.close()
    
    # Convert counts to a simple tensor format for saving
    expert_token_counts = {}
    for layer_num in sorted(expert_counts.keys()):
        layer_data = expert_counts[layer_num]
        # Create a tensor with counts for each expert (assuming 64 experts)
        counts = torch.zeros(64)
        for expert_id, count in layer_data.items():
            counts[expert_id] = count
        expert_token_counts[layer_num] = counts
    
    # Save just the token counts per expert
    torch.save(expert_token_counts, output_path)
    print(f"Expert token counts saved to {output_path}")
    
    # Create a DataFrame for visualization purposes
    rows = []
    for layer_num in sorted(expert_counts.keys()):
        layer_data = expert_counts[layer_num]
        for expert_id in range(64):  # Assuming 64 experts
            count = layer_data.get(expert_id, 0)
            rows.append({
                'layer': layer_num,
                'expert_id': expert_id,
                'token_count': count
            })
    
    df = pd.DataFrame(rows)
    return df

def analyze_text_file_routing(model, tokenizer, file_path):
    """
    Main function to analyze MoE routing for a text file.
    
    Args:
        file_path: Path to text file with prompts (one per line)
        model_name: Name of DeepSeek MoE model to use
    """
    # Process the file for expert counts and save as PyTorch file
    df = process_text_file_for_expert_counts(file_path, model, tokenizer)
    
    print(f"Analysis completed for {file_path}")
    return df

In [None]:
file_path = "data-ext/test.txt"
df = analyze_text_file_routing(model, tokenizer, file_path)


In [6]:
data = torch.load("data-ext/gsm8k_expert_data.pt")

  data = torch.load("data-ext/gsm8k_expert_data.pt")


In [7]:
data[1]

tensor([28550.,  8051., 20614.,  6439.,  8464., 15529.,  8803., 17602.,  7979.,
         9239.,  9358.,  6246.,  6327.,  9828., 10541.,  6800., 16189.,  6640.,
        11420.,  5676.,  8199.,  9876., 11074.,  4400.,  4386.,  6602., 24165.,
         6705.,  6997.,  8640.,  8214.,  4198.,  4410., 11428.,  5912.,  8226.,
         7188.,  6626.,  6126.,  5993., 10989.,  5436.,  6459., 11866., 19479.,
        11095., 19071.,  9194.,  6852.,  8344.,  5369.,  5105.,  6272.,  9522.,
         4824.,  2773.,  6093.,  6270.,  6276.,  6767., 13202.,  9643.,  5591.,
         7132.])

In [46]:
def bar_graph_all_tokens_paper(expert_data, layer_number, tokenizer, domain=None):
    """
    Visualizes expert distribution for all tokens in a file for a specific layer.
    
    Args:
        expert_data: Dictionary with layer numbers as keys and tensor of expert counts as values
                     or path to PyTorch file with this data
        layer_number: Layer to analyze (1-27)
        domain: Optional domain name for title (e.g., 'GSM8K', 'Math', etc.)
    
    Returns:
        fig: Plotly figure object
    """
    # Load data if a file path is provided
    if isinstance(expert_data, str):
        txt_file_path = expert_data.replace("_expert_data_chat.pt", ".txt").replace("_expert_data_base.pt", ".txt")
        
        num_tokens = None
        if os.path.exists(txt_file_path):
            # Try different encodings
            encodings = ['utf-8', 'latin-1', 'utf-16']
            for encoding in encodings:
                try:
                    with open(txt_file_path, 'r', encoding=encoding) as f:
                        text_content = f.read()
                        tokens = tokenizer(text_content, return_tensors="pt")
                        num_tokens = tokens.input_ids.numel()
                        print(f"Total tokens in {txt_file_path}: {num_tokens}")
                        break
                except UnicodeDecodeError:
                    continue
            else:
                print(f"Could not read {txt_file_path} with any of the attempted encodings")
        
        expert_data = torch.load(expert_data)
    else:
        # If expert_data is already loaded (not a string path)
        num_tokens = None
    
    # Validate layer number is in the data
    if layer_number not in expert_data:
        raise ValueError(f"Layer {layer_number} not found in expert data")
    
    # Get counts for the specified layer
    expert_counts = expert_data[layer_number].numpy()
        
    # Compute percentages
    total_tokens = num_tokens
    if total_tokens == 0:
        print("No tokens found for this layer")
        return
    
    print(f"total_tokens: {total_tokens}")
    print(f"num_tokens: {num_tokens}")
    percentages = (expert_counts / total_tokens) * 100
    
    # Set discrete opacity based on 9.375% threshold
    # 9.375% is 6 times the expected uniform distribution (1/64 = 1.5625%)
    threshold = 9.375
    opacities = np.where(percentages >= threshold, 1.0, 0.3)
    
    # Create plotly figure
    fig = go.Figure()
    
    # Add bar trace with discrete color and opacity
    fig.add_trace(go.Bar(
        x=list(range(64)),
        y=percentages,
        marker=dict(
            color='#636EFA',  # Blue color for all bars
            opacity=opacities
        ),
        hovertemplate='Expert ID: %{x}<br>Tokens: %{text}<br>Percentage: %{y:.2f}%<extra></extra>',
        text=[f"{int(count)}" for count in expert_counts],
        textposition='none'  # Ensure no text is displayed on the bars
    ))
    
    # Add horizontal line at threshold
    fig.add_shape(
        type="line",
        x0=-0.5,
        x1=63.5,
        y0=threshold,
        y1=threshold,
        line=dict(
            color="red",
            width=2,
            dash="dash",
        )
    )
    
    # Add annotation for the threshold line
    fig.add_annotation(
        x=63,
        y=threshold,
        text=f"{threshold}% threshold",
        showarrow=False,
        yshift=10,
        font=dict(color="red")
    )
    
    routed_tokens = expert_counts.sum()
    
    # Domain label for title
    domain_label = f" - {domain}" if domain else ""
    token_info = f" (No. of Tokens: {num_tokens}, Routed Tokens: {int(routed_tokens)})" if num_tokens else f" (Routed Tokens: {int(routed_tokens)})"
    
    # Calculate y-axis max based on the max percentage or default to 100 if num_tokens is None
    y_max = 100
    
    # Update layout with white background
    fig.update_layout(
        title=f'Expert Usage Distribution - Layer {layer_number}{domain_label}{token_info}',
        xaxis_title='Expert ID',
        yaxis_title='Usage Percentage (%)',
        yaxis_range=[0, y_max],  # Dynamic y-axis based on data
        xaxis=dict(tickmode='linear', tick0=0, dtick=4),
        showlegend=False,
        width=1000,
        height=600,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    
    # Add gridlines with lighter color
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='rgba(128, 128, 128, 0.1)')
    fig.update_xaxes(showgrid=False)
    
    return fig

In [53]:
fig = bar_graph_all_tokens_paper("data-ext/arxiv_title_abstract_expert_data_base.pt", layer_number=15, tokenizer=tokenizer)
fig.show()

Total tokens in data-ext/arxiv_title_abstract.txt: 103197
total_tokens: 103197
num_tokens: 103197



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [52]:
fig = bar_graph_all_tokens_paper("data-ext/arxiv_title_abstract_expert_data_chat.pt", layer_number=15, tokenizer=tokenizer)
fig.show()

Total tokens in data-ext/arxiv_title_abstract.txt: 103197
total_tokens: 103197
num_tokens: 103197



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

