Skip to content

Commit

Permalink
ARROW-13016: [C++][Compute] Support Null type in Sum/Mean aggregation
Browse files Browse the repository at this point in the history
The Sum/Mean of a Null type array is a Null type scalar.

Closes #10486 from Crystrix/arrow-13016

Authored-by: crystrix <chenxi.li@live.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
Crystrix authored and lidavidm committed Dec 3, 2021
1 parent 6cd288e commit eee9df5
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 1 deletion.
4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/kernels/aggregate_basic.cc
Expand Up @@ -259,7 +259,7 @@ Result<std::unique_ptr<KernelState>> SumInit(KernelContext* ctx,

Result<std::unique_ptr<KernelState>> MeanInit(KernelContext* ctx,
const KernelInitArgs& args) {
SumLikeInit<MeanImplDefault> visitor(
MeanKernelInit<MeanImplDefault> visitor(
ctx, args.inputs[0].type,
static_cast<const ScalarAggregateOptions&>(*args.options));
return visitor.Create();
Expand Down Expand Up @@ -929,6 +929,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddArrayScalarAggKernels(SumInit, SignedIntTypes(), int64(), func.get());
AddArrayScalarAggKernels(SumInit, UnsignedIntTypes(), uint64(), func.get());
AddArrayScalarAggKernels(SumInit, FloatingPointTypes(), float64(), func.get());
AddArrayScalarAggKernels(SumInit, {null()}, int64(), func.get());
// Add the SIMD variants for sum
#if defined(ARROW_HAVE_RUNTIME_AVX2) || defined(ARROW_HAVE_RUNTIME_AVX512)
auto cpu_info = arrow::internal::CpuInfo::GetInstance();
Expand All @@ -955,6 +956,7 @@ void RegisterScalarAggregateBasic(FunctionRegistry* registry) {
AddAggKernel(
KernelSignature::Make({InputType(Type::DECIMAL256)}, OutputType(ScalarFirstType)),
MeanInit, func.get(), SimdLevel::NONE);
AddArrayScalarAggKernels(MeanInit, {null()}, float64(), func.get());
// Add the SIMD variants for mean
#if defined(ARROW_HAVE_RUNTIME_AVX2)
if (cpu_info->IsSupported(arrow::internal::CpuInfo::AVX2)) {
Expand Down
51 changes: 51 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_basic_internal.h
Expand Up @@ -121,6 +121,40 @@ struct SumImpl : public ScalarAggregator {
ScalarAggregateOptions options;
};

template <typename ArrowType>
struct NullSumImpl : public ScalarAggregator {
using ScalarType = typename TypeTraits<ArrowType>::ScalarType;

explicit NullSumImpl(const ScalarAggregateOptions& options_) : options(options_) {}

Status Consume(KernelContext*, const ExecBatch& batch) override {
if (batch[0].is_scalar() || batch[0].array()->GetNullCount() > 0) {
// If the batch is a scalar or an array with elements, set is_empty to false
is_empty = false;
}
return Status::OK();
}

Status MergeFrom(KernelContext*, KernelState&& src) override {
const auto& other = checked_cast<const NullSumImpl&>(src);
this->is_empty &= other.is_empty;
return Status::OK();
}

Status Finalize(KernelContext*, Datum* out) override {
if ((options.skip_nulls || this->is_empty) && options.min_count == 0) {
// Return 0 if the remaining data is empty
out->value = std::make_shared<ScalarType>(0);
} else {
out->value = MakeNullScalar(TypeTraits<ArrowType>::type_singleton());
}
return Status::OK();
}

bool is_empty = true;
ScalarAggregateOptions options;
};

template <typename ArrowType, SimdLevel::type SimdLevel>
struct MeanImpl : public SumImpl<ArrowType, SimdLevel> {
using SumImpl<ArrowType, SimdLevel>::SumImpl;
Expand Down Expand Up @@ -200,12 +234,29 @@ struct SumLikeInit {
return Status::OK();
}

virtual Status Visit(const NullType&) {
state.reset(new NullSumImpl<Int64Type>(options));
return Status::OK();
}

Result<std::unique_ptr<KernelState>> Create() {
RETURN_NOT_OK(VisitTypeInline(*type, this));
return std::move(state);
}
};

template <template <typename> class KernelClass>
struct MeanKernelInit : public SumLikeInit<KernelClass> {
MeanKernelInit(KernelContext* ctx, const std::shared_ptr<DataType>& type,
const ScalarAggregateOptions& options)
: SumLikeInit<KernelClass>(ctx, type, options) {}

Status Visit(const NullType&) override {
this->state.reset(new NullSumImpl<DoubleType>(this->options));
return Status::OK();
}
};

// ----------------------------------------------------------------------
// MinMax implementation

Expand Down
52 changes: 52 additions & 0 deletions cpp/src/arrow/compute/kernels/aggregate_test.cc
Expand Up @@ -591,6 +591,32 @@ TEST(TestDecimalSumKernel, ScalarAggregateOptions) {
}
}

TEST(TestNullSumKernel, Basics) {
auto ty = null();
Datum null_result = std::make_shared<Int64Scalar>();
Datum zero_result = std::make_shared<Int64Scalar>(0);

EXPECT_THAT(Sum(ScalarFromJSON(ty, "null")), ResultWith(null_result));
EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]")), ResultWith(null_result));
EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]")), ResultWith(null_result));
EXPECT_THAT(Sum(ChunkedArrayFromJSON(ty, {"[null]", "[]", "[null, null]"})),
ResultWith(null_result));

ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
EXPECT_THAT(Sum(ScalarFromJSON(ty, "null"), options), ResultWith(zero_result));
EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]"), options), ResultWith(zero_result));
EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]"), options), ResultWith(zero_result));
EXPECT_THAT(Sum(ChunkedArrayFromJSON(ty, {"[null]", "[]", "[null, null]"}), options),
ResultWith(zero_result));

options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0);
EXPECT_THAT(Sum(ScalarFromJSON(ty, "null"), options), ResultWith(null_result));
EXPECT_THAT(Sum(ArrayFromJSON(ty, "[]"), options), ResultWith(zero_result));
EXPECT_THAT(Sum(ArrayFromJSON(ty, "[null]"), options), ResultWith(null_result));
EXPECT_THAT(Sum(ChunkedArrayFromJSON(ty, {"[null]", "[]", "[null, null]"}), options),
ResultWith(null_result));
}

//
// Product
//
Expand Down Expand Up @@ -1354,6 +1380,32 @@ TEST(TestDecimalMeanKernel, ScalarAggregateOptions) {
}
}

TEST(TestNullMeanKernel, Basics) {
auto ty = null();
Datum null_result = std::make_shared<DoubleScalar>();
Datum zero_result = std::make_shared<DoubleScalar>(0);

EXPECT_THAT(Mean(ScalarFromJSON(ty, "null")), ResultWith(null_result));
EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]")), ResultWith(null_result));
EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]")), ResultWith(null_result));
EXPECT_THAT(Mean(ChunkedArrayFromJSON(ty, {"[null]", "[]", "[null, null]"})),
ResultWith(null_result));

ScalarAggregateOptions options(/*skip_nulls=*/true, /*min_count=*/0);
EXPECT_THAT(Mean(ScalarFromJSON(ty, "null"), options), ResultWith(zero_result));
EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"), options), ResultWith(zero_result));
EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"), options), ResultWith(zero_result));
EXPECT_THAT(Mean(ChunkedArrayFromJSON(ty, {"[null]", "[]", "[null, null]"}), options),
ResultWith(zero_result));

options = ScalarAggregateOptions(/*skip_nulls=*/false, /*min_count=*/0);
EXPECT_THAT(Mean(ScalarFromJSON(ty, "null"), options), ResultWith(null_result));
EXPECT_THAT(Mean(ArrayFromJSON(ty, "[]"), options), ResultWith(zero_result));
EXPECT_THAT(Mean(ArrayFromJSON(ty, "[null]"), options), ResultWith(null_result));
EXPECT_THAT(Mean(ChunkedArrayFromJSON(ty, {"[null]", "[]", "[null, null]"}), options),
ResultWith(null_result));
}

//
// Min / Max
//
Expand Down

0 comments on commit eee9df5

Please sign in to comment.