
Support FP8 primary weight in FSDP training#1630

Merged
ksivaman merged 3 commits into NVIDIA:main from shjwudp:fp8_primary_weight_support_for_fsdp
Apr 7, 2025

Conversation

@shjwudp
Contributor

@shjwudp shjwudp commented Apr 1, 2025

Description

This PR modifies the cast_master_weights_to_fp8 function so that FP8 primary weights can be used in FSDP training.

In FSDP training, the model weight may be incomplete: model_weight._data may be a DTensor (FSDP2) or may have been resized for parameter sharding. As a result, we cannot obtain the actual model weight shard through a slice read such as model_weight._data.view(-1)[start_offset:end_offset]. This PR extends cast_master_weights_to_fp8 to also accept the sharded model weight as a direct input, so that the FSDP use case can be supported.
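The API shape of the change can be illustrated with a minimal, library-free sketch. The helper names below are hypothetical stand-ins (the real function is transformer_engine's cast_master_weights_to_fp8, operating on tensors rather than lists); the point is only the contrast between a flat-slice read of the full weight and passing the shard directly:

```python
def cast_shard_via_slice(flat_weight, start_offset, end_offset):
    # Old approach: read the shard as a flat slice of the full weight.
    # Under FSDP this breaks, because model_weight._data may be a DTensor
    # or may have been resized and no longer hold the full flat weight.
    return [round(x, 2) for x in flat_weight[start_offset:end_offset]]

def cast_shard_direct(weight_shard):
    # New approach: the caller passes the already-sharded weight directly,
    # so no assumption is made about how the full weight is stored.
    return [round(x, 2) for x in weight_shard]

full_weight = [0.111, 0.222, 0.333, 0.444]
shard = full_weight[1:3]  # the slice FSDP would hand this rank

# Both paths agree when the full flat weight is available; only the
# direct-shard path works when it is not.
assert cast_shard_via_slice(full_weight, 1, 3) == cast_shard_direct(shard)
```

In the real API, the direct-shard input lets FSDP callers hand over whatever local shard they own, instead of requiring the function to reconstruct it via offsets into a storage layout it cannot rely on.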

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@shjwudp shjwudp force-pushed the fp8_primary_weight_support_for_fsdp branch from 422f432 to bc4b9e9 Compare April 2, 2025 11:25
@shjwudp shjwudp closed this Apr 2, 2025
@shjwudp shjwudp reopened this Apr 2, 2025
Signed-off-by: jianbinc <shjwudp@gmail.com>
@shjwudp shjwudp force-pushed the fp8_primary_weight_support_for_fsdp branch from bc4b9e9 to d14c1f0 Compare April 2, 2025 11:32
@shjwudp shjwudp changed the title Support FP8 primary weight with FSDP Support FP8 primary weight in FSDP training Apr 2, 2025
@ksivaman
Member

ksivaman commented Apr 3, 2025

/te-ci pytorch L0 L1

@ksivaman ksivaman merged commit c84d170 into NVIDIA:main Apr 7, 2025
2 checks passed
wdykas pushed a commit to wdykas/TransformerEngine that referenced this pull request Apr 14, 2025
Support fp8 primary weight in fsdp training

Signed-off-by: jianbinc <shjwudp@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Peter Dykas <wdykas@nvidia.com>
