Skip to content

Commit

Permalink
Clean up symbolic utils interface
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 20, 2023
1 parent 0c74a74 commit c12d5a7
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 27 deletions.
75 changes: 52 additions & 23 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@ import Base: convert
#! format: off
if isdefined(Base, :get_extension)
using SymbolicUtils
import SymbolicUtils: istree, operation, arguments, similarterm, symtype
import SymbolicUtils: istree, operation, arguments, similarterm, symtype, issym, arity, metadata, simplify
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap, mustfindfirst
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
import DynamicExpressions.SelfContainedEquationModule: SelfContainedNode
else
using ..SymbolicUtils
import ..SymbolicUtils: istree, operation, arguments, similarterm, symtype
import ..SymbolicUtils: istree, operation, arguments, similarterm, symtype, issym, arity, metadata, simplify
import ..DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import ..DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap, mustfindfirst
Expand Down Expand Up @@ -286,41 +286,70 @@ function multiply_powers(
end
end

#########################
# Interface #############
#########################
###############################################
# Direct Simplification Interface #############
###############################################

istree(x::SelfContainedNode) = x.tree.degree > 0
arity(x::SelfContainedNode) = x.tree.degree
istree(x::SelfContainedNode) = arity(x) > 0
symtype(::S) where {T,S<:SelfContainedNode{T}} = T
function operation(x::SelfContainedNode)
if x.tree.degree == 1
if arity(x) == 1
return x.operators.unaops[x.tree.op]
else # x.tree.degree == 2
elseif arity(x) == 2
return x.operators.binops[x.tree.op]
else
error("Unexpected degree $(x.tree.degree).")
end
end
function arguments(x::S) where {S<:SelfContainedNode}
if x.tree.degree == 1
return [S(x.tree.l, x.operators)]
else # x.tree.degree == 2
return [S(x.tree.l, x.operators), S(x.tree.r, x.operators)]
function unsorted_arguments(x::S) where {T,S<:SelfContainedNode{T}}
if arity(x) == 0
return Any[]
elseif arity(x) == 1
return Any[isconstant(x.tree.l) ? x.tree.l.val::T : S(x.tree.l, x.operators)]
elseif arity(x) == 2
return Any[
isconstant(x.tree.l) ? x.tree.l.val::T : S(x.tree.l, x.operators),
isconstant(x.tree.r) ? x.tree.r.val::T : S(x.tree.r, x.operators),
]
end
end
function similarterm(t::S, f, args, symtype=nothing) where {S<:SelfContainedNode}
if length(args) == 0
error("Unexpected input.")
elseif length(args) == 1
op_index = mustfindfirst(f, t.operators.unaops)::Integer
new_node = Node(op_index, only(args).tree)
function arguments(x::S) where {T,S<:SelfContainedNode{T}}
return unsorted_arguments(x)
end
function similarterm(
t::S, f::F, args::AbstractArray, symtype=nothing; kws...
)::S where {T,S<:SelfContainedNode{T},F<:Function}
if length(args) > 2
l = similarterm(t, f, args[begin:(begin + 1)], symtype; kws...)
return similarterm(t, f, [l, args[(begin + 2):end]...], symtype; kws...)
end
if length(args) == 1
op_index = mustfindfirst(f, t.operators.unaops)
new_node = Node(op_index, to_node(T, op_index, args[1]))
return S(new_node, t.operators)
elseif length(args) == 2
op_index = mustfindfirst(f, t.operators.binops)::Integer
new_node = Node(op_index, args[1].tree, args[2].tree)
op_index = mustfindfirst(f, t.operators.binops)
new_node = if all(isconstant, args)
to_node(T, op_index, f(args...))
else
Node(op_index, [to_node(T, op_index, arg) for arg in args]...)
end
return S(new_node, t.operators)
else
l = similarterm(t, f, args[begin:(begin + 1)], symtype)
return similarterm(t, f, [l, args[(begin + 2):end]...], symtype)
error("Unexpected length $(length(args)).")
end
end

# Helper functions for interface
isconstant(x::SelfContainedNode) = isconstant(x.tree)
isconstant(x::Node) = x.degree == 0 && x.constant
isconstant(::Number) = true
to_node(::Type{T}, op_index, x::SelfContainedNode{T}) where {T} = Node(op_index, x.tree)
to_node(::Type{T}, op_index, x::Number) where {T} = Node(T; val=x)

function simplify(x::Node, operators::AbstractOperatorEnum, args...; kws...)
return simplify(SelfContainedNode(x, operators), args...; kws...)
end

end
10 changes: 6 additions & 4 deletions src/SelfContainedEquation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,21 @@ struct SelfContainedNode{T,OP<:AbstractOperatorEnum}
operators::OP

function SelfContainedNode(
tree::N, operators::_OP
) where {_T,N<:Node{_T},_OP<:AbstractOperatorEnum}
tree::Node{_T}, operators::_OP
) where {_T,_OP<:AbstractOperatorEnum}
return new{_T,_OP}(tree, operators)
end
function SelfContainedNode{_T,_OP}(
tree::N, operators::_OP
) where {_T,N<:Node{_T},_OP<:AbstractOperatorEnum}
tree::Node{_T}, operators::_OP
) where {_T,_OP<:AbstractOperatorEnum}
return new{_T,_OP}(tree, operators)
end
end

Base.one(x::SelfContainedNode) = SelfContainedNode(one(x.tree), x.operators)
Base.zero(x::SelfContainedNode) = SelfContainedNode(zero(x.tree), x.operators)
Base.isequal(a::S, b::S) where {S<:SelfContainedNode} = isequal(a.tree, b.tree)
Base.hash(a::S) where {S<:SelfContainedNode} = hash((a.tree, a.operators))

function Base.promote(a::S, b::Number) where {TS,S<:SelfContainedNode{TS}}
n = Node(TS; val=b)
Expand Down

0 comments on commit c12d5a7

Please sign in to comment.