In [50]:
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
import numpy as np
from tqdm import tqdm


# setup base model
model_name = "Qwen/Qwen3-0.6B-Base"  # Replace with Qwen3 if available
tokenizer = AutoTokenizer.from_pretrained(model_name)
#base_model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token


#setup lora config
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # typical for attention layers
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

class SFTModel(nn.Module):
    def __init__(self, base_model_name, lora_config):
        super().__init__()
        base_model = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)
        self.model = get_peft_model(base_model, lora_config)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

sft_model = SFTModel(model_name, lora_config)
sft_model.model.print_trainable_parameters()
sft_model.eval()

#setup policy model
class PolicyModel(nn.Module):
    def __init__(self, lora_config=None):
        super().__init__()
        base_model = sft_model.model
        self.model = get_peft_model(base_model, lora_config)

    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        return self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

    def generate(self, *args, **kwargs):
        return self.model.generate(*args, **kwargs)

policy_model = PolicyModel(lora_config)
policy_model.model.print_trainable_parameters()
policy_model.eval()


# setup reward model
class RewardModel(nn.Module):
    def __init__(self, lora_config):
        super().__init__()
        base_model = sft_model.model
        self.model = get_peft_model(base_model, lora_config)
        hidden_size = self.model.config.hidden_size
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden = outputs.hidden_states[-1]
        last_token_idx = attention_mask.sum(dim=1) - 1
        last_token_idx = last_token_idx.unsqueeze(1).unsqueeze(2).expand(-1, 1, last_hidden.size(-1))
        last_hidden_state = last_hidden.gather(1, last_token_idx).squeeze(1)
        reward = self.value_head(last_hidden_state).squeeze(-1)
        return reward
    
reward_model = RewardModel(lora_config)
reward_model.model.print_trainable_parameters()
reward_model.eval()

#setup value model
class ValueModel(nn.Module):
    def __init__(self, lora_config):
        super().__init__()
        base_model = sft_model.model
        self.model = get_peft_model(base_model, lora_config)
        hidden_size = self.model.config.hidden_size
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden = outputs.hidden_states[-1]
        last_token_idx = attention_mask.sum(dim=1) - 1
        last_token_idx = last_token_idx.unsqueeze(1).unsqueeze(2).expand(-1, 1, last_hidden.size(-1))
        last_hidden_state = last_hidden.gather(1, last_token_idx).squeeze(1)
        reward = self.value_head(last_hidden_state).squeeze(-1)
        return reward
    
value_model = RewardModel(lora_config)
value_model.model.print_trainable_parameters()
value_model.eval()


#util
# Preprocessing function (outside the model)
def preprocess_function(example, tokenizer, max_length=256):
    return tokenizer(
        example["prompt"],
        truncation=True,
        max_length=max_length,
        padding="max_length"
    )

# Collate function (outside the model)
def collate_fn(batch):
    input_ids = torch.stack([item["input_ids"] for item in batch])
    attention_mask = torch.stack([item["attention_mask"] for item in batch])
    labels = torch.stack([item["labels"] for item in batch])
    return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}

trainable params: 4,587,520 || all params: 600,637,440 || trainable%: 0.7638
trainable params: 4,587,520 || all params: 600,637,440 || trainable%: 0.7638
trainable params: 4,587,520 || all params: 600,637,440 || trainable%: 0.7638
trainable params: 4,587,520 || all params: 600,637,440 || trainable%: 0.7638




In [51]:
#PPO
import torch

import torch.nn.functional as F

def generate_token_level_sa_and_logprobs_with_ref(
    policy_model, ref_model, tokenizer, prompts, device, max_new_tokens=10
):
    """
    For each prompt, generate a sequence and return:
      - (state, action) pairs at each token step
      - log-probability of each action under the policy model
      - log-probability of each action under the reference (SFT) model
    Returns:
        all_sa_pairs: list of dicts, each with keys:
            'state_text', 'action_token', 'action_token_id', 'logprob', 'ref_logprob', 'prompt_idx', 'step'
    """
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = policy_model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            temperature=1.0,
            return_dict_in_generate=True,
            output_scores=True,
        )
    sequences = outputs.sequences  # (batch, prompt_len + max_new_tokens)
    scores = outputs.scores        # list of length max_new_tokens, each (batch, vocab_size)
    batch_size = sequences.shape[0]
    prompt_len = inputs["input_ids"].shape[1]

    generated_texts = []
    all_sa_pairs = []
    for i in range(batch_size):
        gen_tokens = []
        for t, score_t in enumerate(scores):
            # State: prompt + previously generated tokens (up to t)
            state_ids = sequences[i, :prompt_len + t]
            state_text = tokenizer.decode(state_ids, skip_special_tokens=True)
            # Action: next token
            action_id = sequences[i, prompt_len + t].item()
            action_token = tokenizer.decode([action_id])
            gen_tokens.append(action_token)
            # Logprob for this token (policy)
            log_probs = F.log_softmax(score_t[i], dim=-1)
            logprob = log_probs[action_id].item()
            # Reference logprob
            with torch.no_grad():
                ref_inputs = state_ids.unsqueeze(0)
                ref_outputs = ref_model(ref_inputs)
                ref_logits = ref_outputs.logits  # (1, seq_len, vocab_size)
            ref_next_token_logits = ref_logits[0, -1, :]
            ref_log_probs = F.log_softmax(ref_next_token_logits, dim=-1)
            ref_logprob = ref_log_probs[action_id].item()
            all_sa_pairs.append({
                "state_text": state_text,
                "action_token": action_token,
                "action_token_id": action_id,
                "logprob": logprob,
                "ref_logprob": ref_logprob,
                "prompt_idx": i,
                "step": t,
            })
        generated_texts.append("".join(gen_tokens))
    return all_sa_pairs, generated_texts

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_model = policy_model.to(device)

# Example batch of prompts
prompts = [
    "SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:",
]
sa_pairs, generated_texts = generate_token_level_sa_and_logprobs_with_ref(
    policy_model, sft_model, tokenizer, prompts, device, max_new_tokens=10
)

print("generated texts:")
for text in generated_texts:
    print(text)

print("sa pairs:")
for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} Step {pair['step']}:")
    print("State:", repr(pair['state_text']))
    print("Action token:", repr(pair['action_token']))
    print("Policy logprob:", pair['logprob'])
    print("Reference logprob:", pair['ref_logprob'])
    print("KL (per token):", pair['logprob'] - pair['ref_logprob'])
    print("-" * 30)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


generated texts:
 A woman, whose "other side" boyfriend is
sa pairs:
Prompt 0 Step 0:
State: 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:'
Action token: ' A'
Policy logprob: -4.208090782165527
Reference logprob: -4.679380416870117
KL (per token): 0.47128963470458984
------------------------------
Prompt 0 Step 1:
State: 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A'
Action token: ' woman'
Policy logprob: -3.809088945388794
Reference logprob: -4.732294082641602
KL (per token): 0.9232051372528076
------------------------------
Prompt 0 Step 2:
State: 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A woman'
Action token: ','
Policy logprob: -3.7089056968688965
Reference logprob: -4.155608177185059
KL (per token): 0.4467024803161621
------------------------------
Prompt 0 Step 3:
State: 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A woman,'
A

In [52]:
import torch

def compute_kl_per_token(logprobs_policy, logprobs_ref):
    """
    Compute per-token KL divergence for sampled actions.
    Returns:
        kl_per_token: torch.Tensor of shape (num_tokens,)
    """
    logprobs_policy = torch.tensor(logprobs_policy)
    logprobs_ref = torch.tensor(logprobs_ref)
    kl_per_token = logprobs_policy - logprobs_ref
    return kl_per_token

# Suppose you have these from previous steps:
# sa_pairs: list of dicts, each with 'logprob' (policy) and 'ref_logprob' (reference)
logprobs_policy = [pair['logprob'] for pair in sa_pairs]
logprobs_ref = [pair['ref_logprob'] for pair in sa_pairs]  # You'd add this key after SFT scoring

kl_per_token = compute_kl_per_token(logprobs_policy, logprobs_ref)

def add_kl_to_sa_pairs(sa_pairs):
    for pair in sa_pairs:
        logprob_policy = pair["logprob"]
        logprob_ref = pair["ref_logprob"]
        pair["kl_div"] = logprob_policy - logprob_ref
    return sa_pairs

sa_pairs = add_kl_to_sa_pairs(sa_pairs)

for pair in sa_pairs:
    print(f"Prompt {pair.get('prompt_idx', '?')} | Step {pair.get('step', '?')}")
    print(f"  State:         {repr(pair['state_text'])}")
    print(f"  Action token:  {repr(pair['action_token'])} (id: {pair['action_token_id']})")
    print(f"  Policy logprob:    {pair['logprob']:.4f}")
    print(f"  Reference logprob: {pair['ref_logprob']:.4f}")
    print(f"  KL divergence:     {pair['kl_div']:.4f}")
    print("-" * 50)


Prompt 0 | Step 0
  State:         'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:'
  Action token:  ' A' (id: 362)
  Policy logprob:    -4.2081
  Reference logprob: -4.6794
  KL divergence:     0.4713
--------------------------------------------------
Prompt 0 | Step 1
  State:         'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A'
  Action token:  ' woman' (id: 5220)
  Policy logprob:    -3.8091
  Reference logprob: -4.7323
  KL divergence:     0.9232
--------------------------------------------------
Prompt 0 | Step 2
  State:         'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A woman'
  Action token:  ',' (id: 11)
  Policy logprob:    -3.7089
  Reference logprob: -4.1556
  KL divergence:     0.4467
--------------------------------------------------
Prompt 0 | Step 3
  State:         'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A wom

In [53]:
def compute_rewards_for_sequences(reward_model, tokenizer, prompts, generated_texts, device):
    """
    For each prompt and generated text, compute the reward using the reward model.
    Returns:
        rewards: list of floats, one per sequence
    """
    # Concatenate prompt and generated text for each example
    full_texts = [p + gt for p, gt in zip(prompts, generated_texts)]
    inputs = tokenizer(full_texts, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        rewards = reward_model(inputs["input_ids"], inputs["attention_mask"])
    # Handle both batch and single input cases
    if isinstance(rewards, torch.Tensor):
        rewards = rewards.squeeze()
        if rewards.dim() == 0:
            rewards = [rewards.item()]
        else:
            rewards = rewards.cpu().tolist()
    elif isinstance(rewards, float):
        rewards = [rewards]
    return rewards

rewards = compute_rewards_for_sequences(reward_model, tokenizer, prompts, generated_texts, device)
for i, reward in enumerate(rewards):
    print(f"Prompt: {prompts[i]} \nGenerated text: {generated_texts[i]} \nReward = {reward:.4f}")

Prompt: SUBREDDIT: r/relationships
TITLE: Should I admit to snooping?
POST: ...
TL;DR: 
Generated text:  A woman, whose "other side" boyfriend is 
Reward = -0.6829


In [54]:
def adjust_rewards_with_kl(sa_pairs, rewards, kl_coef=0.1):
    """
    For each (s, a) pair, adjust the reward by subtracting kl_coef * kl_div.
    Broadcasts the sequence reward to each token in the sequence.
    Modifies sa_pairs in-place and also returns it.
    """
    for pair in sa_pairs:
        reward = rewards[pair['prompt_idx']]
        kl = pair['kl_div']
        pair['adjusted_reward'] = reward - kl_coef * kl
    return sa_pairs

kl_coef = 0.1
sa_pairs = adjust_rewards_with_kl(sa_pairs, rewards, kl_coef=kl_coef)

# Print to verify
for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} | Step {pair['step']}")
    print(f"  Reward:           {rewards[pair['prompt_idx']]:.4f}")
    print(f"  KL divergence:    {pair['kl_div']:.4f}")
    print(f"  Adjusted reward:  {pair['adjusted_reward']:.4f}")
    print("-" * 40)

Prompt 0 | Step 0
  Reward:           -0.6829
  KL divergence:    0.4713
  Adjusted reward:  -0.7301
----------------------------------------
Prompt 0 | Step 1
  Reward:           -0.6829
  KL divergence:    0.9232
  Adjusted reward:  -0.7753
----------------------------------------
Prompt 0 | Step 2
  Reward:           -0.6829
  KL divergence:    0.4467
  Adjusted reward:  -0.7276
----------------------------------------
Prompt 0 | Step 3
  Reward:           -0.6829
  KL divergence:    0.4820
  Adjusted reward:  -0.7311
----------------------------------------
Prompt 0 | Step 4
  Reward:           -0.6829
  KL divergence:    0.4030
  Adjusted reward:  -0.7232
----------------------------------------
Prompt 0 | Step 5
  Reward:           -0.6829
  KL divergence:    0.7603
  Adjusted reward:  -0.7590
----------------------------------------
Prompt 0 | Step 6
  Reward:           -0.6829
  KL divergence:    0.1405
  Adjusted reward:  -0.6970
----------------------------------------
Prompt

In [55]:
def add_value_to_sa_pairs(sa_pairs, value_model, tokenizer, device, batch_size=32):
    """
    For each sa_pair, compute the value (predicted reward) using the value model,
    and add it as 'value' to the sa_pair dict.
    """
    # Collect all state_texts
    state_texts = [pair['state_text'] for pair in sa_pairs]
    values = []

    # Process in batches for efficiency
    for i in range(0, len(state_texts), batch_size):
        batch_texts = state_texts[i:i+batch_size]
        inputs = tokenizer(batch_texts, return_tensors="pt", padding=True, truncation=True).to(device)
        with torch.no_grad():
            batch_values = value_model(inputs["input_ids"], inputs["attention_mask"])
        # batch_values: shape (batch,)
        batch_values = batch_values.squeeze().cpu().tolist()
        # Ensure batch_values is a list
        if isinstance(batch_values, float):
            batch_values = [batch_values]
        values.extend(batch_values)

    # Add to sa_pairs
    for pair, value in zip(sa_pairs, values):
        pair['value'] = value

    return sa_pairs

sa_pairs = add_value_to_sa_pairs(sa_pairs, value_model, tokenizer, device)

# Print to verify
for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} | Step {pair['step']}")
    print(f"  State:             {repr(pair['state_text'])}")
    print(f"  Action token:      {repr(pair['action_token'])} (id: {pair['action_token_id']})")
    print(f"  Policy logprob:    {pair['logprob']:.4f}")
    print(f"  Reference logprob: {pair['ref_logprob']:.4f}")
    print(f"  KL divergence:     {pair['kl_div']:.4f}")
    print(f"  Value prediction:  {pair['value']:.4f}")
    print(f"  KL-adjusted reward:{pair['adjusted_reward']:.4f}")
    print("-" * 60)

Prompt 0 | Step 0
  State:             'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:'
  Action token:      ' A' (id: 362)
  Policy logprob:    -4.2081
  Reference logprob: -4.6794
  KL divergence:     0.4713
  Value prediction:  1.1618
  KL-adjusted reward:-0.7301
------------------------------------------------------------
Prompt 0 | Step 1
  State:             'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A'
  Action token:      ' woman' (id: 5220)
  Policy logprob:    -3.8091
  Reference logprob: -4.7323
  KL divergence:     0.9232
  Value prediction:  1.0352
  KL-adjusted reward:-0.7753
------------------------------------------------------------
Prompt 0 | Step 2
  State:             'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A woman'
  Action token:      ',' (id: 11)
  Policy logprob:    -3.7089
  Reference logprob: -4.1556
  KL divergence:     0.4467
  Value predicti

In [56]:
from collections import defaultdict

def compute_td_errors(sa_pairs, gamma=0.99):
    # 1. Group by prompt_idx
    grouped = defaultdict(list)
    for pair in sa_pairs:
        grouped[pair['prompt_idx']].append(pair)

    # 2. For each prompt, sort by step and compute TD error
    for prompt_idx, pairs in grouped.items():
        pairs = sorted(pairs, key=lambda x: x['step'])
        for t, pair in enumerate(pairs):
            reward = pair['adjusted_reward']
            value = pair['value']
            # Next value: value of next step, or 0 if last
            if t < len(pairs) - 1:
                next_value = pairs[t + 1]['value']
            else:
                next_value = 0.0
            pair['td_error'] = reward + gamma * next_value - value
    return sa_pairs

sa_pairs = compute_td_errors(sa_pairs, gamma=0.99)

# Print to verify
for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} | Step {pair['step']}")
    print(f"  KL-adjusted reward: {pair['adjusted_reward']:.4f}")
    print(f"  Value:              {pair['value']:.4f}")
    print(f"  TD error:           {pair['td_error']:.4f}")
    print("-" * 60)

Prompt 0 | Step 0
  KL-adjusted reward: -0.7301
  Value:              1.1618
  TD error:           -0.8670
------------------------------------------------------------
Prompt 0 | Step 1
  KL-adjusted reward: -0.7753
  Value:              1.0352
  TD error:           -0.5560
------------------------------------------------------------
Prompt 0 | Step 2
  KL-adjusted reward: -0.7276
  Value:              1.2671
  TD error:           -0.4070
------------------------------------------------------------
Prompt 0 | Step 3
  KL-adjusted reward: -0.7311
  Value:              1.6038
  TD error:           -0.5590
------------------------------------------------------------
Prompt 0 | Step 4
  KL-adjusted reward: -0.7232
  Value:              1.7939
  TD error:           -2.5203
------------------------------------------------------------
Prompt 0 | Step 5
  KL-adjusted reward: -0.7590
  Value:              -0.0032
  TD error:           0.1668
-----------------------------------------------------

In [57]:
def compute_gae_advantages(sa_pairs, gamma=0.99, lam=0.95):
    """
    Compute GAE advantage for each (s, a) pair in sa_pairs.
    Adds 'advantage' to each sa_pair.
    Assumes sa_pairs are sorted by (prompt_idx, step).
    """
    from collections import defaultdict
    grouped = defaultdict(list)
    for pair in sa_pairs:
        grouped[pair['prompt_idx']].append(pair)

    for prompt_idx, pairs in grouped.items():
        # Sort by step
        pairs = sorted(pairs, key=lambda x: x['step'])
        num_steps = len(pairs)
        advantages = [0.0] * num_steps
        gae = 0.0
        # Go backwards through the sequence
        for t in reversed(range(num_steps)):
            td_error = pairs[t]['td_error']
            gae = td_error + gamma * lam * gae
            advantages[t] = gae
        # Assign to sa_pairs
        for t in range(num_steps):
            pairs[t]['advantage'] = advantages[t]
    return sa_pairs

sa_pairs = compute_gae_advantages(sa_pairs, gamma=0.99, lam=0.95)

for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} | Step {pair['step']}")
    print(f"  TD error:   {pair['td_error']:.4f}")
    print(f"  Advantage:  {pair['advantage']:.4f}")
    print("-" * 60)

Prompt 0 | Step 0
  TD error:   -0.8670
  Advantage:  -6.3835
------------------------------------------------------------
Prompt 0 | Step 1
  TD error:   -0.5560
  Advantage:  -5.8656
------------------------------------------------------------
Prompt 0 | Step 2
  TD error:   -0.4070
  Advantage:  -5.6454
------------------------------------------------------------
Prompt 0 | Step 3
  TD error:   -0.5590
  Advantage:  -5.5699
------------------------------------------------------------
Prompt 0 | Step 4
  TD error:   -2.5203
  Advantage:  -5.3279
------------------------------------------------------------
Prompt 0 | Step 5
  TD error:   0.1668
  Advantage:  -2.9852
------------------------------------------------------------
Prompt 0 | Step 6
  TD error:   -1.2223
  Advantage:  -3.3514
------------------------------------------------------------
Prompt 0 | Step 7
  TD error:   0.5263
  Advantage:  -2.2638
------------------------------------------------------------
Prompt 0 | Step 8


In [58]:
def add_reward_to_go_to_sa_pairs(sa_pairs):
    """
    For each sa_pair, add 'value_target' = advantage + value.
    """
    for pair in sa_pairs:
        pair['reward_to_go'] = pair['advantage'] + pair['value']
    return sa_pairs

sa_pairs = add_reward_to_go_to_sa_pairs(sa_pairs)

for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} | Step {pair['step']}")
    print(f"  Value:         {pair['value']:.4f}")
    print(f"  Advantage:     {pair['advantage']:.4f}")
    print(f"  Reward-to-go:  {pair['reward_to_go']:.4f}")
    print("-" * 40)

Prompt 0 | Step 0
  Value:         1.1618
  Advantage:     -6.3835
  Reward-to-go:  -5.2218
----------------------------------------
Prompt 0 | Step 1
  Value:         1.0352
  Advantage:     -5.8656
  Reward-to-go:  -4.8303
----------------------------------------
Prompt 0 | Step 2
  Value:         1.2671
  Advantage:     -5.6454
  Reward-to-go:  -4.3783
----------------------------------------
Prompt 0 | Step 3
  Value:         1.6038
  Advantage:     -5.5699
  Reward-to-go:  -3.9661
----------------------------------------
Prompt 0 | Step 4
  Value:         1.7939
  Advantage:     -5.3279
  Reward-to-go:  -3.5340
----------------------------------------
Prompt 0 | Step 5
  Value:         -0.0032
  Advantage:     -2.9852
  Reward-to-go:  -2.9884
----------------------------------------
Prompt 0 | Step 6
  Value:         0.9319
  Advantage:     -3.3514
  Reward-to-go:  -2.4195
----------------------------------------
Prompt 0 | Step 7
  Value:         0.4107
  Advantage:     -2.2638
 

In [59]:
for pair in sa_pairs:
    print(f"Prompt {pair['prompt_idx']} | Step {pair['step']}")
    print(f"  State:             {repr(pair['state_text'])}")
    print(f"  Action token:      {repr(pair['action_token'])} (id: {pair['action_token_id']})")
    print(f"  Policy logprob:    {pair['logprob']:.4f}")
    print(f"  Reference logprob: {pair['ref_logprob']:.4f}")
    print(f"  KL divergence:     {pair['kl_div']:.4f}")
    print(f"  Value:             {pair['value']:.4f}")
    print(f"  KL-adjusted reward:{pair['adjusted_reward']:.4f}")
    print(f"  TD error:          {pair['td_error']:.4f}")
    print(f"  Advantage:         {pair['advantage']:.4f}")
    print(f"  Reward-to-go:      {pair['reward_to_go']:.4f}")
    print("-" * 80)

Prompt 0 | Step 0
  State:             'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:'
  Action token:      ' A' (id: 362)
  Policy logprob:    -4.2081
  Reference logprob: -4.6794
  KL divergence:     0.4713
  Value:             1.1618
  KL-adjusted reward:-0.7301
  TD error:          -0.8670
  Advantage:         -6.3835
  Reward-to-go:      -5.2218
--------------------------------------------------------------------------------
Prompt 0 | Step 1
  State:             'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR: A'
  Action token:      ' woman' (id: 5220)
  Policy logprob:    -3.8091
  Reference logprob: -4.7323
  KL divergence:     0.9232
  Value:             1.0352
  KL-adjusted reward:-0.7753
  TD error:          -0.5560
  Advantage:         -5.8656
  Reward-to-go:      -4.8303
--------------------------------------------------------------------------------
Prompt 0 | Step 2
  State:             'SUBREDDIT: r/

In [60]:
import torch
import torch.nn as nn

def train_value_model_one_step(
    value_model, tokenizer, state_texts, reward_to_go_targets, device, optimizer, batch_size=32
):
    value_model.train()
    criterion = nn.MSELoss()
    total_loss = 0.0

    # Prepare data


    # Shuffle data for each epoch/step (optional)
    indices = torch.randperm(len(state_texts))
    state_texts = [state_texts[i] for i in indices]
    reward_to_go_targets = [reward_to_go_targets[i] for i in indices]

    for i in range(0, len(state_texts), batch_size):
        batch_states = state_texts[i:i+batch_size]
        batch_targets = torch.tensor(reward_to_go_targets[i:i+batch_size], dtype=torch.float32).to(device)

        inputs = tokenizer(batch_states, return_tensors="pt", padding=True, truncation=True).to(device)
        preds = value_model(inputs["input_ids"], inputs["attention_mask"]).squeeze(-1)

        loss = criterion(preds, batch_targets)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * len(batch_states)

    avg_loss = total_loss / len(state_texts)
    return avg_loss

# Setup optimizer if not already done
import torch.optim as optim
value_model = value_model.to(device)
value_optimizer = optim.AdamW(value_model.parameters(), lr=1e-4)


state_texts = [pair['state_text'] for pair in sa_pairs]
reward_to_go_targets = [pair['reward_to_go'] for pair in sa_pairs]
# Training step
avg_loss = train_value_model_one_step(
    value_model, tokenizer, state_texts, reward_to_go_targets, device, value_optimizer, batch_size=32
)
print(f"Value model training loss: {avg_loss:.4f}")

Value model training loss: 20.4032


In [78]:
state_texts = [pair['state_text'] for pair in sa_pairs]
action_token_ids = [pair['action_token_id'] for pair in sa_pairs]
old_logprobs = torch.tensor([pair['logprob'] for pair in sa_pairs], dtype=torch.float32)
advantages = torch.tensor([pair['advantage'] for pair in sa_pairs], dtype=torch.float32)

def get_new_logprobs(policy_model, tokenizer, state_texts, action_token_ids, device, batch_size=32):
    new_logprobs = []
    for i in range(0, len(state_texts), batch_size):
        batch_states = state_texts[i:i+batch_size]
        batch_action_ids = action_token_ids[i:i+batch_size]
        inputs = tokenizer(batch_states, return_tensors="pt", padding=True, truncation=True).to(device)
        outputs = policy_model(inputs["input_ids"], attention_mask=inputs["attention_mask"])
        logits = outputs.logits  # (batch, seq_len, vocab_size)
        # Get logprobs for the last token in each sequence
        log_probs = torch.nn.functional.log_softmax(logits[:, -1, :], dim=-1)  # (batch, vocab_size)
        batch_action_ids_tensor = torch.tensor(batch_action_ids, dtype=torch.long, device=device)
        batch_logprobs = log_probs[torch.arange(len(batch_states)), batch_action_ids_tensor]  # (batch,)
        new_logprobs.append(batch_logprobs)
    return torch.cat(new_logprobs, dim=0)  # (total_num_pairs,)

def ppo_clip_loss(new_logprobs, old_logprobs, advantages, clip_epsilon=0.2):
    ratio = torch.exp(new_logprobs - old_logprobs)
    unclipped = ratio * advantages
    clipped = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantages
    loss = -torch.mean(torch.min(unclipped, clipped))
    return loss 

def train_policy_model_one_step(
    policy_model, tokenizer, state_texts, action_token_ids, old_logprobs, advantages, device, optimizer, clip_epsilon=0.2, batch_size=32
):
    policy_model.train()
    # Prepare data
    state_texts = [pair['state_text'] for pair in sa_pairs]
    action_token_ids = [pair['action_token_id'] for pair in sa_pairs]
    old_logprobs = torch.tensor([pair['logprob'] for pair in sa_pairs], dtype=torch.float32).to(device)
    advantages = torch.tensor([pair['advantage'] for pair in sa_pairs], dtype=torch.float32).to(device)

    # Get new logprobs from current policy
    new_logprobs = get_new_logprobs(policy_model, tokenizer, state_texts, action_token_ids, device, batch_size)

    # Compute PPO-clip loss
    loss = ppo_clip_loss(new_logprobs.to(device), old_logprobs, advantages, clip_epsilon)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()

import torch.optim as optim
policy_model = policy_model.to(device)
policy_optimizer = optim.AdamW(policy_model.parameters(), lr=1e-5)

loss = train_policy_model_one_step(
    policy_model, tokenizer, state_texts, action_token_ids, old_logprobs, advantages, device, policy_optimizer, clip_epsilon=0.2, batch_size=32
)
print(f"Policy model PPO loss: {loss:.4f}")

Policy model PPO loss: 3.5926


In [77]:
# 1. Start with a prompt
prompt = "SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:"
prompts = [prompt]

print("\n[START] Prompt:")
print(prompts)

# 2. Generate (s, a) pairs and logprobs from both policy and reference
sa_pairs, generated_texts = generate_token_level_sa_and_logprobs_with_ref(
    policy_model, sft_model, tokenizer, prompts, device, max_new_tokens=10
)
print("\n[AFTER GENERATION] sa_pairs[0]:")
print(sa_pairs[0])
print("\nGenerated text:", generated_texts[0])

# 3. Add KL divergence to sa_pairs (if not already present)
print("\n[BEFORE add_kl_to_sa_pairs] sa_pairs[0]:")
print(sa_pairs[0])
sa_pairs = add_kl_to_sa_pairs(sa_pairs)
print("[AFTER add_kl_to_sa_pairs] sa_pairs[0]:")
print(sa_pairs[0])

# 4. Compute reward for the generated sequence
print("\n[BEFORE compute_rewards_for_sequences] generated_texts:")
print(generated_texts)
rewards = compute_rewards_for_sequences(reward_model, tokenizer, prompts, generated_texts, device)
print("[AFTER compute_rewards_for_sequences] rewards:")
print(rewards)

# 5. Add KL-adjusted reward to sa_pairs
print("\n[BEFORE adjust_rewards_with_kl] sa_pairs[0]:")
print(sa_pairs[0])
sa_pairs = adjust_rewards_with_kl(sa_pairs, rewards, kl_coef=0.1)
print("[AFTER adjust_rewards_with_kl] sa_pairs[0]:")
print(sa_pairs[0])

# 6. Add value predictions to sa_pairs
print("\n[BEFORE add_value_to_sa_pairs] sa_pairs[0]:")
print(sa_pairs[0])
sa_pairs = add_value_to_sa_pairs(sa_pairs, value_model, tokenizer, device, batch_size=32)
print("[AFTER add_value_to_sa_pairs] sa_pairs[0]:")
print(sa_pairs[0])

# 7. Compute TD errors
print("\n[BEFORE compute_td_errors] sa_pairs[0]:")
print(sa_pairs[0])
sa_pairs = compute_td_errors(sa_pairs, gamma=0.99)
print("[AFTER compute_td_errors] sa_pairs[0]:")
print(sa_pairs[0])

# 8. Compute GAE advantages
print("\n[BEFORE compute_gae_advantages] sa_pairs[0]:")
print(sa_pairs[0])
sa_pairs = compute_gae_advantages(sa_pairs, gamma=0.99, lam=0.95)
print("[AFTER compute_gae_advantages] sa_pairs[0]:")
print(sa_pairs[0])

# 9. Compute reward-to-go
print("\n[BEFORE add_reward_to_go_to_sa_pairs] sa_pairs[0]:")
print(sa_pairs[0])
sa_pairs = add_reward_to_go_to_sa_pairs(sa_pairs)
print("[AFTER add_reward_to_go_to_sa_pairs] sa_pairs[0]:")
print(sa_pairs[0])

# 10. Train value model
print("\n[BEFORE train_value_model_one_step] value model params (first layer):")
print(list(value_model.parameters())[0][0][:5])  # print a few params
state_texts = [pair['state_text'] for pair in sa_pairs]
reward_to_go_targets = [pair['reward_to_go'] for pair in sa_pairs]
value_optimizer = torch.optim.AdamW(value_model.parameters(), lr=1e-4)
avg_value_loss = train_value_model_one_step(
    value_model, tokenizer, state_texts, reward_to_go_targets, device, value_optimizer, batch_size=32
)
print("[AFTER train_value_model_one_step] value model params (first layer):")
print(list(value_model.parameters())[0][0][:5])
print("Value model training loss:", avg_value_loss)

# 11. Train policy model (PPO-clip)
print("\n[BEFORE train_policy_model_one_step] policy model params (first layer):")
print(list(policy_model.parameters())[0][0][:5])
state_texts = [pair['state_text'] for pair in sa_pairs]
action_token_ids = [pair['action_token_id'] for pair in sa_pairs]
old_logprobs = torch.tensor([pair['logprob'] for pair in sa_pairs], dtype=torch.float32).to(device)
advantages = torch.tensor([pair['advantage'] for pair in sa_pairs], dtype=torch.float32).to(device)
policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)
avg_policy_loss = train_policy_model_one_step(
    policy_model, tokenizer, state_texts, action_token_ids, old_logprobs, advantages, device, policy_optimizer, clip_epsilon=0.2, batch_size=32
)
print("[AFTER train_policy_model_one_step] policy model params (first layer):")
print(list(policy_model.parameters())[0][0][:5])
print("Policy model PPO loss:", avg_policy_loss)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.



[START] Prompt:
['SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:']

[AFTER GENERATION] sa_pairs[0]:
{'state_text': 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:', 'action_token': ' I', 'action_token_id': 358, 'logprob': -1.483904242515564, 'ref_logprob': -1.9759289026260376, 'prompt_idx': 0, 'step': 0}

Generated text:  I think not! I'm going to have to

[BEFORE add_kl_to_sa_pairs] sa_pairs[0]:
{'state_text': 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:', 'action_token': ' I', 'action_token_id': 358, 'logprob': -1.483904242515564, 'ref_logprob': -1.9759289026260376, 'prompt_idx': 0, 'step': 0}
[AFTER add_kl_to_sa_pairs] sa_pairs[0]:
{'state_text': 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:', 'action_token': ' I', 'action_token_id': 358, 'logprob': -1.483904242515564, 'ref_logprob': -1.9759289026260376, 'prompt_idx': 0, 'step': 0, 'kl_div

In [79]:
import torch

# 1. Start with a prompt
prompt = "SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:"
prompts = [prompt]

print("="*80)
print("STEP 1: Input Prompt")
print("-"*80)
print(prompts[0])

# 2. Generate (s, a) pairs and logprobs from both policy and reference
print("\n" + "="*80)
print("STEP 2: Generate (state, action) pairs and logprobs")
sa_pairs, generated_texts = generate_token_level_sa_and_logprobs_with_ref(
    policy_model, sft_model, tokenizer, prompts, device, max_new_tokens=10
)
print(f"Generated summary: {generated_texts[0]}")
print(f"First (s, a) pair:\n{sa_pairs[0]}")

# 3. Add KL divergence to sa_pairs (if not already present)
print("\n" + "="*80)
print("STEP 3: Add KL divergence to each (s, a) pair")
sa_pairs = add_kl_to_sa_pairs(sa_pairs)
print(f"KL divergence for first pair: {sa_pairs[0]['kl_div']:.4f}")

# 4. Compute reward for the generated sequence
print("\n" + "="*80)
print("STEP 4: Compute reward for the generated sequence")
rewards = compute_rewards_for_sequences(reward_model, tokenizer, prompts, generated_texts, device)
print(f"Reward for generated summary: {rewards[0]:.4f}")

# 5. Add KL-adjusted reward to sa_pairs
print("\n" + "="*80)
print("STEP 5: Add KL-adjusted reward to each (s, a) pair")
sa_pairs = adjust_rewards_with_kl(sa_pairs, rewards, kl_coef=0.1)
print(f"KL-adjusted reward for first pair: {sa_pairs[0]['adjusted_reward']:.4f}")

# 6. Add value predictions to sa_pairs
print("\n" + "="*80)
print("STEP 6: Add value model predictions to each (s, a) pair")
sa_pairs = add_value_to_sa_pairs(sa_pairs, value_model, tokenizer, device, batch_size=32)
print(f"Value prediction for first pair: {sa_pairs[0]['value']:.4f}")

# 7. Compute TD errors
print("\n" + "="*80)
print("STEP 7: Compute TD error for each (s, a) pair")
sa_pairs = compute_td_errors(sa_pairs, gamma=0.99)
print(f"TD error for first pair: {sa_pairs[0]['td_error']:.4f}")

# 8. Compute GAE advantages
print("\n" + "="*80)
print("STEP 8: Compute GAE advantage for each (s, a) pair")
sa_pairs = compute_gae_advantages(sa_pairs, gamma=0.99, lam=0.95)
print(f"GAE advantage for first pair: {sa_pairs[0]['advantage']:.4f}")

# 9. Compute reward-to-go
print("\n" + "="*80)
print("STEP 9: Compute reward-to-go for each (s, a) pair")
sa_pairs = add_reward_to_go_to_sa_pairs(sa_pairs)
print(f"Reward-to-go for first pair: {sa_pairs[0]['reward_to_go']:.4f}")

# 10. Train value model
print("\n" + "="*80)
print("STEP 10: Value model update (regress value to reward-to-go)")
state_texts = [pair['state_text'] for pair in sa_pairs]
reward_to_go_targets = [pair['reward_to_go'] for pair in sa_pairs]
value_optimizer = torch.optim.AdamW(value_model.parameters(), lr=1e-4)
print("Value model params (first 5):", list(value_model.parameters())[0][0][:5])
avg_value_loss = train_value_model_one_step(
    value_model, tokenizer, state_texts, reward_to_go_targets, device, value_optimizer, batch_size=32
)
print("Value model params (first 5, after update):", list(value_model.parameters())[0][0][:5])
print(f"Value model training loss: {avg_value_loss:.4f}")

# 11. Train policy model (PPO-clip)
print("\n" + "="*80)
print("STEP 11: Policy model PPO update (maximize clipped advantage)")
state_texts = [pair['state_text'] for pair in sa_pairs]
action_token_ids = [pair['action_token_id'] for pair in sa_pairs]
old_logprobs = torch.tensor([pair['logprob'] for pair in sa_pairs], dtype=torch.float32).to(device)
advantages = torch.tensor([pair['advantage'] for pair in sa_pairs], dtype=torch.float32).to(device)
policy_optimizer = torch.optim.AdamW(policy_model.parameters(), lr=1e-5)
print("Policy model params (first 5):", list(policy_model.parameters())[0][0][:5])
avg_policy_loss = train_policy_model_one_step(
    policy_model, tokenizer, state_texts, action_token_ids, old_logprobs, advantages, device, policy_optimizer, clip_epsilon=0.2, batch_size=32
)
print("Policy model params (first 5, after update):", list(policy_model.parameters())[0][0][:5])
print(f"Policy model PPO loss: {avg_policy_loss:.4f}")

print("\n" + "="*80)
print("SUMMARY: PPO step complete!")
print(f"Generated summary: {generated_texts[0]}")
print(f"Final reward: {rewards[0]:.4f}")
print("="*80)

Setting `pad_token_id` to `eos_token_id`:151643 for open-end generation.


STEP 1: Input Prompt
--------------------------------------------------------------------------------
SUBREDDIT: r/relationships
TITLE: Should I admit to snooping?
POST: ...
TL;DR:

STEP 2: Generate (state, action) pairs and logprobs
Generated summary:  Should you or should it?
The topic of this
First (s, a) pair:
{'state_text': 'SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: ...\nTL;DR:', 'action_token': ' Should', 'action_token_id': 12260, 'logprob': -2.2864532470703125, 'ref_logprob': -2.791689872741699, 'prompt_idx': 0, 'step': 0}

STEP 3: Add KL divergence to each (s, a) pair
KL divergence for first pair: 0.5052

STEP 4: Compute reward for the generated sequence
Reward for generated summary: 0.0146

STEP 5: Add KL-adjusted reward to each (s, a) pair
KL-adjusted reward for first pair: -0.0359

STEP 6: Add value model predictions to each (s, a) pair
Value prediction for first pair: -0.9528

STEP 7: Compute TD error for each (s, a) pair
TD error for first pair:

In [None]:
policy_dataset = load_dataset("CarperAI/openai_summarize_tldr")

print(policy_dataset)

policy_train_data = policy_dataset["train"]

for i in range(3):
    print(policy_dataset["train"][i]["prompt"])
    print(policy_dataset["train"][i]["label"])
    print("-" * 40)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'label'],
        num_rows: 116722
    })
    test: Dataset({
        features: ['prompt', 'label'],
        num_rows: 6553
    })
    valid: Dataset({
        features: ['prompt', 'label'],
        num_rows: 6447
    })
})
SUBREDDIT: r/relationships
TITLE: I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting
POST: Not sure if this belongs here but it's worth a try. 

Backstory:
When I (f/22) went through my first real breakup 2 years ago because he needed space after a year of dating roand  it effected me more than I thought. It was a horrible time in my life due to living with my mother and finally having the chance to cut her out of my life. I can admit because of it was an emotional wreck and this guy was stable and didn't know how to deal with me. We ended by him avoiding for a month or so after going to a festival with my friends. When I think back I wish he just en

In [None]:

reward_dataset = load_dataset("CarperAI/openai_summarize_comparisons")

print(reward_dataset)

reward_train_data = reward_dataset["train"]

for i in range(3):
    print(reward_dataset["train"][i]["prompt"])
    print(reward_dataset["train"][i]["chosen"])
    print(reward_dataset["train"][i]["rejected"])
    print("-" * 40)

DatasetDict({
    train: Dataset({
        features: ['prompt', 'chosen', 'rejected'],
        num_rows: 92534
    })
    test: Dataset({
        features: ['prompt', 'chosen', 'rejected'],
        num_rows: 83629
    })
    valid1: Dataset({
        features: ['prompt', 'chosen', 'rejected'],
        num_rows: 33082
    })
    valid2: Dataset({
        features: ['prompt', 'chosen', 'rejected'],
        num_rows: 50715
    })
})
SUBREDDIT: r/relationships
TITLE: To admit or not to admit snooping...
POST: I [25M] have snooped in the past and copped up to it to my gf [25F] of 6 years.  We talked it through.  It had been a year or two since the last time.  That's an issue I'm working on.

Now she has a new close male work friend.  I won't go into details, but she hides things from me with him and does other things to make me a bit suspicious.  So...I snooped again, and this time, all texts from her new friend have been deleted and I saw a google search for "how to get over a guy" near so

In [None]:
def preprocess_function(example):
    # You can add a prompt if you want, e.g., "Summarize: "
    input_text = example["prompt"]
    target_text = example["label"]
    model_inputs = tokenizer(
        input_text, max_length=512, truncation=True, padding="max_length"
    )
    labels = tokenizer(
        target_text, max_length=64, truncation=True, padding="max_length"
    )["input_ids"]
    model_inputs["labels"] = labels
    return model_inputs

tokenized_dataset = policy_dataset["train"].map(preprocess_function, batched=False)

Map: 100%|██████████| 116722/116722 [01:55<00:00, 1013.85 examples/s]


In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model




# Prepare a small dataset for demonstration
from datasets import load_dataset
dataset = load_dataset("CarperAI/openai_summarize_tldr", split="train[:100]")  # Use a small subset for speed

def collate_fn(batch):
    # Concatenate prompt and label for each example
    input_texts = [x["prompt"] + "\nTL;DR: " + x["label"] for x in batch]
    # Tokenize the concatenated text
    encodings = tokenizer(
        input_texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=128,
    )
    # For SFT, labels are the same as input_ids
    labels = encodings["input_ids"].clone()
    encodings = {k: v.to(device) for k, v in encodings.items()}
    labels = labels.to(device)
    return encodings, labels

dataloader = DataLoader(dataset, batch_size=1, shuffle=True, collate_fn=collate_fn)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

# Training loop
model.train()
for epoch in range(1):  # 1 epoch for demo
    for batch_idx, (inputs, labels) in enumerate(dataloader):
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if batch_idx % 10 == 0:
            print(f"Epoch {epoch} Batch {batch_idx} Loss: {loss.item():.4f}")

NameError: name 'model' is not defined

In [None]:
model.eval()

prompt = "SUBREDDIT: r/relationships\nTITLE: Should I admit to snooping?\nPOST: I snooped on my girlfriend's phone and found some suspicious messages. Should I tell her about it or keep it to myself?\nTL;DR:"

inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    outputs = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=10,  # Only generate 1 new token
        num_beams=1,
        do_sample=False,
    )
    summary = tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Prompt:\n", prompt)
print("\nGenerated TL;DR:\n", summary)

In [None]:
from peft import LoraConfig, get_peft_model

class RewardModel(nn.Module):
    def __init__(self, base_model_name, lora_config):
        super().__init__()
        base_model = AutoModelForCausalLM.from_pretrained(base_model_name, trust_remote_code=True)
        self.base_model = get_peft_model(base_model, lora_config)
        hidden_size = self.base_model.config.hidden_size
        self.value_head = nn.Linear(hidden_size, 1)

    def forward(self, input_ids, attention_mask):
        outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            return_dict=True
        )
        last_hidden = outputs.hidden_states[-1]
        last_token_idx = attention_mask.sum(dim=1) - 1
        last_token_idx = last_token_idx.unsqueeze(1).unsqueeze(2).expand(-1, 1, last_hidden.size(-1))
        last_hidden_state = last_hidden.gather(1, last_token_idx).squeeze(1)
        reward = self.value_head(last_hidden_state).squeeze(-1)
        return reward

In [None]:
class RewardComparisonDataset(Dataset):
    def __init__(self, hf_dataset, tokenizer, max_length=512):
        self.data = hf_dataset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        prompt = item["prompt"]
        chosen = item["chosen"]
        rejected = item["rejected"]

        def encode(text):
            return self.tokenizer(
                text,
                truncation=True,
                max_length=self.max_length,
                padding="max_length",
                return_tensors="pt"
            )

        chosen_input = encode(prompt + chosen)
        rejected_input = encode(prompt + rejected)

        return {
            "input_ids_chosen": chosen_input["input_ids"].squeeze(),
            "attention_mask_chosen": chosen_input["attention_mask"].squeeze(),
            "input_ids_rejected": rejected_input["input_ids"].squeeze(),
            "attention_mask_rejected": rejected_input["attention_mask"].squeeze(),
        }

In [None]:
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

def pairwise_loss(chosen_reward, rejected_reward):
    return -torch.log(torch.sigmoid(chosen_reward - rejected_reward)).mean()

reward_model = RewardModel(model_name, lora_config)
reward_model.train()
optimizer = torch.optim.AdamW(reward_model.parameters(), lr=5e-6)

reward_dataset = RewardComparisonDataset(summarize_train_data, tokenizer)
reward_loader = DataLoader(reward_dataset, batch_size=4, shuffle=True)

for epoch in range(1):  # 1 epoch
    total_loss = 0
    for batch in tqdm(reward_loader):
        input_ids_chosen = batch["input_ids_chosen"]
        attention_mask_chosen = batch["attention_mask_chosen"]
        input_ids_rejected = batch["input_ids_rejected"]
        attention_mask_rejected = batch["attention_mask_rejected"]

        chosen_reward = reward_model(input_ids_chosen, attention_mask_chosen)
        rejected_reward = reward_model(input_ids_rejected, attention_mask_rejected)

        loss = pairwise_loss(chosen_reward, rejected_reward)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1} - Avg Loss: {total_loss / len(reward_loader)}")

In [None]:
from trl import PPOTrainer, PPOConfig
from transformers import AutoTokenizer

# 1. Prepare your policy and reward models (already done)
# 2. Prepare your tokenizer (already done)

# 3. PPO config
ppo_config = PPOConfig(
    batch_size=4,
    forward_batch_size=2,
    learning_rate=5e-6,
    log_with=None,  # or "wandb"
    mini_batch_size=4,
    optimize_cuda_cache=True,
)

# 4. Prepare your dataset (list of prompts)
prompts = [item["prompt"] for item in tldr_train_data]

# 5. PPOTrainer
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=summariser_model,
    ref_model=None,  # Optional: a reference model for KL penalty
    tokenizer=tokenizer,
    dataset=prompts,
    reward_model=reward_model,
)

# 6. Run PPO training
ppo_trainer.train()

In [None]:
## PPO

