Skip to content

Commit

Permalink
feat: run expression simplifier in a loop until a fixedpoint or 3 cycles (#10358)
Browse files Browse the repository at this point in the history

* feat: run expression simplifier in a loop

* change max_simplifier_iterations to u32

* use simplify_inner to explicitly test iteration count

* refactor simplify_inner loop

* const evaluator should return transformed=false on literals

* update tests

* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>

* run shorten_in_list_simplifier once at the end of the loop

* move UDF test case to core integration tests

* documentation and naming updates

* documentation and naming updates

* remove unused import and minor doc formatting change

* Update datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs

---------

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
  • Loading branch information
erratic-pattern and alamb committed May 7, 2024
1 parent 9fd697c commit f0e96c6
Show file tree
Hide file tree
Showing 2 changed files with 182 additions and 24 deletions.
31 changes: 31 additions & 0 deletions datafusion/core/tests/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,29 @@ fn test_simplify(input_expr: Expr, expected_expr: Expr) {
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
);
}
/// Simplifies `input_expr` with an `ExprSimplifier` built over the test
/// schema, then asserts that both the simplified expression and the number
/// of simplifier cycles reported match the expected values.
fn test_simplify_with_cycle_count(
    input_expr: Expr,
    expected_expr: Expr,
    expected_count: u32,
) {
    let simplifier = ExprSimplifier::new(MyInfo {
        schema: expr_test_schema(),
        execution_props: ExecutionProps::new(),
    });

    // Simplify a clone so the original input is still available for the
    // failure messages below.
    let outcome = simplifier.simplify_with_cycle_count(input_expr.clone());
    let (simplified_expr, count) = outcome.expect("successfully evaluated");

    assert_eq!(
        simplified_expr, expected_expr,
        "Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
    );
    assert_eq!(
        count, expected_count,
        "Mismatch simplifier cycle count\n Expected: {expected_count}\n Got:{count}"
    );
}

#[test]
fn test_simplify_log() {
Expand Down Expand Up @@ -658,3 +681,11 @@ fn test_simplify_concat() {
let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
test_simplify(expr, expected)
}
/// Asserts that the expression below simplifies to `true` and that the
/// simplifier reports a cycle count of 3 for it.
#[test]
fn test_simplify_cycles() {
    // cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX
    let now_as_int = cast(now(), DataType::Int64);
    let epoch_as_int = cast(to_timestamp(vec![lit(0)]), DataType::Int64);
    let expr = now_as_int.lt(epoch_as_int + lit(i64::MAX));

    test_simplify_with_cycle_count(expr, lit(true), 3);
}
175 changes: 151 additions & 24 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,12 @@ pub struct ExprSimplifier<S> {
/// Should expressions be canonicalized before simplification? Defaults to
/// true
canonicalize: bool,
/// Maximum number of simplifier cycles
max_simplifier_cycles: u32,
}

pub const THRESHOLD_INLINE_INLIST: usize = 3;
/// Default maximum number of simplifier cycles. `ExprSimplifier::new` uses
/// this value to initialize `max_simplifier_cycles`.
pub const DEFAULT_MAX_SIMPLIFIER_CYCLES: u32 = 3;

impl<S: SimplifyInfo> ExprSimplifier<S> {
/// Create a new `ExprSimplifier` with the given `info` such as an
Expand All @@ -107,10 +110,11 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
info,
guarantees: vec![],
canonicalize: true,
max_simplifier_cycles: DEFAULT_MAX_SIMPLIFIER_CYCLES,
}
}

/// Simplifies this [`Expr`]`s as much as possible, evaluating
/// Simplifies this [`Expr`] as much as possible, evaluating
/// constants and applying algebraic simplifications.
///
/// The types of the expression must match what operators expect,
Expand Down Expand Up @@ -171,7 +175,18 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// let expr = simplifier.simplify(expr).unwrap();
/// assert_eq!(expr, b_lt_2);
/// ```
pub fn simplify(&self, mut expr: Expr) -> Result<Expr> {
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
Ok(self.simplify_with_cycle_count(expr)?.0)
}

/// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
/// constants and applying algebraic simplifications. Additionally returns a `u32`
/// representing the number of simplification cycles performed, which can be useful for testing
/// optimizations.
///
/// See [Self::simplify] for details and usage examples.
///
pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
Expand All @@ -181,24 +196,26 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
expr = expr.rewrite(&mut Canonicalizer::new()).data()?
}

// TODO iterate until no changes are made during rewrite
// (evaluating constants can enable new simplifications and
// simplifications can enable new constant evaluation)
// https://github.com/apache/datafusion/issues/1160
expr.rewrite(&mut const_evaluator)
.data()?
.rewrite(&mut simplifier)
.data()?
.rewrite(&mut guarantee_rewriter)
.data()?
// run both passes twice to try an minimize simplifications that we missed
.rewrite(&mut const_evaluator)
.data()?
.rewrite(&mut simplifier)
.data()?
// shorten inlist should be started after other inlist rules are applied
.rewrite(&mut shorten_in_list_simplifier)
.data()
// Evaluating constants can enable new simplifications and
// simplifications can enable new constant evaluation
// see `Self::with_max_cycles`
let mut num_cycles = 0;
loop {
let Transformed {
data, transformed, ..
} = expr
.rewrite(&mut const_evaluator)?
.transform_data(|expr| expr.rewrite(&mut simplifier))?
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
expr = data;
num_cycles += 1;
if !transformed || num_cycles >= self.max_simplifier_cycles {
break;
}
}
// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
Ok((expr, num_cycles))
}

/// Apply type coercion to an [`Expr`] so that it can be
Expand Down Expand Up @@ -323,6 +340,63 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
self.canonicalize = canonicalize;
self
}

/// Specifies the maximum number of simplification cycles to run.
///
/// The simplifier can perform multiple passes of simplification. This is
/// because the output of one simplification step can allow more optimizations
/// in another simplification step. For example, constant evaluation can allow more
/// expression simplifications, and expression simplifications can allow more constant
/// evaluations.
///
/// This method specifies the maximum number of allowed iteration cycles before the simplifier
/// returns an [Expr] output. However, it does not always perform the maximum number of cycles.
/// The simplifier will attempt to detect when an [Expr] is unchanged by all the simplification
/// passes, and return early. This avoids wasting time on unnecessary [Expr] tree traversals.
///
/// If no maximum is specified, the value of [DEFAULT_MAX_SIMPLIFIER_CYCLES] is used
/// instead.
///
/// ```rust
/// use arrow::datatypes::{DataType, Field, Schema};
/// use datafusion_expr::{col, lit, Expr};
/// use datafusion_common::{Result, ScalarValue, ToDFSchema};
/// use datafusion_expr::execution_props::ExecutionProps;
/// use datafusion_expr::simplify::SimplifyContext;
/// use datafusion_optimizer::simplify_expressions::ExprSimplifier;
///
/// let schema = Schema::new(vec![
/// Field::new("a", DataType::Int64, false),
/// ])
/// .to_dfschema_ref().unwrap();
///
/// // Create the simplifier
/// let props = ExecutionProps::new();
/// let context = SimplifyContext::new(&props)
/// .with_schema(schema);
/// let simplifier = ExprSimplifier::new(context);
///
/// // Expression: a IS NOT NULL
/// let expr = col("a").is_not_null();
///
/// // When using default maximum cycles, 2 cycles will be performed.
/// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
/// assert_eq!(simplified_expr, lit(true));
/// // 2 cycles were executed, but only 1 was needed
/// assert_eq!(count, 2);
///
/// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
/// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
/// // Expression has again been simplified to the literal `true`
/// assert_eq!(simplified_expr, lit(true));
/// // Only 1 cycle was executed
/// assert_eq!(count, 1);
///
/// ```
pub fn with_max_cycles(mut self, max_simplifier_cycles: u32) -> Self {
self.max_simplifier_cycles = max_simplifier_cycles;
self
}
}

/// Canonicalize any BinaryExprs that are not in canonical form
Expand Down Expand Up @@ -404,6 +478,8 @@ struct ConstEvaluator<'a> {
enum ConstSimplifyResult {
// Expr was simplified and contains the new expression
Simplified(ScalarValue),
// Expr was not simplified (e.g. it was already a literal) and the original
// value is returned
NotSimplified(ScalarValue),
// Evaluation encountered an error, contains the original expression
SimplifyRuntimeError(DataFusionError, Expr),
}
Expand Down Expand Up @@ -450,6 +526,9 @@ impl<'a> TreeNodeRewriter for ConstEvaluator<'a> {
ConstSimplifyResult::Simplified(s) => {
Ok(Transformed::yes(Expr::Literal(s)))
}
ConstSimplifyResult::NotSimplified(s) => {
Ok(Transformed::no(Expr::Literal(s)))
}
ConstSimplifyResult::SimplifyRuntimeError(_, expr) => {
Ok(Transformed::yes(expr))
}
Expand Down Expand Up @@ -548,7 +627,7 @@ impl<'a> ConstEvaluator<'a> {
/// Internal helper to evaluates an Expr
pub(crate) fn evaluate_to_scalar(&mut self, expr: Expr) -> ConstSimplifyResult {
if let Expr::Literal(s) = expr {
return ConstSimplifyResult::Simplified(s);
return ConstSimplifyResult::NotSimplified(s);
}

let phys_expr =
Expand Down Expand Up @@ -1672,15 +1751,14 @@ fn inlist_except(mut l1: InList, l2: InList) -> Result<Expr> {

#[cfg(test)]
mod tests {
use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};
use std::{
collections::HashMap,
ops::{BitAnd, BitOr, BitXor},
sync::Arc,
};

use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};

use crate::simplify_expressions::SimplifyContext;
use crate::test::test_table_scan_with_name;

Expand Down Expand Up @@ -2868,6 +2946,19 @@ mod tests {
try_simplify(expr).unwrap()
}

/// Simplifies `expr` against the test schema using a freshly constructed
/// `ExprSimplifier`, returning the simplified expression together with the
/// number of simplifier cycles reported, or the simplifier's error.
fn try_simplify_with_cycle_count(expr: Expr) -> Result<(Expr, u32)> {
    let schema = expr_test_schema();
    let execution_props = ExecutionProps::new();

    let context = SimplifyContext::new(&execution_props).with_schema(schema);
    let simplifier = ExprSimplifier::new(context);

    simplifier.simplify_with_cycle_count(expr)
}

/// Panicking counterpart of `try_simplify_with_cycle_count`: unwraps the
/// result so tests can use the `(Expr, u32)` pair directly.
fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
    let outcome = try_simplify_with_cycle_count(expr);
    outcome.unwrap()
}

fn simplify_with_guarantee(
expr: Expr,
guarantees: Vec<(Expr, NullableInterval)>,
Expand Down Expand Up @@ -3575,4 +3666,40 @@ mod tests {

assert_eq!(simplify(expr), expected);
}

/// Table-driven check of the simplified output and the reported cycle count
/// for a handful of expressions.
#[test]
fn test_simplify_cycles() {
    // NOTE: this currently does not simplify
    // (((c4 - 10) + 10) *100) / 100
    let unsimplified = (((col("c4") - lit(10)) + lit(10)) * lit(100)) / lit(100);

    // Each case is (input, expected output, expected number of cycles).
    let cases = vec![
        // TRUE
        (lit(true), lit(true), 1),
        // (true != NULL) OR (5 > 10)
        (
            lit(true).not_eq(lit_bool_null()).or(lit(5).gt(lit(10))),
            lit_bool_null(),
            2,
        ),
        // Expected to come back unchanged after a single cycle.
        (unsimplified.clone(), unsimplified, 1),
        // ((c4<1 or c3<2) and c3_non_null<3) and false
        (
            col("c4")
                .lt(lit(1))
                .or(col("c3").lt(lit(2)))
                .and(col("c3_non_null").lt(lit(3)))
                .and(lit(false)),
            lit(false),
            2,
        ),
    ];

    for (input, expected, expected_cycles) in cases {
        let (simplified, num_iter) = simplify_with_cycle_count(input);
        assert_eq!(simplified, expected);
        assert_eq!(num_iter, expected_cycles);
    }
}
}

0 comments on commit f0e96c6

Please sign in to comment.