Skip to content

Commit

Permalink
incorporate Francois' suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
bkietz committed Apr 8, 2019
1 parent 8be7df1 commit 198320d
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 47 deletions.
19 changes: 10 additions & 9 deletions cpp/src/arrow/array/builder_binary.h
Expand Up @@ -195,9 +195,7 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {
}

void UnsafeAppend(util::string_view value) {
#ifndef NDEBUG
CheckValueSize(static_cast<int64_t>(value.size()));
#endif
CheckValueSize(value);
UnsafeAppend(reinterpret_cast<const uint8_t*>(value.data()));
}

Expand All @@ -206,16 +204,12 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {
}

Status Append(const util::string_view& view) {
#ifndef NDEBUG
CheckValueSize(static_cast<int64_t>(view.size()));
#endif
CheckValueSize(view);
return Append(reinterpret_cast<const uint8_t*>(view.data()));
}

Status Append(const std::string& s) {
#ifndef NDEBUG
CheckValueSize(static_cast<int64_t>(s.size()));
#endif
CheckValueSize(s);
return Append(reinterpret_cast<const uint8_t*>(s.data()));
}

Expand Down Expand Up @@ -258,6 +252,13 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {
int32_t byte_width_;
BufferBuilder byte_builder_;

template <typename Sized>
void CheckValueSize(const Sized& s, decltype(s.size())* = nullptr) {
#ifndef NDEBUG
CheckValueSize(static_cast<size_t>(s.size()));
#endif
}

#ifndef NDEBUG
void CheckValueSize(int64_t size);
#endif
Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/compute/kernels/take-test.cc
Expand Up @@ -92,7 +92,7 @@ TEST_F(TestTakeKernelWithBoolean, TakeBoolean) {
ASSERT_RAISES(Invalid,
this->Take(boolean(), "[true, false, true]", "[0, 9, 0]", options, &arr));

options.out_of_bounds = TakeOptions::TONULL;
options.out_of_bounds = TakeOptions::TO_NULL;
this->AssertTake("[true, false, true]", "[0, 9, 0]", options, "[true, null, true]");
}

Expand Down Expand Up @@ -120,7 +120,7 @@ TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", "[0, 9, 0]",
options, &arr));

options.out_of_bounds = TakeOptions::TONULL;
options.out_of_bounds = TakeOptions::TO_NULL;
this->AssertTake("[7, 8, 9]", "[0, 9, 0]", options, "[7, null, 7]");
}

Expand Down Expand Up @@ -155,7 +155,7 @@ TEST_F(TestTakeKernelWithString, TakeString) {
ASSERT_RAISES(Invalid,
this->Take(utf8(), R"(["a", "b", "c"])", "[0, 9, 0]", options, &arr));

options.out_of_bounds = TakeOptions::TONULL;
options.out_of_bounds = TakeOptions::TO_NULL;
this->AssertTake(R"(["a", "b", "c"])", "[0, 9, 0]", options, R"(["a", null, "a"])");
}

Expand All @@ -167,7 +167,7 @@ TEST_F(TestTakeKernelWithString, TakeDictionary) {
"[null, 1, null]");
this->AssertTakeDictionary(dict, "[0, 1, 4]", "[null, 1, 0]", options, "[null, 1, 0]");

options.out_of_bounds = TakeOptions::TONULL;
options.out_of_bounds = TakeOptions::TO_NULL;
this->AssertTakeDictionary(dict, "[0, 1, 4]", "[0, 9, 0]", options, "[0, null, 0]");
}

Expand Down
75 changes: 46 additions & 29 deletions cpp/src/arrow/compute/kernels/take.cc
Expand Up @@ -29,13 +29,20 @@ namespace compute {

Status Take(FunctionContext* context, const Array& values, const Array& indices,
const TakeOptions& options, std::shared_ptr<Array>* out) {
TakeKernel kernel(values.type(), options);
Datum out_datum;
RETURN_NOT_OK(kernel.Call(context, values.data(), indices.data(), &out_datum));
RETURN_NOT_OK(
Take(context, Datum(values.data()), Datum(indices.data()), options, &out_datum));
*out = out_datum.make_array();
return Status::OK();
}

Status Take(FunctionContext* context, const Datum& values, const Datum& indices,
const TakeOptions& options, Datum* out) {
TakeKernel kernel(values.type(), options);
RETURN_NOT_OK(kernel.Call(context, values, indices, out));
return Status::OK();
}

struct TakeParameters {
FunctionContext* context;
std::shared_ptr<Array> values, indices;
Expand All @@ -61,22 +68,21 @@ Status UnsafeAppend(StringBuilder* builder, util::string_view value) {
return Status::OK();
}

template <int OutOfBounds, bool AllValuesValid, bool AllIndicesValid, typename ValueArray,
typename IndexArray, typename OutBuilder>
template <TakeOptions::OutOfBoundsBehavior B, bool AllValuesValid, bool AllIndicesValid,
typename ValueArray, typename IndexArray, typename OutBuilder>
Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices,
OutBuilder* builder) {
for (int64_t i = 0; i != indices.length(); ++i) {
auto raw_indices = indices.raw_values();
for (int64_t i = 0; i < indices.length(); ++i) {
if (!AllIndicesValid && indices.IsNull(i)) {
builder->UnsafeAppendNull();
continue;
}
auto index = indices.raw_values()[i];
if (OutOfBounds == TakeOptions::ERROR &&
static_cast<int64_t>(index) >= values.length()) {
auto index = static_cast<int64_t>(raw_indices[i]);
if (B == TakeOptions::ERROR && (index < 0 || index >= values.length())) {
return Status::Invalid("take index out of bounds");
}
if (OutOfBounds == TakeOptions::TONULL &&
static_cast<int64_t>(index) >= values.length()) {
if (B == TakeOptions::TO_NULL && (index < 0 || index >= values.length())) {
builder->UnsafeAppendNull();
continue;
}
Expand All @@ -89,23 +95,24 @@ Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& in
return Status::OK();
}

template <int OutOfBounds, bool AllValuesValid, typename ValueArray, typename IndexArray,
typename OutBuilder>
template <TakeOptions::OutOfBoundsBehavior B, bool AllValuesValid, typename ValueArray,
typename IndexArray, typename OutBuilder>
Status TakeImpl(FunctionContext* context, const ValueArray& values,
const IndexArray& indices, OutBuilder* builder) {
if (indices.null_count() == 0) {
return TakeImpl<OutOfBounds, AllValuesValid, true>(context, values, indices, builder);
return TakeImpl<B, AllValuesValid, true>(context, values, indices, builder);
}
return TakeImpl<OutOfBounds, AllValuesValid, false>(context, values, indices, builder);
return TakeImpl<B, AllValuesValid, false>(context, values, indices, builder);
}

template <int OutOfBounds, typename ValueArray, typename IndexArray, typename OutBuilder>
template <TakeOptions::OutOfBoundsBehavior B, typename ValueArray, typename IndexArray,
typename OutBuilder>
Status TakeImpl(FunctionContext* context, const ValueArray& values,
const IndexArray& indices, OutBuilder* builder) {
if (values.null_count() == 0) {
return TakeImpl<OutOfBounds, true>(context, values, indices, builder);
return TakeImpl<B, true>(context, values, indices, builder);
}
return TakeImpl<OutOfBounds, false>(context, values, indices, builder);
return TakeImpl<B, false>(context, values, indices, builder);
}

template <typename ValueArray, typename IndexArray, typename OutBuilder>
Expand All @@ -115,8 +122,8 @@ Status TakeImpl(FunctionContext* context, const ValueArray& values,
switch (options.out_of_bounds) {
case TakeOptions::ERROR:
return TakeImpl<TakeOptions::ERROR>(context, values, indices, builder);
case TakeOptions::TONULL:
return TakeImpl<TakeOptions::TONULL>(context, values, indices, builder);
case TakeOptions::TO_NULL:
return TakeImpl<TakeOptions::TO_NULL>(context, values, indices, builder);
case TakeOptions::UNSAFE:
return TakeImpl<TakeOptions::UNSAFE>(context, values, indices, builder);
default:
Expand Down Expand Up @@ -147,8 +154,10 @@ struct UnpackValues {
auto indices_length = params_.indices->length();
if (params_.options.out_of_bounds == TakeOptions::ERROR && indices_length != 0) {
auto indices = static_cast<IndexArrayRef>(*params_.indices).raw_values();
auto max = *std::max_element(indices, indices + indices_length);
if (static_cast<int64_t>(max) > params_.values->length()) {
auto minmax = std::minmax_element(indices, indices + indices_length);
auto min = static_cast<int64_t>(*minmax.first);
auto max = static_cast<int64_t>(*minmax.second);
if (min < 0 || max >= params_.values->length()) {
return Status::Invalid("out of bounds index");
}
}
Expand All @@ -157,14 +166,20 @@ struct UnpackValues {
}

Status Visit(const DictionaryType& t) {
auto dictionary_indices = params_.values->data()->Copy();
dictionary_indices->type = t.index_type();
TakeParameters params = params_;
params.values = MakeArray(dictionary_indices);
UnpackValues<IndexType> unpack = {params};
RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack));
(*params_.out)->data()->type = dictionary(t.index_type(), t.dictionary());
return Status::OK();
std::shared_ptr<Array> taken_indices;
{
// To take from a dictionary, apply the current kernel to the dictionary's
// indices. (Use UnpackValues<IndexType> since IndexType is already unpacked)
auto indices = static_cast<const DictionaryArray*>(params_.values.get())->indices();
TakeParameters params = params_;
params.values = indices;
params.out = &taken_indices;
UnpackValues<IndexType> unpack = {params};
RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack));
}
// create output dictionary from taken indices
return DictionaryArray::FromArrays(dictionary(t.index_type(), t.dictionary()),
taken_indices, params_.out);
}

Status Visit(const ExtensionType& t) {
Expand Down Expand Up @@ -193,9 +208,11 @@ struct UnpackIndices {
UnpackValues<IndexType> unpack = {params_};
return VisitTypeInline(*params_.values->type(), &unpack);
}

Status Visit(const DataType& other) {
return Status::Invalid("index type not supported: ", other);
}

const TakeParameters& params_;
};

Expand Down
23 changes: 18 additions & 5 deletions cpp/src/arrow/compute/kernels/take.h
Expand Up @@ -32,12 +32,14 @@ namespace compute {
class FunctionContext;

struct ARROW_EXPORT TakeOptions {
enum {
// indices out of bounds will raise an error
enum OutOfBoundsBehavior {
// Out of bounds indices will raise an error
ERROR,
// indices out of bounds will result in a null value
TONULL,
// indices out of bounds is undefined behavior
// Out of bounds indices will result in a null value
TO_NULL,
// Bounds checking will be skipped, which is faster.
// Only use this if indices are known to be within bounds;
// out of bounds indices will result in undefined behavior
UNSAFE
} out_of_bounds = ERROR;
};
Expand All @@ -62,6 +64,17 @@ ARROW_EXPORT
Status Take(FunctionContext* context, const Array& values, const Array& indices,
const TakeOptions& options, std::shared_ptr<Array>* out);

/// \brief Take from an array of values at indices in another array
///
/// \param[in] context the FunctionContext
/// \param[in] values datum from which to take
/// \param[in] indices which values to take
/// \param[in] options options
/// \param[out] out resulting datum
ARROW_EXPORT
Status Take(FunctionContext* context, const Datum& values, const Datum& indices,
const TakeOptions& options, Datum* out);

/// \brief BinaryKernel implementing Take operation
class ARROW_EXPORT TakeKernel : public BinaryKernel {
public:
Expand Down

0 comments on commit 198320d

Please sign in to comment.