Skip to content

Commit

Permalink
add test for non-vector inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
Langwen Huang committed Nov 3, 2019
1 parent dacd688 commit 778c332
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/jacobians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,25 @@ function finite_difference_jacobian!(J::AbstractMatrix,
x::AbstractArray{<:Number},
fdtype :: Type{T1}=Val{:forward},
returntype :: Type{T2}=eltype(x),
f_in :: Union{T2,Nothing}=nothing;
f_in :: Union{AbstractArray{<:T2},Nothing}=nothing;
relstep=default_relstep(fdtype, eltype(x)),
absstep=relstep,
colorvec = 1:length(x),
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing) where {T1,T2,T3}
cache = JacobianCache(x, fdtype, returntype)
finite_difference_jacobian!(J, f, x, cache, f_in; relstep=relstep, absstep=absstep, colorvec=colorvec, sparsity=sparsity)
sparsity = ArrayInterface.has_sparsestruct(J) ? J : nothing) where {T1,T2}
if f_in isa Nothing && fdtype == Val{:forward}
if size(J,1) == length(x)
fx = zero(x)
else
fx = zeros(returntype,size(J,1))
end
f(fx,x)
cache = JacobianCache(x, fx, fdtype, returntype)
elseif f_in isa Nothing
cache = JacobianCache(x, fdtype, returntype)
else
cache = JacobianCache(x, f_in, fdtype, returntype)
end
finite_difference_jacobian!(J, f, x, cache, cache.fx; relstep=relstep, absstep=absstep, colorvec=colorvec, sparsity=sparsity)
end

function finite_difference_jacobian!(
Expand Down
14 changes: 14 additions & 0 deletions test/finitedifftests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,20 @@ end
@test err_func(test_iipJac(J_ref, iipf, x, central_cache), J_ref) < 1e-8
@test err_func(test_iipJac(J_ref, iipf, x, Val{:central}), J_ref) < 1e-8
end

# Non vector input
x = rand(2,2)
oopf(x) = x
iipf(fx,x) = (fx.=x)
J_ref = Matrix{Float64}(I,4,4)
@time @testset "Jacobian for non-vector inputs" begin
@test err_func(DiffEqDiffTools.finite_difference_jacobian(oopf, x, Val{:forward}), J_ref) < 1e-8
@test err_func(DiffEqDiffTools.finite_difference_jacobian(oopf, x, Val{:central}), J_ref) < 1e-8
@test err_func(DiffEqDiffTools.finite_difference_jacobian(oopf, x, Val{:complex}), J_ref) < 1e-8
@test err_func(test_iipJac(J_ref, iipf, x, Val{:forward}, eltype(x), iipf(similar(x),x)), J_ref) < 1e-8
@test err_func(test_iipJac(J_ref, iipf, x, Val{:central}, eltype(x), iipf(similar(x),x)), J_ref) < 1e-8
@test err_func(test_iipJac(J_ref, iipf, x, Val{:complex}, eltype(x), iipf(similar(x),x)), J_ref) < 1e-8
end
# Hessian tests

f(x) = sin(x[1]) + cos(x[2])
Expand Down

0 comments on commit 778c332

Please sign in to comment.