Skip to content

Commit

Permalink
Merge pull request #72 from SymbolicML/refactor-modules
Browse files Browse the repository at this point in the history
refactor!: module names to match struct names
  • Loading branch information
MilesCranmer committed Apr 28, 2024
2 parents 27b6199 + 841f240 commit f68d92e
Show file tree
Hide file tree
Showing 19 changed files with 69 additions and 50 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Suppressor", "Zygote"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ DynamicExpressions.jl is the backbone of [SymbolicRegression.jl](https://github.

A dynamic expression is a snippet of code that can change throughout runtime - compilation is not possible! **DynamicExpressions.jl does the following:**
1. Defines an enum over user-specified operators.
2. Using this enum, it defines a [very lightweight and type-stable data structure](https://symbolicml.org/DynamicExpressions.jl/dev/types/#DynamicExpressions.EquationModule.Node) for arbitrary expressions.
2. Using this enum, it defines a [very lightweight and type-stable data structure](https://symbolicml.org/DynamicExpressions.jl/dev/types/#DynamicExpressions.NodeModule.Node) for arbitrary expressions.
3. It then generates specialized [evaluation kernels](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquation.jl#L29-L66) for the space of potential operators.
4. It also generates kernels for the [first-order derivatives](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquationDerivative.jl#L139-L175), using [Zygote.jl](https://github.com/FluxML/Zygote.jl).
5. DynamicExpressions.jl can also operate on arbitrary other types (vectors, tensors, symbols, strings, or even unions) - see last part below.
Expand Down
7 changes: 6 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions.EquationUtilsModule: is_constant

# Trigger extensions:
using LoopVectorization
Expand All @@ -13,6 +12,12 @@ else
@eval using DynamicExpressions: GraphNode
end

if PACKAGE_VERSION < v"0.17.0"
@eval using DynamicExpressions.EquationUtilsModule: is_constant
else
@eval using DynamicExpressions.NodeUtilsModule: is_constant
end

include("../test/tree_gen_utils.jl")

const SUITE = BenchmarkGroup()
Expand Down
4 changes: 2 additions & 2 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module DynamicExpressionsLoopVectorizationExt
using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateEquationModule: @return_on_check
import DynamicExpressions.EvaluateEquationModule:
using DynamicExpressions.EvaluateModule: @return_on_check
import DynamicExpressions.EvaluateModule:
deg1_eval,
deg2_eval,
deg1_l2_ll0_lr0_eval,
Expand Down
2 changes: 1 addition & 1 deletion ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DynamicExpressionsSymbolicUtilsExt

using SymbolicUtils
import DynamicExpressions.EquationModule:
import DynamicExpressions.NodeModule:
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, deprecate_varmap
Expand Down
23 changes: 11 additions & 12 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ module DynamicExpressions
include("Utils.jl")
include("ExtensionInterface.jl")
include("OperatorEnum.jl")
include("Equation.jl")
include("EquationUtils.jl")
include("Node.jl")
include("NodeUtils.jl")
include("Strings.jl")
include("EvaluateEquation.jl")
include("EvaluateEquationDerivative.jl")
include("Evaluate.jl")
include("EvaluateDerivative.jl")
include("EvaluationHelpers.jl")
include("SimplifyEquation.jl")
include("Simplify.jl")
include("OperatorEnumConstruction.jl")
include("Random.jl")

import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
@reexport import .EquationModule:
@reexport import .NodeModule:
AbstractNode,
AbstractExpressionNode,
GraphNode,
Expand All @@ -25,8 +25,8 @@ import Reexport: @reexport
tree_mapreduce,
filter_map,
filter_map!
import .EquationModule: constructorof, preserve_sharing
@reexport import .EquationUtilsModule:
import .NodeModule: constructorof, preserve_sharing
@reexport import .NodeUtilsModule:
count_nodes,
count_constants,
count_depth,
Expand All @@ -40,10 +40,9 @@ import .EquationModule: constructorof, preserve_sharing
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateEquationDerivativeModule:
eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyModule: combine_operators, simplify_tree!
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
@reexport import .RandomModule: NodeSampler
Expand Down
6 changes: 3 additions & 3 deletions src/EvaluateEquation.jl → src/Evaluate.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module EvaluateEquationModule
module EvaluateModule

import ..EquationModule: AbstractExpressionNode, constructorof
import ..NodeModule: AbstractExpressionNode, constructorof
import ..StringsModule: string_tree
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: is_bad_array, fill_similar, counttuple, ResultOk
import ..EquationUtilsModule: is_constant
import ..NodeUtilsModule: is_constant
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded

const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module EvaluateEquationDerivativeModule
module EvaluateDerivativeModule

import ..EquationModule: AbstractExpressionNode, constructorof
import ..NodeModule: AbstractExpressionNode, constructorof
import ..OperatorEnumModule: OperatorEnum
import ..UtilsModule: is_bad_array, fill_similar, ResultOk2
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
import ..EvaluateEquationModule:
deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..NodeUtilsModule: count_constants, index_constants, NodeIndex
import ..EvaluateModule: deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..ExtensionInterfaceModule: _zygote_gradient

"""
Expand Down
6 changes: 3 additions & 3 deletions src/EvaluationHelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ module EvaluationHelpersModule

import Base: adjoint
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: AbstractExpressionNode
import ..EvaluateEquationModule: eval_tree_array
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
import ..NodeModule: AbstractExpressionNode
import ..EvaluateModule: eval_tree_array
import ..EvaluateDerivativeModule: eval_grad_tree_array

# Evaluation:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/Equation.jl → src/Node.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module EquationModule
module NodeModule

import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
Expand Down
4 changes: 2 additions & 2 deletions src/EquationUtils.jl → src/NodeUtils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module EquationUtilsModule
module NodeUtilsModule

import Compat: Returns
import ..EquationModule:
import ..NodeModule:
AbstractNode,
AbstractExpressionNode,
Node,
Expand Down
14 changes: 7 additions & 7 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module OperatorEnumConstructionModule

import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: Node, GraphNode, AbstractExpressionNode, constructorof
import ..NodeModule: Node, GraphNode, AbstractExpressionNode, constructorof
import ..StringsModule: string_tree
import ..EvaluateEquationModule: eval_tree_array, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
import ..EvaluateModule: eval_tree_array, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..EvaluateDerivativeModule: eval_grad_tree_array, _zygote_gradient
import ..EvaluationHelpersModule: _grad_evaluator

"""Used to set a default value for `operators` for ease of use."""
Expand Down Expand Up @@ -110,8 +110,8 @@ function _extend_unary_operator(f::Symbol, type_requirements, internal)
@gensym _constructorof _AbstractExpressionNode
quote
if $$internal
import ..EquationModule.constructorof as $_constructorof
import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode
import ..NodeModule.constructorof as $_constructorof
import ..NodeModule.AbstractExpressionNode as $_AbstractExpressionNode
else
using DynamicExpressions:
constructorof as $_constructorof,
Expand All @@ -137,8 +137,8 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters,
@gensym _constructorof _AbstractExpressionNode
quote
if $$internal
import ..EquationModule.constructorof as $_constructorof
import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode
import ..NodeModule.constructorof as $_constructorof
import ..NodeModule.AbstractExpressionNode as $_AbstractExpressionNode
else
using DynamicExpressions:
constructorof as $_constructorof,
Expand Down
2 changes: 1 addition & 1 deletion src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module RandomModule
import Compat: Returns, @inline
import Random: AbstractRNG
import Base: rand
import ..EquationModule: AbstractNode, tree_mapreduce, filter_map
import ..NodeModule: AbstractNode, tree_mapreduce, filter_map

"""
NodeSampler(; tree, filter::Function=Returns(true), weighting::Union{Nothing,Function}=nothing, break_sharing::Val=Val(false))
Expand Down
6 changes: 3 additions & 3 deletions src/SimplifyEquation.jl → src/Simplify.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module SimplifyEquationModule
module SimplifyModule

import ..EquationModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node!
import ..EquationUtilsModule: tree_mapreduce, is_node_constant
import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node!
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: isbad, isgood

Expand Down
2 changes: 1 addition & 1 deletion src/Strings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module StringsModule

using ..UtilsModule: deprecate_varmap
using ..OperatorEnumModule: AbstractOperatorEnum
using ..EquationModule: AbstractExpressionNode, tree_mapreduce
using ..NodeModule: AbstractExpressionNode, tree_mapreduce

const OP_NAMES = Base.ImmutableDict(
"safe_log" => "log",
Expand Down
8 changes: 7 additions & 1 deletion src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Base: @deprecate
import .EquationModule: Node, GraphNode
import .NodeModule: Node, GraphNode

@deprecate set_constants set_constants!
@deprecate simplify_tree simplify_tree!
Expand Down Expand Up @@ -69,3 +69,9 @@ for N in (:Node, :GraphNode)
end
end
end

Base.@deprecate_binding EquationModule NodeModule
Base.@deprecate_binding EquationUtilsModule NodeUtilsModule
Base.@deprecate_binding EvaluateEquationModule EvaluateModule
Base.@deprecate_binding EvaluateEquationDerivativeModule EvaluateDerivativeModule
Base.@deprecate_binding SimplifyEquationModule SimplifyModule
9 changes: 9 additions & 0 deletions test/test_deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DynamicExpressions
using Test
using Zygote
using Suppressor: @capture_err

operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
x1, x2 = Node{Float64}(; feature=1), Node{Float64}(; feature=2)
Expand Down Expand Up @@ -43,3 +44,11 @@ if VERSION >= v"1.9"
@assert (n.op == 1 && n.l === x1 && n.r === x2)
)
end

# Test deprecated modules
logs = @capture_err begin
@eval using DynamicExpressions.EquationModule
end
@test contains(logs, "DynamicExpressions.EquationModule is deprecated,")

DynamicExpressions.EquationModule.Node === DynamicExpressions.NodeModule.Node
6 changes: 3 additions & 3 deletions test/test_evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ end
@test repr(tree) == "cos(cos(3.0))"
tree = convert(Node{T}, tree)
truth = cos(cos(T(3.0f0)))
@test DynamicExpressions.EvaluateEquationModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, Val(turbo)).x[1]
@test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, Val(turbo)).x[1]
truth

# op(<constant>, <constant>)
tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0))
@test repr(tree) == "3.0 + 4.0"
tree = convert(Node{T}, tree)
truth = T(3.0f0) + T(4.0f0)
@test DynamicExpressions.EvaluateEquationModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), Val(turbo)).x[1]
@test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), Val(turbo)).x[1]
truth

# op(op(<constant>, <constant>))
tree = Node(1, Node(1, Node(; val=3.0f0), Node(; val=4.0f0)))
@test repr(tree) == "cos(3.0 + 4.0)"
tree = convert(Node{T}, tree)
truth = cos(T(3.0f0) + T(4.0f0))
@test DynamicExpressions.EvaluateEquationModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), Val(turbo)).x[1]
@test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), Val(turbo)).x[1]
truth

# Test for presence of NaNs:
Expand Down
4 changes: 2 additions & 2 deletions test/test_simplification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ function Base.:≈(a::String, b::String)
return a == b
end

simplify_tree! = DynamicExpressions.SimplifyEquationModule.simplify_tree!
combine_operators = DynamicExpressions.SimplifyEquationModule.combine_operators
simplify_tree! = DynamicExpressions.SimplifyModule.simplify_tree!
combine_operators = DynamicExpressions.SimplifyModule.combine_operators

binary_operators = (+, -, /, *)

Expand Down

0 comments on commit f68d92e

Please sign in to comment.