[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW #1844
phu0ngng merged 6 commits into NVIDIA:main
Conversation
jberchtold-nvidia left a comment:
Overall this looks good; some small comments. I like the TensorUsage concept.
Should we keep the check for whether we're training or not? It would be something like:
casted_ln_out.get_tensor(TensorUsage.LHS_TRANS) if quantizer_set.x.is_2x2x() else None
because on Hopper this layout wouldn't exist when doing 1x quantization for inference, right?
No, we don't need that check here.
Whether we're training or not should be handled inside the get_tensor() method; currently, we don't support that yet.
Ah, I see; that will be easier if it's automatic. So the idea is: based on whether the tensor is x or the kernel, we know which usage (LHS/RHS) will be needed in the forward and backward passes, so we can automatically drop the unneeded layout when doing inference (not supported currently, but we have all the required info to support it in the future)?
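The automatic behavior discussed above could be sketched roughly as follows. All names here (CastedTensor, the enum members, the constructor arguments) are hypothetical stand-ins for the PR's actual TensorUsage enum and get_tensor() implementation, which may differ:

```python
from enum import Enum, auto

class TensorUsage(Enum):
    """Hypothetical mirror of the PR's usage enum."""
    LHS = auto()        # left GEMM operand in the forward pass
    LHS_TRANS = auto()  # transposed left operand, needed only in backward
    RHS = auto()
    RHS_TRANS = auto()

class CastedTensor:
    """Toy container holding a quantized tensor and its transposed copy."""
    def __init__(self, rowwise, colwise=None, is_training=True):
        self._rowwise = rowwise
        self._colwise = colwise          # None when the cast was 1x (no transpose kept)
        self._is_training = is_training

    def get_tensor(self, usage):
        # The *_TRANS layouts are only consumed by the backward pass, so
        # during inference the getter can drop them automatically instead
        # of every call site repeating the is_2x2x()/training check.
        if usage in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS):
            return self._colwise if self._is_training else None
        return self._rowwise
```

With a getter like this, a call site can always request TensorUsage.LHS_TRANS and simply receive None when the layout was never materialized.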
Adding @huanghua1994 to review the changes in
/te-ci JAX L0
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
* TensorUsage + FP8 GEMM with all layouts handling on BW. Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Description
TensorUsage + FP8 GEMM with all layouts handling on BW.
Verified that JAX does not enforce NT layouts, i.e. no additional transposes are introduced.
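As a rough illustration of why pre-computed transposed layouts avoid extra transposes in the backward pass (plain NumPy shapes only, not the actual TE FP8 kernels):

```python
import numpy as np

# Toy shapes for y = x @ w. In the backward pass,
#   dgrad = dy @ w.T   and   wgrad = x.T @ dy.
# With FP8, the transposed operands (the *_TRANS usages) are produced once
# at cast time, so the backward GEMMs consume them directly and no extra
# transpose needs to be inserted by JAX.
x = np.ones((4, 8))
w = np.ones((8, 16))
dy = np.ones((4, 16))

x_t = x.T   # stands in for the casted LHS_TRANS copy
w_t = w.T   # stands in for the casted RHS_TRANS copy

dgrad = dy @ w_t   # shape (4, 8), matches x
wgrad = x_t @ dy   # shape (8, 16), matches w
```

The same shape bookkeeping applies per GEMM layout; the point of the PR is that every backward layout is already available in the casted tensors.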
Type of change
Checklist: