Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move cache variables out into own types.
- Loading branch information
Showing
5 changed files
with
193 additions
and
188 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,62 +1,56 @@ | ||
function _unchecked_value!(obj, x) | ||
obj.f_calls .+= 1 | ||
copy!(obj.last_x_f, x) | ||
obj.f_x = obj.f(real_to_complex(obj, x)) | ||
end | ||
function value(obj, x) | ||
if x != obj.last_x_f | ||
obj.f_calls .+= 1 | ||
return obj.f(real_to_complex(obj,x)) | ||
function value(cache, obj::AbstractObjective, x) | ||
if x != cache.last_x_f | ||
cache.f_calls .+= 1 | ||
return obj.f(real_to_complex(cache, x)) | ||
end | ||
obj.f_x | ||
cache.f_x | ||
end | ||
function value!(obj, x) | ||
if x != obj.last_x_f | ||
_unchecked_value!(obj, x) | ||
function value!(cache, obj::AbstractObjective, x) | ||
if x != cache.last_x_f | ||
cache.f_calls .+= 1 | ||
copy!(cache.last_x_f, x) | ||
cache.f_x = obj.f(real_to_complex(cache, x)) | ||
end | ||
obj.f_x | ||
cache.f_x | ||
end | ||
|
||
|
||
function _unchecked_gradient!(obj, x) | ||
obj.g_calls .+= 1 | ||
copy!(obj.last_x_g, x) | ||
obj.g!(real_to_complex(obj, obj.g), real_to_complex(obj, x)) | ||
end | ||
function gradient!(obj::AbstractObjective, x) | ||
if x != obj.last_x_g | ||
_unchecked_gradient!(obj, x) | ||
function gradient!(cache, obj::AbstractObjective, x) | ||
if x != cache.last_x_g | ||
cache.g_calls .+= 1 | ||
copy!(cache.last_x_g, x) | ||
obj.g!(real_to_complex(cache, cache.g), real_to_complex(cache, x)) | ||
end | ||
end | ||
|
||
function value_gradient!(obj::AbstractObjective, x) | ||
if x != obj.last_x_f && x != obj.last_x_g | ||
obj.f_calls .+= 1 | ||
obj.g_calls .+= 1 | ||
copy!(obj.last_x_f, x) | ||
copy!(obj.last_x_g, x) | ||
obj.f_x = obj.fg!(real_to_complex(obj, obj.g), real_to_complex(obj, x)) | ||
elseif x != obj.last_x_f | ||
_unchecked_value!(obj, x) | ||
elseif x != obj.last_x_g | ||
_unchecked_gradient!(obj, x) | ||
function value_gradient!(cache, obj::AbstractObjective, x) | ||
if x != cache.last_x_f && x != cache.last_x_g | ||
cache.f_calls .+= 1 | ||
cache.g_calls .+= 1 | ||
copy!(cache.last_x_f, x) | ||
copy!(cache.last_x_g, x) | ||
cache.f_x = obj.fg!(real_to_complex(cache, cache.g), real_to_complex(cache, x)) | ||
elseif x != cache.last_x_f | ||
cache.f_calls .+= 1 | ||
copy!(cache.last_x_f, x) | ||
cache.f_x = obj.f(real_to_complex(cache, x)) | ||
elseif x != cache.last_x_g | ||
cache.g_calls .+= 1 | ||
copy!(cache.last_x_g, x) | ||
obj.g!(real_to_complex(cache, cache.g), real_to_complex(cache, x)) | ||
end | ||
obj.f_x | ||
cache.f_x | ||
end | ||
|
||
function _unchecked_hessian!(obj::AbstractObjective, x) | ||
obj.h_calls .+= 1 | ||
copy!(obj.last_x_h, x) | ||
obj.h!(obj.H, x) | ||
end | ||
function hessian!(obj::AbstractObjective, x) | ||
if x != obj.last_x_h | ||
_unchecked_hessian!(obj, x) | ||
function hessian!(cache, obj::AbstractObjective, x) | ||
if x != cache.last_x_h | ||
cache.h_calls .+= 1 | ||
copy!(cache.last_x_h, x) | ||
obj.h!(cache.H, x) | ||
end | ||
end | ||
|
||
# Getters are without ! and accept only an objective and index or just an objective | ||
value(obj::AbstractObjective) = obj.f_x | ||
gradient(obj::AbstractObjective) = obj.g | ||
gradient(obj::AbstractObjective, i::Integer) = obj.g[i] | ||
hessian(obj::AbstractObjective) = obj.H | ||
# Getters are without ! and accept only an objective cache and index or just an objective cache | ||
value(cache::AbstractObjectiveCache) = cache.f_x | ||
gradient(cache::AbstractObjectiveCache) = cache.g | ||
gradient(cache::AbstractObjectiveCache, i::Integer) = cache.g[i] | ||
hessian(cache::AbstractObjectiveCache) = cache.H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,109 +1,118 @@ | ||
abstract type AbstractObjective end | ||
real_to_complex(d::AbstractObjective, x) = iscomplex(d) ? real_to_complex(x) : x | ||
complex_to_real(d::AbstractObjective, x) = iscomplex(d) ? complex_to_real(x) : x | ||
abstract type AbstractObjectiveCache end | ||
real_to_complex(c::AbstractObjectiveCache, x) = iscomplex(c) ? real_to_complex(x) : x | ||
complex_to_real(c::AbstractObjectiveCache, x) = iscomplex(c) ? complex_to_real(x) : x | ||
|
||
# Used for objectives and solvers where no gradient is available/exists | ||
mutable struct NonDifferentiable{T,A<:AbstractArray{T},Tcplx} <: AbstractObjective where {T<:Real, | ||
Tcplx<:Union{Val{true},Val{false}} #if true, must convert back on every f call | ||
} | ||
struct NonDifferentiable <: AbstractObjective | ||
f | ||
end | ||
|
||
mutable struct NonDifferentiableCache{T,A<:AbstractArray{T},Tcplx} <: AbstractObjectiveCache where {T<:Real, | ||
Tcplx<:Union{Val{true},Val{false}} # true is complex x; must convert back on every f call | ||
} | ||
f_x::T | ||
last_x_f::A | ||
f_calls::Vector{Int} | ||
end | ||
iscomplex(obj::NonDifferentiable{T,A,Val{true}}) where {T,A} = true | ||
iscomplex(obj::NonDifferentiable{T,A,Val{false}}) where {T,A} = false | ||
NonDifferentiable(f,f_x::T, last_x_f::AbstractArray{T}, f_calls::Vector{Int}) where {T} = NonDifferentiable{T,typeof(last_x_f),Val{false}}(f,f_x,last_x_f,f_calls) #compatibility with old constructor | ||
|
||
function NonDifferentiable(f, x_seed::AbstractArray) | ||
iscomplex = eltype(x_seed) <: Complex | ||
if iscomplex | ||
iscomplex(obj::NonDifferentiableCache{T,A,Val{true}}) where {T,A} = true | ||
iscomplex(obj::NonDifferentiableCache{T,A,Val{false}}) where {T,A} = false | ||
function NonDifferentiableCache(f, x_seed::AbstractArray) | ||
x_complex = eltype(x_seed) <: Complex | ||
if x_complex | ||
x_seed = complex_to_real(x_seed) | ||
end | ||
NonDifferentiable{eltype(x_seed),typeof(x_seed),Val{iscomplex}}(f, f(x_seed), copy(x_seed), [1]) | ||
NonDifferentiableCache{eltype(x_seed),typeof(x_seed),Val{x_complex}}(f(x_seed), copy(x_seed), [1]) | ||
end | ||
|
||
# Used for objectives and solvers where the gradient is available/exists | ||
mutable struct OnceDifferentiable{T, Tgrad, A<:AbstractArray{T}, Tcplx} <: AbstractObjective where {T<:Real, Tgrad, Tcplx<:Union{Val{true},Val{false}}} | ||
struct OnceDifferentiable <: AbstractObjective | ||
f | ||
g! | ||
fg! | ||
end | ||
# Automatically create the fg! helper function if only f and g! is provided | ||
function OnceDifferentiable(f, g!) | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
return OnceDifferentiable(f, g!, fg!) | ||
end | ||
|
||
mutable struct OnceDifferentiableCache{T, Tgrad, A<:AbstractArray{T}, Tcplx} <: AbstractObjectiveCache where {T<:Real, Tgrad, Tcplx<:Union{Val{true},Val{false}}} | ||
f_x::T | ||
g::Tgrad | ||
last_x_f::A | ||
last_x_g::A | ||
f_calls::Vector{Int} | ||
g_calls::Vector{Int} | ||
end | ||
iscomplex(obj::OnceDifferentiable{T,Tgrad,A,Val{true}}) where {T,Tgrad,A} = true | ||
iscomplex(obj::OnceDifferentiable{T,Tgrad,A,Val{false}}) where {T,Tgrad,A} = false | ||
OnceDifferentiable(f,g!,fg!,f_x::T, g::Tgrad, last_x_f::A, last_x_g::A, f_calls::Vector{Int}, g_calls::Vector{Int}) where {T, Tgrad, A<:AbstractArray{T}} = OnceDifferentiable{T,Tgrad,A,Val{false}}(f,g!,fg!,f_x, g, last_x_f, last_x_g, f_calls, g_calls) #compatibility with old constructor | ||
|
||
# The user friendly/short form OnceDifferentiable constructor | ||
function OnceDifferentiable(f, g!, fg!, x_seed::AbstractArray) | ||
iscomplex = eltype(x_seed) <: Complex | ||
iscomplex(obj::OnceDifferentiableCache{T,Tgrad,A,Val{true}}) where {T,Tgrad,A} = true | ||
iscomplex(obj::OnceDifferentiableCache{T,Tgrad,A,Val{false}}) where {T,Tgrad,A} = false | ||
function OnceDifferentiableCache(f, g!, x_seed::AbstractArray) | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
OnceDifferentiableCache(f, g!, fg!, x_seed) | ||
end | ||
function OnceDifferentiableCache(f, g!, fg!, x_seed::AbstractArray) | ||
x_complex = eltype(x_seed) <: Complex | ||
g = similar(x_seed) | ||
f_val = fg!(g, x_seed) | ||
|
||
if iscomplex | ||
if x_complex | ||
x_seed = complex_to_real(x_seed) | ||
g = complex_to_real(g) | ||
end | ||
OnceDifferentiable{eltype(x_seed),typeof(g),typeof(x_seed),Val{iscomplex}}(f, g!, fg!, f_val, g, copy(x_seed), copy(x_seed), [1], [1]) | ||
end | ||
# Automatically create the fg! helper function if only f and g! is provided | ||
function OnceDifferentiable(f, g!, x_seed::AbstractArray) | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
return OnceDifferentiable(f, g!, fg!, x_seed) | ||
OnceDifferentiableCache{eltype(x_seed),typeof(g),typeof(x_seed),Val{x_complex}}(f_val, g, copy(x_seed), copy(x_seed), [1], [1]) | ||
end | ||
|
||
# Used for objectives and solvers where the gradient and Hessian is available/exists | ||
mutable struct TwiceDifferentiable{T<:Real,Tgrad,A<:AbstractArray{T}} <: AbstractObjective | ||
struct TwiceDifferentiable <: AbstractObjective | ||
f | ||
g! | ||
fg! | ||
h! | ||
end | ||
# Automatically create the fg! helper function if only f, g! and h! is provided | ||
function TwiceDifferentiable(f, g!, h!) | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
return TwiceDifferentiable(f, g!, fg!, h!) | ||
end | ||
|
||
mutable struct TwiceDifferentiableCache{T<:Real,Tgrad,Thess,A<:AbstractArray{T}} <: AbstractObjectiveCache | ||
f_x::T | ||
g::Tgrad | ||
H::Matrix{T} | ||
H::Thess | ||
last_x_f::A | ||
last_x_g::A | ||
last_x_h::A | ||
f_calls::Vector{Int} | ||
g_calls::Vector{Int} | ||
h_calls::Vector{Int} | ||
end | ||
iscomplex(obj::TwiceDifferentiable) = false | ||
# The user friendly/short form TwiceDifferentiable constructor | ||
function TwiceDifferentiable(td::TwiceDifferentiable, x::AbstractArray) | ||
value_gradient!(td, x) | ||
hessian!(td, x) | ||
td | ||
iscomplex(obj::TwiceDifferentiableCache) = false | ||
# The user friendly/short form TwiceDifferentiableCache constructor | ||
function TwiceDifferentiableCache(f, g!, h!, x_seed::AbstractArray) | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
TwiceDifferentiableCache(f, g!, fg!, h!, x_seed) | ||
end | ||
|
||
function TwiceDifferentiable(f, g!, fg!, h!, x_seed::AbstractArray{T}) where T | ||
function TwiceDifferentiableCache(f, g!, fg!, h!, x_seed::AbstractArray{T}) where T | ||
n_x = length(x_seed) | ||
|
||
g = similar(x_seed) | ||
H = Array{T}(n_x, n_x) | ||
|
||
f_val = fg!(g, x_seed) | ||
h!(H, x_seed) | ||
|
||
TwiceDifferentiable(f, g!, fg!, h!, f_val, | ||
g, H, copy(x_seed), | ||
copy(x_seed), copy(x_seed), [1], [1], [1]) | ||
end | ||
# Automatically create the fg! helper function if only f, g! and h! is provided | ||
function TwiceDifferentiable(f, | ||
g!, | ||
h!, | ||
x_seed::AbstractArray{T}) where T | ||
function fg!(storage, x) | ||
g!(storage, x) | ||
return f(x) | ||
end | ||
return TwiceDifferentiable(f, g!, fg!, h!, x_seed) | ||
end | ||
TwiceDifferentiableCache(f_val, g, H, copy(x_seed), copy(x_seed), copy(x_seed), [1], [1], [1]) | ||
end |
Oops, something went wrong.