Skip to content

Commit

Permalink
[microNPU][ETHOSU] Fix ConcatRewriter args processing
Browse files Browse the repository at this point in the history
In ConcatRewriter the case was not considered when the concatenation argument is TupleGetItem.
  • Loading branch information
Aleksei-grovety committed Oct 30, 2023
1 parent d9f3029 commit 7de4d9a
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/tvm/relay/backend/contrib/ethosu/legalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,7 +1184,7 @@ def callback(
# Find the tensors that are inputs to the concat and the scales and zero points
concat_args = list()
for arg in post.args:
if isinstance(arg, tvm.relay.expr.Call):
if isinstance(arg, (tvm.relay.expr.Call, tvm.relay.expr.TupleGetItem)):
concat_args.append(arg)

axis = post.op.body.attrs.axis
Expand Down
16 changes: 16 additions & 0 deletions tests/python/contrib/test_ethosu/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,22 @@ def concat_func(*inputs):
infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False)


def test_tflite_unstack_concat():
np.random.seed(0)
shapes = [(2, 4, 16)]
axis = 1
accel_type = "ethos-u55-256"

@tf.function
def concat_func(input):
inputs = tf.unstack(input)
inputs.reverse()
op = tf.concat(inputs, axis)
return op

infra.compare_tvm_with_tflite(concat_func, shapes, accel_type, enable_cascader=False)


def test_tflite_concat_with_reused_args():
np.random.seed(0)
shapes = [(1, 1, 24, 1), (1, 1, 24, 1), (1, 1, 10, 1), (1, 1, 68, 1)]
Expand Down

0 comments on commit 7de4d9a

Please sign in to comment.