In [1]:
import torch
from tqdm import tqdm
import pandas as pd

tqdm.pandas()

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead
from trl.core import LengthSampler

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# PPO run configuration. `model_name` is reused by later cells to load the tokenizer
# and both language models; `log_with` selects the experiment tracker.
config = PPOConfig(
    model_name="gpt2",
    learning_rate=1.41e-5,
    log_with="wandb",
)

# Keyword arguments forwarded to every `sentiment_pipe(...)` call in this notebook:
# a score is returned for every label, no activation is applied to the scores, and
# inputs are processed 16 at a time.
sent_kwargs = dict(return_all_scores=True, function_to_apply="none", batch_size=16)

In [3]:
import wandb

# Open a Weights & Biases run before the trainer is created (the PPO config sets
# `log_with="wandb"`). NOTE(review): no project/run name is given, so wandb defaults apply.
wandb.init()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjungliana[0m ([33mpiksle[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [4]:
def build_dataset(
    config,
    dataset_name="yelp_review_full",
    input_min_text_length=3,
    input_max_text_length=6,
    split="test",
    min_review_chars=60,
    max_review_chars=100,
):
    """
    Build the prompt dataset used for PPO training.

    Loads `dataset_name` with `load_dataset`, keeps only short reviews whose label is
    not 2, binarises the label, and truncates every review to a few leading tokens that
    act as the generation prompt. Customize this function to train on another dataset.

    Args:
        config:
            Object exposing `model_name`; the tokenizer is loaded from that checkpoint.
        dataset_name (`str`):
            The name of the dataset to be loaded.
        input_min_text_length (`int`):
            Lower bound given to `LengthSampler` for the prompt length, in tokens.
        input_max_text_length (`int`):
            Upper bound given to `LengthSampler` for the prompt length, in tokens.
            NOTE(review): presumably exclusive - confirm against `trl.core.LengthSampler`.
        split (`str`):
            Dataset split to load.
        min_review_chars (`int`):
            A review is kept only if it has strictly more characters than this.
        max_review_chars (`int`):
            A review is kept only if it has strictly fewer characters than this.

    Returns:
        dataset (`datasets.Dataset`):
            The filtered dataset in torch format, with `text` renamed to `review`, a
            binary `label`, and the added `input_ids` / `query` columns. This is a
            dataset, not a dataloader.
    """

    tokenizer = AutoTokenizer.from_pretrained(config.model_name)
    tokenizer.pad_token = tokenizer.eos_token

    def keep_review(example):
        # Single pass instead of three separate filters: drop label 2 (presumably the
        # neutral middle rating) and keep only reviews inside the open length interval.
        return example["label"] != 2 and min_review_chars < len(example["review"]) < max_review_chars

    def convert_labels(example):
        # Collapse the remaining ratings into a binary label: 1 above 2, otherwise 0.
        example["label"] = 1 if example["label"] > 2 else 0
        return example

    ds = load_dataset(dataset_name, split=split)
    ds = ds.rename_columns({"text": "review"})
    ds = ds.filter(keep_review, batched=False)
    ds = ds.map(convert_labels)

    input_size = LengthSampler(input_min_text_length, input_max_text_length)

    def tokenize(sample):
        # A fresh prompt length is drawn for every sample; `query` is the decoded
        # text of exactly the tokens kept in `input_ids`.
        sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()]
        sample["query"] = tokenizer.decode(sample["input_ids"])
        return sample

    ds = ds.map(tokenize, batched=False)
    ds.set_format(type="torch")
    return ds

In [5]:
# Build the prompt dataset with the default arguments of `build_dataset`.
dataset = build_dataset(config)


def collator(data):
    """Turn a list of per-sample dicts into one dict of per-key lists.

    The key set is taken from the first sample, and each list keeps the sample order.
    """
    first_sample = data[0]
    batch = {}
    for key in first_sample:
        batch[key] = [sample[key] for sample in data]
    return batch

Map: 100%|██████████| 3366/3366 [00:01<00:00, 2635.38 examples/s]


In [6]:
# Tokenizer for the policy checkpoint; EOS is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token

# Two separately loaded copies of the same checkpoint: `model` is handed to the
# trainer as the policy, `ref_model` as its reference model.
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)

In [7]:
# Hand the policy model, reference model, tokenizer, prompt dataset and batch collator
# to the PPO trainer; its `dataloader`, `generate`, `step` and `log_stats` drive training.
ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)

In [8]:
# Device for the reward pipeline. With a single process, use a plain CUDA index or
# "cpu" (to avoid a `pipeline` bug); otherwise use the accelerator's own device.
if ppo_trainer.accelerator.num_processes == 1:
    device = 0 if torch.cuda.is_available() else "cpu"
else:
    device = ppo_trainer.accelerator.device

# Reward model: a sequence classifier loaded from a local checkpoint, paired with the
# distilbert tokenizer capped at 256 tokens.
# NOTE(review): assumes the checkpoint was trained with this tokenizer - confirm.
model2 = AutoModelForSequenceClassification.from_pretrained("../models/yelpBERT")
tokenizer2 = AutoTokenizer.from_pretrained("distilbert-base-uncased", model_max_length=256)

sentiment_pipe = pipeline("sentiment-analysis", model=model2, tokenizer=tokenizer2, device=device)

In [9]:
# Sanity check of the reward model on a positive-sounding phrase; the result lists one
# score per label, and the training loop uses the entry at index 1 as the reward.
text = "peaceful restaurant"
sentiment_pipe(text, **sent_kwargs)



[[{'label': 'NEGATIVE', 'score': -4.097566604614258},
  {'label': 'POSITIVE', 'score': 4.290657997131348}]]

In [10]:
# Sanity check of the reward model on a negative-sounding phrase; here the entry at
# index 1 (the one used as reward) is expected to be the lower of the two scores.
text = "this swimming pool is always crowded"
sentiment_pipe(text, **sent_kwargs)

[[{'label': 'NEGATIVE', 'score': 2.777513265609741},
  {'label': 'POSITIVE', 'score': -2.28507399559021}]]

In [11]:
# Sampling settings shared by the generation calls of the before/after comparison.
gen_kwargs = dict(
    min_length=-1,  # no minimum length is enforced
    top_k=0.0,  # top-k filtering off
    top_p=1.0,  # nucleus filtering off
    do_sample=True,  # sample instead of greedy decoding
    pad_token_id=tokenizer.eos_token_id,  # pad with EOS
)

In [12]:
# The number of new tokens for each response is drawn from this sampler.
output_min_length = 10
output_max_length = 24
output_length_sampler = LengthSampler(output_min_length, output_max_length)


# Training must sample with exactly the settings used for the before/after comparison,
# so derive them from `gen_kwargs` instead of repeating the literal. A copy is taken
# because the training loop writes `max_new_tokens` into this dict, which must not
# leak into `gen_kwargs`.
generation_kwargs = dict(gen_kwargs)


# One pass over the trainer's dataloader; `epoch` is really the batch index.
# tqdm wraps the dataloader itself (not `enumerate`) so the bar knows the total
# number of batches and can show progress and an ETA.
for epoch, batch in enumerate(tqdm(ppo_trainer.dataloader)):
    query_tensors = batch["input_ids"]

    #### Get response from gpt2
    response_tensors = []
    for query in query_tensors:
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = ppo_trainer.generate(query, **generation_kwargs)
        # The generated sequence starts with the prompt. Cut it off by the prompt's
        # length rather than keeping the last `gen_len` tokens: when sampling stops
        # early at EOS, fewer than `gen_len` tokens are new and the tail slice would
        # pull prompt tokens into the response.
        response_tensors.append(response.squeeze()[query.shape[0] :])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    # Reward = score of the label at index 1 for each query+response text.
    rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]

    #### Run PPO step
    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
    ppo_trainer.log_stats(stats, batch, rewards)

0it [00:00, ?it/s]You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
9it [32:41, 217.93s/it]


KeyboardInterrupt: 

In [13]:
#### get a batch from the dataset
bs = 16
game_data = dict()
# `with_format` returns a re-formatted copy, so `dataset` itself stays in torch format
# and the training cell can still be re-run after this one (`set_format` would switch
# the shared dataset to pandas in place).
df_batch = dataset.with_format("pandas")[:].sample(bs)
game_data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    query = torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device)
    prompt_len = query.shape[1]
    # Same prompt and same length budget for both models; the reference model goes first.
    for lm, responses in ((ref_model, response_tensors_ref), (model, response_tensors)):
        output = lm.generate(query, max_new_tokens=gen_len, **gen_kwargs)
        # Drop the prompt by its length: if sampling stops early at EOS, slicing the
        # last `gen_len` tokens would include prompt tokens in the response.
        responses.append(output.squeeze()[prompt_len:])

#### decode responses
game_data["response (before)"] = [tokenizer.decode(r) for r in response_tensors_ref]
game_data["response (after)"] = [tokenizer.decode(r) for r in response_tensors]

#### sentiment analysis of query/response pairs before/after
for tag in ("before", "after"):
    texts = [q + r for q, r in zip(game_data["query"], game_data[f"response ({tag})"])]
    # Score of the label at index 1, the same quantity used as the PPO reward.
    game_data[f"rewards ({tag})"] = [output[1]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(game_data)
df_results



Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,Glad to see so,many different fan favorites and priceless fe...,many people thanking you for the support in t...,4.288124,3.506455
1,The staff worked,"on the opposition environmental campaign, whi...",together with the Grizzlies and the mountain ...,-4.986053,3.997762
2,I visit often.,Instant rib photographing is a wonderfully ef...,ÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂÂ,4.225183,3.745905
3,Best place ever,"except in-game.Ghoul Opens fierce, battleground",for the Hill.\n\nwww.punt.com,4.306678,4.24065
4,Love it!!!,It's my first beer. I am trying to discover m...,I will go here and to the team. I will be a b...,4.35002,4.312842
5,One of my,favorite parts of writing this article is tha...,favorite books by an author of near and multi...,-3.03216,4.266473
6,Some of the,following tags are known to be seen in NSA ar...,"man's family dog, Sharwaj, five brothers and ...",-4.117185,3.552669
7,My favourite boutique in,tonight's admirers before me is Make Yourself...,"the East London area, with its contemporary a...",4.274474,4.275887
8,Excellent Thai food and service,"\n\nLunch: cool and smooth, with plenty of thi...",at a price you will have to meet to be a Stre...,4.297171,4.234684
9,Pristine,"= tmethynology, chalice = thatzanyone of pearl",- An incredible early model Apple detailed on...,3.945855,4.321208
