diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 3409d82606c1..6da6e8d3b76c 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2317,19 +2317,19 @@ class Softmax(OnnxOpConverter): @classmethod def _impl_v1(cls, inputs, attr, params): axis = attr.get("axis", 1) - ndim = len(infer_shape(inputs[0])) + in_shape = infer_shape(inputs[0]) + ndim = len(in_shape) if axis < 0: axis += ndim - # Older ONNX Softmax op does not properly support inputs of dimension > 2 - # But we can use our softmax when the axis is -1 - if axis == ndim - 1: - return _op.nn.softmax(inputs[0], axis=axis) - - axes = list(range(axis, ndim)) - x = inputs[0] - m = _op.max(x, axes, keepdims=True) - e = _op.exp(x - m) - return e / _op.sum(e, axes, keepdims=True) + if axis == 0: + reshape_shape = [-1] + else: + axis_val = [in_shape[i] for i in range(axis)] + reshape_shape = [np.prod(axis_val)] + [-1] + data_reshape = _op.reshape(inputs[0], newshape=reshape_shape) + out = _op.nn.softmax(data_reshape, axis=-1) + out = _op.reshape(out, newshape=in_shape) + return out @classmethod def _impl_v13(cls, inputs, attr, _): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index 01039e7443cb..d1e763bf0726 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -1599,6 +1599,10 @@ def verify_softmax(inshape, axis): verify_softmax((1, 10), None) verify_softmax((1, 10), 1) + verify_softmax((1, 2, 3, 10), 0) + verify_softmax((1, 2, 3, 10), 2) + verify_softmax((1, 2, 3, 4, 10), 3) + verify_softmax((1, 2, 3, 4, 10), 4) @tvm.testing.parametrize_targets