
Forced EOS token in vllm generation? #238

Open
mgerstgrasser opened this issue Mar 8, 2024 · 6 comments

Comments

@mgerstgrasser
Contributor

I see in RemoteExperienceMaker._generate_vllm(), line 375, that for generations that don't finish, i.e. that don't output the EOS token within the max token limit, we manually set the last token to be the EOS token, even though that is not what the model generated.

Isn't this the wrong thing to do? E.g. if the model generated an unfinished sentence like "This is an unfinished" when it ran into the token limit, shouldn't we train on that, rather than on "This is an " followed by EOS? My understanding is also that PPO doesn't do well with off-policy experiences, which is technically what we have if we manually change the last token to EOS. So I just wanted to check whether there's a specific reason to do this?

It also looks to me like the huggingface model.generate() method, and by extension RemoteExperienceMaker._generate_local() and NaiveExperienceMaker, does not do this.
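To illustrate what I mean, the behavior is roughly the following. This is just a paraphrased sketch with a hypothetical helper name, not the actual code:

```python
def force_eos_on_truncation(token_ids: list[int], eos_token_id: int, max_new_tokens: int) -> list[int]:
    # Sketch of the behavior I'm describing: if the sample hit the length limit
    # without emitting EOS, the last generated token is overwritten with EOS.
    if len(token_ids) >= max_new_tokens and token_ids[-1] != eos_token_id:
        return token_ids[:-1] + [eos_token_id]
    return token_ids
```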

@hijkzzz
Collaborator

hijkzzz commented Mar 8, 2024

Because the reward model uses the EOS token to predict the reward value, we had to hack it this way.
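Roughly speaking (simplified sketch, not the exact implementation), the reward is read off at the EOS position, so a sequence that never emits EOS has nothing to score:

```python
import torch

def reward_at_eos(values: torch.Tensor, sequences: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # values: (batch, seq_len) per-token predictions from the reward model's value head.
    # The scalar reward is taken at the EOS position; if a sequence contains no EOS,
    # this lookup has nothing meaningful to point at -- hence the forced EOS token.
    eos_positions = (sequences == eos_token_id).int().argmax(dim=1)  # first EOS per sample (0 if none)
    return values.gather(1, eos_positions.unsqueeze(1)).squeeze(1)
```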

@mgerstgrasser
Contributor Author

Ah, got it, that makes sense. I think that's probably broken with local generation then! I just verified that local generation does not end with EOS when the max token limit is reached.

Also, would taking the RM output at the last non-masked token instead of at EOS be a better way around this? Or alternatively, forcing the EOS token only when feeding the experience to the RM?
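Something like this is what I have in mind for the first option (hypothetical sketch, not written against the actual code):

```python
import torch

def reward_at_last_unmasked(values: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # values: (batch, seq_len) per-token predictions; attention_mask: (batch, seq_len) 0/1.
    # Read the reward at the last non-masked position, whatever token sits there,
    # so no generated token would need to be overwritten with EOS.
    last_index = attention_mask.long().sum(dim=1) - 1
    return values.gather(1, last_index.unsqueeze(1)).squeeze(1)
```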

@hijkzzz
Collaborator

hijkzzz commented Mar 8, 2024

I am not sure which approach to take at the moment, but our current implementation is heavily dependent on EOS tokens.

@mgerstgrasser
Contributor Author

> I am not sure which approach to take at the moment, but our current implementation is heavily dependent on EOS tokens.

You mean specifically for the RM? Or more broadly than that?

@mgerstgrasser
Contributor Author

Oh, for local generation, does actor.process_sequences() do the same thing?

```python
sequences.scatter_(dim=1, index=eos_indices, value=eos_token_id)
```

If so, then doing this in RemoteExperienceMaker seems redundant, since it also calls actor.process_sequences() later anyway, i.e. the EOS forcing is effectively done twice, I think?
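For reference, here is my (possibly inexact) reading of what that scatter does, with the index computation paraphrased:

```python
import torch

def sketch_force_eos(sequences: torch.Tensor, attention_mask: torch.Tensor, eos_token_id: int) -> torch.Tensor:
    # Paraphrased sketch: find the last attended position of each sequence and
    # overwrite that token with EOS, regardless of what was actually generated there.
    eos_indices = attention_mask.long().sum(dim=1, keepdim=True) - 1  # (batch, 1)
    out = sequences.clone()
    out.scatter_(dim=1, index=eos_indices, value=eos_token_id)
    return out
```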

@mgerstgrasser
Contributor Author

@hijkzzz Could I ask a quick related question: In actor.process_sequences() I also see that attention_mask is set to False on all EOS tokens, except the final EOS token in each sequence. In the datasets used in the examples, it seems that there never is an EOS token in the input, but if there was (e.g. in a previous turn in a conversation), shouldn't the attention mask be True there?
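A toy example of the case I mean, with made-up token ids:

```python
import torch

eos_token_id, pad_token_id = 2, 0
# Hypothetical multi-turn sequence: turn 1 ends with EOS, turn 2 ends with EOS, then padding.
sequences = torch.tensor([[11, 12, 2, 13, 14, 15, 2, 0]])

# Mask built as "not EOS and not pad", with only the final EOS re-enabled afterwards:
attention_mask = sequences.ne(eos_token_id) & sequences.ne(pad_token_id)
attention_mask[0, 6] = True
print(attention_mask)
# tensor([[ True,  True, False,  True,  True,  True,  True, False]])
# The EOS from the earlier turn (index 2) ends up masked out, which is the case I'm asking about.
```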
