Skip to content

Commit

Permalink
Merge 2e9b45d into abeb7fa
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Jun 19, 2023
2 parents abeb7fa + 2e9b45d commit db532a7
Show file tree
Hide file tree
Showing 7 changed files with 106 additions and 53 deletions.
18 changes: 15 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,19 @@ jobs:
julia --color=yes --project=. coverage.jl
shell: bash
- name: Coveralls
uses: coverallsapp/github-action@master
uses: coverallsapp/github-action@v2
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
path-to-lcov: coverage-lcov.info
parallel: true
path-to-lcov: lcov.info
flag-name: julia-${{ matrix.julia-version }}-${{ matrix.os }}-${{ github.event_name }}

coveralls:
name: Indicate completion to coveralls
runs-on: ubuntu-latest
needs: test
steps:
- name: Finish
uses: coverallsapp/github-action@v2
with:
parallel-finished: true

5 changes: 3 additions & 2 deletions coverage.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
using Coverage
# process '*.cov' files
coverage = process_folder() # defaults to src/; alternatively, supply the folder name as argument
push!(coverage, process_folder("ext")...)

LCOV.writefile("coverage-lcov.info", coverage)
LCOV.writefile("lcov.info", coverage)

# process '*.info' files
coverage = merge_coverage_counts(
coverage,
filter!(
let prefixes = (joinpath(pwd(), "src", ""),)
let prefixes = (joinpath(pwd(), "src", ""), joinpath(pwd(), "ext", ""))
c -> any(p -> startswith(c.filename, p), prefixes)
end,
LCOV.readfolder("test"),
Expand Down
67 changes: 41 additions & 26 deletions ext/DynamicExpressionsSymbolicUtilsExt.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
module DynamicExpressionsSymbolicUtilsExt

import Base: convert
#! format: off
if isdefined(Base, :get_extension)
using SymbolicUtils
import DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false
import DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
else
using ..SymbolicUtils
import ..DynamicExpressions.EquationModule: Node, DEFAULT_NODE_TYPE
import ..DynamicExpressions.OperatorEnumModule: AbstractOperatorEnum
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false
import ..DynamicExpressions.UtilsModule: isgood, isbad, @return_on_false, deprecate_varmap
import ..DynamicExpressions.ExtensionInterfaceModule: node_to_symbolic, symbolic_to_node
end
#! format: on

const SYMBOLIC_UTILS_TYPES = Union{<:Number,SymbolicUtils.Symbolic{<:Number}}

Expand Down Expand Up @@ -77,8 +79,11 @@ function split_eq(
op,
args,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated:
varMap=nothing,
)
variable_names = deprecate_varmap(variable_names, varMap, :split_eq)
!(op (sum, prod, +, *)) && throw(error("Unsupported operation $op in expression!"))
if Symbol(op) == Symbol(sum)
ind = findoperation(+, operators.binops)
Expand All @@ -89,8 +94,8 @@ function split_eq(
end
return Node(
ind,
convert(Node, args[1], operators; varMap=varMap),
convert(Node, op(args[2:end]...), operators; varMap=varMap),
convert(Node, args[1], operators; variable_names=variable_names),
convert(Node, op(args[2:end]...), operators; variable_names=variable_names),
)
end

Expand All @@ -105,30 +110,31 @@ function convert(
::typeof(SymbolicUtils.Symbolic),
tree::Node,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
index_functions::Bool=false,
# Deprecated:
varMap=nothing,
)
return node_to_symbolic(tree, operators; varMap=varMap, index_functions=index_functions)
variable_names = deprecate_varmap(variable_names, varMap, :convert)
return node_to_symbolic(
tree, operators; variable_names=variable_names, index_functions=index_functions
)
end

function convert(
::typeof(Node),
x::Number,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
)
function convert(::typeof(Node), x::Number, operators::AbstractOperatorEnum; kws...)
return Node(; val=DEFAULT_NODE_TYPE(x))
end

function convert(
::typeof(Node),
expr::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
)
variable_names = deprecate_varmap(variable_names, nothing, :convert)
if !SymbolicUtils.istree(expr)
varMap === nothing && return Node(String(expr.name))
return Node(String(expr.name), varMap)
variable_names === nothing && return Node(String(expr.name))
return Node(String(expr.name), variable_names)
end

# First, we remove integer powers:
Expand All @@ -140,19 +146,21 @@ function convert(
op = convert_to_function(SymbolicUtils.operation(expr), operators)
args = SymbolicUtils.arguments(expr)

length(args) > 2 && return split_eq(op, args, operators; varMap=varMap)
length(args) > 2 && return split_eq(op, args, operators; variable_names=variable_names)
ind = if length(args) == 2
findoperation(op, operators.binops)
else
findoperation(op, operators.unaops)
end

return Node(ind, map(x -> convert(Node, x, operators; varMap=varMap), args)...)
return Node(
ind, map(x -> convert(Node, x, operators; variable_names=variable_names), args)...
)
end

"""
node_to_symbolic(tree::Node, operators::AbstractOperatorEnum;
varMap::Union{Array{String, 1}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
index_functions::Bool=false)
The interface to SymbolicUtils.jl. Passing a tree to this function
Expand All @@ -162,7 +170,7 @@ will generate a symbolic equation in SymbolicUtils.jl format.
- `tree::Node`: The equation to convert.
- `operators::AbstractOperatorEnum`: OperatorEnum, which contains the operators used in the equation.
- `varMap::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: What variable names to use for
each feature. Default is [x1, x2, x3, ...].
- `index_functions::Bool=false`: Whether to generate special names for the
operators, which then allows one to convert back to a `Node` format
Expand All @@ -172,19 +180,23 @@ will generate a symbolic equation in SymbolicUtils.jl format.
function node_to_symbolic(
tree::Node,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
index_functions::Bool=false,
# Deprecated:
varMap=nothing,
)
variable_names = deprecate_varmap(variable_names, varMap, :node_to_symbolic)
expr = subs_bad(parse_tree_to_eqs(tree, operators, index_functions))
# Check for NaN and Inf
@assert isgood(expr) "The recovered equation contains NaN or Inf."
# Return if no varMap is given
varMap === nothing && return expr
# Return if no variable_names is given
variable_names === nothing && return expr
# Create a substitution tuple
subs = Dict(
[
SymbolicUtils.Sym{LiteralReal}(Symbol("x$(i)")) =>
SymbolicUtils.Sym{LiteralReal}(Symbol(varMap[i])) for i in 1:length(varMap)
SymbolicUtils.Sym{LiteralReal}(Symbol(variable_names[i])) for
i in 1:length(variable_names)
]...,
)
return substitute(expr, subs)
Expand All @@ -193,9 +205,12 @@ end
function symbolic_to_node(
eqn::SymbolicUtils.Symbolic,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated:
varMap=nothing,
)::Node
return convert(Node, eqn, operators; varMap=varMap)
variable_names = deprecate_varmap(variable_names, varMap, :symbolic_to_node)
return convert(Node, eqn, operators; variable_names=variable_names)
end

function multiply_powers(eqn::Number)::Tuple{SYMBOLIC_UTILS_TYPES,Bool}
Expand Down
54 changes: 35 additions & 19 deletions src/Equation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module EquationModule

import ..OperatorEnumModule: AbstractOperatorEnum
import ..UtilsModule: @memoize_on, @with_memoize
import ..UtilsModule: @memoize_on, @with_memoize, deprecate_varmap

const DEFAULT_NODE_TYPE = Float32

Expand Down Expand Up @@ -144,14 +144,14 @@ Create a variable node, using the format `"x1"` to mean feature 1
Node(var_string::String) = Node(; feature=parse(Int, var_string[2:end]))

"""
Node(var_string::String, varMap::Array{String, 1})
Node(var_string::String, variable_names::Array{String, 1})
Create a variable node, using a user-passed format
"""
function Node(var_string::String, varMap::Array{String,1})
function Node(var_string::String, variable_names::Array{String,1})
return Node(;
feature=[
i for (i, _variable) in enumerate(varMap) if _variable == var_string
i for (i, _variable) in enumerate(variable_names) if _variable == var_string
][1]::Int,
)
end
Expand Down Expand Up @@ -199,20 +199,23 @@ function string_op(
tree::Node,
operators::AbstractOperatorEnum;
bracketed::Bool=false,
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
)::String where {F}
variable_names = deprecate_varmap(variable_names, varMap, :string_op)
op_name = get_op_name(string(op))
if op_name in ["+", "-", "*", "/", "^"]
l = string_tree(tree.l, operators; bracketed=false, varMap=varMap)
r = string_tree(tree.r, operators; bracketed=false, varMap=varMap)
l = string_tree(tree.l, operators; bracketed=false, variable_names=variable_names)
r = string_tree(tree.r, operators; bracketed=false, variable_names=variable_names)
if bracketed
return "$l $op_name $r"
else
return "($l $op_name $r)"
end
else
l = string_tree(tree.l, operators; bracketed=true, varMap=varMap)
r = string_tree(tree.r, operators; bracketed=true, varMap=varMap)
l = string_tree(tree.l, operators; bracketed=true, variable_names=variable_names)
r = string_tree(tree.r, operators; bracketed=true, variable_names=variable_names)
return "$op_name($l, $r)"
end
end
Expand All @@ -224,31 +227,38 @@ Convert an equation to a string.
# Arguments
- `varMap::Union{Array{String, 1}, Nothing}=nothing`: what variables
- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables
to print for each feature.
"""
function string_tree(
tree::Node{T},
operators::AbstractOperatorEnum;
bracketed::Bool=false,
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
)::String where {T}
variable_names = deprecate_varmap(variable_names, varMap, :string_tree)
if tree.degree == 0
if tree.constant
return string_constant(tree.val::T; bracketed=bracketed)
else
if varMap === nothing
if variable_names === nothing
return "x$(tree.feature)"
else
return varMap[tree.feature]
return variable_names[tree.feature]
end
end
elseif tree.degree == 1
op_name = get_op_name(string(operators.unaops[tree.op]))
return "$(op_name)($(string_tree(tree.l, operators, bracketed=true, varMap=varMap)))"
return "$(op_name)($(string_tree(tree.l, operators, bracketed=true, variable_names=variable_names)))"
else
return string_op(
operators.binops[tree.op], tree, operators; bracketed=bracketed, varMap=varMap
operators.binops[tree.op],
tree,
operators;
bracketed=bracketed,
variable_names=variable_names,
)
end
end
Expand All @@ -267,17 +277,23 @@ function print_tree(
io::IO,
tree::Node,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
)
return println(io, string_tree(tree, operators; varMap=varMap))
variable_names = deprecate_varmap(variable_names, varMap, :print_tree)
return println(io, string_tree(tree, operators; variable_names=variable_names))
end

function print_tree(
tree::Node,
operators::AbstractOperatorEnum;
varMap::Union{Array{String,1},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
# Deprecated
varMap=nothing,
)
return println(string_tree(tree, operators; varMap=varMap))
variable_names = deprecate_varmap(variable_names, varMap, :print_tree)
return println(string_tree(tree, operators; variable_names=variable_names))
end

end
9 changes: 9 additions & 0 deletions src/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,4 +157,13 @@ end

@inline fill_similar(value, array, args...) = fill!(similar(array, args...), value)

function deprecate_varmap(variable_names, varMap, func_name)
if varMap !== nothing
Base.depwarn("`varMap` is deprecated; use `variable_names` instead", func_name)
@assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`"
variable_names = varMap
end
return variable_names
end

end
2 changes: 1 addition & 1 deletion test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ true_s = "((sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0)"

# TODO: Next, we test that custom varMaps work:

s = string_tree(tree, operators; varMap=["v1", "v2", "v3"])
s = string_tree(tree, operators; variable_names=["v1", "v2", "v3"])
true_s = "((sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0)"
@test s == true_s

Expand Down
4 changes: 2 additions & 2 deletions test/test_symbolic_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ operators = OperatorEnum(;
)
tree = Node(5, (Node(; val=3.0) * Node(1, Node("x1")))^2.0, Node(; val=-1.2))

eqn = node_to_symbolic(tree, operators; varMap=["energy"], index_functions=true)
eqn = node_to_symbolic(tree, operators; variable_names=["energy"], index_functions=true)
@test string(eqn) == "greater(safe_pow(3.0_inv(energy), 2.0), -1.2)"

tree2 = symbolic_to_node(eqn, operators; varMap=["energy"])
tree2 = symbolic_to_node(eqn, operators; variable_names=["energy"])
@test string_tree(tree, operators) == string_tree(tree2, operators)

0 comments on commit db532a7

Please sign in to comment.