In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
import json
import numpy as np
import matplotlib.pyplot as plt
import os
from collections import defaultdict
import gc

In [2]:
def setup_device():
    """Pick the torch device to run on.

    Prefers CUDA: enables TF32 matmul/cuDNN kernels and prints the GPU name and
    its total memory. Without CUDA it prints a notice and falls back to the CPU.

    Returns:
        torch.device: ``cuda`` when available, otherwise ``cpu``.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU")
        return torch.device("cpu")

    chosen = torch.device("cuda")
    # TF32 trades a little matmul precision for speed on Ampere+ GPUs.
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    print(f"Using CUDA device: {torch.cuda.get_device_name()}")
    total_gb = torch.cuda.get_device_properties(chosen).total_memory / 1e9
    print(f"GPU Memory available: {total_gb:.2f} GB")
    return chosen

# Pick the compute device once; it is passed to load_model below.
device = setup_device()

def clear_gpu_memory():
    """Release cached CUDA memory (when a GPU is present), then run Python's garbage collector."""
    has_gpu = torch.cuda.is_available()
    if has_gpu:
        # Drop the allocator's unused cached blocks, then wait for queued GPU work.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    gc.collect()

CUDA not available, using CPU


In [3]:
def load_model(model_name, device):
    """Load a causal LM and its tokenizer, ready for inference.

    On CUDA the weights are loaded in bfloat16 and placed by ``device_map="auto"``
    (``device`` is not used in that case). On CPU they are loaded in float32 and
    moved to ``device``. The model is switched to eval mode either way.

    Returns:
        (model, tokenizer)
    """
    on_gpu = torch.cuda.is_available()
    pretrained_kwargs = {
        "torch_dtype": torch.bfloat16 if on_gpu else torch.float32,
        "device_map": "auto" if on_gpu else None,
        # NOTE(review): trust_remote_code lets the checkpoint repo supply custom Python code.
        "trust_remote_code": True,
    }
    model = AutoModelForCausalLM.from_pretrained(model_name, **pretrained_kwargs).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    if on_gpu:
        return model, tokenizer
    return model.to(device), tokenizer

# Load the OLMoE-1B-7B checkpoint and its tokenizer. On CUDA, weight placement is
# handled by device_map="auto" inside load_model; on CPU the model is moved to `device`.
model, tokenizer = load_model("allenai/OLMoE-1B-7B-0924", device)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [4]:
def get_moe_metadata(model, input_ids):
    """
    Capture router logits and selected expert indices from every OLMoE MoE layer.

    A forward hook is attached to the ``gate`` of each ``OlmoeSparseMoeBlock`` in
    ``model.model.layers``. One forward pass on ``input_ids`` runs under
    ``torch.no_grad()``. The hooks are always removed afterwards, even when the
    forward pass raises (e.g. out of memory).

    Args:
        model: OLMoE causal LM; ``model.config.num_experts_per_tok`` sets the top-k.
        input_ids: token id tensor; it is moved to ``model.device`` before the pass.

    Returns:
        (router_logits, expert_indices): per-layer tensors stacked along a new
        leading layer axis, or ``(None, None)`` if no MoE layer was found. For a
        single prompt the notebook observed ``[num_moe_layers, seq_len, num_experts]``
        and ``[num_moe_layers, seq_len, num_experts_per_tok]``.
    """
    router_logits_list = []
    expert_indices_list = []
    top_k = model.config.num_experts_per_tok

    def hook_fn(module, inputs, output):
        # Prefer the gate's own output: these are exactly the logits the model routes on.
        # Fall back to recomputing them from the gate weight if the output is not a tensor.
        if torch.is_tensor(output):
            logits = output
        else:
            logits = torch.matmul(inputs[0], module.weight.T)
        # Softmax in float32 so close bf16 logits do not collapse into ties that topk
        # could break differently from the model. NOTE(review): intended to mirror the
        # routing softmax of HF's OlmoeSparseMoeBlock; verify against the installed version.
        probs = torch.nn.functional.softmax(logits, dim=-1, dtype=torch.float)
        _, indices = torch.topk(probs, top_k, dim=-1)

        router_logits_list.append(logits.detach())
        expert_indices_list.append(indices.detach())
        return output

    moe_layers = [layer for layer in model.model.layers
                  if layer.mlp.__class__.__name__ == 'OlmoeSparseMoeBlock']
    hooks = [layer.mlp.gate.register_forward_hook(hook_fn) for layer in moe_layers]

    try:
        # Move input_ids to same device as model
        input_ids = input_ids.to(model.device)
        with torch.no_grad():
            model(input_ids)
    finally:
        # Without this, a failed forward pass would leave the hooks attached and every
        # later forward pass would keep running them.
        for hook in hooks:
            hook.remove()

    router_logits = torch.stack(router_logits_list) if router_logits_list else None
    expert_indices = torch.stack(expert_indices_list) if expert_indices_list else None

    return router_logits, expert_indices

In [5]:
# Named prompt_text (not `input`) to avoid shadowing the builtin input().
prompt_text = "the quick brown fox"
input_ids = tokenizer.encode(prompt_text, return_tensors="pt")
print(f"input_ids shape: {input_ids.shape}")

# Get MoE metadata
router_logits, expert_indices = get_moe_metadata(model, input_ids)

print(f"router_logits shape: {router_logits.shape}")
print(f"expert_indices shape: {expert_indices.shape}")

# Both outputs are stacked over MoE layers. For this 1x4 input the printed shapes are
# [16, 4, 64] and [16, 4, 8], so no separate batch axis appears per layer:
#   router_logits[layer]  -> [num_tokens, num_experts]
#   expert_indices[layer] -> [num_tokens, num_experts_per_tok]
# (presumably batch and sequence are flattened inside the MoE block; confirm for batch > 1)

input_ids shape: torch.Size([1, 4])
router_logits shape: torch.Size([16, 4, 64])
expert_indices shape: torch.Size([16, 4, 8])


In [6]:
def get_last_token_router_probs(router_logits, model_layer_idx):
    """
    Softmax router probabilities of the last token for one MoE layer.

    Args:
        router_logits: tensor indexed as ``router_logits[layer][token]``, e.g. the
            ``[num_moe_layers, seq_len, num_experts]`` output of ``get_moe_metadata``
            for a single prompt. It may live on any device.
        model_layer_idx: 0-based MoE layer index, ``0 <= idx < router_logits.size(0)``.

    Returns:
        Probabilities over experts for the last token, on ``router_logits``' device.

    Raises:
        ValueError: if ``model_layer_idx`` is out of range.
    """
    num_layers = router_logits.size(0)
    if model_layer_idx < 0 or model_layer_idx >= num_layers:
        raise ValueError(
            f"Invalid model_layer_idx {model_layer_idx}. Must be 0-{num_layers - 1} for MoE layers"
        )

    # Stay on the caller's device: forcing .cuda() breaks CPU-only runs.
    layer_logits = router_logits[model_layer_idx]  # [seq_len, num_experts]
    last_token_logits = layer_logits[-1]  # [num_experts]
    return torch.nn.functional.softmax(last_token_logits, dim=-1)

def topk(router_probs, k):
    """
    Zero out all components except the top ``k`` router probabilities.

    Works along the last dimension, so it accepts a single probability vector or a
    batch of them. The result has the same shape, dtype and device as ``router_probs``.
    """
    values, indices = torch.topk(router_probs, k, dim=-1)
    # scatter along the last axis, so batched inputs work too (plain fancy indexing only handles 1-D)
    return torch.zeros_like(router_probs).scatter(-1, indices, values)

In [7]:
# x = get_last_token_router_probs(router_logits, 3)
# print(f"x: {x}")
# y = topk(x, 8)
# print(f"y: {y}")

In [15]:
def get_moe_data(model, tokenizer, prompts):
    """
    Collect all-token and last-token MoE routing data with one forward pass per prompt.

    Each prompt is tokenized on its own *without padding* (truncated at 4096 tokens)
    and run through ``get_moe_metadata``. Skipping padding has two benefits:
      * prompts are not all run at the longest prompt's length;
      * results do not depend on ``tokenizer.padding_side`` (no attention mask is
        passed to the model).

    Args:
        model: OLMoE model
        tokenizer: OLMoE tokenizer
        prompts: list of prompt strings, or list of lists of prompt strings
            (e.g. one list per domain); nested lists are flattened in order.

    Returns:
        all_token_logits:   float16 [num_prompts, num_moe_layers, max_seq_len, num_experts];
                            0 past each prompt's length.
        all_token_experts:  long [num_prompts, num_moe_layers, max_seq_len, num_experts_per_tok];
                            -1 past each prompt's length.
        last_token_logits:  float16 [num_prompts, num_moe_layers, num_experts]
        last_token_experts: long [num_prompts, num_moe_layers, num_experts_per_tok]
        All four are allocated on CUDA when available, otherwise on the CPU.

    Raises:
        ValueError: if ``prompts`` contains no prompts.
    """
    from tqdm import tqdm

    if not prompts:
        raise ValueError("prompts must contain at least one prompt")
    # Flatten prompts if it's a list of lists
    if isinstance(prompts[0], list):
        prompts = [prompt for domain_prompts in prompts for prompt in domain_prompts]
    num_prompts = len(prompts)
    if num_prompts == 0:
        raise ValueError("prompts must contain at least one prompt")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.empty_cache()
    storage_device = torch.device('cuda' if use_cuda else 'cpu')

    # Tokenize every prompt once, unpadded. The 4096 cap mirrors the model's max length.
    max_len_cap = 4096
    encoded = [
        tokenizer(prompt, return_tensors="pt", truncation=True, max_length=max_len_cap)['input_ids']
        for prompt in prompts
    ]
    max_seq_len = max(ids.shape[1] for ids in encoded)

    num_moe_layers = sum(1 for layer in model.model.layers if layer.mlp.__class__.__name__ == 'OlmoeSparseMoeBlock')
    print(f"Number of MoE layers: {num_moe_layers}")

    num_experts = model.config.num_experts  # 64 for OLMoE
    num_experts_per_tok = model.config.num_experts_per_tok  # 8 for OLMoE

    all_token_logits = torch.zeros((num_prompts, num_moe_layers, max_seq_len, num_experts),
                                   dtype=torch.float16, device=storage_device)
    all_token_experts = torch.full((num_prompts, num_moe_layers, max_seq_len, num_experts_per_tok),
                                   -1, dtype=torch.long, device=storage_device)
    last_token_logits = torch.zeros((num_prompts, num_moe_layers, num_experts),
                                    dtype=torch.float16, device=storage_device)
    last_token_experts = torch.zeros((num_prompts, num_moe_layers, num_experts_per_tok),
                                     dtype=torch.long, device=storage_device)

    for prompt_idx, input_ids in enumerate(tqdm(encoded, desc="Processing prompts")):
        seq_len = input_ids.shape[1]
        if seq_len == 0 or num_moe_layers == 0:
            # Nothing to route (e.g. an empty code block); leave this row at its defaults.
            continue

        with torch.no_grad():
            router_logits, expert_indices = get_moe_metadata(model, input_ids)

        # One row per real token per layer; reshape also drops a size-1 batch axis
        # should the gate ever see [1, seq_len, hidden] input.
        layer_logits = router_logits.reshape(num_moe_layers, seq_len, num_experts)
        layer_experts = expert_indices.reshape(num_moe_layers, seq_len, num_experts_per_tok)

        # Assignment copies across device/dtype (e.g. bf16 on GPU -> float16 storage).
        all_token_logits[prompt_idx, :, :seq_len] = layer_logits
        all_token_experts[prompt_idx, :, :seq_len] = layer_experts
        last_token_logits[prompt_idx] = layer_logits[:, -1]
        last_token_experts[prompt_idx] = layer_experts[:, -1]

        # Clear intermediate tensors
        del router_logits, expert_indices, layer_logits, layer_experts
        if use_cuda:
            torch.cuda.empty_cache()

    return all_token_logits, all_token_experts, last_token_logits, last_token_experts

In [16]:
# test_prompts = ['the quick brown fox', 'the capital of japan is tokyo', 'the capital of france is paris', ]
# all_token_logits, all_token_experts, last_token_logits, last_token_experts = get_moe_data(model, tokenizer, prompts = test_prompts)

# print(f"all_token_logits shape: {all_token_logits.shape}") # [num_prompts, num_layers, seq_len, num_experts]
# print(f"all_token_experts shape: {all_token_experts.shape}")
# print(f"last_token_logits shape: {last_token_logits.shape}")
# print(f"last_token_experts shape: {last_token_experts.shape}")

In [17]:
def prepare_prompts_from_txt(txt_file_path, domain='english', output_path='english.json'):
    """Read one prompt per non-blank line of a text file and save them as JSON.

    Leading/trailing whitespace is stripped from each line and blank lines are
    dropped. The JSON written to ``output_path`` has the form ``{domain: [prompts]}``.

    Returns:
        list[str]: the prompts, in file order.
    """
    with open(txt_file_path, 'r', encoding='utf-8') as src:
        prompts = []
        for raw_line in src:
            text = raw_line.strip()
            if text:
                prompts.append(text)

    payload = {f"{domain}": prompts}
    with open(output_path, 'w', encoding='utf-8') as dst:
        json.dump(payload, dst, indent=4)

    print(f"{domain} prompts saved to {output_path}")
    return prompts

def parse_code_blocks(txt_file_path, output_path='code.json', domain='code'):
    """Split a text file into code blocks on ``` fence lines and save them as JSON.

    Every line whose stripped text starts with ``` ends the block being collected
    (if it has any lines) and starts a new one. Text before the first fence is
    ignored; text after the last fence becomes the final block. Lines are
    right-stripped. The JSON written to ``output_path`` is ``{domain: [blocks]}``.

    Returns:
        list[str]: the collected blocks, newline-joined.
    """
    collected = []
    buffer = []
    seen_fence = False

    with open(txt_file_path, 'r', encoding='utf-8') as handle:
        for raw in handle:
            if raw.strip().startswith('```'):
                # A fence closes whatever has been buffered and opens a fresh block.
                if seen_fence and buffer:
                    collected.append('\n'.join(buffer))
                buffer = []
                seen_fence = True
                continue
            if seen_fence:
                buffer.append(raw.rstrip())

    # Whatever follows the last fence is a block too.
    if buffer:
        collected.append('\n'.join(buffer))

    with open(output_path, 'w', encoding='utf-8') as out:
        json.dump({domain: collected}, out, indent=4)

    print(f"code blocks saved to {output_path}")
    return collected

In [18]:
def load_prompts_from_json(json_file_path):
    """Return the prompt list stored in a ``{"domain": [prompts]}`` JSON file.

    Only the first value of the top-level object is returned; the key is ignored.
    """
    with open(json_file_path, 'r', encoding='utf-8') as fh:
        payload = json.load(fh)
    return [*payload.values()][0]

def prepare_multi_domain_prompts(domain_files, output_path='all_domain_prompts.json'):
    """
    Merge prompts from several domains, write them to one JSON file and return them.

    Args:
        domain_files: dict mapping a domain name to a list of ``(file_path, source)``
            tuples. ``source`` may be
              * a list of prompts, used as-is (``file_path`` is then informational);
              * a callable, called as ``source(file_path)``, which must return a
                list of prompts;
              * anything else, treated as the path of a ``{"domain": [prompts]}``
                JSON file and read with ``load_prompts_from_json``.
        output_path: where the merged ``{domain: [prompts]}`` JSON is written.

    Returns:
        dict mapping each domain to its concatenated prompts (insertion order kept).

    Example:
        domain_files = {
            'code': [('code.txt', parse_code_blocks('code.txt', 'code.json', domain='code'))],
            'english': [('english.txt',
                         lambda p: prepare_prompts_from_txt(p, 'english', 'english.json'))],
        }
    """
    all_prompts = {}

    for domain, file_list in domain_files.items():
        domain_prompts = []
        for file_path, source in file_list:
            if isinstance(source, list):
                prompts = source
            elif callable(source):
                # Parser function, as promised by the (file_path, parser_func) contract.
                prompts = source(file_path)
            else:
                # Otherwise treat `source` as the path of a previously written JSON file.
                prompts = load_prompts_from_json(source)
            domain_prompts.extend(prompts)

        all_prompts[domain] = domain_prompts

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(all_prompts, f, indent=4)

    print(f"all domain prompts saved to {output_path}")
    return all_prompts

def convert_all_to_list(all_prompts):
    """
    Turn a ``{domain: [prompts]}`` mapping into ``[[domain1 prompts], [domain2 prompts], ...]``.

    Domain order follows the dict's insertion order, and the inner lists are the
    dict's own list objects (not copies). Also prints the total prompt count and
    a per-domain breakdown.
    """
    per_domain = list(all_prompts.values())

    print(f"total prompts: {sum(map(len, per_domain))}")
    print(f"prompts per domain:")
    for name, bucket in zip(all_prompts.keys(), per_domain):
        print(f"  {name}: {len(bucket)}")

    return per_domain

In [19]:
# Stage 1: write one JSON file per domain into olmoe-interp-pt/.
for _src_txt, _domain, _dst_json in (
    ('interp-data/engl-lit.txt', 'english', 'olmoe-interp-pt/english.json'),
    ('interp-data/french.txt', 'french', 'olmoe-interp-pt/french.json'),
):
    prepare_prompts_from_txt(_src_txt, domain=_domain, output_path=_dst_json)
parse_code_blocks('interp-data/code.txt', 'olmoe-interp-pt/code.json', domain='code')

# Sanity-check the counts of what was just written.
code_prompts = load_prompts_from_json(json_file_path='olmoe-interp-pt/code.json')
print(f"total code prompts : {len(code_prompts)}")
english_prompts = load_prompts_from_json(json_file_path='olmoe-interp-pt/english.json')
print(f'total english prompts : {len(english_prompts)}')

# Stage 2: parse the same .txt sources again into interp-data/ and merge all domains.
# NOTE(review): this repeats the parsing from stage 1; only the output folder differs.
domain_files = {
    'code': [('interp-data/code.txt', parse_code_blocks(txt_file_path='interp-data/code.txt', output_path='interp-data/code.json', domain='code'))],
}
for _domain, _src_txt in (('english', 'interp-data/engl-lit.txt'), ('french', 'interp-data/french.txt')):
    domain_files[_domain] = [(_src_txt, prepare_prompts_from_txt(txt_file_path=_src_txt, output_path=f'interp-data/{_domain}.json', domain=_domain))]
del _src_txt, _domain, _dst_json

all_prompts = prepare_multi_domain_prompts(domain_files, output_path='interp-data/all_prompts.json')
print(f"total domains : {len(all_prompts)}")

# Convert {domain: prompts} into [[code], [english], [french]] for get_moe_data.
combined_prompts = convert_all_to_list(all_prompts)

english prompts saved to olmoe-interp-pt/english.json
french prompts saved to olmoe-interp-pt/french.json
code blocks saved to olmoe-interp-pt/code.json
total code prompts : 200
total english prompts : 200
code blocks saved to interp-data/code.json
english prompts saved to interp-data/english.json
french prompts saved to interp-data/french.json
all domain prompts saved to interp-data/all_prompts.json
total domains : 3
total prompts: 600
prompts per domain:
  code: 200
  english: 200
  french: 200


In [20]:
# prepare_prompts_from_txt('interp-data/test.txt', domain='test', output_path='interp-data/test.json')

# test_prompts = load_prompts_from_json(json_file_path='interp-data/test.json')
# print(f"total test prompts : {len(test_prompts)}")

# domain_files = {
#     'test': [('interp-data/test.txt', prepare_prompts_from_txt(txt_file_path='interp-data/test.txt', output_path='interp-data/test.json', domain='test'))]
# }
# all_prompts = prepare_multi_domain_prompts(domain_files, output_path='interp-data/all_prompts.json')
# print(f"total domains : {len(all_prompts)}")

# # convert all prompts to a single list of domain-specific prompt lists
# combined_prompts = convert_all_to_list(all_prompts)

In [21]:
# Full extraction over every prompt in combined_prompts (code, english, french, in that order).
# Slow: the recorded run processed 600 prompts in about 18.5 minutes.
all_token_logits, all_token_experts, last_token_logits, last_token_experts = get_moe_data(model, tokenizer, prompts = combined_prompts)

print(f'all_token_logits shape: {all_token_logits.shape}')

Number of MoE layers: 16


Processing prompts: 100%|██████████| 600/600 [18:34<00:00,  1.86s/it]

all_token_logits shape: torch.Size([600, 16, 291, 64])





In [22]:
# Persist the freshly extracted OLMoE routing tensors, one .pt file per tensor.
for _name, _tensor in (
    ("all_token_logits", all_token_logits),
    ("all_token_experts", all_token_experts),
    ("last_token_logits", last_token_logits),
    ("last_token_experts", last_token_experts),
):
    torch.save(_tensor, f"olmoe-interp-pt/{_name}.pt")
del _name, _tensor

In [23]:
# NOTE(review): this reads interp-pt/, not olmoe-interp-pt/ where the save cell wrote.
# The shapes printed next ([600, 27, 284, 64]; 6 experts per token) differ from the
# OLMoE extraction ([600, 16, 291, 64]; 8 experts per token), so this looks like data
# from a different model. The plotting cells below assume these 27-layer / 6-per-token shapes.
def _load_cpu(name):
    """Load interp-pt/<name>.pt onto the CPU with weights_only=True."""
    return torch.load(f"interp-pt/{name}.pt", map_location=torch.device('cpu'), weights_only=True)

all_token_logits = _load_cpu("all_token_logits")
all_token_experts = _load_cpu("all_token_experts")
last_token_logits = _load_cpu("last_token_logits")
last_token_experts = _load_cpu("last_token_experts")
del _load_cpu

In [24]:
for _label, _tensor in (
    ('all_token_logits', all_token_logits),
    ('all_token_experts', all_token_experts),
    ('last_token_logits', last_token_logits),
    ('last_token_experts', last_token_experts),
):
    print(f'{_label} shape: {_tensor.shape}')
del _label, _tensor

all_token_logits shape: torch.Size([600, 27, 284, 64])
all_token_experts shape: torch.Size([600, 27, 284, 6])
last_token_logits shape: torch.Size([600, 27, 64])
last_token_experts shape: torch.Size([600, 27, 6])


In [25]:
def bar_graph_visualize(last_token_experts, layer_number, domain, k=8, num_experts=64):
    """
    Plotly bar chart of how often each expert is chosen for the last token of a
    domain's prompts, at one layer.

    Args:
        last_token_experts: integer tensor [num_prompts, num_layers, experts_per_token]
            of selected expert ids, on CPU or CUDA.
        layer_number: layer to analyze (1-indexed).
        domain: 'code', 'english' or 'french'. Prompts are assumed to be laid out
            as 200 code, then 200 english, then 200 french.
        k: unused; kept for backward compatibility (the per-token expert count comes
            from the tensor itself).
        num_experts: number of routed experts shown on the x-axis (default 64).

    Raises:
        ValueError: for an unknown domain or an out-of-range layer_number.
    """
    # Prompt index ranges per domain (adjust if the domain sizes change).
    domain_slices = {
        'code': slice(0, 200),
        'english': slice(200, 400),
        'french': slice(400, 600)
    }
    if domain not in domain_slices:
        raise ValueError("Invalid domain")

    layer_idx = layer_number - 1
    num_layers = last_token_experts.size(1)
    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Layer number must be between 1 and {num_layers}")

    # .cpu() so tensors still on the GPU (get_moe_data allocates on CUDA when available) work.
    domain_experts = last_token_experts[domain_slices[domain], layer_idx, :].detach().cpu().numpy()

    # Count each expert at most once per token.
    expert_counts = np.zeros(num_experts)
    for token_experts in domain_experts:
        expert_counts[np.unique(token_experts)] += 1

    total_tokens = domain_experts.shape[0]
    if total_tokens:
        percentages = (expert_counts / total_tokens) * 100
    else:
        percentages = np.zeros(num_experts)

    expert_ids = list(range(num_experts))
    fig = go.Figure(data=[
        go.Bar(
            x=expert_ids,
            y=percentages,
            marker_color=['darkblue' if p > 0 else 'lightgray' for p in percentages]
        )
    ])

    fig.update_layout(
        title=f'Expert Usage Distribution - Layer {layer_number} ({domain})',
        xaxis_title='Expert ID',
        yaxis_title=f'% of {domain.capitalize()} Domain Tokens',
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=4
        ),
        yaxis=dict(range=[0, 100]),
        showlegend=False,
        width=1000,
        height=500
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

    fig.show()

In [28]:
# Layer-15 last-token expert usage for the code domain.
bar_graph_visualize(last_token_experts, domain='code', layer_number=15)

In [29]:
# Layer-15 last-token expert usage for the French domain.
bar_graph_visualize(last_token_experts, domain='french', layer_number=15)

In [30]:
# Layer-15 last-token expert usage for the English domain.
bar_graph_visualize(last_token_experts, domain='english', layer_number=15)

In [35]:
def stacked_bar_graph(last_token_experts, layer_number, num_experts=64):
    """
    Plotly stacked bar chart of last-token expert selection, split by domain.

    Args:
        last_token_experts: integer tensor [num_prompts, num_layers, experts_per_token]
            (e.g. [600, 27, 6]), on CPU or CUDA.
        layer_number: layer to analyze (1-indexed).
        num_experts: number of routed experts shown on the x-axis (default 64).

    Prompts are split as code = [0, 200), english = [200, 400), french = [400, 600).
    Bar heights are percentages of *all* prompts in the tensor, so the three
    segments of a bar add up to the overall share of prompts routed to that expert.

    Raises:
        ValueError: if layer_number is out of range.
    """
    layer_idx = layer_number - 1
    num_layers = last_token_experts.size(1)
    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Layer number must be between 1 and {num_layers}.")

    # [num_prompts, experts_per_token]; .cpu() so CUDA tensors are accepted too.
    layer_data = last_token_experts[:, layer_idx, :].detach().cpu().numpy()

    domains = [
        ('Code', layer_data[:200], '#2c7bb6'),
        ('English', layer_data[200:400], '#d7191c'),
        ('French', layer_data[400:600], '#fdae61'),
    ]
    # Denominator is every prompt in the tensor (600 in this notebook), not the per-domain size.
    total_tokens = layer_data.shape[0]

    expert_ids = list(range(num_experts))
    bars = []
    for label, domain_data, color in domains:
        counts = np.zeros(num_experts, dtype=int)
        for token_experts in domain_data:
            # Count each expert at most once per token.
            counts[np.unique(token_experts)] += 1
        pct = (counts / total_tokens) * 100 if total_tokens else np.zeros(num_experts)
        bars.append(go.Bar(name=label, x=expert_ids, y=pct, marker_color=color))

    fig = go.Figure(data=bars)

    fig.update_layout(
        barmode='stack',
        title=f'Domain Contributions to Experts (Layer {layer_number})',
        xaxis_title='Expert ID',
        yaxis_title='Percentage of Total Tokens',
        xaxis=dict(
            tickmode='linear',
            tick0=0,
            dtick=4
        ),
        showlegend=True,
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="right",
            x=0.99
        ),
        width=2000,
        height=750
    )

    fig.update_xaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')
    fig.update_yaxes(showgrid=True, gridwidth=1, gridcolor='lightgray')

    fig.show()

In [36]:
# Per-domain contributions to each expert at layer 1 (last tokens).
stacked_bar_graph(layer_number=1, last_token_experts=last_token_experts)

In [114]:
def bar_graph_all_tokens(all_token_experts, layer_number, domain, num_experts=64):
    """
    Bar chart of expert selection over ALL non-padded tokens of a domain, at one layer.

    Args:
        all_token_experts: integer tensor [num_prompts, num_layers, seq_len, experts_per_token]
            (e.g. [600, 27, 284, 6]); padded positions are filled with -1. CPU or CUDA.
        layer_number: layer to analyze (1-indexed).
        domain: 'code', 'english' or 'french' (200 prompts each, in that order).
        num_experts: number of routed experts shown on the x-axis (default 64).

    Raises:
        ValueError: for an unknown domain or an out-of-range layer_number.
    """
    domain_slices = {
        'code': slice(0, 200),
        'english': slice(200, 400),
        'french': slice(400, 600),
    }
    if domain not in domain_slices:
        raise ValueError("Domain must be 'code', 'english', or 'french'.")

    layer_idx = layer_number - 1
    num_layers = all_token_experts.size(1)
    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Layer number must be between 1 and {num_layers}.")

    # [domain_prompts, seq_len, experts_per_token]; .cpu() so CUDA tensors are accepted.
    domain_data = all_token_experts[domain_slices[domain], layer_idx, :, :].detach().cpu().numpy()

    # Flatten prompts x positions into one token axis. The row width must be the real
    # experts-per-token count (6 here, 8 for OLMoE); a hard-coded width would silently
    # mix experts from neighbouring tokens.
    experts_per_token = domain_data.shape[-1]
    flattened = domain_data.reshape(-1, experts_per_token)
    valid_experts = flattened[flattened[:, 0] != -1]  # padding rows use -1

    # Count each expert at most once per token.
    expert_counts = np.zeros(num_experts, dtype=int)
    for token_experts in valid_experts:
        expert_counts[np.unique(token_experts)] += 1

    total_valid_tokens = valid_experts.shape[0]
    print(f'total valid tokens : {total_valid_tokens}')
    if total_valid_tokens:
        percentages = (expert_counts / total_valid_tokens) * 100
    else:
        percentages = np.zeros(num_experts)

    # Darker blue for experts that were selected at least once.
    colors = ['#2c7bb6' if p > 0 else 'steelblue' for p in percentages]
    fig, ax = plt.subplots(figsize=(15, 6))
    ax.bar(range(num_experts), percentages, color=colors)
    ax.set_xlabel('expert ID', fontsize=12)
    ax.set_ylabel('% of total tokens selecting expert', fontsize=12)
    ax.set_title(f'expert selection for {domain.capitalize()} domain (layer {layer_number}) - all tokens', fontsize=14)
    ax.set_xticks(np.arange(0, num_experts, 4))
    ax.grid(axis='y', linestyle='--', alpha=0.7)

    plt.show()

In [None]:
# All-token expert usage for the French domain at layer 1.
bar_graph_all_tokens(all_token_experts, domain='french', layer_number=1)

In [157]:
def stacked_bar_graph_all_tokens(all_token_experts, layer_number, num_experts=64):
    """
    Visualizes domain contributions to experts across ALL tokens (non-padded) for a given layer.

    Each domain's per-expert counts are rescaled to a common token budget: the
    valid-token count of the smallest domain. Larger domains therefore do not
    dominate the stack.

    Args:
        all_token_experts: integer tensor [num_prompts, num_layers, seq_len, experts_per_token]
            (e.g. [600, 27, 284, 6]); padded positions are -1. CPU or CUDA.
        layer_number: layer to analyze (1-indexed).
        num_experts: number of routed experts shown on the x-axis (default 64).

    Raises:
        ValueError: if layer_number is out of range, or any domain has no valid tokens.
    """
    layer_idx = layer_number - 1
    num_layers = all_token_experts.size(1)
    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Layer number must be between 1 and {num_layers}.")

    # [num_prompts, seq_len, experts_per_token]; .cpu() so CUDA tensors are accepted.
    layer_data = all_token_experts[:, layer_idx, :, :].detach().cpu().numpy()
    # Row width for flattening: the real experts-per-token count, never a hard-coded 6.
    experts_per_token = layer_data.shape[-1]

    code = layer_data[:200]        # Code domain (prompts 0-199)
    english = layer_data[200:400]  # English (prompts 200-399)
    french = layer_data[400:600]   # French (prompts 400-599)

    def process_domain(domain_data):
        """Return (per-expert counts, number of non-padded tokens) for one domain."""
        flattened = domain_data.reshape(-1, experts_per_token)
        valid_experts = flattened[flattened[:, 0] != -1]
        counts = np.zeros(num_experts, dtype=int)
        for token in valid_experts:
            counts[np.unique(token)] += 1
        return counts, valid_experts.shape[0]

    code_counts, code_valid = process_domain(code)
    eng_counts, eng_valid = process_domain(english)
    fr_counts, fr_valid = process_domain(french)

    def normalize_counts(counts, domain_valid_tokens, scaling_factor=1000):
        """Rescale counts to `scaling_factor` tokens. Kept as floats: truncating to int
        would floor small counts to zero."""
        return counts * (scaling_factor / domain_valid_tokens)

    # Common token budget = size of the smallest domain.
    scaling_factor = min(code_valid, eng_valid, fr_valid)
    print(f'scaling factor : {scaling_factor}')
    if scaling_factor == 0:
        raise ValueError("Every domain needs at least one non-padded token to normalize.")

    code_norm = normalize_counts(code_counts, code_valid, scaling_factor)
    eng_norm = normalize_counts(eng_counts, eng_valid, scaling_factor)
    fr_norm = normalize_counts(fr_counts, fr_valid, scaling_factor)

    # Total tokens after normalization (scaling_factor * 3 for 3 domains)
    total_normalized_tokens = scaling_factor * 3
    print(f'total normalized tokens : {total_normalized_tokens}')
    code_pct = (code_norm / total_normalized_tokens) * 100
    eng_pct = (eng_norm / total_normalized_tokens) * 100
    fr_pct = (fr_norm / total_normalized_tokens) * 100

    experts = np.arange(num_experts)
    colors = ['#2c7bb6', '#d7191c', '#fdae61']
    labels = ['Code', 'English', 'French']

    fig, ax = plt.subplots(figsize=(20, 8))
    ax.bar(experts, code_pct, color=colors[0], label=labels[0])
    ax.bar(experts, eng_pct, bottom=code_pct, color=colors[1], label=labels[1])
    ax.bar(experts, fr_pct, bottom=code_pct + eng_pct, color=colors[2], label=labels[2])

    ax.set_xlabel('Expert ID', fontsize=12)
    ax.set_ylabel('Percentage of Normalized Tokens (%)', fontsize=12)
    ax.set_title(f'Balanced Domain Contributions to Experts (Layer {layer_number})', fontsize=14)
    ax.set_xticks(np.arange(0, num_experts, 4))
    ax.legend(loc='upper right')
    plt.show()

In [None]:
# Balanced per-domain contributions over all tokens at the last layer (27).
stacked_bar_graph_all_tokens(layer_number=27, all_token_experts=all_token_experts)

In [171]:
def expert_coactivation(last_token_experts, layer_number, top_k=64, num_experts=64):
    """
    Heatmap of directed expert co-activation at one layer, using last-token experts.

    Entry (i, j) is the percentage of tokens routed to expert i that were also
    routed to expert j. The diagonal is always 0.

    Args:
        last_token_experts: integer tensor [num_prompts, num_layers, experts_per_token]
            (e.g. [600, 27, 6]), on CPU or CUDA.
        layer_number: layer to analyze (1-indexed).
        top_k: number of experts to display, chosen by their maximum co-activation score.
        num_experts: total number of routed experts (default 64).

    Raises:
        ValueError: if layer_number is out of range.
    """
    layer_idx = layer_number - 1
    num_layers = last_token_experts.size(1)
    if layer_idx < 0 or layer_idx >= num_layers:
        raise ValueError(f"Layer number must be between 1 and {num_layers}.")

    # [num_prompts, experts_per_token]; .cpu() so CUDA tensors are accepted.
    layer_data = last_token_experts[:, layer_idx, :].detach().cpu().numpy()

    co_occurrence = np.zeros((num_experts, num_experts), dtype=int)
    expert_activations = np.zeros(num_experts, dtype=int)

    for token_experts in layer_data:
        active = np.unique(token_experts)  # each expert at most once per token
        expert_activations[active] += 1
        co_occurrence[np.ix_(active, active)] += 1
    # Self-pairs were counted above; an expert does not co-activate with itself.
    np.fill_diagonal(co_occurrence, 0)

    # Row-normalize by each expert's activation count; never-activated experts stay 0.
    activations_col = expert_activations[:, None]
    co_activation = np.zeros((num_experts, num_experts))
    np.divide(co_occurrence, activations_col, out=co_activation, where=activations_col > 0)
    co_activation *= 100

    # Keep the top-k experts by highest co-activation score, displayed in id order.
    max_scores = np.max(co_activation, axis=1)
    top_experts = np.sort(np.argsort(-max_scores)[:top_k])
    filtered_matrix = co_activation[np.ix_(top_experts, top_experts)]

    plt.figure(figsize=(12, 10))
    plt.imshow(filtered_matrix, cmap="viridis", interpolation="nearest", aspect="auto")
    plt.colorbar(label="Co-activation (%)", shrink=0.8)

    tick_labels = [f"E{expert}" for expert in top_experts]
    plt.xticks(np.arange(len(top_experts)), tick_labels, rotation=90)
    plt.yticks(np.arange(len(top_experts)), tick_labels)
    plt.xlabel("Expert $E_j$")
    plt.ylabel("Expert $E_i$")
    plt.title(f"Expert Co-activation (Layer {layer_number}) - Top {top_k} Experts")
    plt.tight_layout()
    plt.show()

In [None]:
# Directed co-activation heatmap for layer 1 (last tokens).
expert_coactivation(layer_number=1, last_token_experts=last_token_experts)