# 🚀 Pure Transformer - Interactive Model Exploration

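This notebook visualizes and analyzes the Pure Transformer architecture. It covers where the parameters live, what a forward pass on a sample input produces, how the preset model sizes compare, and what the cosine learning-rate schedule looks like.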

In [None]:
import torch
import torch.nn as nn
from pure_transformer.model import TransformerLM, TransformerConfig
from pure_transformer.configs import TINY_CONFIG, SMALL_CONFIG, MEDIUM_CONFIG
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

sns.set_style('whitegrid')
%matplotlib inline

## 1. Model Instantiation

Let's create a model and explore its structure.

In [None]:
# Build the smallest preset so every exploration cell below stays cheap to run.
model = TransformerLM(TINY_CONFIG)
n_params = model.count_parameters()
print(f"Model created with {n_params:,} parameters")
print("\nModel configuration:")
print(TINY_CONFIG)

## 2. Parameter Distribution

Visualize where parameters are distributed across the model.

In [None]:
# Tally parameters per module *type* (e.g. Linear, Embedding, RMSNorm).
# Each module contributes only its *direct* parameters (recurse=False), and a
# tensor already counted is skipped. This way every parameter is counted
# exactly once:
#   - tied weights (one tensor shared by two modules) are not double-counted;
#     they are attributed to whichever module is visited first;
#   - parameters registered directly on container modules are not dropped.
param_counts = {}
seen_param_ids = set()
for module in model.modules():
    direct_params = [
        p for p in module.parameters(recurse=False) if id(p) not in seen_param_ids
    ]
    if not direct_params:
        continue
    seen_param_ids.update(id(p) for p in direct_params)
    module_type = type(module).__name__
    param_counts[module_type] = param_counts.get(module_type, 0) + sum(
        p.numel() for p in direct_params
    )

# Percentages are taken against the total actually tallied above, so the
# breakdown always sums to 100%.
total_params = sum(param_counts.values())
ranked = sorted(param_counts.items(), key=lambda kv: kv[1], reverse=True)
labels = [module_type for module_type, _ in ranked]
counts = [count for _, count in ranked]

# Plot, largest contributor first.
fig, ax = plt.subplots(figsize=(12, 6))
colors = plt.cm.viridis(np.linspace(0, 1, len(ranked)))
ax.bar(labels, counts, color=colors)
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.set_ylabel('Number of Parameters')
ax.set_title('Parameter Distribution by Module Type')
fig.tight_layout()
plt.show()

print("\nParameter breakdown:")
for module_type, count in ranked:
    print(f"{module_type:30s}: {count:12,} ({100*count/total_params:.1f}%)")
print(f"{'Total (unique tensors)':30s}: {total_params:12,}")

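## 3. Forward Pass on a Sample Input

Run a random token sequence through the model and inspect the output logits and the top next-token predictions. (Attention-pattern visualization is not implemented yet — **TODO**.)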

In [None]:
# Switch to inference mode: eval() disables dropout (and any other train-only
# behaviour), so the logits below are deterministic for a given input.
model.eval()

SEQ_LEN = 32
# Draw token ids strictly inside the vocabulary. Assumes the config exposes
# `vocab_size` (TODO confirm); falls back to the original bound of 1000.
vocab_limit = min(1000, getattr(TINY_CONFIG, 'vocab_size', 1000))
sample_input = torch.randint(0, vocab_limit, (1, SEQ_LEN))

# Forward pass without building an autograd graph.
with torch.no_grad():
    logits = model(sample_input)

print(f"Input shape: {sample_input.shape}")
print(f"Output shape: {logits.shape}")
# logits[0, -1] = scores for the token following the last input position
# (presumably shape (batch, seq, vocab) — confirm against TransformerLM).
print(f"\nTop 5 predicted tokens: {logits[0, -1].topk(5).indices.tolist()}")

## 4. Model Size Comparison

Compare different model configurations.

In [None]:
from matplotlib.ticker import FuncFormatter

configs = {
    'TINY': TINY_CONFIG,
    'SMALL': SMALL_CONFIG,
    'MEDIUM': MEDIUM_CONFIG,
}

comparison = {}
for name, config in configs.items():
    # Build each model only long enough to count its parameters, then release
    # it so the larger presets don't stay resident in memory.
    temp_model = TransformerLM(config)
    comparison[name] = {
        'Parameters': temp_model.count_parameters(),
        'Layers': config.num_layers,
        'Hidden Size': config.hidden_size,
        'Heads': config.num_heads,
    }
    del temp_model

# Plot one panel per metric.
metrics = ['Parameters', 'Layers', 'Hidden Size', 'Heads']
config_names = list(configs)
palette = ['#4CAF50', '#2196F3', '#FF9800']

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle('Model Configuration Comparison', fontsize=16)

for metric, ax in zip(metrics, axes.flat):
    values = [comparison[name][metric] for name in config_names]
    ax.bar(config_names, values, color=palette)
    ax.set_title(metric)
    if metric == 'Parameters':
        ax.set_ylabel('Parameters (Millions)')
        # A formatter tracks the real tick positions. set_yticklabels() on
        # get_yticks() produces stale labels if ticks move, and int() truncated
        # sub-million ticks to "0M".
        ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _pos: f'{y / 1e6:g}M'))
    else:
        ax.set_ylabel('Value')

plt.tight_layout()
plt.show()

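## 5. Training Metrics Simulation

Plot the cosine learning-rate schedule with linear warmup, alongside a synthetic loss curve. The loss curve is illustrative only; it is not produced by training the model.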

In [None]:
# Hyperparameters for the simulated run: 10,000 steps, sampled every 100th
# step for plotting, with a 1,000-step linear warmup.
warmup_steps = 1000
max_lr, min_lr = 3e-4, 3e-5
steps = np.arange(0, 10_000, 100)

def cosine_schedule(step, warmup_steps, total_steps, max_lr, min_lr):
    """Learning rate at ``step`` for linear warmup followed by cosine decay.

    Args:
        step: Current optimisation step (0-based).
        warmup_steps: Number of steps over which the LR ramps linearly from 0
            up to ``max_lr``.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        max_lr: Peak learning rate, reached at the end of warmup.
        min_lr: Floor learning rate, held at and after ``total_steps``.

    Returns:
        The learning rate for ``step``.
    """
    if step < warmup_steps:
        return max_lr * step / warmup_steps
    # Guard against a zero-length (or negative) decay window.
    decay_steps = max(total_steps - warmup_steps, 1)
    # Clamp so steps beyond total_steps stay at min_lr instead of the cosine
    # swinging back up toward max_lr.
    progress = min((step - warmup_steps) / decay_steps, 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + np.cos(np.pi * progress))

# Evaluate the schedule at every sampled step (10,000 = length of the
# simulated run that `steps` spans).
lrs = [cosine_schedule(s, warmup_steps, 10000, max_lr, min_lr) for s in steps]

fig, (ax_lr, ax_loss) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: learning-rate schedule.
ax_lr.plot(steps, lrs, linewidth=2, color='#667eea')
ax_lr.axvline(warmup_steps, color='red', linestyle='--', alpha=0.5, label='End of Warmup')
ax_lr.set_xlabel('Training Step')
ax_lr.set_ylabel('Learning Rate')
ax_lr.set_title('Learning Rate Schedule (Cosine with Warmup)')
ax_lr.legend()
ax_lr.grid(alpha=0.3)

# Right panel: a *synthetic* loss curve. It decays exponentially from 8 toward 3,
# plus Gaussian noise, and is illustrative only; it does not come from training
# the model. A dedicated seeded generator keeps the noise identical on every
# re-run.
noise_rng = np.random.default_rng(0)
loss = 8 - 5 * (1 - np.exp(-steps / 2000)) + noise_rng.normal(0, 0.1, len(steps))
ax_loss.plot(steps, loss, linewidth=2, color='#764ba2', alpha=0.7)
ax_loss.set_xlabel('Training Step')
ax_loss.set_ylabel('Loss')
ax_loss.set_title('Simulated Training Loss Curve')
ax_loss.grid(alpha=0.3)

fig.tight_layout()
plt.show()

## 6. Component Interaction Map

Print a text tree showing how the model's components are composed and which sub-components each one depends on.

In [None]:
print("Component Dependencies:")
print("\nTransformerLM depends on:")
print("  ├─ Token Embeddings")
print("  ├─ RoPE Cache (precompute_rope_cache)")
print("  ├─ Transformer Blocks (24×)")
print("  │   ├─ RMSNorm")
print("  │   ├─ Attention")
print("  │   │   ├─ Q/K/V Projections")
print("  │   │   ├─ apply_rotary_emb")
print("  │   │   ├─ Flash Attention / SDPA")
print("  │   │   └─ Output Projection")
print("  │   ├─ RMSNorm")
print("  │   └─ SwiGLUMLP")
print("  │       ├─ Gate Projection")
print("  │       ├─ Up Projection")
print("  │       └─ Down Projection")
print("  ├─ Final RMSNorm")
print("  └─ LM Head")