# Reward Model Testing

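This notebook tests the trained reward model: it scores hand-written good/bad response pairs, measures pairwise accuracy on the UltraFeedback (binarized) training set, and analyzes the dataset's token-length distribution.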

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

In [None]:
# Checkpoint location of the trained reward model; load_model() reads it.
MODEL_PATH = "Qwen2.5-1.5B-Instruct-ultrafeedback_binarized-reward-num_labels_1_wo_filter"

# Shared state filled in by load_model(); all three stay None until then.
tokenizer = model = device = None

## 1. Load Model

In [None]:
def load_model(model_path=None):
    """Load the reward model and its tokenizer into the module-level globals.

    Sets the globals ``tokenizer``, ``model`` and ``device``, which
    ``get_reward_score`` reads. The model is moved to ``device`` and put in
    eval mode.

    Args:
        model_path: Local directory or Hub id of the reward-model checkpoint.
            Defaults to ``MODEL_PATH``, read at call time so that reassigning
            ``MODEL_PATH`` before calling still takes effect.
    """
    global tokenizer, model, device
    if model_path is None:
        model_path = MODEL_PATH

    use_cuda = torch.cuda.is_available()
    device = "cuda" if use_cuda else "cpu"
    # bfloat16 halves memory on GPUs that support it; otherwise (CPU, or a GPU
    # without bf16 support) fall back to full float32 precision.
    if use_cuda and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32

    print(f"Loading model from {model_path}...")
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    # num_labels=1: a single scalar output head, read as the reward score.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels=1,
        torch_dtype=dtype,
    )
    model.to(device)
    model.eval()
    print(f"Model loaded on {device}")

# Load model
load_model()

## 2. Define Scoring Function

In [None]:
def get_reward_score(prompt, response):
    """Score one prompt/response pair with the loaded reward model.

    The pair is rendered as a single user -> assistant exchange through the
    tokenizer's chat template (no generation prompt appended), tokenized
    without truncation, and run through the model without gradients. The
    value in the first logit slot of the first (only) sequence is returned.

    Relies on the globals ``tokenizer``, ``model`` and ``device`` having been
    set by ``load_model()``.

    Args:
        prompt: The user's question.
        response: The assistant reply to be scored.

    Returns:
        float: The reward score; the tests in this notebook treat a higher
        score as a better response.
    """
    conversation = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    rendered = tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=False
    )

    encoded = tokenizer(
        rendered, return_tensors="pt", truncation=False, max_length=None
    )
    encoded = encoded.to(device)

    with torch.no_grad():
        logits = model(**encoded).logits
        score = logits[0, 0].item()
    return score

## 3. Test Examples

In [None]:
# Define test cases: each prompt is paired with a clearly better and a clearly
# worse answer; the runner below checks that the good one scores higher.
test_cases = [
    {"prompt": question, "good_response": better, "bad_response": worse}
    for question, better, worse in (
        (
            "How do I learn programming?",
            "To learn programming, start with the basics: 1. Choose a beginner-friendly language like Python. 2. Systematically learn syntax and fundamentals. 3. Practice with exercises and projects. 4. Read high-quality code. 5. Join communities to learn and share.",
            "I don't know, Google it yourself.",
        ),
        (
            "What is Artificial Intelligence?",
            "Artificial Intelligence (AI) is a branch of computer science focused on building systems that simulate human intelligence. It encompasses machine learning, deep learning, and natural language processing, enabling applications in image recognition, speech processing, autonomous driving, and more.",
            "It's just machines.",
        ),
        (
            "Recommend a good book.",
            "I recommend 'The Three-Body Problem', a sci-fi trilogy by Liu Cixin. It explores the contact between humanity and extraterrestrial civilization with profound philosophy and vivid imagination. A Hugo Award winner, it's a masterpiece of Chinese science fiction.",
            "Read whatever you find.",
        ),
    )
]

In [None]:
# Run tests: score each good/bad pair and report whether good outranks bad
print("\n" + "=" * 80 + "\nStart Testing Reward Model\n" + "=" * 80)

for i, case in enumerate(test_cases, 1):
    prompt, good_response, bad_response = (
        case["prompt"],
        case["good_response"],
        case["bad_response"],
    )

    good_score = get_reward_score(prompt, good_response)
    bad_score = get_reward_score(prompt, bad_response)

    # A single print with sep="\n" emits exactly one line per field.
    print(
        f"\n[Test Case {i}]",
        f"Question: {prompt}",
        f"\nGood Response: {good_response}",
        f"Score: {good_score:.4f}",
        f"\nBad Response: {bad_response}",
        f"Score: {bad_score:.4f}",
        f"\nDiff: {good_score - bad_score:.4f}",
        f"Order Correct: {'✓' if good_score > bad_score else '✗'}",
        "-" * 80,
        sep="\n",
    )

## 4. Custom Test

In [None]:
# Score one hand-written prompt/response pair
custom_prompt = "How to read a file in Python?"
custom_response = "In Python, you can read a file using the open() function. For example: with open('file.txt', 'r') as f: content = f.read()"

score = get_reward_score(custom_prompt, custom_response)
print(
    f"Question: {custom_prompt}",
    f"Response: {custom_response}",
    f"Score: {score:.4f}",
    sep="\n",
)

In [None]:
# Try your own example: edit the two strings below and re-run this cell
my_prompt, my_response = "Your question", "Your response"

my_score = get_reward_score(my_prompt, my_response)
print("Score: {:.4f}".format(my_score))

## 5. Test Accuracy on Training Set

In [None]:
from datasets import load_dataset

# Load dataset
print("Loading ultrafeedback_binarized training set...")
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
print(f"Dataset size: {len(dataset)}")

print("\nStart testing accuracy on all samples...")
print("="*80)

correct_count = 0
total_count = 0

for i, example in enumerate(dataset):
    # Each side is a list of chat messages: the first message is taken as the
    # user prompt and the last one as the assistant reply being compared.
    # NOTE(review): any intermediate turns are dropped, which is only exact
    # if every example is single-turn — confirm for this dataset.
    prompt = example["chosen"][0]["content"]
    chosen = example["chosen"][-1]["content"]
    rejected = example["rejected"][-1]["content"]

    # Calculate scores
    chosen_score = get_reward_score(prompt, chosen)
    rejected_score = get_reward_score(prompt, rejected)

    # A pair counts as correct only when the chosen reply strictly outscores
    # the rejected one; ties count as incorrect.
    if chosen_score > rejected_score:
        correct_count += 1
    total_count += 1

    # Print progress every 1000 samples
    if (i + 1) % 1000 == 0:
        current_acc = correct_count / total_count * 100
        print(f"Progress: {i+1}/{len(dataset)}, Current Accuracy: {current_acc:.2f}%")

# Calculate final accuracy; an empty dataset yields NaN instead of raising
# ZeroDivisionError.
accuracy = correct_count / total_count * 100 if total_count else float("nan")

print("="*80)
print("\nFinal Result:")
print(f"Correct Count: {correct_count}/{total_count}")
print(f"Accuracy: {accuracy:.2f}%")

## 6. Analyze Dataset Context Length Distribution

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def _count_tokens(text):
    """Return how many token ids ``tokenizer`` produces for *text* (no truncation).

    Only the length is needed, so the ids stay a plain list rather than being
    converted into a torch tensor for every string.
    """
    return len(tokenizer(text, truncation=False)["input_ids"])


def _render_chat(prompt, response):
    """Render one user/assistant exchange with the tokenizer's chat template.

    Uses the same message layout and template arguments as get_reward_score,
    so the measured lengths match the text the reward model is scored on.
    """
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)


# Collect token lengths for all samples
print("Analyzing dataset context length distribution...")
print(f"Total samples in dataset: {len(dataset)}")

# Store lengths for different parts
prompt_lengths = []
chosen_response_lengths = []
rejected_response_lengths = []
full_chosen_lengths = []
full_rejected_lengths = []

# Analyze entire dataset
for i, example in enumerate(dataset):
    prompt = example['chosen'][0]["content"]
    chosen = example["chosen"][-1]["content"]
    rejected = example["rejected"][-1]["content"]

    # Raw lengths of each part, without any chat-template formatting
    prompt_lengths.append(_count_tokens(prompt))
    chosen_response_lengths.append(_count_tokens(chosen))
    rejected_response_lengths.append(_count_tokens(rejected))

    # Full conversation length (prompt + response, chat template applied)
    full_chosen_lengths.append(_count_tokens(_render_chat(prompt, chosen)))
    full_rejected_lengths.append(_count_tokens(_render_chat(prompt, rejected)))

    # Print progress every 1000 samples
    if (i + 1) % 1000 == 0:
        print(f"Processed: {i+1}/{len(dataset)} samples")

# Convert to numpy arrays
prompt_lengths = np.array(prompt_lengths)
chosen_response_lengths = np.array(chosen_response_lengths)
rejected_response_lengths = np.array(rejected_response_lengths)
full_chosen_lengths = np.array(full_chosen_lengths)
full_rejected_lengths = np.array(full_rejected_lengths)
# Pool chosen and rejected conversations into one distribution
all_full_lengths = np.concatenate([full_chosen_lengths, full_rejected_lengths])

# Print summary statistics for every length distribution
print("\n" + "=" * 80)
print("Context Length Statistics:")
print("=" * 80)

# (section title, length array, whether to report the element count first)
for _title, _values, _with_total in (
    ("Prompt Length Statistics", prompt_lengths, False),
    ("Chosen Response Length Statistics", chosen_response_lengths, False),
    ("Rejected Response Length Statistics", rejected_response_lengths, False),
    ("Full Conversation Length Statistics (Prompt + Response)", all_full_lengths, True),
):
    print(f"\n[{_title}]")
    if _with_total:
        print(f"  Total Conversations: {len(_values)}")
    print(f"  Max Length: {np.max(_values)}")
    print(f"  Min Length: {np.min(_values)}")
    print(f"  Mean Length: {np.mean(_values):.2f}")
    print(f"  Median: {np.median(_values):.2f}")
    print(f"  95th Percentile: {np.percentile(_values, 95):.2f}")
print("=" * 80)

# Plot distributions in a 2x2 grid
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. Prompt Length Distribution
axes[0, 0].hist(prompt_lengths, bins=80, edgecolor='black', alpha=0.7, color='skyblue')
axes[0, 0].axvline(np.mean(prompt_lengths), color='red', linestyle='--', linewidth=2, 
                    label=f'Mean: {np.mean(prompt_lengths):.2f}')
axes[0, 0].axvline(np.median(prompt_lengths), color='green', linestyle='--', linewidth=2, 
                    label=f'Median: {np.median(prompt_lengths):.2f}')
axes[0, 0].set_xlabel('Token Length', fontsize=12)
axes[0, 0].set_ylabel('Frequency', fontsize=12)
axes[0, 0].set_title('Prompt Length Distribution', fontsize=14, fontweight='bold')
axes[0, 0].legend()
axes[0, 0].grid(True, alpha=0.3)

# 2. Response Length Distribution Comparison (overlaid, semi-transparent)
axes[0, 1].hist(chosen_response_lengths, bins=80, alpha=0.6, color='green', label='Chosen', edgecolor='black')
axes[0, 1].hist(rejected_response_lengths, bins=80, alpha=0.6, color='red', label='Rejected', edgecolor='black')
axes[0, 1].set_xlabel('Token Length', fontsize=12)
axes[0, 1].set_ylabel('Frequency', fontsize=12)
axes[0, 1].set_title('Response Length Distribution Comparison', fontsize=14, fontweight='bold')
axes[0, 1].legend()
axes[0, 1].grid(True, alpha=0.3)

# 3. Full Conversation Length Distribution
axes[1, 0].hist(all_full_lengths, bins=100, edgecolor='black', alpha=0.7, color='steelblue')
axes[1, 0].axvline(np.mean(all_full_lengths), color='red', linestyle='--', linewidth=2, 
                    label=f'Mean: {np.mean(all_full_lengths):.2f}')
axes[1, 0].axvline(np.median(all_full_lengths), color='green', linestyle='--', linewidth=2, 
                    label=f'Median: {np.median(all_full_lengths):.2f}')
# Reference marker at 512 tokens
axes[1, 0].axvline(512, color='orange', linestyle='--', linewidth=2, label='512 tokens')
axes[1, 0].set_xlabel('Token Length', fontsize=12)
axes[1, 0].set_ylabel('Frequency', fontsize=12)
axes[1, 0].set_title('Full Conversation Length Distribution (Prompt + Response)', fontsize=14, fontweight='bold')
axes[1, 0].legend()
axes[1, 0].grid(True, alpha=0.3)

# 4. Boxplot Comparison
box_data = [prompt_lengths, chosen_response_lengths, rejected_response_lengths]
box_labels = ['Prompt', 'Chosen\nResponse', 'Rejected\nResponse']
# Axes.boxplot's `labels=` keyword is deprecated since Matplotlib 3.9 (renamed
# `tick_labels=`), so label the ticks on the axis instead; this works on both
# older and newer Matplotlib. Boxes are placed at x = 1..N by default.
bp = axes[1, 1].boxplot(box_data, patch_artist=True, showmeans=True)
axes[1, 1].set_xticks(np.arange(1, len(box_labels) + 1))
axes[1, 1].set_xticklabels(box_labels)
colors = ['skyblue', 'lightgreen', 'lightcoral']
for patch, color in zip(bp['boxes'], colors):
    patch.set_facecolor(color)
axes[1, 1].set_ylabel('Token Length', fontsize=12)
axes[1, 1].set_title('Length Comparison Boxplot', fontsize=14, fontweight='bold')
axes[1, 1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

# Share of full conversations per token-length bucket. Each bucket covers
# (start, end]; the final bucket is open-ended (everything above `start`).
print("\n[Full Conversation Length Distribution]")
print("-" * 80)
ranges = [(0, 128), (128, 256), (256, 512), (512, 1024), (1024, 2048), (2048, float('inf'))]
for start, end in ranges:
    _above_start = all_full_lengths > start
    _open_ended = end == float('inf')
    if _open_ended:
        count = np.sum(_above_start)
        _label = f"> {start} tokens"
    else:
        count = np.sum(_above_start & (all_full_lengths <= end))
        _label = f"{start}-{end} tokens"
    print(f"{_label}: {count} ({count/len(all_full_lengths)*100:.2f}%)")