# Day 17: RLHF Components - Part 2

This notebook continues our exploration of RLHF components. In this part, we'll focus on:

1. Training the reward model
2. Implementing a simplified PPO algorithm
3. Optimizing our policy model using PPO
4. Evaluating the results

In [None]:
# Import necessary libraries (if running this notebook independently)
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random
import matplotlib.pyplot as plt
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer
)

# Select the compute device: use CUDA when PyTorch can see a GPU, otherwise the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

## 5. Training the Reward Model

Now we'll train our reward model to predict human preferences. The model should assign higher rewards to preferred responses.

In [None]:
# Define the reward model training loop
def _score_preference_pair(model, batch):
    """Move one preference batch to ``device`` and score both responses.

    Returns ``(chosen_rewards, rejected_rewards)`` as produced by
    ``model(input_ids, attention_mask)``.
    """
    chosen_input_ids = batch["chosen_input_ids"].to(device)
    chosen_attention_mask = batch["chosen_attention_mask"].to(device)
    rejected_input_ids = batch["rejected_input_ids"].to(device)
    rejected_attention_mask = batch["rejected_attention_mask"].to(device)

    chosen_rewards = model(chosen_input_ids, chosen_attention_mask)
    rejected_rewards = model(rejected_input_ids, rejected_attention_mask)
    return chosen_rewards, rejected_rewards


def train_reward_model(model, train_dataset, val_dataset, epochs=3, batch_size=4, lr=1e-5):
    """Train the reward model on pairwise preference data.

    Every batch must expose ``chosen_input_ids``, ``chosen_attention_mask``,
    ``rejected_input_ids`` and ``rejected_attention_mask`` tensors. The model is
    called as ``model(input_ids, attention_mask)``; its output is assumed to hold
    one reward per example along dimension 0 (TODO confirm against the reward
    model defined in Part 1).

    Args:
        model: Reward model to optimize in place.
        train_dataset: Dataset of tokenized preference pairs used for training.
        val_dataset: Dataset of tokenized preference pairs used for validation.
        epochs: Number of passes over ``train_dataset``.
        batch_size: Batch size for both data loaders.
        lr: AdamW learning rate.

    Returns:
        Tuple ``(train_losses, val_accuracies)`` with one entry per epoch: the
        mean training loss and the fraction of validation pairs where the chosen
        response scored strictly higher than the rejected one. An entry is
        ``nan`` when the corresponding loader/dataset is empty.
    """
    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Create data loaders (only the training data is shuffled)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size
    )

    train_losses = []
    val_accuracies = []

    for epoch in range(epochs):
        # ---- Training ----
        model.train()
        epoch_loss = 0.0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            chosen_rewards, rejected_rewards = _score_preference_pair(model, batch)

            # Pairwise (Bradley-Terry) loss: -log sigmoid(r_chosen - r_rejected).
            # logsigmoid is used instead of log(sigmoid(x)) because sigmoid
            # underflows to 0 for strongly negative margins, which would make
            # the loss inf and the gradients nan.
            loss = -F.logsigmoid(chosen_rewards - rejected_rewards).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        num_batches = len(train_loader)
        # Guard against an empty training set instead of dividing by zero
        avg_loss = epoch_loss / num_batches if num_batches else float("nan")
        train_losses.append(avg_loss)

        # ---- Validation ----
        model.eval()
        correct = 0
        total = 0

        with torch.no_grad():
            for batch in val_loader:
                chosen_rewards, rejected_rewards = _score_preference_pair(model, batch)

                # A pair is correct when the chosen response gets the higher reward
                correct += (chosen_rewards > rejected_rewards).sum().item()
                total += chosen_rewards.size(0)

        # Guard against an empty validation set instead of dividing by zero
        val_accuracy = correct / total if total else float("nan")
        val_accuracies.append(val_accuracy)

        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

    return train_losses, val_accuracies

# Fit the reward model on the tokenized preference pairs.
# Note: a real run would use far more data and epochs; the tiny settings below
# only keep the demonstration fast.
train_losses, val_accuracies = train_reward_model(
    model=reward_model,
    train_dataset=tokenized_train_dataset,
    val_dataset=tokenized_val_dataset,
    epochs=2,
    batch_size=2,
    lr=1e-5,
)

In [None]:
# Show per-epoch training loss and validation accuracy side by side
fig, (loss_ax, accuracy_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: mean training loss per epoch
loss_ax.plot(train_losses)
loss_ax.set_title('Training Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')

# Right panel: validation pair accuracy per epoch
accuracy_ax.plot(val_accuracies)
accuracy_ax.set_title('Validation Accuracy')
accuracy_ax.set_xlabel('Epoch')
accuracy_ax.set_ylabel('Accuracy')

fig.tight_layout()
plt.show()

## 6. Testing the Reward Model

Let's test our trained reward model on some examples to see if it can predict human preferences.

In [None]:
def get_reward(model, prompt, response):
    """Score a single prompt/response pair with the reward model.

    The pair is rendered as ``"Prompt: ...\\nResponse: ..."``, tokenized with the
    module-level ``reward_tokenizer`` (padded/truncated to 512 tokens), and
    passed to ``model(input_ids, attention_mask)`` without tracking gradients.

    Returns:
        The model output converted to a Python number via ``.item()``.
    """
    text = f"Prompt: {prompt}\nResponse: {response}"

    encoded = reward_tokenizer(
        text,
        padding="max_length",
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )
    encoded = encoded.to(device)

    # Inference only: no autograd graph is needed for scoring
    with torch.no_grad():
        score = model(encoded["input_ids"], encoded["attention_mask"])

    return score.item()

# Hand-written sanity checks: each prompt comes with a clearly better and a
# clearly worse response
test_examples = [
    {
        "prompt": "Explain the concept of machine learning.",
        "good_response": "Machine learning is a branch of artificial intelligence that enables computers to learn from data and improve their performance on a task without being explicitly programmed. It works by identifying patterns in data and using those patterns to make predictions or decisions.",
        "bad_response": "Machine learning is when computers do stuff with data."
    },
    {
        "prompt": "Write a poem about the moon.",
        "good_response": "Silver orb in midnight sky,\nSilent guardian watching high.\nCasting light on dreams below,\nAncient wisdom you bestow.\nCraters tell your timeless tale,\nAs you wax and as you wane.",
        "bad_response": "Moon is big. Moon is bright. Moon is in the sky at night."
    }
]

# Score both responses for every example and report whether the reward model
# ranks the better one higher
reward_model.eval()
for example in test_examples:
    prompt = example["prompt"]

    # Score the good response first, then the bad one
    good_reward, bad_reward = (
        get_reward(reward_model, prompt, example[key])
        for key in ("good_response", "bad_response")
    )
    ranked_correctly = good_reward > bad_reward

    print(f"Prompt: {prompt}")
    print(f"Good response reward: {good_reward:.4f}")
    print(f"Bad response reward: {bad_reward:.4f}")
    print(f"Correctly identified better response: {ranked_correctly}")
    print("-" * 50)

## 7. Implementing a Simplified PPO Algorithm

Now we'll implement a simplified version of the PPO algorithm for language models. In a real RLHF pipeline, this would be much more complex, but this simplified version will help us understand the key components.

In [None]:
class SimplePPO:
    """Simplified PPO scaffolding for language models.

    Wires a policy model, a reference model and a reward model together to
    collect rollouts and compute KL-penalized advantages. The gradient update
    itself is intentionally not implemented (see ``ppo_step``).
    """

    def __init__(self, policy_model, ref_model, reward_model, tokenizer, device):
        """Store the models, tokenizer and device, and set the PPO hyperparameters."""
        self.policy_model = policy_model
        self.ref_model = ref_model
        self.reward_model = reward_model
        self.tokenizer = tokenizer
        self.device = device

        # PPO hyperparameters
        self.kl_coef = 0.1  # KL penalty coefficient
        self.clip_param = 0.2  # PPO clip parameter (not used until the clipped update exists)

    def generate_responses(self, prompts, max_length=100):
        """Sample one response per prompt from the policy model.

        Args:
            prompts: Iterable of prompt strings.
            max_length: Forwarded to ``generate`` as ``max_length``.

        Returns:
            Tuple ``(responses, log_probs)``: the decoded continuations (prompt
            tokens removed) and, for each one, the summed log-probability of
            its sampled tokens as a Python float.
        """
        responses = []
        log_probs = []

        for prompt in prompts:
            inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
            input_ids = inputs["input_ids"]
            prompt_len = input_ids.size(1)

            # Sampling only: no gradients are needed for the rollout
            with torch.no_grad():
                output = self.policy_model.generate(
                    input_ids,
                    max_length=max_length,
                    do_sample=True,
                    temperature=0.7,
                    top_p=0.9,
                    pad_token_id=self.tokenizer.pad_token_id,
                    return_dict_in_generate=True,
                    output_scores=True
                )

            # Prompts are generated one at a time, so row 0 is prompt + continuation
            new_token_ids = output.sequences[0][prompt_len:]

            # Sum the per-step log-probabilities of the tokens that were sampled.
            # NOTE(review): output.scores are the scores generate() returns with
            # output_scores=True; presumably they already include the
            # temperature/top-p processing - confirm before using them as PPO
            # "old" log-probs.
            response_log_prob = 0.0
            for step_scores, token_id in zip(output.scores, new_token_ids):
                step_log_probs = F.log_softmax(step_scores[0], dim=-1)
                response_log_prob += step_log_probs[int(token_id)].item()

            # Decode only the continuation so the prompt is not echoed back
            response = self.tokenizer.decode(new_token_ids, skip_special_tokens=True)

            responses.append(response)
            log_probs.append(response_log_prob)

        return responses, log_probs

    def compute_rewards(self, prompts, responses):
        """Score every prompt/response pair with the reward model via ``get_reward``."""
        return [
            get_reward(self.reward_model, prompt, response)
            for prompt, response in zip(prompts, responses)
        ]

    def compute_kl_penalty(self, prompts, responses):
        """Compute KL(policy || reference) for each prompt/response pair (simplified).

        The prompt and response are concatenated, re-tokenized, and run through
        both models. The per-position KL divergence between the two next-token
        distributions is summed over all positions of the sequence (prompt
        positions included); a full implementation would restrict this to the
        response tokens.

        Returns:
            List of Python floats, one per pair.
        """
        kl_penalties = []

        for prompt, response in zip(prompts, responses):
            input_text = prompt + response
            inputs = self.tokenizer(input_text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                policy_logits = self.policy_model(inputs["input_ids"]).logits
                ref_logits = self.ref_model(inputs["input_ids"]).logits

                policy_log_probs = F.log_softmax(policy_logits, dim=-1)
                ref_log_probs = F.log_softmax(ref_logits, dim=-1)

                # F.kl_div(input, target) computes KL(target || input), so the
                # policy must be the *target* to obtain KL(policy || reference),
                # the direction used for the RLHF penalty. With a batch of one,
                # "batchmean" leaves the sum over positions and vocabulary.
                kl = F.kl_div(
                    ref_log_probs,
                    policy_log_probs,
                    reduction="batchmean",
                    log_target=True
                )

            kl_penalties.append(kl.item())

        return kl_penalties

    def ppo_step(self, prompts, batch_size=4, lr=1e-6):
        """Collect one batch of rollouts and report KL-penalized advantages.

        This demonstration does not update the policy. It generates responses,
        scores them, computes ``advantage = reward - kl_coef * KL`` and prints
        the values per prompt.

        Args:
            prompts: List of prompt strings.
            batch_size: Number of prompts per reporting mini-batch.
            lr: Learning rate reserved for the (not yet implemented) policy
                update; currently unused.

        Returns:
            Mean reward over all prompts (``nan`` when ``prompts`` is empty).
        """
        if len(prompts) == 0:
            # Nothing to roll out; avoid np.mean([]) and its RuntimeWarning
            return float("nan")

        # Roll out in eval mode: sampling and the KL estimate run under no_grad,
        # so enabling dropout here would only add noise (and would leave the
        # policy in training mode for later generation).
        self.policy_model.eval()

        # Generate responses and compute rewards
        # old_log_probs is what a real update would use as the ratio denominator
        responses, old_log_probs = self.generate_responses(prompts)
        rewards = self.compute_rewards(prompts, responses)
        kl_penalties = self.compute_kl_penalty(prompts, responses)

        # Compute advantages (reward - KL penalty)
        advantages = [r - self.kl_coef * kl for r, kl in zip(rewards, kl_penalties)]

        # PPO update (simplified): walk over mini-batches and report the values
        for start in range(0, len(prompts), batch_size):
            stop = start + batch_size
            mini_batch = zip(
                prompts[start:stop],
                responses[start:stop],
                rewards[start:stop],
                kl_penalties[start:stop],
                advantages[start:stop],
            )

            for prompt, response, reward, kl, advantage in mini_batch:
                print(f"Prompt: {prompt}")
                print(f"Response: {response}")
                print(f"Reward: {reward:.4f}")
                print(f"KL Penalty: {kl:.4f}")
                print(f"Advantage: {advantage:.4f}")
                print("-" * 50)

        # In a real implementation, you would (with the policy in train mode and
        # a persistent optimizer created once, outside this method):
        # 1. Compute new log probs for the same responses
        # 2. Compute ratio = exp(new_log_probs - old_log_probs)
        # 3. Compute clipped objective: min(ratio * advantage, clip(ratio, 1-ε, 1+ε) * advantage)
        # 4. Optimize this objective

        # For demonstration, we'll just return the average reward
        return np.mean(rewards)

## 8. Running the PPO Algorithm

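Now let's run one step of our simplified PPO loop. Note that this demonstration step only generates responses and reports their rewards, KL penalties, and advantages; it does not yet update the policy model's weights.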

In [None]:
import copy

# Create a copy of the SFT model to serve as the reference model.
# Copying the in-memory model (rather than reloading a checkpoint by name)
# guarantees the reference starts from exactly the same weights as the policy.
ref_model = copy.deepcopy(sft_model).to(device)
ref_model.eval()  # Disable dropout for the reference model
# eval() alone does not freeze anything; turn off gradients explicitly
for ref_param in ref_model.parameters():
    ref_param.requires_grad_(False)

# Initialize PPO
ppo = SimplePPO(
    policy_model=sft_model,
    ref_model=ref_model,
    reward_model=reward_model,
    tokenizer=tokenizer,
    device=device
)

# Sample prompts for PPO training
ppo_prompts = [
    "Explain the concept of reinforcement learning.",
    "Write a short story about a robot learning to feel emotions."
]

# Run one PPO step (in a real implementation, you would run many steps)
avg_reward = ppo.ppo_step(ppo_prompts)

## 9. Comparing SFT and RLHF Models

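Finally, let's compare the outputs of our original SFT model and the RLHF-optimized model. Keep in mind that the simplified `ppo_step` above does not perform a gradient update, so the "RLHF" model here is still the policy as it stood before PPO; in a full implementation this comparison would show the effect of policy optimization.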

In [None]:
def generate_response(model, prompt, max_length=100):
    """Generate a response from the model given a prompt.

    Args:
        model: Causal language model exposing ``generate``.
        prompt: Prompt string, tokenized with the module-level ``tokenizer``.
        max_length: Forwarded to ``generate`` as ``max_length``.

    Returns:
        The decoded continuation with the prompt removed and surrounding
        whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].size(1)

    with torch.no_grad():
        output = model.generate(
            inputs["input_ids"],
            max_length=max_length,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id
        )

    # Drop the prompt by token count rather than by character count: decoding
    # the prompt tokens does not always reproduce the prompt string exactly, so
    # slicing the decoded text by len(prompt) can cut into the response or leave
    # part of the prompt behind.
    response_ids = output[0][prompt_len:]
    response = tokenizer.decode(response_ids, skip_special_tokens=True)

    return response.strip()

# Prompts used for the side-by-side comparison
test_prompts = [
    "Explain quantum computing to a high school student.",
    "Write a poem about artificial intelligence.",
    "What are three ways to improve productivity?"
]

# For every prompt: sample from the reference (SFT) model, then from the policy
# (RLHF) model, and score both samples with the reward model
separator = "-" * 80
for prompt in test_prompts:
    print(f"Prompt: {prompt}\n")

    # Reference model plays the role of the SFT baseline
    ref_response = generate_response(ref_model, prompt)
    print(f"SFT Model Response:\n{ref_response}\n")

    # Policy model plays the role of the RLHF model
    policy_response = generate_response(sft_model, prompt)
    print(f"RLHF Model Response:\n{policy_response}\n")

    # Reward for the baseline sample first, then for the policy sample
    ref_reward, policy_reward = (
        get_reward(reward_model, prompt, text)
        for text in (ref_response, policy_response)
    )
    improvement = policy_reward - ref_reward

    print(f"SFT Model Reward: {ref_reward:.4f}")
    print(f"RLHF Model Reward: {policy_reward:.4f}")
    print(f"Improvement: {improvement:.4f}")
    print(separator)

## 10. Conclusion

In this notebook, we've implemented simplified versions of the key components of the RLHF pipeline:

1. **Reward Model Training**: We trained a model to predict human preferences between pairs of responses.
2. **PPO Algorithm**: We implemented a simplified version of PPO for language models.
3. **Policy Optimization**: We walked through how PPO would use the reward model's feedback (minus a KL penalty) to optimize our policy model, stopping short of the actual gradient update.

In a real RLHF implementation, each of these components would be much more complex and would require more data, compute resources, and careful tuning. However, this simplified implementation helps us understand the core concepts and how they fit together.

Key takeaways:
- RLHF builds on SFT by optimizing for human preferences rather than just imitating examples
- The reward model is crucial for guiding the policy optimization
- PPO's clipped objective, together with the KL penalty against the reference model, balances maximizing reward against staying close to the original model
- RLHF can lead to more helpful, harmless, and honest responses compared to SFT alone

Next steps:
- Explore more efficient alternatives to RLHF, such as Direct Preference Optimization (DPO)
- Learn about safety considerations and evaluation methods for aligned models