From cf021a1dfc435590d9d5d8e0075746da01200940 Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Tue, 30 Apr 2024 13:11:24 -0500
Subject: [PATCH] [Relax][UnitTest] Validate IRModule with multiple targets

This commit adds a unit test to verify that a single `IRModule` can
contain functions that will be used on multiple distinct targets.
Previously, this test case caused errors when running the
`LegalizeOps` and `ApplyDefaultSchedule` transforms.
---
 tests/python/relax/test_vm_build.py | 59 +++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 180535231d98..ab40e181a35a 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -1246,5 +1246,64 @@ def test_set_input_get_failure_rpc(exec_mode):
     run_on_rpc(TestVMSetInput, set_input_attempt_get, exec_mode)
 
 
+@tvm.testing.requires_gpu
+def test_relax_module_with_multiple_targets(exec_mode):
+    """Relax functions may contain kernels for multiple targets
+
+    In this example, the module contains one function to execute on
+    LLVM, and one function to execute on CUDA.
+
+    """
+
+    @I.ir_module
+    class Module:
+        I.module_global_infos({"vdevice": [I.vdevice("llvm")]})
+
+        @R.function
+        def func_cuda(A: R.Tensor([32, 32], "float32"), B: R.Tensor([32, 32], "float32")):
+            C = R.add(A, B)
+            return C
+
+        @R.function
+        def func_llvm(
+            A: R.Tensor([32, 32], "float32", "llvm"), B: R.Tensor([32, 32], "float32", "llvm")
+        ):
+            C = R.add(A, B)
+            return C
+
+    seq = tvm.ir.transform.Sequential(
+        [
+            tvm.relax.transform.LegalizeOps(),
+            tvm.dlight.ApplyDefaultSchedule(tvm.dlight.gpu.Fallback()),
+        ],
+        name="LegalizeAndSchedule",
+    )
+    with tvm.target.Target("cuda"):
+        built = tvm.relax.build(seq(Module))
+
+    np_A = np.random.random([32, 32]).astype("float32")
+    np_B = np.random.random([32, 32]).astype("float32")
+
+    dev_llvm = tvm.device("llvm")
+    vm_llvm = tvm.relax.VirtualMachine(built, device=dev_llvm)
+    llvm_output = vm_llvm["func_llvm"](
+        tvm.nd.array(np_A, dev_llvm),
+        tvm.nd.array(np_B, dev_llvm),
+    )
+
+    dev_cuda = tvm.device("cuda")
+    vm_cuda = tvm.relax.VirtualMachine(built, device=dev_cuda)
+
+    cuda_output = vm_cuda["func_cuda"](
+        tvm.nd.array(np_A, dev_cuda),
+        tvm.nd.array(np_B, dev_cuda),
+    )
+
+    np_C = np_A + np_B
+
+    tvm.testing.assert_allclose(llvm_output.numpy(), np_C)
+    tvm.testing.assert_allclose(cuda_output.numpy(), np_C)
+
+
 if __name__ == "__main__":
     tvm.testing.main()