-
Notifications
You must be signed in to change notification settings - Fork 3.9k
ARROW-14725: [C++][Compute] Extract Expression simplification pass registry #11716
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -877,105 +877,11 @@ Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_cont | |
| [](Expression expr, ...) { return expr; }); | ||
| } | ||
|
|
||
| namespace { | ||
|
|
||
| Result<Expression> DirectComparisonSimplification(Expression expr, | ||
| const Expression::Call& guarantee) { | ||
| return Modify( | ||
| std::move(expr), [](Expression expr) { return expr; }, | ||
| [&guarantee](Expression expr, ...) -> Result<Expression> { | ||
| auto call = expr.call(); | ||
| if (!call) return expr; | ||
|
|
||
| // Ensure both calls are comparisons with equal LHS and scalar RHS | ||
| auto cmp = Comparison::Get(expr); | ||
| auto cmp_guarantee = Comparison::Get(guarantee.function_name); | ||
|
|
||
| if (!cmp) return expr; | ||
| if (!cmp_guarantee) return expr; | ||
|
|
||
| const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]); | ||
| const auto& guarantee_lhs = guarantee.arguments[0]; | ||
| if (lhs != guarantee_lhs) return expr; | ||
|
|
||
| auto rhs = call->arguments[1].literal(); | ||
| auto guarantee_rhs = guarantee.arguments[1].literal(); | ||
|
|
||
| if (!rhs) return expr; | ||
| if (!rhs->is_scalar()) return expr; | ||
|
|
||
| if (!guarantee_rhs) return expr; | ||
| if (!guarantee_rhs->is_scalar()) return expr; | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, | ||
| Comparison::Execute(*rhs, *guarantee_rhs)); | ||
| DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA); | ||
|
|
||
| if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { | ||
| // RHS of filter is equal to RHS of guarantee | ||
|
|
||
| if ((*cmp & *cmp_guarantee) == *cmp_guarantee) { | ||
| // guarantee is a subset of filter, so all data will be included | ||
| // x > 1, x >= 1, x != 1 guaranteed by x > 1 | ||
| return literal(true); | ||
| } | ||
|
|
||
| if ((*cmp & *cmp_guarantee) == 0) { | ||
| // guarantee disjoint with filter, so all data will be excluded | ||
| // x > 1, x >= 1, x != 1 unsatisfiable if x == 1 | ||
| return literal(false); | ||
| } | ||
|
|
||
| return expr; | ||
| } | ||
|
|
||
| if (*cmp_guarantee & cmp_rhs_guarantee_rhs) { | ||
| // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3 | ||
| return expr; | ||
| } | ||
|
|
||
| if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) { | ||
| // x > 1, x >= 1, x != 1 guaranteed by x >= 3 | ||
| return literal(true); | ||
| } else { | ||
| // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 | ||
| return literal(false); | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| Result<Expression> SimplifyWithGuarantee(Expression expr, | ||
| const Expression& guaranteed_true_predicate) { | ||
| auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); | ||
|
|
||
| KnownFieldValues known_values; | ||
| RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE(expr, | ||
| ReplaceFieldsWithKnownValues(known_values, std::move(expr))); | ||
|
|
||
| auto CanonicalizeAndFoldConstants = [&expr] { | ||
| ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr))); | ||
| ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr))); | ||
| return Status::OK(); | ||
| }; | ||
| RETURN_NOT_OK(CanonicalizeAndFoldConstants()); | ||
|
|
||
| for (const auto& guarantee : conjunction_members) { | ||
| if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) { | ||
| ARROW_ASSIGN_OR_RAISE( | ||
| auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee))); | ||
|
|
||
| if (Identical(simplified, expr)) continue; | ||
|
|
||
| expr = std::move(simplified); | ||
| RETURN_NOT_OK(CanonicalizeAndFoldConstants()); | ||
| } | ||
| } | ||
|
|
||
| return expr; | ||
| ExecContext ctx; | ||
| return default_expression_simplification_registry()->RunAllPasses( | ||
| std::move(expr), guaranteed_true_predicate, &ctx); | ||
| } | ||
|
|
||
| // Serialization is accomplished by converting expressions to KeyValueMetadata and storing | ||
|
|
@@ -1191,5 +1097,209 @@ Expression or_(const std::vector<Expression>& operands) { | |
|
|
||
| Expression not_(Expression operand) { return call("invert", {std::move(operand)}); } | ||
|
|
||
| ExpressionSimplificationPassRegistry* default_expression_simplification_registry() { | ||
| class DefaultRegistry : public ExpressionSimplificationPassRegistry { | ||
| public: | ||
| DefaultRegistry() { | ||
| Add([](Expression expr, ExecContext* ctx) -> Result<Expression> { | ||
| // if all arguments to a call are literal, we can evaluate this call *now* | ||
| auto call = CallNotNull(expr); | ||
| if (std::all_of(call->arguments.begin(), call->arguments.end(), | ||
| [](const Expression& argument) { return argument.literal(); })) { | ||
| static const ExecBatch ignored_input = ExecBatch{}; | ||
| ARROW_ASSIGN_OR_RAISE(Datum constant, | ||
| ExecuteScalarExpression(expr, ignored_input)); | ||
|
|
||
| return literal(std::move(constant)); | ||
| } | ||
| return expr; | ||
| }); | ||
|
|
||
| Add([](Expression expr, ExecContext* ctx) -> Result<Expression> { | ||
| // kernels which always produce intersected validity can be resolved | ||
| // to null *now* if any of their inputs is a null literal | ||
| auto call = CallNotNull(expr); | ||
| if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) { | ||
| for (const auto& argument : call->arguments) { | ||
| if (argument.IsNullLiteral()) { | ||
| return argument; | ||
| } | ||
| } | ||
| } | ||
| return expr; | ||
| }); | ||
|
|
||
| Add([](Expression expr, ExecContext* ctx) -> Result<Expression> { | ||
| auto call = CallNotNull(expr); | ||
| if (call->function_name == "and_kleene") { | ||
| // false and x == false | ||
| if (call->arguments[0] == literal(false)) return literal(false); | ||
| if (call->arguments[1] == literal(false)) return literal(false); | ||
|
|
||
| // true and x == x | ||
| if (call->arguments[0] == literal(true)) return call->arguments[1]; | ||
| if (call->arguments[1] == literal(true)) return call->arguments[0]; | ||
|
|
||
| // x and x == x | ||
| if (call->arguments[0] == call->arguments[1]) return call->arguments[0]; | ||
| } | ||
| return expr; | ||
| }); | ||
|
|
||
| Add([](Expression expr, ExecContext* ctx) -> Result<Expression> { | ||
| auto call = CallNotNull(expr); | ||
| if (call->function_name == "or_kleene") { | ||
| // true or x == true | ||
| if (call->arguments[0] == literal(true)) return literal(true); | ||
| if (call->arguments[1] == literal(true)) return literal(true); | ||
|
|
||
| // false or x == x | ||
| if (call->arguments[0] == literal(false)) return call->arguments[1]; | ||
| if (call->arguments[1] == literal(false)) return call->arguments[0]; | ||
|
|
||
| // x or x == x | ||
| if (call->arguments[0] == call->arguments[1]) return call->arguments[0]; | ||
| } | ||
| return expr; | ||
| }); | ||
|
|
||
| Add([](Expression expr, const Expression& guarantee_expr, | ||
| ExecContext* ctx) -> Result<Expression> { | ||
| // Ensure both calls are comparisons with equal LHS and scalar RHS | ||
| auto cmp = Comparison::Get(expr); | ||
| auto cmp_guarantee = Comparison::Get(guarantee_expr); | ||
|
|
||
| if (!cmp) return expr; | ||
| if (!cmp_guarantee) return expr; | ||
|
|
||
| const auto& args = CallNotNull(expr)->arguments; | ||
| const auto& guarantee_args = CallNotNull(guarantee_expr)->arguments; | ||
|
|
||
| const auto& lhs = Comparison::StripOrderPreservingCasts(args[0]); | ||
| const auto& guarantee_lhs = guarantee_args[0]; | ||
| if (lhs != guarantee_lhs) return expr; | ||
|
|
||
| auto rhs = args[1].literal(); | ||
| auto guarantee_rhs = guarantee_args[1].literal(); | ||
|
|
||
| if (!rhs) return expr; | ||
| if (!rhs->is_scalar()) return expr; | ||
|
|
||
| if (!guarantee_rhs) return expr; | ||
| if (!guarantee_rhs->is_scalar()) return expr; | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs, | ||
| Comparison::Execute(*rhs, *guarantee_rhs)); | ||
| DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA); | ||
|
|
||
| if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) { | ||
| // RHS of filter is equal to RHS of guarantee | ||
|
|
||
| if ((*cmp & *cmp_guarantee) == *cmp_guarantee) { | ||
| // guarantee is a subset of filter, so all data will be included | ||
| // x > 1, x >= 1, x != 1 guaranteed by x > 1 | ||
| return literal(true); | ||
| } | ||
|
|
||
| if ((*cmp & *cmp_guarantee) == 0) { | ||
| // guarantee disjoint with filter, so all data will be excluded | ||
| // x > 1, x >= 1, x != 1 unsatisfiable if x == 1 | ||
| return literal(false); | ||
| } | ||
|
|
||
| return expr; | ||
| } | ||
|
|
||
| if (*cmp_guarantee & cmp_rhs_guarantee_rhs) { | ||
| // x > 1, x >= 1, x != 1 cannot use guarantee x >= 3 | ||
| return expr; | ||
| } | ||
|
|
||
| if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) { | ||
| // x > 1, x >= 1, x != 1 guaranteed by x >= 3 | ||
| return literal(true); | ||
| } else { | ||
| // x < 1, x <= 1, x == 1 unsatisfiable if x >= 3 | ||
| return literal(false); | ||
| } | ||
| }); | ||
| } | ||
|
|
||
| void Add(IndependentPass p) override { independent_passes_.push_back(std::move(p)); } | ||
|
|
||
| void Add(GuaranteePass p) override { guarantee_passes_.push_back(std::move(p)); } | ||
|
|
||
| Result<Expression> RunIndependentPasses(Expression expr, ExecContext* ctx) override { | ||
| ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx)); | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE( | ||
| auto simplified, | ||
| Modify( | ||
| canonicalized, [](Expression expr) { return expr; }, | ||
| [&](Expression expr, ...) -> Result<Expression> { | ||
| for (const auto& pass : independent_passes_) { | ||
| ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, ctx)); | ||
| if (Identical(simplified, expr)) continue; | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(simplified), ctx)); | ||
| if (!expr.call()) return expr; | ||
| } | ||
| return expr; | ||
| })); | ||
|
|
||
| if (Identical(simplified, canonicalized)) return expr; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should this return `canonicalized` (rather than `expr`)?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It also seems we don't need to canonicalize below before calling this. |
||
| return Canonicalize(simplified, ctx); | ||
| } | ||
|
|
||
| Result<Expression> RunAllPasses(Expression expr, | ||
| const Expression& guaranteed_true_predicate, | ||
| ExecContext* ctx) override { | ||
| auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate); | ||
|
|
||
| KnownFieldValues known_values; | ||
| RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map)); | ||
| ARROW_ASSIGN_OR_RAISE(expr, | ||
| ReplaceFieldsWithKnownValues(known_values, std::move(expr))); | ||
|
|
||
| // first run independent passes | ||
| ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx)); | ||
| ARROW_ASSIGN_OR_RAISE(auto simplified, RunIndependentPasses(canonicalized, ctx)); | ||
|
|
||
| ARROW_ASSIGN_OR_RAISE( | ||
| simplified, | ||
| Modify( | ||
| std::move(simplified), [](Expression expr) { return expr; }, | ||
| [&](Expression expr, Expression* old) -> Result<Expression> { | ||
| for (const auto& pass : guarantee_passes_) { | ||
| for (const auto& guarantee : conjunction_members) { | ||
| ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, guarantee, ctx)); | ||
| if (Identical(simplified, expr)) continue; | ||
|
|
||
| // independent passes may have been invalidated by this guarantee pass | ||
| ARROW_ASSIGN_OR_RAISE( | ||
| expr, RunIndependentPasses(std::move(simplified), ctx)); | ||
|
|
||
| if (!expr.call()) return expr; | ||
| } | ||
| } | ||
|
|
||
| if (old == nullptr) return expr; | ||
|
|
||
| return RunIndependentPasses(std::move(expr), ctx); | ||
| })); | ||
|
|
||
| if (Identical(simplified, canonicalized)) return expr; | ||
| return Canonicalize(simplified, ctx); | ||
| } | ||
|
|
||
| private: | ||
| std::vector<IndependentPass> independent_passes_; | ||
| std::vector<GuaranteePass> guarantee_passes_; | ||
| }; | ||
|
|
||
| static DefaultRegistry instance; | ||
| return &instance; | ||
| } | ||
|
|
||
| } // namespace compute | ||
| } // namespace arrow | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -184,13 +184,51 @@ Result<KnownFieldValues> ExtractKnownFieldValues( | |
| /// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS | ||
| /// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or | ||
| /// other semantically identical Expressions will not be recognized. | ||
| /// | ||
| /// For any simplification, if no changes could be made the identical expression will be | ||
| /// returned (`Identical(old, new)` will be true). | ||
|
|
||
| /// Weak canonicalization which establishes guarantees for subsequent passes. Even | ||
| /// equivalent Expressions may result in different canonicalized expressions. | ||
| /// TODO this could be a strong canonicalization | ||
| ARROW_EXPORT | ||
| Result<Expression> Canonicalize(Expression, ExecContext* = NULLPTR); | ||
|
|
||
| /// An extensible registry for simplification passes over Expressions. | ||
| class ARROW_EXPORT ExpressionSimplificationPassRegistry { | ||
| public: | ||
| /// A pass which can operate on a bound Expression independently. | ||
| /// Independent passes need not recurse into Call::arguments; all independent | ||
| /// passes will be applied to each argument before any is applied to the call. | ||
| /// Expressions will be canonicalized before each pass is run. | ||
| using IndependentPass = std::function<Result<Expression>(Expression, ExecContext*)>; | ||
|
|
||
| /// A pass which utilizes a guaranteed true predicate. | ||
| /// Guarantee passes are allowed to invalidate independent passes; | ||
| /// all independent passes will be applied when any guarantee pass makes a change. | ||
| /// Guarantee passes need not decompose conjunctions; they will be run for | ||
| /// each member of a guarantee conjunction. | ||
| /// Guarantee passes need not recurse into Call::arguments; all guarantee | ||
| /// passes will be applied to each argument before any is applied to the call. | ||
| /// Expressions will be canonicalized before each pass is run. | ||
| using GuaranteePass = | ||
| std::function<Result<Expression>(Expression, const Expression&, ExecContext*)>; | ||
|
|
||
| virtual ~ExpressionSimplificationPassRegistry() = default; | ||
|
|
||
| virtual void Add(IndependentPass) = 0; | ||
| virtual void Add(GuaranteePass) = 0; | ||
|
|
||
| virtual Result<Expression> RunIndependentPasses(Expression, ExecContext*) = 0; | ||
| virtual Result<Expression> RunAllPasses(Expression, | ||
| const Expression& guaranteed_true_predicate, | ||
| ExecContext*) = 0; | ||
| }; | ||
|
|
||
| /// The default registry, which includes built-in simplification passes. | ||
| ARROW_EXPORT | ||
| ExpressionSimplificationPassRegistry* default_expression_simplification_registry(); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit: `DefaultExpressionSimplificationRegistry`?
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Though I see the exec node registry and function registry use different naming schemes (`default_exec_factory_registry`, `GetFunctionRegistry`) |
||
|
|
||
| /// Simplify Expressions based on literal arguments (for example, add(null, x) will always | ||
| /// be null so replace the call with a null literal). Includes early evaluation of all | ||
| /// calls whose arguments are entirely literal. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like FoldConstants should be removed above?