Use version of NDArray split that always returns a list. #454
- Use version of ndarray split that always returns a list for uniform handling.
Great catch, that's a very subtle bug!
sockeye/inference.py
Outdated
@@ -158,7 +158,7 @@ def _get_encoder_module(self) -> Tuple[mx.mod.BucketingModule, int]:

         def sym_gen(source_seq_len: int):
             source = mx.sym.Variable(C.SOURCE_NAME)
-            source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
+            source_words = utils.split(source, num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
I think we don't need this change here: in the symbolic API, split always seems to return an indexable Symbol/SliceChannel. We do these source-factor-related splits in other places in the code as well, and they work just fine there. Also, the util function isn't typed for symbols, and I am surprised that data.split (i.e. the fluent method) works for symbols.
For example, this evaluates fine:
mx.sym.split(mx.sym.Variable('x'), num_outputs=1, axis=2)[0].eval(x=mx.nd.ones((2,2,2)))
This throws an error:
mx.sym.split(mx.sym.Variable('x'), num_outputs=1, axis=2)[1].eval(x=mx.nd.ones((2,2,2)))
Good catch!
    :return: List of NDArrays resulting from the split.
    """
    ndarray_or_list = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
    if num_outputs == 1:
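The hunk above is cut off after the `num_outputs == 1` check. As a rough sketch of the idea (not the code merged in this PR), the wrapper only has to put the single-array result into a list. `FakeArray` is a hypothetical stand-in that mimics the `split` behavior described in this PR so the example runs without MXNet, and `split_always_list` is likewise an illustrative name.

```python
from typing import List, Union


class FakeArray:
    """Hypothetical stand-in for an NDArray. It only mimics the split behavior
    described in this PR: a list for num_outputs > 1, a bare array for 1."""

    def __init__(self, name: str) -> None:
        self.name = name

    def split(self, num_outputs: int, axis: int = 1,
              squeeze_axis: bool = False) -> Union["FakeArray", List["FakeArray"]]:
        if num_outputs == 1:
            # Single output: the array itself comes back, not a one-element list.
            return FakeArray(self.name + "[0]")
        return [FakeArray("{}[{}]".format(self.name, i)) for i in range(num_outputs)]


def split_always_list(data: FakeArray, num_outputs: int, axis: int = 1,
                      squeeze_axis: bool = False) -> List[FakeArray]:
    """Illustrative wrapper: always return a list, whatever num_outputs is."""
    result = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
    if num_outputs == 1:
        return [result]  # wrap the bare array so callers can index uniformly
    return result


source = FakeArray("source")
single = split_always_list(source, num_outputs=1, axis=2, squeeze_axis=True)
several = split_always_list(source, num_outputs=3, axis=2, squeeze_axis=True)
```

With this shape, `single[0]` and `several[0]` are both whole arrays, so calling code no longer needs a special case for a single source factor.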
Would it make sense to avoid the split altogether when num_outputs is 1? If squeeze_axis==True, one would only need a reshape/squeeze, which is essentially a no-op.
Another good point. I think it's a toss-up between staying as close to the original as possible versus micro-optimizing the call. Since we're now using this in just one place, called once per batch, I would lean toward keeping it this way for clarity.
Yes, probably not worth the additional complexity.
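For reference, the shortcut discussed in this thread (which was not adopted) might look roughly like the sketch below. `ShapeOnly` is a hypothetical stand-in that tracks nothing but a shape so the example runs without MXNet, `single_output_split` is an illustrative name, and the sketch assumes a non-negative `axis`.

```python
from typing import List, Tuple


class ShapeOnly:
    """Hypothetical stand-in that tracks only a shape (no data)."""

    def __init__(self, shape: Tuple[int, ...]) -> None:
        self.shape = shape

    def reshape(self, shape: Tuple[int, ...]) -> "ShapeOnly":
        return ShapeOnly(tuple(shape))


def single_output_split(data: ShapeOnly, axis: int, squeeze_axis: bool) -> List[ShapeOnly]:
    """Shortcut for num_outputs == 1: no split call, only an optional squeeze.
    Assumes axis is a non-negative index."""
    if not squeeze_axis:
        return [data]
    if data.shape[axis] != 1:
        raise ValueError("axis {} has size {}, cannot squeeze".format(axis, data.shape[axis]))
    # Drop the size-1 axis; the underlying data would be unchanged.
    squeezed = data.shape[:axis] + data.shape[axis + 1:]
    return [data.reshape(squeezed)]


batch = ShapeOnly((4, 10, 1))  # made-up sizes: (batch, seq_len, one source factor)
parts = single_output_split(batch, axis=2, squeeze_axis=True)
```

The extra branch and shape handling illustrate the trade-off raised above: the call is skipped, but the function is no longer a thin wrapper around the original `split`.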
Again, great fix!
indeed a great catch :)
This fixes an issue where inference was silently breaking when using vocabulary restriction, batch decoding, and a single source factor. The cause was the behavior of mxnet.ndarray.split, which returns a list when num_outputs is greater than 1, but returns the individual NDArray (what would otherwise be element 0) when num_outputs is 1. This led the code to take element 0 of the NDArray instead of the NDArray itself. This commit adds a wrapper for split that always returns a list for consistent behavior.

Pull Request Checklist
- [...] until you can check this box.
- [...] (pytest)
- [...] (pytest test/system)
- [...] (./style-check.sh)
- [...] sockeye/__init__.py. Major version bump if this is a backwards incompatible change.

By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 license.