In [10]:
import sys
import time
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Callable, Union

import os
import einops
import numpy as np
import torch as t
import torch.nn as nn
import wandb
# !pip install eindex
import eindex
from jaxtyping import Float, Int
from rich import print as rprint
from rich.table import Table
from torch import Tensor
# !pip install transformer_lens
from transformer_lens import HookedTransformer, utils
from transformer_lens.hook_points import HookPoint

# Make sure exercises are in the path
chapter = r"chapter2_rl"
exercises_dir = Path(f"{os.getcwd().split(chapter)[0]}/{chapter}/exercises").resolve()
section_dir = exercises_dir / "part4_rlhf"
if str(exercises_dir) not in sys.path: sys.path.append(str(exercises_dir))

import part4_rlhf.tests as tests
import part4_rlhf.solutions as solutions

device = t.device('mps' if t.backends.mps.is_available() else 'cuda' if t.cuda.is_available() else 'cpu')

MAIN = __name__ == "__main__"

LOW_GPU_MEM = False
BASE_MODEL = "gpt2-small" if LOW_GPU_MEM else "gpt2-medium"

  return torch._C._cuda_getDeviceCount() > 0


In [11]:
"""
First define self.base_model and self.value_head in your init step (reminder that you should use HookedTransformer.from_pretrained to load in a pretrained model). Then rewrite the forward method so that it outputs both the logits from a forward pass and the output of the value head.

The easiest and most direct way to get the output of the value head would be to add a hook to the residual stream before the unembedding matrix, which computes the output of the value head and stores it externally (or as a class attribute). You can review the material from section 1.2 if you don't remember how to use hooks, and you can refer to the diagram on the reference page for how to get the correct hook name.

Why do we need to add the hook after the layernorm? The answer is that the residual stream can often grow in magnitude over time. Our rewards will be normalized (see later exercise), and so we want to make sure the outputs of our value head (which are estimates of the reward) also start off normalized.
"""
class TransformerWithValueHead(nn.Module):
    """
    Defines a GPT model with a value head (the latter taking the last hidden state as input,
    post-layernorm).

    The value head is a simple MLP with one hidden layer, and scalar output:

        Linear(d_model -> 4*d_model)
        ReLU
        Linear(4*d_model -> 1)

    All linear layers have biases.
    """
    base_model: HookedTransformer
    value_head: nn.Sequential

    def __init__(self, base_model: str = BASE_MODEL):
        super().__init__()
        self.base_model = HookedTransformer.from_pretrained(base_model)
        d_model = self.base_model.cfg.d_model
        self.value_head = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.ReLU(),
            nn.Linear(4 * d_model, 1)
        )
        self.value_head_output = None

    def forward(self, input_ids: Int[Tensor, "batch seq"]) -> tuple[
        Float[Tensor, "batch seq d_vocab"],
        Int[Tensor, "batch seq"]
    ]:
        def calc_and_store_value_head_output(resid_post, hook):
            self.value_head_output = self.value_head(resid_post).squeeze(-1)
        logits = self.base_model.run_with_hooks(
            input_ids,
            return_type="logits",
            fwd_hooks=[(utils.get_act_name("normalized"), calc_and_store_value_head_output)]
        )
        return logits, self.value_head_output



# Define a reference model (we'll use this during RLHF)
model = TransformerWithValueHead().to(device)

# Test your value head's architecture
assert isinstance(model.base_model, HookedTransformer), "Your model should have a HookedTransformer as its `base_model` attribute."
assert isinstance(model.value_head, nn.Sequential), "Your model should have a `value_head` attribute that is a `nn.Sequential`."
d_model = model.base_model.cfg.d_model
assert len(model.value_head) == 3, "Your value head should be a `nn.Sequential` with 3 layers."
assert sum(p.numel() for p in model.value_head.parameters()) == (d_model+1)*4*d_model + (4*d_model+1), "Your value head should have the correct number of parameters."

# Test your class's forward pass
input_ids = t.randint(0, 1000, (1, 10)).to(device)
logits, values = model(input_ids)
assert logits.shape == (*input_ids.shape, model.base_model.cfg.d_vocab), "Your model's logits should have shape (batch, seq, d_vocab)."
assert values.shape == input_ids.shape, "Your model's value head should give you an output for every token in your input. Did you forget to squeeze the out_features=1 dim?"

print("All tests for `TransformerWithValueHead` passed!")

config.json:   0%|          | 0.00/718 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.52G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]



Loaded pretrained model gpt2-medium into HookedTransformer
All tests for `TransformerWithValueHead` passed!


In [12]:
@t.no_grad()
def get_samples(base_model: HookedTransformer, prompt: str, batch_size: int, gen_len: int, temperature: float):
    """
    Generates samples from the model, which will be fed into the reward model and evaluated.

    Inputs:
        gpt: the transformer to generate samples from (note we use gpt, not the model wrapper, cause we don't need value head)
        prompt: the initial prompt fed into the model
        batch_size: the number of samples to generate
        gen_len: the length of the generated samples (i.e. the number of *new* tokens to generate)

    Returns:
        sample_ids: the token ids of the generated samples (including initial prompt)
        samples: the generated samples (including initial prompt)
    """
    # Make sure we've passed in the base model (the bit we use for sampling)
    assert not isinstance(
        base_model, TransformerWithValueHead
    ), "Please pass in the base model, not the model wrapper."

    # Convert our prompt into tokens
    input_ids = base_model.to_tokens(prompt, prepend_bos=False).squeeze(0)

    # Generate samples (we repeat the input ids which is a bit wasteful but ¯\_(ツ)_/¯)
    input_ids = einops.repeat(input_ids, "seq -> batch seq", batch=batch_size)

    # Generate samples
    output_ids = base_model.generate(
        input_ids,
        max_new_tokens=gen_len,
        stop_at_eos=False,
        temperature=temperature,  # higher means more random completions
        verbose=False,
    )
    samples = base_model.to_string(output_ids)

    return output_ids.clone(), samples

In [21]:
"""
We'll start with a very basic reward function: counting the total number of periods in the sequence. For convenience, you should write your reward function to take in either a single sequence or a list of sequences (it will correspondingly return either a float or a tensor of floats).

An interesting thing to note about this reward function - it counts over all characters, but the episode length is defined in terms of tokens. This means that theoretically our model could reward hack by outputting tokens with more than one . character. This particular model's vocabulary happens to include the token '.' * 64, so rewards would be through the roof if this was ever generated! However, remember that RL is about performing actions, getting feedback on those actions, and using that feedback to influence your policy. The token '.' * 64 is so unlikely to ever be generated that it'll probably never be positively reinforced, and we avoid this problem.

"""

def reward_fn_char_count(generated_sample: str | list[str], char: str = '.') -> float | Float[Tensor, "batch"]:
    """
    Reward function, evaluated on the generated samples.

    In this case it's very simple: it just counts the number of instances of a particular character in
    the generated sample. It returns a tensor of rewards of dtype float the input is a list, or a single
    reward (float) if the input is a string.
    """
    if isinstance(generated_sample, list):
        return (
            t.tensor([reward_fn_char_count(item) for item in generated_sample]).float()
        )
    else:
        return float(generated_sample.count(char))


# Test your reward function
A = 'This is a test.'
B = '......'
C = 'Whatever'
assert isinstance(reward_fn_char_count(A), float)
assert reward_fn_char_count(A) == 1
assert reward_fn_char_count(B) == 6
assert reward_fn_char_count(C) == 0
assert reward_fn_char_count([A, B, C]).dtype == t.float
assert reward_fn_char_count([A, B, C]).tolist() == [1.0, 6.0, 0.0]

print('All tests for `reward_fn_char_count` passed!')

All tests for `reward_fn_char_count` passed!


In [28]:
"""
Following advice from Ziegler el al. (2019), it's important to normalize the reward function over each batch (i.e. subtract mean and divide by std dev). We've been able to get away with not doing this so far because our reward functions were usually nicely bounded, e.g. the reward was always zero or one in cartpole (and even in our reward shaping it was still in the zero-one range). But if we're working with reward functions that could be much higher variance such as the number of periods in a generated sequence, then we should normalize.

Note - we're not super strict about this function; the denominator being std + eps or (var + eps).sqrt() are both fine.
"""

def normalize_reward(reward: Float[Tensor, "batch_size"], eps=1e-5) -> Float[Tensor, "batch_size"]:
    """
    Normalizes the reward function values over the batch of sequences.
    """
    mean = reward.mean()
    std_dev = reward.std() + eps
    return (reward - mean) / std_dev


# Test your reward normalization function
reward = 10 + 5 * t.randn(10_000)
reward_normalized = normalize_reward(reward)
assert reward_normalized.mean().abs() < 1e-4
assert (reward_normalized.std() - 1).abs() < 1e-4
# Test edge case of zero reward
reward = t.zeros(5)
reward_normalized = normalize_reward(reward)
assert reward_normalized.abs().sum() < 1e-4

print('All tests for `normalize_reward` passed!')

All tests for `normalize_reward` passed!


In [29]:
@dataclass
class RLHFTrainingArgs():
    # Basic / global
    seed: int = 1
    cuda: bool = t.cuda.is_available()

    # Wandb / logging
    exp_name: str = "RLHF_Implementation"
    wandb_project_name: str | None = "ch2-day4-rlhf"
    wandb_entity: str | None = None  
    use_wandb: bool = False

    # Duration of different phases
    total_phases: int = 200
    batch_size: int = 256
    num_minibatches: int = 4
    batches_per_learning_phase: int = 2

    # Optimization hyperparameters
    base_learning_rate: float = 2e-5
    head_learning_rate: float = 5e-4
    max_grad_norm: float = 1.0
    warmup_steps: int = 20
    final_scale: float = 0.1

    # Computing other PPO loss functions
    clip_coef: float = 0.2
    vf_coef: float = 0.15
    ent_coef: float = 0.001

    # Base model & sampling arguments
    base_model: str = BASE_MODEL
    gen_len: int = 30
    temperature: float = 0.6
    prefix: str = "This is"

    # Extra stuff for RLHF
    kl_coef: float = 1.0
    reward_fn: Callable = reward_fn_char_count
    normalize_reward: bool = True

    def __post_init__(self):
        assert (
            self.batch_size % self.num_minibatches == 0
        ), "Batch size should be divisible by the number of minibatches."
        self.minibatch_size = self.batch_size // self.num_minibatches

In [37]:
@t.no_grad()
def compute_advantages(
    values: Float[Tensor, "minibatch_size seq_len"],
    rewards: Float[Tensor, "minibatch_size"],
    prefix_len: int,
) -> Float[Tensor, "minibatch_size gen_len"]:
    """
    Computes the advantages for the PPO loss function, i.e. A_pi(s, a) = Q_pi(s, a) - V_pi(s).

    In this formula we replace Q(s, a) with the 1-step Q estimates, and V(s) with the 0-step value estimates.

    Inputs:
        values:
            the value estimates for each token in the generated sequence
        rewards:
            the rewards for the entire generated sequence
        prefix_len:
            the length of the prefix (i.e. the length of the initial prompt)

    Returns:
        advantages:
            the advantages for each token in the generated sequence (not the entire sequence)
    """
    one_step_q_est = t.cat(
        [
            values[:, prefix_len:-1], #shape [minibatch_size, gen_len-1]
            rewards[:, None], #shape [minibatch_size, 1]
        ],
        dim=-1
    )

    zero_step_value_est = values[:, prefix_len -1 : -1] # shape [minibatch_size, gen_len]
    advantages = one_step_q_est - zero_step_value_est

    return advantages

# Test the function
tests.test_compute_advantages(compute_advantages)

All tests in `test_compute_advantages` passed!


In [44]:
@dataclass
class ReplayMinibatch:
    """
    Samples from the replay memory.
    """
    sample_ids: Float[Tensor, "minibatch_size seq_len"]
    logprobs: Float[Tensor, "minibatch_size seq_len"]
    advantages: Float[Tensor, "minibatch_size gen_len"]
    returns: Float[Tensor, "minibatch_size gen_len"]
    ref_logits: Float[Tensor, "minibatch_size seq_len d_vocab"]


class ReplayMemory:
    def __init__(
        self,
        args: RLHFTrainingArgs,
        sample_ids: Float[Tensor, "batch_size seq_len"],
        logprobs: Float[Tensor, "batch_size seq_len"],
        advantages: Float[Tensor, "batch_size gen_len"],
        values: Float[Tensor, "batch_size seq_len"],
        ref_logits: Float[Tensor, "batch_size seq_len d_vocab"],
    ):
        """
        Initializes the replay memory, with all the data generated from the rollout phase at once.

        The advantages are (batch_size, gen_len) because we only compute advantages for the generated
        tokens. The other tensors are (batch_size, seq_len) because they are computed for all tokens.
        """
        self.args = args
        self.sample_ids = sample_ids
        self.logprobs = logprobs
        self.advantages = advantages
        self.values = values
        self.ref_logits = ref_logits


    def get_minibatches(self):
        """
        Generates a list of minibatches by randomly sampling from the replay memory. Each sequence appears
        exactly `batches_per_learning_phase` times in total.
        """
        minibatches = []

        returns = self.advantages + self.values[:, -self.args.gen_len-1:-1]

        for _ in range(self.args.batches_per_learning_phase):

            idxs = t.randperm(self.args.batch_size).reshape(self.args.num_minibatches, self.args.minibatch_size)

            for idx in idxs:
                minibatches.append(
                    ReplayMinibatch(
                        sample_ids = self.sample_ids[idx],
                        logprobs = self.logprobs[idx],
                        advantages = self.advantages[idx],
                        returns = returns[idx],
                        ref_logits = self.ref_logits[idx],
                        )
                    )

        return minibatches

In [45]:
def calc_kl_penalty(
    logits: Float[Tensor, "minibatch_size seq_len d_vocab"],
    ref_logits: Float[Tensor, "minibatch_size seq_len d_vocab"],
    kl_coef: float,
    prefix_len: int,
) -> Float[Tensor, ""]:
    """
    Computes the KL divergence between the logits and the reference logits, scaled
    by the penalty function. This is used to stop the learned policy from diverging
    too much from the original reference model's policy.

    logits:
        The logits of the generated samples (under the new model).
    ref_logits:
        The logits of the generated samples (under the reference model).
    kl_coef:
        The coefficient of the KL penalty.
    prefix_len:
        The length of the prefix to ignore when computing the KL divergence.
    """
    # convert the logits to probs using softmax
    logprobs = logits.log_softmax(dim=-1)
    ref_logprobs = ref_logits.log_softmax(dim=-1)

    # compute the probabilities
    probs = logprobs.exp()

    # compute the KL divergence
    kl_div = (probs * (logprobs - ref_logprobs)).sum(dim=-1)

    kl_div = kl_div[:, prefix_len:]

    return kl_coef * kl_div.mean()



tests.test_calc_kl_penalty(calc_kl_penalty)
tests.test_calc_kl_penalty_stability(calc_kl_penalty)

Exception: Error: your kl_div=0.4819464087486267, kl_div_correct=0.48865893483161926. Did you take the correct slice over sequence positions?