In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import json
import os
import pandas as pd
from collections import defaultdict
import plotly.graph_objects as go

In [2]:
def load_model(model_name="allenai/OLMoE-1B-7B-0924"):
    device = "cpu"
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return model, tokenizer

model, tokenizer = load_model()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

### split text file into tokens for model's context length

In [3]:
def prepare_text_input(file_path, chunk_size=1000, tokenizer=None):
    """    
    args :
        file_path (str): Path to the input text file
        chunk_size (int): Number of tokens per chunk
        tokenizer: HuggingFace tokenizer (if None, will split on whitespace)
        
    output : List of text chunks of approximately chunk_size tokens
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Read the full text file
    with open(file_path, 'r', encoding='utf-8') as f:
        text = f.read()
    
    if tokenizer:
        # Tokenize the full text
        tokens = tokenizer.encode(text)
        tokens_tensor = torch.tensor(tokens).to(device)
        
        # Split into chunks
        chunks = []
        for i in range(0, len(tokens), chunk_size):
            chunk_tokens = tokens_tensor[i:i + chunk_size]
            # Move to CPU for decoding
            chunk_tokens = chunk_tokens.cpu()
            # Decode tokens back to text
            chunk_text = tokenizer.decode(chunk_tokens)
            chunks.append(chunk_text)
            
    else:
        # Simple whitespace tokenization
        words = text.split()
        
        # Split into chunks
        chunks = []
        for i in range(0, len(words), chunk_size):
            chunk = ' '.join(words[i:i + chunk_size])
            chunks.append(chunk)
    
    return chunks

### get the router logits for each token across all layers

In [4]:
def get_router_logits(model, input_text: str, k: int = 1):
    """
    args :
        model: OlmoeForCausalLM model
        input_text: Text string to analyze
        k: Number of top experts to return per token
        
    output : dictionary mapping layer indices to lists of [token_text, expert_index, router_probability] for each token in that layer
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    
    # Tokenize input text
    inputs = tokenizer(input_text, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    
    # Forward pass with router logits enabled
    outputs = model(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        output_router_logits=True,
        return_dict=True,
    )
    
    # Get router logits for all layers
    router_logits = outputs.router_logits
    
    all_layer_results = {}
    for layer_idx, layer_router_logits in enumerate(router_logits):
        # Apply softmax to get probabilities
        probs = torch.nn.functional.softmax(layer_router_logits.detach(), dim=-1)
        # Reshape to [seq_len, num_experts] since batch_size=1
        probs = probs.reshape(inputs['input_ids'].size(1), -1)
        # Get top k probabilities and indices for each token
        top_probs, top_indices = torch.topk(probs, k=k)
        
        # Move tensors to CPU for post-processing
        top_probs = top_probs.cpu()
        top_indices = top_indices.cpu()
        
        # Convert token IDs to text
        tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu())
        
        # Create list of [token, expert, prob] for each token
        layer_tokens = []
        for i in range(len(tokens)):
            for j in range(k):
                # Clean special characters from token text
                clean_token = tokens[i].replace('Ġ', '')
                layer_tokens.append([
                    clean_token,
                    top_indices[i][j].item(),
                    top_probs[i][j].item()
                ])
        
        all_layer_results[layer_idx] = layer_tokens
    
    return all_layer_results # Dictionary mapping layer index to list of [token, expert_number, probability]

### update/create the router logits json file with new tokens

In [5]:
def update_router_logits_json(results, domain, device='cuda'):
    """
    args :
        results: Dictionary mapping layer index to list of [token, expert_number, probability]
        domain: String indicating the domain (e.g., 'arxiv', 'code')
        device: Device to use for tensor operations ('cuda' or 'cpu')
    output : updated json file with new tokens
    """
    if domain == 'arxiv':
        json_path = 'arxiv_all_layers.json'
    elif domain == 'github':
        json_path = 'github_all_layers.json'
    elif domain == 'math':
        json_path = 'math_all_layers.json'
    elif domain == 'physics':
        json_path = 'physics_all_layers.json'
    elif domain == 'biology':
        json_path = 'biology_all_layers.json'
    elif domain == 'legal':
        json_path = 'legal_all_layers.json'

    
    # Initialize an empty dictionary for existing results
    existing_results = {}
    
    if os.path.exists(json_path):
        # Load existing results
        with open(json_path, 'r') as f:
            try:
                existing_results = json.load(f)
                # Convert string keys to integers
                existing_results = {int(k): v for k, v in existing_results.items()}
            except json.JSONDecodeError:
                print(f"Warning: {json_path} is empty or corrupted. Starting with an empty dictionary.")
    
    # Move results to GPU if available
    if torch.cuda.is_available() and device == 'cuda':
        for layer_idx, layer_tokens in results.items():
            # Convert lists to tensors and move to GPU
            tokens_tensor = torch.tensor([[t[0], t[1], t[2]] for t in layer_tokens]).cuda()
            results[layer_idx] = tokens_tensor.tolist()
    
    # Combine existing and new results for each layer
    for layer_idx, layer_tokens in results.items():
        if layer_idx in existing_results:
            existing_results[layer_idx].extend(layer_tokens)
        else:
            existing_results[layer_idx] = layer_tokens
    
    # Save updated results with integer keys
    with open(json_path, 'w') as f:
        json.dump(existing_results, f, indent=4, ensure_ascii=False)
        
    return existing_results

### plot the expert distribution for a particular layer

In [6]:
def plot_expert_distribution(layer_idx, domain, device='cuda'):
    """    
    args :
        json_path: Path to the JSON file containing expert counts
        device: Device to use for tensor operations ('cuda' or 'cpu')
    output : plot of the expert distribution for a particular layer
    """
    if domain == 'arxiv':
        json_path = 'arxiv_all_layers.json'
    elif domain == 'github':
        json_path = 'github_all_layers.json'
    elif domain == 'math':
        json_path = 'math_all_layers.json'
    elif domain == 'physics':
        json_path = 'physics_all_layers.json'
    elif domain == 'biology':
        json_path = 'biology_all_layers.json'
    elif domain == 'legal':
        json_path = 'legal_all_layers.json'

    # Read JSON file
    with open(json_path, 'r') as file:
        data = json.load(file)
    
    # Extract layer results
    layer_results = data[str(layer_idx)]
    
    # Create a dictionary to store expert counts
    expert_counts = defaultdict(int)
    
    # Move data to GPU if available
    if torch.cuda.is_available() and device == 'cuda':
        layer_results = torch.tensor(layer_results).cuda()
        
    # Count how many tokens went to each expert
    total_assignments = len(layer_results)
    print(f'Total assignments: {total_assignments}')
    
    # Count occurrences of each expert
    if torch.cuda.is_available() and device == 'cuda':
        # Process on GPU
        for _, expert, _ in layer_results.cpu().numpy():
            expert_counts[int(expert)] += 1
    else:
        # Process on CPU
        for _, expert, _ in layer_results:
            expert_counts[expert] += 1
            
    print(f'Expert counts: {expert_counts}')
    print(f'Total experts: {len(expert_counts)}')
    print(f'Expert count for l0', expert_counts[0])
    
    # Convert to lists for plotting and calculate percentages
    experts = [f'{i}' for i in range(64)]
    percentages = [expert_counts[i]/total_assignments * 100 for i in range(64)]
    
    # Create bar chart
    fig = go.Figure(data=[
        go.Bar(
            x=experts,
            y=percentages,
            textposition='auto',
            marker_color='red'  # You can use any color here - hex code, RGB, or color name
        )
    ])
    
    fig.update_layout(
        title=f'percentage of total tokens routed to each expert for layer {layer_idx}',
        xaxis_title='expert',
        yaxis_title='% of total tokens',
        yaxis=dict(range=[0, 100]), # Set y-axis range from 0 to 100%
        xaxis_tickangle=-45,
        bargap=0.2
    )
    
    return fig


### expert distribution for all text input and plotting

#### git

In [8]:
# Read and chunk input text file
file_path = 'data/github.txt'
domain = 'github'
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Check if CUDA is available
device = torch.device('cpu')
model = model.to(device)  # Move model to GPU if available

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    # print(f'Sample text: {chunk[:10]}...')  
    
    # Get router logits for the chunk
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        results = get_router_logits(model, chunk)
    all_results.append(results)
    
    # Save intermediate results 
    update_router_logits_json(results, domain=domain)
    

KeyboardInterrupt: 

In [8]:
domain = 'github'
num_layers = 16  # OLMoE has 16 layers

# Combine results from all chunks
combined_results = []
for layer_idx in range(len(all_results[0])):  # For each layer
    layer_combined = []
    for chunk_result in all_results:
        # Skip if chunk_result[layer_idx] is not iterable
        if not isinstance(chunk_result[layer_idx], (list, tuple)):
            continue
            
        # Move data to GPU if available 
        if torch.cuda.is_available():
            chunk_result = [
                (token.cuda() if torch.is_tensor(token) else token,
                 expert.cuda() if torch.is_tensor(expert) else expert,
                 prob.cuda() if torch.is_tensor(prob) else prob)
                for token, expert, prob in chunk_result[layer_idx]
            ]
        layer_combined.extend(chunk_result)
    combined_results.append(layer_combined)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
    layer_results = combined_results[layer_to_plot]
    
    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue
        
    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip if token_info is not a tuple/list
        if not isinstance(token_info, (tuple, list)):
            continue
            
        token, expert, prob = token_info
        # Move to CPU for printing
        if torch.cuda.is_available():
            prob = prob.cpu()
        print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

    # Plot expert distribution for all processed data
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'plots/github/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'plots/github/{domain}_layer{layer_to_plot}_expert_dist.png')
    
    # Clear GPU cache after each layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Routing for first 5 tokens in layer 0: 


  with torch.cuda.amp.autocast():  # Enable automatic mixed precision


Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {57: 159, 8: 4051, 21: 29702, 25: 2619, 6: 4774, 26: 4856, 14: 1674, 2: 9293, 3: 2944, 54: 9222, 53: 4235, 4: 1929, 27: 5356, 42: 3085, 15: 8139, 44: 2915, 29: 5184, 36: 8833, 52: 9488, 37: 2288, 62: 2130, 16: 1172, 17: 5834, 55: 3713, 48: 1388, 33: 3449, 56: 1916, 61: 2122, 59: 3939, 32: 904, 13: 3909, 30: 4094, 60: 4075, 41: 9598, 49: 11515, 18: 3131, 31: 7701, 40: 2834, 47: 5370, 35: 1605, 50: 2513, 45: 1191, 58: 728, 10: 4429, 24: 3000, 20: 284, 1: 1163, 5: 478, 22: 816, 63: 1347, 7: 241, 51: 1508, 11: 404, 43: 605, 19: 376, 0: 384, 38: 945, 46: 355, 28: 57, 39: 370, 23: 24, 12: 18, 34: 49, 9: 40})
Total experts: 64
Expert count for l0 384


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)



Routing for first 5 tokens in layer 1: 



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.


User provided device_type of 'cuda', but CUDA is not available. Disabling



Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {18: 51217, 15: 28059, 42: 1699, 24: 12698, 25: 36107, 61: 8983, 32: 898, 26: 2682, 56: 1892, 7: 6285, 19: 1366, 31: 2488, 44: 2078, 27: 5738, 16: 6647, 33: 773, 62: 811, 40: 537, 6: 662, 1: 1781, 9: 1584, 17: 1225, 28: 1521, 10: 913, 30: 2997, 23: 1617, 55: 757, 47: 2188, 4: 1515, 43: 603, 58: 2280, 29: 4027, 22: 431, 50: 1120, 5: 925, 63: 292, 46: 1377, 59: 852, 8: 1706, 45: 961, 39: 853, 2: 565, 34: 2444, 36: 686, 57: 294, 51: 1208, 14: 271, 11: 1897, 60: 310, 48: 1069, 13: 989, 3: 649, 41: 307, 20: 658, 54: 615, 53: 640, 0: 1196, 38: 223, 21: 3657, 12: 616, 52: 525, 49: 169, 35: 221, 37: 116})
Total experts: 64
Expert count for l0 1196



Routing for first 5 tokens in layer 2: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {30: 6065, 14: 3931, 40: 23540, 5: 40187, 60: 64243, 45: 24498, 31: 513, 17: 1634, 28: 2447, 7: 5252, 16: 1484, 32: 2469, 35: 2549, 38: 1011, 49: 1156, 33: 905, 61: 6317, 26: 1360, 10: 1406, 63: 1906, 54: 548, 0: 935, 62: 748, 44: 588, 24: 5240, 48: 792, 46: 1583, 23: 821, 22: 1397, 39: 1623, 50: 718, 34: 1694, 3: 602, 27: 950, 12: 993, 42: 723, 19: 402, 55: 461, 9: 319, 56: 416, 11: 331, 15: 461, 6: 376, 43: 547, 21: 446, 29: 337, 52: 678, 1: 107, 18: 332, 20: 340, 13: 145, 47: 162, 8: 432, 2: 1507, 4: 471, 57: 79, 37: 91, 36: 143, 25: 415, 41: 273, 53: 60, 51: 76, 59: 198, 58: 37})
Total experts: 64
Expert count for l0 935



Routing for first 5 tokens in layer 3: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {39: 21300, 15: 77228, 43: 42162, 13: 410, 63: 6992, 44: 3246, 36: 880, 6: 3238, 37: 2821, 31: 3105, 29: 1482, 0: 244, 42: 1896, 19: 14073, 9: 5479, 30: 3449, 51: 2352, 62: 2679, 61: 1276, 41: 3006, 7: 1062, 16: 5963, 35: 1634, 34: 534, 32: 831, 3: 154, 38: 1386, 12: 1139, 20: 680, 11: 179, 46: 893, 17: 777, 52: 325, 8: 237, 56: 463, 4: 478, 45: 399, 50: 358, 49: 250, 18: 299, 54: 378, 40: 888, 57: 443, 59: 530, 26: 783, 14: 344, 5: 163, 55: 52, 10: 59, 24: 2295, 60: 183, 53: 90, 47: 177, 28: 195, 22: 6, 21: 32, 27: 122, 2: 26, 33: 112, 25: 92, 48: 83, 23: 6, 58: 20, 1: 32})
Total experts: 64
Expert count for l0 244



Routing for first 5 tokens in layer 4: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {27: 9229, 6: 33810, 25: 40574, 19: 30662, 41: 1700, 21: 21296, 2: 10374, 4: 3254, 29: 215, 61: 2778, 39: 8884, 51: 323, 17: 3447, 3: 4226, 55: 19056, 23: 773, 24: 614, 42: 496, 20: 366, 16: 1248, 43: 3615, 1: 357, 30: 700, 15: 953, 33: 4472, 14: 2006, 34: 1927, 12: 263, 38: 295, 7: 969, 47: 90, 52: 881, 63: 183, 59: 1170, 57: 457, 48: 282, 44: 141, 45: 68, 11: 142, 62: 395, 26: 218, 5: 584, 58: 209, 35: 617, 50: 292, 36: 129, 56: 79, 37: 4049, 49: 173, 31: 159, 28: 1153, 8: 147, 10: 209, 13: 129, 40: 63, 32: 389, 60: 68, 9: 345, 22: 85, 54: 13, 0: 500, 46: 154, 18: 9, 53: 6})
Total experts: 64
Expert count for l0 500



Routing for first 5 tokens in layer 5: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {31: 45585, 2: 28385, 53: 24614, 5: 2602, 8: 2658, 35: 8316, 0: 22845, 15: 4459, 42: 46699, 38: 1074, 29: 2479, 25: 492, 17: 2351, 6: 3858, 22: 1477, 51: 1265, 10: 480, 47: 114, 49: 120, 61: 279, 62: 321, 57: 1693, 40: 700, 21: 1498, 52: 1047, 20: 142, 41: 272, 37: 81, 59: 150, 60: 103, 63: 1624, 1: 194, 9: 327, 11: 374, 54: 382, 39: 635, 43: 3254, 50: 379, 23: 276, 14: 172, 24: 361, 58: 306, 28: 125, 48: 107, 56: 17, 32: 162, 26: 58, 44: 518, 34: 251, 13: 20, 7: 204, 27: 198, 18: 80, 55: 5303, 30: 171, 19: 189, 4: 106, 12: 78, 46: 119, 36: 19, 33: 42, 16: 245, 3: 15})
Total experts: 63
Expert count for l0 22845



Routing for first 5 tokens in layer 6: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {49: 7464, 52: 65270, 1: 15357, 42: 16060, 8: 5214, 59: 4716, 62: 13379, 20: 5541, 24: 20134, 57: 2242, 9: 2069, 22: 7819, 18: 17932, 36: 1799, 40: 18532, 44: 761, 5: 1256, 16: 515, 28: 769, 19: 741, 50: 975, 10: 312, 7: 487, 4: 1491, 31: 130, 23: 1673, 32: 548, 2: 611, 61: 630, 54: 360, 58: 1233, 3: 213, 39: 27, 27: 13, 33: 256, 17: 395, 63: 369, 30: 161, 48: 14, 13: 242, 26: 202, 35: 138, 37: 210, 0: 66, 38: 10, 21: 67, 34: 1475, 43: 600, 53: 1508, 41: 15, 12: 105, 29: 58, 47: 49, 6: 20, 11: 27, 14: 23, 15: 90, 60: 27, 46: 18, 25: 21, 55: 17, 45: 11, 51: 3})
Total experts: 63
Expert count for l0 66



Routing for first 5 tokens in layer 7: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {21: 7247, 17: 53905, 2: 54025, 58: 9367, 0: 15483, 12: 6814, 45: 13864, 4: 10763, 31: 1340, 63: 174, 1: 130, 35: 16984, 37: 624, 29: 486, 46: 72, 27: 10500, 10: 662, 7: 1257, 32: 341, 40: 39, 33: 747, 49: 3667, 30: 240, 34: 464, 24: 645, 11: 73, 47: 126, 54: 94, 13: 90, 20: 317, 36: 224, 60: 70, 42: 5624, 59: 160, 26: 319, 56: 284, 39: 403, 18: 161, 61: 608, 23: 606, 52: 559, 6: 39, 38: 170, 43: 313, 57: 181, 14: 199, 16: 308, 25: 8, 48: 708, 15: 340, 19: 137, 41: 67, 22: 158, 8: 41, 28: 6, 3: 35, 9: 48, 55: 24, 62: 30, 5: 22, 53: 11, 50: 67})
Total experts: 62
Expert count for l0 15483



Routing for first 5 tokens in layer 8: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {22: 1787, 60: 3387, 16: 51252, 6: 3312, 62: 1823, 51: 8042, 3: 24603, 32: 9886, 44: 7790, 48: 2885, 30: 2880, 45: 4560, 54: 53506, 37: 2598, 1: 10986, 52: 1195, 20: 303, 57: 350, 61: 1248, 23: 279, 27: 6167, 59: 308, 7: 307, 33: 11834, 35: 762, 49: 471, 39: 5, 41: 409, 2: 4586, 19: 162, 24: 256, 0: 51, 53: 460, 56: 92, 14: 350, 38: 289, 18: 360, 26: 304, 55: 126, 21: 489, 9: 207, 50: 202, 5: 145, 13: 101, 58: 15, 46: 35, 40: 202, 10: 230, 36: 55, 42: 71, 17: 36, 25: 156, 47: 72, 31: 107, 29: 67, 12: 54, 15: 11, 11: 151, 43: 25, 28: 28, 4: 19, 34: 7, 63: 9, 8: 5})
Total experts: 64
Expert count for l0 51



Routing for first 5 tokens in layer 9: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {8: 12268, 6: 16420, 4: 53183, 39: 9760, 20: 12357, 51: 18764, 46: 39874, 50: 3269, 7: 5514, 48: 385, 5: 6889, 63: 1017, 26: 555, 52: 6828, 31: 2557, 9: 1671, 24: 693, 1: 1359, 33: 3263, 28: 1022, 62: 133, 27: 1246, 34: 1160, 35: 306, 61: 1250, 49: 259, 47: 2077, 19: 57, 14: 814, 57: 4737, 40: 4288, 11: 1267, 54: 602, 22: 368, 12: 1004, 30: 303, 44: 337, 36: 75, 59: 65, 60: 1768, 13: 457, 43: 399, 38: 168, 42: 219, 32: 19, 18: 117, 21: 175, 2: 195, 16: 255, 53: 30, 10: 38, 58: 11, 23: 170, 3: 60, 55: 92, 29: 19, 41: 94, 17: 33, 37: 34, 25: 17, 45: 30, 0: 21, 15: 38, 56: 15})
Total experts: 64
Expert count for l0 21



Routing for first 5 tokens in layer 10: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {4: 945, 56: 42714, 11: 24952, 5: 3533, 43: 60783, 55: 3263, 13: 3438, 2: 4916, 19: 17305, 23: 3366, 18: 7393, 34: 6715, 49: 2396, 29: 4087, 50: 617, 28: 1089, 62: 1257, 16: 482, 58: 1133, 8: 1434, 3: 1907, 22: 436, 41: 623, 9: 183, 15: 708, 26: 1125, 46: 1757, 61: 451, 60: 1302, 45: 225, 20: 1146, 44: 1847, 48: 11700, 21: 562, 35: 364, 0: 266, 30: 379, 24: 460, 59: 345, 52: 329, 53: 90, 39: 276, 17: 149, 12: 533, 37: 101, 6: 172, 57: 1921, 27: 171, 14: 266, 7: 196, 31: 125, 32: 208, 1: 23, 38: 126, 42: 118, 40: 3, 47: 12, 36: 2, 33: 17, 51: 10, 54: 2, 10: 14, 63: 2})
Total experts: 63
Expert count for l0 266



Routing for first 5 tokens in layer 11: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {56: 823, 62: 14474, 46: 8237, 23: 9226, 33: 28098, 51: 2574, 27: 69170, 12: 4767, 41: 2919, 37: 5897, 5: 9514, 47: 5304, 13: 8982, 29: 3703, 28: 3008, 54: 5505, 7: 4770, 40: 1172, 20: 4168, 63: 3885, 2: 1440, 3: 321, 44: 2551, 25: 612, 10: 2144, 11: 466, 26: 262, 35: 161, 6: 538, 57: 1035, 59: 10210, 9: 2462, 43: 601, 16: 122, 4: 193, 38: 366, 39: 209, 32: 538, 15: 123, 45: 456, 18: 104, 61: 39, 31: 32, 58: 25, 22: 347, 24: 150, 14: 71, 52: 58, 34: 108, 55: 202, 8: 9, 36: 83, 0: 17, 50: 37, 17: 30, 1: 6, 49: 25, 42: 19, 48: 8, 19: 39, 30: 30, 60: 21, 21: 4})
Total experts: 63
Expert count for l0 17



Routing for first 5 tokens in layer 12: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {59: 48577, 31: 32781, 51: 7402, 54: 691, 38: 55152, 20: 4241, 10: 5317, 8: 11373, 55: 6519, 24: 1957, 48: 1580, 17: 5934, 57: 3152, 12: 2631, 11: 18, 30: 1829, 29: 709, 44: 1629, 26: 2181, 6: 1425, 45: 340, 60: 1237, 15: 260, 43: 1993, 47: 1363, 28: 1578, 40: 1520, 39: 1585, 56: 3102, 1: 49, 13: 553, 41: 726, 53: 1100, 35: 85, 3: 4520, 4: 591, 42: 357, 23: 1204, 32: 1042, 7: 46, 52: 139, 14: 541, 36: 134, 50: 717, 49: 332, 9: 670, 46: 144, 2: 271, 58: 241, 62: 50, 34: 133, 22: 39, 19: 68, 0: 27, 25: 7, 16: 65, 33: 72, 63: 60, 37: 79, 21: 30, 27: 9, 5: 277, 18: 2, 61: 14})
Total experts: 64
Expert count for l0 27



Routing for first 5 tokens in layer 13: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {2: 62583, 61: 12122, 62: 29873, 5: 9887, 25: 40948, 10: 2724, 46: 418, 27: 1042, 31: 639, 19: 619, 42: 1288, 29: 3635, 16: 89, 54: 1550, 4: 1699, 18: 2431, 28: 2154, 23: 1218, 53: 675, 17: 3274, 55: 1492, 6: 15216, 39: 1728, 32: 2028, 59: 848, 9: 1421, 43: 660, 15: 321, 24: 309, 60: 166, 44: 76, 57: 477, 7: 6565, 50: 329, 52: 55, 51: 669, 38: 354, 30: 27, 12: 314, 11: 883, 22: 418, 34: 77, 41: 812, 20: 435, 3: 45, 21: 6782, 26: 227, 63: 48, 0: 310, 13: 104, 14: 35, 35: 50, 8: 57, 47: 13, 49: 43, 58: 34, 56: 109, 37: 8, 45: 40, 36: 2, 33: 4, 48: 4, 1: 1, 40: 6})
Total experts: 64
Expert count for l0 310



Routing for first 5 tokens in layer 14: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {60: 1607, 34: 7116, 6: 39483, 13: 7294, 58: 37073, 12: 3158, 52: 4137, 9: 38720, 47: 4062, 1: 3449, 19: 6579, 61: 3288, 11: 4412, 36: 7171, 4: 1105, 24: 11586, 35: 693, 56: 805, 7: 4149, 50: 521, 40: 4213, 22: 3582, 54: 3763, 42: 3762, 57: 1853, 49: 2352, 32: 265, 62: 80, 5: 1295, 26: 5871, 29: 1065, 17: 654, 8: 1989, 39: 1028, 27: 1215, 33: 455, 30: 122, 10: 104, 14: 33, 23: 293, 51: 67, 46: 166, 37: 116, 16: 316, 48: 203, 15: 210, 38: 111, 20: 70, 63: 135, 45: 41, 28: 325, 31: 52, 53: 35, 0: 61, 18: 18, 25: 65, 21: 10, 43: 46, 59: 3, 44: 13, 41: 1, 2: 1, 3: 3})
Total experts: 63
Expert count for l0 61



Routing for first 5 tokens in layer 15: 
Total assignments: 222470
Expert counts: defaultdict(<class 'int'>, {8: 24809, 44: 37794, 1: 49056, 36: 10328, 29: 617, 31: 2392, 35: 3747, 17: 12156, 30: 15986, 53: 2554, 24: 8934, 6: 4229, 43: 4652, 20: 1589, 19: 3291, 33: 2874, 5: 5197, 3: 1786, 12: 709, 32: 2362, 23: 277, 9: 4616, 11: 571, 22: 1120, 41: 825, 37: 709, 16: 3985, 40: 233, 18: 1307, 58: 184, 57: 7231, 52: 1024, 50: 105, 4: 946, 42: 349, 34: 448, 47: 89, 54: 431, 13: 143, 62: 391, 63: 359, 2: 450, 21: 23, 0: 750, 51: 191, 15: 35, 61: 56, 56: 8, 55: 214, 60: 54, 46: 32, 28: 28, 27: 20, 48: 150, 7: 3, 45: 8, 26: 7, 14: 2, 25: 4, 10: 26, 38: 2, 49: 1, 59: 1})
Total experts: 63
Expert count for l0 750


#### math

In [9]:
# Read and chunk input file
file_path = 'data/math_arxiv_200k.txt'
domain = 'math'
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # Move model to GPU if available

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    # print(f'Sample text: {chunk[:10]}...')  
    
    # Get router logits for the chunk
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        results = get_router_logits(model, chunk)
    all_results.append(results)
    
    # Save intermediate results 
    update_router_logits_json(results, domain=domain)
    
    # Clear GPU cache periodically
    if (i + 1) % 10 == 0:  # Every 10 chunks
        torch.cuda.empty_cache()

Processing chunk 1/195



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Processing chunk 2/195
Processing chunk 3/195
Processing chunk 4/195
Processing chunk 5/195
Processing chunk 6/195
Processing chunk 7/195
Processing chunk 8/195
Processing chunk 9/195
Processing chunk 10/195
Processing chunk 11/195
Processing chunk 12/195
Processing chunk 13/195
Processing chunk 14/195
Processing chunk 15/195
Processing chunk 16/195
Processing chunk 17/195
Processing chunk 18/195
Processing chunk 19/195
Processing chunk 20/195
Processing chunk 21/195
Processing chunk 22/195
Processing chunk 23/195
Processing chunk 24/195
Processing chunk 25/195
Processing chunk 26/195
Processing chunk 27/195
Processing chunk 28/195
Processing chunk 29/195
Processing chunk 30/195
Processing chunk 31/195
Processing chunk 32/195
Processing chunk 33/195
Processing chunk 34/195
Processing chunk 35/195
Processing chunk 36/195
Processing chunk 37/195
Processing chunk 38/195
Processing chunk 39/195
Processing chunk 40/195
Processing chunk 41/195
Processing chunk 42/195
Processing chunk 43/195


In [10]:
domain = 'math'
num_layers = 16  # OLMoE has 16 layers

# Combine results from all chunks
combined_results = []
for layer_idx in range(len(all_results[0])):  # For each layer
    layer_combined = []
    for chunk_result in all_results:
        # Skip if chunk_result[layer_idx] is not iterable
        if not isinstance(chunk_result[layer_idx], (list, tuple)):
            continue
            
        # Move data to GPU if available 
        if torch.cuda.is_available():
            chunk_result = [
                (token.cuda() if torch.is_tensor(token) else token,
                 expert.cuda() if torch.is_tensor(expert) else expert,
                 prob.cuda() if torch.is_tensor(prob) else prob)
                for token, expert, prob in chunk_result[layer_idx]
            ]
        layer_combined.extend(chunk_result)
    combined_results.append(layer_combined)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
    layer_results = combined_results[layer_to_plot]
    
    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue
        
    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip if token_info is not a tuple/list
        if not isinstance(token_info, (tuple, list)):
            continue
            
        token, expert, prob = token_info
        # Move to CPU for printing
        if torch.cuda.is_available():
            prob = prob.cpu()
        print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

    # Plot expert distribution for all processed data
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'plots/math/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'plots/math/{domain}_layer{layer_to_plot}_expert_dist.png')
    
    # Clear GPU cache after each layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Routing for first 5 tokens in layer 0: 



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {37: 2893, 6: 899, 4: 1536, 59: 3763, 49: 1948, 13: 1746, 10: 5961, 44: 2329, 17: 2534, 26: 3117, 56: 1065, 8: 2681, 48: 1336, 5: 2020, 60: 1887, 54: 2087, 45: 1583, 2: 3321, 36: 10176, 41: 6777, 15: 2258, 38: 9401, 52: 1736, 32: 809, 27: 2585, 30: 1605, 31: 2874, 19: 4606, 24: 856, 35: 1133, 62: 1634, 33: 654, 47: 1865, 25: 2937, 22: 6579, 40: 913, 16: 1464, 61: 1177, 18: 548, 7: 183, 50: 4787, 46: 4666, 14: 1399, 55: 2871, 29: 1770, 0: 70730, 42: 1736, 1: 1080, 53: 505, 51: 1483, 3: 1309, 63: 316, 34: 255, 57: 119, 12: 31, 9: 58, 43: 306, 11: 50, 58: 125, 28: 24, 20: 10, 39: 54, 21: 16, 23: 1})
Total experts: 64
Expert count for l0 70730



Routing for first 5 tokens in layer 1: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {18: 4138, 0: 4909, 9: 6398, 17: 4524, 31: 1344, 58: 2549, 45: 408, 51: 697, 50: 1182, 4: 472, 5: 94, 52: 479, 13: 421, 37: 363, 16: 11701, 47: 94390, 38: 203, 30: 1823, 53: 3540, 29: 1389, 27: 7540, 23: 4527, 21: 808, 11: 991, 7: 6597, 61: 6311, 39: 3250, 41: 148, 54: 1299, 20: 1318, 1: 1146, 40: 1667, 32: 1426, 35: 231, 2: 525, 49: 560, 60: 102, 36: 2268, 33: 391, 48: 1364, 8: 538, 12: 3989, 63: 384, 59: 921, 42: 752, 34: 902, 10: 354, 24: 334, 44: 776, 28: 668, 57: 226, 46: 631, 19: 553, 3: 759, 6: 396, 55: 820, 14: 306, 56: 496, 26: 619, 43: 282, 22: 374, 15: 332, 62: 209, 25: 63})
Total experts: 64
Expert count for l0 4909



Routing for first 5 tokens in layer 2: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {30: 2879, 23: 7869, 34: 3195, 44: 2213, 12: 284, 9: 2024, 28: 3106, 52: 964, 55: 745, 27: 6749, 43: 1513, 15: 14948, 48: 1375, 46: 1570, 38: 517, 63: 20578, 35: 6723, 26: 15628, 13: 289, 37: 179, 25: 513, 53: 59, 57: 445, 1: 1641, 19: 2808, 20: 3038, 39: 3306, 6: 853, 50: 10764, 22: 4754, 21: 3153, 0: 2150, 11: 511, 29: 5705, 32: 2642, 16: 8324, 58: 79, 61: 13172, 17: 1625, 42: 1293, 60: 4722, 45: 6903, 36: 6319, 7: 2036, 2: 589, 3: 1397, 62: 544, 33: 6642, 24: 687, 41: 526, 18: 1008, 49: 253, 31: 1432, 54: 4183, 14: 322, 56: 410, 40: 77, 47: 288, 51: 160, 5: 249, 8: 207, 59: 18, 10: 22})
Total experts: 63
Expert count for l0 2150



Routing for first 5 tokens in layer 3: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {39: 244, 56: 4008, 61: 1325, 11: 1035, 58: 69, 40: 8500, 34: 575, 4: 623, 5: 1195, 41: 2107, 17: 652, 45: 4759, 9: 10810, 19: 493, 3: 308, 42: 5516, 37: 6504, 24: 52, 48: 447, 51: 12118, 59: 991, 33: 66, 63: 3299, 8: 1147, 47: 492, 38: 7516, 32: 3226, 28: 1297, 12: 4097, 50: 1504, 30: 7653, 46: 1530, 7: 3187, 2: 240, 54: 2273, 27: 1213, 21: 51, 20: 4031, 13: 645, 25: 49, 31: 1138, 60: 146, 18: 4814, 36: 3695, 0: 414, 35: 70080, 15: 61, 52: 1331, 57: 3359, 49: 3022, 62: 1305, 44: 625, 29: 1709, 26: 1080, 22: 11, 53: 58, 10: 93, 6: 36, 16: 119, 14: 21, 43: 115, 1: 78, 23: 17, 55: 3})
Total experts: 64
Expert count for l0 414



Routing for first 5 tokens in layer 4: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {27: 1106, 14: 2889, 38: 438, 45: 20, 7: 1433, 33: 1365, 15: 774, 12: 66, 32: 289, 28: 93, 17: 140688, 16: 12503, 41: 572, 34: 1325, 61: 432, 26: 4348, 20: 99, 37: 210, 10: 100, 50: 176, 1: 52, 53: 9, 11: 5, 23: 1178, 5: 1299, 36: 22, 40: 78, 57: 642, 0: 75, 52: 931, 42: 337, 51: 350, 60: 66, 56: 355, 24: 632, 58: 256, 59: 3529, 13: 374, 43: 535, 21: 8034, 44: 425, 39: 876, 6: 2758, 3: 1227, 48: 6, 25: 320, 31: 438, 4: 805, 62: 140, 22: 193, 35: 430, 46: 1845, 49: 318, 2: 125, 47: 208, 29: 729, 30: 263, 8: 213, 55: 111, 54: 5, 19: 6, 9: 48, 63: 2, 18: 1})
Total experts: 64
Expert count for l0 75



Routing for first 5 tokens in layer 5: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {31: 11968, 1: 2460, 57: 9233, 29: 731, 60: 10, 43: 1077, 35: 7768, 15: 1377, 6: 2283, 41: 828, 27: 8621, 17: 2493, 0: 79242, 62: 643, 63: 2793, 53: 4914, 11: 4007, 14: 1153, 3: 33, 16: 62, 20: 230, 45: 11, 54: 7449, 22: 6175, 18: 2114, 4: 529, 12: 97, 9: 1930, 30: 1502, 7: 1313, 19: 1621, 40: 65, 33: 160, 52: 3270, 44: 426, 58: 1906, 38: 7593, 55: 1065, 13: 3704, 36: 1411, 8: 1397, 51: 977, 48: 862, 21: 1352, 25: 2246, 61: 512, 34: 1664, 5: 2048, 24: 710, 47: 156, 10: 1446, 56: 257, 50: 410, 28: 103, 32: 151, 2: 41, 49: 64, 46: 28, 39: 391, 37: 20, 23: 68, 42: 4, 26: 1, 59: 2})
Total experts: 64
Expert count for l0 79242



Routing for first 5 tokens in layer 6: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {49: 354, 17: 1437, 4: 12046, 50: 5654, 18: 15500, 58: 746, 54: 1636, 22: 5824, 8: 2445, 40: 12219, 60: 316, 15: 5, 1: 3342, 53: 684, 24: 4115, 9: 2538, 20: 3023, 55: 561, 6: 514, 63: 4483, 62: 11614, 26: 876, 25: 204, 59: 1105, 11: 54, 14: 47, 35: 198, 37: 85, 13: 1712, 7: 614, 30: 1150, 44: 77, 57: 65547, 32: 516, 5: 19944, 33: 12248, 12: 1403, 31: 17, 56: 75, 47: 118, 29: 449, 45: 63, 16: 349, 51: 17, 3: 350, 48: 133, 19: 315, 61: 891, 36: 163, 2: 881, 46: 108, 34: 99, 42: 156, 39: 24, 43: 31, 28: 24, 41: 5, 0: 18, 21: 39, 23: 14, 10: 1, 52: 1})
Total experts: 62
Expert count for l0 18



Routing for first 5 tokens in layer 7: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {21: 943, 1: 175, 34: 5280, 24: 10722, 42: 5129, 49: 3428, 47: 452, 22: 5135, 45: 1994, 33: 5088, 9: 62, 7: 27, 52: 5336, 48: 6051, 38: 4412, 3: 2, 60: 512, 61: 7777, 46: 1652, 56: 5015, 62: 41, 14: 1684, 20: 1638, 29: 2938, 16: 3561, 59: 10192, 10: 1937, 50: 460, 58: 18029, 40: 492, 12: 4133, 23: 2384, 11: 1294, 41: 772, 5: 1045, 31: 1077, 63: 1307, 36: 14384, 39: 596, 13: 296, 26: 9694, 4: 11328, 17: 19759, 15: 4246, 19: 2204, 54: 145, 30: 538, 35: 4863, 32: 4158, 57: 2663, 18: 107, 53: 192, 37: 347, 44: 48, 55: 55, 8: 72, 6: 138, 0: 35, 27: 353, 51: 28, 43: 50, 25: 20, 2: 682})
Total experts: 63
Expert count for l0 35



Routing for first 5 tokens in layer 8: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {22: 1508, 32: 1035, 31: 684, 3: 49, 38: 972, 15: 214, 44: 10537, 48: 3569, 26: 1338, 41: 1102, 49: 2857, 18: 24330, 5: 1009, 6: 3096, 30: 602, 52: 3521, 45: 3909, 35: 635, 53: 7385, 61: 13644, 24: 9220, 12: 198, 58: 2, 17: 1475, 57: 3608, 28: 584, 2: 1934, 21: 4664, 47: 352, 36: 2249, 50: 1201, 0: 1054, 39: 1, 13: 119, 46: 111, 14: 45120, 29: 18, 37: 5498, 16: 7602, 54: 14640, 40: 1846, 10: 2154, 63: 1304, 25: 2266, 23: 444, 51: 254, 1: 1952, 9: 1433, 42: 183, 20: 1805, 55: 1070, 33: 605, 60: 29, 11: 1221, 56: 220, 62: 41, 27: 212, 4: 278, 43: 32, 59: 9, 19: 17, 34: 129, 7: 21, 8: 6})
Total experts: 64
Expert count for l0 1054



Routing for first 5 tokens in layer 9: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {8: 10679, 12: 69, 1: 3074, 33: 3363, 27: 2370, 39: 421, 40: 1327, 60: 2334, 6: 1550, 48: 1897, 38: 1190, 36: 912, 11: 1264, 31: 2132, 47: 4682, 24: 149, 54: 2312, 9: 496, 57: 5051, 50: 1507, 7: 1480, 13: 1018, 61: 696, 14: 4334, 21: 473, 17: 3, 62: 126, 2: 869, 43: 1756, 41: 44, 32: 86, 52: 2230, 23: 1194, 29: 31, 59: 701, 53: 288, 44: 667, 19: 86, 55: 547, 16: 1605, 25: 11, 26: 116, 20: 7440, 49: 288, 35: 518, 22: 1289, 5: 108883, 63: 223, 28: 8430, 30: 776, 0: 143, 42: 4482, 45: 147, 4: 96, 37: 123, 18: 747, 10: 107, 3: 57, 58: 95, 56: 65, 51: 100, 34: 26, 46: 2})
Total experts: 63
Expert count for l0 143



Routing for first 5 tokens in layer 10: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {4: 470, 21: 2027, 60: 14451, 3: 11538, 32: 3466, 57: 4738, 49: 1403, 55: 2675, 8: 533, 9: 4159, 43: 9941, 46: 2478, 18: 6873, 24: 1218, 22: 10, 62: 2211, 50: 2288, 29: 10312, 59: 4617, 7: 3503, 28: 9325, 47: 191, 26: 5597, 41: 858, 0: 1408, 17: 4695, 14: 7720, 31: 216, 12: 4457, 27: 2256, 61: 665, 58: 3312, 36: 358, 42: 2011, 37: 734, 11: 10736, 35: 773, 39: 3754, 33: 26, 34: 2997, 23: 4216, 54: 47, 5: 1778, 45: 535, 13: 11936, 2: 6145, 56: 8651, 30: 1620, 15: 3469, 52: 4380, 16: 3234, 53: 953, 19: 99, 10: 172, 1: 300, 48: 157, 38: 328, 20: 14, 6: 28, 63: 63, 51: 37, 44: 6, 25: 3, 40: 6})
Total experts: 64
Expert count for l0 1408



Routing for first 5 tokens in layer 11: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {56: 462, 26: 1609, 3: 2338, 54: 16486, 32: 3760, 2: 263, 59: 170, 41: 2210, 29: 1030, 40: 705, 11: 2687, 51: 712, 49: 657, 43: 1358, 47: 65726, 33: 1640, 4: 1190, 0: 614, 12: 5975, 36: 843, 34: 844, 31: 1278, 30: 1471, 35: 678, 60: 114, 18: 1266, 22: 1536, 45: 3206, 15: 356, 28: 224, 52: 13, 16: 518, 20: 3720, 24: 2819, 8: 861, 58: 854, 39: 1068, 19: 1127, 5: 2837, 9: 5308, 6: 47, 23: 48401, 42: 408, 61: 1567, 62: 255, 46: 4400, 17: 208, 21: 118, 63: 245, 48: 708, 44: 663, 10: 46, 7: 242, 37: 93, 57: 583, 13: 306, 38: 4, 55: 43, 27: 70, 50: 111, 53: 8, 1: 103, 25: 2, 14: 13})
Total experts: 64
Expert count for l0 614



Routing for first 5 tokens in layer 12: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {59: 2151, 13: 597, 23: 755, 53: 1123, 48: 2690, 47: 5124, 40: 1278, 10: 5960, 26: 1457, 20: 3979, 56: 1408, 3: 2359, 7: 442, 21: 33, 44: 7259, 29: 101, 30: 1544, 15: 768, 55: 13471, 6: 3284, 5: 11, 27: 62, 24: 1261, 16: 62, 39: 1301, 14: 1422, 42: 142, 49: 971, 22: 35, 43: 107612, 41: 629, 60: 2551, 33: 67, 28: 2167, 46: 152, 9: 2132, 50: 4312, 1: 312, 4: 1727, 57: 2746, 54: 1428, 36: 473, 34: 703, 35: 1252, 17: 245, 0: 268, 18: 411, 62: 1054, 61: 16, 52: 296, 45: 209, 12: 1565, 58: 321, 31: 3627, 51: 1271, 2: 233, 8: 15, 19: 112, 63: 68, 37: 16, 11: 21, 32: 70, 38: 41, 25: 5})
Total experts: 64
Expert count for l0 268



Routing for first 5 tokens in layer 13: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {2: 4663, 50: 2871, 35: 2142, 22: 10756, 14: 189, 29: 796, 18: 5186, 28: 1244, 46: 1883, 57: 11346, 10: 5195, 9: 60, 32: 15418, 5: 6066, 23: 2480, 39: 6679, 44: 111, 26: 115, 37: 149, 42: 465, 8: 1096, 31: 5717, 12: 412, 34: 643, 41: 5277, 47: 324, 51: 882, 15: 1016, 38: 5730, 0: 1864, 20: 37713, 4: 1371, 19: 6042, 36: 199, 55: 2247, 45: 2504, 43: 1642, 24: 462, 33: 35, 61: 14912, 58: 2699, 30: 4515, 3: 7946, 17: 4515, 60: 124, 49: 716, 1: 525, 27: 571, 40: 134, 48: 211, 11: 4827, 13: 232, 59: 187, 54: 993, 53: 13, 25: 682, 7: 1814, 62: 206, 56: 40, 52: 6, 16: 15, 21: 123, 6: 125, 63: 56})
Total experts: 64
Expert count for l0 1864



Routing for first 5 tokens in layer 14: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {60: 1615, 56: 3846, 35: 4482, 63: 3668, 20: 5069, 19: 6342, 26: 2811, 27: 1036, 47: 3980, 49: 379, 3: 188, 24: 13987, 1: 2743, 46: 4043, 22: 111, 57: 3657, 12: 426, 40: 4824, 5: 2726, 9: 15262, 42: 1922, 37: 275, 45: 2344, 0: 31, 53: 423, 59: 8, 18: 4086, 16: 3319, 15: 6236, 44: 1130, 54: 1992, 2: 86, 29: 604, 52: 3542, 34: 105, 17: 780, 51: 18, 7: 2711, 38: 19044, 48: 12822, 62: 604, 6: 32742, 39: 1237, 11: 1053, 10: 1327, 32: 941, 33: 3431, 23: 2677, 41: 838, 61: 1293, 13: 1094, 31: 726, 28: 895, 25: 68, 14: 378, 50: 32, 43: 376, 58: 6102, 36: 392, 30: 146, 21: 69, 55: 40, 4: 6, 8: 37})
Total experts: 64
Expert count for l0 31



Routing for first 5 tokens in layer 15: 
Total assignments: 199177
Expert counts: defaultdict(<class 'int'>, {28: 72, 62: 670, 20: 1436, 22: 515, 23: 325, 9: 27, 16: 301, 41: 596, 6: 2006, 43: 3816, 40: 184, 5: 404, 29: 3231, 31: 2244, 57: 3698, 55: 1518, 12: 54, 60: 351, 3: 1498, 36: 617, 52: 4396, 61: 1022, 1: 5824, 54: 2773, 34: 78050, 48: 180, 14: 4, 47: 161, 46: 47, 4: 1278, 63: 506, 2: 386, 7: 85, 11: 279, 56: 42, 15: 206, 17: 70683, 35: 2401, 19: 230, 53: 1327, 37: 244, 24: 2581, 44: 728, 18: 3, 42: 232, 58: 269, 26: 115, 0: 437, 21: 112, 39: 12, 38: 110, 32: 72, 10: 19, 50: 6, 13: 154, 33: 103, 51: 345, 25: 41, 49: 5, 8: 22, 30: 77, 27: 23, 45: 24})
Total experts: 63
Expert count for l0 437


#### biology

In [11]:
# Read and chunk input file
file_path = 'data/biology_arxiv_200k.txt'
domain = 'biology'
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # Move model to GPU if available

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    # print(f'Sample text: {chunk[:10]}...')  
    
    # Get router logits for the chunk
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        results = get_router_logits(model, chunk)
    all_results.append(results)
    
    # Save intermediate results 
    update_router_logits_json(results, domain=domain)
    
    # Clear GPU cache periodically
    if (i + 1) % 10 == 0:  # Every 10 chunks
        torch.cuda.empty_cache()

Processing chunk 1/197



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Processing chunk 2/197
Processing chunk 3/197
Processing chunk 4/197
Processing chunk 5/197
Processing chunk 6/197
Processing chunk 7/197
Processing chunk 8/197
Processing chunk 9/197
Processing chunk 10/197
Processing chunk 11/197
Processing chunk 12/197
Processing chunk 13/197
Processing chunk 14/197
Processing chunk 15/197
Processing chunk 16/197
Processing chunk 17/197
Processing chunk 18/197
Processing chunk 19/197
Processing chunk 20/197
Processing chunk 21/197
Processing chunk 22/197
Processing chunk 23/197
Processing chunk 24/197
Processing chunk 25/197
Processing chunk 26/197
Processing chunk 27/197
Processing chunk 28/197
Processing chunk 29/197
Processing chunk 30/197
Processing chunk 31/197
Processing chunk 32/197
Processing chunk 33/197
Processing chunk 34/197
Processing chunk 35/197
Processing chunk 36/197
Processing chunk 37/197
Processing chunk 38/197
Processing chunk 39/197
Processing chunk 40/197
Processing chunk 41/197
Processing chunk 42/197
Processing chunk 43/197


In [12]:
domain = 'biology'
num_layers = 16  # OLMoE has 16 layers

# Combine results from all chunks
combined_results = []
for layer_idx in range(len(all_results[0])):  # For each layer
    layer_combined = []
    for chunk_result in all_results:
        # Skip if chunk_result[layer_idx] is not iterable
        if not isinstance(chunk_result[layer_idx], (list, tuple)):
            continue
            
        # Move data to GPU if available 
        if torch.cuda.is_available():
            chunk_result = [
                (token.cuda() if torch.is_tensor(token) else token,
                 expert.cuda() if torch.is_tensor(expert) else expert,
                 prob.cuda() if torch.is_tensor(prob) else prob)
                for token, expert, prob in chunk_result[layer_idx]
            ]
        layer_combined.extend(chunk_result)
    combined_results.append(layer_combined)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
    layer_results = combined_results[layer_to_plot]
    
    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue
        
    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip if token_info is not a tuple/list
        if not isinstance(token_info, (tuple, list)):
            continue
            
        token, expert, prob = token_info
        # Move to CPU for printing
        if torch.cuda.is_available():
            prob = prob.cpu()
        print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

    # Plot expert distribution for all processed data
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'plots/biology/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'plots/biology/{domain}_layer{layer_to_plot}_expert_dist.png')
    
    # Clear GPU cache after each layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Routing for first 5 tokens in layer 0: 



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {57: 224, 24: 1534, 50: 6647, 31: 4113, 6: 858, 25: 3397, 10: 8388, 55: 3617, 41: 17307, 15: 3226, 59: 5941, 13: 3513, 44: 3604, 60: 2940, 51: 2074, 40: 2557, 0: 19799, 38: 9025, 48: 2163, 37: 3453, 19: 6729, 1: 1205, 8: 3698, 35: 1612, 61: 1123, 22: 9027, 16: 1920, 36: 11418, 47: 2020, 14: 1408, 5: 2614, 46: 4906, 33: 995, 2: 5063, 26: 3829, 17: 720, 27: 1908, 45: 1523, 52: 2085, 4: 1541, 30: 2626, 49: 2197, 34: 14149, 62: 1792, 54: 1586, 42: 1120, 29: 1460, 18: 864, 53: 621, 3: 1895, 56: 803, 32: 1339, 20: 36, 63: 115, 11: 85, 12: 5, 39: 93, 7: 105, 43: 111, 28: 21, 9: 632, 21: 17, 58: 11, 23: 2})
Total experts: 64
Expert count for l0 19799



Routing for first 5 tokens in layer 1: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {18: 690, 54: 2147, 60: 215, 23: 6896, 50: 2207, 9: 6417, 53: 4424, 37: 1385, 12: 9657, 16: 13291, 7: 8622, 21: 3107, 40: 3923, 47: 57123, 3: 1020, 42: 1361, 61: 4850, 55: 1308, 0: 8193, 36: 2369, 49: 1357, 58: 4404, 56: 398, 27: 7238, 19: 718, 34: 1049, 39: 3489, 17: 5484, 30: 3136, 20: 3050, 1: 2012, 4: 771, 31: 1824, 13: 304, 48: 2993, 35: 1349, 28: 652, 32: 2479, 59: 4495, 10: 247, 29: 863, 44: 1378, 52: 739, 46: 1447, 33: 441, 26: 868, 22: 489, 62: 431, 14: 354, 8: 553, 45: 681, 2: 951, 57: 1243, 24: 455, 51: 969, 63: 451, 5: 159, 43: 298, 6: 251, 11: 971, 41: 316, 38: 147, 15: 165, 25: 135})
Total experts: 64
Expert count for l0 8193



Routing for first 5 tokens in layer 2: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {30: 2559, 21: 3818, 15: 20186, 20: 3800, 3: 1865, 34: 4569, 33: 8832, 50: 6842, 26: 15661, 63: 21779, 54: 4656, 37: 205, 29: 5816, 32: 6672, 16: 5278, 27: 6875, 0: 1904, 35: 7729, 56: 356, 22: 4780, 6: 1229, 39: 3915, 19: 2371, 23: 8240, 36: 7585, 11: 2215, 1: 3276, 48: 1493, 46: 1800, 31: 1323, 25: 1273, 44: 2904, 17: 2167, 42: 835, 51: 5897, 57: 641, 9: 360, 7: 2421, 28: 1961, 60: 939, 45: 1746, 52: 1822, 61: 2737, 55: 296, 41: 1103, 43: 342, 13: 283, 18: 1954, 49: 324, 12: 642, 24: 618, 47: 252, 14: 443, 62: 211, 2: 163, 38: 433, 58: 137, 10: 28, 59: 141, 40: 175, 5: 423, 53: 77, 8: 32})
Total experts: 63
Expert count for l0 1904



Routing for first 5 tokens in layer 3: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {39: 216, 56: 4491, 57: 9753, 40: 7142, 61: 1315, 11: 1492, 8: 1875, 9: 3500, 35: 62536, 54: 3611, 46: 1307, 41: 2495, 51: 14128, 30: 7416, 38: 6755, 27: 1617, 49: 3740, 20: 4099, 32: 4531, 42: 7423, 18: 6451, 7: 5265, 28: 1365, 12: 4199, 37: 4435, 45: 5917, 63: 2297, 5: 1193, 50: 1885, 47: 1346, 2: 942, 60: 599, 3: 68, 19: 322, 29: 478, 44: 655, 48: 224, 4: 613, 59: 1516, 1: 1883, 31: 1339, 62: 323, 58: 67, 34: 688, 33: 121, 36: 2444, 13: 331, 10: 179, 26: 656, 52: 2376, 6: 39, 21: 76, 24: 322, 15: 22, 17: 150, 0: 951, 53: 11, 43: 42, 14: 48, 25: 68, 55: 3, 16: 38, 23: 13, 22: 7})
Total experts: 64
Expert count for l0 951



Routing for first 5 tokens in layer 4: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {27: 1417, 33: 936, 62: 2030, 38: 2183, 14: 6384, 17: 64823, 44: 1473, 52: 3843, 2: 368, 59: 7839, 16: 14288, 50: 680, 5: 3415, 46: 31771, 42: 450, 26: 8050, 56: 1535, 23: 2934, 13: 1664, 7: 4876, 32: 2198, 39: 1528, 57: 2899, 35: 2301, 58: 1211, 9: 2986, 8: 1371, 28: 301, 4: 1181, 31: 3725, 41: 1141, 49: 675, 6: 605, 3: 1474, 34: 1754, 24: 1681, 10: 786, 21: 1624, 60: 456, 15: 656, 53: 33, 30: 921, 40: 207, 22: 797, 61: 595, 43: 1549, 20: 560, 29: 1953, 19: 269, 1: 106, 55: 1086, 51: 726, 36: 86, 0: 68, 25: 450, 47: 235, 12: 56, 37: 102, 11: 12, 48: 37, 18: 3, 63: 4, 45: 38, 54: 4})
Total experts: 64
Expert count for l0 68



Routing for first 5 tokens in layer 5: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {31: 12736, 42: 39, 36: 2830, 13: 6839, 1: 3376, 57: 10366, 19: 4687, 48: 1432, 27: 13085, 18: 4295, 9: 4201, 22: 6391, 35: 9078, 38: 16414, 54: 11742, 55: 1374, 52: 4639, 30: 1809, 11: 4926, 58: 3322, 34: 3214, 0: 15782, 10: 4226, 41: 372, 28: 246, 24: 2114, 21: 1052, 47: 580, 56: 3421, 62: 1387, 17: 1981, 25: 18216, 7: 1913, 53: 784, 6: 1443, 63: 1446, 5: 3190, 50: 999, 4: 536, 44: 616, 8: 1654, 51: 1418, 14: 2593, 61: 981, 29: 4038, 20: 517, 12: 220, 39: 297, 15: 813, 49: 223, 32: 230, 2: 339, 43: 189, 60: 12, 40: 62, 23: 63, 59: 362, 33: 91, 16: 130, 46: 27, 3: 18, 37: 17, 45: 10, 26: 6})
Total experts: 64
Expert count for l0 15782



Routing for first 5 tokens in layer 6: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {49: 359, 17: 1417, 18: 28080, 55: 853, 45: 102, 4: 12849, 12: 2881, 60: 551, 50: 8761, 40: 13661, 62: 12334, 5: 30977, 22: 6497, 61: 3126, 33: 15683, 9: 969, 26: 1133, 6: 806, 7: 979, 24: 4158, 63: 6370, 59: 1034, 13: 2479, 29: 1249, 58: 1008, 16: 2481, 32: 529, 8: 2235, 35: 562, 34: 9580, 2: 1920, 46: 139, 57: 10335, 3: 568, 1: 3105, 20: 3796, 53: 715, 54: 510, 36: 268, 44: 606, 48: 1151, 30: 1534, 25: 470, 39: 18, 56: 138, 43: 174, 42: 710, 47: 439, 21: 135, 0: 139, 52: 8, 19: 53, 14: 75, 11: 58, 51: 144, 31: 39, 28: 171, 23: 27, 37: 250, 38: 2, 41: 6, 10: 3})
Total experts: 62
Expert count for l0 139



Routing for first 5 tokens in layer 7: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {21: 1061, 13: 741, 16: 6021, 1: 252, 40: 902, 22: 6329, 11: 765, 24: 11196, 58: 15461, 26: 12400, 23: 2832, 14: 2277, 59: 23361, 35: 2547, 32: 6884, 19: 2870, 61: 8458, 57: 3560, 56: 6802, 15: 7585, 41: 1460, 29: 5331, 33: 1370, 36: 12797, 12: 3382, 20: 3354, 52: 5530, 49: 1521, 34: 5828, 60: 1712, 48: 5341, 45: 1548, 63: 3763, 38: 4224, 42: 765, 46: 1972, 8: 258, 5: 1188, 4: 3493, 39: 1181, 53: 531, 31: 3220, 37: 390, 50: 746, 18: 510, 6: 337, 10: 1606, 17: 2192, 30: 617, 54: 283, 51: 44, 62: 24, 0: 283, 47: 146, 9: 190, 44: 83, 43: 637, 25: 723, 55: 408, 2: 15, 27: 88, 3: 10, 7: 3, 28: 1})
Total experts: 64
Expert count for l0 283



Routing for first 5 tokens in layer 8: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {22: 1267, 2: 2394, 14: 56302, 31: 539, 18: 34679, 40: 1714, 4: 688, 46: 154, 49: 5839, 36: 3695, 53: 6368, 38: 728, 44: 9111, 24: 7158, 47: 831, 21: 5002, 11: 6537, 35: 396, 48: 2144, 10: 1293, 9: 1694, 63: 1692, 17: 1807, 28: 597, 6: 2308, 61: 13806, 1: 1828, 23: 478, 20: 2182, 13: 295, 32: 830, 52: 3213, 41: 772, 26: 301, 5: 430, 45: 1525, 16: 872, 25: 2292, 56: 650, 55: 2120, 57: 4971, 50: 1298, 0: 827, 37: 3256, 42: 743, 33: 265, 54: 710, 51: 257, 30: 401, 60: 57, 34: 389, 12: 482, 15: 181, 59: 9, 3: 669, 27: 140, 29: 49, 7: 36, 62: 97, 43: 34, 19: 7})
Total experts: 61
Expert count for l0 827



Routing for first 5 tokens in layer 9: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {8: 13069, 24: 271, 47: 8981, 22: 4588, 30: 2214, 11: 2531, 59: 2259, 19: 35, 28: 26782, 42: 6512, 20: 8869, 5: 38185, 43: 6504, 2: 1866, 14: 7003, 38: 2178, 52: 3894, 31: 1602, 57: 3089, 49: 1169, 27: 5656, 6: 1029, 33: 5570, 48: 2759, 36: 2527, 40: 580, 18: 1116, 1: 7401, 53: 1668, 61: 804, 44: 2597, 26: 1576, 21: 1286, 12: 67, 23: 2506, 9: 3227, 50: 1635, 54: 2921, 7: 2376, 13: 1511, 39: 908, 62: 221, 35: 916, 16: 1875, 58: 468, 32: 626, 37: 226, 3: 286, 63: 481, 55: 1600, 0: 339, 4: 998, 45: 502, 56: 147, 60: 428, 25: 52, 46: 89, 29: 44, 10: 603, 34: 24, 51: 72, 41: 91})
Total experts: 62
Expert count for l0 339



Routing for first 5 tokens in layer 10: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {4: 476, 39: 5485, 15: 4191, 26: 7667, 14: 7328, 60: 13437, 29: 7343, 52: 6537, 1: 690, 0: 1513, 43: 10896, 17: 3519, 59: 6809, 46: 2770, 45: 729, 55: 2019, 32: 4180, 28: 21383, 3: 13342, 42: 2201, 53: 1189, 27: 2391, 34: 1469, 12: 4995, 18: 7205, 58: 5639, 2: 3288, 16: 6761, 13: 10386, 30: 859, 21: 1889, 35: 1503, 5: 1233, 7: 4334, 9: 3130, 50: 1933, 62: 2988, 41: 992, 36: 657, 57: 1098, 49: 643, 61: 700, 23: 3094, 38: 169, 37: 1217, 11: 3316, 8: 231, 24: 1089, 47: 269, 56: 2518, 63: 150, 33: 48, 10: 339, 6: 36, 31: 136, 54: 209, 22: 13, 19: 505, 48: 121, 51: 107, 20: 34, 44: 10, 40: 1})
Total experts: 63
Expert count for l0 1513



Routing for first 5 tokens in layer 11: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {56: 544, 3: 5494, 18: 3335, 39: 2035, 45: 5987, 26: 3997, 19: 3148, 49: 1128, 15: 1195, 48: 29952, 11: 6478, 51: 1591, 33: 5487, 12: 7948, 34: 2314, 41: 1389, 54: 23147, 8: 2000, 35: 1454, 22: 3419, 24: 6061, 42: 1013, 30: 3624, 57: 4406, 4: 2106, 23: 7660, 31: 3114, 0: 2029, 9: 8201, 16: 1232, 32: 5367, 46: 7875, 20: 5158, 6: 125, 36: 2016, 58: 2436, 5: 3021, 61: 3621, 43: 796, 40: 537, 44: 1333, 47: 10182, 50: 126, 63: 1002, 29: 851, 13: 1304, 60: 272, 17: 391, 28: 387, 2: 335, 14: 28, 7: 602, 37: 312, 59: 111, 21: 261, 55: 73, 62: 910, 27: 42, 10: 67, 1: 290, 52: 10, 53: 4, 38: 1, 25: 75})
Total experts: 64
Expert count for l0 2029



Routing for first 5 tokens in layer 12: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {59: 2333, 41: 2318, 4: 4256, 13: 1031, 47: 9204, 44: 8396, 18: 5843, 10: 6830, 50: 10111, 7: 1577, 43: 62787, 20: 4297, 24: 976, 39: 3418, 14: 3441, 46: 354, 2: 1530, 9: 1961, 60: 4040, 49: 1190, 48: 6475, 1: 736, 54: 1944, 15: 2458, 36: 1274, 62: 2808, 31: 5858, 56: 1114, 12: 675, 6: 3870, 34: 2310, 35: 15248, 52: 803, 30: 1172, 40: 405, 3: 1013, 55: 3109, 57: 3517, 28: 3103, 53: 2271, 45: 405, 32: 633, 26: 197, 51: 58, 0: 674, 42: 206, 8: 188, 23: 880, 29: 92, 61: 39, 17: 853, 16: 94, 27: 158, 33: 104, 21: 68, 38: 112, 19: 206, 58: 66, 22: 106, 37: 3, 11: 161, 63: 23, 25: 26, 5: 1})
Total experts: 64
Expert count for l0 674



Routing for first 5 tokens in layer 13: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {2: 772, 25: 37, 38: 8438, 50: 3963, 14: 85, 39: 6667, 35: 2063, 10: 5020, 22: 9073, 5: 5317, 18: 4317, 20: 59892, 8: 1462, 30: 3688, 41: 7461, 61: 2021, 27: 2779, 45: 838, 12: 465, 57: 19421, 34: 604, 58: 3074, 28: 1570, 3: 5368, 31: 3859, 15: 1513, 0: 2335, 26: 376, 55: 1646, 23: 2341, 59: 2781, 17: 4909, 19: 4399, 32: 4147, 46: 712, 49: 1950, 51: 820, 1: 321, 24: 688, 48: 581, 11: 4833, 40: 139, 47: 297, 54: 1151, 62: 788, 4: 1543, 36: 194, 63: 389, 29: 103, 37: 376, 16: 193, 13: 65, 44: 169, 9: 84, 6: 1460, 60: 986, 43: 273, 53: 73, 42: 45, 33: 90, 7: 93, 21: 27, 56: 245, 52: 20})
Total experts: 64
Expert count for l0 2335



Routing for first 5 tokens in layer 14: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {60: 1507, 22: 193, 63: 6909, 56: 5277, 29: 679, 35: 4055, 46: 3374, 17: 1366, 20: 3544, 24: 13809, 42: 2295, 19: 6982, 32: 641, 38: 7727, 47: 3786, 15: 6566, 54: 2203, 48: 10370, 61: 1054, 44: 961, 5: 1740, 23: 4081, 16: 2929, 52: 3692, 57: 3866, 9: 3464, 18: 3049, 27: 191, 6: 10661, 45: 1439, 41: 683, 33: 58012, 39: 786, 50: 102, 7: 612, 40: 1220, 1: 3479, 13: 733, 58: 339, 11: 1365, 31: 941, 53: 331, 10: 1533, 30: 6699, 12: 939, 28: 1265, 37: 140, 21: 77, 26: 516, 3: 66, 36: 642, 34: 54, 14: 359, 4: 39, 49: 139, 0: 53, 8: 945, 62: 373, 25: 472, 2: 8, 43: 61, 51: 6, 55: 2, 59: 8})
Total experts: 64
Expert count for l0 53



Routing for first 5 tokens in layer 15: 
Total assignments: 201409
Expert counts: defaultdict(<class 'int'>, {23: 1151, 52: 5366, 54: 9139, 62: 3028, 34: 93359, 55: 3528, 63: 1574, 31: 2737, 61: 4824, 1: 6302, 57: 4465, 43: 3397, 29: 7336, 35: 1947, 48: 585, 17: 11329, 2: 1913, 47: 792, 24: 5692, 6: 1801, 4: 3467, 20: 3797, 45: 231, 53: 2606, 5: 4255, 41: 642, 36: 523, 58: 117, 60: 105, 3: 1319, 38: 5786, 13: 748, 15: 903, 22: 941, 19: 413, 42: 610, 10: 72, 28: 290, 27: 30, 11: 304, 7: 261, 30: 702, 8: 138, 21: 377, 0: 193, 33: 43, 26: 176, 37: 431, 44: 239, 40: 155, 39: 742, 16: 63, 56: 130, 12: 36, 32: 26, 9: 26, 46: 81, 51: 60, 25: 38, 49: 12, 18: 7, 50: 9, 59: 36, 14: 4})
Total experts: 64
Expert count for l0 193


#### legal

In [13]:
# Read and chunk input file
file_path = 'data/legal.txt'
domain = 'legal'
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # Move model to GPU if available

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    # print(f'Sample text: {chunk[:10]}...')  
    
    # Get router logits for the chunk
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        results = get_router_logits(model, chunk)
    all_results.append(results)
    
    # Save intermediate results 
    update_router_logits_json(results, domain=domain)
    
    # Clear GPU cache periodically
    if (i + 1) % 10 == 0:  # Every 10 chunks
        torch.cuda.empty_cache()

Processing chunk 1/199



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Processing chunk 2/199
Processing chunk 3/199
Processing chunk 4/199
Processing chunk 5/199
Processing chunk 6/199
Processing chunk 7/199
Processing chunk 8/199
Processing chunk 9/199
Processing chunk 10/199
Processing chunk 11/199
Processing chunk 12/199
Processing chunk 13/199
Processing chunk 14/199
Processing chunk 15/199
Processing chunk 16/199
Processing chunk 17/199
Processing chunk 18/199
Processing chunk 19/199
Processing chunk 20/199
Processing chunk 21/199
Processing chunk 22/199
Processing chunk 23/199
Processing chunk 24/199
Processing chunk 25/199
Processing chunk 26/199
Processing chunk 27/199
Processing chunk 28/199
Processing chunk 29/199
Processing chunk 30/199
Processing chunk 31/199
Processing chunk 32/199
Processing chunk 33/199
Processing chunk 34/199
Processing chunk 35/199
Processing chunk 36/199
Processing chunk 37/199
Processing chunk 38/199
Processing chunk 39/199
Processing chunk 40/199
Processing chunk 41/199
Processing chunk 42/199
Processing chunk 43/199


In [14]:
domain = 'legal'
num_layers = 16  # OLMoE has 16 layers

# Combine results from all chunks
combined_results = []
for layer_idx in range(len(all_results[0])):  # For each layer
    layer_combined = []
    for chunk_result in all_results:
        # Skip if chunk_result[layer_idx] is not iterable
        if not isinstance(chunk_result[layer_idx], (list, tuple)):
            continue
            
        # Move data to GPU if available 
        if torch.cuda.is_available():
            chunk_result = [
                (token.cuda() if torch.is_tensor(token) else token,
                 expert.cuda() if torch.is_tensor(expert) else expert,
                 prob.cuda() if torch.is_tensor(prob) else prob)
                for token, expert, prob in chunk_result[layer_idx]
            ]
        layer_combined.extend(chunk_result)
    combined_results.append(layer_combined)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
    layer_results = combined_results[layer_to_plot]
    
    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue
        
    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip if token_info is not a tuple/list
        if not isinstance(token_info, (tuple, list)):
            continue
            
        token, expert, prob = token_info
        # Move to CPU for printing
        if torch.cuda.is_available():
            prob = prob.cpu()
        print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

    # Plot expert distribution for all processed data
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'plots/legal/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'plots/legal/{domain}_layer{layer_to_plot}_expert_dist.png')
    
    # Clear GPU cache after each layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Routing for first 5 tokens in layer 0: 



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {57: 563, 6: 680, 27: 1046, 16: 3715, 2: 11295, 48: 2386, 59: 7534, 44: 2662, 31: 3051, 51: 11719, 37: 4051, 3: 2110, 54: 4613, 5: 8002, 18: 1997, 41: 8752, 24: 4108, 49: 4159, 33: 1916, 62: 6327, 55: 5287, 29: 4112, 50: 3688, 13: 2998, 45: 612, 10: 7087, 36: 14405, 56: 2231, 46: 4629, 30: 2657, 26: 5114, 25: 3050, 32: 779, 15: 5151, 17: 2089, 35: 3264, 14: 2105, 39: 222, 42: 847, 47: 3406, 1: 1464, 38: 9181, 22: 8628, 61: 2387, 19: 3801, 8: 4076, 60: 1733, 63: 174, 53: 798, 40: 643, 4: 1575, 23: 12, 11: 173, 28: 3245, 43: 124, 7: 173, 9: 116, 21: 29, 12: 187, 20: 29, 52: 36, 34: 620, 58: 25, 0: 11})
Total experts: 64
Expert count for l0 11



Routing for first 5 tokens in layer 1: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {18: 3055, 61: 4292, 55: 1041, 38: 563, 40: 5921, 23: 12147, 21: 12536, 35: 5092, 59: 5145, 12: 8337, 36: 2803, 41: 3283, 29: 2616, 27: 12942, 53: 12821, 0: 8787, 1: 4818, 47: 811, 2: 1369, 51: 2121, 11: 3836, 57: 935, 20: 5871, 42: 4901, 39: 2025, 28: 1126, 54: 2367, 37: 3466, 9: 7979, 48: 3642, 13: 1159, 56: 323, 4: 981, 33: 652, 8: 489, 26: 201, 44: 1017, 60: 999, 7: 7124, 24: 764, 46: 1184, 17: 4093, 3: 1836, 6: 241, 49: 2062, 45: 1804, 50: 4537, 63: 1209, 30: 2559, 58: 7005, 32: 7231, 31: 2132, 62: 431, 15: 229, 52: 1793, 19: 1112, 10: 240, 22: 1289, 25: 160, 14: 332, 16: 471, 34: 985, 5: 149, 43: 218})
Total experts: 64
Expert count for l0 8787



Routing for first 5 tokens in layer 2: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {30: 1155, 28: 2273, 17: 2713, 18: 11968, 23: 11429, 12: 1817, 13: 1871, 0: 3542, 63: 1876, 26: 18799, 60: 2352, 8: 15831, 45: 3333, 20: 3558, 21: 11141, 3: 4681, 15: 3563, 52: 3525, 22: 7749, 9: 4953, 27: 5931, 50: 4129, 46: 2348, 56: 835, 29: 5088, 16: 1022, 38: 867, 44: 6553, 34: 5336, 7: 1718, 32: 2891, 31: 2141, 55: 216, 25: 2320, 2: 562, 49: 312, 48: 1071, 62: 1427, 43: 281, 14: 3760, 51: 1112, 59: 714, 35: 4630, 39: 4939, 33: 1908, 1: 1155, 6: 9038, 47: 1514, 54: 930, 61: 1162, 19: 2526, 5: 37, 36: 428, 57: 252, 24: 316, 42: 678, 41: 1364, 11: 789, 53: 95, 40: 49, 37: 526, 10: 2475, 58: 85})
Total experts: 63
Expert count for l0 3542



Routing for first 5 tokens in layer 3: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {39: 358, 63: 1494, 24: 1338, 3: 1281, 1: 602, 42: 17701, 41: 6729, 33: 3105, 31: 8410, 13: 711, 57: 23048, 51: 6912, 14: 105, 26: 1472, 9: 7433, 30: 4051, 32: 4326, 12: 5625, 40: 13097, 50: 9499, 29: 3177, 37: 6124, 48: 2980, 38: 6035, 36: 4587, 4: 1504, 20: 5646, 2: 851, 44: 1475, 22: 11, 19: 2344, 53: 118, 59: 5370, 61: 1681, 11: 589, 45: 2910, 47: 4301, 7: 8930, 34: 1996, 49: 2525, 46: 3365, 28: 1272, 56: 2572, 25: 5806, 18: 2559, 8: 1255, 54: 1262, 5: 393, 43: 106, 0: 82, 60: 1281, 35: 259, 62: 93, 52: 1082, 27: 809, 23: 305, 10: 39, 17: 127, 16: 286, 6: 161, 21: 19, 15: 43, 58: 25, 55: 7})
Total experts: 64
Expert count for l0 82



Routing for first 5 tokens in layer 4: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {27: 692, 33: 2165, 37: 414, 12: 294, 31: 3318, 62: 697, 43: 6504, 13: 520, 3: 106438, 49: 8436, 34: 2668, 16: 1162, 28: 469, 15: 638, 61: 1192, 4: 3444, 25: 899, 7: 5942, 32: 433, 6: 893, 14: 6176, 41: 961, 5: 7670, 40: 1381, 57: 2713, 58: 2462, 1: 844, 50: 1213, 21: 873, 39: 2069, 10: 4596, 23: 1201, 35: 931, 19: 679, 59: 1683, 60: 297, 52: 2572, 51: 2925, 24: 3181, 8: 675, 29: 1268, 42: 1246, 56: 1241, 45: 90, 26: 108, 38: 383, 55: 153, 11: 28, 48: 58, 2: 2450, 20: 219, 30: 1538, 44: 1642, 36: 111, 47: 127, 17: 9, 53: 3, 22: 83, 18: 1, 63: 31, 46: 310, 0: 16, 54: 6, 9: 218})
Total experts: 64
Expert count for l0 16



Routing for first 5 tokens in layer 5: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {31: 470, 35: 9629, 15: 3035, 23: 2713, 61: 1554, 46: 24048, 55: 7805, 39: 5644, 11: 6981, 17: 4348, 8: 1844, 52: 2001, 1: 656, 44: 1204, 50: 3795, 6: 3426, 53: 2363, 57: 16020, 22: 3741, 54: 9837, 34: 8212, 63: 1388, 5: 2428, 16: 1289, 20: 320, 43: 1452, 21: 6825, 27: 2313, 29: 1595, 9: 6750, 58: 7997, 32: 14908, 47: 8447, 51: 2044, 14: 1133, 48: 2562, 18: 572, 2: 116, 24: 2044, 10: 3746, 0: 258, 37: 10, 19: 878, 38: 1713, 28: 259, 7: 2299, 4: 935, 30: 162, 36: 621, 45: 124, 12: 322, 60: 83, 13: 139, 41: 459, 40: 228, 25: 7138, 62: 85, 49: 203, 42: 16, 33: 39, 59: 5, 56: 343, 26: 82, 3: 3})
Total experts: 64
Expert count for l0 258



Routing for first 5 tokens in layer 6: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {49: 462, 1: 4013, 22: 5715, 8: 2407, 43: 16650, 24: 12662, 40: 6644, 2: 1628, 9: 3197, 23: 16228, 62: 8523, 35: 2412, 58: 2914, 54: 1904, 18: 24894, 63: 8687, 26: 7779, 28: 1002, 20: 6402, 44: 249, 4: 12745, 30: 5383, 36: 6241, 13: 7531, 29: 8063, 59: 1147, 50: 4460, 3: 6790, 32: 67, 55: 413, 6: 1183, 5: 1030, 17: 1626, 46: 56, 57: 297, 0: 43, 53: 1963, 33: 193, 60: 175, 7: 1923, 12: 1624, 61: 1345, 47: 251, 38: 2, 25: 120, 37: 1572, 16: 399, 41: 177, 56: 88, 31: 106, 39: 260, 19: 110, 42: 8, 45: 91, 51: 12, 10: 3, 34: 1750, 48: 11, 21: 14, 52: 4, 11: 8, 14: 3})
Total experts: 62
Expert count for l0 43



Routing for first 5 tokens in layer 7: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {21: 14635, 49: 947, 31: 6100, 59: 8169, 33: 3053, 22: 1651, 13: 969, 58: 3824, 12: 4371, 18: 5155, 29: 11957, 0: 14796, 11: 12736, 61: 9947, 26: 7274, 10: 12549, 38: 1677, 19: 1744, 37: 1827, 47: 249, 24: 14829, 42: 755, 35: 3686, 39: 1543, 36: 3623, 52: 10264, 34: 2357, 23: 2883, 4: 2631, 62: 8545, 48: 4095, 45: 1252, 14: 2326, 15: 3311, 53: 2547, 16: 1742, 30: 2550, 5: 136, 56: 1704, 20: 2768, 46: 361, 7: 28, 32: 2138, 2: 1, 60: 180, 17: 20, 57: 432, 6: 527, 54: 132, 27: 159, 40: 298, 9: 490, 1: 259, 41: 480, 50: 141, 8: 27, 25: 481, 43: 220, 63: 99, 55: 2, 51: 3, 3: 1, 44: 3})
Total experts: 63
Expert count for l0 14796



Routing for first 5 tokens in layer 8: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {22: 1010, 16: 738, 44: 10971, 24: 3041, 49: 15674, 18: 35223, 52: 4095, 61: 12399, 57: 3444, 53: 13919, 45: 9785, 23: 2752, 51: 2081, 26: 1870, 48: 2668, 37: 2870, 2: 2949, 30: 667, 6: 3268, 62: 2620, 38: 1755, 21: 7399, 55: 6847, 25: 1931, 10: 381, 20: 2439, 27: 5984, 32: 1600, 3: 152, 17: 1249, 12: 1977, 1: 1811, 28: 6046, 9: 4169, 41: 2473, 63: 303, 42: 3595, 34: 3434, 54: 846, 50: 544, 11: 3436, 36: 2072, 15: 2725, 0: 592, 58: 1, 47: 2436, 60: 126, 40: 2157, 13: 1696, 35: 68, 56: 128, 31: 106, 19: 104, 7: 68, 14: 124, 33: 251, 29: 35, 46: 150, 4: 72, 5: 46, 59: 279, 43: 5, 8: 3})
Total experts: 63
Expert count for l0 592



Routing for first 5 tokens in layer 9: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {8: 1303, 20: 13437, 40: 2294, 9: 1973, 11: 4153, 52: 11054, 31: 4481, 47: 11581, 33: 9288, 39: 7862, 35: 2321, 54: 2520, 14: 10440, 61: 38703, 7: 6868, 1: 5349, 27: 8530, 30: 4053, 48: 1198, 6: 3747, 50: 3937, 63: 3823, 18: 2502, 12: 300, 2: 3509, 44: 1258, 55: 3091, 37: 882, 13: 2128, 26: 223, 42: 1078, 43: 3340, 51: 1225, 36: 742, 60: 552, 28: 7301, 49: 1816, 29: 370, 62: 3524, 38: 1589, 10: 491, 23: 1796, 53: 376, 32: 905, 34: 550, 22: 1213, 59: 991, 24: 243, 46: 84, 4: 31, 21: 921, 16: 352, 3: 386, 45: 128, 5: 375, 58: 68, 15: 1, 0: 50, 57: 273, 41: 67, 19: 9, 25: 4})
Total experts: 62
Expert count for l0 50



Routing for first 5 tokens in layer 10: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {4: 645, 11: 3593, 18: 10743, 23: 1797, 3: 11421, 58: 11143, 49: 1140, 55: 2905, 24: 5485, 60: 12081, 13: 6352, 2: 2699, 62: 4750, 50: 5810, 26: 8042, 52: 11553, 5: 3578, 9: 1521, 7: 5395, 46: 2062, 0: 5944, 29: 5623, 20: 2258, 35: 3076, 34: 2534, 44: 458, 19: 397, 59: 7618, 30: 880, 6: 777, 16: 4570, 28: 4783, 32: 3568, 1: 6759, 8: 5020, 14: 4487, 53: 2475, 41: 1802, 12: 5339, 15: 4796, 31: 775, 43: 904, 37: 1147, 17: 2242, 45: 1454, 47: 865, 27: 2091, 21: 1142, 22: 1050, 39: 2862, 42: 977, 57: 420, 56: 27, 61: 349, 51: 722, 63: 19, 10: 123, 38: 69, 36: 331, 54: 111, 48: 95, 33: 4, 25: 1})
Total experts: 63
Expert count for l0 5944



Routing for first 5 tokens in layer 11: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {56: 347, 46: 5045, 29: 1079, 9: 1189, 14: 465, 24: 11659, 12: 11904, 28: 773, 26: 7480, 49: 2165, 40: 1179, 54: 17697, 44: 4244, 39: 3212, 43: 5213, 7: 4722, 2: 1045, 11: 7165, 15: 3065, 55: 2509, 5: 4908, 45: 8529, 17: 859, 41: 2654, 62: 1011, 20: 5077, 57: 4195, 10: 384, 33: 2870, 51: 3203, 13: 46, 63: 705, 18: 5557, 32: 8258, 34: 4447, 8: 4506, 61: 1637, 35: 4453, 58: 1697, 16: 3207, 22: 3985, 47: 2134, 6: 705, 3: 4472, 4: 3075, 31: 2595, 0: 3939, 36: 4258, 21: 643, 25: 1, 30: 4044, 60: 2005, 23: 2887, 50: 5203, 19: 1035, 37: 74, 42: 302, 38: 3, 27: 74, 52: 16, 59: 117, 1: 375, 48: 1349, 53: 8})
Total experts: 64
Expert count for l0 3939



Routing for first 5 tokens in layer 12: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {59: 1875, 30: 3526, 20: 6588, 10: 11182, 29: 333, 15: 11846, 57: 3983, 48: 7971, 3: 4917, 23: 11348, 4: 3264, 44: 13455, 13: 1462, 47: 15506, 55: 8106, 19: 598, 32: 313, 28: 4303, 40: 1164, 12: 854, 26: 391, 56: 751, 58: 312, 17: 1571, 33: 940, 45: 882, 6: 4362, 50: 12163, 41: 3692, 7: 1866, 34: 3467, 52: 817, 14: 8105, 63: 15160, 42: 942, 60: 5301, 9: 1862, 39: 7075, 0: 2052, 8: 351, 27: 358, 49: 1359, 46: 3039, 24: 1503, 51: 33, 53: 4005, 16: 625, 21: 275, 36: 1180, 22: 443, 2: 1580, 38: 112, 62: 1293, 61: 184, 1: 381, 18: 1335, 54: 246, 43: 145, 37: 213, 5: 10, 31: 19, 35: 618, 11: 24, 25: 23})
Total experts: 64
Expert count for l0 2052



Routing for first 5 tokens in layer 13: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {2: 674, 4: 3357, 10: 8552, 57: 15582, 49: 3380, 18: 6876, 5: 1252, 13: 8818, 30: 12201, 39: 11580, 28: 3676, 32: 17245, 50: 5145, 55: 4030, 17: 948, 9: 508, 15: 36408, 56: 1574, 46: 5927, 54: 3548, 48: 1720, 27: 4548, 43: 1318, 29: 233, 6: 1239, 23: 4217, 22: 8471, 58: 666, 41: 8629, 31: 1002, 38: 2727, 61: 2173, 11: 603, 0: 985, 1: 233, 26: 1741, 12: 190, 62: 151, 35: 1379, 24: 1089, 33: 143, 37: 256, 34: 782, 44: 1390, 51: 1991, 59: 146, 3: 150, 45: 298, 25: 27, 8: 1671, 40: 23, 60: 105, 36: 90, 63: 59, 53: 70, 42: 198, 21: 33, 19: 658, 52: 54, 7: 48, 20: 804, 47: 12, 16: 55, 14: 1})
Total experts: 64
Expert count for l0 985



Routing for first 5 tokens in layer 14: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {60: 3376, 24: 8109, 61: 2117, 13: 5379, 19: 10639, 39: 911, 1: 4757, 63: 35632, 32: 41808, 20: 6777, 52: 5154, 12: 740, 37: 426, 57: 3751, 23: 1857, 7: 1315, 47: 3833, 27: 186, 49: 786, 5: 2652, 43: 17304, 22: 418, 40: 3405, 9: 3106, 11: 2487, 28: 791, 44: 79, 46: 2081, 48: 2928, 54: 4202, 18: 3952, 15: 1081, 31: 863, 45: 1863, 35: 1421, 16: 346, 17: 1737, 56: 739, 29: 1353, 58: 330, 26: 593, 10: 345, 4: 27, 42: 980, 34: 127, 36: 100, 53: 936, 38: 566, 14: 4199, 41: 122, 6: 39, 50: 19, 0: 7, 25: 25, 30: 872, 3: 137, 62: 554, 51: 38, 33: 3159, 55: 24, 8: 86, 21: 8, 2: 4, 59: 1})
Total experts: 64
Expert count for l0 7



Routing for first 5 tokens in layer 15: 
Total assignments: 203659
Expert counts: defaultdict(<class 'int'>, {6: 3172, 1: 3377, 53: 3494, 0: 303, 52: 7618, 31: 5142, 43: 4582, 57: 4987, 42: 49169, 29: 13414, 44: 844, 37: 2030, 24: 908, 3: 3945, 60: 300, 47: 5885, 36: 2237, 61: 5460, 27: 173, 11: 1153, 41: 1158, 5: 4465, 19: 2567, 48: 1767, 35: 2499, 51: 32267, 8: 370, 9: 101, 58: 1032, 28: 415, 63: 870, 54: 6745, 20: 3461, 4: 5882, 25: 1780, 23: 685, 62: 2166, 40: 1457, 2: 3984, 56: 1194, 22: 1220, 17: 1929, 21: 330, 33: 549, 32: 243, 46: 2723, 12: 151, 15: 812, 59: 229, 50: 84, 45: 19, 55: 659, 16: 239, 7: 110, 13: 83, 18: 28, 10: 45, 34: 247, 30: 15, 26: 124, 38: 723, 39: 34, 49: 5})
Total experts: 63
Expert count for l0 303


#### physics

In [15]:
# Read and chunk input file
file_path = 'data/physics_arxiv_200k.txt'
domain = 'physics'
chunks = prepare_text_input(file_path, chunk_size=1024, tokenizer=tokenizer)

# Check if CUDA is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # Move model to GPU if available

# Process all chunks
all_results = []
for i, chunk in enumerate(chunks):
    print(f'Processing chunk {i+1}/{len(chunks)}')
    # print(f'Sample text: {chunk[:10]}...')  
    
    # Get router logits for the chunk
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        results = get_router_logits(model, chunk)
    all_results.append(results)
    
    # Save intermediate results 
    update_router_logits_json(results, domain=domain)
    
    # Clear GPU cache periodically
    if (i + 1) % 10 == 0:  # Every 10 chunks
        torch.cuda.empty_cache()

Processing chunk 1/199



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Processing chunk 2/199
Processing chunk 3/199
Processing chunk 4/199
Processing chunk 5/199
Processing chunk 6/199
Processing chunk 7/199
Processing chunk 8/199
Processing chunk 9/199
Processing chunk 10/199
Processing chunk 11/199
Processing chunk 12/199
Processing chunk 13/199
Processing chunk 14/199
Processing chunk 15/199
Processing chunk 16/199
Processing chunk 17/199
Processing chunk 18/199
Processing chunk 19/199
Processing chunk 20/199
Processing chunk 21/199
Processing chunk 22/199
Processing chunk 23/199
Processing chunk 24/199
Processing chunk 25/199
Processing chunk 26/199
Processing chunk 27/199
Processing chunk 28/199
Processing chunk 29/199
Processing chunk 30/199
Processing chunk 31/199
Processing chunk 32/199
Processing chunk 33/199
Processing chunk 34/199
Processing chunk 35/199
Processing chunk 36/199
Processing chunk 37/199
Processing chunk 38/199
Processing chunk 39/199
Processing chunk 40/199
Processing chunk 41/199
Processing chunk 42/199
Processing chunk 43/199


In [16]:
domain = 'physics'
num_layers = 16  # OLMoE has 16 layers

# Combine results from all chunks
combined_results = []
for layer_idx in range(len(all_results[0])):  # For each layer
    layer_combined = []
    for chunk_result in all_results:
        # Skip if chunk_result[layer_idx] is not iterable
        if not isinstance(chunk_result[layer_idx], (list, tuple)):
            continue
            
        # Move data to GPU if available 
        if torch.cuda.is_available():
            chunk_result = [
                (token.cuda() if torch.is_tensor(token) else token,
                 expert.cuda() if torch.is_tensor(expert) else expert,
                 prob.cuda() if torch.is_tensor(prob) else prob)
                for token, expert, prob in chunk_result[layer_idx]
            ]
        layer_combined.extend(chunk_result)
    combined_results.append(layer_combined)

# Analyze and plot routing for all layers
for layer_to_plot in range(num_layers):
    print(f"\nRouting for first 5 tokens in layer {layer_to_plot}: ")
    layer_results = combined_results[layer_to_plot]
    
    # Skip if no valid results for this layer
    if not layer_results:
        print(f"No valid results for layer {layer_to_plot}")
        continue
        
    for token_info in layer_results[:5]:  # Limit to first 5 tokens
        # Skip if token_info is not a tuple/list
        if not isinstance(token_info, (tuple, list)):
            continue
            
        token, expert, prob = token_info
        # Move to CPU for printing
        if torch.cuda.is_available():
            prob = prob.cpu()
        print(f"Token: {token}, Expert: {expert}, Probability: {prob:.3f}")

    # Plot expert distribution for all processed data
    with torch.cuda.amp.autocast():  # Enable automatic mixed precision
        fig = plot_expert_distribution(layer_idx=layer_to_plot, domain=domain)
    fig.show()

    # Save plot as HTML and image
    fig.write_html(f'plots/physics/{domain}_layer{layer_to_plot}_expert_dist.html')
    fig.write_image(f'plots/physics/{domain}_layer{layer_to_plot}_expert_dist.png')
    
    # Clear GPU cache after each layer
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Routing for first 5 tokens in layer 0: 



`torch.cuda.amp.autocast(args...)` is deprecated. Please use `torch.amp.autocast('cuda', args...)` instead.



Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {57: 131, 8: 3488, 18: 1585, 63: 1033, 15: 4003, 6: 774, 27: 4550, 41: 8372, 24: 1891, 0: 54317, 46: 5526, 30: 2051, 17: 1130, 22: 7610, 31: 3634, 59: 4545, 38: 10507, 42: 1327, 1: 1566, 60: 2466, 44: 3464, 55: 3247, 47: 2041, 2: 3758, 29: 1306, 36: 11489, 50: 4983, 61: 858, 13: 2271, 37: 2719, 5: 2662, 40: 1368, 14: 1058, 19: 5674, 52: 2595, 16: 1602, 35: 1166, 33: 762, 26: 3113, 25: 3134, 49: 1915, 56: 835, 48: 1633, 62: 1593, 32: 882, 10: 6709, 54: 1259, 3: 991, 51: 1754, 34: 818, 53: 1476, 45: 1149, 9: 760, 4: 1379, 20: 21, 7: 113, 11: 183, 39: 119, 28: 35, 43: 98, 12: 31, 58: 20, 23: 2, 21: 5})
Total experts: 64
Expert count for l0 54317



Routing for first 5 tokens in layer 1: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {18: 1605, 60: 1213, 11: 1706, 7: 10213, 61: 6854, 28: 915, 24: 556, 47: 80866, 30: 2196, 12: 5301, 35: 299, 46: 1168, 44: 1309, 0: 7054, 9: 8041, 16: 12200, 2: 760, 23: 5157, 21: 1791, 27: 6884, 6: 331, 48: 2057, 36: 1994, 22: 356, 53: 3990, 51: 1014, 4: 618, 20: 1645, 8: 775, 50: 1448, 42: 961, 32: 1665, 34: 904, 49: 829, 45: 552, 26: 495, 55: 1096, 19: 644, 29: 2706, 10: 274, 17: 4825, 58: 2796, 59: 1559, 52: 652, 39: 2740, 31: 1333, 54: 1638, 1: 1447, 40: 1633, 3: 681, 62: 419, 57: 430, 33: 300, 37: 622, 14: 335, 56: 336, 43: 215, 63: 427, 13: 278, 25: 15, 38: 146, 15: 102, 41: 68, 5: 117})
Total experts: 64
Expert count for l0 7054



Routing for first 5 tokens in layer 2: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {30: 4019, 14: 2804, 11: 599, 7: 3499, 5: 84, 32: 5239, 28: 2882, 15: 15081, 26: 18188, 27: 5549, 36: 26576, 62: 548, 63: 20833, 35: 7761, 34: 3762, 29: 5746, 44: 2293, 47: 243, 19: 2101, 9: 773, 50: 6658, 25: 509, 20: 3606, 33: 6960, 3: 1631, 23: 6814, 22: 4521, 48: 1321, 57: 587, 0: 1777, 1: 2153, 52: 1317, 21: 2597, 61: 5590, 45: 2245, 24: 1616, 17: 1430, 46: 1439, 41: 778, 54: 5104, 16: 5243, 18: 981, 39: 2461, 6: 727, 53: 73, 31: 1059, 12: 474, 49: 192, 40: 42, 13: 197, 38: 531, 60: 1159, 43: 701, 56: 338, 42: 513, 2: 180, 55: 324, 51: 429, 8: 323, 10: 58, 37: 182, 59: 68, 58: 68})
Total experts: 63
Expert count for l0 1777



Routing for first 5 tokens in layer 3: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {39: 222, 42: 7026, 61: 3414, 63: 5189, 35: 79156, 26: 1043, 14: 42, 15: 233, 0: 1443, 4: 374, 38: 6349, 12: 4431, 40: 8341, 48: 459, 37: 4510, 51: 14653, 3: 96, 45: 4712, 28: 1442, 20: 3182, 50: 1763, 30: 7242, 31: 818, 49: 3202, 57: 4186, 54: 2909, 56: 3212, 52: 1682, 29: 387, 44: 1199, 9: 4356, 11: 1112, 27: 1305, 32: 3868, 7: 3665, 5: 1359, 10: 224, 19: 1555, 60: 238, 18: 4277, 47: 527, 36: 1437, 46: 927, 59: 784, 34: 479, 8: 994, 62: 251, 17: 155, 41: 1407, 2: 396, 16: 201, 21: 78, 25: 61, 33: 64, 24: 151, 13: 259, 43: 23, 6: 53, 1: 279, 58: 122, 22: 4, 53: 16, 23: 9, 55: 3})
Total experts: 64
Expert count for l0 1443



Routing for first 5 tokens in layer 4: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {27: 2342, 6: 3133, 17: 85073, 25: 4098, 3: 2631, 35: 2845, 46: 20523, 34: 1744, 14: 6395, 16: 13419, 26: 5645, 21: 4621, 15: 908, 52: 2516, 59: 5387, 56: 973, 23: 2550, 5: 2384, 44: 1517, 7: 4694, 12: 178, 33: 1591, 24: 1432, 39: 1160, 42: 527, 58: 1116, 30: 694, 20: 382, 8: 952, 57: 1436, 43: 962, 55: 102, 13: 867, 38: 1142, 22: 502, 2: 265, 0: 463, 32: 2406, 61: 547, 41: 677, 28: 341, 4: 798, 49: 2227, 40: 117, 50: 641, 53: 58, 62: 811, 60: 215, 51: 694, 37: 197, 10: 270, 29: 5341, 47: 112, 31: 495, 9: 206, 1: 92, 45: 64, 36: 21, 19: 23, 48: 11, 54: 2, 18: 9, 63: 5, 11: 7})
Total experts: 64
Expert count for l0 463



Routing for first 5 tokens in layer 5: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {31: 13987, 51: 3841, 27: 8388, 53: 4840, 35: 11793, 57: 10658, 38: 34941, 50: 1243, 47: 556, 13: 3186, 1: 3625, 63: 1174, 54: 10168, 23: 79, 25: 12536, 18: 2834, 11: 4273, 19: 3020, 22: 6767, 34: 2223, 41: 645, 49: 485, 58: 2773, 30: 1478, 9: 2308, 20: 500, 10: 3923, 62: 1267, 6: 1036, 48: 1693, 52: 4113, 7: 1959, 0: 22481, 55: 1729, 61: 1182, 5: 1743, 39: 268, 14: 2401, 8: 1342, 17: 2168, 21: 849, 36: 1460, 24: 1207, 29: 346, 4: 444, 16: 72, 28: 169, 15: 628, 32: 168, 43: 458, 3: 44, 44: 620, 12: 212, 26: 8, 56: 809, 59: 128, 42: 13, 46: 19, 33: 41, 2: 21, 37: 37, 40: 135, 60: 31, 45: 11})
Total experts: 64
Expert count for l0 22481



Routing for first 5 tokens in layer 6: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {49: 360, 9: 1315, 18: 19195, 1: 6702, 22: 6350, 20: 3632, 40: 14515, 58: 1849, 4: 14063, 57: 44521, 29: 1587, 55: 669, 33: 10150, 50: 6650, 48: 564, 5: 24543, 13: 2386, 54: 666, 62: 11299, 30: 1521, 44: 366, 24: 3941, 63: 6042, 12: 2203, 31: 37, 35: 549, 7: 851, 37: 283, 39: 84, 59: 977, 8: 2441, 17: 1586, 16: 1029, 2: 1446, 25: 298, 36: 978, 47: 487, 61: 1268, 26: 940, 6: 674, 42: 84, 14: 85, 60: 878, 32: 552, 53: 669, 3: 465, 46: 292, 56: 88, 51: 43, 45: 120, 34: 514, 19: 122, 43: 116, 11: 66, 28: 129, 52: 9, 41: 25, 0: 112, 21: 146, 23: 10, 38: 6, 15: 5, 10: 3})
Total experts: 63
Expert count for l0 112



Routing for first 5 tokens in layer 7: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {21: 2356, 39: 661, 0: 4803, 12: 6868, 33: 3330, 45: 3143, 36: 13623, 59: 15686, 24: 11486, 58: 14781, 63: 9369, 26: 11841, 48: 5028, 38: 4379, 61: 8296, 22: 6199, 52: 6092, 18: 205, 35: 5532, 23: 2270, 57: 3148, 32: 6160, 56: 6353, 46: 2288, 5: 1263, 15: 5117, 29: 4129, 11: 1659, 40: 743, 42: 1720, 20: 2101, 6: 165, 4: 5793, 8: 139, 19: 2981, 49: 1960, 34: 5162, 41: 1334, 16: 4037, 31: 1303, 17: 1448, 10: 1683, 30: 614, 14: 1914, 60: 717, 37: 449, 53: 205, 51: 91, 43: 489, 9: 285, 50: 598, 13: 71, 62: 40, 2: 37, 25: 191, 1: 201, 47: 255, 27: 432, 54: 147, 55: 75, 44: 68, 7: 38, 28: 3, 3: 2})
Total experts: 64
Expert count for l0 4803



Routing for first 5 tokens in layer 8: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {22: 1284, 32: 3196, 49: 2938, 16: 7052, 44: 7889, 52: 1933, 60: 498, 62: 1382, 14: 84207, 61: 12763, 18: 18598, 5: 621, 2: 1668, 6: 2030, 45: 1713, 24: 5336, 35: 944, 53: 5008, 20: 1538, 1: 2098, 37: 4119, 10: 1961, 50: 1405, 40: 1432, 27: 1937, 51: 2712, 48: 2144, 55: 1086, 63: 969, 54: 1414, 21: 4656, 57: 1643, 9: 1242, 11: 2336, 0: 695, 25: 1452, 17: 1195, 36: 1701, 38: 645, 31: 615, 47: 256, 28: 316, 30: 325, 42: 218, 41: 707, 12: 263, 4: 250, 33: 667, 34: 162, 46: 222, 26: 550, 43: 45, 15: 311, 29: 49, 23: 158, 56: 782, 13: 115, 3: 42, 7: 10, 19: 44, 59: 2, 8: 3, 58: 3, 39: 1})
Total experts: 64
Expert count for l0 695



Routing for first 5 tokens in layer 9: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {8: 12510, 34: 2016, 28: 23722, 57: 14648, 5: 64268, 20: 7967, 47: 5971, 36: 1861, 50: 1782, 14: 7154, 24: 96, 31: 1935, 7: 1929, 16: 2268, 38: 1547, 26: 756, 33: 3916, 55: 854, 1: 4925, 2: 1751, 23: 1791, 27: 3604, 43: 4121, 40: 675, 52: 2583, 30: 1501, 13: 1742, 54: 1472, 56: 150, 3: 134, 48: 2202, 49: 614, 22: 2883, 53: 659, 6: 1171, 63: 518, 42: 4405, 18: 482, 44: 1203, 11: 1814, 61: 468, 4: 74, 39: 272, 21: 584, 59: 973, 35: 558, 0: 218, 10: 377, 12: 133, 60: 829, 29: 61, 58: 183, 19: 420, 45: 361, 62: 147, 9: 1408, 32: 539, 37: 147, 46: 1, 51: 91, 41: 73, 25: 35, 17: 4})
Total experts: 63
Expert count for l0 218



Routing for first 5 tokens in layer 10: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {4: 1615, 19: 5064, 11: 7905, 59: 8869, 60: 15794, 52: 5655, 3: 11609, 31: 206, 34: 1805, 15: 3667, 26: 7198, 14: 7581, 29: 8182, 28: 18435, 24: 977, 62: 2418, 57: 1286, 49: 768, 58: 3558, 13: 9227, 50: 2754, 43: 9814, 30: 728, 36: 1518, 0: 1635, 37: 1005, 63: 157, 7: 4076, 23: 4769, 18: 6879, 32: 3947, 39: 4029, 12: 4901, 21: 2243, 5: 1366, 17: 3329, 42: 1880, 27: 1961, 1: 591, 38: 466, 53: 986, 47: 266, 44: 446, 55: 2081, 2: 2728, 9: 3603, 16: 4317, 56: 1289, 61: 833, 46: 3019, 45: 723, 35: 789, 41: 1483, 10: 283, 51: 61, 8: 358, 33: 41, 20: 10, 54: 235, 48: 76, 40: 16, 22: 17, 6: 28, 25: 1})
Total experts: 64
Expert count for l0 1635



Routing for first 5 tokens in layer 11: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {56: 486, 44: 1972, 14: 225, 62: 1914, 23: 24299, 29: 1998, 51: 2631, 18: 2748, 48: 6410, 5: 2739, 46: 5018, 54: 19344, 36: 2119, 34: 2248, 47: 30704, 32: 5679, 40: 375, 33: 5030, 26: 3659, 61: 4772, 17: 403, 12: 7751, 3: 4119, 19: 2215, 24: 5350, 11: 5269, 45: 7086, 58: 2543, 35: 1473, 9: 7823, 0: 1812, 57: 3625, 15: 832, 30: 3106, 41: 1963, 22: 2498, 4: 1924, 16: 1193, 20: 5229, 39: 2318, 8: 1825, 7: 568, 31: 2432, 60: 216, 42: 849, 13: 100, 63: 135, 49: 1344, 28: 186, 37: 36, 38: 125, 6: 205, 43: 600, 2: 273, 52: 23, 1: 697, 21: 162, 10: 47, 59: 407, 25: 25, 50: 307, 55: 81, 27: 8, 53: 3})
Total experts: 64
Expert count for l0 1812



Routing for first 5 tokens in layer 12: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {59: 3569, 6: 4554, 29: 78, 17: 3220, 10: 7381, 30: 2347, 55: 9671, 58: 559, 44: 8539, 43: 88103, 35: 8846, 57: 2473, 47: 7140, 53: 2015, 23: 2453, 3: 835, 15: 1208, 40: 513, 4: 3606, 60: 2872, 1: 765, 20: 4837, 48: 3791, 50: 5711, 9: 1957, 2: 1021, 26: 722, 62: 1606, 39: 2572, 28: 2310, 54: 1665, 8: 477, 14: 2382, 31: 1882, 49: 1332, 56: 918, 52: 548, 46: 275, 41: 854, 34: 1038, 13: 785, 0: 248, 21: 58, 18: 518, 12: 733, 24: 1308, 33: 112, 36: 710, 37: 144, 22: 43, 7: 663, 51: 86, 11: 248, 32: 379, 45: 225, 19: 255, 16: 53, 42: 187, 27: 61, 63: 37, 61: 24, 25: 9, 5: 15, 38: 10})
Total experts: 64
Expert count for l0 248



Routing for first 5 tokens in layer 13: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {2: 724, 9: 151, 5: 5665, 54: 1137, 62: 3137, 17: 7354, 55: 3346, 4: 2587, 20: 63226, 39: 6760, 57: 15622, 32: 9043, 0: 1795, 22: 8355, 41: 4588, 46: 955, 38: 6314, 61: 1962, 45: 889, 18: 5186, 19: 17779, 27: 2143, 43: 538, 15: 952, 58: 1392, 10: 5403, 23: 2302, 28: 1219, 30: 3653, 31: 2474, 8: 1563, 3: 2325, 50: 2816, 35: 1360, 11: 983, 6: 55, 51: 725, 49: 915, 24: 454, 48: 589, 36: 153, 1: 152, 14: 698, 34: 388, 26: 174, 44: 140, 40: 100, 12: 268, 42: 184, 29: 315, 47: 150, 33: 83, 7: 256, 59: 427, 25: 51, 37: 192, 56: 261, 63: 197, 60: 275, 13: 189, 53: 11, 21: 357, 16: 10, 52: 89})
Total experts: 64
Expert count for l0 1795



Routing for first 5 tokens in layer 14: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {60: 2728, 42: 2225, 12: 4453, 47: 4628, 24: 14860, 58: 2455, 9: 6070, 38: 22368, 46: 5441, 40: 3071, 63: 5771, 56: 5201, 44: 1340, 48: 13245, 20: 4953, 14: 1463, 5: 1875, 57: 4666, 49: 235, 11: 1919, 10: 1267, 26: 834, 32: 1290, 16: 4018, 15: 7445, 19: 7673, 37: 725, 1: 2943, 52: 4738, 17: 1111, 54: 2034, 35: 4520, 23: 4683, 45: 1515, 22: 2444, 7: 1226, 39: 1712, 33: 28080, 61: 1085, 18: 3157, 6: 2956, 62: 2031, 28: 330, 31: 745, 41: 978, 36: 319, 34: 108, 53: 594, 29: 544, 13: 572, 27: 378, 2: 70, 50: 82, 21: 194, 59: 12, 43: 833, 30: 704, 3: 162, 0: 19, 25: 79, 55: 133, 8: 202, 51: 38, 4: 6})
Total experts: 64
Expert count for l0 19



Routing for first 5 tokens in layer 15: 
Total assignments: 203556
Expert counts: defaultdict(<class 'int'>, {8: 29, 24: 8984, 34: 106019, 43: 4433, 9: 1443, 1: 5925, 53: 1608, 17: 29616, 31: 2492, 5: 1213, 40: 136, 41: 315, 52: 4778, 54: 5120, 29: 4579, 63: 1068, 55: 2007, 57: 4182, 3: 758, 2: 646, 47: 284, 62: 1222, 44: 337, 35: 1681, 36: 547, 4: 1736, 18: 528, 6: 1851, 22: 696, 20: 1589, 15: 366, 13: 293, 19: 388, 61: 1508, 23: 725, 28: 112, 33: 227, 11: 285, 0: 255, 12: 34, 50: 15, 48: 295, 42: 207, 37: 291, 38: 357, 21: 155, 58: 137, 51: 947, 60: 86, 16: 101, 39: 97, 30: 35, 56: 39, 45: 162, 10: 17, 7: 143, 26: 222, 46: 47, 25: 23, 32: 17, 27: 57, 59: 72, 49: 12, 14: 7})
Total experts: 64
Expert count for l0 255
