Conversation
src/operator/sequence_last-inl.h
Outdated
"to specify variable length sequence"); | ||
DMLC_DECLARE_FIELD(axis).set_default(0).describe( | ||
"The sequence axis. Only values of 0 and 1 are current supported."); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
current -> currently
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed.
src/operator/sequence_last-inl.h
Outdated
const DType *idx, int offset1, int offset2, | ||
mshadow::Shape<2> oshape) { | ||
auto opos = mxnet_op::unravel(i, oshape); | ||
int seqpos = static_cast<int>(idx[opos[0]]) - 1; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add const
wherever possible, same for all of the following code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
src/operator/sequence_last-inl.h
Outdated
@@ -74,33 +145,32 @@ class SequenceLastOp : public Operator { | |||
CHECK_EQ(out_data.size(), 1U); | |||
Stream<xpu> *s = ctx.get_stream<xpu>(); | |||
|
|||
// only support axis of 0 or 1 for now | |||
bool axis = static_cast<bool>(param_.axis); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defining a bool value for axis is confusing in terms of readability.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. Fixed.
src/operator/sequence_last-inl.h
Outdated
} | ||
|
||
struct SequenceLastParam : public dmlc::Parameter<SequenceLastParam> { | ||
bool use_sequence_length; | ||
uint32_t axis; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it possible to support negative axis in the future? If so, it's better to define it as a signed integer.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes. Fixed, thanks.
src/operator/sequence_last-inl.h
Outdated
@@ -184,17 +252,22 @@ class SequenceLastProp : public OperatorProperty { | |||
CHECK_EQ(in_shape->size(), param_.use_sequence_length ? 2U : 1U) | |||
<< "Input:[data, sequence_length]"; | |||
|
|||
if ((param_.axis != 0) && (param_.axis != 1)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can simplify this by using CHECK.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
src/operator/sequence_reverse-inl.h
Outdated
@@ -221,6 +224,10 @@ class SequenceReverseProp : public OperatorProperty { | |||
CHECK_EQ(in_shape->size(), param_.use_sequence_length ? 2U : 1U) | |||
<< "Input:[data, sequence_length]"; | |||
|
|||
if ((param_.axis != 0)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use CHECK_EQ to simplify the code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
@reminisce, @cjolivier01: can this be merged? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
return array2[:, -1] | ||
lengths = list(lengths) | ||
return np.array([array2[i, int(lengths[i]) - 1] for i in range(dims[0])]) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: surround top-level function with 2 lines, otherwise, IDEs such as PyCharm would issue a coding style warning.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed.
* added axis parameter + refactored to use expression template * refactor of sequence last * add axis to sequence reverse * add axis support to sequence mask, rewrite kernels, fix bug for kAddTo * remove header * add rigorous tests for sequence ops * conflict * remove conflict * various sequence op fixes * added 2 spaces for top-level python functions to avoid PyCharm lint warning
I just accidentally stumbled upon the bug with the kAddTo in SequenceMask (on a slightly older codebase). Nice to see it already fixed ;-) |
* added axis parameter + refactored to use expression template * refactor of sequence last * add axis to sequence reverse * add axis support to sequence mask, rewrite kernels, fix bug for kAddTo * remove header * add rigorous tests for sequence ops * conflict * remove conflict * various sequence op fixes * added 2 spaces for top-level python functions to avoid PyCharm lint warning
* added axis parameter + refactored to use expression template * refactor of sequence last * add axis to sequence reverse * add axis support to sequence mask, rewrite kernels, fix bug for kAddTo * remove header * add rigorous tests for sequence ops * conflict * remove conflict * various sequence op fixes * added 2 spaces for top-level python functions to avoid PyCharm lint warning
* added axis parameter + refactored to use expression template * refactor of sequence last * add axis to sequence reverse * add axis support to sequence mask, rewrite kernels, fix bug for kAddTo * remove header * add rigorous tests for sequence ops * conflict * remove conflict * various sequence op fixes * added 2 spaces for top-level python functions to avoid PyCharm lint warning
Description
This PR does the following:
axis
toSequenceLast
,SequenceMask
andSequenceReverse
. The default isaxis=0
. Values foraxis=1
andaxis=0
are implemented forSequenceLast
andSequenceMask
.SequenceMask
was fixed (incorrect behaviour forkAddTo
).SequenceLast
avoids any movement of data to the CPU, and uses a much faster implementation similar topick
.SequenceMask
now supports inplace.Axis support will be added to
SequenceReverse
in a separate PR.Checklist
Essentials
make lint
)Changes
Comments