In [None]:
import os
# Single source for the checkpoint directory: it is created here and also
# handed to the trainer as its save location.
CKPT_DIR = "/content/drive/MyDrive/rlhf_checkpoints"
os.makedirs(CKPT_DIR, exist_ok=True)

# Run PPO fine-tuning of `policy` for 4 epochs with one PPO pass per batch;
# `ref_lm` is passed alongside the policy, and the two RL loaders supply the
# training and validation batches.
train_ppo_rlhf(
    policy=policy,
    ref_lm=ref_lm,
    tokenizer=tokenizer,
    rl_train_loader=rl_train_loader,
    rl_valid_loader=rl_valid_loader,
    epochs=4,
    lr=1e-6,
    ppo_epochs=1,
    save_dir=CKPT_DIR,
)


===== Epoch 1 / 4 =====


Epoch 1/4: 100%|██████████| 7500/7500 [1:08:12<00:00,  1.83it/s, R=0.585, KL=2.509, Loss=0.486, KLc=0.0896]



[Train] Epoch summary 1/4
 avg_reward = 0.6686
 avg_KL     = 1.0118
 avg_loss   = 0.3471
 KL_coef    = 0.0896
[Valid] avg reward = 0.6703
✔ Model saved to /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch1.pt

===== Epoch 2 / 4 =====


Epoch 2/4: 100%|██████████| 7500/7500 [1:08:19<00:00,  1.83it/s, R=0.651, KL=0.825, Loss=0.310, KLc=0.1263]



[Train] Epoch summary 2/4
 avg_reward = 0.6696
 avg_KL     = 1.0041
 avg_loss   = 0.3652
 KL_coef    = 0.1263
[Valid] avg reward = 0.6759
✔ Model saved to /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch2.pt

===== Epoch 3 / 4 =====


Epoch 3/4: 100%|██████████| 7500/7500 [1:07:29<00:00,  1.85it/s, R=0.533, KL=0.748, Loss=0.342, KLc=0.1468]



[Train] Epoch summary 3/4
 avg_reward = 0.6712
 avg_KL     = 1.0029
 avg_loss   = 0.4180
 KL_coef    = 0.1468
[Valid] avg reward = 0.6743
✔ Model saved to /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch3.pt

===== Epoch 4 / 4 =====


Epoch 4/4: 100%|██████████| 7500/7500 [1:07:34<00:00,  1.85it/s, R=0.707, KL=0.533, Loss=0.308, KLc=0.1655]



[Train] Epoch summary 4/4
 avg_reward = 0.6732
 avg_KL     = 1.0027
 avg_loss   = 0.3957
 KL_coef    = 0.1655
[Valid] avg reward = 0.6801
✔ Model saved to /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch4.pt


In [None]:
import torch
from transformers import AutoTokenizer
import os

# Use the GPU when torch reports one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Directory that holds the per-epoch PPO policy checkpoints.
CKPT_DIR = "/content/drive/MyDrive/rlhf_checkpoints"

# Right-padded distilgpt2 tokenizer; the EOS token doubles as the pad token.
tokenizer = AutoTokenizer.from_pretrained("distilgpt2", padding_side="right")
tokenizer.pad_token = tokenizer.eos_token

def load_policy(epoch, dir):
    """Load the PPO policy checkpoint written for a given epoch.

    Args:
        epoch: Epoch number used to build the checkpoint file name
            (``ppo_policy_epoch{epoch}.pt``).
        dir: Directory containing the checkpoint files. The name shadows the
            ``dir`` builtin but is kept so existing keyword callers still work.

    Returns:
        A ``PolicyValueModel("distilgpt2")`` moved to ``device``, with the
        checkpoint's state dict loaded and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    ckpt_path = os.path.join(dir, f"ppo_policy_epoch{epoch}.pt")
    print(f"Loading: {ckpt_path}")
    # Fail fast: check for the file before building the model and moving it to
    # the device, so a wrong epoch/path does not cost a model instantiation.
    if not os.path.isfile(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
    policy = PolicyValueModel("distilgpt2").to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints that
    # were produced by this notebook or are otherwise trusted.
    policy.load_state_dict(torch.load(ckpt_path, map_location=device))
    policy.eval()
    return policy

@torch.no_grad()
def generate(model, prompt, max_new_tokens=40, sample=True):
    """Generate a continuation of ``prompt`` with the model's language model.

    Args:
        model: Object exposing the underlying LM as ``model.lm`` (its
            ``generate`` method is called).
        prompt: Text to tokenize and continue.
        max_new_tokens: Upper bound on the number of newly generated tokens.
        sample: If True, sample with ``top_k=50`` and ``top_p=0.95``;
            otherwise decode without sampling.

    Returns:
        The decoded text of the first returned sequence (prompt plus
        continuation) with special tokens stripped.
    """
    enc = tokenizer(prompt, return_tensors="pt").to(device)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sample,
        # Pass the pad id explicitly so generate() does not have to fall back
        # to the EOS id (and warn about it) on every call.
        "pad_token_id": tokenizer.pad_token_id,
    }
    if sample:
        # top-k / nucleus filtering only applies when sampling; forwarding
        # these with do_sample=False only triggers "unused flag" warnings.
        gen_kwargs["top_k"] = 50
        gen_kwargs["top_p"] = 0.95
    out = model.lm.generate(**enc, **gen_kwargs)
    return tokenizer.decode(out[0], skip_special_tokens=True)

In [None]:
# Load every epoch's checkpoint up front (same order as before), keeping the
# individual names available for later cells.
policy_e1, policy_e2, policy_e3, policy_e4 = (
    load_policy(epoch, CKPT_DIR) for epoch in (1, 2, 3, 4)
)

prompt = "The movie was"

# Print one generated sample per epoch; every header after the first is
# preceded by a blank line.
for _epoch, _policy in enumerate((policy_e1, policy_e2, policy_e3, policy_e4), start=1):
    _header = f"=== Epoch {_epoch} ==="
    print(_header if _epoch == 1 else "\n" + _header)
    print(generate(_policy, prompt))

Loading: /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch1.pt
Loading: /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch2.pt
Loading: /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch3.pt
Loading: /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch4.pt
=== Epoch 1 ===
The movie was written during that time when the music was so prevalent, it seemed that some viewers were going to have a great moment watching it, however, I just don't think there was a lot of entertainment to

=== Epoch 2 ===
The movie was shot in black. I must say that this is not one of the most important movie I've seen. I'm not sure if the movie is a great musical; it will entertain you as well as

=== Epoch 3 ===
The movie was supposed to be a comedy, it was not. I actually liked the script and was pleasantly surprised. My friend was great too. Robert Downey Jr. did an amazing job portraying it and I'm

=== Epoch 4 ===
The movie was based on a great TV movie called THE FILMMAKER , and sta

In [None]:
# Score the epoch-4 PPO checkpoint on the SFT test loader (capped at 200
# batches, BERTScore enabled) and print one metric per line.
ppo_policy = load_policy(4, CKPT_DIR)

ppo_metrics = eval_metrics(
    ppo_policy.lm, sft_test_loader, max_batches=200, use_bertscore=True
)

print("\n=== PPO Epoch 4 (RLHF) ===")
for metric_name, metric_value in ppo_metrics.items():
    # Metric names are left-aligned in a 12-character column.
    print(f"{metric_name:12s}: {metric_value}")

Loading: /content/drive/MyDrive/rlhf_checkpoints/ppo_policy_epoch4.pt


Eval: 100%|██████████| 200/200 [08:45<00:00,  2.63s/it, pairs=1600]



=== PPO Epoch 4 (RLHF) ===
bleu        : 0.002936413330527698
rouge1      : 0.19061815299177629
rouge2      : 0.013334473249201348
rougel      : 0.12112305243322945
meteor      : 0.09206513724604497
avg_len     : 41.223125
avg_rep     : 0.1617872499011467
distinct1   : 0.16394014282032232
distinct2   : 0.590192209083705
samples     : 1600
bertscore_f1: 0.8261020183563232
