From b7d33d14d774281516df179902fc07fd835cdda3 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Fri, 8 Apr 2022 05:26:14 -0400 Subject: [PATCH 01/20] Bump SymbolicUtils.jl to 0.19 I'm having some trouble isolating the issue due to FromFile.jl eval. But it seems like it's: ```julia using Metatheory rhs = :(~(x^-1)) Metatheory.Syntax.makeconsequent(rhs) [9:50 PM] julia> Metatheory.Syntax.makeconsequent(rhs) ERROR: Error when parsing right hand side Stacktrace: [1] error(s::String) @ Base .\error.jl:33 [2] makeconsequent(expr::Expr) @ Metatheory.Syntax C:\Users\accou\.julia\dev\Metatheory\src\Syntax.jl:64 [3] top-level scope @ REPL[2]:1 ``` @shashi --- Project.toml | 2 +- src/CustomSymbolicUtilsSimplification.jl | 18 +++++++++++++++--- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 081371b3..4dd08f5f 100644 --- a/Project.toml +++ b/Project.toml @@ -26,7 +26,7 @@ Optim = "0.19, 1.1" Pkg = "1" Reexport = "1" SpecialFunctions = "0.10.1, 1, 2" -SymbolicUtils = "0.6" +SymbolicUtils = "0.19" julia = "1.5" [extras] diff --git a/src/CustomSymbolicUtilsSimplification.jl b/src/CustomSymbolicUtilsSimplification.jl index ffe6aa26..56d1b557 100644 --- a/src/CustomSymbolicUtilsSimplification.jl +++ b/src/CustomSymbolicUtilsSimplification.jl @@ -1,10 +1,22 @@ using FromFile using SymbolicUtils -using SymbolicUtils: Chain, If, RestartedChain, IfElse, Postwalk, Fixpoint, @ordered_acrule, isnotflat, flatten_term, needs_sorting, sort_args, is_literal_number, hasrepeats, merge_repeats, _isone, _iszero, _isinteger, istree, symtype, is_operation, has_trig, polynormalize +using SymbolicUtils: Chain, If, RestartedChain, IfElse, Postwalk, Fixpoint, @ordered_acrule, isnotflat, flatten_term, needs_sorting, sort_args, is_literal_number, hasrepeats, merge_repeats, _isone, _iszero, _isinteger, istree, symtype, is_operation, expand, operation, arguments @from "Core.jl" import Options @from "InterfaceSymbolicUtils.jl" import SYMBOLIC_UTILS_TYPES @from "Utils.jl" import isgood, @return_on_false +function has_trig(term) + !istree(term) && return false + fns = (sin, cos, tan, cot, sec, csc, exp) + op = operation(term) + + if Base.@nany 7 i -> fns[i] === op + return true + else + return any(has_trig, arguments(term)) + end +end + function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} return eqn, true end @@ -154,8 +166,8 @@ function get_simplifier(binops::A, unaops::B) where {A,B} # reduce overhead of simplify by defining these as constant serial_simplifier = If(istree, Fixpoint(default_simplifier())) serial_polynormal_simplifier = If(istree, - Fixpoint(Chain((polynormalize, - Fixpoint(default_simplifier()))))) + Fixpoint(Chain((expand, + Fixpoint(default_simplifier()))))) return serial_polynormal_simplifier end From df539aefd48571d6cf16b524fbdbf0c7829af944 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:20:14 -0400 Subject: [PATCH 02/20] Begin removal of SymbolicUtils-based simplification --- src/Mutate.jl | 9 +-------- src/Options.jl | 3 +-- src/OptionsStruct.jl | 1 - src/SingleIteration.jl | 3 --- 4 files changed, 2 insertions(+), 14 deletions(-) diff --git a/src/Mutate.jl b/src/Mutate.jl index 59e1568b..8b06ce60 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -94,14 +94,7 @@ function nextGeneration(dataset::Dataset{T}, elseif mutationChoice < cweights[6] tree = simplifyTree(tree, options) # Sometimes we simplify tree tree = combineOperators(tree, options) # See if repeated constants at outer levels - # SymbolicUtils is quite slow, so only rarely - # do we use it for simplification. - if rand() < 0.01 && options.use_symbolic_utils - tree = simplifyWithSymbolicUtils(tree, options, curmaxsize) - @recorder tmp_recorder["type"] = "full_simplify" - else - @recorder tmp_recorder["type"] = "partial_simplify" - end + @recorder tmp_recorder["type"] = "partial_simplify" mutation_accepted = true return PopMember(tree, beforeScore, beforeLoss, parent=parent_ref), mutation_accepted diff --git a/src/Options.jl b/src/Options.jl index 5b47d429..c8addd75 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -269,7 +269,6 @@ function Options(; probPickFirst=0.86f0, earlyStopCondition::Union{Function, Float32, Nothing}=nothing, stateReturn::Bool=false, - use_symbolic_utils::Bool=false, timeout_in_seconds=nothing, skip_mutation_failures::Bool=true, ) where {nuna,nbin} @@ -361,7 +360,7 @@ function Options(; earlyStopCondition = (loss, complexity) -> loss < earlyStopCondition end - options = Options{typeof(binary_operators),typeof(unary_operators), typeof(loss)}(binary_operators, unary_operators, bin_constraints, una_constraints, ns, parsimony, alpha, maxsize, maxdepth, fast_cycle, migration, hofMigration, fractionReplacedHof, shouldOptimizeConstants, hofFile, npopulations, perturbationFactor, annealing, batching, batchSize, mutationWeights, crossoverProbability, warmupMaxsizeBy, useFrequency, useFrequencyInTournament, npop, ncyclesperiteration, fractionReplaced, topn, verbosity, probNegate, nuna, nbin, seed, loss, progress, terminal_width, optimizer_algorithm, optimize_probability, optimizer_nrestarts, optimizer_iterations, recorder, recorder_file, probPickFirst, earlyStopCondition, stateReturn, use_symbolic_utils, timeout_in_seconds, skip_mutation_failures) + options = Options{typeof(binary_operators),typeof(unary_operators), typeof(loss)}(binary_operators, unary_operators, bin_constraints, una_constraints, ns, parsimony, alpha, maxsize, maxdepth, fast_cycle, migration, hofMigration, fractionReplacedHof, shouldOptimizeConstants, hofFile, npopulations, perturbationFactor, annealing, batching, batchSize, mutationWeights, crossoverProbability, warmupMaxsizeBy, useFrequency, useFrequencyInTournament, npop, ncyclesperiteration, fractionReplaced, topn, verbosity, probNegate, nuna, nbin, seed, loss, progress, terminal_width, optimizer_algorithm, optimize_probability, optimizer_nrestarts, optimizer_iterations, recorder, recorder_file, probPickFirst, earlyStopCondition, stateReturn, timeout_in_seconds, skip_mutation_failures) @eval begin Base.print(io::IO, tree::Node) = print(io, stringTree(tree, $options)) diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 53ff1e8e..0d39baf1 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -48,7 +48,6 @@ struct Options{A,B,C<:Union{SupervisedLoss,Function}} probPickFirst::Float32 earlyStopCondition::Union{Function, Nothing} stateReturn::Bool - use_symbolic_utils::Bool timeout_in_seconds::Union{Float64, Nothing} skip_mutation_failures::Bool diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index ff58aca0..cc7abd03 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -54,9 +54,6 @@ function OptimizeAndSimplifyPopulation( @inbounds @simd for j=1:pop.n pop.members[j].tree = simplifyTree(pop.members[j].tree, options) pop.members[j].tree = combineOperators(pop.members[j].tree, options) - if options.use_symbolic_utils - pop.members[j].tree = simplifyWithSymbolicUtils(pop.members[j].tree, options, curmaxsize) - end if rand() < options.optimize_probability && options.shouldOptimizeConstants pop.members[j] = optimizeConstants(dataset, baseline, pop.members[j], options) end From 74debe18c420126a92ec91616d9a92db0e9be925 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:23:34 -0400 Subject: [PATCH 03/20] Remove remaining use of use_symbolic_utils --- src/OptionsStruct.jl | 2 +- test/test_params.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index 957d9853..9e711d79 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -81,7 +81,7 @@ Base.print(io::IO, options::Options) = print(io, """Options( # Speed Tweaks: batching=$(options.batching), batchSize=$(options.batchSize), fast_cycle=$(options.fast_cycle), # Logistics: - hofFile=$(options.hofFile), verbosity=$(options.verbosity), seed=$(options.seed), progress=$(options.progress), use_symbolic_utils=$(options.use_symbolic_utils), + hofFile=$(options.hofFile), verbosity=$(options.verbosity), seed=$(options.seed), progress=$(options.progress), # Early Exit: earlyStopCondition=$(options.earlyStopCondition), timeout_in_seconds=$(options.timeout_in_seconds), )""") diff --git a/test/test_params.jl b/test/test_params.jl index 3acf29e9..0af0c172 100644 --- a/test/test_params.jl +++ b/test/test_params.jl @@ -48,7 +48,6 @@ default_params = ( probPickFirst=1.0, earlyStopCondition=nothing, stateReturn=false, - use_symbolic_utils=false, timeout_in_seconds=nothing, skip_mutation_failures=false, ) \ No newline at end of file From e6258ff36959de996cbe4ebc91436ffa59d86451 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:33:46 -0400 Subject: [PATCH 04/20] Remove entire SymbolicUtils simplification interface --- src/CustomSymbolicUtilsSimplification.jl | 186 ----------------------- src/InterfaceSymbolicUtils.jl | 29 ++-- src/Mutate.jl | 2 +- src/SimplifyEquation.jl | 29 ---- src/SingleIteration.jl | 2 +- src/SymbolicRegression.jl | 6 +- test/test_simplification.jl | 17 +-- 7 files changed, 19 insertions(+), 252 deletions(-) delete mode 100644 src/CustomSymbolicUtilsSimplification.jl diff --git a/src/CustomSymbolicUtilsSimplification.jl b/src/CustomSymbolicUtilsSimplification.jl deleted file mode 100644 index b213e299..00000000 --- a/src/CustomSymbolicUtilsSimplification.jl +++ /dev/null @@ -1,186 +0,0 @@ -module CustomSymbolicUtilsSimplificationModule - -using SymbolicUtils -using SymbolicUtils: Chain, If, RestartedChain, IfElse, Postwalk, Fixpoint, @ordered_acrule, isnotflat, flatten_term, needs_sorting, sort_args, is_literal_number, hasrepeats, merge_repeats, _isone, _iszero, _isinteger, istree, symtype, is_operation, expand, operation, arguments -import ..CoreModule: Options -import ..InterfaceSymbolicUtilsModule: SYMBOLIC_UTILS_TYPES -import ..UtilsModule: isgood, @return_on_false - -function has_trig(term) - !istree(term) && return false - fns = (sin, cos, tan, cot, sec, csc, exp) - op = operation(term) - - if Base.@nany 7 i -> fns[i] === op - return true - else - return any(has_trig, arguments(term)) - end -end - -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} - return eqn, true -end - -function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F,T<:SymbolicUtils.Term{<:Number}} - args = SymbolicUtils.arguments(eqn) - nargs = length(args) - if nargs == 1 - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - return op(l), true - elseif op == ^ - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - n = args[2] - if typeof(n) <: Int - if n == 1 - return l, true - elseif n == -1 - return 1.0 / l, true - elseif n > 1 - return reduce(*, [l for i=1:n]), true - elseif n < -1 - return reduce(/, vcat([1], [l for i=1:abs(n)])), true - else - return 1.0, true - end - else - r, complete2 = multiply_powers(args[2]) - @return_on_false complete2 eqn - return l ^ r, true - end - elseif nargs == 2 - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - r, complete2 = multiply_powers(args[2]) - @return_on_false complete2 eqn - @return_on_false isgood(r) eqn - return op(l, r), true - else - # return mapreduce(multiply_powers, op, args) - # ## reduce(op, map(multiply_powers, args)) - out = map(multiply_powers, args) #vector of tuples - for i=1:size(out, 1) - @return_on_false out[i][2] eqn - @return_on_false isgood(out[i][1]) eqn - end - cumulator = out[1][1] - for i=2:size(out, 1) - cumulator = op(cumulator, out[i][1]) - @return_on_false isgood(cumulator) eqn - end - return cumulator, true - end -end - -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:SymbolicUtils.Term{<:Number}} - op = SymbolicUtils.operation(eqn) - return multiply_powers(eqn, op) -end - -# Operators required for each rule: -function get_simplifier(binops::A, unaops::B) where {A,B} - PLUS_RULES = [ - rule for (required_ops, rule) in [ - ((+,), @rule(~x::isnotflat(+) => flatten_term(+, ~x))), - ((+,), @rule(~x::needs_sorting(+) => sort_args(+, ~x))), - ((+,), @ordered_acrule(~a::is_literal_number + ~b::is_literal_number => ~a + ~b)), - ((+,), @acrule(*(~~x) + *(~β, ~~x) => *(1 + ~β, (~~x)...))), - ((+,), @acrule(*(~α, ~~x) + *(~β, ~~x) => *(~α + ~β, (~~x)...))), - ((+,), @acrule(*(~~x, ~α) + *(~~x, ~β) => *(~α + ~β, (~~x)...))), - ((+,), @acrule(~x + *(~β, ~x) => *(1 + ~β, ~x))), - ((+,), @acrule(*(~α::is_literal_number, ~x) + ~x => *(~α + 1, ~x))), - ((+,), @rule(+(~~x::hasrepeats) => +(merge_repeats(*, ~~x)...))), - ((+,), @ordered_acrule((~z::_iszero + ~x) => ~x)), - ((+,), @rule(+(~x) => ~x))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - TIMES_RULES = [ - rule for (required_ops, rule) in [ - ((*,), @rule(~x::isnotflat(*) => flatten_term(*, ~x))), - ((*,), @rule(~x::needs_sorting(*) => sort_args(*, ~x))), - ((*,), @ordered_acrule(~a::is_literal_number * ~b::is_literal_number => ~a * ~b)), - ((*,), @rule(*(~~x::hasrepeats) => *(merge_repeats(^, ~~x)...))), - ((*,), @acrule((~y)^(~n) * ~y => (~y)^(~n+1))), - ((*,), @ordered_acrule((~x)^(~n) * (~x)^(~m) => (~x)^(~n + ~m))), - ((*,), @ordered_acrule((~z::_isone * ~x) => ~x)), - ((*,), @ordered_acrule((~z::_iszero * ~x) => ~z)), - ((*,), @rule(*(~x) => ~x))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - POW_RULES =[ - rule for (required_ops, rule) in [ - ((*,), @rule(^(*(~~x), ~y::_isinteger) => *(map(a->SymbolicUtils.pow(a, ~y), ~~x)...))), - ((*,), @rule((((~x)^(~p::_isinteger))^(~q::_isinteger)) => (~x)^((~p)*(~q)))), - ((*,), @rule(^(~x, ~z::_iszero) => 1)), - ((*,), @rule(^(~x, ~z::_isone) => ~x)), - ((*, /,), @rule(inv(~x) => ~x ^ -1))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - ASSORTED_RULES =[ - rule for (required_ops, rule) in [ - ((), @rule(identity(~x) => ~x)), - ((*,), @rule(-(~x) => -1*~x)), - ((*, -,), @rule(-(~x, ~y) => ~x + -1(~y))), - ((/, *,), @rule(~x / ~y => ~x * SymbolicUtils.pow(~y, -1))), - ((), @rule(one(~x) => one(symtype(~x)))), - ((), @rule(zero(~x) => zero(symtype(~x))))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - TRIG_RULES = [ - rule for (required_ops, rule) in [ - ((sin, cos, *, +,), @acrule(sin(~x)^2 + cos(~x)^2 => one(~x))), - ((sin, cos, *, +,), @acrule(sin(~x)^2 + -1 => cos(~x)^2)), - ((sin, cos, *, +,), @acrule(cos(~x)^2 + -1 => sin(~x)^2)), - ((tan, sec, *, +,), @acrule(tan(~x)^2 + -1*sec(~x)^2 => one(~x))), - ((tan, sec, *, +,), @acrule(tan(~x)^2 + 1 => sec(~x)^2)), - ((tan, sec, *, +,), @acrule(sec(~x)^2 + -1 => tan(~x)^2)), - ((cot, csc, *, +,), @acrule(cot(~x)^2 + -1*csc(~x)^2 => one(~x))), - ((cot, csc, *, +,), @acrule(cot(~x)^2 + 1 => csc(~x)^2)), - ((cot, csc, *, +,), @acrule(csc(~x)^2 + -1 => cot(~x)^2))] - if all([(op in binops || op in unaops) for op in required_ops]) - ] - function number_simplifier() - rule_tree = [If(istree, Chain(ASSORTED_RULES)), - If(is_operation(+), - Chain(PLUS_RULES)), - If(is_operation(*), - Chain(TIMES_RULES)), - If(is_operation(^), - Chain(POW_RULES))] |> RestartedChain - - rule_tree - end - trig_simplifier(;kw...) = Chain(TRIG_RULES) - function default_simplifier(; kw...) - IfElse(has_trig, - Postwalk(Chain((number_simplifier(), - trig_simplifier())), - ; kw...), - Postwalk(number_simplifier()) - ; kw...) - end - # reduce overhead of simplify by defining these as constant - serial_simplifier = If(istree, Fixpoint(default_simplifier())) - serial_polynormal_simplifier = If(istree, - Fixpoint(Chain((expand, - Fixpoint(default_simplifier()))))) - return serial_polynormal_simplifier -end - -function custom_simplify(init_eqn::T, options::Options)::Tuple{SYMBOLIC_UTILS_TYPES, Bool} where {T<:SYMBOLIC_UTILS_TYPES} - if !istree(init_eqn) #simplifier will return nothing if not a tree. - return init_eqn, false - end - simplifier = get_simplifier(options.binops, options.unaops) - eqn = simplifier(init_eqn)::SYMBOLIC_UTILS_TYPES #simplify(eqn, polynorm=true) - - # Remove power laws - return multiply_powers(eqn::SYMBOLIC_UTILS_TYPES) -end - -end \ No newline at end of file diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index b9ecb3b0..0f87098b 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -4,7 +4,6 @@ using SymbolicUtils import ..CoreModule: CONST_TYPE, Node, Options import ..UtilsModule: isgood, isbad, @return_on_false -const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Sym{<:Number},SymbolicUtils.Term{<:Number}} """ node_to_symbolic(tree::Node, options::Options; @@ -28,15 +27,14 @@ will generate a symbolic equation in SymbolicUtils.jl format. using `symbolic_to_node`. """ function node_to_symbolic(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=false - )::SYMBOLIC_UTILS_TYPES + varMap::Union{Array{String, 1}, Nothing}=nothing, + evaluate_functions::Bool=false, + index_functions::Bool=false) if tree.degree == 0 if tree.constant return tree.val else - if varMap == nothing + if varMap === nothing return SymbolicUtils.Sym{Real}(Symbol("x$(tree.feature)")) else return SymbolicUtils.Sym{Real}(Symbol(varMap[tree.feature])) @@ -73,15 +71,14 @@ function node_to_symbolic(tree::Node, options::Options; end function node_to_symbolic_safe(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=false - )::Tuple{SYMBOLIC_UTILS_TYPES,Bool} + varMap::Union{Array{String, 1}, Nothing}=nothing, + evaluate_functions::Bool=false, + index_functions::Bool=false) if tree.degree == 0 if tree.constant return tree.val, true else - if varMap == nothing + if varMap === nothing return SymbolicUtils.Sym{Real}(Symbol("x$(tree.feature)")), true else return SymbolicUtils.Sym{Real}(Symbol(varMap[tree.feature])), true @@ -135,19 +132,18 @@ end # Just constant function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:Number} + varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:Number} return Node(convert(CONST_TYPE, eqn)) end # Just variable function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Sym{<:Number}} + varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Sym{<:Number}} return Node(varMap_to_index(eqn.name, varMap)) end function _multiarg_split(op_idx::Int, eqn::Array{Any, 1}, - options::Options, varMap::Union{Array{String, 1}, Nothing} - )::Node + options::Options, varMap::Union{Array{String, 1}, Nothing})::Node if length(eqn) == 2 return Node(op_idx, symbolic_to_node(eqn[1], options, varMap=varMap), @@ -167,8 +163,7 @@ end # Equation: function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing - )::Node where {T<:SymbolicUtils.Term{<:Number}} + varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Term{<:Number}} args = SymbolicUtils.arguments(eqn) l = symbolic_to_node(args[1], options, varMap=varMap) nargs = length(args) diff --git a/src/Mutate.jl b/src/Mutate.jl index c90e6643..c7b6c96c 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -6,7 +6,7 @@ import ..LossFunctionsModule: scoreFunc, scoreFuncBatch import ..CheckConstraintsModule: check_constraints import ..PopMemberModule: PopMember import ..MutationFunctionsModule: genRandomTreeFixedSize, mutateConstant, mutateOperator, appendRandomOp, prependRandomOp, insertRandomOp, deleteRandomOp, crossoverTrees -import ..SimplifyEquationModule: simplifyTree, combineOperators, simplifyWithSymbolicUtils +import ..SimplifyEquationModule: simplifyTree, combineOperators import ..RecorderModule: @recorder # Go through one simulated options.annealing mutation cycle diff --git a/src/SimplifyEquation.jl b/src/SimplifyEquation.jl index 4145d5e0..3f75c980 100644 --- a/src/SimplifyEquation.jl +++ b/src/SimplifyEquation.jl @@ -2,7 +2,6 @@ module SimplifyEquationModule import ..CoreModule: CONST_TYPE, Node, copyNode, Options import ..EquationUtilsModule: countNodes -import ..CustomSymbolicUtilsSimplificationModule: custom_simplify import ..InterfaceSymbolicUtilsModule: node_to_symbolic_safe, symbolic_to_node import ..CheckConstraintsModule: check_constraints import ..UtilsModule: isbad, isgood @@ -131,32 +130,4 @@ function simplifyTree(tree::Node, options::Options)::Node return tree end - -# Expensive but powerful simplify using SymbolicUtils -function simplifyWithSymbolicUtils(tree::Node, options::Options, curmaxsize::Int)::Node - if !(((+) in options.binops) && ((*) in options.binops)) - return tree - end - init_node = copyNode(tree) - init_size = countNodes(tree) - symbolic_util_form, complete = node_to_symbolic_safe(tree, options, index_functions=true) - if !complete - return init_node - end - eqn_form, complete2 = custom_simplify(symbolic_util_form, options) - if !complete2 - return init_node - end - final_node = symbolic_to_node(eqn_form, options) - final_size = countNodes(tree) - did_simplification_improve = (final_size <= init_size) && (check_constraints(final_node, options, curmaxsize)) - output = did_simplification_improve ? final_node : init_node - - return output -end - -function simplifyWithSymbolicUtils(tree::Node, options::Options)::Node - simplifyWithSymbolicUtils(tree, options, options.maxsize) -end - end diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 7a4f36c7..1f1ceb61 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -3,7 +3,7 @@ module SingleIterationModule import ..CoreModule: Options, Dataset, RecordType, stringTree import ..EquationUtilsModule: countNodes import ..UtilsModule: debug -import ..SimplifyEquationModule: simplifyTree, combineOperators, simplifyWithSymbolicUtils +import ..SimplifyEquationModule: simplifyTree, combineOperators import ..PopMemberModule: copyPopMember import ..PopulationModule: Population, finalizeScores, bestSubPop import ..HallOfFameModule: HallOfFame diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index b4d3babe..2ec1c508 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -21,8 +21,6 @@ export Population, differentiableEvalTreeArray, node_to_symbolic, symbolic_to_node, - custom_simplify, - simplifyWithSymbolicUtils, combineOperators, genRandomTree, genRandomTreeFixedSize, @@ -76,7 +74,6 @@ include("ConstantOptimization.jl") include("Population.jl") include("HallOfFame.jl") include("InterfaceSymbolicUtils.jl") -include("CustomSymbolicUtilsSimplification.jl") include("SimplifyEquation.jl") include("Mutate.jl") include("RegularizedEvolution.jl") @@ -96,8 +93,7 @@ import .PopulationModule: Population, bestSubPop, record_population, bestOfSampl import .HallOfFameModule: HallOfFame, calculateParetoFrontier, string_dominating_pareto_curve import .SingleIterationModule: SRCycle, OptimizeAndSimplifyPopulation import .InterfaceSymbolicUtilsModule: node_to_symbolic, symbolic_to_node -import .CustomSymbolicUtilsSimplificationModule: custom_simplify -import .SimplifyEquationModule: simplifyWithSymbolicUtils, combineOperators, simplifyTree +import .SimplifyEquationModule: combineOperators, simplifyTree import .ProgressBarsModule: ProgressBar, set_multiline_postfix import .RecorderModule: @recorder, find_iteration_from_record diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 76cfbdb0..76744637 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -10,18 +10,9 @@ options = Options(; default_params..., binary_operators=binary_operators) tree = Node("x1") + Node("x1") # Should simplify to 2*x1: -eqn = node_to_symbolic(tree, options; index_functions=true) -eqn2 = custom_simplify(eqn, options) - -@test occursin("2", "$(eqn2[1])") - -# Repeat test with simplifyWithSymbolicUtils: -simple_tree = simplifyWithSymbolicUtils(tree, options, 5) +import SymbolicUtils: simplify -# Check that the first operator is *, for 2 * x1: -@test simple_tree.op == index_of_mult - -tree = Node("x1") * Node("x1") + Node("x1") * Node("x1") +eqn = node_to_symbolic(tree, options; index_functions=true) +eqn2 = simplify(eqn, options) -# Should not convert this to power: -@test !occursin("^", stringTree(simplifyWithSymbolicUtils(tree, options), options)) \ No newline at end of file +@test occursin("2", "$(eqn2[1])") \ No newline at end of file From 51b7da4d95e34f87c663f9ac6d0834e6cb23f541 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:38:13 -0400 Subject: [PATCH 05/20] Revert InterfaceSymbolicUtils.jl to post-0.6 SymbolicUtils.jl --- src/InterfaceSymbolicUtils.jl | 346 ++++++++++++++++------------------ 1 file changed, 164 insertions(+), 182 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 0f87098b..33f90418 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -4,6 +4,98 @@ using SymbolicUtils import ..CoreModule: CONST_TYPE, Node, Options import ..UtilsModule: isgood, isbad, @return_on_false +const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}} + +const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /) + +isgood(x::SymbolicUtils.Symbolic) = SymbolicUtils.istree(x) ? all(isgood.([SymbolicUtils.operation(x);SymbolicUtils.arguments(x)])) : true +subs_bad(x) = isgood(x) ? x : Inf + +function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=false, evaluate_functions::Bool=false) + if tree.degree == 0 + # Return constant if needed + tree.constant && return subs_bad(tree.val) + return SymbolicUtils.Sym{Number}(Symbol("x$(tree.feature)")) + end + # Collect the next children + children = tree.degree >= 2 ? (tree.l, tree.r) : (tree.l,) + # Get the operation + op = tree.degree > 1 ? options.binops[tree.op] : options.unaops[tree.op] + # Create an N tuple of Numbers for each argument + dtypes = map(x->Number, 1:tree.degree) + # + if !(op ∈ SUPPORTED_OPS) && index_functions + op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...}, Number}}(Symbol(op)) + end + + return subs_bad(op(map(x->parse_tree_to_eqs(x, options, index_functions, evaluate_functions), children)...)) +end + +## Convert symbolic function back +convert_to_function(x, args...) = x + +function convert_to_function(x::SymbolicUtils.Sym{SymbolicUtils.FnType{T, Number}}, options::Options) where {T <: Tuple} + degree = length(T.types) + if degree == 1 + ind = findoperation(x.name, options.unaops) + return options.unaops[ind] + elseif degree == 2 + ind = findoperation(x.name, options.binops) + return options.binops[ind] + else + throw(AssertionError("Function $(String(x.name)) has degree > 2 !")) + end +end + +# Split equation +function split_eq(op, args, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + !(op ∈ (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!")) + if Symbol(op) == Symbol(sum) + ind = findoperation(+, options.binops) + elseif Symbol(op) == Symbol(prod) + ind = findoperation(*, options.binops) + else + ind = findoperation(op, options.binops) + end + return Node(ind, convert(Node, args[1], options; varMap=varMap), convert(Node, op(args[2:end]...), options; varMap=varMap)) +end + +function findoperation(op, ops) + for (i,oi) in enumerate(ops) + Symbol(oi) == Symbol(op) && return i + end + throw(error("Operation $(op) in expression not found in operations $(ops)!")) +end + +function Base.convert(::typeof(Node), x::Number, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + return Node(CONST_TYPE(x)) +end + +function Base.convert(::typeof(Node), x::Symbol, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + varMap === nothing && return Node(String(x)) + return Node(String(x), varMap) +end + +function Base.convert(::typeof(Node), x::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + if !SymbolicUtils.istree(x) + varMap === nothing && return Node(String(x.name)) + return Node(String(x.name), varMap) + end + + # First, we remove integer powers: + y, good_return = multiply_powers(x) + if good_return + x = y + end + + op = convert_to_function(SymbolicUtils.operation(x)) + args = SymbolicUtils.arguments(x) + + length(args) > 2 && return split_eq(op, args, options; varMap=varMap) + ind = length(args) == 2 ? findoperation(op, options.binops) : findoperation(op, options.unaops) + + return Node(ind, map(x->convert(Node, x, options, varMap=varMap), args)...) +end """ node_to_symbolic(tree::Node, options::Options; @@ -25,203 +117,93 @@ will generate a symbolic equation in SymbolicUtils.jl format. - `index_functions::Bool=false`: Whether to generate special names for the operators, which then allows one to convert back to a `Node` format using `symbolic_to_node`. + (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84). """ -function node_to_symbolic(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=false) - if tree.degree == 0 - if tree.constant - return tree.val - else - if varMap === nothing - return SymbolicUtils.Sym{Real}(Symbol("x$(tree.feature)")) - else - return SymbolicUtils.Sym{Real}(Symbol(varMap[tree.feature])) - end - end - elseif tree.degree == 1 - left_side = node_to_symbolic(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - op = options.unaops[tree.op] - if (op in (cos, sin, exp, cot, tan, csc, sec)) || evaluate_functions - return op(left_side) - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol("_unaop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol(op)) - end - return dummy_op(left_side) - end - else - left_side = node_to_symbolic(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - right_side = node_to_symbolic(tree.r, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - op = options.binops[tree.op] - if (op in (+, -, *, /)) || evaluate_functions - return op(left_side, right_side) - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol("_binop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol(op)) - end - return dummy_op(left_side, right_side) - end - end -end - -function node_to_symbolic_safe(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=false) - if tree.degree == 0 - if tree.constant - return tree.val, true - else - if varMap === nothing - return SymbolicUtils.Sym{Real}(Symbol("x$(tree.feature)")), true - else - return SymbolicUtils.Sym{Real}(Symbol(varMap[tree.feature])), true - end - end - elseif tree.degree == 1 - left_side, complete = node_to_symbolic_safe(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - @return_on_false complete Inf - @return_on_false isgood(left_side) Inf - op = options.unaops[tree.op] - if (op in (cos, sin, exp, cot, tan, csc, sec)) || evaluate_functions - out = op(left_side) - @return_on_false isgood(out) Inf - return out, true - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol("_unaop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number}, Real}}(Symbol(op)) - end - out = dummy_op(left_side) - #TODO: Can probably delete this check: - @return_on_false isgood(out) Inf - return out, true - end - else - left_side, complete = node_to_symbolic_safe(tree.l, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - @return_on_false complete Inf - @return_on_false isgood(left_side) Inf - right_side, complete2 = node_to_symbolic_safe(tree.r, options, varMap=varMap, evaluate_functions=evaluate_functions, index_functions=index_functions) - @return_on_false complete2 Inf - @return_on_false isgood(right_side) Inf - op = options.binops[tree.op] - if (op in (+, -, *, /)) || evaluate_functions - out = op(left_side, right_side) - @return_on_false isgood(out) Inf - return out, true - else - if index_functions - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol("_binop$(tree.op)")) - else - dummy_op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{Number,Number}, Real}}(Symbol(op)) - end - out = dummy_op(left_side, right_side) - # Can probably delete this check TODO - @return_on_false isgood(out) Inf - return out, true - end - end +function node_to_symbolic(tree::Node, options::Options; + varMap::Union{Array{String, 1}, Nothing}=nothing, + evaluate_functions::Bool=false, + index_functions::Bool=true + ) + expr = subs_bad(parse_tree_to_eqs(tree, options, index_functions, evaluate_functions)) + # Check for NaN and Inf + @assert isgood(expr) "The recovered equation contains NaN or Inf." + # Return if no varMap is given + varMap === nothing && return expr + # Create a substitution tuple + subs = Dict( + [SymbolicUtils.Sym{Number}(Symbol("x$(i)")) => SymbolicUtils.Sym{Number}(Symbol(varMap[i])) for i in 1:length(varMap)]... + ) + return substitute(expr, subs) end -# Just constant function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:Number} - return Node(convert(CONST_TYPE, eqn)) -end + varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Symbolic} -# Just variable -function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Sym{<:Number}} - return Node(varMap_to_index(eqn.name, varMap)) + convert(Node, eqn, options; varMap=varMap) end -function _multiarg_split(op_idx::Int, eqn::Array{Any, 1}, - options::Options, varMap::Union{Array{String, 1}, Nothing})::Node - if length(eqn) == 2 - return Node(op_idx, - symbolic_to_node(eqn[1], options, varMap=varMap), - symbolic_to_node(eqn[2], options, varMap=varMap)) - elseif length(eqn) == 3 - return Node(op_idx, - symbolic_to_node(eqn[1], options, varMap=varMap), - _multiarg_split(op_idx, eqn[2:3], options, varMap)) - else - # Minimize depth: - split_point = round(Int, length(eqn) // 2) - return Node(op_idx, - _multiarg_split(op_idx, eqn[1:split_point], options, varMap), - _multiarg_split(op_idx, eqn[split_point+1:end], options, varMap)) - end +function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} + return eqn, true end -# Equation: -function symbolic_to_node(eqn::T, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing)::Node where {T<:SymbolicUtils.Term{<:Number}} - args = SymbolicUtils.arguments(eqn) - l = symbolic_to_node(args[1], options, varMap=varMap) - nargs = length(args) - op = SymbolicUtils.operation(eqn) - if nargs == 1 - op_idx = unaop_to_index(op, options) - return Node(op_idx, l) - else - op_idx = binop_to_index(op, options) - if nargs == 2 - r = symbolic_to_node(args[2], options, varMap=varMap) - return Node(op_idx, l, r) +function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F,T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} + args = SymbolicUtils.arguments(eqn) + nargs = length(args) + if nargs == 1 + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + return op(l), true + elseif op == ^ + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + n = args[2] + if typeof(n) <: Int + if n == 1 + return l, true + elseif n == -1 + return 1.0 / l, true + elseif n > 1 + return reduce(*, [l for i=1:n]), true + elseif n < -1 + return reduce(/, vcat([1], [l for i=1:abs(n)])), true + else + return 1.0, true + end else - # TODO: Assert operator is +, * - return _multiarg_split(op_idx, args, options, varMap) + r, complete2 = multiply_powers(args[2]) + @return_on_false complete2 eqn + return l ^ r, true end - end -end - -function unaop_to_index(op::F, options::Options)::Int where {F<:SymbolicUtils.Sym} - # In format _unaop1 - parse(Int, string(op.name)[7:end]) -end - -function binop_to_index(op::F, options::Options)::Int where {F<:SymbolicUtils.Sym} - # In format _binop1 - parse(Int, string(op.name)[7:end]) -end - -function unaop_to_index(op::F, options::Options)::Int where {F<:Function} - for i=1:options.nuna - if op == options.unaops[i] - return i - end - end - error("Operator $(op) in simplified expression not found in options $(options.unaops)!") -end - -function binop_to_index(op::F, options::Options)::Int where {F<:Function} - for i=1:options.nbin - if op == options.binops[i] - return i + elseif nargs == 2 + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + r, complete2 = multiply_powers(args[2]) + @return_on_false complete2 eqn + @return_on_false isgood(r) eqn + return op(l, r), true + else + # return mapreduce(multiply_powers, op, args) + # ## reduce(op, map(multiply_powers, args)) + out = map(multiply_powers, args) #vector of tuples + for i=1:size(out, 1) + @return_on_false out[i][2] eqn + @return_on_false isgood(out[i][1]) eqn end - end - error("Operator $(op) in simplified expression not found in options $(options.binops)!") -end - -function varMap_to_index(var::Symbol, varMap::Array{String, 1})::Int - str = string(var) - for i=1:length(varMap) - if str == varMap[i] - return i + cumulator = out[1][1] + for i=2:size(out, 1) + cumulator = op(cumulator, out[i][1]) + @return_on_false isgood(cumulator) eqn end - end + return cumulator, true + end end -function varMap_to_index(var::Symbol, varMap::Nothing)::Int - return parse(Int, string(var)[2:end]) +function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} + op = SymbolicUtils.operation(eqn) + return multiply_powers(eqn, op) end end From b59cb3f304368f72b07b01577264ae2f749ab443 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:41:48 -0400 Subject: [PATCH 06/20] Remove unused simplification import --- src/SimplifyEquation.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/SimplifyEquation.jl b/src/SimplifyEquation.jl index 3f75c980..e08b2669 100644 --- a/src/SimplifyEquation.jl +++ b/src/SimplifyEquation.jl @@ -2,7 +2,6 @@ module SimplifyEquationModule import ..CoreModule: CONST_TYPE, Node, copyNode, Options import ..EquationUtilsModule: countNodes -import ..InterfaceSymbolicUtilsModule: node_to_symbolic_safe, symbolic_to_node import ..CheckConstraintsModule: check_constraints import ..UtilsModule: isbad, isgood From b42b15ce8552dca57859583e345e489a50530268 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:51:09 -0400 Subject: [PATCH 07/20] Remove unused evaluate functions argument --- src/InterfaceSymbolicUtils.jl | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 33f90418..c5f3c658 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -11,7 +11,7 @@ const SUPPORTED_OPS = (cos, sin, exp, cot, tan, csc, sec, +, -, *, /) isgood(x::SymbolicUtils.Symbolic) = SymbolicUtils.istree(x) ? all(isgood.([SymbolicUtils.operation(x);SymbolicUtils.arguments(x)])) : true subs_bad(x) = isgood(x) ? x : Inf -function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=false, evaluate_functions::Bool=false) +function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=false) if tree.degree == 0 # Return constant if needed tree.constant && return subs_bad(tree.val) @@ -28,7 +28,7 @@ function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=f op = SymbolicUtils.Sym{(SymbolicUtils.FnType){Tuple{dtypes...}, Number}}(Symbol(op)) end - return subs_bad(op(map(x->parse_tree_to_eqs(x, options, index_functions, evaluate_functions), children)...)) + return subs_bad(op(map(x->parse_tree_to_eqs(x, options, index_functions), children)...)) end ## Convert symbolic function back @@ -100,7 +100,6 @@ end """ node_to_symbolic(tree::Node, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, index_functions::Bool=false) The interface to SymbolicUtils.jl. Passing a tree to this function @@ -112,19 +111,15 @@ will generate a symbolic equation in SymbolicUtils.jl format. - `options::Options`: Options, which contains the operators used in the equation. - `varMap::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for each feature. Default is [x1, x2, x3, ...]. -- `evaluate_functions::Bool=false`: Whether to evaluate the operators, or - leave them as symbolic. - `index_functions::Bool=false`: Whether to generate special names for the operators, which then allows one to convert back to a `Node` format using `symbolic_to_node`. (CURRENTLY UNAVAILABLE - See https://github.com/MilesCranmer/SymbolicRegression.jl/pull/84). """ function node_to_symbolic(tree::Node, options::Options; - varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, - index_functions::Bool=true - ) - expr = subs_bad(parse_tree_to_eqs(tree, options, index_functions, evaluate_functions)) + varMap::Union{Array{String, 1}, Nothing}=nothing, + index_functions::Bool=true) + expr = subs_bad(parse_tree_to_eqs(tree, options, index_functions)) # Check for NaN and Inf @assert isgood(expr) "The recovered equation contains NaN or Inf." # Return if no varMap is given From 96383a8b7b01e58c4e96bb1824745bb328338812 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 00:53:45 -0400 Subject: [PATCH 08/20] Make index_functions default false --- src/InterfaceSymbolicUtils.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index c5f3c658..2d15d4ab 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -118,7 +118,7 @@ will generate a symbolic equation in SymbolicUtils.jl format. """ function node_to_symbolic(tree::Node, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing, - index_functions::Bool=true) + index_functions::Bool=false) expr = subs_bad(parse_tree_to_eqs(tree, options, index_functions)) # Check for NaN and Inf @assert isgood(expr) "The recovered equation contains NaN or Inf." From fa0df71f32819c93248b11a43e8a0a5f829c4c28 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 01:25:41 -0400 Subject: [PATCH 09/20] Fix SymbolicUtils conversion test --- test/test_simplification.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 76744637..13ca6964 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -12,7 +12,6 @@ tree = Node("x1") + Node("x1") # Should simplify to 2*x1: import SymbolicUtils: simplify -eqn = node_to_symbolic(tree, options; index_functions=true) -eqn2 = simplify(eqn, options) - -@test occursin("2", "$(eqn2[1])") \ No newline at end of file +eqn = node_to_symbolic(tree, options; index_functions=false) +eqn2 = simplify(eqn) +@test occursin("2", "$(eqn2[1])") From 824461af6273303910615768e14cc6851706cb7b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 21:14:27 -0400 Subject: [PATCH 10/20] Add convert function from Node to Symbolic --- src/InterfaceSymbolicUtils.jl | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 2d15d4ab..9f00bb44 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -67,6 +67,12 @@ function findoperation(op, ops) throw(error("Operation $(op) in expression not found in operations $(ops)!")) end +function Base.convert(::typeof(SymbolicUtils.Symbolic), tree::Node, options::Options; + varMap::Union{Array{String, 1}, Nothing}=nothing, + index_functions::Bool=false) + node_to_symbolic(tree, options; varMap=varMap, index_functions=index_functions) +end + function Base.convert(::typeof(Node), x::Number, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) return Node(CONST_TYPE(x)) end @@ -76,20 +82,20 @@ function Base.convert(::typeof(Node), x::Symbol, options::Options; varMap::Union return Node(String(x), varMap) end -function Base.convert(::typeof(Node), x::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) - if !SymbolicUtils.istree(x) - varMap === nothing && return Node(String(x.name)) - return Node(String(x.name), varMap) +function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + if !SymbolicUtils.istree(expr) + varMap === nothing && return Node(String(expr.name)) + return Node(String(expr.name), varMap) end # First, we remove integer powers: - y, good_return = multiply_powers(x) + y, good_return = multiply_powers(expr) if good_return - x = y + expr = y end - op = convert_to_function(SymbolicUtils.operation(x)) - args = SymbolicUtils.arguments(x) + op = convert_to_function(SymbolicUtils.operation(expr)) + args = SymbolicUtils.arguments(expr) length(args) > 2 && return split_eq(op, args, options; varMap=varMap) ind = length(args) == 2 ? findoperation(op, options.binops) : findoperation(op, options.unaops) From f7846b677776a171a3becfeeb290e801c01b7e6b Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 21:27:34 -0400 Subject: [PATCH 11/20] Replace String with repr --- test/test_simplification.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 13ca6964..41516426 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -14,4 +14,5 @@ import SymbolicUtils: simplify eqn = node_to_symbolic(tree, options; index_functions=false) eqn2 = simplify(eqn) -@test occursin("2", "$(eqn2[1])") +# Should correctly simplify to 2 x1: +@test occursin("2", "$(repr(eqn2)[1])") From 5133bcba63f7e33fc7cce446f21e4d6bdbbda52c Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 21:42:38 -0400 Subject: [PATCH 12/20] Fix API calls --- docs/src/api.md | 1 - test/full.jl | 6 +++--- test/test_simplification.jl | 4 ++-- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/docs/src/api.md b/docs/src/api.md index e785e941..eabb78f0 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -66,7 +66,6 @@ evalTreeArray(tree::Node, cX::AbstractMatrix{T}, options::Options) where {T<:Rea ```@docs node_to_symbolic(tree::Node, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing, - evaluate_functions::Bool=false, index_functions::Bool=false) ``` diff --git a/test/full.jl b/test/full.jl index 3e978665..ac79219a 100644 --- a/test/full.jl +++ b/test/full.jl @@ -107,7 +107,7 @@ for i=0:5 # Always assume multi for dom in dominating best = dom[end] - eqn = node_to_symbolic(best.tree, options, evaluate_functions=true) + eqn = node_to_symbolic(best.tree, options) local x4 = SymbolicUtils.Sym{Real}(Symbol("x4")) true_eqn = 2*cos(x4) @@ -145,7 +145,7 @@ dominating = calculateParetoFrontier(X, y, hallOfFame, options) best = dominating[end] eqn = node_to_symbolic(best.tree, options; - evaluate_functions=true, varMap=varMap) + varMap=varMap) t4 = SymbolicUtils.Sym{Real}(Symbol("t4")) true_eqn = 2*cos(t4) @@ -170,7 +170,7 @@ dominating = calculateParetoFrontier(X, y, hallOfFame, options) best = dominating[end] printTree(best.tree, options) eqn = node_to_symbolic(best.tree, options; - evaluate_functions=true, varMap=varMap) + varMap=varMap) residual = simplify(eqn - true_eqn) + t4 * 1e-10 @test best.loss < maximum_residual / 10 diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 41516426..136b74e7 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -10,9 +10,9 @@ options = Options(; default_params..., binary_operators=binary_operators) tree = Node("x1") + Node("x1") # Should simplify to 2*x1: -import SymbolicUtils: simplify +import SymbolicUtils: simplify, Symbolic -eqn = node_to_symbolic(tree, options; index_functions=false) +eqn = convert(Symbolic, tree, options) eqn2 = simplify(eqn) # Should correctly simplify to 2 x1: @test occursin("2", "$(repr(eqn2)[1])") From a1ec1b24ae52f4651c7ac74ef1498ece11cec663 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 22:36:10 -0400 Subject: [PATCH 13/20] Use LiteralReal in SymbolicUtils export --- src/InterfaceSymbolicUtils.jl | 4 ++-- test/test_simplification.jl | 16 ++++++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 9f00bb44..1d9db543 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -15,7 +15,7 @@ function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=f if tree.degree == 0 # Return constant if needed tree.constant && return subs_bad(tree.val) - return SymbolicUtils.Sym{Number}(Symbol("x$(tree.feature)")) + return SymbolicUtils.Sym{LiteralReal}(Symbol("x$(tree.feature)")) end # Collect the next children children = tree.degree >= 2 ? (tree.l, tree.r) : (tree.l,) @@ -132,7 +132,7 @@ function node_to_symbolic(tree::Node, options::Options; varMap === nothing && return expr # Create a substitution tuple subs = Dict( - [SymbolicUtils.Sym{Number}(Symbol("x$(i)")) => SymbolicUtils.Sym{Number}(Symbol(varMap[i])) for i in 1:length(varMap)]... + [SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) => SymbolicUtils.Sym{LiteralReal}(Symbol(varMap[i])) for i in 1:length(varMap)]... ) return substitute(expr, subs) end diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 136b74e7..2ed6656a 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -16,3 +16,19 @@ eqn = convert(Symbolic, tree, options) eqn2 = simplify(eqn) # Should correctly simplify to 2 x1: @test occursin("2", "$(repr(eqn2)[1])") + +# Let's convert back: +tree = convert(Node, eqn2, options) +# Make sure one of the nodes is now 2.0: +@test (tree.l.constant ? tree.l : tree.r).val == 2 +# Make sure the other node is x1: +@test (!tree.l.constant ? tree.l : tree.r).feature == 1 + +# Finally, let's try simplifying a product, and ensure +# that SymbolicUtils does not convert it to a power: +tree = Node("x1") * Node("x1") +eqn = convert(Symbolic, tree, options) +@test repr(eqn) == "x1*x1" +# Test converting back: +tree_copy = convert(Node, eqn, options) +@test repr(tree_copy) == "(x1 * x1)" From aa96cb033abc84d19cb2ec233b686d7de32e1ac1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 22:52:23 -0400 Subject: [PATCH 14/20] Remove unnecessary conversion function --- src/InterfaceSymbolicUtils.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 1d9db543..93912a43 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -77,11 +77,6 @@ function Base.convert(::typeof(Node), x::Number, options::Options; varMap::Union return Node(CONST_TYPE(x)) end -function Base.convert(::typeof(Node), x::Symbol, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) - varMap === nothing && return Node(String(x)) - return Node(String(x), varMap) -end - function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) if !SymbolicUtils.istree(expr) varMap === nothing && return Node(String(expr.name)) From 22caeab9847963d420fe6a379887aa532ae73e36 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 22:55:14 -0400 Subject: [PATCH 15/20] Remove unused multiply_powers function --- src/InterfaceSymbolicUtils.jl | 69 ----------------------------------- 1 file changed, 69 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 93912a43..0f3f2f81 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -83,12 +83,6 @@ function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Opt return Node(String(expr.name), varMap) end - # First, we remove integer powers: - y, good_return = multiply_powers(expr) - if good_return - expr = y - end - op = convert_to_function(SymbolicUtils.operation(expr)) args = SymbolicUtils.arguments(expr) @@ -138,68 +132,5 @@ function symbolic_to_node(eqn::T, options::Options; convert(Node, eqn, options; varMap=varMap) end -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} - return eqn, true -end - -function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F,T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} - args = SymbolicUtils.arguments(eqn) - nargs = length(args) - if nargs == 1 - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - return op(l), true - elseif op == ^ - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - n = args[2] - if typeof(n) <: Int - if n == 1 - return l, true - elseif n == -1 - return 1.0 / l, true - elseif n > 1 - return reduce(*, [l for i=1:n]), true - elseif n < -1 - return reduce(/, vcat([1], [l for i=1:abs(n)])), true - else - return 1.0, true - end - else - r, complete2 = multiply_powers(args[2]) - @return_on_false complete2 eqn - return l ^ r, true - end - elseif nargs == 2 - l, complete = multiply_powers(args[1]) - @return_on_false complete eqn - @return_on_false isgood(l) eqn - r, complete2 = multiply_powers(args[2]) - @return_on_false complete2 eqn - @return_on_false isgood(r) eqn - return op(l, r), true - else - # return mapreduce(multiply_powers, op, args) - # ## reduce(op, map(multiply_powers, args)) - out = map(multiply_powers, args) #vector of tuples - for i=1:size(out, 1) - @return_on_false out[i][2] eqn - @return_on_false isgood(out[i][1]) eqn - end - cumulator = out[1][1] - for i=2:size(out, 1) - cumulator = op(cumulator, out[i][1]) - @return_on_false isgood(cumulator) eqn - end - return cumulator, true - end -end - -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} - op = SymbolicUtils.operation(eqn) - return multiply_powers(eqn, op) -end end From ae5115070e7a7107fce561d5698eb2a4221ca573 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 23:10:59 -0400 Subject: [PATCH 16/20] Add more complex tests of simplification --- test/test_simplification.jl | 33 +++++++++++++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 2ed6656a..979d0ebd 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -1,5 +1,7 @@ include("test_params.jl") using SymbolicRegression, Test +import SymbolicUtils: simplify, Symbolic +import Random: MersenneTwister binary_operators = (+, -, /, *) @@ -10,8 +12,6 @@ options = Options(; default_params..., binary_operators=binary_operators) tree = Node("x1") + Node("x1") # Should simplify to 2*x1: -import SymbolicUtils: simplify, Symbolic - eqn = convert(Symbolic, tree, options) eqn2 = simplify(eqn) # Should correctly simplify to 2 x1: @@ -32,3 +32,32 @@ eqn = convert(Symbolic, tree, options) # Test converting back: tree_copy = convert(Node, eqn, options) @test repr(tree_copy) == "(x1 * x1)" + +# Let's test a much more complex function, +# with custom operators, and unary operators: +x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") +pow_abs(x, y) = abs(x) ^ y +custom_cos(x) = cos(x)^2 + +# Define for Node (usually these are done internally to Options) +pow_abs(l::Node, r::Node)::Node = (l.constant && r.constant) ? Node(pow_abs(l.val, r.val)::AbstractFloat) : Node(5, l, r) +pow_abs(l::Node, r::AbstractFloat)::Node = l.constant ? Node(pow_abs(l.val, r)::AbstractFloat) : Node(5, l, r) +pow_abs(l::AbstractFloat, r::Node)::Node = r.constant ? Node(pow_abs(l, r.val)::AbstractFloat) : Node(5, l, r) +custom_cos(x::Node)::Node = x.constant ? Node(custom_cos(x.val)::AbstractFloat) : Node(1, x) + +options = Options(; + binary_operators=(+, *, -, /, pow_abs), + unary_operators=(custom_cos, exp, sin), +) +tree = (((x2 + x2) * ((-0.5982493 / pow_abs(x1, x2)) / -0.54734415)) + (sin(custom_cos(sin(1.2926733 - 1.6606787) / sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426)) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2))) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262)))) +# We use `index_functions` to avoid converting the custom operators into the primitives. +eqn = convert(Symbolic, tree, options, index_functions=true) + +tree_copy = convert(Node, eqn, options) +# Too difficult to check the representation, so we check by evaluation: +N = 100 +X = rand(MersenneTwister(0), 3, N) +output1, flag1 = evalTreeArray(tree, X, options) +output2, flag2 = evalTreeArray(tree_copy, X, options) + +isapprox(output1, output2, atol=1e-5 * N) \ No newline at end of file From 67f617d176b9c2e71bf3b8292ec937e49977b59e Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 23:33:15 -0400 Subject: [PATCH 17/20] Bring back multiply_powers in SymbolicUtils conversion --- src/InterfaceSymbolicUtils.jl | 69 +++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 0f3f2f81..93912a43 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -83,6 +83,12 @@ function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Opt return Node(String(expr.name), varMap) end + # First, we remove integer powers: + y, good_return = multiply_powers(expr) + if good_return + expr = y + end + op = convert_to_function(SymbolicUtils.operation(expr)) args = SymbolicUtils.arguments(expr) @@ -132,5 +138,68 @@ function symbolic_to_node(eqn::T, options::Options; convert(Node, eqn, options; varMap=varMap) end +function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} + return eqn, true +end + +function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F,T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} + args = SymbolicUtils.arguments(eqn) + nargs = length(args) + if nargs == 1 + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + return op(l), true + elseif op == ^ + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + n = args[2] + if typeof(n) <: Int + if n == 1 + return l, true + elseif n == -1 + return 1.0 / l, true + elseif n > 1 + return reduce(*, [l for i=1:n]), true + elseif n < -1 + return reduce(/, vcat([1], [l for i=1:abs(n)])), true + else + return 1.0, true + end + else + r, complete2 = multiply_powers(args[2]) + @return_on_false complete2 eqn + return l ^ r, true + end + elseif nargs == 2 + l, complete = multiply_powers(args[1]) + @return_on_false complete eqn + @return_on_false isgood(l) eqn + r, complete2 = multiply_powers(args[2]) + @return_on_false complete2 eqn + @return_on_false isgood(r) eqn + return op(l, r), true + else + # return mapreduce(multiply_powers, op, args) + # ## reduce(op, map(multiply_powers, args)) + out = map(multiply_powers, args) #vector of tuples + for i=1:size(out, 1) + @return_on_false out[i][2] eqn + @return_on_false isgood(out[i][1]) eqn + end + cumulator = out[1][1] + for i=2:size(out, 1) + cumulator = op(cumulator, out[i][1]) + @return_on_false isgood(cumulator) eqn + end + return cumulator, true + end +end + +function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} + op = SymbolicUtils.operation(eqn) + return multiply_powers(eqn, op) +end end From 1a2536c168ff7e3d37615a5bfae85128827b56a8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 1 May 2022 23:47:16 -0400 Subject: [PATCH 18/20] Fix multiply_powers for new SymbolicUtils API --- src/InterfaceSymbolicUtils.jl | 23 +++++++++++++++-------- test/test_simplification.jl | 17 ++++++++++++----- 2 files changed, 27 insertions(+), 13 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 93912a43..67791c8b 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -138,11 +138,23 @@ function symbolic_to_node(eqn::T, options::Options; convert(Node, eqn, options; varMap=varMap) end -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{<:Number,SymbolicUtils.Sym{<:Number}}} - return eqn, true +# function Base.convert(::typeof(Node), x::Number, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) +# function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) + +function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} + return eqn, true +end + +function multiply_powers(eqn::SymbolicUtils.Symbolic)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} + if !SymbolicUtils.istree(eqn) + return eqn, true + end + op = SymbolicUtils.operation(eqn) + return multiply_powers(eqn, op) end -function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F,T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} + +function multiply_powers(eqn::SymbolicUtils.Symbolic, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {F} args = SymbolicUtils.arguments(eqn) nargs = length(args) if nargs == 1 @@ -197,9 +209,4 @@ function multiply_powers(eqn::T, op::F)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where end end -function multiply_powers(eqn::T)::Tuple{SYMBOLIC_UTILS_TYPES,Bool} where {T<:Union{SymbolicUtils.Term{<:Number},SymbolicUtils.Symbolic{<:Number}}} - op = SymbolicUtils.operation(eqn) - return multiply_powers(eqn, op) -end - end diff --git a/test/test_simplification.jl b/test/test_simplification.jl index 979d0ebd..df1eca04 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -7,7 +7,7 @@ binary_operators = (+, -, /, *) index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1] -options = Options(; default_params..., binary_operators=binary_operators) +options = Options(binary_operators=binary_operators) tree = Node("x1") + Node("x1") @@ -15,16 +15,18 @@ tree = Node("x1") + Node("x1") eqn = convert(Symbolic, tree, options) eqn2 = simplify(eqn) # Should correctly simplify to 2 x1: +# (although it might use 2(x1^1)) @test occursin("2", "$(repr(eqn2)[1])") -# Let's convert back: +# Let's convert back the simplified version. +# This should remove the ^ operator: tree = convert(Node, eqn2, options) # Make sure one of the nodes is now 2.0: @test (tree.l.constant ? tree.l : tree.r).val == 2 # Make sure the other node is x1: @test (!tree.l.constant ? tree.l : tree.r).feature == 1 -# Finally, let's try simplifying a product, and ensure +# Finally, let's try converting a product, and ensure # that SymbolicUtils does not convert it to a power: tree = Node("x1") * Node("x1") eqn = convert(Symbolic, tree, options) @@ -54,10 +56,15 @@ tree = (((x2 + x2) * ((-0.5982493 / pow_abs(x1, x2)) / -0.54734415)) + (sin(cust eqn = convert(Symbolic, tree, options, index_functions=true) tree_copy = convert(Node, eqn, options) +tree_copy2 = convert(Node, simplify(eqn), options) # Too difficult to check the representation, so we check by evaluation: N = 100 -X = rand(MersenneTwister(0), 3, N) +X = rand(MersenneTwister(0), 3, N) .+ 0.1 output1, flag1 = evalTreeArray(tree, X, options) output2, flag2 = evalTreeArray(tree_copy, X, options) +output3, flag3 = evalTreeArray(tree_copy2, X, options) -isapprox(output1, output2, atol=1e-5 * N) \ No newline at end of file +@test isapprox(output1, output2, atol=1e-4 * sqrt(N)) +# Simplified equation may give a different answer due to rounding errors, +# so we weaken the requirement: +@test isapprox(output1, output3, atol=1e-2 * sqrt(N)) \ No newline at end of file From 6579b8777b738d1fa6b83430f70f061d79736ab1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 2 May 2022 00:28:35 -0400 Subject: [PATCH 19/20] Clean up error in using convert_to_function --- src/InterfaceSymbolicUtils.jl | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/InterfaceSymbolicUtils.jl b/src/InterfaceSymbolicUtils.jl index 67791c8b..a1580282 100644 --- a/src/InterfaceSymbolicUtils.jl +++ b/src/InterfaceSymbolicUtils.jl @@ -31,9 +31,8 @@ function parse_tree_to_eqs(tree::Node, options::Options, index_functions::Bool=f return subs_bad(op(map(x->parse_tree_to_eqs(x, options, index_functions), children)...)) end -## Convert symbolic function back -convert_to_function(x, args...) = x - +# For operators which are indexed, we need to convert them back +# using the string: function convert_to_function(x::SymbolicUtils.Sym{SymbolicUtils.FnType{T, Number}}, options::Options) where {T <: Tuple} degree = length(T.types) if degree == 1 @@ -47,6 +46,10 @@ function convert_to_function(x::SymbolicUtils.Sym{SymbolicUtils.FnType{T, Number end end +# For normal operators, simply return the function itself: +convert_to_function(x, options::Options) = x + + # Split equation function split_eq(op, args, options::Options; varMap::Union{Array{String, 1}, Nothing}=nothing) !(op ∈ (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!")) @@ -89,7 +92,7 @@ function Base.convert(::typeof(Node), expr::SymbolicUtils.Symbolic, options::Opt expr = y end - op = convert_to_function(SymbolicUtils.operation(expr)) + op = convert_to_function(SymbolicUtils.operation(expr), options) args = SymbolicUtils.arguments(expr) length(args) > 2 && return split_eq(op, args, options; varMap=varMap) From 1e57fbb9bb834976405255190471e15e3f4693b1 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Mon, 2 May 2022 00:28:51 -0400 Subject: [PATCH 20/20] Bump overall version to 0.9.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 440bebff..bcf074c9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicRegression" uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" authors = ["MilesCranmer "] -version = "0.8.7" +version = "0.9.0" [deps] Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"