diff --git a/src/common_interface/function_types.jl b/src/common_interface/function_types.jl index e6b09213..4dc04eb2 100644 --- a/src/common_interface/function_types.jl +++ b/src/common_interface/function_types.jl @@ -1,5 +1,6 @@ abstract type AbstractFunJac{J2} end -mutable struct FunJac{F, F2, J, P, M, J2, uType, uType2, Prec, PS} <: AbstractFunJac{J2} +mutable struct FunJac{N, F, F2, J, P, M, J2, Prec, PS, + TResid <: Union{Nothing, Array{Float64, N}}} <: AbstractFunJac{J2} fun::F fun2::F2 jac::J @@ -8,9 +9,9 @@ mutable struct FunJac{F, F2, J, P, M, J2, uType, uType2, Prec, PS} <: AbstractFu jac_prototype::J2 prec::Prec psetup::PS - u::uType - du::uType - resid::uType2 + u::Array{Float64, N} + du::Array{Float64, N} + resid::TResid end function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du) FunJac(fun, nothing, jac, p, m, @@ -25,20 +26,20 @@ function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du, resid) du, resid) end -function cvodefunjac(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), - length(funjac.du)) +function cvodefunjac(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac{N}) where {N} + funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) + funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du), + size(funjac.du)) _du = funjac.du _u = funjac.u funjac.fun(_du, _u, funjac.p, t) return CV_SUCCESS end -function cvodefunjac2(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), - length(funjac.du)) +function cvodefunjac2(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac{N}) where {N} + funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) + funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du), + size(funjac.du)) _du = funjac.du _u = funjac.u funjac.fun2(_du, _u, funjac.p, t) @@ -79,14 +80,15 @@ function cvodejac(t::realtype, return CV_SUCCESS end -function idasolfun(t::Float64, u::N_Vector, du::N_Vector, resid::N_Vector, funjac::FunJac) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) +function idasolfun(t::Float64, u::N_Vector, du::N_Vector, resid::N_Vector, + funjac::FunJac{N}) where {N} + funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) _u = funjac.u - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), - length(funjac.du)) + funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du), + size(funjac.du)) _du = funjac.du - funjac.resid = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(resid), - length(funjac.resid)) + funjac.resid = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(resid), + size(funjac.resid)) _resid = funjac.resid funjac.fun(_resid, _du, _u, funjac.p, t) return IDA_SUCCESS @@ -102,10 +104,11 @@ function idajac(t::realtype, tmp1::N_Vector, tmp2::N_Vector, tmp3::N_Vector) - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) + N = ndims(funjac.u) + funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) _u = funjac.u - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), - length(funjac.du)) + funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du), + size(funjac.du)) _du = funjac.du funjac.jac(convert(Matrix, J), _du, _u, funjac.p, cj, t) @@ -123,11 +126,11 @@ function idajac(t::realtype, tmp2::N_Vector, tmp3::N_Vector) jac_prototype = funjac.jac_prototype - - funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u)) + N = ndims(funjac.u) + funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u)) _u = funjac.u - funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du), - length(funjac.du)) + funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du), + size(funjac.du)) _du = funjac.du funjac.jac(jac_prototype, _du, _u, funjac.p, cj, t) diff --git a/src/common_interface/integrator_types.jl b/src/common_interface/integrator_types.jl index 8102b241..523be46c 100644 --- a/src/common_interface/integrator_types.jl +++ b/src/common_interface/integrator_types.jl @@ -28,26 +28,23 @@ end abstract type AbstractSundialsIntegrator{algType} <: DiffEqBase.AbstractODEIntegrator{algType, true, Vector{Float64}, Float64} end -mutable struct CVODEIntegrator{uType, +mutable struct CVODEIntegrator{N, pType, - memType, solType, algType, fType, UFType, JType, oType, - toutType, - sizeType, - tmpType, LStype, Atype, CallbackCacheType} <: AbstractSundialsIntegrator{algType} - u::uType + u::Array{Float64, N} + u_nvec::NVector p::pType t::Float64 tprev::Float64 - mem::memType + mem::Handle{CVODEMem} LS::LStype A::Atype sol::solType @@ -56,12 +53,11 @@ mutable struct CVODEIntegrator{uType, userfun::UFType jac::JType opts::oType - tout::toutType + tout::Vector{Float64} tdir::Float64 - sizeu::sizeType u_modified::Bool - tmp::tmpType - uprev::tmpType + tmp::Array{Float64, N} + uprev::Array{Float64, N} flag::Cint just_hit_tstop::Bool event_last_time::Int @@ -74,7 +70,7 @@ function (integrator::CVODEIntegrator)(t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {T} out = similar(integrator.u) - integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), out) + integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), vec(out)) return idxs === nothing ? out : out[idxs] end @@ -82,32 +78,29 @@ function (integrator::CVODEIntegrator)(out, t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {T} - integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), out) + integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), vec(out)) return idxs === nothing ? out : @view out[idxs] end -mutable struct ARKODEIntegrator{uType, +mutable struct ARKODEIntegrator{N, pType, - memType, solType, algType, fType, UFType, JType, oType, - toutType, - sizeType, - tmpType, LStype, Atype, MLStype, Mtype, CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE} - u::uType + u::Array{Float64, N} + u_nvec::NVector p::pType t::Float64 tprev::Float64 - mem::memType + mem::Handle{ARKStepMem} LS::LStype A::Atype MLS::MLStype @@ -118,12 +111,11 @@ mutable struct ARKODEIntegrator{uType, userfun::UFType jac::JType opts::oType - tout::toutType + tout::Vector{Float64} tdir::Float64 - sizeu::sizeType u_modified::Bool - tmp::tmpType - uprev::tmpType + tmp::Array{Float64, N} + uprev::Array{Float64, N} flag::Cint just_hit_tstop::Bool event_last_time::Int @@ -136,7 +128,7 @@ function (integrator::ARKODEIntegrator)(t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {T} out = similar(integrator.u) - integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out) + integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), vec(out)) return idxs === nothing ? out : out[idxs] end @@ -144,34 +136,28 @@ function (integrator::ARKODEIntegrator)(out, t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {T} - integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out) + integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), vec(out)) return idxs === nothing ? out : @view out[idxs] end -mutable struct IDAIntegrator{uType, - duType, +mutable struct IDAIntegrator{N, pType, - memType, solType, algType, fType, UFType, JType, oType, - toutType, - sizeType, - sizeDType, - tmpType, LStype, Atype, CallbackCacheType, IA} <: AbstractSundialsIntegrator{IDA} - u::uType - du::duType + u::Array{Float64, N} + du::Array{Float64, N} p::pType t::Float64 tprev::Float64 - mem::memType + mem::Handle{IDAMem} LS::LStype A::Atype sol::solType @@ -180,19 +166,19 @@ mutable struct IDAIntegrator{uType, userfun::UFType jac::JType opts::oType - tout::toutType + tout::Vector{Float64} tdir::Float64 - sizeu::sizeType - sizedu::sizeDType u_modified::Bool - tmp::tmpType - uprev::tmpType + tmp::Array{Float64, N} + uprev::Array{Float64, N} flag::Cint just_hit_tstop::Bool event_last_time::Int vector_event_last_time::Int callback_cache::CallbackCacheType last_event_error::Float64 + u_nvec::NVector + du_nvec::NVector initializealg::IA end @@ -200,7 +186,7 @@ function (integrator::IDAIntegrator)(t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {T} out = similar(integrator.u) - integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), out) + integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), vec(out)) return idxs === nothing ? out : out[idxs] end @@ -208,7 +194,7 @@ function (integrator::IDAIntegrator)(out, t::Number, deriv::Type{Val{T}} = Val{0}; idxs = nothing) where {T} - integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), out) + integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), vec(out)) return idxs === nothing ? out : @view out[idxs] end diff --git a/src/common_interface/integrator_utils.jl b/src/common_interface/integrator_utils.jl index 6d530a72..365f02e6 100644 --- a/src/common_interface/integrator_utils.jl +++ b/src/common_interface/integrator_utils.jl @@ -51,12 +51,12 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, curt = pop!(integrator.opts.saveat) tmp = integrator(curt) - save_value!(integrator.sol.u, tmp, uType, integrator.sizeu, + save_value!(integrator.sol.u, tmp, uType, integrator.opts.save_idxs, Val{false}) push!(integrator.sol.t, curt) if integrator.opts.dense tmp = integrator(curt, Val{1}) - save_value!(integrator.sol.interp.du, tmp, uType, integrator.sizeu, + save_value!(integrator.sol.interp.du, tmp, uType, integrator.opts.save_idxs, Val{false}) end end @@ -65,12 +65,12 @@ function DiffEqBase.savevalues!(integrator::AbstractSundialsIntegrator, (integrator.opts.save_everystep && (isempty(integrator.sol.t) || (integrator.t !== integrator.sol.t[end]))) saved = true - save_value!(integrator.sol.u, integrator.u, uType, integrator.sizeu, + save_value!(integrator.sol.u, integrator.u, uType, integrator.opts.save_idxs) push!(integrator.sol.t, integrator.t) if integrator.opts.dense tmp = integrator(integrator.t, Val{1}) - save_value!(integrator.sol.interp.du, tmp, uType, integrator.sizeu, + save_value!(integrator.sol.interp.du, tmp, uType, integrator.opts.save_idxs) end end @@ -81,15 +81,16 @@ end function save_value!(save_array, val, ::Type{T}, - sizeu, save_idxs, + save_idxs, make_copy::Type{Val{bool}} = Val{true}) where {T <: Number, bool} push!(save_array, first(val)) end function save_value!(save_array, val, ::Type{T}, - sizeu, save_idxs, + save_idxs, make_copy::Type{Val{bool}} = Val{true}) where {T <: Vector, bool} + @assert val isa Array save = if save_idxs !== nothing val[save_idxs] else @@ -100,20 +101,20 @@ end function save_value!(save_array, val, ::Type{T}, - sizeu, save_idxs, + save_idxs, make_copy::Type{Val{bool}} = Val{true}) where {T <: AbstractArray, bool } + @assert val isa Array save = if save_idxs !== nothing - reshape(val, sizeu)[save_idxs] + val[save_idxs] else x = bool ? copy(val) : val - reshape(x, sizeu) end push!(save_array, save) end function handle_callback_modifiers!(integrator::CVODEIntegrator) - CVodeReInit(integrator.mem, integrator.t, integrator.u) + CVodeReInit(integrator.mem, integrator.t, integrator.u_nvec) end function handle_callback_modifiers!(integrator::ARKODEIntegrator) @@ -159,11 +160,11 @@ end end @inline function DiffEqBase.get_du(integrator::IDAIntegrator) - reshape(integrator.du, integrator.sizedu) + integrator.du end @inline function DiffEqBase.get_du!(out, integrator::IDAIntegrator) - out .= reshape(integrator.du, integrator.sizedu) + out .= integrator.du end function DiffEqBase.change_t_via_interpolation!(integrator::AbstractSundialsIntegrator, t) @@ -208,7 +209,7 @@ function DiffEqBase.initialize_dae!(integrator::IDAIntegrator, else init_type = IDA_YA_YDP_INIT integrator.flag = IDASetId(integrator.mem, - integrator.sol.prob.differential_vars) + vec(integrator.sol.prob.differential_vars)) end dt = integrator.dt == tstart ? tend : integrator.dt integrator.flag = IDACalcIC(integrator.mem, init_type, dt) diff --git a/src/common_interface/solve.jl b/src/common_interface/solve.jl index df756b0d..1d43dd39 100644 --- a/src/common_interface/solve.jl +++ b/src/common_interface/solve.jl @@ -31,13 +31,12 @@ function DiffEqBase.__solve(prob::Union{ recompile::Type{Val{recompile_flag}} = Val{true}; kwargs...) where {algType <: SundialsNonlinearSolveAlgorithm, recompile_flag, uType, isinplace} - if typeof(prob.u0) <: Number + if prob.u0 isa Number u0 = [prob.u0] else u0 = deepcopy(prob.u0) end - sizeu = size(prob.u0) p = prob.p userdata = alg.userdata linsolve = linear_solver(alg) @@ -45,34 +44,22 @@ function DiffEqBase.__solve(prob::Union{ jac_lower = alg.jac_lower ### Fix the more general function to Sundials allowed style - if typeof(prob.f) <: ODEFunction + if prob.f isa ODEFunction t = Inf - if !isinplace && typeof(prob.u0) <: Number + if !isinplace && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: Vector{Float64} + elseif !isinplace f! = (du, u) -> (du .= prob.f(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: AbstractArray - f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0) <: Vector{Float64} - f! = (du, u) -> prob.f(du, u, p, t) else # Then it's an in-place function on an abstract array - f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t); - du = vec(du); - 0) + f! = (du, u) -> prob.f(du, u, p, t) end - elseif typeof(prob.f) <: NonlinearFunction - if !isinplace && typeof(prob.u0) <: Number + elseif prob.f isa NonlinearFunction + if !isinplace && prob.u0 isa Number f! = (du, u) -> (du .= prob.f(first(u), p); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: Vector{Float64} + elseif !isinplace f! = (du, u) -> (du .= prob.f(u, p); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: AbstractArray - f! = (du, u) -> (du .= vec(prob.f(reshape(u, sizeu), p)); Cint(0)) - elseif typeof(prob.u0) <: Vector{Float64} - f! = (du, u) -> prob.f(du, u, p) else # Then it's an in-place function on an abstract array - f! = (du, u) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p); - du = vec(du); - 0) + f! = (du, u) -> prob.f(du, u, p) end end u = zero(u0) @@ -107,10 +94,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i save_everystep = isempty(saveat), save_idxs = nothing, save_on = true, save_start = save_everystep || isempty(saveat) || - typeof(saveat) <: Number ? true : + saveat isa Number ? true : prob.tspan[1] in saveat, save_end = save_everystep || isempty(saveat) || - typeof(saveat) <: Number ? true : + saveat isa Number ? true : prob.tspan[2] in saveat, dense = save_everystep && isempty(saveat), progress = false, @@ -134,7 +121,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i error("This solver is not able to use mass matrices.") end - if typeof(reltol) <: AbstractArray + if reltol isa AbstractArray error("Sundials only allows scalar reltol.") end @@ -161,36 +148,28 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) - if typeof(prob.u0) <: Number + if prob.u0 isa Number u0 = [prob.u0] else if alias_u0 - u0 = vec(prob.u0) + u0 = prob.u0 else - u0 = vec(copy(prob.u0)) + u0 = copy(prob.u0) end end - sizeu = size(prob.u0) - ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0) <: Number + if !isinplace && prob.u0 isa Number f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: Vector{Float64} + elseif !isinplace f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: AbstractArray - f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0) <: Vector{Float64} - f! = prob.f else # Then it's an in-place function on an abstract array - f! = (du, u, p, t) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t); - du = vec(du); - 0) + f! = prob.f end - if typeof(alg) <: CVODE_BDF + if alg isa CVODE_BDF alg_code = CV_BDF - elseif typeof(alg) <: CVODE_Adams + elseif alg isa CVODE_Adams alg_code = CV_ADAMS end @@ -211,8 +190,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i save_start ? ts = [t0] : ts = Float64[] - _u0 = copy(u0) - utmp = NVector(_u0) + out = copy(u0) + uvec = vec(u0) # aliases u0 + utmp = NVector(uvec) # aliases u0 use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) && LinearSolver ∈ SPARSE_SOLVERS) || @@ -225,7 +205,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i alg.prec, alg.psetup, u0, - _u0) + out) function getcfunf(::T) where {T} @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) @@ -237,7 +217,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i flag = CVodeSetMinStep(mem, dtmin) flag = CVodeSetMaxStep(mem, dtmax) flag = CVodeSetUserData(mem, userfun) - if typeof(abstol) <: Array + if abstol isa Array flag = CVodeSVtolerances(mem, reltol, abstol) else flag = CVodeSStolerances(mem, reltol, abstol) @@ -255,24 +235,24 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i if Method == :Newton # Only use a linear solver if it's a Newton-based method if LinearSolver in (:Dense, :LapackDense) nojacobian = false - A = SUNDenseMatrix(length(u0), length(u0)) + A = SUNDenseMatrix(length(uvec), length(uvec)) _A = MatrixHandle(A, DenseMatrix()) if LinearSolver === :Dense - LS = SUNLinSol_Dense(u0, A) + LS = SUNLinSol_Dense(uvec, A) _LS = LinSolHandle(LS, Dense()) else - LS = SUNLinSol_LapackDense(u0, A) + LS = SUNLinSol_LapackDense(uvec, A) _LS = LinSolHandle(LS, LapackDense()) end elseif LinearSolver in (:Band, :LapackBand) nojacobian = false - A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) + A = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower) _A = MatrixHandle(A, BandMatrix()) if LinearSolver === :Band - LS = SUNLinSol_Band(u0, A) + LS = SUNLinSol_Band(uvec, A) _LS = LinSolHandle(LS, Band()) else - LS = SUNLinSol_LapackBand(u0, A) + LS = SUNLinSol_LapackBand(uvec, A) _LS = LinSolHandle(LS, LapackBand()) end elseif LinearSolver == :Diagonal @@ -281,43 +261,43 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i _A = nothing _LS = nothing elseif LinearSolver == :GMRES - LS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = Sundials.LinSolHandle(LS, Sundials.SPGMR()) elseif LinearSolver == :FGMRES - LS = SUNLinSol_SPFGMR(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPFGMR()) elseif LinearSolver == :BCG - LS = SUNLinSol_SPBCGS(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPBCGS()) elseif LinearSolver == :PCG - LS = SUNLinSol_PCG(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_PCG(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, PCG()) elseif LinearSolver == :TFQMR - LS = SUNLinSol_SPTFQMR(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, PTFQMR()) elseif LinearSolver == :KLU nojacobian = false nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) - A = SUNSparseMatrix(length(u0), length(u0), nnz, CSC_MAT) - LS = SUNLinSol_KLU(u0, A) + A = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT) + LS = SUNLinSol_KLU(uvec, A) _A = MatrixHandle(A, SparseMatrix()) _LS = LinSolHandle(LS, KLU()) end if LinearSolver !== :Diagonal flag = CVodeSetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) end - NLS = SUNNonlinSol_Newton(u0) + NLS = SUNNonlinSol_Newton(uvec) else _A = nothing _LS = nothing # TODO: Anderson Acceleration anderson_m = 0 - NLS = SUNNonlinSol_FixedPoint(u0, anderson_m) + NLS = SUNNonlinSol_FixedPoint(uvec, anderson_m) end CVodeSetNonlinearSolver(mem, NLS) @@ -341,7 +321,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i jac = nothing end - if typeof(prob.f.jac_prototype) <: AbstractSciMLOperator + if prob.f.jac_prototype isa AbstractSciMLOperator "here!!!!" function getcfunjtimes(::T) where {T} @cfunction(jactimes, @@ -378,24 +358,24 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i CVodeSetPreconditioner(mem, psetupfun, precfun) end - callbacks_internal === nothing ? tmp = nothing : tmp = similar(u0) - callbacks_internal === nothing ? uprev = nothing : uprev = similar(u0) + tmp = isnothing(callbacks_internal) ? u0 : similar(u0) + uprev = isnothing(callbacks_internal) ? u0 : similar(u0) tout = [tspan[1]] if save_start if save_idxs === nothing ures = Vector{uType}() dures = Vector{uType}() - save_value!(ures, u0, uType, sizeu, save_idxs) + save_value!(ures, u0, uType, save_idxs) if dense - f!(_u0, u0, prob.p, tspan[1]) - save_value!(dures, utmp, uType, sizeu, save_idxs) + f!(out, u0, prob.p, tspan[1]) + save_value!(dures, out, uType, save_idxs) end else ures = [u0[save_idxs]] if dense - f!(_u0, u0, prob.p, tspan[1]) - dures = [_u0[save_idxs]] + f!(out, u0, prob.p, tspan[1]) + dures = [out[save_idxs]] end end else @@ -435,6 +415,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i progress_message, maxiters) integrator = CVODEIntegrator(u0, + utmp, prob.p, t0, t0, @@ -449,7 +430,6 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i opts, tout, tdir, - sizeu, false, tmp, uprev, @@ -503,7 +483,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i warned && DiffEqBase.warn_compat() end - if typeof(reltol) <: AbstractArray + if reltol isa AbstractArray error("Sundials only allows scalar reltol.") end @@ -530,23 +510,22 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) - if typeof(prob.u0) <: Number + if prob.u0 isa Number u0 = [prob.u0] else if alias_u0 - u0 = vec(prob.u0) + u0 = prob.u0 else - u0 = vec(copy(prob.u0)) + u0 = copy(prob.u0) end end - sizeu = size(prob.u0) save_start ? ts = [t0] : ts = Float64[] - u0nv = NVector(u0) - _u0 = copy(u0) - utmp = NVector(_u0) + out = copy(u0) + uvec = vec(u0) + utmp = NVector(uvec) - function arkodemem(; fe = C_NULL, fi = C_NULL, t0 = t0, u0 = u0nv) + function arkodemem(; fe = C_NULL, fi = C_NULL, t0 = t0, u0 = utmp) mem_ptr = ARKStepCreate(fe, fi, t0, u0) (mem_ptr == C_NULL) && error("Failed to allocate ARKODE solver object") mem = Handle(mem_ptr) @@ -559,42 +538,26 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i end ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0) <: Number + if !isinplace && prob.u0 isa Number f! = (du, u, p, t) -> (du .= prob.f(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: Vector{Float64} + elseif !isinplace f! = (du, u, p, t) -> (du .= prob.f(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: AbstractArray - f! = (du, u, p, t) -> (du .= vec(prob.f(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0) <: Vector{Float64} - f! = prob.f else # Then it's an in-place function on an abstract array - f! = (du, u, p, t) -> (prob.f(reshape(du, sizeu), reshape(u, sizeu), p, t); - du = vec(du); - Cint(0)) + f! = prob.f end - if typeof(prob.problem_type) <: SplitODEProblem + if prob.problem_type isa SplitODEProblem ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0) <: Number + if !isinplace && prob.u0 isa Number f1! = (du, u, p, t) -> (du .= prob.f.f1(first(u), p, t); Cint(0)) f2! = (du, u, p, t) -> (du .= prob.f.f2(first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: Vector{Float64} + elseif !isinplace f1! = (du, u, p, t) -> (du .= prob.f.f1(u, p, t); Cint(0)) f2! = (du, u, p, t) -> (du .= prob.f.f2(u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: AbstractArray - f1! = (du, u, p, t) -> (du .= vec(prob.f.f1(reshape(u, sizeu), p, t)); Cint(0)) - f2! = (du, u, p, t) -> (du .= vec(prob.f.f2(reshape(u, sizeu), p, t)); Cint(0)) - elseif typeof(prob.u0) <: Vector{Float64} + else # Then it's an in-place function on an abstract array f1! = prob.f.f1 f2! = prob.f.f2 - else # Then it's an in-place function on an abstract array - f1! = (du, u, p, t) -> (prob.f.f1(reshape(du, sizeu), reshape(u, sizeu), p, t); - du = vec(du); - Cint(0)) - f2! = (du, u, p, t) -> (prob.f.f2(reshape(du, sizeu), reshape(u, sizeu), p, t); - du = vec(du); - Cint(0)) end use_jac_prototype = (isa(prob.f.f1.jac_prototype, SparseArrays.SparseMatrixCSC) && @@ -608,7 +571,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i alg.prec, alg.psetup, u0, - _u0, + out, nothing) function getcfunjac(::T) where {T} @@ -632,7 +595,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i alg.prec, alg.psetup, u0, - _u0) + out) if alg.stiffness == Explicit() function getcfun1(::T) where {T} @cfunction(cvodefunjac, Cint, (realtype, N_Vector, N_Vector, Ref{T})) @@ -652,7 +615,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i flag = ARKStepSetMinStep(mem, dtmin) flag = ARKStepSetMaxStep(mem, dtmax) flag = ARKStepSetUserData(mem, userfun) - if typeof(abstol) <: Array + if abstol isa Array flag = ARKStepSVtolerances(mem, reltol, abstol) else flag = ARKStepSStolerances(mem, reltol, abstol) @@ -692,50 +655,50 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i if Method == :Newton && alg.stiffness !== Explicit() # Only use a linear solver if it's a Newton-based method if LinearSolver in (:Dense, :LapackDense) nojacobian = false - A = SUNDenseMatrix(length(u0), length(u0)) + A = SUNDenseMatrix(length(uvec), length(uvec)) _A = MatrixHandle(A, DenseMatrix()) if LinearSolver === :Dense - LS = SUNLinSol_Dense(u0, A) + LS = SUNLinSol_Dense(uvec, A) _LS = LinSolHandle(LS, Dense()) else - LS = SUNLinSol_LapackDense(u0, A) + LS = SUNLinSol_LapackDense(uvec, A) _LS = LinSolHandle(LS, LapackDense()) end elseif LinearSolver in (:Band, :LapackBand) nojacobian = false - A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) + A = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower) _A = MatrixHandle(A, BandMatrix()) if LinearSolver === :Band - LS = SUNLinSol_Band(u0, A) + LS = SUNLinSol_Band(uvec, A) _LS = LinSolHandle(LS, Band()) else - LS = SUNLinSol_LapackBand(u0, A) + LS = SUNLinSol_LapackBand(uvec, A) _LS = LinSolHandle(LS, LapackBand()) end elseif LinearSolver == :GMRES - LS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = Sundials.LinSolHandle(LS, Sundials.SPGMR()) elseif LinearSolver == :FGMRES - LS = SUNLinSol_SPFGMR(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPFGMR(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPFGMR()) elseif LinearSolver == :BCG - LS = SUNLinSol_SPBCGS(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPBCGS(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPBCGS()) elseif LinearSolver == :PCG - LS = SUNLinSol_PCG(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_PCG(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, PCG()) elseif LinearSolver == :TFQMR - LS = SUNLinSol_SPTFQMR(u0, alg.prec_side, alg.krylov_dim) + LS = SUNLinSol_SPTFQMR(uvec, alg.prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, PTFQMR()) elseif LinearSolver == :KLU nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) - A = SUNSparseMatrix(length(u0), length(u0), nnz, CSC_MAT) - LS = SUNLinSol_KLU(u0, A) + A = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT) + LS = SUNLinSol_KLU(uvec, A) _A = MatrixHandle(A, SparseMatrix()) _LS = LinSolHandle(LS, KLU()) end @@ -748,10 +711,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i _LS = nothing end - if (typeof(prob.problem_type) <: SplitODEProblem && - typeof(prob.f.f1.jac_prototype) <: AbstractSciMLOperator) || - (!(typeof(prob.problem_type) <: SplitODEProblem) && - typeof(prob.f.jac_prototype) <: AbstractSciMLOperator) && + if (prob.problem_type isa SplitODEProblem && + prob.f.f1.jac_prototype isa AbstractSciMLOperator) || + (!(prob.problem_type isa SplitODEProblem) && + prob.f.jac_prototype isa AbstractSciMLOperator) && alg.stiffness !== Explicit() function getcfunjtimes(::T) where {T} @cfunction(jactimes, @@ -765,50 +728,50 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i if prob.f.mass_matrix != LinearAlgebra.I && alg.stiffness !== Explicit() if MassLinearSolver in (:Dense, :LapackDense) nojacobian = false - M = SUNDenseMatrix(length(u0), length(u0)) + M = SUNDenseMatrix(length(uvec), length(uvec)) _M = MatrixHandle(M, DenseMatrix()) if MassLinearSolver === :Dense - MLS = SUNLinSol_Dense(u0, M) + MLS = SUNLinSol_Dense(uvec, M) _MLS = LinSolHandle(MLS, Dense()) else - MLS = SUNLinSol_LapackDense(u0, M) + MLS = SUNLinSol_LapackDense(uvec, M) _MLS = LinSolHandle(MLS, LapackDense()) end elseif MassLinearSolver in (:Band, :LapackBand) nojacobian = false - M = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) + M = SUNBandMatrix(length(uvec), alg.jac_upper, alg.jac_lower) _M = MatrixHandle(M, BandMatrix()) if MassLinearSolver === :Band - MLS = SUNLinSol_Band(u0, M) + MLS = SUNLinSol_Band(uvec, M) _MLS = LinSolHandle(MLS, Band()) else - MLS = SUNLinSol_LapackBand(u0, M) + MLS = SUNLinSol_LapackBand(uvec, M) _MLS = LinSolHandle(MLS, LapackBand()) end elseif MassLinearSolver == :GMRES - MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) + MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) _M = nothing _MLS = LinSolHandle(MLS, SPGMR()) elseif MassLinearSolver == :FGMRES - MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) + MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) _M = nothing _MLS = LinSolHandle(MLS, SPFGMR()) elseif MassLinearSolver == :BCG - MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) + MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) _M = nothing _MLS = LinSolHandle(MLS, SPBCGS()) elseif MassLinearSolver == :PCG - MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) + MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) _M = nothing _MLS = LinSolHandle(MLS, PCG()) elseif MassLinearSolver == :TFQMR - MLS = SUNLinSol_SPGMR(u0, alg.prec_side, alg.mass_krylov_dim) + MLS = SUNLinSol_SPGMR(uvec, alg.prec_side, alg.mass_krylov_dim) _M = nothing _MLS = LinSolHandle(MLS, PTFQMR()) elseif MassLinearSolver == :KLU nnz = length(SparseArrays.nonzeros(prob.f.mass_matrix)) - M = SUNSparseMatrix(length(u0), length(u0), nnz, CSC_MAT) - MLS = SUNLinSol_KLU(u0, M) + M = SUNSparseMatrix(length(uvec), length(uvec), nnz, CSC_MAT) + MLS = SUNLinSol_KLU(uvec, M) _M = MatrixHandle(M, SparseMatrix()) _MLS = LinSolHandle(MLS, KLU()) end @@ -871,24 +834,24 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i ARKStepSetPreconditioner(mem, psetupfun, precfun) end - callbacks_internal === nothing ? tmp = nothing : tmp = similar(u0) - callbacks_internal === nothing ? uprev = nothing : uprev = similar(u0) + tmp = isnothing(callbacks_internal) ? u0 : similar(u0) + uprev = isnothing(callbacks_internal) ? u0 : similar(u0) tout = [tspan[1]] if save_start if save_idxs === nothing ures = Vector{uType}() dures = Vector{uType}() - save_value!(ures, u0, uType, sizeu, save_idxs) + save_value!(ures, u0, uType, save_idxs) if dense - f!(_u0, u0, prob.p, tspan[1]) - save_value!(dures, utmp, uType, sizeu, save_idxs) + f!(out, u0, prob.p, tspan[1]) + save_value!(dures, out, uType, save_idxs) end else ures = [u0[save_idxs]] if dense - f!(_u0, u0, prob.p, tspan[1]) - dures = [_u0[save_idxs]] + f!(out, u0, prob.p, tspan[1]) + dures = [out[save_idxs]] end end else @@ -927,7 +890,8 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i progress_name, progress_message, maxiters) - integrator = ARKODEIntegrator(utmp, + integrator = ARKODEIntegrator(u0, + utmp, prob.p, t0, t0, @@ -944,7 +908,6 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractODEProblem{uType, tupType, i opts, tout, tdir, - sizeu, false, tmp, uprev, @@ -975,7 +938,7 @@ function tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) tstops_internal = DataStructures.BinaryMaxHeap(tstops_vec) end - if typeof(saveat) <: Number + if saveat isa Number if (tspan[1]:saveat:tspan[end])[end] == tspan[end] saveat_vec = convert(Vector{tType}, collect(tType, (tspan[1] + saveat):saveat:tspan[end])) @@ -1042,11 +1005,11 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu warned && DiffEqBase.warn_compat() end - if typeof(reltol) <: AbstractArray + if reltol isa AbstractArray error("Sundials only allows scalar reltol.") end - if length(prob.u0) <= 0 + if length(prob.u0) == 0 error("Sundials requires at least one state variable.") end @@ -1068,37 +1031,22 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu tstops_internal, saveat_internal = tstop_saveat_disc_handling(tstops, saveat, tdir, tspan, tType) - - if typeof(prob.u0) <: Number + @assert size(prob.u0) == size(prob.du0) + if prob.u0 isa Number u0 = [prob.u0] - else - u0 = vec(copy(prob.u0)) - end - - if typeof(prob.du0) <: Number du0 = [prob.du0] else - du0 = vec(copy(prob.du0)) + u0 = copy(prob.u0) + du0 = copy(prob.du0) end - sizeu = size(prob.u0) - sizedu = size(prob.du0) - ### Fix the more general function to Sundials allowed style - if !isinplace && typeof(prob.u0) <: Number + if !isinplace && prob.u0 isa Number f! = (out, du, u, p, t) -> (out .= prob.f(first(du), first(u), p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: Vector{Float64} + elseif !isinplace f! = (out, du, u, p, t) -> (out .= prob.f(du, u, p, t); Cint(0)) - elseif !isinplace && typeof(prob.u0) <: AbstractArray - f! = (out, du, u, p, t) -> (out .= vec(prob.f(reshape(du, sizedu), - reshape(u, sizeu), p, t)); - Cint(0)) - elseif typeof(prob.u0) <: Vector{Float64} - f! = prob.f else # Then it's an in-place function on an abstract array - f! = (out, du, u, p, t) -> (prob.f(reshape(out, sizeu), reshape(du, sizedu), - reshape(u, sizeu), p, t); - Cint(0)) + f! = prob.f end mem_ptr = IDACreate() @@ -1112,11 +1060,10 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu ts = [t0] - _u0 = copy(u0) - utmp = NVector(_u0) - _du0 = copy(du0) - dutmp = NVector(_du0) - rtest = zeros(length(u0)) + # vec shares memory + utmp = NVector(vec(u0)) + dutmp = NVector(vec(du0)) + rtest = zeros(size(u0)) use_jac_prototype = (isa(prob.f.jac_prototype, SparseArrays.SparseMatrixCSC) && LinearSolver ∈ SPARSE_SOLVERS) @@ -1127,8 +1074,8 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu use_jac_prototype ? prob.f.jac_prototype : nothing, alg.prec, alg.psetup, - _u0, - _du0, + u0, + du0, rtest) function getcfun(::T) where {T} @@ -1139,7 +1086,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu dt !== nothing && (flag = IDASetInitStep(mem, dt)) flag = IDASetUserData(mem, userfun) flag = IDASetMaxStep(mem, dtmax) - if typeof(abstol) <: Array + if abstol isa Array flag = IDASVtolerances(mem, reltol, abstol) else flag = IDASStolerances(mem, reltol, abstol) @@ -1163,7 +1110,7 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu A = SUNDenseMatrix(length(u0), length(u0)) _A = MatrixHandle(A, DenseMatrix()) if LinearSolver === :Dense - LS = SUNLinSol_Dense(u0, A) + LS = SUNLinSol_Dense(utmp, A) _LS = LinSolHandle(LS, Dense()) else LS = SUNLinSol_LapackDense(u0, A) @@ -1174,42 +1121,42 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu A = SUNBandMatrix(length(u0), alg.jac_upper, alg.jac_lower) _A = MatrixHandle(A, BandMatrix()) if LinearSolver === :Band - LS = SUNLinSol_Band(u0, A) + LS = SUNLinSol_Band(utmp, A) _LS = LinSolHandle(LS, Band()) else - LS = SUNLinSol_LapackBand(u0, A) + LS = SUNLinSol_LapackBand(utmp, A) _LS = LinSolHandle(LS, LapackBand()) end elseif LinearSolver == :GMRES - LS = SUNLinSol_SPGMR(u0, prec_side, alg.krylov_dim) + LS = SUNLinSol_SPGMR(utmp, prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPGMR()) elseif LinearSolver == :FGMRES - LS = SUNLinSol_SPFGMR(u0, prec_side, alg.krylov_dim) + LS = SUNLinSol_SPFGMR(utmp, prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPFGMR()) elseif LinearSolver == :BCG - LS = SUNLinSol_SPBCGS(u0, prec_side, alg.krylov_dim) + LS = SUNLinSol_SPBCGS(utmp, prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, SPBCGS()) elseif LinearSolver == :PCG - LS = SUNLinSol_PCG(u0, prec_side, alg.krylov_dim) + LS = SUNLinSol_PCG(utmp, prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, PCG()) elseif LinearSolver == :TFQMR - LS = SUNLinSol_SPTFQMR(u0, prec_side, alg.krylov_dim) + LS = SUNLinSol_SPTFQMR(utmp, prec_side, alg.krylov_dim) _A = nothing _LS = LinSolHandle(LS, PTFQMR()) elseif LinearSolver == :KLU nnz = length(SparseArrays.nonzeros(prob.f.jac_prototype)) A = SUNSparseMatrix(length(u0), length(u0), nnz, Sundials.CSC_MAT) - LS = SUNLinSol_KLU(u0, A) + LS = SUNLinSol_KLU(utmp, A) _A = MatrixHandle(A, SparseMatrix()) _LS = LinSolHandle(LS, KLU()) end flag = IDASetLinearSolver(mem, LS, _A === nothing ? C_NULL : A) - if typeof(prob.f.jac_prototype) <: AbstractSciMLOperator + if prob.f.jac_prototype isa AbstractSciMLOperator function getcfunjtimes(::T) where {T} @cfunction(idajactimes, Cint, @@ -1281,9 +1228,9 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu if save_idxs === nothing ures = Vector{uType}() dures = Vector{uType}() - save_value!(ures, u0, uType, sizeu, save_idxs) + save_value!(ures, u0, uType, save_idxs) if dense - save_value!(dures, du0, uType, sizedu, save_idxs) + save_value!(dures, du0, uType, save_idxs) end else ures = [u0[save_idxs]] @@ -1296,8 +1243,8 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu dures = Vector{uType}() end - callbacks_internal === nothing ? tmp = nothing : tmp = similar(u0) - callbacks_internal === nothing ? uprev = nothing : uprev = similar(u0) + tmp = isnothing(callbacks_internal) ? u0 : similar(u0) + uprev = isnothing(callbacks_internal) ? u0 : similar(u0) retcode = flag >= 0 ? ReturnCode.Default : ReturnCode.InitialFailure sol = DiffEqBase.build_solution(prob, alg, @@ -1334,8 +1281,8 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu progress_message, maxiters) - integrator = IDAIntegrator(utmp, - dutmp, + integrator = IDAIntegrator(u0, + du0, prob.p, t0, t0, @@ -1350,8 +1297,6 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu opts, tout, tdir, - sizeu, - sizedu, false, tmp, uprev, @@ -1361,6 +1306,8 @@ function DiffEqBase.__init(prob::DiffEqBase.AbstractDAEProblem{uType, duType, tu 1, callback_cache, 0.0, + utmp, + dutmp, initializealg) DiffEqBase.initialize_dae!(integrator, initializealg) @@ -1379,7 +1326,7 @@ function interpret_sundials_retcode(flag) end function solver_step(integrator::CVODEIntegrator, tstop) - integrator.flag = CVode(integrator.mem, tstop, integrator.u, integrator.tout, + integrator.flag = CVode(integrator.mem, tstop, integrator.u_nvec, integrator.tout, CV_ONE_STEP) if integrator.opts.progress Logging.@logmsg(-1, @@ -1393,14 +1340,14 @@ function solver_step(integrator::CVODEIntegrator, tstop) end end function solver_step(integrator::ARKODEIntegrator, tstop) - integrator.flag = ARKStepEvolve(integrator.mem, tstop, integrator.u, integrator.tout, - ARK_ONE_STEP) + integrator.flag = ARKStepEvolve(integrator.mem, tstop, integrator.u_nvec, + integrator.tout, ARK_ONE_STEP) if integrator.opts.progress Logging.@logmsg(-1, integrator.opts.progress_name, _id=:Sundials, message=integrator.opts.progress_message(integrator.dt, - integrator.u, + integrator.u_nvec, integrator.p, integrator.t), progress=integrator.t / integrator.sol.prob.tspan[2]) @@ -1410,8 +1357,8 @@ function solver_step(integrator::IDAIntegrator, tstop) integrator.flag = IDASolve(integrator.mem, tstop, integrator.tout, - integrator.u, - integrator.du, + integrator.u_nvec, + integrator.du_nvec, IDA_ONE_STEP) if integrator.opts.progress Logging.@logmsg(-1, @@ -1457,7 +1404,7 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = tstop = first(integrator.opts.tstops) set_stop_time(integrator, tstop) integrator.tprev = integrator.t - if !(typeof(integrator.opts.callback.continuous_callbacks) <: Tuple{}) + if !(integrator.opts.callback.continuous_callbacks isa Tuple{}) integrator.uprev .= integrator.u end integrator.userfun.p = integrator.p @@ -1481,12 +1428,12 @@ function DiffEqBase.solve!(integrator::AbstractSundialsIntegrator; early_free = if integrator.opts.save_end && (isempty(integrator.sol.t) || integrator.sol.t[end] != integrator.t) - save_value!(integrator.sol.u, integrator.u, uType, integrator.sizeu, + save_value!(integrator.sol.u, integrator.u, uType, integrator.opts.save_idxs) push!(integrator.sol.t, integrator.t) if integrator.opts.dense integrator(integrator.u, integrator.t, Val{1}) - save_value!(integrator.sol.interp.du, integrator.u, uType, integrator.sizeu, + save_value!(integrator.sol.interp.du, integrator.u, uType, integrator.opts.save_idxs) end end