
Wrap all fp8 extra states in LocalNonpersistentObject #9422

Closed

jbaczek wants to merge 1 commit from the jbaczek/llm/fix_war_for_fp8_load branch

Conversation

@jbaczek jbaczek (Collaborator) commented Jun 10, 2024

What does this PR do?

This PR generalizes FP8 extra state wrapping for all tensors.
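For illustration, a minimal sketch of what such a generalization could look like (the helper name wrap_fp8_extra_states is hypothetical; ShardedObject, LocalNonpersistentObject, and dict_list_map_inplace are the megatron.core dist_checkpointing utilities the existing fused-attention WAR already uses):

# Sketch only: wrap every FP8 _extra_state entry, not just the fused-attention
# ones, so the state from module initialization is kept and nothing is loaded
# from the checkpoint for those entries.
from megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace
from megatron.core.dist_checkpointing.mapping import LocalNonpersistentObject, ShardedObject

def wrap_fp8_extra_states(sharded_state_dict):  # hypothetical helper name
    def wrap(x):
        if isinstance(x, ShardedObject) and '_extra_state' in x.key:
            return LocalNonpersistentObject(x.data)
        return x

    dict_list_map_inplace(wrap, sharded_state_dict)
    return sharded_state_dict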

Collection: nlp

Before your PR is "Ready for review"

Pre-checks:

  • Make sure you have read and followed the Contributor guidelines
  • Did you write any new necessary tests?
  • Did you add or update any necessary documentation?
  • Does the PR affect components that are optional to install? (e.g. Numba, Pynini, Apex)
    • Reviewer: Does the PR have correct import guards for all optional libraries?

PR Type:

  • New Feature
  • Bugfix
  • Documentation

@github-actions github-actions bot added the NLP label Jun 10, 2024
@jbaczek jbaczek requested a review from timmoon10 June 10, 2024 09:25
Signed-off-by: Jan Baczek <jbaczek@nvidia.com>
@jbaczek jbaczek force-pushed the jbaczek/llm/fix_war_for_fp8_load branch from 7f822fa to 549c541 Compare June 10, 2024 09:46
github-actions (bot) commented

This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.

@github-actions github-actions bot added the stale label Jun 28, 2024
@jbaczek jbaczek removed the stale label Jul 5, 2024
@@ -1850,7 +1850,7 @@ def sharded_state_dict(self, prefix: str = '') -> Dict[str, Any]:

        # WAR: This is a temporary fix to skip loading FP8 parameters for Dot Product Attention
        def skip_fp8_load(x):
            if isinstance(x, ShardedObject) and 'fused_attention' in x.key and '_extra_state' in x.key:
                x = LocalNonpersistentObject(x.data)  # keep the FP8 state from initialization, not from the checkpoint
            return x
Collaborator commented
I'm concerned that this PR essentially makes the GPT models ignore the FP8 state for all layers in the checkpoint.

In the meantime I have prepared a thorough solution, with an almost-merged MCore branch and a corresponding NeMo branch.

@jbaczek could you check (in theory or in practice) whether this would solve your problem?
The required flag to set would be model.dist_ckpt_load_strictness=log_all
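For context, a hedged sketch of how that non-strict loading surfaces at the MCore level, assuming the strict argument and StrictHandling enum introduced by the branch mentioned above (the checkpoint path and state-dict variable are illustrative):

# Sketch only: log both missing and unexpected keys (e.g. absent FP8 extra
# states) instead of raising; this is what a log_all strictness maps to.
from megatron.core import dist_checkpointing
from megatron.core.dist_checkpointing.validation import StrictHandling

state_dict = dist_checkpointing.load(
    sharded_state_dict=model_sharded_state_dict,  # illustrative variable
    checkpoint_dir='/path/to/checkpoint',         # illustrative path
    strict=StrictHandling.LOG_ALL,
)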

Collaborator (Author) commented

I think that non-strict loading would solve the problem. I see that this branch has already been merged into MCore. When should we expect the repositories to sync, so that I can use the public implementation?

Collaborator commented

I asked for a public sync; ideally it should be available later today, but we don't have an official ETA.

github-actions (bot) commented

This PR is stale because it has been open for 14 days with no activity. Remove the stale label, comment, or update the PR, or it will be closed in 7 days.

@github-actions github-actions bot added the stale label Jul 23, 2024
github-actions (bot) commented

This PR was closed because it has been inactive for 7 days since being marked as stale.

@github-actions github-actions bot closed this Jul 30, 2024