Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 68 additions & 0 deletions datafusion/core/tests/sql/select.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use datafusion::{
datasource::empty::EmptyTable, from_slice::FromSlice,
physical_plan::collect_partitioned,
};
use datafusion_common::ScalarValue;
use tempfile::TempDir;

#[tokio::test]
Expand Down Expand Up @@ -1257,6 +1258,73 @@ async fn csv_join_unaliased_subqueries() -> Result<()> {
Ok(())
}

// Test prepare statement from sql to final result
// This test is equivalent with the test parallel_query_with_filter below but using prepare statement
#[tokio::test]
async fn test_prepare_statement() -> Result<()> {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I want to have this test to verify my logical plan works correctly

let tmp_dir = TempDir::new()?;
let partition_count = 4;
let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?;

// sql to statement then to prepare logical plan with parameters
// c1 defined as UINT32, c2 defined as UInt64 but the params are Int32 and Float64
let logical_plan =
ctx.create_logical_plan("PREPARE my_plan(INT, DOUBLE) AS SELECT c1, c2 FROM test WHERE c1 > $2 AND c1 < $1")?;

// prepare logical plan to logical plan without parameters
let param_values = vec![ScalarValue::Int32(Some(3)), ScalarValue::Float64(Some(0.0))];
let logical_plan = logical_plan.with_param_values(param_values)?;

// logical plan to optimized logical plan
let logical_plan = ctx.optimize(&logical_plan)?;

// optimized logical plan to physical plan
let physical_plan = ctx.create_physical_plan(&logical_plan).await?;

let task_ctx = ctx.task_ctx();
let results = collect_partitioned(physical_plan, task_ctx).await?;

// note that the order of partitions is not deterministic
let mut num_rows = 0;
for partition in &results {
for batch in partition {
num_rows += batch.num_rows();
}
}
assert_eq!(20, num_rows);

let results: Vec<RecordBatch> = results.into_iter().flatten().collect();
let expected = vec![
"+----+----+",
"| c1 | c2 |",
"+----+----+",
"| 1 | 1 |",
"| 1 | 10 |",
"| 1 | 2 |",
"| 1 | 3 |",
"| 1 | 4 |",
"| 1 | 5 |",
"| 1 | 6 |",
"| 1 | 7 |",
"| 1 | 8 |",
"| 1 | 9 |",
"| 2 | 1 |",
"| 2 | 10 |",
"| 2 | 2 |",
"| 2 | 3 |",
"| 2 | 4 |",
"| 2 | 5 |",
"| 2 | 6 |",
"| 2 | 7 |",
"| 2 | 8 |",
"| 2 | 9 |",
"+----+----+",
];
assert_batches_sorted_eq!(expected, &results);

Ok(())
}

#[tokio::test]
async fn parallel_query_with_filter() -> Result<()> {
let tmp_dir = TempDir::new()?;
Expand Down
107 changes: 106 additions & 1 deletion datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,20 @@
// under the License.

use crate::expr::BinaryExpr;
use crate::expr_rewriter::{ExprRewritable, ExprRewriter};
///! Logical plan types
use crate::logical_plan::builder::validate_unique_names;
use crate::logical_plan::display::{GraphvizVisitor, IndentVisitor};
use crate::logical_plan::extension::UserDefinedLogicalNode;
use crate::utils::{
exprlist_to_fields, from_plan, grouping_set_expr_count, grouping_set_to_exprlist,
self, exprlist_to_fields, from_plan, grouping_set_expr_count,
grouping_set_to_exprlist,
};
use crate::{Expr, ExprSchemable, TableProviderFilterPushDown, TableSource};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use datafusion_common::{
plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, OwnedTableReference,
ScalarValue,
};
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug, Display, Formatter};
Expand Down Expand Up @@ -364,6 +367,42 @@ impl LogicalPlan {
) -> Result<LogicalPlan, DataFusionError> {
from_plan(self, &self.expressions(), inputs)
}

/// Convert a prepare logical plan into its inner logical plan with all params replaced with their corresponding values
pub fn with_param_values(
self,
param_values: Vec<ScalarValue>,
) -> Result<LogicalPlan, DataFusionError> {
match self {
LogicalPlan::Prepare(prepare_lp) => {
// Verify if the number of params matches the number of values
if prepare_lp.data_types.len() != param_values.len() {
return Err(DataFusionError::Internal(format!(
"Expected {} parameters, got {}",
prepare_lp.data_types.len(),
param_values.len()
)));
}

// Verify if the types of the params matches the types of the values
let iter = prepare_lp.data_types.iter().zip(param_values.iter());
for (i, (param_type, value)) in iter.enumerate() {
if *param_type != value.get_datatype() {
return Err(DataFusionError::Internal(format!(
"Expected parameter of type {:?}, got {:?} at index {}",
param_type,
value.get_datatype(),
i
)));
}
}

let input_plan = prepare_lp.input;
input_plan.replace_params_with_values(&param_values)
}
_ => Ok(self),
}
}
}

/// Trait that implements the [Visitor
Expand Down Expand Up @@ -534,6 +573,72 @@ impl LogicalPlan {
_ => {}
}
}

/// Return a logical plan with all placeholders/params (e.g $1 $2, ...) replaced with corresponding values provided in the prams_values
pub fn replace_params_with_values(
&self,
param_values: &Vec<ScalarValue>,
) -> Result<LogicalPlan, DataFusionError> {
let exprs = self.expressions();
let mut new_exprs = vec![];
for expr in exprs {
new_exprs.push(Self::replace_placeholders_with_values(expr, param_values)?);
}

let new_inputs = self.inputs();
let mut new_inputs_with_values = vec![];
for input in new_inputs {
new_inputs_with_values.push(input.replace_params_with_values(param_values)?);
}

let new_plan = utils::from_plan(self, &new_exprs, &new_inputs_with_values)?;
Ok(new_plan)
}

/// Return an Expr with all placeholders replaced with their corresponding values provided in the prams_values
fn replace_placeholders_with_values(
expr: Expr,
param_values: &Vec<ScalarValue>,
) -> Result<Expr, DataFusionError> {
struct PlaceholderReplacer<'a> {
param_values: &'a Vec<ScalarValue>,
}

impl<'a> ExprRewriter for PlaceholderReplacer<'a> {
fn mutate(&mut self, expr: Expr) -> Result<Expr, DataFusionError> {
if let Expr::Placeholder { id, data_type } = &expr {
// convert id (in format $1, $2, ..) to idx (0, 1, ..)
let idx = id[1..].parse::<usize>().map_err(|e| {
DataFusionError::Internal(format!(
"Failed to parse placeholder id: {}",
e
))
})? - 1;
// value at the idx-th position in param_values should be the value for the placeholder
let value = self.param_values.get(idx).ok_or_else(|| {
DataFusionError::Internal(format!(
"No value found for placeholder with id {}",
id
))
})?;
// check if the data type of the value matches the data type of the placeholder
if value.get_datatype() != *data_type {
return Err(DataFusionError::Internal(format!(
"Placeholder value type mismatch: expected {:?}, got {:?}",
data_type,
value.get_datatype()
)));
}
// Replace the placeholder with the value
Ok(Expr::Literal(value.clone()))
} else {
Ok(expr)
}
}
}

expr.rewrite(&mut PlaceholderReplacer { param_values })
}
}

// Various implementations for printing out LogicalPlans
Expand Down
Loading