From 9afcd371cf2b950941e4151759bd9142483d2e9a Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:29:20 -0300 Subject: [PATCH 1/4] add lambda column capture support --- datafusion/common/src/utils/mod.rs | 157 +++++++++++++++++- datafusion/expr/src/execution_props.rs | 22 +++ datafusion/expr/src/higher_order_function.rs | 89 ++++++++-- .../functions-nested/src/array_any_match.rs | 10 +- .../functions-nested/src/array_transform.rs | 12 +- .../physical-expr/src/expressions/case.rs | 53 ++++-- .../physical-expr/src/expressions/lambda.rs | 76 ++++++++- .../src/higher_order_function.rs | 15 +- datafusion/physical-expr/src/planner.rs | 69 +++++--- .../test_files/array/array_transform.slt | 118 +++++++++++-- 10 files changed, 546 insertions(+), 75 deletions(-) diff --git a/datafusion/common/src/utils/mod.rs b/datafusion/common/src/utils/mod.rs index 8c88be03fd5c8..acee7b7a84b02 100644 --- a/datafusion/common/src/utils/mod.rs +++ b/datafusion/common/src/utils/mod.rs @@ -31,19 +31,23 @@ use arrow::array::{ cast::AsArray, }; use arrow::array::{ - Datum, GenericListArray, Int32Array, Int64Array, MutableArrayData, make_array, + ArrowPrimitiveType, Datum, GenericListArray, Int32Array, Int64Array, + MutableArrayData, PrimitiveArray, make_array, }; use arrow::array::{LargeListViewArray, ListViewArray}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; use arrow::compute::kernels::cmp::neq; use arrow::compute::kernels::length::length; use arrow::compute::{SortColumn, SortOptions, partition}; -use arrow::datatypes::{DataType, Field, SchemaRef}; +use arrow::datatypes::{ + ArrowNativeType, DataType, Field, Int32Type, Int64Type, SchemaRef, +}; #[cfg(feature = "sql")] use sqlparser::{ast::Ident, dialect::GenericDialect, parser::Parser}; use std::borrow::{Borrow, Cow}; use std::cmp::{Ordering, min}; use std::collections::HashSet; +use std::iter::repeat_n; use std::num::NonZero; use std::ops::Range; use std::sync::{Arc, LazyLock}; @@ 
-1181,6 +1185,74 @@ fn truncate_list_nulls( Ok(list.clone()) } +/// If `array` is a list or a map, returns a new array of the same length as its inner values +/// where each value is the 0-based index of the sublist it is contained in. Example: +/// +/// `[[1], [2, 3], [4, 5, 6]] => [0, 1, 1, 2, 2, 2]` +/// +/// Otherwise returns an error +pub fn list_values_row_number(array: &dyn Array) -> Result { + match array.data_type() { + DataType::List(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_list().offsets()))), + DataType::LargeList(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int64Type, + >(array.as_list().offsets()))), + DataType::ListView(_) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_list_view().offsets()))), + DataType::LargeListView(_) => { + Ok(Arc::new(variable_size_list_values_row_number::( + array.as_list_view().offsets(), + ))) + } + DataType::FixedSizeList(_, _) => { + let fixed_size_list = array.as_fixed_size_list(); + + Ok(Arc::new(fsl_values_row_number( + fixed_size_list.value_length(), + fixed_size_list.len(), + )?)) + } + DataType::Map(_, _) => Ok(Arc::new(variable_size_list_values_row_number::< + Int32Type, + >(array.as_map().offsets()))), + other => _exec_err!("expected list, got {other}"), + } +} + +/// [0, 2, 2, 5, 6] -> [0, 0, 2, 2, 2, 3] +fn variable_size_list_values_row_number( + offsets: &[T::Native], +) -> PrimitiveArray { + let mut rows_number = Vec::with_capacity( + offsets[offsets.len() - 1].to_usize().unwrap() - offsets[0].to_usize().unwrap(), + ); + + for (i, w) in offsets.windows(2).enumerate() { + let len = w[1].as_usize() - w[0].as_usize(); + rows_number.extend(repeat_n(T::Native::usize_as(i), len)); + } + + PrimitiveArray::new(rows_number.into(), None) +} + +/// (2, 3) -> [0, 0, 1, 1, 2, 2] +fn fsl_values_row_number(list_size: i32, array_len: usize) -> Result { + let list_size = list_size.to_usize().ok_or_else(|| { 
_exec_datafusion_err!("fsl_values_index: invalid list_size {list_size}") + })?; + + let mut rows_number = Vec::with_capacity(list_size * array_len); + + for i in 0..array_len { + rows_number.extend(repeat_n(i as i32, list_size)); + } + + Ok(PrimitiveArray::new(rows_number.into(), None)) +} + #[cfg(test)] mod tests { use std::sync::Arc; @@ -1617,4 +1689,85 @@ mod tests { assert_eq!(res.values(), expected.values()); assert_eq!(res.offsets(), expected.offsets()); } + + #[test] + fn test_list_array_values_row_number() { + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([1, 3, 0, 2,]) + ), + Int32Array::from(vec![0, 1, 1, 1, 3, 3]) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([]) + ), + Int32Array::new_null(0) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([0]) + ), + Int32Array::new_null(0) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([0, 0]) + ), + Int32Array::new_null(0) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([1]) + ), + Int32Array::from(vec![0]) + ); + + assert_eq!( + variable_size_list_values_row_number::( + &OffsetBuffer::from_lengths([2]) + ), + Int32Array::from(vec![0, 0]) + ); + } + + #[test] + fn test_fsl_values_row_number() { + assert_eq!( + fsl_values_row_number(2, 3).unwrap(), + Int32Array::from(vec![0, 0, 1, 1, 2, 2]) + ); + + assert_eq!( + fsl_values_row_number(1, 3).unwrap(), + Int32Array::from(vec![0, 1, 2]) + ); + + assert_eq!( + fsl_values_row_number(2, 1).unwrap(), + Int32Array::from(vec![0, 0]) + ); + + assert_eq!( + fsl_values_row_number(2, 0).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 2).unwrap(), + Int32Array::new_null(0), + ); + + assert_eq!( + fsl_values_row_number(0, 0).unwrap(), + Int32Array::new_null(0), + ); + + fsl_values_row_number(-1, 2).unwrap_err(); + 
fsl_values_row_number(-1, 0).unwrap_err(); + } } diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 24d0f333a6e56..4e6d9c8e7aa65 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -19,6 +19,7 @@ use crate::var_provider::{VarProvider, VarType}; use chrono::{DateTime, Utc}; use datafusion_common::HashMap; use datafusion_common::ScalarValue; +use datafusion_common::TableReference; use datafusion_common::alias::AliasGenerator; use datafusion_common::config::ConfigOptions; use datafusion_common::{Result, internal_err}; @@ -69,6 +70,10 @@ pub struct ExecutionProps { /// Shared results container for uncorrelated scalar subquery values. /// Populated at execution time by `ScalarSubqueryExec`. pub subquery_results: ScalarSubqueryResults, + /// Maps each lambda variable name to its lambda qualifier generated + /// during physical planning. Populated by the physical planner for + /// each lambda before calling `create_physical_expr`. + pub lambda_variable_qualifier: HashMap, } impl Default for ExecutionProps { @@ -87,6 +92,7 @@ impl ExecutionProps { var_providers: None, subquery_indexes: HashMap::new(), subquery_results: ScalarSubqueryResults::default(), + lambda_variable_qualifier: HashMap::new(), } } @@ -145,6 +151,22 @@ impl ExecutionProps { pub fn config_options(&self) -> Option<&Arc> { self.config_options.as_ref() } + + /// Adds a mapping for each variable to the given qualifier. Existing + /// variables with conflicting names get shadowed + pub fn with_qualified_lambda_variables( + mut self, + qualifier: &TableReference, + variables: &[String], + ) -> Self { + for var in variables { + self.lambda_variable_qualifier + .entry_ref(var) + .insert(qualifier.clone()); + } + + self + } } /// Index of a scalar subquery within a [`ScalarSubqueryResults`] container. 
diff --git a/datafusion/expr/src/higher_order_function.rs b/datafusion/expr/src/higher_order_function.rs index 0e238ffc65f1e..64f56f4d9b4d9 100644 --- a/datafusion/expr/src/higher_order_function.rs +++ b/datafusion/expr/src/higher_order_function.rs @@ -23,7 +23,7 @@ use arrow::array::{ArrayRef, RecordBatch}; use arrow::datatypes::{DataType, FieldRef, Schema}; use arrow_schema::SchemaRef; use datafusion_common::config::ConfigOptions; -use datafusion_common::{Result, ScalarValue, not_impl_err}; +use datafusion_common::{Result, ScalarValue, exec_err, not_impl_err}; use datafusion_expr_common::dyn_eq::{DynEq, DynHash}; use datafusion_expr_common::signature::Volatility; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -224,35 +224,104 @@ pub struct LambdaArgument { /// per outer sublist), avoiding the per-call `Schema::new` build that /// includes constructing the internal name -> index map. schema: SchemaRef, + /// A RecordBatch containing the captured columns inside this lambda body, if any + /// + /// For example, for `array_transform([2], v -> v + a + b)`, + /// this will be a `RecordBatch` with two columns, `a` and `b` + captures: Option, } impl LambdaArgument { - pub fn new(params: Vec, body: Arc) -> Self { - let schema = Arc::new(Schema::new(params.clone())); + pub fn new( + params: Vec, + body: Arc, + captures: Option, + ) -> Self { + let fields = match &captures { + Some(batch) => batch + .schema_ref() + .fields() + .iter() + .cloned() + .chain(params.clone()) + .collect(), + None => params.clone(), + }; + + let schema = Arc::new(Schema::new(fields)); + Self { params, body, schema, + captures, } } /// Evaluate this lambda /// `args` should evaluate to the value of each parameter /// of the correspondent lambda returned in [HigherOrderUDF::lambda_parameters]. 
+ /// + /// `adjust` should adjust the captured columns of this + /// lambda, if any, relative to its parameters pub fn evaluate( &self, args: &[&dyn Fn() -> Result], + adjust: impl FnOnce(&[ArrayRef]) -> Result>, ) -> Result { - let columns = args + let adjusted_captures = self + .captures + .as_ref() + .map(|captures| { + let adjusted_columns = adjust(captures.columns())?; + + RecordBatch::try_new(captures.schema(), adjusted_columns) + }) + .transpose()?; + + let merged = merge_captures_with_variables( + adjusted_captures.as_ref(), + Arc::clone(&self.schema), + &self.params, + args, + )?; + + self.body.evaluate(&merged) + } +} + +fn merge_captures_with_variables( + captures: Option<&RecordBatch>, + schema: SchemaRef, + params: &[FieldRef], + variables: &[&dyn Fn() -> Result], +) -> Result { + if variables.len() < params.len() { + return exec_err!( + "expected at least {} lambda arguments to merge with captures, got {}", + params.len(), + variables.len() + ); + } + + let columns = match captures { + Some(captures) => { + let mut columns = captures.columns().to_vec(); + + for arg in &variables[..params.len()] { + columns.push(arg()?); + } + + columns + } + None => variables .iter() - .take(self.params.len()) + .take(params.len()) .map(|arg| arg()) - .collect::>()?; - - let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?; + .collect::>()?, + }; - self.body.evaluate(&batch) - } + Ok(RecordBatch::try_new(schema, columns)?) 
} /// Information about arguments passed to the function diff --git a/datafusion/functions-nested/src/array_any_match.rs b/datafusion/functions-nested/src/array_any_match.rs index e0a56e0f3c117..dce06bb2f2d3b 100644 --- a/datafusion/functions-nested/src/array_any_match.rs +++ b/datafusion/functions-nested/src/array_any_match.rs @@ -20,11 +20,14 @@ use arrow::{ array::{Array, AsArray, BooleanArray, BooleanBuilder, new_null_array}, buffer::NullBuffer, + compute::take_arrays, datatypes::{ArrowNativeType, DataType, Field, FieldRef}, }; use datafusion_common::{ Result, exec_datafusion_err, exec_err, plan_err, - utils::{adjust_offsets_for_slice, list_values, take_function_args}, + utils::{ + adjust_offsets_for_slice, list_values, list_values_row_number, take_function_args, + }, }; use datafusion_expr::{ ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, @@ -196,7 +199,10 @@ impl HigherOrderUDF for ArrayAnyMatch { let values_param = || Ok(Arc::clone(&list_values)); let predicate_results = lambda - .evaluate(&[&values_param])? + .evaluate(&[&values_param], |arrays| { + let indices = list_values_row_number(&list_array)?; + Ok(take_arrays(arrays, &indices, None)?) + })? 
.into_array(list_values.len())?; let predicate_bool = predicate_results diff --git a/datafusion/functions-nested/src/array_transform.rs b/datafusion/functions-nested/src/array_transform.rs index 0dcc7a7613f1e..2542d4ab1fe48 100644 --- a/datafusion/functions-nested/src/array_transform.rs +++ b/datafusion/functions-nested/src/array_transform.rs @@ -19,11 +19,14 @@ use arrow::{ array::{Array, ArrayRef, AsArray, LargeListArray, ListArray}, + compute::take_arrays, datatypes::{DataType, Field, FieldRef}, }; use datafusion_common::{ Result, ScalarValue, exec_err, plan_err, - utils::{adjust_offsets_for_slice, list_values, take_function_args}, + utils::{ + adjust_offsets_for_slice, list_values, list_values_row_number, take_function_args, + }, }; use datafusion_expr::{ ColumnarValue, Documentation, HigherOrderFunctionArgs, HigherOrderReturnFieldArgs, @@ -203,7 +206,12 @@ impl HigherOrderUDF for ArrayTransform { // call the transforming lambda let transformed_values = lambda - .evaluate(&[&values_param])? + .evaluate(&[&values_param], |arrays| { + // if any column got captured, we need to adjust it to the values arrays, + // duplicating values of lists with multiple values and removing values of empty lists + let indices = list_values_row_number(&list_array)?; + Ok(take_arrays(arrays, &indices, None)?) + })? 
.into_array(list_values.len())?; let field = match args.return_field.data_type() { diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index bf95f8e6acf93..edea4697315cf 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -19,7 +19,7 @@ mod literal_lookup_table; use super::{Column, Literal}; use crate::PhysicalExpr; -use crate::expressions::{LambdaExpr, LambdaVariable, lit, try_cast}; +use crate::expressions::{LambdaVariable, lit, try_cast}; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{ @@ -137,9 +137,6 @@ impl CaseBody { expr.downcast_ref::() { used_column_indices.insert(lambda_variable.index()); - } else if expr.is::() { - //todo: remove this branch when lambda supports column capture - return Ok(TreeNodeRecursion::Jump); } Ok(TreeNodeRecursion::Continue) }) @@ -189,9 +186,6 @@ impl CaseBody { Arc::clone(lambda_variable.field()), )))); } - } else if e.is::() { - //todo: remove this branch when lambda supports column capture - return Ok(Transformed::new(e, false, TreeNodeRecursion::Jump)); } Ok(Transformed::no(e)) }) @@ -1038,8 +1032,15 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if projected.projection.len() < batch.num_columns() { - let projected_batch = batch.project(&projected.projection)?; + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + let projected_batch = batch.project(&projection)?; projected .body .case_when_with_expr(&projected_batch, &return_type) @@ -1061,8 +1062,15 @@ impl CaseExpr { projected: &ProjectedCaseBody, ) -> Result { let return_type = self.data_type(&batch.schema())?; - if 
projected.projection.len() < batch.num_columns() { - let projected_batch = batch.project(&projected.projection)?; + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + let projected_batch = batch.project(&projection)?; projected .body .case_when_no_expr(&projected_batch, &return_type) @@ -1180,14 +1188,23 @@ impl CaseExpr { )?)) } } - } else if projected.projection.len() < batch.num_columns() { - // The case expressions do not use all the columns of the input batch. - // Project first to reduce time spent filtering. - let projected_batch = batch.project(&projected.projection)?; - projected.body.expr_or_expr(&projected_batch, when_value) } else { - // All columns are used in the case expressions, so there is no need to project. - self.body.expr_or_expr(batch, when_value) + // projected.projection may include indexes of lambda variables not available on this batch + let projection = projected + .projection + .iter() + .copied() + .filter(|index| *index < batch.num_columns()) + .collect::>(); + if projection.len() < batch.num_columns() { + // The case expressions do not use all the columns of the input batch. + // Project first to reduce time spent filtering. + let projected_batch = batch.project(&projection)?; + projected.body.expr_or_expr(&projected_batch, when_value) + } else { + // All columns are used in the case expressions, so there is no need to project. 
+ self.body.expr_or_expr(batch, when_value) + } } } diff --git a/datafusion/physical-expr/src/expressions/lambda.rs b/datafusion/physical-expr/src/expressions/lambda.rs index 267f5f605f8ff..5e6dca1a62667 100644 --- a/datafusion/physical-expr/src/expressions/lambda.rs +++ b/datafusion/physical-expr/src/expressions/lambda.rs @@ -20,12 +20,18 @@ use std::hash::Hash; use std::sync::Arc; -use crate::physical_expr::PhysicalExpr; +use crate::{ + expressions::{Column, LambdaVariable}, + physical_expr::PhysicalExpr, +}; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; -use datafusion_common::plan_err; +use datafusion_common::{ + HashMap, plan_err, + tree_node::{Transformed, TreeNode, TreeNodeRecursion}, +}; use datafusion_common::{HashSet, Result, internal_err}; use datafusion_expr::ColumnarValue; @@ -34,6 +40,8 @@ use datafusion_expr::ColumnarValue; pub struct LambdaExpr { params: Vec, body: Arc, + projected_body: Arc, + projection: Vec, } // Manually derive PartialEq and Hash to work around https://github.com/rust-lang/rust/issues/78808 [https://github.com/apache/datafusion/issues/13196] @@ -61,7 +69,61 @@ impl LambdaExpr { } fn new(params: Vec, body: Arc) -> Self { - Self { params, body } + let mut used_column_indices = HashSet::new(); + + body.apply(|node| { + if let Some(col) = node.downcast_ref::() { + used_column_indices.insert(col.index()); + } else if let Some(var) = node.downcast_ref::() { + used_column_indices.insert(var.index()); + } + + Ok(TreeNodeRecursion::Continue) + }) + .expect("closure should be infallible"); + + let mut projection = used_column_indices.into_iter().collect::>(); + + projection.sort(); + + let column_index_map = projection + .iter() + .enumerate() + .map(|(projected, original)| (*original, projected)) + .collect::>(); + + let projected_body = Arc::clone(&body) + .transform_down(|e| { + if let Some(column) = e.downcast_ref::() { + let original = column.index(); + let projected = 
*column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(Column::new( + column.name(), + projected, + )))); + } + } else if let Some(lambda_variable) = e.downcast_ref::() { + let original = lambda_variable.index(); + let projected = *column_index_map.get(&original).unwrap(); + if projected != original { + return Ok(Transformed::yes(Arc::new(LambdaVariable::new( + projected, + Arc::clone(lambda_variable.field()), + )))); + } + } + Ok(Transformed::no(e)) + }) + .expect("closure should be infallible") + .data; + + Self { + params, + body, + projected_body, + projection, + } } /// Get the lambda's params names @@ -73,6 +135,14 @@ impl LambdaExpr { pub fn body(&self) -> &Arc { &self.body } + + pub(crate) fn projection(&self) -> &[usize] { + &self.projection + } + + pub(crate) fn projected_body(&self) -> &Arc { + &self.projected_body + } } impl std::fmt::Display for LambdaExpr { diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs index 3b2002a150a02..9a5606f97d326 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -337,9 +337,22 @@ impl PhysicalExpr for HigherOrderFunctionExpr { .map(|(name, param)| param.renamed(name.as_str())) .collect(); + // lambda.projection may include indexes of nested lambda variables not present on this batch + let projection = lambda + .projection() + .iter() + .copied() + .filter(|i| *i < batch.num_columns()) + .collect::>(); + Ok(ValueOrLambda::Lambda(LambdaArgument::new( params, - Arc::clone(lambda.body()), + Arc::clone(lambda.projected_body()), + if projection.is_empty() { + None + } else { + Some(batch.project(&projection)?) 
+ }, ))) } ArgSlot::Value => { diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs index 9cb20de252aa0..d0d0508a106a5 100644 --- a/datafusion/physical-expr/src/planner.rs +++ b/datafusion/physical-expr/src/planner.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; use std::sync::Arc; use crate::scalar_subquery::ScalarSubqueryExpr; @@ -30,8 +29,8 @@ use datafusion_common::config::ConfigOptions; use datafusion_common::datatype::FieldExt; use datafusion_common::metadata::{FieldMetadata, format_type_and_metadata}; use datafusion_common::{ - DFSchema, Result, ScalarValue, ToDFSchema, exec_err, internal_datafusion_err, - not_impl_err, plan_datafusion_err, plan_err, + DFSchema, Result, ScalarValue, TableReference, ToDFSchema, exec_err, + internal_datafusion_err, not_impl_err, plan_datafusion_err, plan_err, }; use datafusion_expr::execution_props::ExecutionProps; use datafusion_expr::expr::{ @@ -452,6 +451,18 @@ pub fn create_physical_expr( ); } + let lambda_qualifier = 1 + input_dfschema + .iter() + .filter_map(|(qualifier, _field)| { + qualifier.and_then(|tbl| { + tbl.table().strip_prefix("lambda_")?.parse::().ok() + }) + }) + .max() + .unwrap_or_default(); + + let qualifier = TableReference::bare(format!("lambda_{lambda_qualifier}")); + let physical_args = args .iter() .map(|arg| match arg { @@ -465,15 +476,26 @@ pub fn create_physical_expr( })? 
.into_iter() .zip(&lambda.params) - .map(|(field, name)| field.renamed(name.as_str())) + .map(|(field, name)| { + (Some(qualifier.clone()), field.renamed(name.as_str())) + }); + + let new_fields = input_dfschema + .iter() + .map(|(tbl, field)| (tbl.cloned(), Arc::clone(field))) + .chain(lambda_parameters) .collect(); - let lambda_schema = DFSchema::from_unqualified_fields( - lambda_parameters, - HashMap::new(), + let lambda_schema = DFSchema::new_with_metadata( + new_fields, + input_dfschema.metadata().clone(), )?; - create_physical_expr(arg, &lambda_schema, execution_props) + let execution_props = execution_props + .clone() + .with_qualified_lambda_variables(&qualifier, &lambda.params); + + create_physical_expr(arg, &lambda_schema, &execution_props) } _ => create_physical_expr(arg, input_dfschema, execution_props), }) @@ -491,17 +513,10 @@ pub fn create_physical_expr( config_options, )?)) } - Expr::Lambda(Lambda { params, body }) => { - // tracked at https://github.com/apache/datafusion/issues/21172 - if body.any_column_refs() { - return plan_err!("lambda doesn't support column capture"); - } - - expressions::lambda( - params, - create_physical_expr(body, input_dfschema, execution_props)?, - ) - } + Expr::Lambda(Lambda { params, body }) => expressions::lambda( + params, + create_physical_expr(body, input_dfschema, execution_props)?, + ), Expr::LambdaVariable(LambdaVariable { name, field, @@ -511,7 +526,21 @@ pub fn create_physical_expr( plan_datafusion_err!("unresolved LambdaVariable {name}") })?; - let index = input_dfschema.inner().index_of(name)?; + let qualifier = execution_props + .lambda_variable_qualifier + .get(name) + .ok_or_else(|| { + plan_datafusion_err!("qualifier for lambda variable {name} not found") + })?; + + let index = input_dfschema + .index_of_column_by_name(Some(qualifier), name) + .ok_or_else(|| { + plan_datafusion_err!( + "lambda variable {qualifier}.{name} not found in planning schema" + ) + })?; + let schema_field = 
input_dfschema.field(index); // LambdaVariable.field will be made optional as in Expr::Placeholder diff --git a/datafusion/sqllogictest/test_files/array/array_transform.slt b/datafusion/sqllogictest/test_files/array/array_transform.slt index 235abf5b229c4..aaf77a8706c38 100644 --- a/datafusion/sqllogictest/test_files/array/array_transform.slt +++ b/datafusion/sqllogictest/test_files/array/array_transform.slt @@ -23,11 +23,11 @@ statement ok set datafusion.sql_parser.dialect = databricks; statement ok -CREATE TABLE t (list array, number int) +CREATE TABLE t (text varchar, list array, number int) AS VALUES -([1, 50], 10), -([4, 50], 40), -([7, 50], 60); +('a', [1, 50], 10), +('b', [4, 50], 40), +('c', [7, 50], 60); statement ok CREATE TABLE with_null_list (list array) @@ -99,7 +99,7 @@ logical_plan 01)Projection: array_transform(make_array(t.number), (v) -> CAST(v AS Float64) + Float64(3)) 02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@0 AS Float64) + 3) as array_transform(make_array(t.number),(v) -> v + Float64(3))] +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> CAST(v@1 AS Float64) + 3) as array_transform(make_array(t.number),(v) -> v + Float64(3))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] #cse should not eliminate subtrees containing lambdas @@ -122,7 +122,7 @@ logical_plan 02)--Projection: make_array(t.number) AS __common_expr_1 03)----TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform(__common_expr_1@0, (v) -> CAST(v@0 AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), array_transform(__common_expr_1@0, (v) -> CAST(v@0 AS Int64) * 2 - 1) as array_transform(make_array(t.number),(v) -> v * Int64(2) - Int64(1))] +01)ProjectionExec: expr=[array_transform(__common_expr_1@0, (v) -> CAST(v@1 AS Int64) * 2) as array_transform(make_array(t.number),(v) -> v * Int64(2)), 
array_transform(__common_expr_1@0, (v) -> CAST(v@1 AS Int64) * 2 - 1) as array_transform(make_array(t.number),(v) -> v * Int64(2) - Int64(1))] 02)--ProjectionExec: expr=[make_array(number@0) as __common_expr_1] 03)----DataSourceExec: partitions=1, partition_sizes=[1] @@ -142,7 +142,7 @@ logical_plan 01)Projection: array_transform(make_array(t.number), (v) -> v IS NOT NULL OR Boolean(NULL)) AS array_transform(make_array(t.number),(v) -> v = v) 02)--TableScan: t projection=[number] physical_plan -01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> v@0 IS NOT NULL OR NULL) as array_transform(make_array(t.number),(v) -> v = v)] +01)ProjectionExec: expr=[array_transform(make_array(number@0), (v) -> v@1 IS NOT NULL OR NULL) as array_transform(make_array(t.number),(v) -> v = v)] 02)--DataSourceExec: partitions=1, partition_sizes=[1] @@ -154,7 +154,7 @@ logical_plan 01)Projection: array_transform(CAST(CAST(t.list AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a AS Int64) + Int64(1)) AS array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1)) 02)--TableScan: t projection=[list] physical_plan -01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a@0 AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1))] +01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS ListView(Int32)) AS List(Int32)), (a) -> CAST(a@1 AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("ListView(Int32)")),(a) -> a + Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query ? @@ -166,9 +166,9 @@ select array_transform(arrow_cast(t.list, 'ListView(Int32)'), a -> a+1) from t; # higher order function with inner case using lambda variables only query ? 
-select array_transform([3, 5, 0], v -> case when v > 1 then 2 when v > 4 then 6 else 8 end); +select array_transform([1, 5, 9], v -> case when v = 1 then 2 when v = 5 then 6 else 8 end); ---- -[2, 2, 8] +[2, 6, 8] #case with inner higher order function query I?? @@ -186,6 +186,42 @@ order by t.number; 40 [4, 50] [5, 51] 60 [7, 50] [8, 51] +# case with inner nested higher order function +query T?I? +select + t.text, + t.list, + t.number, + case + when t.number > 30 then array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + v + list[1] + ) + ) + ) + else array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + list[1] + ) + ) + ) + end +from t +order by t.number; +---- +a [1, 50] 10 [[[11, 11]]] +b [4, 50] 40 [[[48, 94]]] +c [7, 50] 60 [[[74, 117]]] + + # array_transform coercion rules query TT explain select array_transform(arrow_cast(t.list, 'FixedSizeList(2, Int32)'), a -> a+1) from t; @@ -194,7 +230,7 @@ logical_plan 01)Projection: array_transform(CAST(CAST(t.list AS FixedSizeList(2 x Int32)) AS List(Int32)), (a) -> CAST(a AS Int64) + Int64(1)) AS array_transform(arrow_cast(t.list,Utf8("FixedSizeList(2, Int32)")),(a) -> a + Int64(1)) 02)--TableScan: t projection=[list] physical_plan -01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS FixedSizeList(2 x Int32)) AS List(Int32)), (a) -> CAST(a@0 AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("FixedSizeList(2, Int32)")),(a) -> a + Int64(1))] +01)ProjectionExec: expr=[array_transform(CAST(CAST(list@0 AS FixedSizeList(2 x Int32)) AS List(Int32)), (a) -> CAST(a@1 AS Int64) + 1) as array_transform(arrow_cast(t.list,Utf8("FixedSizeList(2, Int32)")),(a) -> a + Int64(1))] 02)--DataSourceExec: partitions=1, partition_sizes=[1] query ? 
@@ -235,13 +271,61 @@ select array_transform(list, v -> v+1) from fully_null_list; NULL NULL -# higher order function with inner case using lambda variables and captured column(capture not supported yet) -query error DataFusion error: Error during planning: lambda doesn't support column capture -select array_transform([3, 5, 9], v -> case when v > 1 then t.number when v > 4 then 6 else 8 end) from t; +# higher order function with inner case using lambda variables and captured column +query ? +select array_transform([1, 5, 9], v -> case when v = 1 then t.number when v = 5 then 6 else 8 end) from t; +---- +[10, 6, 8] +[40, 6, 8] +[60, 6, 8] + +# higher order function with inner case using captured column only +query ? +select array_transform([3, 5, 9], v -> case when t.number = 10 then 2 when t.number = 40 then 6 else 8 end) from t; +---- +[2, 2, 2] +[6, 6, 6] +[8, 8, 8] -# higher order function with inner case using captured column only(capture not supported yet) -query error DataFusion error: Error during planning: lambda doesn't support column capture -select array_transform([3, 5, 9], v -> case when t.number > 1 then 2 when t.number > 4 then 6 else 8 end) from t; +# explain case with inner nested higher order function +query TT +explain select + t.text, + t.list, + t.number, + case + when t.number > 30 then array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + v + list[1] + ) + ) + ) + else array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + list[1] + ) + ) + ) + end +from t +order by t.number; +---- +logical_plan +01)Sort: t.number ASC NULLS LAST +02)--Projection: t.text, t.list, t.number, CASE WHEN t.number > Int32(30) THEN array_transform(make_array(make_array(t.list)), (list) -> array_transform(list, (list) -> array_transform(list, (v) -> t.number + v + array_element(list, Int64(1))))) ELSE 
array_transform(make_array(make_array(t.list)), (list) -> array_transform(list, (list) -> array_transform(list, (v) -> t.number + array_element(list, Int64(1))))) END AS CASE WHEN t.number > Int64(30) THEN array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + v + list[Int64(1)]))) ELSE array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + list[Int64(1)]))) END +03)----TableScan: t projection=[text, list, number] +physical_plan +01)SortExec: expr=[number@2 ASC NULLS LAST], preserve_partitioning=[false] +02)--ProjectionExec: expr=[text@0 as text, list@1 as list, number@2 as number, CASE WHEN number@2 > 30 THEN array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + v@5 + array_element(list@4, 1)))) ELSE array_transform(make_array(make_array(list@1)), (list) -> array_transform(list@3, (list) -> array_transform(list@4, (v) -> number@2 + array_element(list@4, 1)))) END as CASE WHEN t.number > Int64(30) THEN array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + v + list[Int64(1)]))) ELSE array_transform(make_array(make_array(t.list)),(list) -> array_transform(list,(list) -> array_transform(list,(v) -> t.number + list[Int64(1)]))) END] +03)----DataSourceExec: partitions=1, partition_sizes=[1] query error select array_transform(); From c07e1688dd32f8419b89d17413721cfb1c2d483d Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Thu, 30 Apr 2026 17:47:18 -0300 Subject: [PATCH 2/4] fix: add new argument to lambda test --- datafusion/physical-expr/src/higher_order_function.rs | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/higher_order_function.rs 
b/datafusion/physical-expr/src/higher_order_function.rs index 9a5606f97d326..f7500df31c235 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -566,9 +566,10 @@ mod tests { args: HigherOrderFunctionArgs, ) -> Result { match &args.args[0] { - ValueOrLambda::Lambda(lambda) => { - lambda.evaluate(&[&|| Ok(Arc::new(NullArray::new(args.number_rows)))]) - } + ValueOrLambda::Lambda(lambda) => lambda.evaluate( + &[&|| Ok(Arc::new(NullArray::new(args.number_rows)))], + |_| unreachable!(), + ), ValueOrLambda::Value(value) => Ok(value.clone()), } } From 17e50e0bda9adb02f380b001b34bc068d7c60e2f Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Thu, 30 Apr 2026 18:24:00 -0300 Subject: [PATCH 3/4] fix: add lambda_qualifiers to ExecutionProps debug test --- datafusion/expr/src/execution_props.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/expr/src/execution_props.rs b/datafusion/expr/src/execution_props.rs index 4e6d9c8e7aa65..649f74ed3997c 100644 --- a/datafusion/expr/src/execution_props.rs +++ b/datafusion/expr/src/execution_props.rs @@ -274,7 +274,7 @@ mod test { fn debug() { let props = ExecutionProps::new(); assert_eq!( - "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [] }", + "ExecutionProps { query_execution_start_time: None, alias_generator: AliasGenerator { next_id: 1 }, config_options: None, var_providers: None, subquery_indexes: {}, subquery_results: [], lambda_variable_qualifier: {} }", format!("{props:?}") ); } From 26bc71371532bb67be450314193c500cc354bfb6 Mon Sep 17 00:00:00 2001 From: gstvg <28798827+gstvg@users.noreply.github.com> Date: Mon, 4 May 2026 08:20:27 +0000 Subject: [PATCH 4/4] address pr review --- datafusion/expr/src/higher_order_function.rs | 22 ++++++++---- 
 .../physical-expr/src/expressions/case.rs     |  6 ++--
 .../physical-expr/src/expressions/column.rs   |  1 +
 .../src/expressions/lambda_variable.rs        |  3 +-
 .../src/higher_order_function.rs              |  2 +-
 .../test_files/array/array_transform.slt      | 34 +++++++++++++++++++
 6 files changed, 57 insertions(+), 11 deletions(-)

diff --git a/datafusion/expr/src/higher_order_function.rs b/datafusion/expr/src/higher_order_function.rs
index 64f56f4d9b4d9..d027fa084055e 100644
--- a/datafusion/expr/src/higher_order_function.rs
+++ b/datafusion/expr/src/higher_order_function.rs
@@ -262,25 +262,33 @@ impl LambdaArgument {
     /// `args` should evaluate to the value of each parameter
     /// of the correspondent lambda returned in [HigherOrderUDF::lambda_parameters].
     ///
-    /// `adjust` should adjust the captured columns of this
-    /// lambda, if any, relative to it's parameters
+    /// `spread_captures` is responsible for transforming the captured column arrays
+    /// so they align with the evaluation batch. Captures are snapshotted from the
+    /// outer batch at construction time, giving one value per outer row, but the
+    /// function may evaluate the lambda body over a batch with a different number
+    /// of rows. It is the function's responsibility to provide the appropriate `spread_captures`
+    /// closure to expand (or otherwise reshape) the captures to match.
+    /// Functions working on lists, for example `array_transform(arr, v -> v + 1)`,
+    /// flatten all list elements into a single batch, duplicating captured
+    /// values for rows with multiple elements and dropping them for empty lists.
+    /// If the lambda has no captures, `spread_captures` is never called.
pub fn evaluate( &self, args: &[&dyn Fn() -> Result], - adjust: impl FnOnce(&[ArrayRef]) -> Result>, + spread_captures: impl FnOnce(&[ArrayRef]) -> Result>, ) -> Result { - let adjusted_captures = self + let spread_captures = self .captures .as_ref() .map(|captures| { - let adjusted_columns = adjust(captures.columns())?; + let spread_columns = spread_captures(captures.columns())?; - RecordBatch::try_new(captures.schema(), adjusted_columns) + RecordBatch::try_new(captures.schema(), spread_columns) }) .transpose()?; let merged = merge_captures_with_variables( - adjusted_captures.as_ref(), + spread_captures.as_ref(), Arc::clone(&self.schema), &self.params, args, diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index bf7ea2684ba28..20d0a9e97e833 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -33,8 +33,9 @@ use datafusion_common::{ internal_datafusion_err, internal_err, }; use datafusion_expr::ColumnarValue; -use indexmap::{IndexMap, IndexSet}; +use indexmap::IndexMap; use std::borrow::Cow; +use std::collections::BTreeSet; use std::hash::Hash; use std::sync::Arc; @@ -128,7 +129,8 @@ impl CaseBody { /// Derives a [ProjectedCaseBody] from this [CaseBody]. fn project(&self) -> Result { // Determine the set of columns that are used in all the expressions of the case body. 
- let mut used_column_indices = IndexSet::::new(); + // Use an ordered set so lambda variables continue to be positioned after columns + let mut used_column_indices = BTreeSet::::new(); let mut collect_column_indices = |expr: &Arc| { expr.apply(|expr| { if let Some(column) = expr.downcast_ref::() { diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index ba8cd5e6360a1..7d4b0e7e2f396 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -124,6 +124,7 @@ impl PhysicalExpr for Column { } fn return_field(&self, input_schema: &Schema) -> Result { + self.bounds_check(input_schema)?; Ok(input_schema.field(self.index).clone().into()) } diff --git a/datafusion/physical-expr/src/expressions/lambda_variable.rs b/datafusion/physical-expr/src/expressions/lambda_variable.rs index fb9657e897550..1c130ab12e9bb 100644 --- a/datafusion/physical-expr/src/expressions/lambda_variable.rs +++ b/datafusion/physical-expr/src/expressions/lambda_variable.rs @@ -107,7 +107,8 @@ impl PhysicalExpr for LambdaVariable { if self.field.as_ref() != batch.schema_ref().field(self.index) { return exec_err!( - "Physical LambdaVariable field doesn't match batch field during evaluation {} != {}", + "Field of physical LambdaVariable with index {} doesn't match batch field during evaluation {} != {}", + self.index, self.field, batch.schema_ref().field(self.index) ); diff --git a/datafusion/physical-expr/src/higher_order_function.rs b/datafusion/physical-expr/src/higher_order_function.rs index f7500df31c235..801e69ea8fb69 100644 --- a/datafusion/physical-expr/src/higher_order_function.rs +++ b/datafusion/physical-expr/src/higher_order_function.rs @@ -568,7 +568,7 @@ mod tests { match &args.args[0] { ValueOrLambda::Lambda(lambda) => lambda.evaluate( &[&|| Ok(Arc::new(NullArray::new(args.number_rows)))], - |_| unreachable!(), + |arrays| Ok(arrays.to_vec()), ), 
ValueOrLambda::Value(value) => Ok(value.clone()), } diff --git a/datafusion/sqllogictest/test_files/array/array_transform.slt b/datafusion/sqllogictest/test_files/array/array_transform.slt index aaf77a8706c38..cba1aca4f71f3 100644 --- a/datafusion/sqllogictest/test_files/array/array_transform.slt +++ b/datafusion/sqllogictest/test_files/array/array_transform.slt @@ -221,6 +221,40 @@ a [1, 50] 10 [[[11, 11]]] b [4, 50] 40 [[[48, 94]]] c [7, 50] 60 [[[74, 117]]] +# case with inner nested higher order function where the condition is a lambda function +query T?I? +select + t.text, + t.list, + t.number, + case + when array_any_match(t.list, v -> v = 7) then array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + v + list[1] + ) + ) + ) + else array_transform( + [[t.list]], + list -> array_transform( + list, + list -> array_transform( + list, + v -> number + list[1] + ) + ) + ) + end +from t +order by t.number; +---- +a [1, 50] 10 [[[11, 11]]] +b [4, 50] 40 [[[44, 44]]] +c [7, 50] 60 [[[74, 117]]] # array_transform coercion rules query TT