[BUG] [Relay] PlanDevices Pass Failure when Two Operators with Different Target Devices Share the Same Input #15019

@lecoan

Description

While trying to use Apache TVM for heterogeneous execution, I've encountered a problem with the PlanDevices pass. Specifically, when two operators share the same input but are assigned to different target devices, the PlanDevices pass fails.

As an example, consider the operation (a+b)-(b+c). If I assign the first add operator to the CPU and the second one to the GPU, the PlanDevices pass fails, apparently because it cannot determine the appropriate device for b.

Given that it's common for multiple layers in a neural network to process the same input, this seems to be a bug in TVM that warrants attention.

Expected behavior

There are two behaviors I would expect from the PlanDevices pass in this scenario:

  1. Automatic addition of a device_copy: if CPU is the default device, the PlanDevices pass should insert a device_copy between b and the second add operator.
  2. Input replication: The PlanDevices pass could replicate b on both CPU and GPU.
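For concreteness, behavior 1 might produce IR roughly like the following. This is a hand-written sketch, not actual PlanDevices output; device_copy's textual attribute syntax varies across TVM versions, so the annotations are elided:

```
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
          %c: Tensor[(5, 7), float32]) {
  %0 = add(%a, %b);           // planned for CPU
  %1 = device_copy(%b, ...);  // copy the shared input from CPU to GPU
  %2 = add(%1, %c);           // planned for GPU
  subtract(%0, %2)
}
```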

Actual behavior

Assuming the PlanDevices pass visits (a+b) first, it marks b for the CPU. When it subsequently visits (b+c), it throws an error while attempting to place b on the GPU.
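The shape of the failure can be illustrated with a toy model of flow-based device planning. This is plain Python, not TVM's actual implementation: it simply gives each variable exactly one device, so a shared input feeding operators on different devices triggers a conflict, mirroring the bug:

```python
# Toy model of single-assignment device planning (illustration only,
# not TVM's PlanDevices implementation).

def plan_devices(ops):
    """ops: list of (op_name, input_vars, device). Returns {var: device}."""
    assignment = {}
    for name, inputs, device in ops:
        for var in inputs:
            # First use of a variable fixes its device.
            prev = assignment.setdefault(var, device)
            if prev != device:
                # A later use on a different device has no resolution here.
                raise ValueError(
                    f"conflict: {var} already planned for {prev}, "
                    f"but {name} wants it on {device}")
    return assignment

# (a + b) on CPU marks b for CPU; (b + c) on GPU then conflicts on b.
ops = [("add1", ["a", "b"], "cpu"), ("add2", ["b", "c"], "gpu")]
```

Either of the expected behaviors above (copying or replicating b) would resolve the conflict instead of raising.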

Here is the error message:

TVMError: Function parameters and result VirtualDevices do not match those of call. Call:
free_var %b: Tensor[(5, 7), float32] ;
free_var %c: Tensor[(5, 7), float32] ;
%0 = add(%b, %c) ;
on_device(%0, virtual_device=VirtualDevice(device_type=2, virtual_device_id=0, target=Target(id=12cebdda0, kind='cuda', keys={'cuda', 'gpu'}, attrs={'max_num_threads': 1024, 'thread_warp_size': 32, 'arch': "sm_50"}, host=Target(id=12ceba810, kind='llvm', keys={'cpu'})))) 
with function virtual devices:
fn(?4828570296?VirtualDevice(device_type=2, virtual_device_id=0, target=Target(id=11fcaac00, kind='cuda', keys={'cuda', 'gpu'}, attrs={'max_num_threads': 1024, 'thread_warp_size': 32, 'arch': "sm_50"}, host=Target(id=11fcc96b0, kind='llvm', keys={'cpu'})))):?4828570408?
and implied call virtual devices:
fn(?4828403160?VirtualDevice(device_type=1, virtual_device_id=0, target=Target(id=11fcb2f80, kind='llvm', keys={'cpu'}, host=Target(id=11fcc96b0, kind='llvm', keys={'cpu'})))):?4828554904?

Environment

  • OS: Linux
  • TVM: Latest commit (4267fbf)

Steps to reproduce

Below is a minimal reproduction that plans devices for (a+b) - (b+c), assigning the first add operator to the CPU and the second to the GPU:

import tvm
from tvm import relay

HOST_DEVICE = tvm.device("cpu")
HOST_TARGET = tvm.target.Target("llvm")

CPU_DEVICE = tvm.device("cpu")
CPU_TARGET = tvm.target.Target("llvm").with_host(HOST_TARGET)

GPU_DEVICE = tvm.device("cuda")
GPU_TARGET = tvm.target.Target("cuda").with_host(HOST_TARGET)
CPU = tvm.target.VirtualDevice(CPU_DEVICE, CPU_TARGET)  # device_type=1
GPU = tvm.target.VirtualDevice(GPU_DEVICE, GPU_TARGET)  # device_type=2

metatable = {"VirtualDevice": [CPU, GPU]}

mod = tvm.relay.parse(
        """
        #[version = "0.0.5"]
        def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
                    %c: Tensor[(5, 7), float32]) {
            %0 = add(%a, %b);
            %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
            %2 = add(%b, %c);
            %3 = on_device(%2, virtual_device=meta[VirtualDevice][1]);
            subtract(%1, %3)
        }
        """,
        "from_string",
        None,
        metatable,
    )

DEFAULT = GPU
CTXT = tvm.transform.PassContext(config={"relay.fallback_device_type": DEFAULT.device_type_int})
TARGETS = [CPU_TARGET, GPU_TARGET]

config = tvm.target.make_compilation_config(CTXT, TARGETS)
mod = relay.transform.InferType()(mod)
mod = relay.transform.PlanDevices(config)(mod)  # raises the TVMError above
mod = relay.transform.InferType()(mod)

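A possible manual workaround, corresponding to expected behavior 2, is to replicate the shared input as two distinct parameters and feed the same tensor to both at runtime. This is a sketch, not verified against commit 4267fbf:

```
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b1: Tensor[(5, 7), float32],
          %b2: Tensor[(5, 7), float32], %c: Tensor[(5, 7), float32]) {
    %0 = add(%a, %b1);
    %1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
    %2 = add(%b2, %c);
    %3 = on_device(%2, virtual_device=meta[VirtualDevice][1]);
    subtract(%1, %3)
}
```

Each parameter now has a single consistent device, so PlanDevices should no longer hit the conflict, at the cost of duplicating the input.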
Triage

  • needs-triage

cc @shingjan
