Skip to content

Commit

Permalink
Merge pull request #50 from SymbolicML/lighter-types
Browse files Browse the repository at this point in the history
Switch to UInt8/UInt16 for Node fields
  • Loading branch information
MilesCranmer committed Aug 12, 2023
2 parents 5836ba2 + e79953b commit cd1fc0c
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 63 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.12.3"
version = "0.13.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
18 changes: 8 additions & 10 deletions benchmark/benchmark_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@ function random_node(tree::Node{T})::Node{T} where {T}
if tree.degree == 0
return tree
end
b = 0
c = 0
if tree.degree >= 1
b = count_nodes(tree.l)
end
if tree.degree == 2
c = count_nodes(tree.r)
b = count_nodes(tree.l)
c = if tree.degree == 2
count_nodes(tree.r)
else
0
end

i = rand(1:(1 + b + c))
Expand All @@ -27,7 +25,7 @@ function random_node(tree::Node{T})::Node{T} where {T}
return random_node(tree.r)
end

function make_random_leaf(nfeatures::Int, ::Type{T})::Node{T} where {T}
function make_random_leaf(nfeatures::Integer, ::Type{T})::Node{T} where {T}
if rand() > 0.5
return Node(; val=randn(T))
else
Expand All @@ -37,7 +35,7 @@ end

# Add a random unary/binary operation to the end of a tree
function append_random_op(
tree::Node{T}, operators, nfeatures::Int; makeNewBinOp::Union{Bool,Nothing}=nothing
tree::Node{T}, operators, nfeatures::Integer; makeNewBinOp::Union{Bool,Nothing}=nothing
)::Node{T} where {T}
nuna = length(operators.unaops)
nbin = length(operators.binops)
Expand Down Expand Up @@ -66,7 +64,7 @@ function append_random_op(
end

function gen_random_tree_fixed_size(
node_count::Int, operators, nfeatures::Int, ::Type{T}
node_count::Integer, operators, nfeatures::Integer, ::Type{T}
)::Node{T} where {T}
tree = make_random_leaf(nfeatures, T)
cur_size = count_nodes(tree)
Expand Down
2 changes: 1 addition & 1 deletion docs/src/eval.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ all variables (or, all constants). Both use forward-mode automatic, but use
`Zygote.jl` to compute derivatives of each operator, so this is very efficient.

```@docs
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int) where {T<:Number}
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer) where {T<:Number}
eval_grad_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum; turbo::Bool=false, variable::Bool=false) where {T<:Number}
```

Expand Down
4 changes: 2 additions & 2 deletions docs/src/types.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ There are a variety of constructors for `Node` objects, including:

```@docs
Node(::Type{T}; val=nothing, feature::Integer=nothing) where {T}
Node(op::Int, l::Node)
Node(op::Int, l::Node, r::Node)
Node(op::Integer, l::Node)
Node(op::Integer, l::Node, r::Node)
Node(var_string::String)
```

Expand Down
6 changes: 3 additions & 3 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ function parse_tree_to_eqs(
return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)"))
end
# Collect the next children
children = tree.degree >= 2 ? (tree.l, tree.r) : (tree.l,)
children = tree.degree == 2 ? (tree.l, tree.r) : (tree.l,)
# Get the operation
op = tree.degree > 1 ? operators.binops[tree.op] : operators.unaops[tree.op]
op = tree.degree == 2 ? operators.binops[tree.op] : operators.unaops[tree.op]
# Create an N tuple of Numbers for each argument
dtypes = map(x -> Number, 1:(tree.degree))
#
Expand Down Expand Up @@ -228,7 +228,7 @@ function multiply_powers(
@return_on_false complete eqn
@return_on_false isgood(l) eqn
n = args[2]
if typeof(n) <: Int
if typeof(n) <: Integer
if n == 1
return l, true
elseif n == -1
Expand Down
45 changes: 21 additions & 24 deletions src/Equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap

const DEFAULT_NODE_TYPE = Float32

#! format: off
"""
Node{T}
Expand All @@ -15,16 +16,16 @@ nodes, you can evaluate or print a given expression.
# Fields
- `degree::Int`: Degree of the node. 0 for constants, 1 for
- `degree::UInt8`: Degree of the node. 0 for constants, 1 for
unary operators, 2 for binary operators.
- `constant::Bool`: Whether the node is a constant.
- `val::T`: Value of the node. If `degree==0`, and `constant==true`,
this is the value of the constant. It has a type specified by the
overall type of the `Node` (e.g., `Float64`).
- `feature::Int` (optional): Index of the feature to use in the
- `feature::UInt16`: Index of the feature to use in the
case of a feature node. Only used if `degree==0` and `constant==false`.
Only defined if `degree == 0 && constant == false`.
- `op::Int`: If `degree==1`, this is the index of the operator
- `op::UInt8`: If `degree==1`, this is the index of the operator
in `operators.unaops`. If `degree==2`, this is the index of the
operator in `operators.binops`. In other words, this is an enum
of the operators, and is dependent on the specific `OperatorEnum`
Expand All @@ -36,36 +37,32 @@ nodes, you can evaluate or print a given expression.
argument to the binary operator.
"""
mutable struct Node{T}
degree::Int # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
degree::UInt8 # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
constant::Bool # false if variable
val::Union{T,Nothing} # If is a constant, this stores the actual value
# ------------------- (possibly undefined below)
feature::Int # If is a variable (e.g., x in cos(x)), this stores the feature index.
op::Int # If operator, this is the index of the operator in operators.binary_operators, or operators.unary_operators
feature::UInt16 # If is a variable (e.g., x in cos(x)), this stores the feature index.
op::UInt8 # If operator, this is the index of the operator in operators.binops, or operators.unaops
l::Node{T} # Left child node. Only defined for degree=1 or degree=2.
r::Node{T} # Right child node. Only defined for degree=2.

#################
## Constructors:
#################
Node(d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
Node(::Type{_T}, d::Int, c::Bool, v::_T) where {_T} = new{_T}(d, c, v)
Node(::Type{_T}, d::Int, c::Bool, v::Nothing, f::Int) where {_T} = new{_T}(d, c, v, f)
function Node(d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}) where {_T}
return new{_T}(d, c, v, f, o, l)
end
function Node(
d::Int, c::Bool, v::Nothing, f::Int, o::Int, l::Node{_T}, r::Node{_T}
) where {_T}
return new{_T}(d, c, v, f, o, l, r)
end
Node(d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
Node(::Type{_T}, d::Integer, c::Bool, v::_T) where {_T} = new{_T}(UInt8(d), c, v)
Node(::Type{_T}, d::Integer, c::Bool, v::Nothing, f::Integer) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f))
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l)
Node(d::Integer, c::Bool, v::Nothing, f::Integer, o::Integer, l::Node{_T}, r::Node{_T}) where {_T} = new{_T}(UInt8(d), c, v, UInt16(f), UInt8(o), l, r)

end
################################################################################
#! format: on

include("base.jl")

"""
Node([::Type{T}]; val=nothing, feature::Int=nothing) where {T}
Node([::Type{T}]; val=nothing, feature::Union{Integer,Nothing}=nothing) where {T}
Create a leaf node: either a constant, or a variable.
Expand Down Expand Up @@ -115,18 +112,18 @@ function Node(
end

"""
Node(op::Int, l::Node)
Node(op::Integer, l::Node)
Apply unary operator `op` (enumerating over the order given) to `Node` `l`
"""
Node(op::Int, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)
Node(op::Integer, l::Node{T}) where {T} = Node(1, false, nothing, 0, op, l)

"""
Node(op::Int, l::Node, r::Node)
Node(op::Integer, l::Node, r::Node)
Apply binary operator `op` (enumerating over the order given) to `Node`s `l` and `r`
"""
function Node(op::Int, l::Node{T1}, r::Node{T2}) where {T1,T2}
function Node(op::Integer, l::Node{T1}, r::Node{T2}) where {T1,T2}
# Get highest type:
if T1 != T2
T = promote_type(T1, T2)
Expand All @@ -141,7 +138,7 @@ end
Create a variable node, using the format `"x1"` to mean feature 1
"""
Node(var_string::String) = Node(; feature=parse(Int, var_string[2:end]))
Node(var_string::String) = Node(; feature=parse(UInt16, var_string[2:end]))

"""
Node(var_string::String, variable_names::Array{String, 1})
Expand Down Expand Up @@ -261,7 +258,7 @@ Convert an equation to a string.
# Keyword Arguments
- `bracketed`: (optional) whether to put brackets around the outside.
- `f_variable`: (optional) function to convert a variable to a string, of the form `(feature::Int, variable_names)`.
- `f_variable`: (optional) function to convert a variable to a string, of the form `(feature::UInt8, variable_names)`.
- `f_constant`: (optional) function to convert a constant to a string, of the form `(val, bracketed::Bool)`
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: (optional) what variables to print for each feature.
"""
Expand Down
12 changes: 6 additions & 6 deletions src/EquationUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,15 +46,15 @@ has_constants(tree::Node) = any(is_node_constant, tree)
Check if a tree has any operators.
"""
has_operators(tree::Node) = tree.degree !== 0
has_operators(tree::Node) = tree.degree != 0

"""
is_constant(tree::Node)::Bool
Check if an expression is a constant numerical value, or
whether it depends on input features.
"""
is_constant(tree::Node) = all(t -> t.degree !== 0 || t.constant, tree)
is_constant(tree::Node) = all(t -> t.degree != 0 || t.constant, tree)

"""
get_constants(tree::Node{T})::Vector{T} where {T}
Expand Down Expand Up @@ -92,25 +92,25 @@ end
# This will mirror a Node struct, rather
# than adding a new attribute to Node.
mutable struct NodeIndex
constant_index::Int # Index of this constant (if a constant exists here)
constant_index::UInt16 # Index of this constant (if a constant exists here)
l::NodeIndex
r::NodeIndex

NodeIndex() = new()
end

function index_constants(tree::Node)::NodeIndex
return index_constants(tree, 0)
return index_constants(tree, UInt16(0))
end

function index_constants(tree::Node, left_index::Int)::NodeIndex
function index_constants(tree::Node, left_index)::NodeIndex
index_tree = NodeIndex()
index_constants!(tree, index_tree, left_index)
return index_tree
end

# Count how many constants to the left of this node, and put them in a tree
function index_constants!(tree::Node, index_tree::NodeIndex, left_index::Int)
function index_constants!(tree::Node, index_tree::NodeIndex, left_index)
if tree.degree == 0
if tree.constant
index_tree.constant_index = left_index + 1
Expand Down
18 changes: 9 additions & 9 deletions src/EvaluateEquationDerivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function assert_autodiff_enabled(operators::OperatorEnum)
end

"""
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Int; turbo::Bool=false)
eval_diff_tree_array(tree::Node{T}, cX::AbstractMatrix{T}, operators::OperatorEnum, direction::Integer; turbo::Bool=false)
Compute the forward derivative of an expression, using a similar
structure and optimization to eval_tree_array. `direction` is the index of a particular
Expand All @@ -31,7 +31,7 @@ respect to `x1`.
- `cX::AbstractMatrix{T}`: The data matrix, with each column being a data point.
- `operators::OperatorEnum`: The operators used to create the `tree`. Note that `operators.enable_autodiff`
must be `true`. This is needed to create the derivative operations.
- `direction::Int`: The index of the variable to take the derivative with respect to.
- `direction::Integer`: The index of the variable to take the derivative with respect to.
- `turbo::Bool`: Use `LoopVectorization.@turbo` for faster evaluation.
# Returns
Expand All @@ -43,7 +43,7 @@ function eval_diff_tree_array(
tree::Node{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
direction::Int;
direction::Integer;
turbo::Bool=false,
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
assert_autodiff_enabled(operators)
Expand All @@ -57,7 +57,7 @@ function eval_diff_tree_array(
tree::Node{T1},
cX::AbstractMatrix{T2},
operators::OperatorEnum,
direction::Int;
direction::Integer;
turbo::Bool=false,
) where {T1<:Number,T2<:Number}
T = promote_type(T1, T2)
Expand All @@ -71,7 +71,7 @@ function _eval_diff_tree_array(
tree::Node{T},
cX::AbstractMatrix{T},
operators::OperatorEnum,
direction::Int,
direction::Integer,
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,turbo}
evaluation, derivative, complete = if tree.degree == 0
Expand Down Expand Up @@ -102,7 +102,7 @@ function _eval_diff_tree_array(
end

function diff_deg0_eval(
tree::Node{T}, cX::AbstractMatrix{T}, direction::Int
tree::Node{T}, cX::AbstractMatrix{T}, direction::Integer
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number}
const_part = deg0_eval(tree, cX)[1]
derivative_part = if ((!tree.constant) && tree.feature == direction)
Expand All @@ -119,7 +119,7 @@ function diff_deg1_eval(
op::F,
diff_op::dF,
operators::OperatorEnum,
direction::Int,
direction::Integer,
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
Expand All @@ -144,7 +144,7 @@ function diff_deg2_eval(
op::F,
diff_op::dF,
operators::OperatorEnum,
direction::Int,
direction::Integer,
::Val{turbo},
)::Tuple{AbstractVector{T},AbstractVector{T},Bool} where {T<:Number,F,dF,turbo}
(cumulator, dcumulator, complete) = _eval_diff_tree_array(
Expand Down Expand Up @@ -200,7 +200,7 @@ function eval_grad_tree_array(
)::Tuple{AbstractVector{T},AbstractMatrix{T},Bool} where {T<:Number}
assert_autodiff_enabled(operators)
n_gradients = variable ? size(cX, 1) : count_constants(tree)
index_tree = index_constants(tree, 0)
index_tree = index_constants(tree, UInt16(0))
return eval_grad_tree_array(
tree,
Val(n_gradients),
Expand Down
6 changes: 3 additions & 3 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import ..EvaluateEquationDerivativeModule: eval_grad_tree_array, _zygote_gradien
import ..EvaluationHelpersModule: _grad_evaluator

"""Used to set a default value for `operators` for ease of use."""
@enum AvailableOperatorTypes begin
@enum AvailableOperatorTypes::UInt8 begin
IsNothing
IsOperatorEnum
IsGenericOperatorEnum
Expand All @@ -19,8 +19,8 @@ end

const LATEST_OPERATORS = Ref{Union{Nothing,AbstractOperatorEnum}}(nothing)
const LATEST_OPERATORS_TYPE = Ref{AvailableOperatorTypes}(IsNothing)
const LATEST_UNARY_OPERATOR_MAPPING = Dict{Function,Int}()
const LATEST_BINARY_OPERATOR_MAPPING = Dict{Function,Int}()
const LATEST_UNARY_OPERATOR_MAPPING = Dict{Function,fieldtype(Node{Float64}, :op)}()
const LATEST_BINARY_OPERATOR_MAPPING = Dict{Function,fieldtype(Node{Float64}, :op)}()
const ALREADY_DEFINED_UNARY_OPERATORS = (;
operator_enum=Dict{Function,Bool}(), generic_operator_enum=Dict{Function,Bool}()
)
Expand Down
4 changes: 2 additions & 2 deletions src/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ julia> tree_mapreduce(t -> 1, (p, c...) -> p + max(c...), tree) # compute depth
5
julia> tree_mapreduce(vcat, tree) do t
t.degree == 2 ? [t.op] : Int[]
t.degree == 2 ? [t.op] : UInt8[]
end # Get list of binary operators used. (regular mapreduce also works)
2-element Vector{Int64}:
2-element Vector{UInt8}:
1
2
Expand Down
5 changes: 3 additions & 2 deletions test/test_base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,9 @@ end
@test length(unique(map(objectid, copy_node(tree; preserve_sharing=true)))) == 24 - 3
map(t -> (t.degree == 0 && t.constant) ? (t.val *= 2) : nothing, ctree)
@test sum(t -> t.val, filter(t -> t.degree == 0 && t.constant, ctree)) == 11.6 * 2
@test typeof(map(t -> t.degree, ctree, Int)) == Vector{Int}
@test first(map(t -> t.degree, ctree, Int)) == 2
local T = fieldtype(typeof(ctree), :degree)
@test typeof(map(t -> t.degree, ctree, T)) == Vector{T}
@test first(map(t -> t.degree, ctree, T)) == 2
end

@testset "in" begin
Expand Down

2 comments on commit cd1fc0c

@MilesCranmer
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/89512

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.0 -m "<description of version>" cd1fc0c66559608b96244285a3bc89e4f4e770ff
git push origin v0.13.0

Please sign in to comment.