
[Bug] relax onnx frontend can not support dynamic slice #17535

@irasin

Description

Thanks for participating in the TVM community! We use https://discuss.tvm.ai for any general usage questions and discussions. The issue tracker is used for actionable items such as feature proposals discussion, roadmaps, and bug tracking. You are always welcomed to post on the forum first 😸

Issues that are inactive for a period of time may get closed. We adopt this policy so that we won't lose track of actionable issues that may fall at the bottom of the pile. Feel free to reopen a new one if you feel there is an additional problem that needs attention when an old one gets closed.

Expected behavior

I have an ONNX model that contains only a single Slice op, as shown below:
(screenshot: ONNX model graph with a single Slice node)

I tested it with onnxruntime using the example data from https://onnx.ai/onnx/operators/onnx__Slice.html:
(screenshot: onnxruntime result)
The answer is correct.
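For reference, the expected result of that example can be reproduced in plain NumPy (this just mirrors the documented ONNX Slice semantics; it is not TVM code):

```python
import numpy as np

# Inputs taken from the Slice example at https://onnx.ai/onnx/operators/onnx__Slice.html
data = np.array([[1, 2, 3, 4],
                 [5, 6, 7, 8]], dtype="float32")
starts, ends, axes, steps = [1, 0], [2, 3], [0, 1], [1, 2]

# Build one slice object per axis, then index, mirroring ONNX Slice semantics.
slices = [slice(None)] * data.ndim
for start, end, axis, step in zip(starts, ends, axes, steps):
    slices[axis] = slice(start, end, step)

print(data[tuple(slices)])  # [[5. 7.]]
```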

I expected relax to be able to run this model too, but it failed.

When I build it with the relax ONNX frontend, it reports:

    raise ValueError("Only constant Slice parameters are currently supported.")
ValueError: Only constant Slice parameters are currently supported.

which is raised in onnx_frontend.py:

class Slice(OnnxOpConverter):
    """Converts an onnx Splice node into an equivalent Relax expression."""

    @classmethod
    def _impl_v13(cls, bb, inputs, attr, params):
        # TODO (jwfromm) currently only supports constant parameters.
        data = inputs[0]
        starts = get_constant(inputs[1], params)
        ends = get_constant(inputs[2], params)
        axes = get_constant(inputs[3], params)
        steps = get_constant(inputs[4], params)
        if not all(
            [
                (
                    isinstance(param, (relax.Constant, relax.ShapeExpr, relax.PrimValue))
                    or param is None
                )
                for param in [starts, ends, axes, steps]
            ]
        ):
            raise ValueError("Only constant Slice parameters are currently supported.")

However, I found that a relax op called dynamic_strided_slice already exists in https://github.com/apache/tvm/blob/main/python/tvm/relax/op/index.py#L102

Why can't it be used in the ONNX frontend? And if it can, how should the ONNX Slice op converter be modified?
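For concreteness, here is a rough, untested sketch of the kind of fallback I have in mind. `get_constant` and the constant-parameter check are from the existing converter; `relax.op.dynamic_strided_slice(data, begin, end, strides)` is the op linked above (its exact signature should be double-checked against main), and the axes handling below is deliberately simplified:

```python
# Untested sketch, not a working patch: fall back to dynamic_strided_slice
# when the Slice parameters are not compile-time constants.
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
    data = inputs[0]
    starts = get_constant(inputs[1], params)
    ends = get_constant(inputs[2], params)
    axes = get_constant(inputs[3], params)
    steps = get_constant(inputs[4], params)
    constant_like = (relax.Constant, relax.ShapeExpr, relax.PrimValue)
    if all(p is None or isinstance(p, constant_like) for p in (starts, ends, axes, steps)):
        ...  # existing constant-parameter lowering stays unchanged
    # Dynamic fallback: assumes axes == [0, 1, ..., rank-1] (full rank, in order);
    # a real patch would need to normalize axes and fill in default steps first.
    return relax.op.dynamic_strided_slice(data, inputs[1], inputs[2], inputs[4])
```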

Thanks a lot.

Environment

I used the latest main branch of TVM, built from source, on my Ubuntu server.

Steps to reproduce

Here is a minimal Python script that builds the ONNX model and tests it on relax:

from typing import Dict, Optional

import numpy as np
import onnx
from onnx import ModelProto, TensorProto, helper
import onnxruntime

import tvm
import tvm.testing
from tvm import relax
from tvm.relax.frontend.onnx import from_onnx


bg = np.random.MT19937(0)
rg = np.random.Generator(bg)


def generate_random_value(shape, elem_type) -> np.ndarray:

    # Extract datatype for the input.
    if elem_type:
        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
    else:
        dtype = "float32"

    # Generate random inputs for each input.
    if dtype == "bool":
        # random_value = np.random.choice(a=[False, True], size=shape)
        random_value = rg.choice(a=[False, True], size=shape)
    elif dtype.startswith("int"):
        # Keep non-zero values
        random_value = rg.integers(low=-63, high=63, size=shape).astype(dtype)
        random_value[random_value <= 0] -= 1
    else:
        random_value = rg.standard_normal(size=shape).astype(dtype)

    return random_value


def generate_random_inputs(
    model: ModelProto, inputs: Optional[Dict[str, np.ndarray]] = None
) -> Dict[str, np.ndarray]:
    input_values = {}
    # Iterate through model inputs and extract their shape.
    for i in model.graph.input:
        if inputs is not None and i.name in inputs and inputs[i.name] is not None:
            input_values[i.name] = inputs[i.name]
            continue
        shape = []
        for dim in i.type.tensor_type.shape.dim:
            shape.append(dim.dim_value)

        input_values[i.name] = generate_random_value(shape, i.type.tensor_type.elem_type)

    return input_values



def check_correctness(
    model: ModelProto,
    inputs: Optional[Dict[str, np.ndarray]] = None,
    ir_version: int = 8,
    opset: int = 14,
    rtol: float = 1e-7,
    atol: float = 1e-5,
) -> None:
    """Run an onnx model in both onnxruntime and TVM through our importer
       confirm that the results match. Otherwise, an exception will be raised.

    Parameters
    ----------
    model: ModelProto
        The input onnx model that should be tested.
    inputs: Optional[Dict[str, np.ndarray]]
        An optional dictionary containing values for each input in the onnx model.
    ir_version: int
        Which version of the onnx IR to use.
    opset: int
        The opset version to use for the onnx importer.
    rtol: float
        Relative tolerance for correctness checking.
    atol: float
        Absolute tolerance for correctness checking. Some ops may show more
        arithmetic variance than others.
    """
    # Configure model format.
    if ir_version is not None:
        model.ir_version = ir_version
    if opset is not None:
        model.opset_import[0].version = opset

    # If inputs are not provided, extract them from the onnx graph and produce random
    # values that we'll use for testing.
    inputs = generate_random_inputs(model, inputs)

    # Run the model through onnx to get the expected result.
    ort_session = onnxruntime.InferenceSession(
        model.SerializeToString(), providers=["CPUExecutionProvider"]
    )
    ort_output = ort_session.run([], inputs)

    # Convert the onnx model into relax through the onnx importer.
    tvm_model = from_onnx(model, opset=opset, keep_params_in_input=True)
    # Convert operators for inference mode.
    tvm_model = relax.transform.DecomposeOpsForInference()(tvm_model)
    # Legalize any relax ops into tensorir.
    tvm_model = relax.transform.LegalizeOps()(tvm_model)

    # Separate model from parameters.
    tvm_model, params = relax.frontend.detach_params(tvm_model)
    # Compile the relax graph into a VM then run.

    target = "cuda"
    dev = tvm.device(target, 0)

    with tvm.target.Target(target):
        tvm_model = tvm.tir.transform.DefaultGPUSchedule()(tvm_model)

    with tvm.transform.PassContext(opt_level=3):
        ex = relax.build(tvm_model, target=target)
        vm = relax.VirtualMachine(ex, dev)


    # Prepare inputs.
    input_list = [
        inputs[key.name_hint] for key in tvm_model["main"].params if key.name_hint in inputs
    ]
    if params:
        input_list += params["main"]

    # Run model and check outputs.
    vm.set_input("main", *input_list)
    vm.invoke_stateful("main")
    tvm_output = vm.get_outputs("main")
    # Wrap as a list if there is only one output.
    if len(ort_output) == 1:
        # Do not check the output number for TVM
        # As for sequence output, the TVM output is a Tuple
        # while the ONNX output number is one, which is a list
        tvm_output = [tvm_output]

    def _check_output(tvm_out, ort_out):
        if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
            assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
            for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
                _check_output(tvm_out_i, ort_out_i)
        elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
        elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
            shape_out = tvm.nd.array([int(i) for i in tvm_out])
            tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
        elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
            tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
        else:
            raise ValueError(f"Unsupported types: {type(tvm_out)}, {type(ort_out)}")

    # Check that number of outputs match.
    assert len(tvm_output) == len(ort_output), "Unequal number of outputs"
    for tvm_out, ort_out in zip(tvm_output, ort_output):
        # TODO Allow configurable tolerance.
        if ort_out is not None:
            _check_output(tvm_out, ort_out)


if __name__ == "__main__":


    data_shape = ["A", "B"]
    output_shape = ["C", "D"]
    starts_shape = [2]
    ends_shape = [2]
    axes_shape = [2]
    steps_shape = [2]

    slice_node = helper.make_node("Slice", inputs=["data", "starts", "ends", "axes", "steps"], outputs=["output"])

    graph = helper.make_graph(
        [slice_node],
        "slice_test",
        inputs=[
            helper.make_tensor_value_info("data", TensorProto.FLOAT, data_shape),
            helper.make_tensor_value_info("starts", TensorProto.INT64, starts_shape),
            helper.make_tensor_value_info("ends", TensorProto.INT64, ends_shape),
            helper.make_tensor_value_info("axes", TensorProto.INT64, axes_shape),
            helper.make_tensor_value_info("steps", TensorProto.INT64, steps_shape),
        ],
        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, output_shape)],

    )

    model = helper.make_model(graph, producer_name="slice_test")
    model.opset_import[0].version = 11
    onnx.checker.check_model(model)
    # onnx.save(model, "slice.onnx")

    # data from https://onnx.ai/onnx/operators/onnx__Slice.html
    inputs = {
        "data": np.array([
            [1, 2, 3, 4],
            [5, 6, 7, 8]]
        ).astype("float32"),
        "starts": np.array([1, 0], "int64"),
        "ends": np.array([2, 3], "int64"),
        "axes": np.array([0, 1], "int64"),
        "steps": np.array([1, 2], "int64"),   
    }
    expected_output = np.array([[5, 7]]).astype("float32")


    import onnxruntime as ort
    session = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
    ort_output = session.run([], inputs)
    print(ort_output[0])
    print(expected_output)


    check_correctness(model, inputs)

Triage

Please refer to the list of label tags here to find the relevant tags and add them below in a bullet format (example below).

  • needs-triage
