Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
implementation of histogram operator
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed May 14, 2018
1 parent 89ffab9 commit 7c44e7f
Show file tree
Hide file tree
Showing 7 changed files with 517 additions and 2 deletions.
28 changes: 27 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
"ones", "add", "arange", "eye", "divide", "equal", "full", "greater", "greater_equal",
"imdecode", "lesser", "lesser_equal", "logical_and", "logical_or", "logical_xor",
"maximum", "minimum", "moveaxis", "modulo", "multiply", "not_equal", "onehot_encode",
"power", "subtract", "true_divide", "waitall", "_new_empty_handle"]
"power", "subtract", "true_divide", "waitall", "_new_empty_handle", "histogram"]

_STORAGE_TYPE_UNDEFINED = -1
_STORAGE_TYPE_DEFAULT = 0
Expand Down Expand Up @@ -3740,3 +3740,29 @@ def empty(shape, ctx=None, dtype=None):
if dtype is None:
dtype = mx_real_t
return NDArray(handle=_new_alloc_handle(shape, ctx, False, dtype))


def histogram(a, bins=10, range_=None):
"""Compute the histogram of the input data.
Parameters
----------
a : NDArray
Input data. The histogram is computed over the flattened array.
bins : int or sequence of scalars
If bins is an int, it defines the number of equal-width bins in the
given range (10, by default). If bins is a sequence, it defines the bin edges,
including the rightmost edge, allowing for non-uniform bin widths.
range_ : (float, float), optional
The lower and upper range of the bins. If not provided, range is simply (a.min(), a.max()).
Values outside the range are ignored. The first element of the range must be less than or
equal to the second. range affects the automatic bin computation as well, the range will
be equally divided by the number of bins.
"""

if isinstance(bins, NDArray):
return _internal._histogram(data=a, bins=bins)
elif isinstance(bins, int):
if range_ is None:
range_ = (float(a.min().asnumpy()[0]), float(a.max().asnumpy()[0]))
return _internal._histogram(data=a, bins=array([]), bin_cnt=bins, range=range_)
30 changes: 30 additions & 0 deletions src/common/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,36 @@ static inline __device__ void atomicAdd(mshadow::half::half_t *address,
} while (assumed != old);
}

static inline __device__ void atomicAdd(uint8_t *address, uint8_t val) {
unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
unsigned int old = *address_as_ui;
unsigned int shift = (((size_t)address & 0x3) << 3);
unsigned int sum;
unsigned int assumed;

do {
assumed = old;
sum = val + static_cast<uint8_t>((old >> shift) & 0xff);
old = (old & ~(0x000000ff << shift)) | (sum << shift);
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}

static inline __device__ void atomicAdd(int8_t *address, int8_t val) {
unsigned int * address_as_ui = (unsigned int *) (address - ((size_t)address & 0x3));
unsigned int old = *address_as_ui;
unsigned int shift = (((size_t)address & 0x3) << 3);
unsigned int sum;
unsigned int assumed;

do {
assumed = old;
sum = val + static_cast<int8_t>((old >> shift) & 0xff);
old = (old & ~(0x000000ff << shift)) | (sum << shift);
old = atomicCAS(address_as_ui, assumed, old);
} while (assumed != old);
}

// Overload atomicAdd to work for signed int64 on all architectures
static inline __device__ void atomicAdd(int64_t *address, int64_t val) {
atomicAdd(reinterpret_cast<unsigned long long*>(address), static_cast<unsigned long long>(val)); // NOLINT
Expand Down
178 changes: 178 additions & 0 deletions src/operator/tensor/histogram-inl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#ifndef MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
#define MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_

#include <dmlc/logging.h>
#include <dmlc/parameter.h>
#include <mxnet/operator.h>
#include <mxnet/operator_util.h>
#include <dmlc/optional.h>
#include <mshadow/tensor.h>
#include <nnvm/op.h>
#include <nnvm/node.h>
#include <nnvm/op_attr_types.h>
#include <vector>
#include <type_traits>
#include "./util/tensor_util-inl.h"
#include "../elemwise_op_common.h"
#include "../mshadow_op.h"
#include "../mxnet_op.h"
#include "../operator_common.h"

namespace mxnet {
namespace op {

struct HistogramParam : public dmlc::Parameter<HistogramParam> {
dmlc::optional<int> bin_cnt;
dmlc::optional<nnvm::Tuple<float>> range;
DMLC_DECLARE_PARAMETER(HistogramParam) {
DMLC_DECLARE_FIELD(bin_cnt)
.set_default(dmlc::optional<int>())
.describe("Number of bins for uniform case");
DMLC_DECLARE_FIELD(range)
.set_default(dmlc::optional<nnvm::Tuple<float>>())
.describe("The lower and upper range of the bins. if not provided, "
"range is simply (a.min(), a.max()). values outside the "
"range are ignored. the first element of the range must be "
"less than or equal to the second. range affects the automatic "
"bin computation as well. while bin width is computed to be "
"optimal based on the actual data within range, the bin count "
"will fill the entire range including portions containing no data.");
}
};

struct FillBinBoundsKernel {
template<typename DType>
static MSHADOW_XINLINE void Map(int i, DType* bin_bounds, int bin_cnt, float min, float max) {
if (i <= bin_cnt) {
bin_bounds[i] = DType((i * max + (bin_cnt - i) * min) / bin_cnt);
}
}
};

inline bool HistogramOpShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape>* in_attrs,
std::vector<TShape>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 2U);
HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
const bool has_cnt = param.bin_cnt.has_value();
const bool has_range = param.range.has_value();
const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
CHECK(legal_param) << "cnt and range should both or neither specified";
if (has_cnt) {
// if cnt is specified, the output histogram has shape (cnt,)
// while output bins has shape (cnt+1,)
const int bin_cnt = param.bin_cnt.value();
SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({bin_cnt}));
SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape({bin_cnt + 1}));
} else {
// if cnt is not specified, the output histogram has shape (bins.Size() - 1)
// while output bins has same shape as input bins
TShape oshape = (*in_attrs)[1];

CHECK_EQ(oshape.ndim(), 1U) << "bins argument should be an 1D vector";
CHECK_GE(oshape.Size(), 2U) << "number of bounds should be >= 2";

SHAPE_ASSIGN_CHECK(*out_attrs, 0, TShape({(oshape[0] - 1)}));
SHAPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(1));
}

return out_attrs->at(0).ndim() == 1U && out_attrs->at(0).Size() != 0U &&
out_attrs->at(1).ndim() == 1U && out_attrs->at(1).Size() != 0U &&
out_attrs->at(0).Size() == out_attrs->at(1).Size() - 1;
}

inline bool HistogramOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(in_attrs->at(0), in_attrs->at(1));
CHECK_EQ(out_attrs->size(), 2U);

TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));
return out_attrs->at(0) != -1 && out_attrs->at(1) != -1;
}

template<typename xpu>
void HistogramForwardImpl(mshadow::Stream<xpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins);

template<typename xpu>
void HistogramOpForward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 2U);
CHECK_EQ(req.size(), 2U);
CHECK_EQ(req[0], kWriteTo);
CHECK_EQ(req[1], kWriteTo);

mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
const TBlob& in_data = inputs[0];
const TBlob& bin_bounds = inputs[1];
const TBlob& out_data = outputs[0];
const TBlob& out_bins = outputs[1];

HistogramForwardImpl<xpu>(s, ctx, attrs, in_data, bin_bounds, out_data, out_bins);
}

template<typename xpu>
void HistogramBackwardImpl(const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& out_grad,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& in_grad);

template<typename xpu>
void HistogramOpBackward(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
CHECK_EQ(inputs.size(), 4U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req.size(), 1U);
CHECK_EQ(req[0], kWriteTo);

const TBlob& out_grad = inputs[0];
const TBlob& in_data = inputs[1];
const TBlob& bin_bounds = inputs[2];
const TBlob& out_data = inputs[3];
const TBlob& in_grad = outputs[0];

HistogramBackwardImpl<xpu>(ctx, attrs, out_grad, in_data, bin_bounds, out_data, in_grad);
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_TENSOR_HISTOGRAM_INL_H_
144 changes: 144 additions & 0 deletions src/operator/tensor/histogram.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include "./histogram-inl.h"

namespace mxnet {
namespace op {

struct ComputeBinKernel {
template<typename DType>
MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* bin_indices,
int bin_cnt, float width, float min, float max) {
DType data = in_data[i];
if (data >= min && data <= max) {
bin_indices[i] = mshadow_op::floor::Map((in_data[i] - min) / width);
bin_indices[i] = mshadow_op::minimum::Map(bin_cnt - 1, bin_indices[i]);
} else {
bin_indices[i] = -1;
}
}

template<typename DType>
MSHADOW_XINLINE static void Map(int i, const DType* in_data, int* bin_indices,
const DType* bin_bounds, int num_bins) {
DType data = in_data[i];
int target_idx = -1;
if (data >= bin_bounds[0] && data <= bin_bounds[num_bins]) {
target_idx = 0;
while ((data - bin_bounds[target_idx]) >= 0) {
target_idx += 1;
}
target_idx = mshadow_op::minimum::Map(target_idx - 1, num_bins - 1);
}
bin_indices[i] = target_idx;
}
};

template<typename CType>
void ComputeHistogram(const int* bin_indices, CType* out_data, size_t input_size) {
for (size_t i = 0; i < input_size; ++i) {
int target = bin_indices[i];
if (target >= 0) {
out_data[target] += 1;
}
}
}

template<typename cpu>
void HistogramForwardImpl(mshadow::Stream<cpu>* s,
const OpContext& ctx,
const nnvm::NodeAttrs& attrs,
const TBlob& in_data,
const TBlob& bin_bounds,
const TBlob& out_data,
const TBlob& out_bins) {
using namespace mshadow;
using namespace mxnet_op;
HistogramParam param = nnvm::get<HistogramParam>(attrs.parsed);
const bool has_cnt = param.bin_cnt.has_value();
const bool has_range = param.range.has_value();
const bool legal_param = (has_cnt && has_range) || (!has_cnt && !has_range);
CHECK(legal_param) << "width and range should both or neither be specified";

CHECK(!has_range || (has_range && (param.range.value().ndim() == 2U)));
CHECK(!has_range || (has_range && (param.range.value()[0] < param.range.value()[1])));

Tensor<cpu, 1, int> bin_indices =
ctx.requested[0].get_space_typed<cpu, 1, int>(Shape1(in_data.Size()), s);
const int bin_cnt = out_data.Size();
MSHADOW_TYPE_SWITCH(in_data.type_flag_, DType, {
if (has_cnt) {
float max = param.range.value()[1];
float min = param.range.value()[0];
float width = (max - min) / bin_cnt;
Kernel<ComputeBinKernel, cpu>::Launch(
s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_,
bin_cnt, width, min, max);
Kernel<FillBinBoundsKernel, cpu>::Launch(
s, bin_cnt+1, out_bins.dptr<DType>(), bin_cnt, min, max);
} else {
Kernel<ComputeBinKernel, cpu>::Launch(
s, in_data.Size(), in_data.dptr<DType>(), bin_indices.dptr_, bin_bounds.dptr<DType>(),
bin_cnt);
Kernel<op_with_req<mshadow_op::identity, kWriteTo>, cpu>::Launch(
s, bin_bounds.Size(), out_bins.dptr<DType>(), bin_bounds.dptr<DType>());
}
});
MSHADOW_TYPE_SWITCH(out_data.type_flag_, CType, {
Kernel<set_zero, cpu>::Launch(s, bin_cnt, out_data.dptr<CType>());
ComputeHistogram(bin_indices.dptr_, out_data.dptr<CType>(), in_data.Size());
});
}

DMLC_REGISTER_PARAMETER(HistogramParam);

NNVM_REGISTER_OP(_histogram)
.describe(R"code(This operators implements the histogram function.
Example::
x = [[0, 1], [2, 2], [3, 4]]
histo, bin_edges = histogram(data=x, bin_bounds=[], bin_cnt=5, range=(0,5))
histo = [1, 1, 2, 1, 1]
bin_edges = [0., 1., 2., 3., 4.]
histo, bin_edges = histogram(data=x, bin_bounds=[0., 2.1, 3.])
histo = [4, 1]
)code" ADD_FILELINE)
.set_attr_parser(ParamParser<HistogramParam>)
.set_num_inputs(2)
.set_num_outputs(2)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"data", "bins"};
})
.set_attr<nnvm::FInferShape>("FInferShape", HistogramOpShape)
.set_attr<nnvm::FInferType>("FInferType", HistogramOpType)
.set_attr<FCompute>("FCompute<cpu>", HistogramOpForward<cpu>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption",
[](const NodeAttrs& attrs) {
return std::vector<std::pair<int, int> >{};
})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_argument("bins", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(HistogramParam::__FIELDS__());

} // namespace op
} // namespace mxnet

Loading

0 comments on commit 7c44e7f

Please sign in to comment.