Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 207 additions & 97 deletions cpp/src/arrow/compute/exec/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -877,105 +877,11 @@ Result<Expression> Canonicalize(Expression expr, compute::ExecContext* exec_cont
[](Expression expr, ...) { return expr; });
}

namespace {

Result<Expression> DirectComparisonSimplification(Expression expr,
const Expression::Call& guarantee) {
return Modify(
std::move(expr), [](Expression expr) { return expr; },
[&guarantee](Expression expr, ...) -> Result<Expression> {
auto call = expr.call();
if (!call) return expr;

// Ensure both calls are comparisons with equal LHS and scalar RHS
auto cmp = Comparison::Get(expr);
auto cmp_guarantee = Comparison::Get(guarantee.function_name);

if (!cmp) return expr;
if (!cmp_guarantee) return expr;

const auto& lhs = Comparison::StripOrderPreservingCasts(call->arguments[0]);
const auto& guarantee_lhs = guarantee.arguments[0];
if (lhs != guarantee_lhs) return expr;

auto rhs = call->arguments[1].literal();
auto guarantee_rhs = guarantee.arguments[1].literal();

if (!rhs) return expr;
if (!rhs->is_scalar()) return expr;

if (!guarantee_rhs) return expr;
if (!guarantee_rhs->is_scalar()) return expr;

ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
Comparison::Execute(*rhs, *guarantee_rhs));
DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);

if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
// RHS of filter is equal to RHS of guarantee

if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
// guarantee is a subset of filter, so all data will be included
// x > 1, x >= 1, x != 1 guaranteed by x > 1
return literal(true);
}

if ((*cmp & *cmp_guarantee) == 0) {
// guarantee disjoint with filter, so all data will be excluded
// x > 1, x >= 1, x != 1 unsatisfiable if x == 1
return literal(false);
}

return expr;
}

if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
// x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
return expr;
}

if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
// x > 1, x >= 1, x != 1 guaranteed by x >= 3
return literal(true);
} else {
// x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
return literal(false);
}
});
}

} // namespace

Result<Expression> SimplifyWithGuarantee(Expression expr,
const Expression& guaranteed_true_predicate) {
auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);

KnownFieldValues known_values;
RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));

ARROW_ASSIGN_OR_RAISE(expr,
ReplaceFieldsWithKnownValues(known_values, std::move(expr)));

auto CanonicalizeAndFoldConstants = [&expr] {
ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(expr)));
ARROW_ASSIGN_OR_RAISE(expr, FoldConstants(std::move(expr)));
return Status::OK();
};
RETURN_NOT_OK(CanonicalizeAndFoldConstants());

for (const auto& guarantee : conjunction_members) {
if (Comparison::Get(guarantee) && guarantee.call()->arguments[1].literal()) {
ARROW_ASSIGN_OR_RAISE(
auto simplified, DirectComparisonSimplification(expr, *CallNotNull(guarantee)));

if (Identical(simplified, expr)) continue;

expr = std::move(simplified);
RETURN_NOT_OK(CanonicalizeAndFoldConstants());
}
}

return expr;
ExecContext ctx;
return default_expression_simplification_registry()->RunAllPasses(
std::move(expr), guaranteed_true_predicate, &ctx);
}

// Serialization is accomplished by converting expressions to KeyValueMetadata and storing
Expand Down Expand Up @@ -1191,5 +1097,209 @@ Expression or_(const std::vector<Expression>& operands) {

Expression not_(Expression operand) { return call("invert", {std::move(operand)}); }

ExpressionSimplificationPassRegistry* default_expression_simplification_registry() {
class DefaultRegistry : public ExpressionSimplificationPassRegistry {
public:
DefaultRegistry() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like FoldConstants should be removed above?

Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
// if all arguments to a call are literal, we can evaluate this call *now*
auto call = CallNotNull(expr);
if (std::all_of(call->arguments.begin(), call->arguments.end(),
[](const Expression& argument) { return argument.literal(); })) {
static const ExecBatch ignored_input = ExecBatch{};
ARROW_ASSIGN_OR_RAISE(Datum constant,
ExecuteScalarExpression(expr, ignored_input));

return literal(std::move(constant));
}
return expr;
});

Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
// kernels which always produce intersected validity can be resolved
// to null *now* if any of their inputs is a null literal
auto call = CallNotNull(expr);
if (GetNullHandling(*call) == compute::NullHandling::INTERSECTION) {
for (const auto& argument : call->arguments) {
if (argument.IsNullLiteral()) {
return argument;
}
}
}
return expr;
});

Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
auto call = CallNotNull(expr);
if (call->function_name == "and_kleene") {
// false and x == false
if (call->arguments[0] == literal(false)) return literal(false);
if (call->arguments[1] == literal(false)) return literal(false);

// true and x == x
if (call->arguments[0] == literal(true)) return call->arguments[1];
if (call->arguments[1] == literal(true)) return call->arguments[0];

// x and x == x
if (call->arguments[0] == call->arguments[1]) return call->arguments[0];
}
return expr;
});

Add([](Expression expr, ExecContext* ctx) -> Result<Expression> {
auto call = CallNotNull(expr);
if (call->function_name == "or_kleene") {
// true or x == true
if (call->arguments[0] == literal(true)) return literal(true);
if (call->arguments[1] == literal(true)) return literal(true);

// false or x == x
if (call->arguments[0] == literal(false)) return call->arguments[1];
if (call->arguments[1] == literal(false)) return call->arguments[0];

// x or x == x
if (call->arguments[0] == call->arguments[1]) return call->arguments[0];
}
return expr;
});

Add([](Expression expr, const Expression& guarantee_expr,
ExecContext* ctx) -> Result<Expression> {
// Ensure both calls are comparisons with equal LHS and scalar RHS
auto cmp = Comparison::Get(expr);
auto cmp_guarantee = Comparison::Get(guarantee_expr);

if (!cmp) return expr;
if (!cmp_guarantee) return expr;

const auto& args = CallNotNull(expr)->arguments;
const auto& guarantee_args = CallNotNull(guarantee_expr)->arguments;

const auto& lhs = Comparison::StripOrderPreservingCasts(args[0]);
const auto& guarantee_lhs = guarantee_args[0];
if (lhs != guarantee_lhs) return expr;

auto rhs = args[1].literal();
auto guarantee_rhs = guarantee_args[1].literal();

if (!rhs) return expr;
if (!rhs->is_scalar()) return expr;

if (!guarantee_rhs) return expr;
if (!guarantee_rhs->is_scalar()) return expr;

ARROW_ASSIGN_OR_RAISE(auto cmp_rhs_guarantee_rhs,
Comparison::Execute(*rhs, *guarantee_rhs));
DCHECK_NE(cmp_rhs_guarantee_rhs, Comparison::NA);

if (cmp_rhs_guarantee_rhs == Comparison::EQUAL) {
// RHS of filter is equal to RHS of guarantee

if ((*cmp & *cmp_guarantee) == *cmp_guarantee) {
// guarantee is a subset of filter, so all data will be included
// x > 1, x >= 1, x != 1 guaranteed by x > 1
return literal(true);
}

if ((*cmp & *cmp_guarantee) == 0) {
// guarantee disjoint with filter, so all data will be excluded
// x > 1, x >= 1, x != 1 unsatisfiable if x == 1
return literal(false);
}

return expr;
}

if (*cmp_guarantee & cmp_rhs_guarantee_rhs) {
// x > 1, x >= 1, x != 1 cannot use guarantee x >= 3
return expr;
}

if (*cmp & Comparison::GetFlipped(cmp_rhs_guarantee_rhs)) {
// x > 1, x >= 1, x != 1 guaranteed by x >= 3
return literal(true);
} else {
// x < 1, x <= 1, x == 1 unsatisfiable if x >= 3
return literal(false);
}
});
}

void Add(IndependentPass p) override { independent_passes_.push_back(std::move(p)); }

void Add(GuaranteePass p) override { guarantee_passes_.push_back(std::move(p)); }

Result<Expression> RunIndependentPasses(Expression expr, ExecContext* ctx) override {
ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx));

ARROW_ASSIGN_OR_RAISE(
auto simplified,
Modify(
canonicalized, [](Expression expr) { return expr; },
[&](Expression expr, ...) -> Result<Expression> {
for (const auto& pass : independent_passes_) {
ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, ctx));
if (Identical(simplified, expr)) continue;

ARROW_ASSIGN_OR_RAISE(expr, Canonicalize(std::move(simplified), ctx));
if (!expr.call()) return expr;
}
return expr;
}));

if (Identical(simplified, canonicalized)) return expr;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this return canonicalized?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also seems we don't need to canonicalize below before calling this.

return Canonicalize(simplified, ctx);
}

Result<Expression> RunAllPasses(Expression expr,
const Expression& guaranteed_true_predicate,
ExecContext* ctx) override {
auto conjunction_members = GuaranteeConjunctionMembers(guaranteed_true_predicate);

KnownFieldValues known_values;
RETURN_NOT_OK(ExtractKnownFieldValuesImpl(&conjunction_members, &known_values.map));
ARROW_ASSIGN_OR_RAISE(expr,
ReplaceFieldsWithKnownValues(known_values, std::move(expr)));

// first run independent passes
ARROW_ASSIGN_OR_RAISE(auto canonicalized, Canonicalize(expr, ctx));
ARROW_ASSIGN_OR_RAISE(auto simplified, RunIndependentPasses(canonicalized, ctx));

ARROW_ASSIGN_OR_RAISE(
simplified,
Modify(
std::move(simplified), [](Expression expr) { return expr; },
[&](Expression expr, Expression* old) -> Result<Expression> {
for (const auto& pass : guarantee_passes_) {
for (const auto& guarantee : conjunction_members) {
ARROW_ASSIGN_OR_RAISE(auto simplified, pass(expr, guarantee, ctx));
if (Identical(simplified, expr)) continue;

// independent passes may have been invalidated by this guarantee pass
ARROW_ASSIGN_OR_RAISE(
expr, RunIndependentPasses(std::move(simplified), ctx));

if (!expr.call()) return expr;
}
}

if (old == nullptr) return expr;

return RunIndependentPasses(std::move(expr), ctx);
}));

if (Identical(simplified, canonicalized)) return expr;
return Canonicalize(simplified, ctx);
}

private:
std::vector<IndependentPass> independent_passes_;
std::vector<GuaranteePass> guarantee_passes_;
};

static DefaultRegistry instance;
return &instance;
}

} // namespace compute
} // namespace arrow
38 changes: 38 additions & 0 deletions cpp/src/arrow/compute/exec/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,51 @@ Result<KnownFieldValues> ExtractKnownFieldValues(
/// guarantee on a field value, an Expression must be a call to "equal" with field_ref LHS
/// and literal RHS. Flipping the arguments, "is_in" with a one-long value_set, ... or
/// other semantically identical Expressions will not be recognized.
///
/// For any simplification, if no changes could be made the identical expression will be
/// returned (`IsIdentical(old, new)` will be true).

/// Weak canonicalization which establishes guarantees for subsequent passes. Even
/// equivalent Expressions may result in different canonicalized expressions.
/// TODO this could be a strong canonicalization
ARROW_EXPORT
Result<Expression> Canonicalize(Expression, ExecContext* = NULLPTR);

/// An extensible registry for simplification passes over Expressions.
class ARROW_EXPORT ExpressionSimplificationPassRegistry {
public:
/// A pass which can operate on a bound Expression independently.
/// Independent passes need not recurse into Call::arguments; all independent
/// passes will be applied to each argument before any is applied to the call.
/// Expressions will be canonicalized before each pass is run.
using IndependentPass = std::function<Result<Expression>(Expression, ExecContext*)>;

/// A pass which utilizes a guaranteed true predicate.
/// Guarantee passes are allowed to invalidate independent passes;
/// all independent passes will be applied when any guarantee pass makes a change.
/// Guarantee passes need not decompose conjunctions; they will be run for
/// each member of a guarantee conjunction.
/// Guarantee passes need not recurse into Call::arguments; all guarantee
/// passes will be applied to each argument before any is applied to the call.
/// Expressions will be canonicalized before each pass is run.
using GuaranteePass =
std::function<Result<Expression>(Expression, const Expression&, ExecContext*)>;

virtual ~ExpressionSimplificationPassRegistry() = default;

virtual void Add(IndependentPass) = 0;
virtual void Add(GuaranteePass) = 0;

virtual Result<Expression> RunIndependentPasses(Expression, ExecContext*) = 0;
virtual Result<Expression> RunAllPasses(Expression,
const Expression& guaranteed_true_predicate,
ExecContext*) = 0;
};

/// The default registry, which includes built-in simplification passes.
ARROW_EXPORT
ExpressionSimplificationPassRegistry* default_expression_simplification_registry();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: DefaultExpressionSimplificationRegistry?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though I see the exec node registry and function registry use different naming schemes (default_exec_factory_registry, GetFunctionRegistry)


/// Simplify Expressions based on literal arguments (for example, add(null, x) will always
/// be null so replace the call with a null literal). Includes early evaluation of all
/// calls whose arguments are entirely literal.
Expand Down