From 561c545b30004333a3b101d06ca464f517c69405 Mon Sep 17 00:00:00 2001 From: NGA-TRAN Date: Thu, 8 Dec 2022 21:37:54 -0500 Subject: [PATCH 1/4] feat: prepare logical plan to logicl plan without params/placeholders --- datafusion/core/tests/sql/select.rs | 68 +++ datafusion/expr/src/logical_plan/plan.rs | 691 ++++++++++++++++++++++- datafusion/sql/src/planner.rs | 239 ++++++-- 3 files changed, 957 insertions(+), 41 deletions(-) diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 5a56247e4f59..85d03832c8b2 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -20,6 +20,7 @@ use datafusion::{ datasource::empty::EmptyTable, from_slice::FromSlice, physical_plan::collect_partitioned, }; +use datafusion_common::ScalarValue; use tempfile::TempDir; #[tokio::test] @@ -1257,6 +1258,73 @@ async fn csv_join_unaliased_subqueries() -> Result<()> { Ok(()) } +// Test prepare statement from sql to final result +// This test is equivalent with the test parallel_query_with_filter below but using prepare statement +#[tokio::test] +async fn test_prepare_statement() -> Result<()> { + let tmp_dir = TempDir::new()?; + let partition_count = 4; + let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; + + // sql to statement then to prepare logical plan with parameters + // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64 + let logical_plan = + ctx.create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1")?; + + // prepare logical plan to logical plan without parameters + let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))]; + let logical_plan = LogicalPlan::execute(logical_plan, param_values)?; + + // logical plan to optimized logical plan + let logical_plan = ctx.optimize(&logical_plan)?; + + // optimized logical plan to physical plan + let physical_plan = ctx.create_physical_plan(&logical_plan).await?; + + let task_ctx = ctx.task_ctx(); + let results = collect_partitioned(physical_plan, task_ctx).await?; + + // note that the order of partitions is not deterministic + let mut num_rows = 0; + for partition in &results { + for batch in partition { + num_rows += batch.num_rows(); + } + } + assert_eq!(20, num_rows); + + let results: Vec = results.into_iter().flatten().collect(); + let expected = vec![ + "+----+----+", + "| c1 | c2 |", + "+----+----+", + "| 1 | 1 |", + "| 1 | 10 |", + "| 1 | 2 |", + "| 1 | 3 |", + "| 1 | 4 |", + "| 1 | 5 |", + "| 1 | 6 |", + "| 1 | 7 |", + "| 1 | 8 |", + "| 1 | 9 |", + "| 2 | 1 |", + "| 2 | 10 |", + "| 2 | 2 |", + "| 2 | 3 |", + "| 2 | 4 |", + "| 2 | 5 |", + "| 2 | 6 |", + "| 2 | 7 |", + "| 2 | 8 |", + "| 2 | 9 |", + "+----+----+", + ]; + assert_batches_sorted_eq!(expected, &results); + + Ok(()) +} + #[tokio::test] async fn parallel_query_with_filter() -> Result<()> { let tmp_dir = TempDir::new()?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 7f38e7dbb2ef..e664a3a4c49b 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -23,10 +23,14 @@ use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::utils::{ exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, }; -use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource}; +use crate::{ + Between, Case, Cast, Expr, ExprSchemable, GetIndexedField, GroupingSet, Like, + TableProviderFilterPushDown, TableSource, +}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, + ScalarValue, }; use std::collections::{HashMap, HashSet}; use std::fmt::{self, Debug, Display, Formatter}; @@ -364,6 +368,42 @@ impl LogicalPlan { ) -> Result { from_plan(self, &self.expressions(), inputs) } + + /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values + pub fn execute( + logical_plan: LogicalPlan, + param_values: Vec, + ) -> Result { + match logical_plan { + LogicalPlan::Prepare(prepare_lp) => { + // Verify if the number of params matches the number of values + if prepare_lp.data_types.len() != param_values.len() { + return Err(DataFusionError::Internal(format!( + "Expected {} parameters, got {}", + prepare_lp.data_types.len(), + param_values.len() + ))); + } + + // Verify if the types of the params matches the types of the values + for (param_type, value) in + prepare_lp.data_types.iter().zip(param_values.iter()) + { + if *param_type != value.get_datatype() { + return Err(DataFusionError::Internal(format!( + "Expected parameter of type {:?}, got {:?}", + param_type, + value.get_datatype() + ))); + } + } + + let input_plan = prepare_lp.input; + input_plan.replace_params_with_values(¶m_values) + } + _ => Ok(logical_plan), + } + } } /// Trait that implements the [Visitor @@ -534,6 +574,655 @@ impl LogicalPlan { _ => {} } } + + /// recursively to replace the params (e.g $1 $2, ...) wit corresponding values provided in the prams_values + pub fn replace_params_with_values( + &self, + param_values: &Vec, + ) -> Result { + match self { + LogicalPlan::Projection(Projection { + expr, + input, + schema, + }) => { + let expr = &mut expr + .iter() + .map(|e| { + Self::replace_placeholders_with_values(e.clone(), param_values) + }) + .collect::, _>>()?; + + let input = input.replace_params_with_values(param_values)?; + Ok(LogicalPlan::Projection(Projection { + expr: expr.clone(), + input: Arc::new(input), + schema: Arc::clone(schema), + })) + } + LogicalPlan::Filter(Filter { predicate, input }) => { + let predicate = Self::replace_placeholders_with_values( + predicate.clone(), + param_values, + )?; + let input = input.replace_params_with_values(param_values)?; + Ok(LogicalPlan::Filter(Filter { + predicate, + input: Arc::new(input), + })) + } + LogicalPlan::Repartition(Repartition { + input, + partitioning_scheme, + }) => { + let input = input.replace_params_with_values(param_values)?; + // Even though the `partitioning` member of Repartition include expresions , they are internal ones and should not include params + // Hence no need to look for placeholders and replace them + Ok(LogicalPlan::Repartition(Repartition { + input: Arc::new(input), + partitioning_scheme: partitioning_scheme.clone(), + })) + } + LogicalPlan::Window(Window { + input, + window_expr, + schema, + }) => { + let input = input.replace_params_with_values(param_values)?; + let window_expr = &mut window_expr + .iter() + .map(|e| { + Self::replace_placeholders_with_values(e.clone(), param_values) + }) + .collect::, _>>()?; + Ok(LogicalPlan::Window(Window { + input: Arc::new(input), + window_expr: window_expr.clone(), + schema: Arc::clone(schema), + })) + } + LogicalPlan::Aggregate(Aggregate { + input, + group_expr, + aggr_expr, + schema, + }) => { + let input = input.replace_params_with_values(param_values)?; + let group_expr = &mut group_expr + .iter() + .map(|e| { + Self::replace_placeholders_with_values(e.clone(), param_values) + }) + .collect::, _>>()?; + let aggr_expr = &mut aggr_expr + .iter() + .map(|e| { + Self::replace_placeholders_with_values(e.clone(), param_values) + }) + .collect::, _>>()?; + Ok(LogicalPlan::Aggregate(Aggregate { + input: Arc::new(input), + group_expr: group_expr.clone(), + aggr_expr: aggr_expr.clone(), + schema: Arc::clone(schema), + })) + } + LogicalPlan::Sort(Sort { input, expr, fetch }) => { + let input = input.replace_params_with_values(param_values)?; + let expr = &mut expr + .iter() + .map(|e| { + Self::replace_placeholders_with_values(e.clone(), param_values) + }) + .collect::, _>>()?; + Ok(LogicalPlan::Sort(Sort { + input: Arc::new(input), + expr: expr.clone(), + fetch: *fetch, + })) + } + LogicalPlan::Join(Join { + left, + right, + filter, + on, + join_type, + join_constraint, + schema, + null_equals_null, + }) => { + let left = left.replace_params_with_values(param_values)?; + let fright = right.replace_params_with_values(param_values)?; + let filter = filter.clone().map(|f| { + Self::replace_placeholders_with_values(f, param_values) + .expect("Failed to replace params in join filter") + }); + Ok(LogicalPlan::Join(Join { + left: Arc::new(left), + right: Arc::new(fright), + filter, + on: on.clone(), + join_type: *join_type, + join_constraint: *join_constraint, + schema: Arc::clone(schema), + null_equals_null: *null_equals_null, + })) + } + LogicalPlan::CrossJoin(CrossJoin { + left, + right, + schema, + }) => { + let left = left.replace_params_with_values(param_values)?; + let right = right.replace_params_with_values(param_values)?; + Ok(LogicalPlan::CrossJoin(CrossJoin { + left: Arc::new(left), + right: Arc::new(right), + schema: Arc::clone(schema), + })) + } + LogicalPlan::Limit(Limit { input, skip, fetch }) => { + let input = input.replace_params_with_values(param_values)?; + Ok(LogicalPlan::Limit(Limit { + input: Arc::new(input), + skip: *skip, + fetch: *fetch, + })) + } + LogicalPlan::Subquery(Subquery { subquery }) => { + let subquery = subquery.replace_params_with_values(param_values)?; + Ok(LogicalPlan::Subquery(Subquery { + subquery: Arc::new(subquery), + })) + } + LogicalPlan::SubqueryAlias(SubqueryAlias { + input, + alias, + schema, + }) => { + let input = input.replace_params_with_values(param_values)?; + Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { + input: Arc::new(input), + alias: alias.clone(), + schema: Arc::clone(schema), + })) + } + LogicalPlan::Extension(Extension { node }) => { + // Currently only support params in standard SQL + // and extesion should not have any params + Ok(LogicalPlan::Extension(Extension { node: node.clone() })) + } + LogicalPlan::Union(Union { inputs, schema }) => { + let inputs = inputs + .iter() + .map(|input| input.replace_params_with_values(param_values)) + .collect::, _>>()?; + Ok(LogicalPlan::Union(Union { + inputs: inputs.into_iter().map(Arc::new).collect(), + schema: Arc::clone(schema), + })) + } + LogicalPlan::Distinct(Distinct { input }) => { + let input = input.replace_params_with_values(param_values)?; + Ok(LogicalPlan::Distinct(Distinct { + input: Arc::new(input), + })) + } + LogicalPlan::Prepare(Prepare { + name, + data_types, + input, + }) => { + let input = input.replace_params_with_values(param_values)?; + Ok(LogicalPlan::Prepare(Prepare { + name: name.clone(), + data_types: data_types.clone(), + input: Arc::new(input), + })) + } + // plans without inputs + LogicalPlan::TableScan(TableScan { + table_name, + source, + projection, + projected_schema, + filters, + fetch, + }) => { + let filters = filters + .iter() + .map(|f| { + Self::replace_placeholders_with_values(f.clone(), param_values) + .expect("Failed to replace params in table scan filter") + }) + .collect(); + Ok(LogicalPlan::TableScan(TableScan { + table_name: table_name.clone(), + source: Arc::clone(source), + projection: projection.clone(), + projected_schema: Arc::clone(projected_schema), + filters, + fetch: *fetch, + })) + } + LogicalPlan::Values(Values { values, schema }) => { + let values = values + .iter() + .map(|row| { + row.iter() + .map(|expr| { + Self::replace_placeholders_with_values( + expr.clone(), + param_values, + ) + .expect("Failed to replace params in values") + }) + .collect() + }) + .collect(); + Ok(LogicalPlan::Values(Values { + values, + schema: Arc::clone(schema), + })) + } + LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row, + schema, + }) => Ok(LogicalPlan::EmptyRelation(EmptyRelation { + produce_one_row: *produce_one_row, + schema: Arc::clone(schema), + })), + LogicalPlan::Explain(_) + | LogicalPlan::Analyze(_) + | LogicalPlan::CreateMemoryTable(_) + | LogicalPlan::CreateView(_) + | LogicalPlan::CreateExternalTable(_) + | LogicalPlan::CreateCatalogSchema(_) + | LogicalPlan::CreateCatalog(_) + | LogicalPlan::DropTable(_) + | LogicalPlan::SetVariable(_) + | LogicalPlan::DropView(_) => Err::( + DataFusionError::NotImplemented(format!( + "This logical plan should not contain parameters/placeholder: {}", + self.display() + )), + ), + } + } + + /// Recrusively to walk the expression and convert a placeholder into a literal value + fn replace_placeholders_with_values( + expr: Expr, + param_values: &Vec, + ) -> Result { + match expr { + Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => Ok(expr), + Expr::Alias(expr, name) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::Alias(Box::new(expr), name)) + } + Expr::BinaryExpr(BinaryExpr { left, op, right }) => { + let left = Self::replace_placeholders_with_values(*left, param_values)?; + let right = Self::replace_placeholders_with_values(*right, param_values)?; + Ok(Expr::BinaryExpr(BinaryExpr { + left: Box::new(left), + op, + right: Box::new(right), + })) + } + Expr::Case(case) => { + let expr = match case.expr { + Some(expr) => Some(Box::new(Self::replace_placeholders_with_values( + *expr, + param_values, + )?)), + None => None, + }; + let mut when_then_expr = vec![]; + for (w, t) in case.when_then_expr { + let w = Self::replace_placeholders_with_values(*w, param_values)?; + let t = Self::replace_placeholders_with_values(*t, param_values)?; + when_then_expr.push((Box::new(w), Box::new(t))); + } + let else_expr = match case.else_expr { + Some(expr) => Some(Box::new(Self::replace_placeholders_with_values( + *expr, + param_values, + )?)), + None => None, + }; + Ok(Expr::Case(Case { + expr, + when_then_expr, + else_expr, + })) + } + Expr::Cast(Cast { expr, data_type }) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::Cast(Cast { + expr: Box::new(expr), + data_type, + })) + } + Expr::TryCast { expr, data_type } => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::TryCast { + expr: Box::new(expr), + data_type, + }) + } + Expr::Not(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::Not(Box::new(expr))) + } + Expr::Negative(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::Negative(Box::new(expr))) + } + Expr::IsNull(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsNull(Box::new(expr))) + } + Expr::IsNotNull(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsNotNull(Box::new(expr))) + } + Expr::IsTrue(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsTrue(Box::new(expr))) + } + Expr::IsFalse(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsFalse(Box::new(expr))) + } + Expr::IsUnknown(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsUnknown(Box::new(expr))) + } + Expr::IsNotTrue(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsNotTrue(Box::new(expr))) + } + Expr::IsNotFalse(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsNotFalse(Box::new(expr))) + } + Expr::IsNotUnknown(expr) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::IsNotUnknown(Box::new(expr))) + } + Expr::GetIndexedField(GetIndexedField { key, expr }) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::GetIndexedField(GetIndexedField { + key, + expr: Box::new(expr), + })) + } + Expr::ScalarFunction { fun, args, .. } => { + let new_args = args + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + Ok(Expr::ScalarFunction { + fun, + args: new_args, + }) + } + Expr::ScalarUDF { fun, args } => { + let new_args = args + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + Ok(Expr::ScalarUDF { + fun, + args: new_args, + }) + } + Expr::WindowFunction { + fun, + args, + partition_by, + order_by, + window_frame, + } => { + let new_args = args + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + let new_partition_by = partition_by + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + let new_order_by = order_by + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + Ok(Expr::WindowFunction { + fun, + args: new_args, + partition_by: new_partition_by, + order_by: new_order_by, + window_frame, + }) + } + Expr::AggregateFunction { + fun, + distinct, + args, + filter, + } => { + let new_args = args + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + let new_filter = match filter { + Some(filter) => Some(Box::new( + Self::replace_placeholders_with_values(*filter, param_values)?, + )), + None => None, + }; + Ok(Expr::AggregateFunction { + fun, + distinct, + args: new_args, + filter: new_filter, + }) + } + Expr::AggregateUDF { fun, args, filter } => { + let new_args = args + .into_iter() + .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) + .collect::, _>>()?; + let new_filter = match filter { + Some(filter) => Some(Box::new( + Self::replace_placeholders_with_values(*filter, param_values)?, + )), + None => None, + }; + Ok(Expr::AggregateUDF { + fun, + args: new_args, + filter: new_filter, + }) + } + Expr::GroupingSet(grouping_set) => match grouping_set { + GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup( + exprs + .into_iter() + .map(|e| Self::replace_placeholders_with_values(e, param_values)) + .collect::, _>>()?, + ))), + GroupingSet::Cube(exprs) => Ok(Expr::GroupingSet(GroupingSet::Cube( + exprs + .into_iter() + .map(|e| Self::replace_placeholders_with_values(e, param_values)) + .collect::, _>>()?, + ))), + GroupingSet::GroupingSets(exprs) => { + Ok(Expr::GroupingSet(GroupingSet::GroupingSets( + exprs + .into_iter() + .map(|e| { + e.into_iter() + .map(|e| { + Self::replace_placeholders_with_values( + e, + param_values, + ) + }) + .collect::, _>>() + }) + .collect::, _>>()?, + ))) + } + }, + Expr::InList { + expr, + list, + negated, + } => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + let list = list + .into_iter() + .map(|expr| { + Self::replace_placeholders_with_values(expr, param_values) + }) + .collect::, _>>()?; + Ok(Expr::InList { + expr: Box::new(expr), + list, + negated, + }) + } + Expr::Exists { subquery, negated } => { + subquery.subquery.replace_params_with_values(param_values)?; + Ok(Expr::Exists { subquery, negated }) + } + Expr::InSubquery { + expr, + subquery, + negated, + } => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + subquery.subquery.replace_params_with_values(param_values)?; + Ok(Expr::InSubquery { + expr: Box::new(expr), + subquery, + negated, + }) + } + Expr::ScalarSubquery(_) => Err(DataFusionError::NotImplemented( + "Scalar subqueries are not yet supported in the physical plan" + .to_string(), + )), + Expr::Between(Between { + expr, + negated, + low, + high, + }) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + let low = Self::replace_placeholders_with_values(*low, param_values)?; + let high = Self::replace_placeholders_with_values(*high, param_values)?; + Ok(Expr::Between(Between { + expr: Box::new(expr), + negated, + low: Box::new(low), + high: Box::new(high), + })) + } + Expr::Like(Like { + negated, + expr, + pattern, + escape_char, + }) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + let pattern = + Self::replace_placeholders_with_values(*pattern, param_values)?; + Ok(Expr::Like(Like { + negated, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char, + })) + } + Expr::ILike(Like { + negated, + expr, + pattern, + escape_char, + }) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + let pattern = + Self::replace_placeholders_with_values(*pattern, param_values)?; + Ok(Expr::ILike(Like { + negated, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char, + })) + } + Expr::SimilarTo(Like { + negated, + expr, + pattern, + escape_char, + }) => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + let pattern = + Self::replace_placeholders_with_values(*pattern, param_values)?; + Ok(Expr::SimilarTo(Like { + negated, + expr: Box::new(expr), + pattern: Box::new(pattern), + escape_char, + })) + } + Expr::Sort { + expr, + asc, + nulls_first, + } => { + let expr = Self::replace_placeholders_with_values(*expr, param_values)?; + Ok(Expr::Sort { + expr: Box::new(expr), + asc, + nulls_first, + }) + } + Expr::Wildcard => Ok(Expr::Wildcard), + Expr::QualifiedWildcard { qualifier } => { + Ok(Expr::QualifiedWildcard { qualifier }) + } + Expr::Placeholder { id, data_type } => { + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..].parse::().map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {}", + e + )) + })? - 1; + // value at the idx-th position in param_values should be the value for the placeholder + let value = param_values.get(idx).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {}", + id + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if value.get_datatype() != data_type { + return Err(DataFusionError::Internal(format!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.get_datatype() + ))); + } + // Replace the placeholder with the value + Ok(Expr::Literal(value.clone())) + } + } + } } // Various implementations for printing out LogicalPlans diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 82d8c3834294..2d7506db2fbd 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -5421,15 +5421,32 @@ mod tests { sql: &str, expected_plan: &str, expected_data_types: &str, - ) { + ) -> LogicalPlan { let plan = logical_plan(sql).unwrap(); + + let assert_plan = plan.clone(); // verify plan - assert_eq!(format!("{:?}", plan), expected_plan); + assert_eq!(format!("{:?}", assert_plan), expected_plan); + // verify data types - if let LogicalPlan::Prepare(Prepare { data_types, .. }) = plan { + if let LogicalPlan::Prepare(Prepare { data_types, .. }) = assert_plan { let dt = format!("{:?}", data_types); assert_eq!(dt, expected_data_types); } + + plan + } + + fn prepare_stmt_replace_params_quick_test( + plan: LogicalPlan, + param_values: Vec, + expected_plan: &str, + ) -> LogicalPlan { + // replace params + let plan = LogicalPlan::execute(plan, param_values).unwrap(); + assert_eq!(format!("{:?}", plan), expected_plan); + + plan } struct MockContextProvider {} @@ -6239,11 +6256,7 @@ mod tests { // param is not number following the $ sign // panic due to error returned from the parser let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo"; - - let expected_plan = "whatever"; - let expected_dt = "whatever"; - - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + logical_plan(sql).unwrap(); } #[test] @@ -6252,11 +6265,7 @@ mod tests { // param is not number following the $ sign // panic due to error returned from the parser let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo"; - - let expected_plan = "whatever"; - let expected_dt = "whatever"; - - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + logical_plan(sql).unwrap(); } #[test] @@ -6265,11 +6274,7 @@ mod tests { )] fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() { let sql = "PREPARE my_plan(INT) AS SELECT id + $1"; - - let expected_plan = "whatever"; - let expected_dt = "whatever"; - - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + logical_plan(sql).unwrap(); } #[test] @@ -6279,11 +6284,7 @@ mod tests { fn test_prepare_statement_to_plan_panic_no_data_types() { // only provide 1 data type while using 2 params let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2"; - - let expected_plan = "whatever"; - let expected_dt = "whatever"; - - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + logical_plan(sql).unwrap(); } #[test] @@ -6292,11 +6293,7 @@ mod tests { )] fn test_prepare_statement_to_plan_panic_is_param() { let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1"; - - let expected_plan = "whatever"; - let expected_dt = "whatever"; - - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + logical_plan(sql).unwrap(); } #[test] @@ -6311,9 +6308,18 @@ mod tests { let expected_dt = "[Int32]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let expected_plan = "Projection: person.id, person.age\ + \n Filter: person.age = Int64(10)\ + \n TableScan: person"; - ///////////////////////// + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + + ////////////////////////////////////////// // no embedded parameter and no declare it let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; @@ -6324,7 +6330,54 @@ mod tests { let expected_dt = "[]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![]; + let expected_plan = "Projection: person.id, person.age\ + \n Filter: person.age = Int64(10)\ + \n TableScan: person"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + } + + #[test] + #[should_panic(expected = "value: Internal(\"Expected 1 parameters, got 0\")")] + fn test_prepare_statement_to_plan_one_param_no_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![]; + let expected_plan = "whatever"; + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + } + + #[test] + #[should_panic( + expected = "value: Internal(\"Expected parameter of type Int32, got Float64\")" + )] + fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![ScalarValue::Float64(Some(20.0))]; + let expected_plan = "whatever"; + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + } + + #[test] + #[should_panic(expected = "value: Internal(\"Expected 0 parameters, got 1\")")] + fn test_prepare_statement_to_plan_no_param_on_value_panic() { + // no embedded parameter but still declare it + let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10"; + let plan = logical_plan(sql).unwrap(); + // declare 1 param but provide 0 + let param_values = vec![ScalarValue::Int32(Some(10))]; + let expected_plan = "whatever"; + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } #[test] @@ -6335,25 +6388,50 @@ mod tests { \n Projection: $1\n EmptyRelation"; let expected_dt = "[Int32]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); - ///////////////////////// + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let expected_plan = "Projection: Int32(10)\n EmptyRelation"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + + /////////////////////////////////////// let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1"; let expected_plan = "Prepare: \"my_plan\" [Int32] \ \n Projection: Int64(1) + $1\n EmptyRelation"; let expected_dt = "[Int32]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let expected_plan = "Projection: Int64(1) + Int32(10)\n EmptyRelation"; - ///////////////////////// + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + + /////////////////////////////////////// let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2"; let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \ \n Projection: Int64(1) + $1 + $2\n EmptyRelation"; let expected_dt = "[Int32, Float64]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(10.0)), + ]; + let expected_plan = + "Projection: Int64(1) + Int32(10) + Float64(10)\n EmptyRelation"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } #[test] @@ -6367,7 +6445,41 @@ mod tests { let expected_dt = "[Int32]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ScalarValue::Int32(Some(10))]; + let expected_plan = "Projection: person.id, person.age\ + \n Filter: person.age = Int32(10)\ + \n TableScan: person"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); + } + + #[test] + fn test_prepare_statement_to_plan_data_type() { + let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1"; + + // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64 + // Prepare statement and its logical plan should be created successfully + let expected_plan = "Prepare: \"my_plan\" [Float64] \ + \n Projection: person.id, person.age\ + \n Filter: person.age = $1\ + \n TableScan: person"; + + let expected_dt = "[Float64]"; + + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values still succeed and use Float64 + let param_values = vec![ScalarValue::Float64(Some(10.0))]; + let expected_plan = "Projection: person.id, person.age\ + \n Filter: person.age = Float64(10)\ + \n TableScan: person"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } #[test] @@ -6384,7 +6496,24 @@ mod tests { let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Utf8(Some("abc".to_string())), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Int32(Some(20)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Utf8(Some("xyz".to_string())), + ]; + let expected_plan = + "Projection: person.id, person.age, Utf8(\"xyz\")\ + \n Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\ + \n TableScan: person"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } #[test] @@ -6406,7 +6535,24 @@ mod tests { let expected_dt = "[Int32, Float64, Float64, Float64]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Int32(Some(10)), + ScalarValue::Float64(Some(100.0)), + ScalarValue::Float64(Some(200.0)), + ScalarValue::Float64(Some(300.0)), + ]; + let expected_plan = + "Projection: person.id, SUM(person.age)\ + \n Filter: SUM(person.age) < Int32(10) AND SUM(person.age) > Int64(10) OR SUM(person.age) IN ([Float64(200), Float64(300)])\ + \n Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\ + \n Filter: person.salary > Float64(100)\ + \n TableScan: person"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } #[test] @@ -6421,7 +6567,20 @@ mod tests { let expected_dt = "[Utf8, Utf8]"; - prepare_stmt_quick_test(sql, expected_plan, expected_dt); + let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt); + + /////////////////// + // replace params with values + let param_values = vec![ + ScalarValue::Utf8(Some("a".to_string())), + ScalarValue::Utf8(Some("b".to_string())), + ]; + let expected_plan = "Projection: num, letter\ + \n Projection: t.column1 AS num, t.column2 AS letter\ + \n SubqueryAlias: t\ + \n Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))"; + + prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan); } #[test] From 98f71623c9da65db9ad4d6778a94a5c96a445835 Mon Sep 17 00:00:00 2001 From: NGA-TRAN Date: Thu, 8 Dec 2022 21:55:12 -0500 Subject: [PATCH 2/4] fix: typo --- datafusion/expr/src/logical_plan/plan.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index e664a3a4c49b..55f1bda07d04 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -575,7 +575,7 @@ impl LogicalPlan { } } - /// recursively to replace the params (e.g $1 $2, ...) wit corresponding values provided in the prams_values + /// recursively to replace the params (e.g $1 $2, ...) with corresponding values provided in the prams_values pub fn replace_params_with_values( &self, param_values: &Vec, From 9800b38f0884825a3d22f4fdb66c926fa52dfb2a Mon Sep 17 00:00:00 2001 From: NGA-TRAN Date: Fri, 9 Dec 2022 14:37:40 -0500 Subject: [PATCH 3/4] refactor: address review comments --- datafusion/core/tests/sql/select.rs | 2 +- datafusion/expr/src/logical_plan/plan.rs | 700 ++--------------------- datafusion/sql/src/planner.rs | 2 +- 3 files changed, 60 insertions(+), 644 deletions(-) diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs index 85d03832c8b2..c82eee7d0f7c 100644 --- a/datafusion/core/tests/sql/select.rs +++ b/datafusion/core/tests/sql/select.rs @@ -1273,7 +1273,7 @@ async fn test_prepare_statement() -> Result<()> { // prepare logical plan to logical plan without parameters let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))]; - let logical_plan = LogicalPlan::execute(logical_plan, param_values)?; + let logical_plan = logical_plan.with_param_values(param_values)?; // logical plan to optimized logical plan let logical_plan = ctx.optimize(&logical_plan)?; diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 55f1bda07d04..8b4f96809953 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -16,17 +16,16 @@ // under the License. use crate::expr::BinaryExpr; +use crate::expr_rewriter::{ExprRewritable, ExprRewriter}; ///! Logical plan types use crate::logical_plan::builder::validate_unique_names; use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor}; use crate::logical_plan::extension::UserDefinedLogicalNode; use crate::utils::{ - exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist, -}; -use crate::{ - Between, Case, Cast, Expr, ExprSchemable, GetIndexedField, GroupingSet, Like, - TableProviderFilterPushDown, TableSource, + self, exprlist_to_fields, from_plan, grouping_set_expr_count, + grouping_set_to_exprlist, }; +use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion_common::{ plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference, @@ -370,11 +369,11 @@ impl LogicalPlan { } /// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values - pub fn execute( - logical_plan: LogicalPlan, + pub fn with_param_values( + self, param_values: Vec, ) -> Result { - match logical_plan { + match self { LogicalPlan::Prepare(prepare_lp) => { // Verify if the number of params matches the number of values if prepare_lp.data_types.len() != param_values.len() { @@ -401,7 +400,7 @@ impl LogicalPlan { let input_plan = prepare_lp.input; input_plan.replace_params_with_values(¶m_values) } - _ => Ok(logical_plan), + _ => Ok(self), } } } @@ -575,653 +574,70 @@ impl LogicalPlan { } } - /// recursively to replace the params (e.g $1 $2, ...) with corresponding values provided in the prams_values + /// Return a logical plan with all placeholders/params (e.g $1 $2, ...) replaced with corresponding values provided in the prams_values pub fn replace_params_with_values( &self, param_values: &Vec, ) -> Result { - match self { - LogicalPlan::Projection(Projection { - expr, - input, - schema, - }) => { - let expr = &mut expr - .iter() - .map(|e| { - Self::replace_placeholders_with_values(e.clone(), param_values) - }) - .collect::, _>>()?; - - let input = input.replace_params_with_values(param_values)?; - Ok(LogicalPlan::Projection(Projection { - expr: expr.clone(), - input: Arc::new(input), - schema: Arc::clone(schema), - })) - } - LogicalPlan::Filter(Filter { predicate, input }) => { - let predicate = Self::replace_placeholders_with_values( - predicate.clone(), - param_values, - )?; - let input = input.replace_params_with_values(param_values)?; - Ok(LogicalPlan::Filter(Filter { - predicate, - input: Arc::new(input), - })) - } - LogicalPlan::Repartition(Repartition { - input, - partitioning_scheme, - }) => { - let input = input.replace_params_with_values(param_values)?; - // Even though the `partitioning` member of Repartition include expresions , they are internal ones and should not include params - // Hence no need to look for placeholders and replace them - Ok(LogicalPlan::Repartition(Repartition { - input: Arc::new(input), - partitioning_scheme: partitioning_scheme.clone(), - })) - } - LogicalPlan::Window(Window { - input, - window_expr, - schema, - }) => { - let input = input.replace_params_with_values(param_values)?; - let window_expr = &mut window_expr - .iter() - .map(|e| { - Self::replace_placeholders_with_values(e.clone(), param_values) - }) - .collect::, _>>()?; - Ok(LogicalPlan::Window(Window { - input: Arc::new(input), - window_expr: window_expr.clone(), - schema: Arc::clone(schema), - })) - } - LogicalPlan::Aggregate(Aggregate { - input, - group_expr, - aggr_expr, - schema, - }) => { - let input = input.replace_params_with_values(param_values)?; - let group_expr = &mut group_expr - .iter() - .map(|e| { - Self::replace_placeholders_with_values(e.clone(), param_values) - }) - .collect::, _>>()?; - let aggr_expr = &mut aggr_expr - .iter() - .map(|e| { - Self::replace_placeholders_with_values(e.clone(), param_values) - }) - .collect::, _>>()?; - Ok(LogicalPlan::Aggregate(Aggregate { - input: Arc::new(input), - group_expr: group_expr.clone(), - aggr_expr: aggr_expr.clone(), - schema: Arc::clone(schema), - })) - } - LogicalPlan::Sort(Sort { input, expr, fetch }) => { - let input = input.replace_params_with_values(param_values)?; - let expr = &mut expr - .iter() - .map(|e| { - Self::replace_placeholders_with_values(e.clone(), param_values) - }) - .collect::, _>>()?; - Ok(LogicalPlan::Sort(Sort { - input: Arc::new(input), - expr: expr.clone(), - fetch: *fetch, - })) - } - LogicalPlan::Join(Join { - left, - right, - filter, - on, - join_type, - join_constraint, - schema, - null_equals_null, - }) => { - let left = left.replace_params_with_values(param_values)?; - let fright = right.replace_params_with_values(param_values)?; - let filter = filter.clone().map(|f| { - Self::replace_placeholders_with_values(f, param_values) - .expect("Failed to replace params in join filter") - }); - Ok(LogicalPlan::Join(Join { - left: Arc::new(left), - right: Arc::new(fright), - filter, - on: on.clone(), - join_type: *join_type, - join_constraint: *join_constraint, - schema: Arc::clone(schema), - null_equals_null: *null_equals_null, - })) - } - LogicalPlan::CrossJoin(CrossJoin { - left, - right, - schema, - }) => { - let left = left.replace_params_with_values(param_values)?; - let right = right.replace_params_with_values(param_values)?; - Ok(LogicalPlan::CrossJoin(CrossJoin { - left: Arc::new(left), - right: Arc::new(right), - schema: Arc::clone(schema), - })) - } - LogicalPlan::Limit(Limit { input, skip, fetch }) => { - let input = input.replace_params_with_values(param_values)?; - Ok(LogicalPlan::Limit(Limit { - input: Arc::new(input), - skip: *skip, - fetch: *fetch, - })) - } - LogicalPlan::Subquery(Subquery { subquery }) => { - let subquery = subquery.replace_params_with_values(param_values)?; - Ok(LogicalPlan::Subquery(Subquery { - subquery: Arc::new(subquery), - })) - } - LogicalPlan::SubqueryAlias(SubqueryAlias { - input, - alias, - schema, - }) => { - let input = input.replace_params_with_values(param_values)?; - Ok(LogicalPlan::SubqueryAlias(SubqueryAlias { - input: Arc::new(input), - alias: alias.clone(), - schema: Arc::clone(schema), - })) - } - LogicalPlan::Extension(Extension { node }) => { - // Currently only support params in standard SQL - // and extesion should not have any params - Ok(LogicalPlan::Extension(Extension { node: node.clone() })) - } - LogicalPlan::Union(Union { inputs, schema }) => { - let inputs = inputs - .iter() - .map(|input| input.replace_params_with_values(param_values)) - .collect::, _>>()?; - Ok(LogicalPlan::Union(Union { - inputs: inputs.into_iter().map(Arc::new).collect(), - schema: Arc::clone(schema), - })) - } - LogicalPlan::Distinct(Distinct { input }) => { - let input = input.replace_params_with_values(param_values)?; - Ok(LogicalPlan::Distinct(Distinct { - input: Arc::new(input), - })) - } - LogicalPlan::Prepare(Prepare { - name, - data_types, - input, - }) => { - let input = input.replace_params_with_values(param_values)?; - Ok(LogicalPlan::Prepare(Prepare { - name: name.clone(), - data_types: data_types.clone(), - input: Arc::new(input), - })) - } - // plans without inputs - LogicalPlan::TableScan(TableScan { - table_name, - source, - projection, - projected_schema, - filters, - fetch, - }) => { - let filters = filters - .iter() - .map(|f| { - Self::replace_placeholders_with_values(f.clone(), param_values) - .expect("Failed to replace params in table scan filter") - }) - .collect(); - Ok(LogicalPlan::TableScan(TableScan { - table_name: table_name.clone(), - source: Arc::clone(source), - projection: projection.clone(), - projected_schema: Arc::clone(projected_schema), - filters, - fetch: *fetch, - })) - } - LogicalPlan::Values(Values { values, schema }) => { - let values = values - .iter() - .map(|row| { - row.iter() - .map(|expr| { - Self::replace_placeholders_with_values( - expr.clone(), - param_values, - ) - .expect("Failed to replace params in values") - }) - .collect() - }) - .collect(); - Ok(LogicalPlan::Values(Values { - values, - schema: Arc::clone(schema), - })) - } - LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row, - schema, - }) => Ok(LogicalPlan::EmptyRelation(EmptyRelation { - produce_one_row: *produce_one_row, - schema: Arc::clone(schema), - })), - LogicalPlan::Explain(_) - | LogicalPlan::Analyze(_) - | LogicalPlan::CreateMemoryTable(_) - | LogicalPlan::CreateView(_) - | LogicalPlan::CreateExternalTable(_) - | LogicalPlan::CreateCatalogSchema(_) - | LogicalPlan::CreateCatalog(_) - | LogicalPlan::DropTable(_) - | LogicalPlan::SetVariable(_) - | LogicalPlan::DropView(_) => Err::( - DataFusionError::NotImplemented(format!( - "This logical plan should not contain parameters/placeholder: {}", - self.display() - )), - ), + let exprs = self.expressions(); + let mut new_exprs = vec![]; + for expr in exprs { + new_exprs.push(Self::replace_placeholders_with_values(expr, param_values)?); } + + let new_inputs = self.inputs(); + let mut new_inputs_with_values = vec![]; + for input in new_inputs { + new_inputs_with_values.push(input.replace_params_with_values(param_values)?); + } + + let new_plan = utils::from_plan(self, &new_exprs, &new_inputs_with_values)?; + Ok(new_plan) } - /// Recrusively to walk the expression and convert a placeholder into a literal value + /// Return an Expr with all placeholders replaced with their corresponding values provided in the prams_values fn replace_placeholders_with_values( expr: Expr, param_values: &Vec, ) -> Result { - match expr { - Expr::Column(..) | Expr::ScalarVariable(..) | Expr::Literal(..) => Ok(expr), - Expr::Alias(expr, name) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::Alias(Box::new(expr), name)) - } - Expr::BinaryExpr(BinaryExpr { left, op, right }) => { - let left = Self::replace_placeholders_with_values(*left, param_values)?; - let right = Self::replace_placeholders_with_values(*right, param_values)?; - Ok(Expr::BinaryExpr(BinaryExpr { - left: Box::new(left), - op, - right: Box::new(right), - })) - } - Expr::Case(case) => { - let expr = match case.expr { - Some(expr) => Some(Box::new(Self::replace_placeholders_with_values( - *expr, - param_values, - )?)), - None => None, - }; - let mut when_then_expr = vec![]; - for (w, t) in case.when_then_expr { - let w = Self::replace_placeholders_with_values(*w, param_values)?; - let t = Self::replace_placeholders_with_values(*t, param_values)?; - when_then_expr.push((Box::new(w), Box::new(t))); - } - let else_expr = match case.else_expr { - Some(expr) => Some(Box::new(Self::replace_placeholders_with_values( - *expr, - param_values, - )?)), - None => None, - }; - Ok(Expr::Case(Case { - expr, - when_then_expr, - else_expr, - })) - } - Expr::Cast(Cast { expr, data_type }) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::Cast(Cast { - expr: Box::new(expr), - data_type, - })) - } - Expr::TryCast { expr, data_type } => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::TryCast { - expr: Box::new(expr), - data_type, - }) - } - Expr::Not(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::Not(Box::new(expr))) - } - Expr::Negative(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::Negative(Box::new(expr))) - } - Expr::IsNull(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsNull(Box::new(expr))) - } - Expr::IsNotNull(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsNotNull(Box::new(expr))) - } - Expr::IsTrue(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsTrue(Box::new(expr))) - } - Expr::IsFalse(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsFalse(Box::new(expr))) - } - Expr::IsUnknown(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsUnknown(Box::new(expr))) - } - Expr::IsNotTrue(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsNotTrue(Box::new(expr))) - } - Expr::IsNotFalse(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsNotFalse(Box::new(expr))) - } - Expr::IsNotUnknown(expr) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::IsNotUnknown(Box::new(expr))) - } - Expr::GetIndexedField(GetIndexedField { key, expr }) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::GetIndexedField(GetIndexedField { - key, - expr: Box::new(expr), - })) - } - Expr::ScalarFunction { fun, args, .. } => { - let new_args = args - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - Ok(Expr::ScalarFunction { - fun, - args: new_args, - }) - } - Expr::ScalarUDF { fun, args } => { - let new_args = args - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - Ok(Expr::ScalarUDF { - fun, - args: new_args, - }) - } - Expr::WindowFunction { - fun, - args, - partition_by, - order_by, - window_frame, - } => { - let new_args = args - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - let new_partition_by = partition_by - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - let new_order_by = order_by - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - Ok(Expr::WindowFunction { - fun, - args: new_args, - partition_by: new_partition_by, - order_by: new_order_by, - window_frame, - }) - } - Expr::AggregateFunction { - fun, - distinct, - args, - filter, - } => { - let new_args = args - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - let new_filter = match filter { - Some(filter) => Some(Box::new( - Self::replace_placeholders_with_values(*filter, param_values)?, - )), - None => None, - }; - Ok(Expr::AggregateFunction { - fun, - distinct, - args: new_args, - filter: new_filter, - }) - } - Expr::AggregateUDF { fun, args, filter } => { - let new_args = args - .into_iter() - .map(|arg| Self::replace_placeholders_with_values(arg, param_values)) - .collect::, _>>()?; - let new_filter = match filter { - Some(filter) => Some(Box::new( - Self::replace_placeholders_with_values(*filter, param_values)?, - )), - None => None, - }; - Ok(Expr::AggregateUDF { - fun, - args: new_args, - filter: new_filter, - }) - } - Expr::GroupingSet(grouping_set) => match grouping_set { - GroupingSet::Rollup(exprs) => Ok(Expr::GroupingSet(GroupingSet::Rollup( - exprs - .into_iter() - .map(|e| Self::replace_placeholders_with_values(e, param_values)) - .collect::, _>>()?, - ))), - GroupingSet::Cube(exprs) => Ok(Expr::GroupingSet(GroupingSet::Cube( - exprs - .into_iter() - .map(|e| Self::replace_placeholders_with_values(e, param_values)) - .collect::, _>>()?, - ))), - GroupingSet::GroupingSets(exprs) => { - Ok(Expr::GroupingSet(GroupingSet::GroupingSets( - exprs - .into_iter() - .map(|e| { - e.into_iter() - .map(|e| { - Self::replace_placeholders_with_values( - e, - param_values, - ) - }) - .collect::, _>>() - }) - .collect::, _>>()?, - ))) - } - }, - Expr::InList { - expr, - list, - negated, - } => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - let list = list - .into_iter() - .map(|expr| { - Self::replace_placeholders_with_values(expr, param_values) - }) - .collect::, _>>()?; - Ok(Expr::InList { - expr: Box::new(expr), - list, - negated, - }) - } - Expr::Exists { subquery, negated } => { - subquery.subquery.replace_params_with_values(param_values)?; - Ok(Expr::Exists { subquery, negated }) - } - Expr::InSubquery { - expr, - subquery, - negated, - } => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - subquery.subquery.replace_params_with_values(param_values)?; - Ok(Expr::InSubquery { - expr: Box::new(expr), - subquery, - negated, - }) - } - Expr::ScalarSubquery(_) => Err(DataFusionError::NotImplemented( - "Scalar subqueries are not yet supported in the physical plan" - .to_string(), - )), - Expr::Between(Between { - expr, - negated, - low, - high, - }) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - let low = Self::replace_placeholders_with_values(*low, param_values)?; - let high = Self::replace_placeholders_with_values(*high, param_values)?; - Ok(Expr::Between(Between { - expr: Box::new(expr), - negated, - low: Box::new(low), - high: Box::new(high), - })) - } - Expr::Like(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - let pattern = - Self::replace_placeholders_with_values(*pattern, param_values)?; - Ok(Expr::Like(Like { - negated, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char, - })) - } - Expr::ILike(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - let pattern = - Self::replace_placeholders_with_values(*pattern, param_values)?; - Ok(Expr::ILike(Like { - negated, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char, - })) - } - Expr::SimilarTo(Like { - negated, - expr, - pattern, - escape_char, - }) => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - let pattern = - Self::replace_placeholders_with_values(*pattern, param_values)?; - Ok(Expr::SimilarTo(Like { - negated, - expr: Box::new(expr), - pattern: Box::new(pattern), - escape_char, - })) - } - Expr::Sort { - expr, - asc, - nulls_first, - } => { - let expr = Self::replace_placeholders_with_values(*expr, param_values)?; - Ok(Expr::Sort { - expr: Box::new(expr), - asc, - nulls_first, - }) - } - Expr::Wildcard => Ok(Expr::Wildcard), - Expr::QualifiedWildcard { qualifier } => { - Ok(Expr::QualifiedWildcard { qualifier }) - } - Expr::Placeholder { id, data_type } => { - // convert id (in format $1, $2, ..) to idx (0, 1, ..) - let idx = id[1..].parse::().map_err(|e| { - DataFusionError::Internal(format!( - "Failed to parse placeholder id: {}", - e - )) - })? - 1; - // value at the idx-th position in param_values should be the value for the placeholder - let value = param_values.get(idx).ok_or_else(|| { - DataFusionError::Internal(format!( - "No value found for placeholder with id {}", - id - )) - })?; - // check if the data type of the value matches the data type of the placeholder - if value.get_datatype() != data_type { - return Err(DataFusionError::Internal(format!( - "Placeholder value type mismatch: expected {:?}, got {:?}", - data_type, - value.get_datatype() - ))); + struct PlaceholderReplacer<'a> { + param_values: &'a Vec, + } + + impl<'a> ExprRewriter for PlaceholderReplacer<'a> { + fn mutate(&mut self, expr: Expr) -> Result { + if let Expr::Placeholder { id, data_type } = &expr { + // convert id (in format $1, $2, ..) to idx (0, 1, ..) + let idx = id[1..].parse::().map_err(|e| { + DataFusionError::Internal(format!( + "Failed to parse placeholder id: {}", + e + )) + })? - 1; + // value at the idx-th position in param_values should be the value for the placeholder + let value = self.param_values.get(idx).ok_or_else(|| { + DataFusionError::Internal(format!( + "No value found for placeholder with id {}", + id + )) + })?; + // check if the data type of the value matches the data type of the placeholder + if value.get_datatype() != *data_type { + return Err(DataFusionError::Internal(format!( + "Placeholder value type mismatch: expected {:?}, got {:?}", + data_type, + value.get_datatype() + ))); + } + // Replace the placeholder with the value + Ok(Expr::Literal(value.clone())) + } else { + Ok(expr) } - // Replace the placeholder with the value - Ok(Expr::Literal(value.clone())) } } + + expr.rewrite(&mut PlaceholderReplacer { param_values }) } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2d7506db2fbd..cbd7a83f7889 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -5443,7 +5443,7 @@ mod tests { expected_plan: &str, ) -> LogicalPlan { // replace params - let plan = LogicalPlan::execute(plan, param_values).unwrap(); + let plan = plan.with_param_values(param_values).unwrap(); assert_eq!(format!("{:?}", plan), expected_plan); plan From 92443433cedefdf3e95a7ab8bda1ebe6cc2a30ae Mon Sep 17 00:00:00 2001 From: NGA-TRAN Date: Fri, 9 Dec 2022 16:53:49 -0500 Subject: [PATCH 4/4] refactor: add index of the params/values into the error message --- datafusion/expr/src/logical_plan/plan.rs | 10 +++++----- datafusion/sql/src/planner.rs | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 8b4f96809953..43e615e14d95 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -385,14 +385,14 @@ impl LogicalPlan { } // Verify if the types of the params matches the types of the values - for (param_type, value) in - prepare_lp.data_types.iter().zip(param_values.iter()) - { + let iter = prepare_lp.data_types.iter().zip(param_values.iter()); + for (i, (param_type, value)) in iter.enumerate() { if *param_type != value.get_datatype() { return Err(DataFusionError::Internal(format!( - "Expected parameter of type {:?}, got {:?}", + "Expected parameter of type {:?}, got {:?} at index {}", param_type, - value.get_datatype() + value.get_datatype(), + i ))); } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index cbd7a83f7889..dd54f7b51ec4 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -6356,7 +6356,7 @@ mod tests { #[test] #[should_panic( - expected = "value: Internal(\"Expected parameter of type Int32, got Float64\")" + expected = "value: Internal(\"Expected parameter of type Int32, got Float64 at index 0\")" )] fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() { // no embedded parameter but still declare it