ConstantOptimization.jl
module ConstantOptimizationModule

using LineSearches: LineSearches
using Optim: Optim

import DynamicExpressions: Node, count_constants
import ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
import ..UtilsModule: get_birth_order
import ..LossFunctionsModule: score_func, eval_loss
import ..PopMemberModule: PopMember

# Proxy objective for the optimizer: write the candidate constants `x` into
# the tree, then evaluate the loss on the dataset.
function opt_func(
    x, dataset::Dataset{T,L}, tree, constant_nodes, options
) where {T<:DATA_TYPE,L<:LOSS_TYPE}
    _set_constants!(x, constant_nodes)
    # TODO(mcranmer): This should use score_func batching.
    loss = eval_loss(tree, dataset, options)
    return loss::L
end

# Copy the flat vector of candidate values `x` into the constant nodes of the tree.
function _set_constants!(x::AbstractArray{T}, constant_nodes) where {T}
    for (xi, node) in zip(x, constant_nodes)
        node.val::T = xi
    end
    return nothing
end
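
# The two helpers above implement the standard "proxy objective" pattern for
# Optim.jl: the optimizer only ever sees a flat vector `x`, and the proxy
# writes that vector into the tree's constant nodes before computing the
# loss. A minimal self-contained sketch of the same pattern (plain Julia,
# not this module's types; the target values are made up for illustration):
#
#     using Optim
#     target = [3.0, -1.0]
#     f(x) = sum(abs2, x .- target)        # stand-in for eval_loss
#     res = Optim.optimize(f, [1.0, 2.0], Optim.BFGS())
#     Optim.minimizer(res)                 # ≈ [3.0, -1.0]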

# Optimize the constants in an equation, picking an Optim.jl algorithm based
# on the element type and the number of constants.
function optimize_constants(
    dataset::Dataset{T,L}, member::PopMember{T,L}, options::Options
)::Tuple{PopMember{T,L},Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE}
    nconst = count_constants(member.tree)
    nconst == 0 && return (member, 0.0)
    if T <: Complex
        # TODO: Make this more general. Also, do we even need Newton here at all??
        algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())
        return _optimize_constants(
            dataset, member, options, algorithm, options.optimizer_options
        )
    elseif nconst == 1
        # A single constant: Newton's method converges quickly.
        algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking())
        return _optimize_constants(
            dataset, member, options, algorithm, options.optimizer_options
        )
    elseif options.optimizer_algorithm == "NelderMead"
        # Nelder-Mead is derivative-free, so it takes no line search.
        algorithm = Optim.NelderMead()
        return _optimize_constants(
            dataset, member, options, algorithm, options.optimizer_options
        )
    elseif options.optimizer_algorithm == "BFGS"
        algorithm = Optim.BFGS(; linesearch=LineSearches.BackTracking())
        return _optimize_constants(
            dataset, member, options, algorithm, options.optimizer_options
        )
    else
        error("Optimizer algorithm $(options.optimizer_algorithm) not implemented.")
    end
end
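
# `options.optimizer_algorithm` and `options.optimizer_nrestarts` are fields
# of `Options` read by the two functions above. A hedged usage sketch,
# assuming the usual `Options` keyword constructor (check the `Options`
# docstring for the exact keywords in your version):
#
#     options = Options(;
#         binary_operators=(+, *, -),
#         optimizer_algorithm="BFGS",
#         optimizer_nrestarts=2,
#     )
#     # member, num_evals = optimize_constants(dataset, member, options)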

function _optimize_constants(
    dataset, member::PopMember{T,L}, options, algorithm, optimizer_options
)::Tuple{PopMember{T,L},Float64} where {T,L}
    tree = member.tree
    constant_nodes = filter(t -> t.degree == 0 && t.constant, tree)
    x0 = [n.val::T for n in constant_nodes]
    f(x) = opt_func(x, dataset, tree, constant_nodes, options)
    result = Optim.optimize(f, x0, algorithm, optimizer_options)
    num_evals = 0.0
    num_evals += result.f_calls
    # Try other initial conditions: perturb the starting constants and keep
    # the best optimum found across restarts.
    for i in 1:(options.optimizer_nrestarts)
        new_start = x0 .* (T(1) .+ T(1//2) * randn(T, size(x0, 1)))
        tmpresult = Optim.optimize(f, new_start, algorithm, optimizer_options)
        num_evals += tmpresult.f_calls
        if tmpresult.minimum < result.minimum
            result = tmpresult
        end
    end
    if Optim.converged(result)
        # Commit the optimized constants and refresh the member's score.
        _set_constants!(result.minimizer, constant_nodes)
        member.score, member.loss = score_func(dataset, tree, options)
        num_evals += 1
        member.birth = get_birth_order(; deterministic=options.deterministic)
    else
        # Optimization failed to converge; restore the original constants.
        _set_constants!(x0, constant_nodes)
    end
    return member, num_evals
end
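
# The restart loop above re-seeds the optimizer from multiplicative
# perturbations of the original constants, x0 .* (1 .+ 0.5 .* randn(n)), so
# each restart stays on the scale of the current solution rather than
# jumping to an arbitrary region. A standalone sketch of that perturbation
# (illustrative values only):
#
#     x0 = [2.0, -0.5]
#     new_start = x0 .* (1 .+ 0.5 .* randn(length(x0)))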
end