[NNVM][TENSORFLOW] LSTM operator and PTB word prediction frontend #1389
Conversation
@Huyuwei @masahi @srkreddy1238 Please help to review this PR
Some initial review.
_, out_shapes = graph_util.infer_shape(g, **shape_dict)
return out_shapes

def _stridedSlice():
Now we have a strided_slice operator in NNVM. We should use it.
@srkreddy1238 The frontend is already using the NNVM strided_slice. But TensorFlow has additional mask attributes for strided_slice, which are handled here.
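For context, those mask attributes change what a plain begin/end/strides slice produces. A NumPy analogy of two of them is sketched below (illustrative only; TensorFlow's exact semantics are defined in the tf.strided_slice documentation, and these arrays are not from the PR):

```python
import numpy as np

x = np.arange(24, dtype="float32").reshape(2, 3, 4)

# Plain strided slice: begin=[0,0,0], end=[1,2,4], strides=[1,1,2]
plain = x[0:1, 0:2, 0:4:2]

# shrink_axis_mask with bit 0 set drops axis 0 from the result,
# like indexing that axis with a scalar:
shrunk = x[0, 0:2, 0:4:2]

# new_axis_mask with bit 0 set inserts a length-1 axis instead:
expanded = x[np.newaxis, 0:1, 0:2, 0:4:2]
```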
'Taxis', '_class'])(new_input, attr)
return _impl

def _infer_out_shapes(inputs, params):
We already have the _input_shapes attribute. Any challenge using it?
_infer_out_shapes is used for finding intermediate node shapes
ok
@@ -544,3 +711,6 @@ def test_forward_mobilenet():
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_variable()
test_forward_lstm()
Suggest adding a real end-to-end test case for LSTM.
@@ -0,0 +1,242 @@
"""
Tutorial for Tensorflow RNN Models
Please extend the underline to the length of the text.
sample_data_file = 'simple-examples.tgz'
sample_url = sample_repo+sample_data_file

ptb_repo = 'https://github.com/joyalbin/dmlc_store/raw/master/trained-models/tf/ptb/pb/'
Use dmlc/web-data to store these models.
###############################################################################
# Input words
# ---------------------------------------------
Please extend the underline to the length of the text.
o = sigmoid(cs * wco + o)
co = tanh(cs)
h = co .* o
These math equations can be removed in favor of a link to https://github.com/tensorflow/tensorflow/blob/r1.8/tensorflow/contrib/rnn/python/ops/lstm_ops.py#L41-L114
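The quoted lines are the tail of the LSTMBlockCell recurrence. For readers who want the whole picture, one cell step can be sketched in NumPy as below (illustrative only; the gate ordering and peephole terms follow the linked lstm_ops.py equations, and all names here are assumptions, not code from this PR):

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_block_cell(x, cs_prev, h_prev, w, b, wci, wcf, wco):
    # One step of an LSTM block cell with peephole connections.
    # x: (batch, input), h_prev/cs_prev: (batch, hidden)
    # w: (input + hidden, 4 * hidden), b: (4 * hidden,)
    # wci/wcf/wco: (hidden,) peephole weights
    xh = np.concatenate([x, h_prev], axis=1)
    i, ci, f, o = np.split(xh @ w + b, 4, axis=1)
    i = sigmoid(i + cs_prev * wci)      # input gate
    f = sigmoid(f + cs_prev * wcf)      # forget gate
    ci = np.tanh(ci)                    # cell input
    cs = f * cs_prev + i * ci           # new cell state
    o = sigmoid(o + cs * wco)           # output gate (peephole on cs)
    co = np.tanh(cs)
    h = co * o                          # new hidden state
    return cs, h
```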
"""Recurrent network layer handlers.

Unlike normal operators, recurrent network operators have layer concept.
Same operators will be called multiple times (based on number of layers)
Does 'layer' refer to steps, or to multiple cells (layers) in a stacked RNN? The doc here is not very clear.
@Huyuwei Here, 'layers' refers to multiple cells in the RNN stack. I have rephrased the comments.
@joyalbin It seems the current implementation only supports step=1, and users need to do the unrolling manually in the case of multiple steps. What about adding a wrapper to support multiple steps? @merrymercy may have some suggestions.
@Huyuwei, @srkreddy1238 I have reworked all the review comments and updated the PR. @Huyuwei The current implementation supports step=1 at a time, similar to TensorFlow. Here the model input is one 'word', and the LSTM input is calculated from the word's token, so I could not remove the unrolling.
@Huyuwei, @srkreddy1238 can you please help to review this PR further?
Unlike normal operators, stacked rnn have cells and layer concepts.
Each Layer represent a cell in RNN stack. Cells in the same RNN stack
sequentially process input data.
The description here is still a little confusing. Where is the recurrent part?
You can remove it or give a link to an RNN tutorial.
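To make the layer-vs-step distinction concrete, a stacked RNN can be sketched as below (illustrative only; the cell is a plain tanh RNN rather than the PR's LSTM, and all names are assumptions):

```python
import numpy as np

def run_stacked_rnn(inputs, cells, states):
    # `cells` is a list of (w_x, w_h) pairs: one simple tanh cell per layer.
    # At each time step the output of layer k feeds layer k+1 (the stack),
    # while each layer carries its own state across steps (the recurrence).
    outputs = []
    for x in inputs:                            # loop over time steps
        for k, (w_x, w_h) in enumerate(cells):  # loop over stacked layers
            states[k] = np.tanh(x @ w_x + states[k] @ w_h)
            x = states[k]                       # next layer's input
        outputs.append(x)
    return outputs, states
```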
initializer = tf.random_uniform_initializer(config.init_scale,
                                            config.init_scale)
with tf.variable_scope("Model", reuse=None, initializer=initializer):
    mtest = nnvm.testing.tf.PTBModel(is_training=False, config=config)
mtest is not used? Then testing.tf.PTBModel can be removed
#TVM graph module creation
params, m = _get_tvm_graph_module(graph_def)


Remove one blank line.
@joyalbin Sorry for the late response. Have added some comments. The LSTM part looks good to me. The fill, gather, and strided_slice operators need review from @srkreddy1238.
trivial typo fix
Dict of operator attributes

params : dict
    List of pretrained weights and bias
Redundant whitespace after 'of'.
nnvm/python/nnvm/testing/tf.py
Outdated
word_to_id : dict
    English word to integer id mapping
id_to_word : dict

Could you remove this blank line?
Dict of operator attributes

params : dict
    List of pretrained weights and bias
Could you remove redundant whitespace after 'of'.
'Taxis', '_class'])(new_input, attr)
return _impl

def _infer_out_shapes(inputs, params):
ok
new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))

#Constant values used forming output shape.
Thanks for cleaning up the logic above.
Would suggest cleaning up the part below a bit as well.
My understanding of all these masks is as follows. Please correct me if I am wrong.
begin_mask, end_mask: basically ignore the corresponding value from the inputs and use 0 for begin, the maximum for end.
ellipsis_mask: only one bit will be set, which is used to expand begin, end and strides to the size of the input dimensions by filling the full range in between.
new_axis_mask: add a new axis to the result based on the mask bits.
shrink_axis_mask: shrink an axis from the result based on the mask bits.
I see this logic can be split into the following logically separable blocks:
1: Handle begin_mask and end_mask on begin, end.
2: Expand begin and end based on ellipsis_mask.
3: Apply the strided_slice operation.
4: Apply a reshape operation for new_axis_mask and shrink_axis_mask.
Also suggest using the Python API for converting a bitmask into a list of integers.
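That last suggestion, decoding a bitmask attribute into a list of axis indices, can be done with a one-line helper (a sketch; the helper name is hypothetical, not from this PR):

```python
def mask_to_axes(mask, ndim):
    # Return the axis indices whose bit is set in `mask`,
    # e.g. shrink_axis_mask = 0b101 over 4 dims -> axes [0, 2].
    return [axis for axis in range(ndim) if mask & (1 << axis)]
```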
_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, -3, 0], [2, -2, 3], [1, 1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1],
Please add another test case combining begin_mask, end_mask, ellipsis_mask and new/shrink axis masks across multiple axes.
@srkreddy1238 Yes, your understanding is correct, but handling all five masks according to their priority makes things a bit lengthy. @nishi-t @Huyuwei I have handled all the review comments. Please help to approve the changes.
nnvm/python/nnvm/testing/tf.py
Outdated
return ptb_raw_data(data_path, file_name)

def get_workload_ptb():
    """ Import mobilenet workload from frozen protobuf
not mobilenet
nnvm/python/nnvm/testing/tf.py
Outdated
state = session.run(state_input_name)
fetches = [['Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
            'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
            'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
Could you add a comment here? What is fetches, and what is the content of LSTMBlockCell:6?
return int(np.searchsorted(t, 0.5 * s))

def do_tf_sample(session, data, in_states, num_samples):
    """Sampled from the model"""
data is not used in the function.
@Huyuwei data is used at line 192; could you please correct me if I didn't get the real issue here?
@joyalbin Sorry, my mistake.
Thanks @joyalbin @Huyuwei @nishi-t @srkreddy1238, this is merged
This PR contains: