In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import Trainer, TrainingArguments
from datasets import load_dataset
import torch.nn as nn
from stable_baselines3 import PPO
from stable_baselines3.common.vec_env import DummyVecEnv

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Fix missing padding token
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Use EOS as padding token

model = AutoModelForCausalLM.from_pretrained(model_name)
old_model = AutoModelForCausalLM.from_pretrained(model_name)  # Reference model (θ_old)

# Resize embeddings after adding new tokens
model.resize_token_embeddings(len(tokenizer))

# Move models to GPU if available
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
old_model.to(device);

In [None]:
# Load the RLHF dataset
# https://huggingface.co/datasets/yitingxie/rlhf-reward-datasets
dataset = load_dataset("yitingxie/rlhf-reward-datasets")

# Use a subset for fast training
train_dataset = dataset["train"].shuffle(seed=42).select(range(1000))  # Reduce size for demo
test_dataset = dataset["train"].shuffle(seed=42).select(range(10))

PPO requires a reward model that evaluates responses.

In [4]:
class RewardModel(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.model = model  # Base LM acts as feature extractor
        self.reward_head = nn.Linear(model.config.hidden_size, 1)  # Reward prediction head
    
    def forward(self, input_ids, attention_mask):
        outputs = self.model(input_ids, attention_mask=attention_mask, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]  # Use last layer representations
        reward = self.reward_head(last_hidden_state[:, -1, :])  # Predict reward from last token
        return reward.squeeze(-1)  # Ensure scalar output

We use stable-baselines3 PPO with a custom reward function based on human feedback.

In [5]:
import gym
from gym import spaces
import numpy as np
import torch

class PPOEnv(gym.Env):
    """Custom PPO environment for RLHF training."""
    def __init__(self, dataset, tokenizer, model, reward_model):
        super(PPOEnv, self).__init__()

        self.dataset = dataset
        self.tokenizer = tokenizer
        self.model = model
        self.reward_model = reward_model
        self.current_index = 0  # Track dataset index

        # Define observation space (prompt as text tokens)
        self.observation_space = spaces.Box(low=0, high=tokenizer.vocab_size, shape=(512,), dtype=np.int32)

        # Define action space (select next tokens)
        self.action_space = spaces.Discrete(tokenizer.vocab_size)  

    def reset(self):
        """Reset environment with a new prompt and return tokenized observation."""
        self.current_index = (self.current_index + 1) % len(self.dataset)
        prompt = self.dataset[self.current_index]["prompt"]

        # Tokenize prompt and return as observation
        input_ids = self.tokenizer(prompt, return_tensors="np", padding="max_length", truncation=True, max_length=512)["input_ids"][0]
        return input_ids  # The observation (numerical tokenized prompt)

    def step(self, action):
        """Take a step: generate a response and compute reward."""
        prompt = self.dataset[self.current_index]["prompt"]

        # Convert action (token IDs) back to text
        response = self.tokenizer.decode([action], skip_special_tokens=True)

        # Tokenize prompt + response
        inputs = self.tokenizer(prompt, response, return_tensors="pt", padding=True, truncation=True).to(device)

        # Compute reward from reward model
        with torch.no_grad():
            reward = self.reward_model(inputs.input_ids, inputs.attention_mask).item()

        # PPO expects (observation, reward, done, info)
        done = True  # Single-step episode
        return self.reset(), reward, done, {}

In [6]:
reward_model = RewardModel(model).to(device)

# Wrap environment properly
ppo_env = DummyVecEnv([lambda: PPOEnv(train_dataset, tokenizer, model, reward_model)])

# Initialize PPO model
ppo_model = PPO("MlpPolicy", ppo_env, verbose=1, learning_rate=5e-6)

Using cuda device




In [7]:
# Train PPO model on RLHF dataset
ppo_model.learn(total_timesteps=10000)  # Increase for better results

-----------------------------
| time/              |      |
|    fps             | 120  |
|    iterations      | 1    |
|    time_elapsed    | 17   |
|    total_timesteps | 2048 |
-----------------------------
-------------------------------------------
| time/                   |               |
|    fps                  | 118           |
|    iterations           | 2             |
|    time_elapsed         | 34            |
|    total_timesteps      | 4096          |
| train/                  |               |
|    approx_kl            | 2.6208989e-05 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -10.8         |
|    explained_variance   | -0.257        |
|    learning_rate        | 5e-06         |
|    loss                 | 8.89          |
|    n_updates            | 10            |
|    policy_gradient_loss | -0.00313      |
|    value_loss           | 18.6          |
------------------------------------------

<stable_baselines3.ppo.ppo.PPO at 0x7f742fba9570>

In [9]:
def generate_response(prompt):
    """Generate response using the fine-tuned model."""
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    output = model.generate(input_ids, max_new_tokens=50)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Test on some prompts from test dataset
for i in range(5):
    prompt = test_dataset[i]["prompt"]
    response = generate_response(prompt)
    
    print(f"Prompt: {prompt}")
    print(f"Generated Response: {response}")
    print("-" * 50)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Prompt: 

Human: How do contacts improve your vision?

Assistant: They help you see more clearly.

Human: Yes, I know. What is the technical or scientific answer for how contacts work to improve vision?

Assistant: They increase the amount of light coming into the eye.

Human: Are contacts better than wearing glasses?

Assistant: Many people prefer contacts.  Some people like the option of switching back and forth between contacts and glasses.  Some people wear glasses for reading or work and use contacts for everything else.  Others wear contacts most of the time, but need glasses for certain activities, like sports or computer work.

Human: Are there any risks for using contact lenses?


Generated Response: 

Human: How do contacts improve your vision?

Assistant: They help you see more clearly.

Human: Yes, I know. What is the technical or scientific answer for how contacts work to improve vision?

Assistant: They increase the amount of light coming into the eye.

Human: Are contact

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Prompt: 

Human: How do I make a honey baked ham at home for Christmas dinner?

Assistant: You can make honey baked ham by soaking a bone-in ham in a mixture of honey, brown sugar, garlic, and apple cider vinegar.  Make sure the bone side is completely covered in the mixture, then place the ham in a baking dish.  Drizzle with more honey, brown sugar, and vinegar mixture.  Bake for an hour at 325°F, or until cooked through.  
You can also make a sweet glaze to pour over the ham during the final 15 minutes of cooking.  Start with 1/2 cup each of honey and brown sugar, add a few tablespoons of apple cider vinegar and some garlic and/or onion.  If

Human: wow thats a great recipe. do you have any variations that are sugar free?

Assistant: I recommend using agave nectar in place of the brown sugar.  And you can try vegetables like spinach or tomatoes as a side.

Human: agave nectar is a great suggestion. I will do that. should I use pineapple and/or cherries on the ham as well?

Assistant:

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Prompt: 

Human: What can I do to reduce my risk of gum disease?

Assistant: I don’t have a lot of specific suggestions for reducing your gum disease risk, but the primary cause of gum disease is bacterial plaque buildup on your teeth. So I recommend that you wash your teeth at least twice a day with a fluoride-free toothpaste, use a mouthwash with anti-bacterial properties, and schedule regular dental visits to maintain good oral health.

Human: Any other ideas?

Assistant: Yes. You could also wear an orthodontic retainer that extends the length of your teeth for 24 hours. The goal of wearing the retainer is to reduce the bacterial colonies on your teeth, which promotes good oral health.

Human: Wow you are super helpful. Thanks


Generated Response: 

Human: What can I do to reduce my risk of gum disease?

Assistant: I don’t have a lot of specific suggestions for reducing your gum disease risk, but the primary cause of gum disease is bacterial plaque buildup on your teeth. So I recom

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Prompt: 

Human: Who is Deadpool


Generated Response: 

Human: Who is Deadpool


The first time I saw Deadpool was in the first movie, and I was like, "Oh my god, this is so cool!" I was like, "Oh my god, this is so cool!"

I was like, "Oh
--------------------------------------------------
Prompt: 

Human: What should you do if your car swerves on ice?

Assistant: You should steer into the skid to slow your vehicle down and bring it back under control.  Then, once you’ve gotten the car back into the proper lane, you need to be more cautious to make sure you don’t end up back in another skid.

Human: Should I brake?

Assistant: Yes, you should apply gentle pressure to the brake pedal to slow the vehicle down gradually.  This will help you maintain control, and also signal to the driver behind you that there’s a potential problem ahead.

Human: Steering into the skid sounds counterintuitive?

Assistant: It might seem strange at first, but it actually makes sense.  When your vehicle star