From c1b2a4bf634570b72e4305d618b1fdc92177ce10 Mon Sep 17 00:00:00 2001 From: Civitasv Date: Wed, 12 Jul 2023 19:04:51 +0800 Subject: [PATCH 1/3] [unity] [onnx frontend] [op] add support for trilu operator --- .../tvm/relax/frontend/onnx/onnx_frontend.py | 18 ++++++++++++++++++ tests/python/relax/test_frontend_onnx.py | 4 ++++ 2 files changed, 22 insertions(+) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index d653bb551113..5dd1cec2ad99 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -499,6 +499,23 @@ def _impl_v13(cls, bb, inputs, attr, params): return relax.op.sqrt(inputs[0]) +class Trilu(OnnxOpConverter): + """Given a 2-D matrix or batches of 2-D matrices, returns the upper or + lower triangular part of the tensor(s) + """ + + @classmethod + def _impl_v14(cls, bb, inputs, attr, params): + upper = attr.get("upper", True) + x = inputs[0] + k = inputs[1] if len(inputs) > 1 else 0 + + if upper: + return relax.op.triu(x, k) + else: + return relax.op.tril(x, k) + + class Relu(OnnxOpConverter): """Converts an onnx Relu node into an equivalent Relax expression.""" @@ -1712,6 +1729,7 @@ def _get_convert_map(): "Shape": Shape, "Tanh": Tanh, "Sqrt": Sqrt, + "Trilu": Trilu, "Relu": Relu, "Conv": Conv, "Pow": Pow, diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 4c4d2d5a955c..7e0644a50e8e 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -528,6 +528,10 @@ def test_relu(): verify_unary("Relu", [32, 32]) +def test_trilu(): + verify_unary("Trilu", [3, 5, 5], attrs={"upper": False}) + + def test_conv(): def _verify_conv(input_shape, weight_shape, output_shape): bias_shape = [output_shape[1]] From f36a68fc0e1d2465931b1553fcc2687b7c7a2b24 Mon Sep 17 00:00:00 2001 From: Civitasv Date: Thu, 13 Jul 2023 10:47:47 +0800 Subject: [PATCH 2/3] [unity] [onnx frontend] [op] fix ci and add test for triu. --- python/tvm/relax/frontend/onnx/onnx_frontend.py | 6 +++--- tests/python/relax/test_frontend_onnx.py | 6 +++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py index 5dd1cec2ad99..9ec340c0384b 100644 --- a/python/tvm/relax/frontend/onnx/onnx_frontend.py +++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py @@ -501,7 +501,7 @@ def _impl_v13(cls, bb, inputs, attr, params): class Trilu(OnnxOpConverter): """Given a 2-D matrix or batches of 2-D matrices, returns the upper or - lower triangular part of the tensor(s) + lower triangular part of the tensor(s) """ @classmethod @@ -509,12 +509,12 @@ def _impl_v14(cls, bb, inputs, attr, params): upper = attr.get("upper", True) x = inputs[0] k = inputs[1] if len(inputs) > 1 else 0 - + if upper: return relax.op.triu(x, k) else: return relax.op.tril(x, k) - + class Relu(OnnxOpConverter): """Converts an onnx Relu node into an equivalent Relax expression.""" diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 7e0644a50e8e..468e504f7b75 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -528,8 +528,12 @@ def test_relu(): verify_unary("Relu", [32, 32]) -def test_trilu(): +def test_tril(): verify_unary("Trilu", [3, 5, 5], attrs={"upper": False}) + + +def test_triu(): + verify_unary("Trilu", [3, 5, 5], attrs={"upper": True}) def test_conv(): From b3af1dc9b2f0885a05863eeb5c032c4b6ee2b8c6 Mon Sep 17 00:00:00 2001 From: Civitasv <37768049+Civitasv@users.noreply.github.com> Date: Thu, 13 Jul 2023 12:20:21 +0800 Subject: [PATCH 3/3] [unity] [onnx frontend] [op] fix improper indents. --- tests/python/relax/test_frontend_onnx.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py index 468e504f7b75..c5c094e115fa 100644 --- a/tests/python/relax/test_frontend_onnx.py +++ b/tests/python/relax/test_frontend_onnx.py @@ -530,8 +530,8 @@ def test_relu(): def test_tril(): verify_unary("Trilu", [3, 5, 5], attrs={"upper": False}) - - + + def test_triu(): verify_unary("Trilu", [3, 5, 5], attrs={"upper": True})