Skip to content

Commit

Permalink
use stricter references in apply() and visit()
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed May 22, 2024
1 parent 8711fe2 commit 7200e30
Show file tree
Hide file tree
Showing 83 changed files with 209 additions and 267 deletions.
2 changes: 1 addition & 1 deletion datafusion-examples/examples/custom_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ impl ExecutionPlan for CustomExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![]
}

Expand Down
125 changes: 26 additions & 99 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,26 +123,13 @@ pub trait TreeNode: Sized {
/// TreeNodeVisitor::f_up(ChildNode2)
/// TreeNodeVisitor::f_up(ParentNode)
/// ```
fn visit<V: TreeNodeVisitor<Node = Self>>(
&self,
visitor: &mut V,
) -> Result<TreeNodeRecursion> {
visitor
.f_down(self)?
.visit_children(|| self.apply_children(|c| c.visit(visitor)))?
.visit_parent(|| visitor.f_up(self))
}

/// Similar to [`TreeNode::visit()`], but the lifetimes of the [`TreeNode`] references
/// passed to [`TreeNodeRefVisitor::f_down()`] and [`TreeNodeRefVisitor::f_up()`]
/// methods match the lifetime of the original root [`TreeNode`] reference.
fn visit_ref<'n, V: TreeNodeRefVisitor<'n, Node = Self>>(
fn visit<'n, V: TreeNodeVisitor<'n, Node = Self>>(
&'n self,
visitor: &mut V,
) -> Result<TreeNodeRecursion> {
visitor
.f_down(self)?
.visit_children(|| self.apply_children_ref(|c| c.visit_ref(visitor)))?
.visit_children(|| self.apply_children(|c| c.visit(visitor)))?
.visit_parent(|| visitor.f_up(self))
}

Expand Down Expand Up @@ -203,39 +190,18 @@ pub trait TreeNode: Sized {
/// # See Also
/// * [`Self::transform_down`] for the equivalent transformation API.
/// * [`Self::visit`] for both top-down and bottom up traversal.
fn apply<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
mut f: F,
) -> Result<TreeNodeRecursion> {
fn apply_impl<N: TreeNode, F: FnMut(&N) -> Result<TreeNodeRecursion>>(
node: &N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
}

apply_impl(self, &mut f)
}

/// Similar to [`TreeNode::apply()`], but the lifetime of the [`TreeNode`] references
/// passed to the `f` closures matches the lifetime of the original root [`TreeNode`]
/// reference.
fn apply_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
fn apply<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
mut f: F,
) -> Result<TreeNodeRecursion> {
fn apply_ref_impl<
'n,
N: TreeNode,
F: FnMut(&'n N) -> Result<TreeNodeRecursion>,
>(
fn apply_impl<'n, N: TreeNode, F: FnMut(&'n N) -> Result<TreeNodeRecursion>>(
node: &'n N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children_ref(|c| apply_ref_impl(c, f)))
f(node)?.visit_children(|| node.apply_children(|c| apply_impl(c, f)))
}

apply_ref_impl(self, &mut f)
apply_impl(self, &mut f)
}

/// Recursively rewrite the node's children and then the node using `f`
Expand Down Expand Up @@ -461,18 +427,7 @@ pub trait TreeNode: Sized {
///
/// Description: Apply `f` to inspect node's children (but not the node
/// itself).
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: F,
) -> Result<TreeNodeRecursion> {
// The default implementation is the stricter `apply_children_ref()`
self.apply_children_ref(f)
}

/// Similar to [`TreeNode::apply_children()`], but the lifetime of the [`TreeNode`]
/// references passed to the `f` closures matches the lifetime of the original root
/// [`TreeNode`] reference.
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion>;
Expand Down Expand Up @@ -511,27 +466,7 @@ pub trait TreeNode: Sized {
///
/// # See Also:
/// * [`TreeNode::rewrite`] to rewrite owned `TreeNode`s
pub trait TreeNodeVisitor: Sized {
/// The node type which is visitable.
type Node: TreeNode;

/// Invoked while traversing down the tree, before any children are visited.
/// Default implementation continues the recursion.
fn f_down(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}

/// Invoked while traversing up the tree after children are visited. Default
/// implementation continues the recursion.
fn f_up(&mut self, _node: &Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}

/// Similar to [`TreeNodeVisitor`], but the lifetimes of the [`TreeNode`] references
/// passed to [`TreeNodeRefVisitor::f_down()`] and [`TreeNodeRefVisitor::f_up()`] methods
/// match the lifetime of the original root [`TreeNode`] reference.
pub trait TreeNodeRefVisitor<'n>: Sized {
pub trait TreeNodeVisitor<'n>: Sized {
/// The node type which is visitable.
type Node: TreeNode;

Expand Down Expand Up @@ -920,11 +855,7 @@ impl<T> TransformedResult<T> for Result<Transformed<T>> {
/// its related `Arc<dyn T>` will automatically implement [`TreeNode`].
pub trait DynTreeNode {
/// Returns all children of the specified `TreeNode`.
fn arc_children(&self) -> Vec<Arc<Self>>;

fn children(&self) -> Vec<&Arc<Self>> {
panic!("DynTreeNode::children is not implemented yet")
}
fn arc_children(&self) -> Vec<&Arc<Self>>;

/// Constructs a new node with the specified children.
fn with_new_arc_children(
Expand All @@ -937,18 +868,11 @@ pub trait DynTreeNode {
/// Blanket implementation for any `Arc<T>` where `T` implements [`DynTreeNode`]
/// (such as [`Arc<dyn PhysicalExpr>`]).
impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: F,
) -> Result<TreeNodeRecursion> {
self.arc_children().iter().apply_until_stop(f)
}

fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().into_iter().apply_until_stop(f)
self.arc_children().into_iter().apply_until_stop(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
Expand All @@ -957,7 +881,10 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
) -> Result<Transformed<Self>> {
let children = self.arc_children();
if !children.is_empty() {
let new_children = children.into_iter().map_until_stop_and_collect(f)?;
let new_children = children
.into_iter()
.cloned()
.map_until_stop_and_collect(f)?;
// Propagate up `new_children.transformed` and `new_children.tnr`
// along with the node containing transformed children.
if new_children.transformed {
Expand Down Expand Up @@ -989,7 +916,7 @@ pub trait ConcreteTreeNode: Sized {
}

impl<T: ConcreteTreeNode> TreeNode for T {
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
Expand Down Expand Up @@ -1018,8 +945,8 @@ mod tests {
use std::fmt::Display;

use crate::tree_node::{
Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRefVisitor,
TreeNodeRewriter, TreeNodeVisitor,
Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter,
TreeNodeVisitor,
};
use crate::Result;

Expand All @@ -1036,7 +963,7 @@ mod tests {
}

impl<T> TreeNode for TestTreeNode<T> {
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
fn apply_children<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
Expand Down Expand Up @@ -1536,15 +1463,15 @@ mod tests {
}
}

impl<T: Display> TreeNodeVisitor for TestVisitor<T> {
impl<'n, T: Display> TreeNodeVisitor<'n> for TestVisitor<T> {
type Node = TestTreeNode<T>;

fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
self.visits.push(format!("f_down({})", node.data));
(*self.f_down)(node)
}

fn f_up(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
self.visits.push(format!("f_up({})", node.data));
(*self.f_up)(node)
}
Expand Down Expand Up @@ -2001,7 +1928,7 @@ mod tests {
// |
// A
#[test]
fn test_apply_ref() -> Result<()> {
fn test_apply_and_visit_references() -> Result<()> {
let node_a = TestTreeNode::new(vec![], "a".to_string());
let node_b = TestTreeNode::new(vec![], "b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
Expand All @@ -2022,7 +1949,7 @@ mod tests {
let node_a_ref = &node_d_ref.children[0];

let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
tree.apply_ref(|e| {
tree.apply(|e| {
*m.entry(e).or_insert(0) += 1;
Ok(TreeNodeRecursion::Continue)
})?;
Expand All @@ -2041,7 +1968,7 @@ mod tests {
m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
}

impl<'n> TreeNodeRefVisitor<'n> for TestVisitor<'n> {
impl<'n> TreeNodeVisitor<'n> for TestVisitor<'n> {
type Node = TestTreeNode<String>;

fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Expand All @@ -2058,7 +1985,7 @@ mod tests {
}

let mut visitor = TestVisitor { m: HashMap::new() };
tree.visit_ref(&mut visitor)?;
tree.visit(&mut visitor)?;

let expected = HashMap::from([
(node_f_ref, (1, 1)),
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/arrow_file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ impl ExecutionPlan for ArrowExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/avro.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ impl ExecutionPlan for AvroExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/csv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ impl ExecutionPlan for CsvExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
// this is a leaf node and has no children
vec![]
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/datasource/physical_plan/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ impl ExecutionPlan for NdJsonExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
Vec::new()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ impl ExecutionPlan for ParquetExec {
&self.cache
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
// this is a leaf node and has no children
vec![]
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2441,10 +2441,10 @@ impl<'a> BadPlanVisitor<'a> {
}
}

impl<'a> TreeNodeVisitor for BadPlanVisitor<'a> {
impl<'n, 'a> TreeNodeVisitor<'n> for BadPlanVisitor<'a> {
type Node = LogicalPlan;

fn f_down(&mut self, node: &Self::Node) -> Result<TreeNodeRecursion> {
fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
match node {
LogicalPlan::Ddl(ddl) if !self.options.allow_ddl => {
plan_err!("DDL not supported: {}", ddl.name())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ fn take_optimizable(node: &dyn ExecutionPlan) -> Option<Arc<dyn ExecutionPlan>>
return Some(child);
}
}
if let [ref childrens_child] = child.children().as_slice() {
if let [childrens_child] = child.children().as_slice() {
child = Arc::clone(childrens_child);
} else {
break;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1375,8 +1375,8 @@ pub(crate) mod tests {
vec![false]
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

// model that it requires the output ordering of its input
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_optimizer/enforce_sorting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ fn remove_corresponding_sort_from_sub_plan(
// Replace with variants that do not preserve order.
if is_sort_preserving_merge(&node.plan) {
node.children = node.children.swap_remove(0).children;
node.plan = node.plan.children().swap_remove(0);
node.plan = node.plan.children().swap_remove(0).clone();
} else if let Some(repartition) =
node.plan.as_any().downcast_ref::<RepartitionExec>()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,15 @@ impl LimitedDistinctAggregation {
let mut is_global_limit = false;
if let Some(local_limit) = plan.as_any().downcast_ref::<LocalLimitExec>() {
limit = local_limit.fetch();
children = local_limit.children();
children = local_limit.children().into_iter().cloned().collect();
} else if let Some(global_limit) = plan.as_any().downcast_ref::<GlobalLimitExec>()
{
global_fetch = global_limit.fetch();
global_fetch?;
global_skip = global_limit.skip();
// the aggregate must read at least fetch+skip number of rows
limit = global_fetch.unwrap() + global_skip;
children = global_limit.children();
children = global_limit.children().into_iter().cloned().collect();
is_global_limit = true
} else {
return None;
Expand Down
6 changes: 3 additions & 3 deletions datafusion/core/src/physical_optimizer/output_requirements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ impl ExecutionPlan for OutputRequirementExec {
vec![true]
}

fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
vec![self.input.clone()]
fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
vec![&self.input]
}

fn required_input_ordering(&self) -> Vec<Option<Vec<PhysicalSortRequirement>>> {
Expand Down Expand Up @@ -273,7 +273,7 @@ fn require_top_ordering_helper(
// When an operator requires an ordering, any `SortExec` below can not
// be responsible for (i.e. the originator of) the global ordering.
let (new_child, is_changed) =
require_top_ordering_helper(children.swap_remove(0))?;
require_top_ordering_helper(children.swap_remove(0).clone())?;
Ok((plan.with_new_children(vec![new_child])?, is_changed))
} else {
// Stop searching, there is no global ordering desired for the query.
Expand Down

0 comments on commit 7200e30

Please sign in to comment.