Added keep_fp8_weight_transpose_cache checks while updating transpose in fwd pass #298
Conversation
Hi @ipanfilo, @wangye805, @wenchenvincent

wangye805 left a comment:
Please also add a section of documentation for keep_fp8_weight_transpose_cache, either in the source code before its definition or in our README.
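A hedged sketch of what that documentation could look like. The wording below is an assumption for illustration, not taken from the actual TransformerEngine source:

```python
# Illustrative docstring text for the flag; the exact wording is an
# assumption, not TransformerEngine's actual documentation.
KEEP_FP8_WEIGHT_TRANSPOSE_CACHE_DOC = """\
keep_fp8_weight_transpose_cache : bool
    If True, compute the FP8 weight transpose in the forward pass and cache
    it for reuse in the backward pass. If False, skip creating the transpose
    in the forward pass and recompute it in the backward pass, trading extra
    compute for lower memory usage.
"""

print(KEEP_FP8_WEIGHT_TRANSPOSE_CACHE_DOC)
```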
Several meaningful test failures on CI, all stemming from if keep_fp8_weight_transpose_cache:

    assert not transpose_is_empty_or_none, "Expected _transpose to be a valid, non-empty tensor when transpose cache is enabled."
    AssertionError: Expected _transpose to be a valid, non-empty tensor when transpose cache is enabled.
    # This file was modified for portability to AMDGPU
    # Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
    # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
    #
Keep it.
It might be OK to assert that _transpose is not valid when the keep flag is False, but there might be other reasons for _transpose not to be valid when the flag is True (default behaviour).
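The reviewer's point can be sketched as an asymmetric check. This is a minimal illustration with a hypothetical helper name, not the actual TransformerEngine code:

```python
def check_transpose_cache(transpose, keep_fp8_weight_transpose_cache):
    """Validate the cached _transpose against the caching flag.

    Hypothetical helper: only the flag-is-False direction is safe to assert.
    """
    if not keep_fp8_weight_transpose_cache:
        # Caching disabled: the forward pass must not have left a transpose.
        assert transpose is None, (
            "Expected no cached _transpose when "
            "keep_fp8_weight_transpose_cache is False."
        )
    # When the flag is True (default behaviour) we deliberately do NOT assert
    # that the transpose exists: it may be legitimately empty for other
    # reasons, e.g. the first iteration or a freed cache.
```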
    fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=device)

    # FC2 GEMM
nit: remove this to reduce unnecessary changes
… in fwd pass (#298)
* Added keep_fp8_weight_transpose_cache checks while updating transpose
* Added unittest for the fix
* Added comment for the unit test
* Fixed comment
* Reverted test for single iteration, added assert statements to check for transpose cache, Modified docstring
* Fixed test_numerics spacing
* Added HIP Guards
* Addressed PR Comments, and moved assertion statements under fp8 check
* Reverting assertion to fix the dev ticket
* Removed spacing

Co-authored-by: Sudharshan Govindan <sugovind@amd.com>

* Ensure weight transpose is valid for FP8 training (#1596) (#276)
* Update usage of weightmat before saving for backward
* Added keep_fp8_weight_transpose_cache checks while updating transpose in fwd pass (#298)
  * Added keep_fp8_weight_transpose_cache checks while updating transpose
  * Added unittest for the fix
  * Added comment for the unit test
  * Fixed comment
  * Reverted test for single iteration, added assert statements to check for transpose cache, Modified docstring
  * Fixed test_numerics spacing
  * Added HIP Guards
  * Addressed PR Comments, and moved assertion statements under fp8 check
  * Reverting assertion to fix the dev ticket
  * Removed spacing

  Co-authored-by: Sudharshan Govindan <sugovind@amd.com>
* Bug fix for get_fp8_metas
* Added keep_fp8_transpose_cache fix for base.py
* added _fp8_metas check for None
* Added comment

Co-authored-by: Sudharshan Govindan <sugovind@amd.com>
Description
While keep_fp8_weight_transpose_cache is False, the expectation is not to create any transpose in the forward pass and to compute the transpose in the backward pass instead. ad76b62#diff-ba97b0d1ae75d17a678bc38b4fa69ffec1e0ea007657a28d65565ee2cff35b95
The above commit introduced a check to ensure the transpose is created when the input's requires_grad is True, but we don't want to create the transpose when keep_fp8_weight_transpose_cache is False.
Without this check, the transpose is updated while its data pointer isn't initialized, leading to:
RuntimeError: /workspace/TransformerEngine/transformer_engine/common/transpose/transpose.hip:206 in function transpose: Assertion failed: output.data.dptr != nullptr. Output is not allocated.
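A minimal sketch of the fix described above, with toy class and function names standing in for TransformerEngine's actual module forward: gate the transpose update on both requires_grad and the cache flag, so no write ever hits an unallocated buffer:

```python
class Fp8Weight:
    """Toy stand-in for an FP8 weight with an optional transpose cache."""

    def __init__(self, data):
        self.data = data          # list-of-rows stand-in for the FP8 tensor
        self._transpose = None    # buffer is allocated only when caching is on

    def update_transpose(self):
        # Simulates (re)computing and caching the transpose buffer.
        self._transpose = [list(col) for col in zip(*self.data)]


def forward(weight, requires_grad, keep_fp8_weight_transpose_cache):
    # The fix: only update the transpose in the forward pass when the cache
    # is kept. With the cache disabled, the transpose buffer is never touched
    # here and is recomputed in the backward pass instead.
    if requires_grad and keep_fp8_weight_transpose_cache:
        weight.update_transpose()
    return weight
```

With keep_fp8_weight_transpose_cache=False, this forward leaves _transpose as None, and the backward pass is responsible for computing the transpose on demand.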
Fixes # (13552)
https://ontrack-internal.amd.com/browse/SWDEV-553639
Type of change
Changes
Added keep_fp8_weight_transpose_cache checks while updating transpose
Checklist: