-
Notifications
You must be signed in to change notification settings - Fork 53
/
optimize.jl
97 lines (89 loc) · 4.01 KB
/
optimize.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
# Build the keyword dict forwarded to get_params/set_params!, dropping the
# flag that does not apply to the given GP flavour: exact GPs (GPE) carry no
# likelihood hyperparameters, and approximate GPs (GPA) carry no explicit
# observation-noise hyperparameter.
function get_params_kwargs(::GPE; kwargs...)
    opts = Dict(kwargs)
    return delete!(opts, :lik)
end
function get_params_kwargs(::GPA; kwargs...)
    opts = Dict(kwargs)
    return delete!(opts, :noise)
end
"""
optimize!(gp::GPBase, args...; kwargs...)
Optimise the hyperparameters of Gaussian process `gp` based on type II maximum likelihood estimation. This function performs gradient based optimisation using the Optim pacakge to which the user is referred to for further details.
# Keyword arguments:
* `domean::Bool`: Mean function hyperparameters should be optmized
* `kern::Bool`: Kernel function hyperparameters should be optmized
* `noise::Bool`: Observation noise hyperparameter should be optimized (GPE only)
* `lik::Bool`: Likelihood hyperparameters should be optimized (GPA only)
* `meanbounds`: [lowerbounds, upperbounds] for the mean hyperparameters
* `kernbounds`: [lowerbounds, upperbounds] for the kernel hyperparameters
* `noisebounds`: [lowerbound, upperbound] for the noise hyperparameter
* `args/kwargs`: Arguments and keyword arguments for the optimize function from the Optim package https://julianlsolvers.github.io/Optim.jl/stable/#user/config/
"""
function optimize!(gp::GPBase, args...; method = LBFGS(), domean::Bool = true, kern::Bool = true,
noise::Bool = true, lik::Bool = true,
meanbounds = nothing, kernbounds = nothing,
noisebounds = nothing, likbounds = nothing, kwargs...)
params_kwargs = get_params_kwargs(gp; domean=domean, kern=kern, noise=noise, lik=lik)
# println(params_kwargs)
func = get_optim_target(gp; params_kwargs...)
init = get_params(gp; params_kwargs...) # Initial hyperparameter values
if meanbounds == kernbounds == noisebounds == likbounds == nothing
results = optimize(func, init, args...; method=method, kwargs...) # Run optimizer
else
lb, ub = bounds(gp, noisebounds, meanbounds, kernbounds, likbounds;
domean = domean, kern = kern, noise = noise, lik = lik)
results = optimize(func.f, func.df, lb, ub, init, Fminbox(method), args...)
end
set_params!(gp, Optim.minimizer(results); params_kwargs...)
update_target!(gp)
return results
end
# Build the Optim.OnceDifferentiable objective for `gp`: the negative log
# target (and its gradient) as a function of the selected hyperparameters.
function get_optim_target(gp::GPBase; params_kwargs...)
    # Decide whether `err`, raised while evaluating hyperparameters `hyp`, is
    # recoverable: non-finite proposals, argument errors and failed Cholesky
    # factorisations are absorbed by returning Inf so the optimiser can back
    # off; anything else is a genuine bug and must propagate.
    function isrecoverable(err, hyp)
        return !all(isfinite, hyp) ||
               isa(err, ArgumentError) ||
               isa(err, LinearAlgebra.PosDefException)
    end

    # Negative log target at `hyp` (Optim minimises, so negate).
    function ltarget(hyp::AbstractVector)
        prev = get_params(gp; params_kwargs...)
        try
            set_params!(gp, hyp; params_kwargs...)
            update_target!(gp)
            return -gp.target
        catch err
            # reset parameters to remove any NaNs
            set_params!(gp, prev; params_kwargs...)
            if isrecoverable(err, hyp)
                println(err)
                return Inf
            end
            rethrow()  # preserve the original backtrace
        end
    end

    # Negative log target at `hyp`, with the negative gradient written into
    # `grad` in place.
    function ltarget_and_dltarget!(grad::AbstractVector, hyp::AbstractVector)
        prev = get_params(gp; params_kwargs...)
        try
            set_params!(gp, hyp; params_kwargs...)
            update_target_and_dtarget!(gp; params_kwargs...)
            grad[:] = -gp.dtarget
            return -gp.target
        catch err
            # reset parameters to remove any NaNs
            set_params!(gp, prev; params_kwargs...)
            if isrecoverable(err, hyp)
                println(err)
                return Inf
            end
            rethrow()  # preserve the original backtrace
        end
    end

    # Gradient-only entry point required by OnceDifferentiable; the target is
    # computed alongside the gradient anyway, so just delegate.
    function dltarget!(grad::AbstractVector, hyp::AbstractVector)
        ltarget_and_dltarget!(grad, hyp)
    end

    xinit = get_params(gp; params_kwargs...)
    func = OnceDifferentiable(ltarget, dltarget!, ltarget_and_dltarget!, xinit)
    return func
end