From d27e21dd3e697aeda4ed60dab99e9af295b2775f Mon Sep 17 00:00:00 2001
From: Bin Bao
Date: Wed, 15 May 2024 07:05:09 -0700
Subject: [PATCH] [AOTI] Support InplaceBernoulliFallback in the ABI-compatible codegen (#126183)

Summary: Update the torchgen rule for inplace ops like bernoulli_, and update
InplaceBernoulliFallback to generate code in the ABI-compatible mode.

Fixes https://github.com/pytorch/pytorch/issues/121809

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126183
Approved by: https://github.com/angelayi
ghstack dependencies: #126181, #126182
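Example (illustrative, not part of the change): the sketch below exercises the
fallback path. The options keys mirror the torch._inductor.config flags
referenced in this patch, and the shim call in the trailing comment uses a
hypothetical buffer name.

```python
# Repro sketch: route the inplace bernoulli_ lowering onto
# InplaceBernoulliFallback and compile with the ABI-compatible cpp wrapper.
import torch

@torch.compile(
    options={"cpp_wrapper": True, "abi_compatible": True, "fallback_random": True}
)
def f(x):
    return x.bernoulli_(0.5)

f(torch.rand(16))
# The cpp wrapper now emits a call like (buffer name hypothetical):
#   aoti_torch_cpu_bernoulli__float(buf0, 0.5, NULL);
```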
---
 test/inductor/test_cpu_cpp_wrapper.py | 1 -
 test/inductor/test_cuda_cpp_wrapper.py | 1 -
 torch/_inductor/ir.py | 25 ++++++++++++-------
 torch/_inductor/lowering.py | 7 +++++-
 .../aoti_torch/generated/c_shim_cpu.h | 6 +++--
 .../aoti_torch/generated/c_shim_cuda.h | 6 +++--
 torchgen/aoti/fallback_ops.py | 2 ++
 torchgen/gen_aoti_c_shim.py | 12 ++++-----
 8 files changed, 38 insertions(+), 22 deletions(-)

diff --git a/test/inductor/test_cpu_cpp_wrapper.py b/test/inductor/test_cpu_cpp_wrapper.py
index b8fdbc49bd387..66b92eedc97c0 100644
--- a/test/inductor/test_cpu_cpp_wrapper.py
+++ b/test/inductor/test_cpu_cpp_wrapper.py
@@ -71,7 +71,6 @@ class DynamicShapesCppWrapperCpuTests(InductorTestCase):
 if config.abi_compatible:
     xfail_list = [
-        "test_bernoulli1_cpu",  # cpp fallback op naming issue
         "test_conv2d_binary_inplace_fusion_failed_cpu",
         "test_conv2d_binary_inplace_fusion_pass_cpu",
         "test_dynamic_qlinear_cpu",
diff --git a/test/inductor/test_cuda_cpp_wrapper.py b/test/inductor/test_cuda_cpp_wrapper.py
index 5bbe588d3a84e..5cb8af9db165a 100644
--- a/test/inductor/test_cuda_cpp_wrapper.py
+++ b/test/inductor/test_cuda_cpp_wrapper.py
@@ -97,7 +97,6 @@ class DynamicShapesCudaWrapperCudaTests(InductorTestCase):
 if config.abi_compatible:
     xfail_list = [
-        "test_bernoulli1_cuda",  # cpp fallback op naming issue
         "test_profiler_mark_wrapper_call_cuda",
         "test_scaled_dot_product_attention_cuda_dynamic_shapes",
     ]
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index cbf990ea0b777..7b7e8e567a0be 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -4784,9 +4784,17 @@ class InplaceBernoulliFallback(ExternKernel):
 
     def codegen(self, wrapper):
         (x,) = (t.codegen_reference() for t in self.inputs)
-        wrapper.writeline(
-            f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
-        )
+
+        if V.graph.cpp_wrapper and config.abi_compatible:
+            # Inductor doesn't really support aten Generator, so the generator kwarg is always NULL here,
+            # and the cpp wrapper needs to pass that NULL explicitly
+            wrapper.writeline(
+                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
+            )
+        else:
+            wrapper.writeline(
+                f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
+            )
 
     def should_allocate(self):
         return False
@@ -4797,20 +4805,19 @@ def get_mutation_names(self):
     def get_unbacked_symbol_defs(self) -> Set[sympy.Symbol]:
         return set()
 
-    def __init__(self, x, *constant_args):
+    def __init__(self, op_overload, x, *constant_args):
         super().__init__(
             None,
             NoneLayout(x.get_device()),  # type: ignore[arg-type]
             self.unwrap_storage([x]),
             constant_args,
+            op_overload=op_overload,
         )
         self.name = V.graph.register_buffer(self)
         self.python_kernel_name = "aten.bernoulli_"
-        self.cpp_kernel_name = (
-            "aoti_torch_bernoulli_"
-            if config.abi_compatible
-            else "at::native::bernoulli_"
-        )
+        if not config.abi_compatible:
+            # TODO: this should be simplified once we switch to ABI-compatible only
+            self.cpp_kernel_name = "at::native::bernoulli_"
         mark_node_as_mutating(self, x)
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 389ff16e39025..77d7b6c046dee 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1788,7 +1788,12 @@ def bernoulli_(x, *args):
         "cpu"
     ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
     x.realize()
-    ir.InplaceBernoulliFallback(x, *args)
+    op_overload = (
+        aten.bernoulli_.float
+        if len(args) == 0 or isinstance(args[0], float)
+        else aten.bernoulli_.Tensor
+    )
+    ir.InplaceBernoulliFallback(op_overload, x, *args)
     return x
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
index bbd8cbc9d31a0..2c7f05dd84cd9 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cpu.h
@@ -47,6 +47,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d(AtenTensorHandle self
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -105,8 +107,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_randperm(int64_t n, int32_t* dty
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cpu_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
index 2905aa810d3cb..1dceac240e40a 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_cuda.h
@@ -55,6 +55,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d(AtenTensorHandle sel
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d(AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_avg_pool3d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* kernel_size, int64_t kernel_size_len_, const int64_t* stride, int64_t stride_len_, const int64_t* padding, int64_t padding_len_, int32_t ceil_mode, int32_t count_include_pad, int64_t* divisor_override, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__Tensor(AtenTensorHandle self, AtenTensorHandle p, AtenGeneratorHandle* generator);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bernoulli__float(AtenTensorHandle self, double p, AtenGeneratorHandle* generator);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bmm_out(AtenTensorHandle out, AtenTensorHandle self, AtenTensorHandle mat2);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_bucketize_Tensor(AtenTensorHandle self, AtenTensorHandle boundaries, int32_t out_int32, int32_t right, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_cat(const AtenTensorHandle* tensors, int64_t tensors_len_, int64_t dim, AtenTensorHandle* ret0);
@@ -112,8 +114,8 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_randperm(int64_t n, int32_t* dt
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_repeat_interleave_Tensor(AtenTensorHandle repeats, int64_t* output_size, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad1d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_replication_pad2d_backward(AtenTensorHandle grad_output, AtenTensorHandle self, const int64_t* padding, int64_t padding_len_, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format, AtenTensorHandle* ret0);
-AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* memory_format);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_resize_as_(AtenTensorHandle self, AtenTensorHandle the_template, int32_t* memory_format);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_src_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_value_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_cuda_scatter_reduce_two_out(AtenTensorHandle out, AtenTensorHandle self, int64_t dim, AtenTensorHandle index, AtenTensorHandle src, const char* reduce, int32_t include_self);
diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py
index f77527a156beb..4a300c3cc3010 100644
--- a/torchgen/aoti/fallback_ops.py
+++ b/torchgen/aoti/fallback_ops.py
@@ -25,6 +25,8 @@
     "aten.avg_pool2d.default",
     "aten.avg_pool3d_backward.default",
     "aten.avg_pool3d.default",
+    "aten.bernoulli_.float",
+    "aten.bernoulli_.Tensor",
     "aten.bmm.out",
     "aten.bucketize.Tensor",
     "aten.cat.default",
diff --git a/torchgen/gen_aoti_c_shim.py b/torchgen/gen_aoti_c_shim.py
index 5bc29e514a27e..f123bc879cd34 100644
--- a/torchgen/gen_aoti_c_shim.py
+++ b/torchgen/gen_aoti_c_shim.py
@@ -249,18 +249,18 @@ def gen_declaration_and_definition(
         return declaration_definition_cache[(func_name, device, backend_call)]
 
     if schema.is_out_fn():
-        # out_variant has out arguments in the front, and it's ok to ignore return value
+        # out_variant has out arguments in the front, and it's ok to ignore return values
         # because C shim functions only return AOTITorchError
-        # Somehow at::native out-variant functions have out arguments in the back
         args, callsite_exprs = gen_arguments(
-            [*schema.arguments.flat_non_out, *schema.arguments.out]
-            if "at::native" in backend_call
-            else [*schema.arguments.out, *schema.arguments.flat_non_out],
+            [*schema.arguments.out, *schema.arguments.flat_non_out]
         )
         ret_assignments: List[str] = []
     else:
         args, callsite_exprs = gen_arguments(schema.arguments.flat_all)
-        ret_declarations, ret_assignments = gen_returns(schema)
+        # ignore return values for inplace ops
+        ret_declarations, ret_assignments = (
+            ([], []) if schema.name.name.inplace else gen_returns(schema)
+        )
     args.extend(ret_declarations)
 
     declaration = f"AOTITorchError aoti_torch_{device}_{func_name}({', '.join(args)})"