[Relax][UnitTest] Validate IRModule with multiple targets
This commit adds a unit test to verify that a single `IRModule` can
contain functions that will be used on multiple distinct targets.
Previously, this test case caused errors when running the
`LegalizeOps` and `ApplyDefaultSchedule` transforms.
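For illustration (a minimal sketch, not part of the commit, assuming the `Module` defined in the test below and the `tvm`/`tvm.dlight` imports already present in `test_vm_build.py`), the previously-failing path can be exercised by running the two passes directly:

    # Before this fix, running these passes under a "cuda" Target raised
    # errors for this mixed-target module.
    with tvm.target.Target("cuda"):
        lowered = tvm.relax.transform.LegalizeOps()(Module)
        scheduled = tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback())(lowered)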
Lunderberg committed May 15, 2024
1 parent f044eef commit cf021a1
Showing 1 changed file with 59 additions and 0 deletions.

tests/python/relax/test_vm_build.py (+59 −0)

@@ -1246,5 +1246,64 @@ def test_set_input_get_failure_rpc(exec_mode):
    run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode)


@tvm.testing.requires_gpu
def test_relax_module_with_multiple_targets(exec_mode):
    """Relax functions may contain kernels for multiple targets

    In this example, the module contains one function to execute on
    LLVM, and one function to execute on CUDA.
    """

    @I.ir_module
    class Module:
        I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

        @R.function
        def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
            C = R.add(A, B)
            return C

        @R.function
        def func_llvm(
            A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
        ):
            C = R.add(A, B)
            return C

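    # func_llvm's parameters carry the "llvm" vdevice registered in the
    # module's global_infos and therefore stay on the CPU; func_cuda's
    # unannotated parameters follow the ambient target ("cuda" below).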
    seq = tvm.ir.transform.Sequential(
        [
            tvm.relax.transform.LegalizeOps(),
            tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
        ],
        name="LegalizeAndSchedule",
    )
    with tvm.target.Target("cuda"):
        built = tvm.relax.build(seq(Module))

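    # One executable holds both the CUDA kernels and the LLVM (CPU) code;
    # each VirtualMachine below binds it to a different device.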
    np_A = np.random.random([32, 32]).astype("float32")
    np_B = np.random.random([32, 32]).astype("float32")

    dev_llvm = tvm.device("llvm")
    vm_llvm = tvm.relax.VirtualMachine(built, device=dev_llvm)
    llvm_output = vm_llvm["func_llvm"](
        tvm.nd.array(np_A, dev_llvm),
        tvm.nd.array(np_B, dev_llvm),
    )

    dev_cuda = tvm.device("cuda")
    vm_cuda = tvm.relax.VirtualMachine(built, device=dev_cuda)

    cuda_output = vm_cuda["func_cuda"](
        tvm.nd.array(np_A, dev_cuda),
        tvm.nd.array(np_B, dev_cuda),
    )

    np_C = np_A + np_B

    tvm.testing.assert_allclose(llvm_output.numpy(), np_C)
    tvm.testing.assert_allclose(cuda_output.numpy(), np_C)


if __name__ == "__main__":
    tvm.testing.main()
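Usage note: because the test is decorated with `@tvm.testing.requires_gpu`, pytest skips it automatically on hosts without a GPU or without a CUDA-enabled build of TVM.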
