Skip to content

Commit

Permalink
apacheGH-35468: [C++] Fix Acero var/std for multiple batches (apache#35469)
Browse files Browse the repository at this point in the history

See apache#35468

**This PR contains a "Critical Fix".**

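Acero currently returns an incorrect variance/standard deviation when the input spans multiple batches: each batch overwrites the accumulated state instead of being merged into it, so only the last batch affects the result.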
* Closes: apache#35468

Authored-by: Yaron Gvili <rtpsw@hotmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
rtpsw authored and liujiacheng777 committed May 11, 2023
1 parent 8b95025 commit eff04ca
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
27 changes: 27 additions & 0 deletions cpp/src/arrow/acero/groupby_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <memory>

#include "arrow/table.h"
#include "arrow/testing/gtest_util.h"

namespace arrow {
Expand Down Expand Up @@ -123,5 +124,31 @@ TEST(GroupByConvenienceFunc, Invalid) {
TableGroupBy(in_table, {{"add", {"value"}, "value_add"}}, {}));
}

void TestVarStdMultiBatch(const std::string& var_std_func_name) {
std::shared_ptr<Schema> in_schema = schema({field("value", float64())});
std::shared_ptr<Table> in_table = TableFromJSON(in_schema, {R"([
[1],
[2],
[3]
])",
R"([
[4],
[4],
[4]
])"});

ASSERT_OK_AND_ASSIGN(std::shared_ptr<Table> actual,
TableGroupBy(in_table, {{var_std_func_name, {"value"}, "x"}}, {},
/*use_threads=*/false));

ASSERT_OK_AND_ASSIGN(auto var_scalar, actual->column(0)->GetScalar(0));
// the next assertion will fail if only the second batch affects the result
ASSERT_NE(0, std::dynamic_pointer_cast<DoubleScalar>(var_scalar)->value);
}

// Runs the multi-batch regression helper with the "variance" aggregate.
TEST(GroupByConvenienceFunc, VarianceMultiBatch) {
  TestVarStdMultiBatch("variance");
}

// Runs the multi-batch regression helper with the "stddev" aggregate.
TEST(GroupByConvenienceFunc, StdDevMultiBatch) {
  TestVarStdMultiBatch("stddev");
}

} // namespace acero
} // namespace arrow
8 changes: 5 additions & 3 deletions cpp/src/arrow/compute/kernels/aggregate_var_std.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,9 +71,11 @@ struct VarStdState {
return (v - mean) * (v - mean);
});

this->count = count;
this->mean = mean;
this->m2 = m2;
ThisType state(decimal_scale, options);
state.count = count;
state.mean = mean;
state.m2 = m2;
this->MergeFrom(state);
}

// int32/16/8: textbook one pass algorithm with integer arithmetic
Expand Down

0 comments on commit eff04ca

Please sign in to comment.