diff --git a/datafusion/core/tests/sql/select.rs b/datafusion/core/tests/sql/select.rs
index 5a56247e4f59..c82eee7d0f7c 100644
--- a/datafusion/core/tests/sql/select.rs
+++ b/datafusion/core/tests/sql/select.rs
@@ -20,6 +20,7 @@ use datafusion::{
     datasource::empty::EmptyTable, from_slice::FromSlice,
     physical_plan::collect_partitioned,
 };
+use datafusion_common::ScalarValue;
 use tempfile::TempDir;
 
 #[tokio::test]
@@ -1257,6 +1258,73 @@ async fn csv_join_unaliased_subqueries() -> Result<()> {
     Ok(())
 }
 
+// Test prepare statement from sql to final result
+// This test is equivalent to the test parallel_query_with_filter below but uses a prepare statement
+#[tokio::test]
+async fn test_prepare_statement() -> Result<()> {
+    let tmp_dir = TempDir::new()?;
+    let partition_count = 4;
+    let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;
+
+    // sql to statement then to prepare logical plan with parameters
+    // c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64
+    let logical_plan =
+        ctx.create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1")?;
+
+    // prepare logical plan to logical plan without parameters
+    let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))];
+    let logical_plan = logical_plan.with_param_values(param_values)?;
+
+    // logical plan to optimized logical plan
+    let logical_plan = ctx.optimize(&logical_plan)?;
+
+    // optimized logical plan to physical plan
+    let physical_plan = ctx.create_physical_plan(&logical_plan).await?;
+
+    let task_ctx = ctx.task_ctx();
+    let results = collect_partitioned(physical_plan, task_ctx).await?;
+
+    // note that the order of partitions is not deterministic
+    let mut num_rows = 0;
+    for partition in &results {
+        for batch in partition {
+            num_rows += batch.num_rows();
+        }
+    }
+    assert_eq!(20, num_rows);
+
+    let results: Vec<RecordBatch> = results.into_iter().flatten().collect();
+    let expected = vec![
+        "+----+----+",
+        "| c1 | c2 |",
+        "+----+----+",
+        "| 1  | 1  |",
+        "| 1  | 10 |",
+        "| 1  | 2  |",
+        "| 1  | 3  |",
+        "| 1  | 4  |",
+        "| 1  | 5  |",
+        "| 1  | 6  |",
+        "| 1  | 7  |",
+        "| 1  | 8  |",
+        "| 1  | 9  |",
+        "| 2  | 1  |",
+        "| 2  | 10 |",
+        "| 2  | 2  |",
+        "| 2  | 3  |",
+        "| 2  | 4  |",
+        "| 2  | 5  |",
+        "| 2  | 6  |",
+        "| 2  | 7  |",
+        "| 2  | 8  |",
+        "| 2  | 9  |",
+        "+----+----+",
+    ];
+    assert_batches_sorted_eq!(expected, &results);
+
+    Ok(())
+}
+
 #[tokio::test]
 async fn parallel_query_with_filter() -> Result<()> {
     let tmp_dir = TempDir::new()?;
diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs
index 7f38e7dbb2ef..43e615e14d95 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -16,17 +16,20 @@
 // under the License.
 
 use crate::expr::BinaryExpr;
+use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
 ///! Logical plan types
 use crate::logical_plan::builder::validate_unique_names;
 use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
 use crate::logical_plan::extension::UserDefinedLogicalNode;
 use crate::utils::{
-    exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist,
+    self, exprlist_to_fields, from_plan, grouping_set_expr_count,
+    grouping_set_to_exprlist,
 };
 use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use datafusion_common::{
     plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference,
+    ScalarValue,
 };
 use std::collections::{HashMap, HashSet};
 use std::fmt::{self, Debug, Display, Formatter};
@@ -364,6 +367,51 @@ impl LogicalPlan {
     ) -> Result<LogicalPlan, DataFusionError> {
         from_plan(self, &self.expressions(), inputs)
     }
+
+    /// Convert a prepare logical plan into its inner logical plan with all params
+    /// replaced with their corresponding values.
+    ///
+    /// A plan that is not a `LogicalPlan::Prepare` is returned unchanged.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error when the number of `param_values` differs from the
+    /// number of declared parameter types, or when a value's data type differs from
+    /// the type declared at the same position.
+    pub fn with_param_values(
+        self,
+        param_values: Vec<ScalarValue>,
+    ) -> Result<LogicalPlan, DataFusionError> {
+        match self {
+            LogicalPlan::Prepare(prepare_lp) => {
+                // Verify if the number of params matches the number of values
+                if prepare_lp.data_types.len() != param_values.len() {
+                    return Err(DataFusionError::Internal(format!(
+                        "Expected {} parameters, got {}",
+                        prepare_lp.data_types.len(),
+                        param_values.len()
+                    )));
+                }
+
+                // Verify if the types of the params matches the types of the values
+                let iter = prepare_lp.data_types.iter().zip(param_values.iter());
+                for (i, (param_type, value)) in iter.enumerate() {
+                    if *param_type != value.get_datatype() {
+                        return Err(DataFusionError::Internal(format!(
+                            "Expected parameter of type {:?}, got {:?} at index {}",
+                            param_type,
+                            value.get_datatype(),
+                            i
+                        )));
+                    }
+                }
+
+                let input_plan = prepare_lp.input;
+                input_plan.replace_params_with_values(&param_values)
+            }
+            _ => Ok(self),
+        }
+    }
 }
 
 /// Trait that implements the [Visitor
@@ -534,6 +582,93 @@ impl LogicalPlan {
             _ => {}
         }
     }
+
+    /// Return a logical plan with all placeholders/params (e.g $1 $2, ...) replaced
+    /// with the corresponding values provided in `param_values`.
+    ///
+    /// Placeholders are 1-based: `$1` is replaced by `param_values[0]`. The
+    /// replacement is applied to this node's expressions and, recursively, to every
+    /// input plan.
+    pub fn replace_params_with_values(
+        &self,
+        param_values: &[ScalarValue],
+    ) -> Result<LogicalPlan, DataFusionError> {
+        let exprs = self.expressions();
+        let mut new_exprs = vec![];
+        for expr in exprs {
+            new_exprs.push(Self::replace_placeholders_with_values(expr, param_values)?);
+        }
+
+        let new_inputs = self.inputs();
+        let mut new_inputs_with_values = vec![];
+        for input in new_inputs {
+            new_inputs_with_values.push(input.replace_params_with_values(param_values)?);
+        }
+
+        let new_plan = utils::from_plan(self, &new_exprs, &new_inputs_with_values)?;
+        Ok(new_plan)
+    }
+
+    /// Return an Expr with all placeholders replaced with their corresponding values
+    /// provided in `param_values`.
+    ///
+    /// # Errors
+    ///
+    /// Returns an internal error when the number in a placeholder id cannot be
+    /// parsed, when no value exists at that position (this includes `$0`), or when
+    /// the value's data type differs from the placeholder's data type.
+    fn replace_placeholders_with_values(
+        expr: Expr,
+        param_values: &[ScalarValue],
+    ) -> Result<Expr, DataFusionError> {
+        struct PlaceholderReplacer<'a> {
+            param_values: &'a [ScalarValue],
+        }
+
+        impl<'a> ExprRewriter for PlaceholderReplacer<'a> {
+            fn mutate(&mut self, expr: Expr) -> Result<Expr, DataFusionError> {
+                if let Expr::Placeholder { id, data_type } = &expr {
+                    // Convert the 1-based id (in format $1, $2, ..) to a 0-based idx (0, 1, ..).
+                    // `get(1..)` avoids a slicing panic on an empty id; the parse then fails cleanly.
+                    let position = id
+                        .get(1..)
+                        .unwrap_or_default()
+                        .parse::<usize>()
+                        .map_err(|e| {
+                            DataFusionError::Internal(format!(
+                                "Failed to parse placeholder id: {}",
+                                e
+                            ))
+                        })?;
+                    // `$0` has no 0-based slot: `checked_sub` reports it as a missing value
+                    // instead of underflowing `usize` (a panic in debug builds).
+                    let value = position
+                        .checked_sub(1)
+                        .and_then(|idx| self.param_values.get(idx))
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(format!(
+                                "No value found for placeholder with id {}",
+                                id
+                            ))
+                        })?;
+                    // check if the data type of the value matches the data type of the placeholder
+                    if value.get_datatype() != *data_type {
+                        return Err(DataFusionError::Internal(format!(
+                            "Placeholder value type mismatch: expected {:?}, got {:?}",
+                            data_type,
+                            value.get_datatype()
+                        )));
+                    }
+                    // Replace the placeholder with the value
+                    Ok(Expr::Literal(value.clone()))
+                } else {
+                    Ok(expr)
+                }
+            }
+        }
+
+        expr.rewrite(&mut PlaceholderReplacer { param_values })
+    }
 }
 
 // Various implementations for printing out LogicalPlans
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 82d8c3834294..dd54f7b51ec4 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -5421,15 +5421,32 @@ mod tests {
         sql: &str,
         expected_plan: &str,
         expected_data_types: &str,
-    ) {
+    ) -> LogicalPlan {
         let plan = logical_plan(sql).unwrap();
+
+        let assert_plan = plan.clone();
         // verify plan
-        assert_eq!(format!("{:?}", plan), expected_plan);
+        assert_eq!(format!("{:?}", assert_plan), expected_plan);
+
         // verify data types
-        if let LogicalPlan::Prepare(Prepare { data_types, .. }) = plan {
+        if let LogicalPlan::Prepare(Prepare { data_types, .. }) = assert_plan {
             let dt = format!("{:?}", data_types);
             assert_eq!(dt, expected_data_types);
         }
+
+        plan
+    }
+
+    fn prepare_stmt_replace_params_quick_test(
+        plan: LogicalPlan,
+        param_values: Vec<ScalarValue>,
+        expected_plan: &str,
+    ) -> LogicalPlan {
+        // replace params
+        let plan = plan.with_param_values(param_values).unwrap();
+        assert_eq!(format!("{:?}", plan), expected_plan);
+
+        plan
     }
 
     struct MockContextProvider {}
@@ -6239,11 +6256,7 @@ mod tests {
         // param is not number following the $ sign
         // panic due to error returned from the parser
         let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6252,11 +6265,7 @@ mod tests {
         // param is not number following the $ sign
         // panic due to error returned from the parser
         let sql = "PREPARE AS SELECT id, age FROM person WHERE age = $foo";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6265,11 +6274,7 @@ mod tests {
     )]
     fn test_prepare_statement_to_plan_panic_no_relation_and_constant_param() {
         let sql = "PREPARE my_plan(INT) AS SELECT id + $1";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6279,11 +6284,7 @@ mod tests {
     fn test_prepare_statement_to_plan_panic_no_data_types() {
         // only provide 1 data type while using 2 params
         let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1 + $2";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6292,11 +6293,7 @@ mod tests {
     )]
     fn test_prepare_statement_to_plan_panic_is_param() {
         let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age is $1";
-
-        let expected_plan = "whatever";
-        let expected_dt = "whatever";
-
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        logical_plan(sql).unwrap();
     }
 
     #[test]
@@ -6311,9 +6308,18 @@ mod tests {
 
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Int64(10)\
+        \n    TableScan: person";
 
-        /////////////////////////
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+
+        //////////////////////////////////////////
         // no embedded parameter and no declare it
         let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10";
 
@@ -6324,7 +6330,54 @@ mod tests {
 
         let expected_dt = "[]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Int64(10)\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+    }
+
+    #[test]
+    #[should_panic(expected = "value: Internal(\"Expected 1 parameters, got 0\")")]
+    fn test_prepare_statement_to_plan_one_param_no_value_panic() {
+        // no embedded parameter but still declare it
+        let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10";
+        let plan = logical_plan(sql).unwrap();
+        // declare 1 param but provide 0
+        let param_values = vec![];
+        let expected_plan = "whatever";
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+    }
+
+    #[test]
+    #[should_panic(
+        expected = "value: Internal(\"Expected parameter of type Int32, got Float64 at index 0\")"
+    )]
+    fn test_prepare_statement_to_plan_one_param_one_value_different_type_panic() {
+        // no embedded parameter but still declare it
+        let sql = "PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = 10";
+        let plan = logical_plan(sql).unwrap();
+        // declare 1 INT param but provide a Float64 value
+        let param_values = vec![ScalarValue::Float64(Some(20.0))];
+        let expected_plan = "whatever";
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+    }
+
+    #[test]
+    #[should_panic(expected = "value: Internal(\"Expected 0 parameters, got 1\")")]
+    fn test_prepare_statement_to_plan_no_param_on_value_panic() {
+        // no embedded parameter and none declared
+        let sql = "PREPARE my_plan AS SELECT id, age FROM person WHERE age = 10";
+        let plan = logical_plan(sql).unwrap();
+        // declare 0 params but provide 1
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "whatever";
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
     }
 
     #[test]
@@ -6335,25 +6388,50 @@ mod tests {
         \n  Projection: $1\n    EmptyRelation";
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
 
-        /////////////////////////
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: Int32(10)\n  EmptyRelation";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+
+        ///////////////////////////////////////
         let sql = "PREPARE my_plan(INT) AS SELECT 1 + $1";
 
         let expected_plan = "Prepare: \"my_plan\" [Int32] \
         \n  Projection: Int64(1) + $1\n    EmptyRelation";
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: Int64(1) + Int32(10)\n  EmptyRelation";
 
-        /////////////////////////
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+
+        ///////////////////////////////////////
         let sql = "PREPARE my_plan(INT, DOUBLE) AS SELECT 1 + $1 + $2";
 
         let expected_plan = "Prepare: \"my_plan\" [Int32, Float64] \
         \n  Projection: Int64(1) + $1 + $2\n    EmptyRelation";
         let expected_dt = "[Int32, Float64]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Int32(Some(10)),
+            ScalarValue::Float64(Some(10.0)),
+        ];
+        let expected_plan =
+            "Projection: Int64(1) + Int32(10) + Float64(10)\n  EmptyRelation";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
     }
 
     #[test]
@@ -6367,7 +6445,41 @@ mod tests {
 
         let expected_dt = "[Int32]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![ScalarValue::Int32(Some(10))];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Int32(10)\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
+    }
+
+    #[test]
+    fn test_prepare_statement_to_plan_data_type() {
+        let sql = "PREPARE my_plan(DOUBLE) AS SELECT id, age FROM person WHERE age = $1";
+
+        // age is defined as Int32 but prepare statement declares it as DOUBLE/Float64
+        // Prepare statement and its logical plan should be created successfully
+        let expected_plan = "Prepare: \"my_plan\" [Float64] \
+        \n  Projection: person.id, person.age\
+        \n    Filter: person.age = $1\
+        \n      TableScan: person";
+
+        let expected_dt = "[Float64]";
+
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replacing params with values still succeeds and uses Float64
+        let param_values = vec![ScalarValue::Float64(Some(10.0))];
+        let expected_plan = "Projection: person.id, person.age\
+        \n  Filter: person.age = Float64(10)\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
     }
 
     #[test]
@@ -6384,7 +6496,24 @@ mod tests {
 
         let expected_dt = "[Int32, Utf8, Float64, Int32, Float64, Utf8]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Int32(Some(10)),
+            ScalarValue::Utf8(Some("abc".to_string())),
+            ScalarValue::Float64(Some(100.0)),
+            ScalarValue::Int32(Some(20)),
+            ScalarValue::Float64(Some(200.0)),
+            ScalarValue::Utf8(Some("xyz".to_string())),
+        ];
+        let expected_plan =
+            "Projection: person.id, person.age, Utf8(\"xyz\")\
+        \n  Filter: person.age IN ([Int32(10), Int32(20)]) AND person.salary > Float64(100) AND person.salary < Float64(200) OR person.first_name < Utf8(\"abc\")\
+        \n    TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
     }
 
     #[test]
@@ -6406,7 +6535,24 @@ mod tests {
 
         let expected_dt = "[Int32, Float64, Float64, Float64]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Int32(Some(10)),
+            ScalarValue::Float64(Some(100.0)),
+            ScalarValue::Float64(Some(200.0)),
+            ScalarValue::Float64(Some(300.0)),
+        ];
+        let expected_plan =
+            "Projection: person.id, SUM(person.age)\
+        \n  Filter: SUM(person.age) < Int32(10) AND SUM(person.age) > Int64(10) OR SUM(person.age) IN ([Float64(200), Float64(300)])\
+        \n    Aggregate: groupBy=[[person.id]], aggr=[[SUM(person.age)]]\
+        \n      Filter: person.salary > Float64(100)\
+        \n        TableScan: person";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
     }
 
     #[test]
@@ -6421,7 +6567,20 @@ mod tests {
 
         let expected_dt = "[Utf8, Utf8]";
 
-        prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+        let plan = prepare_stmt_quick_test(sql, expected_plan, expected_dt);
+
+        ///////////////////
+        // replace params with values
+        let param_values = vec![
+            ScalarValue::Utf8(Some("a".to_string())),
+            ScalarValue::Utf8(Some("b".to_string())),
+        ];
+        let expected_plan = "Projection: num, letter\
+        \n  Projection: t.column1 AS num, t.column2 AS letter\
+        \n    SubqueryAlias: t\
+        \n      Values: (Int64(1), Utf8(\"a\")), (Int64(2), Utf8(\"b\"))";
+
+        prepare_stmt_replace_params_quick_test(plan, param_values, expected_plan);
     }
 
     #[test]