Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Sequence Operator Improvements #9306

Merged
merged 10 commits into from Jan 16, 2018
Merged

Sequence Operator Improvements #9306

merged 10 commits into from Jan 16, 2018

Conversation

sbodenstein
Copy link
Contributor

@sbodenstein sbodenstein commented Jan 4, 2018

Description

This PR does the following:

  • Adds an optional argument axis to SequenceLast, SequenceMask and SequenceReverse. The default is axis=0. Values for axis=1 and axis=0 are implemented for SequenceLast and SequenceMask.
  • Rigorous tests are added for all sequence ops.
  • Bug in SequenceMask was fixed (incorrect behaviour for kAddTo).
  • Performance improvements: SequenceLast avoids any movement of data to the CPU, and uses a much faster implementation similar to pick. SequenceMask now supports inplace.

Axis support will be added to SequenceReverse in a separate PR.

Checklist

Essentials

  • Passed code style checking (make lint)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain the what the example does, the source of the dataset, expected performance on test set and reference to the original paper if applicable
  • To the my best knowledge, examples are either not affected by this change, or have been fixed to be compatible with this change

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here

"to specify variable length sequence");
DMLC_DECLARE_FIELD(axis).set_default(0).describe(
"The sequence axis. Only values of 0 and 1 are current supported.");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

current -> currently

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed.

const DType *idx, int offset1, int offset2,
mshadow::Shape<2> oshape) {
auto opos = mxnet_op::unravel(i, oshape);
int seqpos = static_cast<int>(idx[opos[0]]) - 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add const wherever possible, same for all of the following code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -74,33 +145,32 @@ class SequenceLastOp : public Operator {
CHECK_EQ(out_data.size(), 1U);
Stream<xpu> *s = ctx.get_stream<xpu>();

// only support axis of 0 or 1 for now
bool axis = static_cast<bool>(param_.axis);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining a bool value for axis is confusing in terms of readability.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. Fixed.

}

struct SequenceLastParam : public dmlc::Parameter<SequenceLastParam> {
bool use_sequence_length;
uint32_t axis;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to support negative axis in the future? If so, it's better to define it as a signed integer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. Fixed, thanks.

@@ -184,17 +252,22 @@ class SequenceLastProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), param_.use_sequence_length ? 2U : 1U)
<< "Input:[data, sequence_length]";

if ((param_.axis != 0) && (param_.axis != 1)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can simplify this by using CHECK.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -221,6 +224,10 @@ class SequenceReverseProp : public OperatorProperty {
CHECK_EQ(in_shape->size(), param_.use_sequence_length ? 2U : 1U)
<< "Input:[data, sequence_length]";

if ((param_.axis != 0)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use CHECK_EQ to simplify the code.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@sbodenstein
Copy link
Contributor Author

sbodenstein commented Jan 14, 2018

@reminisce, @cjolivier01: can this be merged?

Copy link
Contributor

@reminisce reminisce left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

return array2[:, -1]
lengths = list(lengths)
return np.array([array2[i, int(lengths[i]) - 1] for i in range(dims[0])])

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: surround top-level function with 2 lines, otherwise, IDEs such as PyCharm would issue a coding style warning.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

@piiswrong piiswrong merged commit f960522 into apache:master Jan 16, 2018
larroy pushed a commit to larroy/mxnet that referenced this pull request Jan 18, 2018
* added axis parameter + refactored to use expression template

* refactor of sequence last

* add axis to sequence reverse

* add axis support to sequence mask, rewrite kernels, fix bug for kAddTo

* remove header

* add rigorous tests for sequence ops

* conflict

* remove conflict

* various sequence op fixes

* added 2 spaces for top-level python functions to avoid PyCharm lint warning
@asmushetzel
Copy link
Contributor

I just accidentally stumbled upon the bug with the kAddTo in SequenceMask (on a slightly older codebase). Nice to see it already fixed ;-)

yuxiangw pushed a commit to yuxiangw/incubator-mxnet that referenced this pull request Jan 25, 2018
* added axis parameter + refactored to use expression template

* refactor of sequence last

* add axis to sequence reverse

* add axis support to sequence mask, rewrite kernels, fix bug for kAddTo

* remove header

* add rigorous tests for sequence ops

* conflict

* remove conflict

* various sequence op fixes

* added 2 spaces for top-level python functions to avoid PyCharm lint warning
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
* added axis parameter + refactored to use expression template

* refactor of sequence last

* add axis to sequence reverse

* add axis support to sequence mask, rewrite kernels, fix bug for kAddTo

* remove header

* add rigorous tests for sequence ops

* conflict

* remove conflict

* various sequence op fixes

* added 2 spaces for top-level python functions to avoid PyCharm lint warning
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
* added axis parameter + refactored to use expression template

* refactor of sequence last

* add axis to sequence reverse

* add axis support to sequence mask, rewrite kernels, fix bug for kAddTo

* remove header

* add rigorous tests for sequence ops

* conflict

* remove conflict

* various sequence op fixes

* added 2 spaces for top-level python functions to avoid PyCharm lint warning
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

5 participants