Use version of NDArray split that always returns a list. #454

Merged Jun 25, 2018 (4 commits; diff shown from 3 commits)
CHANGELOG.md (4 additions, 0 deletions)
@@ -11,6 +11,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.


## [1.18.27]
### Fixed
- Fix silent failing of NDArray splits during inference by using a version that always returns a list. This was causing incorrect behavior when using lexicon restriction and batch inference with a single source factor.

## [1.18.26]
### Added
- ROUGE score evaluation. It can be used as the stopping criterion for tasks such as summarization.
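The CHANGELOG entry above describes a silent failure: when `num_outputs=1`, `mxnet.ndarray.split` returns the NDArray itself rather than a one-element list, so indexing `[0]` slices along the batch dimension instead of selecting the first factor. A minimal sketch of the pitfall, using NumPy as a stand-in for `mx.nd` (the `mxnet_style_split` helper is hypothetical, written only to mimic that quirk):

```python
import numpy as np

def mxnet_style_split(data, num_outputs, axis):
    # Mimics the pre-fix mxnet.ndarray.split behavior: return the array
    # itself (not a one-element list) when num_outputs == 1.
    parts = np.split(data, num_outputs, axis=axis)
    return parts[0] if num_outputs == 1 else parts

# A batch of 2 sentences of length 3 with a single source factor.
source = np.arange(6).reshape(2, 3, 1)

result = mxnet_style_split(source, num_outputs=1, axis=2)
# Callers expect a list of per-factor arrays and index [0], but with one
# factor this silently selects the first *sentence* instead: shape (3, 1)
# rather than the full (2, 3, 1) factor slice.
source_words = result[0]
print(source_words.shape)  # (3, 1)
```

No exception is raised, which is why the bug only surfaced as incorrect behavior under lexicon restriction and batch inference with a single source factor.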
sockeye/__init__.py (1 addition, 1 deletion)
@@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

- __version__ = '1.18.26'
+ __version__ = '1.18.27'
sockeye/inference.py (3 additions, 3 deletions)
@@ -158,7 +158,7 @@ def _get_encoder_module(self) -> Tuple[mx.mod.BucketingModule, int]:

def sym_gen(source_seq_len: int):
source = mx.sym.Variable(C.SOURCE_NAME)
- source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
+ source_words = utils.split(source, num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
Comment from a Contributor:

I think we don't need this change here, as in the symbolic API split always seems to return an indexable Symbol/SliceChannel. We do these source-factor-related splits in other places of the code, and they work fine there. Also, the util function isn't typed for symbols, and I am surprised that data.split (i.e. using the fluent method) works for symbols.

This works:

    mx.sym.split(mx.sym.Variable('x'), num_outputs=1, axis=2)[0].eval(x=mx.nd.ones((2,2,2)))

This throws an error:

    mx.sym.split(mx.sym.Variable('x'), num_outputs=1, axis=2)[1].eval(x=mx.nd.ones((2,2,2)))

Reply from the Contributor Author:
Good catch!

source_length = utils.compute_lengths(source_words)

# source embedding
@@ -1334,7 +1334,7 @@ def _beam_search(self,
"""
Translates multiple sentences using beam search.

- :param source: Source ids. Shape: (batch_size, bucket_key).
+ :param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
:param source_length: Max source length.
:param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs)
that must appear in each output.
@@ -1383,9 +1383,9 @@
pad_dist = self.pad_dist
vocab_slice_ids = None # type: mx.nd.NDArray
if self.restrict_lexicon:
+ source_words = utils.split(source, num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
# TODO: See note in method about migrating to pure MXNet when set operations are supported.
# We currently convert source to NumPy and target ids back to NDArray.
- source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
vocab_slice_ids = self.restrict_lexicon.get_trg_ids(source_words.astype("int32").asnumpy())
if any(raw_constraint_list):
# Add the constraint IDs to the list of permissible IDs, and then project them into the reduced space
sockeye/utils.py (25 additions, 0 deletions)
@@ -860,3 +860,28 @@ def uncast_conditionally(data: mx.sym.Symbol, dtype: str) -> mx.sym.Symbol:
if dtype != C.DTYPE_FP32:
return mx.sym.cast(data=data, dtype=C.DTYPE_FP32)
return data


def split(data: mx.nd.NDArray,
num_outputs: int,
axis: int = 1,
squeeze_axis: bool = False) -> List[mx.nd.NDArray]:
"""
Version of mxnet.ndarray.split that always returns a list. The original
implementation only returns a list if num_outputs > 1:
https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.split

Splits an array along a particular axis into multiple sub-arrays.

:param data: The input.
:param num_outputs: Number of splits. Note that this should evenly divide
the length of the axis.
:param axis: Axis along which to split.
:param squeeze_axis: If true, Removes the axis with length 1 from the shapes
of the output arrays.
:return: List of NDArrays resulting from the split.
"""
ndarray_or_list = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
if num_outputs == 1:
Comment from a Contributor (@fhieber, Jun 23, 2018):

Would it make sense to avoid the split altogether when num_outputs is 1? If squeeze_axis == True, one would only need a reshape/squeeze, which is essentially a no-op.

Reply from the Contributor Author:
Another good point. I think it's a toss-up between staying as close to the original as possible versus micro-optimizing the call. Since we're now using this in just one place, called once per batch, I would lean toward keeping it this way for clarity.

Reply from the Contributor:
Yes, probably not worth the additional complexity.

return [ndarray_or_list]
return ndarray_or_list
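The alternative discussed in the review thread above (skip the split entirely when num_outputs is 1 and just squeeze) could be sketched as follows. This `split_or_squeeze` helper is hypothetical, not code from the PR, and uses NumPy as a stand-in for `mx.nd`:

```python
import numpy as np

def split_or_squeeze(data, num_outputs, axis=1, squeeze_axis=False):
    # Hypothetical variant of utils.split: avoid the split entirely for
    # num_outputs == 1; only a squeeze is needed when squeeze_axis is set.
    if num_outputs == 1:
        return [np.squeeze(data, axis=axis) if squeeze_axis else data]
    parts = np.split(data, num_outputs, axis=axis)
    if squeeze_axis:
        parts = [np.squeeze(p, axis=axis) for p in parts]
    return parts

data = np.zeros((4, 10, 2))  # (batch_size, bucket_key, num_factors)
two = split_or_squeeze(data, num_outputs=2, axis=2, squeeze_axis=True)
one = split_or_squeeze(data[:, :, :1], num_outputs=1, axis=2, squeeze_axis=True)
print(len(two), two[0].shape)  # 2 (4, 10)
print(len(one), one[0].shape)  # 1 (4, 10)
```

As the thread concludes, with utils.split now called once per batch in a single place, the micro-optimization is not worth the extra complexity; the sketch only shows that it would preserve the always-a-list contract.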
test/unit/test_utils.py (9 additions, 1 deletion)
@@ -338,4 +338,12 @@ def test_metric_value_is_better(new, old, metric, result):
assert utils.metric_value_is_better(new, old, metric) == result



@pytest.mark.parametrize("num_factors", [1, 2, 3])
def test_split(num_factors):
batch_size = 4
bucket_key = 10
# Simulates splitting factored input
data = mx.nd.random.normal(shape=(batch_size, bucket_key, num_factors))
result = utils.split(data, num_outputs=num_factors, axis=2, squeeze_axis=True)
assert isinstance(result, list)
assert result[0].shape == (batch_size, bucket_key)