Skip to content

Commit

Permalink
Merge 498c551 into 71a7b58
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed May 26, 2023
2 parents 71a7b58 + 498c551 commit 80fdef2
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 14 deletions.
12 changes: 10 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -22,16 +22,24 @@ LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
PrecompileTools = "1"
Reexport = "1"
Requires = "1.0, 1.1, 1.2, 1.3"
SymbolicUtils = "0.19, ^1.0.5"
Zygote = "0.6"
julia = "1.6"

[extensions]
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays"]
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils"]

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
module InterfaceSymbolicUtilsModule
module DynamicExpressionsSymbolicUtilsExt

using SymbolicUtils
import ..EquationModule: Node, DEFAULT_NODE_TYPE
import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: isgood, isbad, @return_on_false
export node_to_symbolic, symbolic_to_node

import Base: convert
if isdefined(Base, :get_extension)
using SymbolicUtils
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
else
using ..SymbolicUtils
import ..DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import ..DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false
import ..DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
end

const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}

Expand Down Expand Up @@ -91,7 +103,7 @@ function findoperation(op, ops)
throw(error("Operation $(op) in expression not found in operations $(ops)!"))
end

function Base.convert(
function convert(
::typeof(SymbolicUtils.Symbolic),
tree::Node,
operators::AbstractOperatorEnum;
Expand All @@ -101,7 +113,7 @@ function Base.convert(
return node_to_symbolic(tree, operators; varMap=varMap, index_functions=index_functions)
end

function Base.convert(
function convert(
::typeof(Node),
x::Number,
operators::AbstractOperatorEnum;
Expand All @@ -110,7 +122,7 @@ function Base.convert(
return Node(; val=DEFAULT_NODE_TYPE(x))
end

function Base.convert(
function convert(
::typeof(Node),
expr::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
Expand Down Expand Up @@ -181,8 +193,10 @@ function node_to_symbolic(
end

function symbolic_to_node(
eqn::T, operators::AbstractOperatorEnum; varMap::Union{Array{String,1},Nothing}=nothing
)::Node where {T<:SymbolicUtils.Symbolic}
eqn::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
)::Node
return convert(Node, eqn, operators; varMap=varMap)
end

Expand Down
13 changes: 11 additions & 2 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@ include("EquationUtils.jl")
include("EvaluateEquation.jl")
include("EvaluateEquationDerivative.jl")
include("EvaluationHelpers.jl")
include("InterfaceSymbolicUtils.jl")
include("SimplifyEquation.jl")
include("OperatorEnumConstruction.jl")
include("ExtensionInterface.jl")

if !isdefined(Base, :get_extension)
using Requires
end
using Reexport
@reexport import .EquationModule:
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
Expand All @@ -30,9 +33,15 @@ using Reexport
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateEquationDerivativeModule:
eval_diff_tree_array, eval_grad_tree_array
@reexport import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

#! format: off
if !isdefined(Base, :get_extension)
@init @require SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/DynamicExpressionsSymbolicUtilsExt.jl")
end
#! format: on

include("deprecated.jl")

Expand Down
18 changes: 18 additions & 0 deletions src/ExtensionInterface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module ExtensionInterfaceModule

import ..EquationModule: Node, DEFAULT_NODE_TYPE
import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: isgood, isbad, @return_on_false

function node_to_symbolic(args...; kws...)
return error(
"Please load the `SymbolicUtils` package to use `node_to_symbolic(::Node, ::AbstractOperatorEnum; kws...)`.",
)
end
function symbolic_to_node(args...; kws...)
return error(
"Please load the `SymbolicUtils` package to use `symbolic_to_node(::Symbolic, ::AbstractOperatorEnum, kws...)`.",
)
end

end
1 change: 1 addition & 0 deletions test/test_symbolic_utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using SymbolicUtils
using DynamicExpressions
using Test
include("test_params.jl")
Expand Down

0 comments on commit 80fdef2

Please sign in to comment.