[PyTorch] cuda graph support #575
Conversation
```python
def _reset_caches(self) -> None:
    """Reset cached values

    Should be called after any in-place operation.
    """
    self._transpose = None
```
Removing the automatic cache clearing makes using the transpose cache a much more manual and dangerous process. Consider something like:

```python
matmul(x, w.transpose(0, 1))
w -= learning_rate * w.grad
matmul(x, w.transpose(0, 1))
```

Previously we could just set `update_cache="lazy"`. Now there needs to be manual logic to figure out the microbatch step, or else it will return stale values.
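To make the hazard concrete, here is a minimal, self-contained sketch; the toy class is a stand-in for a Float8Tensor-style parameter with a transpose cache, and the class name and fake optimizer step are illustrative assumptions rather than Transformer Engine code:

```python
import torch

class CachedTransposeParam:
    """Toy stand-in for a parameter whose transpose() result is cached.

    Illustrative only: the real Float8Tensor stores FP8 data and scale
    factors, but the staleness hazard is the same.
    """

    def __init__(self, data: torch.Tensor):
        self.data = data
        self._transpose = None          # the cache cleared by _reset_caches()

    def transpose(self) -> torch.Tensor:
        if self._transpose is None:     # lazily compute and cache
            self._transpose = self.data.t().contiguous()
        return self._transpose

    def _reset_caches(self) -> None:
        self._transpose = None

x = torch.randn(4, 8)
w = CachedTransposeParam(torch.randn(16, 8))

y1 = x @ w.transpose()                      # first use populates the cache
w.data -= 0.1 * torch.randn_like(w.data)    # in-place "optimizer" update

# Without automatic clearing, the next transpose() would silently return the
# pre-update weights; the caller has to remember to invalidate manually.
w._reset_caches()
y2 = x @ w.transpose()                      # recomputed from the updated data
```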
In this example, caching is not used, so a fresh transpose will be computed each time.
If caching is used, it is reasonable to expect the user to know when to reuse a cached value and when to force a recompute. This is consistent with our design of the `is_first_microbatch` argument to the forward pass of the module APIs.
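For reference, a minimal sketch of that module-level contract, assuming a Transformer Engine module such as `te.Linear` whose forward accepts `is_first_microbatch`; the shapes and recipe are placeholders, and running it requires FP8-capable hardware:

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

layer = te.Linear(1024, 1024).cuda()
micro_batches = [torch.randn(32, 1024, device="cuda") for _ in range(4)]

for i, mb in enumerate(micro_batches):
    with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        # The caller tells the module when cached FP8 weights (and their
        # transposes) may be reused versus recomputed.
        out = layer(mb, is_first_microbatch=(i == 0))
    out.sum().backward()
```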
Note: we use two arguments, `cache` and `update_cache`, to support this logic.
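A compact toy model of how those two arguments could compose; the argument names come from this thread, but the `"force"` value and the exact semantics below are assumptions for illustration, not Transformer Engine's implementation:

```python
import torch

def transpose_with_cache(t: torch.Tensor, state: dict, *,
                         cache: bool = False, update_cache: str = "lazy") -> torch.Tensor:
    """Toy model of the `cache` / `update_cache` contract (semantics assumed)."""
    if not cache:
        return t.t().contiguous()                   # bypass the cache entirely
    if update_cache == "force" or "transpose" not in state:
        state["transpose"] = t.t().contiguous()     # "lazy": fill only when missing
    return state["transpose"]

w, state = torch.randn(16, 8), {}
a = transpose_with_cache(w, state, cache=True, update_cache="lazy")   # computes and stores
b = transpose_with_cache(w, state, cache=True, update_cache="lazy")   # reuses the stored value
c = transpose_with_cache(w, state, cache=True, update_cache="force")  # refreshes, e.g. after a weight update
```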
I think we're overfitting to the Linear weight use-case. For example, in #707 I want to pass Float8Tensors between ops as inputs or dgrads:
```python
class DbiasCastTranspose:
    def backward(self, dy):
        db = dy.sum(dim=0)
        dx = cast_transpose(dy)  # Creates Float8Tensor with transpose cache
        return dx, db

class FP8Linear:  # Part of FP8 attention
    def backward(self, dy):
        if not isinstance(dy, Float8Tensor):
            dy = Float8Tensor.to_float8(dy)
        dx = Float8Tensor(...)  # No transpose cache
        fp8_gemm(w.transpose()._data, dy.transpose()._data, out=dx._data)
        dw = fp8_gemm(x, dy)
        return dx, dw
```
FP8Linear has no idea where its input came from. Maybe it's from DbiasCastTranspose (Float8Tensor with cached transpose), FP8Linear (Float8Tensor without cached transpose), or a non-FP8 op. Our current approach with lazy transpose caching gives us a lot of flexibility and I think we should abandon it only when really necessary.
I suppose this is not precisely relevant since it doesn't involve in-place operations, but it's a more general statement about the design of Float8Tensor.
Commits:
* FP8 cuda graphs
* Fix numerics
* exclude torch compile from numerics tests
* More numerics fixes
* Fix tests
* Fix CI
* rm fusion from unfused path

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Charlene Yang <charleney@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
This PR adds the following features (high-level):

* A `make_graphed_callables` API similar to the PyTorch API, with some additional arguments for FP8 usage (see the usage sketch below). Support for FP8 weight caching via the existing `is_first_microbatch` argument is also retained.
* Changes to `Float8Tensor` that make the transposes persistent for graph capture. Also fixes use cases for the vanilla optimizers (non fp8-distopt).
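As a rough orientation, here is a minimal usage sketch of the new API. The keyword names `fp8_enabled`, `fp8_recipe`, and `fp8_weight_caching`, and passing `is_first_microbatch` to the graphed callable, are assumptions based on the description above; check the merged signature before relying on them.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

model = te.Linear(1024, 1024).cuda()
sample_input = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Argument names below are assumptions based on the PR description.
graphed_model = te.make_graphed_callables(
    model,
    (sample_input,),
    fp8_enabled=True,
    fp8_recipe=DelayedScaling(),
    fp8_weight_caching=True,  # retain is_first_microbatch-style weight caching
)

for i in range(4):
    mb = torch.randn(32, 1024, device="cuda")
    out = graphed_model(mb, is_first_microbatch=(i == 0))
    out.sum().backward()
```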
make_graphed_callablesAPI similar to the PyTorch API with some additional arguments for FP8 usage. Support for fp8 weight caching via existingis_first_microbatchargument is also retained.Float8Tensorthat makes the transposes persistent for graph capture. Also fixes use cases for the vanilla optimizers (non fp8-distopt).