From 7435dab75b045dc499d6c81bfcfc8cd038568879 Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Thu, 20 Nov 2025 16:25:30 -0800 Subject: [PATCH 01/12] Add aggregate expressions and evaluator (issue #330) --- src/iceberg/CMakeLists.txt | 1 + src/iceberg/expression/aggregate.cc | 289 ++++++++++++++++++++ src/iceberg/expression/aggregate.h | 186 +++++++++++++ src/iceberg/expression/binder.cc | 21 ++ src/iceberg/expression/binder.h | 6 + src/iceberg/expression/expression_visitor.h | 25 +- src/iceberg/expression/expressions.cc | 58 ++++ src/iceberg/expression/expressions.h | 41 +++ src/iceberg/meson.build | 1 + src/iceberg/test/CMakeLists.txt | 1 + src/iceberg/test/aggregate_test.cc | 132 +++++++++ src/iceberg/test/meson.build | 1 + 12 files changed, 761 insertions(+), 1 deletion(-) create mode 100644 src/iceberg/expression/aggregate.cc create mode 100644 src/iceberg/expression/aggregate.h create mode 100644 src/iceberg/test/aggregate_test.cc diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 22c222182..0b6000e4e 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -20,6 +20,7 @@ set(ICEBERG_INCLUDES "$" set(ICEBERG_SOURCES arrow_c_data_guard_internal.cc catalog/memory/in_memory_catalog.cc + expression/aggregate.cc expression/binder.cc expression/expression.cc expression/expressions.cc diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc new file mode 100644 index 000000000..2056d3312 --- /dev/null +++ b/src/iceberg/expression/aggregate.cc @@ -0,0 +1,289 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/aggregate.h" + +#include +#include + +#include "iceberg/exception.h" +#include "iceberg/expression/binder.h" +#include "iceberg/expression/expression.h" +#include "iceberg/row/struct_like.h" +#include "iceberg/type.h" +#include "iceberg/util/checked_cast.h" +#include "iceberg/util/macros.h" + +namespace iceberg { + +namespace { + +std::string OperationToPrefix(Expression::Operation op) { + switch (op) { + case Expression::Operation::kMax: + return "max"; + case Expression::Operation::kMin: + return "min"; + case Expression::Operation::kCount: + case Expression::Operation::kCountStar: + return "count"; + default: + break; + } + return "aggregate"; +} + +Result> GetPrimitiveType(const BoundTerm& term) { + auto primitive = std::dynamic_pointer_cast(term.type()); + if (primitive == nullptr) { + return InvalidExpression("Aggregate requires primitive type, got {}", + term.type()->ToString()); + } + return primitive; +} + +} // namespace + +CountAggregate::CountAggregate(Expression::Operation op, Mode mode, + std::shared_ptr> term, + std::shared_ptr reference) + : UnboundAggregate(op), + mode_(mode), + term_(std::move(term)), + reference_(std::move(reference)) {} + +Result> CountAggregate::Count( + std::shared_ptr> term) { + auto ref = term->reference(); + return std::unique_ptr(new CountAggregate( + Expression::Operation::kCount, Mode::kNonNull, std::move(term), std::move(ref))); +} + +Result> CountAggregate::CountNull( + std::shared_ptr> term) { + auto ref = term->reference(); + return std::unique_ptr(new CountAggregate( + Expression::Operation::kCount, Mode::kNull, std::move(term), std::move(ref))); +} + +std::unique_ptr CountAggregate::CountStar() { + return std::unique_ptr(new CountAggregate( + Expression::Operation::kCountStar, Mode::kStar, nullptr, nullptr)); +} + +std::string CountAggregate::ToString() const { + if (mode_ == Mode::kStar) { + return "count(*)"; + } + ICEBERG_DCHECK(reference_ != nullptr, "Count aggregate should have reference"); + switch (mode_) { + case Mode::kNull: + return std::format("count_null({})", reference_->name()); + case Mode::kNonNull: + return std::format("count({})", reference_->name()); + case Mode::kStar: + break; + } + std::unreachable(); +} + +Result> CountAggregate::Bind(const Schema& schema, + bool case_sensitive) const { + std::shared_ptr bound_term; + if (term_ != nullptr) { + ICEBERG_ASSIGN_OR_THROW(auto bound, term_->Bind(schema, case_sensitive)); + bound_term = std::move(bound); + } + auto aggregate = + std::make_shared(op(), mode_, std::move(bound_term)); + return aggregate; +} + +BoundCountAggregate::BoundCountAggregate(Expression::Operation op, + CountAggregate::Mode mode, + std::shared_ptr term) + : BoundAggregate(op, std::move(term)), mode_(mode) {} + +std::string BoundCountAggregate::ToString() const { + if (mode_ == CountAggregate::Mode::kStar) { + return "count(*)"; + } + ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); + switch (mode_) { + case CountAggregate::Mode::kNull: + return std::format("count_null({})", term()->reference()->name()); + case CountAggregate::Mode::kNonNull: + return std::format("count({})", term()->reference()->name()); + case CountAggregate::Mode::kStar: + break; + } + std::unreachable(); +} + +ValueAggregate::ValueAggregate(Expression::Operation op, + std::shared_ptr> term, + std::shared_ptr reference) + : UnboundAggregate(op), term_(std::move(term)), reference_(std::move(reference)) {} + +Result> ValueAggregate::Max( + std::shared_ptr> term) { + auto ref = term->reference(); + return std::unique_ptr( + new ValueAggregate(Expression::Operation::kMax, std::move(term), std::move(ref))); +} + +Result> ValueAggregate::Min( + std::shared_ptr> term) { + auto ref = term->reference(); + return std::unique_ptr( + new ValueAggregate(Expression::Operation::kMin, std::move(term), std::move(ref))); +} + +std::string ValueAggregate::ToString() const { + return std::format("{}({})", OperationToPrefix(op()), reference_->name()); +} + +Result> ValueAggregate::Bind(const Schema& schema, + bool case_sensitive) const { + ICEBERG_ASSIGN_OR_THROW(auto bound, term_->Bind(schema, case_sensitive)); + auto aggregate = std::make_shared( + op(), std::shared_ptr(std::move(bound))); + return aggregate; +} + +BoundValueAggregate::BoundValueAggregate(Expression::Operation op, + std::shared_ptr term) + : BoundAggregate(op, std::move(term)) {} + +std::string BoundValueAggregate::ToString() const { + ICEBERG_DCHECK(term() != nullptr, "Bound value aggregate should have term"); + return std::format("{}({})", OperationToPrefix(op()), term()->reference()->name()); +} + +namespace { + +class CountEvaluator : public AggregateEvaluator { + public: + CountEvaluator(CountAggregate::Mode mode, std::shared_ptr term) + : mode_(mode), term_(std::move(term)) {} + + Status Add(const StructLike& row) override { + switch (mode_) { + case CountAggregate::Mode::kStar: + ++count_; + return {}; + case CountAggregate::Mode::kNonNull: { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); + if (!literal.IsNull()) { + ++count_; + } + return {}; + } + case CountAggregate::Mode::kNull: { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); + if (literal.IsNull()) { + ++count_; + } + return {}; + } + } + std::unreachable(); + } + + Result ResultLiteral() const override { return Literal::Long(count_); } + + private: + CountAggregate::Mode mode_; + std::shared_ptr term_; + int64_t count_ = 0; +}; + +class ValueAggregateEvaluator : public AggregateEvaluator { + public: + ValueAggregateEvaluator(Expression::Operation op, std::shared_ptr term, + std::shared_ptr type) + : op_(op), term_(std::move(term)), type_(std::move(type)) {} + + Status Add(const StructLike& row) override { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); + if (literal.IsNull()) { + return {}; + } + + if (!current_) { + current_ = std::move(literal); + return {}; + } + + auto ordering = literal <=> *current_; + if (ordering == std::partial_ordering::unordered) { + return InvalidExpression("Cannot compare literals of type {}", + literal.type()->ToString()); + } + + if (op_ == Expression::Operation::kMax) { + if (ordering == std::partial_ordering::greater) { + current_ = std::move(literal); + } + } else { + if (ordering == std::partial_ordering::less) { + current_ = std::move(literal); + } + } + return {}; + } + + Result ResultLiteral() const override { + if (!current_) { + return Literal::Null(type_); + } + return *current_; + } + + private: + Expression::Operation op_; + std::shared_ptr term_; + std::shared_ptr type_; + std::optional current_; +}; + +} // namespace + +Result> AggregateEvaluator::Make( + std::shared_ptr aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + + if (auto count = std::dynamic_pointer_cast(aggregate)) { + if (count->mode() != CountAggregate::Mode::kStar && !count->term()) { + return InvalidExpression("Count aggregate requires a term"); + } + return std::unique_ptr( + new CountEvaluator(count->mode(), count->term())); + } + + if (auto value = std::dynamic_pointer_cast(aggregate)) { + ICEBERG_ASSIGN_OR_RAISE(auto type, GetPrimitiveType(*value->term())); + return std::unique_ptr( + new ValueAggregateEvaluator(value->op(), value->term(), std::move(type))); + } + + return NotSupported("Unsupported aggregate: {}", aggregate->ToString()); +} + +} // namespace iceberg diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h new file mode 100644 index 000000000..3c7d71f91 --- /dev/null +++ b/src/iceberg/expression/aggregate.h @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/expression/aggregate.h +/// Aggregate expression definitions. + +#include + +#include "iceberg/expression/expression.h" +#include "iceberg/expression/term.h" +#include "iceberg/result.h" + +namespace iceberg { + +class AggregateEvaluator; + +/// \brief Base class for aggregate expressions. +class ICEBERG_EXPORT Aggregate : public Expression { + public: + ~Aggregate() override = default; + + Expression::Operation op() const override { return operation_; } + + bool is_unbound_aggregate() const override { return false; } + bool is_bound_aggregate() const override { return false; } + + protected: + explicit Aggregate(Expression::Operation op) : operation_(op) {} + + private: + Expression::Operation operation_; +}; + +/// \brief Unbound aggregate with an optional term. +class ICEBERG_EXPORT UnboundAggregate : public Aggregate, public Unbound { + public: + ~UnboundAggregate() override = default; + + bool is_unbound_aggregate() const override { return true; } + + /// \brief Returns the unbound reference if the aggregate has a term. + virtual std::shared_ptr reference() override = 0; + + protected: + explicit UnboundAggregate(Expression::Operation op) : Aggregate(op) {} +}; + +/// \brief Bound aggregate with an optional term. +class ICEBERG_EXPORT BoundAggregate : public Aggregate { + public: + ~BoundAggregate() override = default; + + bool is_bound_aggregate() const override { return true; } + + const std::shared_ptr& term() const { return term_; } + + protected: + BoundAggregate(Expression::Operation op, std::shared_ptr term) + : Aggregate(op), term_(std::move(term)) {} + + private: + std::shared_ptr term_; +}; + +/// \brief COUNT aggregate variants. +class ICEBERG_EXPORT CountAggregate : public UnboundAggregate { + public: + enum class Mode { kNonNull, kNull, kStar }; + + static Result> Count( + std::shared_ptr> term); + + static Result> CountNull( + std::shared_ptr> term); + + static std::unique_ptr CountStar(); + + ~CountAggregate() override = default; + + Mode mode() const { return mode_; } + + const std::shared_ptr>& term() const { return term_; } + + std::shared_ptr reference() override { return reference_; } + + std::string ToString() const override; + + Result> Bind(const Schema& schema, + bool case_sensitive) const override; + + private: + CountAggregate(Expression::Operation op, Mode mode, + std::shared_ptr> term, + std::shared_ptr reference); + + Mode mode_; + std::shared_ptr> term_; + std::shared_ptr reference_; +}; + +/// \brief Bound COUNT aggregate. +class ICEBERG_EXPORT BoundCountAggregate : public BoundAggregate { + public: + BoundCountAggregate(Expression::Operation op, CountAggregate::Mode mode, + std::shared_ptr term); + + CountAggregate::Mode mode() const { return mode_; } + + std::string ToString() const override; + + private: + CountAggregate::Mode mode_; +}; + +/// \brief MAX/MIN aggregate on a single term. +class ICEBERG_EXPORT ValueAggregate : public UnboundAggregate { + public: + static Result> Max( + std::shared_ptr> term); + + static Result> Min( + std::shared_ptr> term); + + ~ValueAggregate() override = default; + + std::shared_ptr reference() override { return reference_; } + + const std::shared_ptr>& term() const { return term_; } + + std::string ToString() const override; + + Result> Bind(const Schema& schema, + bool case_sensitive) const override; + + private: + ValueAggregate(Expression::Operation op, + std::shared_ptr> term, + std::shared_ptr reference); + + std::shared_ptr> term_; + std::shared_ptr reference_; +}; + +/// \brief Bound MAX/MIN aggregate. +class ICEBERG_EXPORT BoundValueAggregate : public BoundAggregate { + public: + BoundValueAggregate(Expression::Operation op, std::shared_ptr term); + + std::string ToString() const override; +}; + +/// \brief Evaluates bound aggregates over StructLike rows. +class ICEBERG_EXPORT AggregateEvaluator { + public: + virtual ~AggregateEvaluator() = default; + + /// \brief Create an evaluator for a bound aggregate. + static Result> Make( + std::shared_ptr aggregate); + + /// \brief Add a row to the aggregate. + virtual Status Add(const StructLike& row) = 0; + + /// \brief Final aggregated value. + virtual Result ResultLiteral() const = 0; +}; + +} // namespace iceberg diff --git a/src/iceberg/expression/binder.cc b/src/iceberg/expression/binder.cc index 62c735308..43c3ebcdf 100644 --- a/src/iceberg/expression/binder.cc +++ b/src/iceberg/expression/binder.cc @@ -64,6 +64,18 @@ Result> Binder::Predicate( return InvalidExpression("Found already bound predicate: {}", pred->ToString()); } +Result> Binder::Aggregate( + const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString()); +} + +Result> Binder::Aggregate( + const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + return aggregate->Bind(schema_, case_sensitive_); +} + Result IsBoundVisitor::IsBound(const std::shared_ptr& expr) { ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null"); IsBoundVisitor visitor; @@ -92,4 +104,13 @@ Result IsBoundVisitor::Predicate(const std::shared_ptr& return false; } +Result IsBoundVisitor::Aggregate(const std::shared_ptr& aggregate) { + return true; +} + +Result IsBoundVisitor::Aggregate( + const std::shared_ptr& aggregate) { + return false; +} + } // namespace iceberg diff --git a/src/iceberg/expression/binder.h b/src/iceberg/expression/binder.h index fcef0731d..a78b7a4bb 100644 --- a/src/iceberg/expression/binder.h +++ b/src/iceberg/expression/binder.h @@ -48,6 +48,10 @@ class ICEBERG_EXPORT Binder : public ExpressionVisitor& pred) override; Result> Predicate( const std::shared_ptr& pred) override; + Result> Aggregate( + const std::shared_ptr& aggregate) override; + Result> Aggregate( + const std::shared_ptr& aggregate) override; private: const Schema& schema_; @@ -65,6 +69,8 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor { Result Or(bool left_result, bool right_result) override; Result Predicate(const std::shared_ptr& pred) override; Result Predicate(const std::shared_ptr& pred) override; + Result Aggregate(const std::shared_ptr& aggregate) override; + Result Aggregate(const std::shared_ptr& aggregate) override; }; // TODO(gangwu): add the Java parity `ReferenceVisitor` diff --git a/src/iceberg/expression/expression_visitor.h b/src/iceberg/expression/expression_visitor.h index aeafa9298..edc99196d 100644 --- a/src/iceberg/expression/expression_visitor.h +++ b/src/iceberg/expression/expression_visitor.h @@ -25,6 +25,7 @@ #include #include +#include "iceberg/expression/aggregate.h" #include "iceberg/expression/expression.h" #include "iceberg/expression/literal.h" #include "iceberg/expression/predicate.h" @@ -77,6 +78,22 @@ class ICEBERG_EXPORT ExpressionVisitor { /// \brief Visit an unbound predicate. /// \param pred The unbound predicate to visit virtual Result Predicate(const std::shared_ptr& pred) = 0; + + /// \brief Visit a bound aggregate. + /// \param aggregate The bound aggregate to visit. + virtual Result Aggregate(const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Bound aggregate cannot be null"); + return NotSupported("Bound aggregate is not supported by this visitor: {}", + aggregate->ToString()); + } + + /// \brief Visit an unbound aggregate. + /// \param aggregate The unbound aggregate to visit. + virtual Result Aggregate(const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Unbound aggregate cannot be null"); + return NotSupported("Unbound aggregate is not supported by this visitor: {}", + aggregate->ToString()); + } }; /// \brief Visitor for bound expressions. @@ -275,7 +292,13 @@ Result Visit(const std::shared_ptr& expr, V& visitor) { return visitor.Predicate(std::dynamic_pointer_cast(expr)); } - // TODO(gangwu): handle aggregate expression + if (expr->is_bound_aggregate()) { + return visitor.Aggregate(std::dynamic_pointer_cast(expr)); + } + + if (expr->is_unbound_aggregate()) { + return visitor.Aggregate(std::dynamic_pointer_cast(expr)); + } switch (expr->op()) { case Expression::Operation::kTrue: diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index b3e88ff18..98230ce4e 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -20,6 +20,7 @@ #include "iceberg/expression/expressions.h" #include "iceberg/exception.h" +#include "iceberg/expression/aggregate.h" #include "iceberg/transform.h" #include "iceberg/type.h" #include "iceberg/util/macros.h" @@ -81,6 +82,63 @@ std::shared_ptr Expressions::Transform( return unbound_transform; } +// Aggregates + +std::shared_ptr Expressions::Count(std::string name) { + return Count(Ref(std::move(name))); +} + +std::shared_ptr Expressions::Count( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::Count(std::move(expr))); + return std::shared_ptr(std::move(agg)); +} + +std::shared_ptr Expressions::CountNull(std::string name) { + return CountNull(Ref(std::move(name))); +} + +std::shared_ptr Expressions::CountNull( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::CountNull(std::move(expr))); + return std::shared_ptr(std::move(agg)); +} + +std::shared_ptr Expressions::CountNotNull(std::string name) { + return CountNotNull(Ref(std::move(name))); +} + +std::shared_ptr Expressions::CountNotNull( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::Count(std::move(expr))); + return std::shared_ptr(std::move(agg)); +} + +std::shared_ptr Expressions::CountStar() { + auto agg = CountAggregate::CountStar(); + return std::shared_ptr(std::move(agg)); +} + +std::shared_ptr Expressions::Max(std::string name) { + return Max(Ref(std::move(name))); +} + +std::shared_ptr Expressions::Max( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, ValueAggregate::Max(std::move(expr))); + return std::shared_ptr(std::move(agg)); +} + +std::shared_ptr Expressions::Min(std::string name) { + return Min(Ref(std::move(name))); +} + +std::shared_ptr Expressions::Min( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, ValueAggregate::Min(std::move(expr))); + return std::shared_ptr(std::move(agg)); +} + // Template implementations for unary predicates std::shared_ptr> Expressions::IsNull( diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h index cf7b6d20e..81ed4c877 100644 --- a/src/iceberg/expression/expressions.h +++ b/src/iceberg/expression/expressions.h @@ -28,6 +28,7 @@ #include #include "iceberg/exception.h" +#include "iceberg/expression/aggregate.h" #include "iceberg/expression/literal.h" #include "iceberg/expression/predicate.h" #include "iceberg/expression/term.h" @@ -101,6 +102,46 @@ class ICEBERG_EXPORT Expressions { static std::shared_ptr Transform( std::string name, std::shared_ptr transform); + // Aggregates + + /// \brief Create COUNT(col) aggregate. + static std::shared_ptr Count(std::string name); + + /// \brief Create COUNT(unbound term) aggregate. + static std::shared_ptr Count( + std::shared_ptr> expr); + + /// \brief Create COUNT_NULL(col) aggregate. + static std::shared_ptr CountNull(std::string name); + + /// \brief Create COUNT_NULL(unbound term) aggregate. + static std::shared_ptr CountNull( + std::shared_ptr> expr); + + /// \brief Create COUNT_NOT_NULL(col) aggregate. + static std::shared_ptr CountNotNull(std::string name); + + /// \brief Create COUNT_NOT_NULL(unbound term) aggregate. + static std::shared_ptr CountNotNull( + std::shared_ptr> expr); + + /// \brief Create COUNT(*) aggregate. + static std::shared_ptr CountStar(); + + /// \brief Create MAX(col) aggregate. + static std::shared_ptr Max(std::string name); + + /// \brief Create MAX(unbound term) aggregate. + static std::shared_ptr Max( + std::shared_ptr> expr); + + /// \brief Create MIN(col) aggregate. + static std::shared_ptr Min(std::string name); + + /// \brief Create MIN(unbound term) aggregate. + static std::shared_ptr Min( + std::shared_ptr> expr); + // Unary predicates /// \brief Create an IS NULL predicate for a field name. diff --git a/src/iceberg/meson.build b/src/iceberg/meson.build index ae5f8bac4..be5057cd5 100644 --- a/src/iceberg/meson.build +++ b/src/iceberg/meson.build @@ -42,6 +42,7 @@ iceberg_include_dir = include_directories('..') iceberg_sources = files( 'arrow_c_data_guard_internal.cc', 'catalog/memory/in_memory_catalog.cc', + 'expression/aggregate.cc', 'expression/binder.cc', 'expression/expression.cc', 'expression/expressions.cc', diff --git a/src/iceberg/test/CMakeLists.txt b/src/iceberg/test/CMakeLists.txt index d82fe17b8..524912455 100644 --- a/src/iceberg/test/CMakeLists.txt +++ b/src/iceberg/test/CMakeLists.txt @@ -87,6 +87,7 @@ add_iceberg_test(table_test add_iceberg_test(expression_test SOURCES + aggregate_test.cc expression_test.cc expression_visitor_test.cc literal_test.cc diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc new file mode 100644 index 000000000..23f57621a --- /dev/null +++ b/src/iceberg/test/aggregate_test.cc @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/aggregate.h" + +#include + +#include "iceberg/exception.h" +#include "iceberg/expression/binder.h" +#include "iceberg/expression/expressions.h" +#include "iceberg/row/struct_like.h" +#include "iceberg/schema.h" +#include "iceberg/test/matchers.h" +#include "iceberg/type.h" +#include "iceberg/util/macros.h" + +namespace iceberg { + +namespace { + +class VectorStructLike : public StructLike { + public: + explicit VectorStructLike(std::vector fields) : fields_(std::move(fields)) {} + + Result GetField(size_t pos) const override { + if (pos >= fields_.size()) { + return InvalidArgument("Position {} out of range", pos); + } + return fields_[pos]; + } + + size_t num_fields() const override { return fields_.size(); } + + private: + std::vector fields_; +}; + +std::shared_ptr BindAggregate(const Schema& schema, + const std::shared_ptr& expr) { + auto result = Binder::Bind(schema, expr, /*case_sensitive=*/true); + EXPECT_TRUE(result.has_value()) + << "Failed to bind aggregate: " << result.error().message; + auto bound = std::dynamic_pointer_cast(std::move(result).value()); + EXPECT_NE(bound, nullptr); + return bound; +} + +} // namespace + +TEST(AggregateTest, CountVariants) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_expr = Expressions::Count("id"); + auto count_bound = BindAggregate(schema, count_expr); + auto count_evaluator = AggregateEvaluator::Make(count_bound).value(); + + auto count_null_expr = Expressions::CountNull("value"); + auto count_null_bound = BindAggregate(schema, count_null_expr); + auto count_null_evaluator = AggregateEvaluator::Make(count_null_bound).value(); + + auto count_star_expr = Expressions::CountStar(); + auto count_star_bound = BindAggregate(schema, count_star_expr); + auto count_star_evaluator = AggregateEvaluator::Make(count_star_bound).value(); + + std::vector rows{ + VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), + VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}})}; + + for (const auto& row : rows) { + ASSERT_TRUE(count_evaluator->Add(row).has_value()); + ASSERT_TRUE(count_null_evaluator->Add(row).has_value()); + ASSERT_TRUE(count_star_evaluator->Add(row).has_value()); + } + + ICEBERG_UNWRAP_OR_FAIL(auto count_result, count_evaluator->ResultLiteral()); + EXPECT_EQ(std::get(count_result.value()), 2); + + ICEBERG_UNWRAP_OR_FAIL(auto count_null_result, count_null_evaluator->ResultLiteral()); + EXPECT_EQ(std::get(count_null_result.value()), 1); + + ICEBERG_UNWRAP_OR_FAIL(auto count_star_result, count_star_evaluator->ResultLiteral()); + EXPECT_EQ(std::get(count_star_result.value()), 3); +} + +TEST(AggregateTest, MaxMinAggregates) { + Schema schema({SchemaField::MakeOptional(1, "value", int32())}); + + auto max_expr = Expressions::Max("value"); + auto min_expr = Expressions::Min("value"); + + auto max_bound = BindAggregate(schema, max_expr); + auto min_bound = BindAggregate(schema, min_expr); + + auto max_eval = AggregateEvaluator::Make(max_bound).value(); + auto min_eval = AggregateEvaluator::Make(min_bound).value(); + + std::vector rows{VectorStructLike({Scalar{int32_t{5}}}), + VectorStructLike({Scalar{std::monostate{}}}), + VectorStructLike({Scalar{int32_t{2}}}), + VectorStructLike({Scalar{int32_t{12}}})}; + + for (const auto& row : rows) { + ASSERT_TRUE(max_eval->Add(row).has_value()); + ASSERT_TRUE(min_eval->Add(row).has_value()); + } + + ICEBERG_UNWRAP_OR_FAIL(auto max_result, max_eval->ResultLiteral()); + EXPECT_EQ(std::get(max_result.value()), 12); + + ICEBERG_UNWRAP_OR_FAIL(auto min_result, min_eval->ResultLiteral()); + EXPECT_EQ(std::get(min_result.value()), 2); +} + +} // namespace iceberg diff --git a/src/iceberg/test/meson.build b/src/iceberg/test/meson.build index 6fbe82dfd..c31a4a14a 100644 --- a/src/iceberg/test/meson.build +++ b/src/iceberg/test/meson.build @@ -54,6 +54,7 @@ iceberg_tests = { }, 'expression_test': { 'sources': files( + 'aggregate_test.cc', 'expression_test.cc', 'expression_visitor_test.cc', 'literal_test.cc', From 0281987f4777ecb1b81d520f2d934ec3e782506a Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Thu, 20 Nov 2025 16:50:37 -0800 Subject: [PATCH 02/12] Address linter braced-init returns in expressions --- src/iceberg/expression/expressions.cc | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index 98230ce4e..6f31f837d 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -91,7 +91,7 @@ std::shared_ptr Expressions::Count(std::string name) { std::shared_ptr Expressions::Count( std::shared_ptr> expr) { ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::Count(std::move(expr))); - return std::shared_ptr(std::move(agg)); + return {std::move(agg)}; } std::shared_ptr Expressions::CountNull(std::string name) { @@ -101,7 +101,7 @@ std::shared_ptr Expressions::CountNull(std::string name) { std::shared_ptr Expressions::CountNull( std::shared_ptr> expr) { ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::CountNull(std::move(expr))); - return std::shared_ptr(std::move(agg)); + return {std::move(agg)}; } std::shared_ptr Expressions::CountNotNull(std::string name) { @@ -111,12 +111,12 @@ std::shared_ptr Expressions::CountNotNull(std::string name) { std::shared_ptr Expressions::CountNotNull( std::shared_ptr> expr) { ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::Count(std::move(expr))); - return std::shared_ptr(std::move(agg)); + return {std::move(agg)}; } std::shared_ptr Expressions::CountStar() { auto agg = CountAggregate::CountStar(); - return std::shared_ptr(std::move(agg)); + return {std::move(agg)}; } std::shared_ptr Expressions::Max(std::string name) { @@ -126,7 +126,7 @@ std::shared_ptr Expressions::Max(std::string name) { std::shared_ptr Expressions::Max( std::shared_ptr> expr) { ICEBERG_ASSIGN_OR_THROW(auto agg, ValueAggregate::Max(std::move(expr))); - return std::shared_ptr(std::move(agg)); + return {std::move(agg)}; } std::shared_ptr Expressions::Min(std::string name) { @@ -136,7 +136,7 @@ std::shared_ptr Expressions::Min(std::string name) { std::shared_ptr Expressions::Min( std::shared_ptr> expr) { ICEBERG_ASSIGN_OR_THROW(auto agg, ValueAggregate::Min(std::move(expr))); - return std::shared_ptr(std::move(agg)); + return {std::move(agg)}; } // Template implementations for unary predicates From 098902656761382476fc54febd23da8473e3fc0b Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Thu, 20 Nov 2025 17:00:08 -0800 Subject: [PATCH 03/12] Remove redundant virtual in UnboundAggregate --- src/iceberg/expression/aggregate.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 3c7d71f91..26694e9ab 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -57,7 +57,7 @@ class ICEBERG_EXPORT UnboundAggregate : public Aggregate, public Unbound reference() override = 0; + std::shared_ptr reference() override = 0; protected: explicit UnboundAggregate(Expression::Operation op) : Aggregate(op) {} From cd387583951f919cdcbb5a0e64367bfc9e7abe54 Mon Sep 17 00:00:00 2001 From: Zhiyuan Liang <115799793+SuKi2cn@users.noreply.github.com> Date: Fri, 21 Nov 2025 01:13:30 -0800 Subject: [PATCH 04/12] Update src/iceberg/expression/aggregate.h Co-authored-by: Gang Wu --- src/iceberg/expression/aggregate.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 26694e9ab..2e50b4f84 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -39,9 +39,6 @@ class ICEBERG_EXPORT Aggregate : public Expression { Expression::Operation op() const override { return operation_; } - bool is_unbound_aggregate() const override { return false; } - bool is_bound_aggregate() const override { return false; } - protected: explicit Aggregate(Expression::Operation op) : operation_(op) {} From 3a0627b76f51a80088c5285a1c7d084aa9091af2 Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Fri, 21 Nov 2025 02:13:36 -0800 Subject: [PATCH 05/12] Refine aggregate execution and evaluator --- src/iceberg/expression/aggregate.cc | 331 ++++++++++++-------------- src/iceberg/expression/aggregate.h | 154 ++++++------ src/iceberg/expression/expressions.cc | 59 +++-- src/iceberg/expression/expressions.h | 24 +- src/iceberg/test/aggregate_test.cc | 28 +++ 5 files changed, 299 insertions(+), 297 deletions(-) diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index 2056d3312..69e2fa56b 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -21,10 +21,10 @@ #include #include +#include #include "iceberg/exception.h" #include "iceberg/expression/binder.h" -#include "iceberg/expression/expression.h" #include "iceberg/row/struct_like.h" #include "iceberg/type.h" #include "iceberg/util/checked_cast.h" @@ -34,21 +34,6 @@ namespace iceberg { namespace { -std::string OperationToPrefix(Expression::Operation op) { - switch (op) { - case Expression::Operation::kMax: - return "max"; - case Expression::Operation::kMin: - return "min"; - case Expression::Operation::kCount: - case Expression::Operation::kCountStar: - return "count"; - default: - break; - } - return "aggregate"; -} - Result> GetPrimitiveType(const BoundTerm& term) { auto primitive = std::dynamic_pointer_cast(term.type()); if (primitive == nullptr) { @@ -60,230 +45,212 @@ Result> GetPrimitiveType(const BoundTerm& term) { } // namespace -CountAggregate::CountAggregate(Expression::Operation op, Mode mode, - std::shared_ptr> term, - std::shared_ptr reference) - : UnboundAggregate(op), - mode_(mode), - term_(std::move(term)), - reference_(std::move(reference)) {} - -Result> CountAggregate::Count( - std::shared_ptr> term) { - auto ref = term->reference(); - return std::unique_ptr(new CountAggregate( - Expression::Operation::kCount, Mode::kNonNull, std::move(term), std::move(ref))); -} - -Result> CountAggregate::CountNull( - std::shared_ptr> term) { - auto ref = term->reference(); - return std::unique_ptr(new CountAggregate( - Expression::Operation::kCount, Mode::kNull, std::move(term), std::move(ref))); -} +// -------------------- Bound aggregates -------------------- -std::unique_ptr CountAggregate::CountStar() { - return std::unique_ptr(new CountAggregate( - Expression::Operation::kCountStar, Mode::kStar, nullptr, nullptr)); -} +BoundCountAggregate::BoundCountAggregate(Expression::Operation op, Mode mode, + std::shared_ptr term) + : BoundAggregate(op, std::move(term)), mode_(mode) {} -std::string CountAggregate::ToString() const { +std::string BoundCountAggregate::ToString() const { if (mode_ == Mode::kStar) { return "count(*)"; } - ICEBERG_DCHECK(reference_ != nullptr, "Count aggregate should have reference"); + ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); switch (mode_) { case Mode::kNull: - return std::format("count_null({})", reference_->name()); + return std::format("count_null({})", term()->reference()->name()); case Mode::kNonNull: - return std::format("count({})", reference_->name()); + return std::format("count({})", term()->reference()->name()); case Mode::kStar: break; } std::unreachable(); } -Result> CountAggregate::Bind(const Schema& schema, - bool case_sensitive) const { - std::shared_ptr bound_term; - if (term_ != nullptr) { - ICEBERG_ASSIGN_OR_THROW(auto bound, term_->Bind(schema, case_sensitive)); - bound_term = std::move(bound); - } - auto aggregate = - std::make_shared(op(), mode_, std::move(bound_term)); - return aggregate; -} - -BoundCountAggregate::BoundCountAggregate(Expression::Operation op, - CountAggregate::Mode mode, - std::shared_ptr term) - : BoundAggregate(op, std::move(term)), mode_(mode) {} - -std::string BoundCountAggregate::ToString() const { - if (mode_ == CountAggregate::Mode::kStar) { - return "count(*)"; - } - ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); +Result BoundCountAggregate::Evaluate(const StructLike& data) const { switch (mode_) { - case CountAggregate::Mode::kNull: - return std::format("count_null({})", term()->reference()->name()); - case CountAggregate::Mode::kNonNull: - return std::format("count({})", term()->reference()->name()); - case CountAggregate::Mode::kStar: - break; + case Mode::kStar: + return Literal::Long(1); + case Mode::kNonNull: { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); + return Literal::Long(literal.IsNull() ? 0 : 1); + } + case Mode::kNull: { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); + return Literal::Long(literal.IsNull() ? 1 : 0); + } } std::unreachable(); } -ValueAggregate::ValueAggregate(Expression::Operation op, - std::shared_ptr> term, - std::shared_ptr reference) - : UnboundAggregate(op), term_(std::move(term)), reference_(std::move(reference)) {} - -Result> ValueAggregate::Max( - std::shared_ptr> term) { - auto ref = term->reference(); - return std::unique_ptr( - new ValueAggregate(Expression::Operation::kMax, std::move(term), std::move(ref))); -} - -Result> ValueAggregate::Min( - std::shared_ptr> term) { - auto ref = term->reference(); - return std::unique_ptr( - new ValueAggregate(Expression::Operation::kMin, std::move(term), std::move(ref))); -} - -std::string ValueAggregate::ToString() const { - return std::format("{}({})", OperationToPrefix(op()), reference_->name()); -} - -Result> ValueAggregate::Bind(const Schema& schema, - bool case_sensitive) const { - ICEBERG_ASSIGN_OR_THROW(auto bound, term_->Bind(schema, case_sensitive)); - auto aggregate = std::make_shared( - op(), std::shared_ptr(std::move(bound))); - return aggregate; -} - BoundValueAggregate::BoundValueAggregate(Expression::Operation op, std::shared_ptr term) : BoundAggregate(op, std::move(term)) {} std::string BoundValueAggregate::ToString() const { ICEBERG_DCHECK(term() != nullptr, "Bound value aggregate should have term"); - return std::format("{}({})", OperationToPrefix(op()), term()->reference()->name()); + auto prefix = op() == Expression::Operation::kMax ? "max" : "min"; + return std::format("{}({})", prefix, term()->reference()->name()); } -namespace { +Result BoundValueAggregate::Evaluate(const StructLike& data) const { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); + return literal; +} -class CountEvaluator : public AggregateEvaluator { - public: - CountEvaluator(CountAggregate::Mode mode, std::shared_ptr term) - : mode_(mode), term_(std::move(term)) {} +// -------------------- Unbound binding -------------------- - Status Add(const StructLike& row) override { - switch (mode_) { - case CountAggregate::Mode::kStar: - ++count_; - return {}; - case CountAggregate::Mode::kNonNull: { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); - if (!literal.IsNull()) { - ++count_; - } - return {}; +template +Result> UnboundAggregateImpl::Bind( + const Schema& schema, bool case_sensitive) const { + std::shared_ptr bound_term; + if (this->term()) { + ICEBERG_ASSIGN_OR_THROW(bound_term, this->term()->Bind(schema, case_sensitive)); + } + + switch (count_mode_) { + case CountMode::kStar: + case CountMode::kNull: + case CountMode::kNonNull: { + auto op = this->op() == Expression::Operation::kCountStar + ? Expression::Operation::kCountStar + : Expression::Operation::kCount; + auto mode = + count_mode_ == CountMode::kNull + ? BoundCountAggregate::Mode::kNull + : (count_mode_ == CountMode::kStar ? BoundCountAggregate::Mode::kStar + : BoundCountAggregate::Mode::kNonNull); + auto aggregate = + std::make_shared(op, mode, std::move(bound_term)); + return aggregate; + } + case CountMode::kNone: { + if (this->op() != Expression::Operation::kMax && + this->op() != Expression::Operation::kMin) { + return NotSupported("Unsupported aggregate operation"); } - case CountAggregate::Mode::kNull: { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); - if (literal.IsNull()) { - ++count_; - } - return {}; + if (!bound_term) { + return InvalidExpression("Aggregate requires a term"); } + auto aggregate = + std::make_shared(this->op(), std::move(bound_term)); + return aggregate; } - std::unreachable(); } + std::unreachable(); +} - Result ResultLiteral() const override { return Literal::Long(count_); } +template class UnboundAggregateImpl; - private: - CountAggregate::Mode mode_; - std::shared_ptr term_; - int64_t count_ = 0; -}; +// -------------------- AggregateEvaluator -------------------- + +namespace { -class ValueAggregateEvaluator : public AggregateEvaluator { +class AggregateEvaluatorImpl : public AggregateEvaluator { public: - ValueAggregateEvaluator(Expression::Operation op, std::shared_ptr term, - std::shared_ptr type) - : op_(op), term_(std::move(term)), type_(std::move(type)) {} + explicit AggregateEvaluatorImpl(std::vector> aggregates) + : aggregates_(std::move(aggregates)), counts_(aggregates_.size(), 0) { + values_.resize(aggregates_.size()); + } Status Add(const StructLike& row) override { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); - if (literal.IsNull()) { - return {}; - } - - if (!current_) { - current_ = std::move(literal); - return {}; - } - - auto ordering = literal <=> *current_; - if (ordering == std::partial_ordering::unordered) { - return InvalidExpression("Cannot compare literals of type {}", - literal.type()->ToString()); + for (size_t i = 0; i < aggregates_.size(); ++i) { + const auto& agg = aggregates_[i]; + switch (agg->kind()) { + case BoundAggregate::Kind::kCount: { + auto count_agg = internal::checked_pointer_cast(agg); + ICEBERG_ASSIGN_OR_RAISE(auto contribution, count_agg->Evaluate(row)); + counts_[i] += std::get(contribution.value()); + break; + } + case BoundAggregate::Kind::kValue: { + auto value_agg = internal::checked_pointer_cast(agg); + ICEBERG_ASSIGN_OR_RAISE(auto val_literal, value_agg->Evaluate(row)); + if (val_literal.IsNull()) { + break; + } + auto& current = values_[i]; + if (!current) { + current = std::move(val_literal); + break; + } + auto ordering = val_literal <=> *current; + if (ordering == std::partial_ordering::unordered) { + return InvalidExpression("Cannot compare literals of type {}", + val_literal.type()->ToString()); + } + if (agg->op() == Expression::Operation::kMax) { + if (ordering == std::partial_ordering::greater) { + current = std::move(val_literal); + } + } else { + if (ordering == std::partial_ordering::less) { + current = std::move(val_literal); + } + } + break; + } + } } + return {}; + } - if (op_ == Expression::Operation::kMax) { - if (ordering == std::partial_ordering::greater) { - current_ = std::move(literal); - } - } else { - if (ordering == std::partial_ordering::less) { - current_ = std::move(literal); + Result> Results() const override { + std::vector out; + out.reserve(aggregates_.size()); + for (size_t i = 0; i < aggregates_.size(); ++i) { + switch (aggregates_[i]->kind()) { + case BoundAggregate::Kind::kCount: + out.emplace_back(Literal::Long(counts_[i])); + break; + case BoundAggregate::Kind::kValue: { + if (values_[i]) { + out.emplace_back(*values_[i]); + } else { + auto value_agg = + internal::checked_pointer_cast(aggregates_[i]); + ICEBERG_ASSIGN_OR_RAISE(auto type, GetPrimitiveType(*value_agg->term())); + out.emplace_back(Literal::Null(type)); + } + break; + } } } - return {}; + return out; } Result ResultLiteral() const override { - if (!current_) { - return Literal::Null(type_); + if (aggregates_.size() != 1) { + return InvalidArgument( + "ResultLiteral() is only valid when evaluating a single aggregate"); } - return *current_; + + ICEBERG_ASSIGN_OR_RAISE(auto all, Results()); + return all.front(); } private: - Expression::Operation op_; - std::shared_ptr term_; - std::shared_ptr type_; - std::optional current_; + std::vector> aggregates_; + std::vector counts_; + std::vector> values_; }; } // namespace Result> AggregateEvaluator::Make( std::shared_ptr aggregate) { - ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); - - if (auto count = std::dynamic_pointer_cast(aggregate)) { - if (count->mode() != CountAggregate::Mode::kStar && !count->term()) { - return InvalidExpression("Count aggregate requires a term"); - } - return std::unique_ptr( - new CountEvaluator(count->mode(), count->term())); - } + std::vector> aggs; + aggs.push_back(std::move(aggregate)); + return MakeList(std::move(aggs)); +} - if (auto value = std::dynamic_pointer_cast(aggregate)) { - ICEBERG_ASSIGN_OR_RAISE(auto type, GetPrimitiveType(*value->term())); - return std::unique_ptr( - new ValueAggregateEvaluator(value->op(), value->term(), std::move(type))); +Result> AggregateEvaluator::MakeList( + std::vector> aggregates) { + if (aggregates.empty()) { + return InvalidArgument("AggregateEvaluator requires at least one aggregate"); } - - return NotSupported("Unsupported aggregate: {}", aggregate->ToString()); + return std::unique_ptr( + new AggregateEvaluatorImpl(std::move(aggregates))); } } // namespace iceberg diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 2e50b4f84..6a780a4a2 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -22,7 +22,9 @@ /// \file iceberg/expression/aggregate.h /// Aggregate expression definitions. +#include #include +#include #include "iceberg/expression/expression.h" #include "iceberg/expression/term.h" @@ -30,130 +32,107 @@ namespace iceberg { -class AggregateEvaluator; +template +concept AggregateTermType = std::derived_from; -/// \brief Base class for aggregate expressions. -class ICEBERG_EXPORT Aggregate : public Expression { +/// \brief Base aggregate holding an operation and a term. +template +class ICEBERG_EXPORT Aggregate : public virtual Expression { public: ~Aggregate() override = default; Expression::Operation op() const override { return operation_; } + const std::shared_ptr& term() const { return term_; } + protected: - explicit Aggregate(Expression::Operation op) : operation_(op) {} + Aggregate(Expression::Operation op, std::shared_ptr term) + : operation_(op), term_(std::move(term)) {} - private: Expression::Operation operation_; + std::shared_ptr term_; }; -/// \brief Unbound aggregate with an optional term. -class ICEBERG_EXPORT UnboundAggregate : public Aggregate, public Unbound { +/// \brief Base class for unbound aggregates. +class ICEBERG_EXPORT UnboundAggregate : public virtual Expression, + public Unbound { public: ~UnboundAggregate() override = default; bool is_unbound_aggregate() const override { return true; } - - /// \brief Returns the unbound reference if the aggregate has a term. - std::shared_ptr reference() override = 0; - - protected: - explicit UnboundAggregate(Expression::Operation op) : Aggregate(op) {} }; -/// \brief Bound aggregate with an optional term. -class ICEBERG_EXPORT BoundAggregate : public Aggregate { +/// \brief Template for unbound aggregates that carry a term and operation. +template +class ICEBERG_EXPORT UnboundAggregateImpl : public UnboundAggregate, + public Aggregate> { + using BASE = Aggregate>; + public: - ~BoundAggregate() override = default; + enum class CountMode { kNonNull, kNull, kStar, kNone }; - bool is_bound_aggregate() const override { return true; } + UnboundAggregateImpl(Expression::Operation op, std::shared_ptr> term, + CountMode count_mode = CountMode::kNone) + : BASE(op, std::move(term)), count_mode_(count_mode) {} - const std::shared_ptr& term() const { return term_; } + std::shared_ptr reference() override { + return BASE::term() ? BASE::term()->reference() : nullptr; + } - protected: - BoundAggregate(Expression::Operation op, std::shared_ptr term) - : Aggregate(op), term_(std::move(term)) {} + Result> Bind(const Schema& schema, + bool case_sensitive) const override; + + CountMode count_mode() const { return count_mode_; } private: - std::shared_ptr term_; + CountMode count_mode_; }; -/// \brief COUNT aggregate variants. -class ICEBERG_EXPORT CountAggregate : public UnboundAggregate { +/// \brief Base class for bound aggregates. +class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound { public: - enum class Mode { kNonNull, kNull, kStar }; - - static Result> Count( - std::shared_ptr> term); - - static Result> CountNull( - std::shared_ptr> term); - - static std::unique_ptr CountStar(); - - ~CountAggregate() override = default; - - Mode mode() const { return mode_; } + using Aggregate::op; + using Aggregate::term; - const std::shared_ptr>& term() const { return term_; } + std::shared_ptr reference() override { + return term_ ? term_->reference() : nullptr; + } - std::shared_ptr reference() override { return reference_; } + virtual Result Evaluate(const StructLike& data) const override = 0; - std::string ToString() const override; + bool is_bound_aggregate() const override { return true; } - Result> Bind(const Schema& schema, - bool case_sensitive) const override; + enum class Kind : int8_t { + // Count aggregates (COUNT, COUNT_STAR, COUNT_NULL) + kCount = 0, + // Value aggregates (MIN, MAX) + kValue, + }; - private: - CountAggregate(Expression::Operation op, Mode mode, - std::shared_ptr> term, - std::shared_ptr reference); + virtual Kind kind() const = 0; - Mode mode_; - std::shared_ptr> term_; - std::shared_ptr reference_; + protected: + BoundAggregate(Expression::Operation op, std::shared_ptr term) + : Aggregate(op, std::move(term)) {} }; /// \brief Bound COUNT aggregate. class ICEBERG_EXPORT BoundCountAggregate : public BoundAggregate { public: - BoundCountAggregate(Expression::Operation op, CountAggregate::Mode mode, - std::shared_ptr term); - - CountAggregate::Mode mode() const { return mode_; } - - std::string ToString() const override; - - private: - CountAggregate::Mode mode_; -}; - -/// \brief MAX/MIN aggregate on a single term. -class ICEBERG_EXPORT ValueAggregate : public UnboundAggregate { - public: - static Result> Max( - std::shared_ptr> term); - - static Result> Min( - std::shared_ptr> term); + enum class Mode { kNonNull, kNull, kStar }; - ~ValueAggregate() override = default; + BoundCountAggregate(Expression::Operation op, Mode mode, + std::shared_ptr term); - std::shared_ptr reference() override { return reference_; } + Mode mode() const { return mode_; } - const std::shared_ptr>& term() const { return term_; } + Kind kind() const override { return Kind::kCount; } std::string ToString() const override; - - Result> Bind(const Schema& schema, - bool case_sensitive) const override; + Result Evaluate(const StructLike& data) const override; private: - ValueAggregate(Expression::Operation op, - std::shared_ptr> term, - std::shared_ptr reference); - - std::shared_ptr> term_; - std::shared_ptr reference_; + Mode mode_; }; /// \brief Bound MAX/MIN aggregate. @@ -161,7 +140,10 @@ class ICEBERG_EXPORT BoundValueAggregate : public BoundAggregate { public: BoundValueAggregate(Expression::Operation op, std::shared_ptr term); + Kind kind() const override { return Kind::kValue; } + std::string ToString() const override; + Result Evaluate(const StructLike& data) const override; }; /// \brief Evaluates bound aggregates over StructLike rows. @@ -169,14 +151,24 @@ class ICEBERG_EXPORT AggregateEvaluator { public: virtual ~AggregateEvaluator() = default; - /// \brief Create an evaluator for a bound aggregate. + /// \brief Create an evaluator for a single bound aggregate. + /// \param aggregate The bound aggregate to evaluate across rows. static Result> Make( std::shared_ptr aggregate); + /// \brief Create an evaluator for multiple bound aggregates. + /// \param aggregates Aggregates to evaluate in one pass; order is preserved in + /// Results(). + static Result> MakeList( + std::vector> aggregates); + /// \brief Add a row to the aggregate. virtual Status Add(const StructLike& row) = 0; /// \brief Final aggregated value. + virtual Result> Results() const = 0; + + /// \brief Convenience accessor when only one aggregate is evaluated. virtual Result ResultLiteral() const = 0; }; diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index 6f31f837d..4902b7629 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -84,59 +84,72 @@ std::shared_ptr Expressions::Transform( // Aggregates -std::shared_ptr Expressions::Count(std::string name) { +std::shared_ptr> Expressions::Count( + std::string name) { return Count(Ref(std::move(name))); } -std::shared_ptr Expressions::Count( +std::shared_ptr> Expressions::Count( std::shared_ptr> expr) { - ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::Count(std::move(expr))); - return {std::move(agg)}; + auto agg = std::make_shared>( + Expression::Operation::kCount, std::move(expr), + UnboundAggregateImpl::CountMode::kNonNull); + return agg; } -std::shared_ptr Expressions::CountNull(std::string name) { +std::shared_ptr> Expressions::CountNull( + std::string name) { return CountNull(Ref(std::move(name))); } -std::shared_ptr Expressions::CountNull( +std::shared_ptr> Expressions::CountNull( std::shared_ptr> expr) { - ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::CountNull(std::move(expr))); - return {std::move(agg)}; + auto agg = std::make_shared>( + Expression::Operation::kCount, std::move(expr), + UnboundAggregateImpl::CountMode::kNull); + return agg; } -std::shared_ptr Expressions::CountNotNull(std::string name) { +std::shared_ptr> Expressions::CountNotNull( + std::string name) { return CountNotNull(Ref(std::move(name))); } -std::shared_ptr Expressions::CountNotNull( +std::shared_ptr> Expressions::CountNotNull( std::shared_ptr> expr) { - ICEBERG_ASSIGN_OR_THROW(auto agg, CountAggregate::Count(std::move(expr))); - return {std::move(agg)}; + auto agg = std::make_shared>( + Expression::Operation::kCount, std::move(expr), + UnboundAggregateImpl::CountMode::kNonNull); + return agg; } -std::shared_ptr Expressions::CountStar() { - auto agg = CountAggregate::CountStar(); - return {std::move(agg)}; +std::shared_ptr> Expressions::CountStar() { + auto agg = std::make_shared>( + Expression::Operation::kCountStar, nullptr, + UnboundAggregateImpl::CountMode::kStar); + return agg; } -std::shared_ptr Expressions::Max(std::string name) { +std::shared_ptr> Expressions::Max(std::string name) { return Max(Ref(std::move(name))); } -std::shared_ptr Expressions::Max( +std::shared_ptr> Expressions::Max( std::shared_ptr> expr) { - ICEBERG_ASSIGN_OR_THROW(auto agg, ValueAggregate::Max(std::move(expr))); - return {std::move(agg)}; + auto agg = std::make_shared>( + Expression::Operation::kMax, std::move(expr)); + return agg; } -std::shared_ptr Expressions::Min(std::string name) { +std::shared_ptr> Expressions::Min(std::string name) { return Min(Ref(std::move(name))); } -std::shared_ptr Expressions::Min( +std::shared_ptr> Expressions::Min( std::shared_ptr> expr) { - ICEBERG_ASSIGN_OR_THROW(auto agg, ValueAggregate::Min(std::move(expr))); - return {std::move(agg)}; + auto agg = std::make_shared>( + Expression::Operation::kMin, std::move(expr)); + return agg; } // Template implementations for unary predicates diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h index 81ed4c877..a00f6622b 100644 --- a/src/iceberg/expression/expressions.h +++ b/src/iceberg/expression/expressions.h @@ -105,41 +105,43 @@ class ICEBERG_EXPORT Expressions { // Aggregates /// \brief Create COUNT(col) aggregate. - static std::shared_ptr Count(std::string name); + static std::shared_ptr> Count(std::string name); /// \brief Create COUNT(unbound term) aggregate. - static std::shared_ptr Count( + static std::shared_ptr> Count( std::shared_ptr> expr); /// \brief Create COUNT_NULL(col) aggregate. - static std::shared_ptr CountNull(std::string name); + static std::shared_ptr> CountNull( + std::string name); /// \brief Create COUNT_NULL(unbound term) aggregate. - static std::shared_ptr CountNull( + static std::shared_ptr> CountNull( std::shared_ptr> expr); /// \brief Create COUNT_NOT_NULL(col) aggregate. - static std::shared_ptr CountNotNull(std::string name); + static std::shared_ptr> CountNotNull( + std::string name); /// \brief Create COUNT_NOT_NULL(unbound term) aggregate. - static std::shared_ptr CountNotNull( + static std::shared_ptr> CountNotNull( std::shared_ptr> expr); /// \brief Create COUNT(*) aggregate. - static std::shared_ptr CountStar(); + static std::shared_ptr> CountStar(); /// \brief Create MAX(col) aggregate. - static std::shared_ptr Max(std::string name); + static std::shared_ptr> Max(std::string name); /// \brief Create MAX(unbound term) aggregate. - static std::shared_ptr Max( + static std::shared_ptr> Max( std::shared_ptr> expr); /// \brief Create MIN(col) aggregate. - static std::shared_ptr Min(std::string name); + static std::shared_ptr> Min(std::string name); /// \brief Create MIN(unbound term) aggregate. - static std::shared_ptr Min( + static std::shared_ptr> Min( std::shared_ptr> expr); // Unary predicates diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc index 23f57621a..3a9e7d6dc 100644 --- a/src/iceberg/test/aggregate_test.cc +++ b/src/iceberg/test/aggregate_test.cc @@ -129,4 +129,32 @@ TEST(AggregateTest, MaxMinAggregates) { EXPECT_EQ(std::get(min_result.value()), 2); } +TEST(AggregateTest, MultipleAggregatesInEvaluator) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_expr = Expressions::Count("id"); + auto max_expr = Expressions::Max("value"); + + auto count_bound = BindAggregate(schema, count_expr); + auto max_bound = BindAggregate(schema, max_expr); + + std::vector> aggregates{count_bound, max_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::MakeList(aggregates)); + + std::vector rows{ + VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), + VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}})}; + + for (const auto& row : rows) { + ASSERT_TRUE(evaluator->Add(row).has_value()); + } + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->Results()); + ASSERT_EQ(results.size(), 2); + EXPECT_EQ(std::get(results[0].value()), 2); + EXPECT_EQ(std::get(results[1].value()), 30); +} + } // namespace iceberg From 8625335d58941ca6210e8a84d2396530f1bfc128 Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Fri, 21 Nov 2025 02:28:54 -0800 Subject: [PATCH 06/12] style: remove redundant virtual on override method --- src/iceberg/expression/aggregate.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 6a780a4a2..51b88c898 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -98,7 +98,7 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound return term_ ? term_->reference() : nullptr; } - virtual Result Evaluate(const StructLike& data) const override = 0; + Result Evaluate(const StructLike& data) const override = 0; bool is_bound_aggregate() const override { return true; } From 31f50f1bb96e60077e24c7fa28bb054396eb426a Mon Sep 17 00:00:00 2001 From: Zhiyuan Liang <115799793+SuKi2cn@users.noreply.github.com> Date: Mon, 24 Nov 2025 07:48:43 -0800 Subject: [PATCH 07/12] Update src/iceberg/expression/aggregate.cc Co-authored-by: Gang Wu --- src/iceberg/expression/aggregate.cc | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index 69e2fa56b..eb351380e 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -35,12 +35,11 @@ namespace iceberg { namespace { Result> GetPrimitiveType(const BoundTerm& term) { - auto primitive = std::dynamic_pointer_cast(term.type()); - if (primitive == nullptr) { + if (!term.type().is_primitive()) { return InvalidExpression("Aggregate requires primitive type, got {}", term.type()->ToString()); } - return primitive; + return internal::checked_pointer_cast(term.type()); } } // namespace From ab0e9de798f5c6d5d8eea6c55b9d72ac55acbb7a Mon Sep 17 00:00:00 2001 From: Zhiyuan Liang <115799793+SuKi2cn@users.noreply.github.com> Date: Mon, 24 Nov 2025 07:50:15 -0800 Subject: [PATCH 08/12] Update src/iceberg/expression/aggregate.h Co-authored-by: Gang Wu --- src/iceberg/expression/aggregate.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 51b88c898..6c80439b1 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -117,7 +117,7 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound }; /// \brief Bound COUNT aggregate. -class ICEBERG_EXPORT BoundCountAggregate : public BoundAggregate { +class ICEBERG_EXPORT CountAggregate : public BoundAggregate { public: enum class Mode { kNonNull, kNull, kStar }; From 63f557af2413d3ea1a6577e3f85b2afe2da669e4 Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Mon, 24 Nov 2025 09:05:08 -0800 Subject: [PATCH 09/12] Refactor aggregate implementation to match Java design This PR refactors the C++ aggregate implementation to better align with the Java Iceberg design and existing Predicate patterns. The changes introduce a dedicated Aggregator hierarchy, simplify COUNT handling by splitting it into distinct classes, and improve evaluator extensibility for future input types (e.g. DataFile). --- src/iceberg/expression/aggregate.cc | 377 ++++++++++++++------ src/iceberg/expression/aggregate.h | 101 ++++-- src/iceberg/expression/expression.cc | 3 + src/iceberg/expression/expression.h | 1 + src/iceberg/expression/expression_visitor.h | 9 +- src/iceberg/expression/expressions.cc | 29 +- src/iceberg/expression/expressions.h | 22 +- src/iceberg/expression/predicate.h | 4 - src/iceberg/expression/term.h | 4 + src/iceberg/test/aggregate_test.cc | 12 +- 10 files changed, 381 insertions(+), 181 deletions(-) diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index eb351380e..35efc9d3b 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -42,101 +42,287 @@ Result> GetPrimitiveType(const BoundTerm& term) { return internal::checked_pointer_cast(term.type()); } +class CountNonNullAggregator : public BoundAggregate::Aggregator { + public: + explicit CountNonNullAggregator(std::shared_ptr term) + : term_(std::move(term)) {} + + Status Update(const StructLike& row) override { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); + if (!literal.IsNull()) { + ++count_; + } + return Status::OK(); + } + + Result ResultLiteral() const override { return Literal::Long(count_); } + + private: + std::shared_ptr term_; + int64_t count_ = 0; +}; + +class CountNullAggregator : public BoundAggregate::Aggregator { + public: + explicit CountNullAggregator(std::shared_ptr term) + : term_(std::move(term)) {} + + Status Update(const StructLike& row) override { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); + if (literal.IsNull()) { + ++count_; + } + return Status::OK(); + } + + Result ResultLiteral() const override { return Literal::Long(count_); } + + private: + std::shared_ptr term_; + int64_t count_ = 0; +}; + +class CountStarAggregator : public BoundAggregate::Aggregator { + public: + Status Update(const StructLike& /*row*/) override { + ++count_; + return Status::OK(); + } + + Result ResultLiteral() const override { return Literal::Long(count_); } + + private: + int64_t count_ = 0; +}; + +class ValueAggregatorImpl : public BoundAggregate::Aggregator { + public: + ValueAggregatorImpl(bool is_max, std::shared_ptr term) + : is_max_(is_max), term_(std::move(term)) {} + + Status Update(const StructLike& row) override { + ICEBERG_ASSIGN_OR_RAISE(auto val_literal, term_->Evaluate(row)); + if (val_literal.IsNull()) { + return Status::OK(); + } + if (!current_) { + current_ = std::move(val_literal); + return Status::OK(); + } + + auto ordering = val_literal <=> *current_; + if (ordering == std::partial_ordering::unordered) { + return InvalidExpression("Cannot compare literals of type {}", + val_literal.type()->ToString()); + } + + if (is_max_) { + if (ordering == std::partial_ordering::greater) { + current_ = std::move(val_literal); + } + } else { + if (ordering == std::partial_ordering::less) { + current_ = std::move(val_literal); + } + } + return Status::OK(); + } + + Result ResultLiteral() const override { + if (current_) { + return *current_; + } + ICEBERG_ASSIGN_OR_RAISE(auto type, GetPrimitiveType(*term_)); + return Literal::Null(type); + } + + private: + bool is_max_; + std::shared_ptr term_; + std::optional current_; +}; + } // namespace // -------------------- Bound aggregates -------------------- -BoundCountAggregate::BoundCountAggregate(Expression::Operation op, Mode mode, - std::shared_ptr term) - : BoundAggregate(op, std::move(term)), mode_(mode) {} +CountNonNullAggregate::CountNonNullAggregate(std::shared_ptr term) + : CountAggregate(Expression::Operation::kCount, std::move(term)) {} -std::string BoundCountAggregate::ToString() const { - if (mode_ == Mode::kStar) { - return "count(*)"; +Result> CountNonNullAggregate::Make( + std::shared_ptr term) { + if (!term) { + return InvalidExpression("Bound count aggregate requires non-null term"); } + return std::shared_ptr( + new CountNonNullAggregate(std::move(term))); +} + +std::string CountNonNullAggregate::ToString() const { ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); - switch (mode_) { - case Mode::kNull: - return std::format("count_null({})", term()->reference()->name()); - case Mode::kNonNull: - return std::format("count({})", term()->reference()->name()); - case Mode::kStar: - break; - } - std::unreachable(); -} - -Result BoundCountAggregate::Evaluate(const StructLike& data) const { - switch (mode_) { - case Mode::kStar: - return Literal::Long(1); - case Mode::kNonNull: { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); - return Literal::Long(literal.IsNull() ? 0 : 1); - } - case Mode::kNull: { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); - return Literal::Long(literal.IsNull() ? 1 : 0); - } + return std::format("count({})", term()->reference()->name()); +} + +Result CountNonNullAggregate::Evaluate(const StructLike& data) const { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); + return Literal::Long(literal.IsNull() ? 0 : 1); +} + +Result> CountNonNullAggregate::NewAggregator() + const { + return std::unique_ptr(new CountNonNullAggregator(term())); +} + +CountNullAggregate::CountNullAggregate(std::shared_ptr term) + : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {} + +Result> CountNullAggregate::Make( + std::shared_ptr term) { + if (!term) { + return InvalidExpression("Bound count aggregate requires non-null term"); } - std::unreachable(); + return std::shared_ptr(new CountNullAggregate(std::move(term))); +} + +std::string CountNullAggregate::ToString() const { + ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); + return std::format("count_null({})", term()->reference()->name()); +} + +Result CountNullAggregate::Evaluate(const StructLike& data) const { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); + return Literal::Long(literal.IsNull() ? 1 : 0); +} + +Result> CountNullAggregate::NewAggregator() + const { + return std::unique_ptr(new CountNullAggregator(term())); } -BoundValueAggregate::BoundValueAggregate(Expression::Operation op, - std::shared_ptr term) +CountStarAggregate::CountStarAggregate() + : CountAggregate(Expression::Operation::kCountStar, nullptr) {} + +Result> CountStarAggregate::Make() { + return std::shared_ptr(new CountStarAggregate()); +} + +std::string CountStarAggregate::ToString() const { return "count(*)"; } + +Result CountStarAggregate::Evaluate(const StructLike& data) const { + return Literal::Long(1); +} + +Result> CountStarAggregate::NewAggregator() + const { + return std::unique_ptr(new CountStarAggregator()); +} + +ValueAggregate::ValueAggregate(Expression::Operation op, std::shared_ptr term) : BoundAggregate(op, std::move(term)) {} -std::string BoundValueAggregate::ToString() const { +std::string ValueAggregate::ToString() const { ICEBERG_DCHECK(term() != nullptr, "Bound value aggregate should have term"); auto prefix = op() == Expression::Operation::kMax ? "max" : "min"; return std::format("{}({})", prefix, term()->reference()->name()); } -Result BoundValueAggregate::Evaluate(const StructLike& data) const { +Result ValueAggregate::Evaluate(const StructLike& data) const { ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); return literal; } +Result> ValueAggregate::NewAggregator() + const { + bool is_max = op() == Expression::Operation::kMax; + return std::unique_ptr( + new ValueAggregatorImpl(is_max, term())); +} + // -------------------- Unbound binding -------------------- template Result> UnboundAggregateImpl::Bind( const Schema& schema, bool case_sensitive) const { + ICEBERG_DCHECK(UnboundAggregateImpl::IsSupportedOp(this->op()), + "Unexpected aggregate operation"); + std::shared_ptr bound_term; if (this->term()) { ICEBERG_ASSIGN_OR_THROW(bound_term, this->term()->Bind(schema, case_sensitive)); } - switch (count_mode_) { - case CountMode::kStar: - case CountMode::kNull: - case CountMode::kNonNull: { - auto op = this->op() == Expression::Operation::kCountStar - ? Expression::Operation::kCountStar - : Expression::Operation::kCount; - auto mode = - count_mode_ == CountMode::kNull - ? BoundCountAggregate::Mode::kNull - : (count_mode_ == CountMode::kStar ? BoundCountAggregate::Mode::kStar - : BoundCountAggregate::Mode::kNonNull); - auto aggregate = - std::make_shared(op, mode, std::move(bound_term)); + switch (this->op()) { + case Expression::Operation::kCountStar: { + ICEBERG_ASSIGN_OR_THROW(auto aggregate, CountStarAggregate::Make()); return aggregate; } - case CountMode::kNone: { - if (this->op() != Expression::Operation::kMax && - this->op() != Expression::Operation::kMin) { - return NotSupported("Unsupported aggregate operation"); + case Expression::Operation::kCount: { + if (!bound_term) { + return InvalidExpression("Aggregate requires a term"); } + ICEBERG_ASSIGN_OR_THROW(auto aggregate, + CountNonNullAggregate::Make(std::move(bound_term))); + return aggregate; + } + case Expression::Operation::kCountNull: { + if (!bound_term) { + return InvalidExpression("Aggregate requires a term"); + } + ICEBERG_ASSIGN_OR_THROW(auto aggregate, + CountNullAggregate::Make(std::move(bound_term))); + return aggregate; + } + case Expression::Operation::kMax: + case Expression::Operation::kMin: { if (!bound_term) { return InvalidExpression("Aggregate requires a term"); } auto aggregate = - std::make_shared(this->op(), std::move(bound_term)); + std::make_shared(this->op(), std::move(bound_term)); return aggregate; } + default: + return NotSupported("Unsupported aggregate operation"); + } +} + +template +Result>> UnboundAggregateImpl::Make( + Expression::Operation op, std::shared_ptr> term) { + if (!IsSupportedOp(op)) { + return NotSupported("Unsupported aggregate operation: {}", ::iceberg::ToString(op)); + } + if (op != Expression::Operation::kCountStar && !term) { + return InvalidExpression("Aggregate term cannot be null unless COUNT(*)"); + } + + return std::shared_ptr>( + new UnboundAggregateImpl(op, std::move(term))); +} + +template +std::string UnboundAggregateImpl::ToString() const { + ICEBERG_DCHECK(UnboundAggregateImpl::IsSupportedOp(this->op()), + "Unexpected aggregate operation"); + ICEBERG_DCHECK( + this->op() == Expression::Operation::kCountStar || this->term() != nullptr, + "Aggregate term should not be null unless COUNT(*)"); + + auto term_str = this->term() ? this->term()->ToString() : std::string{}; + switch (this->op()) { + case Expression::Operation::kCount: + return std::format("count({})", term_str); + case Expression::Operation::kCountNull: + return std::format("count_if({} is null)", term_str); + case Expression::Operation::kCountStar: + return "count(*)"; + case Expression::Operation::kMax: + return std::format("max({})", term_str); + case Expression::Operation::kMin: + return std::format("min({})", term_str); + default: + return "Aggregate"; } - std::unreachable(); } template class UnboundAggregateImpl; @@ -147,49 +333,14 @@ namespace { class AggregateEvaluatorImpl : public AggregateEvaluator { public: - explicit AggregateEvaluatorImpl(std::vector> aggregates) - : aggregates_(std::move(aggregates)), counts_(aggregates_.size(), 0) { - values_.resize(aggregates_.size()); - } - - Status Add(const StructLike& row) override { - for (size_t i = 0; i < aggregates_.size(); ++i) { - const auto& agg = aggregates_[i]; - switch (agg->kind()) { - case BoundAggregate::Kind::kCount: { - auto count_agg = internal::checked_pointer_cast(agg); - ICEBERG_ASSIGN_OR_RAISE(auto contribution, count_agg->Evaluate(row)); - counts_[i] += std::get(contribution.value()); - break; - } - case BoundAggregate::Kind::kValue: { - auto value_agg = internal::checked_pointer_cast(agg); - ICEBERG_ASSIGN_OR_RAISE(auto val_literal, value_agg->Evaluate(row)); - if (val_literal.IsNull()) { - break; - } - auto& current = values_[i]; - if (!current) { - current = std::move(val_literal); - break; - } - auto ordering = val_literal <=> *current; - if (ordering == std::partial_ordering::unordered) { - return InvalidExpression("Cannot compare literals of type {}", - val_literal.type()->ToString()); - } - if (agg->op() == Expression::Operation::kMax) { - if (ordering == std::partial_ordering::greater) { - current = std::move(val_literal); - } - } else { - if (ordering == std::partial_ordering::less) { - current = std::move(val_literal); - } - } - break; - } - } + AggregateEvaluatorImpl( + std::vector> aggregates, + std::vector> aggregators) + : aggregates_(std::move(aggregates)), aggregators_(std::move(aggregators)) {} + + Status Update(const StructLike& row) override { + for (auto& aggregator : aggregators_) { + ICEBERG_RETURN_NOT_OK(aggregator->Update(row)); } return {}; } @@ -197,23 +348,9 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { Result> Results() const override { std::vector out; out.reserve(aggregates_.size()); - for (size_t i = 0; i < aggregates_.size(); ++i) { - switch (aggregates_[i]->kind()) { - case BoundAggregate::Kind::kCount: - out.emplace_back(Literal::Long(counts_[i])); - break; - case BoundAggregate::Kind::kValue: { - if (values_[i]) { - out.emplace_back(*values_[i]); - } else { - auto value_agg = - internal::checked_pointer_cast(aggregates_[i]); - ICEBERG_ASSIGN_OR_RAISE(auto type, GetPrimitiveType(*value_agg->term())); - out.emplace_back(Literal::Null(type)); - } - break; - } - } + for (const auto& aggregator : aggregators_) { + ICEBERG_ASSIGN_OR_RAISE(auto literal, aggregator->ResultLiteral()); + out.emplace_back(std::move(literal)); } return out; } @@ -230,8 +367,7 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { private: std::vector> aggregates_; - std::vector counts_; - std::vector> values_; + std::vector> aggregators_; }; } // namespace @@ -248,8 +384,15 @@ Result> AggregateEvaluator::MakeList( if (aggregates.empty()) { return InvalidArgument("AggregateEvaluator requires at least one aggregate"); } + std::vector> aggregators; + aggregators.reserve(aggregates.size()); + for (const auto& agg : aggregates) { + ICEBERG_ASSIGN_OR_RAISE(auto aggregator, agg->NewAggregator()); + aggregators.push_back(std::move(aggregator)); + } + return std::unique_ptr( - new AggregateEvaluatorImpl(std::move(aggregates))); + new AggregateEvaluatorImpl(std::move(aggregates), std::move(aggregators))); } } // namespace iceberg diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 6c80439b1..21a201a00 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -22,21 +22,20 @@ /// \file iceberg/expression/aggregate.h /// Aggregate expression definitions. -#include #include +#include +#include #include #include "iceberg/expression/expression.h" #include "iceberg/expression/term.h" #include "iceberg/result.h" +#include "iceberg/type_fwd.h" namespace iceberg { -template -concept AggregateTermType = std::derived_from; - /// \brief Base aggregate holding an operation and a term. -template +template class ICEBERG_EXPORT Aggregate : public virtual Expression { public: ~Aggregate() override = default; @@ -69,11 +68,8 @@ class ICEBERG_EXPORT UnboundAggregateImpl : public UnboundAggregate, using BASE = Aggregate>; public: - enum class CountMode { kNonNull, kNull, kStar, kNone }; - - UnboundAggregateImpl(Expression::Operation op, std::shared_ptr> term, - CountMode count_mode = CountMode::kNone) - : BASE(op, std::move(term)), count_mode_(count_mode) {} + static Result>> Make( + Expression::Operation op, std::shared_ptr> term); std::shared_ptr reference() override { return BASE::term() ? BASE::term()->reference() : nullptr; @@ -82,10 +78,22 @@ class ICEBERG_EXPORT UnboundAggregateImpl : public UnboundAggregate, Result> Bind(const Schema& schema, bool case_sensitive) const override; - CountMode count_mode() const { return count_mode_; } + std::string ToString() const override; private: - CountMode count_mode_; + static constexpr bool IsSupportedOp(Expression::Operation op) { + return op == Expression::Operation::kCount || + op == Expression::Operation::kCountNull || + op == Expression::Operation::kCountStar || op == Expression::Operation::kMax || + op == Expression::Operation::kMin; + } + + UnboundAggregateImpl(Expression::Operation op, std::shared_ptr> term) + : BASE(op, std::move(term)) { + ICEBERG_DCHECK(IsSupportedOp(op), "Unexpected aggregate operation"); + ICEBERG_DCHECK(op == Expression::Operation::kCountStar || BASE::term() != nullptr, + "Aggregate term cannot be null unless COUNT(*)"); + } }; /// \brief Base class for bound aggregates. @@ -94,7 +102,20 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound using Aggregate::op; using Aggregate::term; + class ICEBERG_EXPORT Aggregator { + public: + virtual ~Aggregator() = default; + + virtual Status Update(const StructLike& row) = 0; + virtual Status Update(const DataFile& file) { + return NotSupported("Aggregating DataFile not supported"); + } + virtual Result ResultLiteral() const = 0; + }; + std::shared_ptr reference() override { + ICEBERG_DCHECK(term_ != nullptr || op() == Expression::Operation::kCountStar, + "Bound aggregate term should not be null unless COUNT(*)"); return term_ ? term_->reference() : nullptr; } @@ -110,40 +131,74 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound }; virtual Kind kind() const = 0; + virtual Result> NewAggregator() const = 0; protected: BoundAggregate(Expression::Operation op, std::shared_ptr term) : Aggregate(op, std::move(term)) {} }; -/// \brief Bound COUNT aggregate. +/// \brief Base class for COUNT aggregates. class ICEBERG_EXPORT CountAggregate : public BoundAggregate { public: - enum class Mode { kNonNull, kNull, kStar }; + Kind kind() const override { return Kind::kCount; } - BoundCountAggregate(Expression::Operation op, Mode mode, - std::shared_ptr term); + protected: + CountAggregate(Expression::Operation op, std::shared_ptr term) + : BoundAggregate(op, std::move(term)) {} +}; - Mode mode() const { return mode_; } +/// \brief COUNT(term) aggregate. +class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { + public: + static Result> Make( + std::shared_ptr term); - Kind kind() const override { return Kind::kCount; } + std::string ToString() const override; + Result Evaluate(const StructLike& data) const override; + Result> NewAggregator() const override; + + private: + explicit CountNonNullAggregate(std::shared_ptr term); +}; + +/// \brief COUNT_NULL(term) aggregate. +class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { + public: + static Result> Make( + std::shared_ptr term); + + std::string ToString() const override; + Result Evaluate(const StructLike& data) const override; + Result> NewAggregator() const override; + + private: + explicit CountNullAggregate(std::shared_ptr term); +}; + +/// \brief COUNT(*) aggregate. +class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { + public: + static Result> Make(); std::string ToString() const override; Result Evaluate(const StructLike& data) const override; + Result> NewAggregator() const override; private: - Mode mode_; + CountStarAggregate(); }; /// \brief Bound MAX/MIN aggregate. -class ICEBERG_EXPORT BoundValueAggregate : public BoundAggregate { +class ICEBERG_EXPORT ValueAggregate : public BoundAggregate { public: - BoundValueAggregate(Expression::Operation op, std::shared_ptr term); + ValueAggregate(Expression::Operation op, std::shared_ptr term); Kind kind() const override { return Kind::kValue; } std::string ToString() const override; Result Evaluate(const StructLike& data) const override; + Result> NewAggregator() const override; }; /// \brief Evaluates bound aggregates over StructLike rows. @@ -162,8 +217,8 @@ class ICEBERG_EXPORT AggregateEvaluator { static Result> MakeList( std::vector> aggregates); - /// \brief Add a row to the aggregate. - virtual Status Add(const StructLike& row) = 0; + /// \brief Update aggregates with a row. + virtual Status Update(const StructLike& row) = 0; /// \brief Final aggregated value. virtual Result> Results() const = 0; diff --git a/src/iceberg/expression/expression.cc b/src/iceberg/expression/expression.cc index b40082bb4..9ee5dfc17 100644 --- a/src/iceberg/expression/expression.cc +++ b/src/iceberg/expression/expression.cc @@ -191,6 +191,8 @@ std::string_view ToString(Expression::Operation op) { return "NOT_STARTS_WITH"; case Expression::Operation::kCount: return "COUNT"; + case Expression::Operation::kCountNull: + return "COUNT_NULL"; case Expression::Operation::kNot: return "NOT"; case Expression::Operation::kCountStar: @@ -246,6 +248,7 @@ Result Negate(Expression::Operation op) { case Expression::Operation::kMax: case Expression::Operation::kMin: case Expression::Operation::kCount: + case Expression::Operation::kCountNull: return InvalidExpression("No negation for operation: {}", op); } std::unreachable(); diff --git a/src/iceberg/expression/expression.h b/src/iceberg/expression/expression.h index e6e19ea83..58b3fcc1a 100644 --- a/src/iceberg/expression/expression.h +++ b/src/iceberg/expression/expression.h @@ -57,6 +57,7 @@ class ICEBERG_EXPORT Expression : public util::Formattable { kStartsWith, kNotStartsWith, kCount, + kCountNull, kCountStar, kMax, kMin diff --git a/src/iceberg/expression/expression_visitor.h b/src/iceberg/expression/expression_visitor.h index edc99196d..c54da9324 100644 --- a/src/iceberg/expression/expression_visitor.h +++ b/src/iceberg/expression/expression_visitor.h @@ -24,6 +24,7 @@ #include #include +#include #include "iceberg/expression/aggregate.h" #include "iceberg/expression/expression.h" @@ -83,16 +84,16 @@ class ICEBERG_EXPORT ExpressionVisitor { /// \param aggregate The bound aggregate to visit. virtual Result Aggregate(const std::shared_ptr& aggregate) { ICEBERG_DCHECK(aggregate != nullptr, "Bound aggregate cannot be null"); - return NotSupported("Bound aggregate is not supported by this visitor: {}", - aggregate->ToString()); + return NotSupported("Visitor {} does not support bound aggregate: {}", + typeid(*this).name(), aggregate->ToString()); } /// \brief Visit an unbound aggregate. /// \param aggregate The unbound aggregate to visit. virtual Result Aggregate(const std::shared_ptr& aggregate) { ICEBERG_DCHECK(aggregate != nullptr, "Unbound aggregate cannot be null"); - return NotSupported("Unbound aggregate is not supported by this visitor: {}", - aggregate->ToString()); + return NotSupported("Visitor {} does not support unbound aggregate: {}", + typeid(*this).name(), aggregate->ToString()); } }; diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index 4902b7629..786cc0ab7 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -91,9 +91,8 @@ std::shared_ptr> Expressions::Count( std::shared_ptr> Expressions::Count( std::shared_ptr> expr) { - auto agg = std::make_shared>( - Expression::Operation::kCount, std::move(expr), - UnboundAggregateImpl::CountMode::kNonNull); + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kCount, std::move(expr))); return agg; } @@ -104,9 +103,9 @@ std::shared_ptr> Expressions::CountNull( std::shared_ptr> Expressions::CountNull( std::shared_ptr> expr) { - auto agg = std::make_shared>( - Expression::Operation::kCount, std::move(expr), - UnboundAggregateImpl::CountMode::kNull); + ICEBERG_ASSIGN_OR_THROW(auto agg, + UnboundAggregateImpl::Make( + Expression::Operation::kCountNull, std::move(expr))); return agg; } @@ -117,16 +116,14 @@ std::shared_ptr> Expressions::CountNotNull( std::shared_ptr> Expressions::CountNotNull( std::shared_ptr> expr) { - auto agg = std::make_shared>( - Expression::Operation::kCount, std::move(expr), - UnboundAggregateImpl::CountMode::kNonNull); + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kCount, std::move(expr))); return agg; } std::shared_ptr> Expressions::CountStar() { - auto agg = std::make_shared>( - Expression::Operation::kCountStar, nullptr, - UnboundAggregateImpl::CountMode::kStar); + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kCountStar, nullptr)); return agg; } @@ -136,8 +133,8 @@ std::shared_ptr> Expressions::Max(std::stri std::shared_ptr> Expressions::Max( std::shared_ptr> expr) { - auto agg = std::make_shared>( - Expression::Operation::kMax, std::move(expr)); + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMax, std::move(expr))); return agg; } @@ -147,8 +144,8 @@ std::shared_ptr> Expressions::Min(std::stri std::shared_ptr> Expressions::Min( std::shared_ptr> expr) { - auto agg = std::make_shared>( - Expression::Operation::kMin, std::move(expr)); + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMin, std::move(expr))); return agg; } diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h index a00f6622b..cb1d6df7e 100644 --- a/src/iceberg/expression/expressions.h +++ b/src/iceberg/expression/expressions.h @@ -104,43 +104,43 @@ class ICEBERG_EXPORT Expressions { // Aggregates - /// \brief Create COUNT(col) aggregate. + /// \brief Create a COUNT aggregate for a field name. static std::shared_ptr> Count(std::string name); - /// \brief Create COUNT(unbound term) aggregate. + /// \brief Create a COUNT aggregate for an unbound term. static std::shared_ptr> Count( std::shared_ptr> expr); - /// \brief Create COUNT_NULL(col) aggregate. + /// \brief Create a COUNT_NULL aggregate for a field name. static std::shared_ptr> CountNull( std::string name); - /// \brief Create COUNT_NULL(unbound term) aggregate. + /// \brief Create a COUNT_NULL aggregate for an unbound term. static std::shared_ptr> CountNull( std::shared_ptr> expr); - /// \brief Create COUNT_NOT_NULL(col) aggregate. + /// \brief Create a COUNT_NOT_NULL aggregate for a field name. static std::shared_ptr> CountNotNull( std::string name); - /// \brief Create COUNT_NOT_NULL(unbound term) aggregate. + /// \brief Create a COUNT_NOT_NULL aggregate for an unbound term. static std::shared_ptr> CountNotNull( std::shared_ptr> expr); - /// \brief Create COUNT(*) aggregate. + /// \brief Create a COUNT(*) aggregate. static std::shared_ptr> CountStar(); - /// \brief Create MAX(col) aggregate. + /// \brief Create a MAX aggregate for a field name. static std::shared_ptr> Max(std::string name); - /// \brief Create MAX(unbound term) aggregate. + /// \brief Create a MAX aggregate for an unbound term. static std::shared_ptr> Max( std::shared_ptr> expr); - /// \brief Create MIN(col) aggregate. + /// \brief Create a MIN aggregate for a field name. static std::shared_ptr> Min(std::string name); - /// \brief Create MIN(unbound term) aggregate. + /// \brief Create a MIN aggregate for an unbound term. static std::shared_ptr> Min( std::shared_ptr> expr); diff --git a/src/iceberg/expression/predicate.h b/src/iceberg/expression/predicate.h index 29393766b..c2e09e667 100644 --- a/src/iceberg/expression/predicate.h +++ b/src/iceberg/expression/predicate.h @@ -22,7 +22,6 @@ /// \file iceberg/expression/predicate.h /// Predicate interface for boolean expressions that test terms. -#include #include #include "iceberg/expression/expression.h" @@ -31,9 +30,6 @@ namespace iceberg { -template -concept TermType = std::derived_from; - /// \brief A predicate is a boolean expression that tests a term against some criteria. /// /// \tparam TermType The type of the term being tested diff --git a/src/iceberg/expression/term.h b/src/iceberg/expression/term.h index e2a378feb..c19b81324 100644 --- a/src/iceberg/expression/term.h +++ b/src/iceberg/expression/term.h @@ -22,6 +22,7 @@ /// \file iceberg/expression/term.h /// Term interface for Iceberg expressions - represents values that can be evaluated. +#include #include #include #include @@ -41,6 +42,9 @@ class ICEBERG_EXPORT Term : public util::Formattable { virtual Kind kind() const = 0; }; +template +concept TermType = std::derived_from; + /// \brief Interface for unbound expressions that need schema binding. /// /// Unbound expressions contain string-based references that must be resolved diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc index 3a9e7d6dc..139524129 100644 --- a/src/iceberg/test/aggregate_test.cc +++ b/src/iceberg/test/aggregate_test.cc @@ -85,9 +85,9 @@ TEST(AggregateTest, CountVariants) { VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}})}; for (const auto& row : rows) { - ASSERT_TRUE(count_evaluator->Add(row).has_value()); - ASSERT_TRUE(count_null_evaluator->Add(row).has_value()); - ASSERT_TRUE(count_star_evaluator->Add(row).has_value()); + ASSERT_TRUE(count_evaluator->Update(row).has_value()); + ASSERT_TRUE(count_null_evaluator->Update(row).has_value()); + ASSERT_TRUE(count_star_evaluator->Update(row).has_value()); } ICEBERG_UNWRAP_OR_FAIL(auto count_result, count_evaluator->ResultLiteral()); @@ -118,8 +118,8 @@ TEST(AggregateTest, MaxMinAggregates) { VectorStructLike({Scalar{int32_t{12}}})}; for (const auto& row : rows) { - ASSERT_TRUE(max_eval->Add(row).has_value()); - ASSERT_TRUE(min_eval->Add(row).has_value()); + ASSERT_TRUE(max_eval->Update(row).has_value()); + ASSERT_TRUE(min_eval->Update(row).has_value()); } ICEBERG_UNWRAP_OR_FAIL(auto max_result, max_eval->ResultLiteral()); @@ -148,7 +148,7 @@ TEST(AggregateTest, MultipleAggregatesInEvaluator) { VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}})}; for (const auto& row : rows) { - ASSERT_TRUE(evaluator->Add(row).has_value()); + ASSERT_TRUE(evaluator->Update(row).has_value()); } ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->Results()); From 032c7666aa87765f9d5d0e7ba01d3bd1c73d3d5b Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Mon, 24 Nov 2025 09:14:07 -0800 Subject: [PATCH 10/12] Fix aggregate status handling --- src/iceberg/expression/aggregate.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index 35efc9d3b..b51f33f8a 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -35,7 +35,7 @@ namespace iceberg { namespace { Result> GetPrimitiveType(const BoundTerm& term) { - if (!term.type().is_primitive()) { + if (!term.type()->is_primitive()) { return InvalidExpression("Aggregate requires primitive type, got {}", term.type()->ToString()); } @@ -52,7 +52,7 @@ class CountNonNullAggregator : public BoundAggregate::Aggregator { if (!literal.IsNull()) { ++count_; } - return Status::OK(); + return {}; } Result ResultLiteral() const override { return Literal::Long(count_); } @@ -72,7 +72,7 @@ class CountNullAggregator : public BoundAggregate::Aggregator { if (literal.IsNull()) { ++count_; } - return Status::OK(); + return {}; } Result ResultLiteral() const override { return Literal::Long(count_); } @@ -86,7 +86,7 @@ class CountStarAggregator : public BoundAggregate::Aggregator { public: Status Update(const StructLike& /*row*/) override { ++count_; - return Status::OK(); + return {}; } Result ResultLiteral() const override { return Literal::Long(count_); } @@ -103,11 +103,11 @@ class ValueAggregatorImpl : public BoundAggregate::Aggregator { Status Update(const StructLike& row) override { ICEBERG_ASSIGN_OR_RAISE(auto val_literal, term_->Evaluate(row)); if (val_literal.IsNull()) { - return Status::OK(); + return {}; } if (!current_) { current_ = std::move(val_literal); - return Status::OK(); + return {}; } auto ordering = val_literal <=> *current_; @@ -125,7 +125,7 @@ class ValueAggregatorImpl : public BoundAggregate::Aggregator { current_ = std::move(val_literal); } } - return Status::OK(); + return {}; } Result ResultLiteral() const override { @@ -340,7 +340,7 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { Status Update(const StructLike& row) override { for (auto& aggregator : aggregators_) { - ICEBERG_RETURN_NOT_OK(aggregator->Update(row)); + ICEBERG_RETURN_UNEXPECTED(aggregator->Update(row)); } return {}; } From 04a1a12c72d6b20c6bb83e28700a24f3ca30f39a Mon Sep 17 00:00:00 2001 From: Zhiyuan Date: Wed, 26 Nov 2025 01:34:51 -0800 Subject: [PATCH 11/12] refactor: aggregate framework and improve evaluator design This change refactors the aggregate framework to better match the Java implementation and improves API clarity, performance, and test coverage. --- src/iceberg/expression/aggregate.cc | 233 ++++++++++++-------- src/iceberg/expression/aggregate.h | 61 +++-- src/iceberg/expression/expression_visitor.h | 8 +- src/iceberg/test/aggregate_test.cc | 104 ++++++++- 4 files changed, 268 insertions(+), 138 deletions(-) diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index b51f33f8a..d1ac7a9f5 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -34,11 +34,8 @@ namespace iceberg { namespace { -Result> GetPrimitiveType(const BoundTerm& term) { - if (!term.type()->is_primitive()) { - return InvalidExpression("Aggregate requires primitive type, got {}", - term.type()->ToString()); - } +std::shared_ptr GetPrimitiveType(const BoundTerm& term) { + ICEBERG_DCHECK(term.type()->is_primitive(), "Value aggregate term should be primitive"); return internal::checked_pointer_cast(term.type()); } @@ -55,7 +52,7 @@ class CountNonNullAggregator : public BoundAggregate::Aggregator { return {}; } - Result ResultLiteral() const override { return Literal::Long(count_); } + Literal GetResult() const override { return Literal::Long(count_); } private: std::shared_ptr term_; @@ -75,7 +72,7 @@ class CountNullAggregator : public BoundAggregate::Aggregator { return {}; } - Result ResultLiteral() const override { return Literal::Long(count_); } + Literal GetResult() const override { return Literal::Long(count_); } private: std::shared_ptr term_; @@ -89,16 +86,15 @@ class CountStarAggregator : public BoundAggregate::Aggregator { return {}; } - Result ResultLiteral() const override { return Literal::Long(count_); } + Literal GetResult() const override { return Literal::Long(count_); } private: int64_t count_ = 0; }; -class ValueAggregatorImpl : public BoundAggregate::Aggregator { +class MaxAggregator : public BoundAggregate::Aggregator { public: - ValueAggregatorImpl(bool is_max, std::shared_ptr term) - : is_max_(is_max), term_(std::move(term)) {} + explicit MaxAggregator(std::shared_ptr term) : term_(std::move(term)) {} Status Update(const StructLike& row) override { ICEBERG_ASSIGN_OR_RAISE(auto val_literal, term_->Evaluate(row)); @@ -116,28 +112,52 @@ class ValueAggregatorImpl : public BoundAggregate::Aggregator { val_literal.type()->ToString()); } - if (is_max_) { - if (ordering == std::partial_ordering::greater) { - current_ = std::move(val_literal); - } - } else { - if (ordering == std::partial_ordering::less) { - current_ = std::move(val_literal); - } + if (ordering == std::partial_ordering::greater) { + current_ = std::move(val_literal); } return {}; } - Result ResultLiteral() const override { - if (current_) { - return *current_; + Literal GetResult() const override { + return current_.value_or(Literal::Null(GetPrimitiveType(*term_))); + } + + private: + std::shared_ptr term_; + std::optional current_; +}; + +class MinAggregator : public BoundAggregate::Aggregator { + public: + explicit MinAggregator(std::shared_ptr term) : term_(std::move(term)) {} + + Status Update(const StructLike& row) override { + ICEBERG_ASSIGN_OR_RAISE(auto val_literal, term_->Evaluate(row)); + if (val_literal.IsNull()) { + return {}; + } + if (!current_) { + current_ = std::move(val_literal); + return {}; + } + + auto ordering = val_literal <=> *current_; + if (ordering == std::partial_ordering::unordered) { + return InvalidExpression("Cannot compare literals of type {}", + val_literal.type()->ToString()); } - ICEBERG_ASSIGN_OR_RAISE(auto type, GetPrimitiveType(*term_)); - return Literal::Null(type); + + if (ordering == std::partial_ordering::less) { + current_ = std::move(val_literal); + } + return {}; + } + + Literal GetResult() const override { + return current_.value_or(Literal::Null(GetPrimitiveType(*term_))); } private: - bool is_max_; std::shared_ptr term_; std::optional current_; }; @@ -146,47 +166,59 @@ class ValueAggregatorImpl : public BoundAggregate::Aggregator { // -------------------- Bound aggregates -------------------- +std::string BoundAggregate::ToString() const { + std::string term_name; + if (op() != Expression::Operation::kCountStar) { + ICEBERG_DCHECK(term() != nullptr, "Bound aggregate should have term unless COUNT(*)"); + term_name = term()->reference()->name(); + } + + switch (op()) { + case Expression::Operation::kCount: + return std::format("count({})", term_name); + case Expression::Operation::kCountNull: + return std::format("count_null({})", term_name); + case Expression::Operation::kCountStar: + return "count(*)"; + case Expression::Operation::kMax: + return std::format("max({})", term_name); + case Expression::Operation::kMin: + return std::format("min({})", term_name); + default: + return "Aggregate"; + } +} + CountNonNullAggregate::CountNonNullAggregate(std::shared_ptr term) : CountAggregate(Expression::Operation::kCount, std::move(term)) {} -Result> CountNonNullAggregate::Make( +Result> CountNonNullAggregate::Make( std::shared_ptr term) { if (!term) { return InvalidExpression("Bound count aggregate requires non-null term"); } - return std::shared_ptr( + return std::unique_ptr( new CountNonNullAggregate(std::move(term))); } -std::string CountNonNullAggregate::ToString() const { - ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); - return std::format("count({})", term()->reference()->name()); -} - Result CountNonNullAggregate::Evaluate(const StructLike& data) const { ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); return Literal::Long(literal.IsNull() ? 0 : 1); } -Result> CountNonNullAggregate::NewAggregator() - const { +std::unique_ptr CountNonNullAggregate::NewAggregator() const { return std::unique_ptr(new CountNonNullAggregator(term())); } CountNullAggregate::CountNullAggregate(std::shared_ptr term) : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {} -Result> CountNullAggregate::Make( +Result> CountNullAggregate::Make( std::shared_ptr term) { if (!term) { return InvalidExpression("Bound count aggregate requires non-null term"); } - return std::shared_ptr(new CountNullAggregate(std::move(term))); -} - -std::string CountNullAggregate::ToString() const { - ICEBERG_DCHECK(term() != nullptr, "Bound count aggregate should have term"); - return std::format("count_null({})", term()->reference()->name()); + return std::unique_ptr(new CountNullAggregate(std::move(term))); } Result CountNullAggregate::Evaluate(const StructLike& data) const { @@ -194,48 +226,55 @@ Result CountNullAggregate::Evaluate(const StructLike& data) const { return Literal::Long(literal.IsNull() ? 1 : 0); } -Result> CountNullAggregate::NewAggregator() - const { +std::unique_ptr CountNullAggregate::NewAggregator() const { return std::unique_ptr(new CountNullAggregator(term())); } CountStarAggregate::CountStarAggregate() : CountAggregate(Expression::Operation::kCountStar, nullptr) {} -Result> CountStarAggregate::Make() { - return std::shared_ptr(new CountStarAggregate()); +Result> CountStarAggregate::Make() { + return std::unique_ptr(new CountStarAggregate()); } -std::string CountStarAggregate::ToString() const { return "count(*)"; } - Result CountStarAggregate::Evaluate(const StructLike& data) const { return Literal::Long(1); } -Result> CountStarAggregate::NewAggregator() - const { +std::unique_ptr CountStarAggregate::NewAggregator() const { return std::unique_ptr(new CountStarAggregator()); } -ValueAggregate::ValueAggregate(Expression::Operation op, std::shared_ptr term) - : BoundAggregate(op, std::move(term)) {} +MaxAggregate::MaxAggregate(std::shared_ptr term) + : BoundAggregate(Expression::Operation::kMax, std::move(term)) {} -std::string ValueAggregate::ToString() const { - ICEBERG_DCHECK(term() != nullptr, "Bound value aggregate should have term"); - auto prefix = op() == Expression::Operation::kMax ? "max" : "min"; - return std::format("{}({})", prefix, term()->reference()->name()); +std::shared_ptr MaxAggregate::Make(std::shared_ptr term) { + return std::shared_ptr(new MaxAggregate(std::move(term))); } -Result ValueAggregate::Evaluate(const StructLike& data) const { +Result MaxAggregate::Evaluate(const StructLike& data) const { ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); return literal; } -Result> ValueAggregate::NewAggregator() - const { - bool is_max = op() == Expression::Operation::kMax; - return std::unique_ptr( - new ValueAggregatorImpl(is_max, term())); +std::unique_ptr MaxAggregate::NewAggregator() const { + return std::unique_ptr(new MaxAggregator(term())); +} + +MinAggregate::MinAggregate(std::shared_ptr term) + : BoundAggregate(Expression::Operation::kMin, std::move(term)) {} + +std::shared_ptr MinAggregate::Make(std::shared_ptr term) { + return std::shared_ptr(new MinAggregate(std::move(term))); +} + +Result MinAggregate::Evaluate(const StructLike& data) const { + ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); + return literal; +} + +std::unique_ptr MinAggregate::NewAggregator() const { + return std::unique_ptr(new MinAggregator(term())); } // -------------------- Unbound binding -------------------- @@ -248,38 +287,37 @@ Result> UnboundAggregateImpl::Bind( std::shared_ptr bound_term; if (this->term()) { - ICEBERG_ASSIGN_OR_THROW(bound_term, this->term()->Bind(schema, case_sensitive)); + ICEBERG_ASSIGN_OR_RAISE(bound_term, this->term()->Bind(schema, case_sensitive)); } switch (this->op()) { - case Expression::Operation::kCountStar: { - ICEBERG_ASSIGN_OR_THROW(auto aggregate, CountStarAggregate::Make()); - return aggregate; - } - case Expression::Operation::kCount: { - if (!bound_term) { - return InvalidExpression("Aggregate requires a term"); - } - ICEBERG_ASSIGN_OR_THROW(auto aggregate, - CountNonNullAggregate::Make(std::move(bound_term))); - return aggregate; - } - case Expression::Operation::kCountNull: { - if (!bound_term) { - return InvalidExpression("Aggregate requires a term"); - } - ICEBERG_ASSIGN_OR_THROW(auto aggregate, - CountNullAggregate::Make(std::move(bound_term))); - return aggregate; - } + case Expression::Operation::kCountStar: + return CountStarAggregate::Make().transform([](auto aggregate) { + return std::shared_ptr(std::move(aggregate)); + }); + case Expression::Operation::kCount: + return CountNonNullAggregate::Make(std::move(bound_term)) + .transform([](auto aggregate) { + return std::shared_ptr(std::move(aggregate)); + }); + case Expression::Operation::kCountNull: + return CountNullAggregate::Make(std::move(bound_term)) + .transform([](auto aggregate) { + return std::shared_ptr(std::move(aggregate)); + }); case Expression::Operation::kMax: case Expression::Operation::kMin: { if (!bound_term) { return InvalidExpression("Aggregate requires a term"); } - auto aggregate = - std::make_shared(this->op(), std::move(bound_term)); - return aggregate; + if (!bound_term->type()->is_primitive()) { + return InvalidExpression("Aggregate requires primitive type, got {}", + bound_term->type()->ToString()); + } + if (this->op() == Expression::Operation::kMax) { + return MaxAggregate::Make(std::move(bound_term)); + } + return MinAggregate::Make(std::move(bound_term)); } default: return NotSupported("Unsupported aggregate operation"); @@ -345,29 +383,29 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { return {}; } - Result> Results() const override { - std::vector out; - out.reserve(aggregates_.size()); + Result> GetResults() const override { + results_.clear(); + results_.reserve(aggregates_.size()); for (const auto& aggregator : aggregators_) { - ICEBERG_ASSIGN_OR_RAISE(auto literal, aggregator->ResultLiteral()); - out.emplace_back(std::move(literal)); + results_.emplace_back(aggregator->GetResult()); } - return out; + return std::span(results_); } - Result ResultLiteral() const override { + Result GetResult() const override { if (aggregates_.size() != 1) { return InvalidArgument( - "ResultLiteral() is only valid when evaluating a single aggregate"); + "GetResult() is only valid when evaluating a single aggregate"); } - ICEBERG_ASSIGN_OR_RAISE(auto all, Results()); + ICEBERG_ASSIGN_OR_RAISE(auto all, GetResults()); return all.front(); } private: std::vector> aggregates_; std::vector> aggregators_; + mutable std::vector results_; }; } // namespace @@ -376,10 +414,10 @@ Result> AggregateEvaluator::Make( std::shared_ptr aggregate) { std::vector> aggs; aggs.push_back(std::move(aggregate)); - return MakeList(std::move(aggs)); + return Make(std::move(aggs)); } -Result> AggregateEvaluator::MakeList( +Result> AggregateEvaluator::Make( std::vector> aggregates) { if (aggregates.empty()) { return InvalidArgument("AggregateEvaluator requires at least one aggregate"); @@ -387,8 +425,7 @@ Result> AggregateEvaluator::MakeList( std::vector> aggregators; aggregators.reserve(aggregates.size()); for (const auto& agg : aggregates) { - ICEBERG_ASSIGN_OR_RAISE(auto aggregator, agg->NewAggregator()); - aggregators.push_back(std::move(aggregator)); + aggregators.push_back(agg->NewAggregator()); } return std::unique_ptr( diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index 21a201a00..e5dadcbd8 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -24,6 +24,7 @@ #include #include +#include #include #include @@ -92,7 +93,7 @@ class ICEBERG_EXPORT UnboundAggregateImpl : public UnboundAggregate, : BASE(op, std::move(term)) { ICEBERG_DCHECK(IsSupportedOp(op), "Unexpected aggregate operation"); ICEBERG_DCHECK(op == Expression::Operation::kCountStar || BASE::term() != nullptr, - "Aggregate term cannot be null unless COUNT(*)"); + "Aggregate term cannot be null except for COUNT(*)"); } }; @@ -102,7 +103,7 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound using Aggregate::op; using Aggregate::term; - class ICEBERG_EXPORT Aggregator { + class Aggregator { public: virtual ~Aggregator() = default; @@ -110,28 +111,27 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound virtual Status Update(const DataFile& file) { return NotSupported("Aggregating DataFile not supported"); } - virtual Result ResultLiteral() const = 0; + virtual Literal GetResult() const = 0; }; std::shared_ptr reference() override { ICEBERG_DCHECK(term_ != nullptr || op() == Expression::Operation::kCountStar, - "Bound aggregate term should not be null unless COUNT(*)"); + "Bound aggregate term should not be null except for COUNT(*)"); return term_ ? term_->reference() : nullptr; } + std::string ToString() const override; Result Evaluate(const StructLike& data) const override = 0; bool is_bound_aggregate() const override { return true; } enum class Kind : int8_t { - // Count aggregates (COUNT, COUNT_STAR, COUNT_NULL) kCount = 0, - // Value aggregates (MIN, MAX) kValue, }; virtual Kind kind() const = 0; - virtual Result> NewAggregator() const = 0; + virtual std::unique_ptr NewAggregator() const = 0; protected: BoundAggregate(Expression::Operation op, std::shared_ptr term) @@ -151,12 +151,11 @@ class ICEBERG_EXPORT CountAggregate : public BoundAggregate { /// \brief COUNT(term) aggregate. class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { public: - static Result> Make( + static Result> Make( std::shared_ptr term); - std::string ToString() const override; Result Evaluate(const StructLike& data) const override; - Result> NewAggregator() const override; + std::unique_ptr NewAggregator() const override; private: explicit CountNonNullAggregate(std::shared_ptr term); @@ -165,12 +164,11 @@ class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { /// \brief COUNT_NULL(term) aggregate. class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { public: - static Result> Make( + static Result> Make( std::shared_ptr term); - std::string ToString() const override; Result Evaluate(const StructLike& data) const override; - Result> NewAggregator() const override; + std::unique_ptr NewAggregator() const override; private: explicit CountNullAggregate(std::shared_ptr term); @@ -179,26 +177,41 @@ class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { /// \brief COUNT(*) aggregate. class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { public: - static Result> Make(); + static Result> Make(); - std::string ToString() const override; Result Evaluate(const StructLike& data) const override; - Result> NewAggregator() const override; + std::unique_ptr NewAggregator() const override; private: CountStarAggregate(); }; -/// \brief Bound MAX/MIN aggregate. -class ICEBERG_EXPORT ValueAggregate : public BoundAggregate { +/// \brief Bound MAX aggregate. +class ICEBERG_EXPORT MaxAggregate : public BoundAggregate { public: - ValueAggregate(Expression::Operation op, std::shared_ptr term); + static std::shared_ptr Make(std::shared_ptr term); Kind kind() const override { return Kind::kValue; } - std::string ToString() const override; Result Evaluate(const StructLike& data) const override; - Result> NewAggregator() const override; + std::unique_ptr NewAggregator() const override; + + private: + explicit MaxAggregate(std::shared_ptr term); +}; + +/// \brief Bound MIN aggregate. +class ICEBERG_EXPORT MinAggregate : public BoundAggregate { + public: + static std::shared_ptr Make(std::shared_ptr term); + + Kind kind() const override { return Kind::kValue; } + + Result Evaluate(const StructLike& data) const override; + std::unique_ptr NewAggregator() const override; + + private: + explicit MinAggregate(std::shared_ptr term); }; /// \brief Evaluates bound aggregates over StructLike rows. @@ -214,17 +227,17 @@ class ICEBERG_EXPORT AggregateEvaluator { /// \brief Create an evaluator for multiple bound aggregates. /// \param aggregates Aggregates to evaluate in one pass; order is preserved in /// Results(). - static Result> MakeList( + static Result> Make( std::vector> aggregates); /// \brief Update aggregates with a row. virtual Status Update(const StructLike& row) = 0; /// \brief Final aggregated value. - virtual Result> Results() const = 0; + virtual Result> GetResults() const = 0; /// \brief Convenience accessor when only one aggregate is evaluated. - virtual Result ResultLiteral() const = 0; + virtual Result GetResult() const = 0; }; } // namespace iceberg diff --git a/src/iceberg/expression/expression_visitor.h b/src/iceberg/expression/expression_visitor.h index c54da9324..ed1e75d8e 100644 --- a/src/iceberg/expression/expression_visitor.h +++ b/src/iceberg/expression/expression_visitor.h @@ -84,16 +84,16 @@ class ICEBERG_EXPORT ExpressionVisitor { /// \param aggregate The bound aggregate to visit. virtual Result Aggregate(const std::shared_ptr& aggregate) { ICEBERG_DCHECK(aggregate != nullptr, "Bound aggregate cannot be null"); - return NotSupported("Visitor {} does not support bound aggregate: {}", - typeid(*this).name(), aggregate->ToString()); + return NotSupported("Visitor {} does not support bound aggregate", + typeid(*this).name()); } /// \brief Visit an unbound aggregate. /// \param aggregate The unbound aggregate to visit. virtual Result Aggregate(const std::shared_ptr& aggregate) { ICEBERG_DCHECK(aggregate != nullptr, "Unbound aggregate cannot be null"); - return NotSupported("Visitor {} does not support unbound aggregate: {}", - typeid(*this).name(), aggregate->ToString()); + return NotSupported("Visitor {} does not support unbound aggregate", + typeid(*this).name()); } }; diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc index 139524129..adbdc984d 100644 --- a/src/iceberg/test/aggregate_test.cc +++ b/src/iceberg/test/aggregate_test.cc @@ -90,13 +90,13 @@ TEST(AggregateTest, CountVariants) { ASSERT_TRUE(count_star_evaluator->Update(row).has_value()); } - ICEBERG_UNWRAP_OR_FAIL(auto count_result, count_evaluator->ResultLiteral()); + ICEBERG_UNWRAP_OR_FAIL(auto count_result, count_evaluator->GetResult()); EXPECT_EQ(std::get(count_result.value()), 2); - ICEBERG_UNWRAP_OR_FAIL(auto count_null_result, count_null_evaluator->ResultLiteral()); + ICEBERG_UNWRAP_OR_FAIL(auto count_null_result, count_null_evaluator->GetResult()); EXPECT_EQ(std::get(count_null_result.value()), 1); - ICEBERG_UNWRAP_OR_FAIL(auto count_star_result, count_star_evaluator->ResultLiteral()); + ICEBERG_UNWRAP_OR_FAIL(auto count_star_result, count_star_evaluator->GetResult()); EXPECT_EQ(std::get(count_star_result.value()), 3); } @@ -122,39 +122,119 @@ TEST(AggregateTest, MaxMinAggregates) { ASSERT_TRUE(min_eval->Update(row).has_value()); } - ICEBERG_UNWRAP_OR_FAIL(auto max_result, max_eval->ResultLiteral()); + ICEBERG_UNWRAP_OR_FAIL(auto max_result, max_eval->GetResult()); EXPECT_EQ(std::get(max_result.value()), 12); - ICEBERG_UNWRAP_OR_FAIL(auto min_result, min_eval->ResultLiteral()); + ICEBERG_UNWRAP_OR_FAIL(auto min_result, min_eval->GetResult()); EXPECT_EQ(std::get(min_result.value()), 2); } +TEST(AggregateTest, UnboundAggregateCreationAndBinding) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count = Expressions::Count("id"); + auto count_null = Expressions::CountNull("id"); + auto count_star = Expressions::CountStar(); + auto max = Expressions::Max("value"); + auto min = Expressions::Min("value"); + + EXPECT_EQ(count->ToString(), "count(ref(name=\"id\"))"); + EXPECT_EQ(count_null->ToString(), "count_if(ref(name=\"id\") is null)"); + EXPECT_EQ(count_star->ToString(), "count(*)"); + EXPECT_EQ(max->ToString(), "max(ref(name=\"value\"))"); + EXPECT_EQ(min->ToString(), "min(ref(name=\"value\"))"); + + // Bind succeeds for existing columns + EXPECT_TRUE(Binder::Bind(schema, count, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, count_null, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, count_star, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, max, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, min, /*case_sensitive=*/true).has_value()); + + // Binding fails when the reference is missing + auto missing_count = Expressions::Count("missing"); + auto missing_bind = Binder::Bind(schema, missing_count, /*case_sensitive=*/true); + EXPECT_THAT(missing_bind, IsError(ErrorKind::kInvalidExpression)); + + // Creating a value aggregate with null term should fail + auto invalid_unbound = + UnboundAggregateImpl::Make(Expression::Operation::kMax, nullptr); + EXPECT_THAT(invalid_unbound, IsError(ErrorKind::kInvalidExpression)); +} + +TEST(AggregateTest, BoundAggregateEvaluateDirectly) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_bound = BindAggregate(schema, Expressions::Count("id")); + auto count_null_bound = BindAggregate(schema, Expressions::CountNull("value")); + auto max_bound = BindAggregate(schema, Expressions::Max("value")); + + std::vector rows{ + VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), + VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}}), + VectorStructLike({Scalar{int32_t{3}}, Scalar{int32_t{2}}})}; + + int64_t count_sum = 0; + int64_t count_null_sum = 0; + int32_t max_val = std::numeric_limits::min(); + for (const auto& row : rows) { + ICEBERG_UNWRAP_OR_FAIL(auto c, count_bound->Evaluate(row)); + count_sum += std::get(c.value()); + + ICEBERG_UNWRAP_OR_FAIL(auto cn, count_null_bound->Evaluate(row)); + count_null_sum += std::get(cn.value()); + + ICEBERG_UNWRAP_OR_FAIL(auto mv, max_bound->Evaluate(row)); + if (!mv.IsNull()) { + max_val = std::max(max_val, std::get(mv.value())); + } + } + + EXPECT_EQ(count_sum, 3); + EXPECT_EQ(count_null_sum, 1); + EXPECT_EQ(max_val, 30); +} + TEST(AggregateTest, MultipleAggregatesInEvaluator) { Schema schema({SchemaField::MakeOptional(1, "id", int32()), SchemaField::MakeOptional(2, "value", int32())}); auto count_expr = Expressions::Count("id"); auto max_expr = Expressions::Max("value"); + auto min_expr = Expressions::Min("value"); + auto count_null_expr = Expressions::CountNull("value"); + auto count_star_expr = Expressions::CountStar(); auto count_bound = BindAggregate(schema, count_expr); auto max_bound = BindAggregate(schema, max_expr); + auto min_bound = BindAggregate(schema, min_expr); + auto count_null_bound = BindAggregate(schema, count_null_expr); + auto count_star_bound = BindAggregate(schema, count_star_expr); - std::vector> aggregates{count_bound, max_bound}; - ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::MakeList(aggregates)); + std::vector> aggregates{ + count_bound, max_bound, min_bound, count_null_bound, count_star_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, AggregateEvaluator::Make(aggregates)); std::vector rows{ VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), - VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}})}; + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}}), + VectorStructLike({Scalar{int32_t{3}}, Scalar{int32_t{2}}})}; for (const auto& row : rows) { ASSERT_TRUE(evaluator->Update(row).has_value()); } - ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->Results()); - ASSERT_EQ(results.size(), 2); - EXPECT_EQ(std::get(results[0].value()), 2); - EXPECT_EQ(std::get(results[1].value()), 30); + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), 5); + EXPECT_EQ(std::get(results[0].value()), 3); // count + EXPECT_EQ(std::get(results[1].value()), 30); // max + EXPECT_EQ(std::get(results[2].value()), 2); // min + EXPECT_EQ(std::get(results[3].value()), 1); // count_null + EXPECT_EQ(std::get(results[4].value()), 4); // count_star } } // namespace iceberg From 94853eeeb1aad8820d6ac631b2302932e4fe7bdc Mon Sep 17 00:00:00 2001 From: Gang Wu Date: Fri, 28 Nov 2025 10:52:54 +0800 Subject: [PATCH 12/12] polish --- src/iceberg/expression/aggregate.cc | 255 +++++++++------------------- src/iceberg/expression/aggregate.h | 71 ++++---- src/iceberg/test/aggregate_test.cc | 3 +- 3 files changed, 121 insertions(+), 208 deletions(-) diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc index d1ac7a9f5..a9c1a60bf 100644 --- a/src/iceberg/expression/aggregate.cc +++ b/src/iceberg/expression/aggregate.cc @@ -23,8 +23,7 @@ #include #include -#include "iceberg/exception.h" -#include "iceberg/expression/binder.h" +#include "iceberg/expression/literal.h" #include "iceberg/row/struct_like.h" #include "iceberg/type.h" #include "iceberg/util/checked_cast.h" @@ -39,156 +38,124 @@ std::shared_ptr GetPrimitiveType(const BoundTerm& term) { return internal::checked_pointer_cast(term.type()); } -class CountNonNullAggregator : public BoundAggregate::Aggregator { +class CountAggregator : public BoundAggregate::Aggregator { public: - explicit CountNonNullAggregator(std::shared_ptr term) - : term_(std::move(term)) {} + explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {} Status Update(const StructLike& row) override { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); - if (!literal.IsNull()) { - ++count_; - } - return {}; - } - - Literal GetResult() const override { return Literal::Long(count_); } - - private: - std::shared_ptr term_; - int64_t count_ = 0; -}; - -class CountNullAggregator : public BoundAggregate::Aggregator { - public: - explicit CountNullAggregator(std::shared_ptr term) - : term_(std::move(term)) {} - - Status Update(const StructLike& row) override { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term_->Evaluate(row)); - if (literal.IsNull()) { - ++count_; - } - return {}; - } - - Literal GetResult() const override { return Literal::Long(count_); } - - private: - std::shared_ptr term_; - int64_t count_ = 0; -}; - -class CountStarAggregator : public BoundAggregate::Aggregator { - public: - Status Update(const StructLike& /*row*/) override { - ++count_; + ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(row)); + count_ += count; return {}; } Literal GetResult() const override { return Literal::Long(count_); } private: + const CountAggregate& aggregate_; int64_t count_ = 0; }; class MaxAggregator : public BoundAggregate::Aggregator { public: - explicit MaxAggregator(std::shared_ptr term) : term_(std::move(term)) {} + explicit MaxAggregator(const MaxAggregate& aggregate) + : aggregate_(aggregate), + current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {} - Status Update(const StructLike& row) override { - ICEBERG_ASSIGN_OR_RAISE(auto val_literal, term_->Evaluate(row)); - if (val_literal.IsNull()) { + Status Update(const StructLike& data) override { + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data)); + if (value.IsNull()) { return {}; } - if (!current_) { - current_ = std::move(val_literal); + if (current_.IsNull()) { + current_ = std::move(value); return {}; } - auto ordering = val_literal <=> *current_; - if (ordering == std::partial_ordering::unordered) { - return InvalidExpression("Cannot compare literals of type {}", - val_literal.type()->ToString()); + if (auto ordering = value <=> current_; + ordering == std::partial_ordering::unordered) { + return InvalidArgument("Cannot compare literal {} with current value {}", + value.ToString(), current_.ToString()); + } else if (ordering == std::partial_ordering::greater) { + current_ = std::move(value); } - if (ordering == std::partial_ordering::greater) { - current_ = std::move(val_literal); - } return {}; } - Literal GetResult() const override { - return current_.value_or(Literal::Null(GetPrimitiveType(*term_))); - } + Literal GetResult() const override { return current_; } private: - std::shared_ptr term_; - std::optional current_; + const MaxAggregate& aggregate_; + Literal current_; }; class MinAggregator : public BoundAggregate::Aggregator { public: - explicit MinAggregator(std::shared_ptr term) : term_(std::move(term)) {} + explicit MinAggregator(const MinAggregate& aggregate) + : aggregate_(aggregate), + current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {} - Status Update(const StructLike& row) override { - ICEBERG_ASSIGN_OR_RAISE(auto val_literal, term_->Evaluate(row)); - if (val_literal.IsNull()) { + Status Update(const StructLike& data) override { + ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data)); + if (value.IsNull()) { return {}; } - if (!current_) { - current_ = std::move(val_literal); + if (current_.IsNull()) { + current_ = std::move(value); return {}; } - auto ordering = val_literal <=> *current_; - if (ordering == std::partial_ordering::unordered) { - return InvalidExpression("Cannot compare literals of type {}", - val_literal.type()->ToString()); - } - - if (ordering == std::partial_ordering::less) { - current_ = std::move(val_literal); + if (auto ordering = value <=> current_; + ordering == std::partial_ordering::unordered) { + return InvalidArgument("Cannot compare literal {} with current value {}", + value.ToString(), current_.ToString()); + } else if (ordering == std::partial_ordering::less) { + current_ = std::move(value); } return {}; } - Literal GetResult() const override { - return current_.value_or(Literal::Null(GetPrimitiveType(*term_))); - } + Literal GetResult() const override { return current_; } private: - std::shared_ptr term_; - std::optional current_; + const MinAggregate& aggregate_; + Literal current_; }; } // namespace -// -------------------- Bound aggregates -------------------- - -std::string BoundAggregate::ToString() const { - std::string term_name; - if (op() != Expression::Operation::kCountStar) { - ICEBERG_DCHECK(term() != nullptr, "Bound aggregate should have term unless COUNT(*)"); - term_name = term()->reference()->name(); - } +template +std::string Aggregate::ToString() const { + ICEBERG_DCHECK(IsSupportedOp(op()), "Unexpected aggregate operation"); + ICEBERG_DCHECK(op() == Expression::Operation::kCountStar || term() != nullptr, + "Aggregate term should not be null except for COUNT(*)"); switch (op()) { case Expression::Operation::kCount: - return std::format("count({})", term_name); + return std::format("count({})", term()->ToString()); case Expression::Operation::kCountNull: - return std::format("count_null({})", term_name); + return std::format("count_if({} is null)", term()->ToString()); case Expression::Operation::kCountStar: return "count(*)"; case Expression::Operation::kMax: - return std::format("max({})", term_name); + return std::format("max({})", term()->ToString()); case Expression::Operation::kMin: - return std::format("min({})", term_name); + return std::format("min({})", term()->ToString()); default: - return "Aggregate"; + return std::format("Invalid aggregate: {}", ::iceberg::ToString(op())); } } +// -------------------- CountAggregate -------------------- + +Result CountAggregate::Evaluate(const StructLike& data) const { + return CountFor(data).transform([](int64_t count) { return Literal::Long(count); }); +} + +std::unique_ptr CountAggregate::NewAggregator() const { + return std::unique_ptr(new CountAggregator(*this)); +} + CountNonNullAggregate::CountNonNullAggregate(std::shared_ptr term) : CountAggregate(Expression::Operation::kCount, std::move(term)) {} @@ -201,13 +168,9 @@ Result> CountNonNullAggregate::Make( new CountNonNullAggregate(std::move(term))); } -Result CountNonNullAggregate::Evaluate(const StructLike& data) const { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); - return Literal::Long(literal.IsNull() ? 0 : 1); -} - -std::unique_ptr CountNonNullAggregate::NewAggregator() const { - return std::unique_ptr(new CountNonNullAggregator(term())); +Result CountNonNullAggregate::CountFor(const StructLike& data) const { + return term()->Evaluate(data).transform( + [](const auto& val) { return val.IsNull() ? 0 : 1; }); } CountNullAggregate::CountNullAggregate(std::shared_ptr term) @@ -221,13 +184,9 @@ Result> CountNullAggregate::Make( return std::unique_ptr(new CountNullAggregate(std::move(term))); } -Result CountNullAggregate::Evaluate(const StructLike& data) const { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); - return Literal::Long(literal.IsNull() ? 1 : 0); -} - -std::unique_ptr CountNullAggregate::NewAggregator() const { - return std::unique_ptr(new CountNullAggregator(term())); +Result CountNullAggregate::CountFor(const StructLike& data) const { + return term()->Evaluate(data).transform( + [](const auto& val) { return val.IsNull() ? 1 : 0; }); } CountStarAggregate::CountStarAggregate() @@ -237,12 +196,8 @@ Result> CountStarAggregate::Make() { return std::unique_ptr(new CountStarAggregate()); } -Result CountStarAggregate::Evaluate(const StructLike& data) const { - return Literal::Long(1); -} - -std::unique_ptr CountStarAggregate::NewAggregator() const { - return std::unique_ptr(new CountStarAggregator()); +Result CountStarAggregate::CountFor(const StructLike& /*data*/) const { + return 1; } MaxAggregate::MaxAggregate(std::shared_ptr term) @@ -253,12 +208,11 @@ std::shared_ptr MaxAggregate::Make(std::shared_ptr term } Result MaxAggregate::Evaluate(const StructLike& data) const { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); - return literal; + return term()->Evaluate(data); } std::unique_ptr MaxAggregate::NewAggregator() const { - return std::unique_ptr(new MaxAggregator(term())); + return std::unique_ptr(new MaxAggregator(*this)); } MinAggregate::MinAggregate(std::shared_ptr term) @@ -269,12 +223,11 @@ std::shared_ptr MinAggregate::Make(std::shared_ptr term } Result MinAggregate::Evaluate(const StructLike& data) const { - ICEBERG_ASSIGN_OR_RAISE(auto literal, term()->Evaluate(data)); - return literal; + return term()->Evaluate(data); } std::unique_ptr MinAggregate::NewAggregator() const { - return std::unique_ptr(new MinAggregator(term())); + return std::unique_ptr(new MinAggregator(*this)); } // -------------------- Unbound binding -------------------- @@ -292,42 +245,25 @@ Result> UnboundAggregateImpl::Bind( switch (this->op()) { case Expression::Operation::kCountStar: - return CountStarAggregate::Make().transform([](auto aggregate) { - return std::shared_ptr(std::move(aggregate)); - }); + return CountStarAggregate::Make(); case Expression::Operation::kCount: - return CountNonNullAggregate::Make(std::move(bound_term)) - .transform([](auto aggregate) { - return std::shared_ptr(std::move(aggregate)); - }); + return CountNonNullAggregate::Make(std::move(bound_term)); case Expression::Operation::kCountNull: - return CountNullAggregate::Make(std::move(bound_term)) - .transform([](auto aggregate) { - return std::shared_ptr(std::move(aggregate)); - }); + return CountNullAggregate::Make(std::move(bound_term)); case Expression::Operation::kMax: - case Expression::Operation::kMin: { - if (!bound_term) { - return InvalidExpression("Aggregate requires a term"); - } - if (!bound_term->type()->is_primitive()) { - return InvalidExpression("Aggregate requires primitive type, got {}", - bound_term->type()->ToString()); - } - if (this->op() == Expression::Operation::kMax) { - return MaxAggregate::Make(std::move(bound_term)); - } + return MaxAggregate::Make(std::move(bound_term)); + case Expression::Operation::kMin: return MinAggregate::Make(std::move(bound_term)); - } default: - return NotSupported("Unsupported aggregate operation"); + return NotSupported("Unsupported aggregate operation: {}", + ::iceberg::ToString(this->op())); } } template Result>> UnboundAggregateImpl::Make( Expression::Operation op, std::shared_ptr> term) { - if (!IsSupportedOp(op)) { + if (!Aggregate>::IsSupportedOp(op)) { return NotSupported("Unsupported aggregate operation: {}", ::iceberg::ToString(op)); } if (op != Expression::Operation::kCountStar && !term) { @@ -338,31 +274,8 @@ Result>> UnboundAggregateImpl::Make( new UnboundAggregateImpl(op, std::move(term))); } -template -std::string UnboundAggregateImpl::ToString() const { - ICEBERG_DCHECK(UnboundAggregateImpl::IsSupportedOp(this->op()), - "Unexpected aggregate operation"); - ICEBERG_DCHECK( - this->op() == Expression::Operation::kCountStar || this->term() != nullptr, - "Aggregate term should not be null unless COUNT(*)"); - - auto term_str = this->term() ? this->term()->ToString() : std::string{}; - switch (this->op()) { - case Expression::Operation::kCount: - return std::format("count({})", term_str); - case Expression::Operation::kCountNull: - return std::format("count_if({} is null)", term_str); - case Expression::Operation::kCountStar: - return "count(*)"; - case Expression::Operation::kMax: - return std::format("max({})", term_str); - case Expression::Operation::kMin: - return std::format("min({})", term_str); - default: - return "Aggregate"; - } -} - +template class Aggregate>; +template class Aggregate; template class UnboundAggregateImpl; // -------------------- AggregateEvaluator -------------------- @@ -376,9 +289,9 @@ class AggregateEvaluatorImpl : public AggregateEvaluator { std::vector> aggregators) : aggregates_(std::move(aggregates)), aggregators_(std::move(aggregators)) {} - Status Update(const StructLike& row) override { + Status Update(const StructLike& data) override { for (auto& aggregator : aggregators_) { - ICEBERG_RETURN_UNEXPECTED(aggregator->Update(row)); + ICEBERG_RETURN_UNEXPECTED(aggregator->Update(data)); } return {}; } diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h index e5dadcbd8..cde9e4583 100644 --- a/src/iceberg/expression/aggregate.h +++ b/src/iceberg/expression/aggregate.h @@ -23,7 +23,6 @@ /// Aggregate expression definitions. #include -#include #include #include #include @@ -45,10 +44,19 @@ class ICEBERG_EXPORT Aggregate : public virtual Expression { const std::shared_ptr& term() const { return term_; } + std::string ToString() const override; + protected: Aggregate(Expression::Operation op, std::shared_ptr term) : operation_(op), term_(std::move(term)) {} + static constexpr bool IsSupportedOp(Expression::Operation op) { + return op == Expression::Operation::kCount || + op == Expression::Operation::kCountNull || + op == Expression::Operation::kCountStar || op == Expression::Operation::kMax || + op == Expression::Operation::kMin; + } + Expression::Operation operation_; std::shared_ptr term_; }; @@ -79,19 +87,10 @@ class ICEBERG_EXPORT UnboundAggregateImpl : public UnboundAggregate, Result> Bind(const Schema& schema, bool case_sensitive) const override; - std::string ToString() const override; - private: - static constexpr bool IsSupportedOp(Expression::Operation op) { - return op == Expression::Operation::kCount || - op == Expression::Operation::kCountNull || - op == Expression::Operation::kCountStar || op == Expression::Operation::kMax || - op == Expression::Operation::kMin; - } - UnboundAggregateImpl(Expression::Operation op, std::shared_ptr> term) : BASE(op, std::move(term)) { - ICEBERG_DCHECK(IsSupportedOp(op), "Unexpected aggregate operation"); + ICEBERG_DCHECK(BASE::IsSupportedOp(op), "Unexpected aggregate operation"); ICEBERG_DCHECK(op == Expression::Operation::kCountStar || BASE::term() != nullptr, "Aggregate term cannot be null except for COUNT(*)"); } @@ -103,34 +102,36 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound using Aggregate::op; using Aggregate::term; + /// \brief Base class for aggregators. class Aggregator { public: virtual ~Aggregator() = default; - virtual Status Update(const StructLike& row) = 0; + virtual Status Update(const StructLike& data) = 0; + virtual Status Update(const DataFile& file) { - return NotSupported("Aggregating DataFile not supported"); + return NotImplemented("Update(DataFile) not implemented"); } + + /// \brief Get the result of the aggregation. + /// \return The result of the aggregation. + /// \note It is an undefined behavior to call this method if any previous Update call + /// has returned an error. virtual Literal GetResult() const = 0; }; std::shared_ptr reference() override { - ICEBERG_DCHECK(term_ != nullptr || op() == Expression::Operation::kCountStar, + ICEBERG_DCHECK(term() != nullptr || op() == Expression::Operation::kCountStar, "Bound aggregate term should not be null except for COUNT(*)"); - return term_ ? term_->reference() : nullptr; + return term() ? term()->reference() : nullptr; } - std::string ToString() const override; Result Evaluate(const StructLike& data) const override = 0; bool is_bound_aggregate() const override { return true; } - enum class Kind : int8_t { - kCount = 0, - kValue, - }; - - virtual Kind kind() const = 0; + /// \brief Create a new aggregator for this aggregate. + /// \note The returned aggregator cannot outlive the BoundAggregate that creates it. virtual std::unique_ptr NewAggregator() const = 0; protected: @@ -141,7 +142,12 @@ class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound /// \brief Base class for COUNT aggregates. class ICEBERG_EXPORT CountAggregate : public BoundAggregate { public: - Kind kind() const override { return Kind::kCount; } + Result Evaluate(const StructLike& data) const final; + + std::unique_ptr NewAggregator() const override; + + /// \brief Count for a single row. Subclasses implement this. + virtual Result CountFor(const StructLike& data) const = 0; protected: CountAggregate(Expression::Operation op, std::shared_ptr term) @@ -154,8 +160,7 @@ class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { static Result> Make( std::shared_ptr term); - Result Evaluate(const StructLike& data) const override; - std::unique_ptr NewAggregator() const override; + Result CountFor(const StructLike& data) const override; private: explicit CountNonNullAggregate(std::shared_ptr term); @@ -167,8 +172,7 @@ class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { static Result> Make( std::shared_ptr term); - Result Evaluate(const StructLike& data) const override; - std::unique_ptr NewAggregator() const override; + Result CountFor(const StructLike& data) const override; private: explicit CountNullAggregate(std::shared_ptr term); @@ -179,8 +183,7 @@ class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { public: static Result> Make(); - Result Evaluate(const StructLike& data) const override; - std::unique_ptr NewAggregator() const override; + Result CountFor(const StructLike& data) const override; private: CountStarAggregate(); @@ -191,9 +194,8 @@ class ICEBERG_EXPORT MaxAggregate : public BoundAggregate { public: static std::shared_ptr Make(std::shared_ptr term); - Kind kind() const override { return Kind::kValue; } - Result Evaluate(const StructLike& data) const override; + std::unique_ptr NewAggregator() const override; private: @@ -205,16 +207,15 @@ class ICEBERG_EXPORT MinAggregate : public BoundAggregate { public: static std::shared_ptr Make(std::shared_ptr term); - Kind kind() const override { return Kind::kValue; } - Result Evaluate(const StructLike& data) const override; + std::unique_ptr NewAggregator() const override; private: explicit MinAggregate(std::shared_ptr term); }; -/// \brief Evaluates bound aggregates over StructLike rows. +/// \brief Evaluates bound aggregates over StructLike data. class ICEBERG_EXPORT AggregateEvaluator { public: virtual ~AggregateEvaluator() = default; @@ -231,7 +232,7 @@ class ICEBERG_EXPORT AggregateEvaluator { std::vector> aggregates); /// \brief Update aggregates with a row. - virtual Status Update(const StructLike& row) = 0; + virtual Status Update(const StructLike& data) = 0; /// \brief Final aggregated value. virtual Result> GetResults() const = 0; diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc index adbdc984d..264e606f7 100644 --- a/src/iceberg/test/aggregate_test.cc +++ b/src/iceberg/test/aggregate_test.cc @@ -21,19 +21,18 @@ #include -#include "iceberg/exception.h" #include "iceberg/expression/binder.h" #include "iceberg/expression/expressions.h" #include "iceberg/row/struct_like.h" #include "iceberg/schema.h" #include "iceberg/test/matchers.h" #include "iceberg/type.h" -#include "iceberg/util/macros.h" namespace iceberg { namespace { +/// XXX: `Scalar` carries view semantics, so it is unsafe to use std::string_view variant. class VectorStructLike : public StructLike { public: explicit VectorStructLike(std::vector fields) : fields_(std::move(fields)) {}