diff --git a/src/iceberg/CMakeLists.txt b/src/iceberg/CMakeLists.txt index 305e315e9..9c3e9c2a3 100644 --- a/src/iceberg/CMakeLists.txt +++ b/src/iceberg/CMakeLists.txt @@ -20,6 +20,7 @@ set(ICEBERG_INCLUDES "$" set(ICEBERG_SOURCES arrow_c_data_guard_internal.cc catalog/memory/in_memory_catalog.cc + expression/aggregate.cc expression/binder.cc expression/evaluator.cc expression/expression.cc diff --git a/src/iceberg/expression/aggregate.cc b/src/iceberg/expression/aggregate.cc new file mode 100644 index 000000000..a9c1a60bf --- /dev/null +++ b/src/iceberg/expression/aggregate.cc @@ -0,0 +1,348 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+
+#include "iceberg/expression/aggregate.h"
+
+#include
+#include
+#include
+
+#include "iceberg/expression/literal.h"
+#include "iceberg/row/struct_like.h"
+#include "iceberg/type.h"
+#include "iceberg/util/checked_cast.h"
+#include "iceberg/util/macros.h"
+
+namespace iceberg {
+
+namespace {
+
+/// \brief Return the type of `term`, downcast with internal::checked_pointer_cast.
+///
+/// The only validation is the ICEBERG_DCHECK that the type is primitive (presumably a
+/// debug-only assertion -- confirm against iceberg/util/macros.h), so callers are
+/// expected to pass a term whose type is already known to be primitive.
+std::shared_ptr GetPrimitiveType(const BoundTerm& term) {
+  ICEBERG_DCHECK(term.type()->is_primitive(), "Value aggregate term should be primitive");
+  return internal::checked_pointer_cast(term.type());
+}
+
+/// \brief Aggregator that sums the per-row counts reported by a CountAggregate.
+///
+/// Only a reference to the aggregate is stored, so the aggregate must stay alive for as
+/// long as this aggregator is in use.
+class CountAggregator : public BoundAggregate::Aggregator {
+ public:
+  explicit CountAggregator(const CountAggregate& aggregate) : aggregate_(aggregate) {}
+
+  /// \brief Add `aggregate_.CountFor(row)` to the running total.
+  ///
+  /// The CountFor result goes through ICEBERG_ASSIGN_OR_RAISE; presumably a failure is
+  /// returned to the caller before the total is modified (verify against the macro).
+  Status Update(const StructLike& row) override {
+    ICEBERG_ASSIGN_OR_RAISE(auto count, aggregate_.CountFor(row));
+    count_ += count;
+    return {};
+  }
+
+  /// \brief The total accumulated so far, as a long literal (0 before any update).
+  Literal GetResult() const override { return Literal::Long(count_); }
+
+ private:
+  // Non-owning: refers to the aggregate passed to the constructor.
+  const CountAggregate& aggregate_;
+  int64_t count_ = 0;
+};
+
+/// \brief Aggregator that keeps the greatest non-null value produced by a MaxAggregate.
+///
+/// The running value starts as a null literal of the term's primitive type, so
+/// GetResult() stays null until a non-null value has been seen. Only a reference to the
+/// aggregate is stored; the aggregate must outlive this aggregator.
+class MaxAggregator : public BoundAggregate::Aggregator {
+ public:
+  // NOTE(review): the initializer dereferences aggregate_.term() without a null check;
+  // the aggregate is assumed to carry a non-null term -- confirm at the creation site.
+  explicit MaxAggregator(const MaxAggregate& aggregate)
+      : aggregate_(aggregate),
+        current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {}
+
+  /// \brief Fold one row into the running maximum.
+  ///
+  /// Null values are skipped and the first non-null value seeds the maximum. After that
+  /// the value replaces the maximum only when `value <=> current_` is `greater`. An
+  /// `unordered` comparison is reported as InvalidArgument and leaves the maximum as is.
+  Status Update(const StructLike& data) override {
+    ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data));
+    if (value.IsNull()) {
+      return {};
+    }
+    if (current_.IsNull()) {
+      current_ = std::move(value);
+      return {};
+    }
+
+    if (auto ordering = value <=> current_;
+        ordering == std::partial_ordering::unordered) {
+      return InvalidArgument("Cannot compare literal {} with current value {}",
+                             value.ToString(), current_.ToString());
+    } else if (ordering == std::partial_ordering::greater) {
+      current_ = std::move(value);
+    }
+
+    return {};
+  }
+
+  /// \brief The running maximum; a null literal if no non-null value was seen.
+  Literal GetResult() const override { return current_; }
+
+ private:
+  // Non-owning: refers to the aggregate passed to the constructor.
+  const MaxAggregate& aggregate_;
+  Literal current_;
+};
+
+/// \brief Aggregator that keeps the smallest non-null value produced by a MinAggregate.
+///
+/// Mirrors MaxAggregator: the running value starts as a null literal of the term's
+/// primitive type, and only a reference to the aggregate is stored.
+class MinAggregator : public BoundAggregate::Aggregator {
+ public:
+  // NOTE(review): the initializer dereferences aggregate_.term() without a null check;
+  // the aggregate is assumed to carry a non-null term -- confirm at the creation site.
+  explicit MinAggregator(const MinAggregate& aggregate)
+      : aggregate_(aggregate),
+        current_(Literal::Null(GetPrimitiveType(*aggregate_.term()))) {}
+
+  /// \brief Fold one row into the running minimum.
+  ///
+  /// Null values are skipped and the first non-null value seeds the minimum. After that
+  /// the value replaces the minimum only when `value <=> current_` is `less`. An
+  /// `unordered` comparison is reported as InvalidArgument and leaves the minimum as is.
+  Status Update(const StructLike& data) override {
+    ICEBERG_ASSIGN_OR_RAISE(auto value, aggregate_.Evaluate(data));
+    if (value.IsNull()) {
+      return {};
+    }
+    if (current_.IsNull()) {
+      current_ = std::move(value);
+      return {};
+    }
+
+    if (auto ordering = value <=> current_;
+        ordering == std::partial_ordering::unordered) {
+      return InvalidArgument("Cannot compare literal {} with current value {}",
+                             value.ToString(), current_.ToString());
+    } else if (ordering == std::partial_ordering::less) {
+      current_ = std::move(value);
+    }
+    return {};
+  }
+
+  /// \brief The running minimum; a null literal if no non-null value was seen.
+  Literal GetResult() const override { return current_; }
+
+ private:
+  // Non-owning: refers to the aggregate passed to the constructor.
+  const MinAggregate& aggregate_;
+  Literal current_;
+};
+
+}  // namespace
+
+/// \brief Render the aggregate as `count(term)`, `count_if(term is null)`, `count(*)`,
+/// `max(term)` or `min(term)`.
+///
+/// Any other operation produces "Invalid aggregate: <op>". Every branch except COUNT(*)
+/// dereferences term(); a null term there is guarded only by the ICEBERG_DCHECK above
+/// the switch.
+template
+std::string Aggregate::ToString() const {
+  ICEBERG_DCHECK(IsSupportedOp(op()), "Unexpected aggregate operation");
+  ICEBERG_DCHECK(op() == Expression::Operation::kCountStar || term() != nullptr,
+                 "Aggregate term should not be null except for COUNT(*)");
+
+  switch (op()) {
+    case Expression::Operation::kCount:
+      return std::format("count({})", term()->ToString());
+    case Expression::Operation::kCountNull:
+      return std::format("count_if({} is null)", term()->ToString());
+    case Expression::Operation::kCountStar:
+      return "count(*)";
+    case Expression::Operation::kMax:
+      return std::format("max({})", term()->ToString());
+    case Expression::Operation::kMin:
+      return std::format("min({})", term()->ToString());
+    default:
+      return std::format("Invalid aggregate: {}", ::iceberg::ToString(op()));
+  }
+}
+
+// -------------------- CountAggregate --------------------
+
+/// \brief Evaluate a single row: the value of CountFor(data), mapped to a long literal
+/// through `transform`.
+Result CountAggregate::Evaluate(const StructLike& data) const {
+  return CountFor(data).transform([](int64_t count) { return Literal::Long(count); });
+}
+
+/// \brief Create a CountAggregator for this aggregate.
+///
+/// The aggregator is constructed from `*this` and keeps a reference to it, so it must
+/// not be used after this aggregate is destroyed.
+std::unique_ptr CountAggregate::NewAggregator() const {
+  return std::unique_ptr(new CountAggregator(*this));
+}
+
+CountNonNullAggregate::CountNonNullAggregate(std::shared_ptr term) + : CountAggregate(Expression::Operation::kCount, std::move(term)) {} + +Result> CountNonNullAggregate::Make( + std::shared_ptr term) { + if (!term) { + return InvalidExpression("Bound count aggregate requires non-null term"); + } + return std::unique_ptr( + new CountNonNullAggregate(std::move(term))); +} + +Result CountNonNullAggregate::CountFor(const StructLike& data) const { + return term()->Evaluate(data).transform( + [](const auto& val) { return val.IsNull() ? 0 : 1; }); +} + +CountNullAggregate::CountNullAggregate(std::shared_ptr term) + : CountAggregate(Expression::Operation::kCountNull, std::move(term)) {} + +Result> CountNullAggregate::Make( + std::shared_ptr term) { + if (!term) { + return InvalidExpression("Bound count aggregate requires non-null term"); + } + return std::unique_ptr(new CountNullAggregate(std::move(term))); +} + +Result CountNullAggregate::CountFor(const StructLike& data) const { + return term()->Evaluate(data).transform( + [](const auto& val) { return val.IsNull() ? 
1 : 0; }); +} + +CountStarAggregate::CountStarAggregate() + : CountAggregate(Expression::Operation::kCountStar, nullptr) {} + +Result> CountStarAggregate::Make() { + return std::unique_ptr(new CountStarAggregate()); +} + +Result CountStarAggregate::CountFor(const StructLike& /*data*/) const { + return 1; +} + +MaxAggregate::MaxAggregate(std::shared_ptr term) + : BoundAggregate(Expression::Operation::kMax, std::move(term)) {} + +std::shared_ptr MaxAggregate::Make(std::shared_ptr term) { + return std::shared_ptr(new MaxAggregate(std::move(term))); +} + +Result MaxAggregate::Evaluate(const StructLike& data) const { + return term()->Evaluate(data); +} + +std::unique_ptr MaxAggregate::NewAggregator() const { + return std::unique_ptr(new MaxAggregator(*this)); +} + +MinAggregate::MinAggregate(std::shared_ptr term) + : BoundAggregate(Expression::Operation::kMin, std::move(term)) {} + +std::shared_ptr MinAggregate::Make(std::shared_ptr term) { + return std::shared_ptr(new MinAggregate(std::move(term))); +} + +Result MinAggregate::Evaluate(const StructLike& data) const { + return term()->Evaluate(data); +} + +std::unique_ptr MinAggregate::NewAggregator() const { + return std::unique_ptr(new MinAggregator(*this)); +} + +// -------------------- Unbound binding -------------------- + +template +Result> UnboundAggregateImpl::Bind( + const Schema& schema, bool case_sensitive) const { + ICEBERG_DCHECK(UnboundAggregateImpl::IsSupportedOp(this->op()), + "Unexpected aggregate operation"); + + std::shared_ptr bound_term; + if (this->term()) { + ICEBERG_ASSIGN_OR_RAISE(bound_term, this->term()->Bind(schema, case_sensitive)); + } + + switch (this->op()) { + case Expression::Operation::kCountStar: + return CountStarAggregate::Make(); + case Expression::Operation::kCount: + return CountNonNullAggregate::Make(std::move(bound_term)); + case Expression::Operation::kCountNull: + return CountNullAggregate::Make(std::move(bound_term)); + case Expression::Operation::kMax: + return 
MaxAggregate::Make(std::move(bound_term)); + case Expression::Operation::kMin: + return MinAggregate::Make(std::move(bound_term)); + default: + return NotSupported("Unsupported aggregate operation: {}", + ::iceberg::ToString(this->op())); + } +} + +template +Result>> UnboundAggregateImpl::Make( + Expression::Operation op, std::shared_ptr> term) { + if (!Aggregate>::IsSupportedOp(op)) { + return NotSupported("Unsupported aggregate operation: {}", ::iceberg::ToString(op)); + } + if (op != Expression::Operation::kCountStar && !term) { + return InvalidExpression("Aggregate term cannot be null unless COUNT(*)"); + } + + return std::shared_ptr>( + new UnboundAggregateImpl(op, std::move(term))); +} + +template class Aggregate>; +template class Aggregate; +template class UnboundAggregateImpl; + +// -------------------- AggregateEvaluator -------------------- + +namespace { + +class AggregateEvaluatorImpl : public AggregateEvaluator { + public: + AggregateEvaluatorImpl( + std::vector> aggregates, + std::vector> aggregators) + : aggregates_(std::move(aggregates)), aggregators_(std::move(aggregators)) {} + + Status Update(const StructLike& data) override { + for (auto& aggregator : aggregators_) { + ICEBERG_RETURN_UNEXPECTED(aggregator->Update(data)); + } + return {}; + } + + Result> GetResults() const override { + results_.clear(); + results_.reserve(aggregates_.size()); + for (const auto& aggregator : aggregators_) { + results_.emplace_back(aggregator->GetResult()); + } + return std::span(results_); + } + + Result GetResult() const override { + if (aggregates_.size() != 1) { + return InvalidArgument( + "GetResult() is only valid when evaluating a single aggregate"); + } + + ICEBERG_ASSIGN_OR_RAISE(auto all, GetResults()); + return all.front(); + } + + private: + std::vector> aggregates_; + std::vector> aggregators_; + mutable std::vector results_; +}; + +} // namespace + +Result> AggregateEvaluator::Make( + std::shared_ptr aggregate) { + std::vector> aggs; + 
aggs.push_back(std::move(aggregate)); + return Make(std::move(aggs)); +} + +Result> AggregateEvaluator::Make( + std::vector> aggregates) { + if (aggregates.empty()) { + return InvalidArgument("AggregateEvaluator requires at least one aggregate"); + } + std::vector> aggregators; + aggregators.reserve(aggregates.size()); + for (const auto& agg : aggregates) { + aggregators.push_back(agg->NewAggregator()); + } + + return std::unique_ptr( + new AggregateEvaluatorImpl(std::move(aggregates), std::move(aggregators))); +} + +} // namespace iceberg diff --git a/src/iceberg/expression/aggregate.h b/src/iceberg/expression/aggregate.h new file mode 100644 index 000000000..cde9e4583 --- /dev/null +++ b/src/iceberg/expression/aggregate.h @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#pragma once + +/// \file iceberg/expression/aggregate.h +/// Aggregate expression definitions. + +#include +#include +#include +#include + +#include "iceberg/expression/expression.h" +#include "iceberg/expression/term.h" +#include "iceberg/result.h" +#include "iceberg/type_fwd.h" + +namespace iceberg { + +/// \brief Base aggregate holding an operation and a term. 
+///
+/// Stores the aggregate operation and a shared pointer to the aggregated term, and
+/// declares the textual form (ToString). The constructor is protected, so instances are
+/// only created through subclasses.
+template
+class ICEBERG_EXPORT Aggregate : public virtual Expression {
+ public:
+  ~Aggregate() override = default;
+
+  /// \brief The aggregate operation passed to the constructor.
+  Expression::Operation op() const override { return operation_; }
+
+  /// \brief The aggregated term. It can be null; the checks in this file only allow
+  /// that for COUNT(*).
+  const std::shared_ptr& term() const { return term_; }
+
+  /// \brief Textual form of the aggregate; declared here, defined out of line.
+  std::string ToString() const override;
+
+ protected:
+  /// \brief Store `op` and take ownership of `term`. No validation happens here.
+  Aggregate(Expression::Operation op, std::shared_ptr term)
+      : operation_(op), term_(std::move(term)) {}
+
+  /// \brief Whether `op` is one of the aggregate operations: kCount, kCountNull,
+  /// kCountStar, kMax or kMin.
+  static constexpr bool IsSupportedOp(Expression::Operation op) {
+    return op == Expression::Operation::kCount ||
+           op == Expression::Operation::kCountNull ||
+           op == Expression::Operation::kCountStar || op == Expression::Operation::kMax ||
+           op == Expression::Operation::kMin;
+  }
+
+  Expression::Operation operation_;
+  std::shared_ptr term_;
+};
+
+/// \brief Base class for unbound aggregates.
+///
+/// Adds no state of its own; it identifies the expression as an unbound aggregate by
+/// returning true from is_unbound_aggregate().
+class ICEBERG_EXPORT UnboundAggregate : public virtual Expression,
+                                        public Unbound {
+ public:
+  ~UnboundAggregate() override = default;
+
+  bool is_unbound_aggregate() const override { return true; }
+};
+
+/// \brief Template for unbound aggregates that carry a term and operation.
+///
+/// Instances are created through Make(); the constructor is private.
+template
+class ICEBERG_EXPORT UnboundAggregateImpl : public UnboundAggregate,
+                                            public Aggregate> {
+  using BASE = Aggregate>;
+
+ public:
+  /// \brief Factory for an unbound aggregate with operation `op` over `term`.
+  ///
+  /// Declared here and defined out of line. The constructor's checks expect a supported
+  /// operation and a non-null term unless `op` is kCountStar.
+  static Result>> Make(
+      Expression::Operation op, std::shared_ptr> term);
+
+  /// \brief The reference of the underlying term, or nullptr when there is no term.
+  std::shared_ptr reference() override {
+    return BASE::term() ? BASE::term()->reference() : nullptr;
+  }
+
+  /// \brief Bind this aggregate against `schema`; declared here, defined out of line.
+  /// \param schema The schema used to resolve the term.
+  /// \param case_sensitive Whether names are matched case-sensitively.
+  Result> Bind(const Schema& schema,
+               bool case_sensitive) const override;
+
+ private:
+  // The invariants below are only asserted through ICEBERG_DCHECK.
+  // NOTE(review): ICEBERG_DCHECK is used in this header, but the project includes
+  // visible at the top of the header do not name iceberg/util/macros.h -- presumably it
+  // arrives transitively; confirm.
+  UnboundAggregateImpl(Expression::Operation op, std::shared_ptr> term)
+      : BASE(op, std::move(term)) {
+    ICEBERG_DCHECK(BASE::IsSupportedOp(op), "Unexpected aggregate operation");
+    ICEBERG_DCHECK(op == Expression::Operation::kCountStar || BASE::term() != nullptr,
+                   "Aggregate term cannot be null except for COUNT(*)");
+  }
+};
+
+/// \brief Base class for bound aggregates.
+///
+/// A bound aggregate can be evaluated against a single row (Evaluate) and can create an
+/// Aggregator that folds many rows into one result (NewAggregator).
+class ICEBERG_EXPORT BoundAggregate : public Aggregate, public Bound {
+ public:
+  // Make op() and term() of the Aggregate base directly nameable in this class.
+  using Aggregate::op;
+  using Aggregate::term;
+
+  /// \brief Base class for aggregators.
+  ///
+  /// An aggregator is the stateful accumulator of one aggregate: Update() feeds it input
+  /// and GetResult() reads the accumulated value.
+  class Aggregator {
+   public:
+    virtual ~Aggregator() = default;
+
+    /// \brief Fold one row into the aggregation state.
+    virtual Status Update(const StructLike& data) = 0;
+
+    /// \brief Fold a data file into the aggregation state.
+    ///
+    /// The default implementation does nothing with `file` and returns NotImplemented.
+    virtual Status Update(const DataFile& file) {
+      return NotImplemented("Update(DataFile) not implemented");
+    }
+
+    /// \brief Get the result of the aggregation.
+    /// \return The result of the aggregation.
+    /// \note It is undefined behavior to call this method if any previous Update call
+    /// has returned an error.
+    virtual Literal GetResult() const = 0;
+  };
+
+  /// \brief The reference of the underlying term, or nullptr when there is no term.
+  ///
+  /// A missing term is only expected for COUNT(*); that expectation is asserted with
+  /// ICEBERG_DCHECK and otherwise handled by returning nullptr.
+  std::shared_ptr reference() override {
+    ICEBERG_DCHECK(term() != nullptr || op() == Expression::Operation::kCountStar,
+                   "Bound aggregate term should not be null except for COUNT(*)");
+    return term() ? term()->reference() : nullptr;
+  }
+
+  /// \brief Evaluate the aggregate for a single row.
+  ///
+  /// Re-declared as pure virtual, so every concrete aggregate has to implement it.
+  Result Evaluate(const StructLike& data) const override = 0;
+
+  bool is_bound_aggregate() const override { return true; }
+
+  /// \brief Create a new aggregator for this aggregate.
+  /// \note The returned aggregator cannot outlive the BoundAggregate that creates it.
+  virtual std::unique_ptr NewAggregator() const = 0;
+
+ protected:
+  /// \brief Forward `op` and `term` to the Aggregate base; no validation happens here.
+  BoundAggregate(Expression::Operation op, std::shared_ptr term)
+      : Aggregate(op, std::move(term)) {}
+};
+
+/// \brief Base class for COUNT aggregates.
+///
+/// Subclasses only decide how much a single row contributes (CountFor); Evaluate is
+/// declared final here.
+class ICEBERG_EXPORT CountAggregate : public BoundAggregate {
+ public:
+  /// \brief Single-row evaluation; `final`, so subclasses customize CountFor instead.
+  Result Evaluate(const StructLike& data) const final;
+
+  /// \brief Create the aggregator for this count aggregate; defined out of line.
+  std::unique_ptr NewAggregator() const override;
+
+  /// \brief Count for a single row. Subclasses implement this.
+  virtual Result CountFor(const StructLike& data) const = 0;
+
+ protected:
+  /// \brief Forward `op` and `term` to BoundAggregate.
+  CountAggregate(Expression::Operation op, std::shared_ptr term)
+      : BoundAggregate(op, std::move(term)) {}
+};
+
+/// \brief COUNT(term) aggregate.
+class ICEBERG_EXPORT CountNonNullAggregate : public CountAggregate { + public: + static Result> Make( + std::shared_ptr term); + + Result CountFor(const StructLike& data) const override; + + private: + explicit CountNonNullAggregate(std::shared_ptr term); +}; + +/// \brief COUNT_NULL(term) aggregate. +class ICEBERG_EXPORT CountNullAggregate : public CountAggregate { + public: + static Result> Make( + std::shared_ptr term); + + Result CountFor(const StructLike& data) const override; + + private: + explicit CountNullAggregate(std::shared_ptr term); +}; + +/// \brief COUNT(*) aggregate. +class ICEBERG_EXPORT CountStarAggregate : public CountAggregate { + public: + static Result> Make(); + + Result CountFor(const StructLike& data) const override; + + private: + CountStarAggregate(); +}; + +/// \brief Bound MAX aggregate. +class ICEBERG_EXPORT MaxAggregate : public BoundAggregate { + public: + static std::shared_ptr Make(std::shared_ptr term); + + Result Evaluate(const StructLike& data) const override; + + std::unique_ptr NewAggregator() const override; + + private: + explicit MaxAggregate(std::shared_ptr term); +}; + +/// \brief Bound MIN aggregate. +class ICEBERG_EXPORT MinAggregate : public BoundAggregate { + public: + static std::shared_ptr Make(std::shared_ptr term); + + Result Evaluate(const StructLike& data) const override; + + std::unique_ptr NewAggregator() const override; + + private: + explicit MinAggregate(std::shared_ptr term); +}; + +/// \brief Evaluates bound aggregates over StructLike data. +class ICEBERG_EXPORT AggregateEvaluator { + public: + virtual ~AggregateEvaluator() = default; + + /// \brief Create an evaluator for a single bound aggregate. + /// \param aggregate The bound aggregate to evaluate across rows. + static Result> Make( + std::shared_ptr aggregate); + + /// \brief Create an evaluator for multiple bound aggregates. + /// \param aggregates Aggregates to evaluate in one pass; order is preserved in + /// GetResults(). 
+ static Result> Make( + std::vector> aggregates); + + /// \brief Update aggregates with a row. + virtual Status Update(const StructLike& data) = 0; + + /// \brief Final aggregated value. + virtual Result> GetResults() const = 0; + + /// \brief Convenience accessor when only one aggregate is evaluated. + virtual Result GetResult() const = 0; +}; + +} // namespace iceberg diff --git a/src/iceberg/expression/binder.cc b/src/iceberg/expression/binder.cc index 62c735308..43c3ebcdf 100644 --- a/src/iceberg/expression/binder.cc +++ b/src/iceberg/expression/binder.cc @@ -64,6 +64,18 @@ Result> Binder::Predicate( return InvalidExpression("Found already bound predicate: {}", pred->ToString()); } +Result> Binder::Aggregate( + const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + return InvalidExpression("Found already bound aggregate: {}", aggregate->ToString()); +} + +Result> Binder::Aggregate( + const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Aggregate cannot be null"); + return aggregate->Bind(schema_, case_sensitive_); +} + Result IsBoundVisitor::IsBound(const std::shared_ptr& expr) { ICEBERG_DCHECK(expr != nullptr, "Expression cannot be null"); IsBoundVisitor visitor; @@ -92,4 +104,13 @@ Result IsBoundVisitor::Predicate(const std::shared_ptr& return false; } +Result IsBoundVisitor::Aggregate(const std::shared_ptr& aggregate) { + return true; +} + +Result IsBoundVisitor::Aggregate( + const std::shared_ptr& aggregate) { + return false; +} + } // namespace iceberg diff --git a/src/iceberg/expression/binder.h b/src/iceberg/expression/binder.h index fcef0731d..a78b7a4bb 100644 --- a/src/iceberg/expression/binder.h +++ b/src/iceberg/expression/binder.h @@ -48,6 +48,10 @@ class ICEBERG_EXPORT Binder : public ExpressionVisitor& pred) override; Result> Predicate( const std::shared_ptr& pred) override; + Result> Aggregate( + const std::shared_ptr& aggregate) override; + Result> Aggregate( + const 
std::shared_ptr& aggregate) override; private: const Schema& schema_; @@ -65,6 +69,8 @@ class ICEBERG_EXPORT IsBoundVisitor : public ExpressionVisitor { Result Or(bool left_result, bool right_result) override; Result Predicate(const std::shared_ptr& pred) override; Result Predicate(const std::shared_ptr& pred) override; + Result Aggregate(const std::shared_ptr& aggregate) override; + Result Aggregate(const std::shared_ptr& aggregate) override; }; // TODO(gangwu): add the Java parity `ReferenceVisitor` diff --git a/src/iceberg/expression/expression.cc b/src/iceberg/expression/expression.cc index b40082bb4..9ee5dfc17 100644 --- a/src/iceberg/expression/expression.cc +++ b/src/iceberg/expression/expression.cc @@ -191,6 +191,8 @@ std::string_view ToString(Expression::Operation op) { return "NOT_STARTS_WITH"; case Expression::Operation::kCount: return "COUNT"; + case Expression::Operation::kCountNull: + return "COUNT_NULL"; case Expression::Operation::kNot: return "NOT"; case Expression::Operation::kCountStar: @@ -246,6 +248,7 @@ Result Negate(Expression::Operation op) { case Expression::Operation::kMax: case Expression::Operation::kMin: case Expression::Operation::kCount: + case Expression::Operation::kCountNull: return InvalidExpression("No negation for operation: {}", op); } std::unreachable(); diff --git a/src/iceberg/expression/expression.h b/src/iceberg/expression/expression.h index e6e19ea83..58b3fcc1a 100644 --- a/src/iceberg/expression/expression.h +++ b/src/iceberg/expression/expression.h @@ -57,6 +57,7 @@ class ICEBERG_EXPORT Expression : public util::Formattable { kStartsWith, kNotStartsWith, kCount, + kCountNull, kCountStar, kMax, kMin diff --git a/src/iceberg/expression/expression_visitor.h b/src/iceberg/expression/expression_visitor.h index aeafa9298..ed1e75d8e 100644 --- a/src/iceberg/expression/expression_visitor.h +++ b/src/iceberg/expression/expression_visitor.h @@ -24,7 +24,9 @@ #include #include +#include +#include "iceberg/expression/aggregate.h" 
#include "iceberg/expression/expression.h" #include "iceberg/expression/literal.h" #include "iceberg/expression/predicate.h" @@ -77,6 +79,22 @@ class ICEBERG_EXPORT ExpressionVisitor { /// \brief Visit an unbound predicate. /// \param pred The unbound predicate to visit virtual Result Predicate(const std::shared_ptr& pred) = 0; + + /// \brief Visit a bound aggregate. + /// \param aggregate The bound aggregate to visit. + virtual Result Aggregate(const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Bound aggregate cannot be null"); + return NotSupported("Visitor {} does not support bound aggregate", + typeid(*this).name()); + } + + /// \brief Visit an unbound aggregate. + /// \param aggregate The unbound aggregate to visit. + virtual Result Aggregate(const std::shared_ptr& aggregate) { + ICEBERG_DCHECK(aggregate != nullptr, "Unbound aggregate cannot be null"); + return NotSupported("Visitor {} does not support unbound aggregate", + typeid(*this).name()); + } }; /// \brief Visitor for bound expressions. 
@@ -275,7 +293,13 @@ Result Visit(const std::shared_ptr& expr, V& visitor) { return visitor.Predicate(std::dynamic_pointer_cast(expr)); } - // TODO(gangwu): handle aggregate expression + if (expr->is_bound_aggregate()) { + return visitor.Aggregate(std::dynamic_pointer_cast(expr)); + } + + if (expr->is_unbound_aggregate()) { + return visitor.Aggregate(std::dynamic_pointer_cast(expr)); + } switch (expr->op()) { case Expression::Operation::kTrue: diff --git a/src/iceberg/expression/expressions.cc b/src/iceberg/expression/expressions.cc index b3e88ff18..786cc0ab7 100644 --- a/src/iceberg/expression/expressions.cc +++ b/src/iceberg/expression/expressions.cc @@ -20,6 +20,7 @@ #include "iceberg/expression/expressions.h" #include "iceberg/exception.h" +#include "iceberg/expression/aggregate.h" #include "iceberg/transform.h" #include "iceberg/type.h" #include "iceberg/util/macros.h" @@ -81,6 +82,73 @@ std::shared_ptr Expressions::Transform( return unbound_transform; } +// Aggregates + +std::shared_ptr> Expressions::Count( + std::string name) { + return Count(Ref(std::move(name))); +} + +std::shared_ptr> Expressions::Count( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kCount, std::move(expr))); + return agg; +} + +std::shared_ptr> Expressions::CountNull( + std::string name) { + return CountNull(Ref(std::move(name))); +} + +std::shared_ptr> Expressions::CountNull( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, + UnboundAggregateImpl::Make( + Expression::Operation::kCountNull, std::move(expr))); + return agg; +} + +std::shared_ptr> Expressions::CountNotNull( + std::string name) { + return CountNotNull(Ref(std::move(name))); +} + +std::shared_ptr> Expressions::CountNotNull( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kCount, std::move(expr))); + return agg; +} + +std::shared_ptr> Expressions::CountStar() { + 
ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kCountStar, nullptr)); + return agg; +} + +std::shared_ptr> Expressions::Max(std::string name) { + return Max(Ref(std::move(name))); +} + +std::shared_ptr> Expressions::Max( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMax, std::move(expr))); + return agg; +} + +std::shared_ptr> Expressions::Min(std::string name) { + return Min(Ref(std::move(name))); +} + +std::shared_ptr> Expressions::Min( + std::shared_ptr> expr) { + ICEBERG_ASSIGN_OR_THROW(auto agg, UnboundAggregateImpl::Make( + Expression::Operation::kMin, std::move(expr))); + return agg; +} + // Template implementations for unary predicates std::shared_ptr> Expressions::IsNull( diff --git a/src/iceberg/expression/expressions.h b/src/iceberg/expression/expressions.h index 7331982dd..cb1d6df7e 100644 --- a/src/iceberg/expression/expressions.h +++ b/src/iceberg/expression/expressions.h @@ -27,6 +27,8 @@ #include #include +#include "iceberg/exception.h" +#include "iceberg/expression/aggregate.h" #include "iceberg/expression/literal.h" #include "iceberg/expression/predicate.h" #include "iceberg/expression/term.h" @@ -100,6 +102,48 @@ class ICEBERG_EXPORT Expressions { static std::shared_ptr Transform( std::string name, std::shared_ptr transform); + // Aggregates + + /// \brief Create a COUNT aggregate for a field name. + static std::shared_ptr> Count(std::string name); + + /// \brief Create a COUNT aggregate for an unbound term. + static std::shared_ptr> Count( + std::shared_ptr> expr); + + /// \brief Create a COUNT_NULL aggregate for a field name. + static std::shared_ptr> CountNull( + std::string name); + + /// \brief Create a COUNT_NULL aggregate for an unbound term. + static std::shared_ptr> CountNull( + std::shared_ptr> expr); + + /// \brief Create a COUNT_NOT_NULL aggregate for a field name. 
+ static std::shared_ptr> CountNotNull( + std::string name); + + /// \brief Create a COUNT_NOT_NULL aggregate for an unbound term. + static std::shared_ptr> CountNotNull( + std::shared_ptr> expr); + + /// \brief Create a COUNT(*) aggregate. + static std::shared_ptr> CountStar(); + + /// \brief Create a MAX aggregate for a field name. + static std::shared_ptr> Max(std::string name); + + /// \brief Create a MAX aggregate for an unbound term. + static std::shared_ptr> Max( + std::shared_ptr> expr); + + /// \brief Create a MIN aggregate for a field name. + static std::shared_ptr> Min(std::string name); + + /// \brief Create a MIN aggregate for an unbound term. + static std::shared_ptr> Min( + std::shared_ptr> expr); + // Unary predicates /// \brief Create an IS NULL predicate for a field name. diff --git a/src/iceberg/expression/predicate.h b/src/iceberg/expression/predicate.h index dd837f286..b7ed21ff8 100644 --- a/src/iceberg/expression/predicate.h +++ b/src/iceberg/expression/predicate.h @@ -22,7 +22,6 @@ /// \file iceberg/expression/predicate.h /// Predicate interface for boolean expressions that test terms. -#include #include #include "iceberg/expression/expression.h" @@ -31,9 +30,6 @@ namespace iceberg { -template -concept TermType = std::derived_from; - /// \brief A predicate is a boolean expression that tests a term against some criteria. /// /// \tparam TermType The type of the term being tested diff --git a/src/iceberg/expression/term.h b/src/iceberg/expression/term.h index e2a378feb..c19b81324 100644 --- a/src/iceberg/expression/term.h +++ b/src/iceberg/expression/term.h @@ -22,6 +22,7 @@ /// \file iceberg/expression/term.h /// Term interface for Iceberg expressions - represents values that can be evaluated. 
+#include #include #include #include @@ -41,6 +42,9 @@ class ICEBERG_EXPORT Term : public util::Formattable { virtual Kind kind() const = 0; }; +template +concept TermType = std::derived_from; + /// \brief Interface for unbound expressions that need schema binding. /// /// Unbound expressions contain string-based references that must be resolved diff --git a/src/iceberg/meson.build b/src/iceberg/meson.build index 2df94cfc1..15905107b 100644 --- a/src/iceberg/meson.build +++ b/src/iceberg/meson.build @@ -42,6 +42,7 @@ iceberg_include_dir = include_directories('..') iceberg_sources = files( 'arrow_c_data_guard_internal.cc', 'catalog/memory/in_memory_catalog.cc', + 'expression/aggregate.cc', 'expression/binder.cc', 'expression/evaluator.cc', 'expression/expression.cc', diff --git a/src/iceberg/test/CMakeLists.txt b/src/iceberg/test/CMakeLists.txt index c36d33da6..41b22507d 100644 --- a/src/iceberg/test/CMakeLists.txt +++ b/src/iceberg/test/CMakeLists.txt @@ -90,6 +90,7 @@ add_iceberg_test(table_test add_iceberg_test(expression_test SOURCES + aggregate_test.cc expression_test.cc expression_visitor_test.cc literal_test.cc diff --git a/src/iceberg/test/aggregate_test.cc b/src/iceberg/test/aggregate_test.cc new file mode 100644 index 000000000..264e606f7 --- /dev/null +++ b/src/iceberg/test/aggregate_test.cc @@ -0,0 +1,239 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "iceberg/expression/aggregate.h" + +#include + +#include "iceberg/expression/binder.h" +#include "iceberg/expression/expressions.h" +#include "iceberg/row/struct_like.h" +#include "iceberg/schema.h" +#include "iceberg/test/matchers.h" +#include "iceberg/type.h" + +namespace iceberg { + +namespace { + +/// XXX: `Scalar` carries view semantics, so it is unsafe to use std::string_view variant. +class VectorStructLike : public StructLike { + public: + explicit VectorStructLike(std::vector fields) : fields_(std::move(fields)) {} + + Result GetField(size_t pos) const override { + if (pos >= fields_.size()) { + return InvalidArgument("Position {} out of range", pos); + } + return fields_[pos]; + } + + size_t num_fields() const override { return fields_.size(); } + + private: + std::vector fields_; +}; + +std::shared_ptr BindAggregate(const Schema& schema, + const std::shared_ptr& expr) { + auto result = Binder::Bind(schema, expr, /*case_sensitive=*/true); + EXPECT_TRUE(result.has_value()) + << "Failed to bind aggregate: " << result.error().message; + auto bound = std::dynamic_pointer_cast(std::move(result).value()); + EXPECT_NE(bound, nullptr); + return bound; +} + +} // namespace + +TEST(AggregateTest, CountVariants) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_expr = Expressions::Count("id"); + auto count_bound = BindAggregate(schema, count_expr); + auto count_evaluator = AggregateEvaluator::Make(count_bound).value(); + + auto count_null_expr = 
Expressions::CountNull("value"); + auto count_null_bound = BindAggregate(schema, count_null_expr); + auto count_null_evaluator = AggregateEvaluator::Make(count_null_bound).value(); + + auto count_star_expr = Expressions::CountStar(); + auto count_star_bound = BindAggregate(schema, count_star_expr); + auto count_star_evaluator = AggregateEvaluator::Make(count_star_bound).value(); + + std::vector rows{ + VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), + VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}})}; + + for (const auto& row : rows) { + ASSERT_TRUE(count_evaluator->Update(row).has_value()); + ASSERT_TRUE(count_null_evaluator->Update(row).has_value()); + ASSERT_TRUE(count_star_evaluator->Update(row).has_value()); + } + + ICEBERG_UNWRAP_OR_FAIL(auto count_result, count_evaluator->GetResult()); + EXPECT_EQ(std::get(count_result.value()), 2); + + ICEBERG_UNWRAP_OR_FAIL(auto count_null_result, count_null_evaluator->GetResult()); + EXPECT_EQ(std::get(count_null_result.value()), 1); + + ICEBERG_UNWRAP_OR_FAIL(auto count_star_result, count_star_evaluator->GetResult()); + EXPECT_EQ(std::get(count_star_result.value()), 3); +} + +TEST(AggregateTest, MaxMinAggregates) { + Schema schema({SchemaField::MakeOptional(1, "value", int32())}); + + auto max_expr = Expressions::Max("value"); + auto min_expr = Expressions::Min("value"); + + auto max_bound = BindAggregate(schema, max_expr); + auto min_bound = BindAggregate(schema, min_expr); + + auto max_eval = AggregateEvaluator::Make(max_bound).value(); + auto min_eval = AggregateEvaluator::Make(min_bound).value(); + + std::vector rows{VectorStructLike({Scalar{int32_t{5}}}), + VectorStructLike({Scalar{std::monostate{}}}), + VectorStructLike({Scalar{int32_t{2}}}), + VectorStructLike({Scalar{int32_t{12}}})}; + + for (const auto& row : rows) { + ASSERT_TRUE(max_eval->Update(row).has_value()); + ASSERT_TRUE(min_eval->Update(row).has_value()); 
+ } + + ICEBERG_UNWRAP_OR_FAIL(auto max_result, max_eval->GetResult()); + EXPECT_EQ(std::get(max_result.value()), 12); + + ICEBERG_UNWRAP_OR_FAIL(auto min_result, min_eval->GetResult()); + EXPECT_EQ(std::get(min_result.value()), 2); +} + +TEST(AggregateTest, UnboundAggregateCreationAndBinding) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count = Expressions::Count("id"); + auto count_null = Expressions::CountNull("id"); + auto count_star = Expressions::CountStar(); + auto max = Expressions::Max("value"); + auto min = Expressions::Min("value"); + + EXPECT_EQ(count->ToString(), "count(ref(name=\"id\"))"); + EXPECT_EQ(count_null->ToString(), "count_if(ref(name=\"id\") is null)"); + EXPECT_EQ(count_star->ToString(), "count(*)"); + EXPECT_EQ(max->ToString(), "max(ref(name=\"value\"))"); + EXPECT_EQ(min->ToString(), "min(ref(name=\"value\"))"); + + // Bind succeeds for existing columns + EXPECT_TRUE(Binder::Bind(schema, count, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, count_null, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, count_star, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, max, /*case_sensitive=*/true).has_value()); + EXPECT_TRUE(Binder::Bind(schema, min, /*case_sensitive=*/true).has_value()); + + // Binding fails when the reference is missing + auto missing_count = Expressions::Count("missing"); + auto missing_bind = Binder::Bind(schema, missing_count, /*case_sensitive=*/true); + EXPECT_THAT(missing_bind, IsError(ErrorKind::kInvalidExpression)); + + // Creating a value aggregate with null term should fail + auto invalid_unbound = + UnboundAggregateImpl::Make(Expression::Operation::kMax, nullptr); + EXPECT_THAT(invalid_unbound, IsError(ErrorKind::kInvalidExpression)); +} + +TEST(AggregateTest, BoundAggregateEvaluateDirectly) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + 
SchemaField::MakeOptional(2, "value", int32())}); + + auto count_bound = BindAggregate(schema, Expressions::Count("id")); + auto count_null_bound = BindAggregate(schema, Expressions::CountNull("value")); + auto max_bound = BindAggregate(schema, Expressions::Max("value")); + + std::vector rows{ + VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), + VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}}), + VectorStructLike({Scalar{int32_t{3}}, Scalar{int32_t{2}}})}; + + int64_t count_sum = 0; + int64_t count_null_sum = 0; + int32_t max_val = std::numeric_limits::min(); + for (const auto& row : rows) { + ICEBERG_UNWRAP_OR_FAIL(auto c, count_bound->Evaluate(row)); + count_sum += std::get(c.value()); + + ICEBERG_UNWRAP_OR_FAIL(auto cn, count_null_bound->Evaluate(row)); + count_null_sum += std::get(cn.value()); + + ICEBERG_UNWRAP_OR_FAIL(auto mv, max_bound->Evaluate(row)); + if (!mv.IsNull()) { + max_val = std::max(max_val, std::get(mv.value())); + } + } + + EXPECT_EQ(count_sum, 3); + EXPECT_EQ(count_null_sum, 1); + EXPECT_EQ(max_val, 30); +} + +TEST(AggregateTest, MultipleAggregatesInEvaluator) { + Schema schema({SchemaField::MakeOptional(1, "id", int32()), + SchemaField::MakeOptional(2, "value", int32())}); + + auto count_expr = Expressions::Count("id"); + auto max_expr = Expressions::Max("value"); + auto min_expr = Expressions::Min("value"); + auto count_null_expr = Expressions::CountNull("value"); + auto count_star_expr = Expressions::CountStar(); + + auto count_bound = BindAggregate(schema, count_expr); + auto max_bound = BindAggregate(schema, max_expr); + auto min_bound = BindAggregate(schema, min_expr); + auto count_null_bound = BindAggregate(schema, count_null_expr); + auto count_star_bound = BindAggregate(schema, count_star_expr); + + std::vector> aggregates{ + count_bound, max_bound, min_bound, count_null_bound, count_star_bound}; + ICEBERG_UNWRAP_OR_FAIL(auto evaluator, 
AggregateEvaluator::Make(aggregates)); + + std::vector rows{ + VectorStructLike({Scalar{int32_t{1}}, Scalar{int32_t{10}}}), + VectorStructLike({Scalar{int32_t{2}}, Scalar{std::monostate{}}}), + VectorStructLike({Scalar{std::monostate{}}, Scalar{int32_t{30}}}), + VectorStructLike({Scalar{int32_t{3}}, Scalar{int32_t{2}}})}; + + for (const auto& row : rows) { + ASSERT_TRUE(evaluator->Update(row).has_value()); + } + + ICEBERG_UNWRAP_OR_FAIL(auto results, evaluator->GetResults()); + ASSERT_EQ(results.size(), 5); + EXPECT_EQ(std::get(results[0].value()), 3); // count + EXPECT_EQ(std::get(results[1].value()), 30); // max + EXPECT_EQ(std::get(results[2].value()), 2); // min + EXPECT_EQ(std::get(results[3].value()), 1); // count_null + EXPECT_EQ(std::get(results[4].value()), 4); // count_star +} + +} // namespace iceberg diff --git a/src/iceberg/test/meson.build b/src/iceberg/test/meson.build index 72b09a9ec..00cd649e5 100644 --- a/src/iceberg/test/meson.build +++ b/src/iceberg/test/meson.build @@ -56,6 +56,7 @@ iceberg_tests = { }, 'expression_test': { 'sources': files( + 'aggregate_test.cc', 'expression_test.cc', 'expression_visitor_test.cc', 'literal_test.cc',