diff --git a/Project.toml b/Project.toml index 1b6d68117..bcf074c9f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicRegression" uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" authors = ["MilesCranmer "] -version = "0.8.7" +version = "0.9.0" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" @@ -26,7 +26,7 @@ Optim = "0.19, 1.1" Pkg = "1" Reexport = "1" SpecialFunctions = "0.10.1, 1, 2" -SymbolicUtils = "0.6" +SymbolicUtils = "0.19" Zygote = "0.6" julia = "1.5" diff --git a/docs/src/api.md b/docs/src/api.md index e785e941d..eabb78f04 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,7 +66,6 @@ evalTreeArray(tree::Node, cX::AbstractMatrix{T}, options::Options) where {T<:Rea ```@docs node_to_symbolic(tree::Node, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, index_functions::Bool=false) ``` diff --git a/src/CustomSymbolicUtilsSimplification.jl b/src/CustomSymbolicUtilsSimplification.jl deleted file mode 100644 index 175f2610f..000000000 --- a/src/CustomSymbolicUtilsSimplification.jl +++ /dev/null @@ -1,174 +0,0 @@ -module CustomSymbolicUtilsSimplificationModule - -using SymbolicUtils -using SymbolicUtils: Chain, If, RestartedChain, IfElse, Postwalk, Fixpoint, @ordered_acrule, isnotflat, flatten_term, needs_sorting, sort_args, is_literal_number, hasrepeats, merge_repeats, _isone, _iszero, _isinteger, istree, symtype, is_operation, has_trig, polynormalize -import ..CoreModule: Options -import ..InterfaceSymbolicUtilsModule: SYMBOLIC_UTILS_TYPES -import ..UtilsModule: isgood, @return_on_false - -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} - return eqn, true -end - -function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F,T<:SymbolicUtils.Term{<:Number}} - args = SymbolicUtils.arguments(eqn) - nargs = length(args) - if nargs == 1 - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - return op(l), true - elseif op == ^ - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - n = args[2] - if typeof(n) <: Int - if n == 1 - return l, true - elseif n == -1 - return 1.0 / l, true - elseif n > 1 - return reduce(*, [l for i=1:n]), true - elseif n < -1 - return reduce(/, vcat([1], [l for i=1:abs(n)])), true - else - return 1.0, true - end - else - r, complete2 = multiply_powers(args[2]) - @return_on_false complete2 eqn - return l ^ r, true - end - elseif nargs == 2 - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - r, complete2 = multiply_powers(args[2]) - @return_on_false complete2 eqn - @return_on_false isgood(r) eqn - return op(l, r), true - else - # return mapreduce(multiply_powers, op, args) - # ## reduce(op, map(multiply_powers, args)) - out = map(multiply_powers, args) #vector of tuples - for i=1:size(out, 1) - @return_on_false out[i][2] eqn - @return_on_false isgood(out[i][1]) eqn - end - cumulator = out[1][1] - for i=2:size(out, 1) - cumulator = op(cumulator, out[i][1]) - @return_on_false isgood(cumulator) eqn - end - return cumulator, true - end -end - -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:SymbolicUtils.Term{<:Number}} - op = SymbolicUtils.operation(eqn) - return multiply_powers(eqn, op) -end - -# Operators required for each rule: -function get_simplifier(binops::A, unaops::B) where {A,B} - PLUS_RULES = [ - rule for (required_ops, rule) in [ - ((+,), @rule(~x::isnotflat(+) => flatten_term(+, ~x))), - ((+,), @rule(~x::needs_sorting(+) => sort_args(+, ~x))), - ((+,), @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b)), - ((+,), @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...))), - ((+,), @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...))), - ((+,), @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...))), - ((+,), @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x))), - ((+,), @acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x))), - ((+,), @rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...))), - ((+,), @ordered_acrule((~z::_iszero + ~x) => ~x)), - ((+,), @rule(+(~x) => ~x))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - TIMES_RULES = [ - rule for (required_ops, rule) in [ - ((*,), @rule(~x::isnotflat(*) => flatten_term(*, ~x))), - ((*,), @rule(~x::needs_sorting(*) => sort_args(*, ~x))), - ((*,), @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b)), - ((*,), @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...))), - ((*,), @acrule((~y)^(~n) * ~y => (~y)^(~n+1))), - ((*,), @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m))), - ((*,), @ordered_acrule((~z::_isone * ~x) => ~x)), - ((*,), @ordered_acrule((~z::_iszero * ~x) => ~z)), - ((*,), @rule(*(~x) => ~x))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - POW_RULES =[ - rule for (required_ops, rule) in [ - ((*,), @rule(^(*(~~x), ~y::_isinteger) => *(map(a->SymbolicUtils.pow(a, ~y), ~~x)...))), - ((*,), @rule((((~x)^(~p::_isinteger))^(~q::_isinteger)) => (~x)^((~p)*(~q)))), - ((*,), @rule(^(~x, ~z::_iszero) => 1)), - ((*,), @rule(^(~x, ~z::_isone) => ~x)), - ((*, /,), @rule(inv(~x) => ~x ^ -1))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - ASSORTED_RULES =[ - rule for (required_ops, rule) in [ - ((), @rule(identity(~x) => ~x)), - ((*,), @rule(-(~x) => -1*~x)), - ((*, -,), @rule(-(~x, ~y) => ~x + -1(~y))), - ((/, *,), @rule(~x / ~y => ~x * SymbolicUtils.pow(~y, -1))), - ((), @rule(one(~x) => one(symtype(~x)))), - ((), @rule(zero(~x) => zero(symtype(~x))))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - TRIG_RULES = [ - rule for (required_ops, rule) in [ - ((sin, cos, *, +,), @acrule(sin(~x)^2 + cos(~x)^2 => one(~x))), - ((sin, cos, *, +,), @acrule(sin(~x)^2 + -1 => cos(~x)^2)), - ((sin, cos, *, +,), @acrule(cos(~x)^2 + -1 => sin(~x)^2)), - ((tan, sec, *, +,), @acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x))), - ((tan, sec, *, +,), @acrule(tan(~x)^2 + 1 => sec(~x)^2)), - ((tan, sec, *, +,), @acrule(sec(~x)^2 + -1 => tan(~x)^2)), - ((cot, csc, *, +,), @acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x))), - ((cot, csc, *, +,), @acrule(cot(~x)^2 + 1 => csc(~x)^2)), - ((cot, csc, *, +,), @acrule(csc(~x)^2 + -1 => cot(~x)^2))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - function number_simplifier() - rule_tree = [If(istree, Chain(ASSORTED_RULES)), - If(is_operation(+), - Chain(PLUS_RULES)), - If(is_operation(*), - Chain(TIMES_RULES)), - If(is_operation(^), - Chain(POW_RULES))] |> RestartedChain - - rule_tree - end - trig_simplifier(;kw...) = Chain(TRIG_RULES) - function default_simplifier(; kw...) - IfElse(has_trig, - Postwalk(Chain((number_simplifier(), - trig_simplifier())), - ; kw...), - Postwalk(number_simplifier()) - ; kw...) - end - # reduce overhead of simplify by defining these as constant - serial_simplifier = If(istree, Fixpoint(default_simplifier())) - serial_polynormal_simplifier = If(istree, - Fixpoint(Chain((polynormalize, - Fixpoint(default_simplifier()))))) - return serial_polynormal_simplifier -end - -function custom_simplify(init_eqn::T, options::Options)::Tuple{SYMBOLIC_UTILS_TYPES, Bool} where {T<:SYMBOLIC_UTILS_TYPES} - if !istree(init_eqn) #simplifier will return nothing if not a tree. - return init_eqn, false - end - simplifier = get_simplifier(options.binops, options.unaops) - eqn = simplifier(init_eqn)::SYMBOLIC_UTILS_TYPES #simplify(eqn, polynorm=true) - - # Remove power laws - return multiply_powers(eqn::SYMBOLIC_UTILS_TYPES) -end - -end \ No newline at end of file diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index b9ecb3b06..a15802826 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -4,12 +4,106 @@ using SymbolicUtils import ..CoreModule: CONST_TYPE, Node, Options import ..UtilsModule: isgood, isbad, @return_on_false -const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Sym{<:Number},SymbolicUtils.Term{<:Number}} +const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}} + +const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /) + +isgood(x::SymbolicUtils.Symbolic) = SymbolicUtils.istree(x) ? all(isgood.([SymbolicUtils.operation(x);SymbolicUtils.arguments(x)])) : true +subs_bad(x) = isgood(x) ? x : Inf + +function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=false) + if tree.degree == 0 + # Return constant if needed + tree.constant && return subs_bad(tree.val) + return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)")) + end + # Collect the next children + children = tree.degree >= 2 ? (tree.l, tree.r) : (tree.l,) + # Get the operation + op = tree.degree > 1 ? options.binops[tree.op] : options.unaops[tree.op] + # Create an N tuple of Numbers for each argument + dtypes = map(x->Number, 1:tree.degree) + # + if !(op ∈ SUPPORTED_OPS) && index_functions + op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...}, Number}}(Symbol(op)) + end + + return subs_bad(op(map(x->parse_tree_to_eqs(x, options, index_functions), children)...)) +end + +# For operators which are indexed, we need to convert them back +# using the string: +function convert_to_function(x::SymbolicUtils.Sym{SymbolicUtils.FnType{T, Number}}, options::Options) where {T <: Tuple} + degree = length(T.types) + if degree == 1 + ind = findoperation(x.name, options.unaops) + return options.unaops[ind] + elseif degree == 2 + ind = findoperation(x.name, options.binops) + return options.binops[ind] + else + throw(AssertionError("Function $(String(x.name)) has degree > 2 !")) + end +end + +# For normal operators, simply return the function itself: +convert_to_function(x, options::Options) = x + + +# Split equation +function split_eq(op, args, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + !(op ∈ (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!")) + if Symbol(op) == Symbol(sum) + ind = findoperation(+, options.binops) + elseif Symbol(op) == Symbol(prod) + ind = findoperation(*, options.binops) + else + ind = findoperation(op, options.binops) + end + return Node(ind, convert(Node, args[1], options; varMap=varMap), convert(Node, op(args[2:end]...), options; varMap=varMap)) +end + +function findoperation(op, ops) + for (i,oi) in enumerate(ops) + Symbol(oi) == Symbol(op) && return i + end + throw(error("Operation $(op) in expression not found in operations $(ops)!")) +end + +function Base.convert(::typeof(SymbolicUtils.Symbolic), tree::Node, options::Options; + varMap::Union{Array{String, 1}, Nothing}=nothing, + index_functions::Bool=false) + node_to_symbolic(tree, options; varMap=varMap, index_functions=index_functions) +end + +function Base.convert(::typeof(Node), x::Number, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + return Node(CONST_TYPE(x)) +end + +function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + if !SymbolicUtils.istree(expr) + varMap === nothing && return Node(String(expr.name)) + return Node(String(expr.name), varMap) + end + + # First, we remove integer powers: + y, good_return = multiply_powers(expr) + if good_return + expr = y + end + + op = convert_to_function(SymbolicUtils.operation(expr), options) + args = SymbolicUtils.arguments(expr) + + length(args) > 2 && return split_eq(op, args, options; varMap=varMap) + ind = length(args) == 2 ? findoperation(op, options.binops) : findoperation(op, options.unaops) + + return Node(ind, map(x->convert(Node, x, options, varMap=varMap), args)...) +end """ node_to_symbolic(tree::Node, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, index_functions::Bool=false) The interface to SymbolicUtils.jl. Passing a tree to this function @@ -21,212 +115,101 @@ will generate a symbolic equation in SymbolicUtils.jl format. - `options::Options`: Options, which contains the operators used in the equation. - `varMap::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for each feature. Default is [x1, x2, x3, ...]. -- `evaluate_functions::Bool=false`: Whether to evaluate the operators, or - leave them as symbolic. - `index_functions::Bool=false`: Whether to generate special names for the operators, which then allows one to convert back to a `Node` format using `symbolic_to_node`. + (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84). """ -function node_to_symbolic(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=false - )::SYMBOLIC_UTILS_TYPES - if tree.degree == 0 - if tree.constant - return tree.val - else - if varMap == nothing - return SymbolicUtils.Sym{Real}(Symbol("x$(tree.feature)")) - else - return SymbolicUtils.Sym{Real}(Symbol(varMap[tree.feature])) - end - end - elseif tree.degree == 1 - left_side = node_to_symbolic(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - op = options.unaops[tree.op] - if (op in (cos, sin, exp, cot, tan, csc, sec)) || evaluate_functions - return op(left_side) - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol("_unaop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol(op)) - end - return dummy_op(left_side) - end - else - left_side = node_to_symbolic(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - right_side = node_to_symbolic(tree.r, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - op = options.binops[tree.op] - if (op in (+, -, *, /)) || evaluate_functions - return op(left_side, right_side) - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol("_binop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol(op)) - end - return dummy_op(left_side, right_side) - end - end -end - -function node_to_symbolic_safe(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=false - )::Tuple{SYMBOLIC_UTILS_TYPES,Bool} - if tree.degree == 0 - if tree.constant - return tree.val, true - else - if varMap == nothing - return SymbolicUtils.Sym{Real}(Symbol("x$(tree.feature)")), true - else - return SymbolicUtils.Sym{Real}(Symbol(varMap[tree.feature])), true - end - end - elseif tree.degree == 1 - left_side, complete = node_to_symbolic_safe(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - @return_on_false complete Inf - @return_on_false isgood(left_side) Inf - op = options.unaops[tree.op] - if (op in (cos, sin, exp, cot, tan, csc, sec)) || evaluate_functions - out = op(left_side) - @return_on_false isgood(out) Inf - return out, true - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol("_unaop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol(op)) - end - out = dummy_op(left_side) - #TODO: Can probably delete this check: - @return_on_false isgood(out) Inf - return out, true - end - else - left_side, complete = node_to_symbolic_safe(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - @return_on_false complete Inf - @return_on_false isgood(left_side) Inf - right_side, complete2 = node_to_symbolic_safe(tree.r, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - @return_on_false complete2 Inf - @return_on_false isgood(right_side) Inf - op = options.binops[tree.op] - if (op in (+, -, *, /)) || evaluate_functions - out = op(left_side, right_side) - @return_on_false isgood(out) Inf - return out, true - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol("_binop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol(op)) - end - out = dummy_op(left_side, right_side) - # Can probably delete this check TODO - @return_on_false isgood(out) Inf - return out, true - end - end +function node_to_symbolic(tree::Node, options::Options; + varMap::Union{Array{String, 1}, Nothing}=nothing, + index_functions::Bool=false) + expr = subs_bad(parse_tree_to_eqs(tree, options, index_functions)) + # Check for NaN and Inf + @assert isgood(expr) "The recovered equation contains NaN or Inf." + # Return if no varMap is given + varMap === nothing && return expr + # Create a substitution tuple + subs = Dict( + [SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) => SymbolicUtils.Sym{LiteralReal}(Symbol(varMap[i])) for i in 1:length(varMap)]... + ) + return substitute(expr, subs) end -# Just constant function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:Number} - return Node(convert(CONST_TYPE, eqn)) -end + varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Symbolic} -# Just variable -function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Sym{<:Number}} - return Node(varMap_to_index(eqn.name, varMap)) + convert(Node, eqn, options; varMap=varMap) end -function _multiarg_split(op_idx::Int, eqn::Array{Any, 1}, - options::Options, varMap::Union{Array{String, 1}, Nothing} - )::Node - if length(eqn) == 2 - return Node(op_idx, - symbolic_to_node(eqn[1], options, varMap=varMap), - symbolic_to_node(eqn[2], options, varMap=varMap)) - elseif length(eqn) == 3 - return Node(op_idx, - symbolic_to_node(eqn[1], options, varMap=varMap), - _multiarg_split(op_idx, eqn[2:3], options, varMap)) - else - # Minimize depth: - split_point = round(Int, length(eqn) // 2) - return Node(op_idx, - _multiarg_split(op_idx, eqn[1:split_point], options, varMap), - _multiarg_split(op_idx, eqn[split_point+1:end], options, varMap)) - end -end +# function Base.convert(::typeof(Node), x::Number, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) +# function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) -# Equation: -function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing - )::Node where {T<:SymbolicUtils.Term{<:Number}} - args = SymbolicUtils.arguments(eqn) - l = symbolic_to_node(args[1], options, varMap=varMap) - nargs = length(args) - op = SymbolicUtils.operation(eqn) - if nargs == 1 - op_idx = unaop_to_index(op, options) - return Node(op_idx, l) - else - op_idx = binop_to_index(op, options) - if nargs == 2 - r = symbolic_to_node(args[2], options, varMap=varMap) - return Node(op_idx, l, r) - else - # TODO: Assert operator is +, * - return _multiarg_split(op_idx, args, options, varMap) - end - end +function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} + return eqn, true end -function unaop_to_index(op::F, options::Options)::Int where {F<:SymbolicUtils.Sym} - # In format _unaop1 - parse(Int, string(op.name)[7:end]) +function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} + if !SymbolicUtils.istree(eqn) + return eqn, true + end + op = SymbolicUtils.operation(eqn) + return multiply_powers(eqn, op) end -function binop_to_index(op::F, options::Options)::Int where {F<:SymbolicUtils.Sym} - # In format _binop1 - parse(Int, string(op.name)[7:end]) -end -function unaop_to_index(op::F, options::Options)::Int where {F<:Function} - for i=1:options.nuna - if op == options.unaops[i] - return i +function multiply_powers(eqn::SymbolicUtils.Symbolic, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F} + args = SymbolicUtils.arguments(eqn) + nargs = length(args) + if nargs == 1 + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + return op(l), true + elseif op == ^ + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + n = args[2] + if typeof(n) <: Int + if n == 1 + return l, true + elseif n == -1 + return 1.0 / l, true + elseif n > 1 + return reduce(*, [l for i=1:n]), true + elseif n < -1 + return reduce(/, vcat([1], [l for i=1:abs(n)])), true + else + return 1.0, true + end + else + r, complete2 = multiply_powers(args[2]) + @return_on_false complete2 eqn + return l ^ r, true end - end - error("Operator $(op) in simplified expression not found in options $(options.unaops)!") -end - -function binop_to_index(op::F, options::Options)::Int where {F<:Function} - for i=1:options.nbin - if op == options.binops[i] - return i + elseif nargs == 2 + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + r, complete2 = multiply_powers(args[2]) + @return_on_false complete2 eqn + @return_on_false isgood(r) eqn + return op(l, r), true + else + # return mapreduce(multiply_powers, op, args) + # ## reduce(op, map(multiply_powers, args)) + out = map(multiply_powers, args) #vector of tuples + for i=1:size(out, 1) + @return_on_false out[i][2] eqn + @return_on_false isgood(out[i][1]) eqn end - end - error("Operator $(op) in simplified expression not found in options $(options.binops)!") -end - -function varMap_to_index(var::Symbol, varMap::Array{String, 1})::Int - str = string(var) - for i=1:length(varMap) - if str == varMap[i] - return i + cumulator = out[1][1] + for i=2:size(out, 1) + cumulator = op(cumulator, out[i][1]) + @return_on_false isgood(cumulator) eqn end - end -end - -function varMap_to_index(var::Symbol, varMap::Nothing)::Int - return parse(Int, string(var)[2:end]) + return cumulator, true + end end end diff --git a/src/Mutate.jl b/src/Mutate.jl index bc3669c06..c7b6c96c0 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -6,7 +6,7 @@ import ..LossFunctionsModule: scoreFunc, scoreFuncBatch import ..CheckConstraintsModule: check_constraints import ..PopMemberModule: PopMember import ..MutationFunctionsModule: genRandomTreeFixedSize, mutateConstant, mutateOperator, appendRandomOp, prependRandomOp, insertRandomOp, deleteRandomOp, crossoverTrees -import ..SimplifyEquationModule: simplifyTree, combineOperators, simplifyWithSymbolicUtils +import ..SimplifyEquationModule: simplifyTree, combineOperators import ..RecorderModule: @recorder # Go through one simulated options.annealing mutation cycle @@ -95,14 +95,7 @@ function nextGeneration(dataset::Dataset{T}, elseif mutationChoice < cweights[6] tree = simplifyTree(tree, options) # Sometimes we simplify tree tree = combineOperators(tree, options) # See if repeated constants at outer levels - # SymbolicUtils is quite slow, so only rarely - # do we use it for simplification. - if rand() < 0.01 && options.use_symbolic_utils - tree = simplifyWithSymbolicUtils(tree, options, curmaxsize) - @recorder tmp_recorder["type"] = "full_simplify" - else - @recorder tmp_recorder["type"] = "partial_simplify" - end + @recorder tmp_recorder["type"] = "partial_simplify" mutation_accepted = true return PopMember(tree, beforeScore, beforeLoss, parent=parent_ref), mutation_accepted diff --git a/src/Options.jl b/src/Options.jl index 01c3c1217..065051003 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -281,7 +281,6 @@ function Options(; probPickFirst=0.86f0, earlyStopCondition::Union{Function, AbstractFloat, Nothing}=nothing, stateReturn::Bool=false, - use_symbolic_utils::Bool=false, timeout_in_seconds=nothing, skip_mutation_failures::Bool=true, enable_autodiff::Bool=false, @@ -477,7 +476,7 @@ function Options(; earlyStopCondition = (loss, complexity) -> loss < stopping_point end - options = Options{typeof(binary_operators),typeof(unary_operators), typeof(diff_binary_operators), typeof(diff_unary_operators), typeof(loss)}(binary_operators, unary_operators, diff_binary_operators, diff_unary_operators, bin_constraints, una_constraints, ns, parsimony, alpha, maxsize, maxdepth, fast_cycle, migration, hofMigration, fractionReplacedHof, shouldOptimizeConstants, hofFile, npopulations, perturbationFactor, annealing, batching, batchSize, mutationWeights, crossoverProbability, warmupMaxsizeBy, useFrequency, useFrequencyInTournament, npop, ncyclesperiteration, fractionReplaced, topn, verbosity, probNegate, nuna, nbin, seed, loss, progress, terminal_width, optimizer_algorithm, optimize_probability, optimizer_nrestarts, optimizer_iterations, recorder, recorder_file, probPickFirst, earlyStopCondition, stateReturn, use_symbolic_utils, timeout_in_seconds, skip_mutation_failures, enable_autodiff, nested_constraints) + options = Options{typeof(binary_operators),typeof(unary_operators), typeof(diff_binary_operators), typeof(diff_unary_operators), typeof(loss)}(binary_operators, unary_operators, diff_binary_operators, diff_unary_operators, bin_constraints, una_constraints, ns, parsimony, alpha, maxsize, maxdepth, fast_cycle, migration, hofMigration, fractionReplacedHof, shouldOptimizeConstants, hofFile, npopulations, perturbationFactor, annealing, batching, batchSize, mutationWeights, crossoverProbability, warmupMaxsizeBy, useFrequency, useFrequencyInTournament, npop, ncyclesperiteration, fractionReplaced, topn, verbosity, probNegate, nuna, nbin, seed, loss, progress, terminal_width, optimizer_algorithm, optimize_probability, optimizer_nrestarts, optimizer_iterations, recorder, recorder_file, probPickFirst, earlyStopCondition, stateReturn, timeout_in_seconds, skip_mutation_failures, enable_autodiff, nested_constraints) @eval begin Base.print(io::IO, tree::Node) = print(io, stringTree(tree, $options)) diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index c04ffde04..9e711d79b 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -52,7 +52,6 @@ struct Options{A,B,dA,dB,C<:Union{SupervisedLoss,Function}} probPickFirst::Float32 earlyStopCondition::Union{Function, Nothing} stateReturn::Bool - use_symbolic_utils::Bool timeout_in_seconds::Union{Float64, Nothing} skip_mutation_failures::Bool enable_autodiff::Bool @@ -82,7 +81,7 @@ Base.print(io::IO, options::Options) = print(io, """Options( # Speed Tweaks: batching=$(options.batching), batchSize=$(options.batchSize), fast_cycle=$(options.fast_cycle), # Logistics: - hofFile=$(options.hofFile), verbosity=$(options.verbosity), seed=$(options.seed), progress=$(options.progress), use_symbolic_utils=$(options.use_symbolic_utils), + hofFile=$(options.hofFile), verbosity=$(options.verbosity), seed=$(options.seed), progress=$(options.progress), # Early Exit: earlyStopCondition=$(options.earlyStopCondition), timeout_in_seconds=$(options.timeout_in_seconds), )""") diff --git a/src/SimplifyEquation.jl b/src/SimplifyEquation.jl index 4145d5e01..e08b26697 100644 --- a/src/SimplifyEquation.jl +++ b/src/SimplifyEquation.jl @@ -2,8 +2,6 @@ module SimplifyEquationModule import ..CoreModule: CONST_TYPE, Node, copyNode, Options import ..EquationUtilsModule: countNodes -import ..CustomSymbolicUtilsSimplificationModule: custom_simplify -import ..InterfaceSymbolicUtilsModule: node_to_symbolic_safe, symbolic_to_node import ..CheckConstraintsModule: check_constraints import ..UtilsModule: isbad, isgood @@ -131,32 +129,4 @@ function simplifyTree(tree::Node, options::Options)::Node return tree end - -# Expensive but powerful simplify using SymbolicUtils -function simplifyWithSymbolicUtils(tree::Node, options::Options, curmaxsize::Int)::Node - if !(((+) in options.binops) && ((*) in options.binops)) - return tree - end - init_node = copyNode(tree) - init_size = countNodes(tree) - symbolic_util_form, complete = node_to_symbolic_safe(tree, options, index_functions=true) - if !complete - return init_node - end - eqn_form, complete2 = custom_simplify(symbolic_util_form, options) - if !complete2 - return init_node - end - final_node = symbolic_to_node(eqn_form, options) - final_size = countNodes(tree) - did_simplification_improve = (final_size <= init_size) && (check_constraints(final_node, options, curmaxsize)) - output = did_simplification_improve ? final_node : init_node - - return output -end - -function simplifyWithSymbolicUtils(tree::Node, options::Options)::Node - simplifyWithSymbolicUtils(tree, options, options.maxsize) -end - end diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index b2d78b5bb..1f1ceb61a 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -3,7 +3,7 @@ module SingleIterationModule import ..CoreModule: Options, Dataset, RecordType, stringTree import ..EquationUtilsModule: countNodes import ..UtilsModule: debug -import ..SimplifyEquationModule: simplifyTree, combineOperators, simplifyWithSymbolicUtils +import ..SimplifyEquationModule: simplifyTree, combineOperators import ..PopMemberModule: copyPopMember import ..PopulationModule: Population, finalizeScores, bestSubPop import ..HallOfFameModule: HallOfFame @@ -55,9 +55,6 @@ function OptimizeAndSimplifyPopulation( @inbounds @simd for j=1:pop.n pop.members[j].tree = simplifyTree(pop.members[j].tree, options) pop.members[j].tree = combineOperators(pop.members[j].tree, options) - if options.use_symbolic_utils - pop.members[j].tree = simplifyWithSymbolicUtils(pop.members[j].tree, options, curmaxsize) - end if rand() < options.optimize_probability && options.shouldOptimizeConstants pop.members[j] = optimizeConstants(dataset, baseline, pop.members[j], options) end diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 925e0e34e..afc40edaa 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -21,8 +21,6 @@ export Population, differentiableEvalTreeArray, node_to_symbolic, symbolic_to_node, - custom_simplify, - simplifyWithSymbolicUtils, combineOperators, genRandomTree, genRandomTreeFixedSize, @@ -76,7 +74,6 @@ include("ConstantOptimization.jl") include("Population.jl") include("HallOfFame.jl") include("InterfaceSymbolicUtils.jl") -include("CustomSymbolicUtilsSimplification.jl") include("SimplifyEquation.jl") include("Mutate.jl") include("RegularizedEvolution.jl") @@ -96,8 +93,7 @@ import .PopulationModule: Population, bestSubPop, record_population, bestOfSampl import .HallOfFameModule: HallOfFame, calculateParetoFrontier, string_dominating_pareto_curve import .SingleIterationModule: SRCycle, OptimizeAndSimplifyPopulation import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node -import .CustomSymbolicUtilsSimplificationModule: custom_simplify -import .SimplifyEquationModule: simplifyWithSymbolicUtils, combineOperators, simplifyTree +import .SimplifyEquationModule: combineOperators, simplifyTree import .ProgressBarsModule: ProgressBar, set_multiline_postfix import .RecorderModule: @recorder, find_iteration_from_record diff --git a/test/full.jl b/test/full.jl index 3e9786653..ac79219a6 100644 --- a/test/full.jl +++ b/test/full.jl @@ -107,7 +107,7 @@ for i=0:5 # Always assume multi for dom in dominating best = dom[end] - eqn = node_to_symbolic(best.tree, options, evaluate_functions=true) + eqn = node_to_symbolic(best.tree, options) local x4 = SymbolicUtils.Sym{Real}(Symbol("x4")) true_eqn = 2*cos(x4) @@ -145,7 +145,7 @@ dominating = calculateParetoFrontier(X, y, hallOfFame, options) best = dominating[end] eqn = node_to_symbolic(best.tree, options; - evaluate_functions=true, varMap=varMap) + varMap=varMap) t4 = SymbolicUtils.Sym{Real}(Symbol("t4")) true_eqn = 2*cos(t4) @@ -170,7 +170,7 @@ dominating = calculateParetoFrontier(X, y, hallOfFame, options) best = dominating[end] printTree(best.tree, options) eqn = node_to_symbolic(best.tree, options; - evaluate_functions=true, varMap=varMap) + varMap=varMap) residual = simplify(eqn - true_eqn) + t4 * 1e-10 @test best.loss < maximum_residual / 10 diff --git a/test/test_params.jl b/test/test_params.jl index 3acf29e9b..0af0c1721 100644 --- a/test/test_params.jl +++ b/test/test_params.jl @@ -48,7 +48,6 @@ default_params = ( probPickFirst=1.0, earlyStopCondition=nothing, stateReturn=false, - use_symbolic_utils=false, timeout_in_seconds=nothing, skip_mutation_failures=false, ) \ No newline at end of file diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 76cfbdb09..df1eca04a 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -1,27 +1,70 @@ include("test_params.jl") using SymbolicRegression, Test +import SymbolicUtils: simplify, Symbolic +import Random: MersenneTwister binary_operators = (+, -, /, *) index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1] -options = Options(; default_params..., binary_operators=binary_operators) +options = Options(binary_operators=binary_operators) tree = Node("x1") + Node("x1") # Should simplify to 2*x1: -eqn = node_to_symbolic(tree, options; index_functions=true) -eqn2 = custom_simplify(eqn, options) +eqn = convert(Symbolic, tree, options) +eqn2 = simplify(eqn) +# Should correctly simplify to 2 x1: +# (although it might use 2(x1^1)) +@test occursin("2", "$(repr(eqn2)[1])") -@test occursin("2", "$(eqn2[1])") +# Let's convert back the simplified version. +# This should remove the ^ operator: +tree = convert(Node, eqn2, options) +# Make sure one of the nodes is now 2.0: +@test (tree.l.constant ? tree.l : tree.r).val == 2 +# Make sure the other node is x1: +@test (!tree.l.constant ? tree.l : tree.r).feature == 1 -# Repeat test with simplifyWithSymbolicUtils: -simple_tree = simplifyWithSymbolicUtils(tree, options, 5) +# Finally, let's try converting a product, and ensure +# that SymbolicUtils does not convert it to a power: +tree = Node("x1") * Node("x1") +eqn = convert(Symbolic, tree, options) +@test repr(eqn) == "x1*x1" +# Test converting back: +tree_copy = convert(Node, eqn, options) +@test repr(tree_copy) == "(x1 * x1)" -# Check that the first operator is *, for 2 * x1: -@test simple_tree.op == index_of_mult +# Let's test a much more complex function, +# with custom operators, and unary operators: +x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") +pow_abs(x, y) = abs(x) ^ y +custom_cos(x) = cos(x)^2 -tree = Node("x1") * Node("x1") + Node("x1") * Node("x1") +# Define for Node (usually these are done internally to Options) +pow_abs(l::Node, r::Node)::Node = (l.constant && r.constant) ? Node(pow_abs(l.val, r.val)::AbstractFloat) : Node(5, l, r) +pow_abs(l::Node, r::AbstractFloat)::Node = l.constant ? Node(pow_abs(l.val, r)::AbstractFloat) : Node(5, l, r) +pow_abs(l::AbstractFloat, r::Node)::Node = r.constant ? Node(pow_abs(l, r.val)::AbstractFloat) : Node(5, l, r) +custom_cos(x::Node)::Node = x.constant ? Node(custom_cos(x.val)::AbstractFloat) : Node(1, x) -# Should not convert this to power: -@test !occursin("^", stringTree(simplifyWithSymbolicUtils(tree, options), options)) \ No newline at end of file +options = Options(; + binary_operators=(+, *, -, /, pow_abs), + unary_operators=(custom_cos, exp, sin), +) +tree = (((x2 + x2) * ((-0.5982493 / pow_abs(x1, x2)) / -0.54734415)) + (sin(custom_cos(sin(1.2926733 - 1.6606787) / sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426)) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2))) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262)))) +# We use `index_functions` to avoid converting the custom operators into the primitives. +eqn = convert(Symbolic, tree, options, index_functions=true) + +tree_copy = convert(Node, eqn, options) +tree_copy2 = convert(Node, simplify(eqn), options) +# Too difficult to check the representation, so we check by evaluation: +N = 100 +X = rand(MersenneTwister(0), 3, N) .+ 0.1 +output1, flag1 = evalTreeArray(tree, X, options) +output2, flag2 = evalTreeArray(tree_copy, X, options) +output3, flag3 = evalTreeArray(tree_copy2, X, options) + +@test isapprox(output1, output2, atol=1e-4 * sqrt(N)) +# Simplified equation may give a different answer due to rounding errors, +# so we weaken the requirement: +@test isapprox(output1, output3, atol=1e-2 * sqrt(N)) \ No newline at end of file