This code is based on the paper by Stiennon et al. (2020), "Learning to summarize from human feedback", which shows how to use PPO to train a language model (LM) from human feedback on a summarization task.

# IMPORTS 

In [None]:
from datasets import load_dataset
import os, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW

# PARAMETERS

In [None]:
# PPO / RLHF hyper-parameters
kl_beta = 0.5            # weight of the per-token KL penalty against the reference policy
beta = 0.5               # weight of the value (critic) loss in the combined PPO loss
clip_range = 0.2         # PPO ratio clipping epsilon
epoch_nb = 5             # number of PPO optimisation passes over each collected sample
batch_size = 10          # number of dataset examples iterated over during training
return_sequence_nb = 4   # candidate summaries sampled per prompt

In [None]:
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.float16
else:
    device = "cpu"
    dtype = torch.float32

def set_seed(seed = 456):
    """Seed the Python, NumPy and PyTorch (CPU and every CUDA device) RNGs with `seed`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

print("Device:", device, "| dtype:", dtype)

# LOADING SFT MODEL

In [None]:
model_name = "CarperAI/summarize_from_feedback_sft"
tokenizer = AutoTokenizer.from_pretrained(model_name)
if tokenizer.pad_token is None:
    # No pad token defined: reuse EOS so padding / masking code has an id to use.
    tokenizer.pad_token = tokenizer.eos_token

# Policy: the model optimised by PPO.
policy_model = AutoModelForCausalLM.from_pretrained(model_name)
# Reference: a copy of the SFT model used only for the KL penalty. It must never
# change, so freeze it explicitly (also avoids building autograd state for it).
policy_model_ref = AutoModelForCausalLM.from_pretrained(model_name)
policy_model_ref.requires_grad_(False)
policy_model_ref.eval()

policy_model.to(device)
policy_model_ref.to(device)

In [None]:
# The critic's backbone is later overwritten with the trained reward model's LM
# weights (load_state_dict), so it must share the SFT architecture; the previous
# value "gpt3" does not name that checkpoint.
critic_model_name = model_name

class CriticModel(nn.Module):
    """Value network: a causal-LM backbone plus a linear head giving one scalar per position."""

    def __init__(self, base_model_name):
        super().__init__()
        self.base = AutoModelForCausalLM.from_pretrained(base_model_name)
        # Value head: maps each final-layer hidden state to a scalar score.
        # `hidden_size` is the architecture-agnostic config field (RewardModel uses it too).
        self.v_head = nn.Linear(self.base.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        """Return one value per input position.

        Presumably shaped (batch, seq_len), given HF hidden states of shape
        (batch, seq_len, hidden) — the trailing singleton dim is squeezed.
        """
        outputs = self.base(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        # Use the last hidden layer.
        last_hidden_state = outputs.hidden_states[-1]
        values = self.v_head(last_hidden_state).squeeze(-1)
        return values

# LOADING REDDIT DATASET

In [None]:
# Human preference comparisons (pairs of summaries with a preferred choice).
ds = load_dataset(path="openai/summarize_from_feedback", name="comparisons")

def build_prompt(example):
    """Wrap a post in the summarisation instruction used for generation.

    ``example`` is a mapping holding the post either under ``"text"`` or, as
    in the rows produced by ``format_dpo``, under ``"prompt"``. ``"text"``
    takes precedence when both keys are present; a KeyError is raised when
    neither is.
    """
    post = example["text"] if "text" in example else example["prompt"]
    return f"Summarize the following text:\n{post}\nSummary:"

In [None]:
def format_dpo(dataset):
    """Flatten preference comparisons into ``prompt`` / ``chosen`` / ``rejected`` text columns.

    ``choice`` indexes the preferred entry of ``summaries``; ``1 - choice`` is
    taken as the rejected one (so exactly two candidates are assumed). All
    original columns are dropped.
    """
    def _to_pair(row):
        preferred = row["choice"]
        candidates = row["summaries"]
        return {
            "prompt": row["info"]["post"].strip(),
            "chosen": candidates[preferred]["text"].strip(),
            "rejected": candidates[1 - preferred]["text"].strip(),
        }

    return dataset.map(_to_pair, remove_columns=dataset.column_names)

ppo_ds = format_dpo(ds["train"])

# TRAINING REWARD MODEL

In [None]:
class RewardModel(nn.Module):
    """Scalar reward model: an LM backbone whose last real token's hidden state feeds a linear head."""

    def __init__(self, base_lm):
        super().__init__()
        self.lm = base_lm
        self.reward_head = nn.Linear(base_lm.config.hidden_size, 1)

    def forward(self, input_ids, attention_mask=None):
        """Return one reward per sequence.

        Args:
            input_ids: Token ids indexed (batch, position). They are moved to
                the reward head's device, so callers may pass CPU or GPU tensors.
            attention_mask: 1 for real tokens, 0 for padding. ``None`` means
                every position is real.

        Returns:
            Tensor of shape (batch,).
        """
        head_device = self.reward_head.weight.device
        input_ids = input_ids.to(head_device)
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        attention_mask = attention_mask.to(head_device)

        outputs = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        last_layer = outputs.hidden_states[-1]

        # Pool at the LAST NON-PADDING position rather than at index -1: with
        # right/EOS padding, index -1 is a pad token. The largest position
        # whose mask is 1 has the unique largest (position * mask) value, which
        # works for left and right padding alike.
        positions = torch.arange(attention_mask.size(-1), device=head_device)
        last_token_idx = (positions * attention_mask).argmax(dim=-1)
        batch_idx = torch.arange(last_layer.size(0), device=head_device)
        pooled = last_layer[batch_idx, last_token_idx]
        return self.reward_head(pooled).squeeze(-1)


# Reward model: a fresh copy of the SFT language model with a scalar head on top.
reward_base_lm = AutoModelForCausalLM.from_pretrained(model_name)
reward_model = RewardModel(reward_base_lm)

In [None]:
# Training the reward model with the pairwise preference loss (paper, p. 6):
#   loss = -log(sigmoid(r(x, y_chosen) - r(x, y_rejected)))
reward_optimizer = AdamW(reward_model.parameters(), lr=1e-5)
# Tokenized inputs must live on the same device as the reward model.
reward_device = next(reward_model.parameters()).device


def _rm_text(example, summary):
    """Render (post, summary) in the same layout as the PPO-time input.

    That layout is the build_prompt instruction template followed by the summary.
    """
    return f"Summarize the following text:\n{example['prompt']}\nSummary: {summary}"


# NOTE(review): trains on only `batch_size` single examples for one pass;
# right-side truncation could cut the summary off very long posts — confirm
# this is acceptable.
for example in ppo_ds.select(range(batch_size)):
    reward_optimizer.zero_grad()

    inputs_w = tokenizer(_rm_text(example, example['chosen']), return_tensors='pt', truncation=True).to(reward_device)
    inputs_l = tokenizer(_rm_text(example, example['rejected']), return_tensors='pt', truncation=True).to(reward_device)

    r_w = reward_model(**inputs_w)
    r_l = reward_model(**inputs_l)

    # logsigmoid is the numerically stable form of log(sigmoid(.)); the naive
    # form returns -inf once sigmoid underflows to 0.
    loss = -F.logsigmoid(r_w - r_l).mean()
    loss.backward()
    reward_optimizer.step()

In [None]:
# Initialise the value function from the trained reward model (paper: the
# value function starts from the reward model's parameters) — backbone and head.
critic_model = CriticModel(critic_model_name)
critic_model.base.load_state_dict(reward_model.lm.state_dict())
critic_model.v_head.load_state_dict(reward_model.reward_head.state_dict())
critic_model.to(device)
critic_optimizer = torch.optim.AdamW(critic_model.parameters(), lr=1e-5)
# The policy optimizer must wrap the POLICY's parameters; built over the
# critic's parameters, PPO would never update the policy.
policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)

# TRAINING PPO

In [None]:
def get_log_probs(logits, labels):
    """Return the log-probability each position assigns to the token that follows it.

    ``logits[:, t]`` is scored against ``labels[:, t + 1]``, so the result has
    one position fewer than the inputs.
    """
    # The final position predicts a token beyond the sequence, so drop it.
    shifted_logits = logits[:, :-1, :]
    next_tokens = labels[:, 1:].unsqueeze(-1)
    token_log_probs = F.log_softmax(shifted_logits, dim=-1).gather(-1, next_tokens)
    return token_log_probs.squeeze(-1)

In [None]:
def compute_gae(values, rewards, mask=None, gamma=1.0, lam=0.95):
    """Generalized Advantage Estimation over per-position rewards.

    Args:
        values: Critic values indexed (batch, position).
        rewards: Per-position rewards, same shape as ``values``. For
            summarisation the reward is typically sparse (at the end).
        mask: 1 for positions that belong to the response, 0 for prompt or
            padding positions. Masked positions get zero value and reward
            before the backward recursion, so the last response token never
            bootstraps from a padding value; advantages and returns are zeroed
            there afterwards. ``None`` treats every position as valid.
        gamma: Discount factor.
        lam: GAE smoothing factor (lambda).

    Returns:
        (advantages, returns), both shaped like ``rewards``.
    """
    if mask is None:
        mask = torch.ones_like(rewards)
    mask = mask.to(rewards.dtype)
    values = values * mask
    rewards = rewards * mask

    advantages = torch.zeros_like(rewards)
    last_gae = 0.0
    seq_len = rewards.size(1)
    for t in reversed(range(seq_len)):
        next_value = values[:, t + 1] if t + 1 < seq_len else 0.0
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        last_gae = delta + gamma * lam * last_gae
        advantages[:, t] = last_gae
    returns = (advantages + values) * mask
    advantages = advantages * mask
    return advantages, returns

In [None]:
def get_response_mask(full_seq, prompt_len, pad_token_id=None):
    """Mark the generated (response) tokens of prompt+response sequences.

    Args:
        full_seq: Integer tensor of token ids indexed (batch, position).
        prompt_len: Number of leading prompt positions shared by every row.
        pad_token_id: Id treated as padding. Defaults to the module-level
            ``tokenizer.pad_token_id``. If pad is aliased to EOS, the
            terminating EOS is masked out as well.

    Returns:
        Tensor like ``full_seq``: 1 for response tokens, 0 for prompt/padding.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    mask = torch.ones_like(full_seq)
    mask[:, :prompt_len] = 0
    mask[full_seq == pad_token_id] = 0
    return mask

# WIN RATE EVALUATION

In [None]:
judge_name = "Qwen/Qwen2.5-3B-Instruct"
judge_tokenizer = AutoTokenizer.from_pretrained(judge_name)

# Judge weights: half precision, automatic device placement, low peak CPU RAM while loading.
judge_load_kwargs = {
    "device_map": "auto",
    "torch_dtype": torch.float16,
    "low_cpu_mem_usage": True,
}
judge_model = AutoModelForCausalLM.from_pretrained(judge_name, **judge_load_kwargs)
judge_model.eval()

In [None]:
# Judge prompt consumed by evaluate_periodic_win_rate. The placeholders {post},
# {summary_a} and {summary_b} are filled with str.replace (not str.format), so
# literal braces in posts are harmless. Every line after the first carries a
# 4-space indent, which is part of the prompt text sent to the judge.
prompt_template = """
You are an evaluator.

    Given the post and two summaries, choose which summary is better.

    Reply with exactly ONE character:
    A or B

    Do not explain.
    Do not output anything else.

    Post:
    {post}

    Summary A:
    {summary_a}

    Summary B:
    {summary_b}
    """

In [None]:
import torch
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM

def evaluate_periodic_win_rate(
    policy_model: torch.nn.Module,
    ref_model: torch.nn.Module,
    judge_model: torch.nn.Module,
    judge_tokenizer: AutoTokenizer,
    dataset: list,
    prompt_template: str,
    device: str = "cuda",
    max_new_tokens: int = 64,
    temperature: float = 0.25,
    top_p: float = 0.9,
    subset_size: int = 100,
    verbose: bool = True,
    policy_tokenizer=None,
    seed: int = 0,
) -> float:
    """Estimate the policy's win rate against the reference model with an LLM judge.

    Both models summarise each of the first ``subset_size`` posts. The judge
    then picks the better summary. The A/B order is randomised per example
    (seeded) so that a judge position bias does not favour either model.

    Args:
        policy_model: Policy model (PPO/DPO) being evaluated.
        ref_model: Reference model (e.g. the SFT model).
        judge_model: Model used as judge.
        judge_tokenizer: Tokenizer of the judge model. Its chat template is
            applied when it defines one.
        dataset: Indexable collection of rows with a ``"prompt"`` key. Either a
            list of dicts or a Hugging Face ``Dataset`` works, since both
            support ``dataset[i]``.
        prompt_template: Judge prompt with ``{post}``, ``{summary_a}`` and
            ``{summary_b}`` placeholders.
        device: Device the tokenized inputs are moved to.
        max_new_tokens: Maximum number of tokens generated per summary.
        temperature: Sampling temperature for the summaries.
        top_p: Nucleus-sampling parameter for the summaries.
        subset_size: Maximum number of examples evaluated.
        verbose: If True, show progress bars and print the metrics.
        policy_tokenizer: Tokenizer shared by policy_model and ref_model.
            Defaults to the module-level ``tokenizer``.
        seed: Seed of the A/B order randomisation.

    Returns:
        float: Fraction of evaluated examples the judge awarded to the policy.
        Abstentions count as non-wins. Returns 0.0 when nothing was evaluated.
    """
    import random

    if policy_tokenizer is None:
        # The policy/reference models must receive ids from their own
        # vocabulary; the judge's tokenizer may use a different one.
        policy_tokenizer = tokenizer

    @torch.no_grad()
    def generate_summary(model, tok, prompts):
        """Sample one summary per prompt and return only the newly generated text."""
        enc = tok(
            prompts,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            padding=True,
        ).to(device)

        out = model.generate(
            **enc,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tok.eos_token_id,
            eos_token_id=tok.eos_token_id,
        )

        # generate() returns prompt + continuation. Every row starts with the
        # full (padded) prompt width, so the continuation begins there
        # regardless of the padding side.
        prompt_width = enc["input_ids"].shape[1]
        return [tok.decode(row[prompt_width:], skip_special_tokens=True).strip() for row in out]

    @torch.no_grad()
    def judge_pairwise(formatted_prompt):
        """Ask the judge which summary is better; return "A", "B", or None if unparseable."""
        used_chat_template = bool(getattr(judge_tokenizer, "chat_template", None))
        if used_chat_template:
            # Instruction-tuned judges answer reliably only in their chat format.
            formatted_prompt = judge_tokenizer.apply_chat_template(
                [{"role": "user", "content": formatted_prompt}],
                tokenize=False,
                add_generation_prompt=True,
            )
        inputs = judge_tokenizer(
            formatted_prompt,
            return_tensors="pt",
            truncation=True,
            max_length=2048,
            # The chat template already inserts the special tokens.
            add_special_tokens=not used_chat_template,
        ).to(device)

        # Greedy decoding; temperature is irrelevant (and warned about) without sampling.
        outputs = judge_model.generate(
            **inputs,
            max_new_tokens=2,
            do_sample=False,
            eos_token_id=judge_tokenizer.eos_token_id,
            pad_token_id=judge_tokenizer.eos_token_id,
        )

        gen = outputs[0][inputs["input_ids"].shape[1]:]
        text = judge_tokenizer.decode(gen, skip_special_tokens=True).strip().upper()
        if text.startswith("A"):
            return "A"
        if text.startswith("B"):
            return "B"
        return None

    # Generate summaries with both models (in eval mode, restored afterwards).
    records = []  # (post, policy_summary, ref_summary)
    previous_modes = (policy_model.training, ref_model.training)
    policy_model.eval()
    ref_model.eval()
    try:
        n_examples = min(subset_size, len(dataset))
        for idx in tqdm(range(n_examples), desc="Génération des résumés", disable=not verbose):
            # Integer indexing yields one row for lists and HF Datasets alike;
            # slicing an HF Dataset would return a dict of columns instead.
            post = dataset[idx]["prompt"].strip()
            if not post:
                continue
            ref_summary = generate_summary(ref_model, policy_tokenizer, [post])[0]
            policy_summary = generate_summary(policy_model, policy_tokenizer, [post])[0]
            records.append((post, policy_summary, ref_summary))
    finally:
        policy_model.train(previous_modes[0])
        ref_model.train(previous_modes[1])

    # Judge the pairs.
    rng = random.Random(seed)
    wins = 0
    losses = 0
    abstain = 0
    total = len(records)

    for post, policy_summary, ref_summary in tqdm(records, desc="Évaluation des win rates", disable=not verbose):
        policy_first = rng.random() < 0.5
        if policy_first:
            summary_a, summary_b = policy_summary, ref_summary
        else:
            summary_a, summary_b = ref_summary, policy_summary
        judge_prompt = (
            prompt_template.replace("{post}", post)
            .replace("{summary_a}", summary_a)
            .replace("{summary_b}", summary_b)
        )

        choice = judge_pairwise(judge_prompt)

        if choice is None:
            abstain += 1
        elif (choice == "A") == policy_first:
            wins += 1
        else:
            losses += 1

    real_win_rate = wins / total if total > 0 else 0.0
    conditional_win_rate = wins / (wins + losses) if (wins + losses) > 0 else 0.0
    decision_rate = (wins + losses) / total if total > 0 else 0.0

    if verbose:
        print(f"\n=== ÉVALUATION PÉRIODIQUE ===")
        print(f"Win rate:               {real_win_rate:.3f}")
        print(f"Conditional win rate:   {conditional_win_rate:.3f}")
        print(f"Judge decision rate:    {decision_rate:.3f}")
        print(f"Abstentions:            {abstain}/{total}")

    return real_win_rate


# FINAL PIPELINE

In [None]:
reward_model.eval()

eval_every = 5          # run the judge-based win-rate evaluation every `eval_every` prompts
eval_subset_size = 100  # number of examples the judge evaluates each time


def _last_true_index(mask):
    """Return, for each row of a 2-D 0/1 mask, the index of its last non-zero entry (0 if none)."""
    positions = torch.arange(mask.size(-1), device=mask.device)
    return (positions * mask).argmax(dim=-1)


# The reward model may live on a different device than the policy.
reward_device = next(reward_model.parameters()).device

for step, example in enumerate(ppo_ds.select(range(batch_size))):
    # build_prompt reads the post from the "text" key; ppo_ds rows store it under "prompt".
    prompt = build_prompt({"text": example["prompt"]})
    inputs = tokenizer(prompt, return_tensors='pt').to(device)
    prompt_len = inputs['input_ids'].size(1)

    with torch.no_grad():
        # Sample several candidates (sampling is required for num_return_sequences > 1).
        summaries = policy_model.generate(
            **inputs,
            max_new_tokens=64,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
            num_return_sequences=return_sequence_nb,
        )

        # Score every candidate with the reward model and keep the best one.
        # NOTE(review): optimising the best-of-N sample means the data are not
        # drawn from the current policy alone, so the PPO ratio is approximate —
        # confirm this is intended.
        masks = (summaries != tokenizer.pad_token_id).long()
        rm_scores = reward_model(
            summaries.to(reward_device), attention_mask=masks.to(reward_device)
        ).to(device)
        best_idx = rm_scores.argmax()
        summary = summaries[best_idx].unsqueeze(0)  # (1, seq_len)
        rm_score = rm_scores[best_idx]

        mask = (summary != tokenizer.pad_token_id).long()

        # Response mask aligned with get_log_probs outputs (position t scores token t + 1).
        response_mask_shifted = get_response_mask(summary, prompt_len)[:, 1:].float()
        n_tokens = response_mask_shifted.sum()
        if n_tokens == 0:
            continue  # empty completion: nothing to optimise

        old_logits = policy_model(summary, attention_mask=mask).logits
        ref_logits = policy_model_ref(summary, attention_mask=mask).logits

        # Per-token log-probabilities under the behaviour and reference policies.
        old_logprob = get_log_probs(old_logits, summary)
        ref_log_probs = get_log_probs(ref_logits, summary)

        # Reward: per-token KL penalty on response tokens, plus the (sparse)
        # reward-model score granted only at the last response token.
        kl_dist = (old_logprob - ref_log_probs) * response_mask_shifted
        rewards = -kl_beta * kl_dist
        last_idx = _last_true_index(response_mask_shifted)
        rewards[torch.arange(rewards.size(0), device=rewards.device), last_idx] += rm_score

        # Masking values and rewards keeps GAE from bootstrapping off prompt/padding positions.
        values = critic_model(summary, attention_mask=mask)[:, :-1] * response_mask_shifted
        advantages, returns = compute_gae(values, rewards, response_mask_shifted)

        # Whiten advantages using response tokens only.
        adv_mean = (advantages * response_mask_shifted).sum() / n_tokens
        adv_var = ((advantages - adv_mean) ** 2 * response_mask_shifted).sum() / n_tokens
        advantages = ((advantages - adv_mean) / (adv_var.sqrt() + 1e-8) * response_mask_shifted).detach()
        returns = returns.detach()

    # PPO optimisation passes over the collected sample.
    for _ in range(epoch_nb):
        logits = policy_model(summary, attention_mask=mask).logits
        logprob = get_log_probs(logits, summary)
        value = critic_model(summary, attention_mask=mask)[:, :-1]

        # Probability ratio between the current and behaviour policies.
        ratio = torch.exp(logprob - old_logprob)

        # Clipped surrogate policy loss, averaged over response tokens.
        clip1 = ratio * advantages
        clip2 = torch.clamp(ratio, 1 - clip_range, 1 + clip_range) * advantages
        policy_loss = -(torch.min(clip1, clip2) * response_mask_shifted).sum() / n_tokens

        value_loss = (F.mse_loss(value, returns, reduction='none') * response_mask_shifted).sum() / n_tokens

        loss = policy_loss + beta * value_loss

        # policy_model and critic_model are separate networks, so one backward
        # pass yields policy gradients from policy_loss and critic gradients
        # from beta * value_loss (no retain_graph needed).
        policy_optimizer.zero_grad()
        critic_optimizer.zero_grad()
        loss.backward()
        policy_optimizer.step()
        critic_optimizer.step()

        kl_mean = (kl_dist.sum() / n_tokens).item()
        print(f"Loss: {loss.item():.4f} | RM Score: {rm_score.mean().item():.4f} | KL: {kl_mean:.4f}")

    # Periodic win-rate evaluation, once per `eval_every` prompts (not per PPO pass).
    if (step + 1) % eval_every == 0:
        # NOTE(review): evaluates on training prompts; a held-out split would be more informative.
        eval_rows = [ppo_ds[i] for i in range(min(eval_subset_size, len(ppo_ds)))]
        win_rate = evaluate_periodic_win_rate(
            policy_model=policy_model,
            ref_model=policy_model_ref,
            judge_model=judge_model,
            judge_tokenizer=judge_tokenizer,
            dataset=eval_rows,
            prompt_template=prompt_template,
            device=device,
            subset_size=eval_subset_size,
            verbose=True,
        )

output_dir = "/saved_models"
os.makedirs(output_dir, exist_ok=True)

# Policy checkpoint: weights, optimizer state, and the most recent `loss` value.
# 'epoch' stores epoch_nb (the number of PPO passes per collected sample).
policy_model_path = os.path.join(output_dir, "policy_model.pt")
policy_checkpoint = {
    'epoch': epoch_nb,
    'model_state_dict': policy_model.state_dict(),
    'optimizer_state_dict': policy_optimizer.state_dict(),
    'loss': loss.item(),
}
torch.save(policy_checkpoint, policy_model_path)

# Critic checkpoint: weights and optimizer state.
critic_model_path = os.path.join(output_dir, "critic_model.pt")
critic_checkpoint = {
    'epoch': epoch_nb,
    'model_state_dict': critic_model.state_dict(),
    'optimizer_state_dict': critic_optimizer.state_dict(),
}
torch.save(critic_checkpoint, critic_model_path)

print(f"Modèles sauvegardés dans {output_dir}")