In [1]:
import pandas as pd
import torch
import torch.nn.functional as F
from datasets import Dataset, load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer, pipeline
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
from trl.core import LengthSampler

from WARPTrainer import WARPTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
raw_datasets = load_dataset("imdb", split="train")
raw_datasets = raw_datasets['text'][:24960]
raw_datasets = Dataset.from_dict({"query": raw_datasets})

In [3]:
config = PPOConfig(
    model_name="lvwerra/gpt2-imdb",
    learning_rate=1.41e-5,
    remove_unused_columns=False,
    batch_size = 128,
    kl_penalty = "mse"
   
)

In [4]:
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

In [5]:
def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample

raw_datasets = raw_datasets.map(tokenize, batched=False)

Map:   0%|          | 0/24960 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 24960/24960 [00:10<00:00, 2450.60 examples/s]


In [6]:
def preprocess_function(examples):
        new_examples = {
            "input_ids": [],
            "query": [],
        }

       
        
        new_examples["input_ids"] = examples["input_ids"][:15]
        new_examples["query"] = tokenizer.decode(examples["input_ids"][:15])
            
                
        return new_examples

datasets = raw_datasets.map(preprocess_function, batched=False)

Map: 100%|██████████| 24960/24960 [00:02<00:00, 9256.11 examples/s]


In [7]:
datasets["query"][:5]

['I rented I AM CURIOUS-YELLOW from my video store',
 '"I Am Curious: Yellow" is a risible and pretentious ste',
 'If only to avoid making this type of film in the future. This film',
 "This film was probably inspired by Godard's Masculin, fé",
 'Oh, brother...after hearing about this ridiculous film for umpteen years']

In [8]:
datasets = datasets.filter(
    lambda x: len(x["input_ids"]) == 15
    )

Filter: 100%|██████████| 24960/24960 [00:00<00:00, 187671.67 examples/s]


In [9]:
datasets.set_format("pytorch")
dataloader = DataLoader(datasets, batch_size= 128, shuffle=True) 

In [10]:
trainer = WARPTrainer(
    model=model,
    ref_model=ref_model,
    config=config,
    dataset=datasets,
    tokenizer=tokenizer,
)

In [11]:
#Параметры для генерации ответов моделью
generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

In [12]:
#Загрузка предварительно обученой модели вознаграждения и создание pipeline
r_model = AutoModelForSequenceClassification.from_pretrained("reward_model", config=AutoConfig.from_pretrained('reward_model/config.json'))
bert_tokenizer = AutoTokenizer.from_pretrained("distilbert-base-cased")
sentiment_pipe = pipeline("sentiment-analysis", model=r_model, tokenizer=bert_tokenizer, device="cuda")
sent_kwargs = {"return_all_scores": True, "function_to_apply": "none", "batch_size": 16}

In [13]:
#Основной цикл дообучения на батчах данных


output_min_length = 4
output_max_length = 16
output_length_sampler = LengthSampler(output_min_length, output_max_length)


generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

all_stats = []
for epoch, batch in tqdm(enumerate(dataloader)):
    query_tensors = batch["input_ids"]
    
    
    #### Get response from gpt2
    response_tensors = []
    query_tensors_l = []
    for query in query_tensors:
        query_tensors_l.append(query)
        gen_len = output_length_sampler()
        generation_kwargs["max_new_tokens"] = gen_len
        response = trainer.generate(query.to("cuda"), **generation_kwargs)
        response_tensors.append(response.squeeze()[-gen_len:])
    batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
    

    #### Compute sentiment score
    texts = [q + r for q, r in zip(batch["query"], batch["response"])]
    pipe_outputs = sentiment_pipe(texts, **sent_kwargs)
    
    rewards = [torch.tensor(output[0]["score"]) for output in pipe_outputs]
    

    #### Run step
    try:
        stats = trainer.step(query_tensors_l, response_tensors, rewards)
        all_stats.append(stats)
    except Exception as e:
        print(e)
    

0it [00:00, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
  attn_output = torch.nn.functional.scaled_dot_product_attention(
10it [03:18, 20.50s/it]You seem to be using the pipelines sequentially on GPU. In order to maximize efficiency please use a dataset
195it [1:36:32, 29.71s/it]

Batch size (128) does not match number of examples - but got 126 for: queries





In [14]:
gen_kwargs = {"min_length": -1, "top_k": 0.0, "top_p": 1.0, "do_sample": True, "pad_token_id": tokenizer.eos_token_id}

In [15]:
bs = 16
data = dict()
datasets.set_format("pandas")
df_batch = datasets[:].sample(bs)
data["query"] = df_batch["query"].tolist()
query_tensors = df_batch["input_ids"].tolist()

response_tensors_ref, response_tensors = [], []

#### get response from gpt2 and gpt2_ref
for i in range(bs):
    gen_len = output_length_sampler()
    output = ref_model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to("cuda"), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors_ref.append(output)
    output = model.generate(
        torch.tensor(query_tensors[i]).unsqueeze(dim=0).to("cuda"), max_new_tokens=gen_len, **gen_kwargs
    ).squeeze()[-gen_len:]
    response_tensors.append(output)

#### decode responses
data["response (before)"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]
data["response (after)"] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]

#### sentiment analysis of query/response pairs before/after
texts = [q + r for q, r in zip(data["query"], data["response (before)"])]

data["rewards (before)"] = [output[0]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

texts = [q + r for q, r in zip(data["query"], data["response (after)"])]
data["rewards (after)"] = [output[0]["score"] for output in sentiment_pipe(texts, **sent_kwargs)]

# store results in a dataframe
df_results = pd.DataFrame(data)
df_results

Unnamed: 0,query,response (before),response (after),rewards (before),rewards (after)
0,this show is one of the worst shows of ALL TIM...,ity throughout the film,OK OK OK OK,6.70796,6.590818
1,Sure this movie is not historically accurate b...,pp's OK OK,OK OK OK OK,-6.718466,-6.616941
2,Mild Spoilers<br /><br />In the near future,that Spike Lee does OK OK OK OK,OK OK OK OK OK OK OK OK,1.133858,0.825365
3,I've seen some very terrible horror movies in ...,was OK OK OK,OK OK OK OK,-0.448612,-0.953155
4,For anyone who has only seen Disney Production...,and the Beast' OK OK OK OK OK,OK OK OK OK OK OK OK OK OK,-2.353873,-2.755658
5,This is the touching story of two families in ...,of them are OK OK OK OK OK OK OK OK,OK OK OK OK OK OK OK OK OK OK OK,-6.118678,-6.125435
6,The spoilers in this review are offered as a p...,person watching the movie had done OK OK OK O...,OK OK OK OK OK OK OK OK OK OK OK OK OK OK,2.192267,0.971499
7,So this made for TV film scores only a 7.6 on ...,", OK OK OK OK OK OK OK OK OK OK OK OK",OK OK OK OK OK OK OK OK OK OK OK OK OK,0.895894,0.821184
8,This movie is very cool. If you're a fan of Ts...,ark having Luke Skywalker canon about to work ...,OK OK OK OK OK OK OK OK OK OK OK OK,-5.120286,-5.487232
9,Frankly I don't understand why this movie has ...,budget OK OK OK OK OK OK OK OK,OK OK OK OK OK OK OK OK OK,1.117311,0.125291


In [16]:
print("mean:")
display(df_results[["rewards (before)", "rewards (after)"]].mean())
print()
print("median:")
display(df_results[["rewards (before)", "rewards (after)"]].median())

mean:


rewards (before)   -1.094106
rewards (after)    -1.260838
dtype: float64


median:


rewards (before)   -1.054031
rewards (after)    -0.891460
dtype: float64