From a40d9badbf990019bbc2e66725d01b5d41308814 Mon Sep 17 00:00:00 2001 From: jackwener Date: Fri, 5 May 2023 22:51:38 +0800 Subject: [PATCH] fix: `projection_push_down` don't consider VarProvider in columns. --- datafusion/core/tests/dataframe.rs | 41 +++++++++++++++++++++++++++++- datafusion/expr/src/utils.rs | 6 ++--- 2 files changed, 42 insertions(+), 5 deletions(-) diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 93a35f21f4c0..7c1a31b6cca9 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -28,13 +28,15 @@ use datafusion::from_slice::FromSlice; use std::sync::Arc; use datafusion::dataframe::DataFrame; +use datafusion::datasource::MemTable; use datafusion::error::Result; use datafusion::execution::context::SessionContext; use datafusion::prelude::JoinType; use datafusion::prelude::{CsvReadOptions, ParquetReadOptions}; use datafusion::test_util::parquet_test_data; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; -use datafusion_common::ScalarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_execution::config::SessionConfig; use datafusion_expr::expr::{GroupingSet, Sort}; use datafusion_expr::utils::COUNT_STAR_EXPANSION; use datafusion_expr::Expr::Wildcard; @@ -43,6 +45,7 @@ use datafusion_expr::{ sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, }; +use datafusion_physical_expr::var_provider::{VarProvider, VarType}; #[tokio::test] async fn test_count_wildcard_on_sort() -> Result<()> { @@ -1230,3 +1233,39 @@ pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Resul .await?; Ok(()) } +#[derive(Debug)] +struct HardcodedIntProvider {} + +impl VarProvider for HardcodedIntProvider { + fn get_value(&self, _var_names: Vec<String>) -> Result<ScalarValue, DataFusionError> { + Ok(ScalarValue::Int64(Some(1234))) + } + + fn get_type(&self, _: &[String]) -> Option<DataType> { + Some(DataType::Int64) + } +} +
+#[tokio::test] +async fn use_var_provider() -> Result<()> { + let schema = Arc::new(Schema::new(vec![ + Field::new("foo", DataType::Int64, false), + Field::new("bar", DataType::Int64, false), + ])); + + let mem_table = Arc::new(MemTable::try_new(schema, vec![])?); + + let config = SessionConfig::new() + .with_target_partitions(4) + .set_bool("datafusion.optimizer.skip_failed_rules", false); + let ctx = SessionContext::with_config(config); + + ctx.register_table("csv_table", mem_table)?; + ctx.register_variable(VarType::UserDefined, Arc::new(HardcodedIntProvider {})); + + let dataframe = ctx + .sql("SELECT foo FROM csv_table WHERE bar > @var") + .await?; + dataframe.collect().await?; + Ok(()) +} diff --git a/datafusion/expr/src/utils.rs b/datafusion/expr/src/utils.rs index 7babd659e7ef..00e1d0769359 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -270,13 +270,11 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> { Expr::Column(qc) => { accum.insert(qc.clone()); } - Expr::ScalarVariable(_, var_names) => { - accum.insert(Column::from_name(var_names.join("."))); - } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds // new Expr types, they will check here as well - Expr::Alias(_, _) + Expr::ScalarVariable(_, _) + | Expr::Alias(_, _) | Expr::Literal(_) | Expr::BinaryExpr { .. } | Expr::Like { .. }