
Conversation

@ksivaman (Member) commented Dec 22, 2023:

This PR adds the following features (high-level):

  • A make_graphed_callables API, similar to the PyTorch API, with additional arguments for FP8 usage. Support for FP8 weight caching via the existing is_first_microbatch argument is retained (see the sketch after this list).
  • Restructured amax reduction logic with a simpler design that handles the various parallelisms with minimal bookkeeping compared to the previous approach.
  • Forward and backward amaxes are reduced within the scope of the current iteration, fixing numerous bugs w.r.t. checkpointing and removing the need to save global buffers.
  • Support for nested/multiple FP8 autocast contexts with different recipes and distributed groups.
  • Amax reductions are module-independent and happen at the autocast level. This also resolves numerous bugs and enables support for MoE/LoRA-like models.
  • Redesign of Float8Tensor transposes so that they are persistent for graph capture. Also fixes use cases with vanilla optimizers (non FP8-distopt).
  • The scaling inverses for weight tensors are no longer frozen when caching weights across microbatches.
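
A minimal usage sketch of the graph-capture API above (the FP8-specific keyword names fp8_enabled, fp8_recipe, and fp8_weight_caching follow this PR's description and may not match the merged signature exactly; an FP8-capable GPU is assumed):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Module to graph-capture, plus a static sample input used during capture.
model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,             # assumed keyword, per the PR description
    fp8_recipe=DelayedScaling(),  # assumed keyword, per the PR description
    fp8_weight_caching=True,      # assumed keyword; enables is_first_microbatch reuse
)

for step in range(4):
    batch = torch.randn(32, 1024, device="cuda")
    out = graphed_model(batch, is_first_microbatch=(step == 0))
    out.sum().backward()

As with torch.cuda.make_graphed_callables, replay-time inputs must match the shapes and dtypes of the sample input used for capture.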

@ksivaman ksivaman marked this pull request as draft December 22, 2023 14:08
@timmoon10 timmoon10 self-requested a review March 11, 2024 22:29
Comment on lines 543 to 549
def _reset_caches(self) -> None:
    """Reset cached values

    Should be called after any in-place operation.

    """
    self._transpose = None
@timmoon10 (Collaborator) commented Mar 11, 2024:

Removing the automatic cache clearing makes using the transpose cache a much more manual and dangerous process. Consider something like:

matmul(x, w.transpose(0, 1))   # computes (and may cache) the transpose of w
w -= learning_rate * w.grad    # in-place update invalidates any cached transpose
matmul(x, w.transpose(0, 1))   # without automatic clearing, this may reuse a stale transpose

Previously we could just set update_cache="lazy". Now there needs to be manual logic to figure out the microbatch step, or else stale values will be returned.

@ksivaman (Member, Author) replied:

In this example, caching is not used, so a fresh transpose will be computed each time.

@ksivaman (Member, Author) replied:

If caching is used, it is reasonable to expect the user to know when to reuse a cached value and when to force a recompute. This is consistent with the design of the is_first_microbatch argument to the forward of the module APIs.
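
For reference, the is_first_microbatch pattern referred to here looks roughly as follows (a sketch with placeholder shapes; assumes an FP8-capable GPU):

import torch
import transformer_engine.pytorch as te

layer = te.Linear(1024, 1024).cuda()
microbatches = [torch.randn(8, 1024, device="cuda") for _ in range(4)]

for i, mb in enumerate(microbatches):
    with te.fp8_autocast(enabled=True):
        # The caller signals when cached FP8 weights (and their transposes)
        # must be recomputed rather than reused for the remaining microbatches.
        out = layer(mb, is_first_microbatch=(i == 0))
    out.sum().backward()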

@ksivaman (Member, Author) replied:

Note: we use two arguments, cache and update_cache, to support this logic.
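
Roughly, the intended call pattern (the argument names cache and update_cache are taken from the comment above; the exact method name and defaults in the merged code may differ):

import torch
from transformer_engine.pytorch.float8_tensor import Float8Tensor

weight = Float8Tensor.to_float8(torch.randn(1024, 1024, device="cuda"))

# Illustrative only: the first microbatch computes the transpose and stores it
# in the persistent cache; subsequent microbatches reuse the cached value.
w_t = weight.transpose(cache=True, update_cache=True)   # compute and cache
w_t = weight.transpose(cache=True, update_cache=False)  # reuse cached transpose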

@timmoon10 (Collaborator) commented Mar 12, 2024:

I think we're overfitting to the Linear weight use-case. For example, in #707 I want to pass Float8Tensors between ops as inputs or dgrads:

class DbiasCastTranspose:
    def backward(self, dy):
        db = dy.sum(dim=0)
        dx = cast_transpose(dy)  # Creates Float8Tensor with transpose cache
        return dx, db

class FP8Linear:  # Part of FP8 attention
    def backward(self, dy):
        if not isinstance(dy, Float8Tensor):
            dy = Float8Tensor.to_float8(dy)
        dx = Float8Tensor(...)  # No transpose cache
        fp8_gemm(w.transpose()._data, dy.transpose()._data, out=dx._data)
        dw = fp8_gemm(x, dy)
        return dx, dw

FP8Linear has no idea where its input came from. Maybe it's from DbiasCastTranspose (Float8Tensor with cached transpose), FP8Linear (Float8Tensor without cached transpose), or a non-FP8 op. Our current approach with lazy transpose caching gives us a lot of flexibility and I think we should abandon it only when really necessary.

I suppose this is not precisely relevant since it doesn't involve in-place operations, but it's a more general statement about the design of Float8Tensor.

@ksivaman ksivaman marked this pull request as ready for review March 23, 2024 05:42
@ksivaman ksivaman changed the title from "[WIP ] PyTorch FP8 cuda graphs" to "[PyTorch] cuda graph support" Mar 23, 2024
@ksivaman (Member, Author) commented:

/te-ci

@ksivaman (Member, Author) commented:

/te-ci pytorch

@ksivaman (Member, Author) commented:

/te-ci pytorch

@timmoon10 (Collaborator) commented Mar 27, 2024:

#735 has some improvements to the Float8Tensor transpose function, which should reduce the divergence with #707. If there are no issues, we should merge that branch into this PR.

@ksivaman ksivaman marked this pull request as draft March 27, 2024 23:59
@ptrendx ptrendx added the 1.6.0 label Apr 2, 2024
@ksivaman (Member, Author) commented Apr 9, 2024:

/te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
@ksivaman (Member, Author) commented:

/te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman (Member, Author) commented:

/te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman (Member, Author) commented:

/te-ci pytorch

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman (Member, Author) commented:

/te-ci pytorch

@ksivaman ksivaman merged commit 73f8d90 into NVIDIA:main Apr 12, 2024
@ksivaman ksivaman mentioned this pull request Apr 30, 2024
pggPL pushed a commit to pggPL/TransformerEngine that referenced this pull request May 23, 2024
* FP8 cuda graphs

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>

* Fix numerics

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* exclude torch compile from numerics tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* More numerics fixes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix tests

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* rm fusion from unfused path

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>