
Derive RewardModel from PreTrainedModel #2158

Merged · 4 commits · Mar 22, 2023
Conversation

andreaskoepf (Collaborator)

A first, simple way to derive our RewardModel from PreTrainedModel to simplify loading. It still requires a full download of the base model, which could probably be avoided in the future.
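For illustration, a minimal sketch of the pattern (not this PR's actual code): a subclass of `PreTrainedModel` with its own `PretrainedConfig` gets `save_pretrained`/`from_pretrained` for free, which is what simplifies loading. `RewardConfig` and the tiny linear "backbone" below are hypothetical stand-ins.

```python
# Sketch, assuming hypothetical RewardConfig/RewardModel names: deriving a
# reward model from PreTrainedModel so save_pretrained/from_pretrained work.
import tempfile

import torch
from transformers import PretrainedConfig, PreTrainedModel


class RewardConfig(PretrainedConfig):
    model_type = "reward_model"  # hypothetical model_type key

    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class RewardModel(PreTrainedModel):
    config_class = RewardConfig

    def __init__(self, config):
        super().__init__(config)
        # In the real model this would be a full transformer backbone;
        # a single linear layer keeps the sketch self-contained.
        self.backbone = torch.nn.Linear(config.hidden_size, config.hidden_size)
        self.value_head = torch.nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states):
        # Score only the last position of each sequence.
        return self.value_head(self.backbone(hidden_states))[:, -1, :]


model = RewardModel(RewardConfig())
scores = model(torch.randn(2, 4, 16))  # (batch=2, seq=4, hidden=16)

# Deriving from PreTrainedModel is what makes this roundtrip work:
with tempfile.TemporaryDirectory() as d:
    model.save_pretrained(d)
    reloaded = RewardModel.from_pretrained(d)
```

The config travels with the weights in the saved directory, so a later `from_pretrained` rebuilds the same architecture without extra bookkeeping.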

@sanagno (Collaborator) left a comment

LGTM

@dvruette (Collaborator) commented Mar 21, 2023

Looks good! Does this work with AutoModel? If so, how exactly? I would expect it to work with AutoModelForSequenceClassification.from_pretrained.

Also, why not inherit from GPTNeoXModel? Could save us the trouble of loading the base model. We could basically just copy GPTNeoXForCausalLM and add pooling between hidden states and head.

@theblackcat102 (Collaborator) left a comment

LGTM

@andreaskoepf (Collaborator, Author)

> Looks good! Does this work with AutoModel? If so, how exactly? I would expect it to work with AutoModelForSequenceClassification.from_pretrained.

This PR currently does not support loading via AutoModel; it requires explicit use of the RewardModel class. If you have time, you could look into what is necessary to register it for AutoModel loading.
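A sketch of what such a registration could look like, using transformers' `AutoConfig.register`/`AutoModel.register` API. The `RewardConfig`/`RewardModel` classes here are minimal hypothetical stand-ins, not the classes from this repo.

```python
# Sketch: registering a custom class so AutoModel can resolve it.
import torch
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel


class RewardConfig(PretrainedConfig):
    model_type = "reward_model"  # hypothetical model_type key

    def __init__(self, hidden_size=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_size = hidden_size


class RewardModel(PreTrainedModel):
    config_class = RewardConfig

    def __init__(self, config):
        super().__init__(config)
        self.value_head = torch.nn.Linear(config.hidden_size, 1)

    def forward(self, hidden_states):
        return self.value_head(hidden_states)[:, -1, :]


# Map the model_type string to the config class, and the config class to the
# model class; AutoModel can then dispatch on "reward_model" in config.json.
AutoConfig.register("reward_model", RewardConfig)
AutoModel.register(RewardConfig, RewardModel)

model = AutoModel.from_config(RewardConfig())
```

With the registration in place, `AutoModel.from_pretrained` on a saved checkpoint would dispatch to `RewardModel` via the `model_type` stored in its `config.json`.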

> Also, why not inherit from GPTNeoXModel? Could save us the trouble of loading the base model. We could basically just copy GPTNeoXForCausalLM and add pooling between hidden states and head.

It isn't derived from a single architecture so that different types of base models can be used.
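A sketch of why that flexibility matters: the same reward head can sit on top of any backbone that returns `last_hidden_state`. The `build_reward_model`/`score` helpers and the tiny config values are made up for illustration; `from_config` builds an untrained model, so nothing is downloaded.

```python
# Sketch: an architecture-agnostic reward wrapper over any AutoModel backbone.
import torch
from transformers import AutoConfig, AutoModel


def build_reward_model(model_type, **config_kwargs):
    # Build an untrained backbone from config only -- no weight download.
    backbone = AutoModel.from_config(AutoConfig.for_model(model_type, **config_kwargs))
    head = torch.nn.Linear(backbone.config.hidden_size, 1)
    return backbone, head


def score(backbone, head, input_ids):
    hidden = backbone(input_ids=input_ids).last_hidden_state
    return head(hidden)[:, -1, :]  # reward taken at the last token position


# Swapping "gpt_neox" for another model_type reuses the exact same wrapper.
backbone, head = build_reward_model(
    "gpt_neox",
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=2,
    num_attention_heads=4,
    intermediate_size=64,
)
input_ids = torch.randint(0, 100, (2, 8))
rewards = score(backbone, head, input_ids)
```

Inheriting from GPTNeoXModel would hard-wire the wrapper to that one architecture; keeping the backbone behind AutoModel leaves the choice of base model open.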

@andreaskoepf andreaskoepf merged commit 99b42bc into main Mar 22, 2023
1 check passed
@andreaskoepf andreaskoepf deleted the rm_pre_trained_model branch March 22, 2023 00:58