diff --git a/src/objective_types/oncedifferentiable.jl b/src/objective_types/oncedifferentiable.jl index 9790ba3..9abab5a 100644 --- a/src/objective_types/oncedifferentiable.jl +++ b/src/objective_types/oncedifferentiable.jl @@ -32,7 +32,8 @@ function OnceDifferentiable(f, x_seed::AbstractArray{T}, F::Real, DF::AbstractArray, autodiff) where T - + # When here, at the constructor with positional autodiff, it should already + # be the case, that f is inplace. if typeof(f) <: Union{InplaceObjective, NotInplaceObjective} fF = make_f(f, x_seed, F) @@ -46,7 +47,10 @@ function OnceDifferentiable(f, x_seed::AbstractArray{T}, # Figure out which Val-type to use for DiffEqDiffTools based on our # symbol interface. fdtype = diffeqdiff_fdtype(autodiff) - gcache = DiffEqDiffTools.GradientCache(x_seed, x_seed, fdtype) + df_array_spec = DF + x_array_spec = x_seed + return_spec = typeof(F) + gcache = DiffEqDiffTools.GradientCache(df_array_spec, x_array_spec, fdtype, return_spec) function g!(storage, x) DiffEqDiffTools.finite_difference_gradient!(storage, f, x, gcache) @@ -86,36 +90,66 @@ function OnceDifferentiable(f, x::AbstractArray, F::AbstractArray, end OnceDifferentiable(f, x, F, alloc_DF(x, F), :forward, chunk) end -function OnceDifferentiable(f, x::AbstractArray, F::AbstractArray, DF::AbstractArray, - autodiff::Symbol , chunk::ForwardDiff.Chunk = ForwardDiff.Chunk(x)) +function OnceDifferentiable(f, x_seed::AbstractArray, F::AbstractArray, DF::AbstractArray, + autodiff::Symbol , chunk::ForwardDiff.Chunk = ForwardDiff.Chunk(x_seed)) if typeof(f) <: Union{InplaceObjective, NotInplaceObjective} - fF = make_f(f, x, F) - dfF = make_df(f, x, F) - fdfF = make_fdf(f, x, F) - return OnceDifferentiable(fF, dfF, fdfF, x, F, DF) + fF = make_f(f, x_seed, F) + dfF = make_df(f, x_seed, F) + fdfF = make_fdf(f, x_seed, F) + return OnceDifferentiable(fF, dfF, fdfF, x_seed, F, DF) else if is_finitediff(autodiff) - # Figure out which Val-type to use for DiffEqDiffTools based on our # symbol interface. fdtype = diffeqdiff_fdtype(autodiff) + # Apparently only the third input is aliased. + j_diffeqdiff_cache = DiffEqDiffTools.JacobianCache(similar(x_seed), similar(F), similar(F), fdtype) + if autodiff == :finiteforward + # These copies can be done away with if we add a keyword for + # reusing arrays instead for overwriting them. + Fx = copy(F) + DF = copy(DF) + + x_f, x_df = x_of_nans(x_seed), x_of_nans(x_seed) + f_calls, j_calls = [0,], [0,] + function j_finiteforward!(J, x) + # Exploit the possibility that it might be that x_f == x + # then we don't have to call f again. + + # if at least one element of x_f is different from x, update + if any(x_f .!= x) + Fx = similar(Fx) + f(Fx, x) + f_calls .+= 1 + end + + DiffEqDiffTools.finite_difference_jacobian!(J, f, x, j_diffeqdiff_cache, Fx) + end + function fj_finiteforward!(F, J, x) + f(F, x) + DiffEqDiffTools.finite_difference_jacobian!(J, f, x, j_diffeqdiff_cache, F) + end + + + return OnceDifferentiable(f, j_finiteforward!, fj_finiteforward!, copy(F), copy(DF), x_f, x_df, f_calls, j_calls) + end - central_cache = DiffEqDiffTools.JacobianCache(similar(x), similar(F), similar(F), fdtype) function fj_finitediff!(F, J, x) f(F, x) - DiffEqDiffTools.finite_difference_jacobian!(J, f, x, central_cache) + DiffEqDiffTools.finite_difference_jacobian!(J, f, x, j_diffeqdiff_cache) F end function j_finitediff!(J, x) F_cache = similar(F) fj_finitediff!(F_cache, J, x) end - return OnceDifferentiable(f, j_finitediff!, fj_finitediff!, x, F, DF) + + return OnceDifferentiable(f, j_finitediff!, fj_finitediff!, x_seed, F, DF) elseif is_forwarddiff(autodiff) - jac_cfg = ForwardDiff.JacobianConfig(f, F, x, chunk) - ForwardDiff.checktag(jac_cfg, f, x) + jac_cfg = ForwardDiff.JacobianConfig(f, F, x_seed, chunk) + ForwardDiff.checktag(jac_cfg, f, x_seed) F2 = copy(F) function j_forwarddiff!(J, x) @@ -127,7 +161,7 @@ function OnceDifferentiable(f, x::AbstractArray, F::AbstractArray, DF::AbstractA DiffResults.value(jac_res) end - return OnceDifferentiable(f, j_forwarddiff!, fj_forwarddiff!, x, F, DF) + return OnceDifferentiable(f, j_forwarddiff!, fj_forwarddiff!, x_seed, F, DF) else error("The autodiff value $(autodiff) is not supported. Use :finite or :forward.") end