Skip to content

Commit

Permalink
Merge c12d5a7 into 6be88b7
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 20, 2023
2 parents 6be88b7 + c12d5a7 commit 2c1f9e9
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 18 deletions.
12 changes: 6 additions & 6 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[extensions]
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"

[compat]
Compat = "3.37, 4"
LoopVectorization = "0.12"
Expand All @@ -27,9 +33,6 @@ SymbolicUtils = "0.19, ^1.0.5"
Zygote = "0.6"
julia = "1.6"

[extensions]
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Expand All @@ -40,6 +43,3 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils"]

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
74 changes: 72 additions & 2 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,20 @@ import Base: convert
#! format: off
if isdefined(Base, :get_extension)
using SymbolicUtils
import SymbolicUtils: istree, operation, arguments, similarterm, symtype, issym, arity, metadata, simplify
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap, mustfindfirst
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
import DynamicExpressions.SelfContainedEquationModule: SelfContainedNode
else
using ..SymbolicUtils
import ..SymbolicUtils: istree, operation, arguments, similarterm, symtype, issym, arity, metadata, simplify
import ..DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import ..DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap, mustfindfirst
import ..DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
import ..DynamicExpressions.SelfContainedEquationModule: SelfContainedNode
end
#! format: on

Expand Down Expand Up @@ -282,4 +286,70 @@ function multiply_powers(
end
end

###############################################
# Direct Simplification Interface #############
###############################################

arity(x::SelfContainedNode) = x.tree.degree
istree(x::SelfContainedNode) = arity(x) > 0
symtype(::S) where {T,S<:SelfContainedNode{T}} = T
function operation(x::SelfContainedNode)
if arity(x) == 1
return x.operators.unaops[x.tree.op]
elseif arity(x) == 2
return x.operators.binops[x.tree.op]
else
error("Unexpected degree $(x.tree.degree).")
end
end
function unsorted_arguments(x::S) where {T,S<:SelfContainedNode{T}}
if arity(x) == 0
return Any[]
elseif arity(x) == 1
return Any[isconstant(x.tree.l) ? x.tree.l.val::T : S(x.tree.l, x.operators)]
elseif arity(x) == 2
return Any[
isconstant(x.tree.l) ? x.tree.l.val::T : S(x.tree.l, x.operators),
isconstant(x.tree.r) ? x.tree.r.val::T : S(x.tree.r, x.operators),
]
end
end
function arguments(x::S) where {T,S<:SelfContainedNode{T}}
return unsorted_arguments(x)
end
function similarterm(
t::S, f::F, args::AbstractArray, symtype=nothing; kws...
)::S where {T,S<:SelfContainedNode{T},F<:Function}
if length(args) > 2
l = similarterm(t, f, args[begin:(begin + 1)], symtype; kws...)
return similarterm(t, f, [l, args[(begin + 2):end]...], symtype; kws...)
end
if length(args) == 1
op_index = mustfindfirst(f, t.operators.unaops)
new_node = Node(op_index, to_node(T, op_index, args[1]))
return S(new_node, t.operators)
elseif length(args) == 2
op_index = mustfindfirst(f, t.operators.binops)
new_node = if all(isconstant, args)
to_node(T, op_index, f(args...))
else
Node(op_index, [to_node(T, op_index, arg) for arg in args]...)
end
return S(new_node, t.operators)
else
error("Unexpected length $(length(args)).")
end
end

# Helper functions for interface
isconstant(x::SelfContainedNode) = isconstant(x.tree)
isconstant(x::Node) = x.degree == 0 && x.constant
isconstant(::Number) = true
to_node(::Type{T}, op_index, x::SelfContainedNode{T}) where {T} = Node(op_index, x.tree)
to_node(::Type{T}, op_index, x::Number) where {T} = Node(T; val=x)

function simplify(x::Node, operators::AbstractOperatorEnum, args...; kws...)
return simplify(SelfContainedNode(x, operators), args...; kws...)
end

end
2 changes: 2 additions & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ include("EvaluateEquationDerivative.jl")
include("EvaluationHelpers.jl")
include("SimplifyEquation.jl")
include("OperatorEnumConstruction.jl")
include("SelfContainedEquation.jl")
include("ExtensionInterface.jl")

import Requires: @init, @require
Expand All @@ -33,6 +34,7 @@ import Reexport: @reexport
eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
@reexport import .EvaluationHelpersModule
@reexport import .SelfContainedEquationModule: SelfContainedNode
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

#! format: off
Expand Down
5 changes: 5 additions & 0 deletions src/EquationUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ function set_constants!(tree::Node{T}, constants::AbstractVector{T}) where {T}
return nothing
end

Base.one(::Type{N}) where {T,N<:Node{T}} = Node(T; val=one(T))
Base.one(::N) where {N<:Node} = one(N)
Base.zero(::Type{N}) where {T,N<:Node{T}} = Node(T; val=zero(T))
Base.zero(::N) where {N<:Node} = zero(N)

## Assign index to nodes of a tree
# This will mirror a Node struct, rather
# than adding a new attribute to Node.
Expand Down
16 changes: 8 additions & 8 deletions src/OperatorEnum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ Defines an enum over operators, along with their derivatives.
- `diff_binops`: A tuple of Zygote-computed derivatives of the binary operators.
- `diff_unaops`: A tuple of Zygote-computed derivatives of the unary operators.
"""
struct OperatorEnum <: AbstractOperatorEnum
binops::Vector{Function}
unaops::Vector{Function}
diff_binops::Vector{Function}
diff_unaops::Vector{Function}
struct OperatorEnum{B,U,dB,dU} <: AbstractOperatorEnum
binops::B
unaops::U
diff_binops::dB
diff_unaops::dU
end

"""
Expand All @@ -29,9 +29,9 @@ Defines an enum over operators, along with their derivatives.
- `diff_binops`: A tuple of Zygote-computed derivatives of the binary operators.
- `diff_unaops`: A tuple of Zygote-computed derivatives of the unary operators.
"""
struct GenericOperatorEnum <: AbstractOperatorEnum
binops::Vector{Function}
unaops::Vector{Function}
struct GenericOperatorEnum{B,U} <: AbstractOperatorEnum
binops::B
unaops::U
end

end
6 changes: 4 additions & 2 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ function OperatorEnum(;
end

operators = OperatorEnum(
binary_operators, unary_operators, diff_binary_operators, diff_unary_operators
Tuple.((
binary_operators, unary_operators, diff_binary_operators, diff_unary_operators
))...,
)

if define_helper_functions
Expand Down Expand Up @@ -269,7 +271,7 @@ function GenericOperatorEnum(;
binary_operators = Function[op for op in binary_operators]
unary_operators = Function[op for op in unary_operators]

operators = GenericOperatorEnum(binary_operators, unary_operators)
operators = GenericOperatorEnum(Tuple.((binary_operators, unary_operators))...)

if define_helper_functions
@extend_operators_base operators
Expand Down
77 changes: 77 additions & 0 deletions src/SelfContainedEquation.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
module SelfContainedEquationModule

import ..OperatorEnumModule: AbstractOperatorEnum
import ..EquationModule: Node, string_tree
import ..UtilsModule: mustfindfirst

struct SelfContainedNode{T,OP<:AbstractOperatorEnum}
tree::Node{T}
operators::OP

function SelfContainedNode(
tree::Node{_T}, operators::_OP
) where {_T,_OP<:AbstractOperatorEnum}
return new{_T,_OP}(tree, operators)
end
function SelfContainedNode{_T,_OP}(
tree::Node{_T}, operators::_OP
) where {_T,_OP<:AbstractOperatorEnum}
return new{_T,_OP}(tree, operators)
end
end

Base.one(x::SelfContainedNode) = SelfContainedNode(one(x.tree), x.operators)
Base.zero(x::SelfContainedNode) = SelfContainedNode(zero(x.tree), x.operators)
Base.isequal(a::S, b::S) where {S<:SelfContainedNode} = isequal(a.tree, b.tree)
Base.hash(a::S) where {S<:SelfContainedNode} = hash((a.tree, a.operators))

function Base.promote(a::S, b::Number) where {TS,S<:SelfContainedNode{TS}}
n = Node(TS; val=b)
return (a, S(n, a.operators))
end
function Base.promote(a::S, b::N) where {TS,S<:SelfContainedNode{TS},TN,N<:Node{TN}}
T = promote_type(TS, TN)
n_a = convert(Node{T}, a.tree)
n_b = convert(Node{T}, b)
return (S(n_a, a.operators), S(n_b, a.operators))
end
Base.promote(a::S, b::S) where {S<:SelfContainedNode} = (a, b)
Base.promote(a::T, b::S) where {S<:SelfContainedNode,T} = reverse(promote(b, a))

function Base.show(io::IO, m::MIME"text/plain", x::SelfContainedNode)
print(io, "SelfContainedNode(\n")
print(io, " "^4, "tree=", string_tree(x.tree, x.operators), ",\n")
print(io, " "^4, "operators=", x.operators, "\n")
return print(io, ")")
end

for binop in (:(Base.:/), :(Base.:*), :(Base.:+), :(Base.:-), :(Base.:^))
@eval function $(binop)(a::S, b::S) where {S<:SelfContainedNode}
op_index = mustfindfirst($binop, a.operators.binops)
return S(Node(op_index, a.tree, b.tree), a.operators)
end
@eval function $(binop)(a::SelfContainedNode, b)
return $(binop)(promote(a, b)...)
end
@eval function $(binop)(a, b::SelfContainedNode)
return $(binop)(promote(a, b)...)
end
end
for unaop in (
:(Base.sin),
:(Base.cos),
:(Base.exp),
:(Base.tan),
:(Base.log),
:(Base.sqrt),
:(Base.:-),
)
@eval function $(unaop)(a::S) where {S<:SelfContainedNode}
op_index = mustfindfirst($unaop, a.operators.unaops)
return S(Node(op_index, a.tree), a.operators)
end
end

Base.:+(a::SelfContainedNode) = a

end
12 changes: 12 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,4 +166,16 @@ function deprecate_varmap(variable_names, varMap, func_name)
return variable_names
end

"""
mustfindfirst(el, container)
Find the index of the first element in `container` that is equal to `el`.
If no such element exists, throw an error.
"""
function mustfindfirst(el, container)::Integer
i = findfirst(==(el), container)
i === nothing && error("Could not find element $(el) in container $(container)")
return i
end

end

0 comments on commit 2c1f9e9

Please sign in to comment.