In [None]:
# Setup: Run this cell first!
# Check GPU availability and install dependencies

import torch
import sys

# Check GPU
if torch.cuda.is_available():
    device = torch.device('cuda')
    print(f"GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    device = torch.device('cpu')
    print("No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime -> Change runtime type -> GPU")

print(f"\nPython {sys.version.split()[0]}")
print(f"PyTorch {torch.__version__}")

# Set random seeds for reproducibility
import random
import numpy as np

SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"Random seed set to {SEED}")

%matplotlib inline

# Case Study: Autoregressive Contract Clause Generation for Legal Document Drafting
## Implementation Notebook

---

**Scenario:** You are an ML engineer at **Lexis Draft AI**, a legal technology startup building a contract drafting assistant deployed across 18 mid-size law firms. Your current system uses TF-IDF retrieval over a library of 45,000 clauses, achieving a 62% clause relevance score and 38% first-draft acceptance rate.

**Your mission:** Build a GPT-style autoregressive transformer that generates novel, contextually appropriate contract clauses conditioned on deal parameters, previously drafted sections, and firm-specific style. Target: **85%+ clause relevance** and **65%+ first-draft acceptance rate**.

**Why GPT over retrieval?**
1. Retrieval cannot generate novel clauses for unprecedented deal structures (35% of queries).
2. Retrieval selects each clause independently, causing inconsistent defined terms across sections.
3. Retrieval cannot adapt to firm-specific drafting conventions.
4. An autoregressive model generates one token at a time conditioned on the full context -- naturally handling all three failure modes.

**Constraints:**
- Model must fit in 24 GB GPU memory (A10G) at inference time (~350M parameters in FP16)
- Clause generation latency < 5 seconds (100-300 tokens)
- All data must stay within each firm's cloud environment (no external APIs)
- Training data: ~405,000 contract clauses (~81M tokens)

---

# AI Teaching Assistant

Need help with this notebook? Open the **AI Teaching Assistant** -- it has already read this entire notebook and can help with concepts, code, and exercises.

**[Open AI Teaching Assistant](https://course-creator-brown.vercel.app/courses/build-llm/practice/2/assistant)**

*Tip: Open it in a separate tab and work through this notebook side-by-side.*

## 3.1 Data Acquisition and Preprocessing

In the Lexis Draft AI scenario, the real training data would consist of 405,000 contract clauses from 18 law firm clients and SEC EDGAR filings. For this notebook, we use a synthetic legal clause dataset that mirrors the structure and patterns of real contract language.

The techniques you learn here -- character-level tokenization, sequence chunking, and data preparation for autoregressive training -- transfer directly to production systems. The main difference in production would be using a BPE tokenizer (32,000 tokens) trained on the legal corpus, where terms like "indemnification" become single tokens.

In [None]:
# Install dependencies
!pip install matplotlib numpy -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
from collections import Counter

torch.manual_seed(42)
np.random.seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Synthetic legal clause corpus
# These clauses are representative of real contract language patterns
# across the clause categories that Lexis Draft AI handles
#
# Every clause below is written in upper case with no punctuation, so the
# set of distinct characters in this list is limited to capital letters and
# the space character. One string = one clause; the `#` headings group the
# clauses by category and are not part of the data.

LEGAL_CLAUSES = [
    # Limitation of Liability
    "IN NO EVENT SHALL EITHER PARTY BE LIABLE TO THE OTHER PARTY FOR ANY INDIRECT INCIDENTAL SPECIAL CONSEQUENTIAL OR PUNITIVE DAMAGES ARISING OUT OF OR RELATED TO THIS AGREEMENT REGARDLESS OF WHETHER SUCH DAMAGES ARE BASED ON CONTRACT TORT NEGLIGENCE STRICT LIABILITY OR ANY OTHER THEORY EVEN IF SUCH PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES",
    "THE AGGREGATE LIABILITY OF LICENSOR UNDER THIS AGREEMENT SHALL NOT EXCEED THE TOTAL FEES PAID BY LICENSEE DURING THE TWELVE MONTH PERIOD IMMEDIATELY PRECEDING THE EVENT GIVING RISE TO SUCH LIABILITY",
    "NOTWITHSTANDING ANYTHING TO THE CONTRARY HEREIN THE LIMITATIONS SET FORTH IN THIS SECTION SHALL NOT APPLY TO A PARTY BREACH OF ITS CONFIDENTIALITY OBLIGATIONS OR A PARTY INDEMNIFICATION OBLIGATIONS UNDER THIS AGREEMENT",

    # Indemnification
    "LICENSEE SHALL INDEMNIFY DEFEND AND HOLD HARMLESS LICENSOR AND ITS OFFICERS DIRECTORS EMPLOYEES AND AGENTS FROM AND AGAINST ANY AND ALL CLAIMS DAMAGES LOSSES LIABILITIES COSTS AND EXPENSES INCLUDING REASONABLE ATTORNEYS FEES ARISING OUT OF OR RELATED TO LICENSEE USE OF THE SOFTWARE IN VIOLATION OF THIS AGREEMENT",
    "LICENSOR SHALL INDEMNIFY DEFEND AND HOLD HARMLESS LICENSEE FROM AND AGAINST ANY CLAIMS ALLEGING THAT THE SOFTWARE AS PROVIDED BY LICENSOR INFRINGES ANY UNITED STATES PATENT COPYRIGHT OR TRADE SECRET OF A THIRD PARTY",
    "THE INDEMNIFYING PARTY SHALL HAVE THE RIGHT TO CONTROL THE DEFENSE AND SETTLEMENT OF ANY CLAIM SUBJECT TO THE INDEMNIFIED PARTY CONSENT WHICH SHALL NOT BE UNREASONABLY WITHHELD",

    # Confidentiality
    "EACH PARTY AGREES TO HOLD IN STRICT CONFIDENCE ALL CONFIDENTIAL INFORMATION RECEIVED FROM THE OTHER PARTY AND SHALL NOT DISCLOSE SUCH INFORMATION TO ANY THIRD PARTY WITHOUT THE PRIOR WRITTEN CONSENT OF THE DISCLOSING PARTY",
    "CONFIDENTIAL INFORMATION SHALL NOT INCLUDE INFORMATION THAT IS OR BECOMES PUBLICLY AVAILABLE THROUGH NO FAULT OF THE RECEIVING PARTY WAS IN THE RECEIVING PARTY POSSESSION PRIOR TO DISCLOSURE OR IS INDEPENDENTLY DEVELOPED BY THE RECEIVING PARTY WITHOUT USE OF THE DISCLOSING PARTY CONFIDENTIAL INFORMATION",
    "THE OBLIGATIONS OF CONFIDENTIALITY SET FORTH HEREIN SHALL SURVIVE THE TERMINATION OR EXPIRATION OF THIS AGREEMENT FOR A PERIOD OF THREE YEARS",

    # Termination
    "EITHER PARTY MAY TERMINATE THIS AGREEMENT FOR CAUSE UPON THIRTY DAYS PRIOR WRITTEN NOTICE IF THE OTHER PARTY MATERIALLY BREACHES ANY PROVISION OF THIS AGREEMENT AND FAILS TO CURE SUCH BREACH WITHIN THE NOTICE PERIOD",
    "UPON TERMINATION OF THIS AGREEMENT FOR ANY REASON LICENSEE SHALL IMMEDIATELY CEASE ALL USE OF THE SOFTWARE AND SHALL RETURN OR DESTROY ALL COPIES OF THE SOFTWARE AND CONFIDENTIAL INFORMATION IN ITS POSSESSION",
    "THE FOLLOWING PROVISIONS SHALL SURVIVE ANY TERMINATION OR EXPIRATION OF THIS AGREEMENT CONFIDENTIALITY LIMITATION OF LIABILITY INDEMNIFICATION AND ANY PROVISIONS WHICH BY THEIR NATURE ARE INTENDED TO SURVIVE",

    # Representations and Warranties
    "LICENSOR REPRESENTS AND WARRANTS THAT IT HAS THE FULL RIGHT POWER AND AUTHORITY TO ENTER INTO THIS AGREEMENT AND TO GRANT THE LICENSES AND RIGHTS GRANTED HEREIN",
    "LICENSEE REPRESENTS AND WARRANTS THAT IT SHALL USE THE SOFTWARE IN COMPLIANCE WITH ALL APPLICABLE LAWS RULES AND REGULATIONS",
    "EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT LICENSOR MAKES NO WARRANTIES EXPRESS OR IMPLIED INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OF MERCHANTABILITY FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT",

    # Governing Law and Dispute Resolution
    "THIS AGREEMENT SHALL BE GOVERNED BY AND CONSTRUED IN ACCORDANCE WITH THE LAWS OF THE STATE OF DELAWARE WITHOUT REGARD TO ITS CONFLICT OF LAWS PRINCIPLES",
    "ANY DISPUTE ARISING OUT OF OR RELATING TO THIS AGREEMENT SHALL BE RESOLVED BY BINDING ARBITRATION ADMINISTERED BY THE AMERICAN ARBITRATION ASSOCIATION IN ACCORDANCE WITH ITS COMMERCIAL ARBITRATION RULES",
    "THE PARTIES AGREE THAT ANY LEGAL ACTION OR PROCEEDING ARISING UNDER THIS AGREEMENT SHALL BE BROUGHT EXCLUSIVELY IN THE FEDERAL OR STATE COURTS LOCATED IN WILMINGTON DELAWARE",

    # Assignment
    "NEITHER PARTY MAY ASSIGN OR TRANSFER THIS AGREEMENT OR ANY RIGHTS OR OBLIGATIONS HEREUNDER WITHOUT THE PRIOR WRITTEN CONSENT OF THE OTHER PARTY EXCEPT THAT EITHER PARTY MAY ASSIGN THIS AGREEMENT WITHOUT CONSENT IN CONNECTION WITH A MERGER ACQUISITION OR SALE OF ALL OR SUBSTANTIALLY ALL OF ITS ASSETS",

    # Force Majeure
    "NEITHER PARTY SHALL BE LIABLE FOR ANY FAILURE OR DELAY IN PERFORMING ITS OBLIGATIONS UNDER THIS AGREEMENT TO THE EXTENT SUCH FAILURE OR DELAY RESULTS FROM CIRCUMSTANCES BEYOND THE REASONABLE CONTROL OF SUCH PARTY INCLUDING BUT NOT LIMITED TO ACTS OF GOD NATURAL DISASTERS PANDEMIC GOVERNMENT ACTIONS WAR TERRORISM OR CIVIL UNREST",

    # Intellectual Property
    "ALL INTELLECTUAL PROPERTY RIGHTS IN AND TO THE SOFTWARE INCLUDING ALL MODIFICATIONS ENHANCEMENTS AND DERIVATIVE WORKS SHALL REMAIN THE EXCLUSIVE PROPERTY OF LICENSOR AND NOTHING IN THIS AGREEMENT SHALL BE CONSTRUED AS TRANSFERRING ANY OWNERSHIP RIGHTS TO LICENSEE",
]

# Repeat the clause list to create a larger corpus. No augmentation is
# applied: every clause appears verbatim CORPUS_REPEATS times, and clauses
# are separated by a blank line ("\n\n").
CORPUS_REPEATS = 20
corpus_text = "\n\n".join(LEGAL_CLAUSES * CORPUS_REPEATS)
print(f"Corpus size: {len(corpus_text):,} characters")
print(f"Number of clauses: {len(LEGAL_CLAUSES) * CORPUS_REPEATS}")
print(f"Unique clause types: {len(LEGAL_CLAUSES)}")
print("\nSample clause (first 200 chars):")
print(corpus_text[:200] + "...")

### TODO 1: Build a Character-Level Tokenizer and Prepare Training Data

In production, Lexis Draft AI uses a BPE tokenizer with 32,000 tokens trained on the legal corpus, where domain-specific terms like "indemnification" and "notwithstanding" each become a single token. For this notebook, we use character-level tokenization -- it is simpler to implement and lets us focus on the model architecture.

The key insight is the same at both levels: we convert text into a sequence of integer IDs, then chop the corpus into fixed-length training sequences.

In [None]:
def build_tokenizer_and_data(corpus_text, max_seq_len=256):
    """
    Build a character-level tokenizer for the legal corpus and prepare
    training sequences.

    The vocabulary is the sorted set of characters that occur in
    ``corpus_text``; a character's token ID is its rank in that sorted order.
    The encoded corpus is cut into non-overlapping chunks of ``max_seq_len``
    tokens, and any trailing remainder shorter than one chunk is dropped.

    Args:
        corpus_text: string containing the full training corpus
        max_seq_len: length of each training chunk (must be positive)

    Returns:
        - encode: function mapping string -> list[int]
        - decode: function mapping an iterable of int IDs -> string
        - train_data: int64 tensor of shape (num_sequences, max_seq_len);
          num_sequences is 0 when the corpus is shorter than max_seq_len
        - vocab_size: int, number of distinct characters in corpus_text

    Raises:
        ValueError: if max_seq_len is not positive.
        KeyError: from encode/decode when given a character or ID that is
            not part of the vocabulary.
    """
    if max_seq_len <= 0:
        raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")

    # Steps 1-2: vocabulary and the two lookup tables
    chars = sorted(set(corpus_text))
    char_to_id = {ch: i for i, ch in enumerate(chars)}
    id_to_char = dict(enumerate(chars))

    # Steps 3-4: tokenizer closures over the lookup tables
    def encode(text):
        return [char_to_id[ch] for ch in text]

    def decode(ids):
        # int() lets callers pass tensor elements as well as plain ints
        return "".join(id_to_char[int(i)] for i in ids)

    # Step 5: the whole corpus as one 1D tensor of token IDs
    encoded = torch.tensor(encode(corpus_text), dtype=torch.long)

    # Step 6: drop the ragged tail, then chunk. The row count is given
    # explicitly because view(-1, n) is ambiguous for an empty tensor.
    num_sequences = len(encoded) // max_seq_len
    train_data = encoded[:num_sequences * max_seq_len].view(num_sequences, max_seq_len)

    return encode, decode, train_data, len(chars)


# Build the tokenizer and the training tensor, then report their sizes
encode, decode, train_data, vocab_size = build_tokenizer_and_data(corpus_text)
data_shape = train_data.shape
print(f"Vocabulary size: {vocab_size}")
print(f"Training sequences: {data_shape[0]}")
print(f"Sequence length: {data_shape[1]}")
print("\nSample decoded sequence:")
first_sequence_text = decode(train_data[0].tolist())
print(first_sequence_text[:120] + "...")

In [None]:
# Verification: sanity-check the tokenizer functions and the training tensor
for fn_name, fn in (("encode", encode), ("decode", decode)):
    assert callable(fn), f"{fn_name} must be a function"

# Encoding followed by decoding must reproduce the input exactly
round_trip = decode(encode("HELLO"))
assert round_trip == "HELLO", "Round-trip encode/decode failed"

n_dims = train_data.ndim
assert n_dims == 2, f"train_data should be 2D, got {n_dims}D"
chunk_len = train_data.shape[1]
assert chunk_len == 256, f"Sequence length should be 256, got {chunk_len}"
assert vocab_size > 0, "vocab_size must be positive"
print("All tokenizer checks passed.")

# Show (up to) the first 30 vocabulary entries by decoding their IDs in order
first_ids = list(range(min(vocab_size, 30)))
print(f"Vocabulary: {decode(first_ids)}")

---

## 3.2 Model Architecture: Building a GPT from Scratch

Now we build the complete GPT model. The architecture has three main components:

1. **CausalSelfAttention** -- Multi-head attention with a causal mask so each token can only attend to tokens at or before its position. This is what makes the model autoregressive.

2. **TransformerBlock** -- Pre-LayerNorm architecture: `x = x + Attention(LN(x))` then `x = x + FFN(LN(x))`. The residual connections enable deep stacking.

3. **GPT** -- Token embeddings + positional embeddings, N Transformer blocks, final LayerNorm, and a linear projection to vocabulary logits.

**Mathematical foundation:**

The causal self-attention computation is:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}} + M\right)V$$

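where $M_{ij} = 0$ if $i \geq j$ and $M_{ij} = -\infty$ if $i < j$ (row $i$ is the query position, column $j$ is the key position). Position $i$ can therefore attend only to positions $j \leq i$, which ensures that generating token $t$ depends only on tokens $1$ through $t-1$.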

### TODO 2: Implement Causal Self-Attention

In [None]:
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal (masked) self-attention.

    Each token attends only to tokens at the same or earlier
    positions. Future positions are masked with -inf before softmax,
    so they receive exactly zero attention weight.

    Args:
        d_model: model embedding dimension (must be divisible by n_heads)
        n_heads: number of attention heads
        max_seq_len: maximum sequence length (size of the causal mask buffer)

    Attributes:
        qkv: Linear(d_model, 3 * d_model) producing Q, K and V in one matmul
        proj: Linear(d_model, d_model) output projection
        mask: bool buffer of shape (max_seq_len, max_seq_len); True marks a
            (query, key) pair where the key lies in the query's future

    Forward:
        Input: x of shape (batch, seq_len, d_model)
        Output: out of shape (batch, seq_len, d_model)

    Raises:
        ValueError: in forward, if seq_len exceeds max_seq_len.
    """
    def __init__(self, d_model, n_heads, max_seq_len=256):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.proj = nn.Linear(d_model, d_model)

        # Strict upper triangle = future positions. Registered as a buffer so
        # it follows the module across .to(device) without being a parameter.
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer('mask', mask)

    def forward(self, x):
        B, T, C = x.shape
        if T > self.mask.size(0):
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len {self.mask.size(0)}"
            )

        # (B, T, 3C) -> (B, T, 3, heads, d_k) -> (3, B, heads, T, d_k)
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores, shape (B, heads, T, T)
        att = (q @ k.transpose(-2, -1)) / (self.d_k ** 0.5)
        # The diagonal is never masked, so every row keeps at least one
        # finite score and the softmax cannot produce NaN.
        att = att.masked_fill(self.mask[:T, :T], float('-inf'))
        att = F.softmax(att, dim=-1)

        # Weighted sum of values, then merge the heads back into d_model
        out = (att @ v).transpose(1, 2).contiguous().view(B, T, C)
        return self.proj(out)

In [None]:
# Verification: test CausalSelfAttention
test_attn = CausalSelfAttention(d_model=64, n_heads=4, max_seq_len=32)
test_input = torch.randn(2, 16, 64)
test_output = test_attn(test_input)
assert test_output.shape == (2, 16, 64), f"Expected (2,16,64), got {test_output.shape}"

# Verify causality: changing a future token should not affect past outputs
x1 = torch.randn(1, 8, 64)
x2 = x1.clone()
x2[0, 5, :] = torch.randn(64)  # change token at position 5
out1 = test_attn(x1)
out2 = test_attn(x2)
# Positions 0-4 should be identical (they cannot see position 5)
assert torch.allclose(out1[0, :5], out2[0, :5], atol=1e-5), "Causal mask broken: future change affected past"
# Positions 6-7 are allowed to attend to position 5, so they must change.
# Without this check, a layer that ignores context entirely (for example one
# where each token attends only to itself) would pass the causality test.
assert not torch.allclose(out1[0, 6:], out2[0, 6:], atol=1e-5), \
    "Attention ignores context: changing position 5 did not affect later positions"
print("CausalSelfAttention: all checks passed (shape and causality).")

### TODO 3: Implement Transformer Block and Full GPT Model

The Transformer block uses the **Pre-LN** architecture (as in GPT-2), where LayerNorm is applied *before* the sub-layer rather than after. This improves training stability for deeper models.

The feed-forward network (FFN) expands the representation to 4x the model dimension and then compresses it back, using GELU activation. This expansion gives each token position a richer intermediate representation to work with.

In [None]:
class TransformerBlock(nn.Module):
    """
    Single GPT Transformer block with Pre-LN architecture.

    Architecture:
        x = x + Attention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

    FFN: Linear(d_model, 4*d_model) -> GELU -> Linear(4*d_model, d_model)

    Args:
        d_model: model embedding dimension
        n_heads: number of attention heads
        max_seq_len: maximum sequence length, forwarded to the attention layer

    Attributes:
        ln1: LayerNorm applied before attention
        attn: CausalSelfAttention(d_model, n_heads, max_seq_len)
        ln2: LayerNorm applied before the feed-forward network
        ffn: the two-layer feed-forward network described above

    Forward:
        Input and output both have shape (batch, seq_len, d_model).
    """
    def __init__(self, d_model, n_heads, max_seq_len=256):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads, max_seq_len)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def forward(self, x):
        # Pre-LN: normalise the sub-layer input and leave the residual
        # stream itself un-normalised.
        x = x + self.attn(self.ln1(x))
        x = x + self.ffn(self.ln2(x))
        return x


class GPT(nn.Module):
    """
    Complete GPT language model for autoregressive generation.

    Architecture:
        1. Token embedding: nn.Embedding(vocab_size, d_model)
        2. Positional embedding: nn.Embedding(max_seq_len, d_model)
        3. Stack of N TransformerBlocks (``self.blocks``)
        4. Final LayerNorm
        5. Linear head: d_model -> vocab_size (no bias)

    Args:
        vocab_size: number of distinct token IDs
        d_model: model embedding dimension
        n_heads: attention heads per block
        n_layers: number of TransformerBlocks
        max_seq_len: longest sequence the model accepts; also stored as
            ``self.max_seq_len``

    Input: idx of shape (batch_size, seq_len) -- integer token IDs
    Output: logits of shape (batch_size, seq_len, vocab_size)

    Raises:
        ValueError: in forward, if seq_len exceeds max_seq_len.
    """
    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_seq_len=256):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        # ModuleList (not a plain list) so that the blocks' parameters are
        # registered and the blocks can be indexed individually.
        self.blocks = nn.ModuleList(
            [TransformerBlock(d_model, n_heads, max_seq_len) for _ in range(n_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, idx):
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"Sequence length {T} exceeds max_seq_len {self.max_seq_len}"
            )

        # Positions 0..T-1, created on the same device as the token IDs;
        # the (T, d_model) positional embeddings broadcast over the batch.
        positions = torch.arange(T, device=idx.device)
        x = self.tok_emb(idx) + self.pos_emb(positions)

        for block in self.blocks:
            x = block(x)

        return self.head(self.ln_f(x))

    def count_parameters(self):
        """Return the total number of elements across all parameters."""
        return sum(p.numel() for p in self.parameters())

In [None]:
# Instantiate the model
gpt_kwargs = dict(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_layers=4,
    max_seq_len=256,
)
model = GPT(**gpt_kwargs).to(device)

n_params = model.count_parameters()
print(f"Model parameters: {n_params:,}")
print("\nModel architecture:")
print(model)

# Verification: one forward pass on random token IDs must give one row of
# vocabulary logits per input position
test_ids = torch.randint(0, vocab_size, (2, 32)).to(device)
test_logits = model(test_ids)
expected_shape = (2, 32, vocab_size)
assert test_logits.shape == expected_shape, \
    f"Expected (2, 32, {vocab_size}), got {test_logits.shape}"
print(f"\nForward pass check: input {test_ids.shape} -> output {test_logits.shape}")
print("All model checks passed.")

**Thought Questions:**
1. Our model has ~800K parameters. The production Lexis Draft model would have ~350M. What architectural changes would you make to scale up? (Hint: increase d_model, n_heads, n_layers, and max_seq_len.)
2. Why does the GPT model use a **learned** positional embedding rather than the fixed sinusoidal encoding from the original Transformer paper?
3. Why is the output projection head (Linear to vocab_size) set to have no bias? (Hint: think about the interaction with the preceding LayerNorm.)

---

## 3.3 Training: Next-Token Prediction with Cross-Entropy Loss

The training objective is simple but powerful: given a sequence of tokens, predict the next token at every position. The model learns by minimizing cross-entropy loss:

$$\mathcal{L} = -\frac{1}{N}\sum_{i=1}^{N} \log P_\theta(t_i \mid t_1, t_2, \ldots, t_{i-1})$$

In the Lexis Draft scenario, we would apply the loss only to clause tokens (not prompt tokens) by masking prompt positions with label -100. For this notebook, we train on the full sequence.

### Training Configuration

In [None]:
# Training hyperparameters (scaled for Colab)
# The first five entries describe the model, the last four the optimisation run.
config = dict(
    vocab_size=vocab_size,
    d_model=128,
    n_heads=4,
    n_layers=4,
    max_seq_len=256,
    batch_size=32,
    learning_rate=3e-4,
    num_steps=2000,
    eval_interval=100,
)

print("Training configuration:")
for name, value in config.items():
    print(f"  {name}: {value}")

### TODO 4: Implement the Training Loop

The training loop follows the standard four-step process:
1. **Forward pass:** compute logits from the model
2. **Loss computation:** cross-entropy between predicted next-token distribution and actual next token
3. **Backward pass:** compute gradients via backpropagation
4. **Weight update:** adjust parameters with the optimizer

In [None]:
def train_model(model, train_data, num_steps=2000, batch_size=32,
                learning_rate=3e-4, eval_interval=100, device='cpu'):
    """
    Train the GPT model on next-token prediction.

    At each training step:
        1. Sample a random batch of sequences (with replacement) from train_data.
        2. Create inputs x = batch[:, :-1]  (all tokens except the last).
           Create targets y = batch[:, 1:]   (all tokens except the first).
        3. Forward pass: logits = model(x), shape (B, T-1, vocab_size).
        4. Flatten logits to (B*(T-1), vocab_size) and targets to (B*(T-1),)
           then compute F.cross_entropy(logits_flat, targets_flat).
        5. optimizer.zero_grad(), loss.backward(), optimizer.step().
        6. Every eval_interval steps, and on the final step, print and
           record the loss of the current batch.

    Args:
        model: GPT model instance (already moved to ``device``)
        train_data: tensor (num_sequences, seq_len) of token IDs
        num_steps: total training iterations
        batch_size: batch size
        learning_rate: AdamW learning rate
        eval_interval: steps between loss logging (values below 1 are
            treated as 1)
        device: torch device the batches are moved to

    Returns:
        losses: list of (step, loss_value) tuples

    Raises:
        ValueError: if train_data contains no sequences.
    """
    if len(train_data) == 0:
        raise ValueError("train_data contains no sequences")
    eval_interval = max(1, eval_interval)  # avoids a modulo-by-zero below

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    losses = []

    model.train()
    for step in range(num_steps):
        # Sample sequence indices while the data is still on its own
        # device, then move only the selected batch.
        batch_idx = torch.randint(0, len(train_data), (batch_size,))
        batch = train_data[batch_idx].to(device)

        # Shift by one: position t of x is trained to predict token t+1
        x, y = batch[:, :-1], batch[:, 1:]
        logits = model(x)

        # y is a strided slice, so reshape (not view) is needed to flatten it
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % eval_interval == 0 or step == num_steps - 1:
            loss_value = loss.item()
            losses.append((step, loss_value))
            print(f"Step {step:5d}/{num_steps} | loss {loss_value:.4f}")

    return losses


# Train the model
# Hyperparameters are read from `config` so that the values printed by the
# configuration cell are the ones actually used for training.
print("Starting training...")
start_time = time.time()
train_losses = train_model(
    model, train_data,
    num_steps=config['num_steps'],
    batch_size=config['batch_size'],
    learning_rate=config['learning_rate'],
    eval_interval=config['eval_interval'],
    device=device
)
elapsed = time.time() - start_time
print(f"\nTraining complete in {elapsed:.1f} seconds")

In [None]:
# Plot training loss
# Fail with an explicit message rather than the opaque unpacking error that
# zip(*[]) would otherwise raise on the next line.
if not train_losses:
    raise RuntimeError(
        "train_losses is empty: train_model() recorded no losses, so there is nothing to plot"
    )
steps, losses = zip(*train_losses)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(steps, losses, alpha=0.3, color='blue')

# Smoothed curve: trailing mean over the last `window` recorded losses.
# The slice starts at i - window + 1 so that it holds exactly `window`
# points once enough history exists (fewer at the start of the curve).
window = 5
smoothed = [np.mean(losses[max(0, i - window + 1):i + 1]) for i in range(len(losses))]
ax.plot(steps, smoothed, color='blue', linewidth=2, label='Smoothed loss')

# Random baseline: loss if the model predicted uniformly at random
random_baseline = np.log(vocab_size)
ax.axhline(y=random_baseline, color='red', linestyle='--',
           label=f'Random baseline (ln({vocab_size}) = {random_baseline:.2f})')

ax.set_xlabel('Training Step')
ax.set_ylabel('Cross-Entropy Loss')
ax.set_title('GPT Training Loss on Legal Contract Clauses')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

print(f"Initial loss: {losses[0]:.4f}")
print(f"Final loss: {losses[-1]:.4f}")
print(f"Random baseline: {random_baseline:.4f}")
print(f"Loss reduction: {((losses[0] - losses[-1]) / losses[0]) * 100:.1f}%")

**Thought Questions:**
1. The random baseline loss is $\ln(V)$ where $V$ is the vocabulary size. Why is this the expected loss for a model that predicts uniformly at random?
2. If the final loss is significantly below the random baseline but still above zero, what does the remaining loss represent? (Hint: think about the inherent uncertainty in natural language.)
3. In the Lexis Draft production system, we apply label smoothing ($\epsilon = 0.1$). How would this change the training loss curve?

---

## 3.4 Generation: Autoregressive Sampling

Generation works by repeatedly:
1. Feeding the current sequence through the model
2. Taking the logits at the last position
3. Sampling from the resulting distribution
4. Appending the sampled token

The **temperature** parameter controls randomness: lower temperature makes the model more deterministic (sharper distribution), higher temperature makes it more random. **Top-k** filtering restricts sampling to the k most likely tokens, preventing the model from generating very unlikely characters.

In the Lexis Draft scenario, generation continues until an `[END_CLAUSE]` token is produced or 500 tokens are reached.

### TODO 5: Implement Autoregressive Text Generation

In [None]:
@torch.no_grad()
def generate(model, encode, decode, prompt, max_new_tokens=200,
             temperature=0.8, top_k=40, device='cpu', max_seq_len=None):
    """
    Generate text autoregressively from a prompt.

    Algorithm:
        1. Encode the prompt string to token IDs.
        2. Create idx tensor of shape (1, len(prompt_ids)) on device.
        3. For each new token to generate:
           a. Crop idx to its last max_seq_len tokens.
           b. Forward pass: logits = model(idx_cropped).
           c. Take logits at the last position: logits[:, -1, :].
           d. Divide by temperature.
           e. If top_k > 0, set every logit below the k-th largest to -inf.
           f. Apply softmax to get probabilities.
           g. Sample one token: torch.multinomial(probs, num_samples=1).
           h. Append to idx: idx = torch.cat([idx, next_token], dim=1).
        4. Decode the full sequence and return the string.

    Args:
        model: trained GPT model
        encode: tokenizer encode function
        decode: tokenizer decode function
        prompt: string to condition generation on (must encode to >= 1 token)
        max_new_tokens: maximum number of tokens to generate
        temperature: sampling temperature (default 0.8; must be > 0)
        top_k: top-k filtering (default 40; 0 or None = disabled)
        device: torch device
        max_seq_len: context window used for cropping. When None, read from
            ``model.max_seq_len`` if the model has that attribute, otherwise
            256 (the default used by the model classes in this notebook).

    Returns:
        generated_text: string (prompt + generated continuation)

    Raises:
        ValueError: if the prompt encodes to no tokens or temperature <= 0.
    """
    model.eval()

    prompt_ids = encode(prompt)
    if len(prompt_ids) == 0:
        raise ValueError("prompt must encode to at least one token")
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if max_seq_len is None:
        max_seq_len = getattr(model, 'max_seq_len', 256)

    idx = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    for _ in range(max_new_tokens):
        # The model cannot see more than max_seq_len tokens, so keep only
        # the most recent ones as context.
        idx_cropped = idx[:, -max_seq_len:]
        logits = model(idx_cropped)[:, -1, :] / temperature

        if top_k is not None and top_k > 0:
            # k-th largest logit per row; min() guards top_k > vocab size
            top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
            logits = logits.masked_fill(logits < top_values[:, [-1]], float('-inf'))

        probs = F.softmax(logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_token], dim=1)

    return decode(idx[0].tolist())


# Generate sample contract clauses
# Each prompt is the opening of a clause; the model writes the continuation.
prompts = [
    "IN NO EVENT SHALL",
    "LICENSEE SHALL INDEMNIFY",
    "THIS AGREEMENT SHALL BE",
    "NEITHER PARTY MAY",
    "EACH PARTY AGREES",
]

banner = "=" * 70
print(banner)
print("  GENERATED CONTRACT CLAUSES")
print(banner)

# Sampling settings shared by every prompt
sampling_kwargs = dict(max_new_tokens=200, temperature=0.7, top_k=30, device=device)
for prompt in prompts:
    generated = generate(model, encode, decode, prompt, **sampling_kwargs)
    print(f"\nPrompt: '{prompt}'")
    print(f"Generated:\n{generated}")
    print("-" * 70)

**Thought Questions:**
1. Try temperature=0.3 and temperature=1.5. How does the output change? Which would be more appropriate for legal document generation, and why?
2. Why is top-k filtering important for generation quality? What happens without it?
3. In the Lexis Draft production system, generation is conditioned on `[CONTRACT_TYPE]`, `[GOVERNING_LAW]`, `[FIRM_STYLE]`, and `[CONTEXT]` tokens. How would you modify the generate function to handle this structured prompt?

---

## 3.5 Evaluation and Analysis

We evaluate our model on four axes:
1. **Perplexity** -- a standard language modeling metric: how surprised is the model by held-out text?
2. **Next-character accuracy** -- how often is the model's most likely next character the correct one?
3. **Generation latency** -- can we generate a clause within the 5-second budget?
4. **Attention patterns** -- what does the model learn to attend to?

### TODO 6: Compute Perplexity, Latency, and Analyze Attention

In [None]:
def evaluate_model(model, train_data, encode, decode, vocab_size,
                   device='cpu', num_eval_batches=20, batch_size=32,
                   num_latency_tokens=200):
    """
    Evaluate the trained GPT model.

    Compute:
        1. Perplexity on the evaluation split:
           - The last 10% of train_data is used as eval data.
           - For each batch: x = batch[:, :-1], y = batch[:, 1:].
           - Cross-entropy is summed over all evaluated tokens and divided
             by the token count, so a smaller final batch is weighted
             correctly.
           - Perplexity = exp(mean_loss).

        2. Next-character accuracy:
           - The fraction of evaluated positions where
             argmax(logits) == target.

        3. Generation latency:
           - Generate num_latency_tokens tokens by greedy decoding from
             the first 16 tokens of the eval split, calling the model
             directly.
           - Measure wall-clock time and compute tokens_per_second.

    NOTE: the split is taken from the same tensor that is passed to
    train_model elsewhere in this notebook, so unless training was
    restricted to the first 90% these sequences were seen during training
    and the metrics are optimistic.

    Args:
        model: trained GPT model
        train_data: full data tensor of shape (num_sequences, seq_len)
        encode, decode: tokenizer functions; accepted for interface
            compatibility and not used by the computation
        vocab_size: vocabulary size; checked against the model's output width
        device: torch device
        num_eval_batches: maximum number of batches to evaluate
        batch_size: batch size
        num_latency_tokens: number of tokens generated for the latency
            measurement (default 200)

    Returns:
        dict with 'perplexity', 'tokens_per_second', 'accuracy'

    Raises:
        ValueError: if train_data is not a non-empty 2D tensor with at least
            two tokens per sequence, if no tokens were evaluated, or if the
            model's output width differs from vocab_size.
    """
    model.eval()

    if train_data.ndim != 2 or train_data.shape[0] == 0 or train_data.shape[1] < 2:
        raise ValueError(
            "train_data must have shape (num_sequences, seq_len) with "
            "num_sequences >= 1 and seq_len >= 2"
        )

    # int(0.9 * n) < n for every n >= 1, so the split is never empty
    split_idx = int(0.9 * len(train_data))
    eval_data = train_data[split_idx:]

    # --- Perplexity and next-character accuracy ---------------------------
    total_nll = 0.0
    total_correct = 0
    total_tokens = 0
    # Walk the eval split once, in order; each sequence is scored at most once
    batch_starts = range(0, len(eval_data), batch_size)[:num_eval_batches]
    with torch.no_grad():
        for start in batch_starts:
            batch = eval_data[start:start + batch_size].to(device)
            x, y = batch[:, :-1], batch[:, 1:]
            logits = model(x)
            if logits.size(-1) != vocab_size:
                raise ValueError(
                    f"Model produced {logits.size(-1)} logits per position, "
                    f"expected vocab_size={vocab_size}"
                )
            flat_logits = logits.reshape(-1, vocab_size)
            flat_targets = y.reshape(-1)
            total_nll += F.cross_entropy(flat_logits, flat_targets, reduction='sum').item()
            total_correct += (flat_logits.argmax(dim=-1) == flat_targets).sum().item()
            total_tokens += flat_targets.numel()

    if total_tokens == 0:
        raise ValueError("No tokens were evaluated; num_eval_batches must be positive")

    mean_loss = total_nll / total_tokens
    perplexity = torch.exp(torch.tensor(mean_loss)).item()
    accuracy = total_correct / total_tokens

    # --- Generation latency ----------------------------------------------
    context_len = getattr(model, 'max_seq_len', eval_data.shape[1])
    idx = eval_data[:1, :16].to(device)
    on_cuda = torch.device(device).type == 'cuda'
    if on_cuda:
        torch.cuda.synchronize()  # do not time work queued before the loop
    start_time = time.perf_counter()
    with torch.no_grad():
        for _ in range(num_latency_tokens):
            step_logits = model(idx[:, -context_len:])[:, -1, :]
            next_token = step_logits.argmax(dim=-1, keepdim=True)
            idx = torch.cat([idx, next_token], dim=1)
    if on_cuda:
        torch.cuda.synchronize()  # wait for the GPU before stopping the clock
    # Clamp so that a zero-length measurement cannot divide by zero
    elapsed = max(time.perf_counter() - start_time, 1e-9)
    tokens_per_second = num_latency_tokens / elapsed

    return {
        'perplexity': perplexity,
        'tokens_per_second': tokens_per_second,
        'accuracy': accuracy,
    }


# Run the evaluation once and report the three metrics it returns
results = evaluate_model(
    model, train_data, encode, decode, vocab_size,
    device=device,
)
print("\nEvaluation Results:")
print(f"  Perplexity: {results['perplexity']:.2f}")
print(f"  Tokens/second: {results['tokens_per_second']:.0f}")
print(f"  Next-char accuracy: {results['accuracy']:.1%}")

# Context: Lexis Draft production targets
print("\nProduction targets (for reference):")
print("  Target perplexity: < 12.0 (ours is character-level, so not directly comparable)")
print("  Target latency: < 5 sec for 100-300 tokens")

# Convert the measured throughput into the time needed for one clause
tokens_for_clause = 200  # typical clause length
estimated_clause_time = tokens_for_clause / results['tokens_per_second']
print(f"  Estimated clause generation time: {estimated_clause_time:.2f} sec for {tokens_for_clause} tokens")

In [None]:
# Attention weight visualization
def visualize_attention(model, encode, text, layer_idx=-1, head_idx=0,
                        device='cpu'):
    """
    Visualize one attention head's weights for a given input text.

    A temporary forward hook on the selected block's attention module
    recomputes the causally-masked attention matrix from that module's
    input (Q and K are taken from ``module.qkv``). The chosen head is then
    plotted as a heatmap with one tick per input character.

    The hook is always removed before returning, even if encoding or the
    forward pass raises, so a failed call never leaves a stale hook on the
    model.

    Args:
        model: model exposing ``blocks``; the selected block must have an
            ``attn`` module with ``qkv``, ``n_heads`` and ``d_k`` attributes
        encode: tokenizer function; its result for ``text`` is wrapped in a
            batch of one and passed to the model
        text: input string to run through the model
        layer_idx: index into ``model.blocks`` (negative counts from the end)
        head_idx: index of the attention head to plot
        device: torch device on which the input ids are created

    Returns:
        None. Shows a matplotlib figure when attention weights were captured.

    Note:
        The model is switched to eval mode and left in that mode.
    """
    model.eval()
    attention_weights = []

    def hook_fn(module, input, output):
        # Rebuild the attention pattern from the module's own input rather
        # than relying on the module to expose its weights.
        x = input[0]
        B, T, _ = x.shape
        qkv = module.qkv(x).reshape(B, T, 3, module.n_heads, module.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # -> (3, B, n_heads, T, d_k)
        q, k = qkv[0], qkv[1]
        att = (q @ k.transpose(-2, -1)) / (module.d_k ** 0.5)
        # Causal mask: position i may only attend to positions <= i.
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att = att.masked_fill(mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        attention_weights.append(att.detach().cpu())

    # Register hook on the target layer's attention module
    target_layer = list(model.blocks)[layer_idx]
    hook = target_layer.attn.register_forward_hook(hook_fn)
    try:
        ids = torch.tensor([encode(text)], device=device)
        with torch.no_grad():
            model(ids)
    finally:
        # Always detach the hook; otherwise a failed forward pass would leave
        # it attached and it would fire on every later call to the model.
        hook.remove()

    if attention_weights:
        # First captured call, first (only) batch element, requested head.
        attn = attention_weights[0][0, head_idx].numpy()
        chars = list(text[:ids.shape[1]])
        n = len(chars)

        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(attn[:n, :n], cmap='Blues')
        ax.set_xticks(range(n))
        ax.set_yticks(range(n))
        ax.set_xticklabels(chars, rotation=90, fontfamily='monospace', fontsize=7)
        ax.set_yticklabels(chars, fontfamily='monospace', fontsize=7)
        ax.set_xlabel('Key (attending to)')
        ax.set_ylabel('Query (attending from)')
        ax.set_title(f'Attention Weights (Layer {layer_idx}, Head {head_idx})')
        plt.colorbar(im)
        plt.tight_layout()
        plt.show()


# Plot head 0 and then head 1 of the last layer on the same clause opening,
# so the attention patterns of different heads can be compared.
sample_text = "IN NO EVENT SHALL EITHER PARTY BE"
for head in (0, 1):
    visualize_attention(model, encode, sample_text, layer_idx=-1,
                        head_idx=head, device=device)

**Thought Questions:**
1. What patterns do you see in the attention heatmap? Do different heads attend to different things?
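2. In legal text, certain words like "SHALL" and "NOT" fundamentally change the meaning of a clause. Do you see these words receiving high attention? (The model is character-level, so each tick on the heatmap is a single character; look at the group of characters that spells each word.)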
3. The causal mask creates the triangular pattern. Why is this critical for autoregressive generation but not needed for models like BERT?

---

## 3.6 Temperature and Sampling Strategy Analysis

For legal document generation, the choice of sampling strategy directly affects output quality. Too deterministic and the model produces repetitive text. Too random and it generates nonsensical legal language.

### TODO 7: Compare Sampling Strategies

In [None]:
def compare_sampling_strategies(model, encode, decode, prompt,
                                 device='cpu'):
    """
    Compare different sampling strategies for legal text generation.

    Generate text from the same prompt using:
        1. Greedy decoding (temperature=0.01, top_k=1)
        2. Low temperature (temperature=0.3, top_k=0)
        3. Medium temperature (temperature=0.7, top_k=30)
        4. High temperature (temperature=1.2, top_k=0)
        5. Top-k only (temperature=1.0, top_k=10)

    For each strategy, generate 200 tokens and report:
        - The generated text
        - Unique character ratio (num_unique_chars / total_chars)
        - Repetition score: fraction of 4-grams that appear more than once

    Then create a bar chart comparing the repetition scores.

    Args:
        model: trained GPT model
        encode, decode: tokenizer functions
        prompt: string prompt
        device: torch device

    Hints:
        - To compute repetition: extract all 4-character substrings,
          count occurrences, and compute fraction with count > 1
        - Legal drafting benefits from moderate temperature (0.5-0.8)
          with top-k filtering
    """
    # Each entry is (label, temperature, top_k), in the same order as the
    # numbered list in the docstring above.
    # NOTE(review): top_k=0 presumably means "no top-k filtering" -- confirm
    # against the generation helper defined earlier in the notebook.
    strategies = [
        ('Greedy', 0.01, 1),
        ('Low temp (0.3)', 0.3, 0),
        ('Medium temp (0.7) + top-k', 0.7, 30),
        ('High temp (1.2)', 1.2, 0),
        ('Top-k only (k=10)', 1.0, 10),
    ]

    # Exercise stub: until it is implemented, the function does nothing and
    # returns None.
    # YOUR CODE HERE
    pass


# Run the comparison on an indemnification-style prompt.
compare_sampling_strategies(model, encode, decode,
                             "LICENSEE SHALL INDEMNIFY",
                             device=device)

**Thought Questions:**
1. Which strategy produces the most realistic legal language? Why?
2. Greedy decoding often produces repetitive text ("the the the..."). Why does this happen in autoregressive models?
3. For Lexis Draft's production system, what sampling strategy would you recommend and why? Consider that attorneys need both quality and diversity.

---

## 3.7 Scaling Analysis: From Notebook to Production

Our notebook model has ~800K parameters. The production model for Lexis Draft AI would have ~350M parameters. Let us analyze how the key metrics scale.

### TODO 8: Parameter Count and Memory Analysis

In [None]:
def scaling_analysis(vocab_size):
    """
    Analyze how model size, memory, and compute scale with architecture choices.

    Compute and plot the following:

    1. Parameter count for different configurations:
       - Notebook:  d=128,  heads=4,  layers=4,  vocab=vocab_size
       - Small:     d=512,  heads=8,  layers=8,  vocab=32000
       - Medium:    d=768,  heads=12, layers=12, vocab=32000
       - Production:d=1024, heads=16, layers=24, vocab=32000
       - Large:     d=1600, heads=25, layers=48, vocab=32000

       Parameter count formula:
       - Embeddings: vocab * d + max_seq * d
       - Per block: 4*d*d (QKV+proj) + 8*d*d (FFN) + 4*d (layernorms)
       - Head: d * vocab
       - Total: embeddings + layers * per_block + d (final LN) + head

    2. FP16 memory for model weights (params * 2 bytes).

    3. KV-cache memory for a single sequence:
       - Per layer: 2 * seq_len * d * 2 bytes (FP16)
       - Total: layers * per_layer

    4. Create a table and bar chart of the results.

    Args:
        vocab_size: notebook model vocab size

    Hints:
        - Use 2048 as max_seq_len for all non-notebook configs
        - The production config should have ~350M parameters
        - Plot with log scale on y-axis for parameter counts
    """
    # Each entry is (name, d, heads, layers, vocab, max_seq_len), matching the
    # configuration list in the docstring; the notebook model uses 256 as its
    # maximum sequence length and every other config uses 2048.
    # NOTE(review): the formula counts 4*d for the LayerNorms in a block but
    # only d for the final LayerNorm -- if each LayerNorm has a weight and a
    # bias the final term would be 2*d; confirm which convention is intended.
    configs = [
        ('Notebook',   128,   4,  4,  vocab_size, 256),
        ('Small',      512,   8,  8,  32000, 2048),
        ('Medium',     768,  12, 12,  32000, 2048),
        ('Production', 1024, 16, 24,  32000, 2048),
        ('Large',      1600, 25, 48,  32000, 2048),
    ]

    # Exercise stub: until it is implemented, the function does nothing and
    # returns None.
    # YOUR CODE HERE
    pass


# Run the analysis with the notebook tokenizer's vocabulary size.
scaling_analysis(vocab_size)

**Thought Questions:**
1. Does the production model (~350M params) fit within the 24 GB A10G GPU constraint? How much room is left for the KV-cache?
2. The KV-cache grows linearly with sequence length. At what sequence length would the KV-cache exceed the remaining GPU memory?
3. Lexis Draft uses LoRA adapters for firm-specific style. If each adapter adds 2-4M parameters, how does this compare to storing 18 separate full models?

---

## 3.8 Production Considerations

### Output Validation Pipeline

In the Lexis Draft production system, every generated clause passes through a validation pipeline before being shown to the attorney. Let us implement a simplified version.

In [None]:
def validate_clause(generated_text, context_terms=None):
    """
    Validate a generated contract clause with lightweight heuristics.

    Checks:
    1. Defined term consistency: every all-caps word longer than two
       characters (e.g., LICENSOR, LICENSEE, AGREEMENT) must be in
       context_terms, unless it is a common all-caps function word or a
       deprecated term (those are reported by check 2 instead).
       Surrounding punctuation, quotes and brackets are ignored, so
       (the "LICENSOR") is recognized, and each undefined term is
       reported only once.
    2. Prohibited language: flag any deprecated legal terms
       (case-sensitive substring match).
    3. Length check: fewer than 50 characters is an error; more than
       2000 characters is a warning.
    4. Structural check: warn if the clause does not end with '.' or ';'.

    Args:
        generated_text: clause text to validate
        context_terms: container of defined terms allowed in the clause;
            defaults to a small built-in set of licensing terms

    Returns:
        dict with 'passed' (True when there are no errors), 'warnings'
        and 'errors' (lists of message strings)
    """
    if context_terms is None:
        context_terms = {
            'LICENSOR', 'LICENSEE', 'AGREEMENT', 'SOFTWARE',
            'PARTY', 'PARTIES', 'SECTION'
        }

    # All-caps function words that should never be treated as defined terms.
    # Built once here instead of once per word inside the loop.
    common_terms = {'THE', 'AND', 'FOR', 'ANY', 'ALL', 'SHALL',
                    'NOT', 'ITS', 'BUT', 'NOR', 'YET', 'HAS',
                    'MAY', 'SUCH', 'FROM', 'WITH', 'THAT',
                    'THIS', 'UNDER', 'UPON', 'INTO', 'THAN'}
    prohibited = ['HERETOFORE', 'WITNESSETH', 'WHEREAS']
    # Characters that may wrap a term without being part of it, including
    # straight and curly quotes, e.g. (the "LICENSOR").
    wrapping_chars = '.,;:()[]{}!?"\'`\u201c\u201d\u2018\u2019'

    warnings = []
    errors = []

    # 1. Check defined terms
    reported = set()  # undefined terms already warned about
    for word in generated_text.split():
        clean = word.strip(wrapping_chars)
        if not clean.isupper() or len(clean) <= 2:
            continue
        if clean in context_terms or clean in common_terms:
            continue
        # Deprecated terms get their own warning in check 2; calling them
        # "undefined" as well would be misleading.
        if clean in prohibited or clean in reported:
            continue
        reported.add(clean)
        warnings.append(f"Undefined term: {clean}")

    # 2. Prohibited language
    for term in prohibited:
        if term in generated_text:
            warnings.append(f"Deprecated term: {term}")

    # 3. Length check
    if len(generated_text) < 50:
        errors.append("Clause too short (< 50 characters)")
    if len(generated_text) > 2000:
        warnings.append("Clause unusually long (> 2000 characters)")

    # 4. Structural check
    stripped = generated_text.strip()
    if stripped and stripped[-1] not in '.;':
        warnings.append("Clause does not end with proper punctuation")

    # Only errors fail validation; warnings are advisory.
    passed = len(errors) == 0
    return {'passed': passed, 'warnings': warnings, 'errors': errors}


# Generate one clause per prompt and report what the validator found.
print("Clause Validation Results:")
print("=" * 50)
for prompt in prompts:
    generated = generate(
        model, encode, decode, prompt,
        max_new_tokens=200, temperature=0.7, top_k=30, device=device,
    )
    result = validate_clause(generated)
    status = "FAIL" if not result['passed'] else "PASS"
    print(f"\n[{status}] Prompt: '{prompt}'")
    # Looping over an empty list prints nothing, so no emptiness guard is needed.
    for e in result['errors']:
        print(f"  ERROR: {e}")
    for w in result['warnings'][:3]:  # Show first 3 warnings
        print(f"  WARNING: {w}")
    if not (result['errors'] or result['warnings']):
        print("  No issues found.")

---

## Summary

In this notebook, you built a complete GPT-style autoregressive language model for legal contract clause generation:

In [None]:
# Final summary: assemble the whole report and emit it in one print call.
# Only the lines that interpolate live values are f-strings.
print("\n".join([
    "=" * 70,
    "  CASE STUDY SUMMARY: GPT from Scratch",
    "  Autoregressive Contract Clause Generation for Lexis Draft AI",
    "=" * 70,
    "",
    "Model Architecture:",
    "  Type: Decoder-only Transformer (GPT)",
    f"  Parameters: {model.count_parameters():,}",
    "  Layers: 4, Heads: 4, d_model: 128",
    f"  Tokenization: Character-level ({vocab_size} tokens)",
    "  Max sequence length: 256",
    "",
    "Training:",
    f"  Data: {len(train_data)} sequences of legal contract clauses",
    "  Objective: Next-token prediction (cross-entropy loss)",
    "  Optimizer: AdamW (lr=3e-4)",
    "  Steps: 2,000",
    "",
    "Key Findings:",
    "  1. Training loss converged well below random baseline",
    f"     (random baseline = ln({vocab_size}) = {np.log(vocab_size):.2f})",
    "  2. Model generates coherent legal language following",
    "     contract drafting conventions",
    "  3. Causal attention ensures each token conditions on all",
    "     previous context, maintaining clause consistency",
    "  4. Temperature and top-k filtering control the quality/",
    "     diversity tradeoff for legal text",
    "",
    "Production Scaling (Lexis Draft AI):",
    "  - Scale to ~350M parameters (d=1024, heads=16, layers=24)",
    "  - Use BPE tokenizer (32K vocab) trained on legal corpus",
    "  - Add LoRA adapters for per-firm style adaptation",
    "  - Implement KV-cache for O(T) generation (vs O(T^2))",
    "  - Deploy with vLLM for continuous batching",
    "  - Output validation: defined terms, cross-references,",
    "     prohibited language, confidence thresholds",
    "",
    "Business Impact (Target):",
    "  - Clause relevance: 62% -> 85%",
    "  - First-draft acceptance: 38% -> 65%",
    "  - Attorney editing time: 3.2 hrs -> 1.5 hrs per contract",
    "  - Annual savings: ~$18.4M across 18 client firms",
]))

---

## Next Steps

For a deeper understanding of the production system design, read **Section 4** of the case study document, which covers:

1. **Multi-firm model architecture** -- base model + LoRA adapters for firm-specific style
2. **KV-cache** -- how to make generation efficient enough for interactive use
3. **Data pipeline** -- attorney feedback collection and continuous learning
4. **Monitoring and guardrails** -- defined term checks, cross-reference validation, confidence thresholds
5. **Ethical considerations** -- attorney-client privilege, bias, and the boundary between drafting assistance and legal advice