In [None]:
# 🔧 Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Building a Complete Inference Engine

*Part 3 of the Vizuara series on Inference & Scaling*
*Estimated time: 60 minutes*

## 1. Why Does This Matter?

In Notebook 1, we built the KV cache to make generation fast. In Notebook 2, we built sampling strategies to make generation good. Now we put them together into a single, complete inference engine -- a small-scale version of the generation loop that production LLM serving systems are built around.

This is not just engineering glue. There are important design decisions at the intersection:
- How do we handle the **prefill phase** (processing the prompt) separately from the **decode phase** (generating tokens)?
- How do we manage **stopping conditions** (end-of-sequence token, max length)?
- What are the **real bottlenecks** in practice?

In this notebook, we will:
1. Build a **complete generation pipeline** with KV cache + top-p sampling
2. Implement proper **prefill/decode separation**
3. Add **streaming output** (token-by-token delivery)
4. **Benchmark** the full system: throughput, latency, and memory
5. Profile where time is actually spent

In [None]:
# Setup -- run this cell first
!pip install -q torch matplotlib numpy

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time

%matplotlib inline

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 2. Building Intuition

Think of an LLM inference request as a restaurant order. There are two distinct phases:

**Prefill** (the kitchen prepares your main dish): The model processes the entire prompt at once, computing attention over all prompt tokens in parallel. This is compute-bound -- the GPU does a lot of math on many tokens simultaneously.

**Decode** (the waiter brings courses one at a time): The model generates tokens sequentially, one at a time, using the KV cache. This is memory-bound -- the GPU spends most of its time loading model weights and cached KV vectors from memory rather than doing math.

Understanding this distinction is critical for optimizing inference:
- **Prefill optimization**: batch prompts, use tensor parallelism, maximize GPU compute utilization
- **Decode optimization**: minimize memory access, use quantized KV caches, maximize memory bandwidth utilization

Some production systems go as far as separating these two phases entirely (often called disaggregated serving), running them on different hardware configurations.

## 3. The Mathematics

### Prefill Phase

Given a prompt of $n$ tokens, the prefill phase processes all tokens in parallel through the transformer. For each layer $l$:

$$Q^{(l)} = H^{(l)} W_Q^{(l)}, \quad K^{(l)} = H^{(l)} W_K^{(l)}, \quad V^{(l)} = H^{(l)} W_V^{(l)}$$

where $H^{(l)} \in \mathbb{R}^{n \times d}$ is the hidden state matrix for all $n$ tokens.

Attention is computed in parallel over all $n$ query positions with causal masking:

$$\text{Attn}^{(l)} = \text{softmax}\left(\frac{Q^{(l)} {K^{(l)}}^T}{\sqrt{d_k}} + M_{\text{causal}}\right) V^{(l)}$$

**Compute cost**: $O(n^2 \cdot d)$ per layer for attention, plus $O(n \cdot d^2)$ for projections and FFN.

**Output**: The KV cache is populated with $K^{(l)}$ and $V^{(l)}$ for all layers. We also get the logits for the last token to begin generation.
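
To make the two cost terms concrete, here is a rough per-layer FLOP count for the block layout used later in this notebook (fused QKV, output projection, 4x-wide FFN). The function name `prefill_flops_per_layer` and the constant factors are assumptions for illustration: a multiply-add is counted as 2 FLOPs, and LayerNorm, softmax, and any savings from causal masking are ignored.

```python
def prefill_flops_per_layer(n_tokens: int, d_model: int) -> dict:
    """Approximate prefill FLOPs for one transformer layer.

    Assumes 12 * d_model^2 weights per layer (3 for QKV, 1 for the output
    projection, 8 for a 4x-wide FFN) and full, unmasked attention.
    """
    # Each weight takes part in one multiply-add (2 FLOPs) per token.
    linear = 2 * 12 * d_model**2 * n_tokens
    # Q @ K^T and attn @ V each cost 2 * n^2 * d FLOPs summed over all heads.
    attention = 2 * 2 * n_tokens**2 * d_model
    return {
        "linear": linear,
        "attention": attention,
        "attention_share": attention / (linear + attention),
    }


for n in (256, 4096, 32768):
    cost = prefill_flops_per_layer(n, d_model=4096)
    print(f"n={n:>6}: attention share of FLOPs = {cost['attention_share']:.1%}")
```

Under these assumptions the quadratic attention term only becomes a large share of the work once the prompt is several times longer than the model width.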

### Decode Phase

At each decode step $t$, we process a single new token:

1. Compute $q_t, k_t, v_t$ from the new token's hidden state (cost: $O(d^2)$ per layer)
2. Append $k_t, v_t$ to the cache
3. Compute attention: $q_t$ against all $(n + t)$ cached keys (cost: $O((n+t) \cdot d_k)$ per layer per head)

**Total cost per decode step**: $O(d^2 + (n+t) \cdot d)$ per layer

The key insight: for large models ($d = 4096+$), the projection cost $O(d^2)$ dominates when the sequence is short, but the attention cost $O((n+t) \cdot d)$ dominates for long sequences.
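
As a rough worked example (assuming the layer layout used in this notebook: fused QKV, an output projection, and a 4x-wide FFN, with one multiply-add counted as 2 FLOPs), the per-layer cost of one decode step with $L = n + t$ cached positions is approximately:

$$\underbrace{24\,d^2}_{\text{projections + FFN}} \;+\; \underbrace{4\,L\,d}_{\text{attention}}$$

The two terms are equal when $24\,d^2 = 4\,L\,d$, that is, when $L = 6d$. For $d = 4096$ this is about 24,576 cached tokens. For the $d = 128$ toy model built below it is 768, which is beyond its 512-position limit, so projections dominate throughout this notebook.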

### Arithmetic Intensity

The **arithmetic intensity** (FLOPs per byte of memory accessed) determines whether a workload is compute-bound or memory-bound:

$$\text{AI}_{\text{prefill}} = O(n) \quad \text{(high -- compute-bound)}$$
$$\text{AI}_{\text{decode}} = O(1) \quad \text{(low -- memory-bound)}$$

Prefill processes $n$ tokens and reuses the weight matrices across all tokens -- high reuse, high arithmetic intensity. Decode processes 1 token per sequence, so every weight loaded from memory is used only once -- low reuse, low arithmetic intensity. This is why decode is almost always memory-bandwidth-bound at small batch sizes.
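
The scaling above can be checked with a back-of-the-envelope calculation for a single $d \times d$ linear layer. This is a sketch under simple assumptions: the helper name `linear_arithmetic_intensity` is hypothetical, the weights are read from memory once, and activations are read once and written once.

```python
def linear_arithmetic_intensity(n_tokens: int, d_model: int,
                                dtype_bytes: int = 2) -> float:
    """FLOPs per byte for y = x @ W with x: (n, d) and W: (d, d)."""
    flops = 2 * n_tokens * d_model * d_model
    # Bytes moved: read W once, read x, write y.
    bytes_moved = dtype_bytes * (d_model * d_model + 2 * n_tokens * d_model)
    return flops / bytes_moved


print(f"decode  (n=1):   {linear_arithmetic_intensity(1, 4096):.2f} FLOPs/byte")
print(f"prefill (n=512): {linear_arithmetic_intensity(512, 4096):.1f} FLOPs/byte")
```

For comparison, the ratio of peak FLOP/s to memory bandwidth on current datacenter GPUs is roughly in the low hundreds, so a single-token decode step sits far below the point where compute becomes the limit.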

## 4. Let's Build It -- Component by Component

### 4.1 The Model (with KV Cache Support)

In [None]:
class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x, kv_cache=None, start_pos=0):
        B, T, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = qkv.split(self.d_model, dim=-1)

        q = q.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.d_k).transpose(1, 2)

        if kv_cache is not None:
            k_prev, v_prev = kv_cache
            k = torch.cat([k_prev, k], dim=2)
            v = torch.cat([v_prev, v], dim=2)

        new_cache = (k, v)
        T_total = k.shape[2]

        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Causal mask
        causal = torch.triu(
            torch.ones(T, T_total, device=x.device, dtype=torch.bool),
            diagonal=T_total - T + 1
        )
        scores.masked_fill_(causal.unsqueeze(0).unsqueeze(0), float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().view(B, T, self.d_model)
        return self.out_proj(out), new_cache


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x, kv_cache=None, start_pos=0):
        h = self.ln1(x)
        attn_out, new_cache = self.attn(h, kv_cache=kv_cache, start_pos=start_pos)
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        return x, new_cache


class InferenceGPT(nn.Module):
    def __init__(self, vocab_size, d_model=128, n_heads=4, n_layers=4, max_len=512):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([TransformerBlock(d_model, n_heads) for _ in range(n_layers)])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)
        self.eos_token_id = None  # Set after tokenizer is defined

    def forward(self, tokens, kv_caches=None, start_pos=0):
        B, T = tokens.shape
        pos = torch.arange(start_pos, start_pos + T, device=tokens.device)
        x = self.tok_emb(tokens) + self.pos_emb(pos)

        new_caches = []
        for i, layer in enumerate(self.layers):
            cache = kv_caches[i] if kv_caches is not None else None
            x, new_cache = layer(x, kv_cache=cache, start_pos=start_pos)
            new_caches.append(new_cache)

        logits = self.head(self.ln_f(x))
        return logits, new_caches


# Create model
vocab_size = 128  # ASCII characters
model = InferenceGPT(vocab_size).to(device)
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")
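
The `diagonal=T_total - T + 1` offset in the causal mask is what lets the same attention code serve both phases. The sketch below reproduces the mask in plain Python so it can be inspected without a model; `causal_mask` and `show` are hypothetical helper names, and the comparison mirrors what `torch.triu` does with that diagonal.

```python
from typing import List


def causal_mask(T: int, T_total: int) -> List[List[bool]]:
    """True marks a (query i, key j) pair that must be masked out.

    Mirrors torch.triu(ones(T, T_total), diagonal=T_total - T + 1):
    entry (i, j) is True when j - i >= diagonal.
    """
    diagonal = T_total - T + 1
    return [[(j - i) >= diagonal for j in range(T_total)] for i in range(T)]


def show(mask: List[List[bool]]) -> None:
    """Print one row per query: 'x' = masked, '.' = visible."""
    for row in mask:
        print("".join("x" if m else "." for m in row))


print("Prefill, 4 queries over 4 keys:")
show(causal_mask(4, 4))
print("Decode, 1 new query over 5 cached keys:")
show(causal_mask(1, 5))
```

During prefill the mask is the usual upper triangle; during decode the single new query sits at the last position, so nothing is masked.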

### 4.2 The Complete Generation Engine

In [None]:
class GenerationEngine:
    """
    A complete LLM inference engine with KV cache and configurable sampling.
    Separates prefill and decode phases explicitly.
    """

    def __init__(self, model, eos_token_id=None):
        self.model = model
        self.eos_token_id = eos_token_id
        self.model.eval()

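    def generate(self, prompt_tokens, max_new_tokens=100,
                 temperature=0.8, top_p=0.9, stream=False):
        """
        Generate tokens with KV cache and top-p sampling.

        Args:
            prompt_tokens: (1, seq_len) input token ids
            max_new_tokens: maximum tokens to generate
            temperature: sampling temperature
            top_p: nucleus sampling threshold
            stream: if True, return a generator yielding (token_id, info)
                one token at a time; otherwise return (all_tokens, stats)
        """
        # Any function containing `yield` is a generator, so the token loop
        # lives in _steps() and generate() stays an ordinary function that
        # can return (all_tokens, stats) directly.
        stats = {}
        steps = self._steps(prompt_tokens, max_new_tokens, temperature, top_p, stats)
        if stream:
            return ((tok.item(), info) for tok, info in steps)

        generated = [tok for tok, _ in steps]
        all_tokens = torch.cat([prompt_tokens] + generated, dim=1)
        n_prompt = prompt_tokens.shape[1]
        n_decode_steps = len(generated) - 1  # the first token comes from prefill
        stats.update({
            'prefill_tokens': n_prompt,
            'decode_tokens': len(generated),
            'prefill_tps': n_prompt / stats['prefill_time'] if stats['prefill_time'] > 0 else 0,
            'decode_tps': n_decode_steps / stats['decode_time'] if stats['decode_time'] > 0 else 0,
        })
        return all_tokens, stats

    @staticmethod
    def _now():
        """Wall-clock time; waits for queued CUDA kernels so GPU timings are meaningful."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return time.perf_counter()

    @torch.no_grad()
    def _steps(self, prompt_tokens, max_new_tokens, temperature, top_p, stats):
        """Yield (token_tensor, info) per generated token and record timings in `stats`."""
        # Phase 1: Prefill -- process the whole prompt and sample the first token
        prefill_start = self._now()
        logits, kv_caches = self.model(prompt_tokens, kv_caches=None, start_pos=0)
        next_token = self._sample(logits[:, -1, :], temperature, top_p)
        stats['prefill_time'] = self._now() - prefill_start
        stats['decode_time'] = 0.0
        yield next_token, {'phase': 'decode', 'step': 0}

        cur_pos = prompt_tokens.shape[1]

        # Phase 2: Decode -- one token per forward pass, reusing the KV cache
        for step in range(1, max_new_tokens):
            # Stop as soon as the most recent token is end-of-sequence
            if self.eos_token_id is not None and next_token.item() == self.eos_token_id:
                break
            step_start = self._now()
            logits, kv_caches = self.model(
                next_token, kv_caches=kv_caches, start_pos=cur_pos
            )
            cur_pos += 1
            next_token = self._sample(logits[:, -1, :], temperature, top_p)
            # Time only the model step, not the time the consumer spends between yields
            stats['decode_time'] += self._now() - step_start
            yield next_token, {'phase': 'decode', 'step': step}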

    def _sample(self, logits, temperature=1.0, top_p=1.0):
        """Top-p sampling with temperature."""
        if temperature <= 0:
            return torch.argmax(logits, dim=-1, keepdim=True)

        logits = logits / temperature

        if top_p < 1.0:
            sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
            cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumprobs - F.softmax(sorted_logits, dim=-1) >= top_p
            sorted_logits[mask] = float('-inf')
            probs = F.softmax(sorted_logits, dim=-1)
            idx = torch.multinomial(probs, num_samples=1)
            return sorted_indices.gather(-1, idx)
        else:
            probs = F.softmax(logits, dim=-1)
            return torch.multinomial(probs, num_samples=1)


# Create engine
engine = GenerationEngine(model)

# Quick test
prompt = torch.randint(0, vocab_size, (1, 32), device=device)
output, stats = engine.generate(prompt, max_new_tokens=50, temperature=0.8, top_p=0.9)

print(f"Generated {stats['decode_tokens']} tokens")
print(f"Prefill: {stats['prefill_time']*1000:.1f}ms ({stats['prefill_tps']:.0f} tok/s)")
print(f"Decode:  {stats['decode_time']*1000:.1f}ms ({stats['decode_tps']:.0f} tok/s)")
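
The nucleus rule inside `_sample` is easy to misread in tensor form, so here is the same rule on a four-token distribution in plain Python. This is only an illustrative sketch: `top_p_filter` is a hypothetical helper, not part of the engine.

```python
def top_p_filter(probs, top_p):
    """Return {index: renormalized probability} for the nucleus of `probs`.

    A token is dropped when the cumulative probability of the tokens ranked
    above it already reaches top_p, the same rule as
    `cumprobs - probs >= top_p` in GenerationEngine._sample.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, mass_before = [], 0.0
    for i in order:
        if mass_before >= top_p:
            break
        kept.append(i)
        mass_before += probs[i]
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


nucleus = top_p_filter([0.05, 0.5, 0.15, 0.3], top_p=0.9)
print(nucleus)  # token 0 (p=0.05) is excluded; the rest are renormalized
```

Because the test uses the mass *before* each token, the most likely token is always kept, even when `top_p` is very small.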

### Visualization Checkpoint: Prefill vs Decode Timing

In [None]:
# Profile prefill vs decode across different prompt lengths
prompt_lengths = [16, 32, 64, 128, 256]
gen_length = 100
n_trials = 3

prefill_times = []
decode_times = []
prefill_tps_list = []
decode_tps_list = []

for plen in prompt_lengths:
    p_times = []
    d_times = []
    for _ in range(n_trials):
        prompt = torch.randint(0, vocab_size, (1, plen), device=device)
        _, stats = engine.generate(prompt, max_new_tokens=gen_length)
        p_times.append(stats['prefill_time'])
        d_times.append(stats['decode_time'])

    prefill_times.append(np.mean(p_times))
    decode_times.append(np.mean(d_times))
    prefill_tps_list.append(np.mean([plen / t for t in p_times]))
    # decode_time covers gen_length - 1 forward passes (the first token comes from prefill)
    decode_tps_list.append(np.mean([(gen_length - 1) / t for t in d_times]))

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# Stacked bar: prefill vs decode time
ax = axes[0]
x = np.arange(len(prompt_lengths))
ax.bar(x, [t*1000 for t in prefill_times], label='Prefill', color='#3498db', alpha=0.8)
ax.bar(x, [t*1000 for t in decode_times], bottom=[t*1000 for t in prefill_times],
       label='Decode', color='#e74c3c', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels(prompt_lengths)
ax.set_xlabel('Prompt Length')
ax.set_ylabel('Time (ms)')
ax.set_title('Total Latency Breakdown')
ax.legend()

# Throughput comparison
ax = axes[1]
ax.plot(prompt_lengths, prefill_tps_list, 'bo-', linewidth=2, label='Prefill tok/s')
ax.plot(prompt_lengths, decode_tps_list, 'ro-', linewidth=2, label='Decode tok/s')
ax.set_xlabel('Prompt Length')
ax.set_ylabel('Tokens per Second')
ax.set_title('Throughput: Prefill vs Decode')
ax.legend()
ax.grid(True, alpha=0.3)

# Fraction of time in decode
ax = axes[2]
decode_fracs = [d / (p + d) * 100 for p, d in zip(prefill_times, decode_times)]
ax.bar(x, decode_fracs, color='#e74c3c', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels(prompt_lengths)
ax.set_xlabel('Prompt Length')
ax.set_ylabel('% of Total Time')
ax.set_title('Fraction of Time in Decode Phase')
ax.axhline(y=50, color='gray', linestyle='--', alpha=0.5)

plt.tight_layout()
plt.show()
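
Benchmark numbers like the ones plotted above are sensitive to warm-up effects and, on a GPU, to asynchronous kernel launches. One way to make the measurement more robust is a small timing helper. This is a sketch only: `time_call` and its parameters are hypothetical names, and the `sync` hook is left generic so the example runs without a GPU.

```python
import statistics
import time


def time_call(fn, n_warmup=1, n_trials=5, sync=None):
    """Return the median wall-clock seconds of fn() over n_trials runs.

    `sync` is an optional zero-argument callable run before each clock read.
    On a GPU, torch.cuda.synchronize is the natural choice, so that queued
    kernels have finished before the clock is read.
    """
    def now():
        if sync is not None:
            sync()
        return time.perf_counter()

    for _ in range(n_warmup):  # warm-up runs are executed but not timed
        fn()

    samples = []
    for _ in range(n_trials):
        start = now()
        fn()
        samples.append(now() - start)
    return statistics.median(samples)


elapsed = time_call(lambda: sum(range(100_000)))
print(f"median: {elapsed * 1e3:.2f} ms")
```

In this notebook, a call such as `time_call(lambda: model(prompt), sync=torch.cuda.synchronize)` would time a single prefill pass on a CUDA device. The median is used instead of the mean so that one slow outlier does not skew the result.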

## 5. Your Turn

### TODO 1: Implement Streaming Generation

The engine supports a `stream=True` mode that yields tokens one at a time. Use it to build a function that prints tokens as they are generated, simulating the "typing" effect you see in ChatGPT.

In [None]:
# TODO: Use the streaming mode to print tokens as they arrive.
# The engine.generate() function is a generator when stream=True.
#
# YOUR CODE HERE
# def stream_generate(engine, prompt_tokens, max_new_tokens=50,
#                     temperature=0.8, top_p=0.9):
#     """Print tokens as they are generated, with timing info."""
#     print("Generating: ", end="", flush=True)
#     token_times = []
#
#     for token_id, info in engine.generate(
#         prompt_tokens, max_new_tokens=max_new_tokens,
#         temperature=temperature, top_p=top_p, stream=True
#     ):
#         # Convert token_id to character and print
#         char = chr(token_id) if 32 <= token_id < 127 else '?'
#         print(char, end="", flush=True)
#         token_times.append(time.time())
#
#     print("\n")
#
#     # Calculate inter-token latencies
#     if len(token_times) > 1:
#         latencies = [token_times[i+1] - token_times[i]
#                      for i in range(len(token_times)-1)]
#         print(f"Mean inter-token latency: {np.mean(latencies)*1000:.1f}ms")
#         print(f"P50 latency: {np.percentile(latencies, 50)*1000:.1f}ms")
#         print(f"P99 latency: {np.percentile(latencies, 99)*1000:.1f}ms")
#
# # Test it
# prompt = torch.randint(0, vocab_size, (1, 16), device=device)
# stream_generate(engine, prompt)
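
When measuring a stream, it helps to separate the wait for the first token from the gaps between later tokens. The sketch below does this for any token iterator. `measure_stream` and `fake_token_stream` are hypothetical names, and the fake stream stands in for a real engine so the example is self-contained.

```python
import time


def measure_stream(token_iter):
    """Consume a token iterator; return (tokens, ttft_seconds, itl_seconds)."""
    start = time.perf_counter()
    tokens, stamps = [], []
    for tok in token_iter:
        tokens.append(tok)
        stamps.append(time.perf_counter())
    # Time to first token, or None if the stream was empty.
    ttft = stamps[0] - start if stamps else None
    # Inter-token latencies: gaps between consecutive arrivals.
    itl = [b - a for a, b in zip(stamps, stamps[1:])]
    return tokens, ttft, itl


def fake_token_stream(text, delay=0.001):
    """Hypothetical stand-in for a streaming engine: one character per step."""
    for ch in text:
        time.sleep(delay)
        yield ch


tokens, ttft, itl = measure_stream(fake_token_stream("To be or not"))
print("".join(tokens))
print(f"TTFT: {ttft * 1e3:.1f} ms, inter-token gaps measured: {len(itl)}")
```

With a real engine, the first gap reflects prefill time and the remaining gaps reflect decode steps, which is why the two are usually reported as separate metrics.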

### TODO 2: KV Cache Memory Calculator

Build a function that calculates the exact KV cache memory for any model configuration and prints a formatted report.

In [None]:
# TODO: Implement a KV cache memory calculator.
#
# def kv_cache_memory(d_model, n_heads, n_layers, seq_len,
#                     dtype_bytes=2, batch_size=1):
#     """
#     Calculate KV cache memory in bytes.
#
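#     Memory = 2 (K and V) * n_layers * seq_len * d_model * dtype_bytes * batch_size
#
#     This assumes standard multi-head attention, where n_heads * d_k = d_model,
#     so n_heads does not appear in the formula. Models that use grouped-query
#     attention cache fewer K/V heads, so treat the result as an upper bound for them.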
#
#     Returns: dict with memory in bytes, MB, and GB
#     """
#     # YOUR CODE HERE
#     pass
#
# # Test with common model sizes
# models = {
#     "GPT-2 (124M)": {"d_model": 768, "n_heads": 12, "n_layers": 12},
#     "LLaMA-7B": {"d_model": 4096, "n_heads": 32, "n_layers": 32},
#     "LLaMA-70B": {"d_model": 8192, "n_heads": 64, "n_layers": 80},
# }
#
# print(f"{'Model':<20} {'Seq=2K':<12} {'Seq=8K':<12} {'Seq=32K':<12}")
# print("-" * 56)
# for name, cfg in models.items():
#     mem_2k = kv_cache_memory(**cfg, seq_len=2048)
#     mem_8k = kv_cache_memory(**cfg, seq_len=8192)
#     mem_32k = kv_cache_memory(**cfg, seq_len=32768)
#     print(f"{name:<20} {mem_2k['gb']:<12.2f} {mem_8k['gb']:<12.2f} {mem_32k['gb']:<12.2f}")
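
As a reference point for checking an implementation of the formula in the docstring (which assumes every attention head stores its own K and V), the LLaMA-7B configuration above at a 2,048-token context in 16-bit precision works out to:

$$2 \times 32 \times 2048 \times 4096 \times 2 \ \text{bytes} = 2^{30}\ \text{bytes} = 1\ \text{GiB} \approx 1.07\ \text{GB}$$

The cache grows linearly with sequence length, so the 8K and 32K columns should be 4x and 16x this value.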

## 6. Putting It All Together

Let us train the model on real text and run the complete engine.

In [None]:
# Train on a short excerpt from Hamlet (punctuation removed)
corpus = """
To be or not to be that is the question
Whether tis nobler in the mind to suffer
The slings and arrows of outrageous fortune
Or to take arms against a sea of troubles
And by opposing end them To die to sleep
No more and by a sleep to say we end
The heartache and the thousand natural shocks
That flesh is heir to Tis a consummation
Devoutly to be wished To die to sleep
To sleep perchance to dream aye there is the rub
""".strip()
corpus = (corpus + "\n") * 20  # keep a newline between repetitions so lines do not run together

chars = sorted(set(corpus))
c2i = {c: i for i, c in enumerate(chars)}
i2c = {i: c for c, i in c2i.items()}
v = len(chars)

# Rebuild model with correct vocab
model = InferenceGPT(v, d_model=128, n_heads=4, n_layers=4).to(device)
model.eos_token_id = None

data = torch.tensor([c2i[c] for c in corpus], device=device)
optimizer = torch.optim.Adam(model.parameters(), lr=3e-3)

seq_len = 64
model.train()
for epoch in range(200):
    total_loss, n = 0, 0
    for i in range(0, len(data) - seq_len - 1, seq_len):
        x = data[i:i+seq_len].unsqueeze(0)
        y = data[i+1:i+seq_len+1].unsqueeze(0)
        logits, _ = model(x, start_pos=0)
        loss = F.cross_entropy(logits.view(-1, v), y.view(-1))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        n += 1
    if (epoch + 1) % 50 == 0:
        print(f"Epoch {epoch+1}: loss = {total_loss/n:.4f}")

engine = GenerationEngine(model)

## 7. Training and Results

In [None]:
# Generate with the trained model using different settings
model.eval()
prompt_text = "To be or not"
prompt_ids = torch.tensor([[c2i[c] for c in prompt_text]], device=device)

settings = [
    {"temperature": 0.0, "top_p": 1.0, "label": "Greedy (T=0)"},
    {"temperature": 0.5, "top_p": 0.9, "label": "Conservative (T=0.5, p=0.9)"},
    {"temperature": 0.8, "top_p": 0.9, "label": "Balanced (T=0.8, p=0.9)"},
    {"temperature": 1.2, "top_p": 0.95, "label": "Creative (T=1.2, p=0.95)"},
]

print("=" * 70)
print(f"Prompt: '{prompt_text}'")
print("=" * 70)

for s in settings:
    output, stats = engine.generate(
        prompt_ids, max_new_tokens=80,
        temperature=s['temperature'], top_p=s['top_p']
    )
    text = ''.join([i2c[t.item()] for t in output[0]])
    print(f"\n[{s['label']}]")
    print(f"  {text}")
    print(f"  Prefill: {stats['prefill_time']*1000:.1f}ms | "
          f"Decode: {stats['decode_time']*1000:.1f}ms | "
          f"Throughput: {stats['decode_tps']:.0f} tok/s")

print("\n" + "=" * 70)

## 8. Final Output

In [None]:
# Comprehensive performance dashboard
gen_lengths = [25, 50, 100, 200]
prompt_len = 32

results = []
for gl in gen_lengths:
    prompt = torch.tensor([[c2i[c] for c in "To be or not to be that is"[:prompt_len]]], device=device)
    # Pad if needed
    if prompt.shape[1] < prompt_len:
        pad = torch.randint(0, v, (1, prompt_len - prompt.shape[1]), device=device)
        prompt = torch.cat([prompt, pad], dim=1)

    times = []
    for _ in range(5):
        _, stats = engine.generate(prompt, max_new_tokens=gl)
        times.append(stats)

    avg_prefill = np.mean([t['prefill_time'] for t in times])
    avg_decode = np.mean([t['decode_time'] for t in times])
    avg_tps = np.mean([t['decode_tps'] for t in times])

    results.append({
        'gen_len': gl,
        'prefill_ms': avg_prefill * 1000,
        'decode_ms': avg_decode * 1000,
        'total_ms': (avg_prefill + avg_decode) * 1000,
        'tokens_per_sec': avg_tps,
    })

# Summary table
print(f"\n{'Gen Length':<12} {'Prefill (ms)':<14} {'Decode (ms)':<14} {'Total (ms)':<14} {'Tok/s':<10}")
print("-" * 64)
for r in results:
    print(f"{r['gen_len']:<12} {r['prefill_ms']:<14.1f} {r['decode_ms']:<14.1f} "
          f"{r['total_ms']:<14.1f} {r['tokens_per_sec']:<10.0f}")

# Performance plot
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

gls = [r['gen_len'] for r in results]
ax1.plot(gls, [r['total_ms'] for r in results], 'bo-', linewidth=2, markersize=8, label='Total')
ax1.plot(gls, [r['prefill_ms'] for r in results], 'g^--', linewidth=1.5, label='Prefill')
ax1.plot(gls, [r['decode_ms'] for r in results], 'rs--', linewidth=1.5, label='Decode')
ax1.set_xlabel('Generation Length (tokens)')
ax1.set_ylabel('Latency (ms)')
ax1.set_title('End-to-End Latency')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.plot(gls, [r['tokens_per_sec'] for r in results], 'ko-', linewidth=2, markersize=8)
ax2.set_xlabel('Generation Length (tokens)')
ax2.set_ylabel('Tokens per Second')
ax2.set_title('Decode Throughput')
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## 9. Reflection and Next Steps

### What We Learned

1. **Prefill and decode are fundamentally different**: Prefill processes many tokens in parallel (compute-bound), while decode processes one token at a time (memory-bound). In production, they are often handled separately.

2. **The complete generation loop** combines KV cache management, positional encoding offsets, sampling strategy, and stopping conditions into a single pipeline.

3. **Streaming generation** delivers tokens to the user as they are produced, hiding the per-token latency behind the reading speed of the human user.

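4. **Performance characteristics**: Decode throughput stays roughly constant over the generation lengths tested here, and degrades only once attention over a long KV cache (or the memory it occupies) becomes the bottleneck. Prefill latency grows about linearly with prompt length while projections and the FFN dominate, and approaches quadratic growth for long prompts where the $O(n^2 \cdot d)$ attention term takes over.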

### Key Metrics for Production Systems

| Metric | What It Measures | Rough Values (vary by model and hardware) |
|--------|-----------------|----------------|
| Time to First Token (TTFT) | Prefill latency | 50-500ms |
| Inter-Token Latency (ITL) | Per-token decode time | 10-50ms |
| Throughput (tok/s) | Decode speed | 30-150 tok/s per user |
| KV Cache Memory | Memory per user | 0.5-4 GB |
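
These metrics are related in a simple way: a request waits one TTFT for its first token and then one ITL for each further token. The sketch below encodes that relationship. The function names are hypothetical, and the model ignores queueing and network overhead.

```python
def request_latency_ms(ttft_ms: float, itl_ms: float, n_tokens: int) -> float:
    """End-to-end latency: first token after TTFT, then one ITL per further token."""
    if n_tokens < 1:
        raise ValueError("n_tokens must be at least 1")
    return ttft_ms + (n_tokens - 1) * itl_ms


def decode_tokens_per_sec(itl_ms: float) -> float:
    """Per-user decode throughput implied by a given inter-token latency."""
    return 1000.0 / itl_ms


total = request_latency_ms(ttft_ms=200, itl_ms=20, n_tokens=100)
print(f"100 tokens: {total:.0f} ms total, "
      f"{decode_tokens_per_sec(20):.0f} tok/s while decoding")
```

For long responses the ITL term dominates the total, which is why decode-side optimizations tend to matter most for end-to-end latency.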

### What is Next

With efficient, high-quality generation working, the next notebook tackles a different problem entirely: **how to make the model better at specific tasks without retraining it from scratch**. We will implement LoRA fine-tuning from scratch.