[AOTI][refactor] Add aoti_torch_item as a util function (pytorch#126352)
Summary: This logic is repeated several times in the code, so it is worth factoring it out into a common util function.
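
For illustration, here is a minimal standalone sketch (not part of the commit) of the pattern being factored out: the generated C++ declares a scalar of the matching C++ type, then calls the dtype-suffixed aoti_torch_item_* shim to read the value out of the tensor. The buffer names buf0/buf0_scalar are hypothetical, and only a two-entry slice of the real DTYPE_TO_CPP mapping is assumed here.

    import torch

    # Assumed two-entry slice of the DTYPE_TO_CPP mapping in cpp_wrapper_cpu.py.
    DTYPE_TO_CPP = {torch.float32: "float", torch.bool: "bool"}

    def emit_tensor_item(dtype, tensor, scalar):
        # Mirrors the new codegen_tensor_item helper: declare the C++ scalar,
        # then extract its value through the ABI shim function.
        dtype_str = str(dtype).split(".")[-1]  # torch.float32 -> "float32"
        return [
            f"{DTYPE_TO_CPP[dtype]} {scalar};",
            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));",
        ]

    print("\n".join(emit_tensor_item(torch.float32, "buf0", "buf0_scalar")))
    # float buf0_scalar;
    # AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_float32(buf0, &buf0_scalar));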

Pull Request resolved: pytorch#126352
Approved by: https://github.com/chenyang78
ghstack dependencies: pytorch#126181, pytorch#126182, pytorch#126183
desertfire authored and ZelboK committed May 19, 2024
1 parent d27e21d commit 272b119
Showing 1 changed file with 24 additions and 18 deletions.
torch/_inductor/codegen/cpp_wrapper_cpu.py
@@ -552,10 +552,8 @@ def write_wrapper_decl(self):
                     ), "Fails to get the dtype of the sympy.Expr"
                     cpp_dtype = DTYPE_TO_CPP[dtype]
                     if config.abi_compatible:
-                        self.prefix.writeline(f"{cpp_dtype} {input_key};")
-                        dtype_str = str(dtype).split(".")[-1]
-                        self.prefix.writeline(
-                            f"aoti_torch_item_{dtype_str}(inputs[{idx}], &{input_key});"
-                        )
+                        self.codegen_tensor_item(
+                            dtype, f"inputs[{idx}]", input_key, self.prefix
+                        )
                     else:
                         self.prefix.writeline(
@@ -890,6 +888,19 @@ def codegen_scalar_to_tensor(self, output: str):
         )
         return name
 
+    def codegen_tensor_item(
+        self, dtype: torch.dtype, tensor: str, scalar: str, indented_buffer=None
+    ):
+        assert (
+            config.abi_compatible
+        ), "codegen_tensor_item is only used for the ABI-compatible mode"
+        dtype_str = str(dtype).split(".")[-1]
+        writer = indented_buffer or self
+        writer.writeline(f"{DTYPE_TO_CPP[dtype]} {scalar};")
+        writer.writeline(
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_{dtype_str}({tensor}, &{scalar}));"
+        )
+
     @cache_on_self
     def get_output_refs(self):
         return [
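
Design note on the helper above: the optional indented_buffer argument lets one function serve both destinations seen in this diff. write_wrapper_decl passes self.prefix so the declaration lands in the wrapper prologue, while the other call sites omit it and `writer = indented_buffer or self` falls back to self.writeline. The leading assert records that the aoti_torch_item_* shims are only emitted in ABI-compatible mode.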
@@ -1376,10 +1387,9 @@ def codegen_shape_tuple(self, shape: Tuple[Expr, ...]) -> str:
     def codegen_dynamic_scalar(self, node):
         (data,) = (t.codegen_reference() for t in node.inputs)
         if config.abi_compatible:
-            dtype = node.inputs[0].get_dtype()
-            dtype_str = str(dtype).split(".")[-1]
-            self.writeline(f"{DTYPE_TO_CPP[dtype]} {node.sym}_raw;")
-            self.writeline(f"aoti_torch_item_{dtype_str}({data}, &{node.sym}_raw);")
+            self.codegen_tensor_item(
+                node.inputs[0].get_dtype(), data, f"{node.sym}_raw"
+            )
         else:
             convert_type = DTYPE_TO_ATEN[node.inputs[0].get_dtype()].replace(
                 "at::k", "to"
@@ -1763,12 +1773,13 @@ def codegen_conditional(self, conditional):
             outer_outputs.append(out.get_name())
 
         if not isinstance(conditional.predicate, ir.ShapeAsConstantBuffer):
-            predicate = f"{conditional.predicate.get_name()}_scalar"
-            self.writeline(f"bool {predicate};")
             # in ABI-compatible mode, we need to use the ABI shim function
             # to extract a C++ bool from the unrelying scalar bool Tensor
-            self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({conditional.predicate.codegen_reference()}, &{predicate}));"
+            predicate = f"{conditional.predicate.get_name()}_scalar"
+            self.codegen_tensor_item(
+                torch.bool,
+                conditional.predicate.codegen_reference(),
+                predicate,
             )
         else:
             # the predicate is not a Tensor: SymBool or Python bool
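
As a sanity check that this hunk is behavior-preserving: for a boolean predicate, codegen_tensor_item emits exactly the two C++ lines the removed writeline calls produced (tensor name buf2 hypothetical):

    bool buf2_scalar;
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool(buf2, &buf2_scalar));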
@@ -1847,12 +1858,7 @@ def codegen_while_loop(self, while_loop):
 
         if config.abi_compatible:
             cond_result = f"{cond_result_name}_scalar"
-            self.writeline(f"bool {cond_result};")
-            # in ABI-compatible mode, we need to use the ABI shim function
-            # to extract a C++ bool from the unrelying scalar bool Tensor
-            self.writeline(
-                f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_item_bool({cond_result_name}, &{cond_result}));"
-            )
+            self.codegen_tensor_item(torch.bool, cond_result_name, cond_result)
         else:
             cond_result = f"{cond_result_name}.item<bool>()"
         self.writeline(f"if (!{cond_result}) break;")