diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 94b24e2adfee..2d5eb46804df 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -220,7 +220,8 @@ impl LogicalPlan { .. }) => match partitioning_scheme { Partitioning::Hash(expr, _) => expr.clone(), - _ => vec![], + Partitioning::DistributeBy(expr) => expr.clone(), + Partitioning::RoundRobinBatch(_) => vec![], }, LogicalPlan::Window(Window { window_expr, .. }) => window_expr.clone(), LogicalPlan::Aggregate(Aggregate { diff --git a/datafusion/optimizer/Cargo.toml b/datafusion/optimizer/Cargo.toml index b1a64384c83b..277f2d95f9cb 100644 --- a/datafusion/optimizer/Cargo.toml +++ b/datafusion/optimizer/Cargo.toml @@ -48,5 +48,6 @@ log = "^0.4" [dev-dependencies] ctor = "0.1.22" +datafusion-sql = { path = "../sql", version = "11.0.0" } env_logger = "0.9.0" diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs new file mode 100644 index 000000000000..b9d4d3b6333c --- /dev/null +++ b/datafusion/optimizer/tests/integration-test.rs @@ -0,0 +1,149 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use datafusion_common::{DataFusionError, Result}; +use datafusion_expr::{AggregateUDF, LogicalPlan, ScalarUDF, TableSource}; +use datafusion_optimizer::common_subexpr_eliminate::CommonSubexprEliminate; +use datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery; +use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists; +use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn; +use datafusion_optimizer::eliminate_filter::EliminateFilter; +use datafusion_optimizer::eliminate_limit::EliminateLimit; +use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys; +use datafusion_optimizer::filter_push_down::FilterPushDown; +use datafusion_optimizer::limit_push_down::LimitPushDown; +use datafusion_optimizer::optimizer::Optimizer; +use datafusion_optimizer::projection_push_down::ProjectionPushDown; +use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin; +use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate; +use datafusion_optimizer::simplify_expressions::SimplifyExpressions; +use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy; +use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin; +use datafusion_optimizer::{OptimizerConfig, OptimizerRule}; +use datafusion_sql::planner::{ContextProvider, SqlToRel}; +use datafusion_sql::sqlparser::ast::Statement; +use datafusion_sql::sqlparser::dialect::GenericDialect; +use datafusion_sql::sqlparser::parser::Parser; +use datafusion_sql::TableReference; +use std::any::Any; +use std::collections::HashMap; +use std::sync::Arc; + +#[test] +fn distribute_by() -> Result<()> { + // regression test for https://github.com/apache/arrow-datafusion/issues/3234 + let sql = "SELECT col_int32, col_utf8 FROM test DISTRIBUTE BY (col_utf8)"; + let plan = test_sql(sql)?; + let expected = "Repartition: DistributeBy(#col_utf8)\ + \n Projection: #test.col_int32, #test.col_utf8\ + \n TableScan: test projection=[col_int32, col_utf8]"; + assert_eq!(expected, format!("{:?}", plan)); + Ok(()) +} + +fn test_sql(sql: &str) -> Result { + let rules: Vec> = vec![ + // Simplify expressions first to maximize the chance + // of applying other optimizations + Arc::new(SimplifyExpressions::new()), + Arc::new(DecorrelateWhereExists::new()), + Arc::new(DecorrelateWhereIn::new()), + Arc::new(DecorrelateScalarSubquery::new()), + Arc::new(SubqueryFilterToJoin::new()), + Arc::new(EliminateFilter::new()), + Arc::new(CommonSubexprEliminate::new()), + Arc::new(EliminateLimit::new()), + Arc::new(ProjectionPushDown::new()), + Arc::new(RewriteDisjunctivePredicate::new()), + Arc::new(FilterNullJoinKeys::default()), + Arc::new(ReduceOuterJoin::new()), + Arc::new(FilterPushDown::new()), + Arc::new(LimitPushDown::new()), + Arc::new(SingleDistinctToGroupBy::new()), + ]; + + let optimizer = Optimizer::new(rules); + + // parse the SQL + let dialect = GenericDialect {}; // or AnsiDialect, or your own dialect ... + let ast: Vec = Parser::parse_sql(&dialect, sql).unwrap(); + let statement = &ast[0]; + + // create a logical query plan + let schema_provider = MySchemaProvider {}; + let sql_to_rel = SqlToRel::new(&schema_provider); + let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap(); + + // optimize the logical plan + let mut config = OptimizerConfig::new().with_skip_failing_rules(false); + optimizer.optimize(&plan, &mut config, &observe) +} + +struct MySchemaProvider {} + +impl ContextProvider for MySchemaProvider { + fn get_table_provider( + &self, + name: TableReference, + ) -> datafusion_common::Result> { + let table_name = name.table(); + if table_name.starts_with("test") { + let schema = Schema::new_with_metadata( + vec![ + Field::new("col_int32", DataType::Int32, true), + Field::new("col_utf8", DataType::Utf8, true), + ], + HashMap::new(), + ); + + Ok(Arc::new(MyTableSource { + schema: Arc::new(schema), + })) + } else { + Err(DataFusionError::Plan("table does not exist".to_string())) + } + } + + fn get_function_meta(&self, _name: &str) -> Option> { + None + } + + fn get_aggregate_meta(&self, _name: &str) -> Option> { + None + } + + fn get_variable_type(&self, _variable_names: &[String]) -> Option { + None + } +} + +fn observe(_plan: &LogicalPlan, _rule: &dyn OptimizerRule) {} + +struct MyTableSource { + schema: SchemaRef, +} + +impl TableSource for MyTableSource { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + self.schema.clone() + } +}