Skip to content

Commit

Permalink
[AOTInductor] ProxyExecutor supports List[Tensor] return type (pytorch#110182)
Browse files Browse the repository at this point in the history

Summary:
Pull Request resolved: pytorch#110182

Support custom ops that return a List[Tensor] type, like `"fn_with_list_output(Tensor[] tensors, int i) -> Tensor[]"`

As an example
`out5, out6 = torch.ops.fb.fn_with_list_output([out3, out4], 1)`

got compiled into

```
    AtenTensorHandle buf8_handle;  // output buffer
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf8_handle));
    RAIIAtenTensorHandle buf8(buf8_handle);
    AtenTensorHandle buf9_handle;  // output buffer
    AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf9_handle));
    RAIIAtenTensorHandle buf9(buf9_handle);
    AtenTensorHandle tensor_args_var_5[] = {buf5.get(), buf6.get(), buf8.get(), buf9.get()};
    int64_t int_args_var_6[] = {1};
    aoti_torch_proxy_executor_call_function(proxy_executor, 2, 1, int_args_var_6, 4, tensor_args_var_5);
```

Test Plan: Test

Differential Revision: D49694691

fbshipit-source-id: 4be9fe4c4786f7099710e8cbe4ce01cd5a3d70b8
  • Loading branch information
SherlockNoMad authored and facebook-github-bot committed Sep 28, 2023
1 parent eaf27cb commit b44b4cd
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 21 deletions.
44 changes: 24 additions & 20 deletions torch/_inductor/codegen/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,15 @@ def convert_arg_type(python_type):


def convert_return_type(python_type):
# TODO: only support Tensor as func return type for now
# TODO: support alias
assert (
python_type == "Tensor"
), f"only support tensor output for cpp_wrapper, but receive type {python_type}"
return f"at::{python_type}"
python_to_cpp = {
"Tensor": "at::Tensor",
"List[Tensor]": "std::vector<at::Tensor>",
}

cpp_type = python_to_cpp.get(python_type, None)
assert cpp_type is not None, f"NYI return type: {python_type}"
return cpp_type


def get_cpp_op_schema(kernel):
Expand Down Expand Up @@ -1742,30 +1745,31 @@ def fill_output_arg(arg, return_type):
)
self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
new_tensor_args.append(f"{arg}.get()")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.TensorType
):
# TODO: handle tensor list return type
raise NotImplementedError("NYI support for return type: List[Tensor]")
elif isinstance(return_type, torch.SymIntType):
raise NotImplementedError("NYI support for return type: SymInt")
elif isinstance(return_type, torch.ListType) and isinstance(
return_type.getElementType(), torch.SymIntType
):
raise NotImplementedError("NYI support for return type: List[SymInt]")
else:
raise AssertionError(f"Unsupport return type found: {return_type}")
raise AssertionError(f"Unsupported return type found: {return_type}")

# TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
for return_type in return_types:
if isinstance(return_type, (torch.TensorType)):
pass
elif isinstance(return_type, torch.OptionalType):
assert isinstance(return_type.getElementType(), torch.TensorType)
elif isinstance(return_type, torch.ListType):
assert isinstance(return_type.getElementType(), torch.TensorType)
else:
raise NotImplementedError(
f"return type {return_type} is not yet supported."
)

for output_arg, return_type in zip(output_args, return_types):
for output_arg in output_args:
if output_arg is not None:
if isinstance(return_type, torch.OptionalType):
fill_output_arg(output_arg, return_type.getElementType())
elif isinstance(return_type, torch.TensorType):
fill_output_arg(output_arg, return_type)
else:
raise NotImplementedError(
"Only Tensor and OptionalTensor return type is supported."
)
fill_output_arg(output_arg, torch.TensorType.get())

return new_tensor_args, new_int_args

Expand Down
13 changes: 12 additions & 1 deletion torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -3792,13 +3792,24 @@ def export_extern_kernel_node(self):
named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)

# serialize_outputs
if isinstance(self.outputs, (list, tuple)):
if isinstance(self.outputs, tuple):
# For tuple returns, e.g "-> (Tensor, Tensor)"
output_arguments = [
export_schema.Argument.create(
as_tensor=export_schema.TensorArgument(name=output.get_name())
)
for output in self.outputs
]
elif isinstance(self.outputs, list):
# For list of tensor, e.g. "-> List[Tensor]"
output_arguments = [
export_schema.Argument.create(
as_tensors=[
export_schema.TensorArgument(name=output.get_name())
for output in self.outputs
]
)
]
else:
output_arguments = [
export_schema.Argument.create(
Expand Down

0 comments on commit b44b4cd

Please sign in to comment.