Skip to content
Permalink
Browse files

Merge pull request #83 from Microsoft/EMI-Bugfixes

[Minor] Adding custom non-linearity option to FastRNN
  • Loading branch information...
adityakusupati committed Apr 15, 2019
2 parents 67f24ae + ee408fd commit b6cbd5cf9ab0819196618cabba975fcb050f605b
Showing with 9 additions and 2 deletions.
  1. +9 −2 tf/edgeml/graph/rnn.py
@@ -12,6 +12,9 @@
def gen_non_linearity(A, non_linearity):
'''
Returns required activation for a tensor based on the inputs
non_linearity is either a callable or a value in
['tanh', 'sigmoid', 'relu', 'quantTanh', 'quantSigm']
'''
if non_linearity == "tanh":
return math_ops.tanh(A)
@@ -25,7 +28,12 @@ def gen_non_linearity(A, non_linearity):
A = (A + 1.0) / 2.0
return gen_math_ops.maximum(gen_math_ops.minimum(A, 1.0), 0.0)
else:
return math_ops.tanh(A)
# non_linearity is a user specified function
if not callable(non_linearity):
raise ValueError("non_linearity is either a callable or a value " +
                 "['tanh', 'sigmoid', 'relu', 'quantTanh', " +
                 "'quantSigm']")
return non_linearity(A)


class FastGRNNCell(RNNCell):
@@ -181,7 +189,6 @@ def call(self, inputs, state):
"B_h", [1, self._hidden_size], initializer=bias_update_init)
c = gen_non_linearity(
pre_comp + self.bias_update, self._update_non_linearity)

new_h = z * state + (math_ops.sigmoid(self.zeta) * (1.0 - z) +
math_ops.sigmoid(self.nu)) * c
return new_h, new_h

0 comments on commit b6cbd5c

Please sign in to comment.
You can’t perform that action at this time.