Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: module names to match struct names #72

Merged
merged 5 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Zygote"]
test = ["Test", "SafeTestsets", "Aqua", "Bumper", "Enzyme", "ForwardDiff", "LinearAlgebra", "LoopVectorization", "Optim", "SpecialFunctions", "StaticArrays", "SymbolicUtils", "Suppressor", "Zygote"]
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ DynamicExpressions.jl is the backbone of [SymbolicRegression.jl](https://github.

A dynamic expression is a snippet of code that can change throughout runtime - compilation is not possible! **DynamicExpressions.jl does the following:**
1. Defines an enum over user-specified operators.
2. Using this enum, it defines a [very lightweight and type-stable data structure](https://symbolicml.org/DynamicExpressions.jl/dev/types/#DynamicExpressions.EquationModule.Node) for arbitrary expressions.
2. Using this enum, it defines a [very lightweight and type-stable data structure](https://symbolicml.org/DynamicExpressions.jl/dev/types/#DynamicExpressions.NodeModule.Node) for arbitrary expressions.
3. It then generates specialized [evaluation kernels](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquation.jl#L29-L66) for the space of potential operators.
4. It also generates kernels for the [first-order derivatives](https://github.com/SymbolicML/DynamicExpressions.jl/blob/fe8e6dfa160d12485fb77c226d22776dd6ed697a/src/EvaluateEquationDerivative.jl#L139-L175), using [Zygote.jl](https://github.com/FluxML/Zygote.jl).
5. DynamicExpressions.jl can also operate on arbitrary other types (vectors, tensors, symbols, strings, or even unions) - see last part below.
Expand Down
7 changes: 6 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using DynamicExpressions, BenchmarkTools, Random
using DynamicExpressions.EquationUtilsModule: is_constant

# Trigger extensions:
using LoopVectorization
Expand All @@ -13,6 +12,12 @@ else
@eval using DynamicExpressions: GraphNode
end

if PACKAGE_VERSION < v"0.17.0"
@eval using DynamicExpressions.EquationUtilsModule: is_constant
else
@eval using DynamicExpressions.NodeUtilsModule: is_constant
end

include("../test/tree_gen_utils.jl")

const SUITE = BenchmarkGroup()
Expand Down
4 changes: 2 additions & 2 deletions ext/DynamicExpressionsLoopVectorizationExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ module DynamicExpressionsLoopVectorizationExt
using LoopVectorization: @turbo
using DynamicExpressions: AbstractExpressionNode
using DynamicExpressions.UtilsModule: ResultOk, fill_similar
using DynamicExpressions.EvaluateEquationModule: @return_on_check
import DynamicExpressions.EvaluateEquationModule:
using DynamicExpressions.EvaluateModule: @return_on_check
import DynamicExpressions.EvaluateModule:
deg1_eval,
deg2_eval,
deg1_l2_ll0_lr0_eval,
Expand Down
2 changes: 1 addition & 1 deletion ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module DynamicExpressionsSymbolicUtilsExt

using SymbolicUtils
import DynamicExpressions.EquationModule:
import DynamicExpressions.NodeModule:
AbstractExpressionNode, Node, constructorof, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, deprecate_varmap
Expand Down
23 changes: 11 additions & 12 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,19 @@ module DynamicExpressions
include("Utils.jl")
include("ExtensionInterface.jl")
include("OperatorEnum.jl")
include("Equation.jl")
include("EquationUtils.jl")
include("Node.jl")
include("NodeUtils.jl")
include("Strings.jl")
include("EvaluateEquation.jl")
include("EvaluateEquationDerivative.jl")
include("Evaluate.jl")
include("EvaluateDerivative.jl")
include("EvaluationHelpers.jl")
include("SimplifyEquation.jl")
include("Simplify.jl")
include("OperatorEnumConstruction.jl")
include("Random.jl")

import PackageExtensionCompat: @require_extensions
import Reexport: @reexport
@reexport import .EquationModule:
@reexport import .NodeModule:
AbstractNode,
AbstractExpressionNode,
GraphNode,
Expand All @@ -25,8 +25,8 @@ import Reexport: @reexport
tree_mapreduce,
filter_map,
filter_map!
import .EquationModule: constructorof, preserve_sharing
@reexport import .EquationUtilsModule:
import .NodeModule: constructorof, preserve_sharing
@reexport import .NodeUtilsModule:
count_nodes,
count_constants,
count_depth,
Expand All @@ -40,10 +40,9 @@ import .EquationModule: constructorof, preserve_sharing
@reexport import .OperatorEnumModule: AbstractOperatorEnum
@reexport import .OperatorEnumConstructionModule:
OperatorEnum, GenericOperatorEnum, @extend_operators, set_default_variable_names!
@reexport import .EvaluateEquationModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateEquationDerivativeModule:
eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyEquationModule: combine_operators, simplify_tree!
@reexport import .EvaluateModule: eval_tree_array, differentiable_eval_tree_array
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .SimplifyModule: combine_operators, simplify_tree!
@reexport import .EvaluationHelpersModule
@reexport import .ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
@reexport import .RandomModule: NodeSampler
Expand Down
6 changes: 3 additions & 3 deletions src/EvaluateEquation.jl → src/Evaluate.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module EvaluateEquationModule
module EvaluateModule

import ..EquationModule: AbstractExpressionNode, constructorof
import ..NodeModule: AbstractExpressionNode, constructorof
import ..StringsModule: string_tree
import ..OperatorEnumModule: OperatorEnum, GenericOperatorEnum
import ..UtilsModule: is_bad_array, fill_similar, counttuple, ResultOk
import ..EquationUtilsModule: is_constant
import ..NodeUtilsModule: is_constant
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded

const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
module EvaluateEquationDerivativeModule
module EvaluateDerivativeModule

import ..EquationModule: AbstractExpressionNode, constructorof
import ..NodeModule: AbstractExpressionNode, constructorof
import ..OperatorEnumModule: OperatorEnum
import ..UtilsModule: is_bad_array, fill_similar, ResultOk2
import ..EquationUtilsModule: count_constants, index_constants, NodeIndex
import ..EvaluateEquationModule:
deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..NodeUtilsModule: count_constants, index_constants, NodeIndex
import ..EvaluateModule: deg0_eval, get_nuna, get_nbin, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..ExtensionInterfaceModule: _zygote_gradient

"""
Expand Down
6 changes: 3 additions & 3 deletions src/EvaluationHelpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ module EvaluationHelpersModule

import Base: adjoint
import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: AbstractExpressionNode
import ..EvaluateEquationModule: eval_tree_array
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array
import ..NodeModule: AbstractExpressionNode
import ..EvaluateModule: eval_tree_array
import ..EvaluateDerivativeModule: eval_grad_tree_array

# Evaluation:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/Equation.jl → src/Node.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
module EquationModule
module NodeModule

import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap, Undefined
Expand Down
4 changes: 2 additions & 2 deletions src/EquationUtils.jl → src/NodeUtils.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module EquationUtilsModule
module NodeUtilsModule

import Compat: Returns
import ..EquationModule:
import ..NodeModule:
AbstractNode,
AbstractExpressionNode,
Node,
Expand Down
14 changes: 7 additions & 7 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
module OperatorEnumConstructionModule

import ..OperatorEnumModule: AbstractOperatorEnum, OperatorEnum, GenericOperatorEnum
import ..EquationModule: Node, GraphNode, AbstractExpressionNode, constructorof
import ..NodeModule: Node, GraphNode, AbstractExpressionNode, constructorof
import ..StringsModule: string_tree
import ..EvaluateEquationModule: eval_tree_array, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradient
import ..EvaluateModule: eval_tree_array, OPERATOR_LIMIT_BEFORE_SLOWDOWN
import ..EvaluateDerivativeModule: eval_grad_tree_array, _zygote_gradient
import ..EvaluationHelpersModule: _grad_evaluator

"""Used to set a default value for `operators` for ease of use."""
Expand Down Expand Up @@ -110,8 +110,8 @@ function _extend_unary_operator(f::Symbol, type_requirements, internal)
@gensym _constructorof _AbstractExpressionNode
quote
if $$internal
import ..EquationModule.constructorof as $_constructorof
import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode
import ..NodeModule.constructorof as $_constructorof
import ..NodeModule.AbstractExpressionNode as $_AbstractExpressionNode
else
using DynamicExpressions:
constructorof as $_constructorof,
Expand All @@ -137,8 +137,8 @@ function _extend_binary_operator(f::Symbol, type_requirements, build_converters,
@gensym _constructorof _AbstractExpressionNode
quote
if $$internal
import ..EquationModule.constructorof as $_constructorof
import ..EquationModule.AbstractExpressionNode as $_AbstractExpressionNode
import ..NodeModule.constructorof as $_constructorof
import ..NodeModule.AbstractExpressionNode as $_AbstractExpressionNode
else
using DynamicExpressions:
constructorof as $_constructorof,
Expand Down
2 changes: 1 addition & 1 deletion src/Random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module RandomModule
import Compat: Returns, @inline
import Random: AbstractRNG
import Base: rand
import ..EquationModule: AbstractNode, tree_mapreduce, filter_map
import ..NodeModule: AbstractNode, tree_mapreduce, filter_map

"""
NodeSampler(; tree, filter::Function=Returns(true), weighting::Union{Nothing,Function}=nothing, break_sharing::Val=Val(false))
Expand Down
6 changes: 3 additions & 3 deletions src/SimplifyEquation.jl → src/Simplify.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module SimplifyEquationModule
module SimplifyModule

import ..EquationModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node!
import ..EquationUtilsModule: tree_mapreduce, is_node_constant
import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set_node!
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: isbad, isgood

Expand Down
2 changes: 1 addition & 1 deletion src/Strings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module StringsModule

using ..UtilsModule: deprecate_varmap
using ..OperatorEnumModule: AbstractOperatorEnum
using ..EquationModule: AbstractExpressionNode, tree_mapreduce
using ..NodeModule: AbstractExpressionNode, tree_mapreduce

const OP_NAMES = Base.ImmutableDict(
"safe_log" => "log",
Expand Down
8 changes: 7 additions & 1 deletion src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import Base: @deprecate
import .EquationModule: Node, GraphNode
import .NodeModule: Node, GraphNode

@deprecate set_constants set_constants!
@deprecate simplify_tree simplify_tree!
Expand Down Expand Up @@ -69,3 +69,9 @@ for N in (:Node, :GraphNode)
end
end
end

Base.@deprecate_binding EquationModule NodeModule
Base.@deprecate_binding EquationUtilsModule NodeUtilsModule
Base.@deprecate_binding EvaluateEquationModule EvaluateModule
Base.@deprecate_binding EvaluateEquationDerivativeModule EvaluateDerivativeModule
Base.@deprecate_binding SimplifyEquationModule SimplifyModule
9 changes: 9 additions & 0 deletions test/test_deprecations.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using DynamicExpressions
using Test
using Zygote
using Suppressor: @capture_err

operators = OperatorEnum(; binary_operators=[+, -, *, /], unary_operators=[cos, sin])
x1, x2 = Node{Float64}(; feature=1), Node{Float64}(; feature=2)
Expand Down Expand Up @@ -43,3 +44,11 @@ if VERSION >= v"1.9"
@assert (n.op == 1 && n.l === x1 && n.r === x2)
)
end

# Test deprecated modules
logs = @capture_err begin
@eval using DynamicExpressions.EquationModule
end
@test contains(logs, "DynamicExpressions.EquationModule is deprecated,")

DynamicExpressions.EquationModule.Node === DynamicExpressions.NodeModule.Node
6 changes: 3 additions & 3 deletions test/test_evaluation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,23 +93,23 @@ end
@test repr(tree) == "cos(cos(3.0))"
tree = convert(Node{T}, tree)
truth = cos(cos(T(3.0f0)))
@test DynamicExpressions.EvaluateEquationModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, Val(turbo)).x[1] ≈
@test DynamicExpressions.EvaluateModule.deg1_l1_ll0_eval(tree, [zero(T)]', cos, cos, Val(turbo)).x[1] ≈
truth

# op(<constant>, <constant>)
tree = Node(1, Node(; val=3.0f0), Node(; val=4.0f0))
@test repr(tree) == "3.0 + 4.0"
tree = convert(Node{T}, tree)
truth = T(3.0f0) + T(4.0f0)
@test DynamicExpressions.EvaluateEquationModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), Val(turbo)).x[1] ≈
@test DynamicExpressions.EvaluateModule.deg2_l0_r0_eval(tree, [zero(T)]', (+), Val(turbo)).x[1] ≈
truth

# op(op(<constant>, <constant>))
tree = Node(1, Node(1, Node(; val=3.0f0), Node(; val=4.0f0)))
@test repr(tree) == "cos(3.0 + 4.0)"
tree = convert(Node{T}, tree)
truth = cos(T(3.0f0) + T(4.0f0))
@test DynamicExpressions.EvaluateEquationModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), Val(turbo)).x[1] ≈
@test DynamicExpressions.EvaluateModule.deg1_l2_ll0_lr0_eval(tree, [zero(T)]', cos, (+), Val(turbo)).x[1] ≈
truth

# Test for presence of NaNs:
Expand Down
4 changes: 2 additions & 2 deletions test/test_simplification.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ function Base.:≈(a::String, b::String)
return a == b
end

simplify_tree! = DynamicExpressions.SimplifyEquationModule.simplify_tree!
combine_operators = DynamicExpressions.SimplifyEquationModule.combine_operators
simplify_tree! = DynamicExpressions.SimplifyModule.simplify_tree!
combine_operators = DynamicExpressions.SimplifyModule.combine_operators

binary_operators = (+, -, /, *)

Expand Down
Loading