In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict

In [2]:
def load_model(model_name):
    # set default device to CUDA if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = 'cpu'
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        use_flash_attention_2=True,
        device_map="auto",
    )
    # move model to GPU
    model = model.to(device)
    if device.type == "cuda":
        model = model.half()  # convert to FP16 for GPU memory efficiency
        torch.backends.cudnn.benchmark = True  # enable CUDNN autotuner
    
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model("deepseek-ai/deepseek-moe-16b-base")

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

In [3]:
def get_device():
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = get_device()

In [4]:
def get_moe_metadata(model, input_ids):
    """Get both router logits and expert indices for all MoE layers"""
    router_logits_list = []
    expert_indices_list = []
    
    def hook_fn(module, input, output):
        # output contains: (topk_idx, topk_weight, aux_loss) 
        hidden_states = input[0]
        # move computation to GPU and use half precision
        logits = torch.matmul(hidden_states.half(), module.weight.T.half()).to(model.device)
        router_logits_list.append(logits.detach().cpu()) # move to CPU to save GPU memory
        # store expert indices on CPU
        expert_indices_list.append(output[0].detach().cpu())
        
        # torch.cuda.empty_cache() 
        return output
    
    hooks = []
    for layer_idx, layer in enumerate(model.model.layers):
        if layer.mlp.__class__.__name__ == 'DeepseekMoE':
            hook = layer.mlp.gate.register_forward_hook(hook_fn)
            hooks.append(hook)

    with torch.no_grad():
        input_ids = input_ids.to(model.device)
        model(input_ids)
        input_ids = input_ids.cpu() # move back to CPU

    for hook in hooks:
        hook.remove()
    torch.cuda.empty_cache()

    # stack tensors on CPU
    moe_metadata = {
        'router_logits': torch.stack(router_logits_list) if router_logits_list else None,
        'expert_indices': torch.stack(expert_indices_list) if expert_indices_list else None
    }
    
    if moe_metadata['router_logits'] is not None:
        print(f"Router logits shape: {moe_metadata['router_logits'].shape}")
    if moe_metadata['expert_indices'] is not None:
        print(f"Expert indices shape: {moe_metadata['expert_indices'].shape}")
    
    # clear lists to free memory
    router_logits_list.clear()
    expert_indices_list.clear()
    torch.cuda.empty_cache()
    
    return moe_metadata

In [5]:
# input_txt = 'the quick brown fox'
# input_ids = tokenizer.encode(input_txt, return_tensors="pt")
# moe_metadata = get_moe_metadata(model, input_ids)
# print(moe_metadata)
# router_logits = moe_metadata['router_logits']

In [6]:
def get_last_token_router_probs(router_logits, model_layer_idx):
    """
    Get router probabilities for the last token in the sequence
    for a specific MODEL LAYER INDEX (1-27 for MoE layers)
    """
    # convert model layer index (1-27) to router_logits index (0-26)
    router_logits_idx = model_layer_idx - 1
    
    if router_logits_idx < 0 or router_logits_idx >= router_logits.size(0):
        raise ValueError(f"Invalid model_layer_idx {model_layer_idx}. Must be 1-27 for MoE layers")
    
    layer_logits = router_logits[router_logits_idx]  # [1, 5, 64]
    
    last_token_logits = layer_logits[0, -1, :]  # [64]
    routing_probs = torch.nn.functional.softmax(last_token_logits, dim=-1)
    
    return routing_probs

def topk(router_probs, k):
    """zero out all components except top k router probabilities"""
    values, indices = torch.topk(router_probs, k)
    zeroed_probs = torch.zeros_like(router_probs)
    zeroed_probs[indices] = values
    return zeroed_probs

In [7]:
def get_moe_data(model, prompts):
    """
    collects both all-token and last-token MoE data in a single forward pass per prompt.
    
    returns:
        all_token_logits: Tensor [num_prompts, num_layers, max_seq_len, num_experts] (padded)
        all_token_experts: Tensor [num_prompts, num_layers, max_seq_len, top_k] (padded with -1)
        last_token_logits: Tensor [num_prompts, num_layers, num_experts]
        last_token_experts: Tensor [num_prompts, num_layers, top_k]
    """
    num_prompts = len(prompts)
    max_seq_len = max(tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.size(1) for prompt in prompts)
    num_moe_layers = sum(1 for layer in model.model.layers if layer.mlp.__class__.__name__ == 'DeepseekMoE')
    print(f"Number of MoE layers: {num_moe_layers}")
    num_experts = model.config.n_routed_experts  # 64
    top_k = model.config.num_experts_per_tok  # 6

    # initialize tensors directly on GPU
    all_token_logits = torch.zeros((num_prompts, num_moe_layers, max_seq_len, num_experts),
                                  dtype=torch.float16, device=model.device)
    all_token_experts = -torch.ones((num_prompts, num_moe_layers, max_seq_len, top_k),
                                   dtype=torch.long, device=model.device)
    last_token_logits = torch.zeros((num_prompts, num_moe_layers, num_experts),
                                   dtype=torch.float16, device=model.device)
    last_token_experts = torch.zeros((num_prompts, num_moe_layers, top_k),
                                    dtype=torch.long, device=model.device)
    
    print(f'its working till here ?')

    for prompt_idx, prompt in enumerate(prompts):
        input_ids = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).input_ids.to(model.device)
        seq_len = input_ids.size(1)
        
        print(f'one before gate computation')
        with torch.no_grad(), torch.cuda.amp.autocast():
            moe_metadata = get_moe_metadata(model, input_ids)

        print(f'two after gate computation')
        
        # process each MoE layer
        for layer_idx in range(num_moe_layers):
            # get raw data for this layer
            layer_logits = moe_metadata['router_logits'][layer_idx].squeeze(0)  # [seq_len, 64]
            layer_experts = moe_metadata['expert_indices'][layer_idx]  # [seq_len, 6]

            # store all tokens data (with padding)
            all_token_logits[prompt_idx, layer_idx, :seq_len] = layer_logits
            all_token_experts[prompt_idx, layer_idx, :seq_len] = layer_experts
            
            # extract and store last token data
            last_token_logits[prompt_idx, layer_idx] = layer_logits[seq_len-1]
            last_token_experts[prompt_idx, layer_idx] = layer_experts[seq_len-1]

        # clear CUDA cache after each prompt
        torch.cuda.empty_cache()

    return all_token_logits, all_token_experts, last_token_logits, last_token_experts

In [8]:
# test_prompts = ['the quick brown fox', 'the capital of japan is tokyo', '12*2 is 24']
# all_token_logits, all_token_experts, last_token_logits, last_token_experts = get_moe_data(model, tokenizer, prompts= test_prompts)

In [9]:
def prepare_prompts_from_txt(txt_file_path,  domain = 'english', output_path=f'english.json'):
    """ read prompts from a txt file and save them in json format. """
    
    with open(txt_file_path, 'r', encoding='utf-8') as f:
        prompts = [line.strip() for line in f.readlines() if line.strip()]
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump({f"{domain}": prompts}, f, indent=4)
        
    print(f"{domain} prompts saved to {output_path}")
    return prompts

def parse_code_blocks(txt_file_path, output_path='code.json', domain='code'):
    """parse code blocks between ``` markers from a text file and save them in json format."""
    code_blocks = []
    current_block = []
    in_block = False
    
    with open(txt_file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip().startswith('```'):
                if in_block:
                    # Current block is complete, save it and start new block
                    if current_block:
                        code_blocks.append('\n'.join(current_block))
                    current_block = []
                # Always start a new block since ``` only indicates start
                in_block = True
                current_block = []
            elif in_block:
                # Add line to current block
                current_block.append(line.rstrip())
    
    # Save final block if exists
    if current_block:
        code_blocks.append('\n'.join(current_block))
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump({domain: code_blocks}, f, indent=4)
        
    print(f"code blocks saved to {output_path}")
    return code_blocks

In [10]:
def load_prompts_from_json(json_file_path):
    """load prompts from a json file and return them as a list."""
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # get the first (and only) value from the dictionary
    # since the json structure is {"domain": [prompts]}
    prompts = list(data.values())[0]
    return prompts

def prepare_multi_domain_prompts(domain_files, output_path='all_domain_prompts.json'):
    """
    prepare a json file containing prompts from multiple domains.
    
    args:
        domain_files: Dict mapping domain names to lists of tuples (file_path, parser_func)
            where parser_func is a function that takes a file path and returns a list of prompts
            
    example:
        domain_files = {
            'code': [('code.txt', parse_code_blocks)], 
            'english': [('english.txt', prepare_prompts_from_txt)]
        }
    """
    all_prompts = {}
    
    for domain, file_list in domain_files.items():
        domain_prompts = []
        for _, prompts in file_list:
            # Use load_prompts_from_json if prompts is a dict
            if not isinstance(prompts, list):
                prompts = load_prompts_from_json(prompts)
            domain_prompts.extend(prompts)
                
        all_prompts[domain] = domain_prompts
        
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_prompts, f, indent=4)
        
    print(f"all domain prompts saved to {output_path}")
    return all_prompts

def convert_all_to_list(all_prompts):
    """
    combines prompts from all domains into a single list of domain-specific prompt lists.
    returns a list in the format [[domain1_prompts], [domain2_prompts], ...].
    """
    # Create list of domain-specific prompt lists
    combined_prompts = [
        prompts for prompts in all_prompts.values()
    ]
        
    total_prompts = sum(len(prompts) for prompts in combined_prompts)
    print(f"total prompts: {total_prompts}")
    print(f"prompts per domain:")
    for domain, prompts in zip(all_prompts.keys(), combined_prompts):
        print(f"  {domain}: {len(prompts)}")
        
    return combined_prompts

In [11]:
prepare_prompts_from_txt('interp-data/engl-lit.txt', domain='english', output_path='english.json')
prepare_prompts_from_txt('interp-data/french.txt', domain='french', output_path='french.json')
parse_code_blocks('interp-data/code.txt', 'code.json', domain='code')

code_prompts = load_prompts_from_json(json_file_path='code.json')
print(f"total code prompts : {len(code_prompts)}")
english_prompts = load_prompts_from_json(json_file_path='english.json')
print(f'total english prompts : {len(english_prompts)}')


domain_files = {
    'code': [('interp-data/code.txt', parse_code_blocks(txt_file_path='interp-data/code.txt', output_path='code.json', domain='code'))],
    'english': [('interp-data/engl-lit.txt', prepare_prompts_from_txt(txt_file_path='interp-data/engl-lit.txt', output_path='english.json', domain='english'))],
    'french': [('interp-data/french.txt', prepare_prompts_from_txt(txt_file_path='interp-data/french.txt', output_path='french.json', domain='french'))]
}
all_prompts = prepare_multi_domain_prompts(domain_files, output_path='all_prompts.json')
print(f"total domains : {len(all_prompts)}")

# convert all prompts to a single list of domain-specific prompt lists
combined_prompts = convert_all_to_list(all_prompts)

english prompts saved to english.json
french prompts saved to french.json
code blocks saved to code.json
total code prompts : 200
total english prompts : 200
code blocks saved to code.json
english prompts saved to english.json
french prompts saved to french.json
all domain prompts saved to all_prompts.json
total domains : 3
total prompts: 600
prompts per domain:
  code: 200
  english: 200
  french: 200


In [12]:
# flatten the list of lists into a single list of prompts
x =[prompt for domain_prompts in combined_prompts for prompt in domain_prompts]

In [None]:
all_token_logits, all_token_experts, last_token_logits, last_token_experts = get_moe_data(model, prompts= x,)

In [None]:
torch.save(all_token_logits, "all_token_logits.pt")
torch.save(all_token_experts, "all_token_experts.pt")
torch.save(last_token_logits, "last_token_logits.pt")
torch.save(last_token_experts, "last_token_experts.pt")

In [13]:
# Flatten the list of lists into a single list of prompts
flattened_prompts = [prompt for domain_prompts in combined_prompts for prompt in domain_prompts]

num_steps = 1  # 600 prompts total
prompts_per_step = 600

for step in range(num_steps):
    print(f"\nProcessing step {step+1}/{num_steps}")
    
    # get slice of prompts for this step
    start_idx = step * prompts_per_step
    end_idx = start_idx + prompts_per_step
    step_prompts = flattened_prompts[start_idx:end_idx]
    
    print(f"processing prompts {start_idx} to {end_idx-1}")
    all_token_logits, all_token_experts, last_token_logits, last_token_experts = get_moe_data(model, prompts=step_prompts)
    
    print(f"saving results for step {step+1}")
    torch.save(all_token_logits, f"all_token_logits-{step+1}.pt")
    torch.save(all_token_experts, f"all_token_experts-{step+1}.pt") 
    torch.save(last_token_logits, f"last_token_logits-{step+1}.pt")
    torch.save(last_token_experts, f"last_token_experts-{step+1}.pt")
    
    # clear memory
    del all_token_logits, all_token_experts, last_token_logits, last_token_experts
    torch.cuda.empty_cache()
    
print("\nFinished processing all prompts")



Processing step 1/1
processing prompts 0 to 0
Number of MoE layers: 27
its working till here ?
one before gate computation
Router logits shape: torch.Size([27, 1, 26, 64])
Expert indices shape: torch.Size([27, 26, 6])
two after gate computation
saving results for step 1

Finished processing all prompts
