Skip to content

Commit

Permalink
Use version of NDArray split that always returns a list. (#454)
Browse files Browse the repository at this point in the history
* Fix source factor splitting for single factor.

- Use version of ndarray split that always returns a list for uniform
handling.

* Unit test for factor splitting.

* Update version, changelog.

* Keep original split call for sym_gen.
  • Loading branch information
mjdenkowski authored and tdomhan committed Jun 25, 2018
1 parent a4f8698 commit c59361d
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 4 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ Note that Sockeye has checks in place to not translate with an old model that wa
Each version section may have subsections for: _Added_, _Changed_, _Removed_, _Deprecated_, and _Fixed_.


## [1.18.27]
### Fixed
- Fix silent failing of NDArray splits during inference by using a version that always returns a list. This was causing incorrect behavior when using lexicon restriction and batch inference with a single source factor.

## [1.18.26]
### Added
- ROUGE score evaluation. It can be used as the stopping criterion for tasks such as summarization.
Expand Down
2 changes: 1 addition & 1 deletion sockeye/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

__version__ = '1.18.26'
__version__ = '1.18.27'
4 changes: 2 additions & 2 deletions sockeye/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,7 +1334,7 @@ def _beam_search(self,
"""
Translates multiple sentences using beam search.
:param source: Source ids. Shape: (batch_size, bucket_key).
:param source: Source ids. Shape: (batch_size, bucket_key, num_factors).
:param source_length: Max source length.
:param raw_constraint_list: A list of optional lists containing phrases (as lists of target word IDs)
that must appear in each output.
Expand Down Expand Up @@ -1383,9 +1383,9 @@ def _beam_search(self,
pad_dist = self.pad_dist
vocab_slice_ids = None # type: mx.nd.NDArray
if self.restrict_lexicon:
source_words = utils.split(source, num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
# TODO: See note in method about migrating to pure MXNet when set operations are supported.
# We currently convert source to NumPy and target ids back to NDArray.
source_words = source.split(num_outputs=self.num_source_factors, axis=2, squeeze_axis=True)[0]
vocab_slice_ids = self.restrict_lexicon.get_trg_ids(source_words.astype("int32").asnumpy())
if any(raw_constraint_list):
# Add the constraint IDs to the list of permissible IDs, and then project them into the reduced space
Expand Down
25 changes: 25 additions & 0 deletions sockeye/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,3 +860,28 @@ def uncast_conditionally(data: mx.sym.Symbol, dtype: str) -> mx.sym.Symbol:
if dtype != C.DTYPE_FP32:
return mx.sym.cast(data=data, dtype=C.DTYPE_FP32)
return data


def split(data: mx.nd.NDArray,
          num_outputs: int,
          axis: int = 1,
          squeeze_axis: bool = False) -> List[mx.nd.NDArray]:
    """
    List-returning wrapper around the NDArray split operation.

    The underlying implementation only returns a list when num_outputs > 1:
    https://mxnet.incubator.apache.org/api/python/ndarray/ndarray.html#mxnet.ndarray.split
    For num_outputs == 1 this wrapper places the single result in a
    one-element list, so callers can index the return value uniformly.

    :param data: The input array to split.
    :param num_outputs: Number of splits. Note that this should evenly divide
                        the length of the axis.
    :param axis: Axis along which to split.
    :param squeeze_axis: If true, removes the axis with length 1 from the
                         shapes of the output arrays.
    :return: List of NDArrays resulting from the split.
    """
    pieces = data.split(num_outputs=num_outputs, axis=axis, squeeze_axis=squeeze_axis)
    # A single-output split is not wrapped by the underlying call; wrap it here.
    return [pieces] if num_outputs == 1 else pieces
10 changes: 9 additions & 1 deletion test/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,4 +338,12 @@ def test_metric_value_is_better(new, old, metric, result):
assert utils.metric_value_is_better(new, old, metric) == result



@pytest.mark.parametrize("num_factors", [1, 2, 3])
def test_split(num_factors):
    """
    Check that utils.split returns a list with one entry per factor,
    including the single-factor case, and that every entry has the split
    axis squeezed away.
    """
    batch_size = 4
    bucket_key = 10
    # Simulates splitting factored input
    data = mx.nd.random.normal(shape=(batch_size, bucket_key, num_factors))
    result = utils.split(data, num_outputs=num_factors, axis=2, squeeze_axis=True)
    assert isinstance(result, list)
    # One piece per factor; checking only result[0] would miss a wrong count.
    assert len(result) == num_factors
    for factor in result:
        assert factor.shape == (batch_size, bucket_key)

0 comments on commit c59361d

Please sign in to comment.