In [8]:
import os

# os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
# os.environ["PYTORCH_TRANSFORMERS_SDP_BACKEND"] = "flash"

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go
import numpy as np

In [2]:
def load_model(model_name="Qwen/Qwen1.5-MoE-A2.7B"):
    """Load a causal-LM checkpoint together with its matching tokenizer.

    Args:
        model_name: Hugging Face model id (or local path) passed to ``from_pretrained``.

    Returns:
        tuple: ``(model, tokenizer)``.

    No device placement is requested here; later cells move the model explicitly.
    """
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name)
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    return loaded_model, loaded_tokenizer

model, tokenizer = load_model()
model.eval()  # eval mode for analysis; as the last expression it also displays the module tree

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

Qwen2MoeForCausalLM(
  (model): Qwen2MoeModel(
    (embed_tokens): Embedding(151936, 2048)
    (layers): ModuleList(
      (0-23): 24 x Qwen2MoeDecoderLayer(
        (self_attn): Qwen2MoeSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): Qwen2MoeRotaryEmbedding()
        )
        (mlp): Qwen2MoeSparseMoeBlock(
          (gate): Linear(in_features=2048, out_features=60, bias=False)
          (experts): ModuleList(
            (0-59): 60 x Qwen2MoeMLP(
              (gate_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (up_proj): Linear(in_features=2048, out_features=1408, bias=False)
              (down_proj): Linear(in_features=1408, out_features=2048, bias=False)
        

In [3]:
# def print_expert_weights(model, layer_idx, expert_idx):
#     """
#     Print the weights of a specific expert MLP at a given layer.
    
#     Args:
#         model: The OLMoE model
#         layer_idx: Index of the layer containing the expert
#         expert_idx: Index of the expert within the layer
#     """
#     gate_proj = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight'
#     up_proj = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.up_proj.weight'
#     down_proj = f'model.layers.{layer_idx}.mlp.experts.{expert_idx}.down_proj.weight'
    
#     print("\nGate Projection:")
#     print(model.state_dict()[gate_proj])
#     print("\nUp Projection:") 
#     print(model.state_dict()[up_proj])
#     print("\nDown Projection:")
#     print(model.state_dict()[down_proj])


In [3]:
def prepare_text_input(file_path, chunk_size=1000, tokenizer=None):
    """Read a UTF-8 text file and split it into consecutive chunks of about ``chunk_size`` tokens.

    Args:
        file_path (str): Path to the input text file.
        chunk_size (int): Number of tokens (or words, without a tokenizer) per chunk; must be positive.
        tokenizer: HuggingFace tokenizer used to count tokens. If None, the text is split
            on whitespace and ``chunk_size`` counts words instead.

    Returns:
        list[str]: The chunks in file order (empty list for an empty file). With a tokenizer,
        each chunk is decoded from ``chunk_size`` token ids (the last one may be shorter);
        re-tokenizing the decoded text may not give back exactly the same count, hence
        "about". Without a tokenizer, words are re-joined with single spaces.

    Raises:
        ValueError: if ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # Compare against None explicitly: HF tokenizers define __len__, so relying on
    # truthiness would silently fall back to whitespace splitting for a zero-length one.
    if tokenizer is not None:
        # encode() is expected to return a sequence of ids; decode() accepts a plain
        # slice of it, so no tensor round-trip is needed.
        token_ids = tokenizer.encode(text)
        return [
            tokenizer.decode(token_ids[start:start + chunk_size])
            for start in range(0, len(token_ids), chunk_size)
        ]

    # Simple whitespace tokenization.
    words = text.split()
    return [
        ' '.join(words[start:start + chunk_size])
        for start in range(0, len(words), chunk_size)
    ]

In [4]:
def get_router_logits(model, input_text: str, k: int = 1, tokenizer=None):
    """Record the top-``k`` experts chosen by the router for every token in every MoE layer.

    Args:
        model: MoE causal LM whose forward accepts ``output_router_logits=True`` and returns
            ``router_logits`` with one entry per MoE layer (e.g. the Qwen2-MoE model loaded above).
        input_text: Text string to analyze; it is tokenized as a single sequence (batch size 1).
        k: Number of top experts to report per token.
        tokenizer: HuggingFace tokenizer. Defaults to the notebook-level ``tokenizer``.

    Returns:
        dict mapping layer index -> list of ``[token_text, expert_index, router_probability]``,
        with ``k`` consecutive entries per token ordered by descending probability.
        Probabilities are a softmax over that layer's router logits.

    Raises:
        NameError: if no tokenizer is passed and no global ``tokenizer`` exists.
    """
    if tokenizer is None:
        # Backward compatibility: earlier callers relied on the global tokenizer.
        tokenizer = globals().get("tokenizer")
        if tokenizer is None:
            raise NameError("get_router_logits needs a tokenizer: pass one or define a global `tokenizer`.")

    device = "cpu"
    model = model.to(device)

    # Tokenize input text
    inputs = tokenizer(input_text, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    input_ids = inputs['input_ids']
    seq_len = input_ids.size(1)

    # no_grad avoids building an autograd graph when the caller has not already
    # disabled gradients (it is harmless inside torch.inference_mode()).
    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=inputs['attention_mask'],
            output_router_logits=True,
            return_dict=True,
        )

    # The token strings are identical for every layer, so convert them once.
    # 'Ġ' is the byte-level BPE marker for a leading space.
    tokens = [tok.replace('Ġ', '') for tok in tokenizer.convert_ids_to_tokens(input_ids[0].cpu())]

    all_layer_results = {}
    for layer_idx, layer_router_logits in enumerate(outputs.router_logits):
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Batch size is 1, so flatten to [seq_len, num_experts].
        probs = probs.reshape(seq_len, -1)
        top_probs, top_indices = torch.topk(probs, k=k)
        top_probs = top_probs.cpu().tolist()
        top_indices = top_indices.cpu().tolist()

        all_layer_results[layer_idx] = [
            [token, top_indices[i][j], top_probs[i][j]]
            for i, token in enumerate(tokens)
            for j in range(k)
        ]

    return all_layer_results

In [5]:
def update_router_logits_json(results, domain, device='cpu'):
    """Append one chunk's routing records to the per-domain JSON file and rewrite it.

    Args:
        results: Dictionary mapping layer index -> list of [token, expert_number, probability]
            (the output of get_router_logits).
        domain: Domain name selecting the output file (e.g. 'arxiv', 'github', 'qwen-arxiv').
        device: Accepted for backward compatibility only. The merge is plain Python list
            bookkeeping; the previous CUDA branch tried to build a tensor from string
            tokens and could never succeed.

    Returns:
        dict mapping int layer index -> accumulated list of records, i.e. what is now on disk.

    Raises:
        ValueError: if ``domain`` is not one of the known domains.

    Note:
        Records are appended to whatever the file already holds and the whole file is
        rewritten on every call. Re-running the notebook without deleting the file
        therefore duplicates records.
    """
    json_paths = {
        'arxiv': 'arxiv_all_layers.json',
        'github': 'github_all_layers.json',
        'math': 'math_all_layers.json',
        'physics': 'physics_all_layers.json',
        'biology': 'biology_all_layers.json',
        'legal': 'legal_all_layers.json',
        'swap': 'swap_all_layers.json',
        'mix': 'mix_all_layers.json',
        'mix-swap': 'mix_swap_all_layers.json',
        'instruct': 'instruct_all_layers.json',
        'qwen-arxiv': 'qwen_arxiv_all_layers.json',
    }
    if domain not in json_paths:
        raise ValueError(f"Unknown domain {domain!r}; expected one of {sorted(json_paths)}")
    json_path = json_paths[domain]

    existing_results = {}
    if os.path.exists(json_path):
        # Explicit UTF-8: tokens are written with ensure_ascii=False below, so the
        # platform default encoding must not be relied on.
        with open(json_path, 'r', encoding='utf-8') as f:
            try:
                # JSON object keys are strings; convert back to int layer indices.
                existing_results = {int(k): v for k, v in json.load(f).items()}
            except json.JSONDecodeError:
                print(f"Warning: {json_path} is empty or corrupted. Starting with an empty dictionary.")

    # Combine existing and new results for each layer. setdefault(..., []) gives new
    # layers their own list, so later extends never mutate the caller's lists.
    for layer_idx, layer_tokens in results.items():
        existing_results.setdefault(layer_idx, []).extend(layer_tokens)

    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump(existing_results, f, indent=4, ensure_ascii=False)

    return existing_results

In [15]:
def plot_expert_distribution(domain, device='cpu', num_layers=None, num_experts=None):
    """Build a layer x expert heatmap of how often each expert appears in the routing records.

    Reads the per-domain JSON file written by ``update_router_logits_json`` (layer index ->
    list of ``[token, expert_index, probability]`` records). For every layer it plots the
    percentage of that layer's records that name each expert.

    Args:
        domain: Domain name selecting the JSON file (e.g. 'arxiv', 'qwen-arxiv').
        device: Accepted for backward compatibility only; counting is done with NumPy.
            The previous CUDA branch built a tensor from string tokens and always failed.
        num_layers: Number of heatmap rows. Defaults to (largest layer index in the file) + 1.
        num_experts: Number of heatmap columns. Defaults to (largest expert index in the
            file) + 1. Pass the model's expert count explicitly if the highest-numbered
            expert might never have been selected.

    Returns:
        plotly.graph_objects.Figure: the heatmap. Each non-empty row sums to 100 when every
        recorded expert index is below ``num_experts``.

    Raises:
        ValueError: if ``domain`` is not one of the known domains.
        FileNotFoundError: if the domain's JSON file has not been written yet.
    """
    json_paths = {
        'arxiv': 'arxiv_all_layers.json',
        'github': 'github_all_layers.json',
        'math': 'math_all_layers.json',
        'physics': 'physics_all_layers.json',
        'biology': 'biology_all_layers.json',
        'legal': 'legal_all_layers.json',
        'swap': 'swap_all_layers.json',
        'mix': 'mix_all_layers.json',
        'mix-swap': 'mix_swap_all_layers.json',
        'instruct': 'instruct_all_layers.json',
        'qwen-arxiv': 'qwen_arxiv_all_layers.json',
    }
    if domain not in json_paths:
        raise ValueError(f"Unknown domain {domain!r}; expected one of {sorted(json_paths)}")
    json_path = json_paths[domain]

    with open(json_path, 'r', encoding='utf-8') as file:
        # JSON object keys are strings; layer indices are ints.
        data = {int(layer): records for layer, records in json.load(file).items()}

    # Size the grid from the data instead of hard-coding one model's shape: the
    # Qwen1.5-MoE model loaded above has 24 layers x 60 experts, not 16 x 64.
    if num_layers is None:
        num_layers = max(data, default=-1) + 1
    if num_experts is None:
        num_experts = max(
            (int(rec[1]) for records in data.values() for rec in records),
            default=-1,
        ) + 1

    expert_matrix = np.zeros((num_layers, num_experts))
    for layer, records in data.items():
        # Skip layers outside the requested grid, and empty layers (no division by zero).
        if not 0 <= layer < num_layers or not records:
            continue
        experts = np.array([int(rec[1]) for rec in records], dtype=np.int64)
        counts = np.bincount(experts, minlength=num_experts)[:num_experts]
        # Percentage of this layer's records assigned to each expert.
        expert_matrix[layer] = counts / len(records) * 100

    fig = go.Figure(data=go.Heatmap(
        z=expert_matrix,
        x=[str(i) for i in range(num_experts)],
        y=[str(i) for i in range(num_layers)],
        colorscale='Reds'
    ))

    fig.update_layout(
        title='Distribution of Tokens Across Experts and Layers',
        xaxis_title='Expert Index',
        yaxis_title='Layer',
        width=800,
        height=800,
        xaxis=dict(
            tickangle=-45,
            constrain='domain'
        ),
        yaxis=dict(
            scaleanchor='x',
            scaleratio=1
        )
    )

    return fig


In [8]:
# def swap_experts(model, expert_idx, target_layer_idx, source_layer_idx=0, source_expert_idx=0):
#     """
#     Swap experts between two layers in the OLMoE model.
    
#     Args:
#         model: The OLMoE model
#         expert_idx: Index of the expert in target layer to swap with
#         target_layer_idx: Index of the layer containing the expert to swap with
#         source_layer_idx: Index of the source layer (default 0)
#         source_expert_idx: Index of the source expert (default 0)

#     """
#     # Access the decoder layers
#     decoder_layers = model.model.layers
#     print(decoder_layers[0].mlp.experts[0].gate_proj.weight.shape)
    
#     # Verify indices are valid
#     num_layers = len(decoder_layers)
#     if target_layer_idx >= num_layers or source_layer_idx >= num_layers:
#         raise ValueError(f"Layer index out of range. Model has {num_layers} layers.")
    
#     # Get the MoE blocks from both layers
#     source_moe = decoder_layers[source_layer_idx].mlp
#     target_moe = decoder_layers[target_layer_idx].mlp
    
#     # Verify expert indices are valid
#     num_experts = len(source_moe.experts)
#     if expert_idx >= num_experts or source_expert_idx >= num_experts:
#         raise ValueError(f"Expert index out of range. Each layer has {num_experts} experts.")
        
#     # Swap the expert weights
#     source_expert = source_moe.experts[source_expert_idx]
#     target_expert = target_moe.experts[expert_idx]
    
#     # Swap gate projection weights
#     source_expert.gate_proj.weight, target_expert.gate_proj.weight = \
#         target_expert.gate_proj.weight, source_expert.gate_proj.weight
        
#     # Swap up projection weights
#     source_expert.up_proj.weight, target_expert.up_proj.weight = \
#         target_expert.up_proj.weight, source_expert.up_proj.weight
        
#     # Swap down projection weights  
#     source_expert.down_proj.weight, target_expert.down_proj.weight = \
#         target_expert.down_proj.weight, source_expert.down_proj.weight
    
#     return {
#         'swapped_experts': {
#             'source': {
#                 'layer': source_layer_idx,
#                 'expert': source_expert_idx
#             },
#             'target': {
#                 'layer': target_layer_idx,
#                 'expert': expert_idx
#             }
#         }
#     }

In [9]:
# # Layer 0 swaps
# source_experts = [0, 21, 41, 36, 15, 10, 38, 31]
# target_experts = [39, 7, 9, 57, 20, 12, 28, 23]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=0, source_layer_idx=0, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 0, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 1 swaps
# source_experts = [47, 18, 15, 25, 16, 7, 61, 27]
# target_experts = [57, 35, 10, 43, 63, 14, 41, 38]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=1, source_layer_idx=1, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 1, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 2 swaps
# source_experts = [60, 5, 45, 63, 26, 15, 40, 36]
# target_experts = [2, 47, 13, 37, 53, 59, 58, 4]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=2, source_layer_idx=2, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 2, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 3 swaps
# source_experts = [35, 15, 43, 39, 51, 9, 30, 63]
# target_experts = [10,25,58,53,21,55,23,22]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=3, source_layer_idx=3, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 3, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 4 swaps
# source_experts = [17, 25, 19, 6, 21, 46, 16, 55]
# target_experts = [12, 36, 45, 47, 11, 53, 54, 18]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=4, source_layer_idx=4, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 4, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 5 swaps
# source_experts = [0,31,42,53,2,38,35,57]
# target_experts = [46,33,40,60,26,3,37,45]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=5, source_layer_idx=5, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 5, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 6 swaps
# source_experts = [52, 57, 18, 5, 62, 24, 1, 40]
# target_experts = [45, 0, 14, 51, 41, 15, 38, 27]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=6, source_layer_idx=6, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 6, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 7 swaps
# source_experts = [2, 17, 58, 0, 59, 4, 36, 45]
# target_experts = [8, 25, 55, 51, 3, 62, 44, 28]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=7, source_layer_idx=7, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 7, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 8 swaps
# source_experts = [14, 16, 54, 18, 3, 44, 61, 32]
# target_experts = [13, 59, 34, 29, 43, 39, 58, 8]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=8, source_layer_idx=8, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 8, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 9 swaps
# source_experts = [5, 4, 46, 28, 8, 57, 51, 20]
# target_experts = [12, 3, 56, 41, 29, 25, 17, 15]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=9, source_layer_idx=9, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 9, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 10 swaps
# source_experts = [43, 56, 11, 19, 28, 48, 60, 13]
# target_experts = [31, 6, 54, 63, 51, 33, 40, 25]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=10, source_layer_idx=10, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 10, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 11 swaps
# source_experts = [27, 47, 23, 33, 54, 62, 46, 12]
# target_experts = [1, 6, 14, 25, 21, 38, 52, 53]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=11, source_layer_idx=11, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 11, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 12 swaps
# source_experts = [43, 38, 59, 31, 55, 8, 10, 44]
# target_experts = [63, 16, 11, 21, 22, 61, 5, 25]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=12, source_layer_idx=12, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 12, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 13 swaps
# source_experts = [20, 2, 25, 62, 57, 5, 61, 7]
# target_experts = [56, 37, 36, 16, 40, 63, 52, 33]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=13, source_layer_idx=13, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 13, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 14 swaps
# source_experts = [6, 9, 33, 58, 24, 38, 48, 19]
# target_experts = [3, 25, 21, 2, 51, 0, 55, 59]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=14, source_layer_idx=14, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 14, expert {source_experts[i]} with expert {target_experts[i]}")

# # Layer 15 swaps
# source_experts = [34, 1, 17, 44, 8, 57, 24, 30]
# target_experts = [25, 26, 56, 27, 10, 14, 49, 59]
# for i in range(len(source_experts)):
#     swap_experts(model, expert_idx=target_experts[i], target_layer_idx=15, source_layer_idx=15, source_expert_idx=source_experts[i])
#     print(f"Swapped experts at layer 15, expert {source_experts[i]} with expert {target_experts[i]}")


In [10]:
# Tokenize the arXiv corpus and cut it into 512-token chunks.
file_path = 'data/qwen_arxiv_25k.txt'
domain = 'qwen-arxiv'
chunks = prepare_text_input(file_path, chunk_size=512, tokenizer=tokenizer)

# Everything in this analysis runs on the CPU.
device = 'cpu'
model = model.to(device)

# Route each chunk through the model. Results are kept in memory and also appended to
# the domain's JSON file after every chunk, so progress is saved while the loop runs.
# NOTE: update_router_logits_json appends to an existing file, so re-running this cell
# without deleting that file duplicates every record.
all_results = []
total_chunks = len(chunks)
with torch.inference_mode():
    for i, chunk in enumerate(chunks):
        print(f'Processing chunk {i + 1}/{total_chunks}')
        results = get_router_logits(model, chunk)
        all_results.append(results)
        update_router_logits_json(results, domain=domain)

Processing chunk 1/49
Processing chunk 2/49
Processing chunk 3/49
Processing chunk 4/49
Processing chunk 5/49
Processing chunk 6/49
Processing chunk 7/49
Processing chunk 8/49
Processing chunk 9/49
Processing chunk 10/49
Processing chunk 11/49
Processing chunk 12/49
Processing chunk 13/49
Processing chunk 14/49
Processing chunk 15/49
Processing chunk 16/49
Processing chunk 17/49
Processing chunk 18/49
Processing chunk 19/49
Processing chunk 20/49
Processing chunk 21/49
Processing chunk 22/49
Processing chunk 23/49
Processing chunk 24/49
Processing chunk 25/49
Processing chunk 26/49
Processing chunk 27/49
Processing chunk 28/49
Processing chunk 29/49
Processing chunk 30/49
Processing chunk 31/49
Processing chunk 32/49
Processing chunk 33/49
Processing chunk 34/49
Processing chunk 35/49
Processing chunk 36/49
Processing chunk 37/49
Processing chunk 38/49
Processing chunk 39/49
Processing chunk 40/49
Processing chunk 41/49
Processing chunk 42/49
Processing chunk 43/49
Processing chunk 44/

In [18]:
domain = 'qwen-arxiv'
plot_dir = f'plots-{domain}'

# Heatmap of per-layer expert usage over everything accumulated for this domain.
fig = plot_expert_distribution(domain=domain)
fig.show()

# Save plot as HTML and image. Create the output directory first; otherwise
# write_html/write_image fail on a fresh checkout.
os.makedirs(plot_dir, exist_ok=True)
fig.write_html(f'{plot_dir}/{domain}_expert_dist.html')
fig.write_image(f'{plot_dir}/{domain}_expert_dist.png')




huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
