From 79c78e6f732e036f2bc95bfa0da1f14c44f5116a Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 16:00:39 +0300 Subject: [PATCH 01/17] add test that column in different scope should not be used for case when --- datafusion/common/src/tree_node.rs | 416 +++++++++++++++++- .../physical-expr/src/expressions/case.rs | 330 +++++++++++++- 2 files changed, 727 insertions(+), 19 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 1e7c02e424256..25c465d1bd658 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -31,6 +31,15 @@ macro_rules! handle_transform_recursion { }}; } +/// These macros are used to determine continuation during transforming traversals. +macro_rules! handle_transform_recursion_in_scope { + ($F_DOWN:expr, $F_CHILD:expr, $F_UP:expr) => {{ + $F_DOWN? + .transform_children(|n| n.map_children_in_scope($F_CHILD))? + .transform_parent($F_UP) + }}; +} + /// API for inspecting and rewriting tree data structures. /// /// The `TreeNode` API is used to express algorithms separately from traversing @@ -435,6 +444,274 @@ pub trait TreeNode: Sized { ) -> Result>; } +/// API for inspecting and rewriting tree data structures. +/// +/// See [`TreeNode`] for more details. +/// +/// This adds the notion of scopes to [`TreeNode`] and allows you to operate within a single scope. +/// +/// Scope is left for implementors to define; for `PhysicalExpr`, a child is in scope if it has the same input schema as the current `PhysicalExpr`. +pub trait ScopedTreeNode: TreeNode { + /// Visit the tree node with a [`TreeNodeVisitor`], performing a + /// depth-first walk of the node and its children. + /// + /// [`TreeNodeVisitor::f_down()`] is called in top-down order (before + /// children are visited), [`TreeNodeVisitor::f_up()`] is called in + /// bottom-up order (after children are visited). 
+ /// + /// # Return Value + /// Specifies how the tree walk ended. See [`TreeNodeRecursion`] for details. + /// + /// # See Also: + /// * [`Self::apply_in_scope`] for inspecting in-scope nodes with a closure + /// * [`Self::rewrite_in_scope`] to rewrite owned `TreeNode`s within the same scope + /// + /// # Example + /// Consider the following tree structure: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// Here, the nodes would be visited using the following order: + /// ```text + /// TreeNodeVisitor::f_down(ParentNode) + /// TreeNodeVisitor::f_down(ChildNode1) + /// TreeNodeVisitor::f_up(ChildNode1) + /// TreeNodeVisitor::f_down(ChildNode2) + /// TreeNodeVisitor::f_up(ChildNode2) + /// TreeNodeVisitor::f_up(ParentNode) + /// ``` + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn visit_in_scope<'n, V: TreeNodeVisitor<'n, Node = Self>>( + &'n self, + visitor: &mut V, + ) -> Result { + visitor + .f_down(self)? + .visit_children(|| self.apply_children_in_scope(|c| c.visit_in_scope(visitor)))? + .visit_parent(|| visitor.f_up(self)) + } + + /// Rewrite the tree node with a [`TreeNodeRewriter`], performing a + /// depth-first walk of the node and its children. + /// + /// [`TreeNodeRewriter::f_down()`] is called in top-down order (before + /// children are visited), [`TreeNodeRewriter::f_up()`] is called in + /// bottom-up order (after children are visited). + /// + /// Note: If using the default [`TreeNodeRewriter::f_up`] or + /// [`TreeNodeRewriter::f_down`] that do nothing, consider using + /// [`Self::transform_down`] instead. + /// + /// # Return Value + /// The return value specifies how the tree walk should proceed. See + /// [`TreeNodeRecursion`] for details. If an [`Err`] is returned, the + /// recursion stops immediately. + /// + /// # See Also + /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s + /// * [Self::transform_down_up] for a top-down (pre-order) traversal. 
+ /// * [Self::transform_down] for a top-down (pre-order) traversal. + /// * [`Self::transform_up`] for a bottom-up (post-order) traversal. + /// + /// # Example + /// Consider the following tree structure: + /// ```text + /// ParentNode + /// left: ChildNode1 + /// right: ChildNode2 + /// ``` + /// + /// Here, the nodes would be visited using the following order: + /// ```text + /// TreeNodeRewriter::f_down(ParentNode) + /// TreeNodeRewriter::f_down(ChildNode1) + /// TreeNodeRewriter::f_up(ChildNode1) + /// TreeNodeRewriter::f_down(ChildNode2) + /// TreeNodeRewriter::f_up(ChildNode2) + /// TreeNodeRewriter::f_up(ParentNode) + /// ``` + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn rewrite_in_scope>( + self, + rewriter: &mut R, + ) -> Result> { + handle_transform_recursion_in_scope!( + rewriter.f_down(self), + |c| c.rewrite_in_scope(rewriter), + |n| { rewriter.f_up(n) } + ) + } + + /// Applies `f` to the node then each of its children, recursively (a + /// top-down, pre-order traversal). + /// + /// The return [`TreeNodeRecursion`] controls the recursion and can cause + /// an early return. + /// + /// # See Also + /// * [`Self::transform_down`] for the equivalent transformation API. + /// * [`Self::visit`] for both top-down and bottom up traversal. + fn apply_in_scope<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + mut f: F, + ) -> Result { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn apply_impl< + 'n, + N: ScopedTreeNode, + F: FnMut(&'n N) -> Result, + >( + node: &'n N, + f: &mut F, + ) -> Result { + f(node)?.visit_children(|| node.apply_children_in_scope(|c| apply_impl(c, f))) + } + + apply_impl(self, &mut f) + } + + /// Recursively rewrite the node's children and then the node using `f` + /// (a bottom-up post-order traversal). + /// + /// A synonym of [`Self::transform_up_in_scope`]. 
+ fn transform_in_scope Result>>( + self, + f: F, + ) -> Result> { + self.transform_up_in_scope(f) + } + + /// Recursively rewrite the tree using `f` in a top-down (pre-order) + /// fashion. + /// + /// `f` is applied to the node first, and then its children. + /// + /// # See Also + /// * [`Self::transform_down`] for the same transformation but in all children ignoring scope + /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. + /// * [Self::transform_down_up_in_scope] for a combined traversal with closures + /// * [`Self::rewrite`] for a combined traversal with a visitor + fn transform_down_in_scope Result>>( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_impl< + N: ScopedTreeNode, + F: FnMut(N) -> Result>, + >( + node: N, + f: &mut F, + ) -> Result> { + f(node)?.transform_children(|n| { + n.map_children_in_scope(|c| transform_down_impl(c, f)) + }) + } + + transform_down_impl(self, &mut f) + } + + /// Recursively rewrite the node using `f` in a bottom-up (post-order) + /// fashion. + /// + /// `f` is applied to the node's children first, and then to the node itself. + /// + /// # See Also + /// * [`Self::transform_down`] top-down (pre-order) traversal. + /// * [Self::transform_down_up] for a combined traversal with closures + /// * [`Self::rewrite`] for a combined traversal with a visitor + fn transform_up_in_scope Result>>( + self, + mut f: F, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_up_impl Result>>( + node: N, + f: &mut F, + ) -> Result> { + node.map_children_in_scope(|c| transform_up_impl(c, f))? 
+ .transform_parent(f) + } + + transform_up_impl(self, &mut f) + } + + /// Same as [`Self::transform_down_up`] but limited to the same scope + fn transform_down_up_in_scope< + FD: FnMut(Self) -> Result>, + FU: FnMut(Self) -> Result>, + >( + self, + mut f_down: FD, + mut f_up: FU, + ) -> Result> { + #[cfg_attr(feature = "recursive_protection", recursive::recursive)] + fn transform_down_up_impl< + N: ScopedTreeNode, + FD: FnMut(N) -> Result>, + FU: FnMut(N) -> Result>, + >( + node: N, + f_down: &mut FD, + f_up: &mut FU, + ) -> Result> { + handle_transform_recursion_in_scope!( + f_down(node), + |c| transform_down_up_impl(c, f_down, f_up), + f_up + ) + } + + transform_down_up_impl(self, &mut f_down, &mut f_up) + } + + /// Returns true if `f` returns true for any node in the tree. + /// + /// Stops recursion as soon as a matching node is found + fn exists_in_scope Result>(&self, mut f: F) -> Result { + let mut found = false; + self.apply_in_scope(|n| { + Ok(if f(n)? { + found = true; + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + }) + }) + .map(|_| found) + } + + /// Low-level API used to implement other APIs. + /// + /// If you want to implement the [`TreeNode`] trait for your own type, you + /// should implement this method and [`Self::map_children`]. + /// + /// Users should use one of the higher level APIs described on [`Self`]. + /// + /// Description: Apply `f` to inspect node's children that are in the same scope as this node (but not the node + /// itself), scope is defined by the node. + fn apply_children_in_scope<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + f: F, + ) -> Result; + + /// Low-level API used to implement other APIs. + /// + /// If you want to implement the [`TreeNode`] trait for your own type, you + /// should implement this method and [`Self::apply_children`]. + /// + /// Users should use one of the higher level APIs described on [`Self`]. 
+ /// + /// Description: Apply `f` to rewrite the node's children (but not the node itself). + fn map_children_in_scope Result>>( + self, + f: F, + ) -> Result>; +} + /// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively /// inspecting [`TreeNode`]s via [`TreeNode::visit`]. /// @@ -1293,6 +1570,57 @@ impl TreeNode for Arc { } } +/// Helper trait for implementing [`ScopedTreeNode`] that have children stored as +/// `Arc`s. If some trait object, such as `dyn T`, implements this trait, +/// its related `Arc` will automatically implement [`ScopedTreeNode`]. +pub trait DynScopedTreeNode: DynTreeNode { + /// Returns all children of the specified `ScopedTreeNode`. + fn arc_children_in_scope(&self) -> Vec<&Arc>; + + /// Constructs a new node with the specified children in scope. + fn with_new_arc_children_in_scope( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result>; +} + +/// Blanket implementation for any `Arc` where `T` implements [`DynScopedTreeNode`] +/// (such as [`Arc`]). +impl ScopedTreeNode for Arc { + fn apply_children_in_scope<'n, F: FnMut(&'n Self) -> Result>( + &'n self, + f: F, + ) -> Result { + self.arc_children_in_scope().into_iter().apply_until_stop(f) + } + + fn map_children_in_scope Result>>( + self, + f: F, + ) -> Result> { + let children_in_scope = self.arc_children_in_scope(); + if !children_in_scope.is_empty() { + let new_children_in_scope = children_in_scope + .into_iter() + .cloned() + .map_until_stop_and_collect(f)?; + // Propagate up `new_children_in_scope.transformed` and `new_children_in_scope.tnr` + // along with the node containing transformed children. 
+ if new_children_in_scope.transformed { + let arc_self = Arc::clone(&self); + new_children_in_scope.map_data(|new_children_in_scope| { + self.with_new_arc_children_in_scope(arc_self, new_children_in_scope) + }) + } else { + Ok(Transformed::new(self, false, new_children_in_scope.tnr)) + } + } else { + Ok(Transformed::no(self)) + } + } +} + /// Instead of implementing [`TreeNode`], it's recommended to implement a [`ConcreteTreeNode`] for /// trees that contain nodes with payloads. This approach ensures safe execution of algorithms /// involving payloads, by enforcing rules for detaching and reattaching child nodes. @@ -1338,24 +1666,53 @@ pub(crate) mod tests { use crate::Result; use crate::tree_node::{ - Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, TreeNodeRewriter, - TreeNodeVisitor, + ScopedTreeNode, Transformed, TreeNode, TreeNodeContainer, TreeNodeRecursion, + TreeNodeRewriter, TreeNodeVisitor, }; #[derive(Debug, Eq, Hash, PartialEq, Clone)] pub struct TestTreeNode { pub(crate) children: Vec>, + pub(crate) children_in_scope: Vec>, pub(crate) data: T, } impl TestTreeNode { pub(crate) fn new(children: Vec>, data: T) -> Self { - Self { children, data } + Self { + children, + children_in_scope: vec![], + data, + } + } + + pub(crate) fn new_scoped( + children_in_scope: Vec>, + data: T, + ) -> Self { + Self { + children: vec![], + children_in_scope, + data, + } + } + + pub(crate) fn new_mixed( + children: Vec>, + children_in_scope: Vec>, + data: T, + ) -> Self { + Self { + children, + children_in_scope, + data, + } } pub(crate) fn new_leaf(data: T) -> Self { Self { children: vec![], + children_in_scope: vec![], data, } } @@ -1387,6 +1744,31 @@ pub(crate) mod tests { } } + impl ScopedTreeNode for TestTreeNode { + fn apply_children_in_scope< + 'n, + F: FnMut(&'n Self) -> Result, + >( + &'n self, + f: F, + ) -> Result { + // TODO - should call apply elements + self.children_in_scope.apply_elements(f) + } + + fn map_children_in_scope Result>>( + self, + 
f: F, + ) -> Result> { + Ok(self.children_in_scope.map_elements(f)?.update_data( + |new_children_in_scope| Self { + children_in_scope: new_children_in_scope, + ..self + }, + )) + } + } + impl<'a, T: 'a> TreeNodeContainer<'a, Self> for TestTreeNode { fn apply_elements Result>( &'a self, @@ -1403,28 +1785,28 @@ pub(crate) mod tests { } } - // J - // | - // I - // | - // F - // / \ - // E G - // | | - // C H + // J + // | + // I + // | + // F (mixed) + // / \ + // E (scoped) G + // | | + // C (mixed) H // / \ - // B D + // B D (scoped) // | // A fn test_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); let node_b = TestTreeNode::new_leaf("b".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_d = TestTreeNode::new_scoped(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new_mixed(vec![node_b], vec![node_d], "c".to_string()); + let node_e = TestTreeNode::new_scoped(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); - let node_f = TestTreeNode::new(vec![node_e, node_g], "f".to_string()); + let node_f = TestTreeNode::new_mixed(vec![node_g], vec![node_e], "f".to_string()); let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); TestTreeNode::new(vec![node_i], "j".to_string()) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 4ac40df2201e5..dd316d1cdc5ec 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1425,9 +1425,9 @@ mod tests { use crate::expressions; use crate::expressions::{BinaryExpr, binary, cast, col, is_not_null}; - use arrow::buffer::Buffer; + use arrow::buffer::{BooleanBuffer, Buffer}; use 
arrow::datatypes::DataType::Float64; - use arrow::datatypes::Field; + use arrow::datatypes::{ArrowNativeType, Field, Fields}; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; @@ -1731,6 +1731,332 @@ mod tests { Ok(()) } + #[test] + fn case_with_expression_that_have_different_scope() -> Result<()> { + /// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function + /// + /// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution + #[derive(Debug, Hash, Clone)] + pub struct AllListElementMatchMiniLambda { + child: Arc, + predicate_on_list_elements: Arc, + } + impl PartialEq for AllListElementMatchMiniLambda { + fn eq(&self, other: &Self) -> bool { + self.child.as_ref() == other.child.as_ref() && self.predicate_on_list_elements.as_ref() == other.predicate_on_list_elements.as_ref() + } + } + + impl Eq for AllListElementMatchMiniLambda {} + + impl AllListElementMatchMiniLambda { + pub fn new( + child: Arc, + predicate_on_list_element: Arc, + ) -> Arc { + Arc::new(Self { + child, + predicate_on_list_elements: predicate_on_list_element + }) + } + } + + impl std::fmt::Display for AllListElementMatchMiniLambda { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "all_match({:?}, {:?})", self.child, self.predicate_on_list_elements) + } + } + + impl PhysicalExpr for AllListElementMatchMiniLambda { + fn as_any(&self) -> &dyn Any { + self + } + + fn return_field(&self, input_schema: &Schema) -> Result { + let is_child_nullable = self.child.nullable(input_schema)?; + Ok(Arc::new(Field::new("match", DataType::Boolean, is_child_nullable))) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let child = self.child.evaluate(batch)?; + let DataType::List(child_list_field) = 
self.child.data_type(batch.schema_ref())? else { + unreachable!() + }; + + let child = child.to_array_of_size(batch.num_rows())?; + let list = child.as_list::(); + + let lambda_schema = Arc::new(Schema::new(Fields::from(vec![ + Field::new("index", DataType::UInt32, false), + child_list_field.as_ref().clone() + ]))); + + assert_eq!(list.value_offsets()[0].as_usize(), 0, "this is mock implementation, it does not support sliced list"); + assert_eq!(list.value_offsets().last().unwrap().as_usize(), list.values().len(), "this is mock implementation, it does not support sliced list"); + + let list_values = list.values(); + + let new_batch = RecordBatch::try_new(Arc::clone(&lambda_schema), vec![ + Arc::new(list.offsets().lengths().flat_map(|list_len| 0..list_len as u32).collect::()), + Arc::clone(list_values), + ])?; + + let any_match = self.predicate_on_list_elements.evaluate(&new_batch)?; + let any_match = any_match.to_array_of_size(list_values.len())?; + let any_match = any_match.as_boolean(); + + let all_match_per_list = list.offsets().windows(2).map(|start_and_end| { + let length = start_and_end[1] - start_and_end[0]; + let list_matches = any_match.slice(start_and_end[0] as usize, length as usize); + + list_matches.true_count() == list_matches.len() as usize + }).collect::(); + + let result = Arc::new(BooleanArray::new(all_match_per_list, list.nulls().cloned())); + + Ok(ColumnarValue::Array(result)) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child, &self.predicate_on_list_elements] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + child: iter.next().unwrap(), + predicate_on_list_elements: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + let mut iter 
= children_in_scope.into_iter(); + Ok(Arc::new(Self { + child: iter.next().unwrap(), + // TODO - but what if child has changed to not be list or the data type has changed?? + predicate_on_list_elements: Arc::clone(&self.predicate_on_list_elements), + })) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } + } + + let input_schema = Arc::new(Schema::new(vec![ + Arc::new(Field::new("list", DataType::new_list(DataType::UInt32, true), true)) + ])); + + let input_list = ListArray::from_iter_primitive::(vec![ + // all even place numbers are even + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Not all even place are even but all odd place are odd + Some(vec![Some(0), Some(1), Some(1)]), + + // Not odd and not even in corresponding places + Some(vec![Some(1), Some(2)]), + ]); + + let batch = RecordBatch::try_new( + input_schema, + vec![ + Arc::new(input_list) + ] + ).unwrap(); + let schema = batch.schema(); + + // case + // WHEN have_all(list, item -> idx % 2 == 0 && item % 2 == 0) THEN "all even values" + // WHEN have_all(list, item -> idx % 2 == 1 && item % 2 == 1) THEN "all odd values are odd" + fn create_when_expr(is_even: bool) -> Arc { + let idx_col: Arc = Arc::new(Column::new("idx", 0)); + let item_col: Arc = Arc::new(Column::new("item", 1)); + AllListElementMatchMiniLambda::new( + Arc::new(Column::new("list", 0)), + create_both_odd_or_even(&idx_col, &item_col, is_even) + ) + } + + fn create_both_odd_or_even(idx_column: &Arc, list_item_column: &Arc, is_even: bool) -> Arc { + let equal_value = if is_even { 0 } else {1}; + let idx_equal = module_2_equal_value(idx_column, equal_value); + let item_equal = module_2_equal_value(list_item_column, equal_value); + + case( + None, + vec![(idx_equal, item_equal)], + // if idx not equal than true + Some(lit(true)), + ).unwrap() + } + + fn module_2_equal_value(left: &Arc, equal_value: u32) -> Arc { + let modulo2 = BinaryExpr::new(Arc::clone(&left), Operator::Modulo, 
lit(2u32)); + let equal_value = BinaryExpr::new(Arc::new(modulo2), Operator::Eq, lit(equal_value)); + + Arc::new(equal_value) + } + + let expr = generate_case_when_with_type_coercion( + None, + vec![(create_when_expr(true), lit("both even")), (create_when_expr(false), lit("both odd"))], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? + .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + + let expected = &StringArray::from(vec![Some("both even"), None, Some("both odd"), None]); + + assert_eq!(expected, result.as_string::()); + + Ok(()) + } + + #[test] + fn case_without_expr_and_with_custom_column_impl() -> Result<()> { + /// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function + /// + /// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution + #[derive(Debug, Hash, PartialEq, Eq, Clone)] + pub struct CustomColumn { + /// The name of the column (used for debugging and display purposes) + name: String, + /// The index of the column in its schema + index: usize, + data_type: DataType, + nullable: bool, + } + + impl CustomColumn { + pub fn new_with_schema( + name: &str, + schema: &Schema, + ) -> Result> { + let index = schema.index_of(name)?; + let field = schema.field(index); + Ok(Arc::new(CustomColumn { + name: name.to_string(), + index, + data_type: field.data_type().clone(), + nullable: field.is_nullable(), + })) + } + } + + impl std::fmt::Display for CustomColumn { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!(f, "{}@{}", self.name, self.index) + } + } + + impl PhysicalExpr for CustomColumn { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + Ok(self.nullable) + } + + fn evaluate(&self, batch: 
&RecordBatch) -> Result { + self.bounds_check(batch.schema().as_ref())?; + Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) + } + + fn children(&self) -> Vec<&Arc> { + vec![] + } + + fn with_new_children( + self: Arc, + _children: Vec>, + ) -> Result> { + Ok(self) + } + + fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result { + unimplemented!() + } + } + + impl CustomColumn { + fn bounds_check(&self, input_schema: &Schema) -> Result<()> { + if self.index < input_schema.fields.len() { + Ok(()) + } else { + internal_err!( + "PhysicalExpr BoundLambdaColumn references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}", + self.name, + self.index, + input_schema.fields.len(), + input_schema + .fields() + .iter() + .map(|f| f.name()) + .collect::>() + ) + } + } + } + + let batch = case_test_batch()?; + let schema = batch.schema(); + + // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END + let when1 = binary( + CustomColumn::new_with_schema("a", &schema)?, + Operator::Eq, + lit("foo"), + &batch.schema(), + )?; + let then1 = lit(123i32); + let when2 = binary( + CustomColumn::new_with_schema("a", &schema)?, + Operator::Eq, + lit("bar"), + &batch.schema(), + )?; + let then2 = lit(456i32); + + let expr = generate_case_when_with_type_coercion( + None, + vec![(when1, then1), (when2, then2)], + None, + schema.as_ref(), + )?; + let result = expr + .evaluate(&batch)? 
+ .into_array(batch.num_rows()) + .expect("Failed to convert to array"); + let result = as_int32_array(&result)?; + + let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); + + assert_eq!(expected, result); + + Ok(()) + } + #[test] fn case_with_expr_when_null() -> Result<()> { let batch = case_test_batch()?; From 18f5cd67fb673239c560c056ee4f807f39f2ead1 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 16:01:07 +0300 Subject: [PATCH 02/17] add tree scope --- .../src/schema_rewriter.rs | 3 +- .../physical-expr-common/src/physical_expr.rs | 57 ++++++- .../physical-expr-common/src/tree_node.rs | 18 ++- .../physical-expr/src/expressions/case.rs | 141 +----------------- datafusion/physical-expr/src/physical_expr.rs | 4 +- datafusion/physical-expr/src/projection.rs | 8 +- datafusion/physical-expr/src/utils/mod.rs | 3 +- .../physical-plan/src/filter_pushdown.rs | 4 +- 8 files changed, 88 insertions(+), 150 deletions(-) diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 3a255ae05f76f..1ec889546065d 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -26,6 +26,7 @@ use std::sync::Arc; use arrow::array::RecordBatch; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; +use datafusion_common::tree_node::ScopedTreeNode; use datafusion_common::{ DataFusionError, Result, ScalarValue, exec_err, metadata::FieldMetadata, @@ -69,7 +70,7 @@ where K: Borrow + Eq + Hash, V: Borrow, { - expr.transform_down(|expr| { + expr.transform_down_in_scope(|expr| { if let Some(column) = expr.as_any().downcast_ref::() && let Some(replacement_value) = replacements.get(column.name()) { diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 7107b0a9004d3..b4e3175979844 100644 --- 
a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -165,12 +165,37 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { /// Get a list of child PhysicalExpr that provide the input for this expr. fn children(&self) -> Vec<&Arc>; + /// Get the list of child PhysicalExpr that provide the input for this expr and are in the same scope as this expression. + /// + /// Because most expressions have all of their children in the same scope, the default implementation calls [`Self::children`]. + /// + /// To decide whether a specific child is in the same scope, answer this simple question: + /// if that child were a `Column`, could that column be evaluated with the same input schema as this expression? + /// Expressions like `plus`, `sum`, etc. have all children in scope. + /// Lambda expressions like `array_filter(list, value -> value + 1)` have the `list` in the same scope and the lambda function in a different scope. + fn children_in_scope(&self) -> Vec<&Arc> { + self.children() + } + /// Returns a new PhysicalExpr where all children were replaced by new exprs. fn with_new_children( self: Arc, children: Vec>, ) -> Result>; + /// Returns a new PhysicalExpr where all in-scope children were replaced by new exprs. + /// + /// See [`Self::children_in_scope`] for the definition of which children are considered in scope. + /// + /// Because most expressions have all of their children in the same scope, the default implementation calls [`Self::with_new_children`]. + /// + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + self.with_new_children(children_in_scope) + } + /// Computes the output interval for the expression, given the input /// intervals. 
/// @@ -476,16 +501,40 @@ pub fn with_new_children_if_necessary( ); if children.is_empty() - || children - .iter() - .zip(old_children.iter()) - .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + || children + .iter() + .zip(old_children.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) { Ok(expr.with_new_children(children)?) } else { Ok(expr) } } +/// Returns a copy of this expr if we change any child according to the pointer comparison. +/// The size of `children_in_scope` must be equal to the size of [`PhysicalExpr::children_in_scope()`]. +pub fn with_new_children_in_scope_if_necessary( + expr: Arc, + children_in_scope: Vec>, +) -> Result> { + let old_children_in_scope = expr.children_in_scope(); + assert_eq_or_internal_err!( + children_in_scope.len(), + old_children_in_scope.len(), + "PhysicalExpr: Wrong number of children in scope" + ); + + if children_in_scope.is_empty() + || children_in_scope + .iter() + .zip(old_children_in_scope.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + { + Ok(expr.with_new_children_in_scope(children_in_scope)?) 
+ } else { + Ok(expr) + } +} /// Returns [`Display`] able a list of [`PhysicalExpr`] /// diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs index 6c7d04a22535f..1f9e855762be7 100644 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ b/datafusion/physical-expr-common/src/tree_node.rs @@ -20,10 +20,10 @@ use std::fmt::{self, Display, Formatter}; use std::sync::Arc; -use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary}; +use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary, with_new_children_in_scope_if_necessary}; use datafusion_common::Result; -use datafusion_common::tree_node::{ConcreteTreeNode, DynTreeNode}; +use datafusion_common::tree_node::{ConcreteTreeNode, DynScopedTreeNode, DynTreeNode}; impl DynTreeNode for dyn PhysicalExpr { fn arc_children(&self) -> Vec<&Arc> { @@ -39,6 +39,20 @@ impl DynTreeNode for dyn PhysicalExpr { } } +impl DynScopedTreeNode for dyn PhysicalExpr { + fn arc_children_in_scope(&self) -> Vec<&Arc> { + self.children_in_scope() + } + + fn with_new_arc_children_in_scope( + &self, + arc_self: Arc, + new_children: Vec>, + ) -> Result> { + with_new_children_in_scope_if_necessary(arc_self, new_children) + } +} + /// A node object encapsulating a [`PhysicalExpr`] node with a payload. Since there are /// two ways to access child plans—directly from the plan and through child nodes—it's /// recommended to perform mutable operations via [`Self::update_expr_from_children`]. 
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index dd316d1cdc5ec..2c4504c7c13a0 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -40,7 +40,9 @@ use std::{any::Any, sync::Arc}; use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; -use datafusion_common::tree_node::{Transformed, TreeNode, TreeNodeRecursion}; +use datafusion_common::tree_node::{ + ScopedTreeNode, Transformed, TreeNode, TreeNodeRecursion, +}; use datafusion_physical_expr_common::datum::compare_with_eq; use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; @@ -130,7 +132,7 @@ impl CaseBody { // Determine the set of columns that are used in all the expressions of the case body. let mut used_column_indices = IndexSet::::new(); let mut collect_column_indices = |expr: &Arc| { - expr.apply(|expr| { + expr.apply_in_scope(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { used_column_indices.insert(column.index()); } @@ -161,7 +163,7 @@ impl CaseBody { // using the column index mapping. 
let project = |expr: &Arc| -> Result> { Arc::clone(expr) - .transform_down(|e| { + .transform_down_in_scope(|e| { if let Some(column) = e.as_any().downcast_ref::() { let original = column.index(); let projected = *column_index_map.get(&original).unwrap(); @@ -1397,7 +1399,7 @@ fn replace_with_null( input_schema: &Schema, ) -> Result, DataFusionError> { let with_null = Arc::clone(expr) - .transform_down(|e| { + .transform_down_in_scope(|e| { if e.as_ref().dyn_eq(expr_to_replace) { let data_type = e.data_type(input_schema)?; let null_literal = lit(ScalarValue::try_new_null(&data_type)?); @@ -1928,135 +1930,6 @@ mod tests { Ok(()) } - #[test] - fn case_without_expr_and_with_custom_column_impl() -> Result<()> { - /// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function - /// - /// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution - #[derive(Debug, Hash, PartialEq, Eq, Clone)] - pub struct CustomColumn { - /// The name of the column (used for debugging and display purposes) - name: String, - /// The index of the column in its schema - index: usize, - data_type: DataType, - nullable: bool, - } - - impl CustomColumn { - pub fn new_with_schema( - name: &str, - schema: &Schema, - ) -> Result> { - let index = schema.index_of(name)?; - let field = schema.field(index); - Ok(Arc::new(CustomColumn { - name: name.to_string(), - index, - data_type: field.data_type().clone(), - nullable: field.is_nullable(), - })) - } - } - - impl std::fmt::Display for CustomColumn { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!(f, "{}@{}", self.name, self.index) - } - } - - impl PhysicalExpr for CustomColumn { - fn as_any(&self) -> &dyn Any { - self - } - - fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.data_type.clone()) - } - - fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.nullable) 
- } - - fn evaluate(&self, batch: &RecordBatch) -> Result { - self.bounds_check(batch.schema().as_ref())?; - Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) - } - - fn children(&self) -> Vec<&Arc> { - vec![] - } - - fn with_new_children( - self: Arc, - _children: Vec>, - ) -> Result> { - Ok(self) - } - - fn fmt_sql(&self, _: &mut Formatter<'_>) -> std::fmt::Result { - unimplemented!() - } - } - - impl CustomColumn { - fn bounds_check(&self, input_schema: &Schema) -> Result<()> { - if self.index < input_schema.fields.len() { - Ok(()) - } else { - internal_err!( - "PhysicalExpr BoundLambdaColumn references column '{}' at index {} (zero-based) but input schema only has {} columns: {:?}", - self.name, - self.index, - input_schema.fields.len(), - input_schema - .fields() - .iter() - .map(|f| f.name()) - .collect::>() - ) - } - } - } - - let batch = case_test_batch()?; - let schema = batch.schema(); - - // CASE WHEN a = 'foo' THEN 123 WHEN a = 'bar' THEN 456 END - let when1 = binary( - CustomColumn::new_with_schema("a", &schema)?, - Operator::Eq, - lit("foo"), - &batch.schema(), - )?; - let then1 = lit(123i32); - let when2 = binary( - CustomColumn::new_with_schema("a", &schema)?, - Operator::Eq, - lit("bar"), - &batch.schema(), - )?; - let then2 = lit(456i32); - - let expr = generate_case_when_with_type_coercion( - None, - vec![(when1, then1), (when2, then2)], - None, - schema.as_ref(), - )?; - let result = expr - .evaluate(&batch)? 
- .into_array(batch.num_rows()) - .expect("Failed to convert to array"); - let result = as_int32_array(&result)?; - - let expected = &Int32Array::from(vec![Some(123), None, None, Some(456)]); - - assert_eq!(expected, result); - - Ok(()) - } - #[test] fn case_with_expr_when_null() -> Result<()> { let batch = case_test_batch()?; @@ -2552,7 +2425,7 @@ mod tests { .unwrap(); let expr3 = Arc::clone(&expr) - .transform_down(|e| { + .transform_down_in_scope(|e| { let transformed = match e.as_any().downcast_ref::() { Some(lit_value) => match lit_value.value() { ScalarValue::Utf8(Some(str_value)) => { diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index a03b58e0b594d..37bcddb0a92c3 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -22,7 +22,7 @@ use crate::{LexOrdering, PhysicalSortExpr, create_physical_expr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult, TreeNode}; use datafusion_common::{DFSchema, HashMap}; use datafusion_common::{Result, plan_err}; use datafusion_expr::execution_props::ExecutionProps; @@ -38,7 +38,7 @@ pub fn add_offset_to_expr( expr: Arc, offset: isize, ) -> Result> { - expr.transform_down(|e| match e.as_any().downcast_ref::() { + expr.transform_down_in_scope(|e| match e.as_any().downcast_ref::() { Some(col) => { let Some(idx) = col.index().checked_add_signed(offset) else { return plan_err!("Column index overflow"); diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index e133e5a849cd8..53b2ae279e5bb 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -27,7 +27,7 @@ use crate::utils::collect_columns; use arrow::array::{RecordBatch, 
RecordBatchOptions}; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult, TreeNode}; use datafusion_common::{ Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err, plan_err, @@ -920,7 +920,7 @@ pub fn update_expr( let mut state = RewriteState::Unchanged; let new_expr = Arc::clone(expr) - .transform_up(|expr| { + .transform_up_in_scope(|expr| { if state == RewriteState::RewrittenInvalid { return Ok(Transformed::no(expr)); } @@ -1043,7 +1043,7 @@ impl ProjectionMapping { let mut map = IndexMap::<_, ProjectionTargets>::new(); for (expr_idx, (expr, name)) in expr.into_iter().enumerate() { let target_expr = Arc::new(Column::new(&name, expr_idx)) as _; - let source_expr = expr.transform_down(|e| match e.as_any().downcast_ref::() { + let source_expr = expr.transform_down_in_scope(|e| match e.as_any().downcast_ref::() { Some(col) => { // Sometimes, an expression and its name in the input_schema // doesn't match. This can cause problems, so we make sure @@ -1162,7 +1162,7 @@ pub fn project_ordering( ) -> Option { let mut projected_exprs = vec![]; for PhysicalSortExpr { expr, options } in ordering.iter() { - let transformed = Arc::clone(expr).transform_up(|expr| { + let transformed = Arc::clone(expr).transform_up_in_scope(|expr| { let Some(col) = expr.as_any().downcast_ref::() else { return Ok(Transformed::no(expr)); }; diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 6a8b49ac52523..c97d5ef723098 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -16,6 +16,7 @@ // under the License. 
mod guarantee; +use datafusion_common::tree_node::ScopedTreeNode; pub use guarantee::{Guarantee, LiteralGuarantee}; use std::borrow::Borrow; @@ -312,7 +313,7 @@ pub fn reassign_expr_columns( expr: Arc, schema: &Schema, ) -> Result> { - expr.transform_down(|expr| { + expr.transform_down_in_scope(|expr| { if let Some(column) = expr.as_any().downcast_ref::() { let index = schema.index_of(column.name())?; diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 7e82b9e8239e0..8910d26913776 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -40,7 +40,7 @@ use std::sync::Arc; use arrow_schema::SchemaRef; use datafusion_common::{ Result, - tree_node::{Transformed, TreeNode}, + tree_node::{ScopedTreeNode, Transformed, TreeNode}, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -364,7 +364,7 @@ impl FilterRemapper { filter: &Arc, ) -> Result>> { let mut all_valid = true; - let transformed = Arc::clone(filter).transform_down(|expr| { + let transformed = Arc::clone(filter).transform_down_in_scope(|expr| { if let Some(col) = expr.as_any().downcast_ref::() { if self.allowed_indices.contains(&col.index()) && let Ok(new_index) = self.child_schema.index_of(col.name()) From 94fb0c3498bec7ad65215491e4ee041208302fc8 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 16:09:45 +0300 Subject: [PATCH 03/17] make the test actually fail on main --- datafusion/physical-expr/src/expressions/case.rs | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 2c4504c7c13a0..a3409202d51c2 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1858,7 
+1858,10 @@ mod tests { } let input_schema = Arc::new(Schema::new(vec![ - Arc::new(Field::new("list", DataType::new_list(DataType::UInt32, true), true)) + Arc::new(Field::new("col_1", DataType::Utf8, true)), + Arc::new(Field::new("col_2", DataType::Utf8, true)), + Arc::new(Field::new("col_3", DataType::Utf8, true)), + Arc::new(Field::new("list", DataType::new_list(DataType::UInt32, true), true)), ])); let input_list = ListArray::from_iter_primitive::(vec![ @@ -1875,7 +1878,10 @@ mod tests { let batch = RecordBatch::try_new( input_schema, vec![ - Arc::new(input_list) + new_null_array(&DataType::Utf8, input_list.len()), + new_null_array(&DataType::Utf8, input_list.len()), + new_null_array(&DataType::Utf8, input_list.len()), + Arc::new(input_list), ] ).unwrap(); let schema = batch.schema(); @@ -1887,7 +1893,7 @@ mod tests { let idx_col: Arc = Arc::new(Column::new("idx", 0)); let item_col: Arc = Arc::new(Column::new("item", 1)); AllListElementMatchMiniLambda::new( - Arc::new(Column::new("list", 0)), + Arc::new(Column::new("list", 3)), create_both_odd_or_even(&idx_col, &item_col, is_even) ) } From c8d264862568618fa2be4cfa9888eb42841f383d Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 17:27:27 +0300 Subject: [PATCH 04/17] add tests --- .../src/schema_rewriter.rs | 127 ++++++++++++ datafusion/physical-expr/src/projection.rs | 160 +++++++++++++- datafusion/physical-expr/src/utils/mod.rs | 127 ++++++++++++ .../physical-plan/src/filter_pushdown.rs | 196 ++++++++++++++++++ 4 files changed, 609 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 1ec889546065d..9eee7a5599269 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -1734,4 +1734,131 @@ mod tests { assert_eq!(cast_expr.input_field().data_type(), &DataType::Int32); 
assert_eq!(cast_expr.target_field().data_type(), &DataType::Int64); } + + /// A mock expression with an in-scope child and an out-of-scope child. + /// Used to verify that scoped traversal does not modify out-of-scope children. + #[derive(Debug, Hash, Clone)] + struct ScopedExprMock { + in_scope_child: Arc, + out_of_scope_child: Arc, + } + + impl PartialEq for ScopedExprMock { + fn eq(&self, other: &Self) -> bool { + self.in_scope_child.as_ref() == other.in_scope_child.as_ref() + && self.out_of_scope_child.as_ref() + == other.out_of_scope_child.as_ref() + } + } + + impl Eq for ScopedExprMock {} + + impl std::fmt::Display for ScopedExprMock { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "scoped_mock({}, {})", + self.in_scope_child, self.out_of_scope_child + ) + } + } + + impl PhysicalExpr for ScopedExprMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn return_field( + &self, + input_schema: &Schema, + ) -> Result> { + self.in_scope_child.return_field(input_schema) + } + + fn evaluate( + &self, + _batch: &RecordBatch, + ) -> Result { + unimplemented!("ScopedExprMock does not support evaluation") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.in_scope_child, &self.out_of_scope_child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + in_scope_child: iter.next().unwrap(), + out_of_scope_child: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.in_scope_child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + Ok(Arc::new(Self { + in_scope_child: children_in_scope.into_iter().next().unwrap(), + out_of_scope_child: Arc::clone(&self.out_of_scope_child), + })) + } + + fn fmt_sql( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + 
std::fmt::Display::fmt(&self, f) + } + } + + #[test] + fn test_replace_columns_with_literals_does_not_modify_out_of_scope_children() { + // The in-scope child references column "a" which should be replaced + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + // The out-of-scope child also references a column "a" but should NOT be replaced + let out_of_scope_child: Arc = + Arc::new(Column::new("a", 0)); + + let expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let mut replacements = HashMap::new(); + replacements.insert("a", ScalarValue::Int32(Some(42))); + + let result = replace_columns_with_literals(expr, &replacements).unwrap(); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a" should be replaced with literal 42 + let in_scope_lit = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be replaced with Literal"); + assert_eq!(in_scope_lit.value(), &ScalarValue::Int32(Some(42))); + + // The out-of-scope child "a@0" should be UNCHANGED (still a Column) + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should still be Column"); + assert_eq!(out_of_scope_col.name(), "a"); + assert_eq!(out_of_scope_col.index(), 0); + } } diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 53b2ae279e5bb..58b4c0b7cb8f5 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -1200,7 +1200,7 @@ pub(crate) mod tests { use super::*; use crate::equivalence::{EquivalenceProperties, convert_to_orderings}; use crate::expressions::{BinaryExpr, col}; - use crate::utils::tests::TestScalarUDF; + use crate::utils::tests::{ScopedExprMock, TestScalarUDF}; use crate::{PhysicalExprRef, ScalarFunctionExpr}; use arrow::compute::SortOptions; @@ -3038,4 +3038,162 @@ pub(crate) mod 
tests { Ok(()) } + + #[test] + fn test_update_expr_does_not_modify_out_of_scope_children() -> Result<()> { + // Outer schema: [a, b, c] + // Expression: ScopedExprMock(in_scope=a_new@1, out_of_scope=x@0) + // Projection: [c@2 as c_new, a@0 as a_new, b@1 as b_new] + // After unproject: in_scope should become a@0, out_of_scope should stay x@0 + let in_scope_child: Arc = Arc::new(Column::new("a_new", 1)); + let out_of_scope_child: Arc = + Arc::new(Column::new("x", 0)); + + let expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let projected_exprs = vec![ + ProjectionExpr { + expr: Arc::new(Column::new("c", 2)), + alias: "c_new".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("a", 0)), + alias: "a_new".to_string(), + }, + ProjectionExpr { + expr: Arc::new(Column::new("b", 1)), + alias: "b_new".to_string(), + }, + ]; + + let result = + update_expr(&expr, &projected_exprs, true)?.expect("Should be valid"); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a_new@1" should be unprojected to "a@0" + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // The out-of-scope child "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + + Ok(()) + } + + #[test] + fn test_project_ordering_does_not_modify_out_of_scope_children() { + // Schema: [a: Int32, b: Int32] + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + let out_of_scope_child: Arc = + Arc::new(Column::new("x", 0));
+ + let scoped_expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let ordering = + LexOrdering::new(vec![PhysicalSortExpr::new(scoped_expr, SortOptions::new(false, false))]) + .unwrap(); + + let result = project_ordering(&ordering, &schema).expect("Should project"); + + let projected_expr = &result.first().expr; + let mock = projected_expr + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope column "a" should be reindexed (stays at 0 in this case) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // The out-of-scope child "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + } + + #[test] + fn test_projection_mapping_does_not_modify_out_of_scope_children() -> Result<()> { + // Input schema: [a: Int32, b: Int32] + let input_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + let out_of_scope_child: Arc = + Arc::new(Column::new("x", 0)); + + let scoped_expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + // Project: [ScopedExprMock as "result"] + let projection_exprs = vec![(scoped_expr, "result".to_string())]; + + let mapping = ProjectionMapping::try_new(projection_exprs, &input_schema)?; + + // The source expression in the mapping should have its in-scope column + // validated but the out-of-scope column left untouched + let (source_expr, _targets) = mapping.iter().next().unwrap(); + let mock = source_expr + .as_any() + .downcast_ref::() + .expect("Should still be 
ScopedExprMock"); + + // In-scope child: "a@0" should still be "a@0" (name matches schema) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // Out-of-scope child: "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + + Ok(()) + } } diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index c97d5ef723098..6ed986722fd86 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -343,8 +343,90 @@ pub(crate) mod tests { ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; + use arrow::array::RecordBatch; use petgraph::visit::Bfs; + /// A mock expression that has two children but only one is "in scope". + /// This simulates a lambda-like expression where the `in_scope_child` + /// references columns in the outer schema and the `out_of_scope_child` + /// references columns in a different (lambda) schema. 
+ #[derive(Debug, Hash, Clone)] + pub(crate) struct ScopedExprMock { + pub in_scope_child: Arc, + pub out_of_scope_child: Arc, + } + + impl PartialEq for ScopedExprMock { + fn eq(&self, other: &Self) -> bool { + self.in_scope_child.as_ref() == other.in_scope_child.as_ref() + && self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref() + } + } + + impl Eq for ScopedExprMock {} + + impl Display for ScopedExprMock { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + write!( + f, + "scoped_mock({}, {})", + self.in_scope_child, self.out_of_scope_child + ) + } + } + + impl PhysicalExpr for ScopedExprMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn return_field( + &self, + input_schema: &Schema, + ) -> Result> { + self.in_scope_child.return_field(input_schema) + } + + fn evaluate(&self, _batch: &RecordBatch) -> Result { + unimplemented!("ScopedExprMock does not support evaluation") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.in_scope_child, &self.out_of_scope_child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + in_scope_child: iter.next().unwrap(), + out_of_scope_child: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.in_scope_child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + Ok(Arc::new(Self { + in_scope_child: children_in_scope.into_iter().next().unwrap(), + out_of_scope_child: Arc::clone(&self.out_of_scope_child), + })) + } + + fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + Display::fmt(&self, f) + } + } + #[derive(Debug, PartialEq, Eq, Hash)] pub struct TestScalarUDF { pub(crate) signature: Signature, @@ -648,4 +730,49 @@ pub(crate) mod tests { Ok(()) } + + #[test] + fn test_reassign_expr_columns_does_not_modify_out_of_scope_children() { + 
// Outer schema: [a: Int32, b: Int32] + // Lambda schema: [x: Int32] (different scope, should not be touched) + let outer_schema = Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ]); + + // The in-scope child references "b" at index 5 (wrong index for outer_schema) + let in_scope_child: Arc = Arc::new(Column::new("b", 5)); + // The out-of-scope child references "x" at index 0 in the lambda schema + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); + + let expr: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let result = reassign_expr_columns(expr, &outer_schema).unwrap(); + + // The in-scope "b" column should be reassigned to index 1 (its position in outer_schema) + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "b"); + assert_eq!(in_scope_col.index(), 1); // reassigned to correct index + + // The out-of-scope "x" column should be UNCHANGED (still index 0) + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); // not modified + } } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index 8910d26913776..c742dc9796498 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -553,3 +553,199 @@ impl FilterDescription { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow::array::RecordBatch; + use arrow::datatypes::{DataType, Field, Schema}; + use datafusion_expr::ColumnarValue; + use datafusion_physical_expr::expressions::Column; + + /// A mock expression with an 
in-scope child and an out-of-scope child. + /// Used to verify that scoped traversal does not modify out-of-scope children. + #[derive(Debug, Hash, Clone)] + struct ScopedExprMock { + in_scope_child: Arc, + out_of_scope_child: Arc, + } + + impl PartialEq for ScopedExprMock { + fn eq(&self, other: &Self) -> bool { + self.in_scope_child.as_ref() == other.in_scope_child.as_ref() + && self.out_of_scope_child.as_ref() + == other.out_of_scope_child.as_ref() + } + } + + impl Eq for ScopedExprMock {} + + impl std::fmt::Display for ScopedExprMock { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "scoped_mock({}, {})", + self.in_scope_child, self.out_of_scope_child + ) + } + } + + impl PhysicalExpr for ScopedExprMock { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn return_field( + &self, + input_schema: &Schema, + ) -> Result> { + self.in_scope_child.return_field(input_schema) + } + + fn evaluate( + &self, + _batch: &RecordBatch, + ) -> Result { + unimplemented!("ScopedExprMock does not support evaluation") + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.in_scope_child, &self.out_of_scope_child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert_eq!(children.len(), 2); + let mut iter = children.into_iter(); + Ok(Arc::new(Self { + in_scope_child: iter.next().unwrap(), + out_of_scope_child: iter.next().unwrap(), + })) + } + + fn children_in_scope(&self) -> Vec<&Arc> { + vec![&self.in_scope_child] + } + + fn with_new_children_in_scope( + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + Ok(Arc::new(Self { + in_scope_child: children_in_scope.into_iter().next().unwrap(), + out_of_scope_child: Arc::clone(&self.out_of_scope_child), + })) + } + + fn fmt_sql( + &self, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } + } + + #[test] + fn 
test_filter_remapper_does_not_modify_out_of_scope_children() { + // Child schema: [a: Int32, b: Int32] + let child_schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Int32, false), + ])); + + let remapper = FilterRemapper::new(Arc::clone(&child_schema)); + + // The in-scope child references column "a@0" which exists in child schema + let in_scope_child: Arc = Arc::new(Column::new("a", 0)); + // The out-of-scope child also references "a@0" but should NOT be remapped + let out_of_scope_child: Arc = + Arc::new(Column::new("a", 0)); + + let filter: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let result = remapper + .try_remap(&filter) + .unwrap() + .expect("Should remap successfully"); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a@0" should be remapped (stays "a@0" since schema matches) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 0); + + // The out-of-scope child "a@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should still be Column"); + assert_eq!(out_of_scope_col.name(), "a"); + assert_eq!(out_of_scope_col.index(), 0); + } + + #[test] + fn test_filter_remapper_remaps_in_scope_but_not_out_of_scope() { + // Parent schema: [a: Int32, b: Int32, c: Int32] + // Child schema: [b: Int32, a: Int32] (different order, so "a" remaps from 0 -> 1) + let child_schema = Arc::new(Schema::new(vec![ + Field::new("b", DataType::Int32, false), + Field::new("a", DataType::Int32, false), + ])); + + let remapper = FilterRemapper::new(Arc::clone(&child_schema)); + + // The in-scope child references "a@0" - should be remapped to "a@1" in child schema + let in_scope_child: Arc = 
Arc::new(Column::new("a", 0)); + // The out-of-scope child references "x@0" in the lambda schema - should NOT be touched + let out_of_scope_child: Arc = + Arc::new(Column::new("x", 0)); + + let filter: Arc = Arc::new(ScopedExprMock { + in_scope_child, + out_of_scope_child, + }); + + let result = remapper + .try_remap(&filter) + .unwrap() + .expect("Should remap successfully"); + + let mock = result + .as_any() + .downcast_ref::() + .expect("Should still be ScopedExprMock"); + + // The in-scope child "a@0" should be remapped to "a@1" (position in child schema) + let in_scope_col = mock + .in_scope_child + .as_any() + .downcast_ref::() + .expect("in_scope_child should be Column"); + assert_eq!(in_scope_col.name(), "a"); + assert_eq!(in_scope_col.index(), 1); + + // The out-of-scope child "x@0" should be UNCHANGED + let out_of_scope_col = mock + .out_of_scope_child + .as_any() + .downcast_ref::() + .expect("out_of_scope_child should still be Column"); + assert_eq!(out_of_scope_col.name(), "x"); + assert_eq!(out_of_scope_col.index(), 0); + } +} From b6383cc727f56095338a6f1ee903aeff69e212fd Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 17:38:51 +0300 Subject: [PATCH 05/17] remove outdated column --- datafusion/physical-expr/src/expressions/case.rs | 3 --- 1 file changed, 3 deletions(-) diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index a3409202d51c2..d113460899881 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1886,9 +1886,6 @@ mod tests { ).unwrap(); let schema = batch.schema(); - // case - // WHEN have_all(list, item -> idx % 2 == 0 && item % 2 == 0) THEN "all even values" - // WHEN have_all(list, item -> idx % 2 == 1 && item % 2 == 1) THEN "all odd values are odd" fn create_when_expr(is_even: bool) -> Arc { let idx_col: Arc = Arc::new(Column::new("idx", 0)); let 
item_col: Arc = Arc::new(Column::new("item", 1)); From d981a15c9f3242ee557c53371f481699f333aa3b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 18:15:39 +0300 Subject: [PATCH 06/17] add tests --- datafusion/common/src/tree_node.rs | 912 ++++++++++++++++++++++++++++- 1 file changed, 897 insertions(+), 15 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 25c465d1bd658..da092afcaf198 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -490,7 +490,7 @@ pub trait ScopedTreeNode: TreeNode { ) -> Result { visitor .f_down(self)? - .visit_children(|| self.apply_children_in_scope(|c| c.visit(visitor)))? + .visit_children(|| self.apply_children_in_scope(|c| c.visit_in_scope(visitor)))? .visit_parent(|| visitor.f_up(self)) } @@ -1677,7 +1677,7 @@ pub(crate) mod tests { pub(crate) data: T, } - impl TestTreeNode { + impl TestTreeNode { pub(crate) fn new(children: Vec>, data: T) -> Self { Self { children, @@ -1686,24 +1686,28 @@ pub(crate) mod tests { } } + /// Creates a node where all children are in scope. + /// Automatically sets `children` = clone of `children_in_scope`. pub(crate) fn new_scoped( children_in_scope: Vec>, data: T, ) -> Self { Self { - children: vec![], + children: children_in_scope.clone(), children_in_scope, data, } } + /// Creates a node with explicit `children` (all, in order) and + /// `children_in_scope` (the scoped subset). pub(crate) fn new_mixed( - children: Vec>, + all_children: Vec>, children_in_scope: Vec>, data: T, ) -> Self { Self { - children, + children: all_children, children_in_scope, data, } @@ -1720,6 +1724,20 @@ pub(crate) mod tests { pub(crate) fn is_leaf(&self) -> bool { self.children.is_empty() } + + /// Strip children_in_scope recursively - used to compare trees + /// in TreeNode tests where children_in_scope is not relevant. 
+ fn strip_scope(self) -> Self { + Self { + children: self + .children + .into_iter() + .map(|c| c.strip_scope()) + .collect(), + children_in_scope: vec![], + data: self.data, + } + } } impl TreeNode for TestTreeNode { @@ -1798,17 +1816,31 @@ pub(crate) mod tests { // B D (scoped) // | // A + // + // TreeNode (children) traversal visits ALL nodes: J, I, F, E, C, B, D, A, G, H + // (new_scoped/new_mixed auto-add children_in_scope to children) + // + // ScopedTreeNode (children_in_scope) traversal visits: J, I, F, E, C, D, A + // (skips B, G, H which are only in children, not children_in_scope) fn test_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); let node_b = TestTreeNode::new_leaf("b".to_string()); let node_d = TestTreeNode::new_scoped(vec![node_a], "d".to_string()); - let node_c = TestTreeNode::new_mixed(vec![node_b], vec![node_d], "c".to_string()); + let node_c = TestTreeNode::new_mixed( + vec![node_b, node_d.clone()], + vec![node_d], + "c".to_string(), + ); let node_e = TestTreeNode::new_scoped(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); - let node_f = TestTreeNode::new_mixed(vec![node_g], vec![node_e], "f".to_string()); - let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); - TestTreeNode::new(vec![node_i], "j".to_string()) + let node_f = TestTreeNode::new_mixed( + vec![node_e.clone(), node_g], + vec![node_e], + "f".to_string(), + ); + let node_i = TestTreeNode::new_scoped(vec![node_f], "i".to_string()); + TestTreeNode::new_scoped(vec![node_i], "j".to_string()) } // Continue on all nodes @@ -2352,7 +2384,7 @@ pub(crate) mod tests { } } - fn transform_yes>( + fn transform_yes + Clone>( transformation_name: N, ) -> impl FnMut(TestTreeNode) -> Result>> { move |node| { @@ -2365,7 +2397,7 @@ pub(crate) mod tests { fn transform_and_event_on< N: Display, - T: PartialEq + Display + From, + T: PartialEq + Display + From + 
Clone, D: Into, >( transformation_name: N, @@ -2392,7 +2424,9 @@ pub(crate) mod tests { fn $NAME() -> Result<()> { let tree = test_tree(); let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); - assert_eq!(tree.rewrite(&mut rewriter)?, $EXPECTED_TREE); + let actual = tree.rewrite(&mut rewriter)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, $EXPECTED_TREE); Ok(()) } @@ -2404,7 +2438,9 @@ pub(crate) mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform_down_up($F_DOWN, $F_UP,)?, $EXPECTED_TREE); + let actual = tree.transform_down_up($F_DOWN, $F_UP)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, $EXPECTED_TREE); Ok(()) } @@ -2416,7 +2452,9 @@ pub(crate) mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform_down($F)?, $EXPECTED_TREE); + let actual = tree.transform_down($F)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, $EXPECTED_TREE); Ok(()) } @@ -2428,7 +2466,9 @@ pub(crate) mod tests { #[test] fn $NAME() -> Result<()> { let tree = test_tree(); - assert_eq!(tree.transform_up($F)?, $EXPECTED_TREE); + let actual = tree.transform_up($F)?; + let actual_stripped = actual.update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, $EXPECTED_TREE); Ok(()) } @@ -2813,4 +2853,846 @@ pub(crate) mod tests { item.visit(&mut visitor).unwrap(); } + + // ===================================================================== + // ScopedTreeNode tests + // + // Models nested scope boundaries, like expressions with nested lambdas. + // Each scope has multiple nodes (>=3) to test sibling traversal within + // a scope, plus some scopes with only 1 in-scope child. 
+ // + // OUTER SCOPE (root traversal visits these): + // root → and → [plus, gt] + // plus → [col_a, col_b] + // gt → [col_c, Lambda1] ← Lambda1 is out of scope + // + // Scoped traversal: root, and, plus, col_a, col_b, gt, col_c + // (7 nodes, Lambda1 not entered) + // + // LAMBDA1 SCOPE (Lambda1's own traversal): + // lambda1 → [list_col, cmp] + // cmp → [idx_col, Lambda2] ← Lambda2 is out of scope + // + // Scoped traversal: lambda1, list_col, cmp, idx_col + // (4 nodes, Lambda2 not entered) + // + // LAMBDA2 SCOPE (Lambda2's own traversal — only 1 in-scope child): + // lambda2 → [inner_col] + // + // Scoped traversal: lambda2, inner_col + // (2 nodes) + // + // root + // | (scoped) + // and (scoped: [plus, gt]) + // / \ + // plus gt (mixed: in_scope=[col_c], out=[Lambda1]) + // / \ / \ + // col_a col_b col_c Lambda1 ← scope boundary + // | + // (in_scope=[list_col, cmp], out=[]) + // / \ + // list_col cmp (mixed: in_scope=[idx_col], out=[Lambda2]) + // / \ + // idx_col Lambda2 ← nested scope boundary + // | + // (in_scope=[inner_col]) + // | + // inner_col + // ===================================================================== + + fn scoped_test_tree() -> TestTreeNode { + // Leaves for outer scope + let col_a = TestTreeNode::new_leaf("col_a".to_string()); + let col_b = TestTreeNode::new_leaf("col_b".to_string()); + let col_c = TestTreeNode::new_leaf("col_c".to_string()); + + let plus = TestTreeNode::new_scoped(vec![col_a, col_b], "plus".to_string()); + + // --- Innermost lambda scope: Lambda2 → inner_col --- + let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); + let lambda2 = TestTreeNode::new_scoped(vec![inner_col], "lambda2".to_string()); + + // --- Middle lambda scope: Lambda1 → list_col, cmp --- + let list_col = TestTreeNode::new_leaf("list_col".to_string()); + let idx_col = TestTreeNode::new_leaf("idx_col".to_string()); + // cmp has idx_col in scope, Lambda2 out of scope + let cmp = TestTreeNode::new_mixed( + vec![idx_col.clone(), 
lambda2.clone()], + vec![idx_col], + "cmp".to_string(), + ); + let lambda1 = TestTreeNode::new_scoped( + vec![list_col, cmp], + "lambda1".to_string(), + ); + + // --- Outer scope --- + // gt has col_c in scope, Lambda1 out of scope + let gt = TestTreeNode::new_mixed( + vec![col_c.clone(), lambda1.clone()], + vec![col_c], + "gt".to_string(), + ); + let and = TestTreeNode::new_scoped(vec![plus, gt], "and".to_string()); + + TestTreeNode::new_scoped(vec![and], "root".to_string()) + } + + /// Collect all data reachable via children_in_scope (scoped DFS). + fn collect_scoped_data(node: &TestTreeNode) -> Vec { + let mut result = vec![node.data.clone()]; + for child in &node.children_in_scope { + result.extend(collect_scoped_data(child)); + } + result + } + + /// Collect all data reachable via children (non-scoped DFS). + fn collect_children_data(node: &TestTreeNode) -> Vec { + let mut result = vec![node.data.clone()]; + for child in &node.children { + result.extend(collect_children_data(child)); + } + result + } + + // Scoped transform helpers that preserve both children and children_in_scope + fn transform_yes_in_scope>( + transformation_name: N, + ) -> impl FnMut(TestTreeNode) -> Result>> { + move |node| { + Ok(Transformed::yes(TestTreeNode { + children: node.children, + children_in_scope: node.children_in_scope, + data: format!("{}({})", transformation_name, node.data).into(), + })) + } + } + + fn transform_and_event_on_in_scope< + N: Display, + T: PartialEq + Display + From, + D: Into, + >( + transformation_name: N, + data: D, + event: TreeNodeRecursion, + ) -> impl FnMut(TestTreeNode) -> Result>> { + let d = data.into(); + move |node| { + let new_node = TestTreeNode { + children: node.children, + children_in_scope: node.children_in_scope, + data: format!("{}({})", transformation_name, node.data).into(), + }; + Ok(if node.data == d { + Transformed::new(new_node, true, event) + } else { + Transformed::yes(new_node) + }) + } + } + + fn s(v: &[&str]) -> Vec { + 
v.iter().map(|s| s.to_string()).collect() + } + + // === visit_in_scope === + + #[test] + fn test_visit_in_scope_continue() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_continue), + ); + tree.visit_in_scope(&mut visitor)?; + assert_eq!(visitor.visits, s(&[ + "f_down(root)", "f_down(and)", + "f_down(plus)", "f_down(col_a)", "f_up(col_a)", + "f_down(col_b)", "f_up(col_b)", "f_up(plus)", + "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", + "f_up(and)", "f_up(root)", + ])); + Ok(()) + } + + #[test] + fn test_visit_in_scope_does_not_enter_lambda() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_continue), + ); + tree.visit_in_scope(&mut visitor)?; + let out_of_scope = ["lambda1", "lambda2", "list_col", "cmp", "idx_col", "inner_col"]; + for v in &visitor.visits { + for name in &out_of_scope { + assert!(!v.contains(name), "should not enter other scope ({name}): {v}"); + } + } + Ok(()) + } + + #[test] + fn test_visit_in_scope_f_down_jump_on_plus() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_event_on("plus", TreeNodeRecursion::Jump)), + Box::new(visit_continue), + ); + tree.visit_in_scope(&mut visitor)?; + // Jump on Plus: skip Plus's children, continue with sibling gt + assert_eq!(visitor.visits, s(&[ + "f_down(root)", "f_down(and)", "f_down(plus)", + "f_up(plus)", + "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", + "f_up(and)", "f_up(root)", + ])); + Ok(()) + } + + #[test] + fn test_visit_in_scope_f_down_stop_on_col_a() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_event_on("col_a", TreeNodeRecursion::Stop)), + Box::new(visit_continue), + ); + tree.visit_in_scope(&mut visitor)?; + assert_eq!(visitor.visits, s(&[ + "f_down(root)", "f_down(and)", "f_down(plus)", 
"f_down(col_a)", + ])); + Ok(()) + } + + #[test] + fn test_visit_in_scope_f_up_jump_on_col_a() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_event_on("col_a", TreeNodeRecursion::Jump)), + ); + tree.visit_in_scope(&mut visitor)?; + // Jump after f_up(col_a): continue with sibling col_b. + // col_b returns Continue, resetting the tnr, so plus's f_up IS called. + assert_eq!(visitor.visits, s(&[ + "f_down(root)", "f_down(and)", + "f_down(plus)", "f_down(col_a)", "f_up(col_a)", + "f_down(col_b)", "f_up(col_b)", + "f_up(plus)", + "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", + "f_up(and)", "f_up(root)", + ])); + Ok(()) + } + + #[test] + fn test_visit_in_scope_f_up_stop_on_col_a() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_event_on("col_a", TreeNodeRecursion::Stop)), + ); + tree.visit_in_scope(&mut visitor)?; + assert_eq!(visitor.visits, s(&[ + "f_down(root)", "f_down(and)", "f_down(plus)", + "f_down(col_a)", "f_up(col_a)", + ])); + Ok(()) + } + + #[test] + fn test_visit_in_scope_f_up_jump_on_plus() -> Result<()> { + let tree = scoped_test_tree(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_event_on("plus", TreeNodeRecursion::Jump)), + ); + tree.visit_in_scope(&mut visitor)?; + // Jump after f_up(plus): continue with sibling gt, skip f_up(and) + assert_eq!(visitor.visits, s(&[ + "f_down(root)", "f_down(and)", + "f_down(plus)", "f_down(col_a)", "f_up(col_a)", + "f_down(col_b)", "f_up(col_b)", "f_up(plus)", + // gt is sibling of plus, so it gets visited + "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", + // and's f_up is skipped (Jump from plus), but root f_up happens + // because gt returned Continue, resetting the last tnr + "f_up(and)", "f_up(root)", + ])); + Ok(()) + } + + // === apply_in_scope === + + #[test] + fn 
test_apply_in_scope_continue() -> Result<()> { + let tree = scoped_test_tree(); + let mut visits = vec![]; + tree.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(visits, s(&["root", "and", "plus", "col_a", "col_b", "gt", "col_c"])); + Ok(()) + } + + #[test] + fn test_apply_in_scope_does_not_enter_lambda() -> Result<()> { + let tree = scoped_test_tree(); + let mut visits = vec![]; + tree.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + let out_of_scope = ["lambda1", "lambda2", "list_col", "cmp", "idx_col", "inner_col"]; + for name in &out_of_scope { + assert!(!visits.contains(&name.to_string()), "{name} should not be visited"); + } + Ok(()) + } + + #[test] + fn test_apply_in_scope_jump_on_plus() -> Result<()> { + let tree = scoped_test_tree(); + let mut visits = vec![]; + tree.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(if n.data == "plus" { + TreeNodeRecursion::Jump + } else { + TreeNodeRecursion::Continue + }) + })?; + // Jump on Plus skips its children, continues to sibling gt + assert_eq!(visits, s(&["root", "and", "plus", "gt", "col_c"])); + Ok(()) + } + + #[test] + fn test_apply_in_scope_stop_on_col_a() -> Result<()> { + let tree = scoped_test_tree(); + let mut visits = vec![]; + tree.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(if n.data == "col_a" { + TreeNodeRecursion::Stop + } else { + TreeNodeRecursion::Continue + }) + })?; + assert_eq!(visits, s(&["root", "and", "plus", "col_a"])); + Ok(()) + } + + // === transform_down_in_scope === + + #[test] + fn test_transform_down_in_scope_continue() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_down_in_scope( + transform_yes_in_scope("f_down"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_down(root)", "f_down(and)", "f_down(plus)", + "f_down(col_a)", "f_down(col_b)", + "f_down(gt)", "f_down(col_c)", + ])); + // Lambda internals untouched 
in children path + let children = collect_children_data(&result.data); + assert_eq!(children[0], "f_down(root)"); + assert!(children.contains(&"lambda1".to_string())); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"list_col".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_transform_down_in_scope_does_not_transform_lambda_internals() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_down_in_scope( + transform_yes_in_scope("f_down"), + )?; + let scoped = collect_scoped_data(&result.data); + let out_of_scope = ["lambda1", "lambda2", "list_col", "cmp", "idx_col", "inner_col"]; + for v in &scoped { + for name in &out_of_scope { + assert!(!v.contains(name), "{name} should not be in scoped data: {v}"); + } + } + Ok(()) + } + + #[test] + fn test_transform_down_in_scope_jump_on_plus() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_down_in_scope( + transform_and_event_on_in_scope("f_down", "plus", TreeNodeRecursion::Jump), + )?; + // Plus is transformed but children skipped, gt still visited + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_down(root)", "f_down(and)", "f_down(plus)", + "col_a", "col_b", + "f_down(gt)", "f_down(col_c)", + ])); + Ok(()) + } + + // === transform_up_in_scope === + + #[test] + fn test_transform_up_in_scope_continue() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_up_in_scope( + transform_yes_in_scope("f_up"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(root)", "f_up(and)", "f_up(plus)", + "f_up(col_a)", "f_up(col_b)", + "f_up(gt)", "f_up(col_c)", + ])); + let children = collect_children_data(&result.data); + assert!(children.contains(&"lambda1".to_string())); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn 
test_transform_up_in_scope_stop_on_col_a() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_up_in_scope( + transform_and_event_on_in_scope("f_up", "col_a", TreeNodeRecursion::Stop), + )?; + // Stop on col_a: only col_a transformed, everything else untouched + assert_eq!(collect_scoped_data(&result.data), s(&[ + "root", "and", "plus", "f_up(col_a)", "col_b", + "gt", "col_c", + ])); + Ok(()) + } + + // === transform_down_up_in_scope === + + #[test] + fn test_transform_down_up_in_scope_continue() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_down_up_in_scope( + transform_yes_in_scope("f_down"), + transform_yes_in_scope("f_up"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(root))", "f_up(f_down(and))", "f_up(f_down(plus))", + "f_up(f_down(col_a))", "f_up(f_down(col_b))", + "f_up(f_down(gt))", "f_up(f_down(col_c))", + ])); + let children = collect_children_data(&result.data); + assert!(children.contains(&"lambda1".to_string())); + assert!(children.contains(&"lambda2".to_string())); + Ok(()) + } + + // === exists_in_scope === + + #[test] + fn test_exists_in_scope_found_in_scope() -> Result<()> { + let tree = scoped_test_tree(); + assert!(tree.exists_in_scope(|n| Ok(n.data == "root"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "and"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "plus"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "col_a"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "col_b"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "gt"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "col_c"))?); + Ok(()) + } + + #[test] + fn test_exists_in_scope_not_found_lambda1_scope() -> Result<()> { + let tree = scoped_test_tree(); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "lambda1"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "list_col"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "cmp"))?); + 
assert!(!tree.exists_in_scope(|n| Ok(n.data == "idx_col"))?); + Ok(()) + } + + #[test] + fn test_exists_in_scope_not_found_nested_lambda2_scope() -> Result<()> { + let tree = scoped_test_tree(); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "lambda2"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "inner_col"))?); + Ok(()) + } + + #[test] + fn test_exists_in_scope_not_found_nonexistent() -> Result<()> { + let tree = scoped_test_tree(); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "zzz"))?); + Ok(()) + } + + // === rewrite_in_scope === + + #[test] + fn test_rewrite_in_scope_continue() -> Result<()> { + let tree = scoped_test_tree(); + let mut rewriter = TestRewriter::new( + Box::new(transform_yes_in_scope("f_down")), + Box::new(transform_yes_in_scope("f_up")), + ); + let result = tree.rewrite_in_scope(&mut rewriter)?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(root))", "f_up(f_down(and))", "f_up(f_down(plus))", + "f_up(f_down(col_a))", "f_up(f_down(col_b))", + "f_up(f_down(gt))", "f_up(f_down(col_c))", + ])); + let children = collect_children_data(&result.data); + assert!(children.contains(&"lambda1".to_string())); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_rewrite_in_scope_f_down_jump_on_plus() -> Result<()> { + let tree = scoped_test_tree(); + let mut rewriter = TestRewriter::new( + Box::new(transform_and_event_on_in_scope( + "f_down", "plus", TreeNodeRecursion::Jump, + )), + Box::new(transform_yes_in_scope("f_up")), + ); + let result = tree.rewrite_in_scope(&mut rewriter)?; + // Jump on Plus: children skipped, but sibling gt is visited + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(root))", "f_up(f_down(and))", "f_up(f_down(plus))", + "col_a", "col_b", + "f_up(f_down(gt))", "f_up(f_down(col_c))", + ])); + Ok(()) + } + + // === Nested scope isolation tests === + // Each scope is tested with multiple 
traversal methods to ensure + // no scope crossing occurs in any direction. + + /// Build the Lambda2 subtree (innermost scope: lambda2 → inner_col) + fn build_lambda2() -> TestTreeNode { + let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); + TestTreeNode::new_scoped(vec![inner_col], "lambda2".to_string()) + } + + /// Build the Lambda1 subtree (middle scope: lambda1 → [list_col, cmp → idx_col]) + fn build_lambda1() -> TestTreeNode { + let lambda2 = build_lambda2(); + let list_col = TestTreeNode::new_leaf("list_col".to_string()); + let idx_col = TestTreeNode::new_leaf("idx_col".to_string()); + let cmp = TestTreeNode::new_mixed( + vec![idx_col.clone(), lambda2.clone()], + vec![idx_col], + "cmp".to_string(), + ); + TestTreeNode::new_scoped(vec![list_col, cmp], "lambda1".to_string()) + } + + // --- Lambda1 scope (middle): 4 in-scope nodes --- + + #[test] + fn test_lambda1_scope_apply() -> Result<()> { + let lambda1 = build_lambda1(); + let mut visits = vec![]; + lambda1.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(visits, s(&["lambda1", "list_col", "cmp", "idx_col"])); + assert!(!visits.contains(&"lambda2".to_string())); + assert!(!visits.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_lambda1_scope_visit() -> Result<()> { + let lambda1 = build_lambda1(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_continue), + ); + lambda1.visit_in_scope(&mut visitor)?; + assert_eq!(visitor.visits, s(&[ + "f_down(lambda1)", "f_down(list_col)", "f_up(list_col)", + "f_down(cmp)", "f_down(idx_col)", "f_up(idx_col)", "f_up(cmp)", + "f_up(lambda1)", + ])); + // Must not enter lambda2 scope + for v in &visitor.visits { + assert!(!v.contains("lambda2"), "lambda1 visit must not enter lambda2: {v}"); + assert!(!v.contains("inner_col"), "lambda1 visit must not enter lambda2: {v}"); + } + Ok(()) + } + + #[test] + fn test_lambda1_scope_transform_down() -> 
Result<()> { + let lambda1 = build_lambda1(); + let result = lambda1.transform_down_in_scope( + transform_yes_in_scope("TX"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)", + ])); + // Lambda2 untouched in children path + let children = collect_children_data(&result.data); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_lambda1_scope_transform_up() -> Result<()> { + let lambda1 = build_lambda1(); + let result = lambda1.transform_up_in_scope( + transform_yes_in_scope("TX"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)", + ])); + let children = collect_children_data(&result.data); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_lambda1_scope_exists() -> Result<()> { + let lambda1 = build_lambda1(); + // In scope + assert!(lambda1.exists_in_scope(|n| Ok(n.data == "lambda1"))?); + assert!(lambda1.exists_in_scope(|n| Ok(n.data == "list_col"))?); + assert!(lambda1.exists_in_scope(|n| Ok(n.data == "cmp"))?); + assert!(lambda1.exists_in_scope(|n| Ok(n.data == "idx_col"))?); + // Out of scope (lambda2's scope) + assert!(!lambda1.exists_in_scope(|n| Ok(n.data == "lambda2"))?); + assert!(!lambda1.exists_in_scope(|n| Ok(n.data == "inner_col"))?); + Ok(()) + } + + #[test] + fn test_lambda1_scope_rewrite() -> Result<()> { + let lambda1 = build_lambda1(); + let mut rewriter = TestRewriter::new( + Box::new(transform_yes_in_scope("f_down")), + Box::new(transform_yes_in_scope("f_up")), + ); + let result = lambda1.rewrite_in_scope(&mut rewriter)?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(lambda1))", "f_up(f_down(list_col))", + "f_up(f_down(cmp))", "f_up(f_down(idx_col))", + ])); + let children = 
collect_children_data(&result.data); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_lambda1_scope_transform_down_up() -> Result<()> { + let lambda1 = build_lambda1(); + let result = lambda1.transform_down_up_in_scope( + transform_yes_in_scope("f_down"), + transform_yes_in_scope("f_up"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(lambda1))", "f_up(f_down(list_col))", + "f_up(f_down(cmp))", "f_up(f_down(idx_col))", + ])); + let children = collect_children_data(&result.data); + assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"inner_col".to_string())); + Ok(()) + } + + // --- Lambda2 scope (innermost): 2 in-scope nodes --- + + #[test] + fn test_lambda2_scope_apply() -> Result<()> { + let lambda2 = build_lambda2(); + let mut visits = vec![]; + lambda2.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(visits, s(&["lambda2", "inner_col"])); + Ok(()) + } + + #[test] + fn test_lambda2_scope_visit() -> Result<()> { + let lambda2 = build_lambda2(); + let mut visitor = TestVisitor::new( + Box::new(visit_continue), + Box::new(visit_continue), + ); + lambda2.visit_in_scope(&mut visitor)?; + assert_eq!(visitor.visits, s(&[ + "f_down(lambda2)", "f_down(inner_col)", "f_up(inner_col)", "f_up(lambda2)", + ])); + Ok(()) + } + + #[test] + fn test_lambda2_scope_transform_down() -> Result<()> { + let lambda2 = build_lambda2(); + let result = lambda2.transform_down_in_scope( + transform_yes_in_scope("TX"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "TX(lambda2)", "TX(inner_col)", + ])); + Ok(()) + } + + #[test] + fn test_lambda2_scope_transform_up() -> Result<()> { + let lambda2 = build_lambda2(); + let result = lambda2.transform_up_in_scope( + transform_yes_in_scope("TX"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "TX(lambda2)", 
"TX(inner_col)", + ])); + Ok(()) + } + + #[test] + fn test_lambda2_scope_transform_down_up() -> Result<()> { + let lambda2 = build_lambda2(); + let result = lambda2.transform_down_up_in_scope( + transform_yes_in_scope("f_down"), + transform_yes_in_scope("f_up"), + )?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(lambda2))", "f_up(f_down(inner_col))", + ])); + Ok(()) + } + + #[test] + fn test_lambda2_scope_rewrite() -> Result<()> { + let lambda2 = build_lambda2(); + let mut rewriter = TestRewriter::new( + Box::new(transform_yes_in_scope("f_down")), + Box::new(transform_yes_in_scope("f_up")), + ); + let result = lambda2.rewrite_in_scope(&mut rewriter)?; + assert_eq!(collect_scoped_data(&result.data), s(&[ + "f_up(f_down(lambda2))", "f_up(f_down(inner_col))", + ])); + Ok(()) + } + + #[test] + fn test_lambda2_scope_exists() -> Result<()> { + let lambda2 = build_lambda2(); + assert!(lambda2.exists_in_scope(|n| Ok(n.data == "lambda2"))?); + assert!(lambda2.exists_in_scope(|n| Ok(n.data == "inner_col"))?); + // Not in any scope from lambda2's perspective + assert!(!lambda2.exists_in_scope(|n| Ok(n.data == "lambda1"))?); + assert!(!lambda2.exists_in_scope(|n| Ok(n.data == "col_a"))?); + Ok(()) + } + + // --- Outer scope: transform must not affect any inner scope --- + + #[test] + fn test_outer_scope_transform_does_not_affect_lambda1() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_down_in_scope( + transform_yes_in_scope("TX"), + )?; + let all_data = collect_children_data(&result.data); + assert_eq!(all_data[0], "TX(root)"); + // Lambda1 scope completely untouched + assert!(all_data.contains(&"lambda1".to_string())); + assert!(all_data.contains(&"list_col".to_string())); + assert!(all_data.contains(&"cmp".to_string())); + assert!(all_data.contains(&"idx_col".to_string())); + // Lambda2 scope completely untouched + assert!(all_data.contains(&"lambda2".to_string())); + assert!(all_data.contains(&"inner_col".to_string())); 
+ Ok(()) + } + + #[test] + fn test_outer_scope_transform_down_up_does_not_affect_inner_scopes() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_down_up_in_scope( + transform_yes_in_scope("f_down"), + transform_yes_in_scope("f_up"), + )?; + let all_data = collect_children_data(&result.data); + assert!(all_data.contains(&"lambda1".to_string())); + assert!(all_data.contains(&"list_col".to_string())); + assert!(all_data.contains(&"cmp".to_string())); + assert!(all_data.contains(&"idx_col".to_string())); + assert!(all_data.contains(&"lambda2".to_string())); + assert!(all_data.contains(&"inner_col".to_string())); + Ok(()) + } + + #[test] + fn test_outer_scope_transform_up_does_not_affect_inner_scopes() -> Result<()> { + let tree = scoped_test_tree(); + let result = tree.transform_up_in_scope( + transform_yes_in_scope("TX"), + )?; + let all_data = collect_children_data(&result.data); + // Lambda1 + Lambda2 scopes completely untouched + assert!(all_data.contains(&"lambda1".to_string())); + assert!(all_data.contains(&"list_col".to_string())); + assert!(all_data.contains(&"cmp".to_string())); + assert!(all_data.contains(&"idx_col".to_string())); + assert!(all_data.contains(&"lambda2".to_string())); + assert!(all_data.contains(&"inner_col".to_string())); + Ok(()) + } + + // === Edge cases === + + #[test] + fn test_leaf_node_scoped_traversal() -> Result<()> { + let leaf = TestTreeNode::::new_leaf("leaf".to_string()); + let mut visits = vec![]; + leaf.apply_in_scope(|n| { + visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(visits, vec!["leaf"]); + Ok(()) + } + + #[test] + fn test_node_with_only_children_no_scope() -> Result<()> { + let child = TestTreeNode::new_leaf("child".to_string()); + let parent = TestTreeNode::new(vec![child], "parent".to_string()); + + let mut scoped_visits = vec![]; + parent.apply_in_scope(|n| { + scoped_visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + 
assert_eq!(scoped_visits, vec!["parent"]); + + let mut all_visits = vec![]; + parent.apply(|n| { + all_visits.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) + })?; + assert_eq!(all_visits, vec!["parent", "child"]); + Ok(()) + } } From 25fcc4a728d61b182d6a236f259116b9999ead66 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 18:25:09 +0300 Subject: [PATCH 07/17] lint and format --- datafusion/common/src/tree_node.rs | 501 +++++++++++------- .../src/schema_rewriter.rs | 26 +- .../physical-expr-common/src/physical_expr.rs | 30 +- .../physical-expr-common/src/tree_node.rs | 4 +- .../physical-expr/src/expressions/case.rs | 191 ++++--- datafusion/physical-expr/src/physical_expr.rs | 2 +- datafusion/physical-expr/src/projection.rs | 19 +- datafusion/physical-expr/src/utils/mod.rs | 15 +- .../physical-plan/src/filter_pushdown.rs | 36 +- 9 files changed, 516 insertions(+), 308 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index da092afcaf198..2247dd87d7c49 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -445,12 +445,12 @@ pub trait TreeNode: Sized { } /// API for inspecting and rewriting tree data structures. -/// +/// /// See [`TreeNode`] for more details. -/// +/// /// This add the notion of scopes to [`TreeNode`] and allow you to operate in that. -/// -/// Scope is left for implementators to define, for `PhysicalExpr` child is defined in scope if it have the same input schema as current `PhysicalExpr`. +/// +/// Scope is left for implementers to define, for `PhysicalExpr` child is defined in scope if it have the same input schema as current `PhysicalExpr`. pub trait ScopedTreeNode: TreeNode { /// Visit the tree node with a [`TreeNodeVisitor`], performing a /// depth-first walk of the node and its children. @@ -490,7 +490,9 @@ pub trait ScopedTreeNode: TreeNode { ) -> Result { visitor .f_down(self)? 
- .visit_children(|| self.apply_children_in_scope(|c| c.visit_in_scope(visitor)))? + .visit_children(|| { + self.apply_children_in_scope(|c| c.visit_in_scope(visitor)) + })? .visit_parent(|| visitor.f_up(self)) } @@ -1729,11 +1731,7 @@ pub(crate) mod tests { /// in TreeNode tests where children_in_scope is not relevant. fn strip_scope(self) -> Self { Self { - children: self - .children - .into_iter() - .map(|c| c.strip_scope()) - .collect(), + children: self.children.into_iter().map(|c| c.strip_scope()).collect(), children_in_scope: vec![], data: self.data, } @@ -2922,10 +2920,8 @@ pub(crate) mod tests { vec![idx_col], "cmp".to_string(), ); - let lambda1 = TestTreeNode::new_scoped( - vec![list_col, cmp], - "lambda1".to_string(), - ); + let lambda1 = + TestTreeNode::new_scoped(vec![list_col, cmp], "lambda1".to_string()); // --- Outer scope --- // gt has col_c in scope, Lambda1 out of scope @@ -3003,33 +2999,51 @@ pub(crate) mod tests { #[test] fn test_visit_in_scope_continue() -> Result<()> { let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - Box::new(visit_continue), - ); + let mut visitor = + TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); tree.visit_in_scope(&mut visitor)?; - assert_eq!(visitor.visits, s(&[ - "f_down(root)", "f_down(and)", - "f_down(plus)", "f_down(col_a)", "f_up(col_a)", - "f_down(col_b)", "f_up(col_b)", "f_up(plus)", - "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", - "f_up(and)", "f_up(root)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_down(col_a)", + "f_up(col_a)", + "f_down(col_b)", + "f_up(col_b)", + "f_up(plus)", + "f_down(gt)", + "f_down(col_c)", + "f_up(col_c)", + "f_up(gt)", + "f_up(and)", + "f_up(root)", + ]) + ); Ok(()) } #[test] fn test_visit_in_scope_does_not_enter_lambda() -> Result<()> { let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - 
Box::new(visit_continue), - ); + let mut visitor = + TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); tree.visit_in_scope(&mut visitor)?; - let out_of_scope = ["lambda1", "lambda2", "list_col", "cmp", "idx_col", "inner_col"]; + let out_of_scope = [ + "lambda1", + "lambda2", + "list_col", + "cmp", + "idx_col", + "inner_col", + ]; for v in &visitor.visits { for name in &out_of_scope { - assert!(!v.contains(name), "should not enter other scope ({name}): {v}"); + assert!( + !v.contains(name), + "should not enter other scope ({name}): {v}" + ); } } Ok(()) @@ -3044,12 +3058,21 @@ pub(crate) mod tests { ); tree.visit_in_scope(&mut visitor)?; // Jump on Plus: skip Plus's children, continue with sibling gt - assert_eq!(visitor.visits, s(&[ - "f_down(root)", "f_down(and)", "f_down(plus)", - "f_up(plus)", - "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", - "f_up(and)", "f_up(root)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_up(plus)", + "f_down(gt)", + "f_down(col_c)", + "f_up(col_c)", + "f_up(gt)", + "f_up(and)", + "f_up(root)", + ]) + ); Ok(()) } @@ -3061,9 +3084,15 @@ pub(crate) mod tests { Box::new(visit_continue), ); tree.visit_in_scope(&mut visitor)?; - assert_eq!(visitor.visits, s(&[ - "f_down(root)", "f_down(and)", "f_down(plus)", "f_down(col_a)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_down(col_a)", + ]) + ); Ok(()) } @@ -3077,14 +3106,25 @@ pub(crate) mod tests { tree.visit_in_scope(&mut visitor)?; // Jump after f_up(col_a): continue with sibling col_b. // col_b returns Continue, resetting the tnr, so plus's f_up IS called. 
- assert_eq!(visitor.visits, s(&[ - "f_down(root)", "f_down(and)", - "f_down(plus)", "f_down(col_a)", "f_up(col_a)", - "f_down(col_b)", "f_up(col_b)", - "f_up(plus)", - "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", - "f_up(and)", "f_up(root)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_down(col_a)", + "f_up(col_a)", + "f_down(col_b)", + "f_up(col_b)", + "f_up(plus)", + "f_down(gt)", + "f_down(col_c)", + "f_up(col_c)", + "f_up(gt)", + "f_up(and)", + "f_up(root)", + ]) + ); Ok(()) } @@ -3096,10 +3136,16 @@ pub(crate) mod tests { Box::new(visit_event_on("col_a", TreeNodeRecursion::Stop)), ); tree.visit_in_scope(&mut visitor)?; - assert_eq!(visitor.visits, s(&[ - "f_down(root)", "f_down(and)", "f_down(plus)", - "f_down(col_a)", "f_up(col_a)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_down(col_a)", + "f_up(col_a)", + ]) + ); Ok(()) } @@ -3112,16 +3158,28 @@ pub(crate) mod tests { ); tree.visit_in_scope(&mut visitor)?; // Jump after f_up(plus): continue with sibling gt, skip f_up(and) - assert_eq!(visitor.visits, s(&[ - "f_down(root)", "f_down(and)", - "f_down(plus)", "f_down(col_a)", "f_up(col_a)", - "f_down(col_b)", "f_up(col_b)", "f_up(plus)", - // gt is sibling of plus, so it gets visited - "f_down(gt)", "f_down(col_c)", "f_up(col_c)", "f_up(gt)", - // and's f_up is skipped (Jump from plus), but root f_up happens - // because gt returned Continue, resetting the last tnr - "f_up(and)", "f_up(root)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_down(col_a)", + "f_up(col_a)", + "f_down(col_b)", + "f_up(col_b)", + "f_up(plus)", + // gt is sibling of plus, so it gets visited + "f_down(gt)", + "f_down(col_c)", + "f_up(col_c)", + "f_up(gt)", + // and's f_up is skipped (Jump from plus), but root f_up happens + // because gt returned Continue, resetting the last tnr + "f_up(and)", + 
"f_up(root)", + ]) + ); Ok(()) } @@ -3135,7 +3193,10 @@ pub(crate) mod tests { visits.push(n.data.clone()); Ok(TreeNodeRecursion::Continue) })?; - assert_eq!(visits, s(&["root", "and", "plus", "col_a", "col_b", "gt", "col_c"])); + assert_eq!( + visits, + s(&["root", "and", "plus", "col_a", "col_b", "gt", "col_c"]) + ); Ok(()) } @@ -3147,9 +3208,19 @@ pub(crate) mod tests { visits.push(n.data.clone()); Ok(TreeNodeRecursion::Continue) })?; - let out_of_scope = ["lambda1", "lambda2", "list_col", "cmp", "idx_col", "inner_col"]; + let out_of_scope = [ + "lambda1", + "lambda2", + "list_col", + "cmp", + "idx_col", + "inner_col", + ]; for name in &out_of_scope { - assert!(!visits.contains(&name.to_string()), "{name} should not be visited"); + assert!( + !visits.contains(&name.to_string()), + "{name} should not be visited" + ); } Ok(()) } @@ -3192,14 +3263,19 @@ pub(crate) mod tests { #[test] fn test_transform_down_in_scope_continue() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope( - transform_yes_in_scope("f_down"), - )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_down(root)", "f_down(and)", "f_down(plus)", - "f_down(col_a)", "f_down(col_b)", - "f_down(gt)", "f_down(col_c)", - ])); + let result = tree.transform_down_in_scope(transform_yes_in_scope("f_down"))?; + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "f_down(col_a)", + "f_down(col_b)", + "f_down(gt)", + "f_down(col_c)", + ]) + ); // Lambda internals untouched in children path let children = collect_children_data(&result.data); assert_eq!(children[0], "f_down(root)"); @@ -3213,14 +3289,22 @@ pub(crate) mod tests { #[test] fn test_transform_down_in_scope_does_not_transform_lambda_internals() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope( - transform_yes_in_scope("f_down"), - )?; + let result = 
tree.transform_down_in_scope(transform_yes_in_scope("f_down"))?; let scoped = collect_scoped_data(&result.data); - let out_of_scope = ["lambda1", "lambda2", "list_col", "cmp", "idx_col", "inner_col"]; + let out_of_scope = [ + "lambda1", + "lambda2", + "list_col", + "cmp", + "idx_col", + "inner_col", + ]; for v in &scoped { for name in &out_of_scope { - assert!(!v.contains(name), "{name} should not be in scoped data: {v}"); + assert!( + !v.contains(name), + "{name} should not be in scoped data: {v}" + ); } } Ok(()) @@ -3229,15 +3313,24 @@ pub(crate) mod tests { #[test] fn test_transform_down_in_scope_jump_on_plus() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope( - transform_and_event_on_in_scope("f_down", "plus", TreeNodeRecursion::Jump), - )?; + let result = tree.transform_down_in_scope(transform_and_event_on_in_scope( + "f_down", + "plus", + TreeNodeRecursion::Jump, + ))?; // Plus is transformed but children skipped, gt still visited - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_down(root)", "f_down(and)", "f_down(plus)", - "col_a", "col_b", - "f_down(gt)", "f_down(col_c)", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_down(root)", + "f_down(and)", + "f_down(plus)", + "col_a", + "col_b", + "f_down(gt)", + "f_down(col_c)", + ]) + ); Ok(()) } @@ -3246,14 +3339,19 @@ pub(crate) mod tests { #[test] fn test_transform_up_in_scope_continue() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_up_in_scope( - transform_yes_in_scope("f_up"), - )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(root)", "f_up(and)", "f_up(plus)", - "f_up(col_a)", "f_up(col_b)", - "f_up(gt)", "f_up(col_c)", - ])); + let result = tree.transform_up_in_scope(transform_yes_in_scope("f_up"))?; + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_up(root)", + "f_up(and)", + "f_up(plus)", + "f_up(col_a)", + "f_up(col_b)", + "f_up(gt)", + "f_up(col_c)", + ]) + ); let 
children = collect_children_data(&result.data); assert!(children.contains(&"lambda1".to_string())); assert!(children.contains(&"lambda2".to_string())); @@ -3264,14 +3362,16 @@ pub(crate) mod tests { #[test] fn test_transform_up_in_scope_stop_on_col_a() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_up_in_scope( - transform_and_event_on_in_scope("f_up", "col_a", TreeNodeRecursion::Stop), - )?; + let result = tree.transform_up_in_scope(transform_and_event_on_in_scope( + "f_up", + "col_a", + TreeNodeRecursion::Stop, + ))?; // Stop on col_a: only col_a transformed, everything else untouched - assert_eq!(collect_scoped_data(&result.data), s(&[ - "root", "and", "plus", "f_up(col_a)", "col_b", - "gt", "col_c", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&["root", "and", "plus", "f_up(col_a)", "col_b", "gt", "col_c",]) + ); Ok(()) } @@ -3284,11 +3384,18 @@ pub(crate) mod tests { transform_yes_in_scope("f_down"), transform_yes_in_scope("f_up"), )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(root))", "f_up(f_down(and))", "f_up(f_down(plus))", - "f_up(f_down(col_a))", "f_up(f_down(col_b))", - "f_up(f_down(gt))", "f_up(f_down(col_c))", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_up(f_down(root))", + "f_up(f_down(and))", + "f_up(f_down(plus))", + "f_up(f_down(col_a))", + "f_up(f_down(col_b))", + "f_up(f_down(gt))", + "f_up(f_down(col_c))", + ]) + ); let children = collect_children_data(&result.data); assert!(children.contains(&"lambda1".to_string())); assert!(children.contains(&"lambda2".to_string())); @@ -3345,11 +3452,18 @@ pub(crate) mod tests { Box::new(transform_yes_in_scope("f_up")), ); let result = tree.rewrite_in_scope(&mut rewriter)?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(root))", "f_up(f_down(and))", "f_up(f_down(plus))", - "f_up(f_down(col_a))", "f_up(f_down(col_b))", - "f_up(f_down(gt))", "f_up(f_down(col_c))", - ])); + assert_eq!( + 
collect_scoped_data(&result.data), + s(&[ + "f_up(f_down(root))", + "f_up(f_down(and))", + "f_up(f_down(plus))", + "f_up(f_down(col_a))", + "f_up(f_down(col_b))", + "f_up(f_down(gt))", + "f_up(f_down(col_c))", + ]) + ); let children = collect_children_data(&result.data); assert!(children.contains(&"lambda1".to_string())); assert!(children.contains(&"lambda2".to_string())); @@ -3362,17 +3476,26 @@ pub(crate) mod tests { let tree = scoped_test_tree(); let mut rewriter = TestRewriter::new( Box::new(transform_and_event_on_in_scope( - "f_down", "plus", TreeNodeRecursion::Jump, + "f_down", + "plus", + TreeNodeRecursion::Jump, )), Box::new(transform_yes_in_scope("f_up")), ); let result = tree.rewrite_in_scope(&mut rewriter)?; // Jump on Plus: children skipped, but sibling gt is visited - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(root))", "f_up(f_down(and))", "f_up(f_down(plus))", - "col_a", "col_b", - "f_up(f_down(gt))", "f_up(f_down(col_c))", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_up(f_down(root))", + "f_up(f_down(and))", + "f_up(f_down(plus))", + "col_a", + "col_b", + "f_up(f_down(gt))", + "f_up(f_down(col_c))", + ]) + ); Ok(()) } @@ -3418,20 +3541,32 @@ pub(crate) mod tests { #[test] fn test_lambda1_scope_visit() -> Result<()> { let lambda1 = build_lambda1(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - Box::new(visit_continue), - ); + let mut visitor = + TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); lambda1.visit_in_scope(&mut visitor)?; - assert_eq!(visitor.visits, s(&[ - "f_down(lambda1)", "f_down(list_col)", "f_up(list_col)", - "f_down(cmp)", "f_down(idx_col)", "f_up(idx_col)", "f_up(cmp)", - "f_up(lambda1)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(lambda1)", + "f_down(list_col)", + "f_up(list_col)", + "f_down(cmp)", + "f_down(idx_col)", + "f_up(idx_col)", + "f_up(cmp)", + "f_up(lambda1)", + ]) + ); // Must not enter lambda2 scope for v in 
&visitor.visits { - assert!(!v.contains("lambda2"), "lambda1 visit must not enter lambda2: {v}"); - assert!(!v.contains("inner_col"), "lambda1 visit must not enter lambda2: {v}"); + assert!( + !v.contains("lambda2"), + "lambda1 visit must not enter lambda2: {v}" + ); + assert!( + !v.contains("inner_col"), + "lambda1 visit must not enter lambda2: {v}" + ); } Ok(()) } @@ -3439,12 +3574,11 @@ pub(crate) mod tests { #[test] fn test_lambda1_scope_transform_down() -> Result<()> { let lambda1 = build_lambda1(); - let result = lambda1.transform_down_in_scope( - transform_yes_in_scope("TX"), - )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)", - ])); + let result = lambda1.transform_down_in_scope(transform_yes_in_scope("TX"))?; + assert_eq!( + collect_scoped_data(&result.data), + s(&["TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)",]) + ); // Lambda2 untouched in children path let children = collect_children_data(&result.data); assert!(children.contains(&"lambda2".to_string())); @@ -3455,12 +3589,11 @@ pub(crate) mod tests { #[test] fn test_lambda1_scope_transform_up() -> Result<()> { let lambda1 = build_lambda1(); - let result = lambda1.transform_up_in_scope( - transform_yes_in_scope("TX"), - )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)", - ])); + let result = lambda1.transform_up_in_scope(transform_yes_in_scope("TX"))?; + assert_eq!( + collect_scoped_data(&result.data), + s(&["TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)",]) + ); let children = collect_children_data(&result.data); assert!(children.contains(&"lambda2".to_string())); assert!(children.contains(&"inner_col".to_string())); @@ -3489,10 +3622,15 @@ pub(crate) mod tests { Box::new(transform_yes_in_scope("f_up")), ); let result = lambda1.rewrite_in_scope(&mut rewriter)?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(lambda1))", 
"f_up(f_down(list_col))", - "f_up(f_down(cmp))", "f_up(f_down(idx_col))", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_up(f_down(lambda1))", + "f_up(f_down(list_col))", + "f_up(f_down(cmp))", + "f_up(f_down(idx_col))", + ]) + ); let children = collect_children_data(&result.data); assert!(children.contains(&"lambda2".to_string())); assert!(children.contains(&"inner_col".to_string())); @@ -3506,10 +3644,15 @@ pub(crate) mod tests { transform_yes_in_scope("f_down"), transform_yes_in_scope("f_up"), )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(lambda1))", "f_up(f_down(list_col))", - "f_up(f_down(cmp))", "f_up(f_down(idx_col))", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&[ + "f_up(f_down(lambda1))", + "f_up(f_down(list_col))", + "f_up(f_down(cmp))", + "f_up(f_down(idx_col))", + ]) + ); let children = collect_children_data(&result.data); assert!(children.contains(&"lambda2".to_string())); assert!(children.contains(&"inner_col".to_string())); @@ -3533,38 +3676,40 @@ pub(crate) mod tests { #[test] fn test_lambda2_scope_visit() -> Result<()> { let lambda2 = build_lambda2(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - Box::new(visit_continue), - ); + let mut visitor = + TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); lambda2.visit_in_scope(&mut visitor)?; - assert_eq!(visitor.visits, s(&[ - "f_down(lambda2)", "f_down(inner_col)", "f_up(inner_col)", "f_up(lambda2)", - ])); + assert_eq!( + visitor.visits, + s(&[ + "f_down(lambda2)", + "f_down(inner_col)", + "f_up(inner_col)", + "f_up(lambda2)", + ]) + ); Ok(()) } #[test] fn test_lambda2_scope_transform_down() -> Result<()> { let lambda2 = build_lambda2(); - let result = lambda2.transform_down_in_scope( - transform_yes_in_scope("TX"), - )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "TX(lambda2)", "TX(inner_col)", - ])); + let result = lambda2.transform_down_in_scope(transform_yes_in_scope("TX"))?; + 
assert_eq!( + collect_scoped_data(&result.data), + s(&["TX(lambda2)", "TX(inner_col)",]) + ); Ok(()) } #[test] fn test_lambda2_scope_transform_up() -> Result<()> { let lambda2 = build_lambda2(); - let result = lambda2.transform_up_in_scope( - transform_yes_in_scope("TX"), - )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "TX(lambda2)", "TX(inner_col)", - ])); + let result = lambda2.transform_up_in_scope(transform_yes_in_scope("TX"))?; + assert_eq!( + collect_scoped_data(&result.data), + s(&["TX(lambda2)", "TX(inner_col)",]) + ); Ok(()) } @@ -3575,9 +3720,10 @@ pub(crate) mod tests { transform_yes_in_scope("f_down"), transform_yes_in_scope("f_up"), )?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(lambda2))", "f_up(f_down(inner_col))", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&["f_up(f_down(lambda2))", "f_up(f_down(inner_col))",]) + ); Ok(()) } @@ -3589,9 +3735,10 @@ pub(crate) mod tests { Box::new(transform_yes_in_scope("f_up")), ); let result = lambda2.rewrite_in_scope(&mut rewriter)?; - assert_eq!(collect_scoped_data(&result.data), s(&[ - "f_up(f_down(lambda2))", "f_up(f_down(inner_col))", - ])); + assert_eq!( + collect_scoped_data(&result.data), + s(&["f_up(f_down(lambda2))", "f_up(f_down(inner_col))",]) + ); Ok(()) } @@ -3611,9 +3758,7 @@ pub(crate) mod tests { #[test] fn test_outer_scope_transform_does_not_affect_lambda1() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope( - transform_yes_in_scope("TX"), - )?; + let result = tree.transform_down_in_scope(transform_yes_in_scope("TX"))?; let all_data = collect_children_data(&result.data); assert_eq!(all_data[0], "TX(root)"); // Lambda1 scope completely untouched @@ -3647,9 +3792,7 @@ pub(crate) mod tests { #[test] fn test_outer_scope_transform_up_does_not_affect_inner_scopes() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_up_in_scope( - transform_yes_in_scope("TX"), - )?; + let result = 
tree.transform_up_in_scope(transform_yes_in_scope("TX"))?; let all_data = collect_children_data(&result.data); // Lambda1 + Lambda2 scopes completely untouched assert!(all_data.contains(&"lambda1".to_string())); diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index 9eee7a5599269..aac450af766f0 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -657,6 +657,7 @@ mod tests { use datafusion_common::{assert_contains, record_batch}; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{Column, Literal, col, lit}; + use std::hash::Hash; fn create_test_schema() -> (Schema, Schema) { let physical_schema = Schema::new(vec![ @@ -1737,17 +1738,23 @@ mod tests { /// A mock expression with an in-scope child and an out-of-scope child. /// Used to verify that scoped traversal does not modify out-of-scope children. - #[derive(Debug, Hash, Clone)] + #[derive(Debug, Clone)] struct ScopedExprMock { in_scope_child: Arc, out_of_scope_child: Arc, } + impl Hash for ScopedExprMock { + fn hash(&self, state: &mut H) { + self.in_scope_child.hash(state); + self.out_of_scope_child.hash(state); + } + } + impl PartialEq for ScopedExprMock { fn eq(&self, other: &Self) -> bool { self.in_scope_child.as_ref() == other.in_scope_child.as_ref() - && self.out_of_scope_child.as_ref() - == other.out_of_scope_child.as_ref() + && self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref() } } @@ -1768,10 +1775,7 @@ mod tests { self } - fn return_field( - &self, - input_schema: &Schema, - ) -> Result> { + fn return_field(&self, input_schema: &Schema) -> Result> { self.in_scope_child.return_field(input_schema) } @@ -1813,10 +1817,7 @@ mod tests { })) } - fn fmt_sql( - &self, - f: &mut std::fmt::Formatter<'_>, - ) -> std::fmt::Result { + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 
std::fmt::Display::fmt(&self, f) } } @@ -1826,8 +1827,7 @@ mod tests { // The in-scope child references column "a" which should be replaced let in_scope_child: Arc = Arc::new(Column::new("a", 0)); // The out-of-scope child also references a column "a" but should NOT be replaced - let out_of_scope_child: Arc = - Arc::new(Column::new("a", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("a", 0)); let expr: Arc = Arc::new(ScopedExprMock { in_scope_child, diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index b4e3175979844..0c3b98e16ef2e 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -166,9 +166,9 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { fn children(&self) -> Vec<&Arc>; /// Get a list of child PhysicalExpr that provide the input for this expr that are in the same scope as this expression. - /// - /// Due to the majority of expressions being in the same scope the default implementation is to call to [`Self::children`] - /// + /// + /// Due to the majority of expressions being in the same scope the default implementation is to call to [`Self::children`] + /// /// To know if specific child is considered in the same scope you can answer this simple question: /// If that child is a `Column` would that column can be evaluated with the same input schema /// Expressions like `plus`, `sum`, etc have all children in scope. @@ -184,11 +184,11 @@ pub trait PhysicalExpr: Any + Send + Sync + Display + Debug + DynEq + DynHash { ) -> Result>; /// Returns a new PhysicalExpr where all scoped children were replaced by new exprs. 
- /// + /// /// See [`Self::children_in_scope`] for definition of what child considered a scope - /// - /// Due to the majority of expressions being in the same scope the default implementation is to call to [`Self::with_new_children`] - /// + /// + /// Due to the majority of expressions being in the same scope the default implementation is to call to [`Self::with_new_children`] + /// fn with_new_children_in_scope( self: Arc, children_in_scope: Vec>, @@ -501,10 +501,10 @@ pub fn with_new_children_if_necessary( ); if children.is_empty() - || children - .iter() - .zip(old_children.iter()) - .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + || children + .iter() + .zip(old_children.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) { Ok(expr.with_new_children(children)?) } else { @@ -525,10 +525,10 @@ pub fn with_new_children_in_scope_if_necessary( ); if children_in_scope.is_empty() - || children_in_scope - .iter() - .zip(old_children_in_scope.iter()) - .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) + || children_in_scope + .iter() + .zip(old_children_in_scope.iter()) + .any(|(c1, c2)| !Arc::ptr_eq(c1, c2)) { Ok(expr.with_new_children_in_scope(children_in_scope)?) 
} else { diff --git a/datafusion/physical-expr-common/src/tree_node.rs b/datafusion/physical-expr-common/src/tree_node.rs index 1f9e855762be7..c61ce91e4024c 100644 --- a/datafusion/physical-expr-common/src/tree_node.rs +++ b/datafusion/physical-expr-common/src/tree_node.rs @@ -20,7 +20,9 @@ use std::fmt::{self, Display, Formatter}; use std::sync::Arc; -use crate::physical_expr::{PhysicalExpr, with_new_children_if_necessary, with_new_children_in_scope_if_necessary}; +use crate::physical_expr::{ + PhysicalExpr, with_new_children_if_necessary, with_new_children_in_scope_if_necessary, +}; use datafusion_common::Result; use datafusion_common::tree_node::{ConcreteTreeNode, DynScopedTreeNode, DynTreeNode}; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index d113460899881..6354f1ce4e1ee 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -40,9 +40,7 @@ use std::{any::Any, sync::Arc}; use crate::expressions::case::literal_lookup_table::LiteralLookupTable; use arrow::compute::kernels::merge::{MergeIndex, merge, merge_n}; -use datafusion_common::tree_node::{ - ScopedTreeNode, Transformed, TreeNode, TreeNodeRecursion, -}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TreeNodeRecursion}; use datafusion_physical_expr_common::datum::compare_with_eq; use datafusion_physical_expr_common::utils::scatter; use itertools::Itertools; @@ -1738,14 +1736,23 @@ mod tests { /// Represents the column at a given index in a RecordBatch that is inside a Spark lambda function /// /// This is the same as the datafusion [`datafusion::physical_expr::expressions::Column`] except that it store the entire info so that it can be used in lambda execution - #[derive(Debug, Hash, Clone)] + #[derive(Debug, Clone)] pub struct AllListElementMatchMiniLambda { child: Arc, predicate_on_list_elements: Arc, } + + impl Hash for AllListElementMatchMiniLambda { + fn 
hash(&self, state: &mut H) { + self.child.hash(state); + self.predicate_on_list_elements.hash(state); + } + } impl PartialEq for AllListElementMatchMiniLambda { fn eq(&self, other: &Self) -> bool { - self.child.as_ref() == other.child.as_ref() && self.predicate_on_list_elements.as_ref() == other.predicate_on_list_elements.as_ref() + self.child.as_ref() == other.child.as_ref() + && self.predicate_on_list_elements.as_ref() + == other.predicate_on_list_elements.as_ref() } } @@ -1755,17 +1762,21 @@ mod tests { pub fn new( child: Arc, predicate_on_list_element: Arc, - ) -> Arc { - Arc::new(Self { + ) -> Self { + Self { child, - predicate_on_list_elements: predicate_on_list_element - }) + predicate_on_list_elements: predicate_on_list_element, + } } } impl std::fmt::Display for AllListElementMatchMiniLambda { fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!(f, "all_match({:?}, {:?})", self.child, self.predicate_on_list_elements) + write!( + f, + "all_match({:?}, {:?})", + self.child, self.predicate_on_list_elements + ) } } @@ -1774,47 +1785,80 @@ mod tests { self } - fn return_field(&self, input_schema: &Schema) -> Result { + fn return_field( + &self, + input_schema: &Schema, + ) -> Result { let is_child_nullable = self.child.nullable(input_schema)?; - Ok(Arc::new(Field::new("match", DataType::Boolean, is_child_nullable))) + Ok(Arc::new(Field::new( + "match", + DataType::Boolean, + is_child_nullable, + ))) } fn evaluate(&self, batch: &RecordBatch) -> Result { let child = self.child.evaluate(batch)?; - let DataType::List(child_list_field) = self.child.data_type(batch.schema_ref())? else { + let DataType::List(child_list_field) = + self.child.data_type(batch.schema_ref())? 
+ else { unreachable!() - }; + }; let child = child.to_array_of_size(batch.num_rows())?; let list = child.as_list::(); let lambda_schema = Arc::new(Schema::new(Fields::from(vec![ Field::new("index", DataType::UInt32, false), - child_list_field.as_ref().clone() + child_list_field.as_ref().clone(), ]))); - assert_eq!(list.value_offsets()[0].as_usize(), 0, "this is mock implementation, it does not support sliced list"); - assert_eq!(list.value_offsets().last().unwrap().as_usize(), list.values().len(), "this is mock implementation, it does not support sliced list"); + assert_eq!( + list.value_offsets()[0].as_usize(), + 0, + "this is mock implementation, it does not support sliced list" + ); + assert_eq!( + list.value_offsets().last().unwrap().as_usize(), + list.values().len(), + "this is mock implementation, it does not support sliced list" + ); let list_values = list.values(); - let new_batch = RecordBatch::try_new(Arc::clone(&lambda_schema), vec![ - Arc::new(list.offsets().lengths().flat_map(|list_len| 0..list_len as u32).collect::()), - Arc::clone(list_values), - ])?; + let new_batch = RecordBatch::try_new( + Arc::clone(&lambda_schema), + vec![ + Arc::new( + list.offsets() + .lengths() + .flat_map(|list_len| 0..list_len as u32) + .collect::(), + ), + Arc::clone(list_values), + ], + )?; let any_match = self.predicate_on_list_elements.evaluate(&new_batch)?; let any_match = any_match.to_array_of_size(list_values.len())?; let any_match = any_match.as_boolean(); - let all_match_per_list = list.offsets().windows(2).map(|start_and_end| { - let length = start_and_end[1] - start_and_end[0]; - let list_matches = any_match.slice(start_and_end[0] as usize, length as usize); + let all_match_per_list = list + .offsets() + .windows(2) + .map(|start_and_end| { + let length = start_and_end[1] - start_and_end[0]; + let list_matches = + any_match.slice(start_and_end[0] as usize, length as usize); - list_matches.true_count() == list_matches.len() as usize - }).collect::(); + 
list_matches.true_count() == list_matches.len() as usize + }) + .collect::(); - let result = Arc::new(BooleanArray::new(all_match_per_list, list.nulls().cloned())); + let result = Arc::new(BooleanArray::new( + all_match_per_list, + list.nulls().cloned(), + )); Ok(ColumnarValue::Array(result)) } @@ -1840,16 +1884,18 @@ mod tests { } fn with_new_children_in_scope( - self: Arc, - children_in_scope: Vec>, - ) -> Result> { - assert_eq!(children_in_scope.len(), 1); - let mut iter = children_in_scope.into_iter(); - Ok(Arc::new(Self { - child: iter.next().unwrap(), - // TODO - but what if child has changed to not be list or the data type has changed?? - predicate_on_list_elements: Arc::clone(&self.predicate_on_list_elements), - })) + self: Arc, + children_in_scope: Vec>, + ) -> Result> { + assert_eq!(children_in_scope.len(), 1); + let mut iter = children_in_scope.into_iter(); + Ok(Arc::new(Self { + child: iter.next().unwrap(), + // TODO - but what if child has changed to not be list or the data type has changed?? 
+ predicate_on_list_elements: Arc::clone( + &self.predicate_on_list_elements, + ), + })) } fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { @@ -1861,42 +1907,50 @@ mod tests { Arc::new(Field::new("col_1", DataType::Utf8, true)), Arc::new(Field::new("col_2", DataType::Utf8, true)), Arc::new(Field::new("col_3", DataType::Utf8, true)), - Arc::new(Field::new("list", DataType::new_list(DataType::UInt32, true), true)), + Arc::new(Field::new( + "list", + DataType::new_list(DataType::UInt32, true), + true, + )), ])); let input_list = ListArray::from_iter_primitive::(vec![ - // all even place numbers are even - Some(vec![Some(0), Some(1), Some(2)]), - None, - // Not all even place are even but all odd place are odd - Some(vec![Some(0), Some(1), Some(1)]), - - // Not odd and not even in corresponding places - Some(vec![Some(1), Some(2)]), - ]); + // all even place numbers are even + Some(vec![Some(0), Some(1), Some(2)]), + None, + // Not all even place are even but all odd place are odd + Some(vec![Some(0), Some(1), Some(1)]), + // Not odd and not even in corresponding places + Some(vec![Some(1), Some(2)]), + ]); let batch = RecordBatch::try_new( input_schema, vec![ - new_null_array(&DataType::Utf8, input_list.len()), - new_null_array(&DataType::Utf8, input_list.len()), - new_null_array(&DataType::Utf8, input_list.len()), - Arc::new(input_list), - ] - ).unwrap(); + new_null_array(&DataType::Utf8, input_list.len()), + new_null_array(&DataType::Utf8, input_list.len()), + new_null_array(&DataType::Utf8, input_list.len()), + Arc::new(input_list), + ], + ) + .unwrap(); let schema = batch.schema(); fn create_when_expr(is_even: bool) -> Arc { let idx_col: Arc = Arc::new(Column::new("idx", 0)); let item_col: Arc = Arc::new(Column::new("item", 1)); - AllListElementMatchMiniLambda::new( + Arc::new(AllListElementMatchMiniLambda::new( Arc::new(Column::new("list", 3)), - create_both_odd_or_even(&idx_col, &item_col, is_even) - ) + create_both_odd_or_even(&idx_col, &item_col, 
is_even), + )) } - fn create_both_odd_or_even(idx_column: &Arc, list_item_column: &Arc, is_even: bool) -> Arc { - let equal_value = if is_even { 0 } else {1}; + fn create_both_odd_or_even( + idx_column: &Arc, + list_item_column: &Arc, + is_even: bool, + ) -> Arc { + let equal_value = if is_even { 0 } else { 1 }; let idx_equal = module_2_equal_value(idx_column, equal_value); let item_equal = module_2_equal_value(list_item_column, equal_value); @@ -1905,19 +1959,27 @@ mod tests { vec![(idx_equal, item_equal)], // if idx not equal than true Some(lit(true)), - ).unwrap() + ) + .unwrap() } - fn module_2_equal_value(left: &Arc, equal_value: u32) -> Arc { - let modulo2 = BinaryExpr::new(Arc::clone(&left), Operator::Modulo, lit(2u32)); - let equal_value = BinaryExpr::new(Arc::new(modulo2), Operator::Eq, lit(equal_value)); + fn module_2_equal_value( + left: &Arc, + equal_value: u32, + ) -> Arc { + let modulo2 = BinaryExpr::new(Arc::clone(left), Operator::Modulo, lit(2u32)); + let equal_value = + BinaryExpr::new(Arc::new(modulo2), Operator::Eq, lit(equal_value)); Arc::new(equal_value) } let expr = generate_case_when_with_type_coercion( None, - vec![(create_when_expr(true), lit("both even")), (create_when_expr(false), lit("both odd"))], + vec![ + (create_when_expr(true), lit("both even")), + (create_when_expr(false), lit("both odd")), + ], None, schema.as_ref(), )?; @@ -1926,7 +1988,8 @@ mod tests { .into_array(batch.num_rows()) .expect("Failed to convert to array"); - let expected = &StringArray::from(vec![Some("both even"), None, Some("both odd"), None]); + let expected = + &StringArray::from(vec![Some("both even"), None, Some("both odd"), None]); assert_eq!(expected, result.as_string::()); diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 37bcddb0a92c3..555eca7c3bf81 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -22,7 +22,7 @@ use 
crate::{LexOrdering, PhysicalSortExpr, create_physical_expr}; use arrow::compute::SortOptions; use arrow::datatypes::{Schema, SchemaRef}; -use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult}; use datafusion_common::{DFSchema, HashMap}; use datafusion_common::{Result, plan_err}; use datafusion_expr::execution_props::ExecutionProps; diff --git a/datafusion/physical-expr/src/projection.rs b/datafusion/physical-expr/src/projection.rs index 58b4c0b7cb8f5..3c5d4e547a61c 100644 --- a/datafusion/physical-expr/src/projection.rs +++ b/datafusion/physical-expr/src/projection.rs @@ -27,7 +27,7 @@ use crate::utils::collect_columns; use arrow::array::{RecordBatch, RecordBatchOptions}; use arrow::datatypes::{Field, Schema, SchemaRef}; use datafusion_common::stats::{ColumnStatistics, Precision}; -use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult, TreeNode}; +use datafusion_common::tree_node::{ScopedTreeNode, Transformed, TransformedResult}; use datafusion_common::{ Result, ScalarValue, Statistics, assert_or_internal_err, internal_datafusion_err, plan_err, @@ -3046,8 +3046,7 @@ pub(crate) mod tests { // Projection: [c@2 as c_new, a@0 as a_new, b@1 as b_new] // After unproject: in_scope should become c@2, out_of_scope should stay x@0 let in_scope_child: Arc = Arc::new(Column::new("a_new", 1)); - let out_of_scope_child: Arc = - Arc::new(Column::new("x", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); let expr: Arc = Arc::new(ScopedExprMock { in_scope_child, @@ -3107,17 +3106,18 @@ pub(crate) mod tests { ])); let in_scope_child: Arc = Arc::new(Column::new("a", 0)); - let out_of_scope_child: Arc = - Arc::new(Column::new("x", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); let scoped_expr: Arc = Arc::new(ScopedExprMock { in_scope_child, out_of_scope_child, }); - let ordering = - 
LexOrdering::new(vec![PhysicalSortExpr::new(scoped_expr, SortOptions::new(false, false))]) - .unwrap(); + let ordering = LexOrdering::new(vec![PhysicalSortExpr::new( + scoped_expr, + SortOptions::new(false, false), + )]) + .unwrap(); let result = project_ordering(&ordering, &schema).expect("Should project"); @@ -3155,8 +3155,7 @@ pub(crate) mod tests { ])); let in_scope_child: Arc = Arc::new(Column::new("a", 0)); - let out_of_scope_child: Arc = - Arc::new(Column::new("x", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("x", 0)); let scoped_expr: Arc = Arc::new(ScopedExprMock { in_scope_child, diff --git a/datafusion/physical-expr/src/utils/mod.rs b/datafusion/physical-expr/src/utils/mod.rs index 6ed986722fd86..f5a2508d68249 100644 --- a/datafusion/physical-expr/src/utils/mod.rs +++ b/datafusion/physical-expr/src/utils/mod.rs @@ -334,6 +334,7 @@ pub(crate) mod tests { use super::*; use crate::expressions::{Literal, binary, cast, col, in_list, lit}; + use std::hash::Hash; use arrow::array::{ArrayRef, Float32Array, Float64Array}; use arrow::datatypes::{DataType, Field}; @@ -350,12 +351,19 @@ pub(crate) mod tests { /// This simulates a lambda-like expression where the `in_scope_child` /// references columns in the outer schema and the `out_of_scope_child` /// references columns in a different (lambda) schema. 
- #[derive(Debug, Hash, Clone)] + #[derive(Debug, Clone)] pub(crate) struct ScopedExprMock { pub in_scope_child: Arc, pub out_of_scope_child: Arc, } + impl Hash for ScopedExprMock { + fn hash(&self, state: &mut H) { + self.in_scope_child.hash(state); + self.out_of_scope_child.hash(state); + } + } + impl PartialEq for ScopedExprMock { fn eq(&self, other: &Self) -> bool { self.in_scope_child.as_ref() == other.in_scope_child.as_ref() @@ -380,10 +388,7 @@ pub(crate) mod tests { self } - fn return_field( - &self, - input_schema: &Schema, - ) -> Result> { + fn return_field(&self, input_schema: &Schema) -> Result> { self.in_scope_child.return_field(input_schema) } diff --git a/datafusion/physical-plan/src/filter_pushdown.rs b/datafusion/physical-plan/src/filter_pushdown.rs index c742dc9796498..7e60c9a937d9d 100644 --- a/datafusion/physical-plan/src/filter_pushdown.rs +++ b/datafusion/physical-plan/src/filter_pushdown.rs @@ -40,7 +40,7 @@ use std::sync::Arc; use arrow_schema::SchemaRef; use datafusion_common::{ Result, - tree_node::{ScopedTreeNode, Transformed, TreeNode}, + tree_node::{ScopedTreeNode, Transformed}, }; use datafusion_physical_expr::expressions::Column; use datafusion_physical_expr_common::physical_expr::PhysicalExpr; @@ -561,20 +561,27 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr::expressions::Column; + use std::hash::Hash; /// A mock expression with an in-scope child and an out-of-scope child. /// Used to verify that scoped traversal does not modify out-of-scope children. 
- #[derive(Debug, Hash, Clone)] + #[derive(Debug, Clone)] struct ScopedExprMock { in_scope_child: Arc, out_of_scope_child: Arc, } + impl Hash for ScopedExprMock { + fn hash(&self, state: &mut H) { + self.in_scope_child.hash(state); + self.out_of_scope_child.hash(state); + } + } + impl PartialEq for ScopedExprMock { fn eq(&self, other: &Self) -> bool { self.in_scope_child.as_ref() == other.in_scope_child.as_ref() - && self.out_of_scope_child.as_ref() - == other.out_of_scope_child.as_ref() + && self.out_of_scope_child.as_ref() == other.out_of_scope_child.as_ref() } } @@ -595,17 +602,11 @@ mod tests { self } - fn return_field( - &self, - input_schema: &Schema, - ) -> Result> { + fn return_field(&self, input_schema: &Schema) -> Result> { self.in_scope_child.return_field(input_schema) } - fn evaluate( - &self, - _batch: &RecordBatch, - ) -> Result { + fn evaluate(&self, _batch: &RecordBatch) -> Result { unimplemented!("ScopedExprMock does not support evaluation") } @@ -640,10 +641,7 @@ mod tests { })) } - fn fmt_sql( - &self, - f: &mut std::fmt::Formatter<'_>, - ) -> std::fmt::Result { + fn fmt_sql(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(&self, f) } } @@ -661,8 +659,7 @@ mod tests { // The in-scope child references column "a@0" which exists in child schema let in_scope_child: Arc = Arc::new(Column::new("a", 0)); // The out-of-scope child also references "a@0" but should NOT be remapped - let out_of_scope_child: Arc = - Arc::new(Column::new("a", 0)); + let out_of_scope_child: Arc = Arc::new(Column::new("a", 0)); let filter: Arc = Arc::new(ScopedExprMock { in_scope_child, @@ -712,8 +709,7 @@ mod tests { // The in-scope child references "a@0" - should be remapped to "a@1" in child schema let in_scope_child: Arc = Arc::new(Column::new("a", 0)); // The out-of-scope child references "x@0" in the lambda schema - should NOT be touched - let out_of_scope_child: Arc = - Arc::new(Column::new("x", 0)); + let out_of_scope_child: Arc = 
Arc::new(Column::new("x", 0)); let filter: Arc = Arc::new(ScopedExprMock { in_scope_child, From 984076bff76c3f0f4d4618e89a355f36ebd8c827 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 18:30:23 +0300 Subject: [PATCH 08/17] update comments --- datafusion/common/src/tree_node.rs | 66 +++++++++++++++++------------- 1 file changed, 38 insertions(+), 28 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 2247dd87d7c49..5d61500a137b6 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -453,7 +453,7 @@ pub trait TreeNode: Sized { /// Scope is left for implementers to define, for `PhysicalExpr` child is defined in scope if it have the same input schema as current `PhysicalExpr`. pub trait ScopedTreeNode: TreeNode { /// Visit the tree node with a [`TreeNodeVisitor`], performing a - /// depth-first walk of the node and its children. + /// depth-first walk of the node and its children that are in the same scope. /// /// [`TreeNodeVisitor::f_down()`] is called in top-down order (before /// children are visited), [`TreeNodeVisitor::f_up()`] is called in @@ -463,8 +463,8 @@ pub trait ScopedTreeNode: TreeNode { /// Specifies how the tree walk ended. See [`TreeNodeRecursion`] for details. /// /// # See Also: - /// * [`Self::apply`] for inspecting nodes with a closure - /// * [`Self::rewrite`] to rewrite owned `TreeNode`s + /// * [`Self::apply_in_scope`] for inspecting nodes with a closure + /// * [`Self::rewrite_in_scope`] to rewrite owned `ScopedTreeNode`s /// /// # Example /// Consider the following tree structure: @@ -497,7 +497,7 @@ pub trait ScopedTreeNode: TreeNode { } /// Rewrite the tree node with a [`TreeNodeRewriter`], performing a - /// depth-first walk of the node and its children. + /// depth-first walk of the node and its children that are in the same scope. 
/// /// [`TreeNodeRewriter::f_down()`] is called in top-down order (before /// children are visited), [`TreeNodeRewriter::f_up()`] is called in @@ -505,7 +505,7 @@ pub trait ScopedTreeNode: TreeNode { /// /// Note: If using the default [`TreeNodeRewriter::f_up`] or /// [`TreeNodeRewriter::f_down`] that do nothing, consider using - /// [`Self::transform_down`] instead. + /// [`Self::transform_down_in_scope`] instead. /// /// # Return Value /// The returns value specifies how the tree walk should proceed. See @@ -513,10 +513,10 @@ pub trait ScopedTreeNode: TreeNode { /// recursion stops immediately. /// /// # See Also - /// * [`Self::visit`] for inspecting (without modification) `TreeNode`s - /// * [Self::transform_down_up] for a top-down (pre-order) traversal. - /// * [Self::transform_down] for a top-down (pre-order) traversal. - /// * [`Self::transform_up`] for a bottom-up (post-order) traversal. + /// * [`Self::visit_in_scope`] for inspecting (without modification) `ScopedTreeNode`s + /// * [Self::transform_down_up_in_scope] for a top-down (pre-order) traversal. + /// * [Self::transform_down_in_scope] for a top-down (pre-order) traversal. + /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. /// /// # Example /// Consider the following tree structure: @@ -547,15 +547,15 @@ pub trait ScopedTreeNode: TreeNode { ) } - /// Applies `f` to the node then each of its children, recursively (a - /// top-down, pre-order traversal). + /// Applies `f` to the node then each of its children that are in the + /// same scope, recursively (a top-down, pre-order traversal). /// /// The return [`TreeNodeRecursion`] controls the recursion and can cause /// an early return. /// /// # See Also - /// * [`Self::transform_down`] for the equivalent transformation API. - /// * [`Self::visit`] for both top-down and bottom up traversal. + /// * [`Self::transform_down_in_scope`] for the equivalent transformation API. 
+ /// * [`Self::visit_in_scope`] for both top-down and bottom up traversal. fn apply_in_scope<'n, F: FnMut(&'n Self) -> Result>( &'n self, mut f: F, @@ -587,15 +587,15 @@ pub trait ScopedTreeNode: TreeNode { } /// Recursively rewrite the tree using `f` in a top-down (pre-order) - /// fashion. + /// fashion, limited to children in the same scope. /// - /// `f` is applied to the node first, and then its children. + /// `f` is applied to the node first, and then its children in scope. /// /// # See Also /// * [`Self::transform_down`] for the same transformation but in all children ignoring scope /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. /// * [Self::transform_down_up_in_scope] for a combined traversal with closures - /// * [`Self::rewrite`] for a combined traversal with a visitor + /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor fn transform_down_in_scope Result>>( self, mut f: F, @@ -617,14 +617,14 @@ pub trait ScopedTreeNode: TreeNode { } /// Recursively rewrite the node using `f` in a bottom-up (post-order) - /// fashion. + /// fashion, limited to children in the same scope. /// - /// `f` is applied to the node's children first, and then to the node itself. + /// `f` is applied to the node's children in scope first, and then to the node itself. /// /// # See Also - /// * [`Self::transform_down`] top-down (pre-order) traversal. - /// * [Self::transform_down_up] for a combined traversal with closures - /// * [`Self::rewrite`] for a combined traversal with a visitor + /// * [`Self::transform_down_in_scope`] top-down (pre-order) traversal. 
+ /// * [Self::transform_down_up_in_scope] for a combined traversal with closures + /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor fn transform_up_in_scope Result>>( self, mut f: F, @@ -641,7 +641,16 @@ pub trait ScopedTreeNode: TreeNode { transform_up_impl(self, &mut f) } - /// Same as [`Self::transform_down_up`] but limited to the same scope + /// Transforms the node using `f_down` while traversing the tree top-down + /// (pre-order), and using `f_up` while traversing the tree bottom-up + /// (post-order), limited to children in the same scope. + /// + /// Same as [`Self::transform_down_up`] but limited to the same scope. + /// + /// # See Also + /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. + /// * [Self::transform_down_in_scope] for a top-down (pre-order) traversal. + /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor fn transform_down_up_in_scope< FD: FnMut(Self) -> Result>, FU: FnMut(Self) -> Result>, @@ -670,7 +679,8 @@ pub trait ScopedTreeNode: TreeNode { transform_down_up_impl(self, &mut f_down, &mut f_up) } - /// Returns true if `f` returns true for any node in the tree. + /// Returns true if `f` returns true for any node in the tree + /// that is in the same scope. /// /// Stops recursion as soon as a matching node is found fn exists_in_scope Result>(&self, mut f: F) -> Result { @@ -688,8 +698,8 @@ pub trait ScopedTreeNode: TreeNode { /// Low-level API used to implement other APIs. /// - /// If you want to implement the [`TreeNode`] trait for your own type, you - /// should implement this method and [`Self::map_children`]. + /// If you want to implement the [`ScopedTreeNode`] trait for your own type, you + /// should implement this method and [`Self::map_children_in_scope`]. /// /// Users should use one of the higher level APIs described on [`Self`]. /// @@ -702,12 +712,12 @@ pub trait ScopedTreeNode: TreeNode { /// Low-level API used to implement other APIs. 
/// - /// If you want to implement the [`TreeNode`] trait for your own type, you - /// should implement this method and [`Self::apply_children`]. + /// If you want to implement the [`ScopedTreeNode`] trait for your own type, you + /// should implement this method and [`Self::apply_children_in_scope`]. /// /// Users should use one of the higher level APIs described on [`Self`]. /// - /// Description: Apply `f` to rewrite the node's children (but not the node itself). + /// Description: Apply `f` to rewrite the node's children in scope (but not the node itself). fn map_children_in_scope Result>>( self, f: F, From 864cec3b1ab13c43f8e85ae1c8cd1324fe928c83 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 18:32:46 +0300 Subject: [PATCH 09/17] update comment --- datafusion/common/src/tree_node.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 5d61500a137b6..3f4448e5ac2a2 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -514,7 +514,7 @@ pub trait ScopedTreeNode: TreeNode { /// /// # See Also /// * [`Self::visit_in_scope`] for inspecting (without modification) `ScopedTreeNode`s - /// * [Self::transform_down_up_in_scope] for a top-down (pre-order) traversal. + /// * [Self::transform_down_up_in_scope] for a combined top-down and bottom-up traversal. /// * [Self::transform_down_in_scope] for a top-down (pre-order) traversal. /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. /// @@ -575,8 +575,8 @@ pub trait ScopedTreeNode: TreeNode { apply_impl(self, &mut f) } - /// Recursively rewrite the node's children and then the node using `f` - /// (a bottom-up post-order traversal). + /// Recursively rewrite the node's children in scope and then the node + /// using `f` (a bottom-up post-order traversal). 
/// /// A synonym of [`Self::transform_up_in_scope`]. fn transform_in_scope Result>>( @@ -622,7 +622,7 @@ pub trait ScopedTreeNode: TreeNode { /// `f` is applied to the node's children in scope first, and then to the node itself. /// /// # See Also - /// * [`Self::transform_down_in_scope`] top-down (pre-order) traversal. + /// * [`Self::transform_down_in_scope`] for a top-down (pre-order) traversal. /// * [Self::transform_down_up_in_scope] for a combined traversal with closures /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor fn transform_up_in_scope Result>>( From fdb1f9d7cd0aaed2f376b013d4f646c469ea7e3b Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 18:41:57 +0300 Subject: [PATCH 10/17] Update tree_node.rs --- datafusion/common/src/tree_node.rs | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 3f4448e5ac2a2..f1e09a10b0a68 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -1778,7 +1778,6 @@ pub(crate) mod tests { &'n self, f: F, ) -> Result { - // TODO - should call apply elements self.children_in_scope.apply_elements(f) } @@ -3167,7 +3166,8 @@ pub(crate) mod tests { Box::new(visit_event_on("plus", TreeNodeRecursion::Jump)), ); tree.visit_in_scope(&mut visitor)?; - // Jump after f_up(plus): continue with sibling gt, skip f_up(and) + // Jump after f_up(plus): skip plus's parent f_up, continue with sibling gt. + // gt returns Continue which resets the tnr, so f_up(and) and f_up(root) are called. 
assert_eq!( visitor.visits, s(&[ @@ -3184,8 +3184,7 @@ pub(crate) mod tests { "f_down(col_c)", "f_up(col_c)", "f_up(gt)", - // and's f_up is skipped (Jump from plus), but root f_up happens - // because gt returned Continue, resetting the last tnr + // gt returned Continue, resetting the tnr, so f_up(and) is called "f_up(and)", "f_up(root)", ]) From 9221b1f55021e66eb462b0a0578cd599eef6baec Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 19:28:10 +0300 Subject: [PATCH 11/17] fix clippy and update comment --- datafusion/common/src/tree_node.rs | 4 ++-- datafusion/physical-expr/src/expressions/case.rs | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index f1e09a10b0a68..f1d23a17081de 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -592,7 +592,7 @@ pub trait ScopedTreeNode: TreeNode { /// `f` is applied to the node first, and then its children in scope. /// /// # See Also - /// * [`Self::transform_down`] for the same transformation but in all children ignoring scope + /// * [`TreeNode::transform_down`] for the same transformation but in all children ignoring scope /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. /// * [Self::transform_down_up_in_scope] for a combined traversal with closures /// * [`Self::rewrite_in_scope`] for a combined traversal with a visitor @@ -645,7 +645,7 @@ pub trait ScopedTreeNode: TreeNode { /// (pre-order), and using `f_up` while traversing the tree bottom-up /// (post-order), limited to children in the same scope. /// - /// Same as [`Self::transform_down_up`] but limited to the same scope. + /// Same as [`TreeNode::transform_down_up`] but limited to the same scope. /// /// # See Also /// * [`Self::transform_up_in_scope`] for a bottom-up (post-order) traversal. 
diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 6354f1ce4e1ee..5becfed9d5b47 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -1851,7 +1851,7 @@ mod tests { let list_matches = any_match.slice(start_and_end[0] as usize, length as usize); - list_matches.true_count() == list_matches.len() as usize + list_matches.true_count() == list_matches.len() }) .collect::(); From e9200ac3c142291e6f449d063f9a6c2a153a5e75 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 20:38:31 +0300 Subject: [PATCH 12/17] rename --- datafusion/common/src/tree_node.rs | 264 ++++++++++++++--------------- 1 file changed, 132 insertions(+), 132 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index f1d23a17081de..03cee402de6d4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -1685,42 +1685,16 @@ pub(crate) mod tests { #[derive(Debug, Eq, Hash, PartialEq, Clone)] pub struct TestTreeNode { pub(crate) children: Vec>, - pub(crate) children_in_scope: Vec>, + pub(crate) children_in_same_scope: Vec>, pub(crate) data: T, } impl TestTreeNode { + /// Creates a node where all children are in the same scope. pub(crate) fn new(children: Vec>, data: T) -> Self { Self { + children_in_same_scope: children.clone(), children, - children_in_scope: vec![], - data, - } - } - - /// Creates a node where all children are in scope. - /// Automatically sets `children` = clone of `children_in_scope`. - pub(crate) fn new_scoped( - children_in_scope: Vec>, - data: T, - ) -> Self { - Self { - children: children_in_scope.clone(), - children_in_scope, - data, - } - } - - /// Creates a node with explicit `children` (all, in order) and - /// `children_in_scope` (the scoped subset). 
- pub(crate) fn new_mixed( - all_children: Vec>, - children_in_scope: Vec>, - data: T, - ) -> Self { - Self { - children: all_children, - children_in_scope, data, } } @@ -1728,7 +1702,7 @@ pub(crate) mod tests { pub(crate) fn new_leaf(data: T) -> Self { Self { children: vec![], - children_in_scope: vec![], + children_in_same_scope: vec![], data, } } @@ -1737,17 +1711,40 @@ pub(crate) mod tests { self.children.is_empty() } - /// Strip children_in_scope recursively - used to compare trees - /// in TreeNode tests where children_in_scope is not relevant. + /// Strip children_in_new_scope recursively - used to compare trees + /// in TreeNode tests where children_in_new_scope is not relevant. fn strip_scope(self) -> Self { Self { children: self.children.into_iter().map(|c| c.strip_scope()).collect(), - children_in_scope: vec![], + children_in_same_scope: vec![], data: self.data, } } } + impl TestTreeNode { + /// Creates a node with explicit `children` (all, in order). + /// `out_of_scope_children` are children that start a new scope + /// (i.e., NOT in the current node's scope). The remaining children + /// are computed as `children_in_same_scope`. 
+ pub(crate) fn new_mixed( + all_children: Vec>, + out_of_scope_children: Vec>, + data: T, + ) -> Self { + let children_in_same_scope = all_children + .iter() + .filter(|c| !out_of_scope_children.contains(c)) + .cloned() + .collect(); + Self { + children: all_children, + children_in_same_scope, + data, + } + } + } + impl TreeNode for TestTreeNode { fn apply_children<'n, F: FnMut(&'n Self) -> Result>( &'n self, @@ -1778,16 +1775,16 @@ pub(crate) mod tests { &'n self, f: F, ) -> Result { - self.children_in_scope.apply_elements(f) + self.children_in_same_scope.apply_elements(f) } fn map_children_in_scope Result>>( self, f: F, ) -> Result> { - Ok(self.children_in_scope.map_elements(f)?.update_data( - |new_children_in_scope| Self { - children_in_scope: new_children_in_scope, + Ok(self.children_in_same_scope.map_elements(f)?.update_data( + |new_children| Self { + children_in_same_scope: new_children, ..self }, )) @@ -1810,35 +1807,36 @@ pub(crate) mod tests { } } - // J - // | - // I - // | - // F (mixed) - // / \ - // E (scoped) G - // | | - // C (mixed) H - // / \ - // B D (scoped) - // | - // A + // J + // | + // I + // | + // F (mixed) + // / \ + // E (new scope) G (same scope as F) + // | | + // C (mixed) H + // / \ + // B (Same scope as C) D (new scope) + // | + // A // - // TreeNode (children) traversal visits ALL nodes: J, I, F, E, C, B, D, A, G, H - // (new_scoped/new_mixed auto-add children_in_scope to children) + // TreeNode (children) traversal visits ALL nodes: J, I, F, E, C, B, A, D, G, H + // (new/new_mixed set both children and children_in_new_scope) // - // ScopedTreeNode (children_in_scope) traversal visits: J, I, F, E, C, D, A - // (skips B, G, H which are only in children, not children_in_scope) + // ScopedTreeNode (children_in_new_scope) traversal visits: J, I, F, G, H + // (skips E, C, B, A, D — E and D are in new scopes, the rest are + // unreachable via scoped traversal) fn test_tree() -> TestTreeNode { let node_a = 
TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new_leaf("b".to_string()); - let node_d = TestTreeNode::new_scoped(vec![node_a], "d".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); let node_c = TestTreeNode::new_mixed( vec![node_b, node_d.clone()], vec![node_d], "c".to_string(), ); - let node_e = TestTreeNode::new_scoped(vec![node_c], "e".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); let node_g = TestTreeNode::new(vec![node_h], "g".to_string()); let node_f = TestTreeNode::new_mixed( @@ -1846,8 +1844,8 @@ pub(crate) mod tests { vec![node_e], "f".to_string(), ); - let node_i = TestTreeNode::new_scoped(vec![node_f], "i".to_string()); - TestTreeNode::new_scoped(vec![node_i], "j".to_string()) + let node_i = TestTreeNode::new(vec![node_f], "i".to_string()); + TestTreeNode::new(vec![node_i], "j".to_string()) } // Continue on all nodes @@ -1860,10 +1858,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", "f_up(a)", + "f_up(b)", + "f_down(d)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -1883,8 +1881,8 @@ pub(crate) mod tests { // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new_leaf("f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); @@ -1899,8 +1897,8 @@ pub(crate) mod tests { // Expected transformed tree after a top-down 
traversal fn transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); + let node_d = TestTreeNode::new_leaf("f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); @@ -1913,8 +1911,8 @@ pub(crate) mod tests { // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_up(b)".to_string()); + let node_d = TestTreeNode::new_leaf("f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); @@ -1933,10 +1931,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", "f_up(a)", + "f_up(b)", + "f_down(d)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -1955,8 +1953,8 @@ pub(crate) mod tests { fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); + let node_d = TestTreeNode::new_leaf("f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); @@ -1989,8 +1987,8 @@ pub(crate) mod tests { fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new_leaf("b".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); @@ -2003,8 +2001,8 @@ pub(crate) mod tests { fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new_leaf("b".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); @@ -2023,10 +2021,12 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", "f_up(a)", + "f_down(d)", + "f_up(d)", + "f_up(c)", + "f_up(e)", "f_down(g)", "f_down(h)", "f_up(h)", @@ -2042,10 +2042,10 @@ pub(crate) mod tests { fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], 
"f_down(c)".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); + let node_d = TestTreeNode::new_leaf("f_up(f_down(d))".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = @@ -2056,10 +2056,10 @@ pub(crate) mod tests { fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); + let node_d = TestTreeNode::new_leaf("f_up(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); @@ -2076,10 +2076,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", "f_up(a)", + "f_up(b)", + "f_down(d)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -2114,8 +2114,6 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", ] .into_iter() @@ -2125,8 +2123,8 @@ pub(crate) mod tests { fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = 
TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2138,8 +2136,8 @@ pub(crate) mod tests { fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2159,8 +2157,8 @@ pub(crate) mod tests { fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new_leaf("b".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2172,8 +2170,8 @@ pub(crate) mod tests { fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new_leaf("b".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_b = 
TestTreeNode::new(vec![node_a], "b".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2192,8 +2190,6 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", "f_up(a)", ] @@ -2204,8 +2200,8 @@ pub(crate) mod tests { fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2217,8 +2213,8 @@ pub(crate) mod tests { fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); + let node_d = TestTreeNode::new_leaf("d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2237,10 +2233,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_up(b)", - "f_down(d)", "f_down(a)", "f_up(a)", + "f_up(b)", + "f_down(d)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -2252,8 +2248,8 @@ pub(crate) mod tests { fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { let node_a = 
TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new_leaf("f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); @@ -2266,8 +2262,8 @@ pub(crate) mod tests { fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); - let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); + let node_b = TestTreeNode::new(vec![node_a], "f_up(b)".to_string()); + let node_d = TestTreeNode::new_leaf("f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2433,7 +2429,8 @@ pub(crate) mod tests { let mut rewriter = TestRewriter::new(Box::new($F_DOWN), Box::new($F_UP)); let actual = tree.rewrite(&mut rewriter)?; let actual_stripped = actual.update_data(|d| d.strip_scope()); - assert_eq!(actual_stripped, $EXPECTED_TREE); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2447,7 +2444,8 @@ pub(crate) mod tests { let tree = test_tree(); let actual = tree.transform_down_up($F_DOWN, $F_UP)?; let actual_stripped = actual.update_data(|d| d.strip_scope()); - assert_eq!(actual_stripped, $EXPECTED_TREE); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2461,7 +2459,8 @@ pub(crate) mod tests { let tree = test_tree(); let 
actual = tree.transform_down($F)?; let actual_stripped = actual.update_data(|d| d.strip_scope()); - assert_eq!(actual_stripped, $EXPECTED_TREE); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2475,7 +2474,8 @@ pub(crate) mod tests { let tree = test_tree(); let actual = tree.transform_up($F)?; let actual_stripped = actual.update_data(|d| d.strip_scope()); - assert_eq!(actual_stripped, $EXPECTED_TREE); + let expected_stripped = ($EXPECTED_TREE).update_data(|d| d.strip_scope()); + assert_eq!(actual_stripped, expected_stripped); Ok(()) } @@ -2914,40 +2914,40 @@ pub(crate) mod tests { let col_b = TestTreeNode::new_leaf("col_b".to_string()); let col_c = TestTreeNode::new_leaf("col_c".to_string()); - let plus = TestTreeNode::new_scoped(vec![col_a, col_b], "plus".to_string()); + let plus = TestTreeNode::new(vec![col_a, col_b], "plus".to_string()); // --- Innermost lambda scope: Lambda2 → inner_col --- let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); - let lambda2 = TestTreeNode::new_scoped(vec![inner_col], "lambda2".to_string()); + let lambda2 = TestTreeNode::new(vec![inner_col], "lambda2".to_string()); // --- Middle lambda scope: Lambda1 → list_col, cmp --- let list_col = TestTreeNode::new_leaf("list_col".to_string()); let idx_col = TestTreeNode::new_leaf("idx_col".to_string()); - // cmp has idx_col in scope, Lambda2 out of scope + // cmp has idx_col in scope, Lambda2 out of scope (new scope) let cmp = TestTreeNode::new_mixed( vec![idx_col.clone(), lambda2.clone()], - vec![idx_col], + vec![lambda2], "cmp".to_string(), ); let lambda1 = - TestTreeNode::new_scoped(vec![list_col, cmp], "lambda1".to_string()); + TestTreeNode::new(vec![list_col, cmp], "lambda1".to_string()); // --- Outer scope --- - // gt has col_c in scope, Lambda1 out of scope + // gt has col_c in scope, Lambda1 out of scope (new scope) let gt = TestTreeNode::new_mixed( vec![col_c.clone(), 
lambda1.clone()], - vec![col_c], + vec![lambda1], "gt".to_string(), ); - let and = TestTreeNode::new_scoped(vec![plus, gt], "and".to_string()); + let and = TestTreeNode::new(vec![plus, gt], "and".to_string()); - TestTreeNode::new_scoped(vec![and], "root".to_string()) + TestTreeNode::new(vec![and], "root".to_string()) } - /// Collect all data reachable via children_in_scope (scoped DFS). + /// Collect all data reachable via children_in_new_scope (scoped DFS). fn collect_scoped_data(node: &TestTreeNode) -> Vec { let mut result = vec![node.data.clone()]; - for child in &node.children_in_scope { + for child in &node.children_in_same_scope { result.extend(collect_scoped_data(child)); } result @@ -2962,14 +2962,14 @@ pub(crate) mod tests { result } - // Scoped transform helpers that preserve both children and children_in_scope + // Scoped transform helpers that preserve both children and children_in_new_scope fn transform_yes_in_scope>( transformation_name: N, ) -> impl FnMut(TestTreeNode) -> Result>> { move |node| { Ok(Transformed::yes(TestTreeNode { children: node.children, - children_in_scope: node.children_in_scope, + children_in_same_scope: node.children_in_same_scope, data: format!("{}({})", transformation_name, node.data).into(), })) } @@ -2988,7 +2988,7 @@ pub(crate) mod tests { move |node| { let new_node = TestTreeNode { children: node.children, - children_in_scope: node.children_in_scope, + children_in_same_scope: node.children_in_same_scope, data: format!("{}({})", transformation_name, node.data).into(), }; Ok(if node.data == d { @@ -3515,7 +3515,7 @@ pub(crate) mod tests { /// Build the Lambda2 subtree (innermost scope: lambda2 → inner_col) fn build_lambda2() -> TestTreeNode { let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); - TestTreeNode::new_scoped(vec![inner_col], "lambda2".to_string()) + TestTreeNode::new(vec![inner_col], "lambda2".to_string()) } /// Build the Lambda1 subtree (middle scope: lambda1 → [list_col, cmp → idx_col]) @@ -3525,10 
+3525,10 @@ pub(crate) mod tests { let idx_col = TestTreeNode::new_leaf("idx_col".to_string()); let cmp = TestTreeNode::new_mixed( vec![idx_col.clone(), lambda2.clone()], - vec![idx_col], + vec![lambda2], "cmp".to_string(), ); - TestTreeNode::new_scoped(vec![list_col, cmp], "lambda1".to_string()) + TestTreeNode::new(vec![list_col, cmp], "lambda1".to_string()) } // --- Lambda1 scope (middle): 4 in-scope nodes --- @@ -3828,7 +3828,7 @@ pub(crate) mod tests { } #[test] - fn test_node_with_only_children_no_scope() -> Result<()> { + fn test_node_new_has_all_children_in_scope() -> Result<()> { let child = TestTreeNode::new_leaf("child".to_string()); let parent = TestTreeNode::new(vec![child], "parent".to_string()); @@ -3837,7 +3837,7 @@ pub(crate) mod tests { scoped_visits.push(n.data.clone()); Ok(TreeNodeRecursion::Continue) })?; - assert_eq!(scoped_visits, vec!["parent"]); + assert_eq!(scoped_visits, vec!["parent", "child"]); let mut all_visits = vec![]; parent.apply(|n| { From e817a7e958be2b8190a24edefc19b568b0505d24 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 20:48:12 +0300 Subject: [PATCH 13/17] update --- datafusion/common/src/tree_node.rs | 115 +++++++++++++++-------------- 1 file changed, 58 insertions(+), 57 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 03cee402de6d4..752604eb315f0 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -1818,19 +1818,18 @@ pub(crate) mod tests { // C (mixed) H // / \ // B (Same scope as C) D (new scope) - // | - // A + // | + // A // - // TreeNode (children) traversal visits ALL nodes: J, I, F, E, C, B, A, D, G, H - // (new/new_mixed set both children and children_in_new_scope) + // TreeNode (children) traversal visits ALL nodes: J, I, F, E, C, B, D, A, G, H + // (new/new_mixed set both children and children_in_same_scope) // - // ScopedTreeNode (children_in_new_scope) 
traversal visits: J, I, F, G, H - // (skips E, C, B, A, D — E and D are in new scopes, the rest are - // unreachable via scoped traversal) + // ScopedTreeNode (children_in_same_scope) traversal visits: J, I, F, G, H + // (skips E which is out of F's scope; skips C, B, D, A which are under E) fn test_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new_mixed( vec![node_b, node_d.clone()], vec![node_d], @@ -1858,10 +1857,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_down(a)", - "f_up(a)", "f_up(b)", "f_down(d)", + "f_down(a)", + "f_up(a)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -1881,8 +1880,8 @@ pub(crate) mod tests { // Expected transformed tree after a combined traversal fn transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new_leaf("f_up(f_down(d))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); @@ -1897,8 +1896,8 @@ pub(crate) mod tests { // Expected transformed tree after a top-down traversal fn transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); - let node_d = TestTreeNode::new_leaf("f_down(d)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); + let node_d = 
TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); @@ -1911,8 +1910,8 @@ pub(crate) mod tests { // Expected transformed tree after a bottom-up traversal fn transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_up(b)".to_string()); - let node_d = TestTreeNode::new_leaf("f_up(d)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); @@ -1931,10 +1930,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_down(a)", - "f_up(a)", "f_up(b)", "f_down(d)", + "f_down(a)", + "f_up(a)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -1953,8 +1952,8 @@ pub(crate) mod tests { fn f_down_jump_on_a_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); - let node_d = TestTreeNode::new_leaf("f_down(d)".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); @@ -1987,8 +1986,8 @@ pub(crate) mod tests { fn f_down_jump_on_e_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = 
TestTreeNode::new(vec![node_a], "b".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); @@ -2001,8 +2000,8 @@ pub(crate) mod tests { fn f_down_jump_on_e_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_down(h)".to_string()); @@ -2021,12 +2020,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", + "f_up(b)", + "f_down(d)", "f_down(a)", "f_up(a)", - "f_down(d)", - "f_up(d)", - "f_up(c)", - "f_up(e)", "f_down(g)", "f_down(h)", "f_up(h)", @@ -2042,10 +2039,10 @@ pub(crate) mod tests { fn f_up_jump_on_a_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); - let node_d = TestTreeNode::new_leaf("f_up(f_down(d))".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); + let node_e = 
TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("f_up(f_down(h))".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(f_down(g))".to_string()); let node_f = @@ -2056,10 +2053,10 @@ pub(crate) mod tests { fn f_up_jump_on_a_transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); - let node_d = TestTreeNode::new_leaf("f_up(d)".to_string()); - let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); - let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); + let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); + let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("f_up(h)".to_string()); let node_g = TestTreeNode::new(vec![node_h], "f_up(g)".to_string()); let node_f = TestTreeNode::new(vec![node_e, node_g], "f_up(f)".to_string()); @@ -2076,10 +2073,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_down(a)", - "f_up(a)", "f_up(b)", "f_down(d)", + "f_down(a)", + "f_up(a)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -2114,6 +2111,8 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", + "f_up(b)", + "f_down(d)", "f_down(a)", ] .into_iter() @@ -2123,8 +2122,8 @@ pub(crate) mod tests { fn f_down_stop_on_a_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = 
TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2136,8 +2135,8 @@ pub(crate) mod tests { fn f_down_stop_on_a_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_down(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("f_down(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2157,8 +2156,8 @@ pub(crate) mod tests { fn f_down_stop_on_e_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2170,8 +2169,8 @@ pub(crate) mod tests { fn f_down_stop_on_e_transformed_down_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("a".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("b".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2190,6 +2189,8 @@ pub(crate) mod tests { 
"f_down(e)", "f_down(c)", "f_down(b)", + "f_up(b)", + "f_down(d)", "f_down(a)", "f_up(a)", ] @@ -2200,8 +2201,8 @@ pub(crate) mod tests { fn f_up_stop_on_a_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_down(b)".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_down(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_down(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_down(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2213,8 +2214,8 @@ pub(crate) mod tests { fn f_up_stop_on_a_transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "b".to_string()); - let node_d = TestTreeNode::new_leaf("d".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "d".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string()); let node_e = TestTreeNode::new(vec![node_c], "e".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); @@ -2233,10 +2234,10 @@ pub(crate) mod tests { "f_down(e)", "f_down(c)", "f_down(b)", - "f_down(a)", - "f_up(a)", "f_up(b)", "f_down(d)", + "f_down(a)", + "f_up(a)", "f_up(d)", "f_up(c)", "f_up(e)", @@ -2248,8 +2249,8 @@ pub(crate) mod tests { fn f_up_stop_on_e_transformed_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(f_down(a))".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_up(f_down(b))".to_string()); - let node_d = TestTreeNode::new_leaf("f_up(f_down(d))".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(f_down(b))".to_string()); + let node_d = TestTreeNode::new(vec![node_a], 
"f_up(f_down(d))".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(f_down(c))".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(f_down(e))".to_string()); @@ -2262,8 +2263,8 @@ pub(crate) mod tests { fn f_up_stop_on_e_transformed_up_tree() -> TestTreeNode { let node_a = TestTreeNode::new_leaf("f_up(a)".to_string()); - let node_b = TestTreeNode::new(vec![node_a], "f_up(b)".to_string()); - let node_d = TestTreeNode::new_leaf("f_up(d)".to_string()); + let node_b = TestTreeNode::new_leaf("f_up(b)".to_string()); + let node_d = TestTreeNode::new(vec![node_a], "f_up(d)".to_string()); let node_c = TestTreeNode::new(vec![node_b, node_d], "f_up(c)".to_string()); let node_e = TestTreeNode::new(vec![node_c], "f_up(e)".to_string()); let node_h = TestTreeNode::new_leaf("h".to_string()); From bfc9db99fecf1546000d3ce6062c198f7d2edbb4 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 20:54:56 +0300 Subject: [PATCH 14/17] update comment --- datafusion/common/src/tree_node.rs | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 752604eb315f0..0b052f177398a 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -2870,22 +2870,22 @@ pub(crate) mod tests { // a scope, plus some scopes with only 1 in-scope child. 
// // OUTER SCOPE (root traversal visits these): - // root → and → [plus, gt] - // plus → [col_a, col_b] - // gt → [col_c, Lambda1] ← Lambda1 is out of scope + // root -> and -> [plus, gt] + // plus -> [col_a, col_b] + // gt -> [col_c, Lambda1] <- Lambda1 is out of scope // // Scoped traversal: root, and, plus, col_a, col_b, gt, col_c // (7 nodes, Lambda1 not entered) // // LAMBDA1 SCOPE (Lambda1's own traversal): - // lambda1 → [list_col, cmp] - // cmp → [idx_col, Lambda2] ← Lambda2 is out of scope + // lambda1 -> [list_col, cmp] + // cmp -> [idx_col, Lambda2] <- Lambda2 is out of scope // // Scoped traversal: lambda1, list_col, cmp, idx_col // (4 nodes, Lambda2 not entered) // - // LAMBDA2 SCOPE (Lambda2's own traversal — only 1 in-scope child): - // lambda2 → [inner_col] + // LAMBDA2 SCOPE (Lambda2's own traversal -- only 1 in-scope child): + // lambda2 -> [inner_col] // // Scoped traversal: lambda2, inner_col // (2 nodes) @@ -2896,13 +2896,13 @@ pub(crate) mod tests { // / \ // plus gt (mixed: in_scope=[col_c], out=[Lambda1]) // / \ / \ - // col_a col_b col_c Lambda1 ← scope boundary + // col_a col_b col_c Lambda1 <- scope boundary // | // (in_scope=[list_col, cmp], out=[]) // / \ // list_col cmp (mixed: in_scope=[idx_col], out=[Lambda2]) // / \ - // idx_col Lambda2 ← nested scope boundary + // idx_col Lambda2 <- nested scope boundary // | // (in_scope=[inner_col]) // | @@ -2917,11 +2917,11 @@ pub(crate) mod tests { let plus = TestTreeNode::new(vec![col_a, col_b], "plus".to_string()); - // --- Innermost lambda scope: Lambda2 → inner_col --- + // --- Innermost lambda scope: Lambda2 -> inner_col --- let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); let lambda2 = TestTreeNode::new(vec![inner_col], "lambda2".to_string()); - // --- Middle lambda scope: Lambda1 → list_col, cmp --- + // --- Middle lambda scope: Lambda1 -> list_col, cmp --- let list_col = TestTreeNode::new_leaf("list_col".to_string()); let idx_col = 
TestTreeNode::new_leaf("idx_col".to_string()); // cmp has idx_col in scope, Lambda2 out of scope (new scope) @@ -3513,13 +3513,13 @@ pub(crate) mod tests { // Each scope is tested with multiple traversal methods to ensure // no scope crossing occurs in any direction. - /// Build the Lambda2 subtree (innermost scope: lambda2 → inner_col) + /// Build the Lambda2 subtree (innermost scope: lambda2 -> inner_col) fn build_lambda2() -> TestTreeNode { let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); TestTreeNode::new(vec![inner_col], "lambda2".to_string()) } - /// Build the Lambda1 subtree (middle scope: lambda1 → [list_col, cmp → idx_col]) + /// Build the Lambda1 subtree (middle scope: lambda1 -> [list_col, cmp -> idx_col]) fn build_lambda1() -> TestTreeNode { let lambda2 = build_lambda2(); let list_col = TestTreeNode::new_leaf("list_col".to_string()); From b0b8d85fab086aeeb8344997c94a01a24b608e7c Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 21:27:13 +0300 Subject: [PATCH 15/17] renamed --- datafusion/common/src/tree_node.rs | 750 ++++++++++++++--------------- 1 file changed, 352 insertions(+), 398 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0b052f177398a..0bc0047b18f37 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -2862,87 +2862,62 @@ pub(crate) mod tests { item.visit(&mut visitor).unwrap(); } - // ===================================================================== - // ScopedTreeNode tests + // A + // | + // B + // / \ + // C F (mixed) + // / \ / \ + // D E G H (new scope) + // | + // / \ + // I J (mixed) + // / \ + // K L (new scope) + // | + // M // - // Models nested scope boundaries, like expressions with nested lambdas. - // Each scope has multiple nodes (>=3) to test sibling traversal within - // a scope, plus some scopes with only 1 in-scope child. 
+ // ScopedTreeNode traversal from A: A, B, C, D, E, F, G + // (skips H, I, J, K, L, M -- H is a new scope) // - // OUTER SCOPE (root traversal visits these): - // root -> and -> [plus, gt] - // plus -> [col_a, col_b] - // gt -> [col_c, Lambda1] <- Lambda1 is out of scope + // ScopedTreeNode traversal from H: H, I, J, K + // (skips L, M -- L is a new scope) // - // Scoped traversal: root, and, plus, col_a, col_b, gt, col_c - // (7 nodes, Lambda1 not entered) - // - // LAMBDA1 SCOPE (Lambda1's own traversal): - // lambda1 -> [list_col, cmp] - // cmp -> [idx_col, Lambda2] <- Lambda2 is out of scope - // - // Scoped traversal: lambda1, list_col, cmp, idx_col - // (4 nodes, Lambda2 not entered) - // - // LAMBDA2 SCOPE (Lambda2's own traversal -- only 1 in-scope child): - // lambda2 -> [inner_col] - // - // Scoped traversal: lambda2, inner_col - // (2 nodes) - // - // root - // | (scoped) - // and (scoped: [plus, gt]) - // / \ - // plus gt (mixed: in_scope=[col_c], out=[Lambda1]) - // / \ / \ - // col_a col_b col_c Lambda1 <- scope boundary - // | - // (in_scope=[list_col, cmp], out=[]) - // / \ - // list_col cmp (mixed: in_scope=[idx_col], out=[Lambda2]) - // / \ - // idx_col Lambda2 <- nested scope boundary - // | - // (in_scope=[inner_col]) - // | - // inner_col - // ===================================================================== + // ScopedTreeNode traversal from L: L, M fn scoped_test_tree() -> TestTreeNode { // Leaves for outer scope - let col_a = TestTreeNode::new_leaf("col_a".to_string()); - let col_b = TestTreeNode::new_leaf("col_b".to_string()); - let col_c = TestTreeNode::new_leaf("col_c".to_string()); - - let plus = TestTreeNode::new(vec![col_a, col_b], "plus".to_string()); - - // --- Innermost lambda scope: Lambda2 -> inner_col --- - let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); - let lambda2 = TestTreeNode::new(vec![inner_col], "lambda2".to_string()); - - // --- Middle lambda scope: Lambda1 -> list_col, cmp --- - let list_col = 
TestTreeNode::new_leaf("list_col".to_string()); - let idx_col = TestTreeNode::new_leaf("idx_col".to_string()); - // cmp has idx_col in scope, Lambda2 out of scope (new scope) - let cmp = TestTreeNode::new_mixed( - vec![idx_col.clone(), lambda2.clone()], - vec![lambda2], - "cmp".to_string(), + let d = TestTreeNode::new_leaf("d".to_string()); + let e = TestTreeNode::new_leaf("e".to_string()); + let g = TestTreeNode::new_leaf("g".to_string()); + + let c = TestTreeNode::new(vec![d, e], "c".to_string()); + + // --- Innermost scope: l -> m --- + let m = TestTreeNode::new_leaf("m".to_string()); + let l = TestTreeNode::new(vec![m], "l".to_string()); + + // --- Middle scope: h -> [i, j] --- + let i = TestTreeNode::new_leaf("i".to_string()); + let k = TestTreeNode::new_leaf("k".to_string()); + // j has k in scope, l out of scope (new scope) + let j = TestTreeNode::new_mixed( + vec![k.clone(), l.clone()], + vec![l], + "j".to_string(), ); - let lambda1 = - TestTreeNode::new(vec![list_col, cmp], "lambda1".to_string()); + let h = TestTreeNode::new(vec![i, j], "h".to_string()); // --- Outer scope --- - // gt has col_c in scope, Lambda1 out of scope (new scope) - let gt = TestTreeNode::new_mixed( - vec![col_c.clone(), lambda1.clone()], - vec![lambda1], - "gt".to_string(), + // f has g in scope, h out of scope (new scope) + let f = TestTreeNode::new_mixed( + vec![g.clone(), h.clone()], + vec![h], + "f".to_string(), ); - let and = TestTreeNode::new(vec![plus, gt], "and".to_string()); + let b = TestTreeNode::new(vec![c, f], "b".to_string()); - TestTreeNode::new(vec![and], "root".to_string()) + TestTreeNode::new(vec![b], "a".to_string()) } /// Collect all data reachable via children_in_new_scope (scoped DFS). 
@@ -3015,20 +2990,20 @@ pub(crate) mod tests { assert_eq!( visitor.visits, s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_down(col_a)", - "f_up(col_a)", - "f_down(col_b)", - "f_up(col_b)", - "f_up(plus)", - "f_down(gt)", - "f_down(col_c)", - "f_up(col_c)", - "f_up(gt)", - "f_up(and)", - "f_up(root)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_down(d)", + "f_up(d)", + "f_down(e)", + "f_up(e)", + "f_up(c)", + "f_down(f)", + "f_down(g)", + "f_up(g)", + "f_up(f)", + "f_up(b)", + "f_up(a)", ]) ); Ok(()) @@ -3040,14 +3015,7 @@ pub(crate) mod tests { let mut visitor = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); tree.visit_in_scope(&mut visitor)?; - let out_of_scope = [ - "lambda1", - "lambda2", - "list_col", - "cmp", - "idx_col", - "inner_col", - ]; + let out_of_scope = ["h", "l", "i", "j", "k", "m"]; for v in &visitor.visits { for name in &out_of_scope { assert!( @@ -3060,134 +3028,134 @@ pub(crate) mod tests { } #[test] - fn test_visit_in_scope_f_down_jump_on_plus() -> Result<()> { + fn test_visit_in_scope_f_down_jump_on_c() -> Result<()> { let tree = scoped_test_tree(); let mut visitor = TestVisitor::new( - Box::new(visit_event_on("plus", TreeNodeRecursion::Jump)), + Box::new(visit_event_on("c", TreeNodeRecursion::Jump)), Box::new(visit_continue), ); tree.visit_in_scope(&mut visitor)?; - // Jump on Plus: skip Plus's children, continue with sibling gt + // Jump on c: skip c's children, continue with sibling f assert_eq!( visitor.visits, s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_up(plus)", - "f_down(gt)", - "f_down(col_c)", - "f_up(col_c)", - "f_up(gt)", - "f_up(and)", - "f_up(root)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_up(c)", + "f_down(f)", + "f_down(g)", + "f_up(g)", + "f_up(f)", + "f_up(b)", + "f_up(a)", ]) ); Ok(()) } #[test] - fn test_visit_in_scope_f_down_stop_on_col_a() -> Result<()> { + fn test_visit_in_scope_f_down_stop_on_d() -> Result<()> { let tree = scoped_test_tree(); let 
mut visitor = TestVisitor::new( - Box::new(visit_event_on("col_a", TreeNodeRecursion::Stop)), + Box::new(visit_event_on("d", TreeNodeRecursion::Stop)), Box::new(visit_continue), ); tree.visit_in_scope(&mut visitor)?; assert_eq!( visitor.visits, s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_down(col_a)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_down(d)", ]) ); Ok(()) } #[test] - fn test_visit_in_scope_f_up_jump_on_col_a() -> Result<()> { + fn test_visit_in_scope_f_up_jump_on_d() -> Result<()> { let tree = scoped_test_tree(); let mut visitor = TestVisitor::new( Box::new(visit_continue), - Box::new(visit_event_on("col_a", TreeNodeRecursion::Jump)), + Box::new(visit_event_on("d", TreeNodeRecursion::Jump)), ); tree.visit_in_scope(&mut visitor)?; - // Jump after f_up(col_a): continue with sibling col_b. - // col_b returns Continue, resetting the tnr, so plus's f_up IS called. + // Jump after f_up(d): continue with sibling e. + // e returns Continue, resetting the tnr, so c's f_up IS called. 
assert_eq!( visitor.visits, s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_down(col_a)", - "f_up(col_a)", - "f_down(col_b)", - "f_up(col_b)", - "f_up(plus)", - "f_down(gt)", - "f_down(col_c)", - "f_up(col_c)", - "f_up(gt)", - "f_up(and)", - "f_up(root)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_down(d)", + "f_up(d)", + "f_down(e)", + "f_up(e)", + "f_up(c)", + "f_down(f)", + "f_down(g)", + "f_up(g)", + "f_up(f)", + "f_up(b)", + "f_up(a)", ]) ); Ok(()) } #[test] - fn test_visit_in_scope_f_up_stop_on_col_a() -> Result<()> { + fn test_visit_in_scope_f_up_stop_on_d() -> Result<()> { let tree = scoped_test_tree(); let mut visitor = TestVisitor::new( Box::new(visit_continue), - Box::new(visit_event_on("col_a", TreeNodeRecursion::Stop)), + Box::new(visit_event_on("d", TreeNodeRecursion::Stop)), ); tree.visit_in_scope(&mut visitor)?; assert_eq!( visitor.visits, s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_down(col_a)", - "f_up(col_a)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_down(d)", + "f_up(d)", ]) ); Ok(()) } #[test] - fn test_visit_in_scope_f_up_jump_on_plus() -> Result<()> { + fn test_visit_in_scope_f_up_jump_on_c() -> Result<()> { let tree = scoped_test_tree(); let mut visitor = TestVisitor::new( Box::new(visit_continue), - Box::new(visit_event_on("plus", TreeNodeRecursion::Jump)), + Box::new(visit_event_on("c", TreeNodeRecursion::Jump)), ); tree.visit_in_scope(&mut visitor)?; - // Jump after f_up(plus): skip plus's parent f_up, continue with sibling gt. - // gt returns Continue which resets the tnr, so f_up(and) and f_up(root) are called. + // Jump after f_up(c): skip c's parent f_up, continue with sibling f. + // f returns Continue which resets the tnr, so f_up(b) and f_up(a) are called. 
assert_eq!( visitor.visits, s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_down(col_a)", - "f_up(col_a)", - "f_down(col_b)", - "f_up(col_b)", - "f_up(plus)", - // gt is sibling of plus, so it gets visited - "f_down(gt)", - "f_down(col_c)", - "f_up(col_c)", - "f_up(gt)", - // gt returned Continue, resetting the tnr, so f_up(and) is called - "f_up(and)", - "f_up(root)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_down(d)", + "f_up(d)", + "f_down(e)", + "f_up(e)", + "f_up(c)", + // f is sibling of c, so it gets visited + "f_down(f)", + "f_down(g)", + "f_up(g)", + "f_up(f)", + // f returned Continue, resetting the tnr, so f_up(b) is called + "f_up(b)", + "f_up(a)", ]) ); Ok(()) @@ -3205,7 +3173,7 @@ pub(crate) mod tests { })?; assert_eq!( visits, - s(&["root", "and", "plus", "col_a", "col_b", "gt", "col_c"]) + s(&["a", "b", "c", "d", "e", "f", "g"]) ); Ok(()) } @@ -3218,14 +3186,7 @@ pub(crate) mod tests { visits.push(n.data.clone()); Ok(TreeNodeRecursion::Continue) })?; - let out_of_scope = [ - "lambda1", - "lambda2", - "list_col", - "cmp", - "idx_col", - "inner_col", - ]; + let out_of_scope = ["h", "l", "i", "j", "k", "m"]; for name in &out_of_scope { assert!( !visits.contains(&name.to_string()), @@ -3236,35 +3197,35 @@ pub(crate) mod tests { } #[test] - fn test_apply_in_scope_jump_on_plus() -> Result<()> { + fn test_apply_in_scope_jump_on_c() -> Result<()> { let tree = scoped_test_tree(); let mut visits = vec![]; tree.apply_in_scope(|n| { visits.push(n.data.clone()); - Ok(if n.data == "plus" { + Ok(if n.data == "c" { TreeNodeRecursion::Jump } else { TreeNodeRecursion::Continue }) })?; - // Jump on Plus skips its children, continues to sibling gt - assert_eq!(visits, s(&["root", "and", "plus", "gt", "col_c"])); + // Jump on c skips its children, continues to sibling f + assert_eq!(visits, s(&["a", "b", "c", "f", "g"])); Ok(()) } #[test] - fn test_apply_in_scope_stop_on_col_a() -> Result<()> { + fn test_apply_in_scope_stop_on_d() -> Result<()> { 
let tree = scoped_test_tree(); let mut visits = vec![]; tree.apply_in_scope(|n| { visits.push(n.data.clone()); - Ok(if n.data == "col_a" { + Ok(if n.data == "d" { TreeNodeRecursion::Stop } else { TreeNodeRecursion::Continue }) })?; - assert_eq!(visits, s(&["root", "and", "plus", "col_a"])); + assert_eq!(visits, s(&["a", "b", "c", "d"])); Ok(()) } @@ -3277,22 +3238,22 @@ pub(crate) mod tests { assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "f_down(col_a)", - "f_down(col_b)", - "f_down(gt)", - "f_down(col_c)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "f_down(d)", + "f_down(e)", + "f_down(f)", + "f_down(g)", ]) ); - // Lambda internals untouched in children path + // Scope h internals untouched in children path let children = collect_children_data(&result.data); - assert_eq!(children[0], "f_down(root)"); - assert!(children.contains(&"lambda1".to_string())); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"list_col".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert_eq!(children[0], "f_down(a)"); + assert!(children.contains(&"h".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"i".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } @@ -3301,14 +3262,7 @@ pub(crate) mod tests { let tree = scoped_test_tree(); let result = tree.transform_down_in_scope(transform_yes_in_scope("f_down"))?; let scoped = collect_scoped_data(&result.data); - let out_of_scope = [ - "lambda1", - "lambda2", - "list_col", - "cmp", - "idx_col", - "inner_col", - ]; + let out_of_scope = ["h", "l", "i", "j", "k", "m"]; for v in &scoped { for name in &out_of_scope { assert!( @@ -3321,24 +3275,24 @@ pub(crate) mod tests { } #[test] - fn test_transform_down_in_scope_jump_on_plus() -> Result<()> { + fn test_transform_down_in_scope_jump_on_c() -> Result<()> { let tree = scoped_test_tree(); let result = 
tree.transform_down_in_scope(transform_and_event_on_in_scope( "f_down", - "plus", + "c", TreeNodeRecursion::Jump, ))?; - // Plus is transformed but children skipped, gt still visited + // c is transformed but children skipped, f still visited assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_down(root)", - "f_down(and)", - "f_down(plus)", - "col_a", - "col_b", - "f_down(gt)", - "f_down(col_c)", + "f_down(a)", + "f_down(b)", + "f_down(c)", + "d", + "e", + "f_down(f)", + "f_down(g)", ]) ); Ok(()) @@ -3353,34 +3307,34 @@ pub(crate) mod tests { assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_up(root)", - "f_up(and)", - "f_up(plus)", - "f_up(col_a)", - "f_up(col_b)", - "f_up(gt)", - "f_up(col_c)", + "f_up(a)", + "f_up(b)", + "f_up(c)", + "f_up(d)", + "f_up(e)", + "f_up(f)", + "f_up(g)", ]) ); let children = collect_children_data(&result.data); - assert!(children.contains(&"lambda1".to_string())); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert!(children.contains(&"h".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } #[test] - fn test_transform_up_in_scope_stop_on_col_a() -> Result<()> { + fn test_transform_up_in_scope_stop_on_d() -> Result<()> { let tree = scoped_test_tree(); let result = tree.transform_up_in_scope(transform_and_event_on_in_scope( "f_up", - "col_a", + "d", TreeNodeRecursion::Stop, ))?; - // Stop on col_a: only col_a transformed, everything else untouched + // Stop on d: only d transformed, everything else untouched assert_eq!( collect_scoped_data(&result.data), - s(&["root", "and", "plus", "f_up(col_a)", "col_b", "gt", "col_c",]) + s(&["a", "b", "c", "f_up(d)", "e", "f", "g",]) ); Ok(()) } @@ -3397,18 +3351,18 @@ pub(crate) mod tests { assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_up(f_down(root))", - "f_up(f_down(and))", - "f_up(f_down(plus))", - "f_up(f_down(col_a))", - 
"f_up(f_down(col_b))", - "f_up(f_down(gt))", - "f_up(f_down(col_c))", + "f_up(f_down(a))", + "f_up(f_down(b))", + "f_up(f_down(c))", + "f_up(f_down(d))", + "f_up(f_down(e))", + "f_up(f_down(f))", + "f_up(f_down(g))", ]) ); let children = collect_children_data(&result.data); - assert!(children.contains(&"lambda1".to_string())); - assert!(children.contains(&"lambda2".to_string())); + assert!(children.contains(&"h".to_string())); + assert!(children.contains(&"l".to_string())); Ok(()) } @@ -3417,31 +3371,31 @@ pub(crate) mod tests { #[test] fn test_exists_in_scope_found_in_scope() -> Result<()> { let tree = scoped_test_tree(); - assert!(tree.exists_in_scope(|n| Ok(n.data == "root"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "and"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "plus"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "col_a"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "col_b"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "gt"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "col_c"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "a"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "b"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "c"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "d"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "e"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "f"))?); + assert!(tree.exists_in_scope(|n| Ok(n.data == "g"))?); Ok(()) } #[test] - fn test_exists_in_scope_not_found_lambda1_scope() -> Result<()> { + fn test_exists_in_scope_not_found_scope_h() -> Result<()> { let tree = scoped_test_tree(); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "lambda1"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "list_col"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "cmp"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "idx_col"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "h"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "i"))?); + 
assert!(!tree.exists_in_scope(|n| Ok(n.data == "j"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "k"))?); Ok(()) } #[test] - fn test_exists_in_scope_not_found_nested_lambda2_scope() -> Result<()> { + fn test_exists_in_scope_not_found_nested_scope_l() -> Result<()> { let tree = scoped_test_tree(); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "lambda2"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "inner_col"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "l"))?); + assert!(!tree.exists_in_scope(|n| Ok(n.data == "m"))?); Ok(()) } @@ -3465,45 +3419,45 @@ pub(crate) mod tests { assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_up(f_down(root))", - "f_up(f_down(and))", - "f_up(f_down(plus))", - "f_up(f_down(col_a))", - "f_up(f_down(col_b))", - "f_up(f_down(gt))", - "f_up(f_down(col_c))", + "f_up(f_down(a))", + "f_up(f_down(b))", + "f_up(f_down(c))", + "f_up(f_down(d))", + "f_up(f_down(e))", + "f_up(f_down(f))", + "f_up(f_down(g))", ]) ); let children = collect_children_data(&result.data); - assert!(children.contains(&"lambda1".to_string())); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert!(children.contains(&"h".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } #[test] - fn test_rewrite_in_scope_f_down_jump_on_plus() -> Result<()> { + fn test_rewrite_in_scope_f_down_jump_on_c() -> Result<()> { let tree = scoped_test_tree(); let mut rewriter = TestRewriter::new( Box::new(transform_and_event_on_in_scope( "f_down", - "plus", + "c", TreeNodeRecursion::Jump, )), Box::new(transform_yes_in_scope("f_up")), ); let result = tree.rewrite_in_scope(&mut rewriter)?; - // Jump on Plus: children skipped, but sibling gt is visited + // Jump on c: children skipped, but sibling f is visited assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_up(f_down(root))", - "f_up(f_down(and))", - "f_up(f_down(plus))", - 
"col_a", - "col_b", - "f_up(f_down(gt))", - "f_up(f_down(col_c))", + "f_up(f_down(a))", + "f_up(f_down(b))", + "f_up(f_down(c))", + "d", + "e", + "f_up(f_down(f))", + "f_up(f_down(g))", ]) ); Ok(()) @@ -3513,272 +3467,272 @@ pub(crate) mod tests { // Each scope is tested with multiple traversal methods to ensure // no scope crossing occurs in any direction. - /// Build the Lambda2 subtree (innermost scope: lambda2 -> inner_col) - fn build_lambda2() -> TestTreeNode { - let inner_col = TestTreeNode::new_leaf("inner_col".to_string()); - TestTreeNode::new(vec![inner_col], "lambda2".to_string()) - } - - /// Build the Lambda1 subtree (middle scope: lambda1 -> [list_col, cmp -> idx_col]) - fn build_lambda1() -> TestTreeNode { - let lambda2 = build_lambda2(); - let list_col = TestTreeNode::new_leaf("list_col".to_string()); - let idx_col = TestTreeNode::new_leaf("idx_col".to_string()); - let cmp = TestTreeNode::new_mixed( - vec![idx_col.clone(), lambda2.clone()], - vec![lambda2], - "cmp".to_string(), + /// Build the scope_l subtree (innermost scope: l -> m) + fn build_scope_l() -> TestTreeNode { + let m = TestTreeNode::new_leaf("m".to_string()); + TestTreeNode::new(vec![m], "l".to_string()) + } + + /// Build the scope_h subtree (middle scope: h -> [i, j -> k]) + fn build_scope_h() -> TestTreeNode { + let l = build_scope_l(); + let i = TestTreeNode::new_leaf("i".to_string()); + let k = TestTreeNode::new_leaf("k".to_string()); + let j = TestTreeNode::new_mixed( + vec![k.clone(), l.clone()], + vec![l], + "j".to_string(), ); - TestTreeNode::new(vec![list_col, cmp], "lambda1".to_string()) + TestTreeNode::new(vec![i, j], "h".to_string()) } - // --- Lambda1 scope (middle): 4 in-scope nodes --- + // --- Scope h (middle): 4 in-scope nodes --- #[test] - fn test_lambda1_scope_apply() -> Result<()> { - let lambda1 = build_lambda1(); + fn test_scope_h_apply() -> Result<()> { + let h = build_scope_h(); let mut visits = vec![]; - lambda1.apply_in_scope(|n| { + h.apply_in_scope(|n| { 
visits.push(n.data.clone()); Ok(TreeNodeRecursion::Continue) })?; - assert_eq!(visits, s(&["lambda1", "list_col", "cmp", "idx_col"])); - assert!(!visits.contains(&"lambda2".to_string())); - assert!(!visits.contains(&"inner_col".to_string())); + assert_eq!(visits, s(&["h", "i", "j", "k"])); + assert!(!visits.contains(&"l".to_string())); + assert!(!visits.contains(&"m".to_string())); Ok(()) } #[test] - fn test_lambda1_scope_visit() -> Result<()> { - let lambda1 = build_lambda1(); + fn test_scope_h_visit() -> Result<()> { + let h = build_scope_h(); let mut visitor = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); - lambda1.visit_in_scope(&mut visitor)?; + h.visit_in_scope(&mut visitor)?; assert_eq!( visitor.visits, s(&[ - "f_down(lambda1)", - "f_down(list_col)", - "f_up(list_col)", - "f_down(cmp)", - "f_down(idx_col)", - "f_up(idx_col)", - "f_up(cmp)", - "f_up(lambda1)", + "f_down(h)", + "f_down(i)", + "f_up(i)", + "f_down(j)", + "f_down(k)", + "f_up(k)", + "f_up(j)", + "f_up(h)", ]) ); - // Must not enter lambda2 scope + // Must not enter scope l for v in &visitor.visits { assert!( - !v.contains("lambda2"), - "lambda1 visit must not enter lambda2: {v}" + !v.contains("(l)"), + "scope h visit must not enter scope l: {v}" ); assert!( - !v.contains("inner_col"), - "lambda1 visit must not enter lambda2: {v}" + !v.contains("(m)"), + "scope h visit must not enter scope l: {v}" ); } Ok(()) } #[test] - fn test_lambda1_scope_transform_down() -> Result<()> { - let lambda1 = build_lambda1(); - let result = lambda1.transform_down_in_scope(transform_yes_in_scope("TX"))?; + fn test_scope_h_transform_down() -> Result<()> { + let h = build_scope_h(); + let result = h.transform_down_in_scope(transform_yes_in_scope("TX"))?; assert_eq!( collect_scoped_data(&result.data), - s(&["TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)",]) + s(&["TX(h)", "TX(i)", "TX(j)", "TX(k)",]) ); - // Lambda2 untouched in children path + // Scope l untouched in children path let 
children = collect_children_data(&result.data); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } #[test] - fn test_lambda1_scope_transform_up() -> Result<()> { - let lambda1 = build_lambda1(); - let result = lambda1.transform_up_in_scope(transform_yes_in_scope("TX"))?; + fn test_scope_h_transform_up() -> Result<()> { + let h = build_scope_h(); + let result = h.transform_up_in_scope(transform_yes_in_scope("TX"))?; assert_eq!( collect_scoped_data(&result.data), - s(&["TX(lambda1)", "TX(list_col)", "TX(cmp)", "TX(idx_col)",]) + s(&["TX(h)", "TX(i)", "TX(j)", "TX(k)",]) ); let children = collect_children_data(&result.data); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } #[test] - fn test_lambda1_scope_exists() -> Result<()> { - let lambda1 = build_lambda1(); + fn test_scope_h_exists() -> Result<()> { + let h = build_scope_h(); // In scope - assert!(lambda1.exists_in_scope(|n| Ok(n.data == "lambda1"))?); - assert!(lambda1.exists_in_scope(|n| Ok(n.data == "list_col"))?); - assert!(lambda1.exists_in_scope(|n| Ok(n.data == "cmp"))?); - assert!(lambda1.exists_in_scope(|n| Ok(n.data == "idx_col"))?); - // Out of scope (lambda2's scope) - assert!(!lambda1.exists_in_scope(|n| Ok(n.data == "lambda2"))?); - assert!(!lambda1.exists_in_scope(|n| Ok(n.data == "inner_col"))?); + assert!(h.exists_in_scope(|n| Ok(n.data == "h"))?); + assert!(h.exists_in_scope(|n| Ok(n.data == "i"))?); + assert!(h.exists_in_scope(|n| Ok(n.data == "j"))?); + assert!(h.exists_in_scope(|n| Ok(n.data == "k"))?); + // Out of scope (scope l) + assert!(!h.exists_in_scope(|n| Ok(n.data == "l"))?); + assert!(!h.exists_in_scope(|n| Ok(n.data == "m"))?); Ok(()) } #[test] - fn 
test_lambda1_scope_rewrite() -> Result<()> { - let lambda1 = build_lambda1(); + fn test_scope_h_rewrite() -> Result<()> { + let h = build_scope_h(); let mut rewriter = TestRewriter::new( Box::new(transform_yes_in_scope("f_down")), Box::new(transform_yes_in_scope("f_up")), ); - let result = lambda1.rewrite_in_scope(&mut rewriter)?; + let result = h.rewrite_in_scope(&mut rewriter)?; assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_up(f_down(lambda1))", - "f_up(f_down(list_col))", - "f_up(f_down(cmp))", - "f_up(f_down(idx_col))", + "f_up(f_down(h))", + "f_up(f_down(i))", + "f_up(f_down(j))", + "f_up(f_down(k))", ]) ); let children = collect_children_data(&result.data); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } #[test] - fn test_lambda1_scope_transform_down_up() -> Result<()> { - let lambda1 = build_lambda1(); - let result = lambda1.transform_down_up_in_scope( + fn test_scope_h_transform_down_up() -> Result<()> { + let h = build_scope_h(); + let result = h.transform_down_up_in_scope( transform_yes_in_scope("f_down"), transform_yes_in_scope("f_up"), )?; assert_eq!( collect_scoped_data(&result.data), s(&[ - "f_up(f_down(lambda1))", - "f_up(f_down(list_col))", - "f_up(f_down(cmp))", - "f_up(f_down(idx_col))", + "f_up(f_down(h))", + "f_up(f_down(i))", + "f_up(f_down(j))", + "f_up(f_down(k))", ]) ); let children = collect_children_data(&result.data); - assert!(children.contains(&"lambda2".to_string())); - assert!(children.contains(&"inner_col".to_string())); + assert!(children.contains(&"l".to_string())); + assert!(children.contains(&"m".to_string())); Ok(()) } - // --- Lambda2 scope (innermost): 2 in-scope nodes --- + // --- Scope l (innermost): 2 in-scope nodes --- #[test] - fn test_lambda2_scope_apply() -> Result<()> { - let lambda2 = build_lambda2(); + fn test_scope_l_apply() -> Result<()> { 
+ let l = build_scope_l(); let mut visits = vec![]; - lambda2.apply_in_scope(|n| { + l.apply_in_scope(|n| { visits.push(n.data.clone()); Ok(TreeNodeRecursion::Continue) })?; - assert_eq!(visits, s(&["lambda2", "inner_col"])); + assert_eq!(visits, s(&["l", "m"])); Ok(()) } #[test] - fn test_lambda2_scope_visit() -> Result<()> { - let lambda2 = build_lambda2(); + fn test_scope_l_visit() -> Result<()> { + let l = build_scope_l(); let mut visitor = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); - lambda2.visit_in_scope(&mut visitor)?; + l.visit_in_scope(&mut visitor)?; assert_eq!( visitor.visits, s(&[ - "f_down(lambda2)", - "f_down(inner_col)", - "f_up(inner_col)", - "f_up(lambda2)", + "f_down(l)", + "f_down(m)", + "f_up(m)", + "f_up(l)", ]) ); Ok(()) } #[test] - fn test_lambda2_scope_transform_down() -> Result<()> { - let lambda2 = build_lambda2(); - let result = lambda2.transform_down_in_scope(transform_yes_in_scope("TX"))?; + fn test_scope_l_transform_down() -> Result<()> { + let l = build_scope_l(); + let result = l.transform_down_in_scope(transform_yes_in_scope("TX"))?; assert_eq!( collect_scoped_data(&result.data), - s(&["TX(lambda2)", "TX(inner_col)",]) + s(&["TX(l)", "TX(m)",]) ); Ok(()) } #[test] - fn test_lambda2_scope_transform_up() -> Result<()> { - let lambda2 = build_lambda2(); - let result = lambda2.transform_up_in_scope(transform_yes_in_scope("TX"))?; + fn test_scope_l_transform_up() -> Result<()> { + let l = build_scope_l(); + let result = l.transform_up_in_scope(transform_yes_in_scope("TX"))?; assert_eq!( collect_scoped_data(&result.data), - s(&["TX(lambda2)", "TX(inner_col)",]) + s(&["TX(l)", "TX(m)",]) ); Ok(()) } #[test] - fn test_lambda2_scope_transform_down_up() -> Result<()> { - let lambda2 = build_lambda2(); - let result = lambda2.transform_down_up_in_scope( + fn test_scope_l_transform_down_up() -> Result<()> { + let l = build_scope_l(); + let result = l.transform_down_up_in_scope( transform_yes_in_scope("f_down"), 
transform_yes_in_scope("f_up"), )?; assert_eq!( collect_scoped_data(&result.data), - s(&["f_up(f_down(lambda2))", "f_up(f_down(inner_col))",]) + s(&["f_up(f_down(l))", "f_up(f_down(m))",]) ); Ok(()) } #[test] - fn test_lambda2_scope_rewrite() -> Result<()> { - let lambda2 = build_lambda2(); + fn test_scope_l_rewrite() -> Result<()> { + let l = build_scope_l(); let mut rewriter = TestRewriter::new( Box::new(transform_yes_in_scope("f_down")), Box::new(transform_yes_in_scope("f_up")), ); - let result = lambda2.rewrite_in_scope(&mut rewriter)?; + let result = l.rewrite_in_scope(&mut rewriter)?; assert_eq!( collect_scoped_data(&result.data), - s(&["f_up(f_down(lambda2))", "f_up(f_down(inner_col))",]) + s(&["f_up(f_down(l))", "f_up(f_down(m))",]) ); Ok(()) } #[test] - fn test_lambda2_scope_exists() -> Result<()> { - let lambda2 = build_lambda2(); - assert!(lambda2.exists_in_scope(|n| Ok(n.data == "lambda2"))?); - assert!(lambda2.exists_in_scope(|n| Ok(n.data == "inner_col"))?); - // Not in any scope from lambda2's perspective - assert!(!lambda2.exists_in_scope(|n| Ok(n.data == "lambda1"))?); - assert!(!lambda2.exists_in_scope(|n| Ok(n.data == "col_a"))?); + fn test_scope_l_exists() -> Result<()> { + let l = build_scope_l(); + assert!(l.exists_in_scope(|n| Ok(n.data == "l"))?); + assert!(l.exists_in_scope(|n| Ok(n.data == "m"))?); + // Not in any scope from l's perspective + assert!(!l.exists_in_scope(|n| Ok(n.data == "h"))?); + assert!(!l.exists_in_scope(|n| Ok(n.data == "d"))?); Ok(()) } // --- Outer scope: transform must not affect any inner scope --- #[test] - fn test_outer_scope_transform_does_not_affect_lambda1() -> Result<()> { + fn test_outer_scope_transform_does_not_affect_scope_h() -> Result<()> { let tree = scoped_test_tree(); let result = tree.transform_down_in_scope(transform_yes_in_scope("TX"))?; let all_data = collect_children_data(&result.data); - assert_eq!(all_data[0], "TX(root)"); - // Lambda1 scope completely untouched - 
assert!(all_data.contains(&"lambda1".to_string())); - assert!(all_data.contains(&"list_col".to_string())); - assert!(all_data.contains(&"cmp".to_string())); - assert!(all_data.contains(&"idx_col".to_string())); - // Lambda2 scope completely untouched - assert!(all_data.contains(&"lambda2".to_string())); - assert!(all_data.contains(&"inner_col".to_string())); + assert_eq!(all_data[0], "TX(a)"); + // Scope h completely untouched + assert!(all_data.contains(&"h".to_string())); + assert!(all_data.contains(&"i".to_string())); + assert!(all_data.contains(&"j".to_string())); + assert!(all_data.contains(&"k".to_string())); + // Scope l completely untouched + assert!(all_data.contains(&"l".to_string())); + assert!(all_data.contains(&"m".to_string())); Ok(()) } @@ -3790,12 +3744,12 @@ pub(crate) mod tests { transform_yes_in_scope("f_up"), )?; let all_data = collect_children_data(&result.data); - assert!(all_data.contains(&"lambda1".to_string())); - assert!(all_data.contains(&"list_col".to_string())); - assert!(all_data.contains(&"cmp".to_string())); - assert!(all_data.contains(&"idx_col".to_string())); - assert!(all_data.contains(&"lambda2".to_string())); - assert!(all_data.contains(&"inner_col".to_string())); + assert!(all_data.contains(&"h".to_string())); + assert!(all_data.contains(&"i".to_string())); + assert!(all_data.contains(&"j".to_string())); + assert!(all_data.contains(&"k".to_string())); + assert!(all_data.contains(&"l".to_string())); + assert!(all_data.contains(&"m".to_string())); Ok(()) } @@ -3804,13 +3758,13 @@ pub(crate) mod tests { let tree = scoped_test_tree(); let result = tree.transform_up_in_scope(transform_yes_in_scope("TX"))?; let all_data = collect_children_data(&result.data); - // Lambda1 + Lambda2 scopes completely untouched - assert!(all_data.contains(&"lambda1".to_string())); - assert!(all_data.contains(&"list_col".to_string())); - assert!(all_data.contains(&"cmp".to_string())); - assert!(all_data.contains(&"idx_col".to_string())); - 
assert!(all_data.contains(&"lambda2".to_string())); - assert!(all_data.contains(&"inner_col".to_string())); + // Scope h + scope l completely untouched + assert!(all_data.contains(&"h".to_string())); + assert!(all_data.contains(&"i".to_string())); + assert!(all_data.contains(&"j".to_string())); + assert!(all_data.contains(&"k".to_string())); + assert!(all_data.contains(&"l".to_string())); + assert!(all_data.contains(&"m".to_string())); Ok(()) } From 7b9fc250dc10ee28b0127c1d3b6eabc2bfcb01c5 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 21:55:06 +0300 Subject: [PATCH 16/17] simplify tree node tests by asserting that scoped result are the same as non scoped for that subtree --- datafusion/common/src/tree_node.rs | 1003 ++++------------------------ 1 file changed, 140 insertions(+), 863 deletions(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index 0bc0047b18f37..f46c7fcde7638 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -1782,12 +1782,13 @@ pub(crate) mod tests { self, f: F, ) -> Result> { - Ok(self.children_in_same_scope.map_elements(f)?.update_data( - |new_children| Self { + Ok(self + .children_in_same_scope + .map_elements(f)? + .update_data(|new_children| Self { children_in_same_scope: new_children, ..self - }, - )) + })) } } @@ -2422,6 +2423,20 @@ pub(crate) mod tests { } } + /// Like `transform_yes`, but preserves `children_in_same_scope` from the original node, + /// so scope boundaries are not lost during scoped traversal. + fn transform_yes_scoped + Clone>( + transformation_name: N, + ) -> impl FnMut(TestTreeNode) -> Result>> { + move |node| { + Ok(Transformed::yes(TestTreeNode { + children_in_same_scope: node.children_in_same_scope.clone(), + children: node.children, + data: format!("{}({})", transformation_name, node.data).into(), + })) + } + } + macro_rules! 
rewrite_test { ($NAME:ident, $F_DOWN:expr, $F_UP:expr, $EXPECTED_TREE:expr) => { #[test] @@ -2862,6 +2877,13 @@ pub(crate) mod tests { item.visit(&mut visitor).unwrap(); } + // ===================================================================== + // ScopedTreeNode tests + // + // The scoped tree has 3 nested scopes. Scoped traversal of each scope + // should produce the same result as non-scoped traversal of a tree + // containing only that scope's nodes. + // // A // | // B @@ -2878,49 +2900,43 @@ pub(crate) mod tests { // M // // ScopedTreeNode traversal from A: A, B, C, D, E, F, G - // (skips H, I, J, K, L, M -- H is a new scope) - // // ScopedTreeNode traversal from H: H, I, J, K - // (skips L, M -- L is a new scope) - // // ScopedTreeNode traversal from L: L, M + // ===================================================================== + /// Full tree with scope boundaries. fn scoped_test_tree() -> TestTreeNode { - // Leaves for outer scope let d = TestTreeNode::new_leaf("d".to_string()); let e = TestTreeNode::new_leaf("e".to_string()); let g = TestTreeNode::new_leaf("g".to_string()); - let c = TestTreeNode::new(vec![d, e], "c".to_string()); - // --- Innermost scope: l -> m --- let m = TestTreeNode::new_leaf("m".to_string()); let l = TestTreeNode::new(vec![m], "l".to_string()); - // --- Middle scope: h -> [i, j] --- let i = TestTreeNode::new_leaf("i".to_string()); let k = TestTreeNode::new_leaf("k".to_string()); - // j has k in scope, l out of scope (new scope) - let j = TestTreeNode::new_mixed( - vec![k.clone(), l.clone()], - vec![l], - "j".to_string(), - ); + let j = + TestTreeNode::new_mixed(vec![k.clone(), l.clone()], vec![l], "j".to_string()); let h = TestTreeNode::new(vec![i, j], "h".to_string()); - // --- Outer scope --- - // f has g in scope, h out of scope (new scope) - let f = TestTreeNode::new_mixed( - vec![g.clone(), h.clone()], - vec![h], - "f".to_string(), - ); + let f = + TestTreeNode::new_mixed(vec![g.clone(), h.clone()], vec![h], 
"f".to_string()); let b = TestTreeNode::new(vec![c, f], "b".to_string()); - TestTreeNode::new(vec![b], "a".to_string()) } - /// Collect all data reachable via children_in_new_scope (scoped DFS). + /// Build a non-scoped tree containing only the in-scope nodes, + /// by following `children_in_same_scope` recursively. + fn extract_scope_tree(node: &TestTreeNode) -> TestTreeNode { + let children: Vec<_> = node + .children_in_same_scope + .iter() + .map(extract_scope_tree) + .collect(); + TestTreeNode::new(children, node.data.clone()) + } + fn collect_scoped_data(node: &TestTreeNode) -> Vec { let mut result = vec![node.data.clone()]; for child in &node.children_in_same_scope { @@ -2929,7 +2945,6 @@ pub(crate) mod tests { result } - /// Collect all data reachable via children (non-scoped DFS). fn collect_children_data(node: &TestTreeNode) -> Vec { let mut result = vec![node.data.clone()]; for child in &node.children { @@ -2938,868 +2953,130 @@ pub(crate) mod tests { result } - // Scoped transform helpers that preserve both children and children_in_new_scope - fn transform_yes_in_scope>( - transformation_name: N, - ) -> impl FnMut(TestTreeNode) -> Result>> { - move |node| { - Ok(Transformed::yes(TestTreeNode { - children: node.children, - children_in_same_scope: node.children_in_same_scope, - data: format!("{}({})", transformation_name, node.data).into(), - })) - } - } - - fn transform_and_event_on_in_scope< - N: Display, - T: PartialEq + Display + From, - D: Into, - >( - transformation_name: N, - data: D, - event: TreeNodeRecursion, - ) -> impl FnMut(TestTreeNode) -> Result>> { - let d = data.into(); - move |node| { - let new_node = TestTreeNode { - children: node.children, - children_in_same_scope: node.children_in_same_scope, - data: format!("{}({})", transformation_name, node.data).into(), - }; - Ok(if node.data == d { - Transformed::new(new_node, true, event) - } else { - Transformed::yes(new_node) - }) - } - } - - fn s(v: &[&str]) -> Vec { - v.iter().map(|s| 
s.to_string()).collect() - } - - // === visit_in_scope === - - #[test] - fn test_visit_in_scope_continue() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = - TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); - tree.visit_in_scope(&mut visitor)?; - assert_eq!( - visitor.visits, - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_down(d)", - "f_up(d)", - "f_down(e)", - "f_up(e)", - "f_up(c)", - "f_down(f)", - "f_down(g)", - "f_up(g)", - "f_up(f)", - "f_up(b)", - "f_up(a)", - ]) - ); - Ok(()) - } - - #[test] - fn test_visit_in_scope_does_not_enter_lambda() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = - TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); - tree.visit_in_scope(&mut visitor)?; - let out_of_scope = ["h", "l", "i", "j", "k", "m"]; - for v in &visitor.visits { - for name in &out_of_scope { - assert!( - !v.contains(name), - "should not enter other scope ({name}): {v}" - ); - } + /// Collect references to every node in the tree (DFS through all children). 
+ fn all_nodes(node: &TestTreeNode) -> Vec<&TestTreeNode> { + let mut result = vec![node]; + for child in &node.children { + result.extend(all_nodes(child)); } - Ok(()) - } - - #[test] - fn test_visit_in_scope_f_down_jump_on_c() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_event_on("c", TreeNodeRecursion::Jump)), - Box::new(visit_continue), - ); - tree.visit_in_scope(&mut visitor)?; - // Jump on c: skip c's children, continue with sibling f - assert_eq!( - visitor.visits, - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_up(c)", - "f_down(f)", - "f_down(g)", - "f_up(g)", - "f_up(f)", - "f_up(b)", - "f_up(a)", - ]) - ); - Ok(()) - } - - #[test] - fn test_visit_in_scope_f_down_stop_on_d() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_event_on("d", TreeNodeRecursion::Stop)), - Box::new(visit_continue), - ); - tree.visit_in_scope(&mut visitor)?; - assert_eq!( - visitor.visits, - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_down(d)", - ]) - ); - Ok(()) - } - - #[test] - fn test_visit_in_scope_f_up_jump_on_d() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - Box::new(visit_event_on("d", TreeNodeRecursion::Jump)), - ); - tree.visit_in_scope(&mut visitor)?; - // Jump after f_up(d): continue with sibling e. - // e returns Continue, resetting the tnr, so c's f_up IS called. 
- assert_eq!( - visitor.visits, - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_down(d)", - "f_up(d)", - "f_down(e)", - "f_up(e)", - "f_up(c)", - "f_down(f)", - "f_down(g)", - "f_up(g)", - "f_up(f)", - "f_up(b)", - "f_up(a)", - ]) - ); - Ok(()) - } - - #[test] - fn test_visit_in_scope_f_up_stop_on_d() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - Box::new(visit_event_on("d", TreeNodeRecursion::Stop)), - ); - tree.visit_in_scope(&mut visitor)?; - assert_eq!( - visitor.visits, - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_down(d)", - "f_up(d)", - ]) - ); - Ok(()) - } - - #[test] - fn test_visit_in_scope_f_up_jump_on_c() -> Result<()> { - let tree = scoped_test_tree(); - let mut visitor = TestVisitor::new( - Box::new(visit_continue), - Box::new(visit_event_on("c", TreeNodeRecursion::Jump)), - ); - tree.visit_in_scope(&mut visitor)?; - // Jump after f_up(c): skip c's parent f_up, continue with sibling f. - // f returns Continue which resets the tnr, so f_up(b) and f_up(a) are called. - assert_eq!( - visitor.visits, - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_down(d)", - "f_up(d)", - "f_down(e)", - "f_up(e)", - "f_up(c)", - // f is sibling of c, so it gets visited - "f_down(f)", - "f_down(g)", - "f_up(g)", - "f_up(f)", - // f returned Continue, resetting the tnr, so f_up(b) is called - "f_up(b)", - "f_up(a)", - ]) - ); - Ok(()) + result } - // === apply_in_scope === + /// For a given node, assert that all scoped traversal functions produce + /// the same result as their non-scoped counterparts on the in-scope subtree. 
+ fn assert_all_scoped_traversals_match(scoped: &TestTreeNode) { + let equivalent = extract_scope_tree(scoped); - #[test] - fn test_apply_in_scope_continue() -> Result<()> { - let tree = scoped_test_tree(); - let mut visits = vec![]; - tree.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!( - visits, - s(&["a", "b", "c", "d", "e", "f", "g"]) - ); - Ok(()) - } - - #[test] - fn test_apply_in_scope_does_not_enter_lambda() -> Result<()> { - let tree = scoped_test_tree(); - let mut visits = vec![]; - tree.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - let out_of_scope = ["h", "l", "i", "j", "k", "m"]; - for name in &out_of_scope { - assert!( - !visits.contains(&name.to_string()), - "{name} should not be visited" - ); - } - Ok(()) - } + // visit_in_scope == visit + let mut sv = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); + scoped.visit_in_scope(&mut sv).unwrap(); + let mut ev = TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); + equivalent.visit(&mut ev).unwrap(); + assert_eq!(sv.visits, ev.visits, "visit mismatch for {}", scoped.data); - #[test] - fn test_apply_in_scope_jump_on_c() -> Result<()> { - let tree = scoped_test_tree(); - let mut visits = vec![]; - tree.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(if n.data == "c" { - TreeNodeRecursion::Jump - } else { - TreeNodeRecursion::Continue + // apply_in_scope == apply + let mut s_apply = vec![]; + scoped + .apply_in_scope(|n| { + s_apply.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) }) - })?; - // Jump on c skips its children, continues to sibling f - assert_eq!(visits, s(&["a", "b", "c", "f", "g"])); - Ok(()) - } - - #[test] - fn test_apply_in_scope_stop_on_d() -> Result<()> { - let tree = scoped_test_tree(); - let mut visits = vec![]; - tree.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(if n.data == "d" { - TreeNodeRecursion::Stop - } 
else { - TreeNodeRecursion::Continue + .unwrap(); + let mut e_apply = vec![]; + equivalent + .apply(|n| { + e_apply.push(n.data.clone()); + Ok(TreeNodeRecursion::Continue) }) - })?; - assert_eq!(visits, s(&["a", "b", "c", "d"])); - Ok(()) - } - - // === transform_down_in_scope === - - #[test] - fn test_transform_down_in_scope_continue() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope(transform_yes_in_scope("f_down"))?; - assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "f_down(d)", - "f_down(e)", - "f_down(f)", - "f_down(g)", - ]) - ); - // Scope h internals untouched in children path - let children = collect_children_data(&result.data); - assert_eq!(children[0], "f_down(a)"); - assert!(children.contains(&"h".to_string())); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"i".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_transform_down_in_scope_does_not_transform_lambda_internals() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope(transform_yes_in_scope("f_down"))?; - let scoped = collect_scoped_data(&result.data); - let out_of_scope = ["h", "l", "i", "j", "k", "m"]; - for v in &scoped { - for name in &out_of_scope { - assert!( - !v.contains(name), - "{name} should not be in scoped data: {v}" - ); - } - } - Ok(()) - } - - #[test] - fn test_transform_down_in_scope_jump_on_c() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope(transform_and_event_on_in_scope( - "f_down", - "c", - TreeNodeRecursion::Jump, - ))?; - // c is transformed but children skipped, f still visited - assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_down(a)", - "f_down(b)", - "f_down(c)", - "d", - "e", - "f_down(f)", - "f_down(g)", - ]) - ); - Ok(()) - } - - // === transform_up_in_scope === - - #[test] - fn 
test_transform_up_in_scope_continue() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_up_in_scope(transform_yes_in_scope("f_up"))?; - assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_up(a)", - "f_up(b)", - "f_up(c)", - "f_up(d)", - "f_up(e)", - "f_up(f)", - "f_up(g)", - ]) - ); - let children = collect_children_data(&result.data); - assert!(children.contains(&"h".to_string())); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_transform_up_in_scope_stop_on_d() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_up_in_scope(transform_and_event_on_in_scope( - "f_up", - "d", - TreeNodeRecursion::Stop, - ))?; - // Stop on d: only d transformed, everything else untouched - assert_eq!( - collect_scoped_data(&result.data), - s(&["a", "b", "c", "f_up(d)", "e", "f", "g",]) - ); - Ok(()) - } - - // === transform_down_up_in_scope === - - #[test] - fn test_transform_down_up_in_scope_continue() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_down_up_in_scope( - transform_yes_in_scope("f_down"), - transform_yes_in_scope("f_up"), - )?; - assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_up(f_down(a))", - "f_up(f_down(b))", - "f_up(f_down(c))", - "f_up(f_down(d))", - "f_up(f_down(e))", - "f_up(f_down(f))", - "f_up(f_down(g))", - ]) - ); - let children = collect_children_data(&result.data); - assert!(children.contains(&"h".to_string())); - assert!(children.contains(&"l".to_string())); - Ok(()) - } - - // === exists_in_scope === - - #[test] - fn test_exists_in_scope_found_in_scope() -> Result<()> { - let tree = scoped_test_tree(); - assert!(tree.exists_in_scope(|n| Ok(n.data == "a"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "b"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "c"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "d"))?); - assert!(tree.exists_in_scope(|n| 
Ok(n.data == "e"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "f"))?); - assert!(tree.exists_in_scope(|n| Ok(n.data == "g"))?); - Ok(()) - } - - #[test] - fn test_exists_in_scope_not_found_scope_h() -> Result<()> { - let tree = scoped_test_tree(); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "h"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "i"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "j"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "k"))?); - Ok(()) - } - - #[test] - fn test_exists_in_scope_not_found_nested_scope_l() -> Result<()> { - let tree = scoped_test_tree(); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "l"))?); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "m"))?); - Ok(()) - } - - #[test] - fn test_exists_in_scope_not_found_nonexistent() -> Result<()> { - let tree = scoped_test_tree(); - assert!(!tree.exists_in_scope(|n| Ok(n.data == "zzz"))?); - Ok(()) - } - - // === rewrite_in_scope === - - #[test] - fn test_rewrite_in_scope_continue() -> Result<()> { - let tree = scoped_test_tree(); - let mut rewriter = TestRewriter::new( - Box::new(transform_yes_in_scope("f_down")), - Box::new(transform_yes_in_scope("f_up")), - ); - let result = tree.rewrite_in_scope(&mut rewriter)?; - assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_up(f_down(a))", - "f_up(f_down(b))", - "f_up(f_down(c))", - "f_up(f_down(d))", - "f_up(f_down(e))", - "f_up(f_down(f))", - "f_up(f_down(g))", - ]) - ); - let children = collect_children_data(&result.data); - assert!(children.contains(&"h".to_string())); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_rewrite_in_scope_f_down_jump_on_c() -> Result<()> { - let tree = scoped_test_tree(); - let mut rewriter = TestRewriter::new( - Box::new(transform_and_event_on_in_scope( - "f_down", - "c", - TreeNodeRecursion::Jump, - )), - Box::new(transform_yes_in_scope("f_up")), - ); - let result = 
tree.rewrite_in_scope(&mut rewriter)?; - // Jump on c: children skipped, but sibling f is visited - assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_up(f_down(a))", - "f_up(f_down(b))", - "f_up(f_down(c))", - "d", - "e", - "f_up(f_down(f))", - "f_up(f_down(g))", - ]) - ); - Ok(()) - } - - // === Nested scope isolation tests === - // Each scope is tested with multiple traversal methods to ensure - // no scope crossing occurs in any direction. - - /// Build the scope_l subtree (innermost scope: l -> m) - fn build_scope_l() -> TestTreeNode { - let m = TestTreeNode::new_leaf("m".to_string()); - TestTreeNode::new(vec![m], "l".to_string()) - } - - /// Build the scope_h subtree (middle scope: h -> [i, j -> k]) - fn build_scope_h() -> TestTreeNode { - let l = build_scope_l(); - let i = TestTreeNode::new_leaf("i".to_string()); - let k = TestTreeNode::new_leaf("k".to_string()); - let j = TestTreeNode::new_mixed( - vec![k.clone(), l.clone()], - vec![l], - "j".to_string(), - ); - TestTreeNode::new(vec![i, j], "h".to_string()) - } - - // --- Scope h (middle): 4 in-scope nodes --- - - #[test] - fn test_scope_h_apply() -> Result<()> { - let h = build_scope_h(); - let mut visits = vec![]; - h.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!(visits, s(&["h", "i", "j", "k"])); - assert!(!visits.contains(&"l".to_string())); - assert!(!visits.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_scope_h_visit() -> Result<()> { - let h = build_scope_h(); - let mut visitor = - TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); - h.visit_in_scope(&mut visitor)?; - assert_eq!( - visitor.visits, - s(&[ - "f_down(h)", - "f_down(i)", - "f_up(i)", - "f_down(j)", - "f_down(k)", - "f_up(k)", - "f_up(j)", - "f_up(h)", - ]) - ); - // Must not enter scope l - for v in &visitor.visits { - assert!( - !v.contains("(l)"), - "scope h visit must not enter scope l: {v}" - ); - assert!( - !v.contains("(m)"), 
- "scope h visit must not enter scope l: {v}" + .unwrap(); + assert_eq!(s_apply, e_apply, "apply mismatch for {}", scoped.data); + + // exists_in_scope == exists for each in-scope node + for name in &e_apply { + assert_eq!( + scoped.exists_in_scope(|n| Ok(&n.data == name)).unwrap(), + equivalent.exists(|n| Ok(&n.data == name)).unwrap(), + "exists mismatch for node {name} in scope {}", + scoped.data, ); } - Ok(()) - } - - #[test] - fn test_scope_h_transform_down() -> Result<()> { - let h = build_scope_h(); - let result = h.transform_down_in_scope(transform_yes_in_scope("TX"))?; - assert_eq!( - collect_scoped_data(&result.data), - s(&["TX(h)", "TX(i)", "TX(j)", "TX(k)",]) - ); - // Scope l untouched in children path - let children = collect_children_data(&result.data); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_scope_h_transform_up() -> Result<()> { - let h = build_scope_h(); - let result = h.transform_up_in_scope(transform_yes_in_scope("TX"))?; - assert_eq!( - collect_scoped_data(&result.data), - s(&["TX(h)", "TX(i)", "TX(j)", "TX(k)",]) - ); - let children = collect_children_data(&result.data); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_scope_h_exists() -> Result<()> { - let h = build_scope_h(); - // In scope - assert!(h.exists_in_scope(|n| Ok(n.data == "h"))?); - assert!(h.exists_in_scope(|n| Ok(n.data == "i"))?); - assert!(h.exists_in_scope(|n| Ok(n.data == "j"))?); - assert!(h.exists_in_scope(|n| Ok(n.data == "k"))?); - // Out of scope (scope l) - assert!(!h.exists_in_scope(|n| Ok(n.data == "l"))?); - assert!(!h.exists_in_scope(|n| Ok(n.data == "m"))?); - Ok(()) - } - #[test] - fn test_scope_h_rewrite() -> Result<()> { - let h = build_scope_h(); - let mut rewriter = TestRewriter::new( - Box::new(transform_yes_in_scope("f_down")), - Box::new(transform_yes_in_scope("f_up")), - ); - let 
result = h.rewrite_in_scope(&mut rewriter)?; + // transform_down_in_scope == transform_down + let s_td = scoped + .clone() + .transform_down_in_scope(transform_yes_scoped("tx")) + .unwrap(); + let e_td = equivalent + .clone() + .transform_down(transform_yes("tx")) + .unwrap(); assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_up(f_down(h))", - "f_up(f_down(i))", - "f_up(f_down(j))", - "f_up(f_down(k))", - ]) + collect_scoped_data(&s_td.data), + collect_children_data(&e_td.data), + "transform_down mismatch for {}", + scoped.data, ); - let children = collect_children_data(&result.data); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - #[test] - fn test_scope_h_transform_down_up() -> Result<()> { - let h = build_scope_h(); - let result = h.transform_down_up_in_scope( - transform_yes_in_scope("f_down"), - transform_yes_in_scope("f_up"), - )?; + // transform_up_in_scope == transform_up + let s_tu = scoped + .clone() + .transform_up_in_scope(transform_yes_scoped("tx")) + .unwrap(); + let e_tu = equivalent + .clone() + .transform_up(transform_yes("tx")) + .unwrap(); assert_eq!( - collect_scoped_data(&result.data), - s(&[ - "f_up(f_down(h))", - "f_up(f_down(i))", - "f_up(f_down(j))", - "f_up(f_down(k))", - ]) + collect_scoped_data(&s_tu.data), + collect_children_data(&e_tu.data), + "transform_up mismatch for {}", + scoped.data, ); - let children = collect_children_data(&result.data); - assert!(children.contains(&"l".to_string())); - assert!(children.contains(&"m".to_string())); - Ok(()) - } - - // --- Scope l (innermost): 2 in-scope nodes --- - #[test] - fn test_scope_l_apply() -> Result<()> { - let l = build_scope_l(); - let mut visits = vec![]; - l.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!(visits, s(&["l", "m"])); - Ok(()) - } - - #[test] - fn test_scope_l_visit() -> Result<()> { - let l = build_scope_l(); - let mut visitor = - 
TestVisitor::new(Box::new(visit_continue), Box::new(visit_continue)); - l.visit_in_scope(&mut visitor)?; - assert_eq!( - visitor.visits, - s(&[ - "f_down(l)", - "f_down(m)", - "f_up(m)", - "f_up(l)", - ]) - ); - Ok(()) - } - - #[test] - fn test_scope_l_transform_down() -> Result<()> { - let l = build_scope_l(); - let result = l.transform_down_in_scope(transform_yes_in_scope("TX"))?; - assert_eq!( - collect_scoped_data(&result.data), - s(&["TX(l)", "TX(m)",]) - ); - Ok(()) - } - - #[test] - fn test_scope_l_transform_up() -> Result<()> { - let l = build_scope_l(); - let result = l.transform_up_in_scope(transform_yes_in_scope("TX"))?; + // transform_down_up_in_scope == transform_down_up + let s_tdu = scoped + .clone() + .transform_down_up_in_scope( + transform_yes_scoped("f_down"), + transform_yes_scoped("f_up"), + ) + .unwrap(); + let e_tdu = equivalent + .clone() + .transform_down_up(transform_yes("f_down"), transform_yes("f_up")) + .unwrap(); assert_eq!( - collect_scoped_data(&result.data), - s(&["TX(l)", "TX(m)",]) + collect_scoped_data(&s_tdu.data), + collect_children_data(&e_tdu.data), + "transform_down_up mismatch for {}", + scoped.data, ); - Ok(()) - } - #[test] - fn test_scope_l_transform_down_up() -> Result<()> { - let l = build_scope_l(); - let result = l.transform_down_up_in_scope( - transform_yes_in_scope("f_down"), - transform_yes_in_scope("f_up"), - )?; - assert_eq!( - collect_scoped_data(&result.data), - s(&["f_up(f_down(l))", "f_up(f_down(m))",]) + // rewrite_in_scope == rewrite + let mut sr = TestRewriter::new( + Box::new(transform_yes_scoped("f_down")), + Box::new(transform_yes_scoped("f_up")), ); - Ok(()) - } - - #[test] - fn test_scope_l_rewrite() -> Result<()> { - let l = build_scope_l(); - let mut rewriter = TestRewriter::new( - Box::new(transform_yes_in_scope("f_down")), - Box::new(transform_yes_in_scope("f_up")), + let s_rw = scoped.clone().rewrite_in_scope(&mut sr).unwrap(); + let mut er = TestRewriter::new( + 
Box::new(transform_yes("f_down")), + Box::new(transform_yes("f_up")), ); - let result = l.rewrite_in_scope(&mut rewriter)?; + let e_rw = equivalent.rewrite(&mut er).unwrap(); assert_eq!( - collect_scoped_data(&result.data), - s(&["f_up(f_down(l))", "f_up(f_down(m))",]) + collect_scoped_data(&s_rw.data), + collect_children_data(&e_rw.data), + "rewrite mismatch for {}", + scoped.data, ); - Ok(()) - } - - #[test] - fn test_scope_l_exists() -> Result<()> { - let l = build_scope_l(); - assert!(l.exists_in_scope(|n| Ok(n.data == "l"))?); - assert!(l.exists_in_scope(|n| Ok(n.data == "m"))?); - // Not in any scope from l's perspective - assert!(!l.exists_in_scope(|n| Ok(n.data == "h"))?); - assert!(!l.exists_in_scope(|n| Ok(n.data == "d"))?); - Ok(()) } - // --- Outer scope: transform must not affect any inner scope --- - #[test] - fn test_outer_scope_transform_does_not_affect_scope_h() -> Result<()> { + fn test_scoped_traversal_matches_non_scoped() -> Result<()> { let tree = scoped_test_tree(); - let result = tree.transform_down_in_scope(transform_yes_in_scope("TX"))?; - let all_data = collect_children_data(&result.data); - assert_eq!(all_data[0], "TX(a)"); - // Scope h completely untouched - assert!(all_data.contains(&"h".to_string())); - assert!(all_data.contains(&"i".to_string())); - assert!(all_data.contains(&"j".to_string())); - assert!(all_data.contains(&"k".to_string())); - // Scope l completely untouched - assert!(all_data.contains(&"l".to_string())); - assert!(all_data.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_outer_scope_transform_down_up_does_not_affect_inner_scopes() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_down_up_in_scope( - transform_yes_in_scope("f_down"), - transform_yes_in_scope("f_up"), - )?; - let all_data = collect_children_data(&result.data); - assert!(all_data.contains(&"h".to_string())); - assert!(all_data.contains(&"i".to_string())); - assert!(all_data.contains(&"j".to_string())); - 
assert!(all_data.contains(&"k".to_string())); - assert!(all_data.contains(&"l".to_string())); - assert!(all_data.contains(&"m".to_string())); - Ok(()) - } - - #[test] - fn test_outer_scope_transform_up_does_not_affect_inner_scopes() -> Result<()> { - let tree = scoped_test_tree(); - let result = tree.transform_up_in_scope(transform_yes_in_scope("TX"))?; - let all_data = collect_children_data(&result.data); - // Scope h + scope l completely untouched - assert!(all_data.contains(&"h".to_string())); - assert!(all_data.contains(&"i".to_string())); - assert!(all_data.contains(&"j".to_string())); - assert!(all_data.contains(&"k".to_string())); - assert!(all_data.contains(&"l".to_string())); - assert!(all_data.contains(&"m".to_string())); - Ok(()) - } - - // === Edge cases === - - #[test] - fn test_leaf_node_scoped_traversal() -> Result<()> { - let leaf = TestTreeNode::::new_leaf("leaf".to_string()); - let mut visits = vec![]; - leaf.apply_in_scope(|n| { - visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!(visits, vec!["leaf"]); - Ok(()) - } - - #[test] - fn test_node_new_has_all_children_in_scope() -> Result<()> { - let child = TestTreeNode::new_leaf("child".to_string()); - let parent = TestTreeNode::new(vec![child], "parent".to_string()); - - let mut scoped_visits = vec![]; - parent.apply_in_scope(|n| { - scoped_visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!(scoped_visits, vec!["parent", "child"]); - - let mut all_visits = vec![]; - parent.apply(|n| { - all_visits.push(n.data.clone()); - Ok(TreeNodeRecursion::Continue) - })?; - assert_eq!(all_visits, vec!["parent", "child"]); + for node in all_nodes(&tree) { + assert_all_scoped_traversals_match(node); + } Ok(()) } } From 11d7e7970c67603126bf2c73dec5dd67ab548509 Mon Sep 17 00:00:00 2001 From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com> Date: Sun, 5 Apr 2026 22:31:09 +0300 Subject: [PATCH 17/17] avoid clone --- 
datafusion/common/src/tree_node.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/datafusion/common/src/tree_node.rs b/datafusion/common/src/tree_node.rs index f46c7fcde7638..7dafd58160ae4 100644 --- a/datafusion/common/src/tree_node.rs +++ b/datafusion/common/src/tree_node.rs @@ -2868,7 +2868,13 @@ pub(crate) mod tests { fn test_large_tree() { let mut item = TestTreeNode::new_leaf("initial".to_string()); for i in 0..3000 { - item = TestTreeNode::new(vec![item], format!("parent-{i}")); + // Avoid TestTreeNode::new() here which clones children into + // children_in_same_scope - that would be O(n^2) for a deep chain. + item = TestTreeNode { + children: vec![item], + children_in_same_scope: vec![], + data: format!("parent-{i}"), + }; } let mut visitor =