# Lab 2.4.5: Architecture Comparison - SOLUTIONS

Complete solutions for the architecture comparison exercises.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from dataclasses import dataclass
from typing import Dict, List

# Select the compute device: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

## Exercise Solution: Add Another Model

In [None]:
# Extended model configuration, keyed by a short alias. Each entry holds:
#   name         - model identifier string
#   architecture - architecture family label
#   active_ratio - active / total parameters (1.0 when every parameter is active)
MODELS_TO_COMPARE = {
    'mamba-2.8b': dict(
        name='state-spaces/mamba-2.8b-hf',
        architecture='Mamba',  # State Space Models
        active_ratio=1.0,
    ),
    'phi-2': dict(
        name='microsoft/phi-2',
        architecture='Transformer',
        active_ratio=1.0,
    ),
    'qwen-moe': dict(
        name='Qwen/Qwen1.5-MoE-A2.7B',
        architecture='MoE',
        active_ratio=0.19,  # 2.7B / 14.3B
    ),
    # TinyLlama is the extra entry: a smaller model for comparison.
    'tinyllama': dict(
        name='TinyLlama/TinyLlama-1.1B-Chat-v1.0',
        architecture='Transformer',
        active_ratio=1.0,
    ),
}

# List every configured model as "alias: architecture (identifier)".
print('Extended model comparison:')
for name, config in MODELS_TO_COMPARE.items():
    print(f'  {name}: {config["architecture"]} ({config["name"]})')

In [None]:
# Simulated benchmark results for visualization
# (Replace with actual benchmarks when running with models)

@dataclass
class BenchmarkResult:
    """Benchmark figures recorded for a single model.

    The two dictionary fields are keyed by context length.
    """
    model_name: str
    architecture: str
    total_params_b: float                # total parameters, in billions
    active_params_b: float               # active parameters, in billions
    memory_gb: float                     # equals 2 x total_params_b in the simulated rows below
    perplexity: float                    # lower is better
    tokens_per_second: Dict[int, float]  # context length -> generation speed
    peak_memory: Dict[int, float]        # context length -> peak memory in GB


# Context lengths at which every simulated model was "measured".
_CONTEXT_LENGTHS = (1024, 4096, 16384)

# One row per model:
# (model_name, architecture, total_params_b, active_params_b, memory_gb,
#  perplexity, tokens/second per context length, peak memory per context length)
_SIMULATED_ROWS = [
    ('mamba-2.8b', 'Mamba', 2.8, 2.8, 5.6, 8.2, (45, 42, 38), (6.0, 6.2, 6.5)),  # State Space Models
    ('phi-2', 'Transformer', 2.7, 2.7, 5.4, 6.8, (55, 35, 15), (6.0, 12.0, 28.0)),
    ('qwen-moe', 'MoE', 14.3, 2.7, 28.6, 7.5, (40, 32, 18), (30.0, 36.0, 48.0)),
    ('tinyllama', 'Transformer', 1.1, 1.1, 2.2, 9.5, (80, 60, 25), (2.5, 5.0, 12.0)),
]

# Simulated results
results = [
    BenchmarkResult(
        model_name=alias,
        architecture=family,
        total_params_b=total_b,
        active_params_b=active_b,
        memory_gb=weights_gb,
        perplexity=ppl,
        tokens_per_second=dict(zip(_CONTEXT_LENGTHS, speeds)),
        peak_memory=dict(zip(_CONTEXT_LENGTHS, peaks)),
    )
    for alias, family, total_b, active_b, weights_gb, ppl, speeds, peaks in _SIMULATED_ROWS
]

In [None]:
# Visualize extended comparison
DGX_SPARK_MEMORY_GB = 128             # height of the 'DGX Spark' reference line on the memory panel
FALLBACK_COLOR = '#9B59B6'            # used for any architecture missing from `colors`
LINE_STYLES = ('-', '--', ':', '-.')  # cycled for models that share an architecture

# Architecture colors: Mamba (State Space Models), Transformer, MoE
colors = {'Mamba': '#27AE60', 'Transformer': '#E74C3C', 'MoE': '#3498DB'}
# Not used by the plots below, which read the context lengths from each result.
contexts = [1024, 4096, 16384]


def _plot_metric_by_context(ax, results, metric, marker, ylabel, title):
    """Draw one line per model for a per-context-length metric.

    Args:
        ax: Matplotlib axes to draw on.
        results: Iterable of BenchmarkResult objects.
        metric: Name of the dict attribute to plot
            ('tokens_per_second' or 'peak_memory').
        marker: Matplotlib marker code for the data points.
        ylabel: Y-axis label.
        title: Panel title.

    The legend is left to the caller so extra artists (for example a
    reference line) can be added before it is built.
    """
    models_per_arch = {}
    for r in results:
        series = getattr(r, metric)
        ctx = sorted(series)
        # Models of the same architecture share a color, so each additional
        # one gets the next line style to stay distinguishable in the legend.
        rank = models_per_arch.get(r.architecture, 0)
        models_per_arch[r.architecture] = rank + 1
        ax.plot(ctx, [series[c] for c in ctx],
                marker=marker,
                linestyle=LINE_STYLES[rank % len(LINE_STYLES)],
                label=r.model_name,
                color=colors.get(r.architecture, FALLBACK_COLOR),
                linewidth=2)
    ax.set_xlabel('Context Length')
    ax.set_ylabel(ylabel)
    ax.set_title(title, fontweight='bold')
    ax.set_xscale('log', base=2)
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Speed comparison
ax = axes[0, 0]
_plot_metric_by_context(ax, results, 'tokens_per_second', marker='o',
                        ylabel='Tokens/Second', title='Generation Speed')
ax.legend()

# Memory comparison
ax = axes[0, 1]
_plot_metric_by_context(ax, results, 'peak_memory', marker='s',
                        ylabel='Peak Memory (GB)', title='Memory Usage')
ax.axhline(y=DGX_SPARK_MEMORY_GB, color='gray', linestyle='--', label='DGX Spark')
ax.legend()  # built after the reference line so it appears in the legend

# Parameter efficiency
ax = axes[1, 0]
models = [r.model_name for r in results]
x = np.arange(len(models))
width = 0.35
ax.bar(x - width/2, [r.total_params_b for r in results], width, label='Total', color='#3498DB')
ax.bar(x + width/2, [r.active_params_b for r in results], width, label='Active', color='#27AE60')
ax.set_ylabel('Parameters (Billions)')
ax.set_title('Total vs Active Parameters', fontweight='bold')
ax.set_xticks(x)
ax.set_xticklabels(models, rotation=15)
ax.legend()

# Perplexity
ax = axes[1, 1]
ax.bar(x, [r.perplexity for r in results],
       color=[colors.get(r.architecture, FALLBACK_COLOR) for r in results])
ax.set_ylabel('Perplexity (lower = better)')
ax.set_title('Model Quality', fontweight='bold')
# Fix the tick positions before labelling them: calling set_xticklabels on
# its own triggers a Matplotlib warning and can misplace the labels.
ax.set_xticks(x)
ax.set_xticklabels(models, rotation=15)

plt.tight_layout()
plt.show()

# Summary table
print('\n Summary Table:')
print('=' * 80)
for r in results:
    print(f'{r.model_name:<15} | {r.architecture:<12} | '
          f'{r.total_params_b:.1f}B total | {r.active_params_b:.1f}B active | '
          f'PPL: {r.perplexity:.1f}')