Skip to content

Commit

Permalink
Merge 61df8b8 into 6be88b7
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jul 22, 2023
2 parents 6be88b7 + 61df8b8 commit 4c1fee9
Show file tree
Hide file tree
Showing 11 changed files with 42 additions and 75 deletions.
22 changes: 12 additions & 10 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,45 +1,47 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.10.1"
version = "0.11.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"
DynamicExpressionsZygoteExt = "Zygote"

[compat]
Compat = "3.37, 4"
LoopVectorization = "0.12"
MacroTools = "0.4, 0.5"
PackageExtensionCompat = "1"
PrecompileTools = "1"
Reexport = "1"
Requires = "1.0, 1.1, 1.2, 1.3"
SymbolicUtils = "0.19, ^1.0.5"
Zygote = "0.6"
julia = "1.6"

[extensions]
DynamicExpressionsSymbolicUtilsExt = "SymbolicUtils"

[extras]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils"]

[weakdeps]
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
test = ["Test", "SafeTestsets", "SpecialFunctions", "ForwardDiff", "StaticArrays", "SymbolicUtils", "Zygote"]
1 change: 1 addition & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1 change: 1 addition & 0 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions.EquationUtilsModule: is_constant
using Zygote

include("benchmark_utils.jl")

Expand Down
28 changes: 8 additions & 20 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,12 @@
module DynamicExpressionsSymbolicUtilsExt

import Base: convert
#! format: off
if isdefined(Base, :get_extension)
using SymbolicUtils
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
else
using ..SymbolicUtils
import ..DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import ..DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import ..DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
end
#! format: on
using SymbolicUtils
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}

const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /)

function isgood(x::SymbolicUtils.Symbolic)
Expand Down Expand Up @@ -106,7 +94,7 @@ function findoperation(op, ops)
throw(error("Operation $(op) in expression not found in operations $(ops)!"))
end

function convert(
function Base.convert(
::typeof(SymbolicUtils.Symbolic),
tree::Node,
operators::AbstractOperatorEnum;
Expand All @@ -121,11 +109,11 @@ function convert(
)
end

function convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
function Base.convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
return Node(; val=DEFAULT_NODE_TYPE(x))
end

function convert(
function Base.convert(
::typeof(Node),
expr::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
Expand Down
9 changes: 9 additions & 0 deletions ext/DynamicExpressionsZygoteExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
module DynamicExpressionsZygoteExt

import Zygote: gradient
import DynamicExpressions.EvaluateEquationDerivativeModule: _zygote_gradient

_zygote_gradient(op::F, ::Val{1}) where {F} = x -> gradient(op, x)[1]
_zygote_gradient(op::F, ::Val{2}) where {F} = (x, y) -> gradient(op, x, y)

end
8 changes: 3 additions & 5 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ include("SimplifyEquation.jl")
include("OperatorEnumConstruction.jl")
include("ExtensionInterface.jl")

import Requires: @init, @require
import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
@reexport import .EquationModule:
Node, string_tree, print_tree, copy_node, set_node!, tree_mapreduce, filter_map
Expand All @@ -35,11 +35,9 @@ import Reexport: @reexport
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node

#! format: off
if !isdefined(Base, :get_extension)
@init @require SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" include("../ext/DynamicExpressionsSymbolicUtilsExt.jl")
function __init__()
@require_extensions
end
#! format: on

include("deprecated.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/EvaluateEquationDerivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import ..UtilsModule: @return_on_false2, @maybe_turbo, is_bad_array, fill_simila
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
import ..EvaluateEquationModule: deg0_eval

_zygote_gradient(args...) = error("Please load the Zygote.jl package.")

function assert_autodiff_enabled(operators::OperatorEnum)
if length(operators.diff_binops) == 0 && length(operators.diff_unaops) == 0
error(
Expand Down
9 changes: 3 additions & 6 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
module OperatorEnumConstructionModule

import Zygote: gradient
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: string_tree, Node
import ..EvaluateEquationModule: eval_tree_array
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
import ..EvaluationHelpersModule: _grad_evaluator

function create_evaluation_helpers!(operators::OperatorEnum)
Expand Down Expand Up @@ -223,12 +222,10 @@ function OperatorEnum(;

if enable_autodiff
for op in binary_operators
diff_op(x, y) = gradient(op, x, y)
push!(diff_binary_operators, diff_op)
push!(diff_binary_operators, _zygote_gradient(op, Val(2)))
end
for op in unary_operators
diff_op(x) = gradient(op, x)[1]
push!(diff_unary_operators, diff_op)
push!(diff_unary_operators, _zygote_gradient(op, Val(1)))
end
end

Expand Down
35 changes: 1 addition & 34 deletions src/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,55 +30,32 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types

X = rand(T, 3, 10)
operators = OperatorEnum(;
binary_operators=binops,
unary_operators=unaops,
define_helper_functions=false,
enable_autodiff=true,
binary_operators=binops, unary_operators=unaops, define_helper_functions=false
)
x = Node(T; feature=1)
c = Node(T; val=one(T))

# Trivial:
for l in (x, c)
@ignore_domain_error eval_tree_array(l, X, operators; turbo=use_turbo)
for variable in (true, false)
@ignore_domain_error eval_grad_tree_array(
l, X, operators; turbo=use_turbo, variable
)
end
end

# Binary operators
for i in eachindex(binops), l in (x, c), r in (x, c)
tree = Node(i, l, r)
tree = convert(Node{T}, tree)
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
for variable in (true, false)
@ignore_domain_error eval_grad_tree_array(
l, X, operators; turbo=use_turbo, variable
)
end
end

# Unary operators
for j in eachindex(unaops), k in eachindex(unaops), l in (x, c)
tree = Node(j, l)
tree = convert(Node{T}, tree)
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
for variable in (true, false)
@ignore_domain_error eval_grad_tree_array(
l, X, operators; turbo=use_turbo, variable
)
end

tree = Node(j, Node(k, l))
tree = convert(Node{T}, tree)
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
for variable in (true, false)
@ignore_domain_error eval_grad_tree_array(
l, X, operators; turbo=use_turbo, variable
)
end
end

# Both operators
Expand All @@ -91,20 +68,10 @@ function test_all_combinations(; binary_operators, unary_operators, turbo, types
tree = Node(i, Node(j1, l), Node(j2, r))
tree = convert(Node{T}, tree)
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
for variable in (true, false)
@ignore_domain_error eval_grad_tree_array(
l, X, operators; turbo=use_turbo, variable
)
end

tree = Node(j1, Node(i, l, r))
tree = convert(Node{T}, tree)
@ignore_domain_error eval_tree_array(tree, X, operators; turbo=use_turbo)
for variable in (true, false)
@ignore_domain_error eval_grad_tree_array(
l, X, operators; turbo=use_turbo, variable
)
end
end
end
return nothing
Expand Down
1 change: 1 addition & 0 deletions test/test_base.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DynamicExpressions
using Random
using Test
using Zygote

operators = OperatorEnum(;
binary_operators=[+, -, *, /], unary_operators=[cos, sin], enable_autodiff=true
Expand Down
1 change: 1 addition & 0 deletions test/test_container_preserved.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DynamicExpressions
using StaticArrays
using Test
using Zygote

@testset "StaticArrays type preserved" begin
for T in (Float32, Float64)
Expand Down

0 comments on commit 4c1fee9

Please sign in to comment.