Skip to content

Commit

Permalink
update differentiation misc
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 25, 2017
1 parent c4c0fed commit d424fcd
Showing 1 changed file with 21 additions and 17 deletions.
38 changes: 21 additions & 17 deletions src/misc_utils.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
struct DiffCache{T, S}
du::Vector{T}
dual_du::Vector{S}
struct DiffEqNLSolveTag end

immutable DiffCache{T<:AbstractArray, S<:AbstractArray}
du::T
dual_du::S
end

Base.@pure function DiffCache(T, length, ::Type{Val{chunk_size}}) where chunk_size
DiffCache(zeros(T, length), zeros(Dual{:DiffEqNLSolve,T,chunk_size}, length))
Base.@pure function DiffCache{chunk_size}(T, size, ::Type{Val{chunk_size}})
DiffCache(zeros(T, size...), zeros(Dual{typeof(ForwardDiff.Tag(Void,T)),T,chunk_size}, size...))
end

Base.@pure DiffCache(u::AbstractArray) = DiffCache(eltype(u),length(u),Val{ForwardDiff.pickchunksize(length(u))})
Base.@pure DiffCache(u::AbstractArray,nlsolve) = DiffCache(eltype(u),length(u),Val{get_chunksize(nlsolve)})
Base.@pure DiffCache(u::AbstractArray,T::Type{Val{CS}}) where {CS} = DiffCache(eltype(u),length(u),T)
Base.@pure DiffCache(u::AbstractArray) = DiffCache(eltype(u),size(u),Val{ForwardDiff.pickchunksize(length(u))})
Base.@pure DiffCache(u::AbstractArray,nlsolve) = DiffCache(eltype(u),size(u),Val{get_chunksize(nlsolve)})
Base.@pure DiffCache{CS}(u::AbstractArray,T::Type{Val{CS}}) = DiffCache(eltype(u),size(u),T)

get_du(dc::DiffCache, ::Type{T}) where {T<:Dual} = dc.dual_du
get_du{T<:Dual}(dc::DiffCache, ::Type{T}) = dc.dual_du
get_du(dc::DiffCache, T) = dc.du

# Default nlsolve behavior, should move to DiffEqDiffTools.jl
Expand All @@ -25,14 +27,14 @@ Base.@pure function determine_chunksize(u,CS)
end
end

function autodiff_setup(f!, initial_x::Vector,chunk_size::Type{Val{CS}}) where CS
function autodiff_setup{CS}(f!, initial_x::Vector,chunk_size::Type{Val{CS}})

permf! = (fx, x) -> f!(x, fx)

fx2 = copy(initial_x)
jac_cfg = ForwardDiff.JacobianConfig(:DiffEqNLSolve, initial_x, ForwardDiff.Chunk{CS}())
jac_cfg = ForwardDiff.JacobianConfig(nothing,
initial_x, initial_x,
ForwardDiff.Chunk{CS}())
g! = (x, gx) -> ForwardDiff.jacobian!(gx, permf!, fx2, x, jac_cfg)

fg! = (x, fx, gx) -> begin
jac_res = DiffBase.DiffResult(fx, gx)
ForwardDiff.jacobian!(jac_res, permf!, fx2, x, jac_cfg)
Expand All @@ -46,16 +48,18 @@ function non_autodiff_setup(f!, initial_x::Vector)
DifferentiableMultivariateFunction(f!)
end

struct NLSOLVEJL_SETUP{CS,AD} end
immutable NLSOLVEJL_SETUP{CS,AD} end
Base.@pure NLSOLVEJL_SETUP(;chunk_size=0,autodiff=true) = NLSOLVEJL_SETUP{chunk_size,autodiff}()
(p::NLSOLVEJL_SETUP)(f,u0) = (res=NLsolve.nlsolve(f,u0); res.zero)
function (p::NLSOLVEJL_SETUP{CS,AD}){CS,AD}(::Type{Val{:init}},f,u0_prototype)
if AD
return non_autodiff_setup(f,u0_prototype)
return autodiff_setup(f,u0_prototype,Val{determine_chunksize(u0_prototype,CS)})
else
return autodiff_setup(f,u0_prototype,Val{determine_chunksize(initial_x,CS)})
return non_autodiff_setup(f,u0_prototype)
end
end

get_chunksize(x) = 0
get_chunksize(x::NLSOLVEJL_SETUP{CS,AD}) where {CS,AD} = CS
get_chunksize{CS,AD}(x::NLSOLVEJL_SETUP{CS,AD}) = CS

export NLSOLVEJL_SETUP

0 comments on commit d424fcd

Please sign in to comment.