diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index fcc59b6e..5132391e 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -215,7 +215,8 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i utmp = NVector(_u0) use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) && - LinearSolver ∈ SPARSE_SOLVERS) + LinearSolver ∈ SPARSE_SOLVERS) || + prob.f.jac_prototype isa AbstractSciMLOperator userfun = FunJac(f!, prob.f.jac, prob.p, @@ -341,6 +342,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i end if typeof(prob.f.jac_prototype) <: AbstractSciMLOperator + "here!!!!" function getcfunjtimes(::T) where {T} @cfunction(jactimes, Cint, diff --git a/src/nvector_wrapper.jl b/src/nvector_wrapper.jl index b4de0bc7..fe704a6d 100644 --- a/src/nvector_wrapper.jl +++ b/src/nvector_wrapper.jl @@ -77,6 +77,7 @@ Conversion happens in two steps within ccall: """ Base.cconvert(::Type{N_Vector}, v::Vector{realtype}) = convert(NVector, v) # will just return v if v is an NVector Base.unsafe_convert(::Type{N_Vector}, nv::NVector) = nv.n_v +Base.copy!(v::Vector, nv::Ptr{Sundials._generic_N_Vector}) = copy!(v, convert(NVector, nv)) Base.similar(nv::NVector) = NVector(similar(nv.v)) diff --git a/test/common_interface/jacobians.jl b/test/common_interface/jacobians.jl index 6c03e763..6558a0e6 100644 --- a/test/common_interface/jacobians.jl +++ b/test/common_interface/jacobians.jl @@ -21,6 +21,7 @@ end Lotka_f = ODEFunction(Lotka; jac = Lotka_jac) prob = ODEProblem(Lotka_f, ones(2), (0.0, 10.0)) good_sol = solve(prob, CVODE_BDF()) +testsol = solve(prob, CVODE_BDF(), saveat = 0.1, abstol = 1e-12, reltol = 1e-12) @test jac_called == true Lotka_f = ODEFunction(Lotka; @@ -33,10 +34,14 @@ sol9 = solve(prob, CVODE_BDF(; linear_solver = :KLU)) @test jac_called == true @test Array(sol9) ≈ Array(good_sol) -Lotka_fj = ODEFunction(Lotka; jac_prototype = JacVec(Lotka, ones(2))) +Lotka_fj = ODEFunction(Lotka; + jac_prototype = JacVec((du, u) -> Lotka(du, u, (), 0.0), ones(2), + SciMLBase.NullParameters())) prob = ODEProblem(Lotka_fj, ones(2), (0.0, 10.0)) -sol9 = solve(prob, CVODE_BDF(; linear_solver = :GMRES)) +sol9 = solve(prob, CVODE_BDF(; linear_solver = :GMRES), saveat = 0.1, abstol = 1e-12, + reltol = 1e-12) +@test Array(sol9) ≈ Array(testsol) function f2!(res, du, u, p, t) res[1] = 1.01du[1]