# RLHF From Scratch on LLMs

In this notebook, I start with the history of RLHF and why it matters for LLMs, then go into the techniques TRPO, PPO, GRPO and DPO. Each technique comes with the math, the code and an explanation of how it works. At the end, we'll try these techniques on a prebuilt LLM (the LLMs are not built from scratch here, since I've already done that in my [llm-from-scratch repository](https://github.com/ashworks1706/llm-from-scratch)). If you're new to RL, check out [dqn-from-scratch](https://github.com/ashworks1706/dqn-from-scratch), where I explain RL in depth from the ground up.

## Brief History of RLHF



RLHF emerged around 2017-2018 when researchers at OpenAI and DeepMind developed techniques to incorporate human preferences into reinforcement learning systems. The seminal paper "Deep Reinforcement Learning from Human Preferences" by Christiano et al. (2017) introduced the core concept of using human comparisons between pairs of outputs to train a reward model that could guide RL agents toward preferred behaviors. Initially applied to simpler tasks such as simulated robotics and Atari games, the technique remained relatively specialized for several years. It gained mainstream attention in 2022 when OpenAI used it to fine-tune a GPT-3.5 model into ChatGPT, dramatically improving output quality by aligning the model with human preferences. This breakthrough demonstrated RLHF's potential to transform raw language model capabilities into systems that better align with human intent and values. Since then, RLHF has become a standard component in developing advanced language models like GPT-4, Claude, and Llama 2, with each iteration refining the techniques to achieve better alignment.

### Why RLHF Matters for LLMs

<img src="assets/rlhf-vs-finetune.png" width=300>

Large Language Models trained solely on next-token prediction have knowledge, but they don't know how to answer properly. Say you have a model trained on Shakespeare's works. Great! But how do you make it answer questions the way we want (the way humans talk)? These models optimize for predicting the next token based on the training data distribution, which doesn't necessarily correlate with producing helpful, harmless, or honest responses. Traditional LLMs may generate toxic, harmful, or misleading content because they're simply trying to produce statistically likely continuations without understanding human values or preferences. They lack an inherent mechanism to distinguish between content that is statistically probable and content that is actually desirable according to human standards. RLHF addresses these issues by creating a feedback loop where human preferences explicitly guide the model's learning process, steering it toward outputs that humans find more helpful, honest, and aligned with their intent. This alignment process transforms a powerful but directionless prediction engine into a system that can better understand and respect nuanced human values and follow complex instructions in ways that maximize utility for users.




## The Bird's-Eye View

<img src="assets/workflow.png" >


Before delving into the complexities of RLHF, it's essential to understand the overall workflow of a typical RLHF-based project. When a large language model is initially trained on vast internet text corpora, it develops remarkable capabilities to predict text and acquire factual knowledge, but this training alone doesn't prepare it to be helpful in specific contexts or respond appropriately to human instructions. Consider a language model trained on extensive educational materials, including university Canvas modules, academic papers, and textbooks. This model would possess substantial knowledge about various academic subjects, pedagogical approaches, and educational concepts. However, if asked to "Explain the concept of photosynthesis to a 10-year-old," it might produce a technically accurate but overly complex explanation filled with academic jargon that would confuse rather than enlighten a young student. The model hasn't been optimized to serve as an effective tutor - it simply predicts what text might follow in educational materials.

The Supervised Fine-Tuning stage addresses this gap by training the model on demonstrations of desired behavior. For our hypothetical educational assistant, SFT would involve collecting thousands of examples showing how skilled human tutors respond to student questions: simplifying complex concepts, using age-appropriate language, providing relevant examples, checking for understanding, and offering encouragement. These demonstrations are formatted as input-output pairs (prompt and ideal response), and the model is fine-tuned to minimize the difference between its outputs and these human-generated "gold standard" responses. Through this process, the model learns the patterns that characterize helpful tutoring: breaking down complex concepts into simpler components, using analogies relevant to younger audiences, avoiding unnecessary technical terms, and adopting a supportive tone. After SFT, when asked to explain photosynthesis to a 10-year-old, the model is much more likely to respond with an explanation involving plants "eating sunlight" and "breathing in carbon dioxide to make food," rather than discussing electron transport chains and ATP synthesis. The model hasn't gained new knowledge, but it has learned a new way to present its existing knowledge that better aligns with the specific goal of being an effective tutor for younger students. However, SFT alone has significant limitations. First, it can only learn from the specific examples it's shown, leaving gaps in how to handle the infinite variety of possible user requests. Second, the demonstrations might not cover the full range of desirable behaviors or edge cases where special handling is needed. Third, the quality of the SFT model depends entirely on the quality and consistency of the demonstration data. 
Finally, there's no mechanism for the model to understand why certain responses are better than others - it simply learns to mimic patterns without a deeper understanding of the preferences that make one response superior. These limitations are precisely what RLHF is designed to address in the subsequent stages of the alignment process.
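To make "minimize the difference between its outputs and the gold-standard responses" concrete: in practice this is usually the token-level cross-entropy, i.e. the average negative log-likelihood the model assigns to the reference tokens. The sketch below is a minimal pure-Python illustration under that assumption. The probability lists are made-up numbers and `sft_loss` is a hypothetical helper name, not part of any library.

```python
import math

def sft_loss(token_probs):
    """Average negative log-likelihood of the reference (gold) tokens.

    token_probs: probabilities the model assigned to each gold token.
    """
    return -sum(math.log(p) for p in token_probs) / len(token_probs)

# Hypothetical probabilities for the same gold response from two models
confident_model = [0.9, 0.8, 0.95]   # assigns high probability to the gold tokens
unsure_model = [0.2, 0.1, 0.3]       # assigns low probability to the gold tokens

for name, probs in [("confident", confident_model), ("unsure", unsure_model)]:
    loss = sft_loss(probs)
    # Perplexity is the exponential of the average negative log-likelihood
    print(f"{name}: loss={loss:.3f}, perplexity={math.exp(loss):.2f}")
```

A lower loss means the model puts more probability mass on the demonstrated response, which is exactly what the fine-tuning updates push toward.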

Following Supervised Fine-Tuning, the RLHF workflow progresses to Human Preference Collection - a crucial stage that fundamentally changes how model improvement occurs. In this phase, rather than providing gold-standard demonstrations, human evaluators compare and rank different model responses to the same prompt. For our educational assistant, this might involve presenting evaluators with pairs of explanations for the same scientific concept and asking them which better achieves the goal of teaching a young student. One explanation might be more engaging and use more appropriate analogies, while another might be technically accurate but still too complex. By explicitly choosing the better response, humans provide preference signals that capture nuanced quality distinctions beyond what demonstration data alone can convey. These comparisons generate valuable datasets where each entry contains a prompt and two responses, with a label indicating which response humans preferred. The collection process typically gathers thousands or even millions of such comparative judgments, creating a rich dataset that embodies human preferences about what constitutes a high-quality response across diverse scenarios.
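As a rough illustration of what one entry in such a preference dataset could look like, here is a minimal sketch. The `PreferencePair` class and its field names are hypothetical, chosen for this example only; real datasets use their own schemas.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class PreferencePair:
    """One human comparison: a prompt, two candidate responses, and the winner."""
    prompt: str
    response_a: str
    response_b: str
    preferred: str  # "a" or "b", as chosen by the human evaluator

    def chosen(self) -> str:
        return self.response_a if self.preferred == "a" else self.response_b

    def rejected(self) -> str:
        return self.response_b if self.preferred == "a" else self.response_a

example = PreferencePair(
    prompt="Explain photosynthesis to a 10-year-old.",
    response_a="Plants use sunlight, water, and air to make their own food.",
    response_b="Photosynthesis proceeds via light-dependent reactions in the thylakoid membrane.",
    preferred="a",
)
print("Chosen:  ", example.chosen())
print("Rejected:", example.rejected())
```

Storing the pair as chosen/rejected is convenient because the later stages only need to know which response won, not how the evaluator was shown the two options.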

The third stage, Reward Model Training, transforms these human preferences into a quantifiable reward function that can guide further optimization. This reward model takes a prompt and response as input and outputs a scalar score representing how well the response aligns with human preferences. Technically, it's trained to predict which of two responses humans would prefer by maximizing the likelihood of the observed preference data. For our educational tutor, the reward model learns to assign higher scores to explanations that successfully simplify complex concepts without sacrificing accuracy, use age-appropriate analogies, maintain an encouraging tone, and check for understanding. This model becomes a computational proxy for human judgment, capable of evaluating millions of potential responses far beyond what human evaluators could manually assess. The quality of this reward model is critical, as it effectively defines what "good" means for all subsequent optimization.
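One common way to "maximize the likelihood of the observed preference data" is a Bradley-Terry style pairwise loss: the probability that the chosen response wins is the sigmoid of the reward gap, and the loss is its negative log. The sketch below assumes that formulation. The function names are hypothetical and the reward values are made-up scalars standing in for reward model outputs.

```python
import math

def preference_probability(reward_chosen, reward_rejected):
    """P(chosen is preferred) = sigmoid(reward_chosen - reward_rejected)."""
    return 1.0 / (1.0 + math.exp(-(reward_chosen - reward_rejected)))

def pairwise_loss(reward_chosen, reward_rejected):
    """Negative log-likelihood of the human's observed choice."""
    return -math.log(preference_probability(reward_chosen, reward_rejected))

# The reward model agrees with the human: small loss
print(round(pairwise_loss(2.0, 0.0), 4))
# The reward model disagrees with the human: large loss
print(round(pairwise_loss(0.0, 2.0), 4))
```

Only the difference between the two scores matters here, so the reward model is free to pick any absolute scale as long as it ranks responses the way humans do.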

With a trained reward model in place, the final stage applies Reinforcement Learning techniques to optimize the language model toward maximizing the predicted reward. The most common approach is Proximal Policy Optimization (PPO), which iteratively improves the model by adjusting its parameters to generate responses that receive higher reward scores. However, simply maximizing reward can lead to degenerate outputs that exploit loopholes in the reward model or diverge too far from natural language patterns. To prevent this, the optimization includes a "KL divergence" penalty that constrains how much the optimized model can deviate from the SFT model, preserving fluency and knowledge while improving alignment. For our educational tutor, this process might result in a model that maintains scientific accuracy while consistently finding creative, age-appropriate analogies and explanations across a much broader range of topics than were covered in the original demonstration data. The entire RLHF pipeline is often iterative, with new preference data collected from the improved model, leading to refined reward models and further optimization cycles. This continuous feedback loop progressively aligns the language model with human values and preferences, addressing the fundamental limitations of training on prediction alone or even on demonstration data without comparative preference signals.
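The KL penalty described above is often implemented by subtracting a scaled estimate of the divergence from the reward model's score. The sketch below assumes the common per-sample estimate, the sum over generated tokens of `log pi(token) - log pi_ref(token)`. The function name, the log-probabilities, and the `beta` value are illustrative assumptions, not values from this notebook.

```python
def kl_shaped_reward(reward, policy_logprobs, reference_logprobs, beta=0.1):
    """Reward model score minus beta times a per-sample KL estimate.

    policy_logprobs / reference_logprobs: log-probabilities that the optimized
    model and the frozen SFT model assign to each generated token.
    """
    if len(policy_logprobs) != len(reference_logprobs):
        raise ValueError("Both models must score the same tokens")
    kl_estimate = sum(p - q for p, q in zip(policy_logprobs, reference_logprobs))
    return reward - beta * kl_estimate

# The policy assigns higher probability to its tokens than the SFT model does,
# so it has drifted and pays a penalty.
drifted = kl_shaped_reward(1.0, [-0.5, -1.0], [-1.0, -2.0], beta=0.1)
# Identical log-probabilities: no drift, no penalty.
unchanged = kl_shaped_reward(1.0, [-1.0, -2.0], [-1.0, -2.0], beta=0.1)
print(drifted, unchanged)
```

A larger `beta` keeps the optimized model closer to the SFT model, while a smaller one allows more aggressive reward maximization at the risk of reward hacking.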


## 1. Getting a pretrained LLM

<img src="assets/workflow1.png">


The first step is to get a fresh pretrained LLM. We'll use the Hugging Face transformers library for our transformer components.

In [None]:
# %pip install transformers peft datasets tqdm wandb rouge-score
# PEFT (parameter-efficient fine-tuning) lets us fine-tune LLMs without updating all of their parameters, which keeps this tutorial efficient.

In [None]:
# Import from library and setup the model class 
 
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel, PeftConfig
import os

class PretrainedLLM:
    def __init__(self, model_name="facebook/opt-350m", device=None):
        """
        Initialize a pretrained language model (no SFT, no RLHF) for the RLHF experiments
        
        Args:
            model_name: HuggingFace model identifier (default: OPT-350M, a relatively small but capable model)
            device: Computing device (will auto-detect if None)
        """
        self.model_name = model_name
        
        # Auto-detect whether a CUDA-capable GPU is available; otherwise fall back to CPU
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
        else:
            self.device = device
            
        print(f"Loading {model_name} on {self.device}...")
        
        # Load model and tokenizer (for a full guide on implementing an LLM from scratch, see https://github.com/ashworks1706/llm-from-scratch)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name) 
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name, 
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
            low_cpu_mem_usage=True
        )
        
        # Move the model to the selected device (GPU if available, otherwise CPU)
        self.model.to(self.device)
        print(f"Model loaded successfully with {sum(p.numel() for p in self.model.parameters())/1e6:.1f}M parameters")
        
    def generate(self, prompt, max_new_tokens=100, temperature=0.7, top_p=0.9):
        """
        Generate text from the model given a prompt (no RLHF)
        
        Args:
            prompt: Input text to generate from
            max_new_tokens: Maximum number of tokens to generate
            temperature: Sampling temperature (lower = more deterministic)
            top_p: Nucleus sampling parameter (lower = more focused)
            
        Returns:
            Generated text as string
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)
        
        # Generate with sampling, no fancy tuning required
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                do_sample=True
            )
        
        # Decode only the newly generated tokens (everything after the prompt)
        prompt_length = inputs.input_ids.shape[1]
        generated_text = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        
        return generated_text
    
    def save_checkpoint(self, path):
        """Save model checkpoint to the specified path"""
        self.model.save_pretrained(path)
        self.tokenizer.save_pretrained(path)
        print(f"Model saved to {path}")
        
    def load_adapter(self, adapter_path):
        """Load a PEFT adapter for efficient fine-tuning"""
        self.model = PeftModel.from_pretrained(
            self.model,
            adapter_path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
        )
        print(f"Loaded adapter from {adapter_path}")



In [None]:
# Example usage
if __name__ == "__main__":
    # Initialize with a small model for experimentation
    llm = PretrainedLLM(model_name="facebook/opt-350m")
    
    # Test generation
    prompt = "Explain quantum computing to a 10-year-old:"
    print(f"(No SFT RLHF) Prompt: {prompt}")
    print(f"(No SFT RLHF) Response: {llm.generate(prompt, max_new_tokens=150)}")

---
(No SFT RLHF) *Prompt*: Explain quantum computing to a 10-year-old:

(No SFT RLHF) *Response*: Quantum computing is a type of computation that harnesses the collective properties of quantum states, such as superposition, interference, and entanglement, to perform calculations. The basic unit of quantum computation is the quantum bit or qubit. Unlike classical bits, which can only be in states of 0 or 1, qubits can exist in a superposition of states, meaning they can represent both 0 and 1 simultaneously. This property allows quantum computers to process a vast number of possibilities simultaneously.

The field emerged from quantum mechanics, which is a fundamental theory in physics that provides a description of the physical properties of nature at the atomic and subatomic scales. Quantum computing leverages quantum mechanical phenomena to perform operations on data. This can potentially solve certain problems much faster than classical computers.

For instance, Shor's algorithm demonstrates that a quantum computer can factor large integers exponentially faster than the best known algorithm running on a classical computer. This has significant implications for cryptography systems that rely on the difficulty of factoring large numbers.

---

In this example, the user expects the model to give an explanation relevant to a 10-year-old. Instead, the model produces a very technical response in an academic tone, without any examples. A better response (with SFT) could have been:

---

(No RLHF) *Prompt*: Explain quantum computing to a 10-year-old:

(No RLHF) *Response*: Imagine your regular computer is like playing a board game where you can only be on one space at a time. You have to move step by step to get to the end.

Quantum computers are like magical game pieces that can be in lots of different spaces at the same time! They use special bits called "qubits" instead of regular computer bits.

Regular computer bits are like light switches - they can only be ON or OFF. But qubits are special because they can be ON and OFF at the same time, kind of like how a spinning coin is neither heads nor tails until it stops.

This special power lets quantum computers solve some really tricky puzzles much faster than regular computers. Scientists are excited about them because they might help us discover new medicines, understand weather better, and solve problems that would take regular computers thousands of years!

The quantum computers we have today are still small and make mistakes, but engineers are working hard to make them bigger and better.

---


As you can see here, the SFT model has learned to use simple language and analogies, break down complex concepts, avoid technical jargon like superposition, and adapt somewhat to the age level of the audience.

This is our target response with SFT and RLHF:

---

(SFT RLHF) *Prompt*: Explain quantum computing to a 10-year-old:

(SFT RLHF) *Response*: Hey there! Imagine you have a super cool magic coin. When you flip a normal coin, it lands on either heads or tails, right? But this magic coin can spin and be a little bit heads and a little bit tails at the same time while it's spinning!

Quantum computers use special things called "qubits" that work like these magic coins. While normal computers use regular bits that can only be 0 OR 1 (like a coin that's only heads or tails), quantum computers use qubits that can be 0 AND 1 at the same time (like our spinning magic coin).

Why is this so awesome? Well, imagine you lost your favorite toy somewhere in your house. A normal computer would have to look in each room one by one. But a quantum computer is like having a bunch of friends who can check all the rooms at the exact same time! That makes solving certain problems MUCH faster.

Scientists are still learning how to build better quantum computers because they're tricky to make. The qubits are super sensitive - kind of like trying to keep a hundred coins spinning without letting any of them fall over!

Does that make sense? What do you think would be the coolest thing to use a quantum computer for?

---


In this example, we can clearly see the difference: the response is more conversational and user-friendly, and it uses analogies to make the explanation easier to follow, like a human would.

## 2. Supervised Fine Tuning

<img src="assets/workflow2.png" >


### Understanding the difference between SFT and RLHF



But what if I keep the model at SFT and skip RLHF? Or what if I skip SFT and go straight to RLHF?

These are excellent questions that get to the heart of why the complete RLHF pipeline exists. Supervised Fine-Tuning (SFT) and Reinforcement Learning from Human Feedback (RLHF) serve different but complementary roles in aligning language models with human expectations and preferences.

If you only implement SFT without RLHF, you'll have a model that can follow basic patterns demonstrated in your training examples, but it will struggle to generalize beyond them. As we saw in our quantum computing example, SFT can teach a model to use simpler language and appropriate analogies, but it's limited by the specific demonstrations provided. The model learns to mimic patterns without developing a deeper understanding of why certain responses are better than others. When faced with novel queries or edge cases not covered in the training data, an SFT-only model often fails to maintain the same quality of responses. Additionally, SFT can only optimize for whatever patterns exist in your demonstration data - if that data contains subtle biases or inconsistencies, those will be faithfully reproduced by the model.

Conversely, if you attempt to skip SFT and go directly to RLHF, you're likely to encounter significant challenges. RLHF works by refining an already somewhat aligned model through preference optimization. Starting with a raw pretrained model would make this process extremely inefficient and potentially unstable. The preference learning and reinforcement stages need a reasonable starting point where the model can already produce somewhat appropriate responses that humans can meaningfully compare and rank. Without SFT, the initial responses might be so far from helpful that the preference signals become too noisy or the optimization process becomes prohibitively difficult. It would be like trying to teach advanced painting techniques to someone who hasn't yet learned to hold a brush - the feedback would be overwhelming and difficult to incorporate.

The full RLHF pipeline with SFT followed by preference learning and reinforcement creates a progressively refined alignment. SFT provides the foundation by teaching the model basic response patterns and formats through demonstration. RLHF then builds on this foundation by teaching the model to distinguish between good and better responses through comparative feedback, allowing it to generalize beyond specific examples to broader human preferences. As we observed in our examples, the SFT model improved basic comprehensibility and appropriateness, while the RLHF model further enhanced engagement, conversational tone, and subtle aspects of helpfulness that are difficult to capture through demonstrations alone. This complementary relationship explains why major AI systems like ChatGPT and Claude use both techniques in sequence rather than choosing one over the other. The complete alignment process transforms raw predictive power into carefully balanced helpful assistance that respects complex human values and preferences.

### Components of SFT

To perform Supervised Fine-Tuning (SFT) on our pretrained LLM, we need high-quality demonstration data consisting of prompt-response pairs showing the desired behavior, typically thousands of examples created by experts. We also need a data preprocessing pipeline to format this data consistently, including tokenization and special tokens to distinguish between prompts and responses. SFT requires careful configuration of hyperparameters like learning rate, batch size, and optimization methods, with techniques such as warmup and decay schedules for training stability. Rather than fine-tuning all parameters, we'll use PEFT methods like LoRA that add small trainable modules while keeping most of the model frozen, making training more efficient. We'll implement a training loop for forward passes, loss calculation, backpropagation, and parameter updates, along with evaluation metrics such as perplexity and ROUGE scores to assess performance. Finally, our existing PretrainedLLM class already supports checkpointing and adapter saving, which we'll use to periodically save the model state during training. This SFT process will transform our raw model into one that can follow instructions and communicate appropriately, serving as the foundation for subsequent RLHF stages.

High-quality demonstration data: Thousands of prompt-response pairs created by experts

Data preprocessing pipeline: Consistent formatting, tokenization, and special tokens

Fine-tuning configuration: Learning rate, batch size, warmup/decay schedules

PEFT implementation: Using LoRA to add trainable modules while freezing most parameters

Training loop: Forward passes, loss calculation, backpropagation, parameter updates

Evaluation metrics: Perplexity, ROUGE scores to assess performance

Checkpointing: Saving model state during training using our existing functionality
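To give a feel for why LoRA is efficient, here is a small worked calculation. LoRA replaces the update to a `d_out x d_in` weight matrix with two low-rank factors, so each adapted matrix adds only `rank * (d_in + d_out)` trainable parameters. The hidden size and layer count below are assumptions chosen for illustration and should be checked against the actual model config. The rank of 16 and the two adapted projections per layer match the LoRA settings used in this notebook.

```python
def lora_trainable_params(d_in, d_out, rank):
    """Parameters added by one LoRA adapter: A is (rank x d_in), B is (d_out x rank)."""
    return rank * (d_in + d_out)

# Assumed dimensions for illustration (verify against model.config)
hidden_size = 1024        # assumed width of the attention projections
num_layers = 24           # assumed number of transformer layers
modules_per_layer = 2     # q_proj and v_proj
rank = 16

per_module = lora_trainable_params(hidden_size, hidden_size, rank)
total_lora = per_module * modules_per_layer * num_layers
full_matrix = hidden_size * hidden_size

print(f"LoRA params per projection: {per_module:,}")
print(f"Full params per projection: {full_matrix:,}")
print(f"Total LoRA params:          {total_lora:,}")
```

Under these assumptions each adapter trains about 3% of the parameters of the projection it modifies, and everything else in the model stays frozen.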


For dataset, we're using the Databricks Dolly-15k dataset, which is a high-quality instruction-following dataset specifically designed for fine-tuning language models. This dataset contains 15,000 human-generated prompt/response pairs across various instruction categories including creative writing, classification, information extraction, open QA, brainstorming, and summarization. 

The Dolly dataset was created by Databricks employees who manually wrote both the prompts and high-quality responses, making it particularly valuable for instruction-tuning. Unlike some other datasets which may be generated or filtered from existing sources, Dolly's samples are purpose-built for teaching models to follow instructions in a helpful manner.

In [None]:
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer, 
    TrainingArguments, 
    Trainer,
    DataCollatorForLanguageModeling
)
from peft import (
    get_peft_model,
    LoraConfig,
    TaskType,
    prepare_model_for_kbit_training
)
import os
import numpy as np
from tqdm import tqdm

class SupervisedFineTuner:
    def __init__(self, base_model, dataset_name="databricks/databricks-dolly-15k", max_seq_length=512):
        """
        Initializing SFT framework
        
        Args:
            base_model: The PretrainedLLM instance to fine-tune
            dataset_name: HuggingFace dataset identifier containing instruction/response pairs
            max_seq_length: Maximum sequence length for inputs
        """
        self.llm = base_model
        self.tokenizer = base_model.tokenizer
        self.model = base_model.model
        self.device = base_model.device
        self.max_seq_length = max_seq_length # max length of sequences that model will process
        self.dataset_name = dataset_name
        
        # If tokenizer doesn't have padding token, set it to eos token
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
            
        print(f"Loading dataset {dataset_name}...")
        self.raw_dataset = load_dataset(dataset_name)
        print(f"Dataset loaded with {len(self.raw_dataset['train'])} training examples")
        
    def prepare_data(self):
        """Process the dataset into the format needed for instruction fine-tuning"""
        
        def format_instruction(example):
            """Format an example into a prompt-response pair with special tokens"""
            # Different datasets use different column names for the prompt and the response
            if 'instruction' in example and 'response' in example:
                prompt = example['instruction']
                response = example['response']
            elif 'prompt' in example and 'completion' in example:
                prompt = example['prompt']
                response = example['completion']
            else:
                # Fallback for other dataset formats
                prompt = str(example['input']) if 'input' in example else ""
                response = str(example['output']) if 'output' in example else ""
            
            # Format with plain-text "User:" / "Assistant:" role prefixes
            formatted_text = f"User: {prompt.strip()}\n\nAssistant: {response.strip()}"
            return {"formatted_text": formatted_text}
        
        print("Formatting dataset...")
        self.processed_dataset = self.raw_dataset.map(format_instruction)
        
        def tokenize_function(examples):
            """Tokenize the examples and prepare for training"""
            texts = examples["formatted_text"]
            
            # Tokenize with padding and truncation
            tokenized = self.tokenizer(
                texts,
                padding="max_length",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
            
            # Create labels (for causal LM, labels are the same as input_ids)
            tokenized["labels"] = tokenized["input_ids"].clone()
            
            # Mask padding tokens in the labels to -100 so they're not included in loss
            tokenized["labels"][tokenized["input_ids"] == self.tokenizer.pad_token_id] = -100
            
            return tokenized
        
        print("Tokenizing dataset...")
        self.tokenized_dataset = self.processed_dataset.map(
            tokenize_function,
            batched=True,
            remove_columns=self.processed_dataset["train"].column_names
        )
        
        return self.tokenized_dataset
    
    def setup_peft(self, r=16, lora_alpha=32, lora_dropout=0.05):
        """Set up Parameter-Efficient Fine-Tuning using LoRA"""
        
        print("Setting up LoRA for efficient fine-tuning...")
        # Configure LoRA
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            r=r,  # Rank of the update matrices
            lora_alpha=lora_alpha,  # Scaling factor
            lora_dropout=lora_dropout,
            target_modules=["q_proj", "v_proj"],  # Which modules to apply LoRA to
            bias="none",
            inference_mode=False
        )
        
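        # Prepare model for training (this helper is designed for k-bit quantized models;
        # our model is not quantized, so the call is optional here)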
        self.model = prepare_model_for_kbit_training(self.model)
        
        # Apply LoRA
        self.model = get_peft_model(self.model, peft_config)
        
        # Display trainable parameters
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"Trainable parameters: {trainable_params:,} ({trainable_params/total_params:.2%} of total)")
        
        return self.model
    
    def train(self, output_dir="sft_model", num_epochs=3, batch_size=8, learning_rate=2e-5):
        """Train the model using the prepared dataset"""
        
        print("Setting up training arguments...")
        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,                           # Directory to save model checkpoints
            num_train_epochs=num_epochs,                     # Number of times to iterate through the dataset
            per_device_train_batch_size=batch_size,          # Batch size per GPU/CPU for training
            gradient_accumulation_steps=4,                   # Number of updates steps to accumulate gradients for
            warmup_ratio=0.1,                               # Fraction of training steps used for learning rate warmup
            weight_decay=0.01,                              # L2 regularization weight
            learning_rate=learning_rate,                     # Initial learning rate
            logging_steps=10,                               # How often to log training metrics
            save_steps=200,                                 # How often to save model checkpoints
            save_total_limit=3,                             # Maximum number of checkpoints to keep
            fp16=self.device == "cuda",                      # Whether to use 16-bit floating point precision
            report_to="none"                                # Disable external reporting services
        )
        
        # Create data collator
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=self.tokenizer,
            mlm=False  # We're doing causal LM, not masked LM
        )
        
        print("Creating trainer...")
        # Initialize the trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=self.tokenized_dataset["train"],
            eval_dataset=self.tokenized_dataset["test"] if "test" in self.tokenized_dataset else None,
            data_collator=data_collator
        )
        
        print("Starting training...")
        # Finally Train the model
        trainer.train()
        
        # Save the adapter
        adapter_path = os.path.join(output_dir, "adapter")
        self.model.save_pretrained(adapter_path) # save the model to the path
        self.tokenizer.save_pretrained(adapter_path)
        print(f"Saved LoRA adapter to {adapter_path}")
        
        return adapter_path
    
    def evaluate(self, evaluation_prompts):
        """Evaluate the model on a list of prompts"""
        print("Evaluating model...")
        
        for prompt in evaluation_prompts:
            print(f"Prompt: {prompt}")
            
            # Generate with the original model
            base_response = self.llm.generate(prompt, max_new_tokens=200)
            print(f"Base model response: {base_response}\n")
            
            # Format prompt for the fine-tuned model
            formatted_prompt = f"User: {prompt}\n\nAssistant: "
            inputs = self.tokenizer(formatted_prompt, return_tensors="pt").to(self.device)
            
            # Generate with the fine-tuned model
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=200,
                    temperature=0.7,
                    top_p=0.9,
                    do_sample=True
                )
                
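            # Decode only the newly generated tokens and display them
            # (slicing the decoded string by prompt length can misalign if
            # decoding does not reproduce the prompt character for character)
            prompt_length = inputs["input_ids"].shape[1]
            sft_response = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            print(f"SFT model response: {sft_response}\n")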
            print("-" * 50)



As we can see, the class takes a base model (our pretrained LLM), a dataset name, and a maximum sequence length that controls how much text the model processes at once. It configures the tokenizer by making sure the padding token is set, which is crucial for consistent batch processing during training. Next, the `prepare_data` method loads the Dolly-15k dataset and transforms each example by extracting the prompt-response pair and formatting it with "User:" and "Assistant:" prefixes. These prefixes help the model distinguish between user input and expected output, and teach it the proper conversation structure. After formatting, it tokenizes all examples, handling padding and truncation to ensure consistent lengths, and sets up the labels so that padding tokens are masked out of the loss calculation.
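To make the label handling concrete, here is a minimal sketch of the formatting and masking step. It assumes a toy whitespace tokenizer: `toy_tokenize`, `build_example`, and the `PAD_ID` value are hypothetical stand-ins, not part of the class above. The one fixed point is `-100`, which is the index that PyTorch's cross-entropy loss ignores by default.

```python
IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss
PAD_ID = 0           # hypothetical padding token id


def toy_tokenize(text, vocab):
    """Map each whitespace-separated word to an integer id (toy stand-in for a real tokenizer)."""
    return [vocab.setdefault(word, len(vocab) + 1) for word in text.split()]


def build_example(prompt, response, vocab, max_length=12):
    """Format, truncate, pad, and mask one prompt-response pair."""
    text = f"User: {prompt}\n\nAssistant: {response}"
    input_ids = toy_tokenize(text, vocab)[:max_length]  # truncation
    attention_mask = [1] * len(input_ids)
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad                       # padding
    attention_mask += [0] * n_pad
    # Padding positions get IGNORE_INDEX so they do not contribute to the loss.
    labels = [tok if m == 1 else IGNORE_INDEX for tok, m in zip(input_ids, attention_mask)]
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}


vocab = {}
example = build_example("What is LoRA?", "A low-rank adapter.", vocab)
print(example["labels"])
```

The real tokenizer works on subword units and handles the padding itself, but the idea is the same: real tokens keep their ids as labels, padding positions are ignored.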

The `setup_peft` method implements Parameter-Efficient Fine-Tuning using Low-Rank Adaptation (LoRA). Rather than updating all model weights, which would be computationally expensive, LoRA adds small trainable matrices to key attention components while keeping the original parameters frozen. The method configures LoRA with a rank and a scaling parameter, then applies it to the query and value projection matrices of the transformer architecture. This often reduces the number of trainable parameters to less than 1% of the total, making fine-tuning feasible on modest hardware.

At the core of LoRA's efficiency is an observation about weight updates: during fine-tuning they often have a low "intrinsic rank", meaning they can be approximated by low-rank matrices without significant loss of information. LoRA therefore decomposes the weight update into the product of two smaller matrices. For example, in a transformer model where a weight matrix W might be `768×768` (containing `589,824` parameters), LoRA replaces the full update with two matrices B `(768×16)` and A `(16×768)`, requiring only `24,576` parameters, a 96% reduction. The matrices are initialized so that training begins from the original model's behavior: A starts with random Gaussian values while B begins at zero, so the product BA is zero at the start.

The implementation in the code uses `r=16` for the rank hyperparameter, which determines this compression ratio. The `lora_alpha=32` parameter controls scaling: the update BA is multiplied by `lora_alpha / r` (here 2), both during training and at inference, which determines how strongly the adaptation affects the original weights. The `target_modules=["q_proj", "v_proj"]` parameter targets the query and value projection matrices in the attention mechanism, which are particularly influential for language understanding and generation, while leaving other components untouched.
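The parameter arithmetic and the "start from the original model" property can be checked with a few lines of plain Python. This is a sketch under assumptions: `lora_param_counts` and `matmul` are hypothetical helpers, the tiny dimensions are for illustration only, and the initialization follows the LoRA paper (A random Gaussian, B zero).

```python
import random


def lora_param_counts(d_out, d_in, r):
    """Parameters of a full d_out x d_in update vs. a rank-r LoRA update."""
    full = d_out * d_in
    lora = d_out * r + r * d_in  # B is (d_out x r), A is (r x d_in)
    return full, lora


def matmul(X, Y):
    """Plain-Python matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


full, lora = lora_param_counts(768, 768, r=16)
print(full, lora, f"{1 - lora / full:.1%} fewer trainable parameters")

# Tiny instance: A is random Gaussian, B is zero, so the scaled update
# delta_W = (alpha / r) * B @ A is exactly zero before any training step.
d, r, alpha = 4, 2, 4
random.seed(0)
A = [[random.gauss(0.0, 0.02) for _ in range(d)] for _ in range(r)]
B = [[0.0] * r for _ in range(d)]
scale = alpha / r
delta_W = [[scale * v for v in row] for row in matmul(B, A)]
print(delta_W)
```

Because `delta_W` starts at zero, the adapted layer computes the same output as the frozen layer on the first forward pass.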

The train method handles the actual training process by configuring optimization parameters including learning rate, batch size, and gradient accumulation steps. It sets up a training pipeline with appropriate arguments for supervised learning, including warmup schedules and weight decay for regularization. After training completes, it saves just the LoRA adapter rather than the full model, making the fine-tuned version extremely portable at just a fraction of the full model size. Finally, the evaluation method provides a convenient way to compare the base model against the fine-tuned version using the same prompts. 

In [None]:
if __name__ == "__main__":
    
    # Initialize base model
    base_llm = PretrainedLLM(model_name="facebook/opt-350m")
    
    # Create SFT trainer
    sft = SupervisedFineTuner(base_llm, dataset_name="databricks/databricks-dolly-15k")
    
    # Prepare data
    processed_data = sft.prepare_data()
    
    # Setup PEFT
    peft_model = sft.setup_peft()
    
    # Train the model
    adapter_path = sft.train(output_dir="sft_model", num_epochs=1)  # Reduced for demo
    
    # Load the adapter into the base model
    base_llm.load_adapter(adapter_path)
    
    # Test the model
    evaluation_prompts = [
        "Explain quantum computing to a 10-year-old:",
        "Write a short story about a robot learning to feel emotions:",
        "How do I bake a chocolate cake?"
    ]
    
    sft.evaluate(evaluation_prompts)


And there we go! We now have a supervised fine-tuned LLM!

## 3. Reinforcement Learning with Human Feedback

#### Understand the relation between LLMs and Reinforcement Learning

Before diving into the human feedback mechanisms, it's crucial to understand how Reinforcement Learning fundamentals apply to language models. This connection isn't immediately obvious, as LLMs seem quite different from traditional RL scenarios like game-playing agents or robotic control systems.

Reinforcement Learning is a machine learning paradigm where an agent learns to make decisions by interacting with an environment to maximize cumulative rewards. The core components of any RL system include:

- **Agent**: The decision-maker (our language model) that observes the environment and takes actions
- **Environment**: The context in which the agent operates (the conversation or text generation task)
- **State (s)**: The current situation or context that the agent observes (the prompt and any previously generated text)
- **Action (a)**: The decisions the agent can make (choosing the next token/word to generate)
- **Policy (π)**: The strategy that maps states to actions (how the model decides what token to generate next)
- **Reward (r)**: The feedback signal indicating how good an action was (human preference scores or task-specific metrics)
- **Value Function (V)**: Estimates the expected cumulative reward from a given state (how "good" a particular context is)
- **Trajectory**: A sequence of states, actions, and rewards over time (the complete generation process from prompt to final response)

Let's return to our quantum computing explanation scenario to see how these RL concepts apply to language models. Imagine our model is explaining quantum computing to a 10-year-old:

- **Environment**: The conversation context where a child has asked "Explain quantum computing to a 10-year-old"
- **Initial State**: The prompt "Explain quantum computing to a 10-year-old:" (this is what the model "observes")
- **Agent**: Our language model that must generate an appropriate response
- **Action Space**: The model's vocabulary (approximately 50,000 possible tokens it could choose from at each step)
- **Policy**: The model's current strategy for choosing words - initially just next-token prediction, later optimized for human preferences

At each generation step, the model is in a specific state (the prompt plus all previously generated tokens) and must choose an action (the next token). For example:

- **State 1**: "Explain quantum computing to a 10-year-old:"
- **Action 1**: Choose token "Hey" (starting conversationally)
- **State 2**: "Explain quantum computing to a 10-year-old: Hey"
- **Action 2**: Choose token "there!" (continuing the friendly tone)
- **State 3**: "Explain quantum computing to a 10-year-old: Hey there!"
- **Action 3**: Choose token "Imagine" (starting an analogy)

And so forth, until the model generates an end-of-sequence token, completing the trajectory.
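The state/action loop above can be written down as a small rollout function. This is only a sketch: `rollout` and `scripted_policy` are hypothetical names, words stand in for tokens, and the "policy" simply replays the tokens from the example instead of sampling from a model.

```python
def rollout(prompt, policy, max_steps=10, eos="<eos>"):
    """Generate a trajectory of (state, action) pairs, one token per step."""
    state = prompt
    trajectory = []
    for _ in range(max_steps):
        action = policy(state)           # the policy maps a state to the next token
        trajectory.append((state, action))
        if action == eos:                # the end-of-sequence token completes the trajectory
            break
        state = f"{state} {action}"      # next state = previous state + chosen token
    return trajectory


PROMPT = "Explain quantum computing to a 10-year-old:"
script = ["Hey", "there!", "Imagine", "<eos>"]


def scripted_policy(state):
    """Hypothetical policy: count the tokens generated so far and replay the script."""
    n_generated = len(state[len(PROMPT):].split())
    return script[n_generated]


trajectory = rollout(PROMPT, scripted_policy)
for step, (state, action) in enumerate(trajectory, start=1):
    print(f"State {step}: {state!r} -> Action {step}: {action!r}")
```

A real language model replaces `scripted_policy` with sampling from its next-token distribution, but the loop is the same.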


The genius of applying RL to language models lies in reframing text generation as a sequential decision-making process. Instead of just predicting the statistically most likely next token (as in standard pretraining), we can now optimize for much more sophisticated objectives:

**Traditional Approach**: "What word typically comes next in internet text?"

**RL Approach**: "What word choice will lead to a response that humans find helpful, harmless, and honest?"

In our quantum explanation example, the traditional model might continue with technical terms because they frequently appear after "quantum computing" in training data. But an RL-optimized model learns that choosing words like "imagine" or "think of it like" leads to better human ratings for child-appropriate explanations.

**State Representation in Language Models**: Unlike board games where the state is clearly defined (piece positions), in language models the state is the entire context window - the prompt plus all previously generated tokens. This creates a vast and complex state space where the same prompt can lead to exponentially many possible conversation paths.

**Action Space Complexity**: At each step, the model chooses from tens of thousands of possible tokens. The sequence of these choices determines whether we get a technical academic explanation or a child-friendly analogy with magic coins and spinning objects.

**Reward Signal**: This is where human feedback becomes crucial. Instead of immediate rewards (like points in a game), language model rewards typically come at the end of generation when humans evaluate the complete response. This creates what's called a "sparse reward" problem - the model must learn which early token choices led to the eventual positive feedback.

**Policy Evolution**: The model's policy starts as simple next-token prediction but evolves through RLHF to incorporate human preferences. It learns that certain patterns (like using analogies, asking engaging questions, maintaining appropriate tone) consistently lead to higher rewards.

This reframing transforms language model training from pattern matching to goal-oriented behavior, enabling models to optimize for nuanced human values rather than just statistical likelihood. The subsequent RLHF stages (preference collection, reward modeling, and policy optimization) all build on this fundamental RL foundation to create AI systems that are not just knowledgeable, but genuinely helpful and aligned with human intentions.



### Human Preference Collection

#### But Why Do We Even Need Human Preference Collection?

You might wonder: "We already have an SFT model that can explain quantum computing nicely to children. Why do we need this additional step?" This is a fundamental question that gets to the heart of what makes RLHF so powerful.

The limitation of SFT becomes clear when we consider what it actually teaches the model. SFT is essentially sophisticated pattern matching - it learns to mimic the style and structure of the demonstration data, but it doesn't develop a deeper understanding of *why* certain responses are better than others. It's like teaching someone to paint by having them copy masterpieces stroke by stroke - they might reproduce the paintings accurately, but they haven't learned the principles that make great art.

Consider our quantum computing example. An SFT model might learn that responses for children should use simple language and analogies, but it lacks nuanced understanding of what makes one analogy better than another. It might consistently use the "spinning coin" analogy because it appeared in training data, but it won't know that sometimes a "magic treasure hunt" analogy might be more engaging for a particular child, or that asking follow-up questions creates better learning experiences.

Human Preference Collection addresses these limitations by teaching the model to understand *relative quality* - not just how to generate appropriate responses, but how to distinguish between good, better, and best responses. This comparative learning is fundamentally different from demonstration learning and unlocks much more sophisticated behavior.


<img src="assets/workflow3.png" >



Human Preference Collection is essentially a massive, systematic comparison exercise where human evaluators help the model learn what "better" means in countless different scenarios. Here's how it works in practice:

#### Step 1: Response Generation
First, we use our SFT model to generate multiple responses to the same prompt. For our quantum computing example, we might generate several different explanations:

**Prompt**: "Explain quantum computing to a 10-year-old"

**Response A**: "Quantum computers use qubits instead of regular bits. Think of regular bits like light switches that can only be on or off. Qubits are special because they can be on and off at the same time, like a spinning coin that's both heads and tails until it stops spinning."

**Response B**: "Hey there! Imagine you have a magic coin that can be heads and tails at the same time while it's spinning! Quantum computers use these special 'qubits' that work just like our magic coin. This lets them solve puzzles much faster than regular computers. What's your favorite kind of puzzle?"

**Response C**: "Quantum computing is like having a super-fast helper that can check every room in your house for your lost toy at the same time, instead of checking one room after another like a regular computer would do."

#### Step 2: Human Evaluation
Human evaluators (often experts in education, communication, or the relevant domain) are presented with pairs of these responses and asked to choose which one better fulfills the criteria. The evaluation interface might look like:

```
Prompt: Explain quantum computing to a 10-year-old

Response A: [Response A text]
Response B: [Response B text]

Which response better explains quantum computing to a 10-year-old?
□ Response A is significantly better
□ Response A is slightly better  
□ Response B is slightly better
□ Response B is significantly better
□ They're about the same quality
```

The evaluators consider multiple factors:
- **Age-appropriateness**: Does it use language a 10-year-old would understand?
- **Engagement**: Would this capture and maintain a child's interest?
- **Accuracy**: Is the explanation scientifically sound (even if simplified)?
- **Clarity**: Would a child actually understand this after reading it?
- **Interactivity**: Does it encourage questions or further learning?

#### Step 3: Preference Data Creation
Each comparison creates a preference data point. If evaluators consistently prefer Response B over Response A, this creates a training signal that Response B's approach (conversational tone, asking questions, using engaging analogies) should be valued more highly than Response A's approach (more technical, less interactive).
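As a rough illustration of what one such data point could look like, here is a sketch that turns a rating from the form above into a (chosen, rejected) record. The `PreferencePair` class, the rating keys, and the choice to drop ties are all assumptions made for this example, not a fixed standard.

```python
from dataclasses import dataclass


@dataclass
class PreferencePair:
    prompt: str
    chosen: str
    rejected: str
    strength: int  # 1 = slightly better, 2 = significantly better


# Hypothetical mapping from the form options to (winner, strength).
RATING_TO_OUTCOME = {
    "A_significantly": ("A", 2),
    "A_slightly": ("A", 1),
    "B_slightly": ("B", 1),
    "B_significantly": ("B", 2),
    "same": (None, 0),
}


def to_preference_pair(prompt, response_a, response_b, rating):
    """Convert one evaluator rating into a training record (ties are dropped)."""
    winner, strength = RATING_TO_OUTCOME[rating]
    if winner is None:
        return None
    if winner == "A":
        chosen, rejected = response_a, response_b
    else:
        chosen, rejected = response_b, response_a
    return PreferencePair(prompt, chosen, rejected, strength)


pair = to_preference_pair(
    "Explain quantum computing to a 10-year-old",
    "Quantum computers use qubits instead of regular bits...",
    "Hey there! Imagine you have a magic coin...",
    rating="B_slightly",
)
print(pair)
```

Records of this shape (prompt, chosen response, rejected response) are what a reward model is typically trained on.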


Let's extend our quantum computing analogy to understand preference collection itself:

Imagine you're teaching a young student not just about quantum computers, but about how to *evaluate* different explanations of quantum computers. Instead of just showing them one "correct" way to explain it (like SFT), you show them pairs of explanations and ask: "Which of these would help you understand better?"

**Comparison 1**: Technical explanation vs. Magic coin analogy
→ Student prefers: Magic coin analogy

**Comparison 2**: Magic coin analogy vs. Treasure hunt analogy  
→ Student prefers: Treasure hunt analogy (more engaging)

**Comparison 3**: Treasure hunt analogy vs. Interactive treasure hunt with questions
→ Student prefers: Interactive version (encourages participation)

Through hundreds of these comparisons, the student (our model) learns not just individual explanations, but the *principles* that make explanations effective: engagement beats technicality, interactivity beats passivity, relatable analogies beat abstract concepts.


In real RLHF implementations, this process happens at massive scale:

- **Thousands of prompts** across diverse topics and scenarios
- **Multiple responses** generated for each prompt (typically 2-4)
- **Multiple evaluators** rating each comparison (to ensure reliability)
- **Hundreds of thousands** of preference comparisons collected
- **Quality control** measures to ensure consistent evaluation standards

For our quantum explanation example, the preference data might reveal patterns like:
- Responses with questions score higher than statements
- Analogies to familiar objects (coins, toys) beat abstract concepts
- Conversational tone ("Hey there!") beats formal tone
- Explanations that acknowledge the topic's complexity score higher than oversimplifications

#### Why This Works Better Than More SFT Data

You might think: "Why not just collect more demonstration data instead?" The key insight is that preference collection captures information that demonstration data cannot:

**Demonstration data tells us**: "This is a good response"

**Preference data tells us**: "This response is better than that response *because*..."

The "because" part is crucial. Through comparative evaluation, the model learns the underlying principles that make responses effective, not just specific examples of effective responses. This enables much better generalization to new scenarios and more nuanced quality judgments.

In our quantum computing scenario, instead of learning "use the magic coin analogy," the model learns "use analogies that relate to a child's everyday experience, and frame them in an engaging, interactive way." This principle can then be applied to explaining any complex topic to children, not just quantum computing.

This preference data becomes the foundation for the next stage: training a reward model that can automatically evaluate response quality and guide the final optimization process. The human preferences, collected at scale, teach the AI system not just what to say, but what makes some ways of saying things better than others.


### Reward Model Training



<img src="assets/workflow4.png" >


Dive into the reward model: what role it plays here, why it is needed, and how it works. If possible, also dive into the math behind it.
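As a starting point for the math, a common formulation (the Bradley-Terry style objective used in InstructGPT-like setups) trains the reward model $r_\phi$ on preference pairs by minimizing $-\log\sigma\left(r_\phi(x, y_w) - r_\phi(x, y_l)\right)$, where $y_w$ is the preferred response and $y_l$ the rejected one. The sketch below shows that loss on plain numbers. The function names are hypothetical, and this is an illustration of the common formulation rather than the implementation this notebook will settle on.

```python
import math


def preference_probability(r_chosen, r_rejected):
    """Probability that the chosen response wins, given the two scalar rewards."""
    return 1.0 / (1.0 + math.exp(-(r_chosen - r_rejected)))


def pairwise_loss(r_chosen, r_rejected):
    """Negative log-likelihood of the human preference under the model above."""
    return -math.log(preference_probability(r_chosen, r_rejected))


# The loss is small when the reward model already ranks the pair correctly,
# and large when it ranks the pair the wrong way around.
print(pairwise_loss(2.0, 0.0))
print(pairwise_loss(0.0, 0.0))
print(pairwise_loss(0.0, 2.0))
```

Only the difference between the two rewards matters, which is why reward model scores are meaningful relative to each other rather than on an absolute scale.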

### Optimization Algorithms

#### Why were these algorithms needed?

Start with a basic introduction to policy optimization: how policy gradient optimization works in general, a brief statement of its limitations, and an introduction to off-policy and on-policy learning.

##### Trust Region Policy Optimization

<img src="assets/trpo.png" width=900>

Now in this one, explain a lot: what it is (using our quantum computing for a 10-year-old analogies), how it works, its components, and how it was groundbreaking. Dive deep into the math side of it; we won't be doing code here since we're mainly focusing on PPO and DPO for this notebook.

In policy gradient algorithms, update steps computed at a specific policy $\pi_{\theta_t}$ are only really "predictive" in the neighborhood of $\theta_t$. That is, updates that move outside of this neighborhood may not carry any predictive value at all. Intuitively, you may then think of constraining updates so that they stay in the vicinity of our current policy.

Since the policy is a probability distribution over actions given a state, we measure this vicinity in the space of such distributions, which we refer to as the "policy space". Trust region methods (Kakade, 2001; Schulman et al., 2015a; 2017) aim to restrict how far successive policies are allowed to deviate from each other.

In Trust Region Policy Optimization (Schulman et al., 2015a), this is achieved by constraining the KL divergence between successive policies on the optimization trajectory. Let $\pi(a_t|s_t)$ be the old policy, under which we take action $a_t$ given that we are in state $s_t$ at time $t$, and let $\pi_{\theta}(a_t|s_t)$ be our current (parameterized) policy, which we seek to update. Let $\hat{A}_{\pi}(s_t, a_t)$ be the estimated advantage of picking action $a_t$ in state $s_t$ while following the old policy $\pi$.

Define the optimization problem as:
$$
\begin{aligned}
\underset{\theta}{\max}\ & \mathbb{E}_{(s_t, a_t)\sim \pi} \left[\frac{\pi_{\theta}(a_t|s_t)}{\pi(a_t|s_t)}\hat{A}_{\pi}(s_t, a_t)\right]\\
\text{subject to }\ & D_{KL}\left(\pi_{\theta}(\cdot|s)\,\|\,\pi(\cdot|s)\right)\leq\delta, \ \forall s
\end{aligned}
$$
The "subject to" line is the trust region constraint: at every state $s$, the KL divergence between the new policy and the old policy must not exceed some small $\delta$.

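This theoretical update is not easy to compute. Thus, TRPO approximates both the surrogate objective $\mathcal{L}(\theta_k, \theta)$ and the KL divergence $D_{KL}(\theta\,\|\,\theta_k)$ with Taylor expansions around the current parameters $\theta_k$ (first order for the objective, second order for the KL divergence) to give an easier-to-compute problem. This approximate problem can then be solved using Lagrangian duality to yield the update below, where $g$ is the gradient of the surrogate objective at $\theta_k$, $H$ is the Hessian of the KL divergence at $\theta_k$, and $\alpha^j$ is a backtracking line search coefficient ($\alpha\in(0,1)$, with $j$ the smallest non-negative integer for which the KL constraint holds and the surrogate objective improves):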
$$\theta_{k+1}=\theta_k+\alpha^j\sqrt{\frac{2\delta}{g^TH^{-1}g}}H^{-1}g$$
Since the Hessian inverse $H^{-1}$ is expensive to compute, TRPO utilizes the conjugate gradient algorithm to solve $Hx=g$ (or $x=H^{-1}g$) which requires a function for computing $Hx$ instead of computing and storing the entire matrix $H$.
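To show how the conjugate gradient trick avoids forming $H^{-1}$, here is a small plain-Python sketch. It assumes a toy 2×2 positive-definite matrix standing in for $H$ and accesses it only through a Hessian-vector product function. `conjugate_gradient`, `trpo_step`, and `hvp` are hypothetical names, and the step shown is the full step with $\alpha^0=1$ (no line search).

```python
import math


def conjugate_gradient(hvp, g, iters=10, tol=1e-10):
    """Approximately solve H x = g using only Hessian-vector products hvp(v) = H v."""
    x = [0.0] * len(g)
    r = list(g)  # residual g - H x (x starts at zero)
    p = list(g)  # current search direction
    rs_old = sum(ri * ri for ri in r)
    for _ in range(iters):
        if rs_old < tol:
            break
        Hp = hvp(p)
        alpha = rs_old / sum(pi * hi for pi, hi in zip(p, Hp))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * hi for ri, hi in zip(r, Hp)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


def trpo_step(hvp, g, delta):
    """Full step sqrt(2 * delta / (g^T H^-1 g)) * H^-1 g, without backtracking."""
    x = conjugate_gradient(hvp, g)              # x approximates H^-1 g
    gHg = sum(gi * xi for gi, xi in zip(g, x))  # g^T H^-1 g
    scale = math.sqrt(2.0 * delta / gHg)
    return [scale * xi for xi in x]


# Toy stand-ins: H is only ever used through hvp, never inverted or stored by the solver.
H = [[4.0, 1.0], [1.0, 3.0]]


def hvp(v):
    return [sum(h * vi for h, vi in zip(row, v)) for row in H]


g = [1.0, 2.0]
step = trpo_step(hvp, g, delta=0.01)
print(step)
```

By construction, the step satisfies $\frac{1}{2} s^T H s = \delta$, so it lands exactly on the boundary of the quadratic approximation of the trust region.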

##### Proximal Policy Optimization

Explain what PPO is and how it improved on TRPO, define its components, and explain the math and how it works step by step, in depth.

<img src="assets/ppo.png">

Let's start off by making a simple `PPOModel` class such that we can perform a `forward()` pass. 

In PPO, we refer to the policy model as the "actor" and the value model as the "critic".

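The value model outputs scalar values (scores) just like the reward model. Meanwhile, the policy model outputs logits over the vocabulary: a softmax turns them into a probability distribution, and taking the log of those probabilities gives the log-probabilities we work with below.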

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Tuple
from transformers import Trainer

class PPOModel(nn.Module):
    def __init__(self, actor_model, critic_model):
        super().__init__()
        self.actor_model = actor_model
        self.critic_model = critic_model

    def forward(self, sequences, extra_inputs=None):
        # fetch per-token logits from the actor (policy)
        actor_logits = self.actor_model(**sequences, return_dict=True).logits
        # fetch scalar value estimates from the critic (taken as the last element of its output)
        critic_values = self.critic_model(**sequences)[-1]
        
        # optional auxiliary loss computed by the actor on extra inputs
        if extra_inputs is not None:
            extra_loss = self.actor_model(**extra_inputs, return_dict=True).loss
        else:
            extra_loss = 0.0
        return actor_logits, critic_values, extra_loss

Now, navigate to the `src/ppo` folder to find `ppo_trainer.py`. This will contain the main code for our PPO algorithm.

We will now go over each component of the `PPOTrainer` class.

The KL divergence cannot be computed in its exact form.

The action space (token space) is far too vast for us to sum over every possible sequence.
Furthermore, when we train, we don't store full probability distributions but only the log-probabilities of the sampled tokens.
Thus, it doesn't make sense to waste GPU memory on something we can avoid.
So, it is necessary to estimate it (with Monte Carlo, for example). However, we have a plethora of KL estimators to choose from.

Choosing the "optimal" estimator is out of the scope of this notebook. We only implement popular estimators which have been tried and tested.

Suppose $r=\frac{\pi_{\theta}(y)}{\pi_{\text{ref}}(y)}$ is the ratio between the probability of a sampled token $y$ under the current policy and under the reference policy.
Let $k$ denote a generic KL estimator. Then, we have the following:

$$k_3 = (r - 1) - \log r$$
$$k_{\text{abs}}=|\log r|$$
$$k_{\text{MSE}}=\frac{1}{2}(\log r)^2$$

Which estimator is best is up for debate, but my pick is $k_{\text{MSE}}$.
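To get a feel for how the three estimators behave, the sketch below evaluates them for a single sampled token, given its log-probability under the two policies. The function name `kl_estimates` and the example numbers are made up for illustration.

```python
import math


def kl_estimates(logp_actor, logp_ref):
    """Per-token KL estimates from the log-probabilities of one sampled token."""
    log_r = logp_actor - logp_ref  # log r, with r the probability ratio
    r = math.exp(log_r)
    return {
        "k_3": (r - 1.0) - log_r,
        "abs": abs(log_r),
        "mse": 0.5 * log_r ** 2,
    }


# Identical policies: every estimator reports zero.
print(kl_estimates(-1.2, -1.2))

# The actor is more confident than the reference on this token.
print(kl_estimates(-0.5, -1.5))

# The actor is less confident than the reference on this token.
print(kl_estimates(-1.5, -0.5))
```

All three are non-negative for every token, which keeps the per-token penalty from ever turning into a bonus.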

In [None]:
def compute_rewards_with_kl_penalty(self, ref_values, actor_log_probs, ref_log_probs, responses_mask):
    """
    Computes per-token rewards with the KL divergence penalty.
    Includes implementations of KL estimators since we can't compute it exactly.
    - k_3 is a popular KL estimator proposed here: http://joschu.net/blog/kl-approx.html

    Args:
        ref_values: torch.Tensor
            scores from which `get_last_reward_score` extracts one scalar reward per sequence

        actor_log_probs: torch.Tensor
            log probabilities from our actor model

        ref_log_probs: torch.Tensor
            log probabilities from our reference model

        responses_mask: torch.Tensor
            marks response tokens; shifted by one below to align with the log probabilities
    """
    masks = responses_mask[:, 1:].bool()
    rewards_score = self.get_last_reward_score(ref_values, responses_mask)

    batch_size = rewards_score.shape[0]
    rewards_with_kl_penalty, kl_penalty_all = [], []

    for i in range(batch_size):
        mask = masks[i]
        # log r, where r is simply the ratio: pi(y) / pi_ref(y)
        # kept at full length so every sequence has the same shape and can be stacked at the end
        lp_diff = actor_log_probs[i] - ref_log_probs[i]

        if self.args.kl_penalty_method == 'k_3': # equation: (r - 1) - log r
            kl_est = (torch.exp(lp_diff) - 1.0) - lp_diff

        elif self.args.kl_penalty_method == 'abs': # equation: |log r|
            kl_est = torch.abs(lp_diff)

        elif self.args.kl_penalty_method == 'mse': # equation: 1/2 * (log r)^2
            kl_est = 0.5 * lp_diff ** 2
        else:
            raise ValueError(f"Unknown kl_penalty_method: {self.args.kl_penalty_method}")

        # zero out the estimate on non-response positions
        kl_est = torch.where(mask, kl_est, torch.zeros_like(kl_est))
        kl_penalty = - self.args.kl_penalty_beta * kl_est
        kl_penalty_all.append(kl_penalty)

        if self.args.reward_score_clip is not None:
            rewards_score[i] = torch.clamp(rewards_score[i], -self.args.reward_score_clip, self.args.reward_score_clip)

        # add the sequence-level reward score at the last response token;
        # clone so the pure KL penalty stored above is not modified in place
        rewards = kl_penalty.clone()
        end_index = mask.nonzero()[-1].item()
        rewards[end_index] += rewards_score[i]

        rewards_with_kl_penalty.append(rewards)
    return torch.stack(rewards_with_kl_penalty), torch.stack(kl_penalty_all), rewards_score

Now, we cannot compute the advantage function $A_t$ exactly, so we estimate it. 

The algorithm used in PPO to compute the estimate of the advantage $\hat{A}_t$ is Generalized Advantage Estimation (GAE).
One key thing here is that we do not just naively estimate $A_t$; we have to do it carefully to mitigate a large variance.

If $T$ is our trajectory length, then GAE computes this estimate by:

$$\hat{A}_t=\sum^{T-t-1}_{l=0}(\gamma\lambda)^l \delta_{t+l}$$

where $\gamma$ is the discount factor, $\lambda\in[0,1]$ is the GAE parameter, and $\delta_t=R(s_t,a_t)+\gamma V(s_{t+1})-V(s_t)$ which is the Temporal-Difference (TD) error. Here, $R(s_t, a_t)$ is the reward at time $t$, and $V(s_t)$ is the value function at time $t$. 

Variance reduction is a common theme in reinforcement learning, as long-horizon estimates are prone to deviating from their expected value.
GAE reduces the variance through multi-step "bootstrapping". Bootstrapping refers to updating an estimate using one or more estimates **of the same kind** (here, the value of a state is estimated from the value estimates of later states).

Without bootstrapping, we would have to rely on full sampled returns, whose variance grows with the length of the trajectory. This makes training noisy and slow to converge, and we want to converge as fast as possible. In GAE, the bootstrapping can be seen in the TD error $\delta_t=R(s_t,a_t)+\gamma V(s_{t+1})-V(s_t)$.

However, we also need to consider bias. Remember the bias-variance tradeoff? Because the value function $V$ is itself only an estimate, bootstrapping from it introduces bias. This is the price to be paid for variance reduction, and $\lambda$ controls the trade: $\lambda=0$ reduces to the one-step TD error (lowest variance, highest bias), while $\lambda=1$ sums the TD errors over the whole remaining trajectory (highest variance, lowest bias). There are proposed methods (such as VAPO https://arxiv.org/pdf/2504.05118) which aim to achieve low variance while mitigating the bias. This is a bit more advanced, though.

In the context of applying this to language models, $T$ can be viewed as the length of the model's output sequence. In GAE, the $t$-th step is the $t$-th token sampled from our model, and the difference $T-t$ is just how far away this token is from the end of the sequence (EOS). The reward model score is only added at the `<EOS>` token, so (apart from the per-token KL penalty) it is the only nonzero reward. All prior tokens receive it only through backward propagation (with discounting).

Given this discounting, as we increase the difference $T-t=2, T-t=3,..., T-t=k$, the reward signal grows weaker: in $\hat{A}_t$, the TD error that contains the final reward is weighted by $(\gamma\lambda)^{T-t-1}$. For long responses, the earliest tokens might receive a signal that is practically zero! Yikes. Thankfully, there are methods which attempt to address this decaying reward problem (such as VC-PPO https://arxiv.org/pdf/2503.01491).
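The decay is easy to see on a toy trajectory. The sketch below is a plain-Python, single-sequence version of the GAE recursion, with two simplifying assumptions: the only reward is a `1.0` at the last token, and the critic predicts zero everywhere. The function name `gae` and the numbers are chosen for illustration only.

```python
def gae(rewards, values, gamma, lam):
    """Generalized Advantage Estimation for one trajectory (lists of floats)."""
    T = len(rewards)
    advantages = [0.0] * T
    last = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t < T - 1 else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        last = delta + gamma * lam * last                    # discounted sum of TD errors
        advantages[t] = last
    return advantages


T = 5
rewards = [0.0] * (T - 1) + [1.0]  # sparse reward: only the final token is scored
values = [0.0] * T                 # assumption: an untrained critic that predicts zero
advantages = gae(rewards, values, gamma=1.0, lam=0.5)
print(advantages)  # [0.0625, 0.125, 0.25, 0.5, 1.0]
```

Each step away from the end halves the signal here because $\gamma\lambda=0.5$; with a few hundred tokens, the first ones would see essentially nothing unless the critic carries the information backwards.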

In [None]:
def compute_gae_advantage_return(self, rewards, values, mask, gamma, lam, use_advantage_norm=True):
    """
    Computes the Generalized Advantage Estimation via Temporal-Difference with parameter lambda.

    Args:
        rewards: torch.Tensor
            shape: (bs, response_length)
        values: torch.Tensor
            shape: (bs, response_length)
        mask: torch.Tensor
            shape: (bs, response_length)
        gamma: float
            discount factor
        lam: float
            lambda parameter for GAE algorithm
        use_advantage_norm: bool
            whether to whiten (normalize) the advantages over the unmasked tokens
    """
    B, T = rewards.shape # B is batch size, T is response_length
    mask = mask.to(values.dtype)

    # here, we bootstrap the value model updates with Temporal-Difference (parameter \lambda)
    with torch.no_grad():
        advantages_reversed = []
        lastgaelam = 0

        for t in reversed(range(T)):
            # for long sequences with T - t >> 1, discounting reduces the reward signal to near zero
            next_values = values[:, t + 1] if t < T - 1 else 0.0
            delta = rewards[:, t] + gamma * next_values - values[:, t]
            lastgaelam = (delta + gamma * lam * lastgaelam) * mask[:, t]
            advantages_reversed.append(lastgaelam)
        advantages = torch.stack(advantages_reversed[::-1], dim=1)

        # the returns are the regression targets for the value model
        returns = advantages + values

        if use_advantage_norm:
            # masked whitening: zero mean and unit variance over the unmasked tokens
            mean = (advantages * mask).sum() / mask.sum()
            var = (((advantages - mean) ** 2) * mask).sum() / mask.sum()
            advantages = (advantages - mean) * torch.rsqrt(var + 1e-8) * mask
    return advantages, returns


Implementing PPO (which is an Actor-Critic algorithm) requires a lot of design choices, where you can easily lose yourself trying to read them all. However, the topmost choice deals with the actor and critic themselves. We have two choices: 

1. Learn two largely separate models. One for the actor and another for the critic.
2. Learn one large model, and apply two shallow heads at the end to format our output (one for the actor, another for the critic)

The second approach is usually more popular because the single model we learn is usually a deep transformer which contains rich hidden representations. As such, our two heads only need to be shallow because they serve as "mappings" from these representations to token logits (actor) and scalar values (critic). The large transformer will have already done the grunt work during its intense training runs (and perhaps, any additional supervised fine-tuning).

In such implementations, we train a single large model with two heads on top. One is the actor (policy) head and the other is the critic (value) head.
We optimize this single model by maximizing a three-part objective (or, equivalently, minimizing its negative as a loss).

That is, we start with the $L^{CLIP}$ clipped surrogate objective as proposed in the original PPO paper. Next, we subtract a mean squared error (MSE) loss, $L^{VF}$, for fitting the value function. Finally, we add the entropy of the policy, $H(\pi_{\theta}(\cdot | s))$, multiplied by some constant $c_1$. Usually, we also multiply $L^{VF}$ by a constant $c_2$. So we have:
$$L=L^{CLIP}-c_2 L^{VF} + c_1 H(\pi_{\theta}(\cdot | s))$$

In the following sections, we will go over each term and its theoretical significance.
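As a preview, here is a per-sample sketch of the clipped surrogate term from the PPO paper, $\min\left(r\hat{A},\ \text{clip}(r, 1-\epsilon, 1+\epsilon)\hat{A}\right)$ with $r$ the probability ratio between the new and old policy, together with the combined objective from the formula above. The function names and the default values for `eps`, `c1`, and `c2` are assumptions picked for illustration, not tuned settings.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantage, eps=0.2):
    """PPO clipped surrogate for a single (state, action) sample."""
    ratio = math.exp(logp_new - logp_old)
    clipped_ratio = max(1.0 - eps, min(1.0 + eps, ratio))
    # Taking the minimum makes the objective a pessimistic bound on the unclipped one.
    return min(ratio * advantage, clipped_ratio * advantage)


def ppo_objective(l_clip, l_vf, entropy, c1=0.01, c2=0.5):
    """Combined objective L = L_CLIP - c2 * L_VF + c1 * H, to be maximized."""
    return l_clip - c2 * l_vf + c1 * entropy


# No policy change: the ratio is 1 and the term equals the advantage.
print(clipped_surrogate(-1.0, -1.0, advantage=1.0))

# The ratio doubled with a positive advantage: the gain is capped at (1 + eps) * A.
print(clipped_surrogate(math.log(2.0), 0.0, advantage=1.0))

# The ratio doubled with a negative advantage: no cap, the full penalty applies.
print(clipped_surrogate(math.log(2.0), 0.0, advantage=-1.0))
```

The asymmetry in the last two cases is the point of the clipping: the policy gets no extra credit for moving far in a good direction, but is fully penalized for moving far in a bad one.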

# *=================================BREAK================================*

The actor has an entropy term, which is effectively a regularizer that promotes exploration. Entropy in our case is (very loosely) a measure of the uncertainty in the policy. As such, it tells us how confident the policy is when choosing an action in some state. When the entropy $H(\pi)$ is low, the policy is highly confident, whereas a high $H(\pi)$ implies low confidence.

We want our policy to diversify itself instead of picking the "best choice" each time. Thus, if we want to push the policy to "explore more", we need to keep some amount of uncertainty. We would effectively be "spreading" the probability distribution over multiple actions rather than letting a high confidence confine it to only a few actions. So, naturally, we'd want to maximize the entropy $H(\pi)$.

Another thing is that we can't always compute $H(\pi)$ exactly, so we estimate it by averaging the negative log-probability of the sampled actions over the batch. If $\mathcal{B}$ is our batch of state-action pairs, then we have:
$$H(\pi)\approx \frac{1}{|\mathcal{B}|}\sum_{(s,a)\in\mathcal{B}} \left(-\log \pi(a|s)\right)=-\frac{1}{|\mathcal{B}|}\sum_{(s,a)\in\mathcal{B}} \log \pi(a|s)$$
Here, the negative sign comes from the definition of entropy, $H(p)=-\sum_x p(x)\log p(x)$. Because the actions in the batch were sampled from the policy, the weighting by $p(x)$ is already accounted for by the sampling, so a plain average of $-\log \pi(a|s)$ is a Monte Carlo estimate of the entropy. This is convenient because the log-probabilities of the sampled actions are already computed during training.

Since $\log \pi(a|s)\leq 0$, our entropy estimate comes out as non-negative. Furthermore, we multiply it by a constant $c_1$, which we can tweak to control the effect of $H(\pi)$ on the overall loss $L$:
$$c_1 H(\pi(\cdot|s))$$
However, since the policy head gives us logits over every possible action, we can afford to compute the entropy in its entire form (which is what the code below does):
$$H(\pi)= \frac{1}{|\mathcal{B}|}\sum_{s\in\mathcal{B}} \left(-\sum_{a}\pi(a|s)\log \pi(a|s)\right)$$

In [None]:
def get_entropy(self, logits, mask):
        """
        Computes the entropy of the policy, which incentivizes exploration.

        Args:
            logits: torch.Tensor
                unnormalized log-probabilities from a model
            mask: torch.Tensor
                mask to be applied to the computed entropy
        """
        probs = F.softmax(logits, dim=-1)
        log_probs = F.log_softmax(logits, dim=-1)
        entropy = self.masked_mean(-torch.sum(probs * log_probs, dim=-1), mask)
        return entropy
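To see how the sample-based estimate relates to the exact entropy, here is a small self-contained check. The logits, the sample count, and the seed are arbitrary assumptions chosen for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# A toy policy over 4 actions in a single state (values are arbitrary).
logits = torch.tensor([2.0, 1.0, 0.0, -1.0])
probs = F.softmax(logits, dim=-1)
log_probs = F.log_softmax(logits, dim=-1)

# Exact entropy: -sum_a pi(a) * log pi(a)
exact_entropy = -(probs * log_probs).sum()

# Monte Carlo estimate: sample actions from the policy, then average -log pi(a).
samples = torch.multinomial(probs, num_samples=20000, replacement=True)
sampled_entropy = -log_probs[samples].mean()

# For comparison, a uniform policy over 4 actions has the largest possible entropy, log(4).
uniform = torch.full((4,), 0.25)
uniform_entropy = -(uniform * torch.log(uniform)).sum()

print(exact_entropy.item(), sampled_entropy.item(), uniform_entropy.item())
```

The two estimates should agree closely, and both sit below the uniform policy's entropy because this toy policy strongly prefers its first action.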

The PPO-Clip surrogate objective is defined as:
$$ L^{\text{CLIP}} = \mathbb{E}_t\left[\min\left(r_t(\theta)A_t,\ \text{clip}\left( r_t(\theta), 1-\epsilon, 1+\epsilon \right)A_t \right)\right], \qquad r_t(\theta)=\frac{\pi_{\theta}(a_t|s_t)}{\pi_{\theta_{\text{old}}}(a_t|s_t)} $$

Notice that the code below has no `min` operation and uses `torch.max()` instead? That seems counterintuitive. However, further note that we added a negative sign to `loss1` and `loss2` before taking their `max`.

This is exactly equivalent to the `min()` operation. $L^{\text{CLIP}}$ is an objective we want to maximize, but optimizers minimize, so we minimize $-L^{\text{CLIP}}$ and use the identity $-\min(a, b)=\max(-a,-b)$.

In [None]:
def compute_policy_loss(self, old_log_prob, log_prob, advantages, eos_mask, epsilon):
    """
    Computes the policy gradient loss function for PPO.

    Args:
        old_log_prob: torch.Tensor
            log probabilities from the old policy
        log_prob: torch.Tensor
            log probabilities from the current policy
        advantages: torch.Tensor
            Computed advantages via advantage estimation
        eos_mask: torch.Tensor
            mask applied when averaging the per-token loss
        epsilon: float
            clip range for the probability ratio
    """

    # from log domain -> real domain: ratio = pi_theta / pi_theta_old
    ratio = torch.exp(log_prob - old_log_prob)
    # negate both terms so that minimizing the loss maximizes the objective
    loss1 = -advantages * ratio
    loss2 = -advantages * torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon)

    # max of the negated terms == -min of the original terms
    loss = masked_mean(torch.max(loss1, loss2), eos_mask)
    return loss
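A tiny worked example shows what the clipping does and why taking the `max` of the negated terms matches the `min` in the objective. The ratios, advantages, and `epsilon = 0.2` are made-up numbers for illustration.

```python
import torch

epsilon = 0.2
# Four hypothetical tokens: ratio = pi_new / pi_old, each with its advantage.
ratio = torch.tensor([0.5, 1.0, 1.5, 1.5])
advantages = torch.tensor([1.0, 1.0, 1.0, -1.0])

unclipped = ratio * advantages
clipped = torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages

# The objective from the formula (to be maximized).
objective = torch.min(unclipped, clipped)

# The loss as written in code (to be minimized).
loss = torch.max(-unclipped, -clipped)

print(objective)  # per-token clipped objective
print(loss)       # per-token loss, equal to -objective
```

For the third token (ratio 1.5, positive advantage) the objective is capped at `1.2 * A`, so there is no extra reward for pushing the ratio further. For the fourth token (ratio 1.5, negative advantage) the unclipped term is kept, so the policy is still penalized in full.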

As detailed in the PPO-Clip algorithm, we fit the value function by regression on an MSE loss.

The original PPO paper only states this briefly, as a squared-error term, without motivating it. OpenAI's Spinning Up documentation entry on PPO is more explicit. For the reasoning behind it, we need to do some digging back to the old Sutton & Barto book (you can not escape RL!!!)

Recall that the on-policy value function is defined as the expected discounted return that the agent gets if it starts in some state $s$ and acts according to the policy $\pi$:
$$v_{\pi}(s)=\mathbb{E}_{\pi}[R_{t+1}+\gamma R_{t+2}+\cdots |S_t=s]$$
Now, define the discounted return as $G_t=\sum_{t'=t+1}^{T} \gamma^{t'-t-1}R_{t'}$. If we reindex with $k=t'-t-1$, then this becomes $G_t=\sum_{k=0}^{T-t-1} \gamma^{k}R_{t+k+1}$.
So, the value function becomes:
$$v_{\pi}(s)=\mathbb{E}_{\pi}[G_t|S_t=s]=\mathbb{E}_{\pi}\left[\sum_{k=0}^{T-t-1} \gamma^{k}R_{t+k+1}\,\Big|\,S_t=s\right]$$
which makes a lot more sense now! We see that $v_{\pi}(s)$ is just the expected discounted return!
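As a quick sanity check on the discounted return, here is a small sketch that computes $G_t$ for every step of one finite episode by walking backwards with the recursion $G_t = r + \gamma G_{t+1}$, where $r$ is the reward collected on the step that starts at $t$. The function name, the reward list, and `gamma=0.5` are illustrative assumptions.

```python
def discounted_returns(rewards, gamma):
    """Hypothetical helper: returns[t] is the discounted sum of rewards from step t onward."""
    returns = [0.0] * len(rewards)
    running = 0.0
    # Walk backwards so each return reuses the one after it.
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


rewards = [1.0, 0.0, 2.0]
returns = discounted_returns(rewards, gamma=0.5)
print(returns)  # [1.5, 1.0, 2.0]
```

For the first step this gives `1.0 + 0.5 * 0.0 + 0.25 * 2.0 = 1.5`, which matches the sum written out term by term.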


This objective is found in Chapter 9.2 of the book. There, the "natural objective function" for value prediction is to minimize the mean squared error between the true value $v_{\pi}(s)$ and our approximation $\hat{v}(s,\mathbf{w})$.
That is, our value error is defined as:
$$\overline{VE}(\mathbf{w})=\mathbb{E}_{s\sim\mu}\left[\left(v_{\pi}(s)-\hat{v}(s,\mathbf{w})\right)^2\right]=\sum_s \mu(s)\left[v_{\pi}(s)-\hat{v}(s,\mathbf{w})\right]^2$$
where the expectation $\mathbb{E}[\cdot]$ is taken over the state distribution $\mu(s)$ such that $\mu(s)\geq 0$ and $\sum_s \mu(s)=1$.

Dissecting this equation, we do not know the true value $v_{\pi}(s)$, so in practice we replace it with an empirical return $G_t$ computed from sampled trajectories, i.e. a Monte Carlo estimate. A single sampled return is not equal to $v_{\pi}(s)$, but it is an unbiased estimate of it, so regressing on many samples gets us close enough.

Next, $\hat{v}(s, \mathbf{w})$ is a function of the state $s$, parameterized by weights $\mathbf{w}$ (which we learn). This is effectively our learned value function, $V(s_t)$, which can be computed using a multi-layer neural network.

Thus, in our PPO-Clip algorithm we have the MSE loss:
$$L^{VF} = \sum_{t=1}^T \frac{1}{2}(G_t - V(s_t))^2$$
which tells us how far our value estimates, $V(s_t)$, deviate from the empirical returns, $G_t$.

For a further dissection of On-Policy Prediction with Approximation, I recommend reading Chapter 9 of Sutton & Barto.

In [None]:
def compute_value_loss(self, value_preds, returns, values, eos_mask, epsilon):
    """
    Fits the value function by regression on the MSE loss.

    Args:
        value_preds: torch.Tensor
            the predictions from our value model
        returns: torch.Tensor
            regression targets, shape: (bs, response_length)
        values: torch.Tensor
            reference value estimates that value_preds is clipped around
            (typically the ones recorded during rollout)
        eos_mask: torch.Tensor
            mask applied when averaging the per-token loss
        epsilon: float
            clip range for the value predictions
    """

    # keep the value predictions within an epsilon-distance of `values`,
    # mirroring the clipping applied to the policy ratio
    clip_value_preds = torch.clamp(value_preds, values - epsilon, values + epsilon)

    # two squared errors: unclipped and clipped predictions vs. the returns
    values_error = (value_preds - returns) ** 2
    clip_values_error = (clip_value_preds - returns) ** 2

    # take the larger (more pessimistic) error per token, then average over unmasked tokens;
    # this is essentially the inner sum for one trajectory t \in D_k where D_k is set of trajectories
    loss = 0.5 * masked_mean(torch.max(values_error, clip_values_error), eos_mask)
    return loss, values_error

Notice that in the above code cell, we called `masked_mean()`? What is this? Well, we're now getting into implementation-specific "tricks" associated with PPO. Stuff you might not find in the original literature.
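The notebook's own helpers aren't shown in this section, so here is a minimal sketch of what `masked_mean()` and `masked_whiten()` might look like. Treat both as assumptions, not necessarily the exact implementations used above. The idea is that padded positions (after the end of a response) should not contribute to averages or to normalization statistics.

```python
import torch


def masked_mean(values, mask, eps=1e-8):
    """Mean of `values` over the positions where mask == 1 (assumed helper)."""
    mask = mask.to(values.dtype)
    return (values * mask).sum() / (mask.sum() + eps)


def masked_whiten(values, mask, eps=1e-8):
    """Normalize `values` to zero mean and unit variance over unmasked positions (assumed helper)."""
    mask = mask.to(values.dtype)
    mean = masked_mean(values, mask)
    var = masked_mean((values - mean) ** 2, mask)
    # Zeroing masked positions afterwards is a design choice of this sketch.
    return (values - mean) * torch.rsqrt(var + eps) * mask


# One sequence with two real tokens and one padded position holding garbage.
values = torch.tensor([[1.0, 3.0, 100.0]])
mask = torch.tensor([[1, 1, 0]])

print(values.mean())                # plain mean is dominated by the padded value
print(masked_mean(values, mask))    # approximately 2.0, ignores the padded value
print(masked_whiten(values, mask))  # approximately [[-1., 1., 0.]]
```

Without the mask, a single padded position can swamp the statistics, which is why the losses above average with `masked_mean()` instead of a plain `.mean()`.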



#### Group Relative Policy Optimization

<img src="assets/grpo.png">

Now in this one, explain a lot: what it is (using our quantum 5-year-old analogies), how it works, its components, and how it was groundbreaking. Dive into the math side of it a lot. We won't be doing code here, since we're mainly focusing on PPO and DPO for this notebook.

#### Direct Preference Optimization

<img src="assets/dpo.png">

Now in this one, explain a lot: what it is (using our quantum 5-year-old analogies), how it works, its components, and how it was groundbreaking. Dive into the math side of it a lot, since we're focusing on DPO and PPO.

#### (Bonus) Test Time Preference Optimization

Now in this one, explain a lot: what it is (using our quantum 5-year-old analogies), how it works, its components, and how it was groundbreaking. Dive into the math side of it a lot. We won't be doing code here, since we're mainly focusing on PPO and DPO for this notebook.

## Citations