FP8 & Activation checkpointing do not play well together #415
+1 Here is the error:

    File "/miniconda/lib/python3.10/site-packages/torch/_tensor.py", line 491, in backward
      torch.autograd.backward(
    File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 204, in backward
      Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    File "/miniconda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 804, in unpack_hook
      raise AssertionError(
    AssertionError: if early stop is enabled, we don't expect to reach here
Without checkpointing it is hard to use the full compute capability of the H100, because there is not enough VRAM. For clarity, I am getting the error above when checkpointing TransformerLayer with FP8.
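For context, a minimal sketch of the kind of setup that triggers this (not the exact script from this thread; the layer sizes, input shape, and the use of torch.utils.checkpoint with use_reentrant=False are assumptions for illustration):

```python
# Hypothetical repro sketch: a TransformerEngine layer checkpointed with
# torch's non-reentrant checkpoint inside fp8_autocast. Sizes are illustrative.
import torch
from torch.utils.checkpoint import checkpoint
import transformer_engine.pytorch as te

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16).cuda()
x = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)  # (seq, batch, hidden)

with te.fp8_autocast(enabled=True):
    # FP8 amax/scale bookkeeping also runs during the checkpoint recompute pass,
    # which (as this thread suggests) appears to be what trips the "early stop"
    # assertion in the backward pass.
    out = checkpoint(layer, x, use_reentrant=False)

out.sum().backward()  # AssertionError raised from unpack_hook
```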
So I guess to fix this we just need to detect whether we are within this checkpoint recomputation path and, if so, not change the history at all.
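Purely as an illustration of that idea, a possible shape for the fix could look like the sketch below (the flag, context manager, and update helper are hypothetical names, not TransformerEngine internals):

```python
# Hypothetical sketch of "skip FP8 history updates during recompute"; none of
# these names exist in TransformerEngine, they only illustrate the proposal.
import contextlib

_IN_RECOMPUTE_PHASE = False  # module-level flag marking checkpoint recomputation


@contextlib.contextmanager
def recompute_phase():
    """Wrap the re-execution of the checkpointed forward with this context manager."""
    global _IN_RECOMPUTE_PHASE
    _IN_RECOMPUTE_PHASE = True
    try:
        yield
    finally:
        _IN_RECOMPUTE_PHASE = False


def maybe_update_amax_history(history, new_amax):
    """Only advance the FP8 amax history on the original forward pass."""
    if _IN_RECOMPUTE_PHASE:
        return history  # recompute: reuse existing scales, leave history untouched
    history.append(new_amax)
    return history
```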
How are you currently doing activation checkpointing? Are you using an underlying toolkit such as NeMo?
@ksivaman: I'm using
For this purpose, you can use the checkpoint function provided in TransformerEngine; see its documentation for details. It handles the additional state required for FP8 execution with activation recompute.
@ksivaman @denera: is there an example of using it?
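For reference, a minimal sketch of what such an example might look like, assuming TransformerEngine's checkpoint helper is reachable as te.distributed.checkpoint and accepts the keyword arguments shown (older releases take them positionally; layer sizes and shapes are illustrative):

```python
# Hedged sketch of swapping torch.utils.checkpoint for TransformerEngine's own
# checkpoint helper. Exact argument style may differ between TE versions.
import torch
import transformer_engine.pytorch as te

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16).cuda()
x = torch.randn(128, 2, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True):
    # TransformerEngine's checkpoint is aware of the FP8 state, so the recompute
    # pass should not disturb the amax/scale history.
    out = te.distributed.checkpoint(
        layer,
        x,
        distribute_saved_activations=False,  # set True with sequence/tensor parallelism
        tp_group=None,
    )

out.sum().backward()
```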
Activation checkpointing recomputes the activations and hence needs to re-execute parts of the forward pass.
This re-execution should not affect the FP8 history and should be allowed.
Currently this error is thrown: #93.
Would it be possible to cover this case so that activation checkpointing works seamlessly?