Skip to content

Commit

Permalink
[TIR][RPC] Allow RPC calls to compiled PrimFuncs with no arguments (#…
Browse files Browse the repository at this point in the history
…17098)

The `PackedFunc` interface has arguments `int num_args` and `TVMValue*
args`, which contain the number of arguments and a pointer to the
array of arguments.  Prior to this commit, when implementing the
`PackedFunc` interface for TIR `PrimFunc`s, the `MakePackedAPI` pass
would always assert that the `args` pointer was not null.  However,
the `args` pointer is allowed to be null if `num_args` is zero.  For
example, this occurs when calling an RPC function with no arguments.

This commit updates the `MakePackedAPI` transform to only assert that
`args` is non-null when `num_args` is greater than zero.
  • Loading branch information
Lunderberg committed Jun 18, 2024
1 parent 675a023 commit f6fe2aa
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 7 deletions.
10 changes: 6 additions & 4 deletions src/tir/transforms/make_packed_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -296,10 +296,12 @@ PrimFunc MakePackedAPI(PrimFunc func) {
return error_message.str();
}()));

seq_init.push_back(
MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL"));
seq_init.push_back(
MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL"));
if (num_args > 0) {
seq_init.push_back(
MakeAssertNotNull(v_packed_args, name_hint + ": TVMValue* arg pointer was NULL"));
seq_init.push_back(
MakeAssertNotNull(buf_packed_arg_type_ids->data, name_hint + ": int* type_codes was NULL"));
}

seq_init.emplace_back(DeclBuffer(buf_packed_arg_type_ids, nop));

Expand Down
55 changes: 52 additions & 3 deletions tests/python/runtime/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,27 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import tvm.testing

import multiprocessing
import os
import stat
import sys
import tempfile
import time

import pytest
import numpy as np

import tvm
import tvm.testing

from tvm import te
from tvm import rpc
from tvm.relay.backend import Runtime
from tvm.contrib import utils, cc
from tvm.rpc.tracker import Tracker
from tvm.rpc.proxy import Proxy
from tvm.script import ir as I, tir as T


if __name__ == "__main__":
Expand Down Expand Up @@ -685,3 +690,47 @@ def test_rpc_session_timeout_error(with_proxy):
if with_proxy:
proxy.terminate()
tracker.terminate()


@pytest.mark.parametrize("call_with_unused_argument", [True, False])
def test_compiled_function_with_zero_arguments(call_with_unused_argument):
"""RPC functions do not require an argument
This is a regression test. When no arguments are provided, RPC
provides NULL as the `TVMValue* args` argument to a PackedFunc.
However, previous implementations of `MakePackedAPI`
unconditionally asserted that the `args` pointer was non-null.
This assertion is now generated only when the function accepts
a non-zero number of arguments.
"""

@I.ir_module
class Module:
@T.prim_func
def func_without_arg() -> T.int64:
return T.int64(42)

@T.prim_func
def func_with_arg(unused: T.int64) -> T.int64:
return T.int64(42)

built = tvm.build(Module, target="llvm")

server = tvm.rpc.Server(key="x1")
client = tvm.rpc.connect("127.0.0.1", server.port, key="x1")

libname = "libbuilt.so"
with tempfile.TemporaryDirectory(prefix="tvm_rpc_testing_") as temp_dir:
local_path = os.path.join(temp_dir, libname)
built.export_library(local_path)
client.upload(local_path)

remote_mod = client.load_module(libname)

if call_with_unused_argument:
res = remote_mod["func_with_arg"](0)
else:
res = remote_mod["func_without_arg"]()

assert res == 42
41 changes: 41 additions & 0 deletions tests/python/tir-transform/test_tir_transform_make_packed_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,5 +353,46 @@ def func(A: T.Buffer([16, 16], "int32"), B: T.Buffer([16, 16], "int32")):
built(A, B)


def test_zero_arg_function():
"""Only check non-null args when num_args>0"""

@I.ir_module
class Before:
@T.prim_func
def func_without_arg() -> T.int64:
T.func_attr({"target": T.target("llvm", host="llvm")})
return T.int64(42)

@I.ir_module
class Expected:
@T.prim_func
def func_without_arg(
args: T.handle,
arg_type_ids: T.handle("int32"),
num_args: T.int32,
out_ret_value: T.handle("void"),
out_ret_tcode: T.handle("int32"),
resource_handle: T.handle,
) -> T.int32:
T.func_attr(
{
"calling_conv": 1,
"target": T.target("llvm"),
}
)
assert num_args == 0, "func_without_arg: num_args should be 0"
arg_type_ids_1 = T.decl_buffer((0,), "int32", data=arg_type_ids)
with T.attr(0, "compute_scope", "func_without_arg_compute_"):
out_ret_value_1 = T.Buffer((1,), "int64", data=out_ret_value, strides=(1,))
out_ret_value_1[0] = T.Cast("int64", T.int64(42))
out_ret_tcode_1 = T.Buffer((1,), "int32", data=out_ret_tcode, strides=(1,))
out_ret_tcode_1[0] = 0
return 0
return 0

After = tvm.tir.transform.MakePackedAPI()(Before)
tvm.ir.assert_structural_equal(Expected, After)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit f6fe2aa

Please sign in to comment.