diff --git a/aten/src/ATen/native/cuda/Blas.cpp b/aten/src/ATen/native/cuda/Blas.cpp
index 541e8b8c773f..9a36d8b29176 100644
--- a/aten/src/ATen/native/cuda/Blas.cpp
+++ b/aten/src/ATen/native/cuda/Blas.cpp
@@ -1048,9 +1048,11 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
   TORCH_CHECK(!out_dtype || *out_dtype == out.scalar_type(), "out_dtype must match output matrix type");
   TORCH_CHECK(isFloat8Type(mat1.scalar_type()), "Expected mat1 to be Float8 matrix got ", mat1.scalar_type());
   TORCH_CHECK(isFloat8Type(mat2.scalar_type()), "Expected mat2 to be Float8 matrix got ", mat2.scalar_type());
+  #ifndef USE_ROCM
   // Type restrictions imposed by CuBLASLt as of CUDA-12.1
   TORCH_CHECK(mat1.scalar_type() != ScalarType::Float8_e5m2 || mat2.scalar_type() != ScalarType::Float8_e5m2,
         "Multiplication of two Float8_e5m2 matrices is not supported");
+  #endif
   if (bias) {
     TORCH_CHECK(out.scalar_type() != kFloat, "Bias is not supported when out_dtype is set to Float32");
     TORCH_CHECK(bias->scalar_type() == ScalarType::BFloat16 || bias->scalar_type() == ScalarType::Half,
diff --git a/functorch/experimental/__init__.py b/functorch/experimental/__init__.py
index 3941f6d96e1f..ec414d8c135b 100644
--- a/functorch/experimental/__init__.py
+++ b/functorch/experimental/__init__.py
@@ -1,5 +1,5 @@
 # PyTorch forward-mode is not mature yet
-from functorch import functionalize
+from torch._functorch.deprecated import functionalize
 from torch._functorch.apis import chunk_vmap
 from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_
 from torch._functorch.eager_transforms import hessian, jacfwd, jvp
diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py
index 550c0633e3f2..6d51bbc5c1c3 100644
--- a/test/distributed/_composable/fsdp/test_fully_shard_training.py
+++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py
@@ -26,7 +26,10 @@
 )
 from torch.distributed.tensor import DTensor, init_device_mesh, Shard
 from torch.distributed.tensor.debug import CommDebugMode
-from torch.testing._internal.common_cuda import TEST_CUDA
+from torch.testing._internal.common_cuda import (
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
+    TEST_CUDA,
+)
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import (
     check_sharded_parity,
@@ -40,7 +43,9 @@
 )
 from torch.testing._internal.common_utils import (
     get_cycles_per_ms,
+    NAVI_ARCH,
     run_tests,
+    skipIfRocmArch,
     wrapSwapTensorsTest,
 )
 from torch.testing._internal.distributed._tensor.common_dtensor import (
@@ -93,6 +98,7 @@ def world_size(self) -> int:
         return 4
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
+    @skipIfRocmArch(NAVI_ARCH) # Supported in future releases
     def test_param_registration_after_forward(self):
         """Tests the parameter registration after forward."""
         device = torch.device("cuda", 0)
@@ -199,6 +205,7 @@ def world_size(self) -> int:
 
     @unittest.skipIf(not TEST_CUDA, "no cuda")
     @wrapSwapTensorsTest(True)
+    @skipIfRocmArch(NAVI_ARCH) # Supported in future releases
    def test_to_float64_after_init(self):
         """Tests that the user can cast the module to float64 after init."""
         # NOTE: Test fp64 instead of a lower precision dtype like bf16 for
@@ -317,6 +324,9 @@ def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
 
     @skip_if_lt_x_gpu(2)
     @test_compiled_fsdp(compile_compute_on_module=Transformer)
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA"
+    )
def test_train_parity_multi_group(self): """ Tests train parity against DDP when using multiple parameter groups for diff --git a/test/distributed/_composable/test_replicate_with_compiler.py b/test/distributed/_composable/test_replicate_with_compiler.py index 8690bef6cc26..5e6d7e3331b9 100644 --- a/test/distributed/_composable/test_replicate_with_compiler.py +++ b/test/distributed/_composable/test_replicate_with_compiler.py @@ -330,9 +330,9 @@ def test_bucketing_coalesced_op(self): self.assertEqual(counters["inductor"]["ddp_buckets"], 3) fc = FileCheck() for i in range(3): - fc.check("cpp_fused_").check( - "torch.ops._c10d_functional.all_reduce_coalesced_.default(" - ) + fc.check("cpp_fused_") + for i in range(3): + fc.check("torch.ops._c10d_functional.all_reduce_coalesced_.default(") for i in range(3): fc.check("torch.ops._c10d_functional.wait_tensor.default") @@ -343,9 +343,9 @@ def test_bucketing_coalesced_op(self): self.assertEqual(counters["inductor"]["ddp_buckets"], 3) fc = FileCheck() for i in range(3): - fc.check("cpp_fused_").check( - "torch.ops._c10d_functional.all_reduce_coalesced_.default(" - ) + fc.check("cpp_fused_") + for i in range(3): + fc.check("torch.ops._c10d_functional.all_reduce_coalesced_.default(") for i in range(3): fc.check("torch.ops._c10d_functional.wait_tensor.default") @@ -372,9 +372,9 @@ def test_bucketing_concat_op(self): self.assertEqual(counters["inductor"]["ddp_buckets"], 3) fc = FileCheck() for i in range(3): - fc.check("aten.flatten.using_ints(").check("cpp_fused_").check( - "torch.ops._c10d_functional.all_reduce_.default(" - ) + fc.check("aten.flatten.using_ints(").check("cpp_fused_") + for i in range(3): + fc.check("torch.ops._c10d_functional.all_reduce_.default(") for i in range(3): fc.check("torch.ops._c10d_functional.wait_tensor.default") fc.run(code) @@ -384,9 +384,9 @@ def test_bucketing_concat_op(self): self.assertEqual(counters["inductor"]["ddp_buckets"], 3) fc = FileCheck() for i in range(3): - fc.check("aten.flatten.using_ints(").check("cpp_fused_").check( - "torch.ops._c10d_functional.all_reduce_.default(" - ) + fc.check("aten.flatten.using_ints(").check("cpp_fused_") + for i in range(3): + fc.check("torch.ops._c10d_functional.all_reduce_.default(") for i in range(3): fc.check("torch.ops._c10d_functional.wait_tensor.default") fc.run(code) diff --git a/test/distributed/_tools/test_sac_ilp.py b/test/distributed/_tools/test_sac_ilp.py index 62c6afe76829..abcbec6f59c7 100644 --- a/test/distributed/_tools/test_sac_ilp.py +++ b/test/distributed/_tools/test_sac_ilp.py @@ -18,7 +18,7 @@ get_optimal_checkpointing_policy_per_module, sac_milp, ) -from torch.testing._internal.common_cuda import TEST_CUDA +from torch.testing._internal.common_cuda import TEST_CUDA, PLATFORM_SUPPORTS_FLASH_ATTENTION from torch.testing._internal.common_utils import ( run_tests, skipIfTorchDynamo, @@ -181,7 +181,7 @@ def test_sac_ilp_case1(self): @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/115653") @unittest.skipIf(not TEST_CUDA, "CUDA not available") - @skipIfRocmArch(NAVI_ARCH) + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA") def test_sac_ilp_case2(self): """ This is a case where the memory budget is not binding, meaning that no diff --git a/test/distributed/elastic/test_control_plane.py b/test/distributed/elastic/test_control_plane.py index ede4e352b045..7c72e222c6d8 100644 --- a/test/distributed/elastic/test_control_plane.py +++ b/test/distributed/elastic/test_control_plane.py @@ -16,7 +16,12 @@ 
TORCH_WORKER_SERVER_SOCKET, worker_main, ) -from torch.testing._internal.common_utils import requires_cuda, run_tests, TestCase +from torch.testing._internal.common_utils import ( + requires_cuda, + run_tests, + skipIfRocm, + TestCase, +) class UnixHTTPConnection(HTTPConnection): @@ -152,6 +157,7 @@ def test_dump_nccl_trace_pickle_with_json(self) -> None: ) self.assertEqual(resp.status, 200) + @skipIfRocm # skipped upstream too def test_tcp(self) -> None: import requests diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index a95a35c95c4c..6705d65c976c 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -35,8 +35,11 @@ TransformerWithSharedParams, ) from torch.testing._internal.common_utils import ( + instantiate_parametrized_tests, + NAVI_ARCH, parametrize, run_tests, + skipIfRocmArch, TEST_HPU, TEST_WITH_DEV_DBG_ASAN, ) @@ -160,6 +163,7 @@ def test_nested_always_wrap_model( @skip_if_lt_x_gpu(2) @parametrize(params, configs, subtest_name) + @skipIfRocmArch(NAVI_ARCH) # Supported in future releases def test_transformer( self, cpu_offload: CPUOffload, diff --git a/test/distributed/fsdp/test_fsdp_hybrid_shard.py b/test/distributed/fsdp/test_fsdp_hybrid_shard.py index 9398f7901da4..15668b7982ff 100644 --- a/test/distributed/fsdp/test_fsdp_hybrid_shard.py +++ b/test/distributed/fsdp/test_fsdp_hybrid_shard.py @@ -6,6 +6,7 @@ from enum import auto, Enum from functools import partial from typing import List, Optional, Tuple +import unittest import torch import torch.distributed as dist @@ -31,6 +32,9 @@ FSDPTest, TransformerWithSharedParams, ) +from torch.testing._internal.common_cuda import ( + PLATFORM_SUPPORTS_FLASH_ATTENTION, +) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, run_tests, @@ -227,6 +231,7 @@ def test_invalid_pg_specification_raises(self): # resharded after forward. 
@skip_if_lt_x_gpu(2) + @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Does not support flash attention") def test_fsdp_hybrid_shard_basic_setup(self): """ Tests basic functionality of HYBRID_SHARD and _HYBRID_SHARD_ZERO2: diff --git a/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py b/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py index 0797eb9e0f0a..af660cd76b2d 100644 --- a/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py +++ b/test/distributed/fsdp/test_fsdp_sharded_grad_scaler.py @@ -19,6 +19,7 @@ from torch.distributed.fsdp.wrap import ModuleWrapPolicy from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer from torch.nn.parallel.distributed import DistributedDataParallel as DDP +from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MEM_EFF_ATTENTION from torch.testing._internal.common_distributed import skip_if_lt_x_gpu from torch.testing._internal.common_fsdp import ( DEVICEInitMode, @@ -236,6 +237,9 @@ def _build_model_and_optim( return model, optim, ref_model, ref_optim @skip_if_lt_x_gpu(2) + @unittest.skipIf( + not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Platform does not support fused SDPA" + ) def test_sharded_grad_scaler_found_inf(self): self.run_subtests( { diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 67edb211b9f1..175bdebef0c0 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -923,6 +923,8 @@ def closure_sharded(input_tensor=input_tensor): torch.testing.assert_close( loss_ddp, loss_sharded_optim, + atol=1.6e-3, + rtol=3e-6, msg="Losses differ between local optimizer and ZeRO", ) self._check_same_model_params( diff --git a/test/distributed/tensor/parallel/test_tp_examples.py b/test/distributed/tensor/parallel/test_tp_examples.py index 81bc52278a73..eab1e53fdfb7 100644 --- a/test/distributed/tensor/parallel/test_tp_examples.py +++ b/test/distributed/tensor/parallel/test_tp_examples.py @@ -43,6 +43,7 @@ Transformer, with_comms, ) +from unittest import skipIf c10d_functional = torch.ops.c10d_functional @@ -414,6 +415,7 @@ def test_transformer_training(self, is_seq_parallel, dtype: torch.dtype): + f"{str(dtype).split('.')[-1]}_" + f"thaw_{'__'.join(sorted({n.rpartition('.')[0].replace('.', '_') for n in thaw})) if thaw else 'all'}", ) + @skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Does not support fused scaled dot product attention") def test_transformer_req_grad(self, thaw_params, is_seq_parallel, dtype, exp_cnts): # Sample a subset of `requires_grad` patterns diff --git a/test/distributed/test_compute_comm_reordering.py b/test/distributed/test_compute_comm_reordering.py index a2780b55c203..2481fc5f8bf6 100644 --- a/test/distributed/test_compute_comm_reordering.py +++ b/test/distributed/test_compute_comm_reordering.py @@ -122,8 +122,8 @@ def func(a): # above the 2nd matmul. 
( FileCheck() - .check("torch.ops._c10d_functional.all_reduce_.default") .check("extern_kernels.mm") + .check("torch.ops._c10d_functional.all_reduce_.default") .check("torch.ops._c10d_functional.wait_tensor.default") .check("extern_kernels.mm") .run(code) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 92a2fd6ee2cf..cd34a91746ee 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -659,10 +659,10 @@ def func(inp, *, tag, ranks, group_size): FileCheck() .check("buf0 = empty_strided") .check(".run(arg0_1, buf0") - .check("torch.ops._c10d_functional.all_reduce_.default(buf0") - .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("buf5 = empty_strided") .check(".run(buf5, 16") + .check("torch.ops._c10d_functional.all_reduce_.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("return (buf0, buf5") .run(code) ) @@ -697,10 +697,10 @@ def func(inp, *, tag, ranks, group_size): .check("buf0 = empty_strided") .check("buf5 = empty_strided") .check(".run(arg0_1, buf0, buf5, 16") - .check("torch.ops._c10d_functional.all_reduce_.default(buf0") - .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("buf6 = empty_strided") .check(".run(buf6, 16") + .check("torch.ops._c10d_functional.all_reduce_.default(buf0") + .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("return (buf0, buf5, buf6") .run(code) ) @@ -1153,9 +1153,8 @@ def func(inp, *, tag, ranks, group_size): ) .check("buf2 = buf1[0]") .check("buf3 = buf1[1]") - .check("torch.ops._c10d_functional.wait_tensor.default(buf2") - .check("buf7 = buf0; del buf0 # reuse") .check(".run(buf7, 16") + .check("torch.ops._c10d_functional.wait_tensor.default(buf2") .check("torch.ops._c10d_functional.wait_tensor.default(buf3") .check("return (buf2, buf6, buf7, buf3") .run(code) @@ -1199,9 +1198,8 @@ def func(inp, *, tag, ranks, group_size): ) .check("buf2 = buf1[0]") .check("buf3 = buf1[1]") - .check("torch.ops._c10d_functional.wait_tensor.default(buf2") - .check("buf7 = buf0; del buf0 # reuse") .check(".run(buf7, 16") + .check("torch.ops._c10d_functional.wait_tensor.default(buf2") .check("torch.ops._c10d_functional.wait_tensor.default(buf3") .check("return (buf2, buf6, buf7, buf3") .run(code) diff --git a/test/dynamo/test_graph_deduplication.py b/test/dynamo/test_graph_deduplication.py index 544dea240219..42118fbc769b 100644 --- a/test/dynamo/test_graph_deduplication.py +++ b/test/dynamo/test_graph_deduplication.py @@ -57,18 +57,15 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): subgraph_0 = self.subgraph_0 l_x_ = L_x_ l_y_ = L_y_ - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_y_, l_x_)); invoke_subgraph = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); invoke_subgraph = None o1: "f32[10, 20]" = torch.sin(l_y_) - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(o1, l_x_)); o1 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, o1)); o1 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_y_, l_x_)); subgraph_0 = l_y_ = l_x_ = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = 
l_x_ = l_y_ = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None @@ -78,13 +75,13 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): return (mul_1,) class subgraph_0(torch.nn.Module): - def forward(self, subgraph_input_l_y_, subgraph_input_l_x_): - y0: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None - + def forward(self, subgraph_input_l_x_, subgraph_input_l_y_): x0: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None - sum_2: "f32[]" = y0.sum(); y0 = None + y0: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None + sum_1: "f32[]" = x0.sum(); x0 = None + sum_2: "f32[]" = y0.sum(); y0 = None z: "f32[]" = sum_1 + sum_2; sum_1 = sum_2 = None return (z,) """, @@ -98,12 +95,10 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2) repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'___forward_subgraph_0', (sin, primals_1)); repeated_subgraph0_1 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, '___forward_subgraph_0', (primals_1, sin)); repeated_subgraph0_1 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None repeated_subgraph0_2 = self.repeated_subgraph0 - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, \ -'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph0_2 = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_2, '___forward_subgraph_0', (primals_1, primals_2)); repeated_subgraph0_2 = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None mul: "f32[]" = torch.ops.aten.mul.Tensor(getitem_2, getitem_2) @@ -112,12 +107,12 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): return (mul_1, primals_1, primals_2, sin, getitem_1, getitem_2) class repeated_subgraph0(torch.nn.Module): - def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"): - add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None - add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None + def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 2); arg1_1 = None sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None - add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None + add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add_2,) """, ) @@ -267,27 +262,22 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): y0: "f32[10, 20]" = torch.sin(l_y_) - invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, \ -'subgraph_1', (y0, x0)); invoke_subgraph_3 = None - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (l_y_, l_x_)) + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); invoke_subgraph_3 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)) getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None o1: "f32[]" = torch.sin(getitem); getitem = None - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, \ -'subgraph_0', (y0, 
l_x_)) + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, y0)) getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None - invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, \ -'subgraph_1', (y0, x0)); subgraph_1 = y0 = x0 = None + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(subgraph_1, 'subgraph_1', (x0, y0)); subgraph_1 = x0 = y0 = None getitem_4: "f32[10, 10]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', \ -(l_y_, l_x_)); subgraph_0 = l_y_ = l_x_ = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(subgraph_0, 'subgraph_0', (l_x_, l_y_)); subgraph_0 = l_x_ = l_y_ = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None @@ -297,25 +287,24 @@ def forward(self, L_x_: "f32[10, 10]", L_y_: "f32[10, 20]"): return (add_13,) class subgraph_1(torch.nn.Module): - def forward(self, subgraph_input_y0, subgraph_input_x0): + def forward(self, subgraph_input_x0, subgraph_input_y0): + a0: "f32[10, 10]" = subgraph_input_x0 + 2; subgraph_input_x0 = None + b0: "f32[10, 20]" = subgraph_input_y0 + 3; subgraph_input_y0 = None cos_1: "f32[10, 20]" = b0.cos(); b0 = None sum_1: "f32[]" = cos_1.sum(); cos_1 = None - - a0: "f32[10, 10]" = subgraph_input_x0 + 2; subgraph_input_x0 = None - c: "f32[10, 10]" = a0 * sum_1; a0 = sum_1 = None return (c,) class subgraph_0(torch.nn.Module): - def forward(self, subgraph_input_l_y_, subgraph_input_l_x_): - y1: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None - + def forward(self, subgraph_input_l_x_, subgraph_input_l_y_): x1: "f32[10, 10]" = subgraph_input_l_x_ + 1; subgraph_input_l_x_ = None - sum_3: "f32[]" = y1.sum(); y1 = None + y1: "f32[10, 20]" = subgraph_input_l_y_ + 2; subgraph_input_l_y_ = None + sum_2: "f32[]" = x1.sum(); x1 = None + sum_3: "f32[]" = y1.sum(); y1 = None z: "f32[]" = sum_2 + sum_3; sum_2 = sum_3 = None return (z,) """, @@ -330,23 +319,19 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): sin: "f32[10, 20]" = torch.ops.aten.sin.default(primals_2) repeated_subgraph1 = self.repeated_subgraph1 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1, \ -'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph1 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1, '___forward_subgraph_0', (primals_1, primals_2)); repeated_subgraph1 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None sin_1: "f32[]" = torch.ops.aten.sin.default(getitem_1) repeated_subgraph1_1 = self.repeated_subgraph1 - invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_1, \ -'___forward_subgraph_0', (sin, primals_1)); repeated_subgraph1_1 = None + invoke_subgraph_2 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_1, '___forward_subgraph_0', (primals_1, sin)); repeated_subgraph1_1 = None getitem_2: "f32[]" = invoke_subgraph_2[0]; invoke_subgraph_2 = None repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'___forward_subgraph_1', (sin, cos)); repeated_subgraph0_1 = None + invoke_subgraph_3 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, '___forward_subgraph_1', (cos, sin)); repeated_subgraph0_1 = None getitem_3: "f32[10, 10]" = invoke_subgraph_3[0]; invoke_subgraph_3 = None repeated_subgraph1_2 = 
self.repeated_subgraph1 - invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_2, \ -'___forward_subgraph_0', (primals_2, primals_1)); repeated_subgraph1_2 = None + invoke_subgraph_4 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph1_2, '___forward_subgraph_0', (primals_1, primals_2)); repeated_subgraph1_2 = None getitem_4: "f32[]" = invoke_subgraph_4[0]; invoke_subgraph_4 = None mul: "f32[]" = torch.ops.aten.mul.Tensor(sin_1, getitem_2); sin_1 = None @@ -355,21 +340,21 @@ def forward(self, primals_1: "f32[10, 10]", primals_2: "f32[10, 20]"): return (add, primals_1, primals_2, cos, sin, getitem_1, getitem_2, getitem_3) class repeated_subgraph1(torch.nn.Module): - def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"): - add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None - add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 1); arg1_1 = None + def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg0_1, 1); arg0_1 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 2); arg1_1 = None sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None - add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None + add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add_2,) class repeated_subgraph0(torch.nn.Module): - def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"): - add: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg0_1, 3); arg0_1 = None - cos: "f32[10, 20]" = torch.ops.aten.cos.default(add); add = None + def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): + add: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg0_1, 2); arg0_1 = None + add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, 3); arg1_1 = None + cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None - add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(arg1_1, 2); arg1_1 = None - mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add_1, sum_1); add_1 = sum_1 = None + mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = sum_1 = None return (mul,) """, ) @@ -482,8 +467,7 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): add_3: "f32[10, 20]" = torch.ops.aten.add.Tensor(arg1_1, add_1); add_1 = None repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \ -'subgraph_0', (add_3, add_2)); repeated_subgraph0 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (add_2, add_3)); repeated_subgraph0 = None getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None clone: "f32[10, 10]" = torch.ops.aten.clone.default(add_2); add_2 = None @@ -498,8 +482,7 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): add_7: "f32[10, 20]" = torch.ops.aten.add.Tensor(clone_1, add_5); clone_1 = add_5 = None repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'subgraph_0', (add_7, add_6)); repeated_subgraph0_1 = add_7 = add_6 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (add_6, add_7)); repeated_subgraph0_1 = add_6 = add_7 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None add_8: "f32[]" 
= torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None @@ -508,10 +491,10 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): return (add_8,) class repeated_subgraph0(torch.nn.Module): - def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"): + def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1); arg0_1 = None sum_2: "f32[]" = torch.ops.aten.sum.default(arg1_1); arg1_1 = None - add: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None + add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add,) """, ) @@ -558,12 +541,10 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): view_3: "f32[10, 10]" = torch.ops.aten.view.default(view_2, [10, 10]); view_2 = None repeated_subgraph0 = self.repeated_subgraph0 - invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, \ -'subgraph_0', (arg1_1, arg0_1)); repeated_subgraph0 = None + invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0 = None getitem: "f32[]" = invoke_subgraph[0]; invoke_subgraph = None repeated_subgraph0_1 = self.repeated_subgraph0 - invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, \ -'subgraph_0', (arg1_1, arg0_1)); repeated_subgraph0_1 = arg1_1 = arg0_1 = None + invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, 'subgraph_0', (arg0_1, arg1_1)); repeated_subgraph0_1 = arg0_1 = arg1_1 = None getitem_1: "f32[]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None add: "f32[10, 10]" = torch.ops.aten.add.Tensor(view_1, view_3); view_1 = view_3 = None @@ -574,12 +555,12 @@ def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): return (add_2,) class repeated_subgraph0(torch.nn.Module): - def forward(self, arg0_1: "f32[10, 20]", arg1_1: "f32[10, 10]"): - mul: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None - mul_1: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None + def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"): + mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None + mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None - add: "f32[]" = torch.ops.aten.add.Tensor(sum_2, sum_1); sum_2 = sum_1 = None + add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None return (add,) """, ) diff --git a/test/dynamo/test_graph_region_tracker.py b/test/dynamo/test_graph_region_tracker.py index 04962bc4f8f8..1b003d040f50 100644 --- a/test/dynamo/test_graph_region_tracker.py +++ b/test/dynamo/test_graph_region_tracker.py @@ -70,8 +70,7 @@ def fn(x, y): torch.rand(10, 10), torch.ones(10, 20), ), - """[[['y0', 'x0', 'sum_2', 'sum_1', 'z'], \ -['y0_1', 'x0_1', 'sum_4', 'sum_3', 'z_1'], ['y0_2', 'x0_2', 'sum_6', 'sum_5', 'z_2']]]""", + """[[['x0', 'y0', 'sum_1', 'sum_2', 'z'], ['x0_1', 'y0_1', 'sum_3', 'sum_4', 'z_1'], ['x0_2', 'y0_2', 'sum_5', 'sum_6', 'z_2']]]""", ) def test_get_regions_multiple_region_groups(self): @@ -104,8 +103,7 @@ def fn(x, y): torch.rand(10, 10), torch.ones(10, 20), ), - """[[['y1', 'x1', 'sum_3', 'sum_2', 'z'], ['y1_1', 'x1_1', 'sum_5', 'sum_4', 'z_1'], \ -['y1_2', 'x1_2', 'sum_8', 'sum_7', 'z_2']], [['b', 'cos_1', 'sum_1', 'a', 'c'], ['b_1', 'cos_2', 'sum_6', 
'a_1', 'c_1']]]""", + """[[['x1', 'y1', 'sum_2', 'sum_3', 'z'], ['x1_1', 'y1_1', 'sum_4', 'sum_5', 'z_1'], ['x1_2', 'y1_2', 'sum_7', 'sum_8', 'z_2']], [['a', 'b', 'cos_1', 'sum_1', 'c'], ['a_1', 'b_1', 'cos_2', 'sum_6', 'c_1']]]""", ) def test_no_single_node_regions(self): @@ -177,8 +175,7 @@ def fn(x, y): torch.rand(10, 10), torch.ones(10, 20), ), - """[[['y1', 'sum_1', 'x1', 'o0'], ['y1_1', 'sum_2', 'x1_1', 'o2'], \ -['y1_2', 'sum_3', 'x1_2', 'o4'], ['y1_3', 'sum_4', 'x1_3', 'o5']]]""", + """[[['x1', 'y1', 'sum_1', 'o0'], ['x1_1', 'y1_1', 'sum_2', 'o2'], ['x1_2', 'y1_2', 'sum_3', 'o4'], ['x1_3', 'y1_3', 'sum_4', 'o5']]]""", ) def test_nested_args(self): diff --git a/test/inductor/test_cooperative_reductions.py b/test/inductor/test_cooperative_reductions.py index 26e90136ed24..900f41b140ec 100644 --- a/test/inductor/test_cooperative_reductions.py +++ b/test/inductor/test_cooperative_reductions.py @@ -12,6 +12,7 @@ from torch._inductor.codegen.triton import FixedTritonConfig, TritonKernel from torch._inductor.test_case import TestCase from torch._inductor.utils import run_and_get_code +from torch.testing import assert_close from torch.testing._internal.common_cuda import IS_SM89 from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, @@ -33,19 +34,99 @@ def setUp(self): torch._inductor.metrics.generated_kernel_count = 0 torch._dynamo.reset() - def run_and_check(self, fn, args, *, expect_kernel_count=1): - args_cpu = [tensor.cpu().to(torch.float32) for tensor in args] - expected = fn(*args_cpu).to(torch.float16) - fn = torch.compile(fn, fullgraph=True) - result, (source_code,) = run_and_get_code(fn, *args) - self.assertEqual(result, expected) - self.assertIn("@triton_heuristics.cooperative_reduction", source_code) + def run_and_check(self, fn, args, dtype=None, *, expect_kernel_count=1): + # Define fixed tolerances + RTOL = 1e-5 + ATOL = 1e-6 + + # calculate reference value in higher precision when input dtype is float16 + ref_dtype = dtype + if dtype == torch.float16: + ref_dtype = torch.float64 + + # Cast to the determined reference dtype + args_ref = [tensor.to(ref_dtype) for tensor in args] + + # Calculate expected output + raw_expected = fn(*args_ref) + + if isinstance(raw_expected, (tuple, list)): + # If it's a tuple or list, apply .to(dtype) to each tensor within it + # Also, handle cases where dtype might not be provided (e.g., for bool reductions) + if dtype is not None: + expected = type(raw_expected)( + [ + t.to(dtype) if isinstance(t, torch.Tensor) else t + for t in raw_expected + ] + ) + else: + expected = type(raw_expected)( + [ + t.to(torch.float64) if isinstance(t, torch.Tensor) else t + for t in raw_expected + ] + ) + else: + # If it's a single tensor + if dtype is not None: + expected = raw_expected.to(dtype) + else: + expected = raw_expected.to(torch.float64) + + fn_compiled = torch.compile(fn, fullgraph=True) + result, (source_code,) = run_and_get_code(fn_compiled, *args) + + # For comparison, ensure result is also a tuple/list if expected is + if isinstance(expected, (tuple, list)): + if isinstance(result, torch.Tensor): + result = (result,) + elif not isinstance(result, type(expected)): + result = type(expected)(result) + + if dtype is not None: + result = type(result)( + [t.to(dtype) if isinstance(t, torch.Tensor) else t for t in result] + ) + else: + result = type(result)( + [ + t.to(torch.float64) if isinstance(t, torch.Tensor) else t + for t in result + ] + ) + else: + if dtype is not None and isinstance(result, torch.Tensor): + result = 
result.to(dtype)
+            elif isinstance(result, torch.Tensor):
+                result = result.to(torch.float64)
+
+        # Apply assert_close with fixed tolerances for tensor comparisons
+        if isinstance(result, torch.Tensor) and isinstance(expected, torch.Tensor):
+            assert_close(result, expected, rtol=RTOL, atol=ATOL)
+        elif isinstance(result, (tuple, list)) and isinstance(expected, (tuple, list)):
+            # Iterate through elements for comparison
+            for r_item, e_item in zip(result, expected):
+                if isinstance(r_item, torch.Tensor) and isinstance(
+                    e_item, torch.Tensor
+                ):
+                    assert_close(r_item, e_item, rtol=RTOL, atol=ATOL)
+                else:
+                    # Fallback to assertEqual for non-tensor elements (e.g., bool, int)
+                    self.assertEqual(r_item, e_item)
+        else:
+            # Fallback to assertEqual for other types not handled by assert_close
+            self.assertEqual(result, expected)
+
+        if "@triton_heuristics.fixed_config" in source_code:
+            self.assertIn("cooperative_reduction_grid", source_code)
+        else:
+            self.assertIn("@triton_heuristics.cooperative_reduction", source_code)
         if "async_compile.multi_kernel" not in source_code:
             self.assertEqual(
                 torch._inductor.metrics.generated_kernel_count, expect_kernel_count
             )
         return source_code
-
     @parametrize(
         "name",
         [
diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py
index b779623cfa45..604ce6ad8bc3 100644
--- a/test/inductor/test_cuda_repro.py
+++ b/test/inductor/test_cuda_repro.py
@@ -34,6 +34,7 @@
     IS_FBCODE,
     skipIfRocm,
     TEST_WITH_ASAN,
+    xfailIfPy312Plus,
 )
 
 
@@ -1568,6 +1569,7 @@ def get_input() -> torch.Tensor:
         self.assertEqual(result, a + b)
         self.assertIn("znumel", code)
 
+    @xfailIfPy312Plus # https://github.com/pytorch/pytorch/issues/142032
     def test_repeated_masked_load(self):
         target_size = (8, 2)
         mem_eff_temporal_upsampling_interp_chunks = 2
diff --git a/test/inductor/test_kernel_benchmark.py b/test/inductor/test_kernel_benchmark.py
index 187fe4bfdd2d..63f01d40cd6e 100644
--- a/test/inductor/test_kernel_benchmark.py
+++ b/test/inductor/test_kernel_benchmark.py
@@ -15,6 +15,7 @@
 from torch.testing import FileCheck
 from torch.testing._internal.common_cuda import xfailIfSM89
 from torch.testing._internal.common_device_type import expectedFailureXPU
+from torch.testing._internal.common_utils import skipIfRocm
 from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
 
 
diff --git a/test/test_license.py b/test/test_license.py
index e7eb0459d3aa..516cb78f1202 100644
--- a/test/test_license.py
+++ b/test/test_license.py
@@ -45,7 +45,11 @@ def test_distinfo_license(self):
             'Found too many "torch-*dist-info" directories '
             f'in "{site_packages}, expected only one'
         )
-        with open(os.path.join(os.path.join(distinfo[0], "LICENSE"))) as fid:
+        # setuptools renamed *dist-info/LICENSE to *dist-info/licenses/LICENSE since 77.0
+        license_file = os.path.join(distinfo[0], "licenses", "LICENSE")
+        if not os.path.exists(license_file):
+            license_file = os.path.join(distinfo[0], "LICENSE")
+        with open(license_file) as fid:
             txt = fid.read()
         self.assertTrue(starting_txt in txt)
diff --git a/test/test_linalg.py b/test/test_linalg.py
index 81c58b623b23..13618e4e6bda 100644
--- a/test/test_linalg.py
+++ b/test/test_linalg.py
@@ -19,6 +19,7 @@
      TEST_WITH_ROCM, IS_FBCODE, IS_REMOTE_GPU, iter_indices,
      make_fullrank_matrices_with_distinct_singular_values,
      freeze_rng_state, IS_ARM64, IS_SANDCASTLE, TEST_OPT_EINSUM, parametrize, skipIfTorchDynamo,
+     skipIfRocmArch, NAVI4_ARCH,
      setBlasBackendsToDefaultFinally, setLinalgBackendsToDefaultFinally, serialTest)
 from torch.testing._internal.common_device_type
import \ (instantiate_device_type_tests, dtypes, has_cusolver, has_hipsolver, @@ -6440,6 +6441,7 @@ def test_baddbmm_input_dtypes_compatibility(self, device, dtype): @unittest.skipIf(IS_FBCODE and IS_REMOTE_GPU, "cublas runtime error") @onlyCUDA + @skipIfRocmArch(NAVI4_ARCH) def test_matmul_45724(self, device): # https://github.com/pytorch/pytorch/issues/45724 a = torch.rand(65537, 22, 64, device=device, dtype=torch.half) diff --git a/test/test_matmul_cuda.py b/test/test_matmul_cuda.py index 9da63650f394..1c031561b1b6 100644 --- a/test/test_matmul_cuda.py +++ b/test/test_matmul_cuda.py @@ -1,5 +1,6 @@ # Owner(s): ["module: linear algebra"] +import contextlib import unittest from itertools import product from functools import partial @@ -351,14 +352,15 @@ def _test_tautological_mm(self, device: str = "cuda", @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg) def test_float8_basics(self, device) -> None: self._test_tautological_mm(device, e4m3_type, e4m3_type, size=16) - # hipblaslt does not yet support mixed e4m3_type input - if torch.version.hip is None: - self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32) - self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48) # According to https://docs.nvidia.com/cuda/cublas/#id99 8F_E5M2 MM is unsupported - with self.assertRaises(RuntimeError): + # supported on ROCm but fails on CUDA + ctx = self.assertRaises(RuntimeError) if torch.version.hip is None else contextlib.nullcontext() + with ctx: self._test_tautological_mm(device, e5m2_type, e5m2_type) + self._test_tautological_mm(device, e4m3_type, e5m2_type, size=32) + self._test_tautological_mm(device, e5m2_type, e4m3_type, size=48) + self._test_tautological_mm(device, size=64, out_dtype=torch.float16) self._test_tautological_mm(device, size=96, out_dtype=torch.float32) self._test_tautological_mm(device, size=80, out_dtype=torch.bfloat16)