[Unity] [Tracking Issue] Heterogeneous execution for Relax #15101

Closed
5 tasks done
yongwww opened this issue Jun 14, 2023 · 9 comments
Comments

yongwww (Member) commented Jun 14, 2023

This issue is to track progress for the Relax heterogeneous execution support proposed here.

  • P1. Add the VDevice, update TVMScript parser and printer to support it.
  • P2. Implement hint_on_device, to_vdevice, and the builtin to_device (a usage sketch follows this list).
  • P3. Update StructInfo Deduction and VM Codegen.
  • P4. Add UpdateVDevice pass.
  • P5. Develop RealizeVDevice pass.
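
For reference, a minimal TVMScript sketch of how these pieces are meant to fit together. The syntax here (I.module_global_infos, I.vdevice, the vdevice argument of R.Tensor, and R.to_vdevice) is modeled on the unity-branch multi-device tests and may differ in detail from what finally lands:

from tvm.script import ir as I, relax as R

@I.ir_module
class Example:
    # The module's vdevice list (P1) declares the available virtual devices.
    I.module_global_infos({"vdevice": [I.vdevice("cuda"), I.vdevice("llvm")]})

    @R.function
    def main(
        x: R.Tensor((2, 3), "float32", "cuda"),
        y: R.Tensor((2, 3), "float32", "cuda"),
    ) -> R.Tensor((2, 3), "float32", "llvm"):
        # Runs on the CUDA vdevice implied by the inputs (P3 deduction).
        lv0 = R.add(x, y)
        # Explicit cross-device transfer (P2).
        lv1 = R.to_vdevice(lv0, "llvm")
        # Runs on the CPU vdevice after the transfer.
        gv = R.multiply(lv1, lv1)
        return gv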

cc @quic-sanirudh

yongwww added the type:rfc-tracking (RFC progress tracking; ref: https://github.com/apache/tvm-rfcs) and needs-triage labels Jun 14, 2023
yongwww self-assigned this Jun 14, 2023
csullivan pushed a commit that referenced this issue Sep 5, 2023
This PR adds the RealizeVDevice pass, as mentioned in #15101

* [Unity] RealizeVDevice pass

* Replace hint_on_device with to_vdevice in pass

* Add sugar to to_vdevice

csullivan (Contributor) commented:

Thank you for bringing support for heterogeneous graphs @yongwww! Shall we close this as completed?

qelk123 commented Sep 20, 2023

Thank you for your contribution @yongwww! I am wondering whether there are any e2e test cases I could follow. It seems I can't specify more than one target in the relax.build API. How should I compile when the nodes in my graph sit on different devices (a CPU node and a CUDA node, for example)?
In relay.vm.compile I could pass a target dict like target={"cpu": tvm.target.Target("llvm"), "cuda": tvm.target.Target("cuda")} to support multi-device compilation. What should I do when using the Relax VM?

yongwww (Member, Author) commented Sep 22, 2023

@csullivan I will close this tracking issue in the coming days; I have a PR coming up to add e2e test cases.

yongwww (Member, Author) commented Sep 22, 2023

@qelk123 Great question! Currently relax.build doesn't support multiple targets. I have a PR to enable this e2e test case; it should be up in the coming days (hopefully by the end of this week). What we plan to do in Relax is the following (see the sketch after the list):

  • use the target information defined in the vdevice list of the ir_module's global_infos, so we don't need to feed a list/dict of targets to relax.build;
  • modify tvm.build to lower the TIR based on the target attribute of each prim_func;
  • update the Relax virtual machine to accept multiple devices.
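
A sketch of how this flow could look once the PR lands, reusing a module like the Example sketched earlier in this thread; it illustrates the intent of the three points above rather than the API as it exists today:

import tvm
from tvm import relax

# Targets come from Example's global_infos["vdevice"] list, so relax.build
# takes no target dict (first bullet above).
mod = relax.transform.RealizeVDevice()(Example)
mod = relax.transform.LegalizeOps()(mod)
# Schedule the generated TIR for the CUDA prim_funcs (second bullet).
mod = tvm.tir.transform.DefaultGPUSchedule()(mod)
ex = relax.build(mod)
# One runtime device per vdevice entry, in the same order (third bullet).
vm = relax.VirtualMachine(ex, [tvm.cuda(0), tvm.cpu(0)])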

qelk123 commented Sep 23, 2023

@yongwww Thank you very much! I am currently working on adding end-to-end multi-device support for my Relax heterogeneous graph based on this issue. At the moment, it works on some simple graphs. As you mentioned, these three aspects are also the main aspects I am focusing on during development. However, I have encountered some constraints during my experiments. One such constraint is that when I am not using the R.dataflow scope, the vdevice information does not appear to be properly propagated. Is the current design intended to support heterogeneous execution only within the dataflow scope?

yongwww (Member, Author) commented Sep 26, 2023

@qelk123 Thank you for sharing your progress and challenges on the end-to-end multi-device support. It's promising to hear that it's operational for simpler graphs. The VDevice propagation is intended to work not just for dataflow blocks but also for plain binding blocks; if the vdevice isn't propagating correctly in a binding block, that's a bug that needs fixing. I'll attempt to reproduce the problem and address it. If possible, could you share the specific test case you're working with? Additionally, don't hesitate to report any limitations you've encountered, and you're welcome to submit the patch you have to TVM unity. I'm eager to explore possible enhancements.
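
For anyone following along, a minimal sketch of the kind of case being discussed (plain bindings, no R.dataflow scope). The syntax and the expected propagation are assumptions based on the unity-branch tests, not a confirmed test case:

import tvm
from tvm import relax
from tvm.script import ir as I, relax as R

@I.ir_module
class Before:
    I.module_global_infos({"vdevice": [I.vdevice("llvm")]})

    @R.function
    def main(x: R.Tensor((2, 3), "float32")):
        # Plain binding block, outside any R.dataflow() scope.
        y = R.hint_on_device(x, tvm.cpu())
        z = R.add(y, y)
        return z

# After RealizeVDevice, z's struct info should carry the "llvm" vdevice;
# if it does not, that is the propagation issue described above.
After = relax.transform.RealizeVDevice()(Before)
After.show()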

yongwww (Member, Author) commented Sep 26, 2023

The e2e multi-device test cases were added in #15823. I will close this issue once it lands.

liquanfeng (Contributor) commented Oct 25, 2023

Very useful work! I tried it on tests/python/relax/test_codegen_cudnn.py::test_conv2d_offload, following tests/python/relax/test_vm_multi_device.py::test_multi_device, as shown below:

import numpy as np

import tvm
import tvm.testing
import tvm.topi.testing
from tvm import relax
from tvm.relax.backend.contrib.cudnn import partition_for_cudnn
from tvm.script import relax as R, ir as I

from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder

data_shape, weight_shape, dtype = (
    (16, 32, 32, 16),
    (32, 3, 3, 16),
    "float32",
)

input_np = np.random.randn(*data_shape).astype(dtype)
weight_np = np.random.randn(*weight_shape).astype(dtype)

oc = weight_shape[0]
bias_np = np.random.randn(1, 1, 1, oc).astype(dtype)
args = (input_np, weight_np, bias_np)

with IRBuilder() as builder:
    with relax_builder.function():
        R.func_name("main")
        data = R.arg("data", R.Tensor(data_shape, dtype))
        weight = R.arg("weight", R.Tensor(weight_shape, dtype))
        bias = R.arg("bias", R.Tensor((1, 1, 1, weight_shape[0]), dtype))

        with R.dataflow() as frame:
            output = R.emit(
                R.nn.conv2d(
                    data,
                    weight,
                    out_dtype=dtype,
                    padding=(1, 1),
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                )
            )
            output = R.emit(output + bias)

            output = R.emit(relax.op.to_vdevice(output, I.vdevice("llvm")))
            output = R.emit(R.multiply(output, R.const(2, "float32")))
            R.output(output)

        R.func_ret_value(frame.output_vars[0])

func = builder.get()
mod = tvm.IRModule(
    {"main": func},
    global_infos={
        "vdevice": [
            I.vdevice("cuda", 0),
            I.vdevice("llvm"),
        ]
    },
)

mod = partition_for_cudnn(mod)
mod = relax.transform.RunCodegen()(mod)

devs = [tvm.device("cuda", 0), tvm.device("llvm")]
mod = relax.transform.RealizeVDevice()(mod)
mod = relax.transform.LegalizeOps()(mod)
mod = tvm.tir.transform.DefaultGPUSchedule()(mod)

with tvm.transform.PassContext(config={"relax.backend.use_cuda_graph": False}):
    ex = relax.build(mod)
vm = relax.VirtualMachine(ex, devs)
f = vm["main"]
inputs = [tvm.nd.array(inp, tvm.device("cuda", 0)) for inp in args]

print(f(*inputs).numpy())

but it raises the following error:

Traceback (most recent call last):
  File "/workspace/yongwww/tvm/tests/byoc.py", line 77, in <module>
    ex = relax.build(mod)
  File "/workspace/yongwww/tvm/python/tvm/relax/vm_build.py", line 334, in build
    new_mod = lowering_passes(mod)
  File "/workspace/yongwww/tvm/python/tvm/ir/transform.py", line 238, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/workspace/yongwww/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 239, in __call__
    raise_last_ffi_error()
  File "/workspace/yongwww/tvm/python/tvm/_ffi/base.py", line 476, in raise_last_ffi_error
    raise py_err
tvm._ffi.base.TVMError: Traceback (most recent call last):
  24: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::TVMRetValue)
  23: tvm::transform::Pass::operator()(tvm::IRModule) const
  22: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  18: _ZN3tvm7runtime13PackedFuncObj
  17: tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1}>(tvm::relax::transform::CallTIRRewrite()::{lambda(tvm::IRModule, tvm::transform::PassContext)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}::operator()(tvm::runtime::TVMArgs const, tvm::runtime::TVMRetValue) const
  16: tvm::relax::CallTIRMutator::Run()
  15: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  14: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  13: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::FunctionNode const*)
  12: tvm::relax::ExprMutator::VisitWithNewScope(tvm::RelayExpr const&, tvm::runtime::Optional<tvm::runtime::Array<tvm::relax::Var, void> >)
  11: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  10: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  9: tvm::relax::ExprMutator::VisitExpr_(tvm::relax::SeqExprNode const*)
  8: tvm::relax::ExprMutator::VisitBindingBlock(tvm::relax::BindingBlock const&)
  7: tvm::relax::ExprMutator::VisitBindingBlock_(tvm::relax::BindingBlockNode const*)
  6: tvm::relax::ExprMutator::VisitBinding(tvm::relax::Binding const&)
  5: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*)
  4: tvm::relax::ExprMutator::VisitBinding_(tvm::relax::VarBindingNode const*, tvm::relax::CallNode const*)
  3: tvm::relax::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relax11ExprFunctorIFNS_9RelayExprERKS2_EE10InitVTableEvENUlRKNS_7r
  1: tvm::relax::CallTIRMutator::VisitExpr_(tvm::relax::CallNode const*)
  0: tvm::relax::GetDeviceIndex(tvm::IRModule const&, tvm::VDevice const&)
  File "/workspace/yongwww/tvm/src/relax/transform/utils.h", line 384
TVMError: The vdevice is not in the ir_module.

Is there a problem with BYOC, or am I missing something?

yongwww (Member, Author) commented Oct 31, 2023

@liquanfeng thanks for reporting this! As the error shows, the vdevice is not defined in the global_infos of the IRModule. The reason is that R.emit(relax.op.to_vdevice(output, I.vdevice("llvm"))) creates a new vdevice with I.vdevice(), while global_infos={"vdevice": [I.vdevice("cuda", 0), I.vdevice("llvm")]} creates other, distinct vdevice objects, so the vdevice referenced by to_vdevice is never found in the module. To reuse the same vdevice objects, define the vdevice list before the builder and reference it in both places. Updating your sample as below should let you move forward.


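# Define the VDevice objects once and reuse them both in to_vdevice and in
# global_infos, so the vdevice lookup in the IRModule succeeds.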
vdevice = [I.vdevice("cuda", 0), I.vdevice("llvm")]
with IRBuilder() as builder:
    with relax_builder.function():
        R.func_name("main")
        data = R.arg("data", R.Tensor(data_shape, dtype))
        weight = R.arg("weight", R.Tensor(weight_shape, dtype))
        bias = R.arg("bias", R.Tensor((1, 1, 1, weight_shape[0]), dtype))

        with R.dataflow() as frame:
            output = R.emit(
                R.nn.conv2d(
                    data,
                    weight,
                    out_dtype=dtype,
                    padding=(1, 1),
                    data_layout="NHWC",
                    kernel_layout="OHWI",
                )
            )
            output = R.emit(output + bias)

            output = R.emit(relax.op.to_vdevice(output, vdevice[1]))
            output = R.emit(R.multiply(output, R.const(2, "float32")))
            R.output(output)

        R.func_ret_value(frame.output_vars[0])

func = builder.get()
mod = tvm.IRModule(
    {"main": func},
    global_infos={"vdevice": vdevice},
)

yongwww closed this as completed Dec 20, 2023