# Arithmetic Intensity & The Roofline Model

---

## What You'll Learn

1. **Arithmetic Intensity** - the workload's ratio of operations to bytes moved, which determines whether you're compute-bound or memory-bound
2. **The Roofline Model** - a visual framework for GPU performance analysis
3. **Where prefill and decode land** on the roofline (and why they're fundamentally different)
4. **Detailed attention arithmetic intensity calculation** step by step
5. **How batching shifts you from memory-bound to compute-bound**
6. **Interactive exploration** of batch size and sequence length effects

---

### The Core Problem

GPUs have two fundamental resources:
- **Compute**: How many operations per second (FLOPS)
- **Memory bandwidth**: How many bytes per second can move between HBM and compute cores

The ratio of these defines the **ops:byte ratio** (also called the "machine balance point"):

$$\text{ops:byte ratio} = \frac{\text{Peak FLOPS}}{\text{Peak Bandwidth}}$$

If your workload needs **more ops per byte** than this ratio, you're **compute-bound** (GPU cores are the bottleneck).  
If your workload needs **fewer ops per byte**, you're **memory-bound** (memory bandwidth is the bottleneck).
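
As a quick illustration, the comparison above can be written as a small helper. This is only a sketch: the function names `machine_balance` and `classify_bound`, and the example device (1000 TFLOPS, 4 TB/s), are hypothetical and chosen to keep the arithmetic easy to follow.

```python
def machine_balance(peak_flops: float, peak_bandwidth: float) -> float:
    """Ops:byte ratio of a device: peak FLOPS divided by peak bytes per second."""
    return peak_flops / peak_bandwidth


def classify_bound(workload_ops_per_byte: float, peak_flops: float, peak_bandwidth: float) -> str:
    """Compare a workload's ops per byte against the device's ops:byte ratio."""
    if workload_ops_per_byte > machine_balance(peak_flops, peak_bandwidth):
        return "compute-bound"
    return "memory-bound"


# Hypothetical device: 1000 TFLOPS peak and 4 TB/s bandwidth -> 250 ops per byte
example_peak = 1000e12
example_bw = 4e12
print(machine_balance(example_peak, example_bw))         # 250.0
print(classify_bound(1.0, example_peak, example_bw))     # memory-bound
print(classify_bound(600.0, example_peak, example_bw))   # compute-bound
```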

In [None]:
# Install dependencies
!pip install numpy matplotlib plotly ipywidgets -q

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import plotly.graph_objects as go
from plotly.subplots import make_subplots

print("All imports ready!")

## Part 1: GPU Specs - The Machine Balance Point

Let's define the specs for modern GPUs. The **ops:byte ratio** tells us the crossover point between compute-bound and memory-bound workloads.

In [None]:
# GPU specifications
gpus = {
    'A100 (FP16)': {
        'flops': 312e12,       # 312 TFLOPS
        'bandwidth': 2.0e12,   # 2.0 TB/s
        'memory': 80,          # GB HBM2e
        'color': '#3498db'
    },
    'H100 (FP16)': {
        'flops': 989.5e12,     # ~990 TFLOPS
        'bandwidth': 3.35e12,  # 3.35 TB/s
        'memory': 80,          # GB HBM3
        'color': '#2ecc71'
    },
    'H100 (FP8)': {
        'flops': 1979e12,      # ~1979 TFLOPS
        'bandwidth': 3.35e12,  # 3.35 TB/s
        'memory': 80,          # GB HBM3
        'color': '#27ae60'
    },
    'B200 (FP16)': {
        'flops': 2250e12,      # 2.25 PFLOPS
        'bandwidth': 8.0e12,   # 8.0 TB/s
        'memory': 192,         # GB HBM3e
        'color': '#e74c3c'
    },
    'B200 (FP8)': {
        'flops': 4500e12,      # 4.5 PFLOPS
        'bandwidth': 8.0e12,   # 8.0 TB/s
        'memory': 192,         # GB HBM3e
        'color': '#c0392b'
    },
}

print(f"{'GPU':<20} {'FLOPS':>12} {'Bandwidth':>12} {'Ops:Byte Ratio':>15}")
print("=" * 65)
for name, specs in gpus.items():
    ratio = specs['flops'] / specs['bandwidth']
    specs['ops_byte_ratio'] = ratio
    print(f"{name:<20} {specs['flops']/1e12:>9.1f} TF  {specs['bandwidth']/1e12:>8.1f} TB/s  {ratio:>12.1f}")

print("\n=> The ops:byte ratio is the 'knee' of the roofline.")
print("   H100 FP16: ~295 ops per byte read from memory")
print("   B200 FP16: ~281 ops per byte")

## Part 2: Arithmetic Intensity of Common Operations

**Arithmetic Intensity** = (number of arithmetic operations) / (number of bytes transferred)

Let's calculate this for common operations in inference.

In [None]:
def matmul_arithmetic_intensity(M, N, K, bytes_per_element=2):
    """
    Calculate arithmetic intensity of matrix multiplication C = A @ B
    A: (M, K), B: (K, N), C: (M, N)
    
    FLOPs: 2*M*N*K (one multiply and one add per term of each K-term dot product)
    Bytes: (M*K + K*N + M*N) * bytes_per_element (read A, read B, write C)
    """
    flops = 2 * M * N * K
    bytes_transferred = (M * K + K * N + M * N) * bytes_per_element
    ai = flops / bytes_transferred
    return ai, flops, bytes_transferred


# Example: different matrix sizes
print(f"{'Operation':<38} {'FLOPs':>15} {'Bytes':>12} {'AI':>8} {'Bound':>12}")
print("=" * 89)

h100_ratio = gpus['H100 (FP16)']['ops_byte_ratio']

cases = [
    ('Vector dot (1x1024 @ 1024x1)', 1, 1, 1024),
    ('MV multiply (1x4096 @ 4096x4096)', 1, 4096, 4096),
    ('Small matmul (32x4096 @ 4096x4096)', 32, 4096, 4096),
    ('Medium matmul (256x4096 @ 4096x4096)', 256, 4096, 4096),
    ('Large matmul (1024x4096 @ 4096x4096)', 1024, 4096, 4096),
    ('Huge matmul (4096x4096 @ 4096x4096)', 4096, 4096, 4096),
]

# Column widths fit the longest label (36 chars) and the largest FLOP count
for name, M, N, K in cases:
    ai, flops, bytes_t = matmul_arithmetic_intensity(M, N, K)
    bound = 'COMPUTE' if ai > h100_ratio else 'MEMORY'
    print(f"{name:<38} {flops:>15,} {bytes_t:>12,} {ai:>8.1f} {bound:>12}")

print(f"\nH100 FP16 ops:byte ratio = {h100_ratio:.1f}")
print("Operations with AI < ratio are MEMORY-BOUND")
print("Operations with AI > ratio are COMPUTE-BOUND")

### Key Insight: Batch Size = 1 is Memory-Bound!

When batch size is 1 (the common case for single-user inference), the matrix multiply becomes a **matrix-vector multiply**, which has terrible arithmetic intensity. This is exactly why **decode is memory-bound**.
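
To see how far batch size has to grow before a single linear layer crosses the ridge, the sketch below solves for the crossover batch directly. For a `(batch, d_model) @ (d_model, d_model)` layer the intensity works out to `2*B*d / (bpe * (2*B + d))`, which is roughly `B` in FP16 while `B` is much smaller than `d`. The helper names are hypothetical, and the H100 figures are the FP16 numbers quoted in Part 1.

```python
import math


def linear_layer_ai(batch, d_model, bytes_per_element=2):
    """AI of (batch, d_model) @ (d_model, d_model), counting input, weight, and output traffic."""
    flops = 2 * batch * d_model * d_model
    bytes_moved = (2 * batch * d_model + d_model * d_model) * bytes_per_element
    return flops / bytes_moved


def crossover_batch(d_model, ops_byte_ratio, bytes_per_element=2):
    """Smallest integer batch whose AI reaches ops_byte_ratio, or None if it never does."""
    rb = ops_byte_ratio * bytes_per_element
    if rb >= d_model:
        # AI saturates at d_model / bytes_per_element as batch grows
        return None
    return math.ceil(rb * d_model / (2 * (d_model - rb)))


h100_fp16_ratio = 989.5e12 / 3.35e12          # ~295 ops per byte
print(f"{linear_layer_ai(1, 4096):.2f}")      # 1.00
print(crossover_batch(4096, h100_fp16_ratio)) # 346
```

Under these assumptions a 4096-wide layer needs a batch of roughly 350 before it stops being memory-bound on H100 FP16, which is why batch = 1 sits so far to the left of the ridge.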

In [None]:
# Show how batch size affects arithmetic intensity
d_model = 4096
batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]

ais = []
for bs in batch_sizes:
    ai, _, _ = matmul_arithmetic_intensity(bs, d_model, d_model)
    ais.append(ai)

fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(batch_sizes, ais, 'bo-', linewidth=2, markersize=8)

# Add H100 and B200 lines
ax.axhline(y=gpus['H100 (FP16)']['ops_byte_ratio'], color='green', 
           linestyle='--', linewidth=2, label=f"H100 FP16 ({gpus['H100 (FP16)']['ops_byte_ratio']:.0f})")
ax.axhline(y=gpus['B200 (FP16)']['ops_byte_ratio'], color='red', 
           linestyle='--', linewidth=2, label=f"B200 FP16 ({gpus['B200 (FP16)']['ops_byte_ratio']:.0f})")

# Shade regions
ax.fill_between(batch_sizes, 0, gpus['H100 (FP16)']['ops_byte_ratio'], 
                alpha=0.1, color='red', label='Memory-bound region (H100)')

ax.set_xscale('log', base=2)
ax.set_xlabel('Batch Size', fontsize=13)
ax.set_ylabel('Arithmetic Intensity (ops/byte)', fontsize=13)
ax.set_title('Arithmetic Intensity vs Batch Size\n(Linear layer: batch x 4096 @ 4096 x 4096)',
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nNotice: this layer needs a batch size of roughly 350 (between 256 and 512 on this grid) to become compute-bound on H100!")
print("Single-request decode (batch=1) is deeply memory-bound.")

## Part 3: The Roofline Model

The roofline model plots **achievable performance** (FLOPS) vs **arithmetic intensity** (ops/byte).

Two constraints form the "roofline":
1. **Memory bandwidth ceiling**: Performance = bandwidth x arithmetic_intensity
2. **Compute ceiling**: Performance = peak_FLOPS

Your actual performance is limited by whichever ceiling you hit first.
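
The two ceilings combine into a single `min`. The sketch below evaluates that bound at a few intensities using the H100 FP16 figures from Part 1; `attainable_flops` is a hypothetical helper name, and the result is an upper bound rather than a prediction of measured performance.

```python
def attainable_flops(ai, peak_flops, bandwidth):
    """Roofline bound: the lower of the bandwidth ceiling and the compute ceiling."""
    return min(bandwidth * ai, peak_flops)


# H100 FP16 figures quoted in Part 1
peak, bw = 989.5e12, 3.35e12

for ai in (1, 32, 1000):
    bound_tflops = attainable_flops(ai, peak, bw) / 1e12
    print(f"AI = {ai:>4} ops/byte -> at most {bound_tflops:.1f} TFLOPS")
```

At an intensity of 1 the bound is the bandwidth itself (about 3.4 TFLOPS, well under 1% of peak); only past the ridge near 295 ops/byte does the compute ceiling take over.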

In [None]:
def plot_roofline(gpu_name, gpu_specs, workloads=None):
    """Plot the roofline model for a GPU."""
    peak_flops = gpu_specs['flops']
    bandwidth = gpu_specs['bandwidth']
    ops_byte = gpu_specs['ops_byte_ratio']
    
    # X axis: arithmetic intensity
    ai_range = np.logspace(-1, 4, 1000)
    
    # Memory-bound region: perf = bandwidth * AI
    mem_bound = bandwidth * ai_range
    
    # Compute-bound region: perf = peak_flops
    compute_bound = np.full_like(ai_range, peak_flops)
    
    # Roofline = min of both
    roofline = np.minimum(mem_bound, compute_bound)
    
    fig, ax = plt.subplots(figsize=(14, 8))
    
    # Plot roofline
    ax.loglog(ai_range, roofline, 'b-', linewidth=3, label='Roofline')
    
    # Shade regions
    ax.fill_between(ai_range, roofline, 1e8, where=(ai_range < ops_byte),
                    alpha=0.15, color='red')
    ax.fill_between(ai_range, roofline, 1e8, where=(ai_range >= ops_byte),
                    alpha=0.15, color='green')
    
    # Mark the knee point
    ax.axvline(x=ops_byte, color='gray', linestyle=':', linewidth=1.5)
    ax.annotate(f'Ops:Byte = {ops_byte:.0f}', xy=(ops_byte, peak_flops * 0.3),
                fontsize=11, ha='center', fontweight='bold',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8))
    
    # Labels
    ax.text(1, peak_flops * 0.05, 'MEMORY\nBOUND', fontsize=14, 
            color='red', fontweight='bold', ha='center', alpha=0.7)
    ax.text(ops_byte * 5, peak_flops * 0.05, 'COMPUTE\nBOUND', fontsize=14,
            color='green', fontweight='bold', ha='center', alpha=0.7)
    
    # Plot workloads
    if workloads:
        for name, ai, achieved_flops, marker, color in workloads:
            ax.plot(ai, achieved_flops, marker, markersize=15, color=color, 
                    label=name, markeredgecolor='black', markeredgewidth=1.5, zorder=5)
    
    ax.set_xlabel('Arithmetic Intensity (FLOPs/Byte)', fontsize=13)
    ax.set_ylabel('Achievable Performance (FLOPS)', fontsize=13)
    ax.set_title(f'Roofline Model: {gpu_name}\n'
                 f'Peak: {peak_flops/1e12:.0f} TFLOPS | BW: {bandwidth/1e12:.1f} TB/s',
                 fontsize=14, fontweight='bold')
    ax.legend(fontsize=11, loc='lower right')
    ax.grid(True, alpha=0.3, which='both')
    ax.set_xlim(0.1, 5000)
    ax.set_ylim(1e10, peak_flops * 3)
    
    plt.tight_layout()
    return fig, ax


# Plot H100 roofline
fig, ax = plot_roofline('H100 (FP16)', gpus['H100 (FP16)'])
plt.show()

## Part 4: Where Prefill and Decode Land

This is the crucial insight for inference engineering:

- **Prefill** (processing the entire prompt): Large matrix multiplies with many tokens -> **compute-bound**
- **Decode** (generating one token at a time): Matrix-vector products -> **memory-bound**
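
A back-of-the-envelope consequence of decode being memory-bound: each decode step has to stream every weight from HBM at least once, so bandwidth alone sets a floor on per-token time. The sketch below assumes a model of 7e9 parameters stored in FP16 and the H100 bandwidth from Part 1. It ignores KV-cache and activation traffic, so real steps take longer; the function name is hypothetical.

```python
def decode_step_floor_seconds(n_params, bytes_per_param, bandwidth):
    """Lower bound on one decode step: every weight is read from HBM once."""
    return n_params * bytes_per_param / bandwidth


# Assumptions: 7e9 parameters, FP16 (2 bytes each), 3.35 TB/s of HBM bandwidth
floor_s = decode_step_floor_seconds(7e9, 2, 3.35e12)
print(f"{floor_s * 1e3:.2f} ms per token, at most {1 / floor_s:.0f} tokens/s at batch=1")
```

With these numbers the floor is about 4.2 ms per token. Prefill pays the same weight-read cost once for the whole prompt, which is why its time is set by compute instead.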

In [None]:
# Calculate arithmetic intensity for prefill and decode
d_model = 4096
n_layers = 32
seq_len_prefill = 1024

# Prefill: processing seq_len tokens through a linear layer
# Matrix multiply: (seq_len, d_model) @ (d_model, d_model)
ai_prefill, flops_prefill, bytes_prefill = matmul_arithmetic_intensity(
    seq_len_prefill, d_model, d_model)

# Decode: processing 1 token through a linear layer
# Matrix-vector: (1, d_model) @ (d_model, d_model)
ai_decode, flops_decode, bytes_decode = matmul_arithmetic_intensity(
    1, d_model, d_model)

print(f"Prefill (seq_len={seq_len_prefill}):")
print(f"  AI = {ai_prefill:.1f} ops/byte")
print(f"  FLOPs = {flops_prefill:,.0f}")
print(f"  Bytes = {bytes_prefill:,.0f}")
print(f"  -> {'COMPUTE' if ai_prefill > h100_ratio else 'MEMORY'}-bound on H100")

print(f"\nDecode (batch=1):")
print(f"  AI = {ai_decode:.1f} ops/byte")
print(f"  FLOPs = {flops_decode:,.0f}")
print(f"  Bytes = {bytes_decode:,.0f}")
print(f"  -> {'COMPUTE' if ai_decode > h100_ratio else 'MEMORY'}-bound on H100")

print(f"\nH100 ops:byte ratio = {h100_ratio:.1f}")
print(f"Prefill AI ({ai_prefill:.1f}) > {h100_ratio:.1f} -> Compute-bound")
print(f"Decode AI ({ai_decode:.1f}) < {h100_ratio:.1f} -> Memory-bound")

In [None]:
# Plot roofline with prefill and decode marked
h100_specs = gpus['H100 (FP16)']
bw = h100_specs['bandwidth']
peak = h100_specs['flops']

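# Estimated achievable performance (illustrative utilization factors, not measurements)
prefill_perf = min(bw * ai_prefill, peak) * 0.7  # assume 70% of the roofline
decode_perf = min(bw * ai_decode, peak) * 0.6    # assume 60% of the roofline

# Batched decode: recompute AI for (batch, d_model) @ (d_model, d_model)
ai_decode_b32, _, _ = matmul_arithmetic_intensity(32, d_model, d_model)
ai_decode_b128, _, _ = matmul_arithmetic_intensity(128, d_model, d_model)

workloads = [
    ('Prefill (1024 tokens)', ai_prefill, prefill_perf, 's', '#2ecc71'),
    ('Decode (batch=1)', ai_decode, decode_perf, 'o', '#e74c3c'),
    ('Decode (batch=32)', ai_decode_b32, min(bw * ai_decode_b32, peak) * 0.6, '^', '#f39c12'),
    ('Decode (batch=128)', ai_decode_b128, min(bw * ai_decode_b128, peak) * 0.6, 'D', '#3498db'),
]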

fig, ax = plot_roofline('H100 (FP16)', h100_specs, workloads)
plt.show()

print("\nKey insight:")
print("- Prefill is compute-bound: GPU cores are the bottleneck")
print("- Decode (batch=1) is memory-bound: HBM bandwidth is the bottleneck")
print("- Batching decode moves it toward compute-bound")

## Part 5: Detailed Attention Arithmetic Intensity Calculation

Let's walk through the attention calculation step by step, exactly as described in the inference engineering analysis.

For a single attention head with:
- $d = 128$ (head dimension)
- $N = 4096$ (sequence length)

The three steps of attention:
1. $S = QK^T$ (score computation)
2. $P = \text{softmax}(S)$ (probability computation)
3. $O = PV$ (output computation)
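
Counting element reads and writes for the three steps gives a closed form for the total arithmetic intensity. This is a sketch under the same assumptions as the cell that follows: every intermediate ($S$ and $P$) is written to HBM and read back, softmax costs about 5 operations per element, and $b$ (a symbol introduced here) is the number of bytes per element.

$$
\begin{aligned}
\text{FLOPs} &= 2N^2 d + 5N^2 + 2N^2 d = N^2(4d + 5) \\
\text{Bytes} &= b\left[(2Nd + N^2) + 2N^2 + (N^2 + 2Nd)\right] = 4bN(N + d) \\
\text{AI} &= \frac{N(4d + 5)}{4b(N + d)}
\end{aligned}
$$

With $d = 128$, $N = 4096$, and $b = 2$ (FP16), this evaluates to $4096 \cdot 517 / (8 \cdot 4224) \approx 62.7$ ops/byte.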

In [None]:
def attention_arithmetic_intensity(d, N, bytes_per_element=2):
    """
    Calculate arithmetic intensity for a single attention head.
    
    Args:
        d: head dimension (typically 128)
        N: sequence length
        bytes_per_element: 2 for FP16, 1 for FP8
    
    Returns:
        dict with detailed breakdown
    """
    bpe = bytes_per_element
    
    # ==========================================
    # Step 1: S = Q @ K^T
    # Q: (N, d), K^T: (d, N) -> S: (N, N)
    # ==========================================
    s1_read_bytes = (N * d + N * d) * bpe     # Read Q and K
    s1_write_bytes = (N * N) * bpe             # Write S
    s1_flops = 2 * N * N * d                   # Matrix multiply
    s1_total_bytes = s1_read_bytes + s1_write_bytes
    
    # ==========================================
    # Step 2: P = softmax(S)
    # Read S: (N, N), Write P: (N, N)
    # Operations: exp, sum, divide for each row
    # ==========================================
    s2_read_bytes = (N * N) * bpe              # Read S
    s2_write_bytes = (N * N) * bpe             # Write P
    s2_flops = 5 * N * N                       # ~5 ops per element (exp, sum, div, max, sub)
    s2_total_bytes = s2_read_bytes + s2_write_bytes
    
    # ==========================================
    # Step 3: O = P @ V
    # P: (N, N), V: (N, d) -> O: (N, d)
    # ==========================================
    s3_read_bytes = (N * N + N * d) * bpe      # Read P and V
    s3_write_bytes = (N * d) * bpe             # Write O
    s3_flops = 2 * N * N * d                   # Matrix multiply
    s3_total_bytes = s3_read_bytes + s3_write_bytes
    
    # ==========================================
    # Totals
    # ==========================================
    total_flops = s1_flops + s2_flops + s3_flops
    total_bytes = s1_total_bytes + s2_total_bytes + s3_total_bytes
    total_ai = total_flops / total_bytes
    
    return {
        'd': d, 'N': N, 'bpe': bpe,
        'step1': {'flops': s1_flops, 'bytes': s1_total_bytes, 'ai': s1_flops/s1_total_bytes},
        'step2': {'flops': s2_flops, 'bytes': s2_total_bytes, 'ai': s2_flops/s2_total_bytes},
        'step3': {'flops': s3_flops, 'bytes': s3_total_bytes, 'ai': s3_flops/s3_total_bytes},
        'total_flops': total_flops,
        'total_bytes': total_bytes,
        'total_ai': total_ai
    }


# Calculate for d=128, N=4096
result = attention_arithmetic_intensity(d=128, N=4096, bytes_per_element=2)

print(f"Attention Arithmetic Intensity Calculation")
print(f"d = {result['d']}, N = {result['N']}, FP16 ({result['bpe']} bytes/element)")
print("=" * 70)

for step_name, step_label in [('step1', 'S = Q @ K^T'), ('step2', 'P = softmax(S)'), ('step3', 'O = P @ V')]:
    s = result[step_name]
    print(f"\n{step_label}:")
    print(f"  FLOPs:  {s['flops']:>15,}")
    print(f"  Bytes:  {s['bytes']:>15,}")
    print(f"  AI:     {s['ai']:>15.1f} ops/byte")

print(f"\n{'='*70}")
print(f"TOTAL:")
print(f"  FLOPs:  {result['total_flops']:>15,}")
print(f"  Bytes:  {result['total_bytes']:>15,}")
print(f"  AI:     {result['total_ai']:>15.1f} ops/byte")
print(f"\nH100 FP16 ops:byte ratio = {h100_ratio:.1f}")
print(f"Attention AI ({result['total_ai']:.1f}) < H100 ratio ({h100_ratio:.1f})")
print(f"=> Standard attention is MEMORY-BOUND on H100!")

In [None]:
# Visualize the breakdown
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

steps = ['S = QK^T', 'P = softmax(S)', 'O = PV']
step_keys = ['step1', 'step2', 'step3']
colors = ['#3498db', '#e74c3c', '#2ecc71']

# FLOPs breakdown
flops_vals = [result[k]['flops'] for k in step_keys]
axes[0].bar(steps, [f/1e9 for f in flops_vals], color=colors, edgecolor='black')
axes[0].set_ylabel('GFLOPs', fontsize=12)
axes[0].set_title('FLOPs per Step', fontsize=13, fontweight='bold')
for i, v in enumerate(flops_vals):
    axes[0].text(i, v/1e9 + 0.1, f'{v/1e9:.1f}G', ha='center', fontweight='bold')

# Bytes breakdown
bytes_vals = [result[k]['bytes'] for k in step_keys]
axes[1].bar(steps, [b/1e6 for b in bytes_vals], color=colors, edgecolor='black')
axes[1].set_ylabel('MB', fontsize=12)
axes[1].set_title('Bytes Transferred per Step', fontsize=13, fontweight='bold')
for i, v in enumerate(bytes_vals):
    axes[1].text(i, v/1e6 + 0.5, f'{v/1e6:.1f}M', ha='center', fontweight='bold')

# AI breakdown
ai_vals = [result[k]['ai'] for k in step_keys]
bars = axes[2].bar(steps, ai_vals, color=colors, edgecolor='black')
axes[2].axhline(y=h100_ratio, color='red', linestyle='--', linewidth=2, 
                label=f'H100 ratio ({h100_ratio:.0f})')
axes[2].set_ylabel('Arithmetic Intensity', fontsize=12)
axes[2].set_title('AI per Step', fontsize=13, fontweight='bold')
axes[2].legend(fontsize=10)
for i, v in enumerate(ai_vals):
    axes[2].text(i, v + 2, f'{v:.1f}', ha='center', fontweight='bold')

plt.suptitle(f'Attention Breakdown (d={result["d"]}, N={result["N"]})',
             fontsize=15, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nSoftmax (Step 2) has the LOWEST AI - it's the most memory-bound step.")
print("This is why FlashAttention fuses these operations to avoid writing S and P to HBM!")

## Part 6: How Sequence Length Affects Arithmetic Intensity

As sequence length grows, the N x N attention matrix dominates, changing the compute vs memory balance.
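
The sketch below contrasts two byte counts for the same attention FLOPs. The unfused version writes $S$ and $P$ to HBM and reads them back, as in Part 5. The fused version is an idealized assumption, not a model of any particular kernel: it reads Q, K, V and writes O exactly once. A real tiled kernel typically moves more data than that, so treat the fused figure as an optimistic upper bound. Both function names are hypothetical.

```python
def unfused_attention_ai(d, N, bpe=2):
    """AI when S and P are written to and re-read from HBM (~5 ops per softmax element)."""
    flops = 4 * N * N * d + 5 * N * N
    bytes_moved = (4 * N * N + 4 * N * d) * bpe
    return flops / bytes_moved


def fused_attention_ai(d, N, bpe=2):
    """Idealized fused kernel: only Q, K, V are read and O is written, each once."""
    flops = 4 * N * N * d + 5 * N * N
    bytes_moved = 4 * N * d * bpe
    return flops / bytes_moved


d = 128
for N in (1024, 4096, 65536):
    print(f"N={N:>6}: unfused {unfused_attention_ai(d, N):6.1f}   fused (idealized) {fused_attention_ai(d, N):9.1f}")

# Large-N limit of the unfused count: (4d + 5) / (4 * bpe), about d/2 in FP16
print((4 * d + 5) / (4 * 2))  # 64.625
```

In the unfused count the $N \times N$ traffic grows as fast as the FLOPs, so the intensity saturates near $d/2$. Removing that traffic lets the intensity keep growing with $N$.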

In [None]:
# Vary sequence length
d = 128
seq_lengths = [256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536]

ais_by_seq = []
for N in seq_lengths:
    r = attention_arithmetic_intensity(d, N)
    ais_by_seq.append(r['total_ai'])

fig, ax = plt.subplots(figsize=(12, 6))
ax.semilogx(seq_lengths, ais_by_seq, 'bo-', linewidth=2, markersize=10, label='Attention AI')

# GPU lines
ax.axhline(y=gpus['H100 (FP16)']['ops_byte_ratio'], color='green', linestyle='--', 
           linewidth=2, label=f"H100 FP16 ({gpus['H100 (FP16)']['ops_byte_ratio']:.0f})")
ax.axhline(y=gpus['A100 (FP16)']['ops_byte_ratio'], color='blue', linestyle='--', 
           linewidth=2, label=f"A100 FP16 ({gpus['A100 (FP16)']['ops_byte_ratio']:.0f})")

ax.set_xlabel('Sequence Length (N)', fontsize=13)
ax.set_ylabel('Arithmetic Intensity (ops/byte)', fontsize=13)
ax.set_title(f'Attention AI vs Sequence Length (d={d})\nAI rises with N, then saturates near d/2',
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

# Annotate values
for i, (n, ai) in enumerate(zip(seq_lengths, ais_by_seq)):
    if i % 2 == 0:
        ax.annotate(f'{ai:.0f}', (n, ai), textcoords="offset points", 
                    xytext=(0, 12), ha='center', fontsize=9)

plt.tight_layout()
plt.show()

print(f"\nAs N grows, AI approaches (4d + 5) / 8 = {(4 * d + 5) / 8:.1f} in FP16, roughly d/2 = {d/2}")
print("Standard attention stays memory-bound on H100 for typical sequence lengths.")
print("This is why FlashAttention is so important!")

## Part 7: How Batching Changes Everything

Batching multiple requests together increases arithmetic intensity because you **reuse the same model weights** for multiple inputs.
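
Weight reuse is only part of the picture, because each request also reads its own KV cache, and that traffic is not shared across the batch. The sketch below is a simplified per-layer version of the decode model used in the next cell. It assumes `n_heads * d_head == d_model`, counts only weight and KV-cache bytes (activation traffic is ignored), and uses hypothetical function names.

```python
def simplified_decode_ai(batch, d_model, kv_len, bpe=2):
    """Per-layer decode AI counting only weight and KV-cache traffic."""
    weight_params = 12 * d_model * d_model  # Q, K, V, O projections plus a 4x FFN
    flops = batch * (2 * weight_params + 4 * kv_len * d_model)
    bytes_moved = (weight_params + batch * 2 * kv_len * d_model) * bpe
    return flops / bytes_moved


def batching_limit_ai(d_model, kv_len, bpe=2):
    """Value simplified_decode_ai approaches as batch grows: weights amortized, KV reads not."""
    return (12 * d_model + 2 * kv_len) / (kv_len * bpe)


for b in (1, 32, 1024):
    print(f"batch={b:>4}: AI = {simplified_decode_ai(b, 4096, 2048):.2f}")
print("limit:", batching_limit_ai(4096, 2048))  # 13.0
```

Under these assumptions, with a 2048-token KV cache the intensity climbs from 1 toward a ceiling of 13 ops/byte, far below the H100 ratio. Batching still raises throughput substantially, but shorter or compressed KV caches are needed for the full decode step to approach the compute-bound region.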

In [None]:
def decode_step_ai(batch_size, d_model, n_layers, seq_len_kv, n_heads, d_head, bpe=2):
    """
    Calculate arithmetic intensity for a full decode step.
    
    During decode, for each token:
    - Linear projections: (batch, 1, d_model) @ (d_model, d_model) x4 per layer (Q, K, V, O)
    - Attention: (batch, 1, d_head) @ (batch, d_head, seq_len) per head
    - FFN: (batch, 1, d_model) @ (d_model, 4*d_model) and back
    """
    B = batch_size
    
    total_flops = 0
    total_bytes = 0
    
    for _ in range(n_layers):
        # QKV projections: 3 matmuls of (B, d_model) @ (d_model, d_model)
        # Weight matrix is read ONCE, shared across batch
        qkv_flops = 3 * 2 * B * d_model * d_model
        qkv_bytes = 3 * (d_model * d_model * bpe) + 3 * (B * d_model * bpe) * 2  # weights + in/out
        
        # Attention: for each head, (B, 1, d_head) @ (B, d_head, seq_len_kv)
        # KV cache is per-request, not shared
        attn_flops = n_heads * 2 * B * seq_len_kv * d_head * 2  # QK^T and PV
        attn_bytes = n_heads * (B * seq_len_kv * d_head * bpe * 2 + B * d_head * bpe * 3)  # KV cache + Q + O
        
        # Output projection: (B, d_model) @ (d_model, d_model)
        o_flops = 2 * B * d_model * d_model
        o_bytes = d_model * d_model * bpe + B * d_model * bpe * 2
        
        # FFN: two linear layers (d_model -> 4*d_model -> d_model)
        ffn_flops = 2 * (2 * B * d_model * 4 * d_model)
        ffn_bytes = 2 * (d_model * 4 * d_model * bpe) + B * d_model * bpe * 3
        
        total_flops += qkv_flops + attn_flops + o_flops + ffn_flops
        total_bytes += qkv_bytes + attn_bytes + o_bytes + ffn_bytes
    
    return total_flops / total_bytes, total_flops, total_bytes


# Parameters for a ~7B model
d_model = 4096
n_layers = 32
n_heads = 32
d_head = 128
seq_len_kv = 2048

batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128, 256]

print(f"Decode Step AI for ~7B Model (d={d_model}, L={n_layers}, KV len={seq_len_kv})")
print("=" * 60)

decode_ais = []
for bs in batch_sizes:
    ai, flops, bytes_t = decode_step_ai(bs, d_model, n_layers, seq_len_kv, n_heads, d_head)
    decode_ais.append(ai)
    bound = 'COMPUTE' if ai > h100_ratio else 'MEMORY'
    print(f"  Batch {bs:>4}: AI = {ai:>6.1f} ops/byte -> {bound}-bound")

print(f"\nH100 FP16 ops:byte ratio = {h100_ratio:.0f}")

In [None]:
# Plot batching effect
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# AI vs batch size
axes[0].semilogx(batch_sizes, decode_ais, 'ro-', linewidth=2, markersize=10, 
                  label='Decode AI', base=2)
axes[0].axhline(y=h100_ratio, color='green', linestyle='--', linewidth=2, 
                label=f'H100 FP16 ({h100_ratio:.0f})')
axes[0].set_xlabel('Batch Size', fontsize=13)
axes[0].set_ylabel('Arithmetic Intensity', fontsize=13)
axes[0].set_title('Decode AI vs Batch Size\n(~7B model, KV len=2048)', 
                   fontsize=13, fontweight='bold')
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)
axes[0].fill_between(batch_sizes, 0, h100_ratio, alpha=0.1, color='red')

# Throughput estimate
# Memory-bound: throughput = bandwidth / bytes_per_token
# Compute-bound: throughput = flops / flops_per_token
h100_bw = gpus['H100 (FP16)']['bandwidth']
h100_peak = gpus['H100 (FP16)']['flops']

throughputs = []
for bs in batch_sizes:
    ai, flops, bytes_t = decode_step_ai(bs, d_model, n_layers, seq_len_kv, n_heads, d_head)
    # Time = max(flops/peak, bytes/bandwidth)
    time_compute = flops / (h100_peak * 0.7)  # 70% utilization
    time_memory = bytes_t / (h100_bw * 0.8)   # 80% utilization
    time_total = max(time_compute, time_memory)
    tokens_per_second = bs / time_total
    throughputs.append(tokens_per_second)

axes[1].semilogx(batch_sizes, [t/1000 for t in throughputs], 'bs-', 
                  linewidth=2, markersize=10, base=2)
axes[1].set_xlabel('Batch Size', fontsize=13)
axes[1].set_ylabel('Throughput (K tokens/s)', fontsize=13)
axes[1].set_title('Estimated Decode Throughput\n(H100, ~7B model)', 
                   fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("\nKey takeaway: Batching puts otherwise idle compute to work, because each weight byte read from HBM serves more tokens.")
print("This is why inference servers batch requests together!")

## Part 8: Comparing Multiple GPUs on the Roofline

In [None]:
# Multi-GPU roofline comparison
fig, ax = plt.subplots(figsize=(14, 8))

ai_range = np.logspace(-1, 4, 1000)

for name, specs in gpus.items():
    peak = specs['flops']
    bw = specs['bandwidth']
    
    mem_line = bw * ai_range
    compute_line = np.full_like(ai_range, peak)
    roofline = np.minimum(mem_line, compute_line)
    
    ax.loglog(ai_range, roofline, linewidth=2.5, color=specs['color'], 
              label=f"{name} (peak: {peak/1e12:.0f} TF)")

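# Mark some workloads (AI from the Part 4 linear-layer estimates;
# performance values are illustrative, roughly 50-60% of the H100 FP16 roofline)
workload_points = [
    ('Decode\n(batch=1)', 1, 2e12, 'v', 'red'),
    ('Decode\n(batch=32)', 32, 6e13, '^', 'orange'),
    ('Prefill\n(1K tokens)', 680, 5e14, 's', 'green'),
]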

for label, ai, perf, marker, color in workload_points:
    ax.plot(ai, perf, marker, markersize=15, color=color, 
            markeredgecolor='black', markeredgewidth=2, zorder=5)
    ax.annotate(label, (ai, perf), textcoords="offset points",
                xytext=(15, -5), fontsize=10, fontweight='bold')

ax.set_xlabel('Arithmetic Intensity (FLOPs/Byte)', fontsize=13)
ax.set_ylabel('Performance (FLOPS)', fontsize=13)
ax.set_title('Roofline Comparison: A100 vs H100 vs B200',
             fontsize=15, fontweight='bold')
ax.legend(fontsize=10, loc='lower right')
ax.grid(True, alpha=0.3, which='both')
ax.set_xlim(0.1, 5000)
ax.set_ylim(1e10, 1e16)

plt.tight_layout()
plt.show()

print("B200 has both higher compute AND higher bandwidth.")
print("But its FP16 ops:byte ratio (~281) is close to H100's (~295), so the bound classification barely changes between them; A100's ridge is lower (~156).")

## Part 9: Interactive - Vary Batch Size and Sequence Length

In [None]:
# Static heatmap over a grid of batch sizes and KV lengths (no widgets required)

# Compute a grid of AI values
batch_grid = [1, 2, 4, 8, 16, 32, 64, 128]
seq_grid = [256, 512, 1024, 2048, 4096, 8192]

ai_matrix = np.zeros((len(batch_grid), len(seq_grid)))

for i, bs in enumerate(batch_grid):
    for j, sl in enumerate(seq_grid):
        ai, _, _ = decode_step_ai(bs, d_model, n_layers, sl, n_heads, d_head)
        ai_matrix[i, j] = ai

fig, ax = plt.subplots(figsize=(12, 8))

import matplotlib.colors as mcolors

# Custom colormap: red for memory-bound, green for compute-bound
cmap = plt.cm.RdYlGn
norm = mcolors.TwoSlopeNorm(vmin=0, vcenter=h100_ratio, vmax=h100_ratio*3)

im = ax.imshow(ai_matrix, cmap=cmap, norm=norm, aspect='auto')

ax.set_xticks(range(len(seq_grid)))
ax.set_xticklabels(seq_grid)
ax.set_yticks(range(len(batch_grid)))
ax.set_yticklabels(batch_grid)

ax.set_xlabel('KV Cache Sequence Length', fontsize=13)
ax.set_ylabel('Batch Size', fontsize=13)
ax.set_title(f'Decode Arithmetic Intensity (ops/byte)\n~7B model on H100 FP16 (threshold: {h100_ratio:.0f})',
             fontsize=14, fontweight='bold')

# Add text annotations
for i in range(len(batch_grid)):
    for j in range(len(seq_grid)):
        val = ai_matrix[i, j]
        color = 'white' if val < h100_ratio * 0.3 or val > h100_ratio * 2 else 'black'
        label = f'{val:.0f}\n{"MEM" if val < h100_ratio else "COMP"}'
        ax.text(j, i, label, ha='center', va='center', fontsize=8, 
                color=color, fontweight='bold')

plt.colorbar(im, ax=ax, label='Arithmetic Intensity', shrink=0.8)
plt.tight_layout()
plt.show()

print("Red = Memory-bound | Green = Compute-bound")
print(f"Threshold (H100 FP16): {h100_ratio:.0f} ops/byte")
print("\nLarger batch + shorter KV -> more compute-bound")
print("Smaller batch + longer KV -> more memory-bound")

In [None]:
# Show the effect on estimated latency per token
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Vary batch size (fixed seq_len=2048)
bs_range = [1, 2, 4, 8, 16, 32, 64, 128, 256]
latencies_bs = []
for bs in bs_range:
    ai, flops, bytes_t = decode_step_ai(bs, d_model, n_layers, 2048, n_heads, d_head)
    time_compute = flops / (h100_peak * 0.7)
    time_memory = bytes_t / (h100_bw * 0.8)
    latency_ms = max(time_compute, time_memory) * 1000
    latencies_bs.append(latency_ms / bs)  # Step time amortized over the batch (inverse throughput)

axes[0].plot(bs_range, latencies_bs, 'ro-', linewidth=2, markersize=8)
axes[0].set_xscale('log', base=2)
axes[0].set_xlabel('Batch Size', fontsize=13)
axes[0].set_ylabel('Amortized time per generated token (ms)', fontsize=12)
axes[0].set_title('Amortized Time per Token vs Batch Size\n(seq_len=2048)', 
                   fontsize=13, fontweight='bold')
axes[0].grid(True, alpha=0.3)

# Vary seq_len (fixed batch=1)
sl_range = [256, 512, 1024, 2048, 4096, 8192, 16384]
latencies_sl = []
for sl in sl_range:
    ai, flops, bytes_t = decode_step_ai(1, d_model, n_layers, sl, n_heads, d_head)
    time_compute = flops / (h100_peak * 0.7)
    time_memory = bytes_t / (h100_bw * 0.8)
    latency_ms = max(time_compute, time_memory) * 1000
    latencies_sl.append(latency_ms)

axes[1].plot(sl_range, latencies_sl, 'bs-', linewidth=2, markersize=8)
axes[1].set_xscale('log', base=2)
axes[1].set_xlabel('KV Cache Sequence Length', fontsize=13)
axes[1].set_ylabel('Latency per token (ms)', fontsize=12)
axes[1].set_title('Decode Latency vs Sequence Length\n(batch=1)', 
                   fontsize=13, fontweight='bold')
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print("Left: Batching amortizes weight loading -> lower amortized time per token (each request still waits for the full step)")
print("Right: Longer sequences -> more KV cache to read -> higher latency")

## Part 10: Practical Implications for Inference Engineering

Let's summarize the practical implications of arithmetic intensity analysis.

In [None]:
# Summary table
print("ARITHMETIC INTENSITY CHEAT SHEET")
print("=" * 70)
print()
# AI values are for a (batch or N) x 4096 @ 4096 x 4096 FP16 layer, as computed in Parts 2 and 4
print("OPERATION                  AI (approx)    BOUND (H100 FP16)  OPTIMIZATION")
print("-" * 70)
print("Decode linear (batch=1)    ~1             Memory-bound       Quantize weights")
print("Decode linear (batch=32)   ~32            Memory-bound       Batch more requests")
print("Decode linear (batch=128)  ~120           Memory-bound       Batch more requests")
print("Prefill linear (N=1024)    ~680           Compute-bound      Use FP8/INT8")
print("Attention (N=4096)         ~63            Memory-bound       FlashAttention")
print("FFN layer                  ~batch-dep     Varies             Depends on batch")
print()
print(f"H100 FP16 ops:byte ratio = {h100_ratio:.0f}")
print()
print("KEY RULES:")
print("1. Memory-bound -> reduce bytes (quantization, pruning)")
print("2. Compute-bound -> reduce FLOPs or raise the compute ceiling (pruning, lower precision)")
print("3. Batching is the #1 way to improve throughput for memory-bound ops")
print("4. FlashAttention avoids materializing the N x N matrix (kernel fusion)")
print("5. Higher GPU bandwidth (B200) helps memory-bound more than compute-bound")

---

## Key Takeaways

1. **Arithmetic Intensity (AI)** = FLOPs / Bytes transferred. It tells you whether your operation is compute-bound or memory-bound.

2. **The Roofline Model** has two ceilings:
   - Memory bandwidth ceiling (slope = bandwidth)
   - Compute ceiling (flat = peak FLOPS)
   - The crossover point is the **ops:byte ratio**

3. **Prefill is compute-bound**: Processing many tokens through matrix multiplies gives high AI (~680 for 1,024 tokens through a 4096 x 4096 FP16 layer).

4. **Decode is memory-bound**: Single-token generation has an AI of ~1 in FP16 (about 2 FLOPs per 2-byte weight), far below the H100's ratio of ~295.

5. **Attention AI (~63 for d=128, N=4096)** is below the H100 threshold. FlashAttention helps by fusing operations to reduce memory traffic.

6. **Batching is the primary lever** for improving memory-bound operations. Each doubling of batch size roughly doubles throughput while shared weight reads dominate memory traffic; the gain tapers once per-request KV-cache reads dominate or the workload becomes compute-bound.

7. **Optimization strategy depends on the bound**:
   - Memory-bound -> reduce data movement (quantization, caching, fusion)
   - Compute-bound -> reduce arithmetic or raise the compute ceiling (pruning, lower precision)

---

*Next: We'll explore prefix caching - a technique that reduces redundant KV cache computation for requests with shared prefixes.*