Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,31 @@ def _impl_v1(cls, inputs, attr, params):
output = output[0]
return output

@classmethod
def _impl_v13(cls, inputs, attr, params):
splits = inputs[1]
splits_rank = None
if splits is not None:
splits_rank = len(infer_shape(splits))
if splits is not None and splits_rank > 0:
if isinstance(splits, _expr.Constant):
splits = splits.data.asnumpy()
indices = []
index = 0
for i in splits[:-1]:
index += i
indices.append(index)
else:
raise ValueError("Dynamic Split not yet supported")
# When splits isnt specified divide evenly over axis.
else:
indices = attr["tvm_custom"]["num_outputs"]
output = _op.split(inputs[0], indices, attr.get("axis", 0))
# If the output of split is a single value, unpack if from the TupleWrapper
if len(output) == 1:
output = output[0]
return output


class Slice(OnnxOpConverter):
"""Operator converter for Slice."""
Expand Down
18 changes: 15 additions & 3 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def get_tvm_output_with_vm(
if not isinstance(input_data, list):
input_data = [input_data]
_, shape_dict = get_input_data_shape_dict(graph_def, input_data)

mod, params = relay.frontend.from_onnx(
graph_def,
shape_dict,
Expand Down Expand Up @@ -167,7 +166,6 @@ def verify_with_ort_with_inputs(
model.opset_import[0].version = opset

ort_out = get_onnxruntime_output(model, inputs)

if use_vm:
tvm_out = get_tvm_output_with_vm(
model,
Expand Down Expand Up @@ -1954,7 +1952,9 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
inputs.append(
helper.make_tensor_value_info("split", TensorProto.INT64, list(np_split.shape))
)
indata = [indata, np_split]
# TODO(mbrookhart): Support dynamic split, edit this test case to remove split from
# the initializer and add it back to the input data
indata = [indata] # , np_split]
initializer.append(
helper.make_tensor("split", TensorProto.INT64, list(np_split.shape), np_split)
)
Expand Down Expand Up @@ -1989,6 +1989,8 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
opset=opset,
target=target,
dev=dev,
use_vm=True,
freeze_params=(opset >= 13),
)

# 1D
Expand All @@ -1997,13 +1999,23 @@ def verify_split(indata, outdatas, split, axis=0, pass_split=True, opset=11):
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [2, 2, 2], 0, False
)
verify_split([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0)
verify_split(
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]], [2, 1, 3], 0, opset=13
)
# 2D
verify_split(
[[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]],
[[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]],
[2, 2],
1,
)
verify_split(
[[1.0, 2.0, 3.0, 4.0], [7.0, 8.0, 9.0, 10.0]],
[[[1.0, 2.0], [7.0, 8.0]], [[3.0, 4.0], [9.0, 10.0]]],
[2, 2],
1,
opset=13,
)
# Split evenly (unstack)
verify_split([1, 2, 3], [[1], [2], [3]], False, 0, False)
# Split a single value to a single value
Expand Down