Revert "[AOTI][refactor][3/n] Declare python_kernel_name and cpp_kern…
Browse files Browse the repository at this point in the history
…el_name in ExternKernel (pytorch#115831)"

This reverts commit 287a865.

Reverted pytorch#115831 on behalf of https://github.com/desertfire due to rocm CI failure ([comment](pytorch#115831 (comment)))
pytorchmergebot authored and ZhiweiYan-96 committed Dec 22, 2023
1 parent d000cf2 commit 0942718
Showing 3 changed files with 43 additions and 49 deletions.
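
For context, the interface difference this revert undoes can be sketched as follows. This is a minimal illustration, not the actual PyTorch source: the class names `ExternKernelWithSharedHelper` and `ExternKernelAllocAfterRevert` and the `GraphLowering` stand-in are invented for the example; only the field names, the helper names, and the cpp_wrapper selection logic come from the diff below.

```python
# Hedged sketch of the state being reverted vs. the state being restored.
# `GraphLowering` here is a simplified stand-in for Inductor's V.graph.
from dataclasses import dataclass
from typing import Optional


@dataclass
class GraphLowering:
    cpp_wrapper: bool = False  # True when emitting C++ wrapper code


graph = GraphLowering(cpp_wrapper=True)


@dataclass
class ExternKernelWithSharedHelper:
    # Shape introduced by pytorch#115831 (now reverted): both kernel names are
    # declared on the ExternKernel base class, with one shared helper.
    python_kernel_name: Optional[str] = None
    cpp_kernel_name: Optional[str] = None

    def get_kernel_name(self) -> Optional[str]:
        return self.cpp_kernel_name if graph.cpp_wrapper else self.python_kernel_name


@dataclass
class ExternKernelAllocAfterRevert:
    # Shape restored by this revert: only ExternKernelAlloc carries the helper
    # (codegen_kernel_name); other subclasses set the attributes themselves and
    # pick between the two names inline at each call site.
    python_kernel_name: Optional[str] = None
    cpp_kernel_name: Optional[str] = None

    def codegen_kernel_name(self) -> Optional[str]:
        return self.cpp_kernel_name if graph.cpp_wrapper else self.python_kernel_name


# With cpp_wrapper enabled, both helpers resolve to the C++ kernel name.
k = ExternKernelWithSharedHelper("aten.bernoulli_", "at::native::bernoulli_")
assert k.get_kernel_name() == "at::native::bernoulli_"
```

In the restored state, fallbacks such as InplaceBernoulliFallback and AccumulateGrad go back to writing their python_kernel_name directly and no longer set a cpp_kernel_name, as the ir.py hunks below show.
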
4 changes: 0 additions & 4 deletions test/inductor/test_cpp_wrapper.py
@@ -195,7 +195,6 @@ class BaseTest(NamedTuple):

for item in [
BaseTest("test_as_strided"), # buffer reuse
BaseTest("test_bernoulli1"),
BaseTest("test_bitwise"), # int32
BaseTest("test_bmm1"),
BaseTest("test_bmm2"),
@@ -231,7 +230,6 @@ class BaseTest(NamedTuple):
BaseTest("test_custom_op"),
BaseTest("test_dtype_sympy_expr"),
BaseTest("test_embedding_bag"), # test default FallbackKernel
BaseTest("test_index_put1"),
BaseTest("test_index_put_deterministic_fallback"),
BaseTest("test_adding_tensor_offsets"),
BaseTest("test_int_div", "", test_cpu_repro.CPUReproTests()),
@@ -378,7 +376,6 @@ class BaseTest(NamedTuple):
for item in [
BaseTest("test_as_strided"), # buffer reuse
BaseTest("test_batch_norm_2d_2"),
BaseTest("test_bernoulli1"),
BaseTest("test_bitwise"), # int32
BaseTest("test_bmm1"),
BaseTest("test_bmm2"),
@@ -387,7 +384,6 @@
BaseTest("test_conv_backward"),
BaseTest("test_custom_op"),
BaseTest("test_embedding_bag"), # test default FallbackKernel
BaseTest("test_index_put1"),
BaseTest("test_index_put_deterministic_fallback"),
BaseTest("test_adding_tensor_offsets"),
BaseTest("test_index_tensor"),
12 changes: 5 additions & 7 deletions torch/_inductor/codegen/wrapper.py
@@ -533,7 +533,7 @@ def generate_fallback_kernel(self, fallback_kernel, args):
def generate_extern_kernel_alloc(self, extern_kernel, args):
output_name = extern_kernel.get_name()
origin_node = extern_kernel.get_origin_node()
kernel_name = extern_kernel.get_kernel_name()
kernel_name = extern_kernel.codegen_kernel_name()
ending = self.ending
if config.memory_planning and "view_as_complex" in kernel_name:
# view operation fallbacks cause issues since inductor
@@ -571,7 +571,7 @@ def generate_user_defined_triton_kernel(self, kernel_name, grid, configs, args):
)

def generate_scatter_fallback(
self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
self, output, inputs, kernel, fn, src_is_tensor, reduce, kwargs
):
line = f"{kernel}({','.join(map(str, inputs))}"
if kernel == "aten.scatter_":
@@ -1423,8 +1423,6 @@ def write_header(self):
#include <torch/csrc/inductor/aoti_torch/tensor_converter.h>
#include <torch/csrc/inductor/inductor_ops.h>
#include <torch/types.h>
#include <ATen/ops/bernoulli_native.h>
#define reinterpret_tensor torch::inductor::_reinterpret_tensor
#define alloc_from_pool torch::inductor::_alloc_from_pool
"""
@@ -2022,7 +2020,7 @@ def generate_c_shim_extern_kernel_alloc(self, extern_kernel, args):
self.writeline(f"AtenTensorHandle {output_handle_name};")
output_arg = f"&{output_handle_name}"
self.generate_c_shim_extern_kernel_call(
extern_kernel.get_kernel_name(), args + [output_arg]
extern_kernel.codegen_kernel_name(), args + [output_arg]
)
self.writeline(f"RAIIAtenTensorHandle {name}({output_handle_name});")

@@ -2112,14 +2110,14 @@ def generate_user_defined_triton_kernel(self, kernel_name, grid, configs, args):
)

def generate_scatter_fallback(
self, output, inputs, kernel, python_kernel_name, src_is_tensor, reduce, kwargs
self, output, inputs, kernel, fn, src_is_tensor, reduce, kwargs
):
# TODO: support other overload for cpp wrapper and remove the below assertions
if V.graph.aot_mode and config.aot_inductor.abi_compatible:
# call the ABI shim function instead of the ATen one
kernel = kernel.replace("at::", "aoti_torch_")
line = f"{kernel}({output}, {','.join(map(str, inputs))}"
if python_kernel_name == "aten.scatter_":
if fn == "aten.scatter_":
if src_is_tensor:
if reduce:
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
76 changes: 38 additions & 38 deletions torch/_inductor/ir.py
@@ -3511,8 +3511,6 @@ class ExternKernel(InputsKernel):
constant_args: Tuple[Any, ...] = ()
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None
python_kernel_name: Optional[str] = None
cpp_kernel_name: Optional[str] = None
ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
default_factory=list
)
@@ -3530,9 +3528,6 @@ def codegen_comment(self, wrapper):
def codegen(self, wrapper):
raise NotImplementedError()

def get_kernel_name(self):
return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name

@staticmethod
def copy_input(x):
pw = Pointwise.create(
@@ -3898,14 +3893,16 @@ def __str__(self):

@dataclasses.dataclass
class ExternKernelOut(ExternKernel):
output_view: Optional[ReinterpretView] = None

def codegen(self, wrapper):
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
wrapper.generate_extern_kernel_out(
self.output_view,
self.codegen_reference(),
args,
self.get_kernel_name(),
self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name,
)

def __init__(
@@ -3949,6 +3946,9 @@ def __init__(self, count: int, device: torch.device):


class ExternKernelAlloc(ExternKernel):
def codegen_kernel_name(self):
return self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name

def codegen(self, wrapper):
self.codegen_comment(wrapper)
args = [*self.codegen_args(), *self.codegen_kwargs()]
@@ -4113,7 +4113,7 @@ class InplaceBernoulliFallback(ExternKernel):
def codegen(self, wrapper):
(x,) = (t.codegen_reference() for t in self.inputs)
wrapper.writeline(
f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
f"{self.python_kernel_name}({x}, {', '.join(map(repr, self.constant_args))})"
)

def should_allocate(self):
@@ -4134,7 +4134,6 @@ def __init__(self, x, *constant_args):
)
self.name = V.graph.register_buffer(self)
self.python_kernel_name = "aten.bernoulli_"
self.cpp_kernel_name = "at::native::bernoulli_"
mark_node_as_mutating(self, x)


@@ -4145,9 +4144,7 @@ class AccumulateGrad(ExternKernel):

def codegen(self, wrapper):
(variable, new_grad) = (t.codegen_reference() for t in self.inputs)
wrapper.writeline(
f"{self.get_kernel_name()}({variable}, {new_grad}){wrapper.ending}"
)
wrapper.writeline(f"{self.python_kernel_name}({variable}, {new_grad})")

def should_allocate(self):
return False
@@ -4166,7 +4163,6 @@ def __init__(self, variable, new_grad):
)
self.name = V.graph.register_buffer(self)
self.python_kernel_name = "inductor_ops.accumulate_grad_"
self.cpp_kernel_name = "torch::inductor::accumulate_grad_"
mark_node_as_mutating(self, variable)


@@ -4184,6 +4180,7 @@ def codegen(self, wrapper):
get_operator_enum = {"add": "sum", "multiply": "prod"}
if reduce in get_operator_enum:
reduce = get_operator_enum[reduce]
self.cpp_kernel_name = self.get_cpp_kernel(self.fn, reduce)

if self.src_is_tensor:
(x, index, src) = (t.codegen_reference() for t in self.inputs)
@@ -4193,8 +4190,8 @@
wrapper.generate_scatter_fallback(
x,
[x, self.constant_args[0], index, src],
self.get_kernel_name(),
self.python_kernel_name,
self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name,
self.fn,
self.src_is_tensor,
reduce,
self.codegen_kwargs(),
@@ -4203,9 +4200,8 @@ def should_allocate(self):
def should_allocate(self):
return False

def get_cpp_kernel(self):
reduce = self.kwargs["reduce"]
if self.python_kernel_name == "aten.scatter_":
def get_cpp_kernel(self, fn, reduce):
if fn == "aten.scatter_":
if self.src_is_tensor:
kernel = (
"at::scatter_out" if reduce is None else "at::scatter_reduce_out"
@@ -4230,7 +4226,7 @@ def get_unbacked_symbol_defs(self):

def __init__(
self,
python_kernel_name,
fn,
x,
dim: int,
index,
@@ -4239,8 +4235,10 @@ def __init__(
reduce: Optional[str] = None,
include_self: bool = True,
):
assert python_kernel_name in {"aten.scatter_", "aten.scatter_reduce_"}
assert fn in {"aten.scatter_", "aten.scatter_reduce_"}
self.src_is_tensor = isinstance(src, TensorBox)
self.python_kernel_name = fn
self.fn = fn

constant_args: Tuple[Any, ...]
if self.src_is_tensor:
@@ -4257,9 +4255,6 @@ def __init__(
constant_args,
{"reduce": reduce, "include_self": include_self},
)

self.python_kernel_name = python_kernel_name
self.cpp_kernel_name = self.get_cpp_kernel()
self.ordered_kwargs_for_cpp_kernel = ["reduce", "include_self"]
self.name = V.graph.register_buffer(self)
mark_node_as_mutating(self, x)
@@ -4284,7 +4279,9 @@ def codegen(self, wrapper):
args = [x, indices_str, values, *self.codegen_const_args()]
wrapper.writeline(
wrapper.wrap_kernel_call(
self.get_kernel_name(),
self.cpp_kernel_name
if V.graph.cpp_wrapper
else self.python_kernel_name,
args,
)
)
@@ -4458,7 +4455,9 @@ def is_not_write(arg):

self.cpp_kernel_name = kernel._schema.name
self.cpp_kernel_overload_name = kernel._schema.overload_name
self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr]
self.cpp_kernel_key = (
f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"
)

self.cpp_op_schema = get_cpp_op_schema(kernel)
self.ordered_kwargs_for_cpp_kernel = [
@@ -4496,8 +4495,7 @@ def sdpa_ver_fn():
return self.cpp_kernel_name

kernel_to_ver = {"at::_scaled_dot_product_flash_attention": sdpa_ver_fn}
ver_fn = kernel_to_ver.get(self.cpp_kernel_name, None) # type: ignore[arg-type]
if ver_fn is not None:
if (ver_fn := kernel_to_ver.get(self.cpp_kernel_name, None)) is not None:
return ver_fn()
return self.cpp_kernel_name

@@ -4722,7 +4720,7 @@ def codegen(self, wrapper):

wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
args,
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5142,7 +5140,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.cpp_kernel_name if V.graph.cpp_wrapper else self.python_kernel_name,
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5217,7 +5215,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5307,7 +5305,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5393,7 +5391,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5447,7 +5445,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5513,7 +5511,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5584,7 +5582,7 @@ def __init__(
def codegen(self, wrapper):
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
self.codegen_args(),
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -5837,7 +5835,7 @@ def codegen(self, wrapper):
)
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
codegen_args,
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -6021,7 +6019,7 @@ def codegen(self, wrapper):
)
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
conv_args,
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -6195,7 +6193,7 @@ def codegen(self, wrapper):
)
wrapper.generate_extern_kernel_alloc_and_find_schema_if_needed(
self.get_name(),
self.get_kernel_name(),
self.codegen_kernel_name(),
codegen_args,
self.cpp_op_schema,
self.cpp_kernel_key,
@@ -7227,7 +7225,9 @@ def set_cpp_kernel(self, kernel):

self.cpp_kernel_name = kernel._schema.name
self.cpp_kernel_overload_name = kernel._schema.overload_name
self.cpp_kernel_key = f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}" # type: ignore[union-attr]
self.cpp_kernel_key = (
f"{self.cpp_kernel_name.replace('::', '_')}_{self.cpp_kernel_overload_name}"
)

self.cpp_op_schema = get_cpp_op_schema(kernel)
self.ordered_kwargs_for_cpp_kernel = [
