22 changes: 19 additions & 3 deletions test/inductor/test_torchinductor.py
@@ -8918,7 +8918,7 @@ def forward(self, v1: torch.Tensor):
         model = Model()
         x = torch.rand(10, 3, 0)
 
-        self.common(model, (x,))
+        self.common(model, (x,), exact_stride=True)
 
     def test_randint(self):
         @torch.compile(fullgraph=True)
@@ -8973,9 +8973,21 @@ def bin(index, max_size):
     @config.patch(fallback_random=True)
     def test_like_rands(self):
         def fn(x):
-            return torch.rand_like(x), torch.randn_like(x)
+            return torch.rand_like(x), torch.randn_like(x), torch.randint_like(x, 1, 11)
 
-        self.common(fn, [torch.zeros([20, 20])])
+        self.common(fn, [torch.zeros([20, 20])], exact_stride=True)
+
+    @config.patch(fallback_random=True)
+    @xfail_if_mps  # 100% are not close
+    def test_like_rands_sliced(self):
+        def fn(x):
+            return (
+                torch.randn_like(x),
+                torch.randn_like(x),
+                torch.randint_like(x, 1, 11),
+            )
+
+        self.common(fn, (torch.zeros([3, 4])[:, ::2].permute(1, 0),), exact_stride=True)
 
     @config.patch(check_stack_no_cycles_TESTING_ONLY=True)
     def test_check_stack_no_cycles(self):
@@ -9008,6 +9020,8 @@ def fn(x):
         a0 = fn(x).clone()
         a1 = fn(x).clone()
         self.assertFalse(torch.allclose(a0, a1))
+        self.assertEqual(a0.shape, a1.shape)
+        self.assertEqual(a0.stride(), a1.stride())
 
     @requires_gpu()
     @skip_if_triton_cpu("Flaky on Triton CPU")
@@ -9025,6 +9039,8 @@ def fn(x, device):
         a1 = test_like_rands_on_different_device(GPU_TYPE, "cpu")
         self.assertTrue(a0.device.type == GPU_TYPE)
         self.assertTrue(a1.device.type == "cpu")
+        self.assertEqual(a0.shape, a1.shape)
+        self.assertEqual(a0.stride(), a1.stride())
 
     def test_max_pool2d_with_indices_backward(self):
         def fn(a, b, c):
1 change: 1 addition & 0 deletions test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -173,6 +173,7 @@ def run(*ex, **kwargs):
     "test_bucketize_int_dynamic_shapes": TestFailure(("cpu",)),
     "test_searchsorted_dynamic_shapes": TestFailure(("cpu",)),
     "test_like_rands_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_like_rands_sliced_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace3_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
     "test_linspace4_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
154 changes: 79 additions & 75 deletions torch/_inductor/decomposition.py
@@ -535,49 +535,17 @@ def view_copy_dtype(
     return self.to(dtype).clone()
 
 
-def get_like_layout(
-    tensor: torch.Tensor,
-    memory_format: Optional[torch.memory_format] = None,
-) -> torch.memory_format:
-    # TODO: _to_copy tensor to stride permutation
-    if memory_format is torch.preserve_format or memory_format is None:
-        return utils.suggest_memory_format(tensor)
-    else:
-        return memory_format
-
-
-@register_decomposition(aten.rand_like)
-def rand_like(
+def _get_shape_permutation_like(
     self: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return torch.rand(
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
-
-
-@register_decomposition(aten.randn_like)
-def randn_like(
-    self: torch.Tensor,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return torch.randn(
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+) -> tuple[utils.ShapeType, utils.StrideType]:
+    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(self)
+    shape = [self.shape[l] for l in physical_layout]
+
+    permutation = [0] * len(shape)
+    for p, l in enumerate(physical_layout):
+        permutation[l] = p
+
+    return (shape, permutation)
 
 
 @register_decomposition(aten.full_like)
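An aside on the new helper: _get_shape_permutation_like returns the input's sizes reordered into physical (stride-descending) order plus the inverse permutation back to logical order, so a *_like op can allocate a dense tensor in that order and then permute it to match the input's layout. Below is a rough standalone sketch of the idea (my illustration, not code from this PR); it approximates the helper with a plain sort by stride, whereas the real code uses utils.compute_elementwise_output_logical_to_physical_perm.

import torch

def shape_and_permutation(t: torch.Tensor):
    # Physical order: dims sorted from largest to smallest stride.
    physical = sorted(range(t.ndim), key=lambda d: t.stride(d), reverse=True)
    shape = [t.size(d) for d in physical]
    # Inverse permutation: logical dim -> its position in the physical order.
    permutation = [0] * t.ndim
    for p, l in enumerate(physical):
        permutation[l] = p
    return shape, permutation

x = torch.zeros(3, 4)[:, ::2].permute(1, 0)  # shape (2, 3), strides (2, 4)
shape, perm = shape_and_permutation(x)       # shape [3, 2], perm [1, 0]
y = torch.randn(shape).permute(perm)         # shape (2, 3), strides (1, 2)
# y has x's shape and the same stride order: dim 1 is the outer (larger-stride) dim in both.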
@@ -592,55 +560,91 @@ def full_like(
     requires_grad: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> torch.Tensor:
-    return torch.full(
-        [*self.size()],
-        fill_value,
-        dtype=dtype or self.dtype,
-        layout=layout or self.layout,
-        device=device or self.device,
-        requires_grad=requires_grad,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    dtype = self.dtype if dtype is None else dtype
+    layout = self.layout if layout is None else layout
+    device = self.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        result = torch.full(
+            self.shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        return result.to(memory_format=memory_format)
+
+    else:
+        assert layout == torch.strided
+        shape, permutation = _get_shape_permutation_like(self)
+        result = torch.full(
+            shape,
+            fill_value,
+            dtype=dtype,
+            layout=layout,
+            device=device,
+            pin_memory=pin_memory,
+            requires_grad=requires_grad,
+        )
+        if permutation == list(range(len(permutation))):
+            return result
+        return result.permute(permutation).clone()
 
 
-@register_decomposition(aten.randint_like.default)
-def randint_like(
-    self: torch.Tensor,
-    high: int,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
-) -> torch.Tensor:
-    return aten.randint.low(
-        0,
-        high,
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+def _rand_like(
+    rand_fn: Callable[..., torch.Tensor],
+    self: torch.Tensor,
+    *,
+    dtype: Optional[torch.dtype] = None,
+    device: Optional[torch.device] = None,
+    memory_format: torch.memory_format = torch.preserve_format,
+    **kwargs: Any,
+) -> torch.Tensor:
+    dtype = self.dtype if dtype is None else dtype
+    device = self.device if device is None else device
+
+    if memory_format != torch.preserve_format:
+        return rand_fn(
+            self.shape,
+            dtype=dtype,
+            device=device,
+            **kwargs,
+        ).to(memory_format=memory_format)
+
+    shape, permutation = _get_shape_permutation_like(self)
+    result = rand_fn(
+        shape,
+        dtype=dtype,
+        device=device,
+        **kwargs,
+    )
+    if permutation == list(range(len(permutation))):
+        return result
+    return result.permute(permutation).clone()
+
+
+@register_decomposition(aten.rand_like)
+def rand_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(torch.rand, self, **kwargs)
+
+
+@register_decomposition(aten.randn_like)
+def randn_like(self: torch.Tensor, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(torch.randn, self, **kwargs)
+
+
+@register_decomposition(aten.randint_like.default)
+def randint_like(self: torch.Tensor, high: int, **kwargs: Any) -> torch.Tensor:
+    return _rand_like(functools.partial(aten.randint.low, 0, high), self, **kwargs)
 
 
 @register_decomposition(aten.randint_like.low_dtype)
 def randint_like_low(
-    self: torch.Tensor,
-    low: int,
-    high: int,
-    *,
-    dtype: Optional[torch.dtype] = None,
-    device: Optional[torch.device] = None,
-    memory_format: Optional[torch.memory_format] = None,
-    **kwargs: Any,
+    self: torch.Tensor, low: int, high: int, **kwargs: Any
 ) -> torch.Tensor:
-    return aten.randint.low(
-        low,
-        high,
-        [*self.size()],
-        dtype=dtype or self.dtype,
-        device=device or self.device,
-        **kwargs,
-    ).to(memory_format=get_like_layout(self, memory_format))
+    return _rand_like(functools.partial(aten.randint.low, low, high), self, **kwargs)


@register_decomposition(aten.randint.default)
Expand Down