# RLHF End-to-End (SFT ➜ Reward Model ➜ PPO RLHF)

This single notebook combines the three steps from the original notebooks:
- Supervised Fine-Tuning (SFT)
- Reward Model (RM) training
- PPO-based RLHF training

It deduplicates setup, reuses the same tokenizer and base model, and uses Windows/VS Code friendly steps (no Colab-specific shell commands).

## 0) Setup

In [None]:
%pip install -q transformers datasets==3.5.0 scikit-learn

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.0 requires fsspec==2025.3.0, but you have fsspec 2024.12.0 which is incompatible.

In [None]:
import torch
from torch import nn
import numpy as np
import random
from copy import deepcopy

from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, DataCollatorForLanguageModeling, DataCollatorWithPadding
from torch.utils.data import DataLoader

model_name = 'gpt2'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Ensure pad token is set for GPT-2 flows
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

---
## 1) Supervised Fine-Tuning (SFT)
We train the base model on SST-2 sentences as a causal language modeling task. The resulting SFT checkpoint later initializes the PPO policy and, as a frozen copy, serves as the reference policy.

### 1.1 Tokenizer quick check

In [None]:
text = 'Hello, this is the first step of RLHF training.'
tokens = tokenizer(text)
print(tokens)
print(tokenizer.decode(tokens['input_ids']))

{'input_ids': [15496, 11, 428, 318, 262, 717, 2239, 286, 45715, 29567, 3047, 13], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Hello, this is the first step of RLHF training.


### 1.2 Load and inspect dataset (SST-2)

In [None]:
dataset_name = 'sst2'
ds = load_dataset(dataset_name)
ds_train, ds_val = ds['train'], ds['validation']
ds

DatasetDict({
    train: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 872
    })
    test: Dataset({
        features: ['idx', 'sentence', 'label'],
        num_rows: 1821
    })
})

### 1.3 Tokenize dataset for causal LM

In [None]:
# Function to tokenize the sentences for Supervised Fine-Tuning (SFT)
def tokenize_sft(batch):
    # Use the tokenizer to encode the sentences in the batch
    return tokenizer(batch['sentence'])

# Keyword arguments for the map function to process the dataset in batches
map_kwargs = {
    'batched': True,  # Process the dataset in batches
    'batch_size': 512, # Define the size of each batch
    'remove_columns': ['idx', 'sentence', 'label'] # Remove original columns after tokenization
}

# Apply the tokenization function to the training and validation datasets
tokenized_sft_train = ds_train.map(tokenize_sft, **map_kwargs)
tokenized_sft_val = ds_val.map(tokenize_sft, **map_kwargs)

# Filter out sequences that are too short (less than or equal to 5 tokens)
tokenized_sft_train = tokenized_sft_train.filter(lambda x: len(x['input_ids']) > 5)
tokenized_sft_val = tokenized_sft_val.filter(lambda x: len(x['input_ids']) > 5)

# Set the format of the datasets to 'torch' tensors for PyTorch compatibility
tokenized_sft_train.set_format(type='torch')
tokenized_sft_val.set_format(type='torch')

# Print the number of examples in the filtered training and validation datasets
len(tokenized_sft_train), len(tokenized_sft_val)

(49401, 867)

### 1.4 Dataloaders with padding

In [None]:
# Initialize a DataCollator for Language Modeling.
# This collator pads the tokenized sequences to the same length within each batch,
# pads the attention masks to match, and builds the labels from input_ids
# (padding positions are set to -100 so the loss ignores them).
# mlm=False selects the Causal Language Model (CLM) objective rather than
# the Masked Language Model (MLM) objective.
data_collator_lm = DataCollatorForLanguageModeling(tokenizer, mlm=False)

# Create a DataLoader for the tokenized SFT training dataset.
# shuffle=True reshuffles the examples each epoch before they are batched.
# collate_fn specifies how to batch the data, using the data_collator_lm for padding.
train_loader_sft = DataLoader(tokenized_sft_train, batch_size=32, shuffle=True, collate_fn=data_collator_lm)

# Create a DataLoader for the tokenized SFT validation dataset.
# collate_fn specifies how to batch the data, using the data_collator_lm for padding.
val_loader_sft = DataLoader(tokenized_sft_val, batch_size=32, collate_fn=data_collator_lm)

# Print the number of batches in the training DataLoader.
# This gives an idea of the training loop's iteration count per epoch.
len(train_loader_sft)

1544
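
The batch count above follows directly from the filtered dataset size and the batch size. A small arithmetic check, using the numbers printed in this notebook (the variable names are only illustrative):

```python
import math

# Numbers taken from the outputs above; variable names are illustrative.
num_train_examples = 49401   # filtered SFT training set size
batch_size = 32              # DataLoader batch size

# The DataLoader keeps the final partial batch by default, so round up.
num_batches = math.ceil(num_train_examples / batch_size)
last_batch_size = num_train_examples - (num_batches - 1) * batch_size

print(num_batches, last_batch_size)
```

The last batch is smaller than the others, which is worth remembering when averaging per-batch losses.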

### 1.5 Train SFT

In [None]:
# Initialize the AdamW optimizer with the model's parameters and a learning rate.
# AdamW is a popular optimizer that includes weight decay.
optimizer_sft = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Define the number of training epochs for SFT. An epoch is one full pass through the training dataset.
num_epochs_sft = 1

# Define a function to evaluate the model on the validation set.
def validate_sft(epoch):
    # Set the model to evaluation mode. This disables dropout (GPT-2 has no batch normalization layers).
    model.eval()
    total_loss = 0.0
    # Iterate over the validation data loader.
    for batch in val_loader_sft:
        # Move the batch tensors to the appropriate device (GPU or CPU).
        batch = {k: v.to(device) for k, v in batch.items()}
        # Disable gradient calculation during validation to save memory and speed up computation.
        with torch.no_grad():
            # Pass the batch through the model to get the outputs (including the loss).
            outputs = model(**batch)
            # Add the loss of the current batch to the total loss.
            total_loss += outputs.loss.item()
    # Print the average validation loss for the current epoch.
    print(f'[SFT] val_loss at epoch {epoch}:', total_loss / max(1, len(val_loader_sft)))

# Run validation before starting the training to see the initial performance.
validate_sft(0)

# Start the training loop for the specified number of epochs.
for epoch in range(num_epochs_sft):
    # Set the model back to training mode. This re-enables dropout.
    model.train()
    # Iterate over the training data loader.
    for batch in train_loader_sft:
        # Move the batch tensors to the appropriate device (GPU or CPU).
        batch = {k: v.to(device) for k, v in batch.items()}
        # Pass the batch through the model to get the outputs (including the loss).
        outputs = model(**batch)
        # Get the loss from the model outputs.
        loss = outputs.loss
        # Zero out the gradients from the previous training step.
        optimizer_sft.zero_grad()
        # Perform backpropagation to calculate the gradients of the loss with respect to the model's parameters.
        loss.backward()
        # Update the model's parameters using the calculated gradients and the optimizer.
        optimizer_sft.step()
    # Run validation after each training epoch to monitor performance.
    validate_sft(epoch + 1)

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


[SFT] val_loss at epoch 0: 5.181761639458792
[SFT] val_loss at epoch 1: 3.9878671084131514
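
One way to read these numbers is as perplexity, the exponential of the mean cross-entropy loss. The sketch below converts the two validation losses printed above. Because the notebook averages per-batch means rather than per-token losses, treat the result as an approximation.

```python
import math

# Validation losses copied from the output above.
val_loss_before = 5.181761639458792
val_loss_after = 3.9878671084131514

# Perplexity is exp(cross-entropy) when the loss is measured in nats.
ppl_before = math.exp(val_loss_before)
ppl_after = math.exp(val_loss_after)

print(f"approx. perplexity before SFT: {ppl_before:.1f}")
print(f"approx. perplexity after SFT:  {ppl_after:.1f}")
```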


### 1.6 Save SFT model

In [None]:
sft_dir = './sft_model_epoch_1'
model.save_pretrained(sft_dir)

(Optional) Zip the saved SFT model directory into a .zip (cross-platform).

In [None]:
import shutil
shutil.make_archive('sft_model_epoch_1', 'zip', sft_dir)

'/content/sft_model_epoch_1.zip'

---
## 2) Reward Model (RM) Training
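We add a small linear reward head on top of GPT-2 and train it to predict the SST-2 sentiment label, used here as a proxy for human preference. The score is read at the final position, where an EOS token is appended as the reward token.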

In [None]:
REWARD_TOKEN_ID = tokenizer.eos_token_id
REWARD_TOKEN_ID

50256

In [None]:
def tokenize_rm(batch):
    outputs = tokenizer(batch['sentence'])
    outputs['score'] = [0] * len(outputs['input_ids'])
    outputs['score_index'] = [0] * len(outputs['input_ids'])
    for i in range(len(outputs['input_ids'])):
        outputs['input_ids'][i].append(REWARD_TOKEN_ID)
        outputs['attention_mask'][i].append(1)
        outputs['score'][i] = float(batch['label'][i])
        outputs['score_index'][i] = len(outputs['input_ids'][i]) - 1
    return outputs

map_kwargs_rm = {
    'batched': True,
    'batch_size': 512,
    'remove_columns': ['idx', 'sentence', 'label']
}

tokenized_rm_train = ds_train.map(tokenize_rm, **map_kwargs_rm)
tokenized_rm_val = ds_val.map(tokenize_rm, **map_kwargs_rm)

# format and filter
tokenized_rm_train.set_format(type='torch')
tokenized_rm_val.set_format(type='torch')
tokenized_rm_train = tokenized_rm_train.filter(lambda x: len(x['input_ids']) > 6)
tokenized_rm_val = tokenized_rm_val.filter(lambda x: len(x['input_ids']) > 6)
len(tokenized_rm_train), len(tokenized_rm_val)

(49401, 867)

In [None]:
class RewardHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.reward = nn.Linear(self.hidden_size, 1)
        nn.init.normal_(self.reward.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.reward.bias)

    def forward(self, hidden_states):
        return self.reward(hidden_states)

class GPT2RewardModel(nn.Module):
    def __init__(self, model_name):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_name)
        self.reward_head = RewardHead(self.llm.config)

    def forward(self, input_ids, attention_mask):
        transformer_outputs = self.llm.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        )
        last_hidden_state = transformer_outputs.hidden_states[-1]
        reward = self.reward_head(last_hidden_state).squeeze(-1)
        return torch.sigmoid(reward)

In [None]:
rm_model = GPT2RewardModel(model_name).to(device)
data_collator_rm = DataCollatorWithPadding(tokenizer)
train_loader_rm = DataLoader(tokenized_rm_train, batch_size=64, shuffle=True, collate_fn=data_collator_rm)
val_loader_rm = DataLoader(tokenized_rm_val, batch_size=64, shuffle=True, collate_fn=data_collator_rm)
batch_rm = next(iter(train_loader_rm))
list(batch_rm.keys())

['input_ids', 'attention_mask', 'score', 'score_index']
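
The `score_index` column records where the appended reward token sits in each sequence, so the training loop can pick one score per row even after right-padding. A plain-Python sketch of that selection, with made-up scores standing in for the model output:

```python
# Toy per-token scores for a right-padded batch of 2 sequences (values are made up).
scores = [
    [0.10, 0.20, 0.90, 0.00, 0.00],   # real length 3, followed by padding
    [0.30, 0.40, 0.50, 0.60, 0.70],   # real length 5
]
score_index = [2, 4]  # position of the appended reward token in each row

# List equivalent of the tensor indexing scores[b_idx, score_index].
picked = [row[idx] for row, idx in zip(scores, score_index)]
print(picked)
```

Only the selected positions enter the loss; scores at every other position are ignored.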

### 2.1 Train Reward Model

In [None]:
optimizer_rm = torch.optim.AdamW(rm_model.parameters(), lr=1e-4)
criterion_rm = nn.BCELoss()
num_epochs_rm = 1

def validate_rm():
    rm_model.eval()
    total_loss = 0.0
    for batch in val_loader_rm:
        inputs = {k: v.to(device) for k, v in batch.items()}
        model_inputs = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask']
        }
        with torch.no_grad():
            scores = rm_model(**model_inputs)
            b_idx = torch.arange(scores.shape[0], device=device)
            score = scores[b_idx, inputs['score_index']]
            target = inputs['score']
            loss = criterion_rm(score, target)
        total_loss += loss.item()
    print('[RM] validation loss:', total_loss / max(1, len(val_loader_rm)))

validate_rm()
for epoch in range(num_epochs_rm):
    rm_model.train()
    for batch in train_loader_rm:
        inputs = {k: v.to(device) for k, v in batch.items()}
        model_inputs = {
            'input_ids': inputs['input_ids'],
            'attention_mask': inputs['attention_mask']
        }
        scores = rm_model(**model_inputs)
        b_idx = torch.arange(scores.shape[0], device=device)
        score = scores[b_idx, inputs['score_index']]
        target = inputs['score']
        loss = criterion_rm(score, target)
        optimizer_rm.zero_grad()
        loss.backward()
        optimizer_rm.step()
    validate_rm()

[RM] validation loss: 4.810159223420279
[RM] validation loss: 0.34349739285452024
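
For reference, binary cross-entropy equals ln 2 (about 0.693) when the model predicts 0.5 for every example. A loss well above that, like the value before training, suggests many confident but wrong predictions. The sketch below works through a few cases by hand; the `bce` helper is hypothetical and only mirrors the formula, not `nn.BCELoss` internals.

```python
import math

def bce(p, y, eps=1e-12):
    """Binary cross-entropy for one prediction p in (0, 1) and label y in {0, 1}."""
    p = min(max(p, eps), 1.0 - eps)  # avoid log(0)
    return -(y * math.log(p) + (1.0 - y) * math.log(1.0 - p))

chance = bce(0.5, 1.0)            # uninformative prediction
confident_right = bce(0.9, 1.0)   # confident and correct
confident_wrong = bce(0.01, 1.0)  # confident and wrong

print(chance, confident_right, confident_wrong)
```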


### 2.2 Save Reward Model

In [None]:
torch.save(rm_model.state_dict(), 'reward_model.pt')

(Optional) Quick confusion-matrix style check (threshold 0.5).

In [None]:
from sklearn.metrics import confusion_matrix
rm_model.eval()
all_predictions, all_labels = [], []
for batch in val_loader_rm:
    inputs = {k: v.to(device) for k, v in batch.items()}
    with torch.no_grad():
        scores = rm_model(inputs['input_ids'], inputs['attention_mask'])
        b_idx = torch.arange(scores.shape[0], device=device)
        score = scores[b_idx, inputs['score_index']]
        target = inputs['score']
    predictions = (score > 0.5).int()
    all_predictions.extend(predictions.detach().cpu().numpy())
    all_labels.extend(target.detach().cpu().numpy())
confusion_matrix(all_labels, all_predictions)

array([[348,  76],
       [ 26, 417]])
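
The matrix can be summarized with a few standard metrics. In scikit-learn's layout, rows are true labels and columns are predictions, so the entries are `[[TN, FP], [FN, TP]]`. A short calculation from the numbers above:

```python
# Confusion matrix copied from the output above: rows = true label, columns = prediction.
cm = [[348, 76],
      [26, 417]]
(tn, fp), (fn, tp) = cm

total = tn + fp + fn + tp
accuracy = (tn + tp) / total
precision = tp / (tp + fp)   # of predicted positives, how many are truly positive
recall = tp / (tp + fn)      # of true positives, how many were found

print(f"n={total} accuracy={accuracy:.3f} precision={precision:.3f} recall={recall:.3f}")
```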

---
## 3) PPO RLHF Training
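The policy is initialized from the SFT checkpoint and extended with a value head. A frozen copy of it serves as the reference for the KL penalty, and the Reward Model scores each generated continuation during the PPO-style updates.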

In [None]:
# Load Reward Model weights
reward_model = GPT2RewardModel(model_name).to(device)
reward_model.load_state_dict(torch.load('reward_model.pt', map_location=device))
reward_model.eval()

In [None]:
class ValueHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.value = nn.Linear(self.hidden_size, 1)
        nn.init.normal_(self.value.weight, std=(1.0 / np.sqrt(self.hidden_size + 1)))
        nn.init.zeros_(self.value.bias)
    def forward(self, hidden_states):
        return self.value(hidden_states)

class ModelForCausalLMWithValueHead(nn.Module):
    def __init__(self, model_path):
        super().__init__()
        self.llm = AutoModelForCausalLM.from_pretrained(model_path)
        self.v_head = ValueHead(self.llm.config)
    def forward(self, input_ids, attention_mask):
        transformer_outputs = self.llm.forward(
            input_ids=input_ids, attention_mask=attention_mask, output_hidden_states=True
        )
        lm_logits = transformer_outputs.logits
        last_hidden_state = transformer_outputs.hidden_states[-1]
        value = self.v_head(last_hidden_state).squeeze(-1)
        return lm_logits, value
    def generate(self, *args, **kwargs):
        return self.llm.generate(*args, **kwargs)

### 3.1 Prepare RLHF dataset (queries)

In [None]:
# Reuse original SST-2 split but filter for longer sentences (for more interesting generations)
ds_train_rlhf = ds_train.filter(lambda x: len(x['sentence'].split(' ')) > 8)
ds_val_rlhf = ds_val.filter(lambda x: len(x['sentence'].split(' ')) > 8)

input_min_token_length, input_max_token_length = 2, 8
input_token_length_range = list(range(input_min_token_length, input_max_token_length))

def tokenize_query(sample):
    size = random.choice(input_token_length_range)
    sample['input_ids'] = tokenizer.encode(sample['sentence'])[:size]
    sample['attention_mask'] = [1] * len(sample['input_ids'])
    sample['query'] = tokenizer.decode(sample['input_ids'])
    return sample

map_kwargs_q = { 'batched': False, 'remove_columns': ['idx', 'sentence', 'label'] }
tokenized_train_rlhf = ds_train_rlhf.map(tokenize_query, **map_kwargs_q)
tokenized_val_rlhf = ds_val_rlhf.map(tokenize_query, **map_kwargs_q)
tokenized_train_rlhf.set_format(type='torch')
tokenized_val_rlhf.set_format(type='torch')

def list_collator(batch):
    return {k: [d[k] for d in batch] for k in batch[0]}

batch_size = 32
train_loader_rlhf = DataLoader(tokenized_train_rlhf, batch_size=batch_size, shuffle=True, collate_fn=list_collator)
val_loader_rlhf = DataLoader(tokenized_val_rlhf, batch_size=batch_size, shuffle=True, collate_fn=list_collator)
len(tokenized_train_rlhf), len(tokenized_val_rlhf)

(31105, 807)
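
Note that `range(input_min_token_length, input_max_token_length)` excludes its upper bound, so query prompts end up 2 to 7 tokens long. A small sketch with placeholder token ids (not real tokenizer output) illustrates the truncation:

```python
import random

random.seed(0)
input_min_token_length, input_max_token_length = 2, 8
length_choices = list(range(input_min_token_length, input_max_token_length))

# Placeholder ids standing in for tokenizer.encode(sentence).
token_ids = list(range(100, 120))

# Truncate the same sentence many times to see which prompt lengths occur.
queries = [token_ids[:random.choice(length_choices)] for _ in range(1000)]
lengths = sorted({len(q) for q in queries})
print(lengths)
```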

### 3.2 Generation and scoring utilities

In [None]:
output_min_length, output_max_length = 5, 16
generation_kwargs = {
    'min_length': -1,
    'top_k': 0.0,
    'top_p': 1.0,
    'do_sample': True,
    'pad_token_id': tokenizer.pad_token_id
}
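
With `top_k` set to 0 and `top_p` set to 1.0, neither top-k nor nucleus filtering is applied, so with `do_sample=True` each token is drawn from the full softmax distribution. A minimal plain-Python sketch of that sampling step, using made-up logits and a hypothetical helper name (`sample_full_softmax`):

```python
import math
import random

def sample_full_softmax(logits, rng):
    """Draw one token id from softmax(logits) without any top-k/top-p filtering."""
    m = max(logits)                          # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    r = rng.random()
    cumulative = 0.0
    for token_id, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return token_id, probs
    return len(probs) - 1, probs             # guard against rounding at the tail

rng = random.Random(0)
token_id, probs = sample_full_softmax([2.0, 1.0, 0.1, -3.0], rng)
print(token_id, [round(p, 3) for p in probs])
```

Every token keeps a non-zero probability, which helps exploration during PPO but also lets low-probability tokens through.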

In [None]:
def batch_generate_and_score(batch):
    query_tensors = batch['input_ids']
    query_attention_masks = batch['attention_mask']
    response_tensors, query_response_tensors, score_tensors = [], [], []
    for i, query in enumerate(query_tensors):
        query = query.to(device)
        query_attention_mask = query_attention_masks[i].to(device)
        new_tokens = random.choice(list(range(output_min_length, output_max_length)))
        generation_kwargs['max_new_tokens'] = new_tokens
        query_response = ppo_model.generate(
            input_ids=query.unsqueeze(0),
            attention_mask=query_attention_mask.unsqueeze(0),
            **generation_kwargs
        ).squeeze(0)
        response_len = len(query_response) - len(query)
        response_tensors.append(query_response[-response_len:])
        query_response_tensors.append(query_response)
        with torch.no_grad():
            qr_score_ids = torch.cat([query_response, torch.tensor([REWARD_TOKEN_ID]).to(device)])
            attention_mask = torch.ones_like(qr_score_ids, dtype=torch.long)
            score = reward_model(qr_score_ids.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
            score = 2 * (score - 0.5)
        score_tensors.append(score)
    batch['response'] = [tokenizer.decode(r) for r in response_tensors]
    return query_tensors, response_tensors, query_response_tensors, score_tensors

### 3.3 Compute rewards, advantages, returns

In [None]:
# Utilities to compute rewards and advantages
from transformers import DataCollatorWithPadding

data_collator_pad = DataCollatorWithPadding(tokenizer=tokenizer)

def compute_rewards(input_data, query_tensors, response_tensors, score_tensors):
    with torch.no_grad():
        logits, values = ppo_model(**input_data)  # b, seq, vocab
        ref_logits, _ = sft_reference(**input_data)
        logp = torch.nn.functional.log_softmax(logits[:, :-1, :], dim=-1)
        ref_logp = torch.nn.functional.log_softmax(ref_logits[:, :-1, :], dim=-1)

        labels = input_data['input_ids'][:, 1:]  # b, seq-1
        logp = torch.gather(logp, 2, labels.unsqueeze(-1)).squeeze(-1)       # b, seq-1
        ref_logp = torch.gather(ref_logp, 2, labels.unsqueeze(-1)).squeeze(-1)  # b, seq-1

        kl = logp - ref_logp
        beta = 0.2
        rewards = - beta * kl
        attention_mask = input_data['attention_mask']
        masks = torch.zeros_like(attention_mask[:, 1:])
        masks[:, :] = attention_mask[:, 1:]
        for j in range(len(query_tensors)):
            start = len(query_tensors[j]) - 1
            end = start + len(response_tensors[j])
            masks[j, :start] = 0
            masks[j, end:] = 0
            rewards[j, end - 1] += score_tensors[j]
            rewards[j, :] *= masks[j, :]
            values[j, :-1] *= masks[j, :]

    return logp, rewards, values[:, :-1], masks

def masked_mean(values, mask):
    return (values * mask).sum() / (mask.sum() + 1e-8)

def masked_var(values, mask):
    mean = masked_mean(values, mask)
    centred = values - mean
    return masked_mean(centred ** 2, mask)

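def masked_whiten(values, mask, shift_mean=True):
    # Normalize to unit variance over unmasked positions. With shift_mean=True the
    # result is also zero-mean, the usual choice for PPO advantages.
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened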

def compute_advantage(rewards, values, masks):
    lastgae = 0.0
    advantage_reversed = []
    seq_length = rewards.shape[-1]
    gamma, lam = 1.0, 0.95
    for t in reversed(range(seq_length)):
        nextvalues = values[:, t + 1] if t < seq_length - 1 else 0.0
        delta = rewards[:, t] + gamma * nextvalues - values[:, t]
        lastgae = delta + gamma * lam * lastgae
        advantage_reversed.append(lastgae)
    advantages = torch.stack(advantage_reversed[::-1], dim=1)
    # Value targets use the raw GAE advantages; only the policy-gradient term sees the whitened ones.
    returns = advantages + values
    advantages = masked_whiten(advantages, masks)
    return advantages, returns

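# Instantiate the policy, the frozen SFT reference, and the reward model if they are not
# already defined. Note that this is the first place ppo_model and sft_reference are created.
if 'ppo_model' not in globals():
    model_path = './sft_model_epoch_1'
    ppo_model = ModelForCausalLMWithValueHead(model_path).to(device)
if 'sft_reference' not in globals():
    # Frozen copy of the initial policy, used only for the KL penalty.
    sft_reference = deepcopy(ppo_model).to(device)
    sft_reference.eval()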

if 'reward_model' not in globals():
    reward_model = GPT2RewardModel(model_name).to(device)
    reward_model.load_state_dict(torch.load('reward_model.pt', map_location=device))
    reward_model.eval()

if 'REWARD_TOKEN_ID' not in globals():
    REWARD_TOKEN_ID = tokenizer.eos_token_id

if 'generation_kwargs' not in globals():
    output_min_length, output_max_length = 5, 16
    generation_kwargs = {
        'min_length': -1,
        'top_k': 0.0,
        'top_p': 1.0,
        'do_sample': True,
        'pad_token_id': tokenizer.pad_token_id,
    }

val_gen_lengths = [random.choice(list(range(output_min_length, output_max_length))) for _ in range(len(tokenized_val_rlhf))]

def validate_rlhf():
    scores = []
    for b, batch in enumerate(val_loader_rlhf):
        query_tensors = batch['input_ids']
        query_attention_masks = batch['attention_mask']
        for i, query in enumerate(query_tensors):
            query = query.to(device)
            query_attention_mask = query_attention_masks[i].to(device)
            # Offset by the loader's batch_size so a smaller final batch still gets its own lengths.
            new_tokens = val_gen_lengths[b * batch_size + i]
            generation_kwargs['max_new_tokens'] = new_tokens
            # No gradients are needed for evaluation.
            with torch.no_grad():
                qr = ppo_model.generate(input_ids=query.unsqueeze(0), attention_mask=query_attention_mask.unsqueeze(0), **generation_kwargs).squeeze(0)
                qr_score_ids = torch.cat([qr, torch.tensor([REWARD_TOKEN_ID]).to(device)])
                attention_mask = torch.ones_like(qr_score_ids, dtype=torch.long)
                score = reward_model(qr_score_ids.unsqueeze(0), attention_mask.unsqueeze(0)).squeeze(0)[-1]
            score = 2 * (score - 0.5)
            scores.append(score.item())
    print('avg score:', sum(scores) / max(1, len(scores)))

validate_rlhf()

avg score: 0.2491551961892895
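
To make the reward shaping and advantage computation concrete, here is a plain-Python sketch for a single 3-token response. It assumes the same structure as the cell above (a per-token KL penalty, the reward-model score added on the last response token, then generalized advantage estimation) but skips masking and whitening. All numbers and the helper names `shaped_rewards` and `gae` are illustrative.

```python
def shaped_rewards(logp, ref_logp, score, beta=0.2):
    """Per-token reward: -beta * (logp - ref_logp), plus the RM score on the last token."""
    rewards = [-beta * (lp - rlp) for lp, rlp in zip(logp, ref_logp)]
    rewards[-1] += score
    return rewards

def gae(rewards, values, gamma=1.0, lam=0.95):
    """Generalized advantage estimation, accumulated from the last token backwards."""
    advantages = [0.0] * len(rewards)
    last = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]   # TD residual
        last = delta + gamma * lam * last
        advantages[t] = last
    return advantages

# Made-up log-probabilities and value estimates for a 3-token response.
logp = [-1.0, -2.0, -0.5]       # policy log-probs of the sampled tokens
ref_logp = [-1.5, -2.0, -1.0]   # reference (SFT) log-probs of the same tokens
values = [0.1, 0.2, 0.3]        # value-head predictions

rewards = shaped_rewards(logp, ref_logp, score=0.8)
advantages = gae(rewards, values)
returns = [a + v for a, v in zip(advantages, values)]

print(rewards)
print(advantages)
print(returns)
```

Tokens where the policy is more confident than the reference receive a small negative reward, and only the final token carries the reward-model score.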


### 3.4 PPO mini-batch training

In [None]:
learning_rate = 1e-5
optimizer_ppo = torch.optim.AdamW(ppo_model.parameters(), lr=learning_rate)
mini_batch_size = 4
ppo_epochs = 4
cliprange_ratio = 0.2
v_loss_coeff = 0.1
ratio_threshold = 10.0

def compute_loss(old_logprobs, values, logprobs, vpreds, masks, advantages, returns):
    ratio = torch.exp(logprobs - old_logprobs)
    pg_loss1 = - ratio * advantages
    pg_loss2 = - torch.clamp(ratio, 1 - cliprange_ratio, 1 + cliprange_ratio) * advantages
    pg_loss = masked_mean(torch.max(pg_loss1, pg_loss2), masks)
    v_loss = masked_mean((vpreds - returns) ** 2, masks)
    loss = pg_loss + v_loss_coeff * v_loss
    avg_ratio = masked_mean(ratio, masks)
    if avg_ratio > ratio_threshold:
        pg_loss = pg_loss * 0.0
        v_loss = v_loss * 0.0
        loss = loss * 0.0
    return loss, v_loss

def mini_batch_train(input_data, logprobs, values, masks, advantages, returns):
    for _ in range(ppo_epochs):
        batch_inds = np.random.permutation(input_data['input_ids'].shape[0])
        for start in range(0, len(batch_inds), mini_batch_size):
            end = start + mini_batch_size
            mb_inds = batch_inds[start:end]
            mb_inputs = {
                'input_ids': input_data['input_ids'][mb_inds],
                'attention_mask': input_data['attention_mask'][mb_inds]
            }
            mb_logits, mb_vpreds = ppo_model(**mb_inputs)
            mb_logits = torch.nn.functional.log_softmax(mb_logits[:, :-1, :], dim=-1)
            mb_logprobs = torch.gather(mb_logits, 2, mb_inputs['input_ids'][:, 1:].unsqueeze(-1)).squeeze(-1)
            loss, _ = compute_loss(
                logprobs[mb_inds], values[mb_inds], mb_logprobs, mb_vpreds[:, :-1],
                masks[mb_inds], advantages[mb_inds], returns[mb_inds]
            )
            optimizer_ppo.zero_grad()
            loss.backward()
            optimizer_ppo.step()

### 3.5 Train RLHF (1 epoch demo)

In [None]:
num_epochs_rlhf = 1
for epoch in range(num_epochs_rlhf):
    for batch in train_loader_rlhf:
        q_tensors, r_tensors, qr_tensors, s_tensors = batch_generate_and_score(batch)
        input_batch = data_collator_pad([
            { 'input_ids': ids, 'attention_mask': torch.ones_like(ids) }
            for ids in qr_tensors
        ]).to(device)
        logprobs, rewards, values, masks = compute_rewards(input_batch, q_tensors, r_tensors, s_tensors)
        advantages, returns = compute_advantage(rewards, values, masks)
        mini_batch_train(input_batch, logprobs, values, masks, advantages, returns)
    print(f'[RLHF] epoch {epoch + 1} finished')

[RLHF] epoch 1 finished
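
As a rough guide to cost, the number of optimizer steps in this epoch can be worked out from the settings above. Each rollout batch is revisited `ppo_epochs` times in mini-batches of `mini_batch_size`. The sketch assumes the DataLoader keeps the final partial batch, which is its default.

```python
import math

# Settings and dataset size copied from the cells above.
num_queries = 31105
batch_size = 32
mini_batch_size = 4
ppo_epochs = 4

num_batches = math.ceil(num_queries / batch_size)
total_steps = 0
for b in range(num_batches):
    rows = min(batch_size, num_queries - b * batch_size)   # last batch may be smaller
    total_steps += ppo_epochs * math.ceil(rows / mini_batch_size)

print(num_batches, total_steps)
```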


### 3.6 Validate RLHF policy (avg reward)

In [None]:
# validate_rlhf() and val_gen_lengths are already defined in section 3.3; reuse them here
# to score the policy after PPO training.
validate_rlhf()

avg score: 0.8178268348269598
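
The reported score is the reward model's sigmoid output `p` rescaled as `2 * (p - 0.5)`, which maps [0, 1] to [-1, 1]. Because the mapping is linear, the average score converts back to an average predicted probability of positive sentiment. A small check using the two averages printed in this notebook:

```python
def to_score(p):
    """Rescale a probability in [0, 1] to a score in [-1, 1]."""
    return 2.0 * (p - 0.5)

def to_probability(score):
    """Inverse of to_score."""
    return score / 2.0 + 0.5

# Average scores copied from the validation outputs before and after PPO.
avg_before = 0.2491551961892895
avg_after = 0.8178268348269598

p_before = to_probability(avg_before)
p_after = to_probability(avg_after)
print(f"mean RM probability before PPO: {p_before:.3f}")
print(f"mean RM probability after PPO:  {p_after:.3f}")
```

Keep in mind that this measures agreement with the learned reward model, not text quality as such.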


### 3.7 Save PPO model

In [None]:
torch.save(ppo_model.state_dict(), 'ppo_model_epoch_1.pt')

---
## 4) Inference

In [None]:
# Load the trained PPO model
ppo_model_inference = ModelForCausalLMWithValueHead(model_name).to(device)
ppo_model_inference.load_state_dict(torch.load('ppo_model_epoch_1.pt', map_location=device))
ppo_model_inference.eval()

Now you can test the model with a prompt.

In [None]:
prompt = "This movie was"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

# Generate a response
output_sequences = ppo_model_inference.generate(
    input_ids=input_ids,
    max_length=50,  # Adjust the max_length as needed
    num_return_sequences=1,
    **{k: v for k, v in generation_kwargs.items() if k != 'max_new_tokens'}
)

generated_text = tokenizer.decode(output_sequences[0], skip_special_tokens=True)
print("Generated text:")
print(generated_text)

Generated text:
This movie was a worth getting its foot in the door            .                           


Here are the responses from the Base LLM and the PPO fine-tuned model for comparison:

In [None]:
# Load the original base LLM
base_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
base_model.eval()

prompt = "This movie was"
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(device)

# Generate a response from the Base LLM
base_output_sequences = base_model.generate(
    input_ids=input_ids,
    max_length=50,
    num_return_sequences=1,
    **{k: v for k, v in generation_kwargs.items() if k != 'max_new_tokens'}
)
base_generated_text = tokenizer.decode(base_output_sequences[0], skip_special_tokens=True)
print("Generated text from Base LLM:")
print(base_generated_text)

print("-" * 20) # Separator

# Generate a response from the PPO fine-tuned model
ppo_output_sequences = ppo_model_inference.generate(
    input_ids=input_ids,
    max_length=50,
    num_return_sequences=1,
    **{k: v for k, v in generation_kwargs.items() if k != 'max_new_tokens'}
)
ppo_generated_text = tokenizer.decode(ppo_output_sequences[0], skip_special_tokens=True)
print("Generated text from PPO Model:")
print(ppo_generated_text)

Generated text from Base LLM:
This movie was based off a conversation I had with an ironic Texas mom, who took a hiatus from motivational speaker-like talking about her awful system. After talking to my husband about future happiness, I realized he didn't believe my example, because "
--------------------
Generated text from PPO Model:
This movie was , as its detractors , delightful .                       .                 
