Skip to content

Commit

Permalink
support cast/try_cast in prune with signed integer and decimal (#3422)
Browse files Browse the repository at this point in the history
* support cast/try_cast in prune

* add bound for supported data type in the cast/try_cast prune
  • Loading branch information
liukun4515 committed Sep 12, 2022
1 parent 3b8a20a commit 69d05aa
Show file tree
Hide file tree
Showing 2 changed files with 278 additions and 22 deletions.
287 changes: 265 additions & 22 deletions datafusion/core/src/physical_optimizer/pruning.rs
Expand Up @@ -43,9 +43,9 @@ use arrow::{
datatypes::{DataType, Field, Schema, SchemaRef},
record_batch::RecordBatch,
};
use datafusion_expr::binary_expr;
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter};
use datafusion_expr::utils::expr_to_columns;
use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable};
use datafusion_physical_expr::create_physical_expr;

/// Interface to pass statistics information to [`PruningPredicate`]
Expand Down Expand Up @@ -429,11 +429,13 @@ impl<'a> PruningExpressionBuilder<'a> {
}
};

let (column_expr, correct_operator, scalar_expr) =
match rewrite_expr_to_prunable(column_expr, correct_operator, scalar_expr) {
Ok(ret) => ret,
Err(e) => return Err(e),
};
let df_schema = DFSchema::try_from(schema.clone())?;
let (column_expr, correct_operator, scalar_expr) = rewrite_expr_to_prunable(
column_expr,
correct_operator,
scalar_expr,
df_schema,
)?;
let column = columns.iter().next().unwrap().clone();
let field = match schema.column_with_name(&column.flat_name()) {
Some((_, f)) => f,
Expand Down Expand Up @@ -481,12 +483,15 @@ impl<'a> PruningExpressionBuilder<'a> {
/// 2. `-col > 10` should be rewritten to `col < -10`
/// 3. `!col = true` would be rewritten to `col = !true`
/// 4. `abs(a - 10) > 0` not supported
/// 5. `cast(can_prunable_expr) > 10`
/// 6. `try_cast(can_prunable_expr) > 10`
///
/// More rewrite rules are still in progress.
fn rewrite_expr_to_prunable(
column_expr: &Expr,
op: Operator,
scalar_expr: &Expr,
schema: DFSchema,
) -> Result<(Expr, Operator, Expr)> {
if !is_compare_op(op) {
return Err(DataFusionError::Plan(
Expand All @@ -495,22 +500,29 @@ fn rewrite_expr_to_prunable(
}

match column_expr {
// `col > lit()`
// `col op lit()`
Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())),

// `cast(col) op lit()`
Expr::Cast { expr, data_type } => {
let from_type = expr.get_type(&schema)?;
verify_support_type_for_prune(&from_type, data_type)?;
let (left, op, right) =
rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?;
Ok((cast(left, data_type.clone()), op, right))
}
// `try_cast(col) op lit()`
Expr::TryCast { expr, data_type } => {
let from_type = expr.get_type(&schema)?;
verify_support_type_for_prune(&from_type, data_type)?;
let (left, op, right) =
rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?;
Ok((try_cast(left, data_type.clone()), op, right))
}
// `-col > lit()` --> `col < -lit()`
Expr::Negative(c) => match c.as_ref() {
Expr::Column(_) => Ok((
c.as_ref().clone(),
reverse_operator(op),
Expr::Negative(Box::new(scalar_expr.clone())),
)),
_ => Err(DataFusionError::Plan(format!(
"negative with complex expression {:?} is not supported",
column_expr
))),
},

Expr::Negative(c) => {
let (left, op, right) = rewrite_expr_to_prunable(c, op, scalar_expr, schema)?;
Ok((left, reverse_operator(op), Expr::Negative(Box::new(right))))
}
// `!col = true` --> `col = !true`
Expr::Not(c) => {
if op != Operator::Eq && op != Operator::NotEq {
Expand Down Expand Up @@ -551,6 +563,32 @@ fn is_compare_op(op: Operator) -> bool {
)
}

// The pruning logic is based on the comparing the min/max bounds.
// Must make sure the two type has order.
// For example, casts from string to numbers is not correct.
// Because the "13" is less than "3" with UTF8 comparison order.
fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Result<()> {
// TODO: support other data type for prunable cast or try cast
if matches!(
from_type,
DataType::Int8
| DataType::Int16
| DataType::Int32
| DataType::Int64
| DataType::Decimal128(_, _)
) && matches!(
to_type,
DataType::Int8 | DataType::Int32 | DataType::Int64 | DataType::Decimal128(_, _)
) {
Ok(())
} else {
Err(DataFusionError::Plan(format!(
"Try Cast/Cast with from type {} to type {} is not supported",
from_type, to_type
)))
}
}

/// replaces a column with an old name with a new name in an expression
fn rewrite_column_expr(
e: Expr,
Expand Down Expand Up @@ -804,10 +842,10 @@ mod tests {
datatypes::{DataType, TimeUnit},
};
use datafusion_common::ScalarValue;
use datafusion_expr::{cast, is_null};
use std::collections::HashMap;

#[derive(Debug)]

/// Mock statistic provider for tests
///
/// Each row represents the statistics for a "container" (which
Expand Down Expand Up @@ -1508,6 +1546,78 @@ mod tests {
Ok(())
}

#[test]
fn row_group_predicate_cast() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
let expected_expr =
"CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(#c1_max AS Int64)";

// test column on the left
let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1))));
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

// test column on the right
let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64));
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

let expected_expr = "TRY_CAST(#c1_max AS Int64) > Int64(1)";

// test column on the left
let expr =
try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1))));
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

// test column on the right
let expr =
lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64));
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

Ok(())
}

#[test]
fn row_group_predicate_cast_list() -> Result<()> {
let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]);
// test cast(c1 as int64) in int64(1, 2, 3)
let expr = Expr::InList {
expr: Box::new(cast(col("c1"), DataType::Int64)),
list: vec![
lit(ScalarValue::Int64(Some(1))),
lit(ScalarValue::Int64(Some(2))),
lit(ScalarValue::Int64(Some(3))),
],
negated: false,
};
let expected_expr = "CAST(#c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(2) AND Int64(2) <= CAST(#c1_max AS Int64) OR CAST(#c1_min AS Int64) <= Int64(3) AND Int64(3) <= CAST(#c1_max AS Int64)";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

let expr = Expr::InList {
expr: Box::new(cast(col("c1"), DataType::Int64)),
list: vec![
lit(ScalarValue::Int64(Some(1))),
lit(ScalarValue::Int64(Some(2))),
lit(ScalarValue::Int64(Some(3))),
],
negated: true,
};
let expected_expr = "CAST(#c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(#c1_max AS Int64) AND CAST(#c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(#c1_max AS Int64)";
let predicate_expr =
build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new())?;
assert_eq!(format!("{:?}", predicate_expr), expected_expr);

Ok(())
}

#[test]
fn prune_decimal_data() {
// decimal(9,2)
Expand All @@ -1527,6 +1637,36 @@ mod tests {
vec![Some(5), Some(6), Some(4), None], // max
),
);
let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, false, true];
assert_eq!(result, expected);

// with cast column to other type
let expr = cast(col("s1"), DataType::Decimal128(14, 3))
.gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3)));
let statistics = TestStatistics::new().with(
"s1",
ContainerStats::new_i32(
vec![Some(0), Some(4), None, Some(3)], // min
vec![Some(5), Some(6), Some(4), None], // max
),
);
let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, false, true];
assert_eq!(result, expected);

// with try cast column to other type
let expr = try_cast(col("s1"), DataType::Decimal128(14, 3))
.gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3)));
let statistics = TestStatistics::new().with(
"s1",
ContainerStats::new_i32(
vec![Some(0), Some(4), None, Some(3)], // min
vec![Some(5), Some(6), Some(4), None], // max
),
);
let p = PruningPredicate::try_new(expr, schema).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, false, true];
Expand Down Expand Up @@ -1576,6 +1716,7 @@ mod tests {
let expected = vec![false, true, false, true];
assert_eq!(result, expected);
}

#[test]
fn prune_api() {
let schema = Arc::new(Schema::new(vec![
Expand All @@ -1599,10 +1740,16 @@ mod tests {
// No stats for s2 ==> some rows could pass
// s2 [3, None] (null max) ==> some rows could pass

let p = PruningPredicate::try_new(expr, schema).unwrap();
let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, true, true];
assert_eq!(result, expected);

// filter with cast
let expr = cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5))));
let p = PruningPredicate::try_new(expr, schema).unwrap();
let result = p.prune(&statistics).unwrap();
let expected = vec![false, true, true, true];
assert_eq!(result, expected);
}

Expand Down Expand Up @@ -1852,4 +1999,100 @@ mod tests {
let result = p.prune(&statistics).unwrap();
assert_eq!(result, expected_ret);
}

#[test]
fn prune_cast_column_scalar() {
// The data type of column i is INT32
let (schema, statistics) = int32_setup();
let expected_ret = vec![true, true, false, true, true];

// i > int64(0)
let expr = col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32));
let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
assert_eq!(result, expected_ret);

// cast(i as int64) > int64(0)
let expr = cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0))));
let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
assert_eq!(result, expected_ret);

// try_cast(i as int64) > int64(0)
let expr =
try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0))));
let p = PruningPredicate::try_new(expr, schema.clone()).unwrap();
let result = p.prune(&statistics).unwrap();
assert_eq!(result, expected_ret);

// `-cast(i as int64) < 0` convert to `cast(i as int64) > -0`
let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64)))
.lt(lit(ScalarValue::Int64(Some(0))));
let p = PruningPredicate::try_new(expr, schema).unwrap();
let result = p.prune(&statistics).unwrap();
assert_eq!(result, expected_ret);
}

#[test]
fn test_rewrite_expr_to_prunable() {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
let df_schema = DFSchema::try_from(schema).unwrap();
// column op lit
let left_input = col("a");
let right_input = lit(ScalarValue::Int32(Some(12)));
let (result_left, _, result_right) = rewrite_expr_to_prunable(
&left_input,
Operator::Eq,
&right_input,
df_schema.clone(),
)
.unwrap();
assert_eq!(result_left, left_input);
assert_eq!(result_right, right_input);
// cast op lit
let left_input = cast(col("a"), DataType::Decimal128(20, 3));
let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3));
let (result_left, _, result_right) = rewrite_expr_to_prunable(
&left_input,
Operator::Gt,
&right_input,
df_schema.clone(),
)
.unwrap();
assert_eq!(result_left, left_input);
assert_eq!(result_right, right_input);
// try_cast op lit
let left_input = try_cast(col("a"), DataType::Int64);
let right_input = lit(ScalarValue::Int64(Some(12)));
let (result_left, _, result_right) =
rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema)
.unwrap();
assert_eq!(result_left, left_input);
assert_eq!(result_right, right_input);
// TODO: add test for other case and op
}

#[test]
fn test_rewrite_expr_to_prunable_error() {
// cast string value to numeric value
// this cast is not supported
let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
let df_schema = DFSchema::try_from(schema).unwrap();
let left_input = cast(col("a"), DataType::Int64);
let right_input = lit(ScalarValue::Int64(Some(12)));
let result = rewrite_expr_to_prunable(
&left_input,
Operator::Gt,
&right_input,
df_schema.clone(),
);
assert!(result.is_err());
// other expr
let left_input = is_null(col("a"));
let right_input = lit(ScalarValue::Int64(Some(12)));
let result =
rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema);
assert!(result.is_err());
// TODO: add other negative test for other case and op
}
}
13 changes: 13 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Expand Up @@ -259,6 +259,19 @@ pub fn cast(expr: Expr, data_type: DataType) -> Expr {
}
}

/// Create a try cast expression
pub fn try_cast(expr: Expr, data_type: DataType) -> Expr {
Expr::TryCast {
expr: Box::new(expr),
data_type,
}
}

/// Create is null expression
pub fn is_null(expr: Expr) -> Expr {
Expr::IsNull(Box::new(expr))
}

/// Create an convenience function representing a unary scalar function
macro_rules! unary_scalar_expr {
($ENUM:ident, $FUNC:ident, $DOC:expr) => {
Expand Down

0 comments on commit 69d05aa

Please sign in to comment.