Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Implement Weibull backward (#17590)
Browse files Browse the repository at this point in the history
  • Loading branch information
D-Roberts committed Feb 15, 2020
1 parent 149975c commit b6b1de0
Show file tree
Hide file tree
Showing 7 changed files with 160 additions and 28 deletions.
8 changes: 5 additions & 3 deletions python/mxnet/ndarray/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def exponential(scale=1.0, size=None, ctx=None, out=None):
return _npi.exponential(scale=scale, size=size, ctx=ctx, out=out)


def weibull(a, size=None):
def weibull(a, size=None, ctx=None, out=None):
r"""Draw samples from a 1-parameter Weibull distribution with given
parameter a, via inversion.
Expand Down Expand Up @@ -614,13 +614,15 @@ def weibull(a, size=None):
"""
from ...numpy import ndarray as np_ndarray
tensor_type_name = np_ndarray
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(a, tensor_type_name)
if is_tensor:
return _npi.weibull(a, a=None, size=size)
return _npi.weibull(a, a=None, size=size, ctx=ctx, out=out)
else:
return _npi.weibull(a=a, size=size)
return _npi.weibull(a=a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def exponential(scale=1.0, size=None, ctx=None, out=None):
return _mx_nd_np.random.exponential(scale, size=size, ctx=ctx, out=out)


def weibull(a, size=None):
def weibull(a, size=None, ctx=None, out=None):
r"""Draw samples from a 1-parameter Weibull distribution with given parameter a
via inversion.
Expand Down Expand Up @@ -653,7 +653,7 @@ def weibull(a, size=None):
model time to failure, in modeling particle sizes, in information retrieval
to model dwell time on pages, in quantitative finance to model risk etc.
"""
return _mx_nd_np.random.weibull(a, size)
return _mx_nd_np.random.weibull(a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
Expand Down
8 changes: 5 additions & 3 deletions python/mxnet/symbol/numpy/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,7 @@ def exponential(scale=1.0, size=None, ctx=None, out=None):
return _npi.exponential(scale=scale, size=size, ctx=ctx, out=out)


def weibull(a, size=None):
def weibull(a, size=None, ctx=None, out=None):
r"""Draw samples from a 1-parameter Weibull distribution with given parameter a
via inversion.
Expand Down Expand Up @@ -684,13 +684,15 @@ def weibull(a, size=None):
"""
from ..numpy import _Symbol as np_symbol
tensor_type_name = np_symbol
if ctx is None:
ctx = current_context()
if size == ():
size = None
is_tensor = isinstance(a, tensor_type_name)
if is_tensor:
return _npi.weibull(a, a=None, size=size)
return _npi.weibull(a, a=None, size=size, ctx=ctx, out=out)
else:
return _npi.weibull(a=a, size=size)
return _npi.weibull(a=a, size=size, ctx=ctx, out=out)


def pareto(a, size=None):
Expand Down
38 changes: 35 additions & 3 deletions src/operator/numpy/random/np_weibull_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace op {
DMLC_REGISTER_PARAMETER(NumpyWeibullParam);

NNVM_REGISTER_OP(_npi_weibull)
.describe("Numpy behavior Weibull")
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs) {
const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
Expand All @@ -41,7 +42,11 @@ NNVM_REGISTER_OP(_npi_weibull)
}
return num_inputs;
})
.set_num_outputs(1)
.set_num_outputs(2)
.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs",
[](const NodeAttrs& attrs){
return 1;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
Expand All @@ -52,10 +57,11 @@ NNVM_REGISTER_OP(_npi_weibull)
return (num_inputs == 0) ? std::vector<std::string>() : std::vector<std::string>{"input1"};
})
.set_attr_parser(ParamParser<NumpyWeibullParam>)
.set_attr<mxnet::FInferShape>("FInferShape", UnaryDistOpShape<NumpyWeibullParam>)
.set_attr<mxnet::FInferShape>("FInferShape", TwoparamsDistOpShape<NumpyWeibullParam>)
.set_attr<nnvm::FInferType>("FInferType",
[](const nnvm::NodeAttrs &attrs, std::vector<int> *in_attrs, std::vector<int> *out_attrs) {
(*out_attrs)[0] = mshadow::kFloat32;
(*out_attrs)[1] = mshadow::kFloat32;
return true;
})
.set_attr<FResourceRequest>("FResourceRequest",
Expand All @@ -64,9 +70,35 @@ NNVM_REGISTER_OP(_npi_weibull)
ResourceRequest::kRandom, ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", NumpyWeibullForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseInOut{"_backward_broadcast_weibull"})
.add_argument("input1", "NDArray-or-Symbol", "Source input")
.add_arguments(NumpyWeibullParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_broadcast_weibull)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<NumpyWeibullParam>)
.set_num_inputs(
[](const nnvm::NodeAttrs& attrs){
const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
int num_inputs = 5;
if (param.a.has_value()) num_inputs -= 1;
return num_inputs;
}
)
.set_num_outputs(
[](const nnvm::NodeAttrs& attrs){
const NumpyWeibullParam& param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
int num_outputs = 1;
if (param.a.has_value()) num_outputs -= 1;
return num_outputs;
}
)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs){
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<FCompute>("FCompute<cpu>", WeibullReparamBackward<cpu>)
.add_arguments(NumpyWeibullParam::__FIELDS__());

} // namespace op
} // namespace mxnet
3 changes: 3 additions & 0 deletions src/operator/numpy/random/np_weibull_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -31,5 +31,8 @@ namespace op {
NNVM_REGISTER_OP(_npi_weibull)
.set_attr<FCompute>("FCompute<gpu>", NumpyWeibullForward<gpu>);

NNVM_REGISTER_OP(_backward_broadcast_weibull)
.set_attr<FCompute>("FCompute<gpu>", WeibullReparamBackward<gpu>);

} // namespace op
} // namespace mxnet
94 changes: 77 additions & 17 deletions src/operator/numpy/random/np_weibull_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ namespace op {
struct NumpyWeibullParam : public dmlc::Parameter<NumpyWeibullParam> {
dmlc::optional<float> a;
dmlc::optional<mxnet::Tuple<int>> size;
std::string ctx;
DMLC_DECLARE_PARAMETER(NumpyWeibullParam) {
DMLC_DECLARE_FIELD(a)
.set_default(dmlc::optional<float>());
Expand All @@ -52,23 +53,26 @@ struct NumpyWeibullParam : public dmlc::Parameter<NumpyWeibullParam> {
.describe("Output shape. If the given shape is, "
"e.g., (m, n, k), then m * n * k samples are drawn. "
"Default is None, in which case a single value is returned.");
DMLC_DECLARE_FIELD(ctx).set_default("cpu").describe(
"Context of output, in format [cpu|gpu|cpu_pinned](n)."
" Only used for imperative calls.");
}
};

template <typename DType>
struct scalar_weibull_kernel {
MSHADOW_XINLINE static void Map(index_t i, float a, float *threshold,
MSHADOW_XINLINE static void Map(index_t i, float a, float *noise,
DType *out) {
out[i] = powf(-log(threshold[i]), DType(1.0/a));
out[i] = powf(-log(noise[i]), DType(1.0/a));
}
};

namespace mxnet_op {

template <typename IType>
struct check_legal_a_kernel {
MSHADOW_XINLINE static void Map(index_t i, IType *a, float* flag) {
if (a[i] < 0.0) {
MSHADOW_XINLINE static void Map(index_t i, IType *a, float *flag) {
if (a[i] <= 0.0) {
flag[0] = -1.0;
}
}
Expand All @@ -80,37 +84,39 @@ struct weibull_kernel {
MSHADOW_XINLINE static void Map(index_t i,
const Shape<ndim> &stride,
const Shape<ndim> &oshape,
IType *aparams, float* threshold, OType *out) {
IType *aparams, float *noise, OType *out) {
Shape<ndim> coord = unravel(i, oshape);
auto idx = static_cast<index_t>(dot(coord, stride));
out[i] = powf(-log(threshold[i]), IType(1.0/aparams[idx]));
noise[i] = -log(noise[i]);
out[i] = powf(noise[i], IType(1.0/aparams[idx]));
// get grad
noise[i] = -log(noise[i]) * out[i] * (1.0/(aparams[idx] * aparams[idx]));
}
};

} // namespace mxnet_op

template <typename xpu>
void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mxnet_op;
const NumpyWeibullParam &param = nnvm::get<NumpyWeibullParam>(attrs.parsed);
Stream<xpu> *s = ctx.get_stream<xpu>();
index_t output_len = outputs[0].Size();
Random<xpu, float> *prnd = ctx.requested[0].get_random<xpu, float>(s);
Tensor<xpu, 1, float> workspace =
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(output_len + 1), s);
Tensor<xpu, 1, float> uniform_tensor = workspace.Slice(0, output_len);
Tensor<xpu, 1, float> indicator_device = workspace.Slice(output_len, output_len + 1);
ctx.requested[1].get_space_typed<xpu, 1, float>(Shape1(1), s);
Tensor<xpu, 1, float> uniform_tensor = outputs[1].FlatTo1D<xpu, float>(s);
Tensor<xpu, 1, float> indicator_device = workspace;
float indicator_host = 1.0;
float *indicator_device_ptr = indicator_device.dptr_;
Kernel<set_zero, xpu>::Launch(s, 1, indicator_device_ptr);
prnd->SampleUniform(&workspace, 0.0, 1.0);
prnd->SampleUniform(&uniform_tensor, 0.0, 1.0);
if (param.a.has_value()) {
CHECK_GE(param.a.value(), 0.0) << "ValueError: expect a >= 0";
CHECK_GT(param.a.value(), 0.0) << "ValueError: expect a > 0";
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Kernel<scalar_weibull_kernel<DType>, xpu>::Launch(
s, outputs[0].Size(), param.a.value(),
Expand All @@ -122,7 +128,7 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
s, inputs[0].Size(), inputs[0].dptr<IType>(), indicator_device_ptr);
});
_copy<xpu>(s, &indicator_host, indicator_device_ptr);
CHECK_GE(indicator_host, 0.0) << "ValueError: expect a >= 0";
CHECK_GE(indicator_host, 0.0) << "ValueError: expect a > 0";
mxnet::TShape new_lshape, new_oshape;
int ndim = FillShape(inputs[0].shape_, inputs[0].shape_, outputs[0].shape_,
&new_lshape, &new_lshape, &new_oshape);
Expand All @@ -140,6 +146,60 @@ void NumpyWeibullForward(const nnvm::NodeAttrs &attrs,
}
}

template<typename xpu, int ndim, typename DType>
inline void ScalarWeibullReparamBackwardImpl(const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs,
const mxnet::TShape& new_ishape,
const mxnet::TShape& new_oshape) {
using namespace mshadow;
using namespace mshadow::expr;
using namespace broadcast;
Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob igrad = outputs[0].reshape(new_ishape);
// inputs: [grad_from_samples, grad_from_noise(invisible), input_tensor,
// samples, noise]
const TBlob ograd = inputs[0].reshape(new_oshape);
const TBlob itensor = inputs[2].reshape(new_ishape);
const TBlob samples = inputs[3].reshape(new_oshape);
const TBlob noise = inputs[4].reshape(new_oshape);
size_t workspace_size =
ReduceWorkspaceSize<ndim, DType>(s, igrad.shape_, req[0], ograd.shape_);
Tensor<xpu, 1, char> workspace =
ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(workspace_size), s);
Reduce<red::sum, ndim, DType, op::mshadow_op::mul, op::mshadow_op::left>(
s, igrad, req[0], workspace, ograd, noise, noise);
}

template<typename xpu>
void WeibullReparamBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& reqs,
const std::vector<TBlob>& outputs) {
// skip kernel launch for zero-size tensors
if (inputs[0].shape_.Size() == 0U) {
return;
}
// [scalar] case
if (outputs.size() == 0U) {
return;
}
// [tensor] case
if (inputs.size() == 5U) {
mxnet::TShape new_ishape, new_oshape;
int ndim = FillShape(outputs[0].shape_, outputs[0].shape_, inputs[0].shape_,
&new_ishape, &new_ishape, &new_oshape);
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
BROADCAST_NDIM_SWITCH(ndim, NDim, {
ScalarWeibullReparamBackwardImpl<xpu, NDim, DType>(
ctx, inputs, reqs, outputs, new_ishape, new_oshape);
});
});
}
}

} // namespace op
} // namespace mxnet

Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -4046,6 +4046,39 @@ def _test_exception(a):
assertRaises(ValueError, _test_exception, 0)


@with_seed()
@use_np
def test_np_weibull_grad():
class TestRandomW(HybridBlock):
def __init__(self, shape):
super(TestRandomW, self).__init__()
self._shape = shape

def hybrid_forward(self, F, a):
return F.np.random.weibull(a, self._shape)

output_shapes = [
(3, 2),
(4, 3, 2, 2),
(3, 4, 5)
]
for hybridize in [False, True]:
for out_shape in output_shapes:
test_w_grad = TestRandomW(out_shape)
if hybridize:
test_w_grad.hybridize()
a = np.ones(out_shape)
a.attach_grad()
with mx.autograd.record():
mx_out = test_w_grad(a)
mx_out.backward()

# gradient formula calculus (a=1)
formula_grad = - mx_out * np.log(mx_out)
assert a.grad.shape == out_shape
assert_almost_equal(a.grad.asnumpy().sum(), formula_grad.asnumpy().sum(), rtol=1e-3, atol=1e-5)


@with_seed()
@use_np
def test_np_randn():
Expand Down

0 comments on commit b6b1de0

Please sign in to comment.