# Reward Model Experiment: Clean vs Unfiltered Data

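**Goal:** Train two reward models on HH-RLHF preference data — one on quality-filtered (clean) data, i.e. the training set with the examples listed in `flagged_indices.json` removed, and one on the unfiltered data — and compare their accuracy on the same held-out test set.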

**Setup:** Free T4 GPU on Google Colab. ~25 minutes total runtime.

## 1. Install Dependencies

In [None]:
!pip install -q transformers datasets torch tqdm matplotlib

## 2. Imports & Device Setup

In [None]:
import json
import re
from typing import Any

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoModel, AutoTokenizer

# Pick the compute device once; later cells pass it to training and evaluation.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Only query the GPU name when CUDA is actually in use.
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Reward Model Definition

In [None]:
class RewardModel(nn.Module):
    """Scalar reward model used for Bradley-Terry preference training.

    A pretrained encoder produces per-token hidden states; the hidden state at
    sequence position 0 (the [CLS] slot) is projected by one linear layer to a
    single scalar reward per sequence.
    """

    def __init__(self, model_name: str = "distilbert-base-uncased") -> None:
        """Load the pretrained backbone and attach a fresh 1-output linear head."""
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        self.reward_head = nn.Linear(self.backbone.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one reward per sequence for the given token ids and mask."""
        hidden_states = self.backbone(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # Keep only the first position of every sequence as its summary vector.
        first_token = hidden_states[:, 0, :]
        # The head emits a trailing size-1 dimension; drop it to get a flat vector.
        return self.reward_head(first_token).squeeze(-1)

## 4. Conversation Parsing & Dataset

In [None]:
_SPEAKER_RE = re.compile(r"^(Human|Assistant):\s*", re.MULTILINE)


def last_turn(text: str, role: str) -> str:
    """Extract the last turn for a given role from an HH-RLHF conversation.

    The transcript is split on line-initial ``Human:`` / ``Assistant:``
    markers and the body of the final turn spoken by ``role`` is returned.

    Args:
        text: Conversation transcript containing speaker markers.
        role: Speaker to look for, ``"human"`` or ``"assistant"``. The
            comparison is case-insensitive, so ``"Human"`` works as well.

    Returns:
        The whitespace-stripped body of the last matching turn, or ``""`` when
        the role never speaks (including when ``text`` has no markers at all).
    """
    wanted = role.lower()
    # The pattern has exactly one capturing group, so split() returns
    # [preamble, speaker, body, speaker, body, ...]: speakers sit at the odd
    # indices and each is immediately followed by its body.
    parts = _SPEAKER_RE.split(text)
    speakers = parts[1::2]
    bodies = parts[2::2]
    # Walk backwards so the first hit is the last turn of that role.
    for speaker, body in zip(reversed(speakers), reversed(bodies)):
        if speaker.lower() == wanted:
            return body.strip()
    return ""


class PreferencePairDataset(Dataset):
    """Map-style dataset of tokenized preference pairs.

    Each item pairs one prompt with its chosen and its rejected response. Both
    (prompt, response) pairs are tokenized on access, truncated and padded to
    ``max_length`` tokens.
    """

    def __init__(self, prompts, chosen_responses, rejected_responses, tokenizer, max_length=256):
        """Store the parallel text lists, the tokenizer and the sequence length."""
        self.prompts = prompts
        self.chosen_responses = chosen_responses
        self.rejected_responses = rejected_responses
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of preference pairs (one per prompt)."""
        return len(self.prompts)

    def _encode(self, prompt, response):
        """Tokenize one (prompt, response) pair to a fixed-length encoding."""
        return self.tokenizer(
            prompt,
            response,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

    def __getitem__(self, idx):
        """Return the input ids and attention masks for pair ``idx``."""
        prompt = self.prompts[idx]
        chosen = self._encode(prompt, self.chosen_responses[idx])
        rejected = self._encode(prompt, self.rejected_responses[idx])
        # squeeze(0) removes the leading size-1 dimension of each encoding so
        # that the DataLoader can stack items along a new batch dimension.
        return {
            "chosen_input_ids": chosen["input_ids"].squeeze(0),
            "chosen_attention_mask": chosen["attention_mask"].squeeze(0),
            "rejected_input_ids": rejected["input_ids"].squeeze(0),
            "rejected_attention_mask": rejected["attention_mask"].squeeze(0),
        }

## 5. Load HH-RLHF & Parse Conversations

In [None]:
# Load the train and test splits of the HH-RLHF preference dataset.
print("Loading HH-RLHF dataset...")
hh_train, hh_test = (
    load_dataset("Anthropic/hh-rlhf", split=split_name)
    for split_name in ("train", "test")
)
print(f"Train: {len(hh_train)} examples, Test: {len(hh_test)} examples")

def parse_split(ds):
    """Turn a dataset split into parallel prompt / chosen / rejected lists.

    For every row, the prompt is the last human turn of the ``"chosen"``
    conversation, and the two responses are the last assistant turns of the
    ``"chosen"`` and ``"rejected"`` conversations. Rows where any of the three
    comes back empty are skipped, so the returned lists can be shorter than
    ``ds``.

    Returns:
        A ``(prompts, chosen, rejected)`` tuple of equally long lists.
    """
    prompts, chosen, rejected = [], [], []
    for row in ds:
        chosen_text = row["chosen"]
        triple = (
            last_turn(chosen_text, "human"),
            last_turn(chosen_text, "assistant"),
            last_turn(row["rejected"], "assistant"),
        )
        # An empty string means the turn could not be extracted; drop the row.
        if not all(triple):
            continue
        prompt_text, chosen_reply, rejected_reply = triple
        prompts.append(prompt_text)
        chosen.append(chosen_reply)
        rejected.append(rejected_reply)
    return prompts, chosen, rejected

# Parse both splits into parallel (prompt, chosen, rejected) lists. Rows that
# fail to parse are dropped, so the counts printed here can be lower than the
# raw split sizes.
print("Parsing train split...")
train_lists = parse_split(hh_train)
all_prompts, all_chosen, all_rejected = train_lists
print(f"Parsed {len(all_prompts)} valid train examples")

print("Parsing test split...")
test_lists = parse_split(hh_test)
test_prompts, test_chosen, test_rejected = test_lists
print(f"Parsed {len(test_prompts)} valid test examples")

## 6. Upload flagged_indices.json & Split Data

In [None]:
# Upload flagged_indices.json (exported from local PostgreSQL)
from google.colab import files
uploaded = files.upload()  # Select flagged_indices.json

if "flagged_indices.json" in uploaded:
    # Parse the bytes returned by this upload rather than re-reading the file
    # from disk: if an older flagged_indices.json is already in the working
    # directory, Colab may store the new upload under a different name, and
    # reading the fixed path would silently pick up the stale copy.
    # NOTE(review): assumes files.upload() returns {uploaded name: bytes} — confirm.
    flagged_indices = set(json.loads(uploaded["flagged_indices.json"]))
else:
    # Nothing with the expected name was uploaded (e.g. the dialog was
    # cancelled); fall back to a copy already present on disk, if any.
    with open("flagged_indices.json", encoding="utf-8") as f:
        flagged_indices = set(json.load(f))

print(f"Loaded {len(flagged_indices)} flagged indices")

In [None]:
# Build index mapping: we need to know which parsed examples correspond to
# which original dataset indices (some get skipped if parsing fails).
# The filter below applies the same three non-empty checks that parse_split
# uses, so position k in this list is the original index of all_prompts[k].
valid_to_original = [
    i
    for i, row in enumerate(hh_train)
    if last_turn(row["chosen"], "human")
    and last_turn(row["chosen"], "assistant")
    and last_turn(row["rejected"], "assistant")
]

# Fail loudly if the re-derived mapping and the parsed lists ever disagree;
# a silent misalignment would filter the wrong examples.
if len(valid_to_original) != len(all_prompts):
    raise ValueError(
        f"Index mapping has {len(valid_to_original)} entries but "
        f"{len(all_prompts)} examples were parsed"
    )

# Split into clean vs unfiltered: the clean set keeps every parsed example
# whose original dataset index is not flagged.
clean_prompts, clean_chosen, clean_rejected = [], [], []
for original_idx, prompt_text, chosen_text, rejected_text in zip(
    valid_to_original, all_prompts, all_chosen, all_rejected
):
    if original_idx in flagged_indices:
        continue
    clean_prompts.append(prompt_text)
    clean_chosen.append(chosen_text)
    clean_rejected.append(rejected_text)

print(f"Unfiltered: {len(all_prompts)} examples")
print(f"Clean (filtered): {len(clean_prompts)} examples")
print(f"Removed: {len(all_prompts) - len(clean_prompts)} flagged examples")

## 7. Create Datasets

In [None]:
# Hyperparameters shared by both training runs and by evaluation.
MODEL_NAME = "distilbert-base-uncased"
MAX_LENGTH = 256
BATCH_SIZE = 8
LR = 2e-5
EPOCHS = 1

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# All three datasets share the tokenizer and sequence length; they differ
# only in which text lists they wrap.
clean_dataset = PreferencePairDataset(
    prompts=clean_prompts,
    chosen_responses=clean_chosen,
    rejected_responses=clean_rejected,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
)
unfiltered_dataset = PreferencePairDataset(
    prompts=all_prompts,
    chosen_responses=all_chosen,
    rejected_responses=all_rejected,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
)
test_dataset = PreferencePairDataset(
    prompts=test_prompts,
    chosen_responses=test_chosen,
    rejected_responses=test_rejected,
    tokenizer=tokenizer,
    max_length=MAX_LENGTH,
)

print(f"Clean dataset: {len(clean_dataset)} pairs")
print(f"Unfiltered dataset: {len(unfiltered_dataset)} pairs")
print(f"Test dataset: {len(test_dataset)} pairs")

## 8. Training Functions

In [None]:
def train_reward_model(model, dataset, epochs=1, batch_size=8, lr=2e-5, device="cuda"):
    """Fit ``model`` on preference pairs with the Bradley-Terry pairwise loss.

    Each batch is scored twice (chosen sequences, then rejected sequences) and
    the loss ``-logsigmoid(reward_chosen - reward_rejected)`` is minimised with
    AdamW. Batches are drawn in shuffled order.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns one reward per sequence.
        dataset: Dataset whose items hold the keys ``chosen_input_ids``,
            ``chosen_attention_mask``, ``rejected_input_ids`` and
            ``rejected_attention_mask``.
        epochs: Number of passes over ``dataset``.
        batch_size: Pairs per optimisation step.
        lr: AdamW learning rate.
        device: Device the model and batches are moved to.

    Returns:
        A dict with ``"epoch_loss"`` and ``"epoch_accuracy"`` lists holding one
        entry per epoch: the per-pair mean loss and the fraction of pairs in
        which the chosen response scored strictly higher.
    """
    model = model.to(device)
    model.train()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    batch_keys = (
        "chosen_input_ids",
        "chosen_attention_mask",
        "rejected_input_ids",
        "rejected_attention_mask",
    )
    losses, accuracies = [], []

    for epoch in range(epochs):
        loss_sum = 0.0
        n_correct = 0
        n_seen = 0

        for batch in tqdm(loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            chosen_ids, chosen_mask, rejected_ids, rejected_mask = (
                batch[key].to(device) for key in batch_keys
            )
            n_pairs = chosen_ids.size(0)

            # Score the chosen sequences first, then the rejected ones.
            reward_chosen = model(chosen_ids, chosen_mask)
            reward_rejected = model(rejected_ids, rejected_mask)
            loss = -F.logsigmoid(reward_chosen - reward_rejected).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Weight the batch-mean loss by batch size so that the epoch figure
            # is a true per-pair mean even when the last batch is smaller.
            loss_sum += loss.item() * n_pairs
            n_correct += (reward_chosen > reward_rejected).sum().item()
            n_seen += n_pairs

        epoch_loss = loss_sum / n_seen
        epoch_acc = n_correct / n_seen
        losses.append(epoch_loss)
        accuracies.append(epoch_acc)
        print(f"  Epoch {epoch + 1}: loss={epoch_loss:.4f}, accuracy={epoch_acc:.4f}")

    return {"epoch_loss": losses, "epoch_accuracy": accuracies}


def evaluate_reward_model(model, dataset, batch_size=8, device="cuda"):
    """Measure how often ``model`` ranks the chosen response above the rejected one.

    The model is put in eval mode and run without gradient tracking over
    ``dataset`` in its stored order.

    Args:
        model: Module called as ``model(input_ids, attention_mask)`` that
            returns one reward per sequence.
        dataset: Dataset whose items hold the keys ``chosen_input_ids``,
            ``chosen_attention_mask``, ``rejected_input_ids`` and
            ``rejected_attention_mask``.
        batch_size: Pairs scored per forward pass.
        device: Device the model and batches are moved to.

    Returns:
        A dict with ``"accuracy"`` (fraction of pairs with a strictly positive
        reward gap), ``"correct"``, ``"total"`` and ``"avg_reward_gap"`` (mean
        of chosen minus rejected reward).
    """
    model = model.to(device)
    model.eval()
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

    batch_keys = (
        "chosen_input_ids",
        "chosen_attention_mask",
        "rejected_input_ids",
        "rejected_attention_mask",
    )
    n_correct = 0
    n_pairs = 0
    gap_sum = 0.0

    with torch.no_grad():
        for batch in tqdm(loader, desc="Evaluating"):
            chosen_ids, chosen_mask, rejected_ids, rejected_mask = (
                batch[key].to(device) for key in batch_keys
            )

            # Left operand is evaluated first: chosen is scored before rejected.
            margin = model(chosen_ids, chosen_mask) - model(rejected_ids, rejected_mask)

            n_correct += (margin > 0).sum().item()
            gap_sum += margin.sum().item()
            n_pairs += chosen_ids.size(0)

    return {
        "accuracy": n_correct / n_pairs,
        "correct": n_correct,
        "total": n_pairs,
        "avg_reward_gap": gap_sum / n_pairs,
    }

## 9. Train Model A (Clean Data)

In [None]:
print("=" * 60)
print("Training Model A (CLEAN data)")
print("=" * 60)
# Seed PyTorch's RNG right before building the model so that the randomly
# initialised reward head, the shuffle order and dropout are reproducible and
# do not depend on whatever RNG state earlier cells left behind.
torch.manual_seed(0)
model_a = RewardModel(MODEL_NAME)
history_a = train_reward_model(model_a, clean_dataset, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, device=device)

## 10. Train Model B (Unfiltered Data)

In [None]:
print("=" * 60)
print("Training Model B (UNFILTERED data)")
print("=" * 60)
# Seed PyTorch's RNG right before building the model so that the randomly
# initialised reward head, the shuffle order and dropout are reproducible and
# do not depend on whatever RNG state earlier cells left behind.
torch.manual_seed(0)
model_b = RewardModel(MODEL_NAME)
history_b = train_reward_model(model_b, unfiltered_dataset, epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR, device=device)

## 11. Evaluate Both Models on Test Set

In [None]:
# Score both trained models on the same held-out test pairs with identical
# evaluation settings.
eval_settings = {"batch_size": BATCH_SIZE, "device": device}

print("\nEvaluating Model A (clean) on test set...")
results_a = evaluate_reward_model(model_a, test_dataset, **eval_settings)

print("\nEvaluating Model B (unfiltered) on test set...")
results_b = evaluate_reward_model(model_b, test_dataset, **eval_settings)

## 12. Results Comparison

In [None]:
# Print comparison table
print("\n" + "=" * 60)
print("RESULTS: Reward Model Comparison")
print("=" * 60)
print(f"{'Metric':<25} {'Clean (A)':<15} {'Unfiltered (B)':<15}")
print("-" * 55)
print(f"{'Training examples':<25} {len(clean_dataset):<15,} {len(unfiltered_dataset):<15,}")
print(f"{'Train accuracy':<25} {history_a['epoch_accuracy'][-1]:<15.4f} {history_b['epoch_accuracy'][-1]:<15.4f}")
print(f"{'Train loss':<25} {history_a['epoch_loss'][-1]:<15.4f} {history_b['epoch_loss'][-1]:<15.4f}")
print(f"{'Test accuracy':<25} {results_a['accuracy']:<15.4f} {results_b['accuracy']:<15.4f}")
# Build "correct/total" as a single string first so the whole cell is padded
# to the 15-character column width; padding only the denominator would shift
# the second column depending on the number of digits in the numerator.
ratio_a = f"{results_a['correct']}/{results_a['total']}"
ratio_b = f"{results_b['correct']}/{results_b['total']}"
print(f"{'Test correct/total':<25} {ratio_a:<15} {ratio_b:<15}")
print(f"{'Avg reward gap':<25} {results_a['avg_reward_gap']:<15.4f} {results_b['avg_reward_gap']:<15.4f}")
print("-" * 55)

# Positive diff means the model trained on clean data was more accurate.
diff = results_a["accuracy"] - results_b["accuracy"]
if diff > 0:
    winner = "Clean (A)"
    verdict = "Data filtering IMPROVED reward model quality."
elif diff < 0:
    winner = "Unfiltered (B)"
    verdict = "Data filtering did NOT improve reward model quality."
else:
    # Identical accuracy: neither model wins, and filtering brought no gain.
    winner = "Tie"
    verdict = "Data filtering did NOT improve reward model quality."

print(f"\nAccuracy difference: {abs(diff):.4f} ({abs(diff)*100:.2f}%)")
print(f"Winner: {winner}")
print(verdict)

In [None]:
# Bar chart comparison
labels = ["Clean (A)", "Unfiltered (B)"]
accuracies = [results_a["accuracy"] * 100, results_b["accuracy"] * 100]
colors = ["#2ecc71", "#e74c3c"]

fig, axes = plt.subplots(1, 2, figsize=(12, 5))

# Test accuracy comparison
bars = axes[0].bar(labels, accuracies, color=colors, edgecolor="black", linewidth=0.5)
axes[0].set_ylabel("Accuracy (%)")
axes[0].set_title("Test Accuracy: Clean vs Unfiltered")
# Start the axis at 45% to make small differences visible, but lower the floor
# when a model scores below that so its bar is not clipped out of view.
y_floor = min(45, min(accuracies) - 5)
axes[0].set_ylim(y_floor, max(accuracies) + 5)
axes[0].axhline(y=50, color="gray", linestyle="--", alpha=0.5, label="Random baseline")
axes[0].legend()
for bar, acc in zip(bars, accuracies):
    axes[0].text(bar.get_x() + bar.get_width() / 2, acc + 0.5,
                 f"{acc:.2f}%", ha="center", va="bottom", fontweight="bold")

# Reward gap comparison
gaps = [results_a["avg_reward_gap"], results_b["avg_reward_gap"]]
bars2 = axes[1].bar(labels, gaps, color=colors, edgecolor="black", linewidth=0.5)
axes[1].set_ylabel("Avg Reward Gap")
axes[1].set_title("Avg Reward Gap (chosen - rejected)")
axes[1].axhline(y=0, color="gray", linestyle="--", alpha=0.5)
for bar, gap in zip(bars2, gaps):
    # A negative gap draws the bar downwards; put its label below the bar tip
    # instead of inside the bar.
    points_down = gap < 0
    label_y = gap - 0.01 if points_down else gap + 0.01
    axes[1].text(bar.get_x() + bar.get_width() / 2, label_y,
                 f"{gap:.4f}", ha="center",
                 va="top" if points_down else "bottom", fontweight="bold")

plt.tight_layout()
plt.savefig("reward_model_comparison.png", dpi=150, bbox_inches="tight")
plt.show()
print("Chart saved to reward_model_comparison.png")