In [1]:
import torch

def query_model(input_ids, model, batch_size, response_length = 32, attention_mask=None):
    """Sample continuations for a batch of prompts, ``batch_size`` prompts per generate call.

    Args:
        input_ids: 2-D tensor of prompt token ids, one prompt per row; every row
            has the same (padded) length ``input_ids.shape[1]``.
        model: object exposing ``generate``, ``device`` and ``config.pad_token_id``.
        batch_size: number of prompts passed to ``model.generate`` per call. The
            final chunk may be smaller when the row count is not a multiple of it.
        response_length: maximum number of new tokens sampled per prompt.
        attention_mask: optional mask aligned with ``input_ids``. When given, it is
            forwarded to ``generate`` and a response mask is also returned.

    Returns:
        ``output_ids``: CPU tensor of shape ``(num_prompts, L)`` holding only the
        generated tokens (the prompt is stripped). Chunks that stopped early are
        right-padded with the pad token id (0 if the config has none) so all rows
        share length ``L``.
        When ``attention_mask`` is given, returns ``(output_ids, output_mask)``,
        where ``output_mask`` is 0 wherever ``output_ids`` equals the pad token id
        and 1 elsewhere.
    """
    prompt_len = input_ids.shape[1]
    pad_id = model.config.pad_token_id
    rows = []
    with torch.no_grad():
        # Step by batch_size so a trailing partial chunk is not silently dropped.
        for start in range(0, input_ids.shape[0], batch_size):
            stop = start + batch_size
            ids = input_ids[start:stop].to(model.device)
            mask = None
            if attention_mask is not None:
                mask = attention_mask[start:stop].to(model.device)
            generation_output = model.generate(input_ids=ids,
                                               attention_mask=mask,
                                               max_length=prompt_len + response_length,
                                               do_sample=True)
            # generate() returns prompt + continuation; keep the continuation only
            # (slice on the sequence axis, i.e. by prompt length, not batch size).
            rows.extend(generation_output[:, prompt_len:].to('cpu'))

    # Chunks can end at different lengths (generation may stop early), so
    # right-pad to a common length instead of torch.stack.
    fill = pad_id if pad_id is not None else 0
    output_ids = torch.nn.utils.rnn.pad_sequence(rows, batch_first=True, padding_value=fill)

    if attention_mask is not None:
        output_mask = torch.ones_like(output_ids)
        if pad_id is not None:
            # Compare the tensor (not the Python list of rows) to the pad id.
            output_mask[output_ids == pad_id] = 0
        return output_ids, output_mask
    return output_ids
    
            
#response, mask = query_model(input_ids=query_tensor['input_ids'],
#                       #attention_mask=query_tensor['attention_mask'],
#                       model=gpt2_model, 
#                       batch_size=1, 
#                       response_length = 32)

In [2]:
import os
import torch
from tqdm import tqdm
import math
#from ppo import PPO
from datasets import load_dataset
from utils import logprobs_from_logits
from torch.utils.data import DataLoader
from transformers import AutoModelForSequenceClassification, GPT2LMHeadModel, GPT2Tokenizer
from gpt2withvaluehead import GPT2HeadWithValueModel, respond_to_batch
from ppo2 import PPO


# GPT-2 has no pad token of its own: reuse EOS for padding, and pad on the left.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"
eos_id = tokenizer.eos_token_id

# Reward model: sequence classifier from a local checkpoint, used to score
# prompt+response pairs in the training loop.
sentiment_model = (
    AutoModelForSequenceClassification
    .from_pretrained("../models/distilgpt2_for_generation_for_scoring", pad_token_id=eos_id)
    .to('cuda')
)

# Reference policy (passed to PPO as ref_model) and the policy being trained;
# both start from the same distilgpt2 weights with a freshly initialised value head.
gpt2_model_ref = GPT2HeadWithValueModel.from_pretrained("distilgpt2", pad_token_id=eos_id).to('cuda')
gpt2_model = GPT2HeadWithValueModel.from_pretrained(
    "distilgpt2",
    pad_token_id=eos_id,
    use_cache=True,
).to('cuda')

Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['v_head.summary.weight', 'transformer.h.0.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'v_head.summary.bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.3.attn.masked_bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.1.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of GPT2HeadWithValueModel were not initialized from the model checkpoint at distilgpt2 and are newly initialized: ['v_head.summary.weight', 'transformer.h.0.attn.masked_bias', 'transformer.h.2.attn.masked_bias', 'v_head.summary.bias', 'transformer.h.5.attn.masked_bias', 'transformer.h.3.attn.masked_bias', 'transformer.h.4.attn.masked_bias', 'transformer.h.1.attn.masked_bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and infer

In [3]:
# Load the filtered TL;DR JSON file; records sit under its top-level "data" field.
# NOTE(review): the file name says "test" but it is exposed as the "train" split — confirm intended.
tldr_json_path = "../data/tldr-filtered-test.json"
datasets = load_dataset("json", field='data', data_files={"train": tldr_json_path})

# prep dataset
def tokenize_function(examples, max_length=32):
    """Format Reddit posts as TL;DR prompts and tokenize them.

    Each prompt consists of SUBREDDIT, TITLE and POST lines followed by a final
    ``TL;DR:`` line. Only the body (everything before that last line) is
    truncated, so the full prompt fits in ``max_length`` tokens while the
    ``TL;DR:`` cue is always kept. Truncating the joined string instead would
    cut the cue off of every long post.

    Args:
        examples: batched dataset columns with ``subreddit``, ``title`` and
            ``content`` lists of equal length.
        max_length: token budget for the whole prompt, cue included.

    Returns:
        dict with unpadded ``input_ids`` lists and matching all-ones
        ``attention_mask`` lists; padding is done later by ``tokenizer.pad``
        in ``collate_wrapper``.
    """
    suffix_ids = tokenizer("\nTL;DR:")['input_ids']
    bodies = [f'SUBREDDIT: r/{subreddit}\nTITLE: {title}\nPOST: {post}' for subreddit, title, post in zip(
        examples['subreddit'],
        examples['title'],
        examples['content'],)]
    # Reserve room for the cue, truncate only the body, then re-attach the cue.
    body_ids = tokenizer(bodies, max_length=max_length - len(suffix_ids), truncation=True)['input_ids']
    input_ids = [ids + suffix_ids for ids in body_ids]
    return {
        'input_ids': input_ids,
        'attention_mask': [[1] * len(ids) for ids in input_ids],
    }

# Tokenize every split in parallel batches, keeping only the tokenizer outputs.
tokenized_datasets = datasets.map(
    tokenize_function,
    batched=True,
    num_proc=8,
    remove_columns=datasets["train"].column_names,
)

# Run configuration.
# NOTE(review): ppo_steps and per_device_batch_size are not referenced in the visible cells.
ppo_steps = 1_000_000
batch_size = 64            # prompts per DataLoader batch
per_device_batch_size = 8
response_len = 32          # response length in tokens

def collate_wrapper(batch):
    """Pad a list of tokenized examples into a single batch of PyTorch tensors."""
    padded_batch = tokenizer.pad(batch, return_tensors='pt')
    return padded_batch

# Shuffled batches of `batch_size` prompts, padded on the fly by collate_wrapper.
loader = DataLoader(
    tokenized_datasets['train'],
    batch_size=batch_size,
    shuffle=True,
    pin_memory=False,
    collate_fn=collate_wrapper,
)

Using custom data configuration default
Reusing dataset json (/home/kip/.cache/huggingface/datasets/json/default-913b4a67787bf3b8/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514)
Loading cached processed dataset at /home/kip/.cache/huggingface/datasets/json/default-913b4a67787bf3b8/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-aed1fca3ca3985d6.arrow
Loading cached processed dataset at /home/kip/.cache/huggingface/datasets/json/default-913b4a67787bf3b8/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-75636440fa93c637.arrow
Loading cached processed dataset at /home/kip/.cache/huggingface/datasets/json/default-913b4a67787bf3b8/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-2ef78ea2351df0dc.arrow
Loading cached processed dataset at /home/kip/.cache/huggingface/datasets/json/default-913b4a67787bf3b8/0.0.0/70d89ed4db1394f028c651589fcab6d6b28dddcabbe39d3b21b4d41f9a708514/cache-141a3

In [4]:
# NOTE(review): this dict names IMDB models/tokenizer and is not passed to the
# PPO object below — confirm whether it is still needed.
config = dict(
    lm_name="lvwerra/gpt2-imdb",
    ref_lm_name="lvwerra/gpt2-imdb",
    cls_model_name="lvwerra/bert-imdb",
    tk_name="gpt2",
    steps=51200,
    batch_size=64,
    forward_batch_size=1,
    ppo_epochs=4,
    txt_in_len=5,
    txt_out_len=20,
    lr=1.41e-5,
    init_kl_coef=0.2,
    target=6,
    horizon=10000,
    gamma=1,
    lam=0.95,
    cliprange=0.2,
    cliprange_value=0.2,
    vf_coef=0.1,
    seed=1,
)

import wandb
wandb.init(project='transformer_ppo')

# NOTE(review): batch_size=2 here vs 64 prompts per loader batch — confirm how ppo2.PPO uses it.
ppo_trainer = PPO(
    model=gpt2_model,
    ref_model=gpt2_model_ref,
    batch_size=2,
    wandb=wandb,
)

[34m[1mwandb[0m: Currently logged in as: [33mkdog[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.33 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [5]:
fbs = 32            # forward batch size: prompts per generation / scoring call
max_batch_idx = 80  # stop after processing this loader batch index (81 batches)

for idx, batch in enumerate(tqdm(loader)):
    query_tensors = batch['input_ids'].to('cuda')
    num_queries = query_tensors.shape[0]

    # Generate in chunks of fbs. Stepping by fbs keeps a trailing partial chunk,
    # so there is exactly one response per query even when the last loader batch
    # is not a multiple of fbs (otherwise the torch.cat below fails).
    response_tensors = torch.cat([
        respond_to_batch(gpt2_model, query_tensors[i:i + fbs], txt_len=response_len)
        for i in range(0, num_queries, fbs)
    ])
    torch.cuda.empty_cache()

    # Score prompt+response with the reward model; no autograd graph is needed.
    # NOTE(review): inputs are left-padded with EOS and passed without an
    # attention_mask — confirm the classifier pools the intended token.
    inputs = torch.cat((query_tensors, response_tensors), dim=1)
    with torch.no_grad():
        scores_tensors = torch.cat([
            sentiment_model(inputs[i:i + fbs])['logits']
            for i in range(0, num_queries, fbs)
        ])
    # Squeeze only the label axis so a single-row batch keeps shape (1,).
    scores_tensors = scores_tensors.squeeze(-1)

    stats = ppo_trainer.step(query_tensors, response_tensors, scores_tensors)

    # Log a few decoded examples; skip_special_tokens hides the left-padding EOS tokens.
    data = [[tokenizer.decode(query, skip_special_tokens=True),
             tokenizer.decode(response, skip_special_tokens=True)]
            for query, response in zip(query_tensors[:8], response_tensors[:8])]
    wandb.log({
        "text_examples": wandb.Table(data=data, columns=["Prompt", "Summary"])
    })

    if idx == max_batch_idx:
        break

 13%|█▎        | 80/632 [12:39<1:27:23,  9.50s/it]
