Raise autocast usage error #93
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci
LGTM
For my own understanding: we expect this assert to trigger if we perform multiple forward or backward passes within a single fp8_autocast context. Partial forward or backward passes within an fp8_autocast context (e.g. if we've frozen most of the model to finetune a specific section) should run fine, and it should also run fine if the numbers of forward and backward passes don't match.
We don't want to run any portion of the model twice under the same autocast call when using amax reduction with FP8 training.
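The constraint can be illustrated with a minimal stand-in guard. This is a hypothetical sketch (the class and method names are invented for illustration); TransformerEngine's real bookkeeping lives inside `fp8_autocast` and the TE modules themselves.

```python
class FP8AutocastGuard:
    """Hypothetical stand-in for the double-execution check: tracks which
    modules have run inside the current autocast region."""

    def __init__(self):
        self._seen = set()
        self._active = False

    def __enter__(self):
        self._seen.clear()
        self._active = True
        return self

    def __exit__(self, exc_type, exc, tb):
        self._active = False
        return False

    def register_forward(self, module_id):
        # With amax reduction, each module may contribute its amax
        # statistics only once per autocast region, so a repeated
        # execution of the same module is an error.
        if self._active and module_id in self._seen:
            raise RuntimeError(
                f"Module {module_id!r} already ran inside this fp8_autocast "
                "region; running it twice would double-count its amax."
            )
        self._seen.add(module_id)


guard = FP8AutocastGuard()
with guard:
    guard.register_forward("linear_0")      # first execution: fine
    try:
        guard.register_forward("linear_0")  # second execution: caught
    except RuntimeError as err:
        print("caught:", err)
```

Re-entering the region (a fresh `with` block) clears the bookkeeping, which is why per-iteration use of the context manager is fine.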
* catch incorrect usage of fp8_autocast
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* catch error on first time double execution
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* catch incorrect usage of fp8_autocast
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
* catch error on first time double execution
  Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Charlene Yang <charleney@nvidia.com>
The following Python code gives the error:
The code runs fine without the last line. Since the last line runs outside the … Thanks for any insight. I am running with CUDA 11.8 for this library and CUDA 12.0 on the H100.
Without this change, the following use case errors out with an unhelpful message:
This PR fixes this case and catches the error the first time train() is called in the above script. The "correct" usage is below: