GH-36053: [C++] summarizing a variable results in NA at random, while there is no NA in the subset of data (#36368)

### Rationale for this change

When merging two aggregate states we were failing to read the correct `no_nulls` field. This field tells us whether we should return `null` when `skip_nulls=false`: if `no_nulls` is false, the result is `null`. For example, with `skip_nulls=false` the sum of `[1, null]` is `null`, while the sum of `[1, 2]` is `3`.

Because we were reading the wrong field, we would sometimes emit `null` even when a column didn't actually have any nulls.
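
To make the failure mode concrete, here is a minimal sketch assuming a simplified per-thread state; `ThreadState`, its fields, and this `Merge` signature are illustrative stand-ins, not Arrow's actual `GroupedReducingAggregator`:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical, simplified per-thread grouped-aggregation state.
struct ThreadState {
  std::vector<int64_t> sums;   // running sum per group
  std::vector<bool> no_nulls;  // no_nulls[g] == true: group g saw no nulls

  // Merge `other` into this state. group_id_mapping[other_g] gives the
  // group id in this state that corresponds to group other_g in `other`.
  void Merge(const ThreadState& other,
             const std::vector<uint32_t>& group_id_mapping) {
    for (size_t other_g = 0; other_g < other.sums.size(); ++other_g) {
      uint32_t g = group_id_mapping[other_g];
      sums[g] += other.sums[other_g];
      // Correct: the merged group is null-free only if both inputs were.
      // The bug was equivalent to reading this->no_nulls[other_g] here
      // (our own bitmap at the other state's index) instead of
      // other.no_nulls[other_g], so a group with no nulls could end up
      // marked as containing nulls.
      no_nulls[g] = no_nulls[g] && other.no_nulls[other_g];
    }
  }
};
```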

### What changes are included in this PR?

Fixed the bug.

### Are these changes tested?

Yes, I added a new unit test that reproduced this failure quite reliably.

### Are there any user-facing changes?

No.
* Closes: #36053

Authored-by: Weston Pace <weston.pace@gmail.com>
Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
westonpace committed Jun 29, 2023
1 parent 3fa7d60 commit 0cea12f
Showing 2 changed files with 49 additions and 1 deletion.
cpp/src/arrow/acero/aggregate_node_test.cc (48 additions, 0 deletions)
@@ -22,9 +22,13 @@

#include <memory>

#include "arrow/acero/test_util_internal.h"
#include "arrow/compute/api_aggregate.h"
#include "arrow/result.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/string.h"

namespace arrow {

@@ -162,5 +166,49 @@ TEST(GroupByConvenienceFunc, VarianceMultiBatch) { TestVarStdMultiBatch("variance"); }

TEST(GroupByConvenienceFunc, StdDevMultiBatch) { TestVarStdMultiBatch("stddev"); }

TEST(GroupByNode, NoSkipNulls) {
  constexpr int kNumBatches = 128;

  std::shared_ptr<Schema> in_schema =
      schema({field("key", int32()), field("value", int32())});

  // This regresses GH-36053. Some groups have nulls and other groups do not. The
  // "does this group have nulls" field needs to merge correctly between different
  // aggregate states. We use 128 batches to encourage multiple thread states to
  // be used.
  ExecBatch nulls_batch =
      ExecBatchFromJSON({int32(), int32()}, "[[1, null], [1, null], [1, null]]");
  ExecBatch no_nulls_batch =
      ExecBatchFromJSON({int32(), int32()}, "[[2, 1], [2, 1], [2, 1]]");

  std::vector<ExecBatch> batches;
  batches.reserve(kNumBatches);
  for (int i = 0; i < kNumBatches; i += 2) {
    batches.push_back(nulls_batch);
    batches.push_back(no_nulls_batch);
  }

  std::vector<Aggregate> aggregates = {Aggregate(
      "hash_sum", std::make_shared<compute::ScalarAggregateOptions>(/*skip_nulls=*/false),
      FieldRef("value"))};
  std::vector<FieldRef> keys = {"key"};

  Declaration plan = Declaration::Sequence(
      {{"exec_batch_source", ExecBatchSourceNodeOptions(in_schema, std::move(batches))},
       {"aggregate", AggregateNodeOptions(aggregates, keys)}});

  ASSERT_OK_AND_ASSIGN(BatchesWithCommonSchema out_batches,
                       DeclarationToExecBatches(plan));

  std::shared_ptr<Schema> out_schema =
      schema({field("key", int32()), field("sum_value", int64())});
  int32_t expected_sum = static_cast<int32_t>(no_nulls_batch.length) *
                         static_cast<int32_t>(bit_util::CeilDiv(kNumBatches, 2));
  ExecBatch expected_batch = ExecBatchFromJSON(
      {int32(), int64()}, "[[1, null], [2, " + internal::ToChars(expected_sum) + "]]");

  AssertExecBatchesEqualIgnoringOrder(out_schema, {expected_batch}, out_batches.batches);
}

} // namespace acero
} // namespace arrow
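
As a sanity check on the expected output: each `no_nulls_batch` contributes 3 rows of value `1` to group `2`, and `bit_util::CeilDiv(kNumBatches, 2) = 64` such batches are produced, so `expected_sum` is `3 * 64 = 192`. Group `1` only ever sees nulls, so with `skip_nulls=false` its sum must be `null`.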
cpp/src/arrow/compute/kernels/hash_aggregate.cc (1 addition, 1 deletion)
@@ -491,7 +491,7 @@ struct GroupedReducingAggregator : public GroupedAggregator {

     const CType* other_reduced = other->reduced_.data();
     const int64_t* other_counts = other->counts_.data();
-    const uint8_t* other_no_nulls = no_nulls_.mutable_data();
+    const uint8_t* other_no_nulls = other->no_nulls_.data();

     auto g = group_id_mapping.GetValues<uint32_t>(1);
     for (int64_t other_g = 0; other_g < group_id_mapping.length; ++other_g, ++g) {
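The fix mirrors the `other_reduced` and `other_counts` pointers declared just above: when merging, each pointer must read from `other`'s buffers. The removed line instead pointed `other_no_nulls` at this state's own `no_nulls_` bitmap, so the merge loop consulted the wrong bits when deciding whether a merged group was still null-free.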
