diff --git a/cpp/src/arrow/compute/exec/expression.cc b/cpp/src/arrow/compute/exec/expression.cc index 4249179e1bf..b941ad4ca55 100644 --- a/cpp/src/arrow/compute/exec/expression.cc +++ b/cpp/src/arrow/compute/exec/expression.cc @@ -877,105 +877,11 @@ Result Canonicalize(Expression expr, compute::ExecContext* exec_cont [](Expression expr, ...) { return expr; }); } -namespace { - -Result DirectComparisonSimplification(Expression expr, - const Expression::Call& guarantee) { - return Modify( - std::move(expr), [](Expression expr) { return expr; }, - [&guarantee](Expression expr, ...) -> Result { - auto call = expr.call(); - if (!call) return expr; - - // Ensure both calls are comparisons with equal LHS and scalar RHS - auto cmp = Comparison::Get(expr); - auto cmp_guarantee = Comparison::Get(guarantee.function_name); - - if (!cmp) return expr; - if (!cmp_guarantee) return expr; - - const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); - const auto& guarantee_lhs = guarantee.arguments[0]; - if (lhs != guarantee_lhs) return expr; - - auto rhs = call->arguments[1].literal(); - auto guarantee_rhs = guarantee.arguments[1].literal(); - - if (!rhs) return expr; - if (!rhs->is_scalar()) return expr; - - if (!guarantee_rhs) return expr; - if (!guarantee_rhs->is_scalar()) return expr; - - ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, - Comparison::Execute(*rhs, *guarantee_rhs)); - DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA); - - if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { - // RHS of filter is equal to RHS of guarantee - - if ((*cmp & *cmp_guarantee) == *cmp_guarantee) { - // guarantee is a subset of filter, so all data will be included - // x > 1, x >= 1, x != 1 guaranteed by x > 1 - return literal(true); - } - - if ((*cmp & *cmp_guarantee) == 0) { - // guarantee disjoint with filter, so all data will be excluded - // x > 1, x >= 1, x != 1 unsatisfiable if x == 1 - return literal(false); - } - - return expr; - } - - if (*cmp_guarantee & 
cmp_rhs_guarantee_rhs) { - // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3 - return expr; - } - - if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) { - // x > 1, x >= 1, x != 1 guaranteed by x >= 3 - return literal(true); - } else { - // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 - return literal(false); - } - }); -} - -} // namespace - Result SimplifyWithGuarantee(Expression expr, const Expression& guaranteed_true_predicate) { - auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); - - KnownFieldValues known_values; - RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); - - ARROW_ASSIGN_OR_RAISE(expr, - ReplaceFieldsWithKnownValues(known_values, std::move(expr))); - - auto CanonicalizeAndFoldConstants = [&expr] { - ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); - ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr))); - return Status::OK(); - }; - RETURN_NOT_OK(CanonicalizeAndFoldConstants()); - - for (const auto& guarantee : conjunction_members) { - if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { - ARROW_ASSIGN_OR_RAISE( - auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); - - if (Identical(simplified, expr)) continue; - - expr = std::move(simplified); - RETURN_NOT_OK(CanonicalizeAndFoldConstants()); - } - } - - return expr; + ExecContext ctx; + return default_expression_simplification_registry()->RunAllPasses( + std::move(expr), guaranteed_true_predicate, &ctx); } // Serialization is accomplished by converting expressions to KeyValueMetadata and storing @@ -1191,5 +1097,209 @@ Expression or_(const std::vector& operands) { Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } +ExpressionSimplificationPassRegistry* default_expression_simplification_registry() { + class DefaultRegistry : public ExpressionSimplificationPassRegistry { + public: + DefaultRegistry() { + 
Add([](Expression expr, ExecContext* ctx) -> Result { + // if all arguments to a call are literal, we can evaluate this call *now* + auto call = CallNotNull(expr); + if (std::all_of(call->arguments.begin(), call->arguments.end(), + [](const Expression& argument) { return argument.literal(); })) { + static const ExecBatch ignored_input = ExecBatch{}; + ARROW_ASSIGN_OR_RAISE(Datum constant, + ExecuteScalarExpression(expr, ignored_input)); + + return literal(std::move(constant)); + } + return expr; + }); + + Add([](Expression expr, ExecContext* ctx) -> Result { + // kernels which always produce intersected validity can be resolved + // to null *now* if any of their inputs is a null literal + auto call = CallNotNull(expr); + if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { + for (const auto& argument : call->arguments) { + if (argument.IsNullLiteral()) { + return argument; + } + } + } + return expr; + }); + + Add([](Expression expr, ExecContext* ctx) -> Result { + auto call = CallNotNull(expr); + if (call->function_name == "and_kleene") { + // false and x == false + if (call->arguments[0] == literal(false)) return literal(false); + if (call->arguments[1] == literal(false)) return literal(false); + + // true and x == x + if (call->arguments[0] == literal(true)) return call->arguments[1]; + if (call->arguments[1] == literal(true)) return call->arguments[0]; + + // x and x == x + if (call->arguments[0] == call->arguments[1]) return call->arguments[0]; + } + return expr; + }); + + Add([](Expression expr, ExecContext* ctx) -> Result { + auto call = CallNotNull(expr); + if (call->function_name == "or_kleene") { + // true or x == true + if (call->arguments[0] == literal(true)) return literal(true); + if (call->arguments[1] == literal(true)) return literal(true); + + // false or x == x + if (call->arguments[0] == literal(false)) return call->arguments[1]; + if (call->arguments[1] == literal(false)) return call->arguments[0]; + + // x or x == x + if 
(call->arguments[0] == call->arguments[1]) return call->arguments[0]; + } + return expr; + }); + + Add([](Expression expr, const Expression& guarantee_expr, + ExecContext* ctx) -> Result { + // Ensure both calls are comparisons with equal LHS and scalar RHS + auto cmp = Comparison::Get(expr); + auto cmp_guarantee = Comparison::Get(guarantee_expr); + + if (!cmp) return expr; + if (!cmp_guarantee) return expr; + + const auto& args = CallNotNull(expr)->arguments; + const auto& guarantee_args = CallNotNull(guarantee_expr)->arguments; + + const auto& lhs = Comparison::StripOrderPreservingCasts(args[0]); + const auto& guarantee_lhs = guarantee_args[0]; + if (lhs != guarantee_lhs) return expr; + + auto rhs = args[1].literal(); + auto guarantee_rhs = guarantee_args[1].literal(); + + if (!rhs) return expr; + if (!rhs->is_scalar()) return expr; + + if (!guarantee_rhs) return expr; + if (!guarantee_rhs->is_scalar()) return expr; + + ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, + Comparison::Execute(*rhs, *guarantee_rhs)); + DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA); + + if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { + // RHS of filter is equal to RHS of guarantee + + if ((*cmp & *cmp_guarantee) == *cmp_guarantee) { + // guarantee is a subset of filter, so all data will be included + // x > 1, x >= 1, x != 1 guaranteed by x > 1 + return literal(true); + } + + if ((*cmp & *cmp_guarantee) == 0) { + // guarantee disjoint with filter, so all data will be excluded + // x > 1, x < 1, x != 1 unsatisfiable if x == 1 + return literal(false); + } + + return expr; + } + + if (*cmp_guarantee & cmp_rhs_guarantee_rhs) { + // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3 + return expr; + } + + if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) { + // x > 1, x >= 1, x != 1 guaranteed by x >= 3 + return literal(true); + } else { + // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 + return literal(false); + } + }); + } + + void Add(IndependentPass p) override { 
independent_passes_.push_back(std::move(p)); } + + void Add(GuaranteePass p) override { guarantee_passes_.push_back(std::move(p)); } + + Result RunIndependentPasses(Expression expr, ExecContext* ctx) override { + ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx)); + + ARROW_ASSIGN_OR_RAISE( + auto simplified, + Modify( + canonicalized, [](Expression expr) { return expr; }, + [&](Expression expr, ...) -> Result { + for (const auto& pass : independent_passes_) { + ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, ctx)); + if (Identical(simplified, expr)) continue; + + ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(simplified), ctx)); + if (!expr.call()) return expr; + } + return expr; + })); + + if (Identical(simplified, canonicalized)) return expr; + return Canonicalize(simplified, ctx); + } + + Result RunAllPasses(Expression expr, + const Expression& guaranteed_true_predicate, + ExecContext* ctx) override { + auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); + + KnownFieldValues known_values; + RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); + ARROW_ASSIGN_OR_RAISE(expr, + ReplaceFieldsWithKnownValues(known_values, std::move(expr))); + + // first run independent passes + ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx)); + ARROW_ASSIGN_OR_RAISE(auto simplified, RunIndependentPasses(canonicalized, ctx)); + + ARROW_ASSIGN_OR_RAISE( + simplified, + Modify( + std::move(simplified), [](Expression expr) { return expr; }, + [&](Expression expr, Expression* old) -> Result { + for (const auto& pass : guarantee_passes_) { + for (const auto& guarantee : conjunction_members) { + ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, guarantee, ctx)); + if (Identical(simplified, expr)) continue; + + // independent passes may have been invalidated by this guarantee pass + ARROW_ASSIGN_OR_RAISE( + expr, RunIndependentPasses(std::move(simplified), ctx)); + + if (!expr.call()) 
return expr; + } + } + + if (old == nullptr) return expr; + + return RunIndependentPasses(std::move(expr), ctx); + })); + + if (Identical(simplified, canonicalized)) return expr; + return Canonicalize(simplified, ctx); + } + + private: + std::vector independent_passes_; + std::vector guarantee_passes_; + }; + + static DefaultRegistry instance; + return &instance; +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/exec/expression.h b/cpp/src/arrow/compute/exec/expression.h index 7c567cc8fc6..056998fca89 100644 --- a/cpp/src/arrow/compute/exec/expression.h +++ b/cpp/src/arrow/compute/exec/expression.h @@ -184,6 +184,9 @@ Result ExtractKnownFieldValues( /// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS /// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or /// other semantically identical Expressions will not be recognized. +/// +/// For any simplification, if no changes could be made the identical expression will be +/// returned (`Identical(old, new)` will be true). /// Weak canonicalization which establishes guarantees for subsequent passes. Even /// equivalent Expressions may result in different canonicalized expressions. @@ -191,6 +194,41 @@ Result ExtractKnownFieldValues( ARROW_EXPORT Result Canonicalize(Expression, ExecContext* = NULLPTR); +/// An extensible registry for simplification passes over Expressions. +class ARROW_EXPORT ExpressionSimplificationPassRegistry { + public: + /// A pass which can operate on a bound Expression independently. + /// Independent passes need not recurse into Call::arguments; all independent + /// passes will be applied to each argument before any is applied to the call. + /// Expressions will be canonicalized before each pass is run. + using IndependentPass = std::function(Expression, ExecContext*)>; + + /// A pass which utilizes a guaranteed true predicate. 
+ /// Guarantee passes are allowed to invalidate independent passes; + /// all independent passes will be applied when any guarantee pass makes a change. + /// Guarantee passes need not decompose conjunctions; they will be run for + /// each member of a guarantee conjunction. + /// Guarantee passes need not recurse into Call::arguments; all guarantee + /// passes will be applied to each argument before any is applied to the call. + /// Expressions will be canonicalized before each pass is run. + using GuaranteePass = + std::function(Expression, const Expression&, ExecContext*)>; + + virtual ~ExpressionSimplificationPassRegistry() = default; + + virtual void Add(IndependentPass) = 0; + virtual void Add(GuaranteePass) = 0; + + virtual Result RunIndependentPasses(Expression, ExecContext*) = 0; + virtual Result RunAllPasses(Expression, + const Expression& guaranteed_true_predicate, + ExecContext*) = 0; +}; + +/// The default registry, which includes built-in simplification passes. +ARROW_EXPORT +ExpressionSimplificationPassRegistry* default_expression_simplification_registry(); + /// Simplify Expressions based on literal arguments (for example, add(null, x) will always /// be null so replace the call with a null literal). Includes early evaluation of all /// calls whose arguments are entirely literal.