diff --git a/torch/_inductor/codegen/wrapper.py b/torch/_inductor/codegen/wrapper.py
index b3e1a19154c1a..d2f24b5274352 100644
--- a/torch/_inductor/codegen/wrapper.py
+++ b/torch/_inductor/codegen/wrapper.py
@@ -84,12 +84,15 @@ def convert_arg_type(python_type):
 
 
 def convert_return_type(python_type):
-    # TODO: only support Tensor as func return type for now
     # TODO: support alias
-    assert (
-        python_type == "Tensor"
-    ), f"only support tensor output for cpp_wrapper, but receive type {python_type}"
-    return f"at::{python_type}"
+    python_to_cpp = {
+        "Tensor": "at::Tensor",
+        "List[Tensor]": "std::vector<at::Tensor>",
+    }
+
+    cpp_type = python_to_cpp.get(python_type, None)
+    assert cpp_type is not None, f"NYI return type: {python_type}"
+    return cpp_type
 
 
 def get_cpp_op_schema(kernel):
@@ -1742,11 +1745,6 @@ def fill_output_arg(arg, return_type):
                 )
                 self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
                 new_tensor_args.append(f"{arg}.get()")
-            elif isinstance(return_type, torch.ListType) and isinstance(
-                return_type.getElementType(), torch.TensorType
-            ):
-                # TODO: handle tensor list return type
-                raise NotImplementedError("NYI support for return type: List[Tensor]")
             elif isinstance(return_type, torch.SymIntType):
                 raise NotImplementedError("NYI support for return type: SymInt")
             elif isinstance(return_type, torch.ListType) and isinstance(
@@ -1754,18 +1752,24 @@ def fill_output_arg(arg, return_type):
             ):
                 raise NotImplementedError("NYI support for return type: List[SymInt]")
             else:
-                raise AssertionError(f"Unsupport return type found: {return_type}")
+                raise AssertionError(f"Unsupported return type found: {return_type}")
+
+        # TODO: Only support tensor(s) returns for now, SymInt is not implemented yet
+        for return_type in return_types:
+            if isinstance(return_type, (torch.TensorType)):
+                pass
+            elif isinstance(return_type, torch.OptionalType):
+                assert isinstance(return_type.getElementType(), torch.TensorType)
+            elif isinstance(return_type, torch.ListType):
+                assert isinstance(return_type.getElementType(), torch.TensorType)
+            else:
+                raise NotImplementedError(
+                    f"return type {return_type} is not yet supported."
+                )
 
-        for output_arg, return_type in zip(output_args, return_types):
+        for output_arg in output_args:
             if output_arg is not None:
-                if isinstance(return_type, torch.OptionalType):
-                    fill_output_arg(output_arg, return_type.getElementType())
-                elif isinstance(return_type, torch.TensorType):
-                    fill_output_arg(output_arg, return_type)
-                else:
-                    raise NotImplementedError(
-                        "Only Tensor and OptionalTensor return type is supported."
-                    )
+                fill_output_arg(output_arg, torch.TensorType.get())
 
         return new_tensor_args, new_int_args
diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py
index 61d5f4b32f32d..523b7b1d02efe 100644
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -3792,13 +3792,24 @@ def export_extern_kernel_node(self):
         named_arguments = serializer.serialize_inputs(self.op_overload, args, kwargs)
 
         # serialize_outputs
-        if isinstance(self.outputs, (list, tuple)):
+        if isinstance(self.outputs, tuple):
+            # For tuple returns, e.g "-> (Tensor, Tensor)"
             output_arguments = [
                 export_schema.Argument.create(
                     as_tensor=export_schema.TensorArgument(name=output.get_name())
                 )
                 for output in self.outputs
             ]
+        elif isinstance(self.outputs, list):
+            # For list of tensor, e.g. "-> List[Tensor]"
+            output_arguments = [
+                export_schema.Argument.create(
+                    as_tensors=[
+                        export_schema.TensorArgument(name=output.get_name())
+                        for output in self.outputs
+                    ]
+                )
+            ]
         else:
             output_arguments = [
                 export_schema.Argument.create(