Skip to content

Commit

Permalink
add reference visitor APIs
Browse files Browse the repository at this point in the history
  • Loading branch information
peter-toth committed May 22, 2024
1 parent e893a2e commit 8711fe2
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 11 deletions.
174 changes: 167 additions & 7 deletions datafusion/common/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,19 @@ pub trait TreeNode: Sized {
.visit_parent(|| visitor.f_up(self))
}

/// Similar to [`TreeNode::visit()`], but the lifetimes of the [`TreeNode`] references
/// passed to [`TreeNodeRefVisitor::f_down()`] and [`TreeNodeRefVisitor::f_up()`]
/// methods match the lifetime of the original root [`TreeNode`] reference.
fn visit_ref<'n, V: TreeNodeRefVisitor<'n, Node = Self>>(
&'n self,
visitor: &mut V,
) -> Result<TreeNodeRecursion> {
visitor
.f_down(self)?
.visit_children(|| self.apply_children_ref(|c| c.visit_ref(visitor)))?
.visit_parent(|| visitor.f_up(self))
}

/// Rewrite the tree node with a [`TreeNodeRewriter`], performing a
/// depth-first walk of the node and its children.
///
Expand Down Expand Up @@ -204,6 +217,27 @@ pub trait TreeNode: Sized {
apply_impl(self, &mut f)
}

/// Similar to [`TreeNode::apply()`], but the lifetime of the [`TreeNode`] references
/// passed to the `f` closures match the lifetime of the original root [`TreeNode`]
/// reference.
fn apply_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
mut f: F,
) -> Result<TreeNodeRecursion> {
fn apply_ref_impl<
'n,
N: TreeNode,
F: FnMut(&'n N) -> Result<TreeNodeRecursion>,
>(
node: &'n N,
f: &mut F,
) -> Result<TreeNodeRecursion> {
f(node)?.visit_children(|| node.apply_children_ref(|c| apply_ref_impl(c, f)))
}

apply_ref_impl(self, &mut f)
}

/// Recursively rewrite the node's children and then the node using `f`
/// (a bottom-up post-order traversal).
///
Expand Down Expand Up @@ -430,6 +464,17 @@ pub trait TreeNode: Sized {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
f: F,
) -> Result<TreeNodeRecursion> {
// The default implementation is the stricter `apply_children_ref()`
self.apply_children_ref(f)
}

/// Similar to [`TreeNode::apply_children()`], but the lifetime of the [`TreeNode`]
/// references passed to the `f` closures match the lifetime of the original root
/// [`TreeNode`] reference.
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion>;

/// Low-level API used to implement other APIs.
Expand Down Expand Up @@ -483,6 +528,26 @@ pub trait TreeNodeVisitor: Sized {
}
}

/// Similar to [`TreeNodeVisitor`], but the lifetimes of the [`TreeNode`] references
/// passed to [`TreeNodeRefVisitor::f_down()`] and [`TreeNodeRefVisitor::f_up()`] methods
/// match the lifetime of the original root [`TreeNode`] reference.
pub trait TreeNodeRefVisitor<'n>: Sized {
/// The node type which is visitable.
type Node: TreeNode;

/// Invoked while traversing down the tree, before any children are visited.
/// Default implementation continues the recursion.
fn f_down(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}

/// Invoked while traversing up the tree after children are visited. Default
/// implementation continues the recursion.
fn f_up(&mut self, _node: &'n Self::Node) -> Result<TreeNodeRecursion> {
Ok(TreeNodeRecursion::Continue)
}
}

/// A [Visitor](https://en.wikipedia.org/wiki/Visitor_pattern) for recursively
/// rewriting [`TreeNode`]s via [`TreeNode::rewrite`].
///
Expand Down Expand Up @@ -857,6 +922,10 @@ pub trait DynTreeNode {
/// Returns all children of the specified `TreeNode`.
fn arc_children(&self) -> Vec<Arc<Self>>;

fn children(&self) -> Vec<&Arc<Self>> {
panic!("DynTreeNode::children is not implemented yet")
}

/// Constructs a new node with the specified children.
fn with_new_arc_children(
&self,
Expand All @@ -875,6 +944,13 @@ impl<T: DynTreeNode + ?Sized> TreeNode for Arc<T> {
self.arc_children().iter().apply_until_stop(f)
}

fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().into_iter().apply_until_stop(f)
}

fn map_children<F: FnMut(Self) -> Result<Transformed<Self>>>(
self,
f: F,
Expand Down Expand Up @@ -913,8 +989,8 @@ pub trait ConcreteTreeNode: Sized {
}

impl<T: ConcreteTreeNode> TreeNode for T {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children().into_iter().apply_until_stop(f)
Expand All @@ -938,15 +1014,16 @@ impl<T: ConcreteTreeNode> TreeNode for T {

#[cfg(test)]
mod tests {
use std::collections::HashMap;
use std::fmt::Display;

use crate::tree_node::{
Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRewriter,
TreeNodeVisitor,
Transformed, TreeNode, TreeNodeIterator, TreeNodeRecursion, TreeNodeRefVisitor,
TreeNodeRewriter, TreeNodeVisitor,
};
use crate::Result;

#[derive(PartialEq, Debug)]
#[derive(Debug, Eq, Hash, PartialEq)]
struct TestTreeNode<T> {
children: Vec<TestTreeNode<T>>,
data: T,
Expand All @@ -959,8 +1036,8 @@ mod tests {
}

impl<T> TreeNode for TestTreeNode<T> {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.children.iter().apply_until_stop(f)
Expand Down Expand Up @@ -1912,4 +1989,87 @@ mod tests {
TreeNodeRecursion::Stop
)
);

// F
// / | \
// / | \
// E C A
// | / \
// C B D
// / \ |
// B D A
// |
// A
#[test]
fn test_apply_ref() -> Result<()> {
let node_a = TestTreeNode::new(vec![], "a".to_string());
let node_b = TestTreeNode::new(vec![], "b".to_string());
let node_d = TestTreeNode::new(vec![node_a], "d".to_string());
let node_c = TestTreeNode::new(vec![node_b, node_d], "c".to_string());
let node_e = TestTreeNode::new(vec![node_c], "e".to_string());
let node_a_2 = TestTreeNode::new(vec![], "a".to_string());
let node_b_2 = TestTreeNode::new(vec![], "b".to_string());
let node_d_2 = TestTreeNode::new(vec![node_a_2], "d".to_string());
let node_c_2 = TestTreeNode::new(vec![node_b_2, node_d_2], "c".to_string());
let node_a_3 = TestTreeNode::new(vec![], "a".to_string());
let tree = TestTreeNode::new(vec![node_e, node_c_2, node_a_3], "f".to_string());

let node_f_ref = &tree;
let node_e_ref = &node_f_ref.children[0];
let node_c_ref = &node_e_ref.children[0];
let node_b_ref = &node_c_ref.children[0];
let node_d_ref = &node_c_ref.children[1];
let node_a_ref = &node_d_ref.children[0];

let mut m: HashMap<&TestTreeNode<String>, usize> = HashMap::new();
tree.apply_ref(|e| {
*m.entry(e).or_insert(0) += 1;
Ok(TreeNodeRecursion::Continue)
})?;

let expected = HashMap::from([
(node_f_ref, 1),
(node_e_ref, 1),
(node_c_ref, 2),
(node_d_ref, 2),
(node_b_ref, 2),
(node_a_ref, 3),
]);
assert_eq!(m, expected);

struct TestVisitor<'n> {
m: HashMap<&'n TestTreeNode<String>, (usize, usize)>,
}

impl<'n> TreeNodeRefVisitor<'n> for TestVisitor<'n> {
type Node = TestTreeNode<String>;

fn f_down(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (down_count, _) = self.m.entry(node).or_insert((0, 0));
*down_count += 1;
Ok(TreeNodeRecursion::Continue)
}

fn f_up(&mut self, node: &'n Self::Node) -> Result<TreeNodeRecursion> {
let (_, up_count) = self.m.entry(node).or_insert((0, 0));
*up_count += 1;
Ok(TreeNodeRecursion::Continue)
}
}

let mut visitor = TestVisitor { m: HashMap::new() };
tree.visit_ref(&mut visitor)?;

let expected = HashMap::from([
(node_f_ref, (1, 1)),
(node_e_ref, (1, 1)),
(node_c_ref, (2, 2)),
(node_d_ref, (2, 2)),
(node_b_ref, (2, 2)),
(node_a_ref, (3, 3)),
]);
assert_eq!(visitor.m, expected);

Ok(())
}
}
4 changes: 2 additions & 2 deletions datafusion/expr/src/logical_plan/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,8 @@ use datafusion_common::{
};

impl TreeNode for LogicalPlan {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
self.inputs().into_iter().apply_until_stop(f)
Expand Down
4 changes: 2 additions & 2 deletions datafusion/expr/src/tree_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ use datafusion_common::tree_node::{
use datafusion_common::{map_until_stop_and_collect, Result};

impl TreeNode for Expr {
fn apply_children<F: FnMut(&Self) -> Result<TreeNodeRecursion>>(
&self,
fn apply_children_ref<'n, F: FnMut(&'n Self) -> Result<TreeNodeRecursion>>(
&'n self,
f: F,
) -> Result<TreeNodeRecursion> {
let children = match self {
Expand Down

0 comments on commit 8711fe2

Please sign in to comment.