In [1]:
import collections
import os
import itertools
import time

import datasets
import rich
import torch
from tqdm import tqdm
import transformers
import trlx
from trlx.data.configs import TRLConfig



In [2]:
ds_train = datasets.load_dataset("gigaword", split="train")
# ds_eval  = datasets.load_dataset("gigaword", split="validation")

Found cached dataset gigaword (/home/mila/g/gagnonju/.cache/huggingface/datasets/gigaword/default/1.2.0/ea83a8b819190acac5f2dae011fad51dccf269a0604ec5dd24795b64efb424b6)


In [3]:
class LMDataset(torch.utils.data.Dataset):
    def __init__(self, ds):
        self.ds = ds
    
    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        return self.ds[idx]["document"]

ds_train_obj = torch.utils.data.Subset(LMDataset(ds_train), range(10000))
# ds_eval_obj  = torch.utils.data.Subset(LMDataset(ds_eval),  range(1000))
reward_tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

stuff = collections.Counter()
for entry in ds_train_obj:
    input_ids = reward_tokenizer(entry)["input_ids"]
    stuff.update([len(input_ids)])
    
rich.print(max(stuff.keys()))

In [4]:
reward_model_model_name = "gpt2"
reward_tokenizer = transformers.AutoTokenizer       .from_pretrained(reward_model_model_name)
reward_tokenizer.pad_token = reward_tokenizer.eos_token
reward_model     = transformers.AutoModelForCausalLM.from_pretrained(reward_model_model_name).cuda()

for param in reward_model.parameters():
    param.requires_grad = False

In [5]:
def rl_distillation(samples):
    with torch.no_grad():
        tokenized = reward_tokenizer(
            samples, 
            padding        = True, 
            truncation     = True,
            return_tensors = "pt", 
        )
        tokenized = {k: v.cuda() for k, v in tokenized.items()}
        
        # Good
        logp = reward_model(**tokenized).logits.log_softmax(-1)
        
        # Likely good
        logp = logp.gather(
            dim=-1, 
            index=tokenized["input_ids"].unsqueeze(-1)
        ).squeeze(-1)
        
        assert logp.shape == (len(samples), logp.shape[1]), logp.shape
        per_batch_logp = logp.sum(-1)
        
        assert per_batch_logp.shape == (len(samples),), per_batch_logp.shape

        return per_batch_logp
    
CONFIG_PATH = "/home/mila/g/gagnonju/Marg-Li-CoT/our_scratchpad/configs/ppo_config.yml"
model = trlx.train(
    "distilgpt2", 
    config=TRLConfig.load_yaml(CONFIG_PATH),
    prompts      = ds_train_obj,
    eval_prompts = ds_train_obj,
    reward_fn    = rl_distillation,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: Currently logged in as: [33mjulesgm[0m. Use [1m`wandb login --relogin`[0m to force relogin


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Output()

  tbar = tqdm(


In [6]:
# model_name = "distilgpt2"
# model     = transformers.AutoModelWithLMHead.from_pretrained(model_name)
# tokenizer = transformers.AutoTokenizer      .from_pretrained(model_name)

In [None]:
# transformers.generation_logits_process.NoRepeatNGramLogitsProcessor
# transformers.generation_logits_process.RepetitionPenaltyLogitsProcessor

In [None]:
prompts = {
    "cats are ": "and very pretty",
    "dogs are ": "and very cute",
}

prompt_end     = "<|cls|>"
scratchpad_end = "<|cls|>"

for prompt in prompts:
    assert prompt_end not in prompt, prompt

reward_model_model_name = "gpt2"
reward_tokenizer        = transformers.AutoTokenizer           .from_pretrained(reward_model_model_name)
reward_model            = transformers.AutoModelModelWithLMHead.from_pretrained(reward_model_model_name).cuda()



def scratchpad_reward_fn(samples):
    # The idea is to:
    # 1. Extract the associated answers & tokenize the answers
    # 2. Create a mask for the answers
    # 3. Tokenize the samples
    # 4. Concate the samples & answers
    # 5. Run the reward model on the concatenated samples & answers
    # 6. Extract the logp for the answers
    # 7. Return the logp for the answers

    reward_model_inputs  = []
    reward_model_answers = []

    for sample in samples:
        # 1.a Extract the associated answers 
        splitted = sample.split(prompt_end, 1)
        assert len(splitted) == 2, len(splitted)
        question, scratchpad = splitted
        answer = prompts[question]
        reward_model_inputs.append(sample + scratchpad_end)
    
    # 1.b Tokenize the answers
    tokenized_answers = reward_tokenizer(reward_model_answers)
    tokenized_answers_masks = [[1] * len(x) for x in tokenized_answers["input_ids"]]
    tokenizer_samples = reward_tokenizer(reward_model_inputs)
    full_seq = [sample + answer for sample, answer in zip(tokenizer_samples["input_ids"], tokenized_answers["input_ids"])]
    full_seq = reward_tokenizer.pad(full_seq, return_tensors="pt", padding=True)
    full_seq = {k: v.cuda() for k, v in full_seq.items()}

    # MASKS VALUE INPUTS
    tokenized_answers_masks = [[1] * len(x) for x in tokenizer_samples["input_ids"]]
    full_seq_input_masks = [sample_mask + answer_mask for sample_mask, answer_mask in zip(tokenizer_samples["attention_mask"], tokenized_answers_masks)]
    full_seq_input_masks = torch.nn.utils.rnn.pad_sequence(full_seq_input_masks, batch_first=True, padding_value=0)

    reward_model_outputs = reward_model(full_seq["input_ids"], attention_mask=full_seq_input_masks).logsoftmax()

    # MASKS VALUE LOGITS
    tokenized_answers_masks = [[0] * len(x) for x in tokenizer_samples["input_ids"]]
    full_seq_output_masks = [sample_mask + answer_mask for sample_mask, answer_mask in zip(tokenizer_samples["attention_mask"], tokenized_answers_masks)]
    full_seq_output_masks = torch.nn.utils.rnn.pad_sequence(full_seq_output_masks, batch_first=True, padding_value=0)

    # 6. Extract the logp for the entry
    logp = reward_model_outputs.gather(dim=-1, index=full_seq["input_ids"]).squeeze(-1)

    # 7. Return the logp for the answer
    logp *= full_seq_output_masks
    logp = logp.sum(-1)
    assert logp.shape == (len(samples),), logp.shape

    output = logp.mean()
    assert output.shape == (), output.shape
    return output
