
Save Virtual Prompt Weights Only #4237

Merged · 20 commits · May 26, 2022

Conversation

@vadam5 (Contributor) commented May 23, 2022

Removes all GPT/frozen model configs and weights from the prompt learning model's .nemo file. After training ends, the prompt learning model saves only the prompt table parameters. During training, intermediate checkpoint files may also contain prompt encoder parameters.

Collection: BigNLP

  • Adds custom state_dict and load_state_dict methods for the prompt learning model class (a minimal sketch of this pattern follows the list)
  • Updates the prompt learning inference example script
  • Updates the prompt learning documentation
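
For context on the pattern: overriding the PyTorch state_dict/load_state_dict hooks lets the checkpoint carry only the small trainable virtual-prompt parameters while excluding the frozen GPT backbone, which is restored separately from its own checkpoint. The sketch below is illustrative only and is not NeMo's actual implementation; the attribute names `frozen_model` and `prompt_table` are assumptions based on this PR's description.

```python
import torch
import torch.nn as nn

class PromptLearningSketch(nn.Module):
    """Illustrative only: a frozen backbone plus a small trainable prompt
    table, where only the prompt table is checkpointed. Intended for use
    as a top-level model, matching how this PR describes the behavior."""

    def __init__(self, frozen_model: nn.Module, num_virtual_tokens: int, hidden_size: int):
        super().__init__()
        self.frozen_model = frozen_model  # large GPT backbone, kept frozen
        for p in self.frozen_model.parameters():
            p.requires_grad = False
        # Small trainable table of virtual prompt embeddings (hypothetical name).
        self.prompt_table = nn.Embedding(num_virtual_tokens, hidden_size)

    def state_dict(self, *args, **kwargs):
        # Return only the prompt table weights, so the saved checkpoint
        # stays small and carries no copy of the frozen model.
        return {"prompt_table.weight": self.prompt_table.weight.detach().clone()}

    def load_state_dict(self, state_dict, strict=True):
        # Accept checkpoints that contain only prompt table weights; the
        # frozen model is loaded from its own, separate checkpoint.
        self.prompt_table.load_state_dict({"weight": state_dict["prompt_table.weight"]})
```

With this override, `torch.save(model.state_dict(), path)` writes a checkpoint holding only the prompt table, matching the behavior this PR describes for the final .nemo file.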

Usage

The final .nemo checkpoint file contains

  1. model_config.yaml
  2. model_weights.ckpt

where model_weights.ckpt contains only the prompt table parameters.
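
To verify this, a .nemo file can be opened as a tar archive. A rough sketch, assuming a checkpoint at a hypothetical path (archive member names may or may not carry a "./" prefix depending on how the file was written):

```python
import tarfile
import torch

NEMO_PATH = "prompt_learning_model.nemo"  # hypothetical path to your checkpoint

with tarfile.open(NEMO_PATH, "r:*") as archive:
    names = archive.getnames()
    print(names)  # expect model_config.yaml and model_weights.ckpt
    # Match by suffix in case member names are "./"-prefixed.
    ckpt = next(n for n in names if n.endswith("model_weights.ckpt"))
    archive.extract(ckpt, path="/tmp/nemo_inspect")

# The weights file is a regular torch checkpoint; after this PR its keys
# should cover only prompt table parameters, not the frozen GPT model.
weights = torch.load(f"/tmp/nemo_inspect/{ckpt}", map_location="cpu")
print(list(weights.keys()))
```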

Before your PR is "Ready for review"

Pre checks:

  • Make sure you have read and followed the Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g. Numba, Pynini, Apex)
  • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

If you haven't finished some of the above items, you can still open a "Draft" PR.

Who can review?

Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.

Additional Information

  • Related to # (issue)

@titu1994 requested a review from @ericharper on May 24, 2022 02:43
@titu1994 (Collaborator) commented:

And any modification of core needs two core devs to approve it. FYI @ericharper.

@vadam5 requested a review from @titu1994 on May 24, 2022 17:47
@titu1994 (Collaborator) left a comment:

@ericharper, can you work with Virginia to properly use the base class methods? This currently just replaces save_to with a custom implementation, which will drift further and further from the base implementation.

Review thread on nemo/collections/nlp/parts/nlp_overrides.py (outdated, resolved)
@lgtm-com (bot) commented May 25, 2022:

This pull request introduces 3 alerts when merging 89e374d into ff9bc79 - view on LGTM.com

new alerts:

  • 3 for Unused import

vadam5 and others added 3 commits May 25, 2022 02:22
@vadam5 changed the title from "Add GPT Artifact Preservation Method for Prompt Learning" to "Save Virtual Prompt Weights Only" on May 25, 2022
@lgtm-com (bot) commented May 25, 2022:

This pull request introduces 3 alerts when merging 9b3401d into ff9bc79 - view on LGTM.com

new alerts:

  • 3 for Unused import

@vadam5 requested a review from @titu1994 on May 26, 2022 19:45
@titu1994 previously approved these changes on May 26, 2022.

@titu1994 (Collaborator) left a comment:

Overall looks fine; I'm just worried about the double extraction cost. At least it deletes the intermediate file, so the storage cost should be manageable.

Review thread on docs/source/nlp/prompt_learning.rst (resolved)
@vadam5 requested a review from @titu1994 on May 26, 2022 20:54
@titu1994 (Collaborator) left a comment:

Looks good. @MaximumEntropy for final review before merge.

@vadam5 merged commit 84265ac into main on May 26, 2022
@vadam5 deleted the prompt_learning_preserve_gpt_artifacts branch on May 26, 2022 22:23
yaoyu-33 pushed a commit that referenced this pull request May 31, 2022
* Added GPT artifact preservation method
* Removed redundant line of code
* Moved preserve artifact method to NLPSaveRestoreConnector
* Saving only prompt table weights in final .nemo file
* Put NLP overrides back the way they were
* Put NLP overrides back the way they were
* Added docstrings for new methods
* Python style fix
* Added loading state dict backward compatibility
* Updated prompt learning inference to reset frozen model path
* Python formatting fix
* Update prompt_learning.rst
* Update prompt_learning.rst
* Update prompt_learning.rst
* Changed model_file to gpt_model_file, updated CI tests

Signed-off-by: Virginia Adams <vadams@nvidia.com>
Signed-off-by: Yu Yao <yuya@nvidia.com>