# Adaptive Transformer Inference Demo

This notebook demonstrates how to use the Adaptive Transformer with Sentinel Gates for text generation. We'll load a pre-trained model, generate text, and visualize the attention patterns and gate values during generation.

In [None]:
# Standard library
import sys
import os
# Third-party: deep learning, plotting, numerics, tokenization
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from transformers import AutoTokenizer

# Add the parent directory to the path to import modules
# NOTE(review): '..' is resolved against the kernel's working directory, so this
# assumes the notebook is launched from its own folder — TODO confirm.
sys.path.append('..')

## 1. Load Models

We'll load both the baseline model (GPT-2) and our adaptive model for comparison.

In [None]:
from models.loaders.loader import load_baseline_model, load_adaptive_model
from utils.checkpoint import load_checkpoint

# Seed torch's RNG so the sampled generations below are reproducible when the
# notebook is re-run from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load tokenizer and baseline model
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
baseline_model = load_baseline_model(model_name, device)

# Load adaptive model. It receives the baseline model — presumably to
# initialize its weights from GPT-2; confirm in models/loaders/loader.py.
adaptive_model = load_adaptive_model(model_name, baseline_model, device)

# Path to checkpoint (update this with your checkpoint path)
checkpoint_path = "../checkpoints/adaptive_model.pth"

# Try to load checkpoint if it exists
if os.path.exists(checkpoint_path):
    # load_checkpoint's signature requires an optimizer and a head-LR-multiplier
    # dict (presumably to restore training state); both are throwaways here
    # because this notebook only runs inference.
    optimizer = torch.optim.AdamW(adaptive_model.parameters())
    head_lr_multipliers = {}
    adaptive_model, _, _, _, _ = load_checkpoint(
        adaptive_model, optimizer, head_lr_multipliers, checkpoint_path, device)
    print(f"Loaded checkpoint from {checkpoint_path}")
else:
    print(f"No checkpoint found at {checkpoint_path}. Using freshly initialized adaptive model.")

# Inference only: switch both models to eval mode so dropout is disabled and
# the comparison is not affected by training-time behavior.
baseline_model.eval()
adaptive_model.eval()

## 2. Set Up Generation Wrapper

We'll use the GenerationWrapper class to handle text generation with both models.

In [None]:
from utils.generation_wrapper import GenerationWrapper

# The adaptive wrapper is handed the model and tokenizer loaded above.
adaptive_wrapper = GenerationWrapper(
    model=adaptive_model,
    tokenizer=tokenizer,
    device=device,
)
# The baseline wrapper is built from the model name alone.
# NOTE(review): this presumably loads a second copy of GPT-2 rather than reusing
# `baseline_model`; confirm whether the wrapper can accept it directly.
baseline_wrapper = GenerationWrapper(
    model_name=model_name,
    device=device,
)

## 3. Generate Text

Let's generate text with both models and compare the outputs.

In [None]:
# Prompts used for the side-by-side comparison.
prompts = [
    "Once upon a time in a land far away",
    "The scientist made a discovery that would change",
    "The future of artificial intelligence depends on",
]

# Sampling settings; later generation cells reuse this dict as well.
generation_params = dict(
    max_length=100,
    temperature=0.8,
    top_k=50,
    top_p=0.95,
    do_sample=True,
    num_return_sequences=1,
)

# For each prompt, show the adaptive model's continuation(s), then the baseline's.
for prompt in prompts:
    print(f"Prompt: {prompt}\n")

    print("Adaptive Model:")
    adaptive_outputs = adaptive_wrapper.generate_text(prompt, **generation_params)
    for text in adaptive_outputs:
        print(f"{text}\n")

    print("Baseline Model:")
    baseline_outputs = baseline_wrapper.generate_text(prompt, **generation_params)
    for text in baseline_outputs:
        print(f"{text}\n")

    print("-" * 80)

## 4. Visualize Attention Patterns

Now, let's generate text while visualizing the attention patterns.

In [None]:
import glob
import re

from IPython.display import Image, display

# Generate with attention visualization
prompt = "The key to understanding artificial intelligence is"

# generation_params already contains "max_length"; passing max_length=... next to
# **generation_params raises TypeError (multiple values for keyword argument
# 'max_length'). Merge the override into a single dict instead.
attention_params = {**generation_params, "max_length": 50}  # shorter for clearer visualization
adaptive_outputs = adaptive_wrapper.generate_text(
    prompt,
    visualize_attention=True,
    **attention_params,
)

print(f"Generated text:\n{adaptive_outputs[0]}")


def _attention_sort_key(path):
    """Sort attention_step_<N>*.png files by numeric step N (unparsable names last)."""
    match = re.search(r"attention_step_(\d+)", os.path.basename(path))
    step = int(match.group(1)) if match else float("inf")
    return (step, path)


# The visualizations are saved as PNG files in the current directory.
# A plain lexicographic sort would put step 10 before step 2, so sort numerically.
# NOTE: files left over from earlier runs also match this pattern.
attention_files = sorted(glob.glob("attention_step_*.png"), key=_attention_sort_key)

if not attention_files:
    print("No attention_step_*.png files found in the current directory.")

# Display the first 3 attention maps or all if fewer than 3
for map_number, path in enumerate(attention_files[:3], start=1):
    print(f"Attention map {map_number}:")
    display(Image(path))

## 5. Visualize Gate Values

Let's track how gate values change during generation.

In [None]:
from IPython.display import Image, display

# Generate with gate value tracking
prompt = "The future of language models will be determined by"

# generation_params already contains "max_length"; passing max_length=... next to
# **generation_params raises TypeError (multiple values for keyword argument
# 'max_length'). Merge the override into a single dict instead.
gate_tracking_params = {**generation_params, "max_length": 50}  # shorter for clearer visualization
adaptive_outputs = adaptive_wrapper.generate_text(
    prompt,
    track_gate_values=True,
    **gate_tracking_params,
)

print(f"Generated text:\n{adaptive_outputs[0]}")

# The gate dynamics visualization is expected at gate_dynamics.png; only display
# it if it was actually written, so a missing plot produces a clear message.
gate_plot_path = "gate_dynamics.png"
if os.path.exists(gate_plot_path):
    display(Image(gate_plot_path))
else:
    print(f"{gate_plot_path} not found; no gate dynamics plot was saved.")

## 6. Analyze Head Utilization

Let's analyze which heads are most active in our model.

In [None]:
def get_head_utilization(model, threshold=0.1, plot=True):
    """Calculate and (optionally) visualize attention-head utilization.

    Reads the sentinel gate of every block's attention module and counts a
    head as active when its gate value is strictly greater than ``threshold``.

    Args:
        model: Adaptive model exposing ``model.blocks``; each block is indexed
            with ``"attn"`` and that module must have a ``gate`` tensor
            (assumed one value per head — TODO confirm).
        threshold: Gate value above which a head counts as active. Defaults to
            0.1, the cut-off the original analysis used.
        plot: If True, draw an annotated heatmap of the gate matrix.

    Returns:
        dict with ``active_heads`` (int), ``total_heads`` (int),
        ``utilization_percent`` (float) and ``gate_matrix`` (float64 numpy
        array of shape (num_layers, num_heads)).

    Raises:
        ValueError: if the model has no blocks.
    """
    blocks = list(model.blocks)
    if not blocks:
        raise ValueError("model has no blocks; cannot compute head utilization")

    # Cast on CPU to float64 before .numpy(): numpy has no bfloat16, so a
    # bf16 gate would otherwise make .numpy() fail.
    gate_matrix = np.stack([
        block["attn"].gate.detach().cpu().double().numpy()
        for block in blocks
    ])
    num_layers, num_heads = gate_matrix.shape

    # Overall utilization (guard against a model with zero heads per layer).
    active_heads = int((gate_matrix > threshold).sum())
    total_heads = num_layers * num_heads
    utilization = active_heads / total_heads * 100 if total_heads else 0.0

    if plot:
        fig, ax = plt.subplots(figsize=(12, 8))
        sns.heatmap(gate_matrix, cmap="viridis", annot=True, fmt=".2f", ax=ax)
        ax.set_title(f"Attention Gate Values (Utilization: {utilization:.2f}%)")
        ax.set_xlabel("Head Index")
        ax.set_ylabel("Layer Index")
        fig.tight_layout()
        plt.show()

    return {
        "active_heads": active_heads,
        "total_heads": total_heads,
        "utilization_percent": float(utilization),
        "gate_matrix": gate_matrix,
    }

# Measure which heads are active according to their sentinel gate values,
# then summarize the counts.
utilization_stats = get_head_utilization(adaptive_model)

n_active = utilization_stats['active_heads']
n_total = utilization_stats['total_heads']
pct_active = utilization_stats['utilization_percent']
print(f"Active heads: {n_active} / {n_total} ({pct_active:.2f}%)")

## 7. Compare Generation Performance

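Let's benchmark the generation speed of both models. Throughput is approximated as `max_length / average wall-clock time`. It therefore assumes every run generates the full `max_length` tokens, so treat it as a rough, comparable estimate rather than an exact token rate.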

In [None]:
import time


def benchmark_generation(wrapper, prompt, max_length, num_runs=5, warmup_runs=1):
    """Benchmark text-generation speed of a generation wrapper.

    Calls ``wrapper.generate_text`` with fixed sampling settings
    (temperature 0.7, sampling on, one sequence). It first runs ``warmup_runs``
    untimed calls, then ``num_runs`` timed calls.

    Args:
        wrapper: Object exposing ``generate_text(prompt, **kwargs)``.
        prompt: Prompt text passed to every call.
        max_length: ``max_length`` passed to ``generate_text``. It is also used as
            the token count for the throughput estimate.
        num_runs: Number of timed runs; must be >= 1.
        warmup_runs: Number of untimed runs done first. Defaults to 1.

    Returns:
        dict with ``avg_time_seconds``, ``tokens_per_second`` and the per-run
        ``times`` list (seconds).

    Raises:
        ValueError: if ``num_runs`` < 1 (the average would be undefined).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    def _generate():
        return wrapper.generate_text(
            prompt,
            max_length=max_length,
            temperature=0.7,
            do_sample=True,
            num_return_sequences=1,
        )

    # The first call typically pays one-off costs (device/kernel setup, memory
    # allocation). Keep those out of the timed average.
    for _ in range(warmup_runs):
        _generate()

    times = []
    for _ in range(num_runs):
        # perf_counter is the monotonic, high-resolution clock meant for timing.
        # time.time() can jump if the system clock is adjusted.
        start_time = time.perf_counter()
        _generate()
        times.append(time.perf_counter() - start_time)

    avg_time = sum(times) / len(times)
    # Approximation: assumes every run produces the full max_length tokens.
    # max_length may include the prompt, and generation may stop early at EOS,
    # so this estimate is an upper bound.
    tokens_per_second = max_length / avg_time if avg_time > 0 else float("inf")

    return {
        "avg_time_seconds": avg_time,
        "tokens_per_second": tokens_per_second,
        "times": times,
    }

# Benchmark settings shared by both models.
prompt = "The future of artificial intelligence is"
max_length = 200
num_runs = 3

# Time both models on the same prompt and generation length.
print("Benchmarking adaptive model...")
adaptive_benchmark = benchmark_generation(adaptive_wrapper, prompt, max_length, num_runs)
print("Benchmarking baseline model...")
baseline_benchmark = benchmark_generation(baseline_wrapper, prompt, max_length, num_runs)

adaptive_tps = adaptive_benchmark['tokens_per_second']
baseline_tps = baseline_benchmark['tokens_per_second']

print("Generation benchmark results:")
print(f"Adaptive model: {adaptive_tps:.2f} tokens/sec")
print(f"Baseline model: {baseline_tps:.2f} tokens/sec")

# Relative throughput change in percent; positive means the adaptive model is faster.
speedup = (adaptive_tps / baseline_tps - 1) * 100
direction = 'Speedup' if speedup > 0 else 'Slowdown'
print(f"{direction}: {abs(speedup):.2f}%")

# Bar chart of the two throughputs.
fig, ax = plt.subplots(figsize=(10, 6))
ax.bar(["Adaptive Model", "Baseline Model"], [adaptive_tps, baseline_tps])
ax.set_ylabel("Tokens per Second")
ax.set_title("Generation Speed Comparison")
ax.grid(axis='y')
plt.show()

## 8. Parameter Efficiency

Finally, let's examine the parameter efficiency of our adaptive model.

In [None]:
def count_parameters(model, gate_threshold=0.1):
    """Count parameters in a model, distinguishing between active and inactive heads.

    Every model gets total, trainable (``requires_grad``) and frozen counts.
    A model is treated as adaptive if it has a ``blocks`` attribute. For such
    models, attention-head parameters are also split by whether the head's
    sentinel gate is strictly greater than ``gate_threshold``.

    Args:
        model: A ``torch.nn.Module``. For adaptive models, each entry of
            ``model.blocks`` is indexed with ``"attn"``. That attention module
            must expose ``num_heads``, per-head-indexable ``W_q``/``W_k``/
            ``W_v``/``W_o`` submodules and a ``gate`` indexable by head.
        gate_threshold: Gate value above which a head counts as active.
            Defaults to 0.1, matching the head-utilization analysis.

    Returns:
        dict with ``total_params``, ``trainable_params`` and ``frozen_params``.
        For adaptive models it also has ``active_heads``,
        ``active_head_params``, ``inactive_head_params`` and
        ``non_head_params`` (every parameter outside the per-head projections).
    """
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    counts = {
        "total_params": int(total_params),
        "trainable_params": int(trainable_params),
        "frozen_params": int(total_params - trainable_params),
    }

    if not hasattr(model, "blocks"):
        # Standard HuggingFace model: no per-head gates to inspect.
        return counts

    active_heads = 0
    active_head_params = 0
    inactive_head_params = 0

    for block in model.blocks:
        attn_module = block["attn"]
        projections = (attn_module.W_q, attn_module.W_k, attn_module.W_v, attn_module.W_o)

        for head_idx in range(attn_module.num_heads):
            # Parameters owned by this head's Q/K/V/O projections.
            head_params = sum(
                param.numel()
                for projection in projections
                for param in projection[head_idx].parameters()
            )

            if attn_module.gate[head_idx].item() > gate_threshold:
                active_heads += 1
                active_head_params += head_params
            else:
                inactive_head_params += head_params

    counts.update(
        active_heads=int(active_heads),
        active_head_params=int(active_head_params),
        inactive_head_params=int(inactive_head_params),
        non_head_params=int(total_params - (active_head_params + inactive_head_params)),
    )
    return counts

# Count parameters for both models
adaptive_params = count_parameters(adaptive_model)
baseline_params = count_parameters(baseline_model)


def _format_count(value):
    """Format an int with thousands separators; pass any other value through as str.

    The per-head keys fall back to the string 'N/A' when absent, and
    format('N/A', ',') raises ValueError, so the ',' spec is only applied to ints.
    """
    return f"{value:,}" if isinstance(value, int) else str(value)


# Print parameter counts
print("Parameter counts:")
print(f"Baseline model: {baseline_params['total_params']:,} parameters")
print(f"Adaptive model: {adaptive_params['total_params']:,} parameters")
print(f"  Active heads: {adaptive_params.get('active_heads', 'N/A')}")
print(f"  Active head parameters: {_format_count(adaptive_params.get('active_head_params', 'N/A'))}")
print(f"  Inactive head parameters: {_format_count(adaptive_params.get('inactive_head_params', 'N/A'))}")
print(f"  Non-head parameters: {_format_count(adaptive_params.get('non_head_params', 'N/A'))}")

# The efficiency summary and pie chart need the per-head breakdown, which only
# adaptive models (those with `blocks`) provide.
if 'active_head_params' in adaptive_params:
    # Effective size: everything except the parameters of gated-off heads.
    effective_params = adaptive_params['active_head_params'] + adaptive_params['non_head_params']
    reduction = (baseline_params['total_params'] - effective_params) / baseline_params['total_params'] * 100
    print(f"\nEffective parameter reduction: {reduction:.2f}%")

    # Visualize parameter distribution
    fig, ax = plt.subplots(figsize=(10, 6))
    labels = ['Active Head Parameters', 'Inactive Head Parameters', 'Non-Head Parameters']
    sizes = [
        adaptive_params['active_head_params'],
        adaptive_params['inactive_head_params'],
        adaptive_params['non_head_params'],
    ]
    colors = ['#2ca02c', '#d62728', '#1f77b4']

    ax.pie(sizes, labels=labels, colors=colors, autopct='%1.1f%%', startangle=90)
    ax.axis('equal')
    ax.set_title('Adaptive Model Parameter Distribution')
    plt.show()

## 9. Conclusion

In this notebook, we've demonstrated how to use the Adaptive Transformer with Sentinel Gates for text generation. We've shown:

1. Text generation with both the adaptive and baseline models
2. Visualization of attention patterns during generation
3. Tracking of gate values during generation
4. Analysis of head utilization and parameter efficiency
5. Comparison of generation performance

The adaptive model uses sentinel gates to dynamically control which attention heads are active during processing, potentially leading to both parameter efficiency and computational efficiency during inference.