From 5fd2bc7763cc996d4c1c59a17fe8449895266811 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 27 Feb 2025 08:23:17 +0000 Subject: [PATCH 1/3] Fix quantized tensor shape Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/tensor/quantized_tensor.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index ef21412ca7..b540cd91a1 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -433,7 +433,8 @@ def make_like( data. """ - shape = shape if shape is not None else tensor.shape + if shape is None: + shape = data.shape if data is not None else tensor.shape dtype = dtype if dtype is not None else tensor.dtype kwargs = tensor.get_metadata() if data is not None: From 7034ede4991c32c109b64f951c0fddea4d031e52 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 28 Feb 2025 07:11:08 +0000 Subject: [PATCH 2/3] add shape to make_like; add test for chunk Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8tensor.py | 30 +++++++++++++++++++ .../pytorch/tensor/float8_tensor.py | 9 ++++-- 2 files changed, 36 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/test_float8tensor.py b/tests/pytorch/test_float8tensor.py index 56b01f1dbc..9d01527ac5 100644 --- a/tests/pytorch/test_float8tensor.py +++ b/tests/pytorch/test_float8tensor.py @@ -161,6 +161,36 @@ def test_basic_ops( with pytest.raises(AssertionError): torch.testing.assert_close(x_fp8 + y_fp8, x_ref - y_fp8, **tols) + @pytest.mark.parametrize("dims", [2, [4, 4], [8, 5, 3, 3]]) + def test_chunk_op( + self, + dims: DimsType, + fp8_dtype: tex.DType = tex.DType.kFloat8E4M3, + scale: float = 3.5, + dtype: torch.dtype = torch.float32, + ) -> None: + """Test for ops for which shape of inputs and outputs differ.""" + + # Initialize random data + dims = _to_list(dims) + x_ref = torch.randn(dims, dtype=dtype, device="cpu") + x_fp8 = to_float8(x_ref, fp8_dtype=fp8_dtype, scale=1.0) + + # Get chunks. + chunk1, chunk2 = x_fp8.chunk(2, dim=0) + + # Test chunks. + torch.testing.assert_close(x_fp8[0 : dims[0] // 2,], chunk1, atol=0, rtol=0) + torch.testing.assert_close(x_fp8[dims[0] // 2 :,], chunk2, atol=0, rtol=0) + + # Check shapes. + assert ( + chunk1.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:] + ), "Wrong shape for chunk1" + assert ( + chunk2.shape == torch.Size([x_fp8.shape[0] // 2]) + x_fp8.shape[1:] + ), "Wrong shape for chunk2" + def test_inplace_ops( self, dims: DimsType = 23, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 49bf4facfa..7ceaf731c3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -402,7 +402,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return [Float8Tensor.make_like(tensor, data=split_tensor) for split_tensor in func_out] + return [ + Float8Tensor.make_like(tensor, data=split_tensor, shape=split_tensor.shape) + for split_tensor in func_out + ] if func == aten.new_zeros.default: tensor = args[0] data = tensor._data @@ -412,7 +415,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out) + return Float8Tensor.make_like(tensor, data=func_out, shape=data.func_out.shape) if func == torch.ops.aten.as_strided.default: tensor = args[0] data = tensor._data @@ -422,7 +425,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out) + return Float8Tensor.make_like(tensor, data=func_out, shape=data.func_out.shape) if func == torch.ops.aten.detach.default: return cls.detach(args[0]) if func == torch.ops.aten.clone.default: From 605d33e483a8fd3b571f73542e45e58103263165 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 28 Feb 2025 07:21:10 +0000 Subject: [PATCH 3/3] Fix typo from suggestion Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/tensor/float8_tensor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 7ceaf731c3..c9e65bd93a 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -415,7 +415,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=data.func_out.shape) + return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) if func == torch.ops.aten.as_strided.default: tensor = args[0] data = tensor._data @@ -425,7 +425,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): [data] + list(args[1:]), kwargs, ) - return Float8Tensor.make_like(tensor, data=func_out, shape=data.func_out.shape) + return Float8Tensor.make_like(tensor, data=func_out, shape=func_out.shape) if func == torch.ops.aten.detach.default: return cls.detach(args[0]) if func == torch.ops.aten.clone.default: