Skip to content

Commit

Permalink
Merge pull request #53 from SymbolicML/abstract-node
Browse files Browse the repository at this point in the history
Create `AbstractNode` super type
  • Loading branch information
MilesCranmer committed Aug 28, 2023
2 parents 65184f9 + 4f621ec commit 4638851
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 48 deletions.
6 changes: 6 additions & 0 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,9 @@ You can create a copy of a node with `copy_node`:
```@docs
copy_node(tree::Node)
```

There is also an abstract type `AbstractNode` which is a supertype of `Node`:

```@docs
AbstractNode
```
9 changes: 8 additions & 1 deletion src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,14 @@ include("ExtensionInterface.jl")
import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
@reexport import .EquationModule:
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
AbstractNode,
Node,
string_tree,
print_tree,
copy_node,
set_node!,
tree_mapreduce,
filter_map
@reexport import .EquationUtilsModule:
count_nodes,
count_constants,
Expand Down
20 changes: 19 additions & 1 deletion src/Equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,24 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap

const DEFAULT_NODE_TYPE = Float32

"""
AbstractNode
Abstract type for binary trees. Must have the following fields:
- `degree::Integer`: Degree of the node. Either 0, 1, or 2. If 1,
then `l` needs to be defined as the left child. If 2,
then `r` also needs to be defined as the right child.
- `l::AbstractNode`: Left child of the current node. Should only be
defined if `degree >= 1`; otherwise, leave it undefined (see the
the constructors of `Node{T}` for an example).
Don't use `nothing` to represent an undefined value
as it will incur a large performance penalty.
- `r::AbstractNode`: Right child of the current node. Should only
be defined if `degree == 2`.
"""
abstract type AbstractNode end

#! format: off
"""
Node{T}
Expand Down Expand Up @@ -36,7 +54,7 @@ nodes, you can evaluate or print a given expression.
Same type as the parent node. This is to be passed as the right
argument to the binary operator.
"""
mutable struct Node{T}
mutable struct Node{T} <: AbstractNode
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
constant::Bool # false if variable
val::Union{T,Nothing} # If is a constant, this stores the actual value
Expand Down
10 changes: 5 additions & 5 deletions src/EquationUtils.jl
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
module EquationUtilsModule

import Compat: Returns
import ..EquationModule: Node, copy_node, tree_mapreduce, any, filter_map
import ..EquationModule: AbstractNode, Node, copy_node, tree_mapreduce, any, filter_map

"""
count_nodes(tree::Node{T})::Int where {T}
count_nodes(tree::AbstractNode)::Int
Count the number of nodes in the tree.
"""
count_nodes(tree::Node) = tree_mapreduce(_ -> 1, +, tree)
count_nodes(tree::AbstractNode) = tree_mapreduce(_ -> 1, +, tree)
# This code is given as an example. Normally we could just use sum(Returns(1), tree).

"""
count_depth(tree::Node{T})::Int where {T}
count_depth(tree::AbstractNode)::Int
Compute the max depth of the tree.
"""
function count_depth(tree::Node)
function count_depth(tree::AbstractNode)
return tree_mapreduce(Returns(1), (p, child...) -> p + max(child...), tree)
end

Expand Down
92 changes: 51 additions & 41 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ import Compat: @inline, Returns
import ..UtilsModule: @memoize_on, @with_memoize

"""
tree_mapreduce(f::Function, op::Function, tree::Node, result_type::Type=Nothing)
tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::Node, result_type::Type=Nothing)
tree_mapreduce(f::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing)
tree_mapreduce(f_leaf::Function, f_branch::Function, op::Function, tree::AbstractNode, result_type::Type=Nothing)
Map a function over a tree and aggregate the result using an operator `op`.
`op` should be defined with inputs `(parent, child...) ->` so that it can aggregate
Expand Down Expand Up @@ -66,23 +66,27 @@ end # Get list of constants. (regular mapreduce also works)
```
"""
function tree_mapreduce(
f::F, op::G, tree::N, result_type::Type{RT}=Nothing; preserve_sharing::Bool=false
) where {T,N<:Node{T},F<:Function,G<:Function,RT}
f::F,
op::G,
tree::AbstractNode,
result_type::Type{RT}=Nothing;
preserve_sharing::Bool=false,
) where {F<:Function,G<:Function,RT}
return tree_mapreduce(f, f, op, tree, result_type; preserve_sharing)
end
function tree_mapreduce(
f_leaf::F1,
f_branch::F2,
op::G,
tree::N,
tree::AbstractNode,
result_type::Type{RT}=Nothing;
preserve_sharing::Bool=false,
) where {T,N<:Node{T},F1<:Function,F2<:Function,G<:Function,RT}
) where {F1<:Function,F2<:Function,G<:Function,RT}

# Trick taken from here:
# https://discourse.julialang.org/t/recursive-inner-functions-a-thousand-times-slower/85604/5
# to speed up recursive closure
@memoize_on t function inner(inner, t::Node)
@memoize_on t function inner(inner, t)
if t.degree == 0
return @inline(f_leaf(t))
elseif t.degree == 1
Expand All @@ -97,19 +101,19 @@ function tree_mapreduce(
throw(ArgumentError("Need to specify `result_type` if you use `preserve_sharing`."))

if preserve_sharing && RT != Nothing
return @with_memoize inner(inner, tree) IdDict{N,RT}()
return @with_memoize inner(inner, tree) IdDict{typeof(tree),RT}()
else
return inner(inner, tree)
end
end

"""
any(f::Function, tree::Node)
any(f::Function, tree::AbstractNode)
Reduce a flag function over a tree, returning `true` if the function returns `true` for any node.
By using this instead of tree_mapreduce, we can take advantage of early exits.
"""
function any(f::F, tree::Node) where {F<:Function}
function any(f::F, tree::AbstractNode) where {F<:Function}
if tree.degree == 0
return @inline(f(tree))::Bool
elseif tree.degree == 1
Expand All @@ -119,19 +123,25 @@ function any(f::F, tree::Node) where {F<:Function}
end
end

function Base.:(==)(a::Node{T1}, b::Node{T2})::Bool where {T1,T2}
function Base.:(==)(a::AbstractNode, b::AbstractNode)::Bool
(degree = a.degree) != b.degree && return false
if degree == 0
(constant = a.constant) != b.constant && return false
if constant
return a.val::T1 == b.val::T2
else
return a.feature == b.feature
end
return isequal_deg0(a, b)
elseif degree == 1
return a.op == b.op && a.l == b.l
return isequal_deg1(a, b) && a.l == b.l
else
return isequal_deg2(a, b) && a.l == b.l && a.r == b.r
end
end

@inline isequal_deg1(a::Node, b::Node) = a.op == b.op
@inline isequal_deg2(a::Node, b::Node) = a.op == b.op
@inline function isequal_deg0(a::Node{T1}, b::Node{T2}) where {T1,T2}
(constant = a.constant) != b.constant && return false
if constant
return a.val::T1 == b.val::T2
else
return a.op == b.op && a.l == b.l && a.r == b.r
return a.feature == b.feature
end
end

Expand All @@ -144,33 +154,33 @@ end
Apply a function to each node in a tree.
"""
function foreach(f::Function, tree::Node)
function foreach(f::Function, tree::AbstractNode)
return tree_mapreduce(t -> (@inline(f(t)); nothing), Returns(nothing), tree)
end

"""
filter_map(filter_fnc::Function, map_fnc::Function, tree::Node, result_type::Type)
filter_map(filter_fnc::Function, map_fnc::Function, tree::AbstractNode, result_type::Type)
A faster equivalent to `map(map_fnc, filter(filter_fnc, tree))`
that avoids the intermediate allocation. However, using this requires
specifying the `result_type` of `map_fnc` so the resultant array can
be preallocated.
"""
function filter_map(
filter_fnc::F, map_fnc::G, tree::Node, result_type::Type{GT}
filter_fnc::F, map_fnc::G, tree::AbstractNode, result_type::Type{GT}
) where {F<:Function,G<:Function,GT}
stack = Array{GT}(undef, count(filter_fnc, tree))
filter_map!(filter_fnc, map_fnc, stack, tree)
return stack::Vector{GT}
end

"""
filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::Node)
filter_map!(filter_fnc::Function, map_fnc::Function, stack::Vector{GT}, tree::AbstractNode)
Equivalent to `filter_map`, but stores the results in a preallocated array.
"""
function filter_map!(
filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::Node
filter_fnc::Function, map_fnc::Function, destination::Vector{GT}, tree::AbstractNode
) where {GT}
pointer = Ref(0)
foreach(tree) do t
Expand All @@ -183,49 +193,49 @@ function filter_map!(
end

"""
filter(f::Function, tree::Node)
filter(f::Function, tree::AbstractNode)
Filter nodes of a tree, returning a flat array of the nodes for which the function returns `true`.
"""
function filter(f::F, tree::Node{T}) where {F<:Function,T}
return filter_map(f, identity, tree, Node{T})
function filter(f::F, tree::AbstractNode) where {F<:Function}
return filter_map(f, identity, tree, typeof(tree))
end

collect(tree::Node) = filter(Returns(true), tree)
collect(tree::AbstractNode) = filter(Returns(true), tree)

"""
map(f::Function, tree::Node, result_type::Type{RT}=Nothing)
map(f::Function, tree::AbstractNode, result_type::Type{RT}=Nothing)
Map a function over a tree and return a flat array of the results in depth-first order.
Pre-specifying the `result_type` of the function can be used to avoid extra allocations,
"""
function map(f::F, tree::Node, result_type::Type{RT}=Nothing) where {F<:Function,RT}
function map(f::F, tree::AbstractNode, result_type::Type{RT}=Nothing) where {F<:Function,RT}
if RT == Nothing
return f.(collect(tree))
else
return filter_map(Returns(true), f, tree, result_type)
end
end

function count(f::F, tree::Node; init=0) where {F<:Function}
function count(f::F, tree::AbstractNode; init=0) where {F<:Function}
return tree_mapreduce(t -> @inline(f(t)) ? 1 : 0, +, tree) + init
end

function sum(f::F, tree::Node; init=0) where {F<:Function}
function sum(f::F, tree::AbstractNode; init=0) where {F<:Function}
return tree_mapreduce(f, +, tree) + init
end

all(f::F, tree::Node) where {F<:Function} = !any(t -> !@inline(f(t)), tree)
all(f::F, tree::AbstractNode) where {F<:Function} = !any(t -> !@inline(f(t)), tree)

function mapreduce(f::F, op::G, tree::Node) where {F<:Function,G<:Function}
function mapreduce(f::F, op::G, tree::AbstractNode) where {F<:Function,G<:Function}
return tree_mapreduce(f, (n...) -> reduce(op, n), tree)
end

isempty(::Node) = false
iterate(root::Node) = (root, collect(root)[(begin + 1):end])
iterate(::Node, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
in(item, tree::Node) = any(t -> t == item, tree)
length(tree::Node) = sum(Returns(1), tree)
isempty(::AbstractNode) = false
iterate(root::AbstractNode) = (root, collect(root)[(begin + 1):end])
iterate(::AbstractNode, stack) = isempty(stack) ? nothing : (popfirst!(stack), stack)
in(item, tree::AbstractNode) = any(t -> t == item, tree)
length(tree::AbstractNode) = sum(Returns(1), tree)
function hash(tree::Node{T}) where {T}
return tree_mapreduce(
t -> t.constant ? hash((0, t.val::T)) : hash((1, t.feature)),
Expand Down Expand Up @@ -299,11 +309,11 @@ end

for func in (:reduce, :foldl, :foldr, :mapfoldl, :mapfoldr)
@eval begin
function $func(f, tree::Node; kws...)
function $func(f, tree::AbstractNode; kws...)
throw(
error(
string($func) *
" not implemented for Node. Use `tree_mapreduce` instead.",
" not implemented for AbstractNode. Use `tree_mapreduce` instead.",
),
)
end
Expand Down
37 changes: 37 additions & 0 deletions test/test_custom_node_type.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using DynamicExpressions
using Test

mutable struct MyCustomNode{A,B} <: AbstractNode
degree::Int
val1::A
val2::B
l::MyCustomNode{A,B}
r::MyCustomNode{A,B}

MyCustomNode(val1, val2) = new{typeof(val1),typeof(val2)}(0, val1, val2)
MyCustomNode(val1, val2, l) = new{typeof(val1),typeof(val2)}(1, val1, val2, l)
MyCustomNode(val1, val2, l, r) = new{typeof(val1),typeof(val2)}(2, val1, val2, l, r)
end

node1 = MyCustomNode(1.0, 2)

@test typeof(node1) == MyCustomNode{Float64,Int}
@test node1.degree == 0
@test count_depth(node1) == 1
@test count_nodes(node1) == 1

node2 = MyCustomNode(1.5, 3, node1)

@test typeof(node2) == MyCustomNode{Float64,Int}
@test node2.degree == 1
@test node2.l.degree == 0
@test count_depth(node2) == 2
@test count_nodes(node2) == 2

node2 = MyCustomNode(1.5, 3, node1, node1)

@test count_depth(node2) == 2
@test count_nodes(node2) == 3
@test sum(t -> t.val1, node2) == 1.5 + 1.0 + 1.0
@test sum(t -> t.val2, node2) == 3 + 2 + 2
@test count(t -> t.degree == 0, node2) == 2
4 changes: 4 additions & 0 deletions test/unittest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,7 @@ end
@safetestset "Test helpers break upon redefining" begin
include("test_safe_helpers.jl")
end

@safetestset "Test custom node type" begin
include("test_custom_node_type.jl")
end

0 comments on commit 4638851

Please sign in to comment.