Skip to content

Commit

Permalink
change panic to result
Browse files Browse the repository at this point in the history
  • Loading branch information
liukun4515 committed Aug 23, 2022
1 parent eae5133 commit 8455e3c
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 44 deletions.
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::pre_cast_lit_in_binary_comparison::PreCastLitInBinaryComparisonExpressions;
use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_sql::{
parser::DFParser,
Expand Down Expand Up @@ -1361,7 +1361,7 @@ impl SessionState {
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
Arc::new(PreCastLitInBinaryComparisonExpressions::new()),
Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(DecorrelateScalarSubquery::new()),
Expand Down
22 changes: 19 additions & 3 deletions datafusion/core/tests/provider_filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use datafusion::physical_plan::{
};
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;
use datafusion_common::DataFusionError;
use std::ops::Deref;
use std::sync::Arc;

Expand Down Expand Up @@ -157,11 +158,26 @@ impl TableProvider for CustomProvider {
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v,
_ => unimplemented!(),
other_value => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support value {:?}",
other_value
)))
}
},
_ => unimplemented!(),
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
}
},
_ => unimplemented!(),
other_expr => {
return Err(DataFusionError::NotImplemented(format!(
"Do not support expr {:?}",
other_expr
)))
}
};

Ok(Arc::new(CustomPlan {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/optimizer/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
pub mod utils;

pub mod pre_cast_lit_in_binary_comparison;
pub mod pre_cast_lit_in_comparison;
pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
//! It avoids wrapping the non-literal expr in an `Expr::Cast` by casting the literal expr to the matching type instead.
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
use datafusion_common::{DFSchemaRef, Result, ScalarValue};
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operator};

Expand All @@ -44,15 +44,15 @@ use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, Operat
/// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType of c1 is INT32.
///
#[derive(Default)]
pub struct PreCastLitInBinaryComparisonExpressions {}
pub struct PreCastLitInComparisonExpressions {}

impl PreCastLitInBinaryComparisonExpressions {
impl PreCastLitInComparisonExpressions {
pub fn new() -> Self {
Self::default()
}
}

impl OptimizerRule for PreCastLitInBinaryComparisonExpressions {
impl OptimizerRule for PreCastLitInComparisonExpressions {
fn optimize(
&self,
plan: &LogicalPlan,
Expand All @@ -62,7 +62,7 @@ impl OptimizerRule for PreCastLitInBinaryComparisonExpressions {
}

fn name(&self) -> &str {
"pre_cast_lit_in_binary_comparison"
"pre_cast_lit_in_comparison"
}
}

Expand All @@ -78,24 +78,24 @@ fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
.expressions()
.into_iter()
.map(|expr| visit_expr(expr, schema))
.collect::<Vec<_>>();
.collect::<Result<Vec<_>>>()?;

from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
}

// Visit every type of expr; if the current expr has child exprs, the children are visited first.
fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
// traverse the expr by dfs
match &expr {
Expr::BinaryExpr { left, op, right } => {
// dfs visit the left and right expr
let left = visit_expr(*left.clone(), schema);
let right = visit_expr(*right.clone(), schema);
let left = visit_expr(*left.clone(), schema)?;
let right = visit_expr(*right.clone(), schema)?;
let left_type = left.get_type(schema);
let right_type = right.get_type(schema);
// can't get the data type, just return the expr
if left_type.is_err() || right_type.is_err() {
return expr.clone();
return Ok(expr.clone());
}
let left_type = left_type.unwrap();
let right_type = right_type.unwrap();
Expand All @@ -112,69 +112,79 @@ fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
if can_integer_literal_cast_to_type(
left_lit_value,
&right_type,
) =>
)? =>
{
// cast the left literal to the right type
return binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type),
return Ok(binary_expr(
cast_to_other_scalar_expr(left_lit_value, &right_type)?,
*op,
right,
);
));
}
(_, Expr::Literal(right_lit_value))
if can_integer_literal_cast_to_type(
right_lit_value,
&left_type,
) =>
)
.unwrap() =>
{
// cast the right literal to the left type
return binary_expr(
return Ok(binary_expr(
left,
*op,
cast_to_other_scalar_expr(right_lit_value, &left_type),
);
cast_to_other_scalar_expr(right_lit_value, &left_type)?,
));
}
(_, _) => {
// do nothing
}
};
}
// return the new binary op
binary_expr(left, *op, right)
Ok(binary_expr(left, *op, right))
}
// TODO: optimize in list
// Expr::InList { .. } => {}
// TODO: handle other expr type and dfs visit them
_ => expr,
_ => Ok(expr),
}
}

fn cast_to_other_scalar_expr(origin_value: &ScalarValue, target_type: &DataType) -> Expr {
fn cast_to_other_scalar_expr(
origin_value: &ScalarValue,
target_type: &DataType,
) -> Result<Expr> {
// null case
if origin_value.is_null() {
// if the origin value is null, just convert to another type of null value
// The target type must satisfy the `is_support_data_type` method, so we can unwrap safely
return lit(ScalarValue::try_from(target_type).unwrap());
return Ok(lit(ScalarValue::try_from(target_type).unwrap()));
}
// non-null case
let value: i64 = match origin_value {
ScalarValue::Int8(Some(v)) => *v as i64,
ScalarValue::Int16(Some(v)) => *v as i64,
ScalarValue::Int32(Some(v)) => *v as i64,
ScalarValue::Int64(Some(v)) => *v as i64,
other_type => {
panic!("Invalid type and value {:?}", other_type);
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid type and value {}",
other_value
)))
}
};
lit(match target_type {
Ok(lit(match target_type {
DataType::Int8 => ScalarValue::Int8(Some(value as i8)),
DataType::Int16 => ScalarValue::Int16(Some(value as i16)),
DataType::Int32 => ScalarValue::Int32(Some(value as i32)),
DataType::Int64 => ScalarValue::Int64(Some(value)),
other_type => {
panic!("Invalid target data type {:?}", other_type);
return Err(DataFusionError::Internal(format!(
"Invalid target data type {:?}",
other_type
)))
}
})
}))
}

fn is_comparison_op(op: &Operator) -> bool {
Expand All @@ -200,36 +210,42 @@ fn is_support_data_type(data_type: &DataType) -> bool {
fn can_integer_literal_cast_to_type(
integer_lit_value: &ScalarValue,
target_type: &DataType,
) -> bool {
) -> Result<bool> {
if integer_lit_value.is_null() {
// null value can be cast to any type of null value
return true;
return Ok(true);
}
let (target_min, target_max) = match target_type {
DataType::Int8 => (i8::MIN as i128, i8::MAX as i128),
DataType::Int16 => (i16::MIN as i128, i16::MAX as i128),
DataType::Int32 => (i32::MIN as i128, i32::MAX as i128),
DataType::Int64 => (i64::MIN as i128, i64::MAX as i128),
_ => panic!("Error target data type {:?}", target_type),
other_type => {
return Err(DataFusionError::Internal(format!(
"Error target data type {:?}",
other_type
)))
}
};
let lit_value = match integer_lit_value {
ScalarValue::Int8(Some(v)) => *v as i128,
ScalarValue::Int16(Some(v)) => *v as i128,
ScalarValue::Int32(Some(v)) => *v as i128,
ScalarValue::Int64(Some(v)) => *v as i128,
_ => {
panic!("Invalid literal value {:?}", integer_lit_value)
other_value => {
return Err(DataFusionError::Internal(format!(
"Invalid literal value {:?}",
other_value
)))
}
};
if lit_value >= target_min && lit_value <= target_max {
return true;
}
false

Ok(lit_value >= target_min && lit_value <= target_max)
}

#[cfg(test)]
mod tests {
use crate::pre_cast_lit_in_binary_comparison::visit_expr;
use crate::pre_cast_lit_in_comparison::visit_expr;
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, DFSchemaRef, ScalarValue};
use datafusion_expr::{col, lit, Expr};
Expand Down Expand Up @@ -277,7 +293,7 @@ mod tests {
}

fn optimize_test(expr: Expr, schema: &DFSchemaRef) -> Expr {
visit_expr(expr, schema)
visit_expr(expr, schema).unwrap()
}

fn expr_test_schema() -> DFSchemaRef {
Expand Down

0 comments on commit 8455e3c

Please sign in to comment.