[PT2 Inference] Enable ProxyExecutor with Runtime (pytorch#109748)
Summary:

Switch ProxyExecutor to use AtenTensorHandle instead of void** for its flattened tensor arguments.

bypass-github-pytorch-ci-checks
OSS CI has an unrelated failure.

Test Plan: E2E Test

Reviewed By: yifuwang

Differential Revision: D49471659
SherlockNoMad authored and facebook-github-bot committed Sep 27, 2023
1 parent 6138750 commit b62a2a9
Showing 6 changed files with 55 additions and 15 deletions.
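
For context, AtenTensorHandle is the opaque tensor handle of the AOTInductor C shim: it lets generated code pass tensors across the C ABI without exposing at::Tensor. A simplified sketch of the types involved (not the verbatim headers; the handle typedef follows torch/csrc/inductor/aoti_torch/c/shim.h, and the RAII owner shown here is an abridged stand-in for the runtime's RAIIAtenTensorHandle):

#include <torch/csrc/inductor/aoti_torch/c/shim.h>

// shim.h declares the opaque handle along these lines:
//   struct AtenTensorOpaque;
//   using AtenTensorHandle = AtenTensorOpaque*;

// Abridged RAII owner in the spirit of the runtime's RAIIAtenTensorHandle:
// it frees the underlying at::Tensor when it goes out of scope.
class RAIIAtenTensorHandleSketch {
 public:
  explicit RAIIAtenTensorHandleSketch(AtenTensorHandle handle)
      : handle_(handle) {}
  ~RAIIAtenTensorHandleSketch() { aoti_torch_delete_tensor_object(handle_); }
  AtenTensorHandle get() const { return handle_; }

 private:
  AtenTensorHandle handle_;
};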
7 changes: 6 additions & 1 deletion torch/_export/serde/serialize.py
@@ -526,7 +526,12 @@ def is_sym_bool_arg(self, arg) -> bool:
 
     def serialize_input(self, arg) -> Argument:
         import torch._inductor.ir as inductor_ir
-        inductor_tensor_buffers = (inductor_ir.InputBuffer, inductor_ir.ComputedBuffer, inductor_ir.ConcatKernel)
+        inductor_tensor_buffers = (
+            inductor_ir.InputBuffer,
+            inductor_ir.ComputedBuffer,
+            inductor_ir.ConcatKernel,
+            inductor_ir.ExternKernelOut,
+        )
 
         if isinstance(arg, torch.fx.Node):
             if arg.op == "get_attr":
40 changes: 31 additions & 9 deletions torch/_inductor/codegen/wrapper.py
@@ -488,6 +488,7 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed(
         cpp_kernel_overload_name="",
         op_overload=None,
         raw_args=None,
+        outputs=None,
     ):
         self.writeline(f"{name} = {kernel}({', '.join(codegen_args)})")
 
@@ -836,7 +837,7 @@ def make_buffer_free(self, buffer):
         return f"del {buffer.get_name()}"
 
     def codegen_exact_buffer_reuse(self, old_name: str, new_name: str, del_line: str):
-        return f"{self.declare}{new_name} = {old_name}{del_line} {self.comment} reuse"
+        return f"{self.declare}{new_name} = {old_name}{del_line}{self.ending} {self.comment} reuse"
 
     def make_buffer_reuse(self, old, new):
         assert old.get_dtype() == new.get_dtype()
@@ -1664,10 +1665,16 @@ def fill_args(arg, arg_type):
                 torch.Type,
                 torch.DeviceObjType,
             )
+            inductor_tensor_buffers = (
+                ir.InputBuffer,
+                ir.ComputedBuffer,
+                ir.ConcatKernel,
+                ir.ExternKernelOut,
+            )
 
             if isinstance(arg_type, torch.TensorType):
-                assert isinstance(arg, (ir.InputBuffer, ir.ComputedBuffer))
-                new_tensor_args.append(f"&{arg.name}")
+                assert isinstance(arg, inductor_tensor_buffers)
+                new_tensor_args.append(f"{arg.name}.get()")
             elif isinstance(arg_type, (torch.IntType, torch.SymIntType)):
                 # int or SymInt
                 assert isinstance(arg, int)
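
The visible effect in the emitted C++ is the type of each flattened tensor argument: the old codegen passed the address of a stack-allocated at::Tensor (collected as void*), while the new codegen passes the opaque handle held by the RAII wrapper. As a schematic only (buf0 is a hypothetical buffer name):

// Old emitted form (sketch):
//   at::Tensor buf0;
//   ... pass &buf0, collected into a void*[] ...
// New emitted form (sketch):
//   RAIIAtenTensorHandle buf0(buf0_handle);
//   ... pass buf0.get(), collected into an AtenTensorHandle[] ...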
@@ -1683,14 +1690,16 @@ def fill_args(arg, arg_type):
 
                 # List[Tensor]
                 if isinstance(arg_type.getElementType(), torch.TensorType):
-                    new_tensor_args.extend([f"&{a.name}" for a in arg])
+                    new_tensor_args.extend([f"{a.name}.get()" for a in arg])
                 # List[Optional[Tensor]]
                 elif isinstance(
                     arg_type.getElementType(), torch.OptionalType
                 ) and isinstance(
                     arg_type.getElementType().getElementType(), torch.TensorType
                 ):
-                    new_tensor_args.extend([f"&{a.name}" for a in arg if a is not None])
+                    new_tensor_args.extend(
+                        [f"{a.name}.get()" for a in arg if a is not None]
+                    )
                 # List [int] or List[SymInt]
                 elif isinstance(
                     arg_type.getElementType(), (torch.IntType, torch.SymIntType)
@@ -1723,8 +1732,12 @@ def fill_args(arg, arg_type):
 
         def fill_output_arg(arg, return_type):
             if isinstance(return_type, torch.TensorType):
-                self.writeline(f"at::Tensor {arg}; // output buffer")
-                new_tensor_args.append(f"&{output_arg}")
+                self.writeline(f"AtenTensorHandle {arg}_handle; // output buffer")
+                self.writeline(
+                    f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&{arg}_handle));"
+                )
+                self.writeline(f"RAIIAtenTensorHandle {arg}({arg}_handle);")
+                new_tensor_args.append(f"{arg}.get()")
             elif isinstance(return_type, torch.ListType) and isinstance(
                 return_type.getElementType(), torch.TensorType
             ):
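
Concretely, for a single tensor output the generated wrapper code now reads roughly as follows (buf0 is a hypothetical buffer name; the three lines mirror the writeline calls above):

AtenTensorHandle buf0_handle; // output buffer
AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_new_uninitialized_tensor(&buf0_handle));
RAIIAtenTensorHandle buf0(buf0_handle);
// buf0.get() then goes into the flattened tensor-argument array.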
@@ -1763,16 +1776,19 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed(
         cpp_kernel_overload_name="",
         op_overload=None,
         raw_args=None,
+        outputs=None,
     ):
         if config.is_fbcode():
             assert op_overload is not None
             assert raw_args is not None
+            assert outputs is not None
 
             return self.generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
                 name,
                 cpp_kernel_key,
                 op_overload,
                 raw_args,
+                outputs,
             )
         else:
             return self.generate_extern_kernel_alloc_and_find_schema_if_needed_oss(
@@ -1813,8 +1829,12 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
         cpp_kernel_key,
         op_overload,
         raw_args,  # contains both args and flatten kwargs
+        outputs,
     ):
-        output_args = [name]
+        if isinstance(outputs, (list, tuple)):
+            output_args = [output.get_name() for output in outputs]
+        else:
+            output_args = [outputs.get_name()]
 
         (
             tensor_call_args,
@@ -1825,7 +1845,9 @@ def generate_extern_kernel_alloc_and_find_schema_if_needed_fbcode(
 
         tensor_args_var = f"tensor_args_var_{next(self.kernel_callsite_id)}"
         tensor_call_args_str = ", ".join(tensor_call_args)
-        self.writeline(f"void* {tensor_args_var}[] = {{{tensor_call_args_str}}};")
+        self.writeline(
+            f"AtenTensorHandle {tensor_args_var}[] = {{{tensor_call_args_str}}};"
+        )
 
         int_args_var = f"int_args_var_{next(self.kernel_callsite_id)}"
         int_call_args_str = ", ".join(int_call_args)
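
Put together, a generated call site now looks roughly like this. Only the array types and the shim entry point come from this diff; the buffer names, the integer values, and the leading proxy_executor/extern_node_index arguments (declared outside the hunks shown here) are illustrative assumptions:

AtenTensorHandle tensor_args_var_0[] = {buf0.get(), buf1.get(), buf2.get()};
int64_t int_args_var_0[] = {1, 128};
aoti_torch_proxy_executor_call_function(
    proxy_executor,
    /*extern_node_index=*/0,
    /*num_ints=*/2,
    int_args_var_0,
    /*num_tensors=*/3,
    tensor_args_var_0);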
1 change: 1 addition & 0 deletions torch/_inductor/ir.py
@@ -3829,6 +3829,7 @@ def codegen(self, wrapper):
                 self.cpp_kernel_overlad_name,
                 self.op_overload,
                 exported_args,
+                self.outputs,
             )
         else:
             super().codegen(wrapper)
7 changes: 6 additions & 1 deletion torch/csrc/inductor/aoti_torch/c/shim.h
@@ -171,6 +171,11 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
     AtenTensorHandle* ret8 // returns new reference
 );
 
+// This function will create a new uninitialized tensor object
+// and its pointer is returned through *ret.
+AOTI_TORCH_EXPORT AOTITorchError
+aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret);
+
 AOTI_TORCH_EXPORT AOTITorchError
 aoti_torch_tensor_copy_(AtenTensorHandle src, AtenTensorHandle dst);
 
@@ -214,7 +219,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_proxy_executor_call_function(
     int num_ints,
     int64_t* flatten_int_args,
     int num_tensors,
-    void** flatten_tensor_args);
+    AtenTensorHandle* flatten_tensor_args);
 
 #ifdef __cplusplus
 } // extern "C"
6 changes: 3 additions & 3 deletions torch/csrc/inductor/aoti_torch/proxy_executor.h
@@ -2,12 +2,12 @@
 
 #include <ATen/core/ivalue.h>
 #include <c10/macros/Export.h>
-#include <string>
+#include <torch/csrc/inductor/aoti_torch/c/shim.h>
 
 namespace torch {
 namespace aot_inductor {
 
-class TORCH_API ProxyExecutor : public torch::CustomClassHolder {
+class ProxyExecutor {
  public:
   ProxyExecutor() {}
   virtual ~ProxyExecutor() {}
@@ -17,7 +17,7 @@ class TORCH_API ProxyExecutor : public torch::CustomClassHolder {
       int num_ints,
       int64_t* flatten_int_args,
       int num_tensors,
-      void** flatten_tensor_args) = 0;
+      AtenTensorHandle* flatten_tensor_args) = 0;
 };
 
 } // namespace aot_inductor
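
A concrete executor subclass is expected to turn these flattened arrays back into ATen values and dispatch the serialized extern node. A hypothetical minimal implementation, assuming the portions of the declaration not shown in this hunk (a void return and a leading extern_node_index parameter) and using a reinterpret_cast for the handle-to-tensor conversion, as the shim internals do:

#include <vector>

#include <ATen/core/ivalue.h>
#include <torch/csrc/inductor/aoti_torch/proxy_executor.h>

class MyProxyExecutor : public torch::aot_inductor::ProxyExecutor {
 public:
  void call_function(
      int extern_node_index,
      int num_ints,
      int64_t* flatten_int_args,
      int num_tensors,
      AtenTensorHandle* flatten_tensor_args) override {
    std::vector<at::Tensor> tensors;
    tensors.reserve(num_tensors);
    for (int i = 0; i < num_tensors; ++i) {
      // Each AtenTensorHandle is an opaque pointer to an at::Tensor.
      tensors.push_back(*reinterpret_cast<at::Tensor*>(flatten_tensor_args[i]));
    }
    // ... look up the serialized extern node by extern_node_index, run it
    // with `tensors` and flatten_int_args, and write results back through
    // the output handles ...
  }
};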
9 changes: 8 additions & 1 deletion torch/csrc/inductor/aoti_torch/shim_common.cpp
@@ -239,6 +239,13 @@ AOTITorchError aoti_torch__scaled_dot_product_flash_attention(
   });
 }
 
+AOTITorchError aoti_torch_new_uninitialized_tensor(AtenTensorHandle* ret) {
+  AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
+    at::Tensor* out_tensor = new at::Tensor();
+    *ret = tensor_pointer_to_tensor_handle(out_tensor);
+  });
+}
+
 // TODO: implement a more efficient version instead of calling into aten
 AOTITorchError aoti_torch_tensor_copy_(
     AtenTensorHandle src,
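
Note the ownership model: aoti_torch_new_uninitialized_tensor heap-allocates an empty at::Tensor and hands ownership to the caller, which the generated wrapper manages through RAIIAtenTensorHandle and ultimately frees via aoti_torch_delete_tensor_object. For reference, the conversion helper used above is essentially a pointer cast; a sketch of what it presumably looks like (the real helper lives in the aoti_torch utility headers):

#include <ATen/core/Tensor.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

// Sketch: the opaque handle and at::Tensor* are the same pointer value.
inline AtenTensorHandle tensor_pointer_to_tensor_handle(at::Tensor* tensor) {
  return reinterpret_cast<AtenTensorHandle>(tensor);
}

inline at::Tensor* tensor_handle_to_tensor_pointer(AtenTensorHandle handle) {
  return reinterpret_cast<at::Tensor*>(handle);
}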
@@ -301,7 +308,7 @@ AOTITorchError aoti_torch_proxy_executor_call_function(
     int num_ints,
     int64_t* flatten_int_args,
     int num_tensors,
-    void** flatten_tensor_args) {
+    AtenTensorHandle* flatten_tensor_args) {
   AOTI_TORCH_CONVERT_EXCEPTION_TO_ERROR_CODE({
     ProxyExecutor* executor = reinterpret_cast<ProxyExecutor*>(proxy_executor);
     executor->call_function(
