In [1]:
!pip install trl
!pip install accelerate -U

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [9]:
import torch
import transformers
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, DataCollatorForLanguageModeling, BitsAndBytesConfig,AutoModelForSeq2SeqLM, AutoModelForSequenceClassification
from trl import RewardTrainer, SFTTrainer , AutoModelForSeq2SeqLMWithValueHead,PPOTrainer, PPOConfig
from trl.core import LengthSampler
from trl import create_reference_model
from datasets import Dataset, load_dataset
import numpy as np
import json
import random
import pandas as pd
from transformers import Trainer, TrainingArguments
from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead, create_reference_model

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

In [4]:
dataset = load_dataset('CarperAI/openai_summarize_tldr')

In [5]:
dataset = dataset['train'].select(range(1000))

In [6]:
dataset

Dataset({
    features: ['prompt', 'label'],
    num_rows: 1000
})

In [7]:
dataset = dataset.rename_column("label", "response")

In [10]:
dataset

Dataset({
    features: ['prompt', 'response'],
    num_rows: 1000
})

In [11]:
### If we would have trained the 1st step(SFT), then that model will come here.

policy_model_id = "pszemraj/led-base-book-summary"
policy_model = AutoModelForSeq2SeqLM.from_pretrained(policy_model_id)
policy_model.to(device)
policy_tokenizer = AutoTokenizer.from_pretrained(policy_model_id)

In [12]:
ppo_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(policy_model_id) 
ppo_model.to(device)

AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): LEDForConditionalGeneration(
    (led): LEDModel(
      (shared): Embedding(50265, 768, padding_idx=1)
      (encoder): LEDEncoder(
        (embed_tokens): Embedding(50265, 768, padding_idx=1)
        (embed_positions): LEDLearnedPositionalEmbedding(16384, 768)
        (layers): ModuleList(
          (0-5): 6 x LEDEncoderLayer(
            (self_attn): LEDEncoderAttention(
              (longformer_self_attn): LEDEncoderSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (query_global): Linear(in_features=768, out_features=768, bias=True)
                (key_global): Linear(in_features=768, out_features=768, bias=True)
                (value_global): Linear(in_features=768, out_features=768, bias=True)
              )
  

In [13]:
# ref_model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(policy_model_id) 
# ref_model.to(device)

In [14]:
reward_model_directory = "reward_model"
reward_tokenizer_directory = "reward_model_tokenizer"

rm_model = AutoModelForSequenceClassification.from_pretrained(reward_model_directory).to(device)
rm_tokenizer = AutoTokenizer.from_pretrained(reward_tokenizer_directory)

In [15]:
split_ratio = 0.8  # 80% for training, 20% for evaluation
num_train_samples = int(split_ratio * len(dataset))
train_dataset = dataset.select(range(num_train_samples))
eval_dataset = dataset.select(range(num_train_samples, len(dataset)))


def tokenize_function(example):
    # Tokenize the prompt and store it as input_ids. Also return the response.
    return {
        "input_ids": policy_tokenizer(example["prompt"], return_tensors="pt", truncation=True, max_length=512)["input_ids"].squeeze(),
        "response": example["response"],
    }

train_dataset = train_dataset.map(tokenize_function, batched=False)
eval_dataset = eval_dataset.map(tokenize_function, batched=False)

In [16]:
train_dataset, eval_dataset

(Dataset({
     features: ['prompt', 'response', 'input_ids'],
     num_rows: 800
 }),
 Dataset({
     features: ['prompt', 'response', 'input_ids'],
     num_rows: 200
 }))

In [17]:
train_dataset[0]

{'prompt': "SUBREDDIT: r/relationships\nTITLE: I (f/22) have to figure out if I want to still know these girls or not and would hate to sound insulting\nPOST: Not sure if this belongs here but it's worth a try. \n\nBackstory:\nWhen I (f/22) went through my first real breakup 2 years ago because he needed space after a year of dating roand  it effected me more than I thought. It was a horrible time in my life due to living with my mother and finally having the chance to cut her out of my life. I can admit because of it was an emotional wreck and this guy was stable and didn't know how to deal with me. We ended by him avoiding for a month or so after going to a festival with my friends. When I think back I wish he just ended. So after he ended it added my depression I suffered but my friends helped me through it and I got rid of everything from him along with cutting contact. \n\nNow: Its been almost 3 years now and I've gotten better after counselling and mild anti depressants. My mothe

In [18]:
def collator(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

In [19]:
# Lets sample what the collator generates:
sample_data = [train_dataset[i] for i in range(3)]  # take first three examples
collated_data = collator(sample_data)
print(collated_data.keys())

dict_keys(['prompt', 'response', 'input_ids'])


In [20]:
learning_rate=1.41e-5
max_ppo_epochs=3
mini_batch_size=4
batch_size=16

DEFAULT_REJECTED_SUMMARY_TEXT = "This is a bad summary"

# Some initial values
output_min_length = 30
output_max_length = 150
output_length_sampler = LengthSampler(output_min_length, output_max_length)

generation_kwargs = {
    "temperature": 1.0,
    "min_length": 5,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True
}

max_ppo_steps = 100

In [21]:
config = PPOConfig(
    model_name=policy_model_id,
    learning_rate=learning_rate,
    ppo_epochs=max_ppo_epochs,
    mini_batch_size=mini_batch_size,
    batch_size=batch_size
)

In [22]:
ppo_trainer = PPOTrainer(config=config,
                         model=ppo_model,
                         #ref_model=ref_model,
                         tokenizer=policy_tokenizer,
                         dataset=train_dataset,
                         data_collator=collator)

In [23]:
import torch.nn.functional as F


def score_summaries(model, tokenizer, chosen_summary, rejected_summary):
    # Tokenize the inputs
    chosen_tokens = tokenizer(chosen_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)
    rejected_tokens = tokenizer(rejected_summary, return_tensors="pt", padding='max_length', truncation=True, max_length=512)

    chosen_tokens.to(device)
    rejected_tokens.to(device)

    # Get logits from the model
    with torch.no_grad():
        chosen_logits = model(**chosen_tokens).logits
        rejected_logits = model(**rejected_tokens).logits

    # Apply softmax to get probabilities
    chosen_probs = F.softmax(chosen_logits, dim=-1)
    rejected_probs = F.softmax(rejected_logits, dim=-1)

    # Assuming the positive class (indicating 'chosen' is good) is the second one
    chosen_score = chosen_probs[0][1].item()
    rejected_score = rejected_probs[0][1].item()

    # Extract logits for each summary
    chosen_logit = chosen_logits[0][1].item()
    rejected_logit = rejected_logits[0][1].item()

    return chosen_score, rejected_score, chosen_logit, rejected_logit

In [24]:
chosen_summary = "Water meter in another condo is not in our condo. What can we do legally to restore water to my condo complex?"
rejected_summary = "Go fix the problem."

chosen_score, rejected_score, chosen_logit, rejected_logit = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)

print(f"Chosen Score: {chosen_score:.4f}")
print(f"Rejected Score: {rejected_score:.4f}")

print(f"Chosen Logit: {chosen_logit:.4f}")
print(f"Rejected Logit: {rejected_logit:.4f}")

Chosen Score: 0.5161
Rejected Score: 0.4526
Chosen Logit: 0.3487
Rejected Logit: 0.0255


In [25]:
from tqdm import tqdm

for step, batch in tqdm(enumerate(ppo_trainer.dataloader)):
    if step >= max_ppo_steps: # Break when we reach max_steps.
        break

    prompt_tensors = batch["input_ids"]

    if isinstance(prompt_tensors, list) and all(isinstance(item, list) for item in prompt_tensors): # HACK!!! Check if original_prompt_tensors is a list of lists
        lengths = [len(seq) for seq in prompt_tensors] # Verify if sequences have fixed or variable length
        unique_lengths = set(lengths)

        if len(unique_lengths) > 1: # If sequences have variable lengths, pad them
            max_length = max(unique_lengths)
            original_prompt_tensors = [seq + [0] * (max_length - len(seq)) for seq in prompt_tensors]  # padding with zeros

        prompt_tensors = [torch.tensor(seq).to(device) for seq in prompt_tensors] # Convert original_prompt_tensors to individual tensors

    summary_tensors = []

    for prompt_tensor in prompt_tensors:
        prompt_tensor = torch.tensor(prompt_tensor).to(device)
        max_new_tokens = output_length_sampler()
        generation_kwargs["max_new_tokens"] = max_new_tokens
        summary = ppo_trainer.generate(prompt_tensor, **generation_kwargs)
        summary_tensors.append(summary.squeeze()[-max_new_tokens:])

    batch["response"] = [policy_tokenizer.decode(r.squeeze()) for r in summary_tensors]

    chosen_summaries = batch["response"]
    rejected_summaries = [DEFAULT_REJECTED_SUMMARY_TEXT] * len(batch["response"])

    reward_tensors = []

    for chosen_summary, rejected_summary in zip(chosen_summaries, rejected_summaries):
        chosen_score, _, _, _ = score_summaries(rm_model, rm_tokenizer, chosen_summary, rejected_summary)
        reward_tensors.append(torch.tensor(chosen_score))

    stats = ppo_trainer.step(prompt_tensors, summary_tensors, reward_tensors)
    ppo_trainer.log_stats(stats, batch, reward_tensors)

    print(f'objective/kl: {stats["objective/kl"]}') # Measures how different the policy's action distribution after the update is from the action distribution before the update. PPO tries to make these changes very small to avoid sudden changes.
    print(f'ppo/returns/mean: {stats["ppo/returns/mean"]}') # This is the average return achieved by the agent. Higher is better.
    print(f'ppo/policy/advantages_mean: {stats["ppo/policy/advantages_mean"]}') # Measures how much better an action is than the average action at a given state.
    print('-'.join('' for x in range(100)))

  prompt_tensor = torch.tensor(prompt_tensor).to(device)
Input ids are automatically padded from 493 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 431 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 195 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 221 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 274 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 450 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 504 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 293 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 289 to 1024 to be a multiple of `config.attention_window`: 1024
Input i

objective/kl: -9.676851732365321e-06
ppo/returns/mean: 0.5275627374649048
ppo/policy/advantages_mean: -0.002801092341542244
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 491 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 339 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 294 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 403 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 227 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 351 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 456 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 490 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 301 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 207 to 1024 to be a mult

objective/kl: -6.1851701736450195
ppo/returns/mean: 0.8105399012565613
ppo/policy/advantages_mean: 0.00224125268869102
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 342 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 355 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 429 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 284 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 236 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 332 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 381 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 283 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 336 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 179 to 1024 to be a mult

objective/kl: -27.92009162902832
ppo/returns/mean: 1.771461844444275
ppo/policy/advantages_mean: -0.004473959561437368
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 391 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 404 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 437 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 482 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 367 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 258 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 253 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 472 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 297 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 417 to 1024 to be a mult

objective/kl: 18.988086700439453
ppo/returns/mean: 0.05924704670906067
ppo/policy/advantages_mean: 0.017941389232873917
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 333 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 252 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 463 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 299 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 188 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 459 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 292 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 263 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 226 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 500 to 1024 to be a mult

objective/kl: 19.295307159423828
ppo/returns/mean: -0.39417099952697754
ppo/policy/advantages_mean: 0.007847161963582039
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 462 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 337 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 214 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 286 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 345 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 393 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 379 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 470 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 233 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 498 to 1024 to be a mult

objective/kl: 19.55228614807129
ppo/returns/mean: -0.5506460666656494
ppo/policy/advantages_mean: 0.026833228766918182
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 414 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 384 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 335 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 245 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 269 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 205 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 230 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 446 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 267 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 356 to 1024 to be a mult

objective/kl: 10.767576217651367
ppo/returns/mean: -0.5003163814544678
ppo/policy/advantages_mean: 0.013013686053454876
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 317 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 376 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 277 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 509 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 402 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 412 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 358 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 397 to 1024 to be a multiple of `config.attention_window`: 1024
8it [04:14, 31.57s/it]Input ids are automatically padded from 400 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 11.166448593139648
ppo/returns/mean: -0.6423916220664978
ppo/policy/advantages_mean: 0.016633620485663414
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 360 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 183 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 295 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 371 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 375 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 385 to 1024 to be a multiple of `config.attention_window`: 1024
9it [04:46, 31.96s/it]Input ids are automatically padded from 453 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 16.087535858154297
ppo/returns/mean: -1.0688539743423462
ppo/policy/advantages_mean: 0.012802040204405785
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 160 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 231 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 405 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 304 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 232 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 352 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 264 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 423 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 62 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 168 to 1024 to be a multi

objective/kl: 15.673182487487793
ppo/returns/mean: -1.175992727279663
ppo/policy/advantages_mean: 0.013759779743850231
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 422 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 407 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 409 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 280 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 305 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 234 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 445 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 235 to 1024 to be a multiple of `config.attention_window`: 1024
11it [05:55, 33.04s/it]Input ids are automatically padded from 408 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 20.181324005126953
ppo/returns/mean: -1.4192225933074951
ppo/policy/advantages_mean: 0.006254791282117367
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 191 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 420 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 350 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 467 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 486 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 460 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 479 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 193 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 421 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 331 to 1024 to be a mult

objective/kl: 9.493758201599121
ppo/returns/mean: -1.210973858833313
ppo/policy/advantages_mean: 0.015035843476653099
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 387 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 480 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 228 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 158 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 291 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 278 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 512 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 465 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 349 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 307 to 1024 to be a mult

objective/kl: 13.579954147338867
ppo/returns/mean: -1.3494033813476562
ppo/policy/advantages_mean: 0.009626551531255245
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 239 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 265 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 242 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 353 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 343 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 288 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 260 to 1024 to be a multiple of `config.attention_window`: 1024
14it [07:33, 33.08s/it]

objective/kl: 16.250015258789062
ppo/returns/mean: -1.399639368057251
ppo/policy/advantages_mean: 0.007492770440876484
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 508 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 354 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 475 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 357 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 377 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 298 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 348 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 196 to 1024 to be a multiple of `config.attention_window`: 1024
15it [08:05, 32.51s/it]

objective/kl: 11.326959609985352
ppo/returns/mean: -1.4371449947357178
ppo/policy/advantages_mean: 0.024734018370509148
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 255 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 273 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 487 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 425 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 406 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 505 to 1024 to be a multiple of `config.attention_window`: 1024
16it [08:35, 31.99s/it]

objective/kl: 11.720958709716797
ppo/returns/mean: -1.2684223651885986
ppo/policy/advantages_mean: 0.022238291800022125
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 312 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 378 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 394 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 416 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 271 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 366 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 419 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 275 to 1024 to be a multiple of `config.attention_window`: 1024
17it [09:07, 31.79s/it]

objective/kl: 13.602121353149414
ppo/returns/mean: -1.449101448059082
ppo/policy/advantages_mean: 0.006214344408363104
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 173 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 362 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 259 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 185 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 327 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 203 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 489 to 1024 to be a multiple of `config.attention_window`: 1024
18it [09:36, 31.04s/it]

objective/kl: 14.448728561401367
ppo/returns/mean: -1.398282766342163
ppo/policy/advantages_mean: 0.020223623141646385
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 318 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 497 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 184 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 469 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 340 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 270 to 1024 to be a multiple of `config.attention_window`: 1024
19it [10:04, 30.17s/it]

objective/kl: 15.194849014282227
ppo/returns/mean: -1.447683334350586
ppo/policy/advantages_mean: 0.003445131704211235
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 197 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 261 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 194 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 390 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 241 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 306 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 432 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 346 to 1024 to be a multiple of `config.attention_window`: 1024
20it [10:36, 30.64s/it]

objective/kl: 16.154577255249023
ppo/returns/mean: -1.2357490062713623
ppo/policy/advantages_mean: -0.03138148784637451
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 171 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 458 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 257 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 426 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 319 to 1024 to be a multiple of `config.attention_window`: 1024
21it [11:03, 29.56s/it]

objective/kl: 15.178276062011719
ppo/returns/mean: -1.3342101573944092
ppo/policy/advantages_mean: 0.0051803006790578365
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 325 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 296 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 310 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 169 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 249 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 359 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 247 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 330 to 1024 to be a multiple of `config.attention_window`: 1024
22it [11:29, 28.60s/it]

objective/kl: 19.931812286376953
ppo/returns/mean: -1.4793275594711304
ppo/policy/advantages_mean: -0.03212476521730423
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 251 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 373 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 215 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 186 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 444 to 1024 to be a multiple of `config.attention_window`: 1024
23it [11:55, 27.77s/it]Input ids are automatically padded from 189 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 18.41216278076172
ppo/returns/mean: -1.4293303489685059
ppo/policy/advantages_mean: -0.02179853245615959
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 302 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 395 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 389 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 503 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 386 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 300 to 1024 to be a multiple of `config.attention_window`: 1024
24it [12:21, 27.09s/it]

objective/kl: 18.313270568847656
ppo/returns/mean: -1.3590707778930664
ppo/policy/advantages_mean: 0.04553950950503349
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 181 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 501 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 452 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 334 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 428 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 311 to 1024 to be a multiple of `config.attention_window`: 1024
25it [12:45, 26.40s/it]

objective/kl: 19.455368041992188
ppo/returns/mean: -1.4842860698699951
ppo/policy/advantages_mean: 0.015880003571510315
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 398 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 502 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 372 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 507 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 192 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 321 to 1024 to be a multiple of `config.attention_window`: 1024
26it [13:11, 26.00s/it]

objective/kl: 17.37262725830078
ppo/returns/mean: -1.3334407806396484
ppo/policy/advantages_mean: 0.03635761886835098
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 476 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 209 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 279 to 1024 to be a multiple of `config.attention_window`: 1024
27it [13:34, 25.20s/it]

objective/kl: 16.29381561279297
ppo/returns/mean: -1.3798103332519531
ppo/policy/advantages_mean: 0.01042242906987667
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 370 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 329 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 237 to 1024 to be a multiple of `config.attention_window`: 1024
28it [13:57, 24.46s/it]Input ids are automatically padded from 172 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 16.86618423461914
ppo/returns/mean: -1.6342324018478394
ppo/policy/advantages_mean: 0.01997721567749977
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 481 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 488 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 238 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 223 to 1024 to be a multiple of `config.attention_window`: 1024
29it [14:20, 24.12s/it]

objective/kl: 21.899490356445312
ppo/returns/mean: -2.017723798751831
ppo/policy/advantages_mean: 0.030627649277448654
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 382 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 418 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 485 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 468 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 434 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 415 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 303 to 1024 to be a multiple of `config.attention_window`: 1024
30it [14:43, 23.79s/it]

objective/kl: 21.28044891357422
ppo/returns/mean: -1.9173072576522827
ppo/policy/advantages_mean: 0.003466537222266197
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 506 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 441 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 383 to 1024 to be a multiple of `config.attention_window`: 1024
31it [15:06, 23.52s/it]Input ids are automatically padded from 248 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 22.62061309814453
ppo/returns/mean: -1.8724308013916016
ppo/policy/advantages_mean: -0.016265610232949257
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 243 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 43 to 1024 to be a multiple of `config.attention_window`: 1024
32it [15:29, 23.48s/it]

objective/kl: 20.71583366394043
ppo/returns/mean: -1.9044256210327148
ppo/policy/advantages_mean: 0.042788416147232056
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 287 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 477 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 206 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 443 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 167 to 1024 to be a multiple of `config.attention_window`: 1024
33it [15:52, 23.21s/it]

objective/kl: 20.647018432617188
ppo/returns/mean: -2.075765371322632
ppo/policy/advantages_mean: 0.0029725029598921537
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 316 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 182 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 200 to 1024 to be a multiple of `config.attention_window`: 1024
34it [16:14, 22.91s/it]Input ids are automatically padded from 430 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 17.32471466064453
ppo/returns/mean: -1.8260258436203003
ppo/policy/advantages_mean: -0.007469045929610729
---------------------------------------------------------------------------------------------------


35it [16:36, 22.70s/it]Input ids are automatically padded from 202 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 20.99820327758789
ppo/returns/mean: -1.9017455577850342
ppo/policy/advantages_mean: 0.018114252015948296
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 457 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 187 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 217 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 510 to 1024 to be a multiple of `config.attention_window`: 1024
36it [16:59, 22.73s/it]

objective/kl: 24.015426635742188
ppo/returns/mean: -2.2660670280456543
ppo/policy/advantages_mean: 0.03504979610443115
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 471 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 222 to 1024 to be a multiple of `config.attention_window`: 1024
37it [17:21, 22.62s/it]

objective/kl: 20.361228942871094
ppo/returns/mean: -1.6828737258911133
ppo/policy/advantages_mean: 0.027435656636953354
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 157 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 413 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 313 to 1024 to be a multiple of `config.attention_window`: 1024
38it [17:43, 22.43s/it]Input ids are automatically padded from 473 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 20.096630096435547
ppo/returns/mean: -1.6212701797485352
ppo/policy/advantages_mean: 0.0009362204000353813
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 262 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 178 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 91 to 1024 to be a multiple of `config.attention_window`: 1024
39it [18:05, 22.17s/it]Input ids are automatically padded from 174 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 494 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 17.9088134765625
ppo/returns/mean: -1.7448195219039917
ppo/policy/advantages_mean: -0.0009860291611403227
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 410 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 320 to 1024 to be a multiple of `config.attention_window`: 1024
40it [18:27, 22.11s/it]Input ids are automatically padded from 216 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 19.28685188293457
ppo/returns/mean: -1.5581252574920654
ppo/policy/advantages_mean: 0.010299380868673325
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 220 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 439 to 1024 to be a multiple of `config.attention_window`: 1024
41it [18:49, 22.14s/it]

objective/kl: 17.75981903076172
ppo/returns/mean: -1.445212721824646
ppo/policy/advantages_mean: 0.06496807187795639
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 290 to 1024 to be a multiple of `config.attention_window`: 1024
42it [19:11, 22.01s/it]Input ids are automatically padded from 368 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 21.607620239257812
ppo/returns/mean: -2.1378445625305176
ppo/policy/advantages_mean: 0.008488010615110397
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 440 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 190 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 256 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 347 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 213 to 1024 to be a multiple of `config.attention_window`: 1024
43it [19:32, 21.87s/it]

objective/kl: 20.683544158935547
ppo/returns/mean: -1.8872864246368408
ppo/policy/advantages_mean: -0.021215610206127167
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 276 to 1024 to be a multiple of `config.attention_window`: 1024
44it [19:54, 21.74s/it]

objective/kl: 20.072036743164062
ppo/returns/mean: -1.848901629447937
ppo/policy/advantages_mean: -0.012770958244800568
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 436 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 254 to 1024 to be a multiple of `config.attention_window`: 1024
45it [20:17, 22.04s/it]

objective/kl: 21.51968002319336
ppo/returns/mean: -1.6432360410690308
ppo/policy/advantages_mean: -0.0025103092193603516
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 341 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 199 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 210 to 1024 to be a multiple of `config.attention_window`: 1024
46it [20:39, 22.02s/it]

objective/kl: 17.947484970092773
ppo/returns/mean: -1.4024991989135742
ppo/policy/advantages_mean: -0.011554657481610775
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 323 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 246 to 1024 to be a multiple of `config.attention_window`: 1024
47it [21:00, 21.97s/it]Input ids are automatically padded from 495 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 20.448421478271484
ppo/returns/mean: -1.4475255012512207
ppo/policy/advantages_mean: -0.07981672883033752
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 338 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 364 to 1024 to be a multiple of `config.attention_window`: 1024
Input ids are automatically padded from 328 to 1024 to be a multiple of `config.attention_window`: 1024
48it [21:22, 21.79s/it]Input ids are automatically padded from 204 to 1024 to be a multiple of `config.attention_window`: 1024


objective/kl: 20.476886749267578
ppo/returns/mean: -1.948337435722351
ppo/policy/advantages_mean: -0.007052720990031958
---------------------------------------------------------------------------------------------------


49it [21:43, 21.69s/it]

objective/kl: 19.28260612487793
ppo/returns/mean: -1.769110083580017
ppo/policy/advantages_mean: 0.024180859327316284
---------------------------------------------------------------------------------------------------


Input ids are automatically padded from 388 to 1024 to be a multiple of `config.attention_window`: 1024
50it [22:05, 26.51s/it]

objective/kl: 18.823970794677734
ppo/returns/mean: -1.6166677474975586
ppo/policy/advantages_mean: 0.027625955641269684
---------------------------------------------------------------------------------------------------





In [28]:
stats

{'objective/kl': 18.823970794677734,
 'objective/kl_dist': array([24.216343, 19.108904, 21.07835 , 16.415909, 17.75911 , 14.864124,
        26.851608, 22.932692, 16.316277, 13.655459, 23.952072, 15.781115,
        18.548454, 11.432112, 10.116042, 28.154984], dtype=float32),
 'objective/logprobs': array([[-1.25168972e-05, -1.44583166e+00, -8.41871321e-01,
         -8.78231823e-01, -2.08248310e-02, -8.10619895e-06,
         -2.79246407e+01, -1.77062149e+01, -1.73297615e+01,
         -1.81244164e+01, -1.67996159e+01, -1.94089375e+01,
         -1.90715580e+01, -1.70691032e+01, -1.64233971e+01,
         -1.63964252e+01, -1.93038960e+01, -1.86627960e+01,
         -1.83176861e+01, -1.76473846e+01, -1.81938953e+01,
         -1.69870052e+01, -1.80033016e+01, -1.83171062e+01,
         -1.73901234e+01, -1.84903088e+01],
        [-1.33513513e-05, -1.82678267e-01, -1.44932792e-01,
         -3.91438752e-02, -3.55180586e-04, -2.02771928e-03,
         -2.76918132e-02, -2.46851967e-04, -2.89469872e+01,

In [29]:
ppo_trainer.model

AutoModelForSeq2SeqLMWithValueHead(
  (pretrained_model): LEDForConditionalGeneration(
    (led): LEDModel(
      (shared): Embedding(50265, 768, padding_idx=1)
      (encoder): LEDEncoder(
        (embed_tokens): Embedding(50265, 768, padding_idx=1)
        (embed_positions): LEDLearnedPositionalEmbedding(16384, 768)
        (layers): ModuleList(
          (0-5): 6 x LEDEncoderLayer(
            (self_attn): LEDEncoderAttention(
              (longformer_self_attn): LEDEncoderSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (query_global): Linear(in_features=768, out_features=768, bias=True)
                (key_global): Linear(in_features=768, out_features=768, bias=True)
                (value_global): Linear(in_features=768, out_features=768, bias=True)
              )
  

In [40]:
ppo_model_path = "ppo_model_final_new"
ppo_model_tokenizer = "ppo_model_tokenizer_final_new"

# ppo_trainer.save_model("ppo_model_path")
# Save the model
ppo_model.save_pretrained(ppo_model_path)

# Save the tokenizer
policy_tokenizer.save_pretrained(ppo_model_tokenizer)

Non-default generation parameters: {'max_length': 1024, 'min_length': 8, 'early_stopping': True, 'num_beams': 4, 'repetition_penalty': 3.5, 'length_penalty': 0.8, 'no_repeat_ngram_size': 3}


('ppo_model_tokenizer_final_new/tokenizer_config.json',
 'ppo_model_tokenizer_final_new/special_tokens_map.json',
 'ppo_model_tokenizer_final_new/vocab.json',
 'ppo_model_tokenizer_final_new/merges.txt',
 'ppo_model_tokenizer_final_new/added_tokens.json',
 'ppo_model_tokenizer_final_new/tokenizer.json')

In [41]:
# ppo_saved_model_path = "ppo_model_20Epochs_new"
# tokenizer_path = "ppo_model_tokenizer_new"
ppo_model = AutoModelForSeq2SeqLM.from_pretrained(ppo_model_path)
policy_tokenizer = AutoTokenizer.from_pretrained(ppo_model_tokenizer)

Some weights of the model checkpoint at ppo_model_final_new were not used when initializing LEDForConditionalGeneration: ['v_head.summary.bias', 'v_head.summary.weight']
- This IS expected if you are initializing LEDForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LEDForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [75]:
# def generate_summary(prompt: str, model, tokenizer, generation_kwargs, output_length_sampler) -> str:
#     """
#     Generate a summary for a given prompt using a trained policy model.

#     Args:
#     - prompt (str): The input text for which a summary needs to be generated.
#     - model: The trained policy model.
#     - tokenizer: The tokenizer used for the policy model.
#     - generation_kwargs (dict): Arguments used for response generation.
#     - output_length_sampler (func): Function to sample the length of the output.

#     Returns:
#     - str: Generated summary.
#     """

#     # Tokenize the prompt
#     prompt_tensor = tokenizer.encode(prompt, return_tensors='pt').to(device)

#     # Ensure it's only one tensor and check its shape
#     assert prompt_tensor.dim() == 2, f"Unexpected tensor shape: {prompt_tensor.shape}"

#     # Set the generation arguments
#     max_new_tokens = output_length_sampler()
#     generation_kwargs["max_new_tokens"] = max_new_tokens

#     # Generate a summary
#     summary_tensor = model.generate(input_ids=prompt_tensor, **generation_kwargs)

#     # Decode and return the summary
#     summary = tokenizer.decode(summary_tensor[0], skip_special_tokens=True)
#     return summary

In [31]:
testing_sample = pd.read_csv('testing_samples.csv')
testing_sample

Unnamed: 0,prompt,label
0,SUBREDDIT: r/AskReddit\nTITLE: I need your hel...,American Family Insurance is screwing me with ...
1,SUBREDDIT: r/relationships\nTITLE: My boyfrien...,Boyfriend of 3 years started a business withou...
2,SUBREDDIT: r/AskReddit\nTITLE: Can someone hel...,Grandpa had a light bulb he could light up by ...
3,SUBREDDIT: r/travel\nTITLE: If I don't do this...,"I'm an American, bored with my career, wanting..."
4,SUBREDDIT: r/tifu\nTITLE: TIFU By Showing My H...,Made a bet with teacher to watch Vader vs Hitl...
5,SUBREDDIT: r/dating_advice\nTITLE: Should I [1...,Have a bit of a crush on a guy who I see every...
6,SUBREDDIT: r/relationship_advice\nTITLE: When ...,"If we both know we like each other, and have r..."
7,SUBREDDIT: r/relationships\nTITLE: I [18 M] ha...,Interested in a girl i sit with next to in cla...
8,SUBREDDIT: r/Advice\nTITLE: Freaking out about...,Freaking out about college being too much and ...
9,SUBREDDIT: r/personalfinance\nTITLE: 25 y/o lo...,forces out of home. I have $400 and $6000 debt...


In [34]:
pipe = pipeline("summarization", model=policy_model, tokenizer=policy_tokenizer) #max_length=350, num_return_sequences=1
Base_model_summary = []
for i in testing_sample['prompt']:
    output = pipe(i,temperature =  1.0, min_length = 5, top_k = 0.0, top_p = 1.0, do_sample =  True, max_length=150)
    Base_model_summary.append(output)

In [35]:
Base_model_summary 

[[{'summary_text': 'A Redditch is asking Reddit for help finding his insurance company, American Family Insurance. He explains that he was involved in a car accident where the other driver was at fault. The insurance company arranged for him to get a rental car and pay for auto body repair at the shop my dealer recommended. However, when he picked up the rental car, the insurance company paid for it, he had to sign for the coverage that the rental company (Enterprise) offers, which is $13 a day. This means that he will have to buy a new car every 2 weeks. Plus, there\'s no way this insurance company can "pay for rental insurance when their client is at fault." Translation: I don\'t have full liability insurance, so'}],
 [{'summary_text': 'A Reddit user has been asking a question about relationships: Why is it that his "part-time grad student" boyfriend of 3 years keeps making huge decisions without communicating with me [23 F] at all, is this normal? The redditor explains that he\'s ju

In [36]:
pipe = pipeline("summarization", model=ppo_model, tokenizer=policy_tokenizer) #max_length=350, num_return_sequences=1
PPO_model_summary = []
for i in testing_sample['prompt']:
    output = pipe(i,temperature =  1.0, min_length = 5, top_k = 0.0, top_p = 1.0, do_sample = True, max_length=150)
    PPO_model_summary.append(output)
PPO_model_summary

The model 'AutoModelForSeq2SeqLMWithValueHead' is not supported for summarization. Supported models are ['BartForConditionalGeneration', 'BigBirdPegasusForConditionalGeneration', 'BlenderbotForConditionalGeneration', 'BlenderbotSmallForConditionalGeneration', 'EncoderDecoderModel', 'FSMTForConditionalGeneration', 'GPTSanJapaneseForConditionalGeneration', 'LEDForConditionalGeneration', 'LongT5ForConditionalGeneration', 'M2M100ForConditionalGeneration', 'MarianMTModel', 'MBartForConditionalGeneration', 'MT5ForConditionalGeneration', 'MvpForConditionalGeneration', 'NllbMoeForConditionalGeneration', 'PegasusForConditionalGeneration', 'PegasusXForConditionalGeneration', 'PLBartForConditionalGeneration', 'ProphetNetForConditionalGeneration', 'SeamlessM4TForTextToText', 'SeamlessM4Tv2ForTextToText', 'SwitchTransformersForConditionalGeneration', 'T5ForConditionalGeneration', 'UMT5ForConditionalGeneration', 'XLMProphetNetForConditionalGeneration'].


[[{'summary_text': 'Does my car need repair?'}],
 [{'summary_text': 'Does this normal?'}],
 [{'summary_text': 'Can someone please help my dad?'}],
 [{'summary_text': "I'm looking for advice, help, reassurance, dissuasion or a little bit of each on a major life decision."}],
 [{'summary_text': 'Why would you need those movies?'}],
 [{'summary_text': "Should I, who went on a couple dead end dates with my friend's sister?"}],
 [{'summary_text': 'Does it alright to just ask here to be "more than friends"?'}],
 [{'summary_text': 'I am talking to a student in my class.'}],
 [{'summary_text': 'What do I do?'}],
 [{'summary_text': 'Does a career change make a smart idea?'}]]

In [38]:
testing_sample['base_model'] = Base_model_summary
testing_sample['PPO'] = PPO_model_summary

rlhf_result = testing_sample

rlhf_result.to_csv('rlhf_result_new.csv')

In [74]:
# Base_model_summary = []
# for i in testing_sample['prompt']:
#     prompt = "summarize: " + i
#     generated_summary = generate_summary(prompt, policy_model, policy_tokenizer, generation_kwargs, output_length_sampler)
#     Base_model_summary.append(generated_summary)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [76]:
# ppo_model_summary = []
# for i in testing_sample['prompt']:
#     prompt = "summarize: " + i
#     generated_summary = generate_summary(prompt, ppo_model, policy_tokenizer, generation_kwargs, output_length_sampler)
#     ppo_model_summary.append(generated_summary)

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [127]:
# pipe = pipeline("summarization", model=ppo_model, tokenizer=policy_tokenizer) #max_length=350, num_return_sequences=1
# PPO_model_summary = []
# for i in testing_sample['prompt']:
#     output = pipe(i,temperature =  1.0, min_length = 5, top_k = 0.0, top_p = 1.0, do_sample =  True, max_length=150)
#     PPO_model_summary.append(output)

In [128]:
# PPO_model_summary

[[{'summary_text': 'Oh, yeah.'}],
 [{'summary_text': 'Oh, yeah.'}],
 [{'summary_text': 'Oh, yeah.'}],
 [{'summary_text': 'Oh yeah.'}],
 [{'summary_text': 'Oh, yeah'}],
 [{'summary_text': 'Oh, yeah.'}],
 [{'summary_text': 'Oh, yeah.'}],
 [{'summary_text': 'Oh yeah.'}],
 [{'summary_text': 'Oh, yeah.'}],
 [{'summary_text': 'Oh, yeah.'}]]