Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use version of NDArray split that always returns a list. #454

Merged
merged 4 commits into from
Jun 25, 2018
Merged

Conversation

mjdenkowski
Copy link
Contributor

@mjdenkowski mjdenkowski commented Jun 22, 2018

This fixes an issue where inference was silently breaking when using vocabulary restriction, batch decoding, and a single source factor. The source was the behavior of mxnet.ndarray.split that returns a list when num_outputs is greater than 1, but the individual NDArray that would be element 0 when num_outputs is 1. This was leading the code to pull element 0 of the NDarray instead of the NDArray itself. This commit adds a wrapper for split that always returns a list for consistent behavior.

Pull Request Checklist

  • Changes are complete (if posting work-in-progress code, prefix your pull request title with '[WIP]'
    until you can check this box.
  • Unit tests pass (pytest)
  • System tests pass (pytest test/system)
  • Passed code style checking (./style-check.sh)
  • You have considered writing a test
  • Updated major/minor version in sockeye/__init__.py. Major version bump if this is a backwards incompatible change.
  • Updated CHANGELOG.md

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.

Copy link
Contributor

@fhieber fhieber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great catch, thats a very subtle bug!

@@ -158,7 +158,7 @@ def _get_encoder_module(self) -> Tuple[mx.mod.BucketingModule, int]:

def sym_gen(source_seq_len: int):
source = mx.sym.Variable(C.SOURCE_NAME)
source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
source_words = utils.split(source, num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't need this change here, as in the symbolic API split seems to return always an 'indexable' Symbol/SliceChannel. We do these source-factor related splits also in other places of the code and it works there just fine. Also, the util function isn't typed for symbols and I am surprised data.split (aka using the fluent method) works for symbols).

mx.sym.split(mx.sym.Variable('x'), num_outputs=1, axis=2)[0].eval(x=mx.nd.ones((2,2,2)))

This throws an error:

mx.sym.split(mx.sym.Variable('x'), num_outputs=1, axis=2)[1].eval(x=mx.nd.ones((2,2,2)))

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

:return: List of NDArrays resulting from the split.
"""
ndarray_or_list = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
if num_outputs == 1:
Copy link
Contributor

@fhieber fhieber Jun 23, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to avoid the split altogether when num_outputs is 1? If squeeze_axis==True, one only would need a reshape/squeeze which is essentially a no-op.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another good point. I think it's a toss-up between staying as close to the original as possible versus micro-optimizing the call. Since we're now using this in just one place, called once per batch, I would lean toward keeping it this way for clarity.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, probably not worth the additional complexity.

@fhieber fhieber added the bug label Jun 23, 2018
Copy link
Contributor

@fhieber fhieber left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, great fix!

Copy link
Contributor

@tdomhan tdomhan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

indeed a great catch :)

@tdomhan tdomhan merged commit c59361d into master Jun 25, 2018
@tdomhan tdomhan deleted the split-fix branch June 25, 2018 09:49
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants