[NNVM][TENSORFLOW] LSTM operator and PTB word prediction frontend #1389

Merged
10 commits merged on Jul 25, 2018

Contributor

@joyalbin joyalbin commented Jul 6, 2018

This PR contains:

  • LSTM operator implementation and layer-based operator parsing for TensorFlow
  • Support for single-layer LSTMBlockCell models
  • A frontend script for a PTB model based on LSTMBlockCell
  • Frontend test cases

Contributor Author

joyalbin commented Jul 7, 2018

@Huyuwei @masahi @srkreddy1238 Please help review this PR.

Contributor

@srkreddy1238 srkreddy1238 left a comment

Some initial review.

_, out_shapes = graph_util.infer_shape(g, **shape_dict)
return out_shapes

def _stridedSlice():
Contributor

NNVM now has a strided_slice operator. We should use it.

Contributor Author

@srkreddy1238 The frontend already uses the NNVM strided_slice operator. However, TensorFlow's StridedSlice has additional mask attributes, and those are handled here.

'Taxis', '_class'])(new_input, attr)
return _impl

def _infer_out_shapes(inputs, params):
Contributor

We already have the _input_shapes attribute. Is there any difficulty in using it?

Contributor Author

_infer_out_shapes is used to find the shapes of intermediate nodes.

Contributor

ok

@@ -544,3 +711,6 @@ def test_forward_mobilenet():
test_forward_inception_v1()
test_forward_mobilenet()
test_forward_variable()
test_forward_lstm()
Contributor

I suggest adding a real-time end-to-end test case for LSTM.

@@ -0,0 +1,242 @@
"""
Tutorial for Tensorflow RNN Models
Contributor

Please make the underline match the length of the title text.

sample_data_file = 'simple-examples.tgz'
sample_url = sample_repo+sample_data_file

ptb_repo = 'https://github.com/joyalbin/dmlc_store/raw/master/trained-models/tf/ptb/pb/'
Contributor

Use dmlc/web-data to store these models.


###############################################################################
# Input words
# ---------------------------------------------
Contributor

Make the underline match the length of the text.


o = sigmoid(cs * wco + o)
co = tanh(cs)
h = co .* o

"""Recurrent network layer handlers.

Unlike normal operators, recurrent network operators have layer concept.
Same operators will be called multiple times (based on number of layers)
Contributor

Does "layer" refer to steps, or to multiple cells (layers) in a stacked RNN? The doc here is not very clear.

Contributor Author

@Huyuwei Here, "layers" refers to the multiple cells in an RNN stack. I have rephrased the comments.

Contributor

Huyuwei commented Jul 8, 2018

@joyalbin It seems the current implementation only supports step=1, so users need to unroll manually when there are multiple steps. What about adding a wrapper to support multiple steps?

@merrymercy may have some suggestions.

@joyalbin
Contributor Author

@Huyuwei, @srkreddy1238 I have addressed all the review comments and updated the PR.

@Huyuwei The current implementation supports one step (step=1) at a time, which is similar to how TensorFlow handles it. Here the model input is a single word, and the LSTM input is computed from that word's token, so I was not able to remove the unrolling.
Can we merge this PR with step=1 and handle multiple steps in the next PR? Could you please shed more light on how to avoid this unrolling when handling multiple steps?
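
To make the step=1 discussion concrete, here is a minimal NumPy sketch of a single LSTM cell step and the manual unrolling a user would write around it. It is not the NNVM or TensorFlow implementation. The last three equations follow the snippet quoted earlier in this review with the peephole weight set to zero; the gate order `[i, ci, f, o]`, the function names, and the `forget_bias` default are assumptions made for illustration.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_step(x, h_prev, cs_prev, w, b, forget_bias=1.0):
    """One LSTM step (step=1) without peepholes.

    x: (batch, n_in), h_prev / cs_prev: (batch, n_hidden),
    w: (n_in + n_hidden, 4 * n_hidden), b: (4 * n_hidden,).
    The gate order [i, ci, f, o] is an assumption of this sketch.
    """
    xh = np.concatenate([x, h_prev], axis=1)
    i, ci, f, o = np.split(xh @ w + b, 4, axis=1)
    i = sigmoid(i)
    f = sigmoid(f + forget_bias)
    ci = np.tanh(ci)
    cs = ci * i + cs_prev * f   # new cell state
    o = sigmoid(o)              # peephole weight wco taken as 0
    co = np.tanh(cs)
    h = co * o                  # new hidden output
    return h, cs


def unroll_lstm(inputs, h0, cs0, w, b):
    """Manual unrolling: call the single-step cell once per time step."""
    h, cs = h0, cs0
    outputs = []
    for x in inputs:
        h, cs = lstm_cell_step(x, h, cs, w, b)
        outputs.append(h)
    return outputs, (h, cs)
```

A wrapper for multiple steps, as suggested above, would essentially hide the loop in `unroll_lstm` from the user while carrying `h` and `cs` from one step to the next.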

@joyalbin
Contributor Author

@Huyuwei , @srkreddy1238 can you please help to review this PR further?


Unlike normal operators, stacked rnn have cells and layer concepts.
Each Layer represent a cell in RNN stack. Cells in the same RNN stack
sequentialy process input data.
Contributor

The description here is still a little confusing. Where is the recurrent part?
You could remove it or link to an RNN tutorial.

initializer = tf.random_uniform_initializer(config.init_scale,
                                            config.init_scale)
with tf.variable_scope("Model", reuse=None, initializer=initializer):
    mtest = nnvm.testing.tf.PTBModel(is_training=False, config=config)
Contributor

Is mtest unused? If so, testing.tf.PTBModel can be removed.

#TVM graph module creation
params, m = _get_tvm_graph_module(graph_def)


Contributor

remove one blank line

Contributor

Huyuwei commented Jul 16, 2018

@joyalbin Sorry for the late response. Have added some comments.

The lstm part looks good to me. fill, gather, stridedslice operators need review from @srkreddy1238

Contributor

@nishi-t nishi-t left a comment

trivial typo fix

Dict of operator attributes

params : dict
List of pretrained weights and bias
Contributor

redundant whitespace after of

word_to_id : dict
English word to integer id mapping
id_to_word : dict

Contributor

Could you remove this blank line?

Dict of operator attributes

params : dict
List of pretrained weights and bias
Contributor

@nishi-t nishi-t Jul 17, 2018

Could you remove the redundant whitespace after 'of'?

new_axis_mask = int(attr.get('new_axis_mask', 0))
shrink_axis_mask = int(attr.get('shrink_axis_mask', 0))

#Constant values used forming output shape.
Contributor

Thanks for cleaning up the logic above.
I would suggest cleaning up the part below a bit as well.

My understanding of these masks is as follows. Please correct me if I am wrong.

begin_mask, end_mask: ignore the value given for an axis and use 0 for begin or the maximum for end.
ellipsis_mask: only one bit will be set; it is used to expand begin, end and strides to the number of input dimensions by filling in the full range in between.
new_axis_mask: add a new axis to the result for each set bit.
shrink_axis_mask: remove an axis from the result for each set bit.

I think this logic can be split into the following logically separable blocks:
1: Apply begin_mask and end_mask to begin and end
2: Expand begin and end based on ellipsis_mask
3: Apply the strided_slice operation
4: Apply a reshape for new_axis_mask and shrink_axis_mask

I also suggest using the Python API to convert each bitmask into a list of integers.
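
As a rough illustration of the bitmask-to-list conversion and of blocks 1, 3 and 4 above, here is a minimal NumPy sketch. It is not the code in this PR: `mask_to_axes` and `strided_slice_with_masks` are hypothetical helper names, and ellipsis_mask and new_axis_mask are deliberately left out to keep the example short.

```python
import numpy as np


def mask_to_axes(mask, ndim):
    """Convert an integer bitmask into the list of axes whose bit is set."""
    return [axis for axis in range(ndim) if (mask >> axis) & 1]


def strided_slice_with_masks(data, begin, end, strides,
                             begin_mask=0, end_mask=0, shrink_axis_mask=0):
    """Hypothetical helper covering begin, end and shrink-axis masks only."""
    ndim = len(begin)
    begin_axes = set(mask_to_axes(begin_mask, ndim))
    end_axes = set(mask_to_axes(end_mask, ndim))
    shrink_axes = set(mask_to_axes(shrink_axis_mask, ndim))

    index = []
    for axis in range(ndim):
        if axis in shrink_axes:
            # An integer index takes begin[axis] and drops the axis.
            index.append(begin[axis])
        else:
            # None means "full range in the stride direction".
            b = None if axis in begin_axes else begin[axis]
            e = None if axis in end_axes else end[axis]
            index.append(slice(b, e, strides[axis]))
    return data[tuple(index)]


data = np.arange(36).reshape(3, 4, 3)
print(mask_to_axes(5, 3))  # [0, 2]
out = strided_slice_with_masks(data, [1, 0, 0], [2, 4, 3], [1, 1, 1],
                               begin_mask=1)
print(out.shape)  # (2, 4, 3)
out = strided_slice_with_masks(data, [0, 1, 0], [3, 2, 3], [1, 1, 1],
                               shrink_axis_mask=2)
print(out.shape)  # (3, 3)
```

Keeping the conversion in one small function means each mask is decoded once, and the later blocks only deal with plain lists of axes.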

_test_stridedslice((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, -3, 0], [2, -2, 3], [1, 1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 1, 0], [4, 4, 3], [2, 1, 1], 'float32')
_test_stridedslice((3, 4, 3), [1, 0], [4, 3], [2, 1],
Contributor

Please add another test case that combines begin, end, ellipsis_mask and new/shrink axis masks across multiple axes.

Contributor Author

joyalbin commented Jul 24, 2018

@srkreddy1238 Yes, your understanding is correct. However, handling all 5 masks according to their priority makes the code a bit lengthy.
I have modified the code based on your comments.
The mask logic has been changed as you suggested, and test cases have been added.

@nishi-t @Huyuwei I have addressed all the review comments. Please help approve the changes.

return ptb_raw_data(data_path, file_name)

def get_workload_ptb():
""" Import mobilenet workload from frozen protobuf
Contributor

not mobilenet

state = session.run(state_input_name)
fetches = [['Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:1',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell:6',
'Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/LSTMBlockCell_1:1',
Contributor

Could you add a comment here? What is fetches? What does LSTMBlockCell:6 contain?
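
For context on the naming in fetches, here is a small hedged sketch. TensorFlow tensor names have the form `op_name:output_index`, and the LSTMBlockCell op documents its outputs in the order i, cs, f, o, ci, co, h, so `:1` would be the cell state and `:6` the hidden output. The helper names below are hypothetical and only illustrate how to read such names.

```python
# Output order of the LSTMBlockCell op as documented by TensorFlow.
LSTM_BLOCK_CELL_OUTPUTS = ("i", "cs", "f", "o", "ci", "co", "h")


def split_tensor_name(tensor_name):
    """Split 'op_name:index' into (op_name, index); index defaults to 0."""
    if ":" in tensor_name:
        op_name, index = tensor_name.rsplit(":", 1)
        return op_name, int(index)
    return tensor_name, 0


def describe_fetch(tensor_name):
    """Return which LSTMBlockCell output a fetched tensor name refers to."""
    _, index = split_tensor_name(tensor_name)
    return LSTM_BLOCK_CELL_OUTPUTS[index]


prefix = "Model/RNN/RNN/multi_rnn_cell/cell_0/lstm_cell/"
print(describe_fetch(prefix + "LSTMBlockCell:1"))  # cs
print(describe_fetch(prefix + "LSTMBlockCell:6"))  # h
```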

return int(np.searchsorted(t, 0.5 * s))

def do_tf_sample(session, data, in_states, num_samples):
"""Sampled from the model"""
Contributor

data is not used in the function.

Contributor Author

@Huyuwei data is used at line 192. Please correct me if I have misunderstood the issue here.

Contributor

@joyalbin Sorry, my mistake.

Contributor

@srkreddy1238 srkreddy1238 left a comment

@joyalbin Thanks. The Strided_slice, Gather and Fill operators and the GraphProto class modifications are good to go.

@Huyuwei can confirm on the LSTM part.

@tqchen tqchen merged commit 9176753 into apache:master Jul 25, 2018
Member

tqchen commented Jul 25, 2018

Thanks @joyalbin @Huyuwei @nishi-t @srkreddy1238 , this is merged

5 participants