# 🚀 PPO Fine-Tuning with a Custom Reward Function
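This notebook shows how to fine-tune a small causal language model with Proximal Policy Optimization (PPO) using Hugging Face's `trl` library. A simple hand-written reward function favors responses that mention "life". The code targets trl's legacy `PPOTrainer` API (trl < 0.12).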

In [None]:
# PPO Fine-Tuning of a Language Model with a Custom Reward Function

# 🛠️ Setup
!pip install -q transformers datasets trl accelerate bitsandbytes

from transformers import AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from datasets import load_dataset
import torch

# ⚙️ Load model & tokenizer
# NOTE(review): there appears to be no "tiiuae/falcon-1b" repo on the Hugging Face Hub; Falcon's
# ~1B checkpoint is published as "tiiuae/falcon-rw-1b". Confirm this is the intended model.
model_name = "tiiuae/falcon-rw-1b"  # small model for quick test
# Left padding keeps each prompt's last token adjacent to the generated continuation.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
# Reuse EOS as the pad token (presumably the tokenizer ships without one; confirm).
tokenizer.pad_token = tokenizer.eos_token

# Policy model = causal LM wrapped with a scalar value head (required by PPO), loaded in 8-bit.
# NOTE(review): 8-bit quantized linear weights are generally not trainable. Without a PEFT/LoRA
# adapter, PPO may only update the non-quantized parameters (e.g. the value head). Confirm intent.
model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name, load_in_8bit=True, device_map="auto")

# 📄 Load small dataset
# NOTE(review): `dataset` is not read by the training loop in this notebook, which uses a fixed
# prompt list instead.
dataset = load_dataset("Abirate/english_quotes", split="train[:100]")

# 📈 Define a custom reward function
def simple_reward_fn(query, response, keyword="life", length_scale=100.0):
    """Score a generated response with a toy, rule-based reward.

    The reward is +1.0 if ``keyword`` occurs anywhere in ``response`` and -1.0
    otherwise. The match is a case-insensitive substring match, so e.g.
    "wildlife" also counts. A length bonus of ``len(response) / length_scale``
    is added in either case.

    Args:
        query: The prompt that produced ``response``. Unused; kept so the
            signature matches a ``(query, response)`` reward interface.
        response: Decoded model output to score.
        keyword: Text whose presence earns +1.0 instead of -1.0.
            Defaults to "life".
        length_scale: Number of characters per +1.0 of length bonus.
            Defaults to 100.0.

    Returns:
        The reward as a float.
    """
    keyword_reward = 1.0 if keyword.lower() in response.lower() else -1.0
    # The length bonus counts characters (not tokens) and is unbounded. It only stays
    # "small" here because generation is capped by max_new_tokens.
    return keyword_reward + len(response) / length_scale

# ⚙️ PPO Config
config = PPOConfig(
    model_name=model_name,
    learning_rate=1.41e-5,
    log_with=None,       # no experiment tracker
    mini_batch_size=1,
    batch_size=2,        # each PPOTrainer.step() call must receive exactly this many samples
)

# The legacy PPOTrainer signature is (config, model, ref_model, tokenizer, ...). Passing the
# tokenizer as the third positional argument would bind it to `ref_model`, so name the arguments.
# With ref_model=None, the legacy trainer creates its own frozen reference model for the KL penalty.
ppo_trainer = PPOTrainer(config, model, ref_model=None, tokenizer=tokenizer)

# 🔁 Training loop
text_prompts = ["What is the purpose of life?" for _ in range(100)]

NUM_PPO_STEPS = 5  # keep it short for a smoke test
# PPOTrainer.step() rejects batches whose size differs from config.batch_size.
batch_size = ppo_trainer.config.batch_size
# Let the trainer's accelerator decide where inputs live. The value-head wrapper is an
# nn.Module wrapper and may not expose `.device` itself.
device = ppo_trainer.accelerator.device

# PPO's update assumes responses are *sampled* from the current policy. So use pure sampling
# (no top-k / top-p truncation) rather than the model's default (typically greedy) decoding.
generation_kwargs = {
    "do_sample": True,
    "top_k": 0,
    "top_p": 1.0,
    "max_new_tokens": 32,
    "pad_token_id": tokenizer.pad_token_id,
}

for step_idx in range(NUM_PPO_STEPS):
    start = step_idx * batch_size
    batch_prompts = text_prompts[start:start + batch_size]
    if len(batch_prompts) < batch_size:
        break  # not enough prompts left for a full PPO batch

    # The legacy trl step() API expects a list of 1-D token tensors (one per sample), not strings.
    query_tensors = [
        tokenizer(prompt, return_tensors="pt").input_ids.squeeze(0).to(device)
        for prompt in batch_prompts
    ]
    response_tensors = []
    for query in query_tensors:
        generated = model.generate(query.unsqueeze(0), **generation_kwargs)
        response_tensors.append(generated[0, query.shape[-1]:])  # drop the echoed prompt tokens
    response_texts = tokenizer.batch_decode(response_tensors, skip_special_tokens=True)

    rewards = [simple_reward_fn(prompt, text) for prompt, text in zip(batch_prompts, response_texts)]
    for text, reward in zip(response_texts, rewards):
        print(f"Step {step_idx + 1}: {text} | Reward: {reward}")

    # Scores must also be tensors: one scalar per sample.
    reward_tensors = [torch.tensor(reward, dtype=torch.float32, device=device) for reward in rewards]
    ppo_trainer.step(query_tensors, response_tensors, reward_tensors)