Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set add_special_tokens=False to not add EOS unexpectedly #287

Merged
merged 12 commits into from
Feb 10, 2023

Conversation

cat-state
Copy link
Collaborator

This PR addresses #253 . Set add_special_tokens=False and instead add BOS manually, matching what is done in the trainers.

comparison report with ppo_sentiments: https://wandb.ai/carperai/trlx/reports/PromptPipeline-Add-BOS---VmlldzozNTA2MTM4

@cat-state cat-state marked this pull request as draft February 8, 2023 00:15
@cat-state
Copy link
Collaborator Author

I just realised this breaks the T5 example via

/home/a/trlx/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py:26 in reward_fn              │
│                                                                                                  │
│   23 if __name__ == "__main__":                                                                  │
│   24 │                                                                                           │
│   25def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):              │
│ ❱ 26 │   │   original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]           │
│   27 │   │   scores = [                                                                          │
│   28 │   │   │   meteor.compute(predictions=[output.strip()], references=[original])["meteor"    │
│   29 │   │   │   for (original, output) in zip(original_summaries, outputs)                      │
│                                                                                                  │
│ /home/a/trlx/examples/summarize_daily_cnn/t5_summarize_daily_cnn.py:26 in <listcomp>             │
│                                                                                                  │
│   23 if __name__ == "__main__":                                                                  │
│   24 │                                                                                           │
│   25def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):              │
│ ❱ 26 │   │   original_summaries = [prompt_label[prompt.strip()] for prompt in prompts]           │
│   27 │   │   scores = [                                                                          │
│   28 │   │   │   meteor.compute(predictions=[output.strip()], references=[original])["meteor"    │
│   29 │   │   │   for (original, output) in zip(original_summaries, outputs)                      │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
KeyError: 'Summarize: (CNN) -- This week marks one of the most-exciting non-major events of the golf season -- the Players Championship at the famed TPC Sawgrass. With 
a deep field and a great course, you won\'t want to miss any of the action. Before the tournament tees off, we had a chance to catch up with TPC Sawgrass PGA Head 
Professional Matt Borocz, who provided some inside insight on the home of the PGA Tour. PGA.com: Thanks for joining us. This week presents one of the most exciting on 

@cat-state cat-state marked this pull request as ready for review February 8, 2023 22:54
@jon-tow
Copy link
Collaborator

jon-tow commented Feb 9, 2023

Set add_special_tokens=False and instead add BOS manually, matching what is done in the trainers.

@cat-state wait where is this happening in the trainers? It looks like the <>Trainer.tokenize() methods that manually add BOS tokens never actually get called anymore, e.g. the base method

def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]):
if isinstance(texts[0], torch.LongTensor):
return texts
tokenized = self.tokenizer(
[self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts],
max_length=self.max_length,
truncation=True,
# NOTE: We manually add special tokens (bos) above so we set this False
# to avoid models that automatically add special tokens (e.g. OPT)
# adding them twice more.
add_special_tokens=False,
)
input_ids = list(map(torch.as_tensor, tokenized.input_ids))
return input_ids

Let me loop @reciprocated in here because I vaguely recall he brought this up before. Also, if these methods are dead we should probably remove them.

@maxreciprocate
Copy link
Collaborator

Neither AccelerateRLTrainer.tokenize, AccelerateILQLTrainer.tokenize nor NeMoILQLTrainer.tokenize don't seem to be used anywhere at the moment. Actually there are only three places were tokenization happens: in both orchestrators and in the prompt pipeline (apart from examples)

model_inputs = tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length)

tokens = tokenizer(phrase).input_ids[-ctx_length:]

outputs = self.trainer.tokenizer(str_outputs).input_ids

Only the last needs to have <eos> appended (for the reward model & to make indexing work) and I think none really need <bos>

@cat-state
Copy link
Collaborator Author

are only three places were tokenization happens: in both orchestrators and in the prompt pipeline (apart from examples)

Oh right! so this was just dead code 😅
It seems like we still add EOS and BOS for ILQL in

dialogue = [tokenizer.bos_token, dialogue]

With the 20b or J or gpt2 the tokenizers don't add BOS/EOS by default it seems, however the original issue was when using a tokenizer that does

Copy link
Collaborator

@jon-tow jon-tow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, @cat-state! This looks good 👍 Just leaving some questions for small edits.

trlx/trainer/accelerate_ilql_trainer.py Show resolved Hide resolved
model_inputs = tokenizer(prompts, truncation=True, padding=False, max_length=max_prompt_length)
prompts = model_inputs["input_ids"]

# manually prepend bos token if not already present to match RL trainers tokenization
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment might be misleading since we no longer prepend BOS tokens from the trainers.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh yeah, I'll change it. We still do add BOS and EOS for ILQL btw

Copy link
Collaborator

@maxreciprocate maxreciprocate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like in your report there are 5 different runs all with slightly different results, however all use gpt2-imdb for which this pr shouldn't change anything, do you think this is only due to added BOS's influence?

# default tokenizer behavior for PPO
if tokenizer.bos_token is not None:
prompts = [
tokenizer.bos_token + prompt if not prompt.startswith(tokenizer.bos_token) else prompt
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have to force adding BOS here? In ILQL it is only added in case there is only a single string passed, to make sure loss starts from 0-index of output (action_ixs[0] = 0) by prepending a "prompt" as BOS, and not when real [prompt, output] is passed (unless truncation comes into play). Don't want to change behaviour here as I'm finilizing HH PR with already shaby results as they are, which may or may not take a hit due to this change 🙂

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, that makes sense that the bos is only added to the the prompt for a completion if no prompt is given for that completion.

@jon-tow
Copy link
Collaborator

jon-tow commented Feb 10, 2023

@reciprocated As an aside; could you clarify why you believe "none [of the tokenizations] really need <bos>"? Is this from empirical results you've noticed with RL? The "natural" thing to do with LMs is to include <bos> for inference (e.g. for PPO).

@cat-state
Copy link
Collaborator Author

cat-state commented Feb 10, 2023

Seems like in your report there are 5 different runs all with slightly different results, however all use gpt2-imdb for which this pr shouldn't change anything, do you think this is only due to added BOS's influence?

The main runs ran with the same seeds have the same results, I included 2 runs with each seed for both main and with bos to see how they change it

@cat-state cat-state changed the title Fix PromptPipeline tokenization Set add_special_tokens=False to not add EOS unexpectedly Feb 10, 2023
@maxreciprocate
Copy link
Collaborator

maxreciprocate commented Feb 10, 2023

could you clarify why you believe "none [of the tokenizations] really need <bos>"?

@jon-tow There I've meant that <bos> is really not needed algorithmically (unlike in the single output ILQL case), however I don't have any evidence nor counter-evidence that adding it is not advantageous in general. Yet there is such evidence for adding <eos>, something we actually don't do currently but I feel like ought to, from Anthropic's 2021 HHH paper:

We also found that appending a special ‘end-of-context’ token to each sequence to unambiguously delineate the end of passage sometimes improves performance, as discussed in section C.4.

section C.4.:

Here we outline a technical detail that improves the overall performance of preference models. We designate a special “end-of-context” token (EOC) which is included as the final token of each sample context. The preference model score is also predicted directly on top of this token. For our experiments we used the token, but in principle many other choices are possible.

We compare finetuning experiments with and without the EOC token. For experiments with, we consistently apply the same EOC token throughout both the PMP and fine-tuning stages; and for experiments without, we consistently do not apply the EOC token. From figure 31 we see that the EOC clearly improves performance.

We hypothesize that the improvement comes from two factors:
• Sometimes the sentiment behind a natural language statement can be altered or reversed significantly by the addition of one or two words, and so knowing where the context ends can be helpful for the preference model to predict a sensible score.
• Without an EOC token, the preference model must not only predict a score, but also try to anticipate where the context ends. As a result, the model is forced to predict a score at multiple tokens where the context may end, rather than at a single token where it definitely ends. This adds a level of ambiguity which may cause the model to under-perform.

@jon-tow
Copy link
Collaborator

jon-tow commented Feb 10, 2023

@reciprocated Yeah... I also haven't seen any experiments comparing results from <bos> vs no-<bos> for decoder models but folks make the argument you should because it gets you logprobs for the first token (however that may steer inference 🤷).
Thanks for linking the EOC stuff! (off-topic but I wonder if we did this for the HH RM training 🤔)

Copy link
Collaborator

@jon-tow jon-tow left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


@cat-state Just noting that this subtly changes results for T5 models dailymail/cnn summarization example (PPO): https://api.wandb.ai/links/jon-tow/p3kg4ejf
Reward seems to improve better with this fix since removing the eos token before passing ids to generate is (probably?) the right thing to do.

@maxreciprocate
Copy link
Collaborator

maxreciprocate commented Feb 10, 2023

@jon-tow Yeah, but the first logprob of interest in PPO is in the beginning of output, not of the prompt, so unless len(prompt) == 0 there is no reason to add <bos>. Regarding <eos> with HH RM, yes we did in https://github.com/Dahoas/reward-modeling/blob/deba81c2e9dab1514032800109325258caa470db/reward-modeling/rm_datasets.py#L93 @Dahoas, I also add it manually in:

def reward_fn(samples, prompts, outputs):
samples = [s + reward_tokenizer.eos_token for s in samples]

@cat-state
Copy link
Collaborator Author

cat-state commented Feb 10, 2023

improve better with this fix since removing the eos token before passing ids to generate is (probably?) the right thing to do.

Thanks for checking! I hope so, and its on the same random seed too. Seems to not affect GPT-2/neox tokenizer based models as those don't add special tokens by default.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants