In [None]:
# 🔧 Setup: execute this cell before any other.
# Reports the available hardware and library versions, then seeds every RNG.

import random
import sys

import numpy as np
import torch

# Resolve the compute device once; `device` is reused by later cells.
gpu_present = torch.cuda.is_available()
device = torch.device('cuda' if gpu_present else 'cpu')
if not gpu_present:
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    gpu_memory_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {gpu_memory_gb:.1f} GB")

python_version = sys.version.split()[0]
print(f"\n📦 Python {python_version}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed Python, NumPy and PyTorch (CPU, plus every GPU when one is present)
SEED = 42
for seed_rng in (random.seed, np.random.seed, torch.manual_seed):
    seed_rng(SEED)
if gpu_present:
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

%matplotlib inline

# Case Study: Autoregressive Contract Clause Generation
## Implementation Notebook

---

**Scenario:** You are an ML engineer at Lexis Draft AI, a legal technology startup building a contract drafting assistant for law firms. Your task is to build a GPT-style autoregressive model that generates contextually appropriate contract clauses conditioned on deal parameters and previously drafted sections.

**Current system:** A retrieval-based pipeline with 62% clause relevance score and 38% first-draft acceptance rate. Your target: 85%+ clause relevance and 65%+ first-draft acceptance.

**Why GPT:** The retrieval system cannot generate novel clauses for unprecedented deal structures, cannot maintain defined-term consistency across sections, and cannot adapt to firm-specific drafting style. A GPT-style autoregressive model generates one token at a time conditioned on the full context, naturally handling all three failure modes.

---

## 3.1 Data Acquisition and Preprocessing

We use a small synthetic legal clause dataset for this notebook: 21 hand-written clauses, repeated 20 times to form the training corpus. In the Lexis Draft AI scenario, the real training data would consist of 405,000 contract clauses from 18 law firm clients and SEC EDGAR filings. The synthetic dataset mirrors the structure and patterns of real contract language while being freely usable.

In [None]:
# Install dependencies
!pip install matplotlib numpy -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import time
from collections import Counter

# Re-seed PyTorch and NumPy, then resolve the device used by the rest of the notebook
for reseed in (torch.manual_seed, np.random.seed):
    reseed(42)

cuda_ok = torch.cuda.is_available()
device = torch.device('cuda') if cuda_ok else torch.device('cpu')
print(f"Using device: {device}")
if cuda_ok:
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_total_bytes / 1e9:.1f} GB")

In [None]:
# Synthetic legal clause corpus
# These clauses are representative of real contract language patterns.
# Each entry is one clause written in uppercase with punctuation stripped
# (e.g. "ATTORNEYS FEES", "PARTY BREACH"), so the strings consist of letters
# and single spaces. Entries are grouped by clause category below.
LEGAL_CLAUSES = [
    # Limitation of Liability
    "IN NO EVENT SHALL EITHER PARTY BE LIABLE TO THE OTHER PARTY FOR ANY INDIRECT INCIDENTAL SPECIAL CONSEQUENTIAL OR PUNITIVE DAMAGES ARISING OUT OF OR RELATED TO THIS AGREEMENT REGARDLESS OF WHETHER SUCH DAMAGES ARE BASED ON CONTRACT TORT NEGLIGENCE STRICT LIABILITY OR ANY OTHER THEORY EVEN IF SUCH PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES",
    "THE AGGREGATE LIABILITY OF LICENSOR UNDER THIS AGREEMENT SHALL NOT EXCEED THE TOTAL FEES PAID BY LICENSEE DURING THE TWELVE MONTH PERIOD IMMEDIATELY PRECEDING THE EVENT GIVING RISE TO SUCH LIABILITY",
    "NOTWITHSTANDING ANYTHING TO THE CONTRARY HEREIN THE LIMITATIONS SET FORTH IN THIS SECTION SHALL NOT APPLY TO A PARTY BREACH OF ITS CONFIDENTIALITY OBLIGATIONS OR A PARTY INDEMNIFICATION OBLIGATIONS UNDER THIS AGREEMENT",

    # Indemnification
    "LICENSEE SHALL INDEMNIFY DEFEND AND HOLD HARMLESS LICENSOR AND ITS OFFICERS DIRECTORS EMPLOYEES AND AGENTS FROM AND AGAINST ANY AND ALL CLAIMS DAMAGES LOSSES LIABILITIES COSTS AND EXPENSES INCLUDING REASONABLE ATTORNEYS FEES ARISING OUT OF OR RELATED TO LICENSEE USE OF THE SOFTWARE IN VIOLATION OF THIS AGREEMENT",
    "LICENSOR SHALL INDEMNIFY DEFEND AND HOLD HARMLESS LICENSEE FROM AND AGAINST ANY CLAIMS ALLEGING THAT THE SOFTWARE AS PROVIDED BY LICENSOR INFRINGES ANY UNITED STATES PATENT COPYRIGHT OR TRADE SECRET OF A THIRD PARTY",
    "THE INDEMNIFYING PARTY SHALL HAVE THE RIGHT TO CONTROL THE DEFENSE AND SETTLEMENT OF ANY CLAIM SUBJECT TO THE INDEMNIFIED PARTY CONSENT WHICH SHALL NOT BE UNREASONABLY WITHHELD",

    # Confidentiality
    "EACH PARTY AGREES TO HOLD IN STRICT CONFIDENCE ALL CONFIDENTIAL INFORMATION RECEIVED FROM THE OTHER PARTY AND SHALL NOT DISCLOSE SUCH INFORMATION TO ANY THIRD PARTY WITHOUT THE PRIOR WRITTEN CONSENT OF THE DISCLOSING PARTY",
    "CONFIDENTIAL INFORMATION SHALL NOT INCLUDE INFORMATION THAT IS OR BECOMES PUBLICLY AVAILABLE THROUGH NO FAULT OF THE RECEIVING PARTY WAS IN THE RECEIVING PARTY POSSESSION PRIOR TO DISCLOSURE OR IS INDEPENDENTLY DEVELOPED BY THE RECEIVING PARTY WITHOUT USE OF THE DISCLOSING PARTY CONFIDENTIAL INFORMATION",
    "THE OBLIGATIONS OF CONFIDENTIALITY SET FORTH HEREIN SHALL SURVIVE THE TERMINATION OR EXPIRATION OF THIS AGREEMENT FOR A PERIOD OF THREE YEARS",

    # Termination
    "EITHER PARTY MAY TERMINATE THIS AGREEMENT FOR CAUSE UPON THIRTY DAYS PRIOR WRITTEN NOTICE IF THE OTHER PARTY MATERIALLY BREACHES ANY PROVISION OF THIS AGREEMENT AND FAILS TO CURE SUCH BREACH WITHIN THE NOTICE PERIOD",
    "UPON TERMINATION OF THIS AGREEMENT FOR ANY REASON LICENSEE SHALL IMMEDIATELY CEASE ALL USE OF THE SOFTWARE AND SHALL RETURN OR DESTROY ALL COPIES OF THE SOFTWARE AND CONFIDENTIAL INFORMATION IN ITS POSSESSION",
    "THE FOLLOWING PROVISIONS SHALL SURVIVE ANY TERMINATION OR EXPIRATION OF THIS AGREEMENT CONFIDENTIALITY LIMITATION OF LIABILITY INDEMNIFICATION AND ANY PROVISIONS WHICH BY THEIR NATURE ARE INTENDED TO SURVIVE",

    # Representations and Warranties
    "LICENSOR REPRESENTS AND WARRANTS THAT IT HAS THE FULL RIGHT POWER AND AUTHORITY TO ENTER INTO THIS AGREEMENT AND TO GRANT THE LICENSES AND RIGHTS GRANTED HEREIN",
    "LICENSEE REPRESENTS AND WARRANTS THAT IT SHALL USE THE SOFTWARE IN COMPLIANCE WITH ALL APPLICABLE LAWS RULES AND REGULATIONS",
    "EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT LICENSOR MAKES NO WARRANTIES EXPRESS OR IMPLIED INCLUDING WITHOUT LIMITATION ANY IMPLIED WARRANTIES OF MERCHANTABILITY FITNESS FOR A PARTICULAR PURPOSE OR NONINFRINGEMENT",

    # Governing Law and Dispute Resolution
    "THIS AGREEMENT SHALL BE GOVERNED BY AND CONSTRUED IN ACCORDANCE WITH THE LAWS OF THE STATE OF DELAWARE WITHOUT REGARD TO ITS CONFLICT OF LAWS PRINCIPLES",
    "ANY DISPUTE ARISING OUT OF OR RELATING TO THIS AGREEMENT SHALL BE RESOLVED BY BINDING ARBITRATION ADMINISTERED BY THE AMERICAN ARBITRATION ASSOCIATION IN ACCORDANCE WITH ITS COMMERCIAL ARBITRATION RULES",
    "THE PARTIES AGREE THAT ANY LEGAL ACTION OR PROCEEDING ARISING UNDER THIS AGREEMENT SHALL BE BROUGHT EXCLUSIVELY IN THE FEDERAL OR STATE COURTS LOCATED IN WILMINGTON DELAWARE",

    # Assignment
    "NEITHER PARTY MAY ASSIGN OR TRANSFER THIS AGREEMENT OR ANY RIGHTS OR OBLIGATIONS HEREUNDER WITHOUT THE PRIOR WRITTEN CONSENT OF THE OTHER PARTY EXCEPT THAT EITHER PARTY MAY ASSIGN THIS AGREEMENT WITHOUT CONSENT IN CONNECTION WITH A MERGER ACQUISITION OR SALE OF ALL OR SUBSTANTIALLY ALL OF ITS ASSETS",

    # Force Majeure
    "NEITHER PARTY SHALL BE LIABLE FOR ANY FAILURE OR DELAY IN PERFORMING ITS OBLIGATIONS UNDER THIS AGREEMENT TO THE EXTENT SUCH FAILURE OR DELAY RESULTS FROM CIRCUMSTANCES BEYOND THE REASONABLE CONTROL OF SUCH PARTY INCLUDING BUT NOT LIMITED TO ACTS OF GOD NATURAL DISASTERS PANDEMIC GOVERNMENT ACTIONS WAR TERRORISM OR CIVIL UNREST",

    # Intellectual Property
    "ALL INTELLECTUAL PROPERTY RIGHTS IN AND TO THE SOFTWARE INCLUDING ALL MODIFICATIONS ENHANCEMENTS AND DERIVATIVE WORKS SHALL REMAIN THE EXCLUSIVE PROPERTY OF LICENSOR AND NOTHING IN THIS AGREEMENT SHALL BE CONSTRUED AS TRANSFERRING ANY OWNERSHIP RIGHTS TO LICENSEE",
]

# Build a larger corpus by repeating the clause list 20 times, with a blank
# line ("\n\n") between consecutive clauses. No augmentation is applied: the
# 20 copies are identical, so any slice of the corpus repeats text found
# elsewhere in it.
corpus_text = "\n\n".join(LEGAL_CLAUSES * 20)
print(f"Corpus size: {len(corpus_text)} characters")
# Preview only the first 200 characters to keep the cell output short
print(f"Sample:\n{corpus_text[:200]}...")

### TODO 1: Build a Character-Level Tokenizer and Prepare Training Data

In [None]:
def build_tokenizer_and_data(corpus_text, max_seq_len=256):
    """
    Build a character-level tokenizer for the legal corpus and prepare
    training sequences (exercise stub for TODO 1).

    Args:
        corpus_text: string containing the full training corpus
        max_seq_len: maximum sequence length for training chunks

    Returns:
        A 4-tuple, in this order:
        - encode: function mapping string -> list[int]
        - decode: function mapping list[int] -> string
        - train_data: integer tensor of shape (num_sequences, max_seq_len)
        - vocab_size: int

    Raises:
        NotImplementedError: until the exercise body is filled in. Failing
            here is deliberate; a stub that returns None would only surface
            later as an opaque unpacking error at the call site.

    Steps:
        1. Collect all unique characters in corpus_text, sort them.
        2. Create char-to-id and id-to-char mappings.
        3. Define encode(text) and decode(ids) functions.
        4. Encode the full corpus into a 1D tensor.
        5. Reshape into non-overlapping chunks of max_seq_len.
           Discard any leftover tokens that do not fill a complete chunk.
        6. Return the four values.

    Hints:
        - chars = sorted(set(corpus_text))
        - char_to_id = {ch: i for i, ch in enumerate(chars)}
        - encoded = torch.tensor(encode(corpus_text), dtype=torch.long)
          (embedding lookups and cross-entropy targets need integer ids)
        - To chunk: total_tokens = (len(encoded) // max_seq_len) * max_seq_len
          then encoded[:total_tokens].view(-1, max_seq_len)
    """
    # YOUR CODE HERE (replace the raise below with your implementation)
    raise NotImplementedError(
        "TODO 1: implement build_tokenizer_and_data before running this cell"
    )


# Run the tokenizer builder on the corpus and report what it produced
encode, decode, train_data, vocab_size = build_tokenizer_and_data(corpus_text)

num_sequences = train_data.shape[0]
sequence_length = train_data.shape[1]
first_sequence_text = decode(train_data[0].tolist())

print(f"Vocabulary size: {vocab_size}")
print(f"Training sequences: {num_sequences}")
print(f"Sequence length: {sequence_length}")
# Only the first 100 decoded characters are shown as a sanity check
print(f"\nSample decoded sequence:\n{first_sequence_text[:100]}...")

## 3.2 Model Architecture: Building a GPT from Scratch

We build the complete GPT model following the architecture from the article: token embeddings, positional embeddings, N Transformer blocks (each with causal multi-head self-attention, layer normalization, and a feed-forward network), a final layer norm, and a linear projection to the vocabulary.

### TODO 2: Implement Causal Self-Attention

In [None]:
class CausalSelfAttention(nn.Module):
    """
    Multi-head causal self-attention (exercise stub for TODO 2).

    Each token attends only to tokens at the same or earlier
    positions. Future positions are masked with -inf before softmax.

    Args:
        d_model: model embedding dimension (must be divisible by n_heads)
        n_heads: number of attention heads
        max_seq_len: maximum sequence length

    Forward:
        Input: x of shape (batch, seq_len, d_model)
        Output: out of shape (batch, seq_len, d_model)

    Raises:
        NotImplementedError: from __init__ and forward until the exercise is
            completed, so an unfinished module cannot be silently built and
            used as a parameter-less layer.

    Implementation steps:
        1. Project input to Q, K, V using a single linear layer
           (d_model -> 3 * d_model), then split into Q, K, V.
        2. Reshape Q, K, V to (batch, n_heads, seq_len, d_k)
           where d_k = d_model // n_heads.
        3. Compute attention scores: (Q @ K^T) / sqrt(d_k).
        4. Apply causal mask: set upper-triangle entries to -inf.
        5. Apply softmax to get attention weights.
        6. Multiply attention weights by V.
        7. Reshape back to (batch, seq_len, d_model).
        8. Apply output projection (d_model -> d_model).

    Hints:
        - Register the causal mask as a buffer (not a parameter)
          using self.register_buffer
        - torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
          creates the mask
        - Slice the registered mask to [:T, :T] in forward; inputs are
          usually shorter than max_seq_len (training feeds batch[:, :-1])
        - Use .masked_fill(mask, float('-inf')) before softmax
        - Name the fused projection self.qkv and lay its output out so that
          .reshape(B, T, 3, n_heads, d_k) separates Q, K and V; the
          visualize_attention helper in this notebook reads self.qkv,
          self.n_heads and self.d_k with exactly that layout
    """
    def __init__(self, d_model, n_heads, max_seq_len=256):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        # YOUR CODE HERE: define self.qkv, self.proj, and register causal mask buffer
        # (replace the raise below with your implementation)
        raise NotImplementedError(
            "TODO 2: define self.qkv, self.proj and the causal mask buffer "
            "in CausalSelfAttention.__init__"
        )

    def forward(self, x):
        B, T, C = x.shape
        # YOUR CODE HERE (replace the raise below with your implementation)
        raise NotImplementedError("TODO 2: implement CausalSelfAttention.forward")

### TODO 3: Implement Transformer Block and GPT Model

In [None]:
class TransformerBlock(nn.Module):
    """
    Single GPT Transformer block, Pre-LN architecture (exercise stub for
    TODO 3).

    Architecture:
        x = x + Attention(LayerNorm(x))
        x = x + FFN(LayerNorm(x))

    FFN: Linear(d_model, 4*d_model) -> GELU -> Linear(4*d_model, d_model)

    Args:
        d_model: model dimension
        n_heads: number of attention heads
        max_seq_len: maximum sequence length

    Raises:
        NotImplementedError: from __init__ and forward until the exercise is
            completed, so an empty block cannot be stacked unnoticed.

    Hints:
        - Store the attention sub-module as self.attn; the
          visualize_attention helper in this notebook registers its hook on
          block.attn.
        - Use two separate LayerNorm modules, one per residual branch.
    """
    def __init__(self, d_model, n_heads, max_seq_len=256):
        super().__init__()
        # YOUR CODE HERE (replace the raise below with your implementation)
        raise NotImplementedError("TODO 3: implement TransformerBlock.__init__")

    def forward(self, x):
        # YOUR CODE HERE (replace the raise below with your implementation)
        raise NotImplementedError("TODO 3: implement TransformerBlock.forward")


class GPT(nn.Module):
    """
    Complete GPT language model (exercise stub for TODO 3).

    Architecture:
        1. Token embedding table (vocab_size x d_model)
        2. Positional embedding table (max_seq_len x d_model)
        3. Stack of N TransformerBlocks
        4. Final LayerNorm
        5. Linear head projecting d_model -> vocab_size (no bias)

    Forward pass:
        1. Look up token embeddings
        2. Add positional embeddings for positions 0..T-1
        3. Pass through all Transformer blocks
        4. Apply final LayerNorm
        5. Project to vocabulary logits

    Args:
        vocab_size, d_model, n_heads, n_layers, max_seq_len

    Forward:
        Input: idx of shape (batch, seq_len) holding integer token ids
        Output: logits of shape (batch, seq_len, vocab_size)

    Raises:
        NotImplementedError: from __init__ and forward until the exercise is
            completed. Without this, an unfinished model would be built with
            zero parameters and the failure would only show up much later.

    Hints:
        - Keep the blocks in an nn.ModuleList named self.blocks; the
          visualize_attention helper in this notebook indexes model.blocks.
        - Store max_seq_len on the model so generation can crop the context
          to the model's limit.
        - Build positions with torch.arange(T, device=idx.device).
    """
    def __init__(self, vocab_size, d_model, n_heads, n_layers, max_seq_len=256):
        super().__init__()
        # YOUR CODE HERE (replace the raise below with your implementation)
        raise NotImplementedError("TODO 3: implement GPT.__init__")

    def forward(self, idx):
        B, T = idx.shape
        # YOUR CODE HERE (replace the raise below with your implementation)
        raise NotImplementedError("TODO 3: implement GPT.forward")

    def count_parameters(self):
        """Return the total number of elements across all model parameters."""
        return sum(p.numel() for p in self.parameters())


# Build the demo-sized GPT (4 layers, 4 heads, width 128) and move it to the device
model = GPT(vocab_size=vocab_size, d_model=128, n_heads=4, n_layers=4, max_seq_len=256)
model = model.to(device)

parameter_count = model.count_parameters()
print(f"Model parameters: {parameter_count:,}")
# Header and module tree are emitted on separate lines, as before
print("\nModel architecture:", model, sep="\n")

## 3.3 Training: Loss and Backpropagation

The training loop implements the exact four-step process from the article: forward pass (compute logits), compute cross-entropy loss, backward pass (compute gradients), and weight update (optimizer step).

### TODO 4: Implement the Training Loop

In [None]:
def train_model(model, train_data, num_steps=2000, batch_size=32,
                learning_rate=3e-4, eval_interval=200, device='cpu'):
    """
    Train the GPT model on next-token prediction (exercise stub for TODO 4).

    Training step:
        1. Sample a random batch of sequences from train_data.
        2. Inputs x = batch[:, :-1] (all tokens except last).
           Targets y = batch[:, 1:] (all tokens except first).
        3. Forward pass: logits = model(x).
        4. Compute loss: F.cross_entropy on flattened logits and targets.
        5. optimizer.zero_grad(), loss.backward(), optimizer.step().

    Args:
        model: GPT model
        train_data: tensor of shape (N, seq_len)
        num_steps: number of training iterations
        batch_size: batch size
        learning_rate: AdamW learning rate
        eval_interval: steps between loss logging
        device: torch device

    Returns:
        list of (step, train_loss) tuples, one entry per logging step

    Raises:
        NotImplementedError: until the training loop is implemented. The
            stub previously returned an empty list, which only failed later
            when the loss curve was plotted.

    Hints:
        - Use torch.optim.AdamW
        - Call model.train() before the loop
        - Sample batch indices: torch.randint(0, len(train_data), (batch_size,))
        - Move the sampled batch to `device` before slicing x and y
        - Flatten logits: logits.reshape(-1, logits.size(-1))
          (vocab_size is not an argument of this function; read it from
          the logits instead)
        - Flatten targets: y.reshape(-1). Do not use y.view(-1): the slice
          batch[:, 1:] is not contiguous, so .view raises on it
        - Append (step, loss.item()) every eval_interval steps
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    losses = []

    # YOUR CODE HERE: implement the training loop
    # (replace the raise below with your implementation)
    raise NotImplementedError("TODO 4: implement the training loop in train_model")

    return losses


# Launch training and time the whole run
print("Starting training...")
training_options = dict(
    num_steps=2000,
    batch_size=32,
    learning_rate=3e-4,
    eval_interval=100,
    device=device,
)
start_time = time.time()
train_losses = train_model(model, train_data, **training_options)
elapsed = time.time() - start_time
print(f"\nTraining complete in {elapsed:.1f} seconds")

In [None]:
# Plot training loss
# train_losses is a list of (step, loss) pairs; split it into two sequences.
steps, losses = zip(*train_losses)
plt.figure(figsize=(10, 5))
plt.plot(steps, losses, alpha=0.3, color='blue', label='Logged loss')

# Smoothed curve: trailing mean over the last `window` logged points.
# The slice starts at i - window + 1 so that exactly `window` values are
# averaged (fewer at the very start of the run).
window = 5
smoothed = [np.mean(losses[max(0, i - window + 1):i + 1])
            for i in range(len(losses))]
plt.plot(steps, smoothed, color='blue', linewidth=2, label='Smoothed loss')

# A model that guesses uniformly over the vocabulary has loss log(vocab_size)
random_baseline = np.log(vocab_size)
plt.axhline(y=random_baseline, color='red', linestyle='--',
            label=f'Random baseline ({random_baseline:.2f})')
plt.xlabel('Training Step')
plt.ylabel('Cross-Entropy Loss')
plt.title('GPT Training Loss on Legal Contract Clauses')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()

## 3.4 Generation: Sampling from the Trained Model

With the model trained, we generate contract clauses autoregressively. Each token is sampled from the model's predicted distribution, conditioned on all previously generated tokens.

### TODO 5: Implement Autoregressive Generation

In [None]:
@torch.no_grad()
def generate(model, encode, decode, prompt, max_new_tokens=200,
             temperature=0.8, top_k=40, device='cpu'):
    """
    Generate text autoregressively from a prompt (exercise stub for TODO 5).

    Algorithm:
        1. Encode the prompt string to token IDs.
        2. At each step:
           a. Feed the current sequence through the model.
           b. Take logits at the last position.
           c. Divide by temperature.
           d. If top_k > 0, set every logit below the k-th largest to -inf
              so those tokens get zero probability after softmax. (Setting
              them to 0 would NOT remove them: a logit of 0 still has
              non-zero probability.)
           e. Apply softmax to get probabilities.
           f. Sample one token from the distribution.
           g. Append to the sequence.
        3. Decode the full sequence back to text.

    Args:
        model: trained GPT model
        encode: tokenizer encode function
        decode: tokenizer decode function
        prompt: string prompt to condition on
        max_new_tokens: maximum tokens to generate
        temperature: sampling temperature
        top_k: top-k filtering (0 = disabled)
        device: torch device

    Returns:
        generated_text: string (prompt + generated continuation)

    Raises:
        NotImplementedError: until the sampling loop is implemented, so the
            caller does not silently print None as the generated text.

    Hints:
        - model.eval() before generation
        - Crop sequence to max_seq_len if it exceeds the model's limit
        - For top-k: k = min(top_k, logits.size(-1)), then
          v, _ = torch.topk(logits, k)
          and logits[logits < v[..., [-1]]] = float('-inf').
          The clamp matters: torch.topk raises when k exceeds the vocabulary
          size, which happens with small character-level vocabularies.
        - torch.multinomial(probs, num_samples=1) for sampling
    """
    model.eval()
    # YOUR CODE HERE (replace the raise below with your implementation)
    raise NotImplementedError("TODO 5: implement autoregressive sampling in generate")


# Generate sample clauses
prompts = [
    "IN NO EVENT SHALL",
    "LICENSEE SHALL INDEMNIFY",
    "THIS AGREEMENT SHALL BE",
    "NEITHER PARTY MAY",
    "EACH PARTY AGREES",
]

# torch.topk requires k <= vocabulary size. The corpus is uppercase letters,
# spaces and newlines, so its character vocabulary is smaller than 30 and a
# fixed top_k=30 would make a generate() written to the TODO hint raise.
# Clamp to the real vocabulary size; pick a smaller value (e.g. 10) to see
# top-k filtering actually change the samples.
sampling_top_k = min(30, vocab_size)

print("=" * 60)
print("  GENERATED CONTRACT CLAUSES")
print("=" * 60)

for prompt in prompts:
    generated = generate(model, encode, decode, prompt,
                         max_new_tokens=200, temperature=0.7,
                         top_k=sampling_top_k, device=device)
    print(f"\nPrompt: '{prompt}'")
    print(f"Generated:\n{generated}")
    print("-" * 60)

## 3.5 Evaluation and Analysis

### TODO 6: Compute Perplexity and Analyze Generation Quality

In [None]:
def evaluate_model(model, train_data, encode, decode, vocab_size,
                   device='cpu', num_eval_batches=20, batch_size=32):
    """
    Evaluate the trained model (exercise stub for TODO 6).

    Compute:
        1. Perplexity on the evaluation slice of the data:
           perplexity = exp(average cross-entropy loss)
        2. Generation latency: tokens per second on GPU/CPU
        3. Character-level accuracy: fraction of correctly predicted
           next characters across evaluation data (a value in [0, 1];
           the caller formats it as a percentage)

    Note on "held-out" data: the training call in this notebook passes the
    full train_data tensor to train_model, which samples batches from all
    of it, and the corpus is the same clause list repeated. The last 10%
    slice therefore is not unseen text, and the numbers reported here
    measure fit to the training distribution, not generalization. Exclude
    the slice from training if a true held-out estimate is needed.

    Args:
        model: trained GPT model
        train_data: full training data tensor
        encode, decode: tokenizer functions
        vocab_size: vocabulary size
        device: torch device
        num_eval_batches: number of batches for evaluation
        batch_size: batch size

    Returns:
        dict with 'perplexity', 'tokens_per_second', 'accuracy'

    Raises:
        NotImplementedError: until the evaluation is implemented, so the
            caller does not fail later with an opaque error when indexing
            the missing result dict.

    Hints:
        - Use the last 10% of train_data as eval data
        - Wrap the evaluation in torch.no_grad()
        - Perplexity = exp(mean_loss)
        - For latency, generate 200 tokens and measure wall-clock time
          (on GPU, call torch.cuda.synchronize() before reading the clock)
        - Accuracy: (logits.argmax(-1) == targets).float().mean()
    """
    model.eval()
    # YOUR CODE HERE (replace the raise below with your implementation)
    raise NotImplementedError("TODO 6: implement evaluate_model")


# Run the evaluation once, then print each metric on its own line
results = evaluate_model(
    model, train_data, encode, decode, vocab_size, device=device
)
report_lines = (
    f"  Perplexity: {results['perplexity']:.2f}",
    f"  Tokens/second: {results['tokens_per_second']:.0f}",
    f"  Next-char accuracy: {results['accuracy']:.1%}",
)
print("\nEvaluation Results:")
for report_line in report_lines:
    print(report_line)

In [None]:
# Attention weight visualization
def visualize_attention(model, encode, text, layer_idx=-1, head_idx=0,
                        device='cpu'):
    """
    Visualize attention weights for a given input text.

    A temporary forward hook is placed on the attention module of the chosen
    block. The hook recomputes the causal attention weights from that
    module's own qkv projection of its input (scaled dot product, causal
    mask, softmax) and the selected head is drawn as a heatmap. The weights
    are recomputed here, not read from the module's forward pass, so they
    match the model only if its forward uses the same qkv layout and formula.

    Requirements on the model (taken from what this function accesses):
        - model.blocks: indexable sequence of blocks
        - block.attn: attention module exposing .qkv, .n_heads and .d_k,
          where .qkv(x) can be reshaped to (B, T, 3, n_heads, d_k)

    Args:
        model: GPT model satisfying the requirements above
        encode: tokenizer encode function (string -> list of ids)
        text: input string to visualize; must encode to at least one token
        layer_idx: which transformer block (-1 = last)
        head_idx: which attention head
        device: torch device

    Raises:
        ValueError: if `text` encodes to zero tokens.
    """
    token_ids = encode(text)
    if len(token_ids) == 0:
        raise ValueError("text must encode to at least one token")

    model.eval()

    # Filled by the hook during the forward pass
    attention_weights = []

    def hook_fn(module, inputs, output):
        # `inputs` is the tuple of positional arguments given to the
        # attention module; its first element is the block's hidden state.
        x = inputs[0]
        B, T, C = x.shape
        qkv = module.qkv(x).reshape(B, T, 3, module.n_heads, module.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # -> (3, B, n_heads, T, d_k)
        q, k = qkv[0], qkv[1]
        att = (q @ k.transpose(-2, -1)) / (module.d_k ** 0.5)
        # True above the diagonal marks future positions to hide
        mask = torch.triu(torch.ones(T, T, device=x.device), diagonal=1).bool()
        att = att.masked_fill(mask, float('-inf'))
        att = F.softmax(att, dim=-1)
        attention_weights.append(att.detach().cpu())

    # Register hook on the target layer
    target_layer = list(model.blocks)[layer_idx]
    hook = target_layer.attn.register_forward_hook(hook_fn)

    # Forward pass. The hook is removed in `finally` so that a failing
    # forward pass cannot leave it attached to the model.
    ids = torch.tensor([token_ids], dtype=torch.long, device=device)
    try:
        with torch.no_grad():
            model(ids)
    finally:
        hook.remove()

    # Plot (skipped when the hooked module was never called)
    if attention_weights:
        attn = attention_weights[0][0, head_idx].numpy()
        # One tick label per input character; this assumes one token per
        # character, as produced by a character-level tokenizer.
        chars = list(text[:ids.shape[1]])

        fig, ax = plt.subplots(figsize=(10, 8))
        im = ax.imshow(attn[:len(chars), :len(chars)], cmap='Blues')
        ax.set_xticks(range(len(chars)))
        ax.set_yticks(range(len(chars)))
        ax.set_xticklabels(chars, rotation=90, fontfamily='monospace', fontsize=8)
        ax.set_yticklabels(chars, fontfamily='monospace', fontsize=8)
        ax.set_xlabel('Key (attending to)')
        ax.set_ylabel('Query (attending from)')
        ax.set_title(f'Attention Weights (Layer {layer_idx}, Head {head_idx})')
        plt.colorbar(im)
        plt.tight_layout()
        plt.show()


# Draw head 0 of the last block for the opening words of a liability clause
sample_text = "IN NO EVENT SHALL EITHER PARTY BE"
visualize_attention(
    model,
    encode,
    sample_text,
    layer_idx=-1,
    head_idx=0,
    device=device,
)

In [None]:
# Summary of results
# Read the architecture from the model itself, not from repeated literals,
# so the summary stays correct when the hyperparameters above are changed.
# These attributes (model.blocks, block.attn, attn.n_heads, attn.d_k) are
# the same ones the attention visualization relies on.
summary_attn = model.blocks[0].attn
summary_n_layers = len(model.blocks)
summary_n_heads = summary_attn.n_heads
summary_d_model = summary_attn.n_heads * summary_attn.d_k

# Report the measured loss instead of asserting an outcome
final_step, final_loss = train_losses[-1]
baseline_loss = np.log(vocab_size)

print("=" * 60)
print("  CASE STUDY SUMMARY")
print("=" * 60)
print()
print("Model: GPT-style autoregressive transformer")
print(f"Parameters: {model.count_parameters():,}")
print(f"Architecture: {summary_n_layers} layers, {summary_n_heads} heads, "
      f"d_model={summary_d_model}")
print(f"Training data: {len(train_data)} sequences of legal clauses")
print()
print("Key findings:")
print(f"  1. Final logged training loss: {final_loss:.2f} (step {final_step})")
print(f"     (random baseline = log({vocab_size}) = {baseline_loss:.2f})")
print("  2. Model generates coherent legal language that follows")
print("     contract drafting conventions")
print("  3. Causal attention allows each token to condition on all")
print("     previous context, maintaining clause consistency")
print("  4. The same architecture scales from this notebook demo")
print("     to production models with 350M+ parameters")
print()
print("Production considerations:")
print("  - Scale to 350M parameters for production quality")
print("  - Use BPE tokenizer trained on legal corpus")
print("  - Add LoRA adapters for firm-specific style")
print("  - Implement KV-cache for efficient generation")
print("  - Deploy with vLLM for continuous batching")