This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

[v1.x Backport] Fix softmax, logsoftmax failed on empty ndarray (#18602) (#18708)

* [v1.x] Backport of fix npx.softmax for 0-sized inputs (#18158)

Co-authored-by: Hao Jin <hjjn.amzn@gmail.com>

* Fix softmax, logsoftmax failed on empty ndarray (#18602)

* Fix failing empty array (log_)softmax

* Modify test for npx (log_)softmax

* Fix softmax, logsoftmax backward failed on empty ndarray (#18710)

Co-authored-by: Yiyan66 <57363390+Yiyan66@users.noreply.github.com>
Co-authored-by: Hao Jin <hjjn.amzn@gmail.com>
Co-authored-by: Bart Gawrych <gawrych.bartlomiej@intel.com>
4 people committed Aug 3, 2020
1 parent 1a31cea commit 73d3a7b
Showing 5 changed files with 96 additions and 29 deletions.
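For orientation, here is a minimal sketch of the user-facing behavior this commit targets. This is a hypothetical repro, assuming an MXNet 1.x build with numpy semantics enabled via `npx.set_np()`: before the fix, `npx.softmax` and `npx.log_softmax` could fail on 0-sized ndarrays; afterwards they return an empty array of the same shape.

```python
import mxnet as mx
from mxnet import np, npx

npx.set_np()  # enable numpy-compatible semantics

# 0-sized inputs: one axis has length zero, so the arrays hold no elements
a = np.random.uniform(size=(3, 0, 4))
b = np.random.uniform(size=(0, 0))

out = npx.softmax(a, axis=-1)         # expected shape: (3, 0, 4)
log_out = npx.log_softmax(b, axis=0)  # expected shape: (0, 0)
print(out.shape, log_out.shape)
```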
2 changes: 2 additions & 0 deletions src/operator/nn/log_softmax.cc
@@ -40,6 +40,7 @@ static void LogSoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
@@ -57,6 +58,7 @@ static void LogSoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNLogSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
58 changes: 31 additions & 27 deletions src/operator/nn/softmax-inl.h
@@ -71,6 +71,7 @@ template<typename OP, bool negate, typename AType, typename DType, typename OTyp
inline void Softmax(Stream<cpu> *s, DType *in, OType *out, IType *length,
Shape<ndim> shape, int axis, const DType temperature) {
index_t M = shape[axis];
if (M == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -186,6 +187,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, IType *length, Shape<ndim> shape,
int axis, const DType temperature) {
index_t M = shape[axis];
if (M == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -402,6 +404,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out, IType *length,
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
if (M == 0 || shape.Size() == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -555,6 +558,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
if (M == 0 || shape.Size() == 0) return;
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
@@ -775,7 +779,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) return;
if (req[0] == kNullOp || inputs[0].Size() == 0U) return;
CHECK_NE(req[0], kAddTo);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
int axis = CheckAxis(param.axis, inputs[0].ndim());
@@ -798,35 +802,35 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
type = inputs[1].type_flag_;
}
MXNET_INT32_INT64_TYPE_SWITCH(type, IType, {
IType* mask_ptr = nullptr;
if (param.use_length.value()) {
mask_ptr = inputs[1].dptr<IType>();
IType* mask_ptr = nullptr;
if (param.use_length.value()) {
mask_ptr = inputs[1].dptr<IType>();
}
if (safe_acc) {
if (shape.ndim() == 2) {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
axis, static_cast<DType>(temperature));
} else {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
axis, static_cast<DType>(temperature));
}
if (safe_acc) {
if (shape.ndim() == 2) {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
axis, static_cast<DType>(temperature));
} else {
Softmax<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
axis, static_cast<DType>(temperature));
}
} else {
if (shape.ndim() == 2) {
Softmax<OP, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
axis, static_cast<DType>(temperature));
} else {
if (shape.ndim() == 2) {
Softmax<OP, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<2>(),
axis, static_cast<DType>(temperature));
} else {
Softmax<OP, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
axis, static_cast<DType>(temperature));
}
Softmax<OP, negate, DType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
outputs[0].dptr<OType>(), mask_ptr, shape.get<3>(),
axis, static_cast<DType>(temperature));
}
}
});
});
});
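The kernel-level guards above matter because the statement immediately after them computes `N = shape.Size()/M`, which divides by zero when the softmax axis is empty. A NumPy-only illustration (not the MXNet kernel itself) of why nothing needs to be computed in that case:

```python
import numpy as np

x = np.empty((3, 0, 4))
M = x.shape[1]               # length of the softmax axis is 0
# The kernels would compute N = x.size // M, a division by zero when M == 0,
# hence the early return. With no elements to normalize, the correct result
# is simply an empty array of the same shape.
out = np.empty_like(x)
print(M, x.size, out.shape)  # 0 0 (3, 0, 4)
```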
2 changes: 2 additions & 0 deletions src/operator/nn/softmax.cc
@@ -41,6 +41,7 @@ static void SoftmaxComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[0], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
Expand All @@ -58,6 +59,7 @@ static void SoftmaxGradComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (inputs[0].shape().Size() == 0U) return;
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
if (SupportMKLDNNSoftmax(param, inputs[1], outputs[0])) {
MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
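Together with these wrapper guards, the backward pass on empty inputs also becomes a no-op, which is the "(#18710)" part of this backport. A hedged sketch of what that enables on a CPU context, assuming numpy semantics and autograd:

```python
import mxnet as mx
from mxnet import np, npx

npx.set_np()

a = np.random.uniform(size=(3, 0, 4), ctx=mx.cpu())
a.attach_grad()
with mx.autograd.record():
    out = npx.log_softmax(a, axis=-1)
out.backward()       # previously could fail on empty ndarrays
print(a.grad.shape)  # (3, 0, 4); the gradient of an empty input is empty
```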
6 changes: 4 additions & 2 deletions src/operator/numpy/np_boolean_mask_assign.cc
@@ -221,10 +221,9 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
// If there's no True in mask, return directly
if (valid_num == 0) return;

const TShape& vshape = inputs[2].shape_;

if (inputs.size() == 3U) {
// tensor case
const TShape& vshape = inputs.at(2).shape_;
if (inputs[2].shape_.Size() != 1) {
auto vndim = vshape.ndim();
auto dndim = dshape.ndim();
@@ -254,6 +253,8 @@ void NumpyBooleanAssignForwardCPU(const nnvm::NodeAttrs& attrs,
}

if (inputs.size() == 3U) {
// tensor case
const TShape& vshape = inputs.at(2).shape_;
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
if (inputs[2].shape_.Size() == 1) {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
@@ -269,6 +270,7 @@ }
}
});
} else {
// scalar case
CHECK(attrs.dict.find("value") != attrs.dict.end()) << "value needs be provided";
MSHADOW_TYPE_SWITCH_WITH_BOOL(data.type_flag_, DType, {
Kernel<BooleanAssignCPUKernel<true>, cpu>::Launch(
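For context on the `np_boolean_mask_assign.cc` change: `vshape` is only meaningful when a value tensor is passed (three inputs); in the scalar-value case the operator receives just the data and the mask, so reading `inputs[2]` unconditionally indexed past the end of the input vector. Moving the lookup inside the `inputs.size() == 3U` branch avoids that. A hedged Python sketch of the two call forms that reach the tensor and scalar paths (assuming numpy semantics):

```python
import mxnet as mx
from mxnet import np, npx

npx.set_np()

data = np.arange(6.0).reshape(2, 3)
mask = data > 2.5           # selects three elements

data[mask] = -1.0           # scalar value: the operator sees only data and mask
data[mask] = np.zeros(3)    # tensor value: the value tensor becomes inputs[2]
print(data)
```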
57 changes: 57 additions & 0 deletions tests/python/unittest/test_numpy_op.py
@@ -1558,6 +1558,63 @@ def _test_batchnorm_impl(shape, fix_gamma, cudnn_off, output_mean_var,
data_grad_req,
gamma_grad_req, beta_grad_req)

@with_seed()
@use_np
def test_npx_softmax():
class TestSoftmax(HybridBlock):
def __init__(self, axis):
super(TestSoftmax, self).__init__()
self._axis = axis

def hybrid_forward(self, F, a):
return F.npx.softmax(a, axis=self._axis)

class TestLogSoftmax(HybridBlock):
def __init__(self, axis):
super(TestLogSoftmax, self).__init__()
self._axis = axis

def hybrid_forward(self, F, a):
return F.npx.log_softmax(a, axis=self._axis)

def np_softmax(x, axis=-1):
if (x.shape[axis] == 0):
return _np.sum(x, axis=axis, keepdims=True)
x = x - _np.max(x, axis=axis, keepdims=True)
x = _np.exp(x)
x /= _np.sum(x, axis=axis, keepdims=True)
return x

def np_log_softmax(x, axis=-1):
return _np.log(np_softmax(x, axis))

# (operator, function) tuples
tested_ops = [(TestSoftmax, np_softmax),
(TestLogSoftmax, np_log_softmax)]

# only testing 0-size shaped inputs here, other input cases have been tested in test_operator.py
for SoftmaxOp, softmax_function in tested_ops:
for hybridize in [True, False]:
for shape in [(3, 0, 4), (0, 0)]:
mx_a = np.random.uniform(size=shape)
mx_a.attach_grad()
for axis in range(-len(shape), len(shape)):
test_softmax_op = SoftmaxOp(axis)
if hybridize:
test_softmax_op.hybridize()

with mx.autograd.record():
mx_out = test_softmax_op(mx_a)

mx_out.wait_to_read()

np_out = softmax_function(mx_a.asnumpy(), axis)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5, equal_nan=True)

mx_out.backward()
mx_a.grad.wait_to_read()
assert_almost_equal(mx_a.grad.asnumpy(), _np.zeros(shape), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
