Skip to content

Commit

Permalink
ARROW-8129: [C++][Compute] Refine compare sort kernel
Browse files Browse the repository at this point in the history
The sorting kernel implements two comparison functions: CompareValues uses
array.Value() for numerical data, and CompareViews uses array.GetView()
for non-numerical data. This can be simplified by using only GetView(),
since all data types support GetView(). This patch also refines the unit tests.

This change improves performance by about 40%.

Closes #6640 from cyb70289/sort-refine

Authored-by: Yibo Cai <yibo.cai@arm.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
cyb70289 authored and pitrou committed Mar 17, 2020
1 parent ec7fce5 commit 70db8ab
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 77 deletions.
10 changes: 0 additions & 10 deletions cpp/src/arrow/compute/kernels/nth_to_indices_test.cc
Expand Up @@ -39,16 +39,6 @@ template <typename ArrayType>
class Comparator {
public:
bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) {
if (array.IsNull(rhs)) return true;
if (array.IsNull(lhs)) return false;
return array.Value(lhs) <= array.Value(rhs);
}
};

template <>
class Comparator<StringArray> {
public:
bool operator()(const BinaryArray& array, uint64_t lhs, uint64_t rhs) {
if (array.IsNull(rhs)) return true;
if (array.IsNull(lhs)) return false;
return array.GetView(lhs) <= array.GetView(rhs);
Expand Down
66 changes: 23 additions & 43 deletions cpp/src/arrow/compute/kernels/sort_to_indices.cc
Expand Up @@ -63,23 +63,11 @@ class ARROW_EXPORT SortToIndicesKernel : public UnaryKernel {
std::unique_ptr<SortToIndicesKernel>* out);
};

/// Strict "less than" ordering between two elements of `array`, where each
/// element is read with `array.Value()`.
///
/// \param array container exposing `Value(index)`
/// \param lhs index of the left-hand element
/// \param rhs index of the right-hand element
/// \return true when the element at `lhs` compares less than the one at `rhs`
template <typename ArrayType>
bool CompareValues(const ArrayType& array, uint64_t lhs, uint64_t rhs) {
  const auto left_value = array.Value(lhs);
  const auto right_value = array.Value(rhs);
  return left_value < right_value;
}

/// Strict "less than" ordering between two elements of `array`, where each
/// element is read with `array.GetView()`.
///
/// \param array container exposing `GetView(index)`
/// \param lhs index of the left-hand element
/// \param rhs index of the right-hand element
/// \return true when the view at `lhs` compares less than the view at `rhs`
template <typename ArrayType>
bool CompareViews(const ArrayType& array, uint64_t lhs, uint64_t rhs) {
  const auto left_view = array.GetView(lhs);
  const auto right_view = array.GetView(rhs);
  return left_view < right_view;
}

template <typename ArrowType, typename Comparator>
template <typename ArrowType>
class CompareSorter {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;

public:
explicit CompareSorter(Comparator compare) : compare_(compare) {}

void Sort(int64_t* indices_begin, int64_t* indices_end, const ArrayType& values) {
std::iota(indices_begin, indices_end, 0);

Expand All @@ -90,13 +78,10 @@ class CompareSorter {
[&values](uint64_t ind) { return !values.IsNull(ind); });
}
std::stable_sort(indices_begin, nulls_begin,
[&values, this](uint64_t left, uint64_t right) {
return compare_(values, left, right);
[&values](uint64_t left, uint64_t right) {
return values.GetView(left) < values.GetView(right);
});
}

private:
Comparator compare_;
};

template <typename ArrowType>
Expand Down Expand Up @@ -164,14 +149,12 @@ class CountSorter {
// Sort integers with counting sort or comparison based sorting algorithm
// - Use O(n) counting sort if values are in a small range
// - Use O(nlogn) std::stable_sort otherwise
template <typename ArrowType, typename Comparator>
template <typename ArrowType>
class CountOrCompareSorter {
using ArrayType = typename TypeTraits<ArrowType>::ArrayType;
using c_type = typename ArrowType::c_type;

public:
explicit CountOrCompareSorter(Comparator compare) : compare_sorter_(compare) {}

void Sort(int64_t* indices_begin, int64_t* indices_end, const ArrayType& values) {
if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) {
c_type min{std::numeric_limits<c_type>::max()};
Expand All @@ -198,7 +181,7 @@ class CountOrCompareSorter {
}

private:
CompareSorter<ArrowType, Comparator> compare_sorter_;
CompareSorter<ArrowType> compare_sorter_;
CountSorter<ArrowType> count_sorter_;

// Crossover point at which counting sort is preferred over std::stable_sort (merge sort)
Expand Down Expand Up @@ -257,22 +240,19 @@ class SortToIndicesKernelImpl : public SortToIndicesKernel {
}
};

template <typename ArrowType, typename Comparator,
typename Sorter = CompareSorter<ArrowType, Comparator>>
SortToIndicesKernelImpl<ArrowType, Sorter>* MakeCompareKernel(Comparator comparator) {
return new SortToIndicesKernelImpl<ArrowType, Sorter>(Sorter(comparator));
template <typename ArrowType, typename Sorter = CompareSorter<ArrowType>>
static SortToIndicesKernelImpl<ArrowType, Sorter>* MakeCompareKernel() {
return new SortToIndicesKernelImpl<ArrowType, Sorter>(Sorter());
}

template <typename ArrowType, typename Sorter = CountSorter<ArrowType>>
SortToIndicesKernelImpl<ArrowType, Sorter>* MakeCountKernel(int min, int max) {
static SortToIndicesKernelImpl<ArrowType, Sorter>* MakeCountKernel(int min, int max) {
return new SortToIndicesKernelImpl<ArrowType, Sorter>(Sorter(min, max));
}

template <typename ArrowType, typename Comparator,
typename Sorter = CountOrCompareSorter<ArrowType, Comparator>>
SortToIndicesKernelImpl<ArrowType, Sorter>* MakeCountOrCompareKernel(
Comparator comparator) {
return new SortToIndicesKernelImpl<ArrowType, Sorter>(Sorter(comparator));
template <typename ArrowType, typename Sorter = CountOrCompareSorter<ArrowType>>
static SortToIndicesKernelImpl<ArrowType, Sorter>* MakeCountOrCompareKernel() {
return new SortToIndicesKernelImpl<ArrowType, Sorter>(Sorter());
}

Status SortToIndicesKernel::Make(const std::shared_ptr<DataType>& value_type,
Expand All @@ -286,34 +266,34 @@ Status SortToIndicesKernel::Make(const std::shared_ptr<DataType>& value_type,
kernel = MakeCountKernel<Int8Type>(-128, 127);
break;
case Type::UINT16:
kernel = MakeCountOrCompareKernel<UInt16Type>(CompareValues<UInt16Array>);
kernel = MakeCountOrCompareKernel<UInt16Type>();
break;
case Type::INT16:
kernel = MakeCountOrCompareKernel<Int16Type>(CompareValues<Int16Array>);
kernel = MakeCountOrCompareKernel<Int16Type>();
break;
case Type::UINT32:
kernel = MakeCountOrCompareKernel<UInt32Type>(CompareValues<UInt32Array>);
kernel = MakeCountOrCompareKernel<UInt32Type>();
break;
case Type::INT32:
kernel = MakeCountOrCompareKernel<Int32Type>(CompareValues<Int32Array>);
kernel = MakeCountOrCompareKernel<Int32Type>();
break;
case Type::UINT64:
kernel = MakeCountOrCompareKernel<UInt64Type>(CompareValues<UInt64Array>);
kernel = MakeCountOrCompareKernel<UInt64Type>();
break;
case Type::INT64:
kernel = MakeCountOrCompareKernel<Int64Type>(CompareValues<Int64Array>);
kernel = MakeCountOrCompareKernel<Int64Type>();
break;
case Type::FLOAT:
kernel = MakeCompareKernel<FloatType>(CompareValues<FloatArray>);
kernel = MakeCompareKernel<FloatType>();
break;
case Type::DOUBLE:
kernel = MakeCompareKernel<DoubleType>(CompareValues<DoubleArray>);
kernel = MakeCompareKernel<DoubleType>();
break;
case Type::BINARY:
kernel = MakeCompareKernel<BinaryType>(CompareViews<BinaryArray>);
kernel = MakeCompareKernel<BinaryType>();
break;
case Type::STRING:
kernel = MakeCompareKernel<StringType>(CompareViews<StringArray>);
kernel = MakeCompareKernel<StringType>();
break;
default:
return Status::NotImplemented("Sorting of ", *value_type, " arrays");
Expand All @@ -322,7 +302,7 @@ Status SortToIndicesKernel::Make(const std::shared_ptr<DataType>& value_type,
return Status::OK();
}

Status SortToIndices(FunctionContext* ctx, const Datum& values, Datum* offsets) {
static Status SortToIndices(FunctionContext* ctx, const Datum& values, Datum* offsets) {
std::unique_ptr<SortToIndicesKernel> kernel;
RETURN_NOT_OK(SortToIndicesKernel::Make(values.type(), &kernel));
return kernel->Call(ctx, values, offsets);
Expand Down
34 changes: 10 additions & 24 deletions cpp/src/arrow/compute/kernels/sort_to_indices_test.cc
Expand Up @@ -32,6 +32,8 @@
namespace arrow {
namespace compute {

using arrow::internal::checked_pointer_cast;

template <typename ArrowType>
class TestSortToIndicesKernel : public ComputeFixture, public TestBase {
private:
Expand Down Expand Up @@ -131,26 +133,10 @@ using SortToIndicesableTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
Int32Type, Int64Type, FloatType, DoubleType, StringType>;

using SortToIndicesIntegerTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
Int32Type, Int64Type>;

/// Ordering predicate over the indices of `array`, with elements read through
/// `array.Value()`.
///
/// - A null element is ordered after every non-null element.
/// - Two nulls, or two equal values, are ordered by their indices.
/// - Otherwise the element values decide the order.
template <typename ArrayType>
class Comparator {
 public:
  /// \return true when the element at index `lhs` is ordered before the one at `rhs`
  bool operator()(const ArrayType& array, uint64_t lhs, uint64_t rhs) {
    const bool lhs_is_null = array.IsNull(lhs);
    const bool rhs_is_null = array.IsNull(rhs);
    if (rhs_is_null) {
      // A non-null lhs always precedes a null rhs; two nulls fall back to index order.
      return !lhs_is_null || lhs < rhs;
    }
    if (lhs_is_null) {
      return false;
    }
    const auto lhs_value = array.Value(lhs);
    const auto rhs_value = array.Value(rhs);
    // Equal values are tie-broken by index.
    return (lhs_value == rhs_value) ? (lhs < rhs) : (lhs_value < rhs_value);
  }
};

template <>
class Comparator<StringArray> {
public:
bool operator()(const BinaryArray& array, uint64_t lhs, uint64_t rhs) {
if (array.IsNull(rhs) && array.IsNull(lhs)) return lhs < rhs;
if (array.IsNull(rhs)) return true;
if (array.IsNull(lhs)) return false;
Expand Down Expand Up @@ -230,16 +216,16 @@ TYPED_TEST(TestSortToIndicesKernelRandom, SortRandomValues) {
auto array = rand.Generate(length, null_probability);
std::shared_ptr<Array> offsets;
ASSERT_OK(arrow::compute::SortToIndices(&this->ctx_, *array, &offsets));
ValidateSorted<ArrayType>(*std::static_pointer_cast<ArrayType>(array),
*std::static_pointer_cast<UInt64Array>(offsets));
ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
*checked_pointer_cast<UInt64Array>(offsets));
}
}
}

// Long array with small value range: counting sort
// - length >= 1024 (CountOrCompareSorter::countsort_min_len_)
// - range <= 4096 (CountOrCompareSorter::countsort_max_range_)
TYPED_TEST_SUITE(TestSortToIndicesKernelRandomCount, SortToIndicesIntegerTypes);
TYPED_TEST_SUITE(TestSortToIndicesKernelRandomCount, IntegralArrowTypes);

TYPED_TEST(TestSortToIndicesKernelRandomCount, SortRandomValuesCount) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
Expand All @@ -253,14 +239,14 @@ TYPED_TEST(TestSortToIndicesKernelRandomCount, SortRandomValuesCount) {
auto array = rand.Generate(length, range, null_probability);
std::shared_ptr<Array> offsets;
ASSERT_OK(arrow::compute::SortToIndices(&this->ctx_, *array, &offsets));
ValidateSorted<ArrayType>(*std::static_pointer_cast<ArrayType>(array),
*std::static_pointer_cast<UInt64Array>(offsets));
ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
*checked_pointer_cast<UInt64Array>(offsets));
}
}
}

// Long array with big value range: std::stable_sort
TYPED_TEST_SUITE(TestSortToIndicesKernelRandomCompare, SortToIndicesIntegerTypes);
TYPED_TEST_SUITE(TestSortToIndicesKernelRandomCompare, IntegralArrowTypes);

TYPED_TEST(TestSortToIndicesKernelRandomCompare, SortRandomValuesCompare) {
using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
Expand All @@ -273,8 +259,8 @@ TYPED_TEST(TestSortToIndicesKernelRandomCompare, SortRandomValuesCompare) {
auto array = rand.Generate(length, null_probability);
std::shared_ptr<Array> offsets;
ASSERT_OK(arrow::compute::SortToIndices(&this->ctx_, *array, &offsets));
ValidateSorted<ArrayType>(*std::static_pointer_cast<ArrayType>(array),
*std::static_pointer_cast<UInt64Array>(offsets));
ValidateSorted<ArrayType>(*checked_pointer_cast<ArrayType>(array),
*checked_pointer_cast<UInt64Array>(offsets));
}
}
}
Expand Down

0 comments on commit 70db8ab

Please sign in to comment.