In [1]:
!pip install trl
import os
import torch
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from transformers import pipeline, AutoTokenizer
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
os.environ["WANDB_DISABLED"] = "true"

Collecting trl
  Downloading trl-0.9.6-py3-none-any.whl (245 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/245.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m245.8/245.8 kB[0m [31m9.4 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m245.8/245.8 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
Collecting accelerate (from trl)
  Downloading accelerate-0.32.1-py3-none-any.whl (314 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m314.1/314.1 kB[0m [31m31.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting datasets (from trl)
  Downloading datasets-2.20.0-py3-none-any.whl (547 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m547.8/547.8 kB[0m [31m34.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tyro>=0.5.11 (from trl)
  Downloading tyro-0.8.5-py3-none-any.whl (103 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [2]:
class LengthSampler:
    def __init__(self, min_value, max_value):
        self.values = list(range(min_value, max_value))
    def __call__(self):
        return np.random.choice(self.values)

input_size = LengthSampler(2, 8)
output_size = LengthSampler(4, 16)

In [3]:
generative_model = "lvwerra/gpt2-imdb"
active_model = AutoModelForCausalLMWithValueHead.from_pretrained(generative_model)
reference_model = AutoModelForCausalLMWithValueHead.from_pretrained(generative_model)

config.json:   0%|          | 0.00/577 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/548M [00:00<?, ?B/s]

In [4]:
tokenizer = AutoTokenizer.from_pretrained(generative_model)
tokenizer.pad_token = tokenizer.eos_token
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["review"])[:input_size()]
    sample["query"] = tokenizer.decode(sample["input_ids"])
    return sample
dataset = load_dataset("imdb", split="train")
dataset = dataset.select(range(min(len(dataset), 500)))
dataset = dataset.rename_columns({"text": "review"})
dataset = dataset.filter(lambda x: len(x["review"]) > 200, batched=False)
dataset = dataset.map(tokenize, batched=False)
dataset.set_format(type="torch")

tokenizer_config.json:   0%|          | 0.00/17.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Filter:   0%|          | 0/500 [00:00<?, ? examples/s]

Map:   0%|          | 0/495 [00:00<?, ? examples/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors


In [5]:
config = PPOConfig(model_name=generative_model, learning_rate=1.41e-5)
ppo_trainer = PPOTrainer(config, active_model, reference_model, tokenizer, dataset=dataset, data_collator=lambda data: {key: [d[key] for d in data] for key in data[0]})
device = ppo_trainer.accelerator.device
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"

In [6]:
sentiment_model = pipeline("sentiment-analysis", model="lvwerra/distilbert-imdb", device=device)

config.json:   0%|          | 0.00/735 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/268M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/333 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

In [7]:
sentiment_kwargs = {"top_k": None, "function_to_apply": "none", "batch_size": 16}
generation_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
for epoch, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]
    response_tensors = []
    for query in query_tensors:
        generation_kwargs["max_new_tokens"] = output_size()
        response = ppo_trainer.generate(query, **generation_kwargs)
        response_tensors.append(response.squeeze()[-output_size():])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    sentiment_outputs = sentiment_model(texts, **sentiment_kwargs)
    rewards = [torch.tensor(output[1]["score"]) for output in sentiment_outputs]
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

3it [00:51, 17.25s/it]


In [8]:
query = "This movie is"
query_tensor = tokenizer.encode(query, return_tensors="pt").squeeze()
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}
reference_output = reference_model.generate(query_tensor.clone().detach().unsqueeze(dim=0).to(device), max_new_tokens=output_size(), **gen_kwargs).squeeze()[-output_size():]
active_output = active_model.generate(query_tensor.clone().detach().unsqueeze(dim=0).to(device), max_new_tokens=output_size(), **gen_kwargs).squeeze()[-output_size():]
response_before = tokenizer.decode(reference_output)
response_after = tokenizer.decode(active_output)
reward_before = [active_output[1]["score"] for active_output in sentiment_model([query + response_before], **sentiment_kwargs)][0]
reward_after = [active_output[1]["score"] for active_output in sentiment_model([query + response_after], **sentiment_kwargs)][0]
print("Query:", query)
print("Response before:", response_before)
print("Response after:", response_after)
print("Reward before:", reward_before)
print("Reward after:", reward_after)

Query: This movie is
Response before: This movie is a solid episode in the
Response after:  is one for Tom and Jerry. So don't let the
Reward before: -2.296550989151001
Reward after: -0.30996468663215637
