From edcb9508dcdf567db2d48dac29b6be5628b24c2e Mon Sep 17 00:00:00 2001
From: Sam Larsen
Date: Fri, 1 Aug 2025 12:54:53 -0700
Subject: [PATCH] Fix rand_like decomposition to preserve strides (#159294)

Summary:
Like https://github.com/pytorch/pytorch/pull/158898, the rand_like variants
do not preserve strides. This PR follows the pattern established in
https://github.com/pytorch/pytorch/pull/158898.

Test Plan:
New unit test (fails before this PR; passes after).

Differential Revision: [D79472604](https://our.internmc.facebook.com/intern/diff/D79472604)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159294
Approved by: https://github.com/eellison
---
 test/inductor/test_torchinductor.py           |  22 ++-
 ...st_torchinductor_codegen_dynamic_shapes.py |   1 +
 torch/_inductor/decomposition.py              | 154 +++++++++---
 3 files changed, 99 insertions(+), 78 deletions(-)

diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 25aeed19a84c..34942411fff0 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8918,7 +8918,7 @@ def forward(self, v1: torch.Tensor):
 
         model = Model()
         x = torch.rand(10, 3, 0)
-        self.common(model, (x,))
+        self.common(model, (x,), exact_stride=True)
 
     def test_randint(self):
         @torch.compile(fullgraph=True)
@@ -8973,9 +8973,21 @@ def bin(index, max_size):
     @config.patch(fallback_random=True)
     def test_like_rands(self):
         def fn(x):
-            return torch.rand_like(x), torch.randn_like(x)
+            return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11)
 
-        self.common(fn, [torch.zeros([20, 20])])
+        self.common(fn, [torch.zeros([20, 20])], exact_stride=True)
+
+    @config.patch(fallback_random=True)
+    @xfail_if_mps  # 100% are not close
+    def test_like_rands_sliced(self):
+        def fn(x):
+            return (
+                torch.rand_like(x),
+                torch.randn_like(x),
+                torch.randint_like(x, 1, 11),
+            )
+
+        self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True)
 
     @config.patch(check_stack_no_cycles_TESTING_ONLY=True)
     def test_check_stack_no_cycles(self):
@@ -9008,6 +9020,8 @@ def fn(x):
         a0 = fn(x).clone()
         a1 = fn(x).clone()
         self.assertFalse(torch.allclose(a0, a1))
+        self.assertEqual(a0.shape, a1.shape)
+        self.assertEqual(a0.stride(), a1.stride())
 
     @requires_gpu()
     @skip_if_triton_cpu("Flaky on Triton CPU")
@@ -9025,6 +9039,8 @@ def fn(x, device):
         a1 = test_like_rands_on_different_device(GPU_TYPE, "cpu")
         self.assertTrue(a0.device.type == GPU_TYPE)
         self.assertTrue(a1.device.type == "cpu")
+        self.assertEqual(a0.shape, a1.shape)
+        self.assertEqual(a0.stride(), a1.stride())
 
     def test_max_pool2d_with_indices_backward(self):
         def fn(a, b, c):
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index 29d74152bf4e..8529c0379ef5 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -173,6 +173,7 @@ def run(*ex, **kwargs):
     "test_bucketize_int_dynamic_shapes": TestFailure(("cpu",)),
     "test_searchsorted_dynamic_shapes": TestFailure(("cpu",)),
     "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 2dd8a47feb4a..415fd0a6098b 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -535,49 +535,17 @@ def view_copy_dtype(
     return self.to(dtype).clone()
 
 
-def get_like_layout(
-    tensor: torch.Tensor,
-    memory_format: Optional[torch.memory_format] = None,
-) -> torch.memory_format:
-    # TODO: _to_copy tensor to stride permutation
-    if memory_format is torch.preserve_format or memory_format is None:
-        return utils.suggest_memory_format(tensor)
-    else:
-        return memory_format
-
-
-@register_decomposition(aten.rand_like)
-def rand_like(
+def _get_shape_permutation_like(
     self: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return torch.rand(
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+) -> tuple[utils.ShapeType, utils.StrideType]:
+    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+    shape = [self.shape[l] for l in physical_layout]
 
+    permutation = [0] * len(shape)
+    for p, l in enumerate(physical_layout):
+        permutation[l] = p
 
-@register_decomposition(aten.randn_like)
-def randn_like(
-    self: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return torch.randn(
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    return (shape, permutation)
 
 
 @register_decomposition(aten.full_like)
@@ -592,55 +560,91 @@ def full_like(
     requires_grad: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> torch.Tensor:
-    return torch.full(
-        [*self.size()],
-        fill_value,
-        dtype=dtype or self.dtype,
-        layout=layout or self.layout,
-        device=device or self.device,
-        requires_grad=requires_grad,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    dtype = self.dtype if dtype is None else dtype
+    layout = self.layout if layout is None else layout
+    device = self.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        result = torch.full(
+            self.shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        return result.to(memory_format=memory_format)
+
+    else:
+        assert layout == torch.strided
+        shape, permutation = _get_shape_permutation_like(self)
+        result = torch.full(
+            shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        if permutation == list(range(len(permutation))):
+            return result
+        return result.permute(permutation).clone()
 
 
-@register_decomposition(aten.randint_like.default)
-def randint_like(
+def _rand_like(
+    rand_fn: Callable[..., torch.Tensor],
     self: torch.Tensor,
-    high: int,
     *,
     dtype: Optional[torch.dtype] = None,
     device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
+    memory_format: torch.memory_format = torch.preserve_format,
     **kwargs: Any,
 ) -> torch.Tensor:
-    return aten.randint.low(
-        0,
-        high,
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
+    dtype = self.dtype if dtype is None else dtype
+    device = self.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        return rand_fn(
+            self.shape,
+            dtype=dtype,
+            device=device,
+            **kwargs,
+        ).to(memory_format=memory_format)
+
+    shape, permutation = _get_shape_permutation_like(self)
+    result = rand_fn(
+        shape,
+        dtype=dtype,
+        device=device,
         **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    )
+    if permutation == list(range(len(permutation))):
+        return result
+    return result.permute(permutation).clone()
+
+
+@register_decomposition(aten.rand_like)
+def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(torch.rand, self, **kwargs)
+
+
+@register_decomposition(aten.randn_like)
+def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(torch.randn, self, **kwargs)
+
+
+@register_decomposition(aten.randint_like.default)
+def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
 
 
 @register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
-    self: torch.Tensor,
-    low: int,
-    high: int,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
+    self: torch.Tensor, low: int, high: int, **kwargs: Any
 ) -> torch.Tensor:
-    return aten.randint.low(
-        low,
-        high,
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)
 
 
 @register_decomposition(aten.randint.default)
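
The core of the fix is a pattern worth seeing in isolation: compute the input's
physical dimension order, draw the random values in that order, then permute the
result back to the logical order and clone it, so the output's stride ordering
matches the input even for sliced or transposed tensors. Below is a minimal
standalone sketch of that pattern; the helper names are illustrative (not the
PyTorch internals), and the stride sort is a simplified stand-in for
`utils.compute_elementwise_output_logical_to_physical_perm`:

```python
import torch


def _shape_permutation(t: torch.Tensor) -> tuple[list[int], list[int]]:
    # Simplified physical-order computation: sort dimensions by decreasing
    # stride, so index 0 is the outermost dimension in memory. The real
    # helper also handles stride ties, zero strides, and overlapping layouts.
    physical = sorted(range(t.ndim), key=lambda d: t.stride(d), reverse=True)
    shape = [t.shape[d] for d in physical]
    permutation = [0] * t.ndim
    for p, d in enumerate(physical):
        permutation[d] = p
    return shape, permutation


def rand_like_preserving_strides(t: torch.Tensor) -> torch.Tensor:
    # Generate in physical order, so the values land contiguously in the
    # same dimension order as the input's memory layout.
    shape, permutation = _shape_permutation(t)
    result = torch.rand(shape, dtype=t.dtype, device=t.device)
    if permutation == list(range(len(permutation))):
        return result  # input was already in default (row-major) order
    # permute() is a metadata-only view back to the logical shape; clone()
    # with its default preserve_format keeps the permuted stride ordering.
    return result.permute(permutation).clone()


x = torch.zeros(3, 4)[:, ::2].permute(1, 0)  # shape (2, 3), strides (2, 4)
y = rand_like_preserving_strides(x)
assert y.shape == x.shape
assert y.stride() == (1, 2)  # dim 0 innermost, matching x's stride ordering
# A naive torch.rand(x.shape) would instead give row-major strides (3, 1).
```

This permute-then-clone step is the same one the new `_rand_like` decomposition
applies after calling `rand_fn` on the physical shape.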