# 🔬 OLMoE Hands-On Demo: See Inference with More Experts in Action

## Goal: Show EXACTLY how data flows through inference with different expert counts

You'll see:
- ✅ Real token IDs and embeddings
- ✅ Actual router logits and probabilities
- ✅ Which experts are selected (with numbers!)
- ✅ Side-by-side output comparison (8 vs 16 vs 32 vs 64 experts)
- ✅ Quality differences in real outputs

---

## 📦 Setup (Quick)

In [None]:
%%capture
!pip install -q transformers>=4.40.0 torch accelerate sentencepiece matplotlib pandas numpy

In [None]:
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
import warnings
warnings.filterwarnings('ignore')

# Report the PyTorch build and whether a CUDA device is visible.
cuda_ready = torch.cuda.is_available()
gpu_label = torch.cuda.get_device_name(0) if cuda_ready else 'None'

print(f"‚úì PyTorch: {torch.__version__}")
print(f"‚úì CUDA Available: {cuda_ready}")
print(f"‚úì GPU: {gpu_label}")

## 🚀 Load Model

In [None]:
print("Loading OLMoE model... (this takes 2-3 minutes first time)\n")

# Checkpoint to load and the device that later cells move input tensors to.
model_name = "allenai/OLMoE-1B-7B-0924"
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Keyword arguments forwarded to from_pretrained.
load_options = dict(
    torch_dtype=torch.float16,
    device_map="auto",
    output_router_logits=True,  # CRITICAL: ask the model to return expert routing data
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, **load_options)

print("‚úì Model loaded!")
print(f"‚úì Device: {device}")
moe_config = model.config
print(f"‚úì Config: {moe_config.num_experts} experts, top-{moe_config.num_experts_per_tok} default\n")

---

# 🔍 PART 1: Understanding the Data Flow

## Step-by-Step: What Happens During Inference

In [None]:
# Prompt used for the step-by-step walkthrough; `prompt`, `inputs` and
# `input_ids` are reused by the following cells.
prompt = "Artificial intelligence is"

print("="*80)
print("STEP-BY-STEP DATA FLOW")
print("="*80)
print(f"\nüìù Input Prompt: '{prompt}'\n")

# STEP 1: Tokenization - turn the prompt into a tensor of token IDs.
print("‚îÄ" * 80)
print("STEP 1: TOKENIZATION")
print("‚îÄ" * 80)
encoded = tokenizer(prompt, return_tensors="pt")
inputs = encoded.to(device)
input_ids = inputs.input_ids[0]  # drop the batch dimension

print(f"Token IDs: {input_ids.cpu().numpy()}")
print(f"Shape: {input_ids.shape}")
print(f"\nToken breakdown:")
# Decode each ID on its own so the text piece behind every token is visible.
for position, token_id in enumerate(input_ids):
    piece = tokenizer.decode([token_id])
    print(f"  Position {position}: ID={token_id:5d} ‚Üí '{piece}'")

print(f"\nTotal tokens in input: {input_ids.shape[0]}")

In [None]:
# STEP 2: Forward pass to get router logits
print("\n" + "‚îÄ" * 80)
print("STEP 2: FORWARD PASS (Getting Router Decisions)")
print("‚îÄ" * 80)

# Inference only: no_grad avoids building the autograd graph.
with torch.no_grad():
    outputs = model(**inputs, output_router_logits=True, return_dict=True)

print(f"‚úì Forward pass complete")
print(f"\nOutput structure:")
print(f"  - logits shape: {outputs.logits.shape}")
print(f"    (batch_size, sequence_length, vocab_size)")

if outputs.router_logits:
    layer0_shape = outputs.router_logits[0].shape
    print(f"\n  - router_logits: {len(outputs.router_logits)} layers")
    print(f"    Each layer shape: {layer0_shape}")
    # Choose the axis legend from the tensor's real rank instead of assuming 3-D.
    # NOTE(review): Hugging Face MoE blocks appear to flatten batch and sequence
    # before routing, which yields a 2-D tensor per layer - confirm for the
    # installed transformers version.
    if len(layer0_shape) == 3:
        print(f"    (batch_size, sequence_length, num_experts)")
    else:
        print(f"    (batch_size * sequence_length, num_experts)")
else:
    print("  ‚ö†Ô∏è  Router logits not available!")

In [None]:
# STEP 3: Analyze router decisions for first layer
print("\n" + "‚îÄ" * 80)
print("STEP 3: ROUTER DECISIONS (Layer 0)")
print("‚îÄ" * 80)

if outputs.router_logits:
    # Normalise layer 0 to [tokens, experts]. Indexing with [0][0] only works
    # for a 3-D (batch, seq, experts) tensor; when the layer tensor is 2-D
    # (batch*seq, experts) it would select a single token's vector instead.
    # The reshape handles both layouts (batch size is 1 here).
    layer0_logits = outputs.router_logits[0]
    num_router_experts = layer0_logits.shape[-1]
    router_logits_layer0 = layer0_logits.reshape(-1, num_router_experts)

    # Convert to probabilities; softmax in float32 to avoid half-precision
    # rounding (and missing fp16 CPU kernels on some PyTorch builds).
    router_probs = torch.softmax(router_logits_layer0.float(), dim=-1)

    print(f"Router probabilities shape: {router_probs.shape}")
    print(f"  ‚Üí For each of {router_probs.shape[0]} tokens, we have {num_router_experts} expert probabilities\n")

    # Show detailed routing for each input token. zip() stops at the shorter
    # sequence, so a row-count mismatch cannot raise an IndexError.
    per_token_probs = router_probs.cpu().numpy()
    for token_idx, (token_id, probs) in enumerate(zip(input_ids.tolist(), per_token_probs)):
        token_text = tokenizer.decode([token_id])

        # One descending sort serves both the top-8 (default) and top-16 views.
        ranked = np.argsort(probs)[::-1]
        top8_indices = ranked[:8]
        top8_probs = probs[top8_indices]
        top16_indices = ranked[:16]
        top16_probs = probs[top16_indices]

        print(f"Token {token_idx}: '{token_text}'")
        print(f"  Top-8 experts (DEFAULT):")
        for i, (expert_id, prob) in enumerate(zip(top8_indices, top8_probs)):
            print(f"    {i+1}. Expert {expert_id:2d}: {prob:.4f} ({prob*100:.1f}%)")

        print(f"\n  Top-16 experts (if we use more):")
        for rank, (expert_id, prob) in enumerate(zip(top16_indices, top16_probs), 1):
            # Ranks 1-8 are already used by default; 9-16 are the additions.
            if rank <= 8:
                marker, note = "‚úì", ""
            else:
                marker, note = "+", " [EXTRA with 16 experts]"
            print(f"    {marker} {rank:2d}. Expert {expert_id:2d}: {prob:.4f} ({prob*100:.1f}%){note}")

        print(f"\n  Total probability (top-8): {top8_probs.sum():.4f}")
        print(f"  Total probability (top-16): {top16_probs.sum():.4f}")
        print()
else:
    print("Router logits not available")

---

# 🎯 PART 2: Real Inference with Different Expert Counts

## Now let's generate text with 8, 16, 32, and 64 experts and compare!

In [None]:
def generate_with_experts(model, tokenizer, prompt, num_experts, max_new_tokens=100, seed=42):
    """Generate text while routing every token to ``num_experts`` experts.

    The override is temporary: the config value and any module-level ``top_k``
    attributes are restored before returning, even if generation raises.

    Args:
        model: Causal LM exposing ``config.num_experts_per_tok``, ``modules()``,
            ``device`` and ``generate()``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Text to continue.
        num_experts: Number of experts to activate per token.
        max_new_tokens: Upper bound on the number of generated tokens.
        seed: Seed for the torch RNGs so sampling is repeatable across calls.

    Returns:
        dict with ``text`` (decoded prompt plus continuation), ``time``
        (seconds spent in ``generate``), ``tokens`` (number of new tokens) and
        ``tokens_per_sec``.

    Raises:
        ValueError: If ``num_experts`` is below 1 or exceeds
            ``model.config.num_experts`` (when the config declares it).
    """
    import time

    total_experts = getattr(model.config, "num_experts", None)
    if num_experts < 1 or (total_experts is not None and num_experts > total_experts):
        raise ValueError(
            f"num_experts must be between 1 and {total_experts}, got {num_experts}"
        )

    # Set seed for reproducibility (all GPUs, since the model may be sharded).
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    original = model.config.num_experts_per_tok
    patched = []  # (module, previous top_k) pairs to restore afterwards

    try:
        model.config.num_experts_per_tok = num_experts
        # Editing the config alone is not enough when the MoE layers copied the
        # value into their own attribute at construction time.
        # NOTE(review): this assumes the routing modules expose `top_k` and
        # `num_experts` attributes, as the Hugging Face OLMoE implementation
        # appears to - confirm for the installed transformers version.
        for module in model.modules():
            if hasattr(module, "top_k") and hasattr(module, "num_experts"):
                patched.append((module, module.top_k))
                module.top_k = num_experts

        # Use the model's own device rather than a notebook-level global.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[-1]

        start = time.perf_counter()  # monotonic, high-resolution timer
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id,
            )
        elapsed = time.perf_counter() - start

        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        num_tokens = outputs.shape[-1] - prompt_length

        return {
            'text': generated_text,
            'time': elapsed,
            'tokens': num_tokens,
            # Guard against a zero reading from the timer.
            'tokens_per_sec': num_tokens / elapsed if elapsed > 0 else float('nan'),
        }
    finally:
        model.config.num_experts_per_tok = original
        for module, previous_top_k in patched:
            module.top_k = previous_top_k

print("‚úì Generation function ready!")

## 📊 Test Case 1: Simple Factual Question

In [None]:
# Test case 1: a short factual question, generated once per expert count.
test_prompt_1 = "What is the capital of France?"

print("="*80)
print(f"TEST PROMPT: '{test_prompt_1}'")
print("="*80)

# Maps expert count -> result dict returned by generate_with_experts.
results_1 = {}

for num_experts in (8, 16, 32, 64):
    print(f"\n{'‚îÄ'*80}")
    print(f"üî¨ GENERATING WITH {num_experts} EXPERTS")
    print(f"{'‚îÄ'*80}")

    run = generate_with_experts(
        model,
        tokenizer,
        test_prompt_1,
        num_experts=num_experts,
        max_new_tokens=50,
        seed=42,
    )
    results_1[num_experts] = run

    print(f"\nüìù OUTPUT:")
    print(run['text'])
    print(f"\nüìä STATS:")
    print(f"  ‚è±Ô∏è  Time: {run['time']:.2f}s")
    print(f"  üìà Tokens: {run['tokens']}")
    print(f"  ‚ö° Speed: {run['tokens_per_sec']:.2f} tokens/sec")

print("\n" + "="*80)
print("COMPARISON SUMMARY")
print("="*80)

# One table row per expert count, including a truncated output preview.
summary_rows = []
for expert_count, stats in results_1.items():
    summary_rows.append({
        'Experts': expert_count,
        'Time (s)': f"{stats['time']:.2f}",
        'Tokens': stats['tokens'],
        'Speed (tok/s)': f"{stats['tokens_per_sec']:.2f}",
        'Output Preview': stats['text'][:80] + '...',
    })
comparison_df = pd.DataFrame(summary_rows)

print(comparison_df.to_string(index=False))
print("="*80)

## 📊 Test Case 2: Technical Explanation

In [None]:
# Test case 2: an open-ended technical explanation, generated once per expert count.
test_prompt_2 = "Explain how neural networks learn:"

print("="*80)
print(f"TEST PROMPT: '{test_prompt_2}'")
print("="*80)

# Maps expert count -> result dict returned by generate_with_experts.
results_2 = {}

for num_experts in (8, 16, 32, 64):
    print(f"\n{'‚îÄ'*80}")
    print(f"üî¨ GENERATING WITH {num_experts} EXPERTS")
    print(f"{'‚îÄ'*80}")

    run = generate_with_experts(
        model,
        tokenizer,
        test_prompt_2,
        num_experts=num_experts,
        max_new_tokens=80,
        seed=42,
    )
    results_2[num_experts] = run

    print(f"\nüìù OUTPUT:")
    print(run['text'])
    print(f"\nüìä STATS:")
    print(f"  ‚è±Ô∏è  Time: {run['time']:.2f}s")
    print(f"  üìà Tokens: {run['tokens']}")
    print(f"  ‚ö° Speed: {run['tokens_per_sec']:.2f} tokens/sec")

print("\n" + "="*80)
print("COMPARISON SUMMARY")
print("="*80)

# One table row per expert count.
summary_rows = []
for expert_count, stats in results_2.items():
    summary_rows.append({
        'Experts': expert_count,
        'Time (s)': f"{stats['time']:.2f}",
        'Tokens': stats['tokens'],
        'Speed (tok/s)': f"{stats['tokens_per_sec']:.2f}",
    })
comparison_df = pd.DataFrame(summary_rows)

print(comparison_df.to_string(index=False))
print("="*80)

## 📊 Test Case 3: Code Generation

In [None]:
# Test case 3: a code-generation request, generated once per expert count.
test_prompt_3 = "Write a Python function to calculate fibonacci numbers:"

print("="*80)
print(f"TEST PROMPT: '{test_prompt_3}'")
print("="*80)

# Maps expert count -> result dict returned by generate_with_experts.
results_3 = {}

for num_experts in (8, 16, 32, 64):
    print(f"\n{'‚îÄ'*80}")
    print(f"üî¨ GENERATING WITH {num_experts} EXPERTS")
    print(f"{'‚îÄ'*80}")

    run = generate_with_experts(
        model,
        tokenizer,
        test_prompt_3,
        num_experts=num_experts,
        max_new_tokens=100,
        seed=42,
    )
    results_3[num_experts] = run

    print(f"\nüìù OUTPUT:")
    print(run['text'])
    print(f"\nüìä STATS:")
    print(f"  ‚è±Ô∏è  Time: {run['time']:.2f}s")
    print(f"  üìà Tokens: {run['tokens']}")
    print(f"  ‚ö° Speed: {run['tokens_per_sec']:.2f} tokens/sec")

print("\n" + "="*80)
print("COMPARISON SUMMARY")
print("="*80)

# One table row per expert count.
summary_rows = []
for expert_count, stats in results_3.items():
    summary_rows.append({
        'Experts': expert_count,
        'Time (s)': f"{stats['time']:.2f}",
        'Tokens': stats['tokens'],
        'Speed (tok/s)': f"{stats['tokens_per_sec']:.2f}",
    })
comparison_df = pd.DataFrame(summary_rows)

print(comparison_df.to_string(index=False))
print("="*80)

---

# 📈 PART 3: Visualize Performance

In [None]:
# Gather the three benchmark runs under their display names.
all_results = {
    'Simple Question': results_1,
    'Technical Explanation': results_2,
    'Code Generation': results_3
}

# 2x2 grid: speed bars, time lines, relative speed bars, cross-test average.
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

expert_counts = [8, 16, 32, 64]
colors = ['#3498db', '#e74c3c', '#2ecc71', '#f39c12']

# Shared text styling for axis labels and titles.
axis_font = {'fontsize': 12, 'fontweight': 'bold'}
title_font = {'fontsize': 14, 'fontweight': 'bold'}

# Plot 1: Speed comparison across test cases (grouped bars)
ax = axes[0, 0]
x = np.arange(len(expert_counts))
width = 0.25

for slot, (test_name, runs) in enumerate(all_results.items()):
    speeds = [runs[k]['tokens_per_sec'] for k in expert_counts]
    ax.bar(x + slot*width, speeds, width, label=test_name, alpha=0.8)

ax.set_xlabel('Number of Experts', **axis_font)
ax.set_ylabel('Tokens per Second', **axis_font)
ax.set_title('Generation Speed Comparison', **title_font)
ax.set_xticks(x + width)  # centre the tick under the middle bar of each group
ax.set_xticklabels(expert_counts)
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Plot 2: Generation time (one line per test case)
ax = axes[0, 1]
for test_name, runs in all_results.items():
    times = [runs[k]['time'] for k in expert_counts]
    ax.plot(expert_counts, times, marker='o', label=test_name, linewidth=2, markersize=8)

ax.set_xlabel('Number of Experts', **axis_font)
ax.set_ylabel('Generation Time (seconds)', **axis_font)
ax.set_title('Time vs Expert Count', **title_font)
ax.legend()
ax.grid(alpha=0.3)

# Plot 3: Speedup/Slowdown relative to 8 experts
ax = axes[1, 0]
x = np.arange(len(expert_counts))

for slot, (test_name, runs) in enumerate(all_results.items()):
    baseline_speed = runs[8]['tokens_per_sec']
    relative_speeds = [runs[k]['tokens_per_sec'] / baseline_speed for k in expert_counts]
    ax.bar(x + slot*width, relative_speeds, width, label=test_name, alpha=0.8)

ax.axhline(y=1.0, color='black', linestyle='--', linewidth=2, label='Baseline (8 experts)')
ax.set_xlabel('Number of Experts', **axis_font)
ax.set_ylabel('Relative Speed (vs 8 experts)', **axis_font)
ax.set_title('Performance Scaling', **title_font)
ax.set_xticks(x + width)
ax.set_xticklabels(expert_counts)
ax.legend()
ax.grid(axis='y', alpha=0.3)

# Plot 4: Average across all tests
ax = axes[1, 1]
avg_speeds = []
avg_times = []

for k in expert_counts:
    per_test = [runs[k] for runs in all_results.values()]
    avg_speeds.append(np.mean([stats['tokens_per_sec'] for stats in per_test]))
    avg_times.append(np.mean([stats['time'] for stats in per_test]))

ax.plot(expert_counts, avg_speeds, marker='s', linewidth=3, markersize=10,
        color='purple', label='Avg Speed')
ax.set_xlabel('Number of Experts', **axis_font)
ax.set_ylabel('Average Tokens/Second', **axis_font)
ax.set_title('Average Performance Across All Tests', **title_font)
ax.grid(alpha=0.3)

# Annotate every point with its value, slightly above the marker.
for count, speed in zip(expert_counts, avg_speeds):
    ax.text(count, speed + 0.5, f'{speed:.1f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.savefig('olmoe_performance_comparison.png', dpi=300, bbox_inches='tight')
plt.show()

print("\n‚úì Visualizations complete!")

---

# 🔬 PART 4: Deep Dive - Expert Activation Patterns

In [None]:
# Analyze which experts are activated for different prompts
test_prompts = [
    "Write Python code:",
    "Solve this math problem:",
    "Tell me a story:",
    "Explain quantum physics:"
]

print("="*80)
print("EXPERT ACTIVATION ANALYSIS")
print("="*80)
print("\nShowing which experts are selected for different types of prompts...\n")

# One entry per prompt for which router logits were returned.
activation_data = []

for prompt in test_prompts:
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_router_logits=True, return_dict=True)

    if outputs.router_logits:
        # First layer, first token. Flatten to [tokens, experts] before
        # indexing: [0][0][0] only works for a 3-D (batch, seq, experts)
        # tensor and collapses to a single scalar when the layer tensor is
        # 2-D (batch*seq, experts). Batch size is 1 here.
        layer0_logits = outputs.router_logits[0]
        router_logits = layer0_logits.reshape(-1, layer0_logits.shape[-1])[0]
        # Softmax in float32 to avoid half-precision rounding.
        router_probs = torch.softmax(router_logits.float(), dim=-1).cpu().numpy()

        # One descending sort gives both views; top-8 is a prefix of top-16.
        ranked = np.argsort(router_probs)[::-1]
        top8 = ranked[:8]
        top16 = ranked[:16]

        print(f"Prompt: '{prompt}'")
        print(f"  Top-8 experts (default): {top8.tolist()}")
        print(f"  Top-16 experts: {top16.tolist()}")
        # Ranks 9-16 are exactly the experts added when going from 8 to 16.
        print(f"  Extra experts with 16: {top16[8:].tolist()}")
        print()

        activation_data.append({
            'prompt': prompt,
            'top8': top8,
            'top16': top16,
            'probs': router_probs
        })

if activation_data:
    # Visualize expert activation patterns
    fig, ax = plt.subplots(figsize=(14, 8))

    # Build the heatmap from the collected rows so its width follows the real
    # expert count and every row lines up with the prompt it came from.
    activation_matrix = np.vstack([data['probs'] for data in activation_data])

    im = ax.imshow(activation_matrix, aspect='auto', cmap='YlOrRd', interpolation='nearest')
    ax.set_xlabel('Expert Index', fontsize=12, fontweight='bold')
    ax.set_ylabel('Prompt Type', fontsize=12, fontweight='bold')
    ax.set_title('Expert Activation Heatmap (Probability Distribution)', fontsize=14, fontweight='bold')
    ax.set_yticks(range(len(activation_data)))
    ax.set_yticklabels([data['prompt'].replace(':', '') for data in activation_data])

    # Add colorbar
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Routing Probability', fontsize=10)

    plt.tight_layout()
    plt.savefig('expert_activation_patterns.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("Router logits not available")

print("‚úì Expert activation analysis complete!")

---

# 📊 PART 5: Summary Statistics

In [None]:
# Closing recap: static talking points plus the averaged speed figures
# computed from `all_results` and `expert_counts`.
print("="*80)
print("FINAL SUMMARY: WHAT WE LEARNED")
print("="*80)

print("\n1Ô∏è‚É£ DATA FLOW:")
print("   Input Text ‚Üí Tokens ‚Üí Embeddings ‚Üí Router ‚Üí Top-k Experts ‚Üí Output")

print("\n2Ô∏è‚É£ EXPERT SELECTION:")
for line in (
    "   - Default (8 experts): Router picks top-8 highest probability experts",
    "   - With 16 experts: Router picks top-16 highest probability experts",
    "   - With 32 experts: Router picks top-32 highest probability experts",
    "   - With 64 experts: Router uses ALL experts (no selection)",
):
    print(line)

print("\n3Ô∏è‚É£ PERFORMANCE IMPACT:")
for k in expert_counts:
    # Average each metric over the three test cases for this expert count.
    per_test = list(all_results.values())
    avg_speed = np.mean([runs[k]['tokens_per_sec'] for runs in per_test])
    avg_time = np.mean([runs[k]['time'] for runs in per_test])
    baseline_speed = np.mean([runs[8]['tokens_per_sec'] for runs in per_test])
    relative = avg_speed / baseline_speed

    print(f"   {k:2d} experts: {avg_speed:5.1f} tok/s | {avg_time:5.2f}s | {relative*100:5.1f}% of baseline")

print("\n4Ô∏è‚É£ KEY INSIGHT:")
for line in (
    "   ‚úì More experts = More specialized knowledge combined",
    "   ‚úì More experts = Higher computational cost (linear scaling)",
    "   ‚úì Trade-off: Quality vs Speed",
):
    print(line)

print("\n5Ô∏è‚É£ WHEN TO USE MORE EXPERTS:")
for line in (
    "   ‚úì Complex, multi-domain questions",
    "   ‚úì When quality matters more than speed",
    "   ‚úì Offline batch processing",
    "   ‚úì Research and analysis",
):
    print(line)

print("\n6Ô∏è‚É£ WHEN TO USE DEFAULT (8 EXPERTS):")
for line in (
    "   ‚úì Real-time applications",
    "   ‚úì Simple, well-defined tasks",
    "   ‚úì Resource-constrained environments",
    "   ‚úì Production deployments with high throughput needs",
):
    print(line)

print("\n" + "="*80)
print("üéâ EXPERIMENT COMPLETE!")
print("="*80)
print("\nYou now have concrete data showing:")
for line in (
    "  ‚úì How expert routing works",
    "  ‚úì Which experts are selected",
    "  ‚úì Performance differences between configurations",
    "  ‚úì Real output comparisons",
):
    print(line)
print("\nFeel free to modify the test prompts and re-run to see different results!")
print("="*80)

---

# 🚀 BONUS: Your Custom Test

Try your own prompt and see the difference!

In [None]:
# Customize this! The prompt below is generated once per expert count.
YOUR_PROMPT = "Write a function to sort an array:"

print("="*80)
print("YOUR CUSTOM TEST")
print("="*80)
print(f"\nPrompt: {YOUR_PROMPT}\n")

for num_experts in (8, 16, 32):
    print(f"\n{'‚îÄ'*80}")
    print(f"With {num_experts} experts:")
    print(f"{'‚îÄ'*80}")

    # The seed is left at the function's default.
    result = generate_with_experts(
        model,
        tokenizer,
        YOUR_PROMPT,
        num_experts=num_experts,
        max_new_tokens=80,
    )

    print(result['text'])
    print(f"\n‚ö° {result['tokens_per_sec']:.1f} tokens/sec")

print("\n" + "="*80)