Skip to content

Commit

Permalink
[NNVM][TENSORFLOW] Sigmoid op support #1367 (#1369)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 authored and tqchen committed Jul 4, 2018
1 parent 035696f commit 9ccfaad
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
1 change: 1 addition & 0 deletions nnvm/python/nnvm/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,7 @@ def _impl(inputs, attr, params):
'Relu6' : _relu6(),
'DepthwiseConv2dNative' : _depthwise_conv(),
'Shape' : _shape(),
'Sigmoid' : AttrCvt('sigmoid'),
}


Expand Down
38 changes: 38 additions & 0 deletions nnvm/tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.core.framework import graph_pb2

import nnvm.testing.tf
Expand Down Expand Up @@ -329,6 +330,42 @@ def _test_forward_concat_v2():

_test_concat_v2([t1, t2], 1)

#######################################################################
# Sigmoid
# -------

def _test_sigmoid(data):
""" One iteration of sigmoid """

with tf.Graph().as_default():
in_data = constant_op.constant(data, shape=data.shape, dtype=data.dtype)

# pylint: disable=unused-variable
sigmoid_out = math_ops.sigmoid(in_data)
# pylint: enable=unused-variable

with tf.Session() as sess:
graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
['Sigmoid'],
)

tf_output = run_tf_graph(sess, data,
'Const:0', 'Sigmoid:0')
tvm_output = run_tvm_graph(graph_def,
data,
"Const", tf_output.shape, data.dtype)

np.testing.assert_allclose(tf_output, tvm_output, atol=1e-5, rtol=1e-5)

sess.close()

def test_forward_sigmoid():
""" Sigmoid """

_test_sigmoid(np.random.uniform(size=(3, 4, 4, 3)).astype('float32'))

#######################################################################
# Multi Input to graph
# --------------------
Expand Down Expand Up @@ -437,6 +474,7 @@ def test_forward_mobilenet():
test_forward_pooling()
test_forward_reshape()
test_forward_squeeze()
test_forward_sigmoid()
if tf.__version__ == '1.4.1':
_test_forward_concat_v2()
test_forward_multi_input()
Expand Down

0 comments on commit 9ccfaad

Please sign in to comment.