This code is based on the paper by Stiennon et al. (2020), "Learning to summarize from human feedback", which shows how to use PPO to fine-tune a language model with human feedback for a summarization task.

# IMPORTS 

In [None]:
from datasets import load_dataset
import os, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW

# PARAMETERS

In [None]:
kl_beta = 0.5  # weight of the per-token KL penalty (policy vs. reference) in the PPO reward
batch_size = 10  # number of dataset rows selected with `ppo_ds.select(range(batch_size))`
epoch_nb = 5  # number of PPO optimisation passes over each sampled summary
clip_range = 0.2  # clipping epsilon applied to the PPO probability ratio
return_sequence_nb = 4  # number of candidate summaries requested from `generate` per prompt
beta = 0.5  # weight of the critic (value) loss in the total PPO loss

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision when CUDA is available, full precision otherwise.
# NOTE(review): in the visible code `dtype` is only printed, never passed to a model.
dtype  = torch.float16 if torch.cuda.is_available() else torch.float32

def set_seed(seed = 456):
    """Seed every random number generator used here with the same value.

    Covers Python's `random`, NumPy, PyTorch on CPU and PyTorch on all CUDA devices.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

# Report which device and dtype were selected above.
print("Device:", device, "| dtype:", dtype)

# LOADING SFT MODEL

In [None]:
model_name = "CarperAI/summarize_from_feedback_sft"

# Tokenizer of the SFT checkpoint; fall back to EOS as padding token when none is defined.
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Two copies of the same checkpoint: the policy that PPO trains and the reference
# used for the KL penalty. Both are placed on the selected device.
policy_model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
policy_model_ref = AutoModelForCausalLM.from_pretrained(model_name).to(device)

In [None]:
# The critic backbone is later overwritten with the reward model's LM weights
# (`critic_model.base.load_state_dict(reward_model.lm.state_dict())`), so it must be the
# same architecture as the SFT checkpoint. "gpt3" is not a loadable checkpoint id.
critic_model_name = "CarperAI/summarize_from_feedback_sft"  # keep in sync with `model_name`


class CriticModel(nn.Module):
    """Value network: a causal-LM backbone followed by a linear head.

    The head maps each position's final hidden state to one scalar, so `forward`
    returns one value estimate per token.
    """

    def __init__(self, base_model_name):
        """Load the backbone from `base_model_name` and create the value head."""
        super().__init__()
        self.base = AutoModelForCausalLM.from_pretrained(base_model_name)
        # Value head: turns a hidden state into a scalar score.
        # `hidden_size` is the common config attribute (also used by RewardModel);
        # `n_embd` only exists on GPT-2-style configs.
        self.v_head = nn.Linear(self.base.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return per-token value estimates.

        Args:
            input_ids: token ids passed to the backbone.
            attention_mask: attention mask passed to the backbone.

        Returns:
            Tensor with the trailing size-1 dimension of the head output squeezed away
            (one value per position of the backbone's last hidden state).
        """
        outputs = self.base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        # Use the last layer's hidden states.
        last_hidden_state = outputs.hidden_states[-1]
        values = self.v_head(last_hidden_state).squeeze(-1)
        return values

# LOADING REDDIT DATASET

In [None]:
# Load the "comparisons" configuration of the openai/summarize_from_feedback dataset.
ds = load_dataset("openai/summarize_from_feedback", "comparisons")

def build_prompt(example):
    """Wrap a post in the summarization instruction template.

    Args:
        example: mapping holding the post under "prompt" (the key produced by
            `format_dpo`) or, for backward compatibility, under "text".

    Returns:
        The prompt string ending in "Summary:".

    Raises:
        KeyError: if the mapping has neither a "prompt" nor a "text" key.
    """
    # Rows of `ppo_ds` only carry "prompt"/"chosen"/"rejected"; reading "text"
    # unconditionally raised KeyError for them.
    post = example["prompt"] if "prompt" in example else example["text"]
    return f"Summarize the following text:\n{post}\nSummary:"

In [None]:
def format_dpo(dataset):
    """Convert comparison rows into prompt / chosen / rejected triples.

    Each row's post becomes "prompt", the summary selected by "choice" becomes
    "chosen" and the other summary becomes "rejected" (all stripped of surrounding
    whitespace). The original columns are removed.
    """

    def _to_preference_row(row):
        preferred = row["choice"]
        candidates = row["summaries"]
        return {
            "prompt": row["info"]["post"].strip(),
            "chosen": candidates[preferred]["text"].strip(),
            "rejected": candidates[1 - preferred]["text"].strip(),
        }

    original_columns = dataset.column_names
    return dataset.map(_to_preference_row, remove_columns=original_columns)

# Preference triples (prompt / chosen / rejected) built from the training split.
ppo_ds = format_dpo(ds["train"])

# TRAINING REWARD MODEL

In [None]:
class RewardModel(nn.Module):
    """Scalar reward model: a causal-LM backbone plus a linear reward head.

    The reward of a sequence is read from the final-layer hidden state of its last
    attended token.
    """

    def __init__(self, base_lm):
        """Wrap `base_lm` and create a head sized from `base_lm.config.hidden_size`."""
        super().__init__()
        self.lm = base_lm
        self.reward_head = nn.Linear(base_lm.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one scalar reward per sequence.

        Args:
            input_ids: token ids passed to the backbone.
            attention_mask: 1 for real tokens, 0 for padding; when None the last
                position of every row is used.

        Returns:
            Tensor holding one reward per row of `input_ids`.
        """
        outputs = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-1]
        n_rows, seq_len = hidden.size(0), hidden.size(1)

        if attention_mask is None:
            last_idx = torch.full((n_rows,), seq_len - 1, dtype=torch.long, device=hidden.device)
        else:
            # Pool at the last attended token rather than at position -1: in a padded
            # batch position -1 is a pad token for every shorter sequence.
            # Multiplying by the position index works for left and right padding.
            positions = torch.arange(seq_len, device=hidden.device)
            last_idx = (attention_mask.to(hidden.device).long() * positions).argmax(dim=1)

        row_idx = torch.arange(n_rows, device=hidden.device)
        last_hidden = hidden[row_idx, last_idx]
        return self.reward_head(last_hidden).squeeze(-1)


# The reward model's backbone is loaded from the same checkpoint name as the policy.
# NOTE(review): the model is not moved to `device` here — confirm where it should live.
reward_model = RewardModel(
    AutoModelForCausalLM.from_pretrained(model_name)
)

In [None]:
# Training the reward model on pairwise human preferences.
optimizer = AdamW(reward_model.parameters(), lr=1e-5)

# Send the inputs to whichever device currently holds the reward model.
rm_device = next(reward_model.parameters()).device

# Each iteration processes ONE comparison (a dataset row), not a mini-batch.
# NOTE(review): prompt and summary are concatenated without a separator or the
# instruction template used at PPO time — confirm this is intended.
for batch in ppo_ds.select(range(batch_size)):
    optimizer.zero_grad()

    inputs_w = tokenizer(batch['prompt'] + batch['chosen'], return_tensors='pt', truncation=True, padding=True).to(rm_device)
    inputs_l = tokenizer(batch['prompt'] + batch['rejected'], return_tensors='pt', truncation=True, padding=True).to(rm_device)

    # Pass the two tensors explicitly so extra tokenizer outputs (e.g. token_type_ids)
    # cannot break the call.
    r_w = reward_model(input_ids=inputs_w['input_ids'], attention_mask=inputs_w['attention_mask'])
    r_l = reward_model(input_ids=inputs_l['input_ids'], attention_mask=inputs_l['attention_mask'])

    # Pairwise preference loss -log(sigmoid(r_w - r_l)), based on paper page 6.
    # logsigmoid avoids log(0) = -inf when the margin is strongly negative.
    loss = -F.logsigmoid(r_w - r_l).mean()
    loss.backward()
    optimizer.step()

In [None]:
# Build the critic, copy the reward model's LM weights into its backbone,
# then move it to the selected device.
critic_model = CriticModel(critic_model_name)
reward_backbone_weights = reward_model.lm.state_dict()
critic_model.base.load_state_dict(reward_backbone_weights)
critic_model.to(device)

# TRAINING PPO

In [None]:
def get_log_probs(logits, labels):
    """Return the log-probability assigned to each label token by the preceding position.

    The logits at position t are scored against the label at position t + 1, so the
    result has one column fewer than `labels`.
    """
    # Drop the last position: it has no following label to predict.
    predictive_log_probs = F.log_softmax(logits[:, :-1, :], dim=-1)
    # Drop the first label: nothing precedes it.
    next_tokens = labels[:, 1:].unsqueeze(-1)
    picked = predictive_log_probs.gather(dim=-1, index=next_tokens)
    return picked.squeeze(-1)

In [None]:
def compute_gae(values, rewards, mask, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation with masking of invalid positions.

    Args:
        values: value estimates, same shape as `rewards` (rows x time steps).
        rewards: per-step rewards (floating point).
        mask: 1 for valid steps, 0 for steps to ignore (e.g. padding); None is
            treated as all ones. With an all-ones mask the result equals the
            unmasked recursion.
        gamma: discount factor.
        lam: GAE lambda.

    Returns:
        `(advantages, returns)` with `returns = advantages + values`. Advantages
        are 0 at masked steps.
    """
    if mask is None:
        mask = torch.ones_like(rewards)
    mask = mask.to(rewards.dtype)

    n_steps = rewards.size(1)
    advantages = torch.zeros_like(rewards)
    last_gae = torch.zeros_like(rewards[:, 0])
    for t in reversed(range(n_steps)):
        # For summarization the reward is usually sparse (placed at the end).
        if t + 1 < n_steps:
            # A masked successor is terminal: it contributes no bootstrap value.
            next_value = values[:, t + 1] * mask[:, t + 1]
        else:
            next_value = torch.zeros_like(values[:, t])
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        # Zeroing at masked steps stops padding from leaking into earlier advantages.
        last_gae = (delta + gamma * lam * last_gae) * mask[:, t]
        advantages[:, t] = last_gae
    returns = advantages + values
    return advantages, returns

In [None]:
def get_response_mask(full_seq, prompt_len, pad_token_id=None):
    """Build a mask selecting the generated (non-prompt, non-padding) tokens.

    Args:
        full_seq: token ids of prompt + response, rows x sequence length.
        prompt_len: number of leading positions belonging to the prompt.
        pad_token_id: id treated as padding; defaults to the module-level
            tokenizer's `pad_token_id`. If that is also None, no padding is masked.

    Returns:
        Tensor shaped and typed like `full_seq`: 1 on response tokens, 0 elsewhere.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    mask = torch.ones_like(full_seq)
    mask[:, :prompt_len] = 0
    if pad_token_id is not None:
        mask[full_seq == pad_token_id] = 0
    return mask

In [None]:
# From here on the reward model is only used for scoring; it must sit on the same
# device as the sequences produced by the policy.
reward_model.eval()
reward_model.to(device)

# Dedicated optimizer for the PPO phase. The `optimizer` created for reward-model
# training only holds the reward model's parameters, so stepping it would never
# update the policy or the critic.
ppo_optimizer = AdamW(
    list(policy_model.parameters()) + list(critic_model.parameters()),
    lr=1e-5,
)

for batch in ppo_ds.select(range(batch_size)):
    # `ppo_ds` rows store the post under "prompt"; hand it to build_prompt under
    # the "text" key that it reads.
    prompt = build_prompt({"text": batch["prompt"]})
    inputs = tokenizer(prompt, return_tensors='pt', padding=True).to(device)
    prompt_len = inputs['input_ids'].size(1)

    with torch.no_grad():
        # Sample several candidates. Sampling is needed both for PPO exploration and
        # because greedy decoding does not support num_return_sequences > 1.
        summaries = policy_model.generate(
            inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            max_new_tokens=64,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=return_sequence_nb,
        )

        # Score every candidate and keep the best one (best-of-n).
        masks = (summaries != tokenizer.pad_token_id).long()
        rm_scores = reward_model(summaries, attention_mask=masks)
        best_idx = rm_scores.argmax()
        summary = summaries[best_idx].unsqueeze(0)  # (1, seq_len)
        rm_score = rm_scores[best_idx].reshape(1)   # reward of the kept summary

        mask = (summary != tokenizer.pad_token_id).long()

        # Log-prob / value index t refers to the prediction of token t + 1, hence
        # the one-step shift of the response mask.
        response_mask = get_response_mask(summary, prompt_len)
        response_mask_shifted = response_mask[:, 1:]
        n_response_tokens = response_mask_shifted.sum()
        if n_response_tokens.item() == 0:
            # Nothing was generated; there is nothing to optimise for this prompt.
            continue

        old_logits = policy_model(summary, attention_mask=mask).logits
        ref_logits = policy_model_ref(summary, attention_mask=mask).logits

        # Log-probabilities of the sampled tokens under the policy and the reference.
        old_logprob = get_log_probs(old_logits, summary)
        ref_logprob = get_log_probs(ref_logits, summary)

        resp = response_mask_shifted.to(old_logprob.dtype)

        # Per-token KL penalty, restricted to response tokens.
        kl_dist = (old_logprob - ref_logprob) * resp
        rewards = -kl_beta * kl_dist

        # The reward-model score is granted once, on the last response token.
        # Its index is absolute in the shifted frame, not a count of response tokens.
        positions = torch.arange(response_mask_shifted.size(1), device=summary.device)
        last_idx = (response_mask_shifted * positions).argmax(dim=1)
        rewards[0, last_idx] += rm_score

        # Advantages. Values outside the response are zeroed so that the step after
        # the last response token acts as a terminal state with value 0.
        values = critic_model(summary, attention_mask=mask)[:, :-1] * resp
        advantages, returns = compute_gae(values, rewards, response_mask_shifted)

        # Normalise over response tokens only. The population variance keeps the
        # result finite when the response has a single token.
        adv_mean = (advantages * resp).sum() / n_response_tokens
        adv_var = (((advantages - adv_mean) ** 2) * resp).sum() / n_response_tokens
        advantages = (advantages - adv_mean) / (adv_var.sqrt() + 1e-8)
        advantages = advantages.detach()
        returns = returns.detach()
        mean_kl = kl_dist.sum() / n_response_tokens

    # Actual PPO update
    for _ in range(epoch_nb):
        logits = policy_model(summary, attention_mask=mask).logits
        logprob = get_log_probs(logits, summary)
        value = critic_model(summary, attention_mask=mask)[:, :-1]

        # Probability ratio between the current and the sampling policy.
        ratio = torch.exp(logprob - old_logprob)

        # Clipped surrogate policy loss, averaged over response tokens.
        clip1 = ratio * advantages
        clip2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
        policy_loss = -(torch.min(clip1, clip2) * resp).sum() / n_response_tokens

        # Critic regression towards the GAE returns, averaged over response tokens.
        value_loss = (F.mse_loss(value, returns, reduction='none') * resp).sum() / n_response_tokens

        loss = policy_loss + beta * value_loss

        ppo_optimizer.zero_grad()
        loss.backward()
        ppo_optimizer.step()

        print(f"Loss: {loss.item():.4f} | RM Score: {rm_score.mean().item():.4f} | KL: {mean_kl.item():.4f}")