Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MXNET]Softmin, trunc op support added #5715

Merged
merged 1 commit into from Jun 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 7 additions & 0 deletions python/tvm/relay/frontend/mxnet.py
Expand Up @@ -846,6 +846,11 @@ def _mx_softsign(inputs, attrs):
return inputs[0] / (_expr.const(1.0) + _op.abs(inputs[0]))


def _mx_softmin(inputs, attrs):
axis = attrs.get_int("axis", -1)
return _op.nn.softmax(_op.negative(inputs[0]), axis)


def _mx_hard_sigmoid(inputs, attrs):
x = (_expr.const(0.2) * inputs[0]) + _expr.const(0.5)
return _op.clip(x, a_min=0.0, a_max=1.0)
Expand Down Expand Up @@ -1829,6 +1834,7 @@ def impl(inputs, input_types):
"floor",
"ceil",
"round",
"trunc",
"sign",
"sigmoid",
"negative",
Expand Down Expand Up @@ -1938,6 +1944,7 @@ def impl(inputs, input_types):
"log_softmax" : _softmax_op(_op.nn.log_softmax),
"Softmax" : _softmax_op(_op.nn.softmax),
"softsign" : _mx_softsign,
"softmin" : _mx_softmin,
"hard_sigmoid" : _mx_hard_sigmoid,
"reciprocal" : _mx_reciprocal,
# per op specialization
Expand Down
12 changes: 11 additions & 1 deletion tests/python/frontend/mxnet/test_forward.py
Expand Up @@ -372,8 +372,17 @@ def test_forward_elemwise_ops():
tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy())


def test_forward_softmin():
data = mx.sym.var('data')
mx_sym = mx.sym.softmin(data)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))

mx_sym = mx.sym.softmin(data, axis=2)
verify_mxnet_frontend_impl(mx_sym, (1, 3, 100, 100), (1, 3, 100, 100))


def test_forward_unary_ops():
for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal",
for op in ["abs", "sqrt", "ceil", "floor", "round", "reciprocal", "trunc",
"softsign", "hard_sigmoid",
"cos", "sin", "tan",
"cosh", "sinh", "tanh",
Expand Down Expand Up @@ -1191,6 +1200,7 @@ def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size
test_forward_rrelu()
test_forward_prelu()
test_forward_softrelu()
test_forward_softmin()
test_forward_fc_flatten()
test_forward_clip()
test_forward_split()
Expand Down