In [1]:

# In Colab, run this cell to install everything
# Run this cell (Colab or notebook shell)
!pip uninstall -y trl transformers peft safetensors accelerate sentencepiece || true

!pip install -q transformers 'datasets<3.0.0' accelerate 'trl==0.8.5' sentencepiece
# the exact trl version may change over time. If you hit API errors, bump/downgrade trl accordingly.


Found existing installation: trl 0.8.5
Uninstalling trl-0.8.5:
  Successfully uninstalled trl-0.8.5
Found existing installation: transformers 4.57.3
Uninstalling transformers-4.57.3:
  Successfully uninstalled transformers-4.57.3
[0mFound existing installation: safetensors 0.7.0
Uninstalling safetensors-0.7.0:
  Successfully uninstalled safetensors-0.7.0
Found existing installation: accelerate 1.12.0
Uninstalling accelerate-1.12.0:
  Successfully uninstalled accelerate-1.12.0
Found existing installation: sentencepiece 0.2.1
Uninstalling sentencepiece-0.2.1:
  Successfully uninstalled sentencepiece-0.2.1


In [2]:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, AutoModelForSequenceClassification
from trl import AutoModelForCausalLMWithValueHead
from datasets import Dataset
from trl import PPOTrainer, PPOConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


Device: cuda


In [3]:
# You can change 'gpt2' to a smaller or larger model depending on Colab GPU availability
policy_model_name = "gpt2"          # student policy to be fine-tuned via PPO
ref_model_name = "gpt2"             # reference model (frozen) used for KL penalty


In [4]:
tokenizer = AutoTokenizer.from_pretrained(policy_model_name)
# GPT2 does not have a pad token by default. Add one to avoid generation/padding issues.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [151]:
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(policy_model_name).to(device)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name).to(device)

# Resize token embeddings if we added new tokens
# policy_model.resize_token_embeddings(len(tokenizer))
# ref_model.resize_token_embeddings(len(tokenizer))


# Resize embeddings on the underlying model
policy_model.pretrained_model.resize_token_embeddings(len(tokenizer))
ref_model.pretrained_model.resize_token_embeddings(len(tokenizer))


Embedding(50258, 768)

In [152]:
# Put reference model in eval and freeze parameters
ref_model.eval()
for p in ref_model.parameters():
    p.requires_grad = False

In [153]:
#  reward model: use a sentiment classifier as a proxy reward model
# We will reward outputs that are classified as 'positive' with higher score.
# Replace this with your trained reward model when available.

reward_pipeline = pipeline(
    "sentiment-analysis",
    model="distilbert-base-uncased-finetuned-sst-2-english",
    tokenizer="distilbert-base-uncased-finetuned-sst-2-english",
    device=0 if device == "cuda" else -1
)

Device set to use cuda:0


In [154]:
from collections import Counter
import re

TEMPLATE_PATTERNS = [
    r"Rated\s*\d+\s*out of\s*5",
    r"\bAnonymous\b",
    r"Make sure you're using",
    r"Note:",
]

def contains_template(response):
    for pat in TEMPLATE_PATTERNS:
        if re.search(pat, response, flags=re.IGNORECASE):
            return True
    return False

def ngram_repetition_ratio(text, n=3):
    tokens = text.split()
    if len(tokens) < n + 1:
        return 0.0
    ngrams = [" ".join(tokens[i:i+n]) for i in range(len(tokens)-n+1)]
    counts = Counter(ngrams)
    repeats = sum(c-1 for c in counts.values() if c > 1)
    return repeats / max(len(ngrams), 1)

def reward_fn(prompts, responses, tokenizer=None):
    outs = reward_pipeline(responses, truncation=True)
    rewards = []
    for i, o in enumerate(outs):
        # sentiment
        if o["label"].upper().startswith("POS"):
            sentiment_score = o["score"]
        else:
            sentiment_score = 1.0 - o["score"]

        resp = responses[i]
        # prompt copy ratio
        prompt_words = set(prompts[i].split())
        resp_words = resp.split()
        if len(resp_words) == 0:
            prompt_copy_ratio = 1.0
        else:
            common = sum(1 for w in resp_words if w in prompt_words)
            prompt_copy_ratio = common / len(resp_words)
        prompt_copy_penalty = prompt_copy_ratio * 1.5

        # ngram repetition penalty
        rep_ratio = ngram_repetition_ratio(resp, n=3)
        rep_penalty = rep_ratio * 3.0   # stronger penalty than before

        # template penalty
        template_pen = 1.5 if contains_template(resp) else 0.0

        # length bonus using tokenizer tokens
        if tokenizer is not None:
            token_len = len(tokenizer(resp)["input_ids"])
        else:
            token_len = len(resp_words)
        length_bonus = 0.25 if 20 <= token_len <= 60 else -0.25

        # combine, center sentiment to [-0.5,0.5]
        base = (sentiment_score - 0.5) * 1.2
        reward = base + length_bonus - rep_penalty - prompt_copy_penalty - template_pen

        # clip for stability
        reward = max(-2.0, min(2.0, reward))
        rewards.append(float(reward))
    return rewards


In [155]:
prompts = [
    "Write a short positive product review for a phone that has great battery life.",
    "Write a short product review for a phone with disappointing camera quality.",
    "Give a friendly recommendation for a cafe you like in town.",
    "Explain why someone should try this new restaurant that serves healthy bowls.",
    "Write a short ad copy praising the speed and reliability of a new SSD drive.",
    "Write a short positive review for a laptop with excellent performance.",
    "Write a short review for a Bluetooth speaker with weak bass.",
    "Give a friendly recommendation for a local bakery known for fresh pastries.",
    "Explain why someone should visit a cozy tea house you enjoy.",
    "Write a short ad copy praising the durability of a new backpack.",
    "Describe why a travel pillow would make long flights more comfortable.",
    "Write a short positive review of a wireless mouse with smooth tracking.",
    "Write a short review for a pair of headphones with average noise isolation.",
    "Give a quick recommendation for a bookstore with a relaxing atmosphere.",
    "Explain why someone should try a juice bar known for fresh ingredients.",
    "Write a short positive product review for a smartwatch with helpful health features.",
    "Write a short review for a fitness tracker with limited battery life.",
    "Recommend a quiet study cafe with comfortable seating.",
    "Explain why a family would enjoy a friendly neighborhood diner.",
    "Write an ad copy highlighting the clarity of a new 4K monitor.",
    "Describe why a reusable water bottle is a great daily companion.",
    "Write a short review about a tablet with a sharp display.",
    "Write a short review about a tablet with slow charging speed.",
    "Give a friendly recommendation for a cozy pizza place.",
    "Explain why someone might enjoy a peaceful lakeside park.",
    "Write an ad copy for a pen that writes smoothly.",
    "Write a positive review for a keyboard with soft key travel.",
    "Write a review for a monitor with poor viewing angles.",
    "Give a friendly recommendation for a local craft shop.",
    "Explain why someone should try a smoothie place with fresh fruit.",
    "Write a positive review for a vacuum cleaner with strong suction.",
    "Write a review for a hair dryer that gets too loud.",
    "Recommend a quiet lounge for reading and relaxing.",
    "Explain why someone should try a sandwich shop known for fresh bread.",
    "Write an ad copy praising a portable charger with fast output.",
    "Write a positive review for a pair of stylish sunglasses.",
    "Write a review for a wallet that feels too bulky.",
    "Give a friendly recommendation for a music store with vintage instruments.",
    "Explain why someone should visit a small family owned ice cream shop.",
    "Write an ad copy for running shoes designed for comfort.",
    "Write a positive review for a backpack with many pockets.",
    "Write a review for a coffee maker that takes too long to brew.",
    "Give a friendly recommendation for a peaceful walking trail.",
    "Explain why someone should try a new salad bar with organic options.",
    "Write an ad copy for a camera tripod that is lightweight.",
    "Write a positive review for a portable fan that works quietly.",
    "Write a review for a desk lamp that is too dim.",
    "Give a friendly recommendation for a neighborhood flower shop.",
    "Explain why someone should try a bakery with fresh morning pastries.",
    "Write a short positive review for noise cancelling headphones.",
    "Write a short review for a budget laptop with slow performance.",
    "Recommend a family friendly hotel with good breakfast options.",
    "Explain why someone should try a new vegan restaurant downtown.",
    "Write a short ad for a portable Bluetooth speaker with clear sound.",
    "Write a short positive review for a dishwasher that cleans well.",
    "Write a short review for a blender that struggles with ice.",
    "Recommend a quiet co working space for freelancers.",
    "Explain why a city park is perfect for weekend picnics.",
    "Write an ad for an ergonomic office chair that reduces back pain.",
    "Write a positive review for a streaming service with great originals.",
    "Write a review for an app that crashes frequently.",
    "Recommend a weekend getaway spot for couples on a budget.",
    "Explain why a neighborhood farmer market is worth visiting.",
    "Write an ad for a lightweight luggage set for frequent travelers.",
    "Write a positive review for a smart thermostat that saves energy.",
    "Write a short review for an air purifier with loud fan noise.",
    "Recommend a scenic hiking trail for beginners.",
    "Explain why someone should attend a local pottery workshop.",
    "Write an ad praising the battery life of a new wireless earbuds model.",
    "Write a positive review for a cookbook with easy weeknight recipes.",
    "Write a review for a slow responding GPS navigation app.",
    "Recommend a quiet library with private study rooms.",
    "Explain why a community garden is great for families.",
    "Write an ad for a stain resistant sofa fabric.",
    "Write a positive review for a yoga mat that offers good grip.",
    "Write a review for a smartwatch with unreliable step counting.",
    "Recommend a cozy bed and breakfast near the coast.",
    "Explain why someone should try a cooking class focused on pasta.",
    "Write an ad for a compact espresso machine for small kitchens.",
    "Write a positive review for a photo printing service with fast delivery.",
    "Write a review for a slow charging wireless phone charger.",
    "Recommend an independent film theater with curated screenings.",
    "Explain why a local charity shop is a great place to find bargains.",
    "Write an ad for a kitchen knife set with precision edges.",
    "Write a positive review for a rain jacket that keeps you dry.",
    "Write a review for a dryer that shrinks delicate clothes.",
    "Recommend a kid friendly museum with interactive exhibits.",
    "Explain why a neighborhood walking tour is a fun activity.",
    "Write an ad for a high capacity external hard drive.",
    "Write a positive review for a language learning app with fun lessons.",
    "Write a review for a photo editing software that is confusing to use.",
    "Recommend a dog friendly cafe with outdoor seating.",
    "Explain why a local music festival is worth the trip.",
    "Write an ad for a versatile multi tool for camping trips.",
    "Write a positive review for an electric toothbrush with long battery life.",
    "Write a review for a fashion brand with inconsistent sizing.",
    "Recommend a rooftop bar with great city views.",
    "Explain why someone should try a guided city bike tour.",
    "Write an ad for a quick charging power bank.",
    "Write a positive review for a mattress that improved your sleep.",
    "Write a review for a slow responding smart home hub.",
    "Recommend a botanical garden for a relaxing afternoon.",
    "Explain why a weekend farmers market supports local producers.",
    "Write an ad for a premium tea sampler set.",
    "Write a positive review for a compact projector for movie nights.",
    "Write a review for sneakers that wear out quickly.",
    "Recommend a craft cocktail bar with unique ingredients.",
    "Explain why a pottery studio membership is a good hobby choice.",
    "Write an ad for a noise reducing window insert.",
    "Write a positive review for an online course that taught practical skills.",
    "Write a review for a bike with frequent gear slipping.",
    "Recommend a family friendly amusement park with gentle rides.",
    "Explain why a local book club is a great way to meet people.",
    "Write an ad for a reliable electric scooter for short commutes.",
    "Write a positive review for a subscription box with curated snacks.",
    "Write a review for a slow responding customer support service.",
    "Recommend a small neighborhood bakery with artisan breads.",
    "Explain why a weekend pottery retreat is refreshing.",
    "Write an ad for a compact multi port USB charger.",
    "Write a positive review for a garden hose that does not kink.",
    "Write a review for a smartwatch with short battery life.",
    "Recommend a historic walking route with informative plaques.",
    "Explain why a local cooking co op is a good community resource.",
    "Write an ad for a durable picnic blanket for outdoor events.",
    "Write a positive review for a noise isolating microphone for podcasts.",
    "Write a review for a slow loading news app.",
    "Recommend a casual brunch spot with excellent pancakes.",
    "Explain why someone should volunteer at an animal shelter.",
    "Write an ad for a compact air fryer for busy kitchens.",
    "Write a positive review for a compact camera with sharp images.",
    "Write a review for a jacket that loses color after washing.",
    "Recommend a family friendly lake for swimming and boating.",
    "Explain why a photography walk is great for learning composition.",
    "Write an ad for a travel pillow that supports your neck.",
    "Write a positive review for a cloud storage service with simple sharing.",
    "Write a review for a slow performing video editor.",
    "Recommend a scenic lookout point for sunrise photos.",
    "Explain why joining a local sports league improves fitness and social life.",
    "Write an ad for a soft throw blanket perfect for movie nights.",
    "Write a positive review for an intuitive budgeting app.",
    "Write a review for headphones with poor microphone quality."
]



In [156]:

batch_size = 8
num_steps = 100

ppo_config = PPOConfig(
    learning_rate=1.41e-5,
    batch_size=batch_size,            # number of samples per PPO update
    mini_batch_size=1,       # inner mini-batch
    cliprange=0.2,           # PPO clipping range for policy ratio
    cliprange_value=0.2,     # clipping for value loss if used
    vf_coef=0.5,             # value loss coefficient
    # fixed KL penalty coefficient (use this since adaptive keys are not present)
    target_kl=0.1,
    whiten_rewards=False,
)



In [157]:
# Create trainer
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
)



In [158]:
!set CUDA_LAUNCH_BLOCKING=1

In [159]:

# Generation and PPO training loop
gen_kwargs = {
    "max_new_tokens": 100,
    "do_sample": True,
    "top_k": 20,
    "top_p": 0.95,
    "temperature": 0.7,
    "pad_token_id": tokenizer.pad_token_id,
}

In [160]:


# for step in range(num_steps):
#     prompt = prompts[step % len(prompts)]

#     query_tensors = tokenizer(prompt, return_tensors="pt").to(device)

#     generated_ids = policy_model.generate(**query_tensors, **gen_kwargs)
#     gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

#     response = gen_text[len(prompt):].strip() if gen_text.startswith(prompt) else gen_text


#     if len(response.strip()) == 0:
#         response = "N/A"

#     # Compute reward (scalar)
#     reward = reward_fn([response])[0]


#     query_ids = query_tensors["input_ids"].squeeze(0)
#     response_ids = tokenizer(response, return_tensors="pt", add_special_tokens=True).input_ids.squeeze(0).to(device)

#     stats = ppo_trainer.step(
#         [query_ids],
#         [response_ids],
#         [torch.tensor(reward, device=device)]
#     )



#     if (step + 1) % 5 == 0:
#         print(f"Step {step+1}/{num_steps} | Reward={reward:.4f} | Response='{response[:120]}'")


In [161]:


batch_queries, batch_responses, batch_rewards = [], [], []

for step in range(num_steps):
    prompt = prompts[step % len(prompts)]

    # Encode prompt
    query_tensors = tokenizer(prompt, return_tensors="pt").to(device)

    # Generate response
    generated_ids = policy_model.generate(**query_tensors, **gen_kwargs)
    gen_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # Extract response
    response = gen_text[len(prompt):].strip() if gen_text.startswith(prompt) else gen_text
    if len(response.strip()) == 0:
        response = "N/A"

    # Compute reward
    reward = reward_fn([prompt],[response])[0]  # single float

    # Prepare tensors
    query_ids = query_tensors["input_ids"].squeeze(0)
    response_ids = tokenizer(response, return_tensors="pt", add_special_tokens=True).input_ids.squeeze(0).to(device)

    # Accumulate batch
    batch_queries.append(query_ids)
    batch_responses.append(response_ids)
    batch_rewards.append(torch.tensor(reward, device=device))

    # When batch is full, perform PPO step
    if len(batch_queries) == batch_size:
        stats = ppo_trainer.step(batch_queries, batch_responses, batch_rewards)

        # Reset batch
        batch_queries, batch_responses, batch_rewards = [], [], []

        print(f"Step {step+1}/{num_steps} | Last Reward={reward:.4f} | Last Response='{response[:120]}'")


Step 8/100 | Last Reward=-0.0179 | Last Response='The bakery is located in the beautiful town of Nisqually, British Columbia, on the eastern coast of Canada. The bakery i'




Step 16/100 | Last Reward=-0.8783 | Last Response='We'll use the same quality of products and services.

Our goal is to create a great product. We will provide you with a '
Step 24/100 | Last Reward=-0.0294 | Last Response='We're a family owned and operated by a single mom and a dad.

A great restaurant with a great food menu.

We have a lot '
Step 32/100 | Last Reward=-1.5470 | Last Response='It's a great thing to have. But if you've never had a dryer before, it's not a great thing to have. But, I've got a few '
Step 40/100 | Last Reward=-2.0000 | Last Response='An ad copy of a book or magazine.

An ad copy of a book or magazine for sale.

An ad copy of a book or magazine for sale'
Step 48/100 | Last Reward=-0.9187 | Last Response='Please send us an email with your name and email address.

Thank you for your time and interest.

Best,

Michael
We are '
Step 56/100 | Last Reward=-0.7981 | Last Response='Reviewed By Date Rating

Reviewed By Date Rating 1 of 1 reviews

Reviewed By Date Ra

In [162]:

test_prompt = "Write a positive review for an electric toothbrush with long battery life."

# Encode prompt
inputs = tokenizer(test_prompt, return_tensors="pt").to(device)

# Generate from policy model (trained)
policy_output_ids = policy_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.pad_token_id)
policy_response = tokenizer.decode(policy_output_ids[0], skip_special_tokens=True)

# Generate from reference model (original)
ref_output_ids = ref_model.generate(**inputs, max_new_tokens=50, pad_token_id=tokenizer.pad_token_id)
ref_response = tokenizer.decode(ref_output_ids[0], skip_special_tokens=True)

print("Prompt:", test_prompt)
print(("===" * 50 ))
print("\nPolicy Model Response:\n", policy_response)
print(("===" * 50 ))
print("\nReference Model Response:\n", ref_response)
print(("===" * 50 ))


Prompt: Write a positive review for an electric toothbrush with long battery life.

Policy Model Response:
 Write a positive review for an electric toothbrush with long battery life.

The new, long battery is a long battery battery that is designed to be a long battery for the most efficient and easy to use.

The long battery is a long battery that is designed to be a long battery for the most efficient

Reference Model Response:
 Write a positive review for an electric toothbrush with long battery life.

The best way to get a good quality toothbrush is to buy a good quality toothbrush.

The best way to get a good quality toothbrush is to buy a good quality toothbrush.

The best way to get a
