From 22a83becb3e280a16e972fb773dabf5c657f6c6d Mon Sep 17 00:00:00 2001 From: blackkker <823036806@qq.com> Date: Mon, 25 Apr 2022 10:51:10 +0000 Subject: [PATCH 1/2] update Softmax with uniform operator --- python/tvm/relay/frontend/onnx.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3409d82606c1..6da6e8d3b76c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2317,19 +2317,19 @@ class Softmax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 1) - ndim = len(infer_shape(inputs[0])) + in_shape = infer_shape(inputs[0]) + ndim = len(in_shape) if axis < 0: axis += ndim - # Older ONNX Softmax op does not properly support inputs of dimension > 2 - # But we can use our softmax when the axis is -1 - if axis == ndim - 1: - return _op.nn.softmax(inputs[0], axis=axis) - - axes = list(range(axis, ndim)) - x = inputs[0] - m = _op.max(x, axes, keepdims=True) - e = _op.exp(x - m) - return e / _op.sum(e, axes, keepdims=True) + if axis == 0: + reshape_shape = [-1] + else: + axis_val = [in_shape[i] for i in range(axis)] + reshape_shape = [np.prod(axis_val)] + [-1] + data_reshape = _op.reshape(inputs[0], newshape=reshape_shape) + out = _op.nn.softmax(data_reshape, axis=-1) + out = _op.reshape(out, newshape=in_shape) + return out @classmethod def _impl_v13(cls, inputs, attr, _): From 0be6dbf554845f750845560fe4ac3dc00b2b9fca Mon Sep 17 00:00:00 2001 From: blackkker <823036806@qq.com> Date: Mon, 25 Apr 2022 10:51:34 +0000 Subject: [PATCH 2/2] add testcases for softmax --- tests/python/frontend/onnx/test_forward.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 01039e7443cb..d1e763bf0726 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1599,6 +1599,10 @@ def verify_softmax(inshape, axis): verify_softmax((1, 10), None) verify_softmax((1, 10), 1) + verify_softmax((1, 2, 3, 10), 0) + verify_softmax((1, 2, 3, 10), 2) + verify_softmax((1, 2, 3, 4, 10), 3) + verify_softmax((1, 2, 3, 4, 10), 4) @tvm.testing.parametrize_targets