This code is based on the paper by Stiennon et al. (2020), "Learning to summarize from human feedback", which shows how to use PPO to fine-tune a language model from human feedback on a summarization task.

# IMPORTS 

In [None]:
from datasets import load_dataset
import os, math, random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from torch.optim import AdamW


# LOADING SFT MODEL

In [None]:
# Checkpoint shared by the tokenizer and every model copy loaded in this file.
model_name = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the end-of-sequence token for padding when the checkpoint ships
# without a dedicated pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Two independently loaded copies of the same checkpoint: the first is later
# passed to the PPO trainer as `model`, the second as `ref_model`.
policy_model, policy_model_ref = (
    AutoModelForCausalLM.from_pretrained(model_name) for _ in range(2)
)

# LOADING REDDIT DATASET

In [None]:
# Load the "comparisons" configuration of the summarize_from_feedback dataset.
# Each row is later read through its "info", "summaries" and "choice" fields.
ds = load_dataset("openai/summarize_from_feedback", "comparisons")

def build_prompt(example):
    """Wrap a post in the summarization instruction template.

    Args:
        example: Mapping holding the post under ``"text"`` or, as in the rows
            produced by ``format_dpo``, under ``"prompt"``. ``"text"`` wins
            when both are present.

    Returns:
        The string ``"Summarize the following text:\\n<post>\\nSummary:"``.

    Raises:
        KeyError: If the mapping has neither a ``"text"`` nor a ``"prompt"``
            entry.
    """
    if "text" in example:
        post = example["text"]
    elif "prompt" in example:
        # Rows reshaped by format_dpo keep the post under "prompt".
        post = example["prompt"]
    else:
        raise KeyError("example must contain a 'text' or 'prompt' field")
    return f"Summarize the following text:\n{post}\nSummary:"

In [None]:
def format_dpo(dataset):
    """Reshape comparison rows into prompt / chosen / rejected triples.

    Args:
        dataset: Object exposing ``map`` and ``column_names``. Each row must
            provide ``info["post"]``, a two-element ``summaries`` sequence of
            mappings with a ``"text"`` entry, and ``choice`` (index of the
            preferred summary).

    Returns:
        Whatever ``dataset.map`` returns, with the original columns removed
        and ``"prompt"``, ``"chosen"`` and ``"rejected"`` holding the
        whitespace-stripped post, preferred summary and other summary.
    """

    def to_preference_row(row):
        preferred = row["choice"]
        candidates = row["summaries"]
        return {
            "prompt": row["info"]["post"].strip(),
            "chosen": candidates[preferred]["text"].strip(),
            # With two candidates, 1 - preferred is the index of the loser.
            "rejected": candidates[1 - preferred]["text"].strip(),
        }

    original_columns = dataset.column_names
    return dataset.map(to_preference_row, remove_columns=original_columns)

# Prompt / chosen / rejected rows built from the train split; the reward-model
# training loop below iterates over a slice of it.
ppo_ds = format_dpo(ds["train"])

# TRAINING REWARD MODEL

In [None]:
class RewardModel(nn.Module):
    """Scalar reward model: a language-model backbone plus a linear head.

    Args:
        base_lm: Backbone exposing ``config.hidden_size`` and, when called
            with ``output_hidden_states=True``, returning an object with a
            ``hidden_states`` sequence whose last entry is indexed as
            ``(batch, seq, hidden)``.
    """

    def __init__(self, base_lm):
        super().__init__()
        self.lm = base_lm
        self.reward_head = nn.Linear(base_lm.config.hidden_size, 1)

    # forward must live inside the class: nn.Module.__call__ dispatches to
    # self.forward, which is unimplemented for a module-level function.
    def forward(self, input_ids, attention_mask=None):
        """Return one scalar reward per sequence.

        Args:
            input_ids: Token ids, passed through to the backbone.
            attention_mask: Optional 0/1 mask with the same leading shape as
                ``input_ids``. When given, the reward is read from the last
                attended position of each row, so padding on either side does
                not change which token is scored. When omitted, the final
                position is used.

        Returns:
            Tensor with the trailing size-1 dimension squeezed away
            (one value per input row).
        """
        outputs = self.lm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        final_layer = outputs.hidden_states[-1]

        if attention_mask is None:
            last_hidden = final_layer[:, -1]
        else:
            device = final_layer.device
            positions = torch.arange(final_layer.size(1), device=device)
            # Largest index whose mask entry is 1 == last real token per row.
            weighted = attention_mask.to(device=device, dtype=positions.dtype) * positions
            last_real = weighted.argmax(dim=1)
            rows = torch.arange(final_layer.size(0), device=device)
            last_hidden = final_layer[rows, last_real]

        return self.reward_head(last_hidden).squeeze(-1)


# Build the reward model around its own freshly loaded copy of the
# `model_name` checkpoint, separate from the two policy copies.
reward_model = RewardModel(
    AutoModelForCausalLM.from_pretrained(model_name)
)

In [None]:
# Fit the reward model on pairwise preferences: the chosen summary should
# receive a higher score than the rejected one.
optimizer = AdamW(reward_model.parameters(), lr=1e-5)

# from_pretrained hands back the backbone in eval mode; switch to training
# mode so dropout is active while fitting (the PPO cell calls .eval() again).
reward_model.train()

# Each dataset row is a single comparison, so every step sees one pair.
for example in ppo_ds.select(range(50)):
    optimizer.zero_grad()

    # Use the same "instruction + post + Summary:" layout that the PPO loop
    # feeds to the reward model, so training and scoring inputs match.
    prompt = build_prompt({"text": example["prompt"]})

    # NOTE(review): truncation=True presumably trims from the right, i.e. the
    # summary is cut first for over-long posts — confirm posts fit the context.
    inputs_w = tokenizer(prompt + example["chosen"], return_tensors="pt", truncation=True, padding=True)
    inputs_l = tokenizer(prompt + example["rejected"], return_tensors="pt", truncation=True, padding=True)

    r_w = reward_model(**inputs_w)
    r_l = reward_model(**inputs_l)

    # logsigmoid is the numerically stable form of log(sigmoid(x)); the naive
    # version underflows to log(0) = -inf for strongly negative margins.
    loss = -F.logsigmoid(r_w - r_l).mean()
    loss.backward()
    optimizer.step()

# TRAINING PPO

In [None]:
# PPO hyperparameters for fine-tuning the policy.
# NOTE(review): PPOConfig and PPOTrainer are not imported anywhere in the
# visible part of this file; they presumably come from the `trl` package
# (legacy PPO API) — confirm and add the import, otherwise this cell raises
# NameError.
ppo_config = PPOConfig(
    model_name=model_name,
    learning_rate=1e-5,
    batch_size=4,
    mini_batch_size=2,
    target_kl=0.1
)


# Trainer that updates `policy_model` against the reference copy.
# NOTE(review): the legacy trl PPOTrainer presumably expects a value-head
# wrapper (AutoModelForCausalLMWithValueHead) rather than a plain
# AutoModelForCausalLM — verify against the installed trl version.
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,
    ref_model=policy_model_ref,
    tokenizer=tokenizer
)

In [None]:
# PPO loop: sample a summary from the policy, score it with the reward model,
# and run an optimisation step whenever a full batch of rollouts is collected.
reward_model.eval()

# Rollouts are sampled rather than greedily decoded so PPO sees varied
# responses; pad_token_id is set explicitly because the tokenizer reuses EOS.
generation_kwargs = {
    "max_new_tokens": 64,
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "pad_token_id": tokenizer.pad_token_id,
}

# PPOTrainer.step rejects inputs whose length differs from the configured
# batch size, so rollouts are buffered until a full batch is available.
ppo_batch_size = ppo_trainer.config.batch_size
queries, responses, rewards = [], [], []

for example in ppo_ds.select(range(10)):
    # Rows of ppo_ds keep the post under "prompt".
    prompt = build_prompt({"text": example["prompt"]})

    # generate/step operate on 1-D token tensors, so drop the batch dimension.
    query = tokenizer(prompt, return_tensors="pt")["input_ids"][0]

    # return_prompt=False keeps only the newly generated tokens; otherwise the
    # decoded "summary" would start with a copy of the prompt.
    response = ppo_trainer.generate(
        query,
        return_prompt=False,
        **generation_kwargs,
    )[0]

    summary = tokenizer.decode(response, skip_special_tokens=True)

    with torch.no_grad():
        reward_inputs = tokenizer(prompt + summary, return_tensors="pt", truncation=True)
        # The reward model returns one value per row; take the single scalar.
        reward = reward_model(**reward_inputs)[0]

    queries.append(query)
    responses.append(response)
    rewards.append(reward)

    if len(queries) == ppo_batch_size:
        ppo_trainer.step(queries, responses, rewards)
        queries, responses, rewards = [], [], []

# Rollouts left over in an incomplete final batch are intentionally not used.