diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 7a752e5c003c..451205e4cb39 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -18,10 +18,11 @@ use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::config::ConfigOptions; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{DataFusionError, Result}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; use datafusion_expr::{ AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, }; +use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule}; use datafusion_optimizer::optimizer::Optimizer; use datafusion_optimizer::{utils, OptimizerConfig, OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; @@ -46,10 +47,18 @@ pub fn main() -> Result<()> { logical_plan.display_indent() ); - // now run the optimizer with our custom rule - let optimizer = Optimizer::with_rules(vec![Arc::new(MyRule {})]); + // run the analyzer with our custom rule let config = OptimizerContext::default().with_skip_failing_rules(false); - let optimized_plan = optimizer.optimize(&logical_plan, &config, observe)?; + let analyzer = Analyzer::with_rules(vec![Arc::new(MyAnalyzerRule {})]); + let analyzed_plan = analyzer.execute_and_check(&logical_plan, config.options())?; + println!( + "Analyzed Logical Plan:\n\n{}\n", + analyzed_plan.display_indent() + ); + + // then run the optimizer with our custom rule + let optimizer = Optimizer::with_rules(vec![Arc::new(MyOptimizerRule {})]); + let optimized_plan = optimizer.optimize(&analyzed_plan, &config, observe)?; println!( "Optimized Logical Plan:\n\n{}\n", optimized_plan.display_indent() @@ -66,11 +75,57 @@ fn observe(plan: &LogicalPlan, rule: &dyn OptimizerRule) { ) } -struct MyRule {} +/// An example analyzer rule that changes Int64 literals to UInt64 +struct MyAnalyzerRule {} + +impl AnalyzerRule for MyAnalyzerRule { + fn analyze(&self, plan: LogicalPlan, _config: &ConfigOptions) -> Result { + Self::analyze_plan(plan) + } + + fn name(&self) -> &str { + "my_analyzer_rule" + } +} + +impl MyAnalyzerRule { + fn analyze_plan(plan: LogicalPlan) -> Result { + plan.transform(&|plan| { + Ok(match plan { + LogicalPlan::Filter(filter) => { + let predicate = Self::analyze_expr(filter.predicate.clone())?; + Transformed::Yes(LogicalPlan::Filter(Filter::try_new( + predicate, + filter.input, + )?)) + } + _ => Transformed::No(plan), + }) + }) + } + + fn analyze_expr(expr: Expr) -> Result { + expr.transform(&|expr| { + // closure is invoked for all sub expressions + Ok(match expr { + Expr::Literal(ScalarValue::Int64(i)) => { + // transform to UInt64 + Transformed::Yes(Expr::Literal(ScalarValue::UInt64( + i.map(|i| i as u64), + ))) + } + _ => Transformed::No(expr), + }) + }) + } +} + +/// An example optimizer rule that rewrite BETWEEN expression to binary compare expressions +struct MyOptimizerRule {} -impl OptimizerRule for MyRule { +impl OptimizerRule for MyOptimizerRule { fn name(&self) -> &str { - "my_rule" + "my_optimizer_rule" } fn try_optimize( diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index c3adb4cc74dd..2aa0ca95db71 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -100,7 +100,10 @@ use crate::physical_optimizer::global_sort_selection::GlobalSortSelection; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::pipeline_fixer::PipelineFixer; use crate::physical_optimizer::sort_enforcement::EnforceSorting; -use datafusion_optimizer::OptimizerConfig; +use datafusion_optimizer::{ + analyzer::{Analyzer, AnalyzerRule}, + OptimizerConfig, +}; use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; @@ -1198,6 +1201,8 @@ impl QueryPlanner for DefaultQueryPlanner { pub struct SessionState { /// UUID for the session session_id: String, + /// Responsible for analyzing and rewrite a logical plan before optimization + analyzer: Analyzer, /// Responsible for optimizing a logical plan optimizer: Optimizer, /// Responsible for optimizing a physical execution plan @@ -1336,6 +1341,7 @@ impl SessionState { SessionState { session_id, + analyzer: Analyzer::new(), optimizer: Optimizer::new(), physical_optimizers, query_planner: Arc::new(DefaultQueryPlanner {}), @@ -1448,6 +1454,15 @@ impl SessionState { self } + /// Replace the analyzer rules + pub fn with_analyzer_rules( + mut self, + rules: Vec>, + ) -> Self { + self.analyzer = Analyzer::with_rules(rules); + self + } + /// Replace the optimizer rules pub fn with_optimizer_rules( mut self, @@ -1466,6 +1481,15 @@ impl SessionState { self } + /// Adds a new [`AnalyzerRule`] + pub fn add_analyzer_rule( + mut self, + analyzer_rule: Arc, + ) -> Self { + self.analyzer.rules.push(analyzer_rule); + self + } + /// Adds a new [`OptimizerRule`] pub fn add_optimizer_rule( mut self, @@ -1639,9 +1663,12 @@ impl SessionState { if let LogicalPlan::Explain(e) = plan { let mut stringified_plans = e.stringified_plans.clone(); + let analyzed_plan = self + .analyzer + .execute_and_check(e.plan.as_ref(), self.options())?; // optimize the child plan, capturing the output of each optimizer let (plan, logical_optimization_succeeded) = match self.optimizer.optimize( - e.plan.as_ref(), + &analyzed_plan, self, |optimized_plan, optimizer| { let optimizer_name = optimizer.name().to_string(); @@ -1667,7 +1694,8 @@ impl SessionState { logical_optimization_succeeded, })) } else { - self.optimizer.optimize(plan, self, |_, _| {}) + let analyzed_plan = self.analyzer.execute_and_check(plan, self.options())?; + self.optimizer.optimize(&analyzed_plan, self, |_, _| {}) } } diff --git a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs index ba19108ceb5e..ecd00d7ac15c 100644 --- a/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs +++ b/datafusion/optimizer/src/analyzer/count_wildcard_rule.rs @@ -26,6 +26,7 @@ use crate::analyzer::AnalyzerRule; /// Rewrite `Count(Expr:Wildcard)` to `Count(Expr:Literal)`. /// Resolve issue: https://github.com/apache/arrow-datafusion/issues/5473. +#[derive(Default)] pub struct CountWildcardRule {} impl CountWildcardRule { diff --git a/datafusion/optimizer/src/analyzer/mod.rs b/datafusion/optimizer/src/analyzer/mod.rs index bb9b01c8593e..b5a29a287694 100644 --- a/datafusion/optimizer/src/analyzer/mod.rs +++ b/datafusion/optimizer/src/analyzer/mod.rs @@ -15,9 +15,9 @@ // specific language governing permissions and limitations // under the License. -mod count_wildcard_rule; -mod inline_table_scan; -pub(crate) mod type_coercion; +pub mod count_wildcard_rule; +pub mod inline_table_scan; +pub mod type_coercion; use crate::analyzer::count_wildcard_rule::CountWildcardRule; use crate::analyzer::inline_table_scan::InlineTableScan; diff --git a/datafusion/optimizer/src/optimizer.rs b/datafusion/optimizer/src/optimizer.rs index 6d02c46cc0f5..77b2312bc74a 100644 --- a/datafusion/optimizer/src/optimizer.rs +++ b/datafusion/optimizer/src/optimizer.rs @@ -17,7 +17,6 @@ //! Query optimizer traits -use crate::analyzer::Analyzer; use crate::common_subexpr_eliminate::CommonSubexprEliminate; use crate::decorrelate_where_exists::DecorrelateWhereExists; use crate::decorrelate_where_in::DecorrelateWhereIn; @@ -156,7 +155,7 @@ impl OptimizerConfig for OptimizerContext { /// A rule-based optimizer. #[derive(Clone)] pub struct Optimizer { - /// All rules to apply + /// All optimizer rules to apply pub rules: Vec>, } @@ -264,8 +263,7 @@ impl Optimizer { F: FnMut(&LogicalPlan, &dyn OptimizerRule), { let options = config.options(); - // execute_and_check has it's own timer - let mut new_plan = Analyzer::default().execute_and_check(plan, options)?; + let mut new_plan = plan.clone(); let start_time = Instant::now(); diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs index 0b9134c8b84f..e58a2aaa00c9 100644 --- a/datafusion/optimizer/tests/integration-test.rs +++ b/datafusion/optimizer/tests/integration-test.rs @@ -20,8 +20,9 @@ use chrono::{DateTime, NaiveDateTime, Utc}; use datafusion_common::config::ConfigOptions; use datafusion_common::{DataFusionError, Result}; use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; +use datafusion_optimizer::analyzer::Analyzer; use datafusion_optimizer::optimizer::Optimizer; -use datafusion_optimizer::{OptimizerContext, OptimizerRule}; +use datafusion_optimizer::{OptimizerConfig, OptimizerContext, OptimizerRule}; use datafusion_sql::planner::{ContextProvider, SqlToRel}; use datafusion_sql::sqlparser::ast::Statement; use datafusion_sql::sqlparser::dialect::GenericDialect; @@ -347,8 +348,10 @@ fn test_sql(sql: &str) -> Result { let config = OptimizerContext::new() .with_skip_failing_rules(false) .with_query_execution_start_time(now_time); + let analyzer = Analyzer::new(); let optimizer = Optimizer::new(); - // optimize the logical plan + // analyze and optimize the logical plan + let plan = analyzer.execute_and_check(&plan, config.options())?; optimizer.optimize(&plan, &config, &observe) }