# Mixture-of-Recursions (MoR) Implementation Demo

This notebook demonstrates the implementation of the Mixture-of-Recursions model as described in the paper:
"Mixture-of-Recursions: Learning Dynamic Recursive Depths for Adaptive Token-Level Computation"

## Key Features:
- Recursive transformer layers with parameter sharing
- Adaptive token-level computation via routing
- Selective attention and KV caching
- Efficiency optimizations
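
The first two features can be illustrated with a toy sketch: one shared function is applied repeatedly, and each token stops at its own depth. This is not the repository's implementation. `shared_block`, `recurse_tokens`, and the hard-coded depths are hypothetical stand-ins for a transformer block and a router's decisions.

```python
def shared_block(x: float) -> float:
    """Stand-in for the shared block: the same 'weights' are reused at every step."""
    return 0.5 * x + 1.0


def recurse_tokens(values, depths):
    """Apply shared_block to each token value as many times as its assigned depth."""
    outputs = []
    for value, depth in zip(values, depths):
        for _ in range(depth):
            value = shared_block(value)
        outputs.append(value)
    return outputs


token_values = [0.0, 2.0, 4.0]
token_depths = [1, 2, 3]  # hypothetical router decisions, one per token
result = recurse_tokens(token_values, token_depths)
print(result)  # [1.0, 2.0, 2.25]
```

Because every step calls the same function, adding recursion steps adds computation but no parameters.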

In [None]:
# Import necessary libraries
import sys
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from transformers import AutoTokenizer

# Add src to path
sys.path.append('../src')

from models.mor_model import MixtureOfRecursions, MoRConfig
from utils.config import Config

# Set up plotting
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")

## 1. Model Configuration and Initialization
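
Before building the model, it can help to sanity-check the configuration values. The sketch below uses a hypothetical `DemoConfig` dataclass rather than the repository's `MoRConfig`, whose validation logic is not shown here. The constraints checked (head divisibility, and a minimum depth of at least 1) are assumptions about what a configuration like this would need.

```python
from dataclasses import dataclass


@dataclass
class DemoConfig:
    """Hypothetical stand-in for MoRConfig, holding only the fields checked here."""
    hidden_size: int = 512
    num_attention_heads: int = 8
    min_recursion_depth: int = 1
    max_recursion_depth: int = 4


def check_config(cfg: DemoConfig) -> int:
    """Validate basic constraints and return the per-head dimension."""
    if cfg.hidden_size % cfg.num_attention_heads != 0:
        raise ValueError("hidden_size must be divisible by num_attention_heads")
    if not 1 <= cfg.min_recursion_depth <= cfg.max_recursion_depth:
        raise ValueError("need 1 <= min_recursion_depth <= max_recursion_depth")
    return cfg.hidden_size // cfg.num_attention_heads


head_dim = check_config(DemoConfig())
print(head_dim)  # 64
```

With the demo values used in this notebook, each of the 8 attention heads works on 64 dimensions.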

In [None]:
# Create model configuration
config = MoRConfig(
    vocab_size=50257,  # GPT-2 vocabulary size
    hidden_size=512,   # Smaller for demo
    num_attention_heads=8,
    num_hidden_layers=6,
    max_recursion_depth=4,
    min_recursion_depth=1,
    use_kv_sharing=True,  # Enable KV sharing variant
    router_hidden_size=128
)

print("Model Configuration:")
for key, value in config.__dict__.items():
    print(f"  {key}: {value}")

In [None]:
# Initialize model and tokenizer
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = MixtureOfRecursions(config).to(device)
tokenizer = AutoTokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# Model statistics
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Statistics:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size (fp32): ~{total_params * 4 / 1024**2:.1f} MB")  # assumes 4 bytes per parameter

## 2. Forward Pass Demonstration
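
The forward pass in this section returns `recursion_depths`, one integer per token. One way a router could produce such depths is top-1 selection over per-depth scores: apply a softmax to the scores and take the most probable depth. The sketch below assumes that scheme. The function names and the example logits are hypothetical, and the repository's router may work differently.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def choose_depth(depth_logits, min_depth=1):
    """Pick the most probable depth; index 0 corresponds to min_depth."""
    probs = softmax(depth_logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return min_depth + best, probs[best]


# Hypothetical router scores: one row per token, one column per depth (1 to 4)
router_logits = [
    [2.0, 0.1, -1.0, 0.0],
    [0.0, 0.0, 3.0, 1.0],
]
depths = [choose_depth(row)[0] for row in router_logits]
print(depths)  # [1, 3]
```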

In [None]:
# Test input
test_text = "The quick brown fox jumps over the lazy dog. This sentence demonstrates adaptive computation."
inputs = tokenizer(test_text, return_tensors="pt", padding=True, truncation=True, max_length=32)
input_ids = inputs["input_ids"].to(device)
attention_mask = inputs["attention_mask"].to(device)

print(f"Input text: {test_text}")
print(f"Input shape: {input_ids.shape}")
print(f"Tokens: {tokenizer.convert_ids_to_tokens(input_ids[0])}")

In [None]:
# Forward pass
model.eval()
with torch.no_grad():
    outputs = model(input_ids=input_ids, attention_mask=attention_mask)

logits = outputs["logits"]
recursion_depths = outputs["recursion_depths"]
router_loss = outputs["router_loss"]

print(f"Output logits shape: {logits.shape}")
print(f"Recursion depths shape: {recursion_depths.shape}")
print(f"Router loss: {router_loss.item():.4f}")

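# Show recursion depths per token, skipping padding positions via the attention mask.
# pad_token is set to eos_token above, so comparing token strings would also hide real EOS tokens.
tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
depths = recursion_depths[0].cpu().numpy()
mask = attention_mask[0].cpu().numpy()

print("\nToken-level recursion depths:")
for token, depth, keep in zip(tokens, depths, mask):
    if keep:
        print(f"  {token:15} -> depth {depth}")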

## 3. Recursion Depth Analysis

In [None]:
# Analyze recursion depth patterns across different inputs
test_sentences = [
    "Simple sentence.",
    "This is a more complex sentence with multiple clauses and sophisticated vocabulary.",
    "The quick brown fox jumps.",
    "In the realm of artificial intelligence, the development of sophisticated language models represents a significant milestone.",
    "Hello world!"
]

depth_analysis = []

for sentence in test_sentences:
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=64)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        depths = outputs["recursion_depths"][0].cpu().numpy()
        
        # Only consider non-padding tokens
        valid_depths = depths[attention_mask[0].cpu().numpy() == 1]
        
        depth_analysis.append({
            'sentence': sentence,
            'length': len(sentence.split()),
            'avg_depth': np.mean(valid_depths),
            'max_depth': np.max(valid_depths),
            'min_depth': np.min(valid_depths),
            'depth_std': np.std(valid_depths),
            'depths': valid_depths
        })

# Display analysis
print("Recursion Depth Analysis:")
print("-" * 80)
for analysis in depth_analysis:
    print(f"Sentence: {analysis['sentence'][:50]}{'...' if len(analysis['sentence']) > 50 else ''}")
    print(f"  Length: {analysis['length']} words")
    print(f"  Avg depth: {analysis['avg_depth']:.2f} ± {analysis['depth_std']:.2f}")
    print(f"  Depth range: {analysis['min_depth']} - {analysis['max_depth']}")
    print()

In [None]:
# Visualize recursion depth patterns
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# 1. Average depth vs sentence length
lengths = [a['length'] for a in depth_analysis]
avg_depths = [a['avg_depth'] for a in depth_analysis]

axes[0, 0].scatter(lengths, avg_depths, s=100, alpha=0.7)
axes[0, 0].set_xlabel('Sentence Length (words)')
axes[0, 0].set_ylabel('Average Recursion Depth')
axes[0, 0].set_title('Average Recursion Depth vs Sentence Length')
axes[0, 0].grid(True, alpha=0.3)

# 2. Depth distribution histogram
all_depths = np.concatenate([a['depths'] for a in depth_analysis])
axes[0, 1].hist(all_depths, bins=range(config.min_recursion_depth, config.max_recursion_depth + 2), 
                alpha=0.7, edgecolor='black')
axes[0, 1].set_xlabel('Recursion Depth')
axes[0, 1].set_ylabel('Frequency')
axes[0, 1].set_title('Overall Depth Distribution')
axes[0, 1].set_xticks(range(config.min_recursion_depth, config.max_recursion_depth + 1))

# 3. Depth variance analysis
depth_stds = [a['depth_std'] for a in depth_analysis]
axes[1, 0].bar(range(len(depth_analysis)), depth_stds, alpha=0.7)
axes[1, 0].set_xlabel('Sentence Index')
axes[1, 0].set_ylabel('Depth Standard Deviation')
axes[1, 0].set_title('Depth Variability per Sentence')
axes[1, 0].set_xticks(range(len(depth_analysis)))

# 4. Token-level depth visualization for longest sentence
longest_idx = np.argmax(lengths)
longest_analysis = depth_analysis[longest_idx]
axes[1, 1].plot(longest_analysis['depths'], 'o-', linewidth=2, markersize=6)
axes[1, 1].set_xlabel('Token Position')
axes[1, 1].set_ylabel('Recursion Depth')
axes[1, 1].set_title(f'Token-level Depths: "{longest_analysis["sentence"][:30]}..."')
axes[1, 1].grid(True, alpha=0.3)
axes[1, 1].set_ylim(config.min_recursion_depth - 0.5, config.max_recursion_depth + 0.5)

plt.tight_layout()
plt.show()

## 4. Efficiency Analysis
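
Wall-clock throughput depends on hardware, so a hardware-independent estimate is a useful companion: the number of shared-block applications actually used, as a fraction of running every token to the maximum depth. The helper below is a minimal sketch with hypothetical names. It ignores attention cost, routing overhead, and KV caching.

```python
def relative_block_compute(depths, max_depth):
    """Shared-block applications used, as a fraction of running every token to max_depth."""
    if not depths:
        raise ValueError("depths must not be empty")
    if max_depth < 1 or any(not 1 <= d <= max_depth for d in depths):
        raise ValueError("each depth must lie in [1, max_depth]")
    return sum(depths) / (len(depths) * max_depth)


example_depths = [1, 2, 4, 1]  # hypothetical per-token depths
ratio = relative_block_compute(example_depths, max_depth=4)
print(f"{ratio:.2f}")  # 0.50
```

A ratio of 0.50 means the tokens used half the block applications that a fixed depth of 4 would require.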

In [None]:
# Measure throughput and average recursion depth at different sequence lengths
import time

def measure_throughput(model, input_ids, attention_mask, num_runs=10):
    """Return (average seconds per forward pass, tokens per second, last outputs)."""
    use_cuda = input_ids.is_cuda

    with torch.no_grad():
        # Warmup
        for _ in range(3):
            _ = model(input_ids, attention_mask)

        # Wait for queued GPU work so it is not counted in the timed region
        if use_cuda:
            torch.cuda.synchronize()
        start_time = time.perf_counter()

        for _ in range(num_runs):
            outputs = model(input_ids, attention_mask)

        # Wait for the timed GPU work to finish before reading the clock
        if use_cuda:
            torch.cuda.synchronize()
        end_time = time.perf_counter()

    avg_time = (end_time - start_time) / num_runs
    tokens_per_second = (input_ids.shape[0] * input_ids.shape[1]) / avg_time

    return avg_time, tokens_per_second, outputs

# Test with different sequence lengths
sequence_lengths = [64, 128, 256, 512]
throughput_results = []

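# max_position_embeddings is not set explicitly in the config above; skip the check if it is absent
max_positions = getattr(config, "max_position_embeddings", None)

for seq_len in sequence_lengths:
    if max_positions is not None and seq_len > max_positions:
        continue

    # Create random test input directly on the target device
    test_input = torch.randint(0, config.vocab_size, (4, seq_len), device=device)
    test_mask = torch.ones_like(test_input)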
    
    avg_time, throughput, outputs = measure_throughput(model, test_input, test_mask)
    avg_depth = outputs["recursion_depths"].float().mean().item()
    
    throughput_results.append({
        'seq_len': seq_len,
        'avg_time': avg_time,
        'throughput': throughput,
        'avg_depth': avg_depth
    })
    
    print(f"Seq len {seq_len:3d}: {throughput:6.1f} tokens/sec, {avg_time*1000:5.1f}ms, avg depth: {avg_depth:.2f}")

# Visualize throughput results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

seq_lens = [r['seq_len'] for r in throughput_results]
throughputs = [r['throughput'] for r in throughput_results]
avg_depths = [r['avg_depth'] for r in throughput_results]

ax1.plot(seq_lens, throughputs, 'o-', linewidth=2, markersize=8)
ax1.set_xlabel('Sequence Length')
ax1.set_ylabel('Throughput (tokens/sec)')
ax1.set_title('Model Throughput vs Sequence Length')
ax1.grid(True, alpha=0.3)

ax2.plot(seq_lens, avg_depths, 's-', linewidth=2, markersize=8, color='orange')
ax2.set_xlabel('Sequence Length')
ax2.set_ylabel('Average Recursion Depth')
ax2.set_title('Recursion Depth vs Sequence Length')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 5. Training Demonstration (Mini-batch)
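
The training loop in this section shifts logits and labels by one position so that the prediction at position `t` is scored against token `t+1`, and it marks padding with `-100`, the default `ignore_index` of `torch.nn.CrossEntropyLoss`. The pure-Python sketch below shows only that alignment on a made-up sequence. `next_token_targets` is a hypothetical helper, not part of the repository.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def next_token_targets(input_ids, attention_mask):
    """Return the target for each of the first len-1 positions.

    Padding positions are replaced by IGNORE_INDEX, then the labels are
    shifted left by one so that position t is paired with token t+1.
    """
    labels = [
        tok if keep == 1 else IGNORE_INDEX
        for tok, keep in zip(input_ids, attention_mask)
    ]
    return labels[1:]


input_ids = [10, 11, 12, 0, 0]      # made-up token ids; the last two are padding
attention_mask = [1, 1, 1, 0, 0]
targets = next_token_targets(input_ids, attention_mask)
print(targets)  # [11, 12, -100, -100]
```

The last position has no target, which is why the loop drops the final logit with `logits[..., :-1, :]`.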

In [None]:
# Demonstrate training on a small batch
model.train()

# Create training data
training_texts = [
    "The cat sat on the mat.",
    "Machine learning is fascinating.",
    "Recursive models can be efficient.",
    "Attention mechanisms are powerful."
]

# Tokenize training data
train_inputs = tokenizer(training_texts, return_tensors="pt", padding=True, truncation=True, max_length=32)
train_input_ids = train_inputs["input_ids"].to(device)
train_attention_mask = train_inputs["attention_mask"].to(device)

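# Create labels: copy input_ids and mask out padding.
# The one-position shift for next-token prediction is applied when the loss is computed.
labels = train_input_ids.clone()
labels[train_attention_mask == 0] = -100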

# Setup optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training loop
losses = []
router_losses = []
lm_losses = []

print("Training demonstration (10 steps):")
for step in range(10):
    optimizer.zero_grad()
    
    # Forward pass
    outputs = model(train_input_ids, train_attention_mask)
    logits = outputs["logits"]
    router_loss = outputs["router_loss"]
    
    # Compute language modeling loss (CrossEntropyLoss ignores label -100 by default)
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    
    lm_loss = nn.CrossEntropyLoss()(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1)
    )
    
    # Total loss
    total_loss = lm_loss + router_loss
    
    # Backward pass
    total_loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    
    # Store losses
    losses.append(total_loss.item())
    router_losses.append(router_loss.item())
    lm_losses.append(lm_loss.item())
    
    if step % 2 == 0:
        avg_depth = outputs["recursion_depths"].float().mean().item()
        print(f"Step {step:2d}: Total={total_loss.item():.4f}, LM={lm_loss.item():.4f}, Router={router_loss.item():.4f}, Depth={avg_depth:.2f}")

# Plot training losses
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(losses, label='Total Loss', linewidth=2)
ax1.plot(lm_losses, label='LM Loss', linewidth=2)
ax1.plot(router_losses, label='Router Loss', linewidth=2)
ax1.set_xlabel('Training Step')
ax1.set_ylabel('Loss')
ax1.set_title('Training Losses')
ax1.legend()
ax1.grid(True, alpha=0.3)

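# Show final recursion depths for the first sample, excluding padding positions
first_mask = train_attention_mask[0].bool()
final_depths = outputs["recursion_depths"][0][first_mask].cpu().numpy()
ax2.plot(final_depths, 'o-', linewidth=2, markersize=6)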
ax2.set_xlabel('Token Position')
ax2.set_ylabel('Recursion Depth')
ax2.set_title('Final Recursion Depths (First Sample)')
ax2.grid(True, alpha=0.3)
ax2.set_ylim(config.min_recursion_depth - 0.5, config.max_recursion_depth + 0.5)

plt.tight_layout()
plt.show()

## 6. Model Architecture Visualization

In [None]:
# Visualize model architecture and parameter distribution
def analyze_model_parameters(model):
    """Sum parameter counts by top-level component.

    named_parameters() yields each tensor once, so weights shared between
    modules are not double-counted, and parameters registered directly on
    non-leaf modules are included.
    """
    param_stats = {}

    for name, param in model.named_parameters():
        component = name.split('.')[0]
        param_stats[component] = param_stats.get(component, 0) + param.numel()

    return param_stats

param_stats = analyze_model_parameters(model)

# Create visualization
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

# Parameter distribution pie chart
components = list(param_stats.keys())
param_counts = list(param_stats.values())

ax1.pie(param_counts, labels=components, autopct='%1.1f%%', startangle=90)
ax1.set_title('Parameter Distribution by Component')

# Parameter counts bar chart
ax2.bar(components, [p/1000 for p in param_counts])
ax2.set_ylabel('Parameters (thousands)')
ax2.set_title('Parameter Counts by Component')
ax2.tick_params(axis='x', rotation=45)

plt.tight_layout()
plt.show()

# Print detailed statistics
print("\nDetailed Parameter Analysis:")
print("-" * 40)
total = sum(param_counts)
for component, count in sorted(param_stats.items(), key=lambda x: x[1], reverse=True):
    percentage = 100 * count / total
    print(f"{component:20s}: {count:8,} ({percentage:5.1f}%)")
print("-" * 40)
print(f"{'Total':20s}: {total:8,} (100.0%)")

## 7. Next Steps and Conclusions

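This notebook demonstrated the core implementation of the Mixture-of-Recursions model. Key observations:

1. **Adaptive Computation**: The router assigns each token its own recursion depth. The model here is randomly initialized and trained for only 10 steps, so the depths shown illustrate the mechanism rather than learned behavior.
2. **Parameter Efficiency**: Reusing the same layers at every recursion step keeps the parameter count below that of a standard transformer with the same effective depth.
3. **Router Loss**: The router loss is added to the language modeling loss during training and is intended to balance computation across tokens. Ten steps on four sentences are too few to draw conclusions about training stability.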
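
The parameter-efficiency point can be made concrete with rough arithmetic. The sketch below uses the common approximation of about `12 * hidden_size**2` weights per transformer layer (four attention projections plus an MLP with 4x expansion, ignoring biases and normalization). It assumes, for illustration only, that a single layer is reused at every step. How `MixtureOfRecursions` actually groups its shared layers is not shown in this notebook.

```python
def transformer_layer_params(hidden_size: int, mlp_ratio: int = 4) -> int:
    """Approximate weights in one transformer layer, ignoring biases and norms."""
    attention = 4 * hidden_size * hidden_size        # Q, K, V and output projections
    mlp = 2 * mlp_ratio * hidden_size * hidden_size  # up and down projections
    return attention + mlp


hidden_size = 512  # demo value from this notebook
num_layers = 6     # demo value from this notebook

per_layer = transformer_layer_params(hidden_size)
unique_stack = num_layers * per_layer  # six layers, each with its own weights
shared_stack = per_layer               # one layer reused at every step (assumed)

print(f"per layer:    {per_layer:,}")     # 3,145,728
print(f"unique stack: {unique_stack:,}")  # 18,874,368
print(f"shared stack: {shared_stack:,}")  # 3,145,728
```

At this scale the token embedding table (50257 × 512, about 25.7M weights) is larger than the layer stack, so the saving from sharing is a smaller fraction of the total than the layer counts alone suggest.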

### For Production Use:
- Scale up model size and training data
- Implement more sophisticated routing strategies
- Add comprehensive evaluation on standard benchmarks
- Optimize for inference speed with custom CUDA kernels

### Potential Improvements:
- Dynamic vocabulary routing
- Hierarchical recursion patterns
- Multi-modal extensions
- Distillation from larger models