From dbf2aa5ad9b61bec17c6f6010359383f8707b5ba Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:29:54 -0300 Subject: [PATCH 1/2] add lambda support --- Cargo.lock | 1 + .../examples/custom_file_casts.rs | 11 +- .../examples/default_column_values.rs | 14 +- datafusion-examples/examples/expr_api.rs | 8 +- .../examples/json_shredding.rs | 14 +- datafusion/catalog-listing/src/helpers.rs | 9 +- datafusion/common/src/column.rs | 6 + datafusion/common/src/cse.rs | 22 +- datafusion/common/src/dfschema.rs | 16 +- datafusion/common/src/lib.rs | 2 + datafusion/common/src/utils/mod.rs | 127 +++- .../core/src/execution/session_state.rs | 5 +- datafusion/core/tests/parquet/mod.rs | 2 +- .../core/tests/parquet/schema_adapter.rs | 8 +- .../datasource-parquet/src/row_filter.rs | 16 +- datafusion/expr/src/expr.rs | 69 +- datafusion/expr/src/expr_rewriter/mod.rs | 49 +- datafusion/expr/src/expr_rewriter/order_by.rs | 4 + datafusion/expr/src/expr_schema.rs | 60 +- datafusion/expr/src/lib.rs | 7 +- datafusion/expr/src/tree_node.rs | 702 +++++++++++++++++- datafusion/expr/src/udf.rs | 564 +++++++++++++- datafusion/expr/src/utils.rs | 41 +- datafusion/ffi/src/udf/mod.rs | 8 +- datafusion/ffi/src/udf/return_type_args.rs | 9 +- .../functions-nested/src/array_transform.rs | 266 +++++++ .../src/analyzer/function_rewrite.rs | 21 +- .../optimizer/src/analyzer/type_coercion.rs | 98 +-- .../optimizer/src/common_subexpr_eliminate.rs | 23 +- datafusion/optimizer/src/decorrelate.rs | 20 +- .../optimizer/src/optimize_projections/mod.rs | 4 +- datafusion/optimizer/src/push_down_filter.rs | 37 +- .../optimizer/src/scalar_subquery_to_join.rs | 67 +- .../simplify_expressions/expr_simplifier.rs | 105 ++- datafusion/optimizer/src/utils.rs | 4 +- .../src/schema_rewriter.rs | 29 +- datafusion/physical-expr/Cargo.toml | 4 + .../src/async_scalar_function.rs | 2 + .../physical-expr/src/expressions/column.rs | 21 +- .../physical-expr/src/expressions/lambda.rs | 139 ++++ .../physical-expr/src/expressions/mod.rs | 2 + datafusion/physical-expr/src/lib.rs | 2 + datafusion/physical-expr/src/physical_expr.rs | 10 +- datafusion/physical-expr/src/planner.rs | 29 +- datafusion/physical-expr/src/projection.rs | 53 +- .../physical-expr/src/scalar_function.rs | 701 ++++++++++++++++- .../physical-expr/src/simplifier/mod.rs | 20 +- .../src/simplifier/unwrap_cast.rs | 12 +- datafusion/physical-expr/src/utils/mod.rs | 21 +- .../src/enforce_sorting/sort_pushdown.rs | 60 +- .../src/projection_pushdown.rs | 55 +- datafusion/physical-plan/src/async_func.rs | 6 +- .../src/joins/stream_join_utils.rs | 29 +- datafusion/physical-plan/src/projection.rs | 61 +- datafusion/proto/src/logical_plan/to_proto.rs | 5 + datafusion/pruning/src/pruning_predicate.rs | 11 +- datafusion/sql/src/expr/function.rs | 28 +- datafusion/sql/src/expr/identifier.rs | 13 + datafusion/sql/src/planner.rs | 30 +- datafusion/sql/src/select.rs | 6 +- datafusion/sql/src/unparser/expr.rs | 17 +- datafusion/sql/src/unparser/plan.rs | 6 +- datafusion/sql/src/unparser/rewrite.rs | 14 +- datafusion/sql/src/unparser/utils.rs | 44 +- datafusion/sql/src/utils.rs | 32 +- datafusion/sqllogictest/test_files/array.slt | 8 +- datafusion/sqllogictest/test_files/lambda.slt | 166 +++++ .../src/logical_plan/producer/expr/mod.rs | 1 + 68 files changed, 3573 insertions(+), 483 deletions(-) create mode 100644 datafusion/functions-nested/src/array_transform.rs create mode 100644 datafusion/physical-expr/src/expressions/lambda.rs create mode 100644 
datafusion/sqllogictest/test_files/lambda.slt diff --git a/Cargo.lock b/Cargo.lock index f500265108ff..4a315ff38f2a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2489,6 +2489,7 @@ dependencies = [ "paste", "petgraph 0.8.3", "rand 0.9.2", + "recursive", "rstest", ] diff --git a/datafusion-examples/examples/custom_file_casts.rs b/datafusion-examples/examples/custom_file_casts.rs index 4d97ecd91dc6..d8db97d1e044 100644 --- a/datafusion-examples/examples/custom_file_casts.rs +++ b/datafusion-examples/examples/custom_file_casts.rs @@ -22,7 +22,7 @@ use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; use datafusion::common::not_impl_err; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion::common::tree_node::{Transformed, TransformedResult}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, @@ -31,7 +31,7 @@ use datafusion::execution::context::SessionContext; use datafusion::execution::object_store::ObjectStoreUrl; use datafusion::parquet::arrow::ArrowWriter; use datafusion::physical_expr::expressions::CastExpr; -use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion_physical_expr_adapter::{ DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,11 +181,10 @@ impl PhysicalExprAdapter for CustomCastsPhysicalExprAdapter { expr = self.inner.rewrite(expr)?; // Now we can apply custom casting rules or even swap out all CastExprs for a custom cast kernel / expression // For example, [DataFusion Comet](https://github.com/apache/datafusion-comet) has a [custom cast kernel](https://github.com/apache/datafusion-comet/blob/b4ac876ab420ed403ac7fc8e1b29f42f1f442566/native/spark-expr/src/conversion_funcs/cast.rs#L133-L138). 
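Context for the hunk below: the example switches from the plain `TreeNode::transform` walk to the schema-aware variant added by this patch, so that sub-expressions inside lambda arguments are typed against the lambda's scoped schema. A minimal sketch of the new closure shape (assumptions: the `PhysicalExprExt::transform_with_schema` API introduced by this patch; `rewrite_casts` and its parameters are hypothetical):

    use std::sync::Arc;
    use arrow::datatypes::Schema;
    use datafusion::common::tree_node::{Transformed, TransformedResult};
    use datafusion::common::Result;
    use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt};

    fn rewrite_casts(
        expr: Arc<dyn PhysicalExpr>,
        file_schema: &Schema,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        expr.transform_with_schema(file_schema, |e, schema| {
            // `schema` is the scope at the current node: inside a lambda
            // argument it is the lambda body's schema, not the top-level
            // file schema.
            let _input_type = e.data_type(schema)?;
            Ok(Transformed::no(e))
        })
        .data()
    }
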
- expr.transform(|expr| { + expr.transform_with_schema(&self.physical_file_schema, |expr, schema| { if let Some(cast) = expr.as_any().downcast_ref::() { - let input_data_type = - cast.expr().data_type(&self.physical_file_schema)?; - let output_data_type = cast.data_type(&self.physical_file_schema)?; + let input_data_type = cast.expr().data_type(schema)?; + let output_data_type = cast.data_type(schema)?; if !cast.is_bigger_cast(&input_data_type) { return not_impl_err!( "Unsupported CAST from {input_data_type} to {output_data_type}" diff --git a/datafusion-examples/examples/default_column_values.rs b/datafusion-examples/examples/default_column_values.rs index d3a7d2ec67f3..0d00d2c3af82 100644 --- a/datafusion-examples/examples/default_column_values.rs +++ b/datafusion-examples/examples/default_column_values.rs @@ -26,8 +26,8 @@ use async_trait::async_trait; use datafusion::assert_batches_eq; use datafusion::catalog::memory::DataSourceExec; use datafusion::catalog::{Session, TableProvider}; -use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNode}; -use datafusion::common::DFSchema; +use datafusion::common::tree_node::{Transformed, TransformedResult}; +use datafusion::common::{DFSchema, HashSet}; use datafusion::common::{Result, ScalarValue}; use datafusion::datasource::listing::PartitionedFile; use datafusion::datasource::physical_plan::{FileScanConfigBuilder, ParquetSource}; @@ -38,7 +38,7 @@ use datafusion::logical_expr::{Expr, TableProviderFilterPushDown, TableType}; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_expr::expressions::{CastExpr, Column, Literal}; -use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::{lit, SessionConfig}; use datafusion_physical_expr_adapter::{ @@ -308,11 +308,12 @@ impl PhysicalExprAdapter for DefaultValuePhysicalExprAdapter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom default value injection for missing columns let rewritten = expr - .transform(|expr| { + .transform_with_lambdas_params(|expr, lambdas_params| { self.inject_default_values( expr, &self.logical_file_schema, &self.physical_file_schema, + lambdas_params, ) }) .data()?; @@ -348,12 +349,15 @@ impl DefaultValuePhysicalExprAdapter { expr: Arc, logical_file_schema: &Schema, physical_file_schema: &Schema, + lambdas_params: &HashSet, ) -> Result>> { if let Some(column) = expr.as_any().downcast_ref::() { let column_name = column.name(); // Check if this column exists in the physical schema - if physical_file_schema.index_of(column_name).is_err() { + if !lambdas_params.contains(column_name) + && physical_file_schema.index_of(column_name).is_err() + { // Column is missing from physical schema, check if logical schema has a default if let Ok(logical_field) = logical_file_schema.field_with_name(column_name) diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs index 56f960870e58..29f074e2b400 100644 --- a/datafusion-examples/examples/expr_api.rs +++ b/datafusion-examples/examples/expr_api.rs @@ -23,7 +23,7 @@ use arrow::record_batch::RecordBatch; use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit}; use datafusion::common::stats::Precision; -use datafusion::common::tree_node::{Transformed, TreeNode}; +use datafusion::common::tree_node::Transformed; use datafusion::common::{ColumnStatistics, 
DFSchema}; use datafusion::common::{ScalarValue, ToDFSchema}; use datafusion::error::Result; @@ -556,7 +556,7 @@ fn type_coercion_demo() -> Result<()> { // 3. Type coercion with `TypeCoercionRewriter`. let coerced_expr = expr .clone() - .rewrite(&mut TypeCoercionRewriter::new(&df_schema))? + .rewrite_with_schema(&df_schema, &mut TypeCoercionRewriter::new(&df_schema))? .data; let physical_expr = datafusion::physical_expr::create_physical_expr( &coerced_expr, @@ -567,7 +567,7 @@ fn type_coercion_demo() -> Result<()> { // 4. Apply explicit type coercion by manually rewriting the expression let coerced_expr = expr - .transform(|e| { + .transform_with_schema(&df_schema, |e, df_schema| { // Only type coerces binary expressions. let Expr::BinaryExpr(e) = e else { return Ok(Transformed::no(e)); @@ -575,7 +575,7 @@ fn type_coercion_demo() -> Result<()> { if let Expr::Column(ref col_expr) = *e.left { let field = df_schema.field_with_name(None, col_expr.name())?; let cast_to_type = field.data_type(); - let coerced_right = e.right.cast_to(cast_to_type, &df_schema)?; + let coerced_right = e.right.cast_to(cast_to_type, df_schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( e.left, e.op, diff --git a/datafusion-examples/examples/json_shredding.rs b/datafusion-examples/examples/json_shredding.rs index 5ef8b59b6420..e97f27b818d8 100644 --- a/datafusion-examples/examples/json_shredding.rs +++ b/datafusion-examples/examples/json_shredding.rs @@ -22,10 +22,8 @@ use arrow::array::{RecordBatch, StringArray}; use arrow::datatypes::{DataType, Field, FieldRef, Schema, SchemaRef}; use datafusion::assert_batches_eq; -use datafusion::common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; -use datafusion::common::{assert_contains, exec_datafusion_err, Result}; +use datafusion::common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; +use datafusion::common::{assert_contains, exec_datafusion_err, HashSet, Result}; use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, ListingTableUrl, }; @@ -36,8 +34,8 @@ use datafusion::logical_expr::{ }; use datafusion::parquet::arrow::ArrowWriter; use datafusion::parquet::file::properties::WriterProperties; -use datafusion::physical_expr::PhysicalExpr; use datafusion::physical_expr::{expressions, ScalarFunctionExpr}; +use datafusion::physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion::prelude::SessionConfig; use datafusion::scalar::ScalarValue; use datafusion_physical_expr_adapter::{ @@ -302,7 +300,9 @@ impl PhysicalExprAdapter for ShreddedJsonRewriter { fn rewrite(&self, expr: Arc) -> Result> { // First try our custom JSON shredding rewrite let rewritten = expr - .transform(|expr| self.rewrite_impl(expr, &self.physical_file_schema)) + .transform_with_lambdas_params(|expr, lambdas_params| { + self.rewrite_impl(expr, &self.physical_file_schema, lambdas_params) + }) .data()?; // Then apply the default adapter as a fallback to handle standard schema differences @@ -335,6 +335,7 @@ impl ShreddedJsonRewriter { &self, expr: Arc, physical_file_schema: &Schema, + lambdas_params: &HashSet, ) -> Result>> { if let Some(func) = expr.as_any().downcast_ref::() { if func.name() == "json_get_str" && func.args().len() == 2 { @@ -348,6 +349,7 @@ impl ShreddedJsonRewriter { if let Some(column) = func.args()[1] .as_any() .downcast_ref::() + .filter(|col| !lambdas_params.contains(col.name())) { let column_name = column.name(); // Check if there's a flat column with underscore prefix diff 
--git a/datafusion/catalog-listing/src/helpers.rs b/datafusion/catalog-listing/src/helpers.rs index 82cc36867939..444f505f4280 100644 --- a/datafusion/catalog-listing/src/helpers.rs +++ b/datafusion/catalog-listing/src/helpers.rs @@ -52,9 +52,9 @@ use object_store::{ObjectMeta, ObjectStore}; /// was performed pub fn expr_applicable_for_cols(col_names: &[&str], expr: &Expr) -> bool { let mut is_applicable = true; - expr.apply(|expr| match expr { - Expr::Column(Column { ref name, .. }) => { - is_applicable &= col_names.contains(&name.as_str()); + expr.apply_with_lambdas_params(|expr, lambdas_params| match expr { + Expr::Column(col) => { + is_applicable &= col_names.contains(&col.name()) || col.is_lambda_parameter(lambdas_params); if is_applicable { Ok(TreeNodeRecursion::Jump) } else { @@ -86,7 +86,8 @@ | Expr::InSubquery(_) | Expr::ScalarSubquery(_) | Expr::GroupingSet(_) - | Expr::Case(_) => Ok(TreeNodeRecursion::Continue), + | Expr::Case(_) + | Expr::Lambda(_) => Ok(TreeNodeRecursion::Continue), Expr::ScalarFunction(scalar_function) => { match scalar_function.func.signature().volatility { diff --git a/datafusion/common/src/column.rs b/datafusion/common/src/column.rs index c7f0b5a4f488..dd9b985e6485 100644 --- a/datafusion/common/src/column.rs +++ b/datafusion/common/src/column.rs @@ -22,6 +22,7 @@ use crate::utils::parse_identifiers_normalized; use crate::utils::quote_identifier; use crate::{DFSchema, Diagnostic, Result, SchemaError, Spans, TableReference}; use arrow::datatypes::{Field, FieldRef}; +use std::borrow::Borrow; use std::collections::HashSet; use std::fmt; @@ -325,6 +326,11 @@ impl Column { ..self.clone() } } + + pub fn is_lambda_parameter(&self, lambdas_params: &crate::HashSet<impl Borrow<str> + Eq + std::hash::Hash>) -> bool { + // currently, references to lambda parameters are always unqualified + self.relation.is_none() && lambdas_params.contains(self.name()) + } } impl From<&str> for Column { diff --git a/datafusion/common/src/cse.rs b/datafusion/common/src/cse.rs index 674d3386171f..a7ffde52c93b 100644 --- a/datafusion/common/src/cse.rs +++ b/datafusion/common/src/cse.rs @@ -178,6 +178,14 @@ pub trait CSEController { /// if all are always evaluated. fn conditional_children(node: &Self::Node) -> Option>; + // A helper method called on each node, before is_ignored, on the way down during the first + // (visiting) traversal of CSE. + fn visit_f_down(&mut self, _node: &Self::Node) {} + + // A helper method called on each node, after is_ignored, on the way back up during the first + // (visiting) traversal of CSE. + fn visit_f_up(&mut self, _node: &Self::Node) {} + // Returns true if a node is valid. If a node is invalid then it can't be eliminated. // Validity is propagated up which means no subtree can be eliminated that contains // an invalid node. @@ -274,7 +282,7 @@ where /// thus can not be extracted as a common [`TreeNode`]. conditional: bool, - controller: &'a C, + controller: &'a mut C, } /// Record item that is used when traversing a [`TreeNode`] tree. @@ -352,6 +360,7 @@ where self.visit_stack .push(VisitRecord::EnterMark(self.down_index)); self.down_index += 1; + self.controller.visit_f_down(node); // If a node can short-circuit then some of its children might not be executed so // count the occurrence either normal or conditional.
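For context, a minimal usage sketch of the new `Column::is_lambda_parameter` helper (the parameter set here is hypothetical):

    use datafusion_common::{Column, HashSet};

    // Suppose a lambda `(v) -> ...` is in scope, binding the name "v".
    let lambdas_params: HashSet<String> = ["v".to_string()].into_iter().collect();

    // An unqualified reference matching an in-scope parameter is a lambda parameter...
    assert!(Column::new_unqualified("v").is_lambda_parameter(&lambdas_params));
    // ...while a qualified reference always denotes a real table column.
    assert!(!Column::from_qualified_name("t.v").is_lambda_parameter(&lambdas_params));
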
@@ -414,6 +423,7 @@ where self.visit_stack .push(VisitRecord::NodeItem(node_id, is_valid)); self.up_index += 1; + self.controller.visit_f_up(node); Ok(TreeNodeRecursion::Continue) } @@ -532,7 +542,7 @@ where /// Add an identifier to `id_array` for every [`TreeNode`] in this tree. fn node_to_id_array<'n>( - &self, + &mut self, node: &'n N, node_stats: &mut NodeStats<'n, N>, id_array: &mut IdArray<'n, N>, @@ -546,7 +556,7 @@ where random_state: &self.random_state, found_common: false, conditional: false, - controller: &self.controller, + controller: &mut self.controller, }; node.visit(&mut visitor)?; @@ -561,7 +571,7 @@ where /// Each element is itself the result of [`CSE::node_to_id_array`] for that node /// (e.g. the identifiers for each node in the tree) fn to_arrays<'n>( - &self, + &mut self, nodes: &'n [N], node_stats: &mut NodeStats<'n, N>, ) -> Result<(bool, Vec>)> { @@ -761,7 +771,7 @@ mod test { #[test] fn id_array_visitor() -> Result<()> { let alias_generator = AliasGenerator::new(); - let eliminator = CSE::new(TestTreeNodeCSEController::new( + let mut eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::Normal, )); @@ -853,7 +863,7 @@ mod test { assert_eq!(expected, id_array); // include aggregates - let eliminator = CSE::new(TestTreeNodeCSEController::new( + let mut eliminator = CSE::new(TestTreeNodeCSEController::new( &alias_generator, TestTreeNodeMask::NormalAndAggregates, )); diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 24d152a7dba8..8a09d61292b2 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -314,8 +314,10 @@ impl DFSchema { return; } - let self_fields: HashSet<(Option<&TableReference>, &FieldRef)> = - self.iter().collect(); + let self_fields: HashSet<(Option<&TableReference>, &str)> = self + .iter() + .map(|(qualifier, field)| (qualifier, field.name().as_str())) + .collect(); let self_unqualified_names: HashSet<&str> = self .inner .fields @@ -328,7 +330,10 @@ impl DFSchema { for (qualifier, field) in other_schema.iter() { // skip duplicate columns let duplicated_field = match qualifier { - Some(q) => self_fields.contains(&(Some(q), field)), + Some(q) => { + self_fields.contains(&(Some(q), field.name().as_str())) + || self_fields.contains(&(None, field.name().as_str())) + } // for unqualified columns, check as unqualified name None => self_unqualified_names.contains(field.name().as_str()), }; @@ -867,6 +872,11 @@ impl DFSchema { &self.functional_dependencies } + /// Get functional dependencies + pub fn field_qualifiers(&self) -> &[Option] { + &self.field_qualifiers + } + /// Iterate over the qualifiers and fields in the DFSchema pub fn iter(&self) -> impl Iterator, &FieldRef)> { self.field_qualifiers diff --git a/datafusion/common/src/lib.rs b/datafusion/common/src/lib.rs index 76c7b46e3273..8923df683f89 100644 --- a/datafusion/common/src/lib.rs +++ b/datafusion/common/src/lib.rs @@ -117,6 +117,8 @@ pub mod hash_set { pub use hashbrown::hash_set::Entry; } +pub use hashbrown; + /// Downcast an Arrow Array to a concrete type, return an `DataFusionError::Internal` if the cast is /// not possible. In normal usage of DataFusion the downcast should always succeed. 
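Before the utils additions that follow, a small sketch of the new index helpers' intended semantics, matching their doc-comment examples (assuming the generic parameter is the output index type, as the call sites in `list_indices` suggest):

    use arrow::buffer::OffsetBuffer;
    use arrow::datatypes::Int32Type;
    use datafusion_common::utils::{make_list_array_indices, make_list_element_indices};

    // offsets [0, 2, 2, 5, 6] describe four lists with lengths 2, 0, 3, 1
    let offsets = OffsetBuffer::new(vec![0i32, 2, 2, 5, 6].into());
    // parent row index of every child value
    let rows = make_list_array_indices::<Int32Type>(&offsets);
    assert_eq!(rows.values().to_vec(), vec![0, 0, 2, 2, 2, 3]);
    // position of every child value within its own list
    let elements = make_list_element_indices::<Int32Type>(&offsets);
    assert_eq!(elements.values().to_vec(), vec![0, 1, 0, 1, 2, 0]);
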
/// diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 7b145ac3ae21..ec2dad505a56 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -22,15 +22,20 @@ pub mod memory; pub mod proxy; pub mod string_utils; -use crate::error::{_exec_datafusion_err, _internal_datafusion_err, _internal_err}; +use crate::error::{ + _exec_datafusion_err, _exec_err, _internal_datafusion_err, _internal_err, +}; use crate::{Result, ScalarValue}; use arrow::array::{ cast::AsArray, Array, ArrayRef, FixedSizeListArray, LargeListArray, ListArray, OffsetSizeTrait, }; +use arrow::array::{ArrowPrimitiveType, PrimitiveArray}; use arrow::buffer::OffsetBuffer; use arrow::compute::{partition, SortColumn, SortOptions}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, +}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; @@ -939,6 +944,124 @@ pub fn take_function_args( }) } +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +pub fn make_list_array_indices<T: ArrowPrimitiveType>( + offsets: &OffsetBuffer<T::Native>, +) -> PrimitiveArray<T> { + let mut indices = Vec::with_capacity( + offsets.last().unwrap().as_usize() - offsets.first().unwrap().as_usize(), + ); + + for (i, (&start, &end)) in std::iter::zip(&offsets[..], &offsets[1..]).enumerate() { + indices.extend(std::iter::repeat_n( + T::Native::usize_as(i), + end.as_usize() - start.as_usize(), + )); + } + + PrimitiveArray::new(indices.into(), None) +} + +/// [0, 2, 2, 5, 6] -> [0, 1, 0, 1, 2, 0] +pub fn make_list_element_indices<T: ArrowPrimitiveType>( + offsets: &OffsetBuffer<T::Native>, +) -> PrimitiveArray<T> { + let first = offsets.first().unwrap().as_usize(); + let mut indices = vec![ + T::default_value(); + offsets.last().unwrap().as_usize() - first + ]; + + for (&start, &end) in std::iter::zip(&offsets[..], &offsets[1..]) { + for i in 0..end.as_usize() - start.as_usize() { + // offset by `first` so sliced (non-zero-based) offset buffers stay in bounds + indices[start.as_usize() - first + i] = T::Native::usize_as(i); + } + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 0, 0, 1, 1, 1] +pub fn make_fsl_array_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray<Int32Type> { + let mut indices = vec![0; list_size as usize * array_len]; + + for i in 0..array_len { + for j in 0..list_size as usize { + indices[i * list_size as usize + j] = i as i32; + } + } + + PrimitiveArray::new(indices.into(), None) +} + +/// (3, 2) -> [0, 1, 2, 0, 1, 2] +pub fn make_fsl_element_indices( + list_size: i32, + array_len: usize, +) -> PrimitiveArray<Int32Type> { + let mut indices = vec![0; list_size as usize * array_len]; + + for i in 0..array_len { + for j in 0..list_size as usize { + indices[i * list_size as usize + j] = j as i32; + } + } + + PrimitiveArray::new(indices.into(), None) +} + +pub fn list_values(array: &dyn Array) -> Result<&ArrayRef> { + match array.data_type() { + DataType::List(_) => Ok(array.as_list::<i32>().values()), + DataType::LargeList(_) => Ok(array.as_list::<i64>().values()), + DataType::FixedSizeList(_, _) => Ok(array.as_fixed_size_list().values()), + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn list_indices(array: &dyn Array) -> Result<ArrayRef> { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_array_indices::<Int32Type>( + array.as_list().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_array_indices::<Int64Type>( + array.as_list().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_array_indices(
fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + +pub fn elements_indices(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::LargeList(_) => Ok(Arc::new(make_list_element_indices::( + array.as_list::().offsets(), + ))), + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(make_fsl_element_indices( + fixed_size_list.value_length(), + fixed_size_list.len(), + ))) + } + other => _exec_err!("expected list, got {other}"), + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/core/src/execution/session_state.rs b/datafusion/core/src/execution/session_state.rs index c15b7eae0843..ad4ffb487ee1 100644 --- a/datafusion/core/src/execution/session_state.rs +++ b/datafusion/core/src/execution/session_state.rs @@ -41,7 +41,6 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::config::Dialect; use datafusion_common::config::{ConfigExtension, ConfigOptions, TableOptions}; use datafusion_common::display::{PlanType, StringifiedPlan, ToStringifiedPlan}; -use datafusion_common::tree_node::TreeNode; use datafusion_common::{ config_err, exec_err, plan_datafusion_err, DFSchema, DataFusionError, ResolvedTableReference, TableReference, @@ -701,7 +700,9 @@ impl SessionState { let config_options = self.config_options(); for rewrite in self.analyzer.function_rewrites() { expr = expr - .transform_up(|expr| rewrite.rewrite(expr, df_schema, config_options))? + .transform_up_with_schema(df_schema, |expr, df_schema| { + rewrite.rewrite(expr, df_schema, config_options) + })? .data; } create_physical_expr(&expr, df_schema, self.execution_props()) diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 097600e45ead..eea6085c02b9 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -516,7 +516,7 @@ fn make_uint_batches(start: u8, end: u8) -> RecordBatch { Field::new("u64", DataType::UInt64, true), ])); let v8: Vec = (start..end).collect(); - let v16: Vec = (start as _..end as _).collect(); + let v16: Vec = (start as u16..end as _).collect(); let v32: Vec = (start as _..end as _).collect(); let v64: Vec = (start as _..end as _).collect(); RecordBatch::try_new( diff --git a/datafusion/core/tests/parquet/schema_adapter.rs b/datafusion/core/tests/parquet/schema_adapter.rs index 40fc6176e212..dfa4c91ba5dd 100644 --- a/datafusion/core/tests/parquet/schema_adapter.rs +++ b/datafusion/core/tests/parquet/schema_adapter.rs @@ -27,7 +27,7 @@ use datafusion::datasource::listing::{ ListingTable, ListingTableConfig, ListingTableConfigExt, }; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::DataFusionError; use datafusion_common::{ColumnStatistics, ScalarValue}; use datafusion_datasource::file::FileSource; @@ -39,7 +39,7 @@ use datafusion_datasource::ListingTableUrl; use datafusion_datasource_parquet::source::ParquetSource; use datafusion_execution::object_store::ObjectStoreUrl; use datafusion_physical_expr::expressions::{self, Column}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt}; use datafusion_physical_expr_adapter::{ 
DefaultPhysicalExprAdapter, DefaultPhysicalExprAdapterFactory, PhysicalExprAdapter, PhysicalExprAdapterFactory, @@ -181,10 +181,10 @@ struct CustomPhysicalExprAdapter { impl PhysicalExprAdapter for CustomPhysicalExprAdapter { fn rewrite(&self, mut expr: Arc<dyn PhysicalExpr>) -> Result<Arc<dyn PhysicalExpr>> { expr = expr - .transform(|expr| { + .transform_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::<Column>() { let field_name = column.name(); - if self + if !lambdas_params.contains(field_name) && self .physical_file_schema .field_with_name(field_name) .ok() diff --git a/datafusion/datasource-parquet/src/row_filter.rs b/datafusion/datasource-parquet/src/row_filter.rs index 660b32f48612..45441ad71086 100644 --- a/datafusion/datasource-parquet/src/row_filter.rs +++ b/datafusion/datasource-parquet/src/row_filter.rs @@ -77,7 +77,7 @@ use datafusion_common::Result; use datafusion_datasource::schema_adapter::{SchemaAdapterFactory, SchemaMapper}; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::reassign_expr_columns; -use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr, PhysicalExprExt}; use datafusion_physical_plan::metrics; @@ -336,6 +336,20 @@ impl<'schema> PushdownChecker<'schema> { fn prevents_pushdown(&self) -> bool { self.non_primitive_columns || self.projected_columns } + + fn check(&mut self, node: Arc<dyn PhysicalExpr>) -> Result<TreeNodeRecursion> { + node.apply_with_lambdas_params(|node, lambdas_params| { + if let Some(column) = node.as_any().downcast_ref::<Column>() { + if !lambdas_params.contains(column.name()) { + if let Some(recursion) = self.check_single_column(column.name()) { + return Ok(recursion); + } + } + } + + Ok(TreeNodeRecursion::Continue) + }) + } } impl TreeNodeVisitor<'_> for PushdownChecker<'_> { diff --git a/datafusion/expr/src/expr.rs b/datafusion/expr/src/expr.rs index 13160d573ab4..e2845ea5a7de 100644 --- a/datafusion/expr/src/expr.rs +++ b/datafusion/expr/src/expr.rs @@ -398,6 +398,10 @@ pub enum Expr { OuterReferenceColumn(FieldRef, Column), /// Unnest expression Unnest(Unnest), + /// Lambda expression, valid only as a scalar function argument. + /// Note that it has its own scoped schema, different from the plan schema, + /// which can be constructed with ScalarUDF::arguments_schema_from_logical_args and variants + Lambda(Lambda), } impl Default for Expr { @@ -1211,6 +1215,23 @@ impl GroupingSet { } } +/// Lambda expression. +#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] +pub struct Lambda { + pub params: Vec<String>, + pub body: Box<Expr>, +} + +impl Lambda { + /// Create a new lambda expression + pub fn new(params: Vec<String>, body: Expr) -> Self { + Self { + params, + body: Box::new(body), + } + } +} + #[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)] #[cfg(not(feature = "sql"))] pub struct IlikeSelectItem { @@ -1525,6 +1546,7 @@ impl Expr { #[expect(deprecated)] Expr::Wildcard { .. } => "Wildcard", Expr::Unnest { .. } => "Unnest", + Expr::Lambda { ..
} => "Lambda", } } @@ -1908,9 +1930,11 @@ impl Expr { /// /// See [`Self::column_refs`] for details pub fn add_column_refs<'a>(&'a self, set: &mut HashSet<&'a Column>) { - self.apply(|expr| { + self.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(col) = expr { - set.insert(col); + if col.relation.is_some() || !lambdas_params.contains(col.name()) { + set.insert(col); + } } Ok(TreeNodeRecursion::Continue) }) @@ -1943,9 +1967,11 @@ impl Expr { /// /// See [`Self::column_refs_counts`] for details pub fn add_column_ref_counts<'a>(&'a self, map: &mut HashMap<&'a Column, usize>) { - self.apply(|expr| { + self.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(col) = expr { - *map.entry(col).or_default() += 1; + if !col.is_lambda_parameter(lambdas_params) { + *map.entry(col).or_default() += 1; + } } Ok(TreeNodeRecursion::Continue) }) @@ -1954,8 +1980,10 @@ impl Expr { /// Returns true if there are any column references in this Expr pub fn any_column_refs(&self) -> bool { - self.exists(|expr| Ok(matches!(expr, Expr::Column(_)))) - .expect("exists closure is infallible") + self.exists_with_lambdas_params(|expr, lambdas_params| { + Ok(matches!(expr, Expr::Column(c) if !c.is_lambda_parameter(lambdas_params))) + }) + .expect("exists closure is infallible") } /// Return true if the expression contains out reference(correlated) expressions. @@ -1995,7 +2023,7 @@ impl Expr { /// at least one placeholder. pub fn infer_placeholder_types(self, schema: &DFSchema) -> Result<(Expr, bool)> { let mut has_placeholder = false; - self.transform(|mut expr| { + self.transform_with_schema(schema, |mut expr, schema| { match &mut expr { // Default to assuming the arguments are the same type Expr::BinaryExpr(BinaryExpr { left, op: _, right }) => { @@ -2078,7 +2106,8 @@ impl Expr { | Expr::Wildcard { .. } | Expr::WindowFunction(..) | Expr::Literal(..) - | Expr::Placeholder(..) => false, + | Expr::Placeholder(..) + | Expr::Lambda { .. } => false, } } @@ -2674,6 +2703,12 @@ impl HashNode for Expr { column.hash(state); } Expr::Unnest(Unnest { expr: _expr }) => {} + Expr::Lambda(Lambda { + params, + body: _, + }) => { + params.hash(state); + } }; } } @@ -2987,6 +3022,12 @@ impl Display for SchemaDisplay<'_> { } } } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {body}", display_comma_separated(params)) + } } } } @@ -3167,6 +3208,12 @@ impl Display for SqlDisplay<'_> { } } } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {}", params.join(", "), SchemaDisplay(body)) + } _ => write!(f, "{}", self.0), } } @@ -3474,6 +3521,12 @@ impl Display for Expr { Expr::Unnest(Unnest { expr }) => { write!(f, "{UNNEST_COLUMN_PREFIX}({expr})") } + Expr::Lambda(Lambda { + params, + body, + }) => { + write!(f, "({}) -> {body}", params.join(", ")) + } } } } diff --git a/datafusion/expr/src/expr_rewriter/mod.rs b/datafusion/expr/src/expr_rewriter/mod.rs index 9c3c5df7007f..81ec6e7acbe3 100644 --- a/datafusion/expr/src/expr_rewriter/mod.rs +++ b/datafusion/expr/src/expr_rewriter/mod.rs @@ -62,11 +62,15 @@ pub trait FunctionRewrite: Debug { /// Recursively call `LogicalPlanBuilder::normalize` on all [`Column`] expressions /// in the `expr` expression tree. 
pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result { - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ if let Expr::Column(c) = expr { - let col = LogicalPlanBuilder::normalize(plan, c)?; - Transformed::yes(Expr::Column(col)) + if c.relation.is_some() || !lambdas_params.contains(c.name()) { + let col = LogicalPlanBuilder::normalize(plan, c)?; + Transformed::yes(Expr::Column(col)) + } else { + Transformed::no(Expr::Column(c)) + } } else { Transformed::no(expr) } @@ -91,14 +95,21 @@ pub fn normalize_col_with_schemas_and_ambiguity_check( return Ok(Expr::Unnest(Unnest { expr: Box::new(e) })); } - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ - if let Expr::Column(c) = expr { - let col = - c.normalize_with_schemas_and_ambiguity_check(schemas, using_columns)?; - Transformed::yes(Expr::Column(col)) - } else { - Transformed::no(expr) + match expr { + Expr::Column(c) => { + if c.relation.is_none() && lambdas_params.contains(c.name()) { + Transformed::no(Expr::Column(c)) + } else { + let col = c.normalize_with_schemas_and_ambiguity_check( + schemas, + using_columns, + )?; + Transformed::yes(Expr::Column(col)) + } + } + _ => Transformed::no(expr), } }) }) @@ -133,15 +144,18 @@ pub fn normalize_sorts( /// Recursively replace all [`Column`] expressions in a given expression tree with /// `Column` expressions provided by the hash map argument. pub fn replace_col(expr: Expr, replace_map: &HashMap<&Column, &Column>) -> Result { - expr.transform(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok({ - if let Expr::Column(c) = &expr { - match replace_map.get(c) { - Some(new_c) => Transformed::yes(Expr::Column((*new_c).to_owned())), - None => Transformed::no(expr), + match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match replace_map.get(c) { + Some(new_c) => { + Transformed::yes(Expr::Column((*new_c).to_owned())) + } + None => Transformed::no(expr), + } } - } else { - Transformed::no(expr) + _ => Transformed::no(expr), } }) }) @@ -201,6 +215,7 @@ pub fn strip_outer_reference(expr: Expr) -> Expr { expr.transform(|expr| { Ok({ if let Expr::OuterReferenceColumn(_, col) = expr { + //todo: what if this col collides with a lambda parameter? Transformed::yes(Expr::Column(col)) } else { Transformed::no(expr) diff --git a/datafusion/expr/src/expr_rewriter/order_by.rs b/datafusion/expr/src/expr_rewriter/order_by.rs index 6db95555502d..b94c632ce74b 100644 --- a/datafusion/expr/src/expr_rewriter/order_by.rs +++ b/datafusion/expr/src/expr_rewriter/order_by.rs @@ -77,6 +77,10 @@ fn rewrite_in_terms_of_projection( // assumption is that each item in exprs, such as "b + c" is // available as an output column named "b + c" expr.transform(|expr| { + if matches!(expr, Expr::Lambda(_)) { + return Ok(Transformed::new(expr, false, TreeNodeRecursion::Jump)) + } + // search for unnormalized names first such as "c1" (such as aliases) if let Some(found) = proj_exprs.iter().find(|a| (**a) == expr) { let (qualifier, field_name) = found.qualified_name(); diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9e8d6080b82c..4a1efadccd0e 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -16,19 +16,24 @@ // under the License. 
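As with `normalize_col` and `replace_col` above, the scope-aware traversals let callers distinguish real table columns from lambda parameters. A minimal sketch (assuming the `apply_with_lambdas_params` and `is_lambda_parameter` APIs added in this patch; `count_table_columns` is hypothetical):

    use datafusion_common::tree_node::TreeNodeRecursion;
    use datafusion_common::Result;
    use datafusion_expr::Expr;

    fn count_table_columns(expr: &Expr) -> Result<usize> {
        let mut table_columns = 0;
        expr.apply_with_lambdas_params(|e, lambdas_params| {
            // `lambdas_params` holds the lambda parameter names in scope at `e`,
            // so the inner `v` in `array_transform(v, (v) -> v + 1)` is skipped.
            if let Expr::Column(c) = e {
                if !c.is_lambda_parameter(lambdas_params) {
                    table_columns += 1;
                }
            }
            Ok(TreeNodeRecursion::Continue)
        })?;
        Ok(table_columns)
    }
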
use super::{Between, Expr, Like}; +use crate::expr::FieldMetadata; use crate::expr::{ AggregateFunction, AggregateFunctionParams, Alias, BinaryExpr, Cast, InList, - InSubquery, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, + InSubquery, Lambda, Placeholder, ScalarFunction, TryCast, Unnest, WindowFunction, WindowFunctionParams, }; use crate::type_coercion::functions::{ - data_types_with_scalar_udf, fields_with_aggregate_udf, fields_with_window_udf, + fields_with_aggregate_udf, fields_with_window_udf, +}; +use crate::{ + type_coercion::functions::data_types_with_scalar_udf, udf::ReturnFieldArgs, utils, + LogicalPlan, Projection, Subquery, WindowFunctionDefinition, +}; +use arrow::datatypes::FieldRef; +use arrow::{ + compute::can_cast_types, + datatypes::{DataType, Field}, }; -use crate::udf::ReturnFieldArgs; -use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; -use arrow::compute::can_cast_types; -use arrow::datatypes::{DataType, Field, FieldRef}; -use datafusion_common::metadata::FieldMetadata; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, Column, DataFusionError, ExprSchema, Result, Spans, TableReference, @@ -229,6 +234,7 @@ impl ExprSchemable for Expr { // Grouping sets do not really have a type and do not appear in projections Ok(DataType::Null) } + Expr::Lambda { .. } => Ok(DataType::Null), } } @@ -347,6 +353,7 @@ impl ExprSchemable for Expr { // in projections Ok(true) } + Expr::Lambda { .. } => Ok(false), } } @@ -535,14 +542,31 @@ impl ExprSchemable for Expr { func.return_field(&new_fields) } + // Expr::Lambda(Lambda { params, body}) => body.to_field(schema), Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, fields): (Vec, Vec>) = args + let fields = if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) { + let lambdas_schemas = func.arguments_expr_schema(args, schema)?; + + std::iter::zip(args, lambdas_schemas) + // .map(|(e, schema)| e.to_field(schema).map(|(_, f)| f)) + .map(|(e, schema)| match e { + Expr::Lambda(Lambda { params: _, body }) => { + body.to_field(&schema).map(|(_, f)| f) + } + _ => e.to_field(&schema).map(|(_, f)| f), + }) + .collect::>>()? + } else { + args.iter() + .map(|e| e.to_field(schema).map(|(_, f)| f)) + .collect::>>()? + }; + + let arg_types = fields .iter() - .map(|e| e.to_field(schema).map(|(_, f)| f)) - .collect::>>()? - .into_iter() - .map(|f| (f.data_type().clone(), f)) - .unzip(); + .map(|f| f.data_type().clone()) + .collect::>(); + // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) .map_err(|err| { @@ -573,9 +597,16 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); + + let lambdas = args + .iter() + .map(|e| matches!(e, Expr::Lambda { .. })) + .collect::>(); + let args = ReturnFieldArgs { arg_fields: &new_fields, scalar_arguments: &arguments, + lambdas: &lambdas, }; func.return_field_from_args(args) @@ -600,7 +631,8 @@ impl ExprSchemable for Expr { | Expr::Wildcard { .. 
} | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::Unnest(_) => Ok(Arc::new(Field::new( + | Expr::Unnest(_) + | Expr::Lambda(_) => Ok(Arc::new(Field::new( &schema_name, self.get_type(schema)?, self.nullable(schema)?, diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 2b7cc9d46ad3..46c7422814ac 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -117,7 +117,12 @@ pub use udaf::{ udaf_default_window_function_schema_name, AggregateUDF, AggregateUDFImpl, ReversedUDAF, SetMonotonicity, StatisticsArgs, }; -pub use udf::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl}; +pub use udf::{ + merge_captures_with_args, merge_captures_with_boxed_lazy_args, + merge_captures_with_lazy_args, ReturnFieldArgs, ScalarFunctionArgs, + ScalarFunctionLambdaArg, ScalarUDF, ScalarUDFImpl, ValueOrLambda, ValueOrLambdaField, + ValueOrLambdaParameter, +}; pub use udwf::{LimitEffect, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/tree_node.rs b/datafusion/expr/src/tree_node.rs index 81846b4f8060..63c535b43ee8 100644 --- a/datafusion/expr/src/tree_node.rs +++ b/datafusion/expr/src/tree_node.rs @@ -17,17 +17,20 @@ //! Tree node implementation for Logical Expressions -use crate::expr::{ - AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, Cast, - GroupingSet, InList, InSubquery, Like, Placeholder, ScalarFunction, TryCast, Unnest, - WindowFunction, WindowFunctionParams, +use crate::{ + expr::{ + AggregateFunction, AggregateFunctionParams, Alias, Between, BinaryExpr, Case, + Cast, GroupingSet, InList, InSubquery, Lambda, Like, Placeholder, ScalarFunction, + TryCast, Unnest, WindowFunction, WindowFunctionParams, + }, + Expr, }; -use crate::Expr; - -use datafusion_common::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, +use datafusion_common::{ + tree_node::{ + Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRefContainer, + }, + DFSchema, HashSet, Result, }; -use datafusion_common::Result; /// Implementation of the [`TreeNode`] trait /// @@ -106,6 +109,7 @@ impl TreeNode for Expr { Expr::InList(InList { expr, list, .. }) => { (expr, list).apply_ref_elements(f) } + Expr::Lambda (Lambda{ params: _, body}) => body.apply_elements(f) } } @@ -311,6 +315,686 @@ impl TreeNode for Expr { .update_data(|(new_expr, new_list)| { Expr::InList(InList::new(new_expr, new_list, negated)) }), + Expr::Lambda(Lambda { params, body }) => body + .map_elements(f)? + .update_data(|body| Expr::Lambda(Lambda { params, body })), + }) + } +} + +impl Expr { + /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + pub fn rewrite_with_schema< + R: for<'a> TreeNodeRewriterWithPayload = &'a DFSchema>, + >( + self, + schema: &DFSchema, + rewriter: &mut R, + ) -> Result> { + rewriter + .f_down(self, schema)? + .transform_children(|n| match &n { + Expr::ScalarFunction(ScalarFunction { func, args }) + if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + let mut lambdas_schemas = func + .arguments_schema_from_logical_args(args, schema)? 
+ .into_iter(); + + n.map_children(|n| { + n.rewrite_with_schema(&lambdas_schemas.next().unwrap(), rewriter) + }) + } + _ => n.map_children(|n| n.rewrite_with_schema(schema, rewriter)), + })? + .transform_parent(|n| rewriter.f_up(n, schema)) + } + + /// Similarly to [`Self::rewrite`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn rewrite_with_lambdas_params< + R: for<'a> TreeNodeRewriterWithPayload< + Node = Expr, + Payload<'a> = &'a HashSet<String>, + >, + >( + self, + rewriter: &mut R, + ) -> Result<Transformed<Self>> { + self.rewrite_with_lambdas_params_impl(&HashSet::new(), rewriter) + } + + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn rewrite_with_lambdas_params_impl< + R: for<'a> TreeNodeRewriterWithPayload< + Node = Expr, + Payload<'a> = &'a HashSet<String>, + >, + >( + self, + args: &HashSet<String>, + rewriter: &mut R, + ) -> Result<Transformed<Self>> { + rewriter + .f_down(self, args)? + .transform_children(|n| match n { + Expr::Lambda(Lambda { + ref params, + body: _, + }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + n.map_children(|n| { + n.rewrite_with_lambdas_params_impl(&args, rewriter) + }) + } + _ => { + n.map_children(|n| n.rewrite_with_lambdas_params_impl(args, rewriter)) + } + })? + .transform_parent(|n| rewriter.f_up(n, args)) + } + + /// Similarly to [`Self::map_children`], applies `f` to this node's children, + /// extending the set of in-scope lambda parameters when descending into a lambda + /// such as in `array_transform([1, 2], v -> v*2)`. + /// + /// Returns the current node. + pub fn map_children_with_lambdas_params< + F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>, + >( + self, + args: &HashSet<String>, + mut f: F, + ) -> Result<Transformed<Self>> { + match &self { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().cloned()); + + self.map_children(|expr| f(expr, &args)) + } + _ => self.map_children(|expr| f(expr, args)), + } + } + + /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_up_with_lambdas_params< + F: FnMut(Self, &HashSet<String>) -> Result<Transformed<Self>>, + >( + self, + mut f: F, + ) -> Result<Transformed<Self>> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_lambdas_params_impl< + F: FnMut(Expr, &HashSet<String>) -> Result<Transformed<Expr>>, + >( + node: Expr, + args: &HashSet<String>, + f: &mut F, + ) -> Result<Transformed<Expr>> { + node.map_children_with_lambdas_params(args, |node, args| { + transform_up_with_lambdas_params_impl(node, args, f) + })? + .transform_parent(|node| f(node, args)) + } + + transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + /// Similarly to [`Self::transform_down`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`.
+ pub fn transform_down_with_lambdas_params< + F: FnMut(Self, &HashSet) -> Result>, + >( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_lambdas_params_impl< + F: FnMut(Expr, &HashSet) -> Result>, + >( + node: Expr, + args: &HashSet, + f: &mut F, + ) -> Result> { + f(node, args)?.transform_children(|node| { + node.map_children_with_lambdas_params(args, |node, args| { + transform_down_with_lambdas_params_impl(node, args, f) + }) + }) + } + + transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + pub fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_params_impl< + 'n, + F: FnMut(&'n Expr, &HashSet<&'n str>) -> Result, + >( + node: &'n Expr, + args: &HashSet<&'n str>, + f: &mut F, + ) -> Result { + match node { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut args = args.clone(); + + args.extend(params.iter().map(|v| v.as_str())); + + f(node, &args)?.visit_children(|| { + node.apply_children(|c| { + apply_with_lambdas_params_impl(c, &args, f) + }) + }) + } + _ => f(node, args)?.visit_children(|| { + node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) + }), + } + } + + apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + /// Similarly to [`Self::transform`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + f: F, + ) -> Result> { + self.transform_up_with_schema(schema, f) + } + + /// Similarly to [`Self::transform_up`], rewrites this expr and its inputs using `f`, + /// including lambdas that may appear in expressions such as `array_transform([1, 2], v -> v*2)`. + pub fn transform_up_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_schema_impl< + F: FnMut(Expr, &DFSchema) -> Result>, + >( + node: Expr, + schema: &DFSchema, + f: &mut F, + ) -> Result> { + node.map_children_with_schema(schema, |n, schema| { + transform_up_with_schema_impl(n, schema, f) + })? + .transform_parent(|n| f(n, schema)) + } + + transform_up_with_schema_impl(self, schema, &mut f) + } + + pub fn map_children_with_schema< + F: FnMut(Self, &DFSchema) -> Result>, + >( + self, + schema: &DFSchema, + mut f: F, + ) -> Result> { + match self { + Expr::ScalarFunction(ref fun) + if fun.args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + let mut args_schemas = fun + .func + .arguments_schema_from_logical_args(&fun.args, schema)? + .into_iter(); + + self.map_children(|expr| f(expr, &args_schemas.next().unwrap())) + } + _ => self.map_children(|expr| f(expr, schema)), + } + } + + pub fn exists_with_lambdas_params) -> Result>( + &self, + mut f: F, + ) -> Result { + let mut found = false; + + self.apply_with_lambdas_params(|n, lambdas_params| { + if f(n, lambdas_params)? { + found = true; + Ok(TreeNodeRecursion::Stop) + } else { + Ok(TreeNodeRecursion::Continue) + } + })?; + + Ok(found) + } +} + +pub trait ExprWithLambdasRewriter2: Sized { + /// Invoked while traversing down the tree before any children are rewritten. 
+ /// Default implementation returns the node as is and continues recursion. + fn f_down(&mut self, node: Expr, _schema: &DFSchema) -> Result> { + Ok(Transformed::no(node)) + } + + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up(&mut self, node: Expr, _schema: &DFSchema) -> Result> { + Ok(Transformed::no(node)) + } +} +pub trait TreeNodeRewriterWithPayload: Sized { + type Node; + type Payload<'a>; + + /// Invoked while traversing down the tree before any children are rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_down<'a>( + &mut self, + node: Self::Node, + _payload: Self::Payload<'a>, + ) -> Result> { + Ok(Transformed::no(node)) + } + + /// Invoked while traversing up the tree after all children have been rewritten. + /// Default implementation returns the node as is and continues recursion. + fn f_up<'a>( + &mut self, + node: Self::Node, + _payload: Self::Payload<'a>, + ) -> Result> { + Ok(Transformed::no(node)) + } +} + +/* +struct LambdaColumnNormalizer<'a> { + existing_qualifiers: HashSet<&'a str>, + alias_generator: AliasGenerator, + lambdas_columns: HashMap>, +} + +impl<'a> LambdaColumnNormalizer<'a> { + fn new(dfschema: &'a DFSchema, expr: &'a Expr) -> Self { + let mut existing_qualifiers: HashSet<&'a str> = dfschema + .field_qualifiers() + .iter() + .flatten() + .map(|tbl| tbl.table()) + .filter(|table| table.starts_with("lambda_")) + .collect(); + + expr.apply(|node| { + if let Expr::Lambda(lambda) = node { + if let Some(qualifier) = &lambda.qualifier { + existing_qualifiers.insert(qualifier); + } + } + + Ok(TreeNodeRecursion::Continue) }) + .unwrap(); + + Self { + existing_qualifiers, + alias_generator: AliasGenerator::new(), + lambdas_columns: HashMap::new(), + } + } +} + +impl TreeNodeRewriter for LambdaColumnNormalizer<'_> { + type Node = Expr; + + fn f_down(&mut self, node: Self::Node) -> Result> { + match node { + Expr::Lambda(mut lambda) => { + let tbl = lambda.qualifier.as_ref().map_or_else( + || loop { + let table = self.alias_generator.next("lambda"); + + if !self.existing_qualifiers.contains(table.as_str()) { + break TableReference::bare(table); + } + }, + |qualifier| TableReference::bare(qualifier.as_str()), + ); + + for param in &lambda.params { + self.lambdas_columns + .entry_ref(param) + .or_default() + .push(tbl.clone()); + } + + if lambda.qualifier.is_none() { + lambda.qualifier = Some(tbl.table().to_owned()); + + Ok(Transformed::yes(Expr::Lambda(lambda))) + } else { + Ok(Transformed::no(Expr::Lambda(lambda))) + } + } + Expr::Column(c) if c.relation.is_none() => { + if let Some(lambda_qualifier) = self.lambdas_columns.get(c.name()) { + Ok(Transformed::yes(Expr::Column( + c.with_relation(lambda_qualifier.last().unwrap().clone()), + ))) + } else { + Ok(Transformed::no(Expr::Column(c))) + } + } + _ => Ok(Transformed::no(node)) + } + } + + fn f_up(&mut self, node: Self::Node) -> Result> { + if let Expr::Lambda(lambda) = &node { + for param in &lambda.params { + match self.lambdas_columns.entry_ref(param) { + EntryRef::Occupied(mut entry) => { + let chain = entry.get_mut(); + + chain.pop(); + + if chain.is_empty() { + entry.remove(); + } + } + EntryRef::Vacant(_) => unreachable!(), + } + } + } + + Ok(Transformed::no(node)) + } +} +*/ + +// helpers used in udf.rs +#[cfg(test)] +pub(crate) mod tests { + use super::TreeNodeRewriterWithPayload; + use crate::{ + col, expr::Lambda, Expr, ScalarUDF, 
ScalarUDFImpl, ValueOrLambdaParameter, + }; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{ + tree_node::{Transformed, TreeNodeRecursion}, + DFSchema, HashSet, Result, + }; + use datafusion_expr_common::signature::{Signature, Volatility}; + + pub(crate) fn list_list_int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::new_list(DataType::Int32, false), false), + false, + )])) + .unwrap() + } + + pub(crate) fn list_int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::Int32, false), + false, + )])) + .unwrap() + } + + fn int() -> DFSchema { + DFSchema::try_from(Schema::new(vec![Field::new("v", DataType::Int32, false)])) + .unwrap() + } + + pub(crate) fn array_transform_udf() -> ScalarUDF { + ScalarUDF::new_from_impl(ArrayTransformFunc::new()) + } + + pub(crate) fn args() -> Vec { + vec![ + col("v"), + Expr::Lambda(Lambda::new( + vec!["v".into()], + array_transform_udf().call(vec![ + col("v"), + Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), + ]), + )), + ] + } + + // array_transform(v, |v| -> array_transform(v, |v| -> -v)) + fn array_transform() -> Expr { + array_transform_udf().call(args()) + } + + #[derive(Debug, PartialEq, Eq, Hash)] + pub(crate) struct ArrayTransformFunc { + signature: Signature, + } + + impl ArrayTransformFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ArrayTransformFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let ValueOrLambdaParameter::Value(value_field) = &args[0] else { + unreachable!() + }; + + let DataType::List(field) = value_field.data_type() else { + unreachable!() + }; + + Ok(vec![ + None, + Some(vec![Field::new( + "", + field.data_type().clone(), + field.is_nullable(), + )]), + ]) + } + + fn invoke_with_args( + &self, + _args: crate::ScalarFunctionArgs, + ) -> Result { + unimplemented!() + } + } + + #[test] + fn test_rewrite_with_schema() { + let schema = list_list_int(); + let array_transform = array_transform(); + + let mut rewriter = OkRewriter::default(); + + array_transform + .rewrite_with_schema(&schema, &mut rewriter) + .unwrap(); + + let expected = [ + ( + "f_down array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + list_list_int(), + ), + ("f_down v", list_list_int()), + ("f_up v", list_list_int()), + ("f_down (v) -> array_transform(v, (v) -> (- v))", list_int()), + ("f_down array_transform(v, (v) -> (- v))", list_int()), + ("f_down v", list_int()), + ("f_up v", list_int()), + ("f_down (v) -> (- v)", int()), + ("f_down (- v)", int()), + ("f_down v", int()), + ("f_up v", int()), + ("f_up (- v)", int()), + ("f_up (v) -> (- v)", int()), + ("f_up array_transform(v, (v) -> (- v))", list_int()), + ("f_up (v) -> array_transform(v, (v) -> (- v))", list_int()), + ( + "f_up array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + list_list_int(), + ), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(rewriter.steps, expected) + } + + #[derive(Default)] + struct OkRewriter { + steps: Vec<(String, DFSchema)>, + } + + impl TreeNodeRewriterWithPayload for OkRewriter { + type Node = Expr; + type 
Payload<'a> = &'a DFSchema; + + fn f_down( + &mut self, + node: Expr, + schema: &DFSchema, + ) -> Result> { + self.steps.push((format!("f_down {node}"), schema.clone())); + + Ok(Transformed::no(node)) + } + + fn f_up( + &mut self, + node: Expr, + schema: &DFSchema, + ) -> Result> { + self.steps.push((format!("f_up {node}"), schema.clone())); + + Ok(Transformed::no(node)) + } + } + + #[test] + fn test_transform_up_with_lambdas_params() { + let mut steps = vec![]; + + array_transform() + .transform_up_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(Transformed::no(node)) + }) + .unwrap(); + + let lambdas_params = &HashSet::from([String::from("v")]); + + let expected = [ + ("v", lambdas_params), + ("v", lambdas_params), + ("v", lambdas_params), + ("(- v)", lambdas_params), + ("(v) -> (- v)", lambdas_params), + ("array_transform(v, (v) -> (- v))", lambdas_params), + ("(v) -> array_transform(v, (v) -> (- v))", lambdas_params), + ( + "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + lambdas_params, + ), + ] + .map(|(a, b)| (String::from(a), b.clone())); + + assert_eq!(steps, expected); + } + + #[test] + fn test_apply_with_lambdas_params() { + let array_transform = array_transform(); + let mut steps = vec![]; + + array_transform + .apply_with_lambdas_params(|node, params| { + steps.push((node.to_string(), params.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ("v", HashSet::from(["v"])), + ("v", HashSet::from(["v"])), + ("v", HashSet::from(["v"])), + ("(- v)", HashSet::from(["v"])), + ("(v) -> (- v)", HashSet::from(["v"])), + ("array_transform(v, (v) -> (- v))", HashSet::from(["v"])), + ("(v) -> array_transform(v, (v) -> (- v))", HashSet::from(["v"])), + ( + "array_transform(v, (v) -> array_transform(v, (v) -> (- v)))", + HashSet::from(["v"]), + ), + ] + .map(|(a, b)| (String::from(a), b)); + + assert_eq!(steps, expected); } } diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index fd54bb13a62f..74ac1b456ff0 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -18,21 +18,30 @@ //! 
[`ScalarUDF`]: Scalar User Defined Functions use crate::async_udf::AsyncScalarUDF; -use crate::expr::schema_name_from_exprs_comma_separated_without_space; +use crate::expr::{schema_name_from_exprs_comma_separated_without_space, Lambda}; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::udf_eq::UdfEq; -use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::{DataType, Field, FieldRef}; +use crate::{ColumnarValue, Documentation, Expr, ExprSchemable, Signature}; +use arrow::array::{ArrayRef, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Fields, Schema}; +use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; -use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; +use datafusion_common::tree_node::TreeNodeRecursion; +use datafusion_common::{ + exec_err, not_impl_err, DFSchema, ExprSchema, Result, ScalarValue, +}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::interval_arithmetic::Interval; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use indexmap::IndexMap; use std::any::Any; +use std::borrow::Cow; use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::Debug; use std::hash::{Hash, Hasher}; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; /// Logical representation of a Scalar User Defined Function. /// @@ -343,6 +352,272 @@ impl ScalarUDF { pub fn as_async(&self) -> Option<&AsyncScalarUDF> { self.inner().as_any().downcast_ref::<AsyncScalarUDF>() } + + /// Variation of [`Self::arguments_schema_from_logical_args`] that works with a `&dyn ExprSchema` instead + pub(crate) fn arguments_expr_schema<'a>( + &self, + args: &[Expr], + schema: &'a dyn ExprSchema, + ) -> Result<Vec<ExtendableExprSchema<'a>>> { + self.arguments_scope_with( + &lambda_parameters(args, schema)?, + ExtendableExprSchema::new(schema), + ) + } + + /// Variation of [`Self::arguments_schema_from_logical_args`] that works with arrow `Schema`s and [`ValueOrLambdaParameter`]s instead + pub fn arguments_arrow_schema<'a>( + &self, + args: &[ValueOrLambdaParameter], + schema: &'a Schema, + ) -> Result<Vec<Cow<'a, Schema>>> { + self.arguments_scope_with(args, Cow::Borrowed(schema)) + } + + pub fn arguments_schema_from_logical_args<'a>( + &self, + args: &[Expr], + schema: &'a DFSchema, + ) -> Result<Vec<Cow<'a, DFSchema>>> { + self.arguments_scope_with( + &lambda_parameters(args, schema)?, + Cow::Borrowed(schema), + ) + } + + /// Scalar functions support lambdas as arguments, which will be evaluated with + /// a different schema than that of the function itself.
This function returns a vec + /// with the corresponding schema that each argument will run with + /// + /// Returns a vec with a value for each argument in args: if it's a value, a clone of the base schema; + /// if it's a lambda, the base schema extended (or replaced, when nothing is captured) with the fields + /// from lambdas_parameters updated with the parameter names declared by the lambda + fn arguments_scope_with<T: ExtendSchema + Clone>( + &self, + args: &[ValueOrLambdaParameter], + schema: T, + ) -> Result<Vec<T>> { + let parameters = self.inner().lambdas_parameters(args)?; + + if parameters.len() != args.len() { + return exec_err!( + "lambdas_schemas: {} lambdas_parameters returned {} values instead of {}", + self.name(), + parameters.len(), + args.len() + ); + } + + std::iter::zip(args, parameters) + .enumerate() + .map(|(i, (arg, parameters))| match (arg, parameters) { + (ValueOrLambdaParameter::Value(_), None) => Ok(schema.clone()), + (ValueOrLambdaParameter::Value(_), Some(_)) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a value but the lambdas_parameters result treats it as a lambda", self.name(), i), + (ValueOrLambdaParameter::Lambda(_, _), None) => exec_err!("lambdas_schemas: {} argument {} (0-indexed) is a lambda but the lambdas_parameters result treats it as a value", self.name(), i), + (ValueOrLambdaParameter::Lambda(names, captures), Some(args)) => { + if names.len() > args.len() { + return exec_err!("lambdas_schemas: {} argument {} (0-indexed), a lambda, supports up to {} arguments, but got {}", self.name(), i, args.len(), names.len()) + } + + let fields = std::iter::zip(*names, args) + .map(|(name, arg)| arg.with_name(name)) + .collect::<Fields>(); + + if *captures { + schema.extend(fields) + } else { + T::from_fields(fields) + } + } + }) + .collect() + } +} + +pub trait ExtendSchema: Sized { + fn from_fields(params: Fields) -> Result<Self>; + fn extend(&self, params: Fields) -> Result<Self>; +} + +impl ExtendSchema for DFSchema { + fn from_fields(params: Fields) -> Result<Self> { + DFSchema::from_unqualified_fields(params, Default::default()) + } + + fn extend(&self, params: Fields) -> Result<Self> { + let qualified_fields = self + .iter() + .map(|(qualifier, field)| { + if params.find(field.name().as_str()).is_none() { + return (qualifier.cloned(), Arc::clone(field)); + } + + let alias_gen = AliasGenerator::new(); + + loop { + let alias = alias_gen.next(field.name().as_str()); + + if params.find(&alias).is_none() + && !self.has_column_with_unqualified_name(&alias) + { + return ( + qualifier.cloned(), + Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )), + ); + } + } + }) + .collect(); + + let mut schema = DFSchema::new_with_metadata(qualified_fields, HashMap::new())?; + let fields_schema = DFSchema::from_unqualified_fields(params, HashMap::new())?; + + schema.merge(&fields_schema); + + assert_eq!( + schema.fields().len(), + self.fields().len() + fields_schema.fields().len() + ); + + Ok(schema) + } +} + +impl ExtendSchema for Schema { + fn from_fields(params: Fields) -> Result<Self> { + Ok(Schema::new(params)) + } + + fn extend(&self, params: Fields) -> Result<Self> { + let fields = self + .fields() + .iter() + .map(|field| { + if params.find(field.name().as_str()).is_none() {
+ return Arc::clone(field); + } + + let alias_gen = AliasGenerator::new(); + + loop { + let alias = alias_gen.next(field.name().as_str()); + + if params.find(&alias).is_none() + && self.column_with_name(&alias).is_none() + { + return Arc::new(Field::new( + alias, + field.data_type().clone(), + field.is_nullable(), + )); + } + } + }) + .chain(params.iter().cloned()) + .collect::<Fields>(); + + assert_eq!(fields.len(), self.fields().len() + params.len()); + + Ok(Schema::new_with_metadata(fields, self.metadata.clone())) + } +} + +impl<T: ExtendSchema + Clone> ExtendSchema for Cow<'_, T> { + fn from_fields(params: Fields) -> Result<Self> { + Ok(Cow::Owned(T::from_fields(params)?)) + } + + fn extend(&self, params: Fields) -> Result<Self> { + Ok(Cow::Owned(self.as_ref().extend(params)?)) + } +} + +impl<T: ExtendSchema> ExtendSchema for Arc<T> { + fn from_fields(params: Fields) -> Result<Self> { + Ok(Arc::new(T::from_fields(params)?)) + } + + fn extend(&self, params: Fields) -> Result<Self> { + Ok(Arc::new(self.as_ref().extend(params)?)) + } +} + +impl ExtendSchema for ExtendableExprSchema<'_> { + fn from_fields(params: Fields) -> Result<Self> { + static EMPTY_DFSCHEMA: LazyLock<DFSchema> = LazyLock::new(DFSchema::empty); + + Ok(ExtendableExprSchema { + fields_chain: vec![params], + outer_schema: &*EMPTY_DFSCHEMA, + }) + } + + fn extend(&self, params: Fields) -> Result<Self> { + Ok(ExtendableExprSchema { + fields_chain: std::iter::once(params) + .chain(self.fields_chain.iter().cloned()) + .collect(), + outer_schema: self.outer_schema, + }) + } +} + +/// A `&dyn ExprSchema` wrapper that supports adding the parameters of a lambda +#[derive(Clone, Debug)] +struct ExtendableExprSchema<'a> { + fields_chain: Vec<Fields>, + outer_schema: &'a dyn ExprSchema, +} + +impl<'a> ExtendableExprSchema<'a> { + fn new(schema: &'a dyn ExprSchema) -> Self { + Self { + fields_chain: vec![], + outer_schema: schema, + } + } +} + +impl ExprSchema for ExtendableExprSchema<'_> { + fn field_from_column(&self, col: &datafusion_common::Column) -> Result<&Field> { + if col.relation.is_none() { + for fields in &self.fields_chain { + if let Some((_index, lambda_param)) = fields.find(&col.name) { + return Ok(lambda_param.as_ref()); + } + } + } + + self.outer_schema.field_from_column(col) + } +} + +#[derive(Clone, Debug)] +pub enum ValueOrLambdaParameter<'a> { + /// A columnar value with the given field + Value(FieldRef), + /// A lambda with the given parameter names and a flag indicating whether it captures any columns + Lambda(&'a [String], bool), } impl<F> From<F> for ScalarUDF @@ -359,6 +634,7 @@ where #[derive(Debug, Clone)] pub struct ScalarFunctionArgs { /// The evaluated arguments to the function + /// If it's a lambda, it will be `ColumnarValue::Scalar(ScalarValue::Null)` pub args: Vec<ColumnarValue>, /// Field associated with each arg, if it exists pub arg_fields: Vec<FieldRef>, @@ -370,6 +646,30 @@ pub struct ScalarFunctionArgs { pub return_field: FieldRef, /// The config options at execution time pub config_options: Arc<ConfigOptions>, + /// The lambdas passed to the function + /// If it's not a lambda, it will be `None` + pub lambdas: Option<Vec<Option<ScalarFunctionLambdaArg>>>, +} + +/// A lambda argument to a ScalarFunction +#[derive(Clone, Debug)] +pub struct ScalarFunctionLambdaArg { + /// The parameters defined in this lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be `vec![Field::new("v", DataType::Int32, true)]` + pub params: Vec<FieldRef>, + /// The body of the lambda + /// + /// For example, for `array_transform([2], v -> -v)`, + /// this will be the physical expression of `-v` + pub body: Arc<dyn PhysicalExpr>, + /// A RecordBatch containing at least the captured columns inside this lambda
body, if any + /// Note that it may contain additional, non-specified columns, but that's an implementation detail + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + pub captures: Option<RecordBatch>, } impl ScalarFunctionArgs { @@ -378,6 +678,25 @@ impl ScalarFunctionArgs { pub fn return_type(&self) -> &DataType { self.return_field.data_type() } + + pub fn to_lambda_args(&self) -> Vec<ValueOrLambda<'_>> { + match &self.lambdas { + Some(lambdas) => std::iter::zip(&self.args, lambdas) + .map(|(arg, lambda)| match lambda { + Some(lambda) => ValueOrLambda::Lambda(lambda), + None => ValueOrLambda::Value(arg), + }) + .collect(), + None => self.args.iter().map(ValueOrLambda::Value).collect(), + } + } +} + +/// An argument to a ScalarUDF that supports lambdas +#[derive(Debug)] +pub enum ValueOrLambda<'a> { + Value(&'a ColumnarValue), + Lambda(&'a ScalarFunctionLambdaArg), } /// Information about arguments passed to the function @@ -390,6 +709,12 @@ impl ScalarFunctionArgs { #[derive(Debug)] pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function + /// + /// If argument `i` to the function is a lambda, it will be the field returned by the + /// lambda when executed with the arguments returned from `ScalarUDFImpl::lambdas_parameters` + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[Field::new("", DataType::List(DataType::Int32), false), Field::new("", DataType::Boolean, false)]` pub arg_fields: &'a [FieldRef], /// Is argument `i` to the function a scalar (constant)? /// @@ -398,6 +723,36 @@ pub struct ReturnFieldArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], + /// Is argument `i` to the function a lambda? + /// + /// For example, with `array_transform([1], v -> v == 5)` + /// this field will be `[false, true]` + pub lambdas: &'a [bool], +} + +/// A tagged Field indicating whether it corresponds to a value or a lambda argument +#[derive(Debug)] +pub enum ValueOrLambdaField<'a> { + /// The Field of a ColumnarValue argument + Value(&'a FieldRef), + /// The Field of the return of the lambda body when evaluated with the parameters from ScalarUDF::lambda_parameters + Lambda(&'a FieldRef), +} + +impl<'a> ReturnFieldArgs<'a> { + /// Based on self.lambdas, encodes self.arg_fields as tagged enums + /// indicating whether each corresponds to a value or a lambda argument + pub fn to_lambda_args(&self) -> Vec<ValueOrLambdaField<'a>> { + std::iter::zip(self.arg_fields, self.lambdas) + .map(|(field, is_lambda)| { + if *is_lambda { + ValueOrLambdaField::Lambda(field) + } else { + ValueOrLambdaField::Value(field) + } + }) + .collect() + } } /// Trait for implementing user defined scalar functions. @@ -841,6 +1196,14 @@ pub trait ScalarUDFImpl: Debug + DynEq + DynHash + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// Returns, for each argument, the parameters that a lambda in that position supports, or `None` for value arguments + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result<Vec<Option<Vec<Field>>>> { + Ok(vec![None; args.len()]) + } } /// ScalarUDF that adds an alias to the underlying function.
It is better to @@ -959,6 +1322,118 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { fn documentation(&self) -> Option<&Documentation> { self.inner.documentation() } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result<Vec<Option<Vec<Field>>>> { + self.inner.lambdas_parameters(args) + } +} + +fn lambda_parameters<'a>( + args: &'a [Expr], + schema: &dyn ExprSchema, +) -> Result<Vec<ValueOrLambdaParameter<'a>>> { + args.iter() + .map(|e| match e { + Expr::Lambda(Lambda { params, body: _ }) => { + let mut captures = false; + + e.apply_with_lambdas_params(|expr, lambdas_params| match expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + }) + .unwrap(); + + Ok(ValueOrLambdaParameter::Lambda(params.as_slice(), captures)) + } + _ => Ok(ValueOrLambdaParameter::Value(e.to_field(schema)?.1)), + }) + .collect() +} + +/// Merge the lambda body's captured columns with its arguments +/// DataFusion relies on an unspecified field ordering implemented in this function +/// As such, this is the only correct way to merge the captured values with the arguments +/// The number of args should not be lower than the number of params +/// +/// See also merge_captures_with_lazy_args and merge_captures_with_boxed_lazy_args, which lazily +/// compute only the necessary arguments to match the number of params +pub fn merge_captures_with_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[ArrayRef], +) -> Result<RecordBatch> { + if args.len() < params.len() { + return exec_err!( + "merge_captures_with_args called with {} params but with {} args", + params.len(), + args.len() + ); + } + + // the order of the merged batch must be kept in sync with ScalarFunction::lambdas_schemas variants + let (fields, columns) = match captures { + Some(captures) => { + let fields = captures + .schema() + .fields() + .iter() + .chain(params) + .cloned() + .collect::<Vec<_>>(); + + let columns = [captures.columns(), args].concat(); + + (fields, columns) + } + None => (params.to_vec(), args.to_vec()), + }; + + Ok(RecordBatch::try_new( + Arc::new(Schema::new(fields)), + columns, + )?)
+} + +/// Lazy version of merge_captures_with_args that receives closures to compute the arguments, +/// and calls only those necessary to match the number of params +pub fn merge_captures_with_lazy_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[&dyn Fn() -> Result<ArrayRef>], +) -> Result<RecordBatch> { + merge_captures_with_args( + captures, + params, + &args + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::<Result<Vec<_>>>()?, + ) +} + +/// Variation of merge_captures_with_lazy_args that takes boxed closures +pub fn merge_captures_with_boxed_lazy_args( + captures: Option<&RecordBatch>, + params: &[FieldRef], + args: &[Box<dyn Fn() -> Result<ArrayRef>>], +) -> Result<RecordBatch> { + merge_captures_with_args( + captures, + params, + &args + .iter() + .take(params.len()) + .map(|arg| arg()) + .collect::<Result<Vec<_>>>()?, + ) } #[cfg(test)] @@ -1039,4 +1514,83 @@ mod tests { value.hash(hasher); hasher.finish() } + + use std::borrow::Cow; + + use arrow::datatypes::Fields; + + use crate::{ + tree_node::tests::{args, list_int, list_list_int, array_transform_udf}, + udf::{lambda_parameters, ExtendableExprSchema}, + }; + + #[test] + fn test_arguments_expr_schema() { + let args = args(); + let schema = list_list_int(); + + let schemas = array_transform_udf() + .arguments_expr_schema(&args, &schema) + .unwrap() + .into_iter() + .map(|s| format!("{s:?}")) + .collect::<Vec<_>>(); + + let mut lambdas_parameters = array_transform_udf() + .inner() + .lambdas_parameters(&lambda_parameters(&args, &schema).unwrap()) + .unwrap(); + + assert_eq!( + schemas, + &[ + format!("{}", &list_list_int()), + format!( + "{:?}", + ExtendableExprSchema { + fields_chain: vec![Fields::from( + lambdas_parameters[0].take().unwrap() + )], + outer_schema: &list_list_int() + } + ), + ] + ) + } + + #[test] + fn test_arguments_arrow_schema() { + let list_int = list_int(); + let list_list_int = list_list_int(); + + let schemas = array_transform_udf() + .arguments_arrow_schema( + &lambda_parameters(&args(), &list_list_int).unwrap(), + list_list_int.as_arrow(), + ) + .unwrap(); + + assert_eq!( + schemas, + &[ + Cow::Borrowed(list_list_int.as_arrow()), + Cow::Owned(list_int.as_arrow().clone()) + ] + ) + } + + #[test] + fn test_arguments_schema_from_logical_args() { + let list_list_int = list_list_int(); + + let schemas = array_transform_udf() + .arguments_schema_from_logical_args(&args(), &list_list_int) + .unwrap(); + + assert_eq!( + schemas, + &[Cow::Borrowed(&list_list_int), Cow::Owned(list_int())] + ) + } } diff --git a/datafusion/expr/src/utils.rs index cd733e0a130a..93fcfaef882f 100644 --- a/datafusion/expr/src/utils.rs +++ b/datafusion/expr/src/utils.rs @@ -266,10 +266,12 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> { /// Recursively walk an expression tree, collecting the unique set of columns /// referenced in the expression pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> { - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { match expr { Expr::Column(qc) => { - accum.insert(qc.clone()); + if qc.relation.is_some() || !lambdas_params.contains(qc.name()) { + accum.insert(qc.clone()); + } } // Use explicit pattern match instead of a default // implementation, so that in the future if someone adds @@ -307,7 +309,8 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> { | Expr::ScalarSubquery(_) | Expr::Wildcard { .. } | Expr::Placeholder(_) - | Expr::OuterReferenceColumn { ..
} => {} + | Expr::OuterReferenceColumn { .. } + | Expr::Lambda { .. } => {} } Ok(TreeNodeRecursion::Continue) }) @@ -650,6 +653,7 @@ where /// Search an `Expr`, and all of its nested `Expr`'s, for any that pass the /// provided test. The returned `Expr`'s are deduplicated and returned in order /// of appearance (depth first). +/// TODO: document that columns may refer to a lambda parameter fn find_exprs_in_expr<F>(expr: &Expr, test_fn: &F) -> Vec<Expr> where F: Fn(&Expr) -> bool, @@ -672,6 +676,7 @@ where } /// Recursively inspect an [`Expr`] and all its children. +/// TODO: document that columns may refer to a lambda parameter pub fn inspect_expr_pre<F, E>(expr: &Expr, mut f: F) -> Result<(), E> where F: FnMut(&Expr) -> Result<(), E>, @@ -743,13 +748,19 @@ pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> { _ => return Ok(e), }; let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect(); - e.transform_down(|node: Expr| match exprs_map.get(&node) { - Some(column) => Ok(Transformed::new( - Expr::Column(column.clone()), - true, - TreeNodeRecursion::Jump, - )), - None => Ok(Transformed::no(node)), + e.transform_down_with_lambdas_params(|node: Expr, lambdas_params| { + if matches!(&node, Expr::Column(c) if c.is_lambda_parameter(lambdas_params)) { + return Ok(Transformed::no(node)); + } + + match exprs_map.get(&node) { + Some(column) => Ok(Transformed::new( + Expr::Column(column.clone()), + true, + TreeNodeRecursion::Jump, + )), + None => Ok(Transformed::no(node)), + } }) .data() } @@ -766,9 +777,11 @@ pub fn find_column_exprs(exprs: &[Expr]) -> Vec<Expr> { pub(crate) fn find_columns_referenced_by_expr(e: &Expr) -> Vec<Column> { let mut exprs = vec![]; - e.apply(|expr| { + e.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(c) = expr { - exprs.push(c.clone()) + if !c.is_lambda_parameter(lambdas_params) { + exprs.push(c.clone()) + } } Ok(TreeNodeRecursion::Continue) }) @@ -797,9 +810,9 @@ pub(crate) fn find_column_indexes_referenced_by_expr( schema: &DFSchemaRef, ) -> Vec<usize> { let mut indexes = vec![]; - e.apply(|expr| { + e.apply_with_lambdas_params(|expr, lambdas_params| { match expr { - Expr::Column(qc) => { + Expr::Column(qc) if !qc.is_lambda_parameter(lambdas_params) => { if let Ok(idx) = schema.index_of_column(qc) { indexes.push(idx); } diff --git a/datafusion/ffi/src/udf/mod.rs index 5e59cfc5ecb0..400ad4469604 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -33,7 +33,7 @@ use arrow::{ }; use arrow_schema::FieldRef; use datafusion::config::ConfigOptions; -use datafusion::logical_expr::ReturnFieldArgs; +use datafusion::{common::exec_err, logical_expr::ReturnFieldArgs}; use datafusion::{ error::DataFusionError, logical_expr::type_coercion::functions::data_types_with_scalar_udf, @@ -210,6 +210,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( return_field, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: Arc::new(ConfigOptions::default()), + lambdas: None, }; let result = rresult_return!(udf @@ -382,10 +383,15 @@ impl ScalarUDFImpl for ForeignScalarUDF { arg_fields, number_rows, return_field, + lambdas, // TODO: pass config options: https://github.com/apache/datafusion/issues/17035 config_options: _config_options, } = invoke_args; + if lambdas.is_some_and(|lambdas| lambdas.iter().any(|l| l.is_some())) { + return exec_err!("ForeignScalarUDF doesn't support lambdas"); + } + let args = args .into_iter() .map(|v| v.to_array(number_rows)) diff --git
a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index c437c9537be6..d5cbfff1d3a4 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -21,7 +21,7 @@ use abi_stable::{ }; use arrow_schema::FieldRef; use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, + common::{exec_datafusion_err, exec_err}, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; @@ -42,6 +42,10 @@ impl TryFrom<ReturnFieldArgs<'_>> for FFI_ReturnFieldArgs { type Error = DataFusionError; fn try_from(value: ReturnFieldArgs) -> Result<Self, Self::Error> { + if value.lambdas.iter().any(|l| *l) { + return exec_err!("FFI_ReturnFieldArgs doesn't support lambdas") + } + let arg_fields = vec_fieldref_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result<Vec<_>, Self::Error> = value .scalar_arguments @@ -77,6 +81,7 @@ pub struct ForeignReturnFieldArgsOwned { pub struct ForeignReturnFieldArgs<'a> { arg_fields: &'a [FieldRef], scalar_arguments: Vec<Option<&'a ScalarValue>>, + lambdas: Vec<bool>, // currently always false, used to return a reference in From<&Self> for ReturnFieldArgs } impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { @@ -116,6 +121,7 @@ impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { .iter() .map(|opt| opt.as_ref()) .collect(), + lambdas: vec![false; value.arg_fields.len()] } } } @@ -125,6 +131,7 @@ impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { ReturnFieldArgs { arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, + lambdas: &value.lambdas, } } } diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs new file mode 100644 index 000000000000..700fed477b4c --- /dev/null +++ b/datafusion/functions-nested/src/array_transform.rs @@ -0,0 +1,266 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`ScalarUDFImpl`] definitions for array_transform function.
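+//!
+//! Editorial sketch (assumed usage, not part of the patch): `array_transform_udf` is the
+//! accessor generated by `make_udf_expr_and_func!` below, and `Lambda` is the new expression
+//! variant this patch adds to `datafusion_expr`; this mirrors the helpers in the tree-node tests.
+//! ```ignore
+//! use datafusion_expr::{col, expr::Lambda, Expr};
+//!
+//! // Build the logical expression `array_transform(v, v -> -v)`: the first argument is a
+//! // plain column, the second a lambda whose parameter `v` shadows the outer column `v`.
+//! let expr: Expr = array_transform_udf().call(vec![
+//!     col("v"),
+//!     Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))),
+//! ]);
+//! ```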
+ +use arrow::{ + array::{Array, ArrayRef, AsArray, FixedSizeListArray, LargeListArray, ListArray}, + compute::take_record_batch, + datatypes::{DataType, Field, FieldRef}, +}; +use datafusion_common::{ + HashSet, Result, exec_err, internal_err, tree_node::{Transformed, TreeNode}, utils::{elements_indices, list_indices, list_values, take_function_args} +}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ScalarFunctionArgs, ScalarUDFImpl, Signature, ValueOrLambda, ValueOrLambdaField, ValueOrLambdaParameter, Volatility, expr::Lambda, merge_captures_with_lazy_args +}; +use datafusion_macros::user_doc; +use std::{any::Any, sync::Arc}; + +make_udf_expr_and_func!( + ArrayTransform, + array_transform, + array lambda, + "transforms the values of an array", + array_transform_udf +); + +#[user_doc( + doc_section(label = "Array Functions"), + description = "transforms the values of an array", + syntax_example = "array_transform(array, x -> x*2)", + sql_example = r#"```sql +> select array_transform([1, 2, 3, 4, 5], x -> x*2); ++-------------------------------------------+ +| array_transform([1, 2, 3, 4, 5], x -> x*2) | ++-------------------------------------------+ +| [2, 4, 6, 8, 10] | ++-------------------------------------------+ +```"#, + argument( + name = "array", + description = "Array expression. Can be a constant, column, or function, and any combination of array operators." + ), + argument(name = "lambda", description = "Lambda") +)] +#[derive(Debug, PartialEq, Eq, Hash)] +pub struct ArrayTransform { + signature: Signature, + aliases: Vec<String>, +} + +impl Default for ArrayTransform { + fn default() -> Self { + Self::new() + } +} + +impl ArrayTransform { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + aliases: vec![String::from("list_transform")], + } + } +} + +impl ScalarUDFImpl for ArrayTransform { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn aliases(&self) -> &[String] { + &self.aliases + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> { + internal_err!("return_type called instead of return_field_from_args") + } + + fn return_field_from_args( + &self, + args: datafusion_expr::ReturnFieldArgs, + ) -> Result<FieldRef> { + let args = args.to_lambda_args(); + + let [ValueOrLambdaField::Value(list), ValueOrLambdaField::Lambda(lambda)] = + take_function_args(self.name(), &args)? + else { + return exec_err!( + "{} expects a value followed by a lambda, got {:?}", + self.name(), + args + ); + }; + + //TODO: should metadata be passed? If so, with the same keys or prefixed/suffixed?
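+
+        // Editorial worked example (illustrative, not from the patch): for
+        // `array_transform([1, 2], v -> v = 5)`, `list` has data type `List(Int32)` and
+        // `lambda` is the Boolean field produced by resolving the lambda body against the
+        // parameters returned from `lambdas_parameters`, so the return field computed
+        // below has data type `List(Boolean)`.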
+ + // lambda is the resulting field of executing the lambda body + // with the parameters returned in lambdas_parameters + let field = Arc::new(Field::new( + Field::LIST_FIELD_DEFAULT_NAME, + lambda.data_type().clone(), + lambda.is_nullable(), + )); + + let return_type = match list.data_type() { + DataType::List(_) => DataType::List(field), + DataType::LargeList(_) => DataType::LargeList(field), + DataType::FixedSizeList(_, size) => DataType::FixedSizeList(field, *size), + _ => unreachable!(), + }; + + Ok(Arc::new(Field::new("", return_type, list.is_nullable()))) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> { + // args.lambda_args allows the convenient match below, instead of inspecting both args.args and args.lambdas + let lambda_args = args.to_lambda_args(); + let [list_value, lambda] = take_function_args(self.name(), &lambda_args)?; + + let (ValueOrLambda::Value(list_value), ValueOrLambda::Lambda(lambda)) = + (list_value, lambda) + else { + return exec_err!( + "{} expects a value followed by a lambda, got {:?}", + self.name(), + &lambda_args + ); + }; + + let list_array = list_value.to_array(args.number_rows)?; + + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of lists with multiple values and removing values of empty lists. + // list_indices is not cheap so it is important to avoid it when no column is captured + let adjusted_captures = lambda + .captures + .as_ref() + .map(|captures| take_record_batch(captures, &list_indices(&list_array)?)) + .transpose()?; + + // use closures and merge_captures_with_lazy_args so that it calls only the needed ones based on the number of arguments, + // avoiding unnecessary computations + let values_param = || Ok(Arc::clone(list_values(&list_array)?)); + let indices_param = || elements_indices(&list_array); + + // the order of the merged schema is an unspecified implementation detail that may change in the future; + // using this function is the correct way to merge, as it returns the correct ordering and will change in sync + // with the implementation, without the need for fixes. It also computes only the parameters requested + let lambda_batch = merge_captures_with_lazy_args( + adjusted_captures.as_ref(), + &lambda.params, // ScalarUDF already merged the fields returned in lambdas_parameters with the parameter names defined in the lambda, so we don't need to do it here + &[&values_param, &indices_param], + )?; + + // call the transforming expression with the record batch composed of the list values merged with captured columns + let transformed_values = lambda + .body + .evaluate(&lambda_batch)?
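+            // Editorial note (sketch of the data flow): `lambda_batch` has one row per list
+            // element, not per input row, so evaluating the body here yields a flat values
+            // array that is re-wrapped below with the original list's offsets and nulls.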
+ .into_array(lambda_batch.num_rows())?; + + let field = match args.return_field.data_type() { + DataType::List(field) + | DataType::LargeList(field) + | DataType::FixedSizeList(field, _) => Arc::clone(field), + _ => { + return exec_err!( + "{} expected ScalarFunctionArgs.return_field to be a list, got {}", + self.name(), + args.return_field + ) + } + }; + + let transformed_list = match list_array.data_type() { + DataType::List(_) => { + let list = list_array.as_list(); + + Arc::new(ListArray::new( + field, + list.offsets().clone(), + transformed_values, + list.nulls().cloned(), + )) as ArrayRef + } + DataType::LargeList(_) => { + let large_list = list_array.as_list::<i64>(); + + Arc::new(LargeListArray::new( + field, + large_list.offsets().clone(), + transformed_values, + large_list.nulls().cloned(), + )) + } + DataType::FixedSizeList(_, value_length) => { + Arc::new(FixedSizeListArray::new( + field, + *value_length, + transformed_values, + list_array.as_fixed_size_list().nulls().cloned(), + )) + } + other => exec_err!("expected list, got {other}")?, + }; + + Ok(ColumnarValue::Array(transformed_list)) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result<Vec<Option<Vec<Field>>>> { + let [ValueOrLambdaParameter::Value(list), ValueOrLambdaParameter::Lambda(_, _)] = + args + else { + return exec_err!( + "{} expects a value followed by a lambda, got {:?}", + self.name(), + args + ); + }; + + let (field, index_type) = match list.data_type() { + DataType::List(field) => (field, DataType::Int32), + DataType::LargeList(field) => (field, DataType::Int64), + DataType::FixedSizeList(field, _) => (field, DataType::Int32), + _ => return exec_err!("expected list, got {list}"), + }; + + // we don't need to omit the index in case the lambda doesn't specify it, e.g. array_transform([], v -> v*2), + // nor check whether the lambda contains more than two parameters, e.g.
array_transform([], (v, i, j) -> v+i+j), + // as datafusion will do that for us + let value = Field::new("value", field.data_type().clone(), field.is_nullable()) + .with_metadata(field.metadata().clone()); + let index = Field::new("index", index_type, false); + + Ok(vec![None, Some(vec![value, index])]) + } + + fn documentation(&self) -> Option<&Documentation> { + self.doc() + } +} diff --git a/datafusion/optimizer/src/analyzer/function_rewrite.rs b/datafusion/optimizer/src/analyzer/function_rewrite.rs index c6bf14ebce2e..0e5e602f8238 100644 --- a/datafusion/optimizer/src/analyzer/function_rewrite.rs +++ b/datafusion/optimizer/src/analyzer/function_rewrite.rs @@ -19,7 +19,7 @@ use super::AnalyzerRule; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{DFSchema, Result}; use crate::utils::NamePreserver; @@ -64,15 +64,16 @@ impl ApplyFunctionRewrites { let original_name = name_preserver.save(&expr); // recursively transform the expression, applying the rewrites at each step - let transformed_expr = expr.transform_up(|expr| { - let mut result = Transformed::no(expr); - for rewriter in self.function_rewrites.iter() { - result = result.transform_data(|expr| { - rewriter.rewrite(expr, &schema, options) - })?; - } - Ok(result) - })?; + let transformed_expr = + expr.transform_up_with_schema(&schema, |expr, schema| { + let mut result = Transformed::no(expr); + for rewriter in self.function_rewrites.iter() { + result = result.transform_data(|expr| { + rewriter.rewrite(expr, schema, options) + })?; + } + Ok(result) + })?; Ok(transformed_expr.update_data(|expr| original_name.restore(expr))) }) diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs index 4fb0f8553b4b..1b82182e8600 100644 --- a/datafusion/optimizer/src/analyzer/type_coercion.rs +++ b/datafusion/optimizer/src/analyzer/type_coercion.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use datafusion_expr::binary::BinaryTypeCoercer; +use datafusion_expr::tree_node::TreeNodeRewriterWithPayload; use itertools::{izip, Itertools as _}; use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; @@ -27,7 +28,7 @@ use arrow::datatypes::{DataType, Field, IntervalUnit, Schema}; use crate::analyzer::AnalyzerRule; use crate::utils::NamePreserver; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRewriter}; +use datafusion_common::tree_node::Transformed; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue, @@ -140,7 +141,7 @@ fn analyze_internal( // apply coercion rewrite all expressions in the plan individually plan.map_expressions(|expr| { let original_name = name_preserver.save(&expr); - expr.rewrite(&mut expr_rewrite) + expr.rewrite_with_schema(&schema, &mut expr_rewrite) .map(|transformed| transformed.update_data(|e| original_name.restore(e))) })? 
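+        // Editorial note (sketch): `rewrite_with_schema` drives the same traversal as
+        // `TreeNode::rewrite`, but threads a `&DFSchema` payload down the tree; inside a
+        // lambda body the payload is swapped for the lambda-scoped schema, so the rewriter
+        // always sees the schema its expression actually resolves against.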
// some plans need extra coercion after their expressions are coerced @@ -304,10 +305,11 @@ impl<'a> TypeCoercionRewriter<'a> { } } -impl TreeNodeRewriter for TypeCoercionRewriter<'_> { +impl TreeNodeRewriterWithPayload for TypeCoercionRewriter<'_> { type Node = Expr; + type Payload<'a> = &'a DFSchema; - fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> { + fn f_up(&mut self, expr: Expr, schema: &DFSchema) -> Result<Transformed<Expr>> { match expr { Expr::Unnest(_) => not_impl_err!( "Unnest should be rewritten to LogicalPlan::Unnest before type coercion" ), Expr::ScalarSubquery(Subquery { subquery, outer_ref_columns, spans, }) => { let new_plan = - analyze_internal(self.schema, Arc::unwrap_or_clone(subquery))?.data; + analyze_internal(schema, Arc::unwrap_or_clone(subquery))?.data; Ok(Transformed::yes(Expr::ScalarSubquery(Subquery { subquery: Arc::new(new_plan), outer_ref_columns, @@ -327,7 +329,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { } Expr::Exists(Exists { subquery, negated }) => { let new_plan = analyze_internal( - self.schema, + schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; @@ -346,11 +348,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { negated, }) => { let new_plan = analyze_internal( - self.schema, + schema, Arc::unwrap_or_clone(subquery.subquery), )? .data; - let expr_type = expr.get_type(self.schema)?; + let expr_type = expr.get_type(schema)?; let subquery_type = new_plan.schema().field(0).data_type(); let common_type = comparison_coercion(&expr_type, subquery_type).ok_or( plan_datafusion_err!( @@ -363,32 +365,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { spans: subquery.spans, }; Ok(Transformed::yes(Expr::InSubquery(InSubquery::new( - Box::new(expr.cast_to(&common_type, self.schema)?), + Box::new(expr.cast_to(&common_type, schema)?), cast_subquery(new_subquery, &common_type)?, negated, )))) } Expr::Not(expr) => Ok(Transformed::yes(not(get_casted_expr_for_bool_op( *expr, - self.schema, + schema, )?))), Expr::IsTrue(expr) => Ok(Transformed::yes(is_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsFalse(expr) => Ok(Transformed::yes(is_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown( - get_casted_expr_for_bool_op(*expr, self.schema)?, + get_casted_expr_for_bool_op(*expr, schema)?, ))), Expr::Like(Like { negated, expr, pattern, escape_char, case_insensitive, }) => { - let left_type = expr.get_type(self.schema)?; - let right_type = pattern.get_type(self.schema)?; + let left_type = expr.get_type(schema)?; + let right_type = pattern.get_type(schema)?; let coerced_type = like_coercion(&left_type, &right_type).ok_or_else(|| { let op_name = if case_insensitive { "ILIKE" } else { "LIKE" }; plan_datafusion_err!( "There isn't a common type to coerce {left_type} and {right_type} in {op_name} expression" ) })?; let expr = match left_type { DataType::Dictionary(_, inner) if *inner ==
DataType::Utf8 => expr, - _ => Box::new(expr.cast_to(&coerced_type, self.schema)?), + _ => Box::new(expr.cast_to(&coerced_type, schema)?), }; - let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?); + let pattern = Box::new(pattern.cast_to(&coerced_type, schema)?); Ok(Transformed::yes(Expr::Like(Like::new( negated, expr, pattern, escape_char, case_insensitive, )))) } Expr::BinaryExpr(BinaryExpr { left, op, right }) => { let (left, right) = - self.coerce_binary_op(*left, self.schema, op, *right, self.schema)?; + self.coerce_binary_op(*left, schema, op, *right, schema)?; Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new( Box::new(left), op, Box::new(right), )))) } Expr::Between(Between { expr, negated, low, high, }) => { - let expr_type = expr.get_type(self.schema)?; - let low_type = low.get_type(self.schema)?; + let expr_type = expr.get_type(schema)?; + let low_type = low.get_type(schema)?; let low_coerced_type = comparison_coercion(&expr_type, &low_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {low_type} in BETWEEN expression" ) })?; - let high_type = high.get_type(self.schema)?; + let high_type = high.get_type(schema)?; let high_coerced_type = comparison_coercion(&expr_type, &high_type) .ok_or_else(|| { internal_datafusion_err!( "Failed to coerce types {expr_type} and {high_type} in BETWEEN expression" ) })?; Ok(Transformed::yes(Expr::Between(Between::new( - Box::new(expr.cast_to(&coercion_type, self.schema)?), + Box::new(expr.cast_to(&coercion_type, schema)?), negated, - Box::new(low.cast_to(&coercion_type, self.schema)?), - Box::new(high.cast_to(&coercion_type, self.schema)?), + Box::new(low.cast_to(&coercion_type, schema)?), + Box::new(high.cast_to(&coercion_type, schema)?), )))) } Expr::InList(InList { expr, list, negated, }) => { - let expr_data_type = expr.get_type(self.schema)?; + let expr_data_type = expr.get_type(schema)?; let list_data_types = list .iter() - .map(|list_expr| list_expr.get_type(self.schema)) + .map(|list_expr| list_expr.get_type(schema)) .collect::<Result<Vec<_>>>()?; let result_type = get_coerce_type_for_list(&expr_data_type, &list_data_types); match result_type { None => plan_err!( "Can not find compatible types to compare {expr_data_type} with {list_data_types:?}" ), Some(coerced_type) => { // find the coerced type - let cast_expr = expr.cast_to(&coerced_type, self.schema)?; + let cast_expr = expr.cast_to(&coerced_type, schema)?; let cast_list_expr = list .into_iter() .map(|list_expr| { - list_expr.cast_to(&coerced_type, self.schema) + list_expr.cast_to(&coerced_type, schema) }) .collect::<Result<Vec<_>>>()?; Ok(Transformed::yes(Expr::InList(InList::new( Box::new(cast_expr), cast_list_expr, negated, )))) } } } Expr::Case(case) => { - let case = coerce_case_expression(case, self.schema)?; + let case = coerce_case_expression(case, schema)?; Ok(Transformed::yes(Expr::Case(case))) } Expr::ScalarFunction(ScalarFunction { func, args }) => { let new_expr = coerce_arguments_for_signature_with_scalar_udf( args, - self.schema, + schema, &func, )?; Ok(Transformed::yes(Expr::ScalarFunction( ScalarFunction::new_udf(func, new_expr), ))) } Expr::AggregateFunction(expr::AggregateFunction { func, params, }) => { let new_expr = coerce_arguments_for_signature_with_aggregate_udf( args, - self.schema, + schema, &func, )?; Ok(Transformed::yes(Expr::AggregateFunction( expr::AggregateFunction::new_udf(func, new_expr, params), ))) } Expr::WindowFunction(window_fun) => { let WindowFunction { fun, params: expr::WindowFunctionParams { args, partition_by, order_by, window_frame, null_treatment, }, } = *window_fun; let window_frame = -
coerce_window_frame(window_frame, self.schema, &order_by)?; + coerce_window_frame(window_frame, schema, &order_by)?; let args = match &fun { expr::WindowFunctionDefinition::AggregateUDF(udf) => { coerce_arguments_for_signature_with_aggregate_udf( args, - self.schema, + schema, udf, )? } @@ -597,7 +599,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> { | Expr::Wildcard { .. } | Expr::GroupingSet(_) | Expr::Placeholder(_) - | Expr::OuterReferenceColumn(_, _) => Ok(Transformed::no(expr)), + | Expr::OuterReferenceColumn(_, _) + | Expr::Lambda { .. } => Ok(Transformed::no(expr)), } } } @@ -793,9 +796,11 @@ fn coerce_arguments_for_signature_with_scalar_udf( return Ok(expressions); } - let current_types = expressions - .iter() - .map(|e| e.get_type(schema)) + let current_types = expressions.iter() + .map(|e| match e { + Expr::Lambda { .. } => Ok(DataType::Null), + _ => e.get_type(schema), + }) .collect::<Result<Vec<_>>>()?; let new_types = data_types_with_scalar_udf(&current_types, func)?; @@ -803,7 +808,10 @@ fn coerce_arguments_for_signature_with_scalar_udf( expressions .into_iter() .enumerate() - .map(|(i, expr)| expr.cast_to(&new_types[i], schema)) + .map(|(i, expr)| match expr { + lambda @ Expr::Lambda { .. } => Ok(lambda), + _ => expr.cast_to(&new_types[i], schema), + }) .collect() } @@ -1125,7 +1133,7 @@ mod test { use crate::analyzer::Analyzer; use crate::assert_analyzed_plan_with_config_eq_snapshot; use datafusion_common::config::ConfigOptions; - use datafusion_common::tree_node::{TransformedResult, TreeNode}; + use datafusion_common::tree_node::{TransformedResult}; use datafusion_common::{DFSchema, DFSchemaRef, Result, ScalarValue, Spans}; use datafusion_expr::expr::{self, InSubquery, Like, ScalarFunction}; use datafusion_expr::logical_plan::{EmptyRelation, Projection, Sort}; @@ -2076,7 +2084,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).gt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).gt(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); // eq @@ -2087,7 +2095,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).eq(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).eq(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); // lt @@ -2098,7 +2106,7 @@ mod test { let mut rewriter = TypeCoercionRewriter { schema: &schema }; let expr = is_true(lit(12i32).lt(lit(13i64))); let expected = is_true(cast(lit(12i32), DataType::Int64).lt(lit(13i64))); - let result = expr.rewrite(&mut rewriter).data()?; + let result = expr.rewrite_with_schema(&schema, &mut rewriter).data()?; assert_eq!(expected, result); Ok(()) diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs index 251006849459..e06ed6e547eb 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -29,7 +29,7 @@ use datafusion_common::alias::AliasGenerator; use datafusion_common::cse::{CSEController, FoundCommonNodes, CSE}; use datafusion_common::tree_node::{Transformed, TreeNode}; -use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef, Result}; +use datafusion_common::{qualified_name, Column, DFSchema, DFSchemaRef,
HashSet, Result}; use datafusion_expr::expr::{Alias, ScalarFunction}; use datafusion_expr::logical_plan::{ Aggregate, Filter, LogicalPlan, Projection, Sort, Window, @@ -632,6 +632,7 @@ struct ExprCSEController<'a> { // how many aliases have we seen so far alias_counter: usize, + lambdas_params: HashSet<String>, } impl<'a> ExprCSEController<'a> { @@ -640,6 +641,7 @@ alias_generator, mask, alias_counter: 0, + lambdas_params: HashSet::new(), } } } @@ -693,11 +695,30 @@ impl CSEController for ExprCSEController<'_> { } } + fn visit_f_down(&mut self, node: &Expr) { + if let Expr::Lambda(lambda) = node { + self.lambdas_params + .extend(lambda.params.iter().cloned()); + } + } + + fn visit_f_up(&mut self, node: &Expr) { + if let Expr::Lambda(lambda) = node { + for param in &lambda.params { + self.lambdas_params.remove(param); + } + } + } + fn is_valid(node: &Expr) -> bool { !node.is_volatile_node() } fn is_ignored(&self, node: &Expr) -> bool { + if matches!(node, Expr::Column(c) if c.is_lambda_parameter(&self.lambdas_params)) { + return true + } + // TODO: remove the next line after `Expr::Wildcard` is removed #[expect(deprecated)] let is_normal_minus_aggregates = matches!( diff --git a/datafusion/optimizer/src/decorrelate.rs index 63236787743a..0f4374183400 100644 --- a/datafusion/optimizer/src/decorrelate.rs +++ b/datafusion/optimizer/src/decorrelate.rs @@ -527,18 +527,17 @@ fn proj_exprs_evaluation_result_on_empty_batch( for expr in proj_expr.iter() { let result_expr = expr .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, .. }) = &expr { + .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + if let Some(result_expr) = - input_expr_result_map_for_count_bug.get(name) + input_expr_result_map_for_count_bug.get(col.name()) { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::no(expr)) } + _ => Ok(Transformed::no(expr)), }) .data()?; @@ -570,16 +569,17 @@ fn filter_exprs_evaluation_result_on_empty_batch( ) -> Result<Option<Expr>> { let result_expr = filter_expr .clone() - .transform_up(|expr| { - if let Expr::Column(Column { name, ..
}) = &expr { - if let Some(result_expr) = input_expr_result_map_for_count_bug.get(name) { + .transform_up_with_lambdas_params(|expr, lambdas_params| match &expr { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { + if let Some(result_expr) = + input_expr_result_map_for_count_bug.get(col.name()) + { Ok(Transformed::yes(result_expr.clone())) } else { Ok(Transformed::no(expr)) } - } else { - Ok(Transformed::no(expr)) } + _ => Ok(Transformed::no(expr)), }) .data()?; diff --git a/datafusion/optimizer/src/optimize_projections/mod.rs index 5db71417bc8f..f0187b618ccc 100644 --- a/datafusion/optimizer/src/optimize_projections/mod.rs +++ b/datafusion/optimizer/src/optimize_projections/mod.rs @@ -639,7 +639,7 @@ fn is_expr_trivial(expr: &Expr) -> bool { /// --Source(a, b) /// ``` fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> { - expr.transform_up(|expr| { + expr.transform_up_with_lambdas_params(|expr, lambdas_params| { match expr { // remove any intermediate aliases if they do not carry metadata Expr::Alias(alias) => { @@ -653,7 +653,7 @@ fn rewrite_expr(expr: Expr, input: &Projection) -> Result<Transformed<Expr>> { false => Ok(Transformed::no(Expr::Alias(alias))), } } - Expr::Column(col) => { + Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => { // Find index of column: let idx = input.schema.index_of_column(&col)?; // get the corresponding unaliased input expression diff --git a/datafusion/optimizer/src/push_down_filter.rs index 1c0790b3e3ac..54cb02654327 100644 --- a/datafusion/optimizer/src/push_down_filter.rs +++ b/datafusion/optimizer/src/push_down_filter.rs @@ -293,7 +293,8 @@ fn can_evaluate_as_join_condition(predicate: &Expr) -> Result<bool> { Expr::AggregateFunction(_) | Expr::WindowFunction(_) | Expr::Wildcard { .. } - | Expr::GroupingSet(_) => internal_err!("Unsupported predicate type"), + | Expr::GroupingSet(_) + | Expr::Lambda { .. } => internal_err!("Unsupported predicate type"), })?; Ok(is_evaluate) } @@ -1389,14 +1390,15 @@ pub fn replace_cols_by_name( e: Expr, replace_map: &HashMap<String, Expr>, ) -> Result<Expr> { - e.transform_up(|expr| { - Ok(if let Expr::Column(c) = &expr { - match replace_map.get(&c.flat_name()) { - Some(new_c) => Transformed::yes(new_c.clone()), - None => Transformed::no(expr), + e.transform_up_with_lambdas_params(|expr, lambdas_params| { + Ok(match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match replace_map.get(&c.flat_name()) { + Some(new_c) => Transformed::yes(new_c.clone()), + None => Transformed::no(expr), + } } - } else { - Transformed::no(expr) + _ => Transformed::no(expr), }) }) .data() } /// check whether the expression uses the columns in `check_map`.
fn contain(e: &Expr, check_map: &HashMap<String, Expr>) -> bool { let mut is_contain = false; - e.apply(|expr| { - Ok(if let Expr::Column(c) = &expr { - match check_map.get(&c.flat_name()) { - Some(_) => { - is_contain = true; - TreeNodeRecursion::Stop + e.apply_with_lambdas_params(|expr, lambdas_params| { + Ok(match &expr { + Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => { + match check_map.get(&c.flat_name()) { + Some(_) => { + is_contain = true; + TreeNodeRecursion::Stop + } + None => TreeNodeRecursion::Continue, } - None => TreeNodeRecursion::Continue, } - } else { - TreeNodeRecursion::Continue + _ => TreeNodeRecursion::Continue, }) }) .unwrap(); diff --git a/datafusion/optimizer/src/scalar_subquery_to_join.rs index 48d118252701..f1e619750f9c 100644 --- a/datafusion/optimizer/src/scalar_subquery_to_join.rs +++ b/datafusion/optimizer/src/scalar_subquery_to_join.rs @@ -106,17 +106,22 @@ impl OptimizerRule for ScalarSubqueryToJoin { { if !expr_check_map.is_empty() { rewrite_expr = rewrite_expr - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = expr - .try_as_col() - .and_then(|col| expr_check_map.get(&col.name)) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) + .transform_up_with_lambdas_params( + |expr, lambdas_params| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .filter(|c| { + !c.is_lambda_parameter(lambdas_params) + }) + .and_then(|col| expr_check_map.get(&col.name)) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }, + ) .data()?; } cur_input = optimized_subquery; @@ -171,18 +176,26 @@ impl OptimizerRule for ScalarSubqueryToJoin { { let new_expr = rewrite_expr .clone() - .transform_up(|expr| { - // replace column references with entry in map, if it exists - if let Some(map_expr) = - expr.try_as_col().and_then(|col| { - expr_check_map.get(&col.name) - }) - { - Ok(Transformed::yes(map_expr.clone())) - } else { - Ok(Transformed::no(expr)) - } - }) + .transform_up_with_lambdas_params( + |expr, lambdas_params| { + // replace column references with entry in map, if it exists + if let Some(map_expr) = expr + .try_as_col() + .filter(|c| { + !c.is_lambda_parameter( + lambdas_params, + ) + }) + .and_then(|col| { + expr_check_map.get(&col.name) + }) + { + Ok(Transformed::yes(map_expr.clone())) + } else { + Ok(Transformed::no(expr)) + } + }, + ) .data()?; expr_to_rewrite_expr_map.insert(expr, new_expr); } @@ -396,8 +409,12 @@ fn build_join( let mut expr_rewrite = TypeCoercionRewriter { schema: new_plan.schema(), }; - computation_project_expr - .insert(name, computer_expr.rewrite(&mut expr_rewrite).data()?); + computation_project_expr.insert( + name, + computer_expr + .rewrite_with_schema(new_plan.schema(), &mut expr_rewrite) + .data()?, + ); } } diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs index 05b8c28fadd6..a824f6b7be49 100644 --- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs +++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs @@ -17,27 +17,30 @@ //!
Expression simplification API +use std::collections::HashSet; +use std::ops::Not; +use std::{borrow::Cow, sync::Arc}; + use arrow::{ array::{new_null_array, AsArray}, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; -use std::borrow::Cow; -use std::collections::HashSet; -use std::ops::Not; -use std::sync::Arc; +use datafusion_common::tree_node::TreeNodeContainer; use datafusion_common::{ cast::{as_large_list_array, as_list_array}, metadata::FieldMetadata, - tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter}, + tree_node::{ + Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter, + }, }; use datafusion_common::{ exec_datafusion_err, internal_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::{ - and, binary::BinaryTypeCoercer, lit, or, BinaryExpr, Case, ColumnarValue, Expr, Like, - Operator, Volatility, + and, binary::BinaryTypeCoercer, lit, or, simplify::SimplifyContext, BinaryExpr, Case, + ColumnarValue, Expr, Like, Operator, Volatility, }; use datafusion_expr::{expr::ScalarFunction, interval_arithmetic::NullableInterval}; use datafusion_expr::{ @@ -267,7 +270,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> { /// documentation for more details on type coercion pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> { let mut expr_rewrite = TypeCoercionRewriter { schema }; - expr.rewrite(&mut expr_rewrite).data() + expr.rewrite_with_schema(schema, &mut expr_rewrite).data() } /// Input guarantees about the values of columns. @@ -649,7 +652,8 @@ impl<'a> ConstEvaluator<'a> { | Expr::WindowFunction { .. } | Expr::GroupingSet(_) | Expr::Wildcard { .. } - | Expr::Placeholder(_) => false, + | Expr::Placeholder(_) + | Expr::Lambda { .. } => false, Expr::ScalarFunction(ScalarFunction { func, .. }) => { Self::volatility_ok(func.signature().volatility) } @@ -754,6 +758,89 @@ impl<'a, S> Simplifier<'a, S> { impl<S: SimplifyInfo> TreeNodeRewriter for Simplifier<'_, S> { type Node = Expr; + fn f_down(&mut self, expr: Self::Node) -> Result<Transformed<Self::Node>> { + match expr { + Expr::ScalarFunction(ScalarFunction { func, args }) + if args.iter().any(|arg| matches!(arg, Expr::Lambda(_))) => + { + // there's currently no way to adapt a generic SimplifyInfo with lambda parameters, + // so, if the scalar function has any lambda, we materialize a DFSchema using all the + // column references in every argument. Then we can call arguments_schema_from_logical_args, + // and for each argument, we create a new SimplifyContext with the scoped schema, and + // simplify the argument using this 'sub-context'. Finally, we set Transformed.tnr to + // Jump so the parent context doesn't try to simplify the argument again, without the + // parameters info + + // get all column references + let mut columns_refs = HashSet::new(); + + for arg in &args { + arg.add_column_refs(&mut columns_refs); + } + + // materialize column references into qualified fields + let qualified_fields = columns_refs + .into_iter() + .map(|captured_column| { + let expr = Expr::Column(captured_column.clone()); + + Ok(( + captured_column.relation.clone(), + Arc::new(Field::new( + captured_column.name(), + self.info.get_data_type(&expr)?, + self.info.nullable(&expr)?, + )), + )) + }) + .collect::<Result<Vec<_>>>()?; + + // create a schema using the materialized fields + let dfschema = + DFSchema::new_with_metadata(qualified_fields, Default::default())?; + + let mut scoped_schemas = func + .arguments_schema_from_logical_args(&args, &dfschema)?
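+                    // Editorial note (sketch): this yields one scoped schema per argument;
+                    // value arguments see `dfschema` unchanged, while each lambda argument
+                    // sees it extended with that lambda's parameter fields.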
+
+                let transformed_args = args
+                    .map_elements(|arg| {
+                        let scoped_schema = scoped_schemas.next().unwrap();
+
+                        // create a sub-context, using the scoped schema, that includes
+                        // information about the lambda parameters
+                        let simplify_context =
+                            SimplifyContext::new(self.info.execution_props())
+                                .with_schema(Arc::new(scoped_schema.into_owned()));
+
+                        let mut simplifier = Simplifier::new(&simplify_context);
+
+                        // simplify the argument using its context
+                        arg.rewrite(&mut simplifier)
+                    })?
+                    .update_data(|args| {
+                        Expr::ScalarFunction(ScalarFunction { func, args })
+                    });
+
+                Ok(Transformed::new(
+                    transformed_args.data,
+                    transformed_args.transformed,
+                    // return at least Jump so the parent context doesn't try again to
+                    // simplify the arguments (and fail because it doesn't contain info
+                    // about lambda parameters)
+                    match transformed_args.tnr {
+                        TreeNodeRecursion::Continue | TreeNodeRecursion::Jump => {
+                            TreeNodeRecursion::Jump
+                        }
+                        TreeNodeRecursion::Stop => TreeNodeRecursion::Stop,
+                    },
+                ))
+            }
+            _ => Ok(Transformed::no(expr)),
+        }
+    }
+
     /// rewrite the expression simplifying any constant expressions
     fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
         use datafusion_expr::Operator::{
diff --git a/datafusion/optimizer/src/utils.rs b/datafusion/optimizer/src/utils.rs
index 81763fa0552f..d0ae4932628f 100644
--- a/datafusion/optimizer/src/utils.rs
+++ b/datafusion/optimizer/src/utils.rs
@@ -23,7 +23,7 @@ use crate::analyzer::type_coercion::TypeCoercionRewriter;
 use arrow::array::{new_null_array, Array, RecordBatch};
 use arrow::datatypes::{DataType, Field, Schema};
 use datafusion_common::cast::as_boolean_array;
-use datafusion_common::tree_node::{TransformedResult, TreeNode};
+use datafusion_common::tree_node::TransformedResult;
 use datafusion_common::{Column, DFSchema, Result, ScalarValue};
 use datafusion_expr::execution_props::ExecutionProps;
 use datafusion_expr::expr_rewriter::replace_col;
@@ -148,7 +148,7 @@ fn evaluate_expr_with_null_column<'a>(
 fn coerce(expr: Expr, schema: &DFSchema) -> Result<Expr> {
     let mut expr_rewrite = TypeCoercionRewriter { schema };
-    expr.rewrite(&mut expr_rewrite).data()
+    expr.rewrite_with_schema(schema, &mut expr_rewrite).data()
 }

 #[cfg(test)]
diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs
index 61cc97dae300..4a81a5c99ac7 100644
--- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs
+++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs
@@ -21,12 +21,14 @@ use std::sync::Arc;

 use arrow::compute::can_cast_types;
 use arrow::datatypes::{DataType, FieldRef, Schema, SchemaRef};
+use datafusion_common::HashSet;
 use datafusion_common::{
     exec_err,
-    tree_node::{Transformed, TransformedResult, TreeNode},
+    tree_node::{Transformed, TransformedResult},
     Result, ScalarValue,
 };
 use datafusion_functions::core::getfield::GetFieldFunc;
+use datafusion_physical_expr::PhysicalExprExt;
 use datafusion_physical_expr::{
     expressions::{self, CastExpr, Column},
     ScalarFunctionExpr,
@@ -217,8 +219,10 @@ impl PhysicalExprAdapter for DefaultPhysicalExprAdapter {
             physical_file_schema: &self.physical_file_schema,
             partition_fields: &self.partition_values,
         };
-        expr.transform(|expr| rewriter.rewrite_expr(Arc::clone(&expr)))
-            .data()
+        expr.transform_with_lambdas_params(|expr, lambdas_params| {
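+            // `lambdas_params` carries the lambda parameter names in scope at
+            // this node; such columns are left untouched below because they
+            // resolve against the lambda's scoped schema, not the physical
+            // file schema.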
rewriter.rewrite_expr(Arc::clone(&expr), lambdas_params) + }) + .data() } fn with_partition_values( @@ -242,13 +246,18 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn rewrite_expr( &self, expr: Arc, + lambdas_params: &HashSet, ) -> Result>> { - if let Some(transformed) = self.try_rewrite_struct_field_access(&expr)? { + if let Some(transformed) = + self.try_rewrite_struct_field_access(&expr, lambdas_params)? + { return Ok(Transformed::yes(transformed)); } if let Some(column) = expr.as_any().downcast_ref::() { - return self.rewrite_column(Arc::clone(&expr), column); + if !lambdas_params.contains(column.name()) { + return self.rewrite_column(Arc::clone(&expr), column); + } } Ok(Transformed::no(expr)) @@ -260,6 +269,7 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { fn try_rewrite_struct_field_access( &self, expr: &Arc, + lambdas_params: &HashSet, ) -> Result>> { let get_field_expr = match ScalarFunctionExpr::try_downcast_func::(expr.as_ref()) { @@ -291,8 +301,8 @@ impl<'a> DefaultPhysicalExprAdapterRewriter<'a> { }; let column = match source_expr.as_any().downcast_ref::() { - Some(column) => column, - None => return Ok(None), + Some(column) if !lambdas_params.contains(column.name()) => column, + _ => return Ok(None), }; let physical_field = @@ -446,6 +456,7 @@ mod tests { use super::*; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; + use datafusion_common::hashbrown::HashSet; use datafusion_common::{assert_contains, record_batch, Result, ScalarValue}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{col, lit, CastExpr, Column, Literal}; @@ -852,7 +863,9 @@ mod tests { // Test that when a field exists in physical schema, it returns None let column = Arc::new(Column::new("struct_col", 0)) as Arc; - let result = rewriter.try_rewrite_struct_field_access(&column).unwrap(); + let result = rewriter + .try_rewrite_struct_field_access(&column, &HashSet::new()) + .unwrap(); assert!(result.is_none()); // The actual test for the get_field expression would require creating a proper ScalarFunctionExpr diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index b7654a0f6f60..d4c0e1cbe6eb 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -37,6 +37,9 @@ workspace = true [lib] name = "datafusion_physical_expr" +[features] +recursive_protection = ["dep:recursive"] + [dependencies] ahash = { workspace = true } arrow = { workspace = true } @@ -52,6 +55,7 @@ itertools = { workspace = true, features = ["use_std"] } parking_lot = { workspace = true } paste = "^1.0" petgraph = "0.8.3" +recursive = { workspace = true, optional = true } [dev-dependencies] arrow = { workspace = true, features = ["test_utils"] } diff --git a/datafusion/physical-expr/src/async_scalar_function.rs b/datafusion/physical-expr/src/async_scalar_function.rs index b434694a20cc..a34d3cda4768 100644 --- a/datafusion/physical-expr/src/async_scalar_function.rs +++ b/datafusion/physical-expr/src/async_scalar_function.rs @@ -168,6 +168,7 @@ impl AsyncFuncExpr { number_rows: current_batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .await?, ); @@ -187,6 +188,7 @@ impl AsyncFuncExpr { number_rows: batch.num_rows(), return_field: Arc::clone(&self.return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .await?, ); diff --git a/datafusion/physical-expr/src/expressions/column.rs 
b/datafusion/physical-expr/src/expressions/column.rs
index 9ca464b30430..c55f42ae333b 100644
--- a/datafusion/physical-expr/src/expressions/column.rs
+++ b/datafusion/physical-expr/src/expressions/column.rs
@@ -22,12 +22,13 @@ use std::hash::Hash;
 use std::sync::Arc;

 use crate::physical_expr::PhysicalExpr;
+use crate::PhysicalExprExt;
 use arrow::datatypes::FieldRef;
 use arrow::{
     datatypes::{DataType, Schema, SchemaRef},
     record_batch::RecordBatch,
 };
-use datafusion_common::tree_node::{Transformed, TreeNode};
+use datafusion_common::tree_node::{Transformed, TransformedResult};
 use datafusion_common::{internal_err, plan_err, Result};
 use datafusion_expr::ColumnarValue;

@@ -67,7 +68,8 @@ use datafusion_expr::ColumnarValue;
 pub struct Column {
     /// The name of the column (used for debugging and display purposes)
     name: String,
-    /// The index of the column in its schema
+    /// The index of the column in its schema.
+    /// Within a lambda body, this refers to the lambda's scoped schema, not the plan schema.
     index: usize,
 }

@@ -178,9 +180,9 @@ pub fn with_new_schema(
     expr: Arc<dyn PhysicalExpr>,
     schema: &SchemaRef,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    Ok(expr
-        .transform_up(|expr| {
-            if let Some(col) = expr.as_any().downcast_ref::<Column>() {
+    expr.transform_up_with_lambdas_params(|expr, lambdas_params| {
+        match expr.as_any().downcast_ref::<Column>() {
+            Some(col) if !lambdas_params.contains(col.name()) => {
                 let idx = col.index();
                 let Some(field) = schema.fields().get(idx) else {
                     return plan_err!(
                     );
                 };
                 let new_col = Column::new(field.name(), idx);
+
                 Ok(Transformed::yes(Arc::new(new_col) as _))
-            } else {
-                Ok(Transformed::no(expr))
             }
-        })?
-        .data)
+            _ => Ok(Transformed::no(expr)),
+        }
+    })
+    .data()
 }

 #[cfg(test)]
diff --git a/datafusion/physical-expr/src/expressions/lambda.rs
new file mode 100644
index 000000000000..55110fdf5bf6
--- /dev/null
+++ b/datafusion/physical-expr/src/expressions/lambda.rs
@@ -0,0 +1,139 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//!
Physical lambda expression: [`LambdaExpr`]
+
+use std::hash::Hash;
+use std::sync::Arc;
+use std::{any::Any, sync::OnceLock};
+
+use crate::expressions::Column;
+use crate::physical_expr::PhysicalExpr;
+use crate::PhysicalExprExt;
+use arrow::{
+    datatypes::{DataType, Schema},
+    record_batch::RecordBatch,
+};
+use datafusion_common::tree_node::TreeNodeRecursion;
+use datafusion_common::{internal_err, HashSet, Result};
+use datafusion_expr::ColumnarValue;
+
+/// Represents a lambda expression with the given parameter names and body
+#[derive(Debug, Eq, Clone)]
+pub struct LambdaExpr {
+    params: Vec<String>,
+    body: Arc<dyn PhysicalExpr>,
+    captures: OnceLock<HashSet<usize>>,
+}
+
+// Manually implement PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196]
+impl PartialEq for LambdaExpr {
+    fn eq(&self, other: &Self) -> bool {
+        self.params.eq(&other.params) && self.body.eq(&other.body)
+    }
+}
+
+impl Hash for LambdaExpr {
+    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
+        self.params.hash(state);
+        self.body.hash(state);
+    }
+}
+
+impl LambdaExpr {
+    /// Create a new lambda expression with the given parameters and body
+    pub fn new(params: Vec<String>, body: Arc<dyn PhysicalExpr>) -> Self {
+        Self {
+            params,
+            body,
+            captures: OnceLock::new(),
+        }
+    }
+
+    /// Get the lambda's parameter names
+    pub fn params(&self) -> &[String] {
+        &self.params
+    }
+
+    /// Get the lambda's body
+    pub fn body(&self) -> &Arc<dyn PhysicalExpr> {
+        &self.body
+    }
+
+    /// Get the indices of the columns the body captures from the enclosing
+    /// schema, i.e. column references that do not name a lambda parameter
+    pub fn captures(&self) -> &HashSet<usize> {
+        self.captures.get_or_init(|| {
+            let mut indices = HashSet::new();
+
+            self.body
+                .apply_with_lambdas_params(|expr, lambdas_params| {
+                    if let Some(column) = expr.as_any().downcast_ref::<Column>() {
+                        if !lambdas_params.contains(column.name()) {
+                            indices.insert(column.index());
+                        }
+                    }
+
+                    Ok(TreeNodeRecursion::Continue)
+                })
+                .unwrap();
+
+            indices
+        })
+    }
+}
+
+impl std::fmt::Display for LambdaExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "({}) -> {}", self.params.join(", "), self.body)
+    }
+}
+
+impl PhysicalExpr for LambdaExpr {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        // A lambda has no standalone Arrow type; it is only meaningful as an
+        // argument of the enclosing scalar function
+        Ok(DataType::Null)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        internal_err!("Lambda::evaluate() should not be called")
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.body]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(Self {
+            params: self.params.clone(),
+            body: Arc::clone(&children[0]),
+            captures: OnceLock::new(),
+        }))
+    }
+
+    fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "({}) -> {}", self.params.join(", "), self.body)
+    }
+}
diff --git a/datafusion/physical-expr/src/expressions/mod.rs b/datafusion/physical-expr/src/expressions/mod.rs
index 59d675753d98..e87941da5ef4 100644
--- a/datafusion/physical-expr/src/expressions/mod.rs
+++ b/datafusion/physical-expr/src/expressions/mod.rs
@@ -27,6 +27,7 @@ mod dynamic_filters;
 mod in_list;
 mod is_not_null;
 mod is_null;
+mod lambda;
 mod like;
 mod literal;
 mod negative;
@@ -49,6 +50,7 @@ pub use dynamic_filters::DynamicFilterPhysicalExpr;
 pub use in_list::{in_list, InListExpr};
 pub use is_not_null::{is_not_null, IsNotNullExpr};
 pub use is_null::{is_null, IsNullExpr};
+pub use lambda::LambdaExpr;
 pub use like::{like, LikeExpr};
 pub use literal::{lit, Literal};
 pub use negative::{negative, NegativeExpr};
diff --git a/datafusion/physical-expr/src/lib.rs
b/datafusion/physical-expr/src/lib.rs
index aa8c9e50fd71..873205f28bef 100644
--- a/datafusion/physical-expr/src/lib.rs
+++ b/datafusion/physical-expr/src/lib.rs
@@ -70,6 +70,8 @@ pub use datafusion_physical_expr_common::sort_expr::{
 pub use planner::{create_physical_expr, create_physical_exprs};
 pub use scalar_function::ScalarFunctionExpr;

+pub use scalar_function::PhysicalExprExt;
+
 pub use simplifier::PhysicalExprSimplifier;

 pub use utils::{conjunction, conjunction_opt, split_conjunction};
diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs
index c658a8eddc23..2584fc22885c 100644
--- a/datafusion/physical-expr/src/physical_expr.rs
+++ b/datafusion/physical-expr/src/physical_expr.rs
@@ -18,11 +18,11 @@
 use std::sync::Arc;

 use crate::expressions::{self, Column};
-use crate::{create_physical_expr, LexOrdering, PhysicalSortExpr};
+use crate::{create_physical_expr, LexOrdering, PhysicalExprExt, PhysicalSortExpr};

 use arrow::compute::SortOptions;
 use arrow::datatypes::{Schema, SchemaRef};
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::tree_node::{Transformed, TransformedResult};
 use datafusion_common::{plan_err, Result};
 use datafusion_common::{DFSchema, HashMap};
 use datafusion_expr::execution_props::ExecutionProps;
@@ -38,14 +38,14 @@ pub fn add_offset_to_expr(
     expr: Arc<dyn PhysicalExpr>,
     offset: isize,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
-        Some(col) => {
+    expr.transform_down_with_lambdas_params(|e, lambdas_params| match e.as_any().downcast_ref::<Column>() {
+        Some(col) if !lambdas_params.contains(col.name()) => {
             let Some(idx) = col.index().checked_add_signed(offset) else {
                 return plan_err!("Column index overflow");
             };
             Ok(Transformed::yes(Arc::new(Column::new(col.name(), idx))))
         }
-        None => Ok(Transformed::no(e)),
+        _ => Ok(Transformed::no(e)),
     })
     .data()
 }
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 7790380dffd5..0119c81b8ed9 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -17,6 +17,7 @@
 use std::sync::Arc;

+use crate::expressions::LambdaExpr;
 use crate::ScalarFunctionExpr;
 use crate::{
     expressions::{self, binary, like, similar_to, Column, Literal},
@@ -30,7 +31,7 @@ use datafusion_common::{
     exec_err, not_impl_err, plan_err, DFSchema, Result, ScalarValue, ToDFSchema,
 };
 use datafusion_expr::execution_props::ExecutionProps;
-use datafusion_expr::expr::{Alias, Cast, InList, Placeholder, ScalarFunction};
+use datafusion_expr::expr::{Alias, Cast, InList, Lambda, Placeholder, ScalarFunction};
 use datafusion_expr::var_provider::is_system_variables;
 use datafusion_expr::var_provider::VarType;
 use datafusion_expr::{
@@ -104,7 +105,8 @@
 ///
 /// * `e` - The logical expression
 /// * `input_dfschema` - The DataFusion schema for the input, used to resolve `Column` references
-///   to qualified or unqualified fields by name.
+///   to qualified or unqualified fields by name. Note that when planning a lambda body, this
+///   must be the scoped lambda schema, not the outer schema.
 pub fn create_physical_expr(
     e: &Expr,
     input_dfschema: &DFSchema,
@@ -314,9 +316,28 @@
         input_dfschema,
         execution_props,
     )?),
+        Expr::Lambda { ..
} => {
+            exec_err!("Expr::Lambda should be handled by Expr::ScalarFunction, as it can only exist within it")
+        }
         Expr::ScalarFunction(ScalarFunction { func, args }) => {
-            let physical_args =
-                create_physical_exprs(args, input_dfschema, execution_props)?;
+            let lambdas_schemas =
+                func.arguments_schema_from_logical_args(args, input_dfschema)?;
+
+            let physical_args = std::iter::zip(args, lambdas_schemas)
+                .map(|(expr, schema)| match expr {
+                    Expr::Lambda(Lambda { params, body }) => {
+                        Ok(Arc::new(LambdaExpr::new(
+                            params.clone(),
+                            create_physical_expr(body, &schema, execution_props)?,
+                        )) as Arc<dyn PhysicalExpr>)
+                    }
+                    expr => create_physical_expr(expr, &schema, execution_props),
+                })
+                .collect::<Result<Vec<_>>>()?;
+
             let config_options = match execution_props.config_options.as_ref() {
                 Some(config_options) => Arc::clone(config_options),
                 None => Arc::new(ConfigOptions::default()),
diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs
index a120ab427e1d..70be717a8436 100644
--- a/datafusion/physical-expr/src/projection.rs
+++ b/datafusion/physical-expr/src/projection.rs
@@ -20,11 +20,11 @@ use std::sync::Arc;

 use crate::expressions::Column;
 use crate::utils::collect_columns;
-use crate::PhysicalExpr;
+use crate::{PhysicalExpr, PhysicalExprExt};
 use arrow::datatypes::{Field, Schema, SchemaRef};
 use datafusion_common::stats::{ColumnStatistics, Precision};
-use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode};
+use datafusion_common::tree_node::{Transformed, TransformedResult};
 use datafusion_common::{internal_datafusion_err, internal_err, plan_err, Result};
 use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr};
@@ -499,13 +499,16 @@ pub fn update_expr(
     let mut state = RewriteState::Unchanged;

     let new_expr = Arc::clone(expr)
-        .transform_up(|expr| {
+        .transform_up_with_lambdas_params(|expr, lambdas_params| {
             if state == RewriteState::RewrittenInvalid {
                 return Ok(Transformed::no(expr));
             }
-            let Some(column) = expr.as_any().downcast_ref::<Column>() else {
-                return Ok(Transformed::no(expr));
+            let column = match expr.as_any().downcast_ref::<Column>() {
+                Some(column) if !lambdas_params.contains(column.name()) => column,
+                _ => {
+                    return Ok(Transformed::no(expr));
+                }
             };
             if sync_with_child {
                 state = RewriteState::RewrittenValid;
@@ -616,14 +619,14 @@ impl ProjectionMapping {
         let mut map = IndexMap::<_, ProjectionTargets>::new();
         for (expr_idx, (expr, name)) in expr.into_iter().enumerate() {
             let target_expr = Arc::new(Column::new(&name, expr_idx)) as _;
-            let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::<Column>() {
+            let source_expr = expr.transform_down_with_schema(input_schema, |e, schema| match e.as_any().downcast_ref::<Column>() {
                 Some(col) => {
-                    // Sometimes, an expression and its name in the input_schema
+                    // Sometimes, an expression and its name in the schema
                     // don't match. This can cause problems, so we make sure
-                    // that the expression name matches with the name in `input_schema`.
+                    // that the expression name matches with the name in `schema`.
                     // Conceptually, `source_expr` and `expression` should be the same.
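+                    // Note: `transform_down_with_schema` re-scopes `schema` when it
+                    // descends into a lambda body, so the index lookup below is
+                    // always made against the schema the column was planned with.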
let idx = col.index(); - let matching_field = input_schema.field(idx); + let matching_field = schema.field(idx); let matching_name = matching_field.name(); if col.name() != matching_name { return internal_err!( @@ -737,21 +740,25 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = Arc::clone(expr).transform_up(|expr| { - let Some(col) = expr.as_any().downcast_ref::() else { - return Ok(Transformed::no(expr)); - }; + let transformed = + Arc::clone(expr).transform_up_with_lambdas_params(|expr, lambdas_params| { + let col = match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => col, + _ => { + return Ok(Transformed::no(expr)); + } + }; - let name = col.name(); - if let Some((idx, _)) = schema.column_with_name(name) { - // Compute the new column expression (with correct index) after projection: - Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) - } else { - // Cannot find expression in the projected_schema, - // signal this using an Err result - plan_err!("") - } - }); + let name = col.name(); + if let Some((idx, _)) = schema.column_with_name(name) { + // Compute the new column expression (with correct index) after projection: + Ok(Transformed::yes(Arc::new(Column::new(name, idx)))) + } else { + // Cannot find expression in the projected_schema, + // signal this using an Err result + plan_err!("") + } + }); match transformed { Ok(transformed) => { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 743d5b99cde9..22fa300f05df 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -30,23 +30,25 @@ //! to a function that supports f64, it is coerced to f64. 
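+//!
+//! With this patch a scalar function argument may also be a lambda
+//! (e.g. `array_transform(v, (x) -> -x)`). A [`LambdaExpr`] argument is not
+//! evaluated against the input batch; instead its parameter fields, body, and
+//! a batch of captured columns are packaged into `ScalarFunctionArgs::lambdas`
+//! (as `ScalarFunctionLambdaArg`) so the UDF can drive evaluation itself.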
use std::any::Any; +use std::borrow::Cow; use std::fmt::{self, Debug, Formatter}; use std::hash::{Hash, Hasher}; use std::sync::Arc; -use crate::expressions::Literal; +use crate::expressions::{Column, LambdaExpr, Literal}; use crate::PhysicalExpr; -use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, FieldRef, Schema}; +use arrow::array::{Array, NullArray, RecordBatch}; +use arrow::datatypes::{DataType, Field, FieldRef, Schema}; use datafusion_common::config::{ConfigEntry, ConfigOptions}; -use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::{internal_err, HashSet, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - Volatility, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, + ScalarFunctionLambdaArg, ScalarUDF, ValueOrLambdaParameter, Volatility, }; /// Physical expression of a scalar function @@ -94,10 +96,16 @@ impl ScalarFunctionExpr { schema: &Schema, config_options: Arc, ) -> Result { - let name = fun.name().to_string(); - let arg_fields = args - .iter() - .map(|e| e.return_field(schema)) + let lambdas_schemas = lambdas_schemas_from_args(&fun, &args, schema)?; + + let arg_fields = std::iter::zip(&args, lambdas_schemas) + .map(|(e, schema)| { + if let Some(lambda) = e.as_any().downcast_ref::() { + lambda.body().return_field(&schema) + } else { + e.return_field(&schema) + } + }) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -105,6 +113,7 @@ impl ScalarFunctionExpr { .iter() .map(|f| f.data_type().clone()) .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args @@ -115,11 +124,21 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); + + let lambdas = args + .iter() + .map(|e| e.as_any().is::()) + .collect::>(); + let ret_args = ReturnFieldArgs { arg_fields: &arg_fields, scalar_arguments: &arguments, + lambdas: &lambdas, }; + let return_field = fun.return_field_from_args(ret_args)?; + let name = fun.name().to_string(); + Ok(Self { fun, name, @@ -260,7 +279,10 @@ impl PhysicalExpr for ScalarFunctionExpr { let args = self .args .iter() - .map(|e| e.evaluate(batch)) + .map(|e| match e.as_any().downcast_ref::() { + Some(_) => Ok(ColumnarValue::Scalar(ScalarValue::Null)), + None => Ok(e.evaluate(batch)?), + }) .collect::>>()?; let arg_fields = self @@ -274,6 +296,111 @@ impl PhysicalExpr for ScalarFunctionExpr { .iter() .all(|arg| matches!(arg, ColumnarValue::Scalar(_))); + let lambdas = if self.args.iter().any(|arg| arg.as_any().is::()) { + let args_metadata = std::iter::zip(&self.args, &arg_fields) + .map( + |(expr, field)| match expr.as_any().downcast_ref::() { + Some(lambda) => { + let mut captures = false; + + expr.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + captures = true; + + Ok(TreeNodeRecursion::Stop) + } + _ => Ok(TreeNodeRecursion::Continue), + } + }) + .unwrap(); + + ValueOrLambdaParameter::Lambda(lambda.params(), captures) + } + None => ValueOrLambdaParameter::Value(Arc::clone(field)), + }, + ) + .collect::>(); + + let params = 
self.fun().inner().lambdas_parameters(&args_metadata)?;
+
+            let lambdas = std::iter::zip(&self.args, params)
+                .map(|(arg, lambda_params)| {
+                    arg.as_any()
+                        .downcast_ref::<LambdaExpr>()
+                        .map(|lambda| {
+                            let mut indices = HashSet::new();
+
+                            arg.apply_with_lambdas_params(|expr, lambdas_params| {
+                                if let Some(column) =
+                                    expr.as_any().downcast_ref::<Column>()
+                                {
+                                    if !lambdas_params.contains(column.name()) {
+                                        indices.insert(column.index());
+                                    }
+                                }
+
+                                Ok(TreeNodeRecursion::Continue)
+                            })?;
+
+                            let params =
+                                std::iter::zip(lambda.params(), lambda_params.unwrap())
+                                    .map(|(name, param)| Arc::new(param.with_name(name)))
+                                    .collect();
+
+                            // Build the capture batch: captured columns keep their field
+                            // and data, every other column is masked with nulls so the
+                            // lambda body only sees what it actually references
+                            let captures = if !indices.is_empty() {
+                                let (fields, columns): (Vec<_>, _) = std::iter::zip(
+                                    batch.schema_ref().fields(),
+                                    batch.columns(),
+                                )
+                                .enumerate()
+                                .map(|(column_index, (field, column))| {
+                                    if indices.contains(&column_index) {
+                                        (Arc::clone(field), Arc::clone(column))
+                                    } else {
+                                        (
+                                            Arc::new(Field::new(
+                                                field.name(),
+                                                DataType::Null,
+                                                false,
+                                            )),
+                                            Arc::new(NullArray::new(column.len())) as _,
+                                        )
+                                    }
+                                })
+                                .unzip();
+
+                                let schema = Arc::new(Schema::new(fields));
+
+                                Some(RecordBatch::try_new(schema, columns)?)
+                            } else {
+                                None
+                            };
+
+                            Ok(ScalarFunctionLambdaArg {
+                                params,
+                                body: Arc::clone(lambda.body()),
+                                captures,
+                            })
+                        })
+                        .transpose()
+                })
+                .collect::<Result<Vec<_>>>()?;
+
+            Some(lambdas)
+        } else {
+            None
+        };
+
         // evaluate the function
         let output = self.fun.invoke_with_args(ScalarFunctionArgs {
             args,
@@ -281,6 +408,7 @@
             number_rows: batch.num_rows(),
             return_field: Arc::clone(&self.return_field),
             config_options: Arc::clone(&self.config_options),
+            lambdas,
         })?;

         if let ColumnarValue::Array(array) = &output {
@@ -365,14 +493,377 @@ impl PhysicalExpr for ScalarFunctionExpr {
     }
 }

+/// Compute, for each argument of `fun`, the schema against which that
+/// argument (or, for a lambda, its body) should be resolved
+pub fn lambdas_schemas_from_args<'a>(
+    fun: &ScalarUDF,
+    args: &[Arc<dyn PhysicalExpr>],
+    schema: &'a Schema,
+) -> Result<Vec<Cow<'a, Schema>>> {
+    let args_metadata = args
+        .iter()
+        .map(|e| match e.as_any().downcast_ref::<LambdaExpr>() {
+            Some(lambda) => {
+                let mut captures = false;
+
+                e.apply_with_lambdas_params(|expr, lambdas_params| {
+                    match expr.as_any().downcast_ref::<Column>() {
+                        Some(col) if !lambdas_params.contains(col.name()) => {
+                            captures = true;

+                            Ok(TreeNodeRecursion::Stop)
+                        }
+                        _ => Ok(TreeNodeRecursion::Continue),
+                    }
+                })
+                .unwrap();
+
+                Ok(ValueOrLambdaParameter::Lambda(lambda.params(), captures))
+            }
+            None => Ok(ValueOrLambdaParameter::Value(e.return_field(schema)?)),
+        })
+        .collect::<Result<Vec<_>>>()?;
+
+    fun.arguments_arrow_schema(&args_metadata, schema)
+}
+
+pub trait PhysicalExprExt: Sized {
+    fn apply_with_lambdas_params<
+        'n,
+        F: FnMut(&'n Self, &HashSet<&'n str>) -> Result<TreeNodeRecursion>,
+    >(
+        &'n self,
+        f: F,
+    ) -> Result<TreeNodeRecursion>;
+
+    fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result<TreeNodeRecursion>>(
+        &'n self,
+        schema: &Schema,
+        f: F,
+    ) -> Result<TreeNodeRecursion>;
+
+    fn apply_children_with_schema<
+        'n,
+        F:
FnMut(&'n Self, &Schema) -> Result, + >( + &'n self, + schema: &Schema, + f: F, + ) -> Result; + + fn transform_down_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result>; + + fn transform_up_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result>; + + fn transform_with_schema Result>>( + self, + schema: &Schema, + f: F, + ) -> Result> { + self.transform_up_with_schema(schema, f) + } + + fn transform_down_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result>; + + fn transform_up_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result>; + + fn transform_with_lambdas_params( + self, + f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + self.transform_up_with_lambdas_params(f) + } +} + +impl PhysicalExprExt for Arc { + fn apply_with_lambdas_params< + 'n, + F: FnMut(&'n Self, &HashSet<&'n str>) -> Result, + >( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_params_impl< + 'n, + F: FnMut( + &'n Arc, + &HashSet<&'n str>, + ) -> Result, + >( + node: &'n Arc, + args: &HashSet<&'n str>, + f: &mut F, + ) -> Result { + match node.as_any().downcast_ref::() { + Some(lambda) => { + let mut args = args.clone(); + + args.extend(lambda.params().iter().map(|v| v.as_str())); + + f(node, &args)?.visit_children(|| { + node.apply_children(|c| { + apply_with_lambdas_params_impl(c, &args, f) + }) + }) + } + _ => f(node, args)?.visit_children(|| { + node.apply_children(|c| apply_with_lambdas_params_impl(c, args, f)) + }), + } + } + + apply_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + fn apply_with_schema<'n, F: FnMut(&'n Self, &Schema) -> Result>( + &'n self, + schema: &Schema, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_with_lambdas_impl< + 'n, + F: FnMut(&'n Arc, &Schema) -> Result, + >( + node: &'n Arc, + schema: &Schema, + f: &mut F, + ) -> Result { + f(node, schema)?.visit_children(|| { + node.apply_children_with_schema(schema, |c, schema| { + apply_with_lambdas_impl(c, schema, f) + }) + }) + } + + apply_with_lambdas_impl(self, schema, &mut f) + } + + fn apply_children_with_schema< + 'n, + F: FnMut(&'n Self, &Schema) -> Result, + >( + &'n self, + schema: &Schema, + mut f: F, + ) -> Result { + match self.as_any().downcast_ref::() { + Some(scalar_function) + if scalar_function + .args() + .iter() + .any(|arg| arg.as_any().is::()) => + { + let mut lambdas_schemas = lambdas_schemas_from_args( + scalar_function.fun(), + scalar_function.args(), + schema, + )? 
+ .into_iter(); + + self.apply_children(|expr| f(expr, &lambdas_schemas.next().unwrap())) + } + _ => self.apply_children(|e| f(e, schema)), + } + } + + fn transform_down_with_schema< + F: FnMut(Self, &Schema) -> Result>, + >( + self, + schema: &Schema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_schema_impl< + F: FnMut( + Arc, + &Schema, + ) -> Result>>, + >( + node: Arc, + schema: &Schema, + f: &mut F, + ) -> Result>> { + f(node, schema)?.transform_children(|node| { + map_children_with_schema(node, schema, |n, schema| { + transform_down_with_schema_impl(n, schema, f) + }) + }) + } + + transform_down_with_schema_impl(self, schema, &mut f) + } + + fn transform_up_with_schema Result>>( + self, + schema: &Schema, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_schema_impl< + F: FnMut( + Arc, + &Schema, + ) -> Result>>, + >( + node: Arc, + schema: &Schema, + f: &mut F, + ) -> Result>> { + map_children_with_schema(node, schema, |n, schema| { + transform_up_with_schema_impl(n, schema, f) + })? + .transform_parent(|n| f(n, schema)) + } + + transform_up_with_schema_impl(self, schema, &mut f) + } + + fn transform_up_with_lambdas_params( + self, + mut f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_with_lambdas_params_impl< + F: FnMut( + Arc, + &HashSet, + ) -> Result>>, + >( + node: Arc, + params: &HashSet, + f: &mut F, + ) -> Result>> { + map_children_with_lambdas_params(node, params, |n, params| { + transform_up_with_lambdas_params_impl(n, params, f) + })? + .transform_parent(|n| f(n, params)) + } + + transform_up_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } + + fn transform_down_with_lambdas_params( + self, + mut f: impl FnMut(Self, &HashSet) -> Result>, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_with_lambdas_params_impl< + F: FnMut( + Arc, + &HashSet, + ) -> Result>>, + >( + node: Arc, + params: &HashSet, + f: &mut F, + ) -> Result>> { + f(node, params)?.transform_children(|node| { + map_children_with_lambdas_params(node, params, |node, args| { + transform_down_with_lambdas_params_impl(node, args, f) + }) + }) + } + + transform_down_with_lambdas_params_impl(self, &HashSet::new(), &mut f) + } +} + +fn map_children_with_schema( + node: Arc, + schema: &Schema, + mut f: impl FnMut( + Arc, + &Schema, + ) -> Result>>, +) -> Result>> { + match node.as_any().downcast_ref::() { + Some(fun) if fun.args().iter().any(|arg| arg.as_any().is::()) => { + let mut args_schemas = + lambdas_schemas_from_args(fun.fun(), fun.args(), schema)?.into_iter(); + + node.map_children(|node| f(node, &args_schemas.next().unwrap())) + } + _ => node.map_children(|node| f(node, schema)), + } +} + +fn map_children_with_lambdas_params( + node: Arc, + params: &HashSet, + mut f: impl FnMut( + Arc, + &HashSet, + ) -> Result>>, +) -> Result>> { + match node.as_any().downcast_ref::() { + Some(lambda) => { + let mut params = params.clone(); + + params.extend(lambda.params().iter().cloned()); + + node.map_children(|node| f(node, ¶ms)) + } + None => node.map_children(|node| f(node, params)), + } +} + #[cfg(test)] mod tests { + use std::any::Any; + use std::{borrow::Cow, sync::Arc}; + use super::*; + use super::{lambdas_schemas_from_args, PhysicalExprExt}; use crate::expressions::Column; + use 
crate::{create_physical_expr, ScalarFunctionExpr}; use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_common::{tree_node::TreeNodeRecursion, DFSchema, HashSet, Result}; + use datafusion_expr::{ + col, expr::Lambda, Expr, ScalarFunctionArgs, ValueOrLambdaParameter, Volatility, + }; use datafusion_expr::{ScalarUDF, ScalarUDFImpl, Signature}; + use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_physical_expr_common::physical_expr::is_volatile; - use std::any::Any; + use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// Test helper to create a mock UDF with a specific volatility #[derive(Debug, PartialEq, Eq, Hash)] @@ -444,4 +935,190 @@ mod tests { let stable_arc: Arc = Arc::new(stable_expr); assert!(!is_volatile(&stable_arc)); } + + fn list_list_int() -> Schema { + Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::new_list(DataType::Int32, false), false), + false, + )]) + } + + fn list_int() -> Schema { + Schema::new(vec![Field::new( + "v", + DataType::new_list(DataType::Int32, false), + false, + )]) + } + + fn int() -> Schema { + Schema::new(vec![Field::new("v", DataType::Int32, false)]) + } + + fn array_transform_udf() -> ScalarUDF { + ScalarUDF::new_from_impl(ArrayTransformFunc::new()) + } + + fn args() -> Vec { + vec![ + col("v"), + Expr::Lambda(Lambda::new( + vec!["v".into()], + array_transform_udf().call(vec![ + col("v"), + Expr::Lambda(Lambda::new(vec!["v".into()], -col("v"))), + ]), + )), + ] + } + + // array_transform(v, |v| -> array_transform(v, |v| -> -v)) + fn array_transform() -> Arc { + let e = array_transform_udf().call(args()); + + create_physical_expr( + &e, + &DFSchema::try_from(list_list_int()).unwrap(), + &Default::default(), + ) + .unwrap() + } + + #[derive(Debug, PartialEq, Eq, Hash)] + struct ArrayTransformFunc { + signature: Signature, + } + + impl ArrayTransformFunc { + pub fn new() -> Self { + Self { + signature: Signature::any(2, Volatility::Immutable), + } + } + } + + impl ScalarUDFImpl for ArrayTransformFunc { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn name(&self) -> &str { + "array_transform" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + Ok(arg_types[0].clone()) + } + + fn lambdas_parameters( + &self, + args: &[ValueOrLambdaParameter], + ) -> Result>>> { + let ValueOrLambdaParameter::Value(value_field) = &args[0] else { + unimplemented!() + }; + let DataType::List(field) = value_field.data_type() else { + unimplemented!() + }; + + Ok(vec![ + None, + Some(vec![Field::new( + "", + field.data_type().clone(), + field.is_nullable(), + )]), + ]) + } + + fn invoke_with_args(&self, _args: ScalarFunctionArgs) -> Result { + unimplemented!() + } + } + + #[test] + fn test_lambdas_schemas_from_args() { + let schema = list_list_int(); + let expr = array_transform(); + + let args = expr + .as_any() + .downcast_ref::() + .unwrap() + .args(); + + let schemas = + lambdas_schemas_from_args(&array_transform_udf(), args, &schema).unwrap(); + + assert_eq!(schemas, &[Cow::Borrowed(&schema), Cow::Owned(list_int())]); + } + + #[test] + fn test_apply_with_schema() { + let mut steps = vec![]; + + array_transform() + .apply_with_schema(&list_list_int(), |node, schema| { + steps.push((node.to_string(), schema.clone())); + + Ok(TreeNodeRecursion::Continue) + }) + .unwrap(); + + let expected = [ + ( + "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))", + list_list_int(), + ), + ("(v) -> 
array_transform(v@0, (v) -> (- v@0))", list_int()),
+            ("array_transform(v@0, (v) -> (- v@0))", list_int()),
+            ("(v) -> (- v@0)", int()),
+            ("(- v@0)", int()),
+            ("v@0", int()),
+            ("v@0", int()),
+            ("v@0", int()),
+        ]
+        .map(|(a, b)| (String::from(a), b));
+
+        assert_eq!(steps, expected);
+    }
+
+    #[test]
+    fn test_apply_with_lambdas_params() {
+        let array_transform = array_transform();
+        let mut steps = vec![];
+
+        array_transform
+            .apply_with_lambdas_params(|node, params| {
+                steps.push((node.to_string(), params.clone()));
+
+                Ok(TreeNodeRecursion::Continue)
+            })
+            .unwrap();
+
+        let expected = [
+            (
+                "array_transform(v@0, (v) -> array_transform(v@0, (v) -> (- v@0)))",
+                HashSet::from(["v"]),
+            ),
+            (
+                "(v) -> array_transform(v@0, (v) -> (- v@0))",
+                HashSet::from(["v"]),
+            ),
+            ("array_transform(v@0, (v) -> (- v@0))", HashSet::from(["v"])),
+            ("(v) -> (- v@0)", HashSet::from(["v"])),
+            ("(- v@0)", HashSet::from(["v"])),
+            ("v@0", HashSet::from(["v"])),
+            ("v@0", HashSet::from(["v"])),
+            ("v@0", HashSet::from(["v"])),
+        ]
+        .map(|(a, b)| (String::from(a), b));
+
+        assert_eq!(steps, expected);
+    }
 }
diff --git a/datafusion/physical-expr/src/simplifier/mod.rs b/datafusion/physical-expr/src/simplifier/mod.rs
index 80d6ee0a7b91..dd7e6e314672 100644
--- a/datafusion/physical-expr/src/simplifier/mod.rs
+++ b/datafusion/physical-expr/src/simplifier/mod.rs
@@ -19,12 +19,12 @@ use arrow::datatypes::Schema;

 use datafusion_common::{
-    tree_node::{Transformed, TreeNode, TreeNodeRewriter},
+    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
     Result,
 };
 use std::sync::Arc;

-use crate::PhysicalExpr;
+use crate::{PhysicalExpr, PhysicalExprExt};

 pub mod unwrap_cast;

@@ -48,6 +48,22 @@ impl<'a> PhysicalExprSimplifier<'a> {
         &mut self,
         expr: Arc<dyn PhysicalExpr>,
     ) -> Result<Arc<dyn PhysicalExpr>> {
+        expr.transform_up_with_schema(self.schema, |node, schema| {
+            // Apply unwrap cast optimization
+            #[cfg(test)]
+            let original_type = node.data_type(schema).unwrap();
+            let unwrapped = unwrap_cast::unwrap_cast_in_comparison(node, schema)?;
+            #[cfg(test)]
+            assert_eq!(
+                unwrapped.data.data_type(schema).unwrap(),
+                original_type,
+                "Simplified expression should have the same data type as the original"
+            );
+            Ok(unwrapped)
+        })
+        .data()
-        Ok(expr.rewrite(self)?.data)
     }
 }
diff --git a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs
index d409ce9cb5bf..1ccfc1cfe84d 100644
--- a/datafusion/physical-expr/src/simplifier/unwrap_cast.rs
+++ b/datafusion/physical-expr/src/simplifier/unwrap_cast.rs
@@ -34,22 +34,22 @@ use std::sync::Arc;

 use arrow::datatypes::{DataType, Schema};
-use datafusion_common::{
-    tree_node::{Transformed, TreeNode},
-    Result, ScalarValue,
-};
+use datafusion_common::{tree_node::Transformed, Result, ScalarValue};
 use datafusion_expr::Operator;
 use datafusion_expr_common::casts::try_cast_literal_to_type;

-use crate::expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr};
 use crate::PhysicalExpr;
+use crate::{
+    expressions::{lit, BinaryExpr, CastExpr, Literal, TryCastExpr},
+    PhysicalExprExt,
+};

 /// Attempts to unwrap casts in comparison expressions.
 pub(crate) fn unwrap_cast_in_comparison(
     expr: Arc<dyn PhysicalExpr>,
     schema: &Schema,
 ) -> Result<Transformed<Arc<dyn PhysicalExpr>>> {
-    expr.transform_down(|e| {
+    expr.transform_down_with_schema(schema, |e, schema| {
         if let Some(binary) = e.as_any().downcast_ref::<BinaryExpr>() {
             if let Some(unwrapped) = try_unwrap_cast_binary(binary, schema)?
{ return Ok(Transformed::yes(unwrapped)); diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 745ae855efee..92ecbb7176dc 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -22,6 +22,7 @@ use std::borrow::Borrow; use std::sync::Arc; use crate::expressions::{BinaryExpr, Column}; +use crate::scalar_function::PhysicalExprExt; use crate::tree_node::ExprContext; use crate::PhysicalExpr; use crate::PhysicalSortExpr; @@ -227,9 +228,11 @@ where /// Recursively extract referenced [`Column`]s within a [`PhysicalExpr`]. pub fn collect_columns(expr: &Arc) -> HashSet { let mut columns = HashSet::::new(); - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - columns.get_or_insert_owned(column); + if !lambdas_params.contains(column.name()) { + columns.get_or_insert_owned(column); + } } Ok(TreeNodeRecursion::Continue) }) @@ -251,14 +254,16 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down(|expr| { + expr.transform_down_with_lambdas_params(|expr, lambdas_params| { if let Some(column) = expr.as_any().downcast_ref::() { - let index = schema.index_of(column.name())?; + if !lambdas_params.contains(column.name()) { + let index = schema.index_of(column.name())?; - return Ok(Transformed::yes(Arc::new(Column::new( - column.name(), - index, - )))); + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + index, + )))); + } } Ok(Transformed::no(expr)) }) diff --git a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs index 6e4e78486612..d87e00194641 100644 --- a/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs +++ b/datafusion/physical-optimizer/src/enforce_sorting/sort_pushdown.rs @@ -29,7 +29,7 @@ use datafusion_expr::JoinType; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::utils::collect_columns; use datafusion_physical_expr::{ - add_offset_to_physical_sort_exprs, EquivalenceProperties, + add_offset_to_physical_sort_exprs, EquivalenceProperties, PhysicalExprExt, }; use datafusion_physical_expr_common::sort_expr::{ LexOrdering, LexRequirement, OrderingRequirements, PhysicalSortExpr, @@ -661,20 +661,21 @@ fn handle_custom_pushdown( .into_iter() .map(|req| { let child_schema = plan.children()[maintained_child_idx].schema(); - let updated_columns = req - .expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let new_index = col.index() - sub_offset; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(new_index).name(), - new_index, - )))) - } else { - Ok(Transformed::no(expr)) - } - })? - .data; + let updated_columns = + req.expr + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + let new_index = col.index() - sub_offset; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(new_index).name(), + new_index, + )))) + } + _ => Ok(Transformed::no(expr)), + } + })? 
+ .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; @@ -742,20 +743,21 @@ fn handle_hash_join( .into_iter() .map(|req| { let child_schema = plan.children()[1].schema(); - let updated_columns = req - .expr - .transform_up(|expr| { - if let Some(col) = expr.as_any().downcast_ref::() { - let index = projected_indices[col.index()].index; - Ok(Transformed::yes(Arc::new(Column::new( - child_schema.field(index).name(), - index, - )))) - } else { - Ok(Transformed::no(expr)) - } - })? - .data; + let updated_columns = + req.expr + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } + _ => Ok(Transformed::no(expr)), + } + })? + .data; Ok(PhysicalSortRequirement::new(updated_columns, req.options)) }) .collect::>>()?; diff --git a/datafusion/physical-optimizer/src/projection_pushdown.rs b/datafusion/physical-optimizer/src/projection_pushdown.rs index 987e3cb6f713..8ed81d3874d6 100644 --- a/datafusion/physical-optimizer/src/projection_pushdown.rs +++ b/datafusion/physical-optimizer/src/projection_pushdown.rs @@ -23,6 +23,7 @@ use crate::PhysicalOptimizerRule; use arrow::datatypes::{Fields, Schema, SchemaRef}; use datafusion_common::alias::AliasGenerator; +use datafusion_physical_expr::PhysicalExprExt; use std::collections::HashSet; use std::sync::Arc; @@ -243,9 +244,11 @@ fn minimize_join_filter( rhs_schema: &Schema, ) -> JoinFilter { let mut used_columns = HashSet::new(); - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Some(col) = expr.as_any().downcast_ref::() { - used_columns.insert(col.index()); + if !lambdas_params.contains(col.name()) { + used_columns.insert(col.index()); + } } Ok(TreeNodeRecursion::Continue) }) @@ -267,17 +270,19 @@ fn minimize_join_filter( .collect::(); let final_expr = expr - .transform_up(|expr| match expr.as_any().downcast_ref::() { - None => Ok(Transformed::no(expr)), - Some(column) => { - let new_idx = used_columns - .iter() - .filter(|idx| **idx < column.index()) - .count(); - let new_column = Column::new(column.name(), new_idx); - Ok(Transformed::yes( - Arc::new(new_column) as Arc - )) + .transform_up_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(column) if !lambdas_params.contains(column.name()) => { + let new_idx = used_columns + .iter() + .filter(|idx| **idx < column.index()) + .count(); + let new_column = Column::new(column.name(), new_idx); + Ok(Transformed::yes( + Arc::new(new_column) as Arc + )) + } + _ => Ok(Transformed::no(expr)), } }) .expect("Closure cannot fail"); @@ -380,10 +385,9 @@ impl<'a> JoinFilterRewriter<'a> { // First, add a new projection. The expression must be rewritten, as it is no longer // executed against the filter schema. 
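+        // Columns bound to lambda parameters are deliberately skipped by the
+        // rewrite below: they resolve against the lambda's scoped schema, not
+        // the filter schema being rewritten here.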
let new_idx = self.join_side_projections.len(); - let rewritten_expr = expr.transform_up(|expr| { + let rewritten_expr = expr.transform_up_with_lambdas_params(|expr, lambdas_params| { Ok(match expr.as_any().downcast_ref::() { - None => Transformed::no(expr), - Some(column) => { + Some(column) if !lambdas_params.contains(column.name()) => { let intermediate_column = &self.intermediate_column_indices[column.index()]; assert_eq!(intermediate_column.side, self.join_side); @@ -393,6 +397,7 @@ impl<'a> JoinFilterRewriter<'a> { let new_column = Column::new(field.name(), join_side_index); Transformed::yes(Arc::new(new_column) as Arc) } + _ => Transformed::no(expr), }) })?; self.join_side_projections.push((rewritten_expr.data, name)); @@ -415,15 +420,17 @@ impl<'a> JoinFilterRewriter<'a> { join_side: JoinSide, ) -> Result { let mut result = false; - expr.apply(|expr| match expr.as_any().downcast_ref::() { - None => Ok(TreeNodeRecursion::Continue), - Some(c) => { - let column_index = &self.intermediate_column_indices[c.index()]; - if column_index.side == join_side { - result = true; - return Ok(TreeNodeRecursion::Stop); + expr.apply_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(c) if !lambdas_params.contains(c.name()) => { + let column_index = &self.intermediate_column_indices[c.index()]; + if column_index.side == join_side { + result = true; + return Ok(TreeNodeRecursion::Stop); + } + Ok(TreeNodeRecursion::Continue) } - Ok(TreeNodeRecursion::Continue) + _ => Ok(TreeNodeRecursion::Continue), } })?; diff --git a/datafusion/physical-plan/src/async_func.rs b/datafusion/physical-plan/src/async_func.rs index 54a76e0ebb97..be72a6af2b50 100644 --- a/datafusion/physical-plan/src/async_func.rs +++ b/datafusion/physical-plan/src/async_func.rs @@ -22,13 +22,13 @@ use crate::{ }; use arrow::array::RecordBatch; use arrow_schema::{Fields, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{Transformed, TreeNodeRecursion}; use datafusion_common::{internal_err, Result}; use datafusion_execution::{SendableRecordBatchStream, TaskContext}; use datafusion_physical_expr::async_scalar_function::AsyncFuncExpr; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::expressions::Column; -use datafusion_physical_expr::ScalarFunctionExpr; +use datafusion_physical_expr::{PhysicalExprExt, ScalarFunctionExpr}; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; use futures::stream::StreamExt; use log::trace; @@ -249,7 +249,7 @@ impl AsyncMapper { schema: &Schema, ) -> Result<()> { // recursively look for references to async functions - physical_expr.apply(|expr| { + physical_expr.apply_with_schema(schema, |expr, schema| { if let Some(scalar_func_expr) = expr.as_any().downcast_ref::() { diff --git a/datafusion/physical-plan/src/joins/stream_join_utils.rs b/datafusion/physical-plan/src/joins/stream_join_utils.rs index 80221a77992c..b70a8f60508a 100644 --- a/datafusion/physical-plan/src/joins/stream_join_utils.rs +++ b/datafusion/physical-plan/src/joins/stream_join_utils.rs @@ -35,7 +35,7 @@ use arrow::array::{ }; use arrow::compute::concat_batches; use arrow::datatypes::{ArrowNativeType, Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{Transformed, TransformedResult}; use datafusion_common::utils::memory::estimate_memory_size; use datafusion_common::{ 
arrow_datafusion_err, DataFusionError, HashSet, JoinSide, Result, ScalarValue, @@ -44,7 +44,7 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr::intervals::cp_solver::ExprIntervalGraph; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr}; +use datafusion_physical_expr::{PhysicalExpr, PhysicalExprExt, PhysicalSortExpr}; use datafusion_physical_expr_common::sort_expr::LexOrdering; use hashbrown::HashTable; @@ -312,13 +312,13 @@ pub fn convert_sort_expr_with_filter_schema( // Since we are sure that one to one column mapping includes all columns, we convert // the sort expression into a filter expression. let converted_filter_expr = expr - .transform_up(|p| { - convert_filter_columns(p.as_ref(), &column_map).map(|transformed| { - match transformed { + .transform_up_with_lambdas_params(|p, lambdas_params| { + convert_filter_columns(p.as_ref(), &column_map, lambdas_params).map( + |transformed| match transformed { Some(transformed) => Transformed::yes(transformed), None => Transformed::no(p), - } - }) + }, + ) }) .data()?; // Search the converted `PhysicalExpr` in filter expression; if an exact @@ -361,14 +361,17 @@ pub fn build_filter_input_order( fn convert_filter_columns( input: &dyn PhysicalExpr, column_map: &HashMap, + lambdas_params: &HashSet, ) -> Result>> { // Attempt to downcast the input expression to a Column type. - Ok(if let Some(col) = input.as_any().downcast_ref::() { - // If the downcast is successful, retrieve the corresponding filter column. - column_map.get(col).map(|c| Arc::new(c.clone()) as _) - } else { - // If the downcast fails, return the input expression as is. - None + Ok(match input.as_any().downcast_ref::() { + Some(col) if !lambdas_params.contains(col.name()) => { + column_map.get(col).map(|c| Arc::new(c.clone()) as _) + } + _ => { + // If the downcast fails, return the input expression as is. 
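+                // The same applies when the column names a lambda parameter:
+                // it is scoped to the lambda and must not be mapped to a
+                // filter column.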
+ None + } }) } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index ead2196860cd..ab654e4eee1d 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -42,14 +42,13 @@ use std::task::{Context, Poll}; use arrow::datatypes::SchemaRef; use arrow::record_batch::{RecordBatch, RecordBatchOptions}; use datafusion_common::config::ConfigOptions; -use datafusion_common::tree_node::{ - Transformed, TransformedResult, TreeNode, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNodeRecursion}; use datafusion_common::{internal_err, JoinSide, Result}; use datafusion_execution::TaskContext; use datafusion_physical_expr::equivalence::ProjectionMapping; use datafusion_physical_expr::utils::collect_columns; -use datafusion_physical_expr_common::physical_expr::{fmt_sql, PhysicalExprRef}; +use datafusion_physical_expr::{PhysicalExprExt, PhysicalExprRef}; +use datafusion_physical_expr_common::physical_expr::fmt_sql; use datafusion_physical_expr_common::sort_expr::{LexOrdering, LexRequirement}; // Re-exported from datafusion-physical-expr for backwards compatibility // We recommend updating your imports to use datafusion-physical-expr directly @@ -866,10 +865,12 @@ fn try_unifying_projections( projection.expr().iter().for_each(|proj_expr| { proj_expr .expr - .apply(|expr| { + .apply_with_lambdas_params(|expr, lambdas_params| { Ok({ if let Some(column) = expr.as_any().downcast_ref::() { - *column_ref_map.entry(column.clone()).or_default() += 1; + if !lambdas_params.contains(column.name()) { + *column_ref_map.entry(column.clone()).or_default() += 1; + } } TreeNodeRecursion::Continue }) @@ -957,31 +958,31 @@ fn new_columns_for_join_on( .filter_map(|on| { // Rewrite all columns in `on` Arc::clone(*on) - .transform(|expr| { - if let Some(column) = expr.as_any().downcast_ref::() { - // Find the column in the projection expressions - let new_column = projection_exprs - .iter() - .enumerate() - .find(|(_, (proj_column, _))| { - column.name() == proj_column.name() - && column.index() + column_index_offset - == proj_column.index() - }) - .map(|(index, (_, alias))| Column::new(alias, index)); - if let Some(new_column) = new_column { - Ok(Transformed::yes(Arc::new(new_column))) - } else { - // If the column is not found in the projection expressions, - // it means that the column is not projected. In this case, - // we cannot push the projection down. - internal_err!( - "Column {:?} not found in projection expressions", - column - ) + .transform_with_lambdas_params(|expr, lambdas_params| { + match expr.as_any().downcast_ref::() { + Some(column) if !lambdas_params.contains(column.name()) => { + let new_column = projection_exprs + .iter() + .enumerate() + .find(|(_, (proj_column, _))| { + column.name() == proj_column.name() + && column.index() + column_index_offset + == proj_column.index() + }) + .map(|(index, (_, alias))| Column::new(alias, index)); + if let Some(new_column) = new_column { + Ok(Transformed::yes(Arc::new(new_column))) + } else { + // If the column is not found in the projection expressions, + // it means that the column is not projected. In this case, + // we cannot push the projection down. 
+                                internal_err!(
+                                    "Column {:?} not found in projection expressions",
+                                    column
+                                )
+                            }
+                        }
-                    } else {
-                        Ok(Transformed::no(expr))
+                        _ => Ok(Transformed::no(expr)),
                     }
                 })
                 .data()
diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs
index 2774b5b6ba7c..b87a50b3f528 100644
--- a/datafusion/proto/src/logical_plan/to_proto.rs
+++ b/datafusion/proto/src/logical_plan/to_proto.rs
@@ -622,6 +622,11 @@ pub fn serialize_expr(
                 .unwrap_or(HashMap::new()),
             })),
         },
+        Expr::Lambda { .. } => {
+            return Err(Error::General(
+                "Proto serialization error: Lambda not supported".to_string(),
+            ))
+        }
     };

     Ok(expr_node)
diff --git a/datafusion/pruning/src/pruning_predicate.rs b/datafusion/pruning/src/pruning_predicate.rs
index 380ada10df6e..c9df93f8b693 100644
--- a/datafusion/pruning/src/pruning_predicate.rs
+++ b/datafusion/pruning/src/pruning_predicate.rs
@@ -38,13 +38,14 @@ use datafusion_common::error::Result;
 use datafusion_common::tree_node::TransformedResult;
 use datafusion_common::{
     internal_datafusion_err, internal_err, plan_datafusion_err, plan_err,
-    tree_node::{Transformed, TreeNode},
-    ScalarValue,
+    tree_node::Transformed, ScalarValue,
 };
 use datafusion_common::{Column, DFSchema};
 use datafusion_expr_common::operator::Operator;
 use datafusion_physical_expr::utils::{collect_columns, Guarantee, LiteralGuarantee};
-use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef};
+use datafusion_physical_expr::{
+    expressions as phys_expr, PhysicalExprExt, PhysicalExprRef,
+};
 use datafusion_physical_expr_common::physical_expr::snapshot_physical_expr;
 use datafusion_physical_plan::{ColumnarValue, PhysicalExpr};

@@ -1204,9 +1205,9 @@ fn rewrite_column_expr(
     column_old: &phys_expr::Column,
     column_new: &phys_expr::Column,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    e.transform(|expr| {
+    e.transform_with_lambdas_params(|expr, lambdas_params| {
         if let Some(column) = expr.as_any().downcast_ref::<phys_expr::Column>() {
-            if column == column_old {
+            if !lambdas_params.contains(column.name()) && column == column_old {
                 return Ok(Transformed::yes(Arc::new(column_new.clone())));
             }
         }
diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs
index 50e479af3620..c13fd33104eb 100644
--- a/datafusion/sql/src/expr/function.rs
+++ b/datafusion/sql/src/expr/function.rs
@@ -22,10 +22,10 @@ use datafusion_common::{
     internal_datafusion_err, internal_err, not_impl_err, plan_datafusion_err, plan_err,
     DFSchema, Dependency, Diagnostic, Result, Span,
 };
-use datafusion_expr::expr::{
-    NullTreatment, ScalarFunction, Unnest, WildcardOptions, WindowFunction,
-};
-use datafusion_expr::planner::{PlannerResult, RawAggregateExpr, RawWindowExpr};
+use datafusion_expr::expr::{Lambda, ScalarFunction, Unnest};
+use datafusion_expr::expr::{NullTreatment, WildcardOptions, WindowFunction};
+use datafusion_expr::planner::PlannerResult;
+use datafusion_expr::planner::{RawAggregateExpr, RawWindowExpr};
 use datafusion_expr::{expr, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition};
 use sqlparser::ast::{
     DuplicateTreatment, Expr as SQLExpr, Function as SQLFunction, FunctionArg,
@@ -724,6 +724,26 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
                 let arg_name = crate::utils::normalize_ident(name);
                 Ok((expr, Some(arg_name)))
             }
+            FunctionArg::Unnamed(FunctionArgExpr::Expr(SQLExpr::Lambda(
+                sqlparser::ast::LambdaFunction { params, body },
+            ))) => {
+                let params = params
+                    .into_iter()
+                    .map(|v| v.to_string())
+                    .collect::<Vec<_>>();
+
+                Ok((
+                    Expr::Lambda(Lambda {
+                        params: params.clone(),
+                        body: Box::new(self.sql_expr_to_logical_expr(
+                            *body,
+                            schema,
+                            &mut planner_context.clone().with_lambda_parameters(params),
+                        )?),
+                    }),
+                    None,
+                ))
+            }
             FunctionArg::Unnamed(FunctionArgExpr::Expr(arg)) => {
                 let expr = self.sql_expr_to_logical_expr(arg, schema, planner_context)?;
                 Ok((expr, None))
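During SQL planning the two pieces around this hunk cooperate: `function.rs` plans a lambda body under a child `PlannerContext` that records the parameter names, and `identifier.rs` (next) short-circuits those names into unqualified columns before any schema lookup. A minimal sketch of that context pattern, using a toy struct (the real methods are the `with_lambda_parameters` and `lambdas_parameters` additions shown below):

use std::collections::HashSet;

// Toy stand-in for the PlannerContext additions; names mirror the patch,
// everything else here is illustrative.
#[derive(Default, Clone)]
struct Ctx {
    lambda_params: HashSet<String>,
}

impl Ctx {
    fn with_lambda_parameters(mut self, params: impl IntoIterator<Item = String>) -> Self {
        self.lambda_params.extend(params);
        self
    }

    fn is_lambda_param(&self, name: &str) -> bool {
        self.lambda_params.contains(name)
    }
}

fn main() {
    let outer = Ctx::default();
    // Planning the body of `(e, i) -> e + i` uses a cloned child context that
    // knows `e` and `i` are lambda-bound, so they resolve to unqualified
    // columns instead of failing a schema lookup.
    let inner = outer
        .clone()
        .with_lambda_parameters(["e".to_string(), "i".to_string()]);
    assert!(inner.is_lambda_param("e"));
    assert!(!outer.is_lambda_param("e"));
}

Cloning the context (rather than mutating it in place) keeps the bindings scoped to the lambda body: the outer query continues planning with the original, unextended parameter set.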
diff --git a/datafusion/sql/src/expr/identifier.rs b/datafusion/sql/src/expr/identifier.rs
index 3c57d195ade6..dc39cb4de055 100644
--- a/datafusion/sql/src/expr/identifier.rs
+++ b/datafusion/sql/src/expr/identifier.rs
@@ -53,6 +53,19 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
         // identifier. (e.g. it is "foo.bar" not foo.bar)
         let normalize_ident = self.ident_normalizer.normalize(id);

+        if planner_context
+            .lambdas_parameters()
+            .contains(&normalize_ident)
+        {
+            let mut column = Column::new_unqualified(normalize_ident);
+            if self.options.collect_spans {
+                if let Some(span) = Span::try_from_sqlparser_span(id_span) {
+                    column.spans_mut().add_span(span);
+                }
+            }
+            return Ok(Expr::Column(column));
+        }
+
         // Check for qualified field with unqualified name
         if let Ok((qualifier, _)) =
             schema.qualified_field_with_unqualified_name(normalize_ident.as_str())
diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs
index 7bac0337672d..2992378fd1d6 100644
--- a/datafusion/sql/src/planner.rs
+++ b/datafusion/sql/src/planner.rs
@@ -28,13 +28,11 @@ use datafusion_common::datatype::{DataTypeExt, FieldExt};
 use datafusion_common::error::add_possible_columns_to_diag;
 use datafusion_common::TableReference;
 use datafusion_common::{
-    field_not_found, internal_err, plan_datafusion_err, DFSchemaRef, Diagnostic,
-    SchemaError,
+    field_not_found, plan_datafusion_err, DFSchemaRef, Diagnostic, HashSet, SchemaError,
 };
 use datafusion_common::{not_impl_err, plan_err, DFSchema, DataFusionError, Result};
 use datafusion_expr::logical_plan::{LogicalPlan, LogicalPlanBuilder};
 pub use datafusion_expr::planner::ContextProvider;
-use datafusion_expr::utils::find_column_exprs;
 use datafusion_expr::{col, Expr};
 use sqlparser::ast::{ArrayElemTypeDef, ExactNumberInfo, TimezoneInfo};
 use sqlparser::ast::{ColumnDef as SQLColumnDef, ColumnOption};
@@ -267,6 +265,8 @@ pub struct PlannerContext {
     outer_from_schema: Option<DFSchemaRef>,
     /// The query schema defined by the table
     create_table_schema: Option<DFSchemaRef>,
+    /// The column names introduced by lambda parameters
+    lambdas_parameters: HashSet<String>,
 }

 impl Default for PlannerContext {
@@ -284,6 +284,7 @@ impl PlannerContext {
             outer_query_schema: None,
             outer_from_schema: None,
             create_table_schema: None,
+            lambdas_parameters: HashSet::new(),
         }
     }

@@ -370,6 +371,19 @@ impl PlannerContext {
         self.ctes.get(cte_name).map(|cte| cte.as_ref())
     }

+    pub fn lambdas_parameters(&self) -> &HashSet<String> {
+        &self.lambdas_parameters
+    }
+
+    pub fn with_lambda_parameters(
+        mut self,
+        arguments: impl IntoIterator<Item = String>,
+    ) -> Self {
+        self.lambdas_parameters.extend(arguments);
+
+        self
+    }
+
     /// Remove the plan of CTE / Subquery for the specified name
     pub(super) fn remove_cte(&mut self, cte_name: &str) {
         self.ctes.remove(cte_name);
@@ -531,10 +545,11 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
         schema: &DFSchema,
         exprs: &[Expr],
     ) -> Result<()> {
-        find_column_exprs(exprs)
+        exprs
             .iter()
-            .try_for_each(|col| match col {
-                Expr::Column(col) => match &col.relation {
+            .flat_map(|expr| expr.column_refs())
+            .try_for_each(|col| {
+                match &col.relation {
                 Some(r) => schema.field_with_qualified_name(r, &col.name).map(|_| ()),
                 None => {
                     if !schema.fields_with_unqualified_name(&col.name).is_empty() {
@@ -584,8 +599,7 @@ impl<'a, S: ContextProvider>
SqlToRel<'a, S> { err.with_diagnostic(diagnostic) } _ => err, - }), - _ => internal_err!("Not a column"), + }) }) } diff --git a/datafusion/sql/src/select.rs b/datafusion/sql/src/select.rs index 42013a76a865..0e7490d2c780 100644 --- a/datafusion/sql/src/select.rs +++ b/datafusion/sql/src/select.rs @@ -540,9 +540,11 @@ impl SqlToRel<'_, S> { None => { let mut columns = HashSet::new(); for expr in &aggr_expr { - expr.apply(|expr| { + expr.apply_with_lambdas_params(|expr, lambdas_params| { if let Expr::Column(c) = expr { - columns.insert(Expr::Column(c.clone())); + if !c.is_lambda_parameter(lambdas_params) { + columns.insert(Expr::Column(c.clone())); + } } Ok(TreeNodeRecursion::Continue) }) diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 97f2b58bf840..67ca92bb1c1f 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -15,13 +15,14 @@ // specific language governing permissions and limitations // under the License. -use datafusion_expr::expr::{AggregateFunctionParams, Unnest, WindowFunctionParams}; +use datafusion_expr::expr::{AggregateFunctionParams, WindowFunctionParams}; +use datafusion_expr::expr::{Lambda, Unnest}; use sqlparser::ast::Value::SingleQuotedString; use sqlparser::ast::{ - self, Array, BinaryOperator, CaseWhen, DuplicateTreatment, Expr as AstExpr, Function, - Ident, Interval, ObjectName, OrderByOptions, Subscript, TimezoneInfo, UnaryOperator, - ValueWithSpan, + self, Array, BinaryOperator, Expr as AstExpr, Function, Ident, Interval, + LambdaFunction, ObjectName, Subscript, TimezoneInfo, UnaryOperator, }; +use sqlparser::ast::{CaseWhen, DuplicateTreatment, OrderByOptions, ValueWithSpan}; use std::sync::Arc; use std::vec; @@ -527,6 +528,14 @@ impl Unparser<'_> { } Expr::OuterReferenceColumn(_, col) => self.col_to_sql(col), Expr::Unnest(unnest) => self.unnest_to_sql(unnest), + Expr::Lambda(Lambda { params, body }) => { + Ok(ast::Expr::Lambda(LambdaFunction { + params: ast::OneOrManyWithParens::Many( + params.iter().map(|param| param.as_str().into()).collect(), + ), + body: Box::new(self.expr_to_sql_inner(body)?), + })) + } } } diff --git a/datafusion/sql/src/unparser/plan.rs b/datafusion/sql/src/unparser/plan.rs index e7535338b767..c218ce547b31 100644 --- a/datafusion/sql/src/unparser/plan.rs +++ b/datafusion/sql/src/unparser/plan.rs @@ -40,7 +40,7 @@ use crate::unparser::{ast::UnnestRelationBuilder, rewrite::rewrite_qualify}; use crate::utils::UNNEST_PLACEHOLDER; use datafusion_common::{ internal_err, not_impl_err, - tree_node::{TransformedResult, TreeNode}, + tree_node::TransformedResult, Column, DataFusionError, Result, ScalarValue, TableReference, }; use datafusion_expr::expr::OUTER_REFERENCE_COLUMN_PREFIX; @@ -1131,7 +1131,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = filter_alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } @@ -1197,7 +1197,7 @@ impl Unparser<'_> { .cloned() .map(|expr| { if let Some(ref mut rewriter) = alias_rewriter { - expr.rewrite(rewriter).data() + expr.rewrite_with_lambdas_params(rewriter).data() } else { Ok(expr) } diff --git a/datafusion/sql/src/unparser/rewrite.rs b/datafusion/sql/src/unparser/rewrite.rs index c961f1d6f1f0..58f443509551 100644 --- a/datafusion/sql/src/unparser/rewrite.rs +++ b/datafusion/sql/src/unparser/rewrite.rs @@ -20,10 +20,11 @@ use std::{collections::HashSet, sync::Arc}; use arrow::datatypes::Schema; use 
datafusion_common::tree_node::TreeNodeContainer;
 use datafusion_common::{
-    tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
+    tree_node::{Transformed, TransformedResult, TreeNode},
     Column, HashMap, Result, TableReference,
 };
 use datafusion_expr::expr::{Alias, UNNEST_COLUMN_PREFIX};
+use datafusion_expr::tree_node::TreeNodeRewriterWithPayload;
 use datafusion_expr::{Expr, LogicalPlan, Projection, Sort, SortExpr};
 use sqlparser::ast::Ident;

@@ -466,12 +467,17 @@ pub struct TableAliasRewriter<'a> {
     pub alias_name: TableReference,
 }

-impl TreeNodeRewriter for TableAliasRewriter<'_> {
+impl TreeNodeRewriterWithPayload for TableAliasRewriter<'_> {
     type Node = Expr;
+    type Payload<'a> = &'a datafusion_common::HashSet<String>;

-    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
+    fn f_down(
+        &mut self,
+        expr: Expr,
+        lambdas_params: &datafusion_common::HashSet<String>,
+    ) -> Result<Transformed<Expr>> {
         match expr {
-            Expr::Column(column) => {
+            Expr::Column(column) if !column.is_lambda_parameter(lambdas_params) => {
                 if let Ok(field) = self.table_schema.field_with_name(&column.name) {
                     let new_column =
                         Column::new(Some(self.alias_name.clone()), field.name().clone());
diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs
index 8b3791017a8a..f785f640dbce 100644
--- a/datafusion/sql/src/unparser/utils.rs
+++ b/datafusion/sql/src/unparser/utils.rs
@@ -161,11 +161,11 @@ pub(crate) fn find_window_nodes_within_select<'a>(
 /// For example, if expr contains the column expr "__unnest_placeholder(make_array(Int64(1),Int64(2),Int64(2),Int64(5),NULL),depth=1)"
 /// it will be transformed into an actual unnest expression UNNEST([1, 2, 2, 5, NULL])
 pub(crate) fn unproject_unnest_expr(expr: Expr, unnest: &Unnest) -> Result<Expr> {
-    expr.transform(|sub_expr| {
+    expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| {
         if let Expr::Column(col_ref) = &sub_expr {
             // Check if the column is among the columns to run unnest on.
             // Currently, only List/Array columns (defined in `list_type_columns`) are supported for unnesting.
-            if unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) {
+            if !col_ref.is_lambda_parameter(lambdas_params) && unnest.list_type_columns.iter().any(|e| e.1.output_column.name == col_ref.name) {
                 if let Ok(idx) = unnest.schema.index_of_column(col_ref) {
                     if let LogicalPlan::Projection(Projection { expr, .. }) = unnest.input.as_ref() {
                         if let Some(unprojected_expr) = expr.get(idx) {
@@ -195,22 +195,21 @@ pub(crate) fn unproject_agg_exprs(
     agg: &Aggregate,
     windows: Option<&[&Window]>,
 ) -> Result<Expr> {
-    expr.transform(|sub_expr| {
-        if let Expr::Column(c) = sub_expr {
-            if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
-                Ok(Transformed::yes(unprojected_expr.clone()))
-            } else if let Some(unprojected_expr) =
-                windows.and_then(|w| find_window_expr(w, &c.name).cloned())
-            {
-                // Window function can contain an aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...' that needs to be unprojected
-                Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?))
-            } else {
-                internal_err!(
-                    "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
-                )
-            }
-        } else {
-            Ok(Transformed::no(sub_expr))
+    expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| {
+        match sub_expr {
+            Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => if let Some(unprojected_expr) = find_agg_expr(agg, &c)? {
+                Ok(Transformed::yes(unprojected_expr.clone()))
+            } else if let Some(unprojected_expr) =
+                windows.and_then(|w| find_window_expr(w, &c.name).cloned())
+            {
+                // Window functions can contain aggregation columns, e.g., 'avg(sum(ss_sales_price)) over ...', that need to be unprojected
+                Ok(Transformed::yes(unproject_agg_exprs(unprojected_expr, agg, None)?))
+            } else {
+                internal_err!(
+                    "Tried to unproject agg expr for column '{}' that was not found in the provided Aggregate!", &c.name
+                )
+            },
+            _ => Ok(Transformed::no(sub_expr)),
         }
     })
     .map(|e| e.data)
@@ -222,16 +221,15 @@ pub(crate) fn unproject_agg_exprs(
 /// For example, if expr contains the column expr "COUNT(*) PARTITION BY id" it will be transformed
 /// into an actual window expression as identified in the window node.
 pub(crate) fn unproject_window_exprs(expr: Expr, windows: &[&Window]) -> Result<Expr> {
-    expr.transform(|sub_expr| {
-        if let Expr::Column(c) = sub_expr {
+    expr.transform_up_with_lambdas_params(|sub_expr, lambdas_params| match sub_expr {
+        Expr::Column(c) if !c.is_lambda_parameter(lambdas_params) => {
             if let Some(unproj) = find_window_expr(windows, &c.name) {
                 Ok(Transformed::yes(unproj.clone()))
             } else {
                 Ok(Transformed::no(Expr::Column(c)))
             }
-        } else {
-            Ok(Transformed::no(sub_expr))
         }
+        _ => Ok(Transformed::no(sub_expr)),
     })
     .map(|e| e.data)
 }
@@ -376,7 +374,7 @@ pub(crate) fn try_transform_to_simple_table_scan_with_filters(
         .cloned()
         .map(|expr| {
             if let Some(ref mut rewriter) = filter_alias_rewriter {
-                expr.rewrite(rewriter).data()
+                expr.rewrite_with_lambdas_params(rewriter).data()
             } else {
                 Ok(expr)
             }
diff --git a/datafusion/sql/src/utils.rs b/datafusion/sql/src/utils.rs
index 3c86d2d04905..6380412e3b5e 100644
--- a/datafusion/sql/src/utils.rs
+++ b/datafusion/sql/src/utils.rs
@@ -23,16 +23,16 @@ use arrow::datatypes::{
     DataType, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, DECIMAL_DEFAULT_SCALE,
 };
 use datafusion_common::tree_node::{
-    Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeRewriter,
+    Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
 };
 use datafusion_common::{
-    exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchemaRef,
-    Diagnostic, HashMap, Result, ScalarValue,
+    exec_datafusion_err, exec_err, internal_err, plan_err, Column, DFSchema, Diagnostic, HashMap, Result, ScalarValue
 };
 use datafusion_expr::builder::get_struct_unnested_columns;
 use datafusion_expr::expr::{
     Alias, GroupingSet, Unnest, WindowFunction, WindowFunctionParams,
 };
+use datafusion_expr::tree_node::TreeNodeRewriterWithPayload;
 use datafusion_expr::utils::{expr_as_column_expr, find_column_exprs};
 use datafusion_expr::{
     col, expr_vec_fmt, ColumnUnnestList, Expr, ExprSchemable, LogicalPlan,
@@ -44,9 +44,9 @@ use sqlparser::ast::{Ident, Value};
 /// Make a best-effort attempt at resolving all columns in the expression tree
 pub(crate) fn resolve_columns(expr: &Expr, plan: &LogicalPlan) -> Result<Expr> {
     expr.clone()
-        .transform_up(|nested_expr| {
+        .transform_up_with_lambdas_params(|nested_expr, lambdas_params| {
             match nested_expr {
-                Expr::Column(col) => {
+                Expr::Column(col) if !col.is_lambda_parameter(lambdas_params) => {
                     let (qualifier, field) =
                         plan.schema().qualified_field_from_column(&col)?;
                     Ok(Transformed::yes(Expr::Column(Column::from((
@@ -81,6 +81,7 @@ pub(crate) fn rebase_expr(
     base_exprs: &[Expr],
     plan: &LogicalPlan,
 ) -> Result<Expr> {
+    // TODO: use transform_down_with_lambdas_params
     expr.clone()
         .transform_down(|nested_expr| {
             if base_exprs.contains(&nested_expr) {
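Several of the rewrites above and below guard on `c.is_lambda_parameter(lambdas_params)`. That helper is added to `Column` in `datafusion/common/src/column.rs` (per the diffstat); its body is not shown in this part of the patch, so the following is a plausible sketch only, with the exact signature treated as an assumption:

use std::collections::HashSet;

// Illustrative stand-in for datafusion_common::Column; the real helper
// lives in datafusion/common/src/column.rs and its body is assumed here.
struct Column {
    relation: Option<String>,
    name: String,
}

impl Column {
    // A column can only refer to a lambda parameter when it is unqualified
    // and its name is bound by an enclosing lambda.
    fn is_lambda_parameter(&self, lambdas_params: &HashSet<String>) -> bool {
        self.relation.is_none() && lambdas_params.contains(&self.name)
    }
}

fn main() {
    let params: HashSet<String> = ["v".to_string()].into_iter().collect();
    let bare = Column { relation: None, name: "v".into() };
    let qualified = Column { relation: Some("t".into()), name: "v".into() };
    assert!(bare.is_lambda_parameter(&params));
    // `t.v` names a real table column even when a lambda binds `v`.
    assert!(!qualified.is_lambda_parameter(&params));
}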
@@ -231,8 +232,8 @@ pub(crate) fn resolve_aliases_to_exprs(
     expr: Expr,
     aliases: &HashMap<String, Expr>,
 ) -> Result<Expr> {
-    expr.transform_up(|nested_expr| match nested_expr {
-        Expr::Column(c) if c.relation.is_none() => {
+    expr.transform_up_with_lambdas_params(|nested_expr, lambdas_params| match nested_expr {
+        Expr::Column(c) if c.relation.is_none() && !c.is_lambda_parameter(lambdas_params) => {
             if let Some(aliased_expr) = aliases.get(&c.name) {
                 Ok(Transformed::yes(aliased_expr.clone()))
             } else {
@@ -371,7 +372,6 @@ This is only usedful when used with transform down up
 A full example of how the transformation works:
 */
 struct RecursiveUnnestRewriter<'a> {
-    input_schema: &'a DFSchemaRef,
     root_expr: &'a Expr,
     // Useful to detect which child expr is a part of/ not a part of unnest operation
     top_most_unnest: Option<Unnest>,
@@ -405,6 +405,7 @@ impl RecursiveUnnestRewriter<'_> {
         alias_name: String,
         expr_in_unnest: &Expr,
         struct_allowed: bool,
+        input_schema: &DFSchema,
     ) -> Result<Vec<Expr>> {
         let inner_expr_name = expr_in_unnest.schema_name().to_string();

@@ -418,7 +419,7 @@ impl RecursiveUnnestRewriter<'_> {
         // column name as is, to comply with group by and order by
         let placeholder_column = Column::from_name(placeholder_name.clone());

-        let (data_type, _) = expr_in_unnest.data_type_and_nullable(self.input_schema)?;
+        let (data_type, _) = expr_in_unnest.data_type_and_nullable(input_schema)?;

         match data_type {
             DataType::Struct(inner_fields) => {
@@ -468,17 +469,18 @@ impl RecursiveUnnestRewriter<'_> {
     }
 }

-impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
+impl TreeNodeRewriterWithPayload for RecursiveUnnestRewriter<'_> {
     type Node = Expr;
+    type Payload<'a> = &'a DFSchema;

     /// This downward traversal needs to keep track of:
     /// - Whether or not some unnest expr has been visited from the top util the current node
     /// - If some unnest expr has been visited, maintain a stack of such information, this
     ///   is used to detect if some recursive unnest expr exists (e.g **unnest(unnest(unnest(3d column))))**
-    fn f_down(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
+    fn f_down(&mut self, expr: Expr, input_schema: &DFSchema) -> Result<Transformed<Expr>> {
         if let Expr::Unnest(ref unnest_expr) = expr {
             let (data_type, _) =
-                unnest_expr.expr.data_type_and_nullable(self.input_schema)?;
+                unnest_expr.expr.data_type_and_nullable(input_schema)?;
             self.consecutive_unnest.push(Some(unnest_expr.clone()));
             // if expr inside unnest is a struct, do not consider
             // the next unnest as consecutive unnest (if any)
@@ -532,7 +534,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
     ///             column2
     /// ```
     ///
-    fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
+    fn f_up(&mut self, expr: Expr, input_schema: &DFSchema) -> Result<Transformed<Expr>> {
         if let Expr::Unnest(ref traversing_unnest) = expr {
             if traversing_unnest == self.top_most_unnest.as_ref().unwrap() {
                 self.top_most_unnest = None;
@@ -568,6 +570,7 @@ impl TreeNodeRewriter for RecursiveUnnestRewriter<'_> {
                     expr.schema_name().to_string(),
                     inner_expr,
                     struct_allowed,
+                    input_schema,
                 )?;
                 if struct_allowed {
                     self.transformed_root_exprs = Some(transformed_exprs.clone());
@@ -619,7 +622,6 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up(
     original_expr: &Expr,
 ) -> Result<Vec<Expr>> {
     let mut rewriter = RecursiveUnnestRewriter {
-        input_schema: input.schema(),
         root_expr: original_expr,
         top_most_unnest: None,
         consecutive_unnest: vec![],
@@ -641,7 +643,7 @@ pub(crate) fn rewrite_recursive_unnest_bottom_up(
         data: transformed_expr,
         transformed,
         tnr: _,
-    } = original_expr.clone().rewrite(&mut rewriter)?;
+    } = original_expr.clone().rewrite_with_schema(input.schema(), &mut rewriter)?;

     if !transformed {
         // TODO: remove the next line after `Expr::Wildcard` is removed
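Both `TableAliasRewriter` earlier and `RecursiveUnnestRewriter` here move from `TreeNodeRewriter` to `TreeNodeRewriterWithPayload`, whose generic-associated `Payload<'a>` lets the traversal hand borrowed, per-call state (a lambda-parameter set in one case, a schema in the other) to `f_down`/`f_up` instead of storing it in the rewriter. A compilable toy reduction of that trait shape, with names simplified and everything below treated as illustration rather than the real API:

// Toy reduction of the TreeNodeRewriterWithPayload idea: the payload type
// is a generic associated type, so each call can borrow short-lived state.
trait RewriterWithPayload {
    type Node;
    type Payload<'a>;

    fn f_down<'a>(&mut self, node: Self::Node, payload: Self::Payload<'a>) -> Self::Node;
    fn f_up<'a>(&mut self, node: Self::Node, payload: Self::Payload<'a>) -> Self::Node;
}

struct Suffixer;

impl RewriterWithPayload for Suffixer {
    type Node = String;
    // Borrow a &str per call; a real rewriter might borrow a &DFSchema or
    // a &HashSet<String> here, as the patch does.
    type Payload<'a> = &'a str;

    fn f_down<'a>(&mut self, node: String, suffix: &'a str) -> String {
        format!("{node}{suffix}")
    }

    fn f_up<'a>(&mut self, node: String, _suffix: &'a str) -> String {
        node.to_uppercase()
    }
}

fn main() {
    let mut r = Suffixer;
    let down = r.f_down("col".to_string(), "_1");
    assert_eq!(r.f_up(down, "_1"), "COL_1");
}

Passing the payload per call is what allows `rewrite_recursive_unnest_bottom_up` above to drop the `input_schema` field and instead supply the schema at the `rewrite_with_schema` call site.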
diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt
index 00629c392df4..29ea8cb78607 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -5866,10 +5866,10 @@ select array_ndims(arrow_cast([null], 'List(List(List(Int64)))'));
 3

 # array_ndims scalar function #2
-query II
-select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]);
-----
-3 21
+#query II
+#select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]);
+#----
+#3 21

 # array_ndims scalar function #3
 query II
diff --git a/datafusion/sqllogictest/test_files/lambda.slt b/datafusion/sqllogictest/test_files/lambda.slt
new file mode 100644
index 000000000000..0043eae17a60
--- /dev/null
+++ b/datafusion/sqllogictest/test_files/lambda.slt
@@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+#############
+## Lambda Expressions Tests
+#############
+
+statement ok
+set datafusion.sql_parser.dialect = databricks;
+
+statement ok
+CREATE TABLE tt
+AS VALUES
+([1, 50], 10),
+([4, 50], 40);
+
+statement ok
+CREATE TABLE t AS SELECT 1 as f, [ [ [2, 3], [2] ], [ [1] ], [ [] ] ] as v, 1 as n;
+
+query I?
+SELECT t.n, array_transform([], e1 -> t.n) from t;
+----
+1 []
+
+query ?
+SELECT array_transform([1], e1 -> (select n from t));
+----
+[1]
+
+query ?
+SELECT array_transform(v, (v1, i) -> array_transform(v1, (v2, j) -> array_transform(v2, v3 -> j)) ) from t;
+----
+[[[0, 0], [1]], [[0]], [[]]]
+
+query I?
+SELECT t.n, array_transform([1, 2], (e) -> n) from t;
+----
+1 [1, 1]
+
+# selection pushdown not working yet
+query ?
+SELECT array_transform([1, 2], (e) -> n) from t;
+----
+[1, 1]
+
+query ?
+SELECT array_transform([1, 2], (e, i) -> i) from t;
+----
+[0, 1]
+
+# type coercion
+query ?
+SELECT array_transform([1, 2], (e, i) -> e+i) from t; +---- +[1, 3] + +query TT +EXPLAIN SELECT array_transform([1, 2], (e, i) -> e+i); +---- +logical_plan +01)Projection: array_transform(List([1, 2]), (e, i) -> e + CAST(i AS Int64)) AS array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2], (e, i) -> e@0 + CAST(i@1 AS Int64)) as array_transform(make_array(Int64(1),Int64(2)),(e, i) -> e + i)] +02)--PlaceholderRowExec + +#cse +query TT +explain select n + 1, array_transform([1], v -> v + n + 1) from t; +---- +logical_plan +01)Projection: t.n + Int64(1), array_transform(List([1]), (v) -> v + t.n + Int64(1)) AS array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1)) +02)--TableScan: t projection=[n] +physical_plan +01)ProjectionExec: expr=[n@0 + 1 as t.n + Int64(1), array_transform([1], (v) -> v@1 + n@0 + 1) as array_transform(make_array(Int64(1)),(v) -> v + t.n + Int64(1))] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + +query ? +SELECT array_transform([1,2,3,4,5], v -> 2); +---- +[2, 2, 2, 2, 2] + +query ? +SELECT array_transform([[1,2],[3,4,5]], v -> array_transform(v, v -> v*2)); +---- +[[2, 4], [6, 8, 10]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> repeat("a", v)); +---- +[a, aa, aaa, aaaa, aaaaa] + +query ? +SELECT array_transform([1,2,3,4,5], v -> list_repeat("a", v)); +---- +[[a], [a, a], [a, a, a], [a, a, a, a], [a, a, a, a, a]] + +query TT +EXPLAIN SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +logical_plan +01)Projection: array_transform(List([1, 2, 3, 4, 5]), (v) -> v * Int64(2)) AS array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2)) +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[array_transform([1, 2, 3, 4, 5], (v) -> v@0 * 2) as array_transform(make_array(Int64(1),Int64(2),Int64(3),Int64(4),Int64(5)),(v) -> v * Int64(2))] +02)--PlaceholderRowExec + + +query I?? +SELECT t.n, t.v, array_transform(t.v, (v, i) -> array_transform(v, (v, j) -> n) ) from t; +---- +1 [[[2, 3], [2]], [[1]], [[]]] [[1, 1], [1], [1]] + +query ? +SELECT array_transform([1,2,3,4,5], v -> v*2); +---- +[2, 4, 6, 8, 10] + + +# expr simplifier +query TT +EXPLAIN SELECT v = v, array_transform([1], v -> v = v) from t; +---- +logical_plan +01)Projection: Boolean(true) AS t.v = t.v, array_transform(List([1]), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(Int64(1)),(v) -> v = v) +02)--TableScan: t projection=[] +physical_plan +01)ProjectionExec: expr=[true as t.v = t.v, array_transform([1], (v) -> v@0 IS NOT NULL OR NULL) as array_transform(make_array(Int64(1)),(v) -> v = v)] +02)--DataSourceExec: partitions=1, partition_sizes=[1] + + +query error +select array_transform(); +---- +DataFusion error: Error during planning: 'array_transform' does not support zero arguments No function matches the given name and argument types 'array_transform()'. You might need to add explicit type casts. 
+ Candidate functions: + array_transform(Any, Any) + + +query error DataFusion error: Execution error: expected list, got Field \{ name: "Int64\(1\)", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: \{\} \} +select array_transform(1, v -> v*2); + +query error DataFusion error: Execution error: array_transform expects a value follewed by a lambda, got \[Lambda\(\["v"\], false\), Value\(Field \{ name: "make_array\(Int64\(1\),Int64\(2\)\)", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\)\] +select array_transform(v -> v*2, [1, 2]); + +query error DataFusion error: Execution error: lambdas_schemas: array_transform argument 1 \(0\-indexed\), a lambda, supports up to 2 arguments, but got 3 +SELECT array_transform([1, 2], (e, i, j) -> i) from t; diff --git a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs index f4e43fd58677..103d593cafbc 100644 --- a/datafusion/substrait/src/logical_plan/producer/expr/mod.rs +++ b/datafusion/substrait/src/logical_plan/producer/expr/mod.rs @@ -152,6 +152,7 @@ pub fn to_substrait_rex( not_impl_err!("Cannot convert {expr:?} to Substrait") } Expr::Unnest(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), + Expr::Lambda(expr) => not_impl_err!("Cannot convert {expr:?} to Substrait"), // https://github.com/substrait-io/substrait/issues/349 } } From fa4a8fbebe21207225077f41183c2e9016a24fbd Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Sat, 22 Nov 2025 15:34:04 -0300 Subject: [PATCH 2/2] add lambdas: None to existing ScalarFunctionArgs in tests/benches --- datafusion/functions-nested/benches/map.rs | 1 + datafusion/functions-nested/src/array_has.rs | 2 ++ datafusion/functions-nested/src/lib.rs | 3 +++ datafusion/functions-nested/src/map_values.rs | 1 + datafusion/functions-nested/src/set_ops.rs | 1 + datafusion/functions/benches/ascii.rs | 4 ++++ .../functions/benches/character_length.rs | 4 ++++ datafusion/functions/benches/chr.rs | 1 + datafusion/functions/benches/concat.rs | 1 + datafusion/functions/benches/cot.rs | 2 ++ datafusion/functions/benches/date_bin.rs | 1 + datafusion/functions/benches/date_trunc.rs | 1 + datafusion/functions/benches/encoding.rs | 4 ++++ datafusion/functions/benches/find_in_set.rs | 4 ++++ datafusion/functions/benches/gcd.rs | 3 +++ datafusion/functions/benches/initcap.rs | 3 +++ datafusion/functions/benches/isnan.rs | 2 ++ datafusion/functions/benches/iszero.rs | 2 ++ datafusion/functions/benches/lower.rs | 6 ++++++ datafusion/functions/benches/ltrim.rs | 1 + datafusion/functions/benches/make_date.rs | 4 ++++ datafusion/functions/benches/nullif.rs | 1 + datafusion/functions/benches/pad.rs | 1 + datafusion/functions/benches/random.rs | 2 ++ datafusion/functions/benches/repeat.rs | 1 + datafusion/functions/benches/reverse.rs | 4 ++++ datafusion/functions/benches/signum.rs | 2 ++ datafusion/functions/benches/strpos.rs | 4 ++++ datafusion/functions/benches/substr.rs | 1 + datafusion/functions/benches/substr_index.rs | 1 + datafusion/functions/benches/to_char.rs | 6 ++++++ datafusion/functions/benches/to_hex.rs | 2 ++ datafusion/functions/benches/to_timestamp.rs | 6 ++++++ datafusion/functions/benches/trunc.rs | 2 ++ datafusion/functions/benches/upper.rs | 1 + datafusion/functions/benches/uuid.rs | 1 + 
datafusion/functions/src/core/union_extract.rs | 4 ++++ datafusion/functions/src/core/union_tag.rs | 8 ++++++-- datafusion/functions/src/core/version.rs | 1 + datafusion/functions/src/datetime/date_bin.rs | 1 + .../functions/src/datetime/date_trunc.rs | 2 ++ .../functions/src/datetime/from_unixtime.rs | 2 ++ datafusion/functions/src/datetime/make_date.rs | 1 + datafusion/functions/src/datetime/now.rs | 2 ++ datafusion/functions/src/datetime/to_char.rs | 7 +++++++ datafusion/functions/src/datetime/to_date.rs | 1 + .../functions/src/datetime/to_local_time.rs | 2 ++ .../functions/src/datetime/to_timestamp.rs | 2 ++ datafusion/functions/src/math/log.rs | 18 ++++++++++++++++++ datafusion/functions/src/math/power.rs | 2 ++ datafusion/functions/src/math/signum.rs | 2 ++ datafusion/functions/src/regex/regexpcount.rs | 1 + datafusion/functions/src/regex/regexpinstr.rs | 1 + datafusion/functions/src/string/concat.rs | 1 + datafusion/functions/src/string/concat_ws.rs | 2 ++ datafusion/functions/src/string/contains.rs | 1 + datafusion/functions/src/string/lower.rs | 1 + datafusion/functions/src/string/upper.rs | 1 + .../functions/src/unicode/find_in_set.rs | 1 + datafusion/functions/src/unicode/strpos.rs | 1 + datafusion/functions/src/utils.rs | 3 +++ datafusion/spark/benches/char.rs | 1 + .../spark/src/function/bitmap/bitmap_count.rs | 1 + .../src/function/datetime/make_dt_interval.rs | 1 + .../src/function/datetime/make_interval.rs | 1 + datafusion/spark/src/function/string/concat.rs | 2 ++ datafusion/spark/src/function/utils.rs | 5 ++++- 67 files changed, 162 insertions(+), 3 deletions(-) diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 3197cc55cc95..3075d2e573e4 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -117,6 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions-nested/src/array_has.rs b/datafusion/functions-nested/src/array_has.rs index 080b2f16d92f..d6a333c0a0ef 100644 --- a/datafusion/functions-nested/src/array_has.rs +++ b/datafusion/functions-nested/src/array_has.rs @@ -819,6 +819,7 @@ mod tests { number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; let output = result.into_array(1)?; @@ -847,6 +848,7 @@ mod tests { number_rows: 1, return_field, config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; let output = result.into_array(1)?; diff --git a/datafusion/functions-nested/src/lib.rs b/datafusion/functions-nested/src/lib.rs index 3a66e6569476..55acf24ba465 100644 --- a/datafusion/functions-nested/src/lib.rs +++ b/datafusion/functions-nested/src/lib.rs @@ -37,6 +37,7 @@ pub mod macros; pub mod array_has; +pub mod array_transform; pub mod cardinality; pub mod concat; pub mod dimension; @@ -78,6 +79,7 @@ pub mod expr_fn { pub use super::array_has::array_has; pub use super::array_has::array_has_all; pub use super::array_has::array_has_any; + pub use super::array_transform::array_transform; pub use super::cardinality::cardinality; pub use super::concat::array_append; pub use super::concat::array_concat; @@ -145,6 +147,7 @@ pub fn all_default_nested_functions() -> Vec> { array_has::array_has_udf(), array_has::array_has_all_udf(), array_has::array_has_any_udf(), + array_transform::array_transform_udf(), 
empty::array_empty_udf(), length::array_length_udf(), distance::array_distance_udf(), diff --git a/datafusion/functions-nested/src/map_values.rs b/datafusion/functions-nested/src/map_values.rs index 6ae8a278063d..ac21ff8acd3f 100644 --- a/datafusion/functions-nested/src/map_values.rs +++ b/datafusion/functions-nested/src/map_values.rs @@ -204,6 +204,7 @@ mod tests { let args = datafusion_expr::ReturnFieldArgs { arg_fields: &[field], scalar_arguments: &[None::<&ScalarValue>], + lambdas: &[false], }; func.return_field_from_args(args).unwrap() diff --git a/datafusion/functions-nested/src/set_ops.rs b/datafusion/functions-nested/src/set_ops.rs index 53642bf1622b..f26fc173d8a9 100644 --- a/datafusion/functions-nested/src/set_ops.rs +++ b/datafusion/functions-nested/src/set_ops.rs @@ -596,6 +596,7 @@ mod tests { number_rows: 1, return_field: input_field.clone().into(), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })?; assert_eq!( diff --git a/datafusion/functions/benches/ascii.rs b/datafusion/functions/benches/ascii.rs index 03d25e9c3d4f..97e6ab20ed45 100644 --- a/datafusion/functions/benches/ascii.rs +++ b/datafusion/functions/benches/ascii.rs @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -81,6 +82,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -108,6 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -129,6 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 4a1a63d62765..f98e8a8b1a68 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -79,6 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -103,6 +105,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -127,6 +130,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8356cf7c3172..d51cda4566d6 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -69,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 
09200139a244..637832853782 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -60,6 +60,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 97f21ccd6d55..56f50522acc5 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Float32, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -80,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 74390491d538..1c3713723738 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -66,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index 498a3e63ef29..b757535fb03c 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -71,6 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 98faee91e191..72b033cf5d9e 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -45,6 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(); @@ -63,6 +64,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -82,6 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(); @@ -101,6 +104,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index a928f5655806..6fe498a58d84 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -168,6 +168,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -186,6 +187,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, 
return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -208,6 +210,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::new(ConfigOptions::default()), + lambdas: None, })) }) }); @@ -228,6 +231,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index 19e196d9a3ea..2bfec91e290d 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -58,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -79,6 +80,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) @@ -100,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 0, return_field: Field::new("f", DataType::Int64, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 50aee8dbb916..37d98596deb8 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -70,6 +70,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -86,6 +87,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -100,6 +102,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 4a90d45d6622..dcce59e46ce4 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -53,6 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -77,6 +78,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Boolean, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 961cba7200ce..574539fbb642 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -82,6 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 
batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 6a5178b87fdc..e741afd0d8e0 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -145,6 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -167,6 +168,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -191,6 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -225,6 +228,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); @@ -240,6 +244,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); @@ -256,6 +261,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 4458af614396..9b344cc6b143 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -153,6 +153,7 @@ fn run_with_string_type( number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 15a895468db9..2a681ddedcbe 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -81,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -111,6 +112,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -141,6 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) @@ -168,6 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index d649697cc518..15914cd7ee6c 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { 
number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f92a69bbf4f9..c7d46da3d26c 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -116,6 +116,7 @@ fn invoke_pad_with_args( number_rows, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }; if left_pad { diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 88efb2d1b5b9..293587668580 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -43,6 +43,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 8192, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ); @@ -64,6 +65,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 128, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 80ffa8ee38f1..9a7c63ed4f30 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -76,6 +76,7 @@ fn invoke_repeat_with_args( number_rows: repeat_times as usize, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) } diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index b1eca654fb25..a8af40cd8cc1 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -58,6 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -80,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -107,6 +109,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -131,6 +134,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: N_ROWS, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 24b8861e4d28..805b62c83da6 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -83,6 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 18a99e44bf48..708ebb551872 100644 --- a/datafusion/functions/benches/strpos.rs +++ 
b/datafusion/functions/benches/strpos.rs @@ -128,6 +128,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -146,6 +147,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); @@ -165,6 +167,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, @@ -185,6 +188,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: n_rows, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 771413458c1f..58fda73defd2 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -116,6 +116,7 @@ fn invoke_substr_with_args( number_rows, return_field: Field::new("f", DataType::Utf8View, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) } diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index d0941d9baedd..a77b961657c5 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -110,6 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 945508aec740..61990b4cb8b9 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -149,6 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -176,6 +177,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -203,6 +205,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -229,6 +232,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -256,6 +260,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid values"), ) @@ -288,6 +293,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: 1, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_char should work on valid 
values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a75ed9258791..baa2de80c466 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -44,6 +44,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -62,6 +63,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index a8f5c5816d4d..e510a7c3fad4 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -130,6 +130,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -150,6 +151,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -170,6 +172,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -203,6 +206,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -244,6 +248,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) @@ -286,6 +291,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: batch_len, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 6e225e0e7038..0b08791f9ae5 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -49,6 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) @@ -68,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Arc::clone(&return_field), config_options: Arc::clone(&config_options), + lambdas: None, }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 7328b32574a4..e9f0941032d8 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -50,6 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { number_rows: size, return_field: Field::new("f", DataType::Utf8, true).into(), config_options: Arc::clone(&config_options), + lambdas: None, })) }) }); diff --git a/datafusion/functions/benches/uuid.rs 
diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs
index 1368e2f2af5d..8ad79b2866ea 100644
--- a/datafusion/functions/benches/uuid.rs
+++ b/datafusion/functions/benches/uuid.rs
@@ -37,6 +37,7 @@ fn criterion_benchmark(c: &mut Criterion) {
                 number_rows: 1024,
                 return_field: Field::new("f", DataType::Utf8, true).into(),
                 config_options: Arc::clone(&config_options),
+                lambdas: None,
             }))
         })
     });
diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs
index a71e2e87388d..ac542866f7e4 100644
--- a/datafusion/functions/src/core/union_extract.rs
+++ b/datafusion/functions/src/core/union_extract.rs
@@ -209,6 +209,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })?;
         assert_scalar(result, ScalarValue::Utf8(None));
@@ -232,6 +233,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })?;
         assert_scalar(result, ScalarValue::Utf8(None));
@@ -248,12 +250,14 @@ mod tests {
             .iter()
             .map(|arg| Field::new("a", arg.data_type().clone(), true).into())
             .collect::<Vec<_>>();
+
         let result = fun.invoke_with_args(ScalarFunctionArgs {
             args,
             arg_fields,
             number_rows: 1,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })?;
         assert_scalar(result, ScalarValue::new_utf8("42"));
diff --git a/datafusion/functions/src/core/union_tag.rs b/datafusion/functions/src/core/union_tag.rs
index aeadb8292ba1..ecdebf66e004 100644
--- a/datafusion/functions/src/core/union_tag.rs
+++ b/datafusion/functions/src/core/union_tag.rs
@@ -173,6 +173,7 @@ mod tests {
             fields,
             UnionMode::Dense,
         );
+        let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()];
         let return_type =
             DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
@@ -180,10 +181,11 @@ mod tests {
         let result = UnionTagFunc::new()
             .invoke_with_args(ScalarFunctionArgs {
                 args: vec![ColumnarValue::Scalar(scalar)],
+                arg_fields,
                 number_rows: 1,
                 return_field: Field::new("res", return_type, true).into(),
-                arg_fields: vec![],
                 config_options: Arc::new(ConfigOptions::default()),
+                lambdas: None,
             })
             .unwrap();
@@ -196,6 +198,7 @@ mod tests {
     #[test]
     fn union_scalar_empty() {
         let scalar = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense);
+        let arg_fields = vec![Field::new("a", scalar.data_type(), false).into()];
         let return_type =
             DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
@@ -203,10 +206,11 @@ mod tests {
         let result = UnionTagFunc::new()
             .invoke_with_args(ScalarFunctionArgs {
                 args: vec![ColumnarValue::Scalar(scalar)],
+                arg_fields,
                 number_rows: 1,
                 return_field: Field::new("res", return_type, true).into(),
-                arg_fields: vec![],
                 config_options: Arc::new(ConfigOptions::default()),
+                lambdas: None,
             })
             .unwrap();
diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs
index ef3c5aafa480..390111028c8f 100644
--- a/datafusion/functions/src/core/version.rs
+++ b/datafusion/functions/src/core/version.rs
@@ -112,6 +112,7 @@ mod test {
             number_rows: 0,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })
         .unwrap();
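
Every `ScalarFunctionArgs` literal in the hunks above and below picks up the same trailing field. For orientation, a minimal sketch of the post-patch call shape; the `invoke_once` helper, its `udf` parameter, and the input values are illustrative, not taken from this patch:

    use std::sync::Arc;
    use arrow::datatypes::{DataType, Field};
    use datafusion_common::{config::ConfigOptions, Result, ScalarValue};
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    fn invoke_once(udf: &dyn ScalarUDFImpl) -> Result<ColumnarValue> {
        udf.invoke_with_args(ScalarFunctionArgs {
            args: vec![ColumnarValue::Scalar(ScalarValue::new_utf8("abc"))],
            arg_fields: vec![Field::new("a", DataType::Utf8, true).into()],
            number_rows: 1,
            return_field: Field::new("f", DataType::Utf8, true).into(),
            config_options: Arc::new(ConfigOptions::default()),
            // New in this patch: callers with no lambda arguments pass None.
            lambdas: None,
        })
    }
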
diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs
index 92af123dbafa..546612931464 100644
--- a/datafusion/functions/src/datetime/date_bin.rs
+++ b/datafusion/functions/src/datetime/date_bin.rs
@@ -530,6 +530,7 @@ mod tests {
             number_rows,
             return_field: Arc::clone(return_field),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         DateBinFunc::new().invoke_with_args(args)
     }
diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs
index 913e6217af82..5736c221cae8 100644
--- a/datafusion/functions/src/datetime/date_trunc.rs
+++ b/datafusion/functions/src/datetime/date_trunc.rs
@@ -892,6 +892,7 @@ mod tests {
             )
             .into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = DateTruncFunc::new().invoke_with_args(args).unwrap();
         if let ColumnarValue::Array(result) = result {
@@ -1080,6 +1081,7 @@ mod tests {
             )
             .into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = DateTruncFunc::new().invoke_with_args(args).unwrap();
         if let ColumnarValue::Array(result) = result {
diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs
index 5d6adfb6f119..be44be094e5b 100644
--- a/datafusion/functions/src/datetime/from_unixtime.rs
+++ b/datafusion/functions/src/datetime/from_unixtime.rs
@@ -179,6 +179,7 @@ mod test {
             number_rows: 1,
             return_field: Field::new("f", DataType::Timestamp(Second, None), true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap();
@@ -212,6 +213,7 @@ mod test {
             )
             .into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap();
diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs
index 0fe5d156a838..afa4ef132147 100644
--- a/datafusion/functions/src/datetime/make_date.rs
+++ b/datafusion/functions/src/datetime/make_date.rs
@@ -250,6 +250,7 @@ mod tests {
             number_rows,
             return_field: Field::new("f", DataType::Date32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         MakeDateFunc::new().invoke_with_args(args)
     }
diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs
index 4723548a4558..f18e72a107e2 100644
--- a/datafusion/functions/src/datetime/now.rs
+++ b/datafusion/functions/src/datetime/now.rs
@@ -163,6 +163,7 @@ mod tests {
             .return_field_from_args(ReturnFieldArgs {
                 arg_fields: &empty_fields,
                 scalar_arguments: &empty_scalars,
+                lambdas: &[],
             })
             .expect("legacy now() return field");
@@ -170,6 +171,7 @@ mod tests {
             .return_field_from_args(ReturnFieldArgs {
                 arg_fields: &empty_fields,
                 scalar_arguments: &empty_scalars,
+                lambdas: &[],
             })
             .expect("configured now() return field");
diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs
index 7d9b2bc241e1..5d69ce233f64 100644
--- a/datafusion/functions/src/datetime/to_char.rs
+++ b/datafusion/functions/src/datetime/to_char.rs
@@ -375,6 +375,7 @@ mod tests {
             number_rows: batch_len,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::clone(&Arc::new(ConfigOptions::default())),
+            lambdas: None,
         };
         let result = ToCharFunc::new()
             .invoke_with_args(args)
@@ -480,6 +481,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToCharFunc::new()
            .invoke_with_args(args)
@@ -574,6 +576,7 @@ mod tests {
             number_rows: batch_len,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToCharFunc::new()
             .invoke_with_args(args)
@@ -738,6 +741,7 @@ mod tests {
             number_rows: batch_len,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToCharFunc::new()
             .invoke_with_args(args)
@@ -766,6 +770,7 @@ mod tests {
             number_rows: batch_len,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToCharFunc::new()
             .invoke_with_args(args)
@@ -791,6 +796,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToCharFunc::new().invoke_with_args(args);
         assert_eq!(
@@ -812,6 +818,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToCharFunc::new().invoke_with_args(args);
         assert_eq!(
diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs
index 3840c8d8bbb9..f6b313e6a28b 100644
--- a/datafusion/functions/src/datetime/to_date.rs
+++ b/datafusion/functions/src/datetime/to_date.rs
@@ -186,6 +186,7 @@ mod tests {
             number_rows,
             return_field: Field::new("f", DataType::Date32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         ToDateFunc::new().invoke_with_args(args)
     }
diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs
index 6e0a150b0a35..4d50a70d3723 100644
--- a/datafusion/functions/src/datetime/to_local_time.rs
+++ b/datafusion/functions/src/datetime/to_local_time.rs
@@ -549,6 +549,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", expected.data_type(), true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })
         .unwrap();
         match res {
@@ -620,6 +621,7 @@ mod tests {
             )
             .into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap();
         if let ColumnarValue::Array(result) = result {
diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs
index 0a0700097770..f35e17007303 100644
--- a/datafusion/functions/src/datetime/to_timestamp.rs
+++ b/datafusion/functions/src/datetime/to_timestamp.rs
@@ -1033,6 +1033,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", rt, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let res = udf
             .invoke_with_args(args)
@@ -1083,6 +1084,7 @@ mod tests {
             number_rows: 5,
             return_field: Field::new("f", rt, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let res = udf
             .invoke_with_args(args)
diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs
index f66f6fcfc1f8..1a73ed8436a6 100644
--- a/datafusion/functions/src/math/log.rs
+++ b/datafusion/functions/src/math/log.rs
@@ -370,6 +370,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new().invoke_with_args(args);
         assert!(result.is_err());
@@ -390,6 +391,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new().invoke_with_args(args);
@@ -407,6 +409,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -437,6 +440,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -471,6 +475,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -505,6 +510,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -537,6 +543,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -572,6 +579,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", DataType::Float32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -613,6 +621,7 @@ mod tests {
             number_rows: 5,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -655,6 +664,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", DataType::Float32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -836,6 +846,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Decimal128(38, 0), true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -869,6 +880,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -903,6 +915,7 @@ mod tests {
             number_rows: 6,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -947,6 +960,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -987,6 +1001,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -1037,6 +1052,7 @@ mod tests {
             number_rows: 7,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new()
             .invoke_with_args(args)
@@ -1078,6 +1094,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new().invoke_with_args(args);
         assert!(result.is_err());
@@ -1101,6 +1118,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = LogFunc::new().invoke_with_args(args);
         assert!(result.is_err());
diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs
index ad2e795d086e..21a777abb329 100644
--- a/datafusion/functions/src/math/power.rs
+++ b/datafusion/functions/src/math/power.rs
@@ -222,6 +222,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = PowerFunc::new()
             .invoke_with_args(args)
@@ -258,6 +259,7 @@ mod tests {
             number_rows: 4,
             return_field: Field::new("f", DataType::Int64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = PowerFunc::new()
             .invoke_with_args(args)
diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs
index bbe6178f39b7..d1d49b1bf6f9 100644
--- a/datafusion/functions/src/math/signum.rs
+++ b/datafusion/functions/src/math/signum.rs
@@ -173,6 +173,7 @@ mod test {
             number_rows: array.len(),
             return_field: Field::new("f", DataType::Float32, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = SignumFunc::new()
             .invoke_with_args(args)
@@ -220,6 +221,7 @@ mod test {
             number_rows: array.len(),
             return_field: Field::new("f", DataType::Float64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = SignumFunc::new()
             .invoke_with_args(args)
diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs
index 8bad506217aa..ee6f412bb9a1 100644
--- a/datafusion/functions/src/regex/regexpcount.rs
+++ b/datafusion/functions/src/regex/regexpcount.rs
@@ -628,6 +628,7 @@ mod tests {
             number_rows: args.len(),
             return_field: Field::new("f", Int64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })
     }
diff --git a/datafusion/functions/src/regex/regexpinstr.rs b/datafusion/functions/src/regex/regexpinstr.rs
index 851c182a90dd..1e64f7087ea7 100644
--- a/datafusion/functions/src/regex/regexpinstr.rs
+++ b/datafusion/functions/src/regex/regexpinstr.rs
@@ -494,6 +494,7 @@ mod tests {
             number_rows: args.len(),
             return_field: Arc::new(Field::new("f", Int64, true)),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         })
     }
diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs
index a93e70e714e8..661bcfe4e0fd 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -487,6 +487,7 @@ mod tests {
             number_rows: 3,
             return_field: Field::new("f", Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ConcatFunc::new().invoke_with_args(args)?;
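
Wrappers that delegate to another UDF thread the new field through instead of dropping it; the spark `concat` hunk later in this patch destructures `lambdas` and forwards it. A sketch of that pattern, using a hypothetical `delegate` helper:

    use datafusion_common::Result;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};

    fn delegate(args: ScalarFunctionArgs, inner: &dyn ScalarUDFImpl) -> Result<ColumnarValue> {
        let ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field,
            config_options,
            lambdas,
        } = args;
        // Wrapper-specific rewriting of `args` would happen here.
        inner.invoke_with_args(ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field,
            config_options,
            lambdas, // forwarded unchanged
        })
    }
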
diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs
index cdd30ac8755a..85704d6b2f46 100644
--- a/datafusion/functions/src/string/concat_ws.rs
+++ b/datafusion/functions/src/string/concat_ws.rs
@@ -495,6 +495,7 @@ mod tests {
             number_rows: 3,
             return_field: Field::new("f", Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ConcatWsFunc::new().invoke_with_args(args)?;
@@ -532,6 +533,7 @@ mod tests {
             number_rows: 3,
             return_field: Field::new("f", Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = ConcatWsFunc::new().invoke_with_args(args)?;
diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs
index 7e50676933c8..1edab4c6bf33 100644
--- a/datafusion/functions/src/string/contains.rs
+++ b/datafusion/functions/src/string/contains.rs
@@ -177,6 +177,7 @@ mod test {
             number_rows: 2,
             return_field: Field::new("f", DataType::Boolean, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let actual = udf.invoke_with_args(args).unwrap();
diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs
index ee56a6a54985..099a3ffd44cc 100644
--- a/datafusion/functions/src/string/lower.rs
+++ b/datafusion/functions/src/string/lower.rs
@@ -113,6 +113,7 @@ mod tests {
             arg_fields,
             return_field: Field::new("f", Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = match func.invoke_with_args(args)? {
diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs
index 8bb2ec1d511c..d7d2bde94b0a 100644
--- a/datafusion/functions/src/string/upper.rs
+++ b/datafusion/functions/src/string/upper.rs
@@ -112,6 +112,7 @@ mod tests {
             arg_fields: vec![arg_field],
             return_field: Field::new("f", Utf8, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let result = match func.invoke_with_args(args)? {
diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs
index fa68e539600b..219bd6eaa762 100644
--- a/datafusion/functions/src/unicode/find_in_set.rs
+++ b/datafusion/functions/src/unicode/find_in_set.rs
@@ -485,6 +485,7 @@ mod tests {
             number_rows: cardinality,
             return_field: Field::new("f", return_type, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         });
         assert!(result.is_ok());
diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs
index 4f238b2644bd..a3734b0c0de4 100644
--- a/datafusion/functions/src/unicode/strpos.rs
+++ b/datafusion/functions/src/unicode/strpos.rs
@@ -336,6 +336,7 @@ mod tests {
                 Field::new("f2", DataType::Utf8, substring_nullable).into(),
             ],
             scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
+            lambdas: &[false; 2],
         };
         strpos.return_field_from_args(args).unwrap().is_nullable()
diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs
index 932d61e8007c..d6d56b32722d 100644
--- a/datafusion/functions/src/utils.rs
+++ b/datafusion/functions/src/utils.rs
@@ -234,6 +234,7 @@ pub mod test {
                 let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                     arg_fields: &field_array,
                     scalar_arguments: &scalar_arguments_refs,
+                    lambdas: &vec![false; scalar_arguments_refs.len()],
                 });
                 let arg_fields = $ARGS.iter()
                     .enumerate()
@@ -252,6 +253,7 @@ pub mod test {
                         arg_fields,
                         number_rows: cardinality,
                         return_field,
+                        lambdas: None,
                         config_options: $CONFIG_OPTIONS
                     });
                     assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err());
@@ -274,6 +276,7 @@ pub mod test {
                         arg_fields,
                         number_rows: cardinality,
                         return_field,
+                        lambdas: None,
                        config_options: $CONFIG_OPTIONS,
                    }) {
                        Ok(_) => assert!(false, "expected error"),
diff --git a/datafusion/spark/benches/char.rs b/datafusion/spark/benches/char.rs
index 02eab7630d07..501bfd2a0186 100644
--- a/datafusion/spark/benches/char.rs
+++ b/datafusion/spark/benches/char.rs
@@ -68,6 +68,7 @@ fn criterion_benchmark(c: &mut Criterion) {
                 number_rows: size,
                 return_field: Arc::new(Field::new("f", DataType::Utf8, true)),
                 config_options: Arc::clone(&config_options),
+                lambdas: None,
             })
             .unwrap(),
         )
diff --git a/datafusion/spark/src/function/bitmap/bitmap_count.rs b/datafusion/spark/src/function/bitmap/bitmap_count.rs
index 56a9c5edb812..e4c12ebe1966 100644
--- a/datafusion/spark/src/function/bitmap/bitmap_count.rs
+++ b/datafusion/spark/src/function/bitmap/bitmap_count.rs
@@ -217,6 +217,7 @@ mod tests {
             number_rows: 1,
             return_field: Field::new("f", Int64, true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         let udf = BitmapCount::new();
         let actual = udf.invoke_with_args(args)?;
diff --git a/datafusion/spark/src/function/datetime/make_dt_interval.rs b/datafusion/spark/src/function/datetime/make_dt_interval.rs
index bbfba4486134..aaff5400d0c0 100644
--- a/datafusion/spark/src/function/datetime/make_dt_interval.rs
+++ b/datafusion/spark/src/function/datetime/make_dt_interval.rs
@@ -317,6 +317,7 @@ mod tests {
             number_rows,
             return_field: Field::new("f", Duration(Microsecond), true).into(),
             config_options: Arc::new(Default::default()),
+            lambdas: None,
         };
         SparkMakeDtInterval::new().invoke_with_args(args)
     }
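
The planning-time counterpart is `ReturnFieldArgs`, which now carries one flag per argument marking lambda positions; the `strpos` and `utils.rs` hunks above pass all-`false` slices because none of their arguments are lambdas. A sketch mirroring those tests, with illustrative field values:

    use arrow::datatypes::{DataType, Field, FieldRef};
    use datafusion_common::ScalarValue;
    use datafusion_expr::ReturnFieldArgs;

    let arg_fields: Vec<FieldRef> = vec![
        Field::new("f1", DataType::Utf8, true).into(),
        Field::new("f2", DataType::Utf8, true).into(),
    ];
    let args = ReturnFieldArgs {
        arg_fields: &arg_fields,
        scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>],
        // One entry per argument; no lambda arguments here.
        lambdas: &[false; 2],
    };
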
diff --git a/datafusion/spark/src/function/datetime/make_interval.rs b/datafusion/spark/src/function/datetime/make_interval.rs
index 8e3169556b95..9f98c4b5ce9f 100644
--- a/datafusion/spark/src/function/datetime/make_interval.rs
+++ b/datafusion/spark/src/function/datetime/make_interval.rs
@@ -516,6 +516,7 @@ mod tests {
             number_rows,
             return_field: Field::new("f", Interval(MonthDayNano), true).into(),
             config_options: Arc::new(ConfigOptions::default()),
+            lambdas: None,
         };
         SparkMakeInterval::new().invoke_with_args(args)
     }
diff --git a/datafusion/spark/src/function/string/concat.rs b/datafusion/spark/src/function/string/concat.rs
index 0dcc58d5bb8e..e2cd8d977fe2 100644
--- a/datafusion/spark/src/function/string/concat.rs
+++ b/datafusion/spark/src/function/string/concat.rs
@@ -105,6 +105,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         number_rows,
         return_field,
         config_options,
+        lambdas,
     } = args;
 
     // Handle zero-argument case: return empty string
@@ -130,6 +131,7 @@ fn spark_concat(args: ScalarFunctionArgs) -> Result<ColumnarValue> {
         number_rows,
         return_field,
         config_options,
+        lambdas,
     };
 
     let result = concat_func.invoke_with_args(func_args)?;
diff --git a/datafusion/spark/src/function/utils.rs b/datafusion/spark/src/function/utils.rs
index e272d91d8a70..1064acc34291 100644
--- a/datafusion/spark/src/function/utils.rs
+++ b/datafusion/spark/src/function/utils.rs
@@ -60,7 +60,8 @@ pub mod test {
                 let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs {
                     arg_fields: &arg_fields,
-                    scalar_arguments: &scalar_arguments_refs
+                    scalar_arguments: &scalar_arguments_refs,
+                    lambdas: &vec![false; arg_fields.len()],
                 });
 
                 match expected {
@@ -74,6 +75,7 @@ pub mod test {
                         return_field,
                         arg_fields: arg_fields.clone(),
                         config_options: $CONFIG_OPTIONS,
+                        lambdas: None,
                     }) {
                         Ok(col_value) => {
                             match col_value.to_array(cardinality) {
@@ -117,6 +119,7 @@ pub mod test {
                         return_field: value,
                         arg_fields,
                         config_options: $CONFIG_OPTIONS,
+                        lambdas: None,
                    }) {
                        Ok(_) => assert!(false, "expected error"),
                        Err(error) => {