Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

rnn_cell little bug fixed #11003

Merged
merged 1 commit into from
Jun 30, 2018
Merged

rnn_cell little bug fixed #11003

merged 1 commit into from
Jun 30, 2018

Conversation

chinakook
Copy link
Contributor

@chinakook chinakook commented May 20, 2018

  1. When input is list of "NC" layout(i.e. split by time). The batch_axis should be 0.
  2. Add HybidSequentialRNNCell, which can be nested in HybridBlock but SequentialRNNCell cannot.

Description

(Brief description on what this PR is about)

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

@chinakook chinakook requested a review from szha as a code owner May 20, 2018 04:00
@@ -630,7 +630,7 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
return next_h, [next_h]


class SequentialRNNCell(RecurrentCell):
class SequentialRNNCell(HybridRecurrentCell):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might break existing code where the member cells are not HybridRecurrentCells. Maybe have a separate class like HybridSequentialRNNCell?

Copy link
Member

@szha szha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add tests?

super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)

def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HybridSequentialRNNCell should have hybrid_forward implemented similarly to SequentialRNNCell.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. The hybrid_forward method should have the forward logic similar to what SequentialRNNCell has in its __call__

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piiswrong do you remember why SequentialRNNCell has hybrid_forward?

def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
def hybrid_forward(self, F, *args, **kwargs):
super(HybridSequentialRNNCell, self).hybrid_forward(args, kwargs)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this class is an abstract base class and should not implement hybrid_forward.

super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)

def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for the confusion. The hybrid_forward method should have the forward logic similar to what SequentialRNNCell has in its __call__

super(HybridSequentialRNNCell, self).__init__(prefix=prefix, params=params)

def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@piiswrong do you remember why SequentialRNNCell has hybrid_forward?

@@ -79,9 +79,10 @@ def _format_sequence(length, inputs, layout, merge, in_layout=None):
assert length is None or len(inputs) == length
if isinstance(inputs[0], symbol.Symbol):
F = symbol
# TODO: batch_size cannot got here
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

symbol should be able to infer this just fine

return _cells_begin_state(self._children.values(), **kwargs)

def __call__(self, inputs, states):
raise NotImplementedError("HybridSequentialRNN cannot be stepped. Please use unroll")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should allow forward unless bidirectional cell is registered, similar to sequential rnn cell.
a hybrid cell means that the graph inside the cell (i.e. for a single step) can be hybridized.

@szha szha requested a review from piiswrong May 21, 2018 18:43
@szha
Copy link
Member

szha commented Jun 18, 2018

Ping @chinakook

@szha
Copy link
Member

szha commented Jun 26, 2018

@chinakook pinging again. It would be great if we could get this in. Would you address the review comments? Thanks.

@chinakook
Copy link
Contributor Author

@szha Is that all?

next_states = []
p = 0
for cell in self._children.values():
assert not isinstance(cell, BidirectionalCell)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move the assertion outside to fail fast. assert all(not isinstance(cell) for cell in ...)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also, I'd suggest including a similar error message to what SequentialRNNCell step has.

Copy link
Member

@szha szha left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chinakook otherwise looks good to me. Thanks.

@szha
Copy link
Member

szha commented Jun 29, 2018

@chinakook would you mind doing a rebase? we have a couple of flaky tests that are already fixed in the upstream.

@chinakook
Copy link
Contributor Author

chinakook commented Jun 29, 2018

Rebasing like this? I've not very familiar with rebasing. Should I close this and open another PR?

@szha
Copy link
Member

szha commented Jun 29, 2018

suppose you have a git remote called "upstream" (which you can get by doing git remote add upstream https://github.com/apache/incubator-mxnet), you can do git pull upstream master --rebase and then do a force push. You don't need to reopen the PR

@chinakook chinakook reopened this Jun 29, 2018
@@ -171,6 +171,54 @@ def test_stack():
assert outs == [(10, 100), (10, 100), (10, 100)]


def test_hybridstack():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a unit test for the NC layout bug you fixed?

def hybrid_forward(self, F, x):
return self.rnncell.unroll(3, x, layout="NTC", merge_outputs=True)

x = mx.nd.random.uniform(shape=(10, 3, 100))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@chinakook could you split this into a list to verify your fix?

@szha szha merged commit 92900f0 into apache:master Jun 30, 2018
@szha
Copy link
Member

szha commented Jun 30, 2018

@chinakook thanks for the contribution. @eric-haibin-lin I will add the test.

szha added a commit to szha/mxnet that referenced this pull request Jun 30, 2018
@szha szha mentioned this pull request Jun 30, 2018
4 tasks
szha added a commit that referenced this pull request Jul 1, 2018
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
XinYao1994 pushed a commit to XinYao1994/incubator-mxnet that referenced this pull request Aug 29, 2018
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants