Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions src/common_interface/function_types.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
abstract type AbstractFunJac{J2} end
mutable struct FunJac{F, F2, J, P, M, J2, uType, uType2, Prec, PS} <: AbstractFunJac{J2}
mutable struct FunJac{N, F, F2, J, P, M, J2, Prec, PS,
TResid <: Union{Nothing, Array{Float64, N}}} <: AbstractFunJac{J2}
fun::F
fun2::F2
jac::J
Expand All @@ -8,9 +9,9 @@ mutable struct FunJac{F, F2, J, P, M, J2, uType, uType2, Prec, PS} <: AbstractFu
jac_prototype::J2
prec::Prec
psetup::PS
u::uType
du::uType
resid::uType2
u::Array{Float64, N}
du::Array{Float64, N}
resid::TResid
end
function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du)
FunJac(fun, nothing, jac, p, m,
Expand All @@ -25,20 +26,20 @@ function FunJac(fun, jac, p, m, jac_prototype, prec, psetup, u, du, resid)
du, resid)
end

function cvodefunjac(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac)
funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),
length(funjac.du))
function cvodefunjac(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac{N}) where {N}
funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u))
funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du),
size(funjac.du))
_du = funjac.du
_u = funjac.u
funjac.fun(_du, _u, funjac.p, t)
return CV_SUCCESS
end

function cvodefunjac2(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac)
funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),
length(funjac.du))
function cvodefunjac2(t::Float64, u::N_Vector, du::N_Vector, funjac::FunJac{N}) where {N}
funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u))
funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du),
size(funjac.du))
_du = funjac.du
_u = funjac.u
funjac.fun2(_du, _u, funjac.p, t)
Expand Down Expand Up @@ -79,14 +80,15 @@ function cvodejac(t::realtype,
return CV_SUCCESS
end

function idasolfun(t::Float64, u::N_Vector, du::N_Vector, resid::N_Vector, funjac::FunJac)
funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
function idasolfun(t::Float64, u::N_Vector, du::N_Vector, resid::N_Vector,
funjac::FunJac{N}) where {N}
funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u))
_u = funjac.u
funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),
length(funjac.du))
funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du),
size(funjac.du))
_du = funjac.du
funjac.resid = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(resid),
length(funjac.resid))
funjac.resid = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(resid),
size(funjac.resid))
_resid = funjac.resid
funjac.fun(_resid, _du, _u, funjac.p, t)
return IDA_SUCCESS
Expand All @@ -102,10 +104,11 @@ function idajac(t::realtype,
tmp1::N_Vector,
tmp2::N_Vector,
tmp3::N_Vector)
funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
N = ndims(funjac.u)
funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u))
_u = funjac.u
funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),
length(funjac.du))
funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du),
size(funjac.du))
_du = funjac.du

funjac.jac(convert(Matrix, J), _du, _u, funjac.p, cj, t)
Expand All @@ -123,11 +126,11 @@ function idajac(t::realtype,
tmp2::N_Vector,
tmp3::N_Vector)
jac_prototype = funjac.jac_prototype

funjac.u = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(u), length(funjac.u))
N = ndims(funjac.u)
funjac.u = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(u), size(funjac.u))
_u = funjac.u
funjac.du = unsafe_wrap(Vector{Float64}, N_VGetArrayPointer_Serial(du),
length(funjac.du))
funjac.du = unsafe_wrap(Array{Float64, N}, N_VGetArrayPointer_Serial(du),
size(funjac.du))
_du = funjac.du

funjac.jac(jac_prototype, _du, _u, funjac.p, cj, t)
Expand Down
72 changes: 29 additions & 43 deletions src/common_interface/integrator_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,26 +28,23 @@ end
abstract type AbstractSundialsIntegrator{algType} <:
DiffEqBase.AbstractODEIntegrator{algType, true, Vector{Float64}, Float64} end

mutable struct CVODEIntegrator{uType,
mutable struct CVODEIntegrator{N,
pType,
memType,
solType,
algType,
fType,
UFType,
JType,
oType,
toutType,
sizeType,
tmpType,
LStype,
Atype,
CallbackCacheType} <: AbstractSundialsIntegrator{algType}
u::uType
u::Array{Float64, N}
u_nvec::NVector
p::pType
t::Float64
tprev::Float64
mem::memType
mem::Handle{CVODEMem}
LS::LStype
A::Atype
sol::solType
Expand All @@ -56,12 +53,11 @@ mutable struct CVODEIntegrator{uType,
userfun::UFType
jac::JType
opts::oType
tout::toutType
tout::Vector{Float64}
tdir::Float64
sizeu::sizeType
u_modified::Bool
tmp::tmpType
uprev::tmpType
tmp::Array{Float64, N}
uprev::Array{Float64, N}
flag::Cint
just_hit_tstop::Bool
event_last_time::Int
Expand All @@ -74,40 +70,37 @@ function (integrator::CVODEIntegrator)(t::Number,
deriv::Type{Val{T}} = Val{0};
idxs = nothing) where {T}
out = similar(integrator.u)
integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), out)
integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), vec(out))
return idxs === nothing ? out : out[idxs]
end

function (integrator::CVODEIntegrator)(out,
t::Number,
deriv::Type{Val{T}} = Val{0};
idxs = nothing) where {T}
integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), out)
integrator.flag = @checkflag CVodeGetDky(integrator.mem, t, Cint(T), vec(out))
return idxs === nothing ? out : @view out[idxs]
end

mutable struct ARKODEIntegrator{uType,
mutable struct ARKODEIntegrator{N,
pType,
memType,
solType,
algType,
fType,
UFType,
JType,
oType,
toutType,
sizeType,
tmpType,
LStype,
Atype,
MLStype,
Mtype,
CallbackCacheType} <: AbstractSundialsIntegrator{ARKODE}
u::uType
u::Array{Float64, N}
u_nvec::NVector
p::pType
t::Float64
tprev::Float64
mem::memType
mem::Handle{ARKStepMem}
LS::LStype
A::Atype
MLS::MLStype
Expand All @@ -118,12 +111,11 @@ mutable struct ARKODEIntegrator{uType,
userfun::UFType
jac::JType
opts::oType
tout::toutType
tout::Vector{Float64}
tdir::Float64
sizeu::sizeType
u_modified::Bool
tmp::tmpType
uprev::tmpType
tmp::Array{Float64, N}
uprev::Array{Float64, N}
flag::Cint
just_hit_tstop::Bool
event_last_time::Int
Expand All @@ -136,42 +128,36 @@ function (integrator::ARKODEIntegrator)(t::Number,
deriv::Type{Val{T}} = Val{0};
idxs = nothing) where {T}
out = similar(integrator.u)
integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out)
integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), vec(out))
return idxs === nothing ? out : out[idxs]
end

function (integrator::ARKODEIntegrator)(out,
t::Number,
deriv::Type{Val{T}} = Val{0};
idxs = nothing) where {T}
integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), out)
integrator.flag = @checkflag ARKStepGetDky(integrator.mem, t, Cint(T), vec(out))
return idxs === nothing ? out : @view out[idxs]
end

mutable struct IDAIntegrator{uType,
duType,
mutable struct IDAIntegrator{N,
pType,
memType,
solType,
algType,
fType,
UFType,
JType,
oType,
toutType,
sizeType,
sizeDType,
tmpType,
LStype,
Atype,
CallbackCacheType,
IA} <: AbstractSundialsIntegrator{IDA}
u::uType
du::duType
u::Array{Float64, N}
du::Array{Float64, N}
p::pType
t::Float64
tprev::Float64
mem::memType
mem::Handle{IDAMem}
LS::LStype
A::Atype
sol::solType
Expand All @@ -180,35 +166,35 @@ mutable struct IDAIntegrator{uType,
userfun::UFType
jac::JType
opts::oType
tout::toutType
tout::Vector{Float64}
tdir::Float64
sizeu::sizeType
sizedu::sizeDType
u_modified::Bool
tmp::tmpType
uprev::tmpType
tmp::Array{Float64, N}
uprev::Array{Float64, N}
flag::Cint
just_hit_tstop::Bool
event_last_time::Int
vector_event_last_time::Int
callback_cache::CallbackCacheType
last_event_error::Float64
u_nvec::NVector
du_nvec::NVector
Comment on lines +180 to +181
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do these alias u and du?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make a comment.

initializealg::IA
end

function (integrator::IDAIntegrator)(t::Number,
deriv::Type{Val{T}} = Val{0};
idxs = nothing) where {T}
out = similar(integrator.u)
integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), out)
integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), vec(out))
return idxs === nothing ? out : out[idxs]
end

function (integrator::IDAIntegrator)(out,
t::Number,
deriv::Type{Val{T}} = Val{0};
idxs = nothing) where {T}
integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), out)
integrator.flag = @checkflag IDAGetDky(integrator.mem, t, Cint(T), vec(out))
return idxs === nothing ? out : @view out[idxs]
end

Expand Down
Loading