In [15]:
import torch

from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from sklearn.model_selection import train_test_split

from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead

In [16]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [17]:
def cut(sample):
    sample["query"] = " ".join(sample["text"].split()[: 5])
    return sample

dataset = load_dataset("yelp_review_full", split="test")
dataset = dataset.filter(lambda x: len(x["text"]) > 80, batched=False)
dataset = dataset.filter(lambda x: len(x["text"]) < 120, batched=False)
dataset = dataset.filter(lambda x: x["label"] < 4, batched=False)
dataset = dataset.filter(lambda x: x["label"] > 0, batched=False)
dataset = dataset.map(cut, batched=False)
train_ds, test_ds = train_test_split(dataset,
                 test_size=0.99,
                 random_state=2345)
del train_ds['label']
del train_ds['text']

In [18]:
# 1. load a pretrained model
model = AutoModelForCausalLMWithValueHead.from_pretrained("Zohar/distilgpt2-finetuned-restaurant-reviews")
model_ref = AutoModelForCausalLMWithValueHead.from_pretrained("Zohar/distilgpt2-finetuned-restaurant-reviews")
tokenizer = AutoTokenizer.from_pretrained("Zohar/distilgpt2-finetuned-restaurant-reviews")
tokenizer.pad_token = tokenizer.eos_token

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
    "max_new_tokens": 25,
}

In [19]:
reward_tokenizer = AutoTokenizer.from_pretrained("finiteautomata/bertweet-base-sentiment-analysis", model_max_length=256)
reward_model = AutoModelForSequenceClassification.from_pretrained("finiteautomata/bertweet-base-sentiment-analysis")

sentiment_pipe = pipeline("sentiment-analysis", model=reward_model, device=device, tokenizer=reward_tokenizer)
sent_kwargs = {"top_k": None, "function_to_apply": "softmax", "batch_size": 1}

emoji is not installed, thus not converting emoticons or emojis into text. Install emoji: pip3 install emoji==0.6.0


In [20]:
# 2. initialize trainer
ppo_config = {"batch_size": 1}
config = PPOConfig(**ppo_config)
ppo_trainer = PPOTrainer(config, model, model_ref, tokenizer)

for query_txt in train_ds['query']:
    # 3. encode a query
    print("\n----------------------------")
    print(f"Query: {query_txt}")
    query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device)

    # 4. generate model response
    response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=True, **generation_kwargs)
    response_txt = tokenizer.decode(response_tensor[0], skip_special_tokens=True)
    print(f'Response: {response_txt}')

    # 5. define a reward for response
    # (this could be any reward such as human feedback or output from another model)
    pipe_outputs = sentiment_pipe(response_txt, **sent_kwargs)
    reward = [torch.tensor(next(val for val in pipe_outputs if val["label"] == "POS")['score'], device=model.pretrained_model.device)]
    print(f'Reward: {reward[0].item()}')

    # 6. train model with ppo
    train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward)

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.



----------------------------
Query: The outside and the inside
Response: The outside and the inside are wonderful. You can tell by the color of building with zig zags of chalk together making you feel like you are
Reward: 0.9553338885307312

----------------------------
Query: Two Words: Frozen Cosmos\n\nYou need
Response: Two Words: Frozen Cosmos\n\nYou need to study Greek 'Holy Bread House'

If you're a Diet, do yourself a favor and go to this Vietnamese
Reward: 0.0942615270614624

----------------------------
Query: Food was good. Service was
Response: Food was good. Service was ok. Duck salad was a little dry with low carb and not too spicy!
The BBQ Sandwich was the perfect shave for
Reward: 0.983796238899231

----------------------------
Query: Nice salad bar. Reminds me
Response: Nice salad bar. Reminds me of how I was in central Columbus.  .   Overall was good but cruise.    Husband was
Reward: 0.9850949048995972

----------------------------
Query: We got the spa suite,
Respon