# MoE vs Dense Embeddings - Fair Comparison

This notebook trains **two models with matched active parameter counts**:

1. **Dense Model**: Standard feed-forward layers (one 1536-wide FFN per layer)
2. **MoE Model**: 8 experts of FFN width 768, top-2 routing (same active FFN width per token as the dense model, 4x the FFN parameters in total)

## Key Points for Fair Comparison:
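- **Same active parameters**: Each token passes through the same FFN width in both models (dense: 1536; MoE: 2 experts × 768); exact active/total parameter counts are printed in Section 3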
- **Same data**: Identical training and validation sets
- **Same hyperparameters**: Learning rate, batch size, epochs
- **Same architecture**: Layers, attention heads, hidden dim
- **Only difference**: Dense FFN vs MoE FFN (plus the MoE's auxiliary load-balancing loss during training)

## Expected Training Time: ~10-15 minutes total (both models)

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.metrics.pairwise import cosine_similarity
import warnings
warnings.filterwarnings('ignore')
import time

# Import both models
from src.models import EmbeddingModel, EmbeddingModelMoE
from src.data import SimpleTokenizer, PairDataset, load_dataset_for_training
from src.training import MultipleNegativesRankingLoss, EmbeddingTrainer
from src.evaluation import compute_similarity, compute_embedding_statistics
from src.utils import save_model

np.random.seed(42)
torch.manual_seed(42)


def _select_device():
    """Return ``(device, status_message)`` for the preferred backend.

    Preference order: Apple MPS, then CUDA, then CPU.
    """
    if torch.backends.mps.is_available():
        return torch.device("mps"), "✓ Using MPS (Metal) - M4 MAX GPU"
    if torch.cuda.is_available():
        return torch.device("cuda"), "✓ Using CUDA GPU"
    return torch.device("cpu"), "⚠ Using CPU"


# Device
device, _device_banner = _select_device()
print(_device_banner)

print(f"Device: {device}")

## 1. Configuration - Matched Active Parameters

We'll configure both models to have the **same active parameter count**:

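- **Dense**: 384 hidden, one 1536-wide FFN per layer
- **MoE**: 384 hidden, 8 experts × 768-wide FFN per layer, top-2 routing → 2 × 768 = 1536 active FFN units per token

The MoE layers hold 4x the FFN parameters of the dense layers (8 × 768 vs. 1536 hidden units), so the MoE model has more total parameters. Each token uses only 2 of the 8 experts, so the per-token FFN compute matches. Embeddings and attention are the same size in both models, so the overall total-parameter ratio is below 4x; exact counts are printed in Section 3.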

In [None]:
# Shared configuration: everything except the feed-forward layout is identical
# for both models, so the comparison isolates the FFN design.
SHARED_CONFIG = {
    'hidden_dim': 384,
    'num_layers': 6,
    'num_heads': 12,
    'max_seq_len': 128,
    'dropout': 0.1,
    'pooling_mode': 'mean',
    'num_train_samples': 50000,  # Smaller for faster comparison
    'val_size': 0.1,
    'batch_size': 64,
    'num_epochs': 8,  # Enough to see differences
    'learning_rate': 2e-5,
    'weight_decay': 0.01,
    'temperature': 0.05,
}

# Dense model config
DENSE_CONFIG = {
    **SHARED_CONFIG,
    'ff_dim': 1536,  # Standard 4x hidden_dim
}

# MoE model config - MATCHED ACTIVE PARAMS.
# Each token is routed to `top_k` experts, so the per-expert FFN width is
# derived from the dense width: top_k * ff_dim(MoE) == ff_dim(dense).
# Changing the dense width or top_k therefore keeps the active FFN compute
# matched instead of silently breaking the comparison.
MOE_CONFIG = {
    **SHARED_CONFIG,
    'num_experts': 8,
    'top_k': 2,  # 2/8 = 25% of experts active per token
}
if not 1 <= MOE_CONFIG['top_k'] <= MOE_CONFIG['num_experts']:
    raise ValueError(
        f"top_k ({MOE_CONFIG['top_k']}) must be between 1 and "
        f"num_experts ({MOE_CONFIG['num_experts']})"
    )
if DENSE_CONFIG['ff_dim'] % MOE_CONFIG['top_k'] != 0:
    raise ValueError(
        f"Dense ff_dim ({DENSE_CONFIG['ff_dim']}) must be divisible by "
        f"top_k ({MOE_CONFIG['top_k']}) to match the active FFN width"
    )
MOE_CONFIG['ff_dim'] = DENSE_CONFIG['ff_dim'] // MOE_CONFIG['top_k']  # Per-expert FFN (768)

print("Configuration:")
print(f"  Training samples: {SHARED_CONFIG['num_train_samples']:,}")
print(f"  Batch size: {SHARED_CONFIG['batch_size']}")
print(f"  Epochs: {SHARED_CONFIG['num_epochs']}")
print("\nDense Model:")
print(f"  FFN dim: {DENSE_CONFIG['ff_dim']}")
print("\nMoE Model:")
print(f"  FFN dim per expert: {MOE_CONFIG['ff_dim']}")
print(f"  Num experts: {MOE_CONFIG['num_experts']}")
print(f"  Top-K: {MOE_CONFIG['top_k']}")
print(f"  Active FFN dim per token: {MOE_CONFIG['top_k'] * MOE_CONFIG['ff_dim']}")

## 2. Load Data (Same for Both Models)

In [None]:
print("Loading datasets...\n")

train_pairs, val_pairs = load_dataset_for_training(
    dataset_name='combined',
    num_samples=SHARED_CONFIG['num_train_samples'],
    val_size=SHARED_CONFIG['val_size'],
    cache_dir='../data/cache'
)

print(f"\n✓ Training pairs: {len(train_pairs):,}")
print(f"✓ Validation pairs: {len(val_pairs):,}")

# Build tokenizer. The vocabulary is fitted on a prefix of the TRAINING pairs
# only, so validation text does not influence the vocab.
print("\nBuilding tokenizer...")
VOCAB_FIT_PAIRS = 25000  # cap on pairs used to fit the vocabulary (speed)
vocab_sentences = [
    sentence
    for s1, s2 in train_pairs[:VOCAB_FIT_PAIRS]
    for sentence in (s1, s2)
]

# Keep the tokenizer's max length tied to the models' max_seq_len (both models
# are built with SHARED_CONFIG['max_seq_len']), so editing the config cannot
# make the tokenizer emit sequences longer than the models are configured for.
tokenizer = SimpleTokenizer(vocab_size=30000, max_length=SHARED_CONFIG['max_seq_len'])
tokenizer.fit(vocab_sentences)
print(f"✓ Vocabulary: {len(tokenizer):,} tokens")

# Create datasets (shared by both models)
train_dataset = PairDataset(train_pairs, tokenizer)
val_dataset = PairDataset(val_pairs, tokenizer)

train_loader = DataLoader(train_dataset, batch_size=SHARED_CONFIG['batch_size'], shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=SHARED_CONFIG['batch_size'], shuffle=False, num_workers=0)

print(f"✓ Train batches: {len(train_loader):,}")
print(f"✓ Val batches: {len(val_loader):,}")

## 3. Initialize Models - Verify Parameter Match

In [None]:
print("="*70)
print("INITIALIZING MODELS")
print("="*70)

# Dense model
print("\n[1/2] Dense Model...")
dense_model = EmbeddingModel(
    vocab_size=len(tokenizer),
    hidden_dim=DENSE_CONFIG['hidden_dim'],
    num_layers=DENSE_CONFIG['num_layers'],
    num_heads=DENSE_CONFIG['num_heads'],
    ff_dim=DENSE_CONFIG['ff_dim'],
    max_seq_len=DENSE_CONFIG['max_seq_len'],
    dropout=DENSE_CONFIG['dropout'],
    pooling_mode=DENSE_CONFIG['pooling_mode'],
    pad_token_id=tokenizer.pad_token_id,
    normalize_embeddings=True
).to(device)

# The dense model has no routing, so its total count is reported as its
# active count as well.
dense_params = sum(p.numel() for p in dense_model.parameters())
print(f"  Total parameters: {dense_params:,}")
print(f"  Active parameters: {dense_params:,}")

# MoE model
print("\n[2/2] MoE Model...")
moe_model = EmbeddingModelMoE(
    vocab_size=len(tokenizer),
    hidden_dim=MOE_CONFIG['hidden_dim'],
    num_layers=MOE_CONFIG['num_layers'],
    num_heads=MOE_CONFIG['num_heads'],
    ff_dim=MOE_CONFIG['ff_dim'],
    num_experts=MOE_CONFIG['num_experts'],
    top_k=MOE_CONFIG['top_k'],
    max_seq_len=MOE_CONFIG['max_seq_len'],
    dropout=MOE_CONFIG['dropout'],
    pooling_mode=MOE_CONFIG['pooling_mode'],
    pad_token_id=tokenizer.pad_token_id,
    normalize_embeddings=True
).to(device)

moe_param_stats = moe_model.count_parameters()
print(f"  Total parameters: {moe_param_stats['total']:,}")
print(f"  Active parameters: {moe_param_stats['active']:,}")
print(f"  Expert parameters: {moe_param_stats['expert_total']:,}")
print(f"  Sparsity: {moe_param_stats['sparsity']*100:.1f}%")

# Comparison.
# Largest relative gap (percent of the dense count) still reported as
# "matched"; a tolerance chosen for this notebook, not a standard value.
ACTIVE_PARAM_TOLERANCE_PCT = 5.0
active_param_gap = abs(dense_params - moe_param_stats['active'])
active_param_gap_pct = active_param_gap / dense_params * 100
total_param_ratio = moe_param_stats['total'] / dense_params

print("\n" + "="*70)
print("PARAMETER COMPARISON")
print("="*70)
print(f"Dense active params:  {dense_params:,}")
print(f"MoE active params:    {moe_param_stats['active']:,}")
print(f"Difference:           {active_param_gap:,} ({active_param_gap_pct:.1f}%)")
# Only claim a match when the numbers actually support it.
if active_param_gap_pct <= ACTIVE_PARAM_TOLERANCE_PCT:
    print("\n✓ Models have similar active parameter counts!")
    print(f"\nMoE has {total_param_ratio:.2f}x total params but same compute per forward pass")
else:
    print(f"\n⚠ Active parameter counts differ by more than {ACTIVE_PARAM_TOLERANCE_PCT:.0f}% "
          "- the comparison is NOT compute-matched; revisit MOE_CONFIG['ff_dim'] / top_k")
    print(f"\nMoE has {total_param_ratio:.2f}x total params")

## 4. Train Dense Model

In [None]:
print("="*70)
print("TRAINING DENSE MODEL")
print("="*70)

# Setup
loss_fn_dense = MultipleNegativesRankingLoss(temperature=SHARED_CONFIG['temperature'])
optimizer_dense = torch.optim.AdamW(
    dense_model.parameters(),
    lr=SHARED_CONFIG['learning_rate'],
    weight_decay=SHARED_CONFIG['weight_decay']
)
# T_max is given in epochs; presumably EmbeddingTrainer steps the scheduler
# once per epoch - TODO confirm.
scheduler_dense = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_dense, T_max=SHARED_CONFIG['num_epochs'], eta_min=1e-7
)

trainer_dense = EmbeddingTrainer(
    model=dense_model,
    loss_fn=loss_fn_dense,
    optimizer=optimizer_dense,
    device=device,
    scheduler=scheduler_dense
)

# Re-seed right before training: the shuffling train_loader draws its order
# from torch's default RNG, so both models start training from the same RNG
# state (and hence the same first-epoch batch order) rather than whatever
# state model construction left behind.
torch.manual_seed(42)

# perf_counter is monotonic and intended for measuring elapsed time.
start_time = time.perf_counter()
history_dense = trainer_dense.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=SHARED_CONFIG['num_epochs'],
    eval_every=1,
    save_best=True,
    save_path="../models/dense_comparison.pt"
)
dense_training_time = time.perf_counter() - start_time

print(f"\n✓ Dense model trained in {dense_training_time/60:.2f} minutes")
print(f"  Final train loss: {history_dense['train_loss'][-1]:.4f}")
print(f"  Final val loss: {history_dense['val_loss'][-1]:.4f}")

## 5. Train MoE Model

In [None]:
print("\n".join(("=" * 70, "TRAINING MoE MODEL", "=" * 70)))

# Loss wrapper that can fold the MoE auxiliary (load-balancing) loss into the
# main contrastive objective.
class MoELossWrapper(nn.Module):
    """Wrap a pairwise embedding loss and optionally add a weighted aux loss.

    Args:
        base_loss: Callable ``(emb1, emb2) -> loss`` for the main objective.
        aux_weight: Multiplier applied to ``aux_loss`` when one is passed to
            :meth:`forward`.
    """

    def __init__(self, base_loss, aux_weight=0.01):
        super().__init__()
        self.base_loss = base_loss
        self.aux_weight = aux_weight

    def forward(self, emb1, emb2, aux_loss=None):
        """Return ``base_loss(emb1, emb2) + aux_weight * aux_loss``.

        With ``aux_loss`` omitted or ``None`` (e.g. a validation loop calling
        ``loss_fn(emb1, emb2)``), only the base loss is returned, exactly as
        the two-argument call behaved before.
        """
        loss = self.base_loss(emb1, emb2)
        if aux_loss is not None:
            loss = loss + self.aux_weight * aux_loss
        return loss

# Wrap the same contrastive loss the dense model uses; the wrapper also
# records aux_weight=0.01 for the auxiliary MoE loss.
loss_fn_moe = MoELossWrapper(
    base_loss=MultipleNegativesRankingLoss(temperature=SHARED_CONFIG['temperature']),
    aux_weight=0.01,
)

# Trainer that also optimises the MoE auxiliary (load-balancing) loss.
class MoETrainer(EmbeddingTrainer):
    """EmbeddingTrainer whose training step adds the model's aux loss.

    The aux loss (``output["aux_loss"]`` from each tower, when it is a tensor)
    is averaged over the two towers and scaled by ``self.loss_fn.aux_weight``
    if the loss function defines one (``MoELossWrapper`` does), else by 1.0.

    ``train_epoch`` returns the mean MAIN loss only, so the value the base
    trainer records as the epoch's train loss (presumably
    ``history['train_loss']`` - TODO confirm in EmbeddingTrainer.train) is
    directly comparable with the dense model's train loss. The mean weighted
    aux loss of the most recent epoch is kept in ``self.last_epoch_aux_loss``.
    """

    @staticmethod
    def _mean_aux_loss(*outputs):
        """Average the tensor ``aux_loss`` entries present in ``outputs``.

        Returns ``None`` when no output carries a tensor aux loss.
        """
        aux_terms = [out.get("aux_loss") for out in outputs]
        aux_terms = [term for term in aux_terms if isinstance(term, torch.Tensor)]
        if not aux_terms:
            return None
        return sum(aux_terms) / len(aux_terms)

    def train_epoch(self, train_loader, epoch):
        """Run one training epoch and return the mean main (contrastive) loss.

        Raises:
            ValueError: if ``train_loader`` yields no batches.
        """
        from tqdm import tqdm

        self.model.train()
        # Weight declared on the loss wrapper (0.01 in this notebook).
        # NOTE(review): if EmbeddingModelMoE already scales aux_loss
        # internally, set aux_weight=1.0 on the wrapper to avoid
        # double-scaling - confirm against the model implementation.
        aux_weight = getattr(self.loss_fn, "aux_weight", 1.0)

        total_main = 0.0
        total_aux = 0.0
        num_batches = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}")
        for batch in pbar:
            batch = {k: v.to(self.device) for k, v in batch.items()}

            # Forward both towers of the pair
            output1 = self.model(batch["input_ids_1"], batch["attention_mask_1"])
            output2 = self.model(batch["input_ids_2"], batch["attention_mask_2"])

            # Main loss (two-argument call: base loss only)
            main_loss = self.loss_fn(output1["embeddings"], output2["embeddings"])

            # Weighted aux loss from the MoE layers, if the model reports one
            aux_loss = self._mean_aux_loss(output1, output2)
            if aux_loss is None:
                loss = main_loss
                aux_value = 0.0
            else:
                loss = main_loss + aux_weight * aux_loss
                aux_value = aux_weight * aux_loss.item()

            # Backward
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            main_value = main_loss.item()
            total_main += main_value
            total_aux += aux_value
            num_batches += 1
            pbar.set_postfix({"loss": f"{main_value:.4f}", "aux": f"{aux_value:.4f}"})

        if num_batches == 0:
            raise ValueError("train_loader yielded no batches")

        self.last_epoch_aux_loss = total_aux / num_batches
        return total_main / num_batches

optimizer_moe = torch.optim.AdamW(
    moe_model.parameters(),
    lr=SHARED_CONFIG['learning_rate'],
    weight_decay=SHARED_CONFIG['weight_decay']
)
# Same schedule as the dense run; T_max is in epochs (presumably stepped once
# per epoch by the trainer - TODO confirm).
scheduler_moe = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer_moe, T_max=SHARED_CONFIG['num_epochs'], eta_min=1e-7
)

trainer_moe = MoETrainer(
    model=moe_model,
    loss_fn=loss_fn_moe,
    optimizer=optimizer_moe,
    device=device,
    scheduler=scheduler_moe
)

# Re-seed right before training: the shuffling train_loader draws its order
# from torch's default RNG, so the MoE run starts from the same RNG state (and
# hence the same first-epoch batch order) as a dense run seeded the same way,
# instead of inheriting whatever state the dense run left behind.
torch.manual_seed(42)

# perf_counter is monotonic and intended for measuring elapsed time.
start_time = time.perf_counter()
history_moe = trainer_moe.train(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epochs=SHARED_CONFIG['num_epochs'],
    eval_every=1,
    save_best=True,
    save_path="../models/moe_comparison.pt"
)
moe_training_time = time.perf_counter() - start_time

print(f"\n✓ MoE model trained in {moe_training_time/60:.2f} minutes")
print(f"  Final train loss: {history_moe['train_loss'][-1]:.4f}")
print(f"  Final val loss: {history_moe['val_loss'][-1]:.4f}")

## 6. Compare Training Curves

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))


def _epoch_axis(values):
    """Return 1-based epoch numbers for a metric list with one entry per epoch."""
    return range(1, len(values) + 1)


# Left panel: per-epoch train and val loss for both models. Each series gets
# its own x-axis, so the plot still works if the runs recorded a different
# number of epochs.
curve_specs = [
    (history_dense['train_loss'], 'b-o', 'Dense Train', 1.0),
    (history_dense['val_loss'], 'b--s', 'Dense Val', 0.7),
    (history_moe['train_loss'], 'r-o', 'MoE Train', 1.0),
    (history_moe['val_loss'], 'r--s', 'MoE Val', 0.7),
]
for values, fmt, label, alpha in curve_specs:
    axes[0].plot(_epoch_axis(values), values, fmt, label=label,
                 linewidth=2, markersize=6, alpha=alpha)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel('Loss', fontsize=12)
axes[0].set_title('Training Curves Comparison', fontsize=14, fontweight='bold')
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Right panel: final-epoch losses side by side
models = ['Dense', 'MoE']
dense_final_val = history_dense['val_loss'][-1]
moe_final_val = history_moe['val_loss'][-1]
train_losses = [history_dense['train_loss'][-1], history_moe['train_loss'][-1]]
val_losses = [dense_final_val, moe_final_val]

x = np.arange(len(models))
width = 0.35

axes[1].bar(x - width/2, train_losses, width, label='Train Loss', alpha=0.8)
axes[1].bar(x + width/2, val_losses, width, label='Val Loss', alpha=0.8)
axes[1].set_xlabel('Model', fontsize=12)
axes[1].set_ylabel('Final Loss', fontsize=12)
axes[1].set_title('Final Loss Comparison', fontsize=14, fontweight='bold')
axes[1].set_xticks(x)
axes[1].set_xticklabels(models)
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Lower final validation loss wins; an exact tie is reported as such rather
# than silently credited to MoE.
if dense_final_val < moe_final_val:
    val_winner = 'Dense'
elif moe_final_val < dense_final_val:
    val_winner = 'MoE'
else:
    val_winner = 'Tie'

print("\nTraining Summary:")
print("="*70)
print(f"Dense:  Train={history_dense['train_loss'][-1]:.4f}, Val={dense_final_val:.4f}, Time={dense_training_time/60:.2f}min")
print(f"MoE:    Train={history_moe['train_loss'][-1]:.4f}, Val={moe_final_val:.4f}, Time={moe_training_time/60:.2f}min")
print(f"\nWinner (lower val loss): {val_winner}")

## 7. Evaluation - Semantic Similarity Test

In [None]:
print("Evaluating both models on semantic similarity...\n")

# Hand-written probe set: five paraphrase pairs (label True) followed by five
# unrelated pairs (label False).
_similar_probe_pairs = [
    ("A man is playing guitar", "A person is playing a musical instrument"),
    ("The dog is running in the park", "A canine is jogging outdoors"),
    ("I love programming", "I enjoy writing code"),
    ("The weather is sunny today", "It's a beautiful day"),
    ("A woman is cooking dinner", "Someone is preparing food"),
]
_dissimilar_probe_pairs = [
    ("A man is playing guitar", "The weather is sunny"),
    ("I love programming", "A dog is running"),
    ("The car is fast", "Pizza is delicious"),
    ("Trees are tall", "I enjoy music"),
    ("The ocean is vast", "Programming is fun"),
]
test_pairs = (
    [(left, right, True) for left, right in _similar_probe_pairs]
    + [(left, right, False) for left, right in _dissimilar_probe_pairs]
)

def evaluate_model(model, name, pairs=None, *, text_tokenizer=None,
                   target_device=None, threshold=0.5, max_print=None):
    """Score sentence pairs with ``model`` and report positive/negative separation.

    Puts ``model`` into eval mode (and leaves it there) and runs without
    gradient tracking.

    Args:
        model: Embedding model called as ``model(input_ids, attention_mask)``;
            may return the embedding tensor directly or a dict holding it
            under ``"embeddings"``.
        name: Label used in the printed report.
        pairs: Iterable of ``(sentence1, sentence2, is_similar)``. Defaults to
            the notebook-level ``test_pairs``.
        text_tokenizer: Object whose ``encode(text, return_tensors="pt")``
            returns a dict with ``"input_ids"`` and ``"attention_mask"``.
            Defaults to the notebook-level ``tokenizer``.
        target_device: Device the encoded inputs are moved to. Defaults to
            the notebook-level ``device``.
        threshold: Cosine similarity used only for the printed ✓/✗ marker
            (similar pairs should score above it, dissimilar ones below).
        max_print: Maximum number of per-pair lines to print; ``None`` prints
            every pair. (The default probe set lists all similar pairs first,
            so printing only a prefix would hide every dissimilar pair.)

    Returns:
        ``(similarities, labels, separation)``: per-pair cosine similarities,
        the ``is_similar`` flags, and mean similarity of similar pairs minus
        that of dissimilar pairs (NaN if either group is empty).
    """
    # Fall back to the notebook globals only when no explicit value is given.
    if pairs is None:
        pairs = test_pairs
    if text_tokenizer is None:
        text_tokenizer = tokenizer
    if target_device is None:
        target_device = device

    def _embed(text):
        """Encode one sentence and return its embedding tensor."""
        enc = text_tokenizer.encode(text, return_tensors="pt")
        out = model(enc['input_ids'].to(target_device), enc['attention_mask'].to(target_device))
        return out['embeddings'] if isinstance(out, dict) else out

    def _clip(text, width=30):
        """Shorten ``text`` for display, adding an ellipsis only if truncated."""
        return text if len(text) <= width else text[:width] + "..."

    def _mean_std(values):
        """Return (mean, std) of ``values``, or (nan, nan) for an empty list."""
        if not values:
            return float('nan'), float('nan')
        return float(np.mean(values)), float(np.std(values))

    model.eval()
    similarities = []
    labels = []

    print(f"\n{name} Model:")
    print("-" * 60)

    with torch.no_grad():
        for idx, (s1, s2, is_similar) in enumerate(pairs):
            sim = torch.nn.functional.cosine_similarity(_embed(s1), _embed(s2)).item()
            similarities.append(sim)
            labels.append(is_similar)

            if max_print is None or idx < max_print:
                correct = sim > threshold if is_similar else sim < threshold
                status = "✓" if correct else "✗"
                print(f"{status} {sim:.3f} | {_clip(s1)} <-> {_clip(s2)}")

    similar_sims = [s for s, is_sim in zip(similarities, labels) if is_sim]
    dissimilar_sims = [s for s, is_sim in zip(similarities, labels) if not is_sim]

    sim_mean, sim_std = _mean_std(similar_sims)
    dis_mean, dis_std = _mean_std(dissimilar_sims)
    sep = sim_mean - dis_mean

    print(f"\nSimilar pairs:    {sim_mean:.3f} ± {sim_std:.3f}")
    print(f"Dissimilar pairs: {dis_mean:.3f} ± {dis_std:.3f}")
    print(f"Separation:       {sep:.3f}")

    return similarities, labels, sep

# Run the identical probe set through both models.
dense_sims, dense_labels, dense_sep = evaluate_model(dense_model, "Dense")
moe_sims, moe_labels, moe_sep = evaluate_model(moe_model, "MoE")

_heavy_rule = "=" * 70
print(f"\n{_heavy_rule}\nSIMILARITY COMPARISON\n{_heavy_rule}")
print(f"Dense separation: {dense_sep:.3f}")
print(f"MoE separation:   {moe_sep:.3f}")
_separation_leader = "Dense" if dense_sep > moe_sep else "MoE"
print(f"\nBetter separation (higher is better): {_separation_leader}")

## 8. Expert Usage Analysis (MoE Only)

In [None]:
print("Analyzing MoE expert usage...\n")

# Number of layers to inspect comes from the MoE config rather than a
# hard-coded 6, so the analysis follows config changes.
num_moe_layers = MOE_CONFIG['num_layers']


def _usage_to_numpy(usage):
    """Convert one layer's per-expert usage vector to a float64 numpy array."""
    # NOTE(review): the return type of get_expert_usage_stats isn't visible
    # here; tensors (possibly on MPS/CUDA, possibly integer counts) are moved
    # to CPU and cast so .mean()/.std() and matplotlib work - confirm.
    if isinstance(usage, torch.Tensor):
        usage = usage.detach().cpu().numpy()
    return np.asarray(usage, dtype=np.float64)


# Route a small sample batch through the MoE in inference mode: eval() turns
# off dropout and no_grad avoids building an autograd graph.
moe_model.eval()
sample_batch = next(iter(train_loader))
with torch.no_grad():
    usage_stats = moe_model.get_expert_usage_stats(
        sample_batch['input_ids_1'][:16].to(device),
        sample_batch['attention_mask_1'][:16].to(device)
    )
usage_by_layer = [_usage_to_numpy(usage_stats[f'layer_{i}']) for i in range(num_moe_layers)]

# Plot usage per layer: three panels per row, unused panels hidden.
ncols = 3
nrows = -(-num_moe_layers // ncols)  # ceiling division
fig, axes = plt.subplots(nrows, ncols, figsize=(5 * ncols, 4 * nrows), squeeze=False)
axes = axes.flatten()

for layer_idx, usage in enumerate(usage_by_layer):
    ax = axes[layer_idx]
    ax.bar(range(len(usage)), usage, alpha=0.7, color='steelblue')
    ax.set_xlabel('Expert ID')
    ax.set_ylabel('Usage Count')
    ax.set_title(f'Layer {layer_idx} Expert Usage')
    ax.grid(True, alpha=0.3, axis='y')

    # Add mean line
    mean_usage = usage.mean()
    ax.axhline(mean_usage, color='red', linestyle='--', alpha=0.5, label=f'Mean: {mean_usage:.1f}')
    ax.legend()

for ax in axes[num_moe_layers:]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()

# Check load balancing
print("\nLoad Balancing Analysis:")
print("="*70)
for layer_idx, usage in enumerate(usage_by_layer):
    # Population std (numpy default, ddof=0) across experts; CV = std / mean.
    std = usage.std()
    mean = usage.mean()
    cv = (std / mean) if mean > 0 else 0
    print(f"Layer {layer_idx}: Mean={mean:.1f}, Std={std:.1f}, CV={cv:.3f}")

print("\n✓ Low CV (coefficient of variation) indicates good load balancing")

## 9. Final Comparison Summary

In [None]:
def _print_section(title):
    """Print a section heading preceded by a blank line and followed by a dashed rule."""
    print(f"\n{title}")
    print("-" * 70)


_double_rule = "=" * 70
print(_double_rule)
print("FINAL COMPARISON RESULTS")
print(_double_rule)

# 1. Parameter budgets
_print_section("1. MODEL CONFIGURATION")
_moe_total = moe_param_stats['total']
_moe_active = moe_param_stats['active']
print(f"Dense:  {dense_params:,} total params, {dense_params:,} active")
print(f"MoE:    {_moe_total:,} total params, {_moe_active:,} active")
print(f"Match:  {abs(dense_params - _moe_active) / dense_params * 100:.1f}% difference")

# 2. Loss at the last epoch; the "training" winner is decided on val loss.
_print_section("2. TRAINING PERFORMANCE")
_dense_train, _dense_val = history_dense['train_loss'][-1], history_dense['val_loss'][-1]
_moe_train, _moe_val = history_moe['train_loss'][-1], history_moe['val_loss'][-1]
print(f"Dense:  Final train loss = {_dense_train:.4f}, val loss = {_dense_val:.4f}")
print(f"MoE:    Final train loss = {_moe_train:.4f}, val loss = {_moe_val:.4f}")
winner_train = "Dense" if _dense_val < _moe_val else "MoE"
improvement = abs(_dense_val - _moe_val)
print(f"Winner: {winner_train} (by {improvement:.4f})")

# 3. Probe-set separation (higher wins)
_print_section("3. SEMANTIC SIMILARITY")
print(f"Dense:  Separation = {dense_sep:.3f}")
print(f"MoE:    Separation = {moe_sep:.3f}")
winner_sim = "Dense" if dense_sep > moe_sep else "MoE"
improvement_sim = abs(dense_sep - moe_sep)
print(f"Winner: {winner_sim} (by {improvement_sim:.3f})")

# 4. Wall-clock training time
_print_section("4. TRAINING TIME")
print(f"Dense:  {dense_training_time/60:.2f} minutes")
print(f"MoE:    {moe_training_time/60:.2f} minutes")
print(f"Difference: {abs(dense_training_time - moe_training_time)/60:.2f} minutes")

# 5. Verdict
_print_section("5. CONCLUSION")
if winner_train != winner_sim:
    print("⚠ Mixed results:")
    print(f"  Training: {winner_train} is better")
    print(f"  Similarity: {winner_sim} is better")
else:
    print(f"✓ Clear winner: {winner_train}")
    print("  Better validation loss AND better semantic similarity")

# 6. MoE-specific figures
_print_section("6. MoE BENEFITS")
print(f"  Sparsity: {moe_param_stats['sparsity']*100:.1f}% of experts inactive per token")
print(f"  Total capacity: {_moe_total/dense_params:.2f}x parameters for same compute")
print("  Specialization: Experts can learn different linguistic patterns")

print("\n" + _double_rule)

## Summary

### Key Findings:

1. **Fair Comparison**: Both models have matched active FFN compute per token (exact parameter counts are printed in Section 3)
2. **Training**: Check which model achieves lower validation loss
3. **Similarity**: Check which model has better semantic separation
4. **MoE Benefits**: More total capacity with same compute

### Next Steps:

1. If MoE wins: Experts are helping! Try more experts or different top-k
2. If Dense wins: MoE needs tuning - try different expert sizes or load balancing weights
3. Either way: This is a controlled experiment showing the effect of MoE architecture

### Technical Notes:

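- MoE has 4x the expert (FFN) parameters of the dense model, but only 2 of 8 experts (25% of the expert parameters) are active per token
- The auxiliary load-balancing loss encourages equal expert usage
- Expert specialization can emerge during training (check the Section 8 usage plots)
- The architecture can grow total capacity without increasing per-token FFN compute, but this small experiment alone does not establish production readiness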