Fix shape of new quantized tensor in make_like#1515
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
|
/te-ci pytorch |
There was a problem hiding this comment.
This is fine as a hack.
Rambling digression: I don't like the data kwarg. It's not generic and going forward there's no reason to expect this logic to be valid in the future (FP6 will probably require a randomly sized blob of bytes). Also, if you're providing data you should also probably provide scale_inv (and maybe column-wise data as well). I think it's fine if it's an optional kwarg in Float8Tensor and MXFP8Tensor, but we should stop exposing it in QuantizedTensor.
|
I agree with Tim, why not just add the shape argument to the usage of the make_like in float8tensor: diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 49bf4facf..0063b286a 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -402,7 +402,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
- return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out]
+ return [Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) for split_tensor in func_out]
if func == aten.new_zeros.default:
tensor = args[0]
data = tensor._data
@@ -412,7 +412,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
- return Float8Tensor.make_like(tensor, data=func_out)
+ return Float8Tensor.make_like(tensor, data=func_out, shape=data.func_out.shape)
if func == torch.ops.aten.as_strided.default:
tensor = args[0]
data = tensor._data
@@ -422,7 +422,7 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
[data] + list(args[1:]),
kwargs,
)
- return Float8Tensor.make_like(tensor, data=func_out)
+ return Float8Tensor.make_like(tensor, data=func_out, shape=data.func_out.shape)
if func == torch.ops.aten.detach.default:
return cls.detach(args[0])
if func == torch.ops.aten.clone.default:I confirmed that it also resolves the given repro. |
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
|
/te-ci pytorch |
|
Added the shapes to |
* Fix quantized tensor shape Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * add shape to make_like; add test for chunk Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Fix typo from suggestion Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Description
In functions such as
split,chunketc. in which the shape of the input and output differs, the returned tensor is correct but with the incorrect shape which leads to bugs, e.g. in FSDP2 or checkpoint loading. A small repro:Type of change
Changes
Checklist: