# 01 - Reward Modeling

## Context

RLHF (Reinforcement Learning from Human Feedback) aligns language models
with human preferences. After supervised fine-tuning (Module 06), a model
can follow instructions -- but it has no notion of *which* response is
better when multiple valid completions exist. RLHF encodes those
preferences into the model.

**CoCounsel context:** A legal AI must prefer citing real cases over
plausible-sounding fabrications. It must prefer hedged language
("the court may find...") over overconfident claims ("the court will
certainly rule..."). It must prefer accurate statements of law over
subtly incorrect ones. Without alignment, an instruction-following model
can produce fluent, well-formatted answers that are confidently wrong --
the worst failure mode for a legal tool.

The full RLHF pipeline has three stages:

1. **SFT** -- Train the model to follow instructions (Module 06).
2. **Reward Model** -- Train a model to score response quality using
   human preference data. This is the focus of this notebook.
3. **PPO** -- Use the reward model to further optimize the policy via
   reinforcement learning (covered conceptually at the end).

In this notebook, we build a preference dataset for legal tasks, implement
a reward model from scratch as a PyTorch `nn.Module`, train it using the
Bradley-Terry loss, and check whether it correctly ranks good responses
above bad ones.

## Preference Data

Preference data is the foundation of alignment. For each prompt, we
provide two responses:

- **Chosen** (good): accurate citations, appropriate hedging, correct law
- **Rejected** (bad): hallucinated citations, overconfident claims,
  incorrect legal reasoning

The reward model learns from these pairs: given a prompt, it should assign
a higher scalar score to the chosen response than the rejected one.

In production, preference data comes from human annotators (lawyers, in
our case) who read two responses and select the better one. Here we
create the dataset manually with realistic legal examples.
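
Before training on hand-written or annotator-supplied pairs, it can help to check that every record has the expected shape. The sketch below is one possible check, assuming the `prompt` / `chosen` / `rejected` keys used in this notebook; the helper name `validate_pairs` and the toy records are hypothetical and not part of any library.

```python
REQUIRED_KEYS = ("prompt", "chosen", "rejected")


def validate_pairs(pairs):
    """Return a list of human-readable problems found in preference records."""
    problems = []
    for i, pair in enumerate(pairs):
        # Every field must be a non-empty string.
        for key in REQUIRED_KEYS:
            value = pair.get(key)
            if not isinstance(value, str) or not value.strip():
                problems.append(f"pair {i}: missing or empty '{key}'")
        # A pair whose two responses are identical carries no preference signal.
        if pair.get("chosen") is not None and pair.get("chosen") == pair.get("rejected"):
            problems.append(f"pair {i}: chosen and rejected are identical")
    return problems


# Toy records for illustration only (not taken from the legal dataset).
example_pairs = [
    {"prompt": "Q1", "chosen": "good answer", "rejected": "bad answer"},
    {"prompt": "Q2", "chosen": "same", "rejected": "same"},
    {"prompt": "Q3", "chosen": "", "rejected": "bad answer"},
]
issues = validate_pairs(example_pairs)
for issue in issues:
    print(issue)
```

Run against a well-formed dataset, a check like this would be expected to return an empty list.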

In [None]:
import json
import copy
import re
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.manual_seed(42)

In [None]:
# Hand-crafted preference dataset for legal tasks.
# Each entry has a prompt, a chosen (good) response, and a rejected (bad) response.

preference_data = [
    # --- 1. Real citation vs hallucinated citation ---
    {
        "prompt": "What is the standard for summary judgment in federal court?",
        "chosen": (
            "Under Federal Rule of Civil Procedure 56, summary judgment is "
            "appropriate when there is no genuine dispute as to any material "
            "fact and the movant is entitled to judgment as a matter of law. "
            "The Supreme Court clarified this standard in Celotex Corp. v. "
            "Catrett, 477 U.S. 317 (1986), holding that the moving party "
            "bears the initial burden of demonstrating the absence of a "
            "genuine issue of material fact."
        ),
        "rejected": (
            "Summary judgment is granted when there are no facts in dispute. "
            "The landmark case Henderson v. United States Department of "
            "Justice, 589 U.S. 42 (2019) established the modern three-part "
            "test for summary judgment: (1) materiality, (2) genuineness, "
            "and (3) sufficiency of evidence. This test is universally "
            "applied in all federal courts."
        ),
    },
    # --- 2. Appropriate hedging vs overconfident claim ---
    {
        "prompt": "Can an employer fire an employee for filing a workers' compensation claim?",
        "chosen": (
            "In most jurisdictions, retaliatory termination for filing a "
            "workers' compensation claim is prohibited. However, the specific "
            "protections and remedies vary by state. Many states have enacted "
            "statutes explicitly prohibiting such retaliation, while others "
            "recognize a common-law tort of retaliatory discharge. An "
            "employment attorney licensed in the relevant jurisdiction should "
            "be consulted for case-specific advice."
        ),
        "rejected": (
            "No, an employer absolutely cannot fire an employee for filing a "
            "workers' compensation claim. This is illegal in all 50 states "
            "and will always result in a successful wrongful termination "
            "lawsuit. The employee will be awarded damages and the employer "
            "will face criminal penalties."
        ),
    },
    # --- 3. Accurate holding vs incorrect law ---
    {
        "prompt": "What did the Supreme Court hold in Miranda v. Arizona?",
        "chosen": (
            "In Miranda v. Arizona, 384 U.S. 436 (1966), the Supreme Court "
            "held that the Fifth Amendment's protection against self-"
            "incrimination requires law enforcement to inform suspects of "
            "their rights before custodial interrogation. These rights "
            "include the right to remain silent, the warning that statements "
            "may be used against them, and the right to an attorney."
        ),
        "rejected": (
            "In Miranda v. Arizona (1966), the Supreme Court held that all "
            "criminal suspects must be read their rights at the time of "
            "arrest, regardless of whether interrogation occurs. Failure to "
            "read Miranda rights automatically results in dismissal of all "
            "charges and the suspect must be immediately released from "
            "custody."
        ),
    },
    # --- 4. Proper citation format vs fabricated reporter ---
    {
        "prompt": "What is the doctrine of qualified immunity?",
        "chosen": (
            "Qualified immunity shields government officials from civil "
            "liability unless their conduct violates clearly established "
            "statutory or constitutional rights of which a reasonable person "
            "would have known. Harlow v. Fitzgerald, 457 U.S. 800 (1982). "
            "The doctrine balances the need to hold officials accountable "
            "with the need to protect them from undue interference with "
            "their duties."
        ),
        "rejected": (
            "Qualified immunity means government employees can never be "
            "sued for actions taken during their official duties. This was "
            "established in Roberts v. Federal Government Agency, 12 F.Supp. "
            "4th 892 (D.D.C. 2021). It is an absolute protection that "
            "cannot be overcome regardless of the severity of the "
            "constitutional violation."
        ),
    },
    # --- 5. Nuanced analysis vs oversimplification ---
    {
        "prompt": "Is a verbal agreement legally binding?",
        "chosen": (
            "A verbal agreement can be legally binding if it meets the "
            "general requirements of contract formation: offer, acceptance, "
            "consideration, and mutual assent. However, certain types of "
            "contracts must be in writing under the Statute of Frauds, "
            "including contracts for the sale of land, contracts that cannot "
            "be performed within one year, and contracts for the sale of "
            "goods for $500 or more under UCC 2-201. Enforceability also "
            "depends on the ability to prove the agreement's terms."
        ),
        "rejected": (
            "Verbal agreements are never legally binding. You always need a "
            "written contract signed by both parties for any agreement to "
            "have legal force. Without a written document, no court will "
            "enforce the agreement."
        ),
    },
    # --- 6. Correct jurisdictional analysis vs wrong jurisdiction ---
    {
        "prompt": "Does federal or state law govern a slip-and-fall in a grocery store?",
        "chosen": (
            "Slip-and-fall cases in grocery stores are typically governed by "
            "state negligence law. The specific elements and standards vary "
            "by jurisdiction. In most states, the plaintiff must show the "
            "store owed a duty of care, breached that duty, and the breach "
            "caused the plaintiff's injuries. Some states apply comparative "
            "negligence while others apply contributory negligence. Federal "
            "courts may hear such cases under diversity jurisdiction if the "
            "parties are from different states and the amount in controversy "
            "exceeds $75,000, but state substantive law still applies."
        ),
        "rejected": (
            "Slip-and-fall cases are governed by the Federal Premises "
            "Liability Act, 42 U.S.C. 1983-B. This federal statute "
            "establishes a strict liability standard for all commercial "
            "property owners, meaning the grocery store is automatically "
            "liable for any injury that occurs on its premises."
        ),
    },
    # --- 7. Accurate procedural explanation vs confused procedure ---
    {
        "prompt": "What is the difference between a motion to dismiss and a motion for summary judgment?",
        "chosen": (
            "A motion to dismiss under FRCP 12(b)(6) tests the legal "
            "sufficiency of the complaint, accepting all factual allegations "
            "as true. It asks whether the complaint states a plausible claim "
            "for relief. Ashcroft v. Iqbal, 556 U.S. 662 (2009). A motion "
            "for summary judgment under FRCP 56 comes later, after "
            "discovery, and tests whether genuine disputes of material fact "
            "exist. It considers evidence beyond the pleadings."
        ),
        "rejected": (
            "A motion to dismiss and a motion for summary judgment are "
            "essentially the same thing filed at different times. Both ask "
            "the judge to end the case because the other side has no "
            "evidence. The only difference is that a motion to dismiss is "
            "filed before trial and summary judgment is filed during trial."
        ),
    },
    # --- 8. Correct constitutional analysis vs wrong amendment ---
    {
        "prompt": "What constitutional protection applies to unreasonable searches?",
        "chosen": (
            "The Fourth Amendment protects against unreasonable searches and "
            "seizures by the government. Under Katz v. United States, 389 "
            "U.S. 347 (1967), a search occurs when the government violates "
            "a person's reasonable expectation of privacy. The exclusionary "
            "rule, applied to the states in Mapp v. Ohio, 367 U.S. 643 (1961), "
            "generally bars the use of illegally obtained evidence at trial, "
            "though several exceptions exist."
        ),
        "rejected": (
            "The Sixth Amendment protects against unreasonable searches. It "
            "states that no search can ever be conducted without a warrant "
            "signed by a federal judge. Any evidence found without a warrant "
            "is automatically inadmissible in any court proceeding, with no "
            "exceptions."
        ),
    },
    # --- 9. Responsible limitation acknowledgment vs false certainty ---
    {
        "prompt": "Will I win my discrimination lawsuit against my employer?",
        "chosen": (
            "I cannot predict the outcome of a specific case. Employment "
            "discrimination claims under Title VII require showing that the "
            "employer took an adverse action because of a protected "
            "characteristic. The McDonnell Douglas burden-shifting framework "
            "applies to circumstantial evidence cases. Success depends on "
            "the specific facts, available evidence, jurisdiction, and many "
            "other factors. Consulting with an employment attorney who can "
            "review the details of your situation is strongly recommended."
        ),
        "rejected": (
            "Based on what you've described, you will definitely win your "
            "case. Discrimination lawsuits are almost always successful when "
            "the employee has been treated unfairly. You should expect to "
            "receive at least $500,000 in damages plus attorney's fees. I "
            "recommend filing immediately."
        ),
    },
    # --- 10. Correct standard of review vs wrong standard ---
    {
        "prompt": "What standard of review applies to a district court's factual findings on appeal?",
        "chosen": (
            "Appellate courts review a district court's findings of fact "
            "under the clearly erroneous standard, as required by Federal "
            "Rule of Civil Procedure 52(a)(6). A finding is clearly "
            "erroneous when, although there is evidence to support it, the "
            "reviewing court is left with the definite and firm conviction "
            "that a mistake has been made. Anderson v. City of Bessemer "
            "City, 470 U.S. 564 (1985). Legal conclusions are reviewed "
            "de novo."
        ),
        "rejected": (
            "All district court decisions are reviewed de novo on appeal, "
            "meaning the appellate court starts completely fresh and gives "
            "no deference to the lower court. The appellate court re-weighs "
            "all evidence and can substitute its own factual findings for "
            "those of the district court."
        ),
    },
    # --- 11. Accurate statute of limitations vs wrong timeframe ---
    {
        "prompt": "What is the statute of limitations for personal injury in most states?",
        "chosen": (
            "The statute of limitations for personal injury varies by state "
            "but is commonly two to three years from the date of injury. "
            "Some states, like Kentucky and Louisiana, set it at one year, "
            "while others, like Maine, allow up to six years. The discovery "
            "rule may toll the limitations period in cases where the injury "
            "was not immediately apparent. It is critical to check the "
            "specific statute in the applicable jurisdiction."
        ),
        "rejected": (
            "The statute of limitations for personal injury is exactly five "
            "years in every state under the Uniform Personal Injury "
            "Limitations Act. There are no exceptions or tolling provisions. "
            "If you miss this deadline, your case is permanently barred."
        ),
    },
    # --- 12. Proper legal reasoning vs circular reasoning ---
    {
        "prompt": "Can a landlord evict a tenant without notice?",
        "chosen": (
            "In nearly all jurisdictions, landlords must provide written "
            "notice before initiating eviction proceedings. The required "
            "notice period varies: commonly 3 days for nonpayment of rent, "
            "30 days for month-to-month tenancies, and longer in some rent-"
            "controlled jurisdictions. Self-help evictions -- such as "
            "changing locks or shutting off utilities -- are generally "
            "illegal. The landlord must follow the judicial eviction process "
            "established by state law."
        ),
        "rejected": (
            "A landlord can evict a tenant whenever they want because it is "
            "the landlord's property. Property rights mean the owner has "
            "complete control over who lives in the property. The tenant "
            "has no rights once the landlord decides they should leave."
        ),
    },
    # --- 13. Accurate damages explanation vs inflated expectations ---
    {
        "prompt": "What damages can I recover in a breach of contract case?",
        "chosen": (
            "In a breach of contract case, the non-breaching party may "
            "recover expectation damages -- the amount needed to put them "
            "in the position they would have been in had the contract been "
            "performed. This can include direct damages and consequential "
            "damages that were foreseeable at the time of contracting, per "
            "Hadley v. Baxendale (1854). Punitive damages are generally not "
            "available in contract cases. The non-breaching party also has "
            "a duty to mitigate damages."
        ),
        "rejected": (
            "In a breach of contract case, you can recover unlimited "
            "damages including punitive damages, emotional distress, and "
            "pain and suffering. Courts typically award triple damages as a "
            "punishment to the breaching party. There is no limit on what "
            "you can claim."
        ),
    },
]

print(f"Preference dataset: {len(preference_data)} pairs")
print()
for i, pair in enumerate(preference_data):
    print(f"  [{i:>2}] {pair['prompt'][:70]}")
    print(f"       Chosen:   {pair['chosen'][:60]}...")
    print(f"       Rejected: {pair['rejected'][:60]}...")
    print()

## Reward Model Architecture

A reward model takes a language model and adds a **scalar output head**.
Instead of predicting the next token (a distribution over the vocabulary),
it outputs a single number: a "reward" score indicating how good the
response is.

The architecture:
1. Take a pretrained language model (SmolLM-135M).
2. Remove the language modeling head (the final linear layer that maps
   hidden states to vocabulary logits).
3. Add a new linear layer that maps the last token's hidden state to a
   single scalar.

Why the *last* token? In a causal (left-to-right) language model, the
last token's hidden state has attended to the entire sequence. It serves
as a summary representation of the full prompt + response.
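
To make the indexing concrete, here is a small sketch in plain Python (no model required) of how the last real token can be located from an attention mask. It assumes the usual convention of 1 for real tokens and 0 for padding; the helper names are illustrative only. The sum-based rule is correct only for right-padded batches, while scanning for the last 1 works for either padding side.

```python
def last_index_by_sum(mask_row):
    """Last real-token index, assuming padding sits on the right."""
    return sum(mask_row) - 1


def last_index_by_scan(mask_row):
    """Last real-token index for either padding side."""
    for position in range(len(mask_row) - 1, -1, -1):
        if mask_row[position] == 1:
            return position
    raise ValueError("mask row contains no real tokens")


right_padded = [1, 1, 1, 0, 0]  # three real tokens, then two pads
left_padded = [0, 0, 1, 1, 1]   # two pads, then three real tokens

print(last_index_by_sum(right_padded), last_index_by_scan(right_padded))  # 2 2
print(last_index_by_sum(left_padded), last_index_by_scan(left_padded))    # 2 4
```

With left padding, the sum-based index 2 points at the first real token rather than the last (index 4), so the reward head would read the wrong hidden state. Checking `tokenizer.padding_side` before relying on the sum-based rule is a cheap safeguard.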

In [None]:
# Load the base model and tokenizer
model_name = "HuggingFaceTB/SmolLM-135M"

tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(model_name)

num_params = sum(p.numel() for p in base_model.parameters())
print(f"Model: {model_name}")
print(f"Parameters: {num_params:,}")
print(f"Hidden size: {base_model.config.hidden_size}")
print(f"Vocabulary size: {tokenizer.vocab_size:,}")

In [None]:
class RewardModel(nn.Module):
    """A reward model built on top of a causal language model.

    Takes a pretrained language model, removes the language modeling head,
    and adds a linear layer that maps the last token's hidden state to a
    scalar reward score.

    Args:
        base_model: A HuggingFace causal language model.
    """

    def __init__(self, base_model):
        super().__init__()
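        # Extract the transformer backbone (without the LM head).
        # Llama-style HuggingFace models such as SmolLM expose it as `.model`;
        # other architectures may use a different attribute name.
        self.backbone = base_model.model  # the transformer layers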
        hidden_size = base_model.config.hidden_size

        # Scalar reward head: maps hidden state to a single number
        self.reward_head = nn.Linear(hidden_size, 1)

        # Initialize the reward head with small weights
        nn.init.normal_(self.reward_head.weight, std=0.02)
        nn.init.zeros_(self.reward_head.bias)

    def forward(self, input_ids, attention_mask=None):
        """Compute a scalar reward for the input sequence.

        Args:
            input_ids: Token IDs, shape (batch_size, seq_len).
            attention_mask: Mask for padding tokens, shape (batch_size, seq_len).

        Returns:
            Scalar reward for each sequence, shape (batch_size,).
        """
        # Get hidden states from the transformer backbone
        outputs = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )
        hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden_size)

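        # Find the index of the last real token (not padding) in each sequence.
        # NOTE: the sum-based index assumes right padding (pads after real tokens).
        if attention_mask is not None:
            # Sum the attention mask to find sequence lengths, then subtract 1
            last_token_idx = attention_mask.sum(dim=1) - 1  # (batch_size,)
        else:
            # No mask: treat every position as real and use the final one
            last_token_idx = torch.full(
                (input_ids.size(0),),
                input_ids.size(1) - 1,
                dtype=torch.long,
                device=input_ids.device,  # keep indices on the same device as the inputs
            )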

        # Gather the last token's hidden state for each sequence
        batch_indices = torch.arange(input_ids.size(0), device=input_ids.device)
        last_hidden = hidden_states[batch_indices, last_token_idx]  # (batch, hidden_size)

        # Pass through the reward head to get a scalar
        reward = self.reward_head(last_hidden).squeeze(-1)  # (batch_size,)

        return reward


# Instantiate the reward model
reward_model = RewardModel(copy.deepcopy(base_model))

total_params = sum(p.numel() for p in reward_model.parameters())
head_params = sum(p.numel() for p in reward_model.reward_head.parameters())
print(f"Reward model parameters: {total_params:,}")
print(f"  Backbone: {total_params - head_params:,}")
print(f"  Reward head: {head_params:,}")
print()

# Quick test: pass a dummy input through the model
dummy_input = tokenizer("This is a test.", return_tensors="pt")
with torch.no_grad():
    dummy_reward = reward_model(
        input_ids=dummy_input["input_ids"],
        attention_mask=dummy_input["attention_mask"],
    )
print(f"Dummy reward: {dummy_reward.item():.4f}")
print(f"Reward shape: {dummy_reward.shape}")
print("The reward model outputs a single scalar per input sequence.")

## Bradley-Terry Model

The **Bradley-Terry model** is a probabilistic model of pairwise
preferences; its negative log-likelihood is the loss used to train reward
models. The idea is simple: given a chosen response and a rejected
response, the reward model should assign a higher score to the chosen one.

Formally, the probability that response $y_w$ (chosen/winner) is preferred
over response $y_l$ (rejected/loser) is:

$$P(y_w \succ y_l) = \sigma(r(y_w) - r(y_l))$$

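where $\sigma$ is the sigmoid function and $r$ is the reward model, scored
on the prompt together with the response (the prompt is left out of the
notation for brevity).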

The loss is the negative log-likelihood:

$$\mathcal{L} = -\log \sigma(r(y_w) - r(y_l))$$

This loss approaches zero as $r(y_w) - r(y_l)$ grows -- when the reward
model strongly prefers the chosen response. It equals $\log 2 \approx 0.693$
when the two scores are tied, and it grows roughly linearly with the size
of the gap when the model incorrectly prefers the rejected response
($r(y_l) > r(y_w)$).
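
A short derivation, added here as a sketch, shows how the loss moves the two scores. Write $\Delta = r(y_w) - r(y_l)$ (a shorthand introduced only for this derivation) and use $\frac{d}{d\Delta}\log\sigma(\Delta) = 1 - \sigma(\Delta)$:

$$\frac{\partial \mathcal{L}}{\partial r(y_w)} = -\bigl(1 - \sigma(\Delta)\bigr), \qquad \frac{\partial \mathcal{L}}{\partial r(y_l)} = 1 - \sigma(\Delta)$$

The two scores receive equal and opposite gradients, so a gradient-descent step raises $r(y_w)$ and lowers $r(y_l)$. The gradient magnitude $1 - \sigma(\Delta)$ is the probability the model currently assigns to the wrong ordering: it is $1/2$ when the scores are tied, approaches 1 when the model strongly prefers the rejected response, and vanishes once the pair is ranked confidently.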

In [None]:
def bradley_terry_loss(reward_chosen, reward_rejected):
    """Compute the Bradley-Terry preference loss.

    The loss encourages the reward model to assign a higher score to the
    chosen response than the rejected response.

    L = -log(sigmoid(r_chosen - r_rejected))

    Args:
        reward_chosen: Scalar rewards for chosen responses, shape (batch,).
        reward_rejected: Scalar rewards for rejected responses, shape (batch,).

    Returns:
        Scalar loss (mean over batch).
    """
    return -F.logsigmoid(reward_chosen - reward_rejected).mean()


# Demonstrate the loss behavior
print("Bradley-Terry loss for different reward gaps:")
print(f"{'r_chosen - r_rejected':>25}  {'Loss':>10}  {'P(chosen > rejected)':>22}")
print("-" * 62)

for diff in [-3.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0, 3.0, 5.0]:
    r_c = torch.tensor([diff])
    r_r = torch.tensor([0.0])
    loss = bradley_terry_loss(r_c, r_r)
    prob = torch.sigmoid(r_c - r_r).item()
    print(f"{diff:>25.1f}  {loss.item():>10.4f}  {prob:>22.4f}")

print()
print("When the gap is negative (model prefers rejected), loss is high.")
print("When the gap is positive (model prefers chosen), loss is low.")
print("The loss asymptotically approaches 0 as the gap grows.")

## Training the Reward Model

We now train the reward model on our preference dataset. The training
loop:

1. For each preference pair, tokenize the prompt + chosen and
   prompt + rejected responses.
2. Pass both through the reward model to get scalar rewards.
3. Compute the Bradley-Terry loss.
4. Backpropagate and update parameters.

With only 13 preference pairs and a 135M-parameter model, this trains
quickly on CPU.

In [None]:
def tokenize_for_reward(prompt, response, tokenizer, max_length=256):
    """Tokenize a prompt + response pair for the reward model.

    Joins the prompt and response using a "Question: / Answer:" template,
    then tokenizes, truncating or padding to `max_length`.

    Args:
        prompt: The input prompt string.
        response: The response string.
        tokenizer: A HuggingFace tokenizer.
        max_length: Maximum sequence length.

    Returns:
        Dict with 'input_ids' and 'attention_mask' tensors.
    """
    text = f"Question: {prompt}\nAnswer: {response}"
    encoded = tokenizer(
        text,
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    return encoded


# Prepare all tokenized pairs
tokenized_chosen = []
tokenized_rejected = []

for pair in preference_data:
    chosen_enc = tokenize_for_reward(pair["prompt"], pair["chosen"], tokenizer)
    rejected_enc = tokenize_for_reward(pair["prompt"], pair["rejected"], tokenizer)
    tokenized_chosen.append(chosen_enc)
    tokenized_rejected.append(rejected_enc)

print(f"Prepared {len(tokenized_chosen)} tokenized preference pairs")
print(f"Sequence length: {tokenized_chosen[0]['input_ids'].shape[1]}")

In [None]:
# Training configuration
num_epochs = 15
learning_rate = 1e-5
optimizer = torch.optim.AdamW(reward_model.parameters(), lr=learning_rate)

print("Training configuration:")
print(f"  Learning rate: {learning_rate}")
print(f"  Epochs: {num_epochs}")
print(f"  Pairs per epoch: {len(preference_data)}")
print(f"  Total steps: {num_epochs * len(preference_data)}")
print()
print("Starting reward model training...")
print()

loss_history = []
accuracy_history = []

for epoch in range(num_epochs):
    epoch_losses = []
    epoch_correct = 0
    reward_model.train()

    # Shuffle training order each epoch
    indices = torch.randperm(len(preference_data)).tolist()

    for idx in indices:
        chosen_enc = tokenized_chosen[idx]
        rejected_enc = tokenized_rejected[idx]

        # Forward pass: compute rewards for both responses
        r_chosen = reward_model(
            input_ids=chosen_enc["input_ids"],
            attention_mask=chosen_enc["attention_mask"],
        )
        r_rejected = reward_model(
            input_ids=rejected_enc["input_ids"],
            attention_mask=rejected_enc["attention_mask"],
        )

        # Compute Bradley-Terry loss
        loss = bradley_terry_loss(r_chosen, r_rejected)

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        epoch_losses.append(loss.item())
        loss_history.append(loss.item())

        # Track accuracy: is r_chosen > r_rejected?
        if r_chosen.item() > r_rejected.item():
            epoch_correct += 1

    avg_loss = np.mean(epoch_losses)
    accuracy = epoch_correct / len(preference_data)
    accuracy_history.append(accuracy)

    if (epoch + 1) % 3 == 0 or epoch == 0:
        print(
            f"  Epoch {epoch + 1:>2}/{num_epochs}  "
            f"Loss: {avg_loss:.4f}  "
            f"Accuracy: {accuracy:.1%}"
        )

print()
print("Training complete.")
print(f"Final step loss: {loss_history[-1]:.4f}")
print(f"Final accuracy: {accuracy_history[-1]:.1%}")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curve
ax = axes[0]
ax.plot(loss_history, color="steelblue", alpha=0.6, linewidth=1)
# Smoothed loss
window = max(3, len(loss_history) // 15)
if len(loss_history) >= window:
    smoothed = np.convolve(loss_history, np.ones(window) / window, mode="valid")
    ax.plot(
        range(window - 1, len(loss_history)),
        smoothed,
        color="darkblue",
        linewidth=2,
        label="Smoothed",
    )
ax.set_xlabel("Training step")
ax.set_ylabel("Bradley-Terry Loss")
ax.set_title("Reward Model Training Loss")
ax.legend()
ax.grid(alpha=0.3)

# Accuracy per epoch
ax = axes[1]
ax.plot(
    range(1, num_epochs + 1),
    accuracy_history,
    marker="o",
    color="#2ecc71",
    linewidth=2,
    markersize=6,
)
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy")
ax.set_title("Preference Pair Accuracy (r_chosen > r_rejected)")
ax.set_ylim(0, 1.05)
ax.axhline(y=0.5, color="red", linestyle="--", alpha=0.5, label="Random baseline")
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("Left: The Bradley-Terry loss should decrease over training.")
print("Right: Accuracy should rise above 50% (random) and ideally reach 100%.")
print("100% accuracy means the model ranks chosen > rejected for all training pairs.")

In [None]:
# Show reward scores for all training pairs
reward_model.eval()

print("Reward scores on training data:")
print("=" * 75)
print(f"{'Pair':>4}  {'r_chosen':>10}  {'r_rejected':>10}  {'Gap':>8}  {'Correct':>8}")
print("-" * 75)

all_correct = 0
for i, pair in enumerate(preference_data):
    chosen_enc = tokenized_chosen[i]
    rejected_enc = tokenized_rejected[i]

    with torch.no_grad():
        r_c = reward_model(
            input_ids=chosen_enc["input_ids"],
            attention_mask=chosen_enc["attention_mask"],
        ).item()
        r_r = reward_model(
            input_ids=rejected_enc["input_ids"],
            attention_mask=rejected_enc["attention_mask"],
        ).item()

    correct = r_c > r_r
    all_correct += int(correct)
    status = "YES" if correct else "NO"
    print(f"{i:>4}  {r_c:>10.4f}  {r_r:>10.4f}  {r_c - r_r:>8.4f}  {status:>8}")

print("-" * 75)
print(f"Overall accuracy: {all_correct}/{len(preference_data)} ({all_correct/len(preference_data):.1%})")
print()
print("A positive gap means the model correctly prefers the chosen response.")
print("A negative gap means the model incorrectly prefers the rejected response.")

In [None]:
# Test on NEW examples not in the training set
test_pairs = [
    {
        "prompt": "What is the hearsay rule?",
        "chosen": (
            "Hearsay is an out-of-court statement offered to prove the truth "
            "of the matter asserted, generally inadmissible under Federal "
            "Rule of Evidence 802. However, numerous exceptions exist under "
            "FRE 803 and 804, including present sense impressions, excited "
            "utterances, and statements against interest."
        ),
        "rejected": (
            "The hearsay rule means you can never use any statement made "
            "outside the courtroom. There are absolutely no exceptions to "
            "this rule. If someone said something outside of court, it "
            "cannot be mentioned during trial under any circumstances."
        ),
    },
    {
        "prompt": "Can I represent myself in court?",
        "chosen": (
            "Yes, under the Sixth Amendment and Faretta v. California, 422 "
            "U.S. 806 (1975), criminal defendants have the right to self-"
            "representation. However, the court will typically conduct a "
            "colloquy to ensure the decision is knowing, intelligent, and "
            "voluntary. Pro se litigants are held to the same rules and "
            "standards as attorneys. Self-representation is generally "
            "discouraged due to the complexity of legal procedures."
        ),
        "rejected": (
            "You should absolutely represent yourself. Courts are very "
            "lenient with pro se litigants and will basically guide you "
            "through the entire process. Judges are required to help you "
            "win your case if you don't have a lawyer. It saves money and "
            "the outcome is always just as good."
        ),
    },
]

print("Reward scores on NEW (unseen) examples:")
print("=" * 70)

for i, pair in enumerate(test_pairs):
    chosen_enc = tokenize_for_reward(pair["prompt"], pair["chosen"], tokenizer)
    rejected_enc = tokenize_for_reward(pair["prompt"], pair["rejected"], tokenizer)

    with torch.no_grad():
        r_c = reward_model(
            input_ids=chosen_enc["input_ids"],
            attention_mask=chosen_enc["attention_mask"],
        ).item()
        r_r = reward_model(
            input_ids=rejected_enc["input_ids"],
            attention_mask=rejected_enc["attention_mask"],
        ).item()

    correct = r_c > r_r
    status = "CORRECT" if correct else "WRONG"
    print(f"\n  Prompt: {pair['prompt']}")
    print(f"  r_chosen:   {r_c:.4f}")
    print(f"  r_rejected: {r_r:.4f}")
    print(f"  Gap: {r_c - r_r:.4f}  [{status}]")

print()
print("The reward model should generalize: preferring accurate, hedged")
print("responses over overconfident or incorrect ones, even for new prompts.")
print("With only 13 training pairs, generalization will be limited but")
print("the model may pick up on surface-level quality signals like citations")
print("and hedging language.")
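
The per-pair gaps printed above can be condensed into two summary numbers: pairwise accuracy, and the preference probability implied by the Bradley-Terry model, `sigmoid(r_chosen - r_rejected)`. The sketch below is a minimal illustration in plain Python. `bt_preference_prob` and `summarize_pairs` are hypothetical helpers, not part of the notebook, and the scores are made up rather than taken from the trained model.

```python
import math


def bt_preference_prob(r_chosen, r_rejected):
    """Bradley-Terry probability that the chosen response is preferred."""
    return 1.0 / (1.0 + math.exp(-(r_chosen - r_rejected)))


def summarize_pairs(scores):
    """Summarize a list of (r_chosen, r_rejected) score pairs.

    Returns pairwise accuracy, mean reward gap, and the mean implied
    preference probability.
    """
    n = len(scores)
    gaps = [r_c - r_r for r_c, r_r in scores]
    return {
        "accuracy": sum(g > 0 for g in gaps) / n,
        "mean_gap": sum(gaps) / n,
        "mean_pref_prob": sum(bt_preference_prob(c, r) for c, r in scores) / n,
    }


# Illustrative scores only; these are NOT outputs of the trained model.
example_scores = [(1.2, -0.3), (0.1, 0.4)]
summary = summarize_pairs(example_scores)
print(summary)
```

A gap of zero maps to a probability of 0.5, so the implied probability shows how confident the reward model is in each ranking, not only whether the ranking is correct.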

## PPO Overview

With a trained reward model, the next step in full RLHF is **Proximal
Policy Optimization (PPO)**. PPO uses the reward model to optimize the
language model (the "policy") to generate higher-reward responses.

We do **not** implement PPO here -- it is compute-intensive and requires
careful engineering (value functions, advantage estimation, multiple
forward passes per step). Instead, we explain the algorithm conceptually.

### The PPO Loop

```
# Setup
policy = load_sft_model()           # The model we want to align
reference_policy = freeze(copy(policy))  # Frozen copy to prevent drift
reward_model = load_reward_model()   # Trained reward model

for batch in prompts:
    # Step 1: Generate responses from the current policy
    responses = policy.generate(batch)

    # Step 2: Score responses with the reward model
    rewards = reward_model.score(responses)

    # Step 3: Compute KL penalty to prevent reward hacking
    # The policy should not drift too far from the reference
    kl_penalty = compute_kl(policy, reference_policy, responses)

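    # Step 4: Compute advantages from the KL-penalized reward.
    # A learned value function (critic) provides the baseline.
    penalized_rewards = rewards - beta * kl_penalty
    advantages = penalized_rewards - value_fn(batch, responses)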

    # Step 5: Update policy using PPO clipped objective
    # The clipped objective prevents too-large updates:
    # L = min(ratio * advantage, clip(ratio, 1-eps, 1+eps) * advantage)
    # where ratio = pi_new(a|s) / pi_old(a|s)
    policy.update(advantages)
```
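
To make the clipped objective in Step 5 concrete, here is a minimal sketch for a single action in plain Python. The function name `clipped_surrogate` and the default `eps=0.2` are illustrative assumptions. A real implementation would average this quantity over all tokens in a batch and maximize it with gradient ascent.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage, eps=0.2):
    """PPO clipped objective for one action (a quantity to maximize)."""
    ratio = math.exp(logp_new - logp_old)  # pi_new(a|s) / pi_old(a|s)
    clipped_ratio = max(1.0 - eps, min(ratio, 1.0 + eps))
    return min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage: the gain is capped once the ratio exceeds 1 + eps.
print(clipped_surrogate(math.log(1.5), 0.0, advantage=2.0))

# Negative advantage: the unclipped, more pessimistic term is kept.
print(clipped_surrogate(math.log(1.5), 0.0, advantage=-2.0))
```

Taking the minimum of the two terms means the policy gains nothing from pushing the ratio beyond the clip range in the direction the advantage favors, which is what limits the size of each update.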

### Why the KL Penalty?

Without a KL penalty, the policy can learn to exploit the reward model.
For example, if the reward model gives high scores to responses with
many legal citations, the policy might learn to insert citations into
every response -- even when they are irrelevant or fabricated. This is
called **reward hacking**.

The KL divergence penalty keeps the aligned policy close to the reference
policy (the SFT model). It says: "improve the response, but don't change
the model's behavior so much that it becomes degenerate." The beta
parameter controls the strength of this constraint.
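
One common way to apply the penalty is to estimate the KL term for a sampled response from per-token log-probabilities, then subtract beta times that estimate from the reward-model score. The sketch below assumes that formulation. The name `kl_shaped_reward`, the value `beta=0.1`, and all of the numbers are illustrative.

```python
def kl_shaped_reward(reward, policy_logprobs, ref_logprobs, beta=0.1):
    """Subtract a KL penalty from a reward-model score.

    policy_logprobs / ref_logprobs: per-token log-probabilities of the
    sampled response under the current policy and the frozen reference.
    Returns (shaped_reward, kl_estimate).
    """
    if len(policy_logprobs) != len(ref_logprobs):
        raise ValueError("log-prob sequences must have the same length")
    # Single-sample estimate of KL(policy || reference) for this response.
    kl_estimate = sum(p - r for p, r in zip(policy_logprobs, ref_logprobs))
    return reward - beta * kl_estimate, kl_estimate


# Illustrative numbers: the policy assigns higher probability than the
# reference to every token it sampled, so the KL estimate is positive.
shaped, kl = kl_shaped_reward(
    reward=2.0,
    policy_logprobs=[-0.5, -0.2, -0.1],
    ref_logprobs=[-1.0, -0.9, -0.6],
    beta=0.1,
)
print(f"KL estimate: {kl:.2f}, shaped reward: {shaped:.2f}")
```

A larger beta keeps the policy closer to the SFT model at the cost of slower reward improvement; a smaller beta allows more drift and more room for reward hacking.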

### Why Not Implement PPO?

PPO requires:
- Multiple forward passes per training step (generation + scoring + update)
- A value function (critic) trained alongside the policy
- Careful hyperparameter tuning (clipping range, GAE lambda, etc.)
- Significant compute (impractical on CPU with a 135M model)

In the next notebook, we implement **DPO (Direct Preference Optimization)**
as a simpler alternative that trains directly on preference pairs and can
reach comparable results without a separate reward model or PPO.

## Verifiable Rewards

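For tasks with **ground truth**, you do not need a learned reward model.
Instead, you can compute a reward directly by checking the response
against known facts. This is simpler and more reliable, and it is much
harder to reward-hack, provided the check itself is strict: a loose check
(for example, one that matches only case names) can still be gamed.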

For legal tasks, citation accuracy is a natural verifiable reward: we
can check whether a cited case actually exists in our corpus of known
real cases. A response that cites real cases gets a higher reward; one
that cites fabricated cases gets a lower reward.

import re  # used by extract_citations and citation_accuracy_reward below

# A corpus of known real legal citations.
# In production, this would be a database of all known case citations.
KNOWN_CITATIONS = {
    "Celotex Corp. v. Catrett, 477 U.S. 317 (1986)",
    "Anderson v. Liberty Lobby, Inc., 477 U.S. 242 (1986)",
    "Matsushita Elec. Indus. Co. v. Zenith Radio Corp., 475 U.S. 574 (1986)",
    "Miranda v. Arizona, 384 U.S. 436 (1966)",
    "Harlow v. Fitzgerald, 457 U.S. 800 (1982)",
    "Ashcroft v. Iqbal, 556 U.S. 662 (2009)",
    "Bell Atlantic Corp. v. Twombly, 550 U.S. 544 (2007)",
    "Katz v. United States, 389 U.S. 347 (1967)",
    "Mapp v. Ohio, 367 U.S. 643 (1961)",
    "Anderson v. City of Bessemer City, 470 U.S. 564 (1985)",
    "Faretta v. California, 422 U.S. 806 (1975)",
    "Hadley v. Baxendale (1854)",
    "McDonnell Douglas Corp. v. Green, 411 U.S. 792 (1973)",
    "Terry v. Ohio, 392 U.S. 1 (1968)",
    "Gideon v. Wainwright, 372 U.S. 335 (1963)",
    "Marbury v. Madison, 5 U.S. 137 (1803)",
    "Brown v. Board of Education, 347 U.S. 483 (1954)",
    "Roe v. Wade, 410 U.S. 113 (1973)",
}


def extract_citations(text):
    """Extract legal citations from text using regex.

    Matches patterns like:
    - Name v. Name, NNN U.S. NNN (YYYY)
    - Name v. Name, NNN F.Supp. NNN (Dist. YYYY)
    - Name v. Name (YYYY)

    Returns:
        List of extracted citation strings.
    """
    # Match standard US case citations
    pattern = (
        r'[A-Z][a-zA-Z.\s]+v\.\s+[A-Z][a-zA-Z.\s,]+'
        r'\d+\s+(?:U\.S\.|F\.\d+[a-z]*|F\.Supp\.|S\.Ct\.)\s+\d+\s*\([^)]+\)'
    )
    citations = re.findall(pattern, text)

    # Also match simpler format: Name v. Name (YYYY)
    simple_pattern = r'[A-Z][a-zA-Z.\s]+v\.\s+[A-Z][a-zA-Z.\s]+\(\d{4}\)'
    simple_citations = re.findall(simple_pattern, text)

    all_citations = list(set(citations + simple_citations))
    return [c.strip() for c in all_citations]


def citation_accuracy_reward(response, known_citations=KNOWN_CITATIONS):
    """Compute a verifiable reward based on citation accuracy.

    Checks each citation in the response against a corpus of known real
    citations. Returns a reward between -1 and 1:
    - +1: all citations are real
    - 0: no citations found, or equal numbers of real and fabricated
    - -1: all citations are fabricated
    - in between: (real - fabricated) / total for a mix

    Args:
        response: The model's response text.
        known_citations: Set of known real citation strings.

    Returns:
        Tuple of (reward, details_dict).
    """
    extracted = extract_citations(response)

    if not extracted:
        return 0.0, {"extracted": [], "real": [], "fabricated": []}

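    # Check each citation against the known corpus.
    # Fuzzy matching on the case name only (volume, page, and year are not
    # verified). The extraction regex can pull leading prose into a match
    # ("Under Celotex Corp. v. Catrett, ..."), so we also test whether a
    # known case name occurs inside the extracted string.
    known_names = {
        re.split(r",\s*\d|\s*\(", known, maxsplit=1)[0].strip()
        for known in known_citations
    }
    real = []
    fabricated = []

    for cite in extracted:
        cite_text = " ".join(cite.split())  # collapse line breaks and extra spaces
        # Extract the case name (the "v." part)
        match = re.search(r'([A-Z][a-zA-Z.\s]+v\.\s+[A-Z][a-zA-Z.\s]+)', cite_text)
        case_name = match.group(1).strip() if match else None
        found = any(name in cite_text for name in known_names) or (
            case_name is not None
            and any(case_name in known for known in known_citations)
        )
        if found:
            real.append(cite)
        else:
            fabricated.append(cite)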

    # Reward: fraction of real citations minus fraction of fabricated ones
    total = len(extracted)
    reward = (len(real) - len(fabricated)) / total

    return reward, {
        "extracted": extracted,
        "real": real,
        "fabricated": fabricated,
    }


# Test on our preference data
print("Verifiable Reward: Citation Accuracy")
print("=" * 70)

for i, pair in enumerate(preference_data[:5]):
    r_chosen, details_chosen = citation_accuracy_reward(pair["chosen"])
    r_rejected, details_rejected = citation_accuracy_reward(pair["rejected"])

    print(f"\nPair {i}: {pair['prompt'][:60]}...")
    print(f"  Chosen reward:   {r_chosen:+.2f}  "
          f"(real: {len(details_chosen['real'])}, "
          f"fabricated: {len(details_chosen['fabricated'])})")
    print(f"  Rejected reward: {r_rejected:+.2f}  "
          f"(real: {len(details_rejected['real'])}, "
          f"fabricated: {len(details_rejected['fabricated'])})")
    if details_chosen["fabricated"]:
        print(f"    Fabricated in chosen: {details_chosen['fabricated']}")
    if details_rejected["fabricated"]:
        print(f"    Fabricated in rejected: {details_rejected['fabricated']}")

print()
print("Advantages of verifiable rewards over learned reward models:")
print("  - No training required -- just a lookup.")
print("  - Harder to reward-hack -- the reward checks facts directly,")
print("    though a loose check can still be gamed.")
print("  - Interpretable -- you know exactly why the score is what it is.")
print()
print("Limitation:")
print("  - Limited scope -- only works for tasks with ground truth (like citations).")

## Exercises

### Exercise (a): Verifiable Reward for Case Existence

Implement a more robust verifiable reward that checks whether cited cases
actually exist in a corpus. Extend the `citation_accuracy_reward` function
to also check:

- Whether the volume and page numbers are plausible (e.g., U.S. Reports
  volumes range from 1 to ~600)
- Whether the year is plausible (not in the future, not before 1789)
- Whether the reporter matches the court (e.g., U.S. Reports for Supreme
  Court cases, F.3d for circuit courts)

```python
def enhanced_citation_reward(response, known_citations=KNOWN_CITATIONS):
    """Enhanced verifiable reward with structural checks."""
    extracted = extract_citations(response)
    if not extracted:
        return 0.0, {}

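    passed = []  # citations that pass every check
    failed = []  # citations that fail at least one check

    for cite in extracted:
        # Check 1: Does the case name match a known case?
        # Check 2: Is the volume number plausible?
        # Check 3: Is the year plausible?
        # Check 4: Does the reporter match the court level?
        pass  # Your implementation here: append cite to passed or failed

    # Same scoring rule as citation_accuracy_reward
    reward = (len(passed) - len(failed)) / len(extracted)
    details = {"extracted": extracted, "passed": passed, "failed": failed}
    return reward, details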
```
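
As a possible starting point for checks 2 through 4, the sketch below splits a citation string into volume, reporter, page, and year. `parse_citation_parts` is a hypothetical helper, and its regex only handles the simple `NNN Reporter NNN (YYYY)` shape assumed here; the plausibility checks themselves are left to you.

```python
import re

# Assumed shape: "<name>, <volume> <reporter> <page> (<optional court> <year>)"
_CITE_PARTS = re.compile(
    r"(?P<volume>\d+)\s+"
    r"(?P<reporter>U\.S\.|F\.\d+[a-z]*|F\.Supp\.|S\.Ct\.)\s+"
    r"(?P<page>\d+)\s*"
    r"\((?:[^)]*?\s)?(?P<year>\d{4})\)"
)


def parse_citation_parts(cite):
    """Return volume, reporter, page, and year, or None if the shape differs."""
    m = _CITE_PARTS.search(cite)
    if m is None:
        return None
    return {
        "volume": int(m.group("volume")),
        "reporter": m.group("reporter"),
        "page": int(m.group("page")),
        "year": int(m.group("year")),
    }


parts = parse_citation_parts("Faretta v. California, 422 U.S. 806 (1975)")
print(parts)
```

Citations without a reporter, such as `Hadley v. Baxendale (1854)`, return `None`, so your reward function needs to decide how to score them.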

### Exercise (b): Reward Model Failure Modes

The reward model is trained on a small dataset and may have blind spots.
Analyze its failure modes:

1. Create 5 adversarial examples where the "bad" response has surface-level
   quality signals (long length, formal language, citations) but is
   substantively wrong. Does the reward model get fooled?

2. Create 5 examples where the "good" response is short and informal
   (but correct). Does the reward model penalize brevity?

3. Test whether the reward model has learned to detect fabricated citations
   specifically, or just learned surface-level patterns.

```python
# Example adversarial test: formal but wrong
adversarial_pairs = [
    {
        "prompt": "What is the burden of proof in a civil case?",
        "chosen": "Preponderance of the evidence.",  # Short but correct
        "rejected": (
            "Pursuant to the well-established jurisprudence of the United "
            "States Supreme Court, the burden of proof in civil proceedings "
            "is beyond a reasonable doubt, the same standard applied in "
            "criminal cases."  # Formal but WRONG
        ),
    },
]

# Score with the reward model
for pair in adversarial_pairs:
    c_enc = tokenize_for_reward(pair["prompt"], pair["chosen"], tokenizer)
    r_enc = tokenize_for_reward(pair["prompt"], pair["rejected"], tokenizer)
    with torch.no_grad():
        r_c = reward_model(c_enc["input_ids"], c_enc["attention_mask"]).item()
        r_r = reward_model(r_enc["input_ids"], r_enc["attention_mask"]).item()
    print(f"r_chosen={r_c:.4f}, r_rejected={r_r:.4f}, correct={r_c > r_r}")
```