FP8 & Activation checkpointing do not play well together #415

Open
PiotrDabkowski opened this issue Sep 3, 2023 · 6 comments

Comments

@PiotrDabkowski

PiotrDabkowski commented Sep 3, 2023

Activation checkpointing recomputes activations during the backward pass, and hence needs to re-execute parts of the forward pass.

This re-execution should be allowed, and it should not modify the FP8 amax history.
Currently the error from #93 is thrown.

Would it be possible to handle this case so that activation checkpointing works seamlessly?
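
For illustration, here is a minimal stand-alone example (plain PyTorch, no TE involved) of the re-execution in question: torch.utils.checkpoint simply calls the wrapped function a second time during backward to rebuild the activations.

```python
import torch
from torch.utils.checkpoint import checkpoint

calls = 0

def fn(x):
    global calls
    calls += 1
    return x.sin()

x = torch.randn(4, requires_grad=True)
y = checkpoint(fn, x, use_reentrant=False)
y.sum().backward()
print(calls)  # 2: once in the forward pass, once during the recompute
```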

@jramapuram

+1

Here is the error:

  File "/miniconda/lib/python3.10/site-packages/torch/_tensor.py", line 491, in backward
    torch.autograd.backward(
  File "/miniconda/lib/python3.10/site-packages/torch/autograd/__init__.py", line 204, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/miniconda/lib/python3.10/site-packages/torch/utils/checkpoint.py", line 804, in unpack_hook
    raise AssertionError(
AssertionError: if early stop is enabled, we don't expect to reach here

@PiotrDabkowski
Author

PiotrDabkowski commented Sep 3, 2023

Without checkpointing it is hard to use the full compute capability of the H100, because there is not enough VRAM.

For clarity, I am getting the following error when checkpointing a TransformerLayer with FP8:

```
...
Traceback (most recent call last):
  File "/home/user/llm/llmte/te_model.py", line 129, in <module>
    x.mean().backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 491, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 204, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 274, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/checkpoint.py", line 261, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/transformer.py", line 499, in forward
    self_attention_outputs = self.self_attention(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/attention.py", line 1243, in forward
    layernorm_qkv_outputs = self.layernorm_qkv(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1505, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1514, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/_dynamo/eval_frame.py", line 294, in _fn
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/layernorm_linear.py", line 842, in forward
    with self.prepare_forward(inp, is_first_microbatch) as inp:
  File "/usr/lib/python3.10/contextlib.py", line 135, in __enter__
    return next(self.gen)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/module/base.py", line 612, in prepare_forward
    add_amax_to_global_buffer(self.fp8_meta, forward=True)
  File "/usr/local/lib/python3.10/dist-packages/transformer_engine/pytorch/fp8.py", line 135, in add_amax_to_global_buffer
    assert fp8_meta[buffer_position_key] == len(_global_fp8_buffer[buffer_key]) - 1, \
AssertionError: Same module is being invoked more than once inside an `fp8_autocast` region when using FP8 with amax reduction. This behavior is currently unsupported. For more details and correct usage, please see https://github.com/NVIDIA/TransformerEngine/pull/93.
```
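
The shape of the failing setup is roughly the following (a hedged sketch with illustrative sizes, not my actual te_model.py; the fp8_autocast is placed inside the checkpointed function, so the recompute re-enters the FP8 region):

```python
import torch
from torch.utils.checkpoint import checkpoint
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16).cuda()
fp8_recipe = DelayedScaling()

def run(inp):
    # Re-executed during backward; TE then sees a second invocation of the
    # same module inside an fp8_autocast region and raises the assertion.
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        return layer(inp)

x = torch.randn(128, 4, 1024, device="cuda", requires_grad=True)
out = checkpoint(run, x, use_reentrant=True)
out.mean().backward()
```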

So I guess to fix this we just need to detect whether we are within the checkpoint recomputation path and, if so, not change the amax history at all.
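
Something like the following, perhaps (all names here are hypothetical and only illustrate the idea; this is not TE's actual implementation):

```python
import contextlib

# Hypothetical module-level flag marking the checkpoint recompute phase.
_IN_CHECKPOINT_RECOMPUTE = False

@contextlib.contextmanager
def checkpoint_recompute():
    """Mark a region as re-execution of an already-recorded forward pass."""
    global _IN_CHECKPOINT_RECOMPUTE
    prev = _IN_CHECKPOINT_RECOMPUTE
    _IN_CHECKPOINT_RECOMPUTE = True
    try:
        yield
    finally:
        _IN_CHECKPOINT_RECOMPUTE = prev

def add_amax_to_global_buffer(fp8_meta, forward=True):
    if _IN_CHECKPOINT_RECOMPUTE:
        return  # the recompute must leave the amax history untouched
    ...  # normal amax bookkeeping (omitted)
```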

@ksivaman
Member

ksivaman commented Sep 6, 2023

How are you currently doing activation checkpointing? Are you using an underlying toolkit such as NeMo?

@jramapuram

@ksivaman : I'm using checkpoint_wrapper from torch.distributed.algorithms._checkpoint.checkpoint_wrapper coupled with FSDP.
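
For reference, a hedged sketch of that setup (module sizes are illustrative, and it assumes torch.distributed is already initialized):

```python
import transformer_engine.pytorch as te
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16)
# The non-reentrant recompute path is what raises the unpack_hook
# assertion quoted earlier in this thread.
wrapped = checkpoint_wrapper(layer, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
model = FSDP(wrapped)
```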

@ksivaman
Member

For this purpose, you can use the checkpoint function provided in TransformerEngine (transformer_engine.pytorch.checkpoint); see the TE documentation. It handles the additional state required for FP8 execution with activation recompute.
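
A minimal sketch of that usage (sizes are illustrative, and the exact checkpoint signature may differ between TE versions, so check the docs for your release):

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.TransformerLayer(hidden_size=1024, ffn_hidden_size=4096,
                            num_attention_heads=16).cuda()
inp = torch.randn(128, 4, 1024, device="cuda", requires_grad=True)

with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
    # te.checkpoint saves and restores the FP8 state around the recompute,
    # which torch.utils.checkpoint knows nothing about.
    out = te.checkpoint(layer, inp)

out.mean().backward()
```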

@jramapuram

@ksivaman @denera : is there an example of using transformer_engine.pytorch.checkpoint? Is it possible to add this to the FSDP example in TE?
