Skip to content

Commit

Permalink
fix forward over reverse
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jun 28, 2019
1 parent 54b16c2 commit 26a3fc0
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 56 deletions.
8 changes: 4 additions & 4 deletions src/differentiation/jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -137,19 +137,19 @@ function autonum_hesvec(f,x,v)
end

function autoback_hesvec!(du,f,x,v,
cache2 = ForwardDiff.Dual{DeivVecTag}.(x, v),
cache3 = ForwardDiff.Dual{DeivVecTag}.(x, v))
cache2 = ForwardDiff.Dual{Nothing}.(x, v),
cache3 = ForwardDiff.Dual{Nothing}.(x, v))
g = let f=f
g = (dx,x) -> dx .= first(Zygote.gradient(f,x))
end
cache2 .= Dual{DeivVecTag}.(x, v)
cache2 .= Dual{Nothing}.(x, v)
g(cache3,cache2)
du .= partials.(cache3, 1)
end

function autoback_hesvec(f,x,v)
g = (x) -> first(Zygote.gradient(f,x))
partials.(g(Dual{DeivVecTag}.(x, v)), 1)
ForwardDiff.partials.(g(ForwardDiff.Dual{Nothing}.(x, v)), 1)
end

function num_hesvecgrad!(du,g,x,v,
Expand Down
102 changes: 50 additions & 52 deletions test/test_jaches_products.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@ f(u) = A*u
x = rand(300)
v = rand(300)
du = similar(x)

g(u) = sum(abs2,u)
function h(x)
DiffEqDiffTools.finite_difference_gradient(g,x)
end
function h(dx,x)
DiffEqDiffTools.finite_difference_gradient!(dx,g,x)
end

cache1 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
Expand All @@ -19,43 +25,36 @@ cache2 = ForwardDiff.Dual{SparseDiffTools.DeivVecTag}.(x, v)
@test auto_jacvec!(du, f, x, v, cache1, cache2) ForwardDiff.jacobian(f,similar(x),x)*v
@test auto_jacvec(f, x, v) ForwardDiff.jacobian(f,similar(x),x)*v

f(u) = sum(u.^2)
@test num_hesvec!(du, f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test num_hesvec!(du, f, x, v, similar(v), similar(v), similar(v)) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test num_hesvec(f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test num_hesvec!(du, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test num_hesvec!(du, g, x, v, similar(v), similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test num_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2

@test numauto_hesvec!(du, f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test numauto_hesvec!(du, f, x, v, ForwardDiff.GradientConfig(f,x), similar(v), similar(v)) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test numauto_hesvec(f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test numauto_hesvec!(du, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test numauto_hesvec!(du, g, x, v, ForwardDiff.GradientConfig(g,x), similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test numauto_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

@test autonum_hesvec!(du, f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test autonum_hesvec!(du, f, x, v, similar(v), cache1, cache2) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test autonum_hesvec(f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test autonum_hesvec!(du, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test autonum_hesvec!(du, g, x, v, similar(v), cache1, cache2) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test autonum_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

@test numback_hesvec!(du, f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test numback_hesvec!(du, f, x, v, similar(v), similar(v)) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test numback_hesvec(f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test numback_hesvec!(du, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test numback_hesvec!(du, g, x, v, similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test numback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

@test_broken autoback_hesvec!(du, f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test_broken autoback_hesvec!(du, f, x, v, similar(v), similar(v)) ForwardDiff.hessian(f,x)*v rtol=1e-8
@test_broken autoback_hesvec(f, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-8
cache3 = ForwardDiff.Dual{Nothing}.(x, v)
cache4 = ForwardDiff.Dual{Nothing}.(x, v)
@test autoback_hesvec!(du, g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test autoback_hesvec!(du, g, x, v, cache3, cache4) ForwardDiff.hessian(g,x)*v rtol=1e-8
@test autoback_hesvec(g, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-8

function g(x)
DiffEqDiffTools.finite_difference_gradient(f,x)
end
function g(dx,x)
DiffEqDiffTools.finite_difference_gradient!(dx,f,x)
end
@test num_hesvecgrad!(du, g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test num_hesvecgrad!(du, g, x, v, similar(v), similar(v)) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test num_hesvecgrad(g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test num_hesvecgrad!(du, h, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test num_hesvecgrad!(du, h, x, v, similar(v), similar(v)) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test num_hesvecgrad(h, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2

@test auto_hesvecgrad!(du, g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test auto_hesvecgrad!(du, g, x, v, cache1, cache2) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test auto_hesvecgrad(g, x, v) ForwardDiff.hessian(f,x)*v rtol=1e-2
@test auto_hesvecgrad!(du, h, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test auto_hesvecgrad!(du, h, x, v, cache1, cache2) ForwardDiff.hessian(g,x)*v rtol=1e-2
@test auto_hesvecgrad(h, x, v) ForwardDiff.hessian(g,x)*v rtol=1e-2

f(du,u) = mul!(du,A,u)
f(u) = A*u
L = JacVec(f,x)
@test L*x auto_jacvec(f, x, x)
@test L*v auto_jacvec(f, x, v)
Expand All @@ -75,42 +74,41 @@ L.u .= v
out = similar(v)
gmres!(out, L, v)

f(u) = sum(u.^2)
x = rand(300)
v = rand(300)
L = HesVec(f,x,autodiff=false)
@test L*x num_hesvec(f, x, x)
@test L*v num_hesvec(f, x, v)
@test mul!(du,L,v) num_hesvec(f, x, v) rtol=1e-2
L = HesVec(g,x,autodiff=false)
@test L*x num_hesvec(g, x, x)
@test L*v num_hesvec(g, x, v)
@test mul!(du,L,v) num_hesvec(g, x, v) rtol=1e-2
L.u .= v
@test mul!(du,L,v) num_hesvec(f, v, v) rtol=1e-2
@test mul!(du,L,v) num_hesvec(g, v, v) rtol=1e-2

L = HesVec(f,x)
@test L*x numauto_hesvec(f, x, x)
@test L*v numauto_hesvec(f, x, v)
@test mul!(du,L,v) numauto_hesvec(f, x, v) rtol=1e-8
L = HesVec(g,x)
@test L*x numauto_hesvec(g, x, x)
@test L*v numauto_hesvec(g, x, v)
@test mul!(du,L,v) numauto_hesvec(g, x, v) rtol=1e-8
L.u .= v
@test mul!(du,L,v) numauto_hesvec(f, v, v) rtol=1e-8
@test mul!(du,L,v) numauto_hesvec(g, v, v) rtol=1e-8

### Integration test with IterativeSolvers
out = similar(v)
gmres!(out, L, v)

x = rand(300)
v = rand(300)
L = HesVecGrad(g,x,autodiff=false)
@test L*x num_hesvec(f, x, x)
@test L*v num_hesvec(f, x, v)
@test mul!(du,L,v) num_hesvec(f, x, v) rtol=1e-2
L = HesVecGrad(h,x,autodiff=false)
@test L*x num_hesvec(g, x, x)
@test L*v num_hesvec(g, x, v)
@test mul!(du,L,v) num_hesvec(g, x, v) rtol=1e-2
L.u .= v
@test mul!(du,L,v) num_hesvec(f, v, v) rtol=1e-2
@test mul!(du,L,v) num_hesvec(g, v, v) rtol=1e-2

L = HesVecGrad(g,x,autodiff=true)
@test L*x autonum_hesvec(f, x, x)
@test L*v numauto_hesvec(f, x, v)
@test mul!(du,L,v) numauto_hesvec(f, x, v) rtol=1e-8
L = HesVecGrad(h,x,autodiff=true)
@test L*x autonum_hesvec(g, x, x)
@test L*v numauto_hesvec(g, x, v)
@test mul!(du,L,v) numauto_hesvec(g, x, v) rtol=1e-8
L.u .= v
@test mul!(du,L,v) numauto_hesvec(f, v, v) rtol=1e-8
@test mul!(du,L,v) numauto_hesvec(g, v, v) rtol=1e-8

### Integration test with IterativeSolvers
out = similar(v)
Expand Down

0 comments on commit 26a3fc0

Please sign in to comment.