Fix shape of new quantized tensor in make_like#1515

Merged: ksivaman merged 4 commits into NVIDIA:main from ksivaman:fix_quantized_tensor_shape on Feb 28, 2025
Conversation

@ksivaman
Member

Description

In functions such as split and chunk, where the output shape differs from the input shape, the returned tensor holds the correct data but reports the wrong shape. This leads to bugs, e.g. in FSDP2 or checkpoint loading. A small repro:

import torch
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
import transformer_engine_torch as tex

t = torch.randn(4, 4)
quantizer = Float8Quantizer(
    scale=torch.full([1], 1.0, dtype=torch.float32, device="cuda"),
    amax=torch.empty([1], dtype=torch.float32, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)
x = quantizer(t.cuda())
a, b = x.chunk(2, dim=0)
print(x.shape, a.shape, b.shape)
print(x, a, b)
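
To make the failure mode concrete without needing a GPU, here is a minimal pure-Python sketch. It assumes the bug comes from make_like copying the shape of the source tensor even when new data is passed in. ToyQuantized and its methods are hypothetical stand-ins for illustration, not Transformer Engine code.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ToyQuantized:
    """Hypothetical stand-in for a quantized tensor: raw rows plus a cached shape."""

    data: List[List[int]]
    shape: Tuple[int, int]

    @classmethod
    def make_like(cls, tensor, data=None):
        # Assumed buggy behaviour: the shape is always copied from the source
        # tensor, even when new data of a different size is supplied.
        return cls(data=tensor.data if data is None else data, shape=tensor.shape)

    def chunk(self, chunks):
        # Split the rows evenly and wrap each piece with make_like.
        rows = len(self.data) // chunks
        pieces = [self.data[i * rows:(i + 1) * rows] for i in range(chunks)]
        return [ToyQuantized.make_like(self, data=piece) for piece in pieces]


x = ToyQuantized(data=[[r * 4 + c for c in range(4)] for r in range(4)], shape=(4, 4))
a, b = x.chunk(2)
# The data is split correctly, but the cached shape still describes the input.
print(a.shape, len(a.data))
```

In this toy model each chunk holds two rows while still reporting the four-row shape of its parent, which mirrors the mismatch described above.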

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • If data

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman added the bug (Something isn't working) label on Feb 27, 2025
@ksivaman requested a review from timmoon10 on February 27, 2025 08:52
@ksivaman
Member Author

/te-ci pytorch

Member

@timmoon10 timmoon10 left a comment


This is fine as a hack.

Rambling digression: I don't like the data kwarg. It's not generic, and there's no reason to expect this logic to stay valid in the future (FP6 will probably require a randomly sized blob of bytes). Also, if you're providing data, you should probably also provide scale_inv (and maybe column-wise data as well). I think it's fine as an optional kwarg in Float8Tensor and MXFP8Tensor, but we should stop exposing it in QuantizedTensor.

@ptrendx
Member

ptrendx commented Feb 27, 2025

I agree with Tim. Why not just add the shape argument to the make_like calls in float8_tensor.py:

diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 49bf4facf..0063b286a 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -402,7 +402,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
+            return [Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) for split_tensor in func_out]
         if func == aten.new_zeros.default:
             tensor = args[0]
             data = tensor._data
@@ -412,7 +412,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return Float8Tensor.make_like(tensor, data=func_out)
+            return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
         if func == torch.ops.aten.as_strided.default:
             tensor = args[0]
             data = tensor._data
@@ -422,7 +422,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
                 [data] + list(args[1:]),
                 kwargs,
             )
-            return Float8Tensor.make_like(tensor, data=func_out)
+            return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape)
         if func == torch.ops.aten.detach.default:
             return cls.detach(args[0])
         if func == torch.ops.aten.clone.default:

I confirmed that it also resolves the given repro.
Also, could you add that repro to the sanity tests?
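
One way the repro could be turned into a sanity check is sketched below. It reuses only the calls that appear in the repro; the function name check_chunk_preserves_shape is hypothetical, and running it assumes a CUDA device with Transformer Engine installed. This is not the test that was actually added in the PR.

```python
def check_chunk_preserves_shape():
    """Sanity check sketched from the repro; needs CUDA and Transformer Engine."""
    # Imports are local so that merely defining this function has no dependencies.
    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

    quantizer = Float8Quantizer(
        scale=torch.full([1], 1.0, dtype=torch.float32, device="cuda"),
        amax=torch.empty([1], dtype=torch.float32, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    ref = torch.randn(4, 4, device="cuda")
    x = quantizer(ref)

    # Each quantized chunk should report the same shape as the matching
    # chunk of the high-precision reference tensor.
    for q_chunk, ref_chunk in zip(x.chunk(2, dim=0), ref.chunk(2, dim=0)):
        assert q_chunk.shape == ref_chunk.shape
```

Comparing against the chunks of the unquantized tensor avoids hard-coding the expected shapes, so the same check can be reused for other split sizes.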

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman
Member Author

/te-ci pytorch

@ksivaman
Member Author

Added the shapes to the make_like calls in float8_tensor, but I'm also keeping the original change to quantized_tensor, since that is logically the correct shape to use when the tensor's and data's shapes differ. We could remove the data kwarg from quantized_tensor completely, but that would require implementing make_like in Float8Tensor and MXFP8Tensor, so I'm leaving that for a different PR.
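
The combined behaviour described here can be read as a simple precedence rule. The sketch below is an interpretation, not the merged code: resolve_shape is a hypothetical helper, and the ordering (explicit shape first, then the shape of the supplied data, then the shape of the source tensor) is an assumption based on this comment.

```python
from typing import Optional, Tuple

Shape = Tuple[int, ...]


def resolve_shape(
    tensor_shape: Shape,
    data_shape: Optional[Shape] = None,
    shape: Optional[Shape] = None,
) -> Shape:
    """Pick the shape a make_like-style constructor should report (assumed rule)."""
    if shape is not None:
        # The call site passed an explicit shape, as in the float8_tensor change.
        return shape
    if data_shape is not None:
        # New data was supplied without a shape: trust the data's own shape.
        return data_shape
    # Nothing was overridden: the copy has the same shape as the source tensor.
    return tensor_shape


print(resolve_shape((4, 4)))                     # no override
print(resolve_shape((4, 4), data_shape=(2, 4)))  # fallback to the data's shape
```

With this ordering, a call site that forgets to pass the shape still gets a result consistent with its data, which is the case the quantized_tensor change is meant to cover.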

@ksivaman ksivaman merged commit 9588109 into NVIDIA:main Feb 28, 2025
ptrendx pushed a commit that referenced this pull request Feb 28, 2025
* Fix quantized tensor shape

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* add shape to make_like; add test for chunk

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix typo from suggestion

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

Labels

bug Something isn't working
