
address code reviews and code re-design
Hao Jin committed Jun 1, 2018
1 parent 0fae7e5 commit 240eecc
Showing 5 changed files with 152 additions and 137 deletions.
16 changes: 10 additions & 6 deletions python/mxnet/ndarray/ndarray.py
@@ -3742,7 +3742,8 @@ def empty(shape, ctx=None, dtype=None):
     return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))


-def histogram(a, bins=10, range_=None):
+# pylint: disable= redefined-builtin
+def histogram(a, bins=10, range=None):
     """Compute the histogram of the input data.

     Parameters
@@ -3763,9 +3764,12 @@ def histogram(a, bins=10, range_=None):
     # pylint: disable= no-member, protected-access
     if isinstance(bins, NDArray):
         return _internal._histogram(data=a, bins=bins)
-    elif isinstance(bins, int):
-        if range_ is None:
-            range_ = (float(a.min().asnumpy()[0]), float(a.max().asnumpy()[0]))
-        return _internal._histogram(data=a, bins=array([]), bin_cnt=bins, range=range_)
+    elif isinstance(bins, integer_types):
+        if range is None:
+            warnings.warn("range is not specified, using numpy's result "
+                          "to ensure consistency with numpy")
+            res, bin_bounds = np.histogram(a.asnumpy(), bins=bins)
+            return array(res), array(bin_bounds)
+        return _internal._histogram(data=a, bin_cnt=bins, range=range)
     return None
-# pylint: enable= no-member, protected-access
+# pylint: enable= no-member, protected-access, redefined-builtin
87 changes: 43 additions & 44 deletions src/operator/tensor/histogram-inl.h
@@ -42,13 +42,13 @@ namespace op {

 struct HistogramParam : public dmlc::Parameter<HistogramParam> {
   dmlc::optional<int> bin_cnt;
-  dmlc::optional<nnvm::Tuple<float>> range;
+  dmlc::optional<nnvm::Tuple<double>> range;
   DMLC_DECLARE_PARAMETER(HistogramParam) {
     DMLC_DECLARE_FIELD(bin_cnt)
       .set_default(dmlc::optional<int>())
       .describe("Number of bins for uniform case");
     DMLC_DECLARE_FIELD(range)
-      .set_default(dmlc::optional<nnvm::Tuple<float>>())
+      .set_default(dmlc::optional<nnvm::Tuple<double>>())
       .describe("The lower and upper range of the bins. if not provided, "
                 "range is simply (a.min(), a.max()). values outside the "
                 "range are ignored. the first element of the range must be "
@@ -61,23 +61,24 @@ struct HistogramParam : public dmlc::Parameter<HistogramParam> {

 struct FillBinBoundsKernel {
   template<typename DType>
-  static MSHADOW_XINLINE void Map(int i, DType* bin_bounds, int bin_cnt, float min, float max) {
+  static MSHADOW_XINLINE void Map(int i, DType* bin_bounds, int bin_cnt, double min, double max) {
     if (i <= bin_cnt) {
-      bin_bounds[i] = DType((i * max + (bin_cnt - i) * min) / bin_cnt);
+      bin_bounds[i] = DType((max * i + (bin_cnt - i) * min) / bin_cnt);
     }
   }
 };

 inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs,
                              std::vector<TShape>* in_attrs,
                              std::vector<TShape>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 2U);
   HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
   const bool has_cnt = param.bin_cnt.has_value();
   const bool has_range = param.range.has_value();
   const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
+  CHECK_EQ(in_attrs->size(), has_cnt ? 1U : 2U);
+  CHECK_EQ(out_attrs->size(), 2U);
   CHECK(legal_param) << "cnt and range should both or neither be specified";

   if (has_cnt) {
     // if cnt is specified, the output histogram has shape (cnt,)
     // while output bins has shape (cnt+1,)
@@ -96,21 +97,18 @@ inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs,
     SHAPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(1));
   }

-  return out_attrs->at(0).ndim() == 1U && out_attrs->at(0).Size() != 0U &&
-         out_attrs->at(1).ndim() == 1U && out_attrs->at(1).Size() != 0U &&
+  return !shape_is_none(out_attrs->at(0)) && !shape_is_none(out_attrs->at(1)) &&
          out_attrs->at(0).Size() == out_attrs->at(1).Size() - 1;
 }

 inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
                             std::vector<int>* in_attrs,
                             std::vector<int>* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(in_attrs->at(0), in_attrs->at(1));
   CHECK_EQ(out_attrs->size(), 2U);

   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
   TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));
-  return out_attrs->at(0) != -1 && out_attrs->at(1) != -1;
+  return !type_is_none(out_attrs->at(0)) && !type_is_none(out_attrs->at(1));
 }

 template<typename xpu>
@@ -122,54 +120,55 @@ void HistogramForwardImpl(mshadow::Stream<xpu>* s,
                           const TBlob& out_data,
                           const TBlob& out_bins);

+template<typename xpu>
+void HistogramForwardImpl(mshadow::Stream<xpu>* s,
+                          const OpContext& ctx,
+                          const nnvm::NodeAttrs& attrs,
+                          const TBlob& in_data,
+                          const TBlob& out_data,
+                          const TBlob& out_bins,
+                          const int bin_cnt,
+                          const double min,
+                          const double max);
+
 template<typename xpu>
 void HistogramOpForward(const nnvm::NodeAttrs& attrs,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 2U);
   CHECK_EQ(req.size(), 2U);
   CHECK_EQ(req[0], kWriteTo);
   CHECK_EQ(req[1], kWriteTo);
+  const HistogramParam& param = nnvm::get<HistogramParam>(attrs.parsed);
+  const bool has_cnt = param.bin_cnt.has_value();
+  const bool has_range = param.range.has_value();
+  const bool legal_params = (has_cnt && has_range) || (!has_cnt && !has_range);
+  CHECK(legal_params) << "bin_cnt and range should both or neither be specified";

   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
   const TBlob& in_data = inputs[0];
-  const TBlob& bin_bounds = inputs[1];
   const TBlob& out_data = outputs[0];
   const TBlob& out_bins = outputs[1];

-  HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, bin_bounds, out_data, out_bins);
-}
-
-template<typename xpu>
-void HistogramBackwardImpl(const OpContext& ctx,
-                           const nnvm::NodeAttrs& attrs,
-                           const TBlob& out_grad,
-                           const TBlob& in_data,
-                           const TBlob& bin_bounds,
-                           const TBlob& out_data,
-                           const TBlob& in_grad);
-
-template<typename xpu>
-void HistogramOpBackward(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<TBlob>& inputs,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<TBlob>& outputs) {
-  CHECK_EQ(inputs.size(), 4U);
-  CHECK_EQ(outputs.size(), 1U);
-  CHECK_EQ(req.size(), 1U);
-  CHECK_EQ(req[0], kWriteTo);
-
-  const TBlob& out_grad = inputs[0];
-  const TBlob& in_data = inputs[1];
-  const TBlob& bin_bounds = inputs[2];
-  const TBlob& out_data = inputs[3];
-  const TBlob& in_grad = outputs[0];
-
-  HistogramBackwardImpl<xpu>(ctx, attrs, out_grad, in_data, bin_bounds, out_data, in_grad);
+  if (has_cnt) {
+    CHECK_EQ(param.range.value().ndim(), 2U) << "range should be a tuple with only 2 elements";
+    CHECK(param.range.value()[0] <= param.range.value()[1])
+      << "left hand side of range (" << param.range.value()[0]
+      << ") should be less than or equal to right hand side (" << param.range.value()[1] << ")";
+    double max = param.range.value()[1];
+    double min = param.range.value()[0];
+    const int bin_cnt = param.bin_cnt.value();
+    if (min == max) {
+      min -= 0.5;
+      max += 0.5;
+    }
+    HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, out_data, out_bins, bin_cnt, min, max);
+  } else {
+    const TBlob& bin_bounds = inputs[1];
+    HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, bin_bounds, out_data, out_bins);
+  }
 }

 }  // namespace op
90 changes: 51 additions & 39 deletions src/operator/tensor/histogram.cc
@@ -24,30 +24,32 @@ namespace op {

 struct ComputeBinKernel {
   template<typename DType>
-  MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* bin_indices,
-                                  int bin_cnt, float width, float min, float max) {
+  MSHADOW_XINLINE static void Map(int i, const DType* in_data, const DType* bin_bounds,
+                                  int* bin_indices, int bin_cnt, double min, double max) {
     DType data = in_data[i];
+    int target = -1;
     if (data >= min && data <= max) {
-      bin_indices[i] = mshadow_op::floor::Map((in_data[i] - min) / width);
-      bin_indices[i] = mshadow_op::minimum::Map(bin_cnt - 1, bin_indices[i]);
-    } else {
-      bin_indices[i] = -1;
+      target = (data - min) * bin_cnt / (max - min);
+      target = mshadow_op::minimum::Map(bin_cnt - 1, target);
+      target -= (data < bin_bounds[target]) ? 1 : 0;
+      target += ((data >= bin_bounds[target + 1]) && (target != bin_cnt - 1)) ? 1 : 0;
     }
+    bin_indices[i] = target;
   }

   template<typename DType>
   MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* bin_indices,
                                   const DType* bin_bounds, int num_bins) {
     DType data = in_data[i];
-    int target_idx = -1;
+    int target = -1;
     if (data >= bin_bounds[0] && data <= bin_bounds[num_bins]) {
-      target_idx = 0;
-      while ((data - bin_bounds[target_idx]) >= 0) {
-        target_idx += 1;
+      target = 0;
+      while ((data - bin_bounds[target]) >= 0) {
+        target += 1;
       }
-      target_idx = mshadow_op::minimum::Map(target_idx - 1, num_bins - 1);
+      target = mshadow_op::minimum::Map(target - 1, num_bins - 1);
     }
-    bin_indices[i] = target_idx;
+    bin_indices[i] = target;
   }
 };

@@ -71,37 +73,44 @@ void HistogramForwardImpl(mshadow::Stream<cpu>* s,
                           const TBlob& out_bins) {
   using namespace mshadow;
   using namespace mxnet_op;
-  HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
-  const bool has_cnt = param.bin_cnt.has_value();
-  const bool has_range = param.range.has_value();
-  const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
-  CHECK(legal_param) << "width and range should both or neither be specified";
-  CHECK(!has_range || (has_range && (param.range.value().ndim() == 2U)))
-    << "range should be a tuple with only 2 elements";
-  CHECK(!has_range || (has_range && (param.range.value()[0] < param.range.value()[1])))
-    << "left hand side of range should be less than or equal to right hand side";
-
   Tensor<cpu, 1, int> bin_indices =
     ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
   const int bin_cnt = out_data.Size();

   MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
-    if (has_cnt) {
-      float max = param.range.value()[1];
-      float min = param.range.value()[0];
-      float width = (max - min) / bin_cnt;
-      Kernel<ComputeBinKernel, cpu>::Launch(
-        s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_,
-        bin_cnt, width, min, max);
-      Kernel<FillBinBoundsKernel, cpu>::Launch(
-        s, bin_cnt+1, out_bins.dptr<DType>(), bin_cnt, min, max);
-    } else {
-      Kernel<ComputeBinKernel, cpu>::Launch(
-        s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_, bin_bounds.dptr<DType>(),
-        bin_cnt);
-      Kernel<op_with_req<mshadow_op::identity, kWriteTo>, cpu>::Launch(
-        s, bin_bounds.Size(), out_bins.dptr<DType>(), bin_bounds.dptr<DType>());
-    }
+    Kernel<ComputeBinKernel, cpu>::Launch(
+      s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_, bin_bounds.dptr<DType>(),
+      bin_cnt);
+    Kernel<op_with_req<mshadow_op::identity, kWriteTo>, cpu>::Launch(
+      s, bin_bounds.Size(), out_bins.dptr<DType>(), bin_bounds.dptr<DType>());
   });
   MSHADOW_TYPE_SWITCH(out_data.type_flag_, CType, {
     Kernel<set_zero, cpu>::Launch(s, bin_cnt, out_data.dptr<CType>());
     ComputeHistogram(bin_indices.dptr_, out_data.dptr<CType>(), in_data.Size());
   });
 }

+template<typename cpu>
+void HistogramForwardImpl(mshadow::Stream<cpu>* s,
+                          const OpContext& ctx,
+                          const nnvm::NodeAttrs& attrs,
+                          const TBlob& in_data,
+                          const TBlob& out_data,
+                          const TBlob& out_bins,
+                          const int bin_cnt,
+                          const double min,
+                          const double max) {
+  using namespace mshadow;
+  using namespace mxnet_op;
+  Tensor<cpu, 1, int> bin_indices =
+    ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
+
+  MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
+    Kernel<FillBinBoundsKernel, cpu>::Launch(
+      s, bin_cnt+1, out_bins.dptr<DType>(), bin_cnt, min, max);
+    Kernel<ComputeBinKernel, cpu>::Launch(
+      s, in_data.Size(), in_data.dptr<DType>(), out_bins.dptr<DType>(), bin_indices.dptr_,
+      bin_cnt, min, max);
+  });
   MSHADOW_TYPE_SWITCH(out_data.type_flag_, CType, {
     Kernel<set_zero, cpu>::Launch(s, bin_cnt, out_data.dptr<CType>());
@@ -124,7 +133,10 @@ Example::
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<HistogramParam>)
.set_num_inputs(2)
.set_num_inputs([](const NodeAttrs& attrs) {
const HistogramParam& params = nnvm::get<HistogramParam>(attrs.parsed);
return params.bin_cnt.has_value() ? 1 : 2;
})
.set_num_outputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {