Skip to content

Commit

Permalink
Merge 082d5b9 into 2fa8ecd
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Aug 2, 2023
2 parents 2fa8ecd + 082d5b9 commit 5f18eb4
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicExpressions"
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
version = "0.12.0"
version = "0.12.1"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand Down
2 changes: 1 addition & 1 deletion src/Equation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ function string_constant(val, bracketed::Bool)
end

function string_variable(feature, variable_names)
if variable_names === nothing
if variable_names === nothing || feature > lastindex(variable_names)
return "x" * string(feature)
else
return variable_names[feature]
Expand Down
11 changes: 8 additions & 3 deletions src/OperatorEnumConstruction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,22 @@ const ALREADY_DEFINED_UNARY_OPERATORS = (;
const ALREADY_DEFINED_BINARY_OPERATORS = (;
operator_enum=Dict{Function,Bool}(), generic_operator_enum=Dict{Function,Bool}()
)
const LATEST_VARIABLE_NAMES = Ref{Vector{String}}(String[])

function Base.show(io::IO, tree::Node)
latest_operators_type = LATEST_OPERATORS_TYPE.x
if latest_operators_type == IsNothing
return print(io, string_tree(tree))
return print(io, string_tree(tree; variable_names=LATEST_VARIABLE_NAMES.x))
elseif latest_operators_type == IsOperatorEnum
latest_operators = LATEST_OPERATORS.x::OperatorEnum
return print(io, string_tree(tree, latest_operators))
return print(
io, string_tree(tree, latest_operators; variable_names=LATEST_VARIABLE_NAMES.x)
)
else
latest_operators = LATEST_OPERATORS.x::GenericOperatorEnum
return print(io, string_tree(tree, latest_operators))
return print(
io, string_tree(tree, latest_operators; variable_names=LATEST_VARIABLE_NAMES.x)
)
end
end
function (tree::Node)(X; kws...)
Expand Down
17 changes: 17 additions & 0 deletions test/test_print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,20 @@ end
f_constant(val::Float64, args...) = string(round(val; digits=4))
@test string_tree(tree, operators; f_constant=f_constant) == "((x1 * x1) + 3.1416)"
end

@testset "Test variable names" begin
operators = OperatorEnum(; binary_operators=[+, *, /, -], unary_operators=[cos, sin])
@extend_operators operators
x1, x2, x3 = [Node(Float64; feature=i) for i in 1:3]
DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x = [
"k1", "k2", "k3"
]
tree = x1 * x2 + x3
@test string(tree) == "((k1 * k2) + k3)"
empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x)
@test string(tree) == "((x1 * x2) + x3)"
# Check if we can pass the wrong number of variable names:
DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x = ["k1"]
@test string(tree) == "((k1 * x2) + x3)"
empty!(DynamicExpressions.OperatorEnumConstructionModule.LATEST_VARIABLE_NAMES.x)
end

0 comments on commit 5f18eb4

Please sign in to comment.