Skip to content

Commit

Permalink
Merge pull request #937 from huanglangwen/colorjac
Browse files Browse the repository at this point in the history
replace ForwardDiff.jacobian with SparseDiffTools.forwarddiff_color_jacobian
  • Loading branch information
ChrisRackauckas committed Nov 13, 2019
2 parents 0e1fa24 + 92696c9 commit 75ff4b6
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 68 deletions.
3 changes: 2 additions & 1 deletion src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ module OrdinaryDiffEq

import DiffEqBase: calculate_residuals, calculate_residuals!, unwrap_cache, @tight_loop_macros, islinear

import SparseDiffTools: forwarddiff_color_jacobian!, ForwardColorJacCache
import SparseDiffTools
import SparseDiffTools: forwarddiff_color_jacobian!, forwarddiff_color_jacobian, ForwardColorJacCache, default_chunk_size, getsize

using MacroTools

Expand Down
122 changes: 61 additions & 61 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,24 @@ function derivative(f, x::Union{Number,AbstractArray{<:Number}},
end
end

function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if alg_autodiff(alg)
if DiffEqBase.has_colorvec(integrator.f)
J,tmp = jacobian_autodiff(f, x, integrator)
else
J,tmp = jacobian_autodiff(f, x)
end
else
if DiffEqBase.has_colorvec(integrator.f)
J,tmp = jacobian_finitediff(f, x, alg.diff_type, integrator, integrator.f.colorvec)
else
J,tmp = jacobian_finitediff(f, x, alg.diff_type, integrator)
end
end
integrator.destats.nf += tmp
J
end

jacobian_autodiff(f, x) = (ForwardDiff.derivative(f,x),1)
jacobian_autodiff(f, x::AbstractArray) = (ForwardDiff.jacobian(f, x),1)
function jacobian_autodiff(f, x::AbstractArray, integrator)
colorvec=integrator.f.colorvec
jac=integrator.f.jac_prototype
J=jac isa SparseMatrixCSC ? similar(jac) : fill(0.,size(jac))
(forwarddiff_color_jacobian!(J,f,x,colorvec=colorvec,sparsity=jac),1)
jacobian_autodiff(f, x, odefun) = (ForwardDiff.derivative(f,x),1)
function jacobian_autodiff(f, x::AbstractArray, odefun)
if DiffEqBase.has_colorvec(odefun)
colorvec = odefun.colorvec
sparsity = odefun.jac_prototype
jac_prototype = nothing
else
colorvec = 1:length(x)
sparsity = nothing
jac_prototype = odefun.jac_prototype
end
maxcolor = maximum(colorvec)
chunksize = getsize(default_chunk_size(maxcolor))
num_of_chunks = Int(ceil(maxcolor / chunksize))
(forwarddiff_color_jacobian(f,x,colorvec = colorvec, sparsity = sparsity,
jac_prototype = jac_prototype),
num_of_chunks)
end
#jacobian_autodiff(f, x::AbstractArray, colorvec) = (ForwardDiff.jacobian(f, x, colorvec = colorvec),1)

function _nfcount(N,diff_type)
if diff_type==Val{:complex}
Expand All @@ -77,52 +66,56 @@ function _nfcount(N,diff_type)
tmp
end

jacobian_finitediff(f, x, diff_type, integrator) =
(DiffEqDiffTools.finite_difference_derivative(f, x, diff_type, eltype(x), dir = diffdir(integrator)),2)
jacobian_finitediff(f, x::AbstractArray, diff_type, integrator) =
(DiffEqDiffTools.finite_difference_jacobian(f, x, diff_type, eltype(x), Val{false},
dir = diffdir(integrator)),_nfcount(length(x),diff_type))
jacobian_finitediff(f, x::AbstractArray, diff_type, integrator, colorvec) =
(DiffEqDiffTools.finite_difference_jacobian(f, x, diff_type, eltype(x), Val{false},
dir = diffdir(integrator), colorvec = colorvec, sparsity = integrator.f.jac_prototype),_nfcount(maximum(colorvec),diff_type))
jacobian_finitediff(f, x, diff_type, dir, colorvec, sparsity, jac_prototype) =
(DiffEqDiffTools.finite_difference_derivative(f, x, diff_type, eltype(x), dir = dir),2)
jacobian_finitediff(f, x::AbstractArray, diff_type, dir, colorvec, sparsity, jac_prototype) =
(DiffEqDiffTools.finite_difference_jacobian(f, x, diff_type, eltype(x), diff_type==Val{:forward} ? f(x) : similar(x),
dir = dir, colorvec = colorvec, sparsity = sparsity, jac_prototype = jac_prototype),_nfcount(maximum(colorvec),diff_type))

function jacobian(f, x, integrator)
alg = unwrap_alg(integrator, true)
local tmp
if alg_autodiff(alg)
J, tmp = jacobian_autodiff(f, x, integrator.f)
else
if DiffEqBase.has_colorvec(integrator.f)
colorvec = integrator.f.colorvec
sparsity = integrator.f.jac_prototype
jac_prototype = nothing
else
colorvec = 1:length(x)
sparsity = nothing
jac_prototype = integrator.f.jac_prototype
end
dir = diffdir(integrator)
J, tmp = jacobian_finitediff(f, x, alg.diff_type, dir, colorvec, sparsity, jac_prototype)
end
integrator.destats.nf += tmp
J
end

jacobian_finitediff_forward!(J,f,x,jac_config,forwardcache,integrator)=(DiffEqDiffTools.finite_difference_jacobian!(J,f,x,jac_config,forwardcache,dir=diffdir(integrator));length(x))
jacobian_finitediff_forward!(J,f,x,jac_config,forwardcache,integrator,colorvec)=
(DiffEqDiffTools.finite_difference_jacobian!(J,f,x,jac_config,forwardcache,
dir=diffdir(integrator),colorvec=colorvec,sparsity=integrator.f.jac_prototype);maximum(colorvec))
jacobian_finitediff!(J,f,x,jac_config,integrator)=(DiffEqDiffTools.finite_difference_jacobian!(J,f,x,jac_config,dir=diffdir(integrator));2*length(x))
jacobian_finitediff!(J,f,x,jac_config,integrator,colorvec)=
(DiffEqDiffTools.finite_difference_jacobian!(J,f,x,jac_config,
dir=diffdir(integrator),colorvec=colorvec,sparsity=integrator.f.jac_prototype);2*maximum(colorvec))
jacobian_autodiff!(J,f,x,jac_config)=forwarddiff_color_jacobian!(J,f,x,jac_config)#J::SparseMatrixCSC
jacobian_autodiff!(J,f,fx,x,jac_config)=ForwardDiff.jacobian!(J,f,fx,x,jac_config)

function jacobian!(J::AbstractMatrix{<:Number}, f, x::AbstractArray{<:Number}, fx::AbstractArray{<:Number}, integrator::DiffEqBase.DEIntegrator, jac_config)
alg = unwrap_alg(integrator, true)
if alg_autodiff(alg)
if DiffEqBase.has_colorvec(integrator.f)
jacobian_autodiff!(J,f,x,jac_config)
else
jacobian_autodiff!(J,f,fx,x,jac_config)
end
forwarddiff_color_jacobian!(J,f,x,jac_config)
integrator.destats.nf += 1
else
colorvec = DiffEqBase.has_colorvec(integrator.f) ? integrator.f.colorvec : 1:length(x)
isforward = alg.diff_type === Val{:forward}
if isforward
forwardcache = get_tmp_cache(integrator, alg, unwrap_cache(integrator, true))[2]
f(forwardcache, x)
integrator.destats.nf += 1
if DiffEqBase.has_colorvec(integrator.f)
tmp=jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, integrator, integrator.f.colorvec)
else
tmp=jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, integrator)
end
tmp=jacobian_finitediff_forward!(J, f, x, jac_config, forwardcache, integrator, colorvec)
else # not forward difference
if DiffEqBase.has_colorvec(integrator.f)
tmp=jacobian_finitediff!(J, f, x, jac_config, integrator, integrator.f.colorvec)
else
tmp=jacobian_finitediff!(J, f, x, jac_config, integrator)
end
tmp=jacobian_finitediff!(J, f, x, jac_config, integrator, colorvec)
end
integrator.destats.nf += tmp
end
Expand All @@ -133,10 +126,13 @@ function DiffEqBase.build_jac_config(alg::Union{OrdinaryDiffEqAlgorithm,DAEAlgor
if !DiffEqBase.has_jac(f)
if alg_autodiff(alg)
if DiffEqBase.has_colorvec(f)
jac_config = ForwardColorJacCache(uf,uprev,colorvec=f.colorvec,sparsity=f.jac_prototype)
colorvec = f.colorvec
sparsity = f.jac_prototype
else
jac_config = ForwardDiff.JacobianConfig(uf,du1,uprev,ForwardDiff.Chunk{determine_chunksize(u,alg)}())
colorvec = 1:length(uprev)
sparsity = nothing
end
jac_config = ForwardColorJacCache(uf,uprev,colorvec=colorvec,sparsity=sparsity)
else
if alg.diff_type != Val{:complex}
jac_config = DiffEqDiffTools.JacobianCache(tmp,du1,du2,alg.diff_type)
Expand All @@ -152,10 +148,14 @@ end

get_chunksize(jac_config::ForwardDiff.JacobianConfig{T,V,N,D}) where {T,V,N,D} = N

function DiffEqBase.resize_jac_config!(jac_config::ForwardDiff.JacobianConfig, i)
for j in eachindex(jac_config.duals)
resize!(jac_config.duals[j], i)
end
function DiffEqBase.resize_jac_config!(jac_config::SparseDiffTools.ForwardColorJacCache, i)
resize!(jac_config.fx, i)
resize!(jac_config.dx, i)
resize!(jac_config.t, i)
resize!(jac_config.p, i)
jac_config.p .= SparseDiffTools.adapt.(typeof(jac_config.dx),
SparseDiffTools.generate_chunked_partials(jac_config.dx,
1:length(jac_config.dx),Val(ForwardDiff.npartials(jac_config.t[1]))))
jac_config
end

Expand Down
12 changes: 8 additions & 4 deletions test/integrators/resize_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,10 @@ resize!(i, 5)
@test size(i.cache.nlsolver.cache.J) == (5,5)
@test size(i.cache.nlsolver.cache.W) == (5,5)
@test length(i.cache.nlsolver.cache.du1) == 5
@test length(i.cache.nlsolver.cache.jac_config.duals[1]) == 5
@test length(i.cache.nlsolver.cache.jac_config.duals[2]) == 5
@test length(i.cache.nlsolver.cache.jac_config.fx) == 5
@test length(i.cache.nlsolver.cache.jac_config.dx) == 5
@test length(i.cache.nlsolver.cache.jac_config.t) == 5
@test length(i.cache.nlsolver.cache.jac_config.p) == 5
@test length(i.cache.nlsolver.cache.weight) == 5
solve!(i)

Expand Down Expand Up @@ -75,8 +77,10 @@ resize!(i, 5)
@test size(i.cache.J) == (5, 5)
@test size(i.cache.W) == (5, 5)
@test length(i.cache.linsolve_tmp) == 5
@test length(i.cache.jac_config.duals[1]) == 5
@test length(i.cache.jac_config.duals[2]) == 5
@test length(i.cache.jac_config.fx) == 5
@test length(i.cache.jac_config.dx) == 5
@test length(i.cache.jac_config.t) == 5
@test length(i.cache.jac_config.p) == 5
solve!(i)

i = init(prob, Rosenbrock23(;autodiff=false))
Expand Down
4 changes: 2 additions & 2 deletions test/interface/data_array_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ end

tstop =[tstop1;tstop2]
sol = solve(prob,Tsit5(),callback = cbs, tstops=tstop)
sol = solve(prob,Rodas4(),callback = cbs, tstops=tstop)
sol = solve(prob,Kvaerno3(),callback = cbs, tstops=tstop)
@test_broken sol = solve(prob,Rodas4(),callback = cbs, tstops=tstop)
@test_broken sol = solve(prob,Kvaerno3(),callback = cbs, tstops=tstop)
@test_broken sol = solve(prob,Rodas4(autodiff=false),callback = cbs, tstops=tstop)
@test_broken sol = solve(prob,Kvaerno3(autodiff=false),callback = cbs, tstops=tstop)
end

0 comments on commit 75ff4b6

Please sign in to comment.