Merge pull request #56 from SymbolicML/MilesCranmer/issue14
Graph-like expressions
MilesCranmer committed Dec 19, 2023
2 parents f23ed22 + f8f9678 commit 8109f9c
Showing 28 changed files with 1,620 additions and 820 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.13.1"
version = "0.14.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
@@ -24,6 +24,7 @@ DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"

[compat]
Aqua = "0.7"
Compat = "3.37, 4"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
47 changes: 37 additions & 10 deletions benchmark/benchmarks.jl
@@ -1,6 +1,11 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions.EquationUtilsModule: is_constant
using Zygote
if PACKAGE_VERSION < v"0.14.0"
@eval using DynamicExpressions: Node as GraphNode
else
@eval using DynamicExpressions: GraphNode
end

include("benchmark_utils.jl")

@@ -66,13 +71,15 @@ end

# These macros make the benchmarks work on older versions:
#! format: off
@generated function _convert(::Type{N}, t; preserve_sharing) where {N<:Node}
@generated function _convert(::Type{N}, t; preserve_sharing) where {N}
PACKAGE_VERSION < v"0.7.0" && return :(convert(N, t))
return :(convert(N, t; preserve_sharing=preserve_sharing))
PACKAGE_VERSION < v"0.14.0" && return :(convert(N, t; preserve_sharing=preserve_sharing))
return :(convert(N, t)) # Assume type used to infer sharing
end
@generated function _copy_node(t; preserve_sharing)
PACKAGE_VERSION < v"0.7.0" && return :(copy_node(t; preserve_topology=preserve_sharing))
return :(copy_node(t; preserve_sharing=preserve_sharing))
PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing))
return :(copy_node(t)) # Assume type used to infer sharing
end
@generated function get_set_constants!(tree)
!(@isdefined set_constants!) && return :(set_constants(tree, get_constants(tree)))
@@ -101,13 +108,36 @@ function benchmark_utilities()
:index_constants,
:string_tree,
)
has_both_modes = [:copy, :convert]
if PACKAGE_VERSION >= v"0.14.0"
append!(
has_both_modes,
[
:simplify_tree,
:count_nodes,
:count_constants,
:get_set_constants!,
:index_constants,
:string_tree,
],
)
end

operators = OperatorEnum(; binary_operators=[+, -, /, *], unary_operators=[cos, exp])
for func_k in all_funcs
suite[func_k] = let s = BenchmarkGroup()
for k in (:break_sharing, :preserve_sharing)
has_both_modes = func_k in (:copy, :convert)
k == :preserve_sharing && !has_both_modes && continue
for k in (
if func_k in has_both_modes
[:break_sharing, :preserve_sharing]
else
[:break_sharing]
end
)
preprocess = if k == :preserve_sharing && PACKAGE_VERSION >= v"0.14.0"
tree -> GraphNode(tree)
else
identity
end

f = if func_k == :copy
tree -> _copy_node(tree; preserve_sharing=(k == :preserve_sharing))
@@ -132,12 +162,9 @@ function benchmark_utilities()
setup=(
ntrees=100;
n=20;
trees=[gen_random_tree_fixed_size(n, $operators, 5, Float32) for _ in 1:ntrees]
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
)
)
if !has_both_modes
s = s[k]
end
#! format: on
end
s
85 changes: 70 additions & 15 deletions docs/src/types.md
@@ -48,16 +48,7 @@ Equations are specified as binary trees with the `Node` type, defined
as follows:

```@docs
Node{T}
```

There are a variety of constructors for `Node` objects, including:

```@docs
Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T}
Node(op::Integer, l::Node)
Node(op::Integer, l::Node, r::Node)
Node(var_string::String)
Node
```

When you create an `Options` object, the operators
@@ -69,23 +60,87 @@ When using these node constructors, types will automatically be promoted.
You can convert the type of a node using `convert`:

```@docs
convert(::Type{Node{T1}}, tree::Node{T2}) where {T1, T2}
convert(::Type{AbstractExpressionNode{T1}}, tree::AbstractExpressionNode{T2}) where {T1, T2}
```

You can set a `tree` (in-place) with `set_node!`:

```@docs
set_node!(tree::Node{T}, new_tree::Node{T}) where {T}
set_node!
```

You can create a copy of a node with `copy_node`:

```@docs
copy_node(tree::Node)
copy_node
```

## Graph-Like Equations

You can describe an equation as a *graph* rather than a tree
by using the `GraphNode` type:

```@docs
GraphNode{T}
```

This makes it so you can have multiple parents for a given node,
and share parts of an expression. For example:

```julia
julia> operators = OperatorEnum(;
binary_operators=[+, -, *], unary_operators=[cos, sin, exp]
);

julia> x1, x2 = GraphNode(feature=1), GraphNode(feature=2)
(x1, x2)

julia> y = sin(x1) + 1.5
sin(x1) + 1.5

julia> z = exp(y) + y
exp(sin(x1) + 1.5) + {(sin(x1) + 1.5)}
```

Here, the curly braces `{}` indicate that the node
is shared by more than one parent node.

This means that we only need to change it once
to have changes propagate across the expression:

```julia
julia> y.r.val *= 0.9
1.35

julia> z
exp(sin(x1) + 1.35) + {(sin(x1) + 1.35)}
```

This also means there are fewer nodes to describe an expression:

```julia
julia> length(z)
6

julia> length(convert(Node, z))
10
```

where we have converted the `GraphNode` to a `Node` type,
which breaks shared connections into separate nodes.
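
The node counts above can be reproduced without the package. The following is a minimal sketch, assuming a hypothetical stand-in type `SketchNode` and hypothetical helpers `leaf`, `branch`, `count_unique`, and `count_expanded`; it is not how `GraphNode` is implemented, only an illustration of why sharing one sub-expression gives 6 nodes while expanding it into a tree gives 10.

```julia
# Hypothetical stand-in for a graph-like expression node.
# This is NOT the DynamicExpressions `GraphNode` type.
mutable struct SketchNode
    label::String
    val::Float64
    children::Vector{SketchNode}
end

leaf(label; val=NaN) = SketchNode(label, val, SketchNode[])
branch(label, children...) = SketchNode(label, NaN, SketchNode[children...])

# Graph view: each node object is counted once, tracked by object identity.
function count_unique(node::SketchNode, seen=IdDict{SketchNode,Bool}())
    haskey(seen, node) && return 0
    seen[node] = true
    return 1 + sum(c -> count_unique(c, seen), node.children; init=0)
end

# Tree view: a shared node is counted once per parent that reaches it.
count_expanded(node::SketchNode) = 1 + sum(count_expanded, node.children; init=0)

x1 = leaf("x1")
c = leaf("const"; val=1.5)
y = branch("+", branch("sin", x1), c)   # sin(x1) + 1.5
z = branch("+", branch("exp", y), y)    # exp(y) + y, with `y` shared

println(count_unique(z))    # 6
println(count_expanded(z))  # 10

# Changing the shared constant once is visible through both parents,
# because both point at the same object.
c.val *= 0.9
println(z.children[2] === z.children[1].children[1])  # true
```

Tracking visited nodes by identity (here with `IdDict`) is what distinguishes the graph count from the tree count; comparing nodes by value would merge structurally equal but distinct sub-expressions.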

## Abstract Types

Both the `Node` and `GraphNode` types are subtypes of the abstract type:

```@docs
AbstractExpressionNode{T}
```

There is also an abstract type `AbstractNode` which is a supertype of `Node`:
which can be used to create additional expression-like types.
The supertype of this abstract type is the `AbstractNode` type,
which is more generic but does not have all of the same methods:

```@docs
AbstractNode
AbstractNode{T}
```
57 changes: 33 additions & 24 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
@@ -1,7 +1,8 @@
module DynamicExpressionsSymbolicUtilsExt

using SymbolicUtils
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.EquationModule:
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
@@ -19,14 +20,17 @@ end
subs_bad(x) = isgood(x) ? x : Inf

function parse_tree_to_eqs(
tree::Node{T}, operators::AbstractOperatorEnum, index_functions::Bool=false
tree::AbstractExpressionNode{T},
operators::AbstractOperatorEnum,
index_functions::Bool=false,
) where {T}
if tree.degree == 0
# Return constant if needed
tree.constant && return subs_bad(tree.val::T)
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
end
# Collect the next children
# TODO: Type instability!
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
# Get the operation
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
@@ -66,11 +70,12 @@ convert_to_function(x, operators::AbstractOperatorEnum) = x
function split_eq(
op,
args,
operators::AbstractOperatorEnum;
operators::AbstractOperatorEnum,
::Type{N}=Node;
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated:
varMap=nothing,
)
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, varMap, :split_eq)
!(op ∈ (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!"))
if Symbol(op) == Symbol(sum)
@@ -80,10 +85,10 @@ function split_eq(
else
ind = findoperation(op, operators.binops)
end
return Node(
return constructorof(N)(
ind,
convert(Node, args[1], operators; variable_names=variable_names),
convert(Node, op(args[2:end]...), operators; variable_names=variable_names),
convert(N, args[1], operators; variable_names=variable_names),
convert(N, op(args[2:end]...), operators; variable_names=variable_names),
)
end

@@ -96,7 +101,7 @@ end

function Base.convert(
::typeof(SymbolicUtils.Symbolic),
tree::Node,
tree::AbstractExpressionNode,
operators::AbstractOperatorEnum;
variable_names::Union{Array{String,1},Nothing}=nothing,
index_functions::Bool=false,
@@ -109,20 +114,22 @@ function Base.convert(
)
end

function Base.convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
return Node(; val=DEFAULT_NODE_TYPE(x))
function Base.convert(
::Type{N}, x::Number, operators::AbstractOperatorEnum; kws...
) where {N<:AbstractExpressionNode}
return constructorof(N)(; val=DEFAULT_NODE_TYPE(x))
end

function Base.convert(
::typeof(Node),
::Type{N},
expr::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
variable_names::Union{Array{String,1},Nothing}=nothing,
)
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, nothing, :convert)
if !SymbolicUtils.istree(expr)
variable_names === nothing && return Node(String(expr.name))
return Node(String(expr.name), variable_names)
variable_names === nothing && return constructorof(N)(String(expr.name))
return constructorof(N)(String(expr.name), variable_names)
end

# First, we remove integer powers:
@@ -134,20 +141,21 @@ function Base.convert(
op = convert_to_function(SymbolicUtils.operation(expr), operators)
args = SymbolicUtils.arguments(expr)

length(args) > 2 && return split_eq(op, args, operators; variable_names=variable_names)
length(args) > 2 &&
return split_eq(op, args, operators, N; variable_names=variable_names)
ind = if length(args) == 2
findoperation(op, operators.binops)
else
findoperation(op, operators.unaops)
end

return Node(
ind, map(x -> convert(Node, x, operators; variable_names=variable_names), args)...
return constructorof(N)(
ind, map(x -> convert(N, x, operators; variable_names=variable_names), args)...
)
end

"""
node_to_symbolic(tree::Node, operators::AbstractOperatorEnum;
node_to_symbolic(tree::AbstractExpressionNode, operators::AbstractOperatorEnum;
variable_names::Union{Array{String, 1}, Nothing}=nothing,
index_functions::Bool=false)
@@ -156,17 +164,17 @@ will generate a symbolic equation in SymbolicUtils.jl format.
## Arguments
- `tree::Node`: The equation to convert.
- `tree::AbstractExpressionNode`: The equation to convert.
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
each feature. Default is [x1, x2, x3, ...].
- `index_functions::Bool=false`: Whether to generate special names for the
operators, which then allows one to convert back to a `Node` format
operators, which then allows one to convert back to an `AbstractExpressionNode` format
using `symbolic_to_node`.
(CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84).
"""
function node_to_symbolic(
tree::Node,
tree::AbstractExpressionNode,
operators::AbstractOperatorEnum;
variable_names::Union{Array{String,1},Nothing}=nothing,
index_functions::Bool=false,
@@ -192,13 +200,14 @@ end

function symbolic_to_node(
eqn::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
operators::AbstractOperatorEnum,
::Type{N}=Node;
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated:
varMap=nothing,
)::Node
) where {N<:AbstractExpressionNode}
variable_names = deprecate_varmap(variable_names, varMap, :symbolic_to_node)
return convert(Node, eqn, operators; variable_names=variable_names)
return convert(N, eqn, operators; variable_names=variable_names)
end

function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
5 changes: 4 additions & 1 deletion src/DynamicExpressions.jl
@@ -15,13 +15,16 @@ import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
@reexport import .EquationModule:
AbstractNode,
AbstractExpressionNode,
GraphNode,
Node,
string_tree,
print_tree,
copy_node,
set_node!,
tree_mapreduce,
filter_map
import .EquationModule: constructorof, preserve_sharing
@reexport import .EquationUtilsModule:
count_nodes,
count_constants,
@@ -38,7 +41,7 @@ import Reexport: @reexport
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateEquationDerivativeModule:
eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
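
The extension code in this commit calls `constructorof(N)` so that one function body can build either a `Node` or a `GraphNode`. The pattern can be sketched as follows, using hypothetical stand-in types and helpers (`AbstractSketchNode`, `TreeSketch`, `GraphSketch`, `sketch_constructorof`, `make_constant`) rather than the package's own definitions. Stripping type parameters with `Base.typename(N).wrapper` is an assumption about one way to do this, not a statement of how DynamicExpressions defines `constructorof`.

```julia
# Hypothetical stand-ins; NOT the DynamicExpressions types.
abstract type AbstractSketchNode{T} end

mutable struct TreeSketch{T} <: AbstractSketchNode{T}
    val::T
end

mutable struct GraphSketch{T} <: AbstractSketchNode{T}
    val::T
end

# Recover the unparameterized type, e.g. TreeSketch{Float32} -> TreeSketch.
# Assumption: relies on the internal `wrapper` field of the type name.
sketch_constructorof(::Type{N}) where {N<:AbstractSketchNode} = Base.typename(N).wrapper

# Generic over the node type: the element type is inferred from the value,
# mirroring the shape of `constructorof(N)(; val=DEFAULT_NODE_TYPE(x))`.
function make_constant(::Type{N}, x::Number) where {N<:AbstractSketchNode}
    return sketch_constructorof(N)(Float64(x))
end

a = make_constant(TreeSketch, 1)            # TreeSketch{Float64}
b = make_constant(GraphSketch{Float32}, 2)  # GraphSketch{Float64}
println(typeof(a), " ", typeof(b))
```

Passing the node type as an argument and reconstructing the bare constructor lets the caller choose tree or graph semantics without duplicating the conversion logic for each type.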


2 comments on commit 8109f9c

@MilesCranmer

@JuliaRegistrator register()

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/97376

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.0 -m "<description of version>" 8109f9c93c877d89274a9b1b5a6a6b19bf4e4e02
git push origin v0.14.0
