Skip to content

Commit 4ccbcaf

Browse files
authored
Merge 455e80f into cab1143
2 parents cab1143 + 455e80f commit 4ccbcaf

29 files changed

+860
-285
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DynamicExpressions"
22
uuid = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
33
authors = ["MilesCranmer <miles.cranmer@gmail.com>"]
4-
version = "0.18.5"
4+
version = "0.18.6"
55

66
[deps]
77
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -14,7 +14,6 @@ PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
1414
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1515
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1616
TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
17-
TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe"
1817

1918
[weakdeps]
2019
Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
@@ -43,7 +42,6 @@ PackageExtensionCompat = "1"
4342
PrecompileTools = "1"
4443
Reexport = "1"
4544
SymbolicUtils = "0.19, ^1.0.5, 2"
46-
TestItems = "0.1"
4745
Zygote = "0.6"
4846
julia = "1.6"
4947

benchmark/benchmarks.jl

Lines changed: 54 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
using DynamicExpressions, BenchmarkTools, Random
22

33
# Trigger extensions:
4-
using LoopVectorization
5-
using Bumper
6-
using StrideArrays
7-
using Zygote
4+
using LoopVectorization, Bumper, StrideArrays, Zygote
85

96
if PACKAGE_VERSION < v"0.14.0"
107
@eval using DynamicExpressions: Node as GraphNode
@@ -18,6 +15,14 @@ else
1815
@eval using DynamicExpressions.NodeUtilsModule: is_constant
1916
end
2017

18+
if PACKAGE_VERSION < v"0.18.6"
19+
@eval using DynamicExpressions:
20+
index_constants as index_constant_nodes,
21+
count_constants as count_constant_nodes,
22+
get_constants as get_scalar_constants,
23+
set_constants! as set_scalar_constants!
24+
end
25+
2126
include("../test/tree_gen_utils.jl")
2227

2328
const SUITE = BenchmarkGroup()
@@ -113,15 +118,16 @@ end
113118
PACKAGE_VERSION < v"0.14.0" && return :(copy_node(t; preserve_sharing=preserve_sharing))
114119
return :(copy_node(t)) # Assume type used to infer sharing
115120
end
116-
@generated function get_set_constants!(tree::N) where {T,N<:AbstractExpressionNode{T}}
117-
if !(@isdefined set_constants!)
118-
return :(set_constants(tree, get_constants(tree)))
119-
elseif hasmethod(set_constants!, Tuple{N, Vector{T}})
120-
return :(set_constants!(tree, get_constants(tree)))
121+
@generated function get_set_constants!(tree::N) where {N}
122+
T = eltype(N)
123+
if !(@isdefined set_scalar_constants!)
124+
return :(set_scalar_constants(tree, get_scalar_constants(tree)))
125+
elseif hasmethod(set_scalar_constants!, Tuple{N, Vector{T}})
126+
return :(set_scalar_constants!(tree, get_scalar_constants(tree)))
121127
else
122128
return quote
123-
let (x, refs) = get_constants(tree)
124-
set_constants!(tree, x, refs)
129+
let (x, refs) = get_scalar_constants(tree)
130+
set_scalar_constants!(tree, x, refs)
125131
end
126132
end
127133
end
@@ -141,12 +147,12 @@ function benchmark_utilities()
141147
:combine_operators,
142148
:count_nodes,
143149
:count_depth,
144-
:count_constants,
150+
:count_constant_nodes,
145151
:has_constants,
146152
:has_operators,
147153
:is_constant,
148154
:get_set_constants!,
149-
:index_constants,
155+
:index_constant_nodes,
150156
:string_tree,
151157
:hash,
152158
)
@@ -157,9 +163,9 @@ function benchmark_utilities()
157163
[
158164
:simplify_tree,
159165
:count_nodes,
160-
:count_constants,
166+
:count_constant_nodes,
161167
:get_set_constants!,
162-
:index_constants,
168+
:index_constant_nodes,
163169
:string_tree,
164170
],
165171
)
@@ -207,7 +213,8 @@ function benchmark_utilities()
207213
setup=(
208214
ntrees=100;
209215
n=20;
210-
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32)) for _ in 1:ntrees]
216+
rng=Random.MersenneTwister(0);
217+
trees=[$preprocess(gen_random_tree_fixed_size(n, $operators, 5, Float32, Node, rng)) for _ in 1:ntrees]
211218
)
212219
)
213220
#! format: on
@@ -216,6 +223,37 @@ function benchmark_utilities()
216223
end
217224
end
218225

226+
# Additional methods
227+
@static if PACKAGE_VERSION >= v"0.18.0"
228+
suite["get_set_constants_parametric"] = @benchmarkable(
229+
[get_set_constants!(ex) for ex in exs],
230+
seconds = 10.0,
231+
setup = (
232+
operators = $operators;
233+
ntrees = 100;
234+
n = 20;
235+
n_features = 5;
236+
n_params = 3;
237+
n_param_classes = 10;
238+
rng = Random.MersenneTwister(0);
239+
exs = [
240+
let tree = gen_random_tree_fixed_size(
241+
n, operators, n_features, Float32, ParametricNode, rng
242+
)
243+
ex = ParametricExpression(
244+
tree;
245+
operators,
246+
variable_names=map(i -> "x$i", 1:n_features),
247+
parameters=randn(rng, Float32, n_params, n_param_classes),
248+
parameter_names=map(i -> "p$i", 1:n_params),
249+
)
250+
ex
251+
end for _ in 1:ntrees
252+
]
253+
)
254+
)
255+
end
256+
219257
return suite
220258
end
221259

ext/DynamicExpressionsOptimExt.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using DynamicExpressions:
55
AbstractExpressionNode,
66
filter_map,
77
eval_tree_array,
8-
get_constants,
9-
set_constants!
8+
get_scalar_constants,
9+
set_scalar_constants!,
10+
get_number_type
1011
using Compat: @inline
1112

1213
import Optim: Optim, OptimizationResults, NLSolversBase
@@ -44,9 +45,14 @@ function wrap_func(
4445
function wrapped_f(args::Vararg{Any,M}) where {M}
4546
first_args = args[begin:(end - 1)]
4647
x = args[end]
47-
set_constants!(tree, x, refs)
48+
set_scalar_constants!(tree, x, refs)
4849
return @inline(f(first_args..., tree))
4950
end
51+
# without first args, it looks like this
52+
# function wrapped_f(x)
53+
# set_scalar_constants!(tree, x, refs)
54+
# return @inline(f(tree))
55+
# end
5056
return wrapped_f
5157
end
5258
function wrap_func(
@@ -100,7 +106,8 @@ function Optim.optimize(
100106
if make_copy
101107
tree = copy(tree)
102108
end
103-
x0, refs = get_constants(tree)
109+
110+
x0, refs = get_scalar_constants(tree)
104111
if !isnothing(h!)
105112
throw(
106113
ArgumentError(
@@ -117,7 +124,7 @@ function Optim.optimize(
117124
)
118125
end
119126
minimizer = Optim.minimizer(base_res)
120-
set_constants!(tree, minimizer, refs)
127+
set_scalar_constants!(tree, minimizer, refs)
121128
return ExpressionOptimizationResults(base_res, tree)
122129
end
123130

src/DynamicExpressions.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using DispatchDoctor: @stable, @unstable
44

55
@stable default_mode = "disable" begin
66
include("Utils.jl")
7+
include("ValueInterface.jl")
78
include("ExtensionInterface.jl")
89
include("OperatorEnum.jl")
910
include("Node.jl")
@@ -25,6 +26,13 @@ import PackageExtensionCompat: @require_extensions
2526
import Reexport: @reexport
2627
macro ignore(args...) end
2728

29+
import .ValueInterfaceModule:
30+
is_valid,
31+
is_valid_array,
32+
get_number_type,
33+
pack_scalar_constants!,
34+
unpack_scalar_constants,
35+
ValueInterface
2836
@reexport import .NodeModule:
2937
AbstractNode,
3038
AbstractExpressionNode,
@@ -47,14 +55,15 @@ import .NodeModule:
4755
branch_equal
4856
@reexport import .NodeUtilsModule:
4957
count_nodes,
50-
count_constants,
58+
count_constant_nodes,
5159
count_depth,
5260
NodeIndex,
53-
index_constants,
61+
index_constant_nodes,
5462
has_operators,
5563
has_constants,
56-
get_constants,
57-
set_constants!
64+
count_scalar_constants,
65+
get_scalar_constants,
66+
set_scalar_constants!
5867
@reexport import .StringsModule: string_tree, print_tree
5968
@reexport import .OperatorEnumModule: AbstractOperatorEnum
6069
@reexport import .OperatorEnumConstructionModule:

0 commit comments

Comments
 (0)