# Reinforcement Learning from Human Feedback

[Large Language Models](https://en.wikipedia.org/wiki/Large_language_model) (LLMs) have been at the
forefront of the present AI summer, making the generative AI field explode with innovation in recent
years.

LLMs are trained on vast amounts of data in a self-supervised fashion (i.e., next-token prediction),
but they require substantial fine-tuning to become "user-friendly" and to respond in a way
that humans consider appropriate (e.g., do not
[hallucinate](<https://en.wikipedia.org/wiki/Hallucination_(artificial_intelligence)>)).

Fine-tuning LLMs generally happens in two stages:

- **Supervised Fine-Tuning (SFT):** The pre-trained model is fine-tuned on a smaller, high-quality
  dataset of labeled prompt-response pairs that demonstrate desired behaviors for specific tasks.
  This helps the model become more useful and follow instructions.
- **RLHF:** Human evaluators rank or provide preferences for different model outputs. The feedback
  is used to train a "reward model," which guides the LLM to generate responses that are highly
  preferred by humans (e.g., addressing safety, nuance, creativity, etc.). Here is a
  [seminal paper](https://arxiv.org/abs/1909.08593) on the topic.

## RLHF Steps

The starting point is to collect a "preference dataset", where human raters pick the LLM output
they prefer over another (or over multiple others). The preference dataset has tuples like:

```
(prompt, output_1, output_2, human_preference)
```

Then we use the preference dataset to train a _Reward Model_. The reward model is generally another
LLM that takes a prompt and a completion, and outputs a _score_. Training the reward model is a hard
task in and of itself, and in this notebook we will "fake" it using a hand-crafted reward function.
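
For context, here is a minimal sketch of the kind of objective often used to train a reward model
from preference pairs: a Bradley-Terry style pairwise loss, which pushes the score of the preferred
completion above the score of the rejected one. This notebook does not train a reward model, so the
function name and the example scores below are illustrative assumptions, not part of the notebook.

```python
import math


def pairwise_preference_loss(score_chosen: float, score_rejected: float) -> float:
    """Bradley-Terry style loss: -log(sigmoid(score_chosen - score_rejected))."""
    margin = score_chosen - score_rejected
    # -log(sigmoid(m)) equals softplus(-m); this form is numerically stable.
    return max(-margin, 0.0) + math.log1p(math.exp(-abs(margin)))


# Made-up scores: the loss is low when the preferred completion scores higher.
print(pairwise_preference_loss(2.0, -1.0))  # preferred completion scored higher
print(pairwise_preference_loss(-1.0, 2.0))  # preferred completion scored lower
```

In a real setup the two scores would come from the reward model evaluated on
`(prompt, output_1)` and `(prompt, output_2)`, and the loss would be averaged over the dataset.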

At this point, we can introduce the reinforcement learning loop to fine-tune our original LLM! We
need a second dataset containing many prompts (no completions!), which are analogous to "environment
steps". We feed each prompt to the model, which produces a completion; the reward model scores the
completion, and the model weights are updated via PPO (or a similar RL algorithm).

![RLHF diagram](../assets/14_RLHF_loop.excalidraw.png) <br><small>Reinforcement learning loop for
LLM fine-tuning.</small>

In this notebook, we will fine-tune a small language model to prefer telling stories about the
animal kingdom!


In [10]:
# Import necessary classes from Hugging Face transformers library
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    AutoModelForSequenceClassification,
    PreTrainedModel,
)
from transformers.tokenization_utils import PreTrainedTokenizer

# Import necessary classes from PEFT (Parameter-Efficient Fine-Tuning) library for LoRA
from peft import get_peft_model, LoraConfig, TaskType

import torch
import torch.optim as optim
import torch.nn.functional as F

import random
import copy

from util.gymnastics import RLHF, DEVICE, init_random

We will use a [TinyStories](https://arxiv.org/abs/2305.07759) pretrained model that you can find on
[HuggingFace](https://huggingface.co/roneneldan/TinyStories-33M). Moreover, we limit the output of
the model to 60 tokens, to keep the generation relatively short and more efficient.


In [11]:
# Language model used in this notebook.
MODEL_NAME = "roneneldan/TinyStories-33M"

# Output length for training and sampling.
OUTPUT_LEN = 60

# Sample prompt, kept fixed for tests and for consistent comparison across epochs.
SAMPLE_PROMPT = "Once upon a time"

## Models

In order to fine-tune a language model, we need to understand the very basics of its
architecture. In particular, at the foundation of modern LLMs is the _transformer architecture_,
described in the popular paper ["Attention Is All You Need"](https://arxiv.org/abs/1706.03762).

We will use the [Hugging Face transformers API](https://huggingface.co/docs/transformers/index).
Reading the documentation before proceeding with the notebook is highly recommended, because the
notebook assumes familiarity with such APIs and concepts.

First things first, let's write a function to get a tokenizer for our model. A _tokenizer_ splits
text into tokens and maps each token to its numerical ID (and back).
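
To make the idea concrete, here is a toy sketch of the text-to-IDs round trip. The vocabulary and
the `toy_encode` / `toy_decode` helpers are hypothetical and only for illustration: real tokenizers
(including the one loaded below) use learned sub-word vocabularies rather than whole words.

```python
# Hypothetical toy vocabulary; real tokenizers use learned sub-word vocabularies.
TOY_VOCAB = {"<unk>": 0, "once": 1, "upon": 2, "a": 3, "time": 4}
TOY_IDS_TO_WORDS = {idx: word for word, idx in TOY_VOCAB.items()}


def toy_encode(text: str) -> list[int]:
    """Maps each whitespace-separated word to its ID (unknown words map to <unk>)."""
    return [TOY_VOCAB.get(word, TOY_VOCAB["<unk>"]) for word in text.lower().split()]


def toy_decode(ids: list[int]) -> str:
    """Maps IDs back to words and joins them with spaces."""
    return " ".join(TOY_IDS_TO_WORDS[i] for i in ids)


ids = toy_encode("Once upon a time")
print(ids)              # [1, 2, 3, 4]
print(toy_decode(ids))  # once upon a time
```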


In [12]:
def make_tokenizer(model_name=MODEL_NAME) -> PreTrainedTokenizer:
    # Load the tokenizer associated with the chosen model. The tokenizer converts text to numerical
    # IDs and back.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Some models don't have a default padding token; if so, fall back to the end-of-sequence
    # token. This is important for batching (even if we don't batch in this notebook) and for
    # consistent sequence handling.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer

We don't want to retrain the entire model for performance reasons. So we will use the HF
[PEFT](https://huggingface.co/docs/transformers/en/peft) (Parameter Efficient Fine Tuning) library
to create
[LoRA adapters](https://huggingface.co/docs/peft/conceptual_guides/adapter#low-rank-adaptation-lora).

Low-Rank Adaptation (LoRA) is one of the most popular PEFT methods. It trains a small set of
additional parameters alongside the model's frozen parameters to tweak its behavior.
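
As a rough sense of scale, the sketch below counts the parameters of one linear layer versus its
LoRA update. LoRA keeps the weight matrix frozen and learns a low-rank update `B @ A`; the layer
size used here (768 x 768) is an assumed, illustrative value, and `lora_param_counts` is a
hypothetical helper, not a PEFT API.

```python
def lora_param_counts(d_in: int, d_out: int, r: int) -> tuple[int, int]:
    """Returns (full_params, lora_params) for one linear layer of shape d_out x d_in.

    The LoRA update is B @ A, where A has shape (r, d_in) and B has shape (d_out, r).
    """
    full_params = d_in * d_out
    lora_params = r * d_in + d_out * r
    return full_params, lora_params


# Illustrative (assumed) sizes: a 768 x 768 projection with rank 16.
full, lora = lora_param_counts(d_in=768, d_out=768, r=16)
print(full, lora, f"{lora / full:.1%}")  # 589824 24576 4.2%
```

With these assumed sizes, the adapter trains roughly 4% of the parameters of the layer it wraps.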

Because we will use PPO for training, we will have both an actor and a critic model. Hence, we will
create two LoRA configurations.


In [13]:
def make_lora_configs(r=16, alpha=32, dropout=0.05) -> tuple[LoraConfig, LoraConfig]:
    """LoRA (Low-Rank Adaptation) allows fine-tuning only a small number of parameters, saving
    memory and compute.

    Args:
        r (int): Rank of the LoRA decomposition matrices. Higher rank means more trainable
                 parameters (more capacity but slower).
        alpha (int): LoRA scaling factor, often set to 2*r. Balances the influence of LoRA
                     adapters vs base weights.
        dropout (float): Dropout probability within LoRA layers to prevent overfitting.
    """
    # Specify which layers / modules within the base model to apply LoRA adapters to.
    # Targeting attention projection layers (query, key, value) is common and effective.
    target_layers = ["q_proj", "k_proj", "v_proj"]

    # LoRA configuration specifically for the actor model (Causal Language Model task)
    actor_lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=target_layers,
    )

    # LoRA configuration specifically for the critic model (Sequence Classification task)
    # Note: Use SEQ_CLS task type b/c we load the critic using AutoModelForSequenceClassification.
    # Also, we assume the base architecture shares "targetable" layer names with the CausalLM.
    critic_lora_config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=target_layers,
    )
    return actor_lora_config, critic_lora_config

Finally, we can create the models used for training!


In [14]:
def make_training_models(
    tokenizer: PreTrainedTokenizer,
    model_name=MODEL_NAME,
) -> tuple[PreTrainedModel, PreTrainedModel, PreTrainedModel]:
    """Loads the pre-trained language models to use for RLHF fine-tuning.

    Returns:
        tuple: A tuple containing:
            - PreTrainedModel: the frozen reference model, used for KL-divergence computation and
                               reference during training.
            - PreTrainedModel: the "actor" model, i.e., the language model that is fine-tuned.
            - PreTrainedModel: the "critic" model, i.e., the model for classification used to
                               determine the "values" of the actor outputs.
    """
    # Load the base pre-trained language model for the Actor (policy network).
    base_actor_model = AutoModelForCausalLM.from_pretrained(model_name).to(DEVICE)

    # Load the base pre-trained model, but configure it for sequence classification to act as the
    # critic (value network). `num_labels=1` makes it output a single continuous value (regression),
    # suitable for predicting expected reward (value). This is a simplification; a more standard
    # critic might have a custom value head on the CausalLM base.
    base_critic_model = AutoModelForSequenceClassification.from_pretrained(
        model_name, num_labels=1
    ).to(DEVICE)

    # Ensure critic model knows the correct padding token ID, important for attention mechanisms.
    base_critic_model.config.pad_token_id = tokenizer.pad_token_id

    # Create a deep copy of the original *base* actor model as the reference model, before applying
    # LoRA. This model serves as a fixed reference point for the KL divergence calculation.
    ref_model = copy.deepcopy(base_actor_model).to(DEVICE)
    # Set the reference model to evaluation mode (disables dropout, etc.)
    ref_model.eval()
    # Freeze all parameters of the reference model, so they don't get updated during training.
    for param in ref_model.parameters():
        param.requires_grad = False

    # Wrap the base actor and critic models with LoRA adapters using the defined configurations.
    # `get_peft_model` modifies the model to insert LoRA layers and freezes the original weights.
    actor_lora_config, critic_lora_config = make_lora_configs()
    actor_model = get_peft_model(base_actor_model, actor_lora_config)
    critic_model = get_peft_model(base_critic_model, critic_lora_config)

    return ref_model, actor_model, critic_model

Let's now put some pieces together and create a function that invokes a model and generates a
completion to an input prompt.


In [None]:
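from transformers import BatchEncoding


@torch.no_grad()
def generate(
    tokenizer: PreTrainedTokenizer,
    llm_model: PreTrainedModel,
    prompt: str,
    max_length: int = OUTPUT_LEN,
) -> tuple[BatchEncoding, torch.Tensor, str]:
    """Generates an output from the LLM model.

    Returns:
        tuple: A tuple containing:
            - BatchEncoding: the tokenized prompt (`input_ids` and `attention_mask`).
            - Tensor: the response token IDs (prompt excluded), with shape [1, response_len].
            - str: The decoded response text.
    """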
    # Tokenize the current prompt text into numerical IDs for the model, truncating long prompts to
    # avoid exceeding our specified max_length.
    tokenized_prompt = tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_length // 2
    ).to(DEVICE)
    # Get token IDs and attention mask (which indicates which tokens are real vs padding).
    input_ids = tokenized_prompt.input_ids
    attention_mask = tokenized_prompt.attention_mask
    prompt_len = input_ids.shape[1]  # Store the length of the prompt section

    # Generate a response with the given model. Gradient tracking is already disabled by the
    # `@torch.no_grad()` decorator on this function (saving memory during inference).
    outputs = llm_model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,  # Maximum length of prompt + response
        do_sample=True,  # Enable sampling (essential for exploration in PPO)
        pad_token_id=tokenizer.pad_token_id,  # Specify padding token ID
        eos_token_id=tokenizer.eos_token_id,  # Specify end-of-sequence token ID
        top_k=50,  # Sampling: consider only top 50 probable tokens
        top_p=0.95,  # Sampling: consider tokens summing up to 95% prob (nucleus sampling)
        # Sampling temperature: controls randomness (lower=more focused, higher=more random)
        temperature=0.7,
    )
    # Extract only the generated token IDs (excluding the prompt).
    response_ids = outputs[0][prompt_len:]
    # Decode the response token IDs back into human-readable text.
    response_text = tokenizer.decode(response_ids, skip_special_tokens=True)
    # Return the tokenized prompt, the response IDs with a leading batch dimension (of size 1 in
    # this simple notebook), and the decoded text.
    return tokenized_prompt, response_ids.unsqueeze(0), response_text

Let's create the tokenizer and generate an example.


In [7]:
tokenizer = make_tokenizer()
ref_model, actor_model, critic_model = make_training_models(tokenizer)

Some weights of GPTNeoForSequenceClassification were not initialized from the model checkpoint at roneneldan/TinyStories-33M and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
_, _, output_text = generate(tokenizer, ref_model, prompt=SAMPLE_PROMPT)
print(f"{SAMPLE_PROMPT}{output_text}")

Once upon a time, there was a little girl named Lily. She loved to play with her toys and her friends. One day, Lily's mommy gave her a new toy. It was a big, shiny doll that could talk. Lily was so happy and she hugged her new toy. 


## Reward Model

As previously mentioned, we are going to use a hard-coded function as our "pretend reward model"
for simplicity and to avoid training a full-fledged reward model. While this works in our limited
educational RLHF example, it is certainly not effective for proper fine-tuning and alignment, with
[reward hacking](https://en.wikipedia.org/wiki/Reward_hacking) being one of the typical reasons.

But for now, let's write a reward function that encourages telling stories about animals!


In [15]:
def animals_stories_reward(response: str) -> torch.Tensor:
    """Defines the reward signal used to guide the RL training.

    This simple function rewards mentioning animal words and penalizes mentioning human words.
    Note: Simple keyword-based rewards are easy for education but prone to "reward hacking".
    """
    # Preprocess the generated response.
    text_words = response.lower().split()
    response_length = len(text_words)

    # Count occurrences of animal and human words.
    animal_mentions = [word for word in text_words if word in RLHF.ANIMAL_WORDS]
    animal_score = len(animal_mentions)
    human_score = sum(1 for word in text_words if word in RLHF.HUMAN_WORDS)

    # Penalize very short responses.
    length_penalty = 0 if response_length > 4 else -1.0

    # Penalize excessive repetition of animal words.
    animal_threshold = max(5, int(0.15 * response_length))
    excessive_animal_penalty = 0
    if animal_score > animal_threshold:
        excessive_animal_penalty = -0.5 * (animal_score - animal_threshold)

    # Encourage a diversity of animal words.
    unique_animal_bonus = 0
    if animal_score > 0:
        unique_animal_bonus = 0.5 * len(set(animal_mentions))

    # Calculate final reward.
    reward = (
        (1.2 * animal_score)
        - (0.8 * human_score)
        + length_penalty
        + excessive_animal_penalty
        + unique_animal_bonus
    )
    return torch.tensor(reward, dtype=torch.float32, device=DEVICE)

In [10]:
def test_animals_stories_reward():
    response1 = "The quick brown fox and a lion met a boy"
    expected_reward1 = 2.6
    actual_reward1 = animals_stories_reward(response1)
    assert torch.isclose(
        actual_reward1, torch.tensor(expected_reward1)
    ), f"FAILED Case 1: Expected {expected_reward1}, but got {actual_reward1.item()}"
    print("✅ Test Case 1 (Standard) Passed!")

    response2 = "a cat and dog"
    expected_reward2 = 2.4
    actual_reward2 = animals_stories_reward(response2)
    assert torch.isclose(
        actual_reward2, torch.tensor(expected_reward2)
    ), f"FAILED Case 2: Expected {expected_reward2}, but got {actual_reward2.item()}"
    print("✅ Test Case 2 (Short Response) Passed!")

    response3 = "cat dog lion tiger bear fox cat dog lion tiger bear fox"  # 12 words
    expected_reward3 = 13.9
    actual_reward3 = animals_stories_reward(response3)
    assert torch.isclose(
        actual_reward3, torch.tensor(expected_reward3)
    ), f"FAILED Case 3: Expected {expected_reward3}, but got {actual_reward3.item()}"
    print("✅ Test Case 3 (Excessive Animals) Passed!")

    response4 = "This is a simple sentence about nothing special"
    expected_reward4 = 0.0
    actual_reward4 = animals_stories_reward(response4)
    assert torch.isclose(
        actual_reward4, torch.tensor(expected_reward4)
    ), f"FAILED Case 4: Expected {expected_reward4}, but got {actual_reward4.item()}"
    print("✅ Test Case 4 (No Keywords) Passed!")


test_animals_stories_reward()

✅ Test Case 1 (Standard) Passed!
✅ Test Case 2 (Short Response) Passed!
✅ Test Case 3 (Excessive Animals) Passed!
✅ Test Case 4 (No Keywords) Passed!


## RLHF Fine Tuning w/ PPO

Let's start by writing a helper function that computes the log probabilities that the model assigns
to the tokens of the actual response. We will use this repeatedly in the training loop.


In [None]:
def actual_log_probs(
    logits_all: torch.Tensor, response_ids: torch.Tensor, prompt_len: int
) -> torch.Tensor:
    """Compute the log probabilities that the model assigns to the *actual* response tokens.

    Args:
        logits_all (torch.Tensor): The model output logits for prompt + response. The shape should
                                   be [batch, tokens, vocab_size], where batch is always 1 in this
                                   notebook, tokens counts prompt + response tokens, and vocab_size
                                   is ~50k (i.e., the size of the tokenizer vocabulary).
        response_ids (torch.Tensor): The actual response token IDs. Shape should be
                                     [batch, response_tokens], in this notebook [1, response_len].
        prompt_len (int): Number of tokens in the prompt.

    Returns:
        torch.Tensor: Per-token log probabilities, with shape [batch, response_tokens].
    """
    log_probs_all = F.log_softmax(logits_all, dim=-1)
    # Define the start and end indices corresponding to the *generated response* tokens within the
    # full sequence logits. We need logits from position `prompt_len - 1` to predict tokens at
    # position `prompt_len`, up to the end.
    gen_start = prompt_len - 1
    # Exclusive end of the slice: the last logits used are those at the second-to-last position,
    # which predict the last token.
    gen_end = logits_all.shape[1] - 1
    response_log_probs = log_probs_all[:, gen_start:gen_end, :]
    # The generated `response_ids` tensor needs the correct shape for `gather`.
    # Shape: (1, gen_len) -> (1, gen_len, 1)
    gather_index = response_ids.unsqueeze(-1)
    # Use `gather` to select the log probabilities corresponding to the actual tokens generated by
    # the actor, then remove the last dim.
    # (1, gen_len, vocab) gathered by (1, gen_len, 1) -> selects [i, j, index[i, j, 0]], i.e., the
    # log probability of each actual token.
    token_log_probs = torch.gather(response_log_probs, 2, gather_index).squeeze(-1)
    return token_log_probs

Here is an AI-generated unit-test :)


In [None]:
def test_actual_log_probs():
    batch_size = 1
    prompt_len = 3
    response_len = 2
    total_len = prompt_len + response_len

    logits_all = torch.tensor(
        [
            [  # Batch 0
                [0, 0, 0, 0, 0],  # Logits for token 1 (in prompt)
                [0, 0, 0, 0, 0],  # Logits for token 2 (in prompt)
                [1, 2, 3, 4, 5],  # Logits used to predict the 1st response token
                [5, 4, 3, 2, 1],  # Logits used to predict the 2nd response token
                [0, 0, 0, 0, 0],  # Logits for token after response (will be ignored)
            ]
        ],
        dtype=torch.float32,
    )

    response_ids = torch.tensor([[2, 0]], dtype=torch.int64)
    logits_for_response = logits_all[:, prompt_len - 1 : total_len - 1, :]
    expected_log_probs_slice = F.log_softmax(logits_for_response, dim=-1)
    log_prob_token1 = expected_log_probs_slice[0, 0, 2]
    log_prob_token2 = expected_log_probs_slice[0, 1, 0]

    expected_output = torch.tensor([[log_prob_token1, log_prob_token2]])

    actual_output = actual_log_probs(
        logits_all=logits_all,
        response_ids=response_ids,
        prompt_len=prompt_len,
    )

    assert actual_output.shape == (
        batch_size,
        response_len,
    ), f"FAILED: Expected shape {(batch_size, response_len)}, but got {actual_output.shape}"

    assert torch.allclose(
        actual_output, expected_output
    ), f"FAILED: Expected values {expected_output}, but got {actual_output}"

    print("✅ Test passed!")


test_actual_log_probs()

✅ Test passed!


And finally... let's write our RLHF training loop using PPO! You should already be familiar with
PPO's general training structure and loss function; feel free to take a peek at the PPO notebook
solution to refresh your memory!
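
As a quick refresher, here is a minimal single-sample sketch of PPO's clipped surrogate loss in
plain Python. The helper name `ppo_clipped_loss` and the example numbers are illustrative
assumptions; the training loop in this notebook computes the same expression with tensors.

```python
def ppo_clipped_loss(ratio: float, advantage: float, epsilon: float = 0.2) -> float:
    """Negative clipped surrogate objective for a single sample."""
    clipped_ratio = min(max(ratio, 1.0 - epsilon), 1.0 + epsilon)
    return -min(ratio * advantage, clipped_ratio * advantage)


# Positive advantage, ratio above 1 + epsilon: the gain is capped by the clipped ratio.
print(ppo_clipped_loss(ratio=1.5, advantage=1.0))
# Positive advantage, ratio below 1 - epsilon: the unclipped (smaller) term is used.
print(ppo_clipped_loss(ratio=0.5, advantage=1.0))
# Negative advantage, ratio below 1 - epsilon: the clipped term is used.
print(ppo_clipped_loss(ratio=0.5, advantage=-1.0))
```

Taking the minimum of the two terms makes the objective pessimistic: the policy gets no extra
credit for moving the probability ratio outside `[1 - epsilon, 1 + epsilon]`.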

One important detail about training is that the reward (which comes from the reward model) is
counter-balanced by a penalty that prevents the new model distribution from drifting too far from
the original model distribution. Intuitively, we still want our language model to produce text in a
similar way to how it currently does (and not find ways to maximize reward that would instead
produce unintelligible text).

In order to do that, we measure the difference between the "reference" model distribution and the
model under training via
[KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence).

$$
D_{\text{KL}}(P \| Q) = \sum_i P(i) \log \frac{P(i)}{Q(i)}
$$

Since this is an expectation under $P$, it can be
[efficiently approximated](http://joschu.net/blog/kl-approx.html) by averaging over samples drawn
from $P$:

$$
D_{\text{KL}}(P \| Q) = E_{X \sim P}\left[ \log P(X) - \log Q(X) \right] \approx \frac{1}{N} \sum_{n=1}^{N} \left[ \log P(x_n) - \log Q(x_n) \right], \quad x_n \sim P
$$

Understanding the mathematics behind these formulas is definitely beneficial, but it is not
necessary to appreciate the overall RLHF training loop.
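
The following sketch compares the exact KL divergence with its sample-based estimate on a tiny
discrete example. The two distributions are made up for illustration, and `exact_kl` /
`sampled_kl` are hypothetical helper names that are not used elsewhere in the notebook.

```python
import math
import random


def exact_kl(p: list[float], q: list[float]) -> float:
    """Exact KL(P || Q) for two discrete distributions over the same outcomes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def sampled_kl(p: list[float], q: list[float], n_samples: int, seed: int = 0) -> float:
    """Monte Carlo estimate: average of log P(x) - log Q(x) over samples x ~ P."""
    rng = random.Random(seed)
    samples = rng.choices(range(len(p)), weights=p, k=n_samples)
    return sum(math.log(p[x]) - math.log(q[x]) for x in samples) / n_samples


# Illustrative (made-up) distributions over three tokens.
p = [0.7, 0.2, 0.1]
q = [0.5, 0.3, 0.2]
print(f"exact:   {exact_kl(p, q):.4f}")
print(f"sampled: {sampled_kl(p, q, n_samples=10_000):.4f}")
```

Note that each individual term `log P(x) - log Q(x)` is negative whenever the sampled outcome is
less likely under $P$ than under $Q$, so an estimate built from few samples can be negative even
though the true KL divergence never is.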


In [None]:
def rlhf_fine_tuning(
    actor_model: PreTrainedModel,
    critic_model: PreTrainedModel,
    num_epochs=50,
    n_epoch_updates=2,
    max_length=OUTPUT_LEN,
    actor_lr=5e-5,
    critic_lr=1e-4,
    kl_beta=0.1,
    epsilon=0.2,
    grad_clip_norm=1.0,
    seed=42,
):
    """Fine-tune an LLM using Reinforcement Learning from Human Feedback.

    Note: this function also relies on the notebook-level `tokenizer` and `ref_model` variables.

    Args:
        actor_model (PreTrainedModel): The LoRA-wrapped language model being fine-tuned (policy).
        critic_model (PreTrainedModel): The LoRA-wrapped value model.
        num_epochs (int): Number of times the training loop iterates over all prompts.
        n_epoch_updates (int): Number of PPO updates performed on each generated response.
        max_length (int): Maximum sequence length for generated responses (including prompt).
                          Limits computational cost and output length.
        actor_lr (float): Learning rate for the actor (policy) model optimizer.
        critic_lr (float): Learning rate for the critic (value) model optimizer.
        kl_beta (float): Strength of the KL divergence penalty.
        epsilon (float): PPO clipping parameter; limits how much the policy changes in one update.
        grad_clip_norm (float): Maximum gradient norm, to avoid exploding gradients.
        seed (int): Seed for deterministic training.
    """
    init_random(seed=seed)

    # Optimizers on the trainable parameters, i.e., LoRA weights.
    actor_optimizer = optim.Adam(actor_model.parameters(), lr=actor_lr)
    critic_optimizer = optim.Adam(critic_model.parameters(), lr=critic_lr)

    for epoch in range(num_epochs):
        # Dictionary of metrics to track during training.
        metrics = {
            "raw_scores": [],
            "rewards": [],
            "kl_divs": [],
            "ppo_losses": [],
            "critic_losses": [],
        }

        # Set models to training mode (enables dropout, LoRA updates, etc.)
        actor_model.train()
        critic_model.train()

        # Shuffle prompts randomly each epoch for better generalization.
        epoch_prompts = RLHF.PROMPTS.copy()
        random.shuffle(epoch_prompts)

        for prompt in epoch_prompts:
            # Generate Response (no grad).
            tokenized_prompt, response_ids, response_text = generate(
                tokenizer, actor_model, prompt, max_length
            )
            # Store the length of the response section.
            response_len = response_ids.shape[1]
            # Store the length of the prompt section.
            prompt_len = tokenized_prompt.input_ids.shape[1]

            # Concatenate prompt and response IDs to form the full sequence for model input.
            full_ids = torch.cat([tokenized_prompt.input_ids, response_ids], dim=1)
            # Create the corresponding attention mask for the full sequence (prompt mask + ones for
            # response b/c response has no padding).
            full_attention_mask = torch.cat(
                [
                    tokenized_prompt.attention_mask,
                    torch.ones(1, response_len, device=DEVICE, dtype=torch.long),
                ],
                dim=1,
            )

            # Calculate Metrics (Reward, KL, Value). Perform these calculations without tracking
            # gradients as they are inputs to the loss functions.
            with torch.no_grad():
                # Get logits (raw model outputs before activation) from the actor (PEFT model).
                # The actor model output has two fields: `logits` and `past_key_values` (for caching
                # attention computations).
                actor_logits = actor_model(full_ids, attention_mask=full_attention_mask).logits
                # Calculate the log probabilities using the response_ids.
                actor_log_probs = actual_log_probs(actor_logits, response_ids, prompt_len)

                # Get logits from the frozen reference model (original base model).
                ref_logits = ref_model(full_ids, attention_mask=full_attention_mask).logits
                # Calculate the log probabilities using the response_ids.
                ref_log_probs = actual_log_probs(ref_logits, response_ids, prompt_len)

                # Calculate the log probability of the generated sequence under the old policy
                # (current actor state). Sum the log probabilities across the sequence dimension
                # (dim=1). Squeeze to get a scalar tensor.
                log_probs_old = actor_log_probs.sum(dim=1).squeeze()

                # Get the value prediction from the critic (PEFT model) using the full sequence.
                critic_outputs = critic_model(full_ids, attention_mask=full_attention_mask)
                # The critic outputs logits; squeeze to get the scalar value prediction V(s).
                critic_value = critic_outputs.logits.squeeze()

                # Calculate the KL divergence between the actor and reference model distributions
                # for the generated sequence. KL(P || Q) approx E[log P - log Q] (sample)
                # Sum difference over sequence, get scalar.
                kl_div = (actor_log_probs - ref_log_probs).sum(dim=1).squeeze()

                # Get the base reward based on the generated text using the custom reward function.
                base_reward = animals_stories_reward(response_text)
                # Calculate the final reward used for PPO update: base reward minus the KL penalty.
                total_reward = base_reward - kl_beta * kl_div

            # Calculate the advantage A(s,a) = R - V(s) using the final reward and the value
            # estimate (scalar). Detach it (to be used as a constant target below).
            advantage = (total_reward - critic_value).detach()

            for _ in range(n_epoch_updates):
                # Perform a forward pass through the current actor model again with gradients
                # enabled to get the log probabilities of the generated sequence under the
                # current policy state.
                actor_logits_new = actor_model(full_ids, attention_mask=full_attention_mask).logits
                actor_log_probs_new = actual_log_probs(actor_logits_new, response_ids, prompt_len)
                log_probs_new = actor_log_probs_new.sum(dim=1).squeeze()

                # Calculate the probability ratio: ratio = exp(log_prob_new - log_prob_old).
                # Use the detached log_probs_old as it represents the fixed policy for this step.
                ratio = torch.exp(log_probs_new - log_probs_old.detach())

                # Clip the ratio to the range [1 - epsilon, 1 + epsilon].
                clipped_ratio = torch.clamp(ratio, 1 - epsilon, 1 + epsilon)

                # The PPO loss is the negative of the surrogate objective: the minimum between
                # ratio * advantage and clipped_ratio * advantage.
                ppo_loss = -torch.min(ratio * advantage, clipped_ratio * advantage)

                # Update Actor (Policy Network)
                actor_optimizer.zero_grad()
                ppo_loss.backward()
                torch.nn.utils.clip_grad_norm_(actor_model.parameters(), max_norm=grad_clip_norm)
                actor_optimizer.step()

                # Update Critic (Value Network)
                # Calculate the target value for the critic. Return = Advantage + V(s), which should
                # approximate the total_reward.
                returns = advantage + critic_value.detach()

                # Perform a forward pass through the current critic with gradients enabled to get
                # value prediction for the current state.
                critic_outputs_for_loss = critic_model(full_ids, attention_mask=full_attention_mask)
                critic_value_for_loss = critic_outputs_for_loss.logits.squeeze()  # Scalar tensor

                # Calculate the critic loss, MSE between predicted value and target return.
                critic_loss = F.mse_loss(critic_value_for_loss, returns)

                critic_optimizer.zero_grad()
                critic_loss.backward()
                torch.nn.utils.clip_grad_norm_(critic_model.parameters(), max_norm=grad_clip_norm)
                critic_optimizer.step()

            # Store metrics for this step.
            metrics["raw_scores"].append(base_reward.item())
            metrics["rewards"].append(total_reward.item())
            metrics["kl_divs"].append(kl_div.item())
            metrics["ppo_losses"].append(ppo_loss.item())
            metrics["critic_losses"].append(critic_loss.item())

        # Epoch average metrics.
        avg_raw_score = sum(metrics["raw_scores"]) / len(metrics["raw_scores"])
        avg_final_reward = sum(metrics["rewards"]) / len(metrics["rewards"])
        avg_kl_div = sum(metrics["kl_divs"]) / len(metrics["kl_divs"])
        avg_ppo_loss = sum(metrics["ppo_losses"]) / len(metrics["ppo_losses"])
        avg_critic_loss = sum(metrics["critic_losses"]) / len(metrics["critic_losses"])
        print(f"[Epoch {epoch+1}/{num_epochs}] ", end="")
        print(
            f"Reward: {avg_raw_score:.2f} | Reward (w/ KL): {avg_final_reward:.2f} | "
            + f"KL Div: {avg_kl_div:.4f} | PPO Loss: {avg_ppo_loss:.4f} | "
            + f"Critic Loss: {avg_critic_loss:.4f}"
        )

        # Sample response for qualitative progress assessment.
        actor_model.eval()  # eval mode (e.g., disable dropouts)
        sample_prompt = SAMPLE_PROMPT
        _, _, sample_response = generate(tokenizer, actor_model, sample_prompt, max_length)
        # Flatten newlines outside the f-string: backslashes inside f-string expressions are a
        # syntax error before Python 3.12.
        sample_text = sample_response.replace("\n", " ")
        print(f"[SAMPLE] {sample_prompt}{sample_text}\n")

    print("Training finished!")

## Run Training!

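During training we should monitor the KL divergence. The true KL divergence is non-negative, but
the sample-based estimate used here is noisy and can dip below zero (as in the log below); what
matters is that it does not grow too large. The losses should also trend downwards, but they can be
very noisy.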


In [None]:
# Takes ~15 minutes on a modern CPU/GPU (RTX 4090).
rlhf_fine_tuning(actor_model, critic_model)

[Epoch 1/50] Reward: -0.59 | Reward (w/ KL): -0.59 | KL Div: -0.0093 | PPO Loss: 0.2125 | Critic Loss: 10.6719
[SAMPLE] Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, bright rainbow in the sky. It was so pretty!  Lily ran to her friend, a little boy named Max, and said, "

[Epoch 2/50] Reward: -1.06 | Reward (w/ KL): -0.98 | KL Div: -0.7448 | PPO Loss: -0.1894 | Critic Loss: 12.3777
[SAMPLE] Once upon a time, there was a little girl named Lily. She loved to play outside in her backyard. One day, she saw a big swing hanging from a tree. She ran over to it and started to swing back and forth. It was so much fun!  Suddenly, her mom

[Epoch 3/50] Reward: -0.47 | Reward (w/ KL): -0.35 | KL Div: -1.2162 | PPO Loss: -0.2504 | Critic Loss: 16.1696
[SAMPLE] Once upon a time there was a little boy named Jack. He was 3 years old and liked to explore. One day he went on a walk with his mom and dad.  Jack was so excited to ex

Let's now see how our model generates stories: hopefully, we will get a strong preference towards
animal stories!


In [None]:
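_, _, sample_response = generate(tokenizer, actor_model, SAMPLE_PROMPT)
# Flatten newlines outside the f-string (backslashes in f-string expressions need Python 3.12+).
sample_text = sample_response.replace("\n", " ")
print(f"[SAMPLE] {SAMPLE_PROMPT}{sample_text}")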

[SAMPLE] Once upon a time, there was a little bear named Tim. Tim was very hungry. He went to the woods to find some food. Tim saw a rabbit in the woods. The rabbit looked hungry too. Tim said, "Hi, rabbit! Do you want to share some food with me?"


## Pseudo-statistic

Finally, let's compute an approximate statistic about how often the fine-tuned model uses words
related to animals compared to the reference model. We should see some bias :)


In [None]:
ref_model_count, ref_model_total, fine_tuned_count, fine_tuned_total = 0, 0, 0, 0

for i in range(100):
    print(f"\rSample: {i} ...", end="")
    _, _, ref_model_text = generate(tokenizer, ref_model, SAMPLE_PROMPT)
    _, _, fine_tuned_text = generate(tokenizer, actor_model, SAMPLE_PROMPT)
    ref_model_text_words = ref_model_text.split()
    fine_tuned_text_words = fine_tuned_text.split()
    ref_model_total += len(ref_model_text_words)
    fine_tuned_total += len(fine_tuned_text_words)
    for word in RLHF.ANIMAL_WORDS:
        ref_model_count += ref_model_text_words.count(word)
        fine_tuned_count += fine_tuned_text_words.count(word)

print(f"\rFine-tuned: {(fine_tuned_count / fine_tuned_total) * 100:.1f}%")
print(f"Reference model: {(ref_model_count / ref_model_total) * 100:.1f}%")

Fine-tuned: 6.5%
Reference model: 0.5%
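
One caveat: the loop above counts exact matches on whitespace-split tokens, so, assuming the word
list is lowercase and punctuation-free, tokens such as `rabbit.` or `Bear` are not counted. A
possible variant that normalizes case and punctuation is sketched below. The word set and the
`animal_word_rate` helper are hypothetical, for illustration only; the notebook itself uses
`RLHF.ANIMAL_WORDS`.

```python
import string

# Hypothetical word list for illustration; the notebook uses RLHF.ANIMAL_WORDS.
EXAMPLE_ANIMAL_WORDS = {"bear", "rabbit", "fox", "cat", "dog"}


def animal_word_rate(text: str, animal_words: set[str]) -> float:
    """Fraction of words in `text` that are animal words, ignoring case and punctuation."""
    words = [w.strip(string.punctuation).lower() for w in text.split()]
    words = [w for w in words if w]  # drop tokens that were pure punctuation
    if not words:
        return 0.0
    return sum(w in animal_words for w in words) / len(words)


sample = 'Tim saw a rabbit in the woods. "Hi, rabbit!" said the Bear.'
print(f"{animal_word_rate(sample, EXAMPLE_ANIMAL_WORDS):.1%}")  # 25.0%
```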


## Other / Most Recent Techniques

More recent techniques that are good alternatives to PPO-based RLHF are
[DPO](https://arxiv.org/abs/2305.18290) and [GRPO](https://arxiv.org/abs/2402.03300). In particular,
when you have verifiable tasks (checked either via a script or via another LLM!), GRPO works great
and employs reward functions similar to this notebook's example. DeepLearning.AI has a great
[introductory course on GRPO](https://learn.deeplearning.ai/courses/reinforcement-fine-tuning-llms-grpo),
totally recommended!
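
To give a flavor of GRPO, the sketch below shows its core idea as described in the paper: instead
of learning a critic, it samples a group of completions for the same prompt and uses each
completion's reward, normalized by the group's mean and standard deviation, as its advantage. The
helper name, the `eps` constant, the use of the population standard deviation, and the rewards are
illustrative assumptions, not taken from any specific library.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-8) -> list[float]:
    """Normalizes rewards within a group of completions sampled for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    # `eps` avoids division by zero when all rewards in the group are equal.
    return [(r - mean) / (std + eps) for r in rewards]


# Made-up rewards for four completions of one prompt.
print(group_relative_advantages([2.6, 0.0, -0.8, 1.4]))
```

Completions that score above the group average get a positive advantage and are reinforced; those
below the average get a negative one.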
