Skip to content

Commit

Permalink
Merge pull request #112 from JuliaDiffEq/finalizers
Browse files Browse the repository at this point in the history
move high level APIs to Handle
  • Loading branch information
ChrisRackauckas committed Mar 27, 2017
2 parents 1f052b1 + a70d989 commit cf0634d
Show file tree
Hide file tree
Showing 3 changed files with 185 additions and 205 deletions.
251 changes: 120 additions & 131 deletions src/common.jl
Expand Up @@ -59,79 +59,73 @@ function solve{uType,tType,isinplace,F,Method,LinearSolver}(
method_code = CV_FUNCTIONAL
end

mem = CVodeCreate(alg_code, method_code)

if mem == C_NULL
error("Failed to allocate CVODE solver object")
end
mem_ptr = CVodeCreate(alg_code, method_code)
(mem_ptr == C_NULL) && error("Failed to allocate CVODE solver object")
mem = Handle(mem_ptr)

ures = Vector{Vector{Float64}}()
ts = [t0]

try
userfun = UserFunctionAndData(f!, userdata)
u0nv = NVector(u0)
flag = @checkflag CVodeInit(mem,
cfunction(cvodefun, Cint,
(realtype, N_Vector,
N_Vector, Ref{typeof(userfun)})),
t0, convert(N_Vector, u0nv))
flag = @checkflag CVodeSetUserData(mem, userfun)
flag = @checkflag CVodeSStolerances(mem, reltol, abstol)
flag = @checkflag CVodeSetMaxNumSteps(mem, maxiter)
if Method == :Newton # Only use a linear solver if it's a Newton-based method
if LinearSolver == :Dense
flag = @checkflag CVDense(mem, length(u0))
elseif LinearSolver == :Banded
flag = @checkflag CVBand(mem,length(u0),alg.jac_upper,alg.jac_lower)
elseif LinearSolver == :Diagonal
flag = @checkflag CVDiag(mem)
elseif LinearSolver == :GMRES
flag = @checkflag CVSpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :BCG
flag = @checkflag CVSpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :TFQMR
flag = @checkflag CVSptfqmr(mem,PREC_NONE,alg.krylov_dim)
end
userfun = UserFunctionAndData(f!, userdata)
u0nv = NVector(u0)
flag = @checkflag CVodeInit(mem,
cfunction(cvodefun, Cint,
(realtype, N_Vector,
N_Vector, Ref{typeof(userfun)})),
t0, convert(N_Vector, u0nv))
flag = @checkflag CVodeSetUserData(mem, userfun)
flag = @checkflag CVodeSStolerances(mem, reltol, abstol)
flag = @checkflag CVodeSetMaxNumSteps(mem, maxiter)
if Method == :Newton # Only use a linear solver if it's a Newton-based method
if LinearSolver == :Dense
flag = @checkflag CVDense(mem, length(u0))
elseif LinearSolver == :Banded
flag = @checkflag CVBand(mem,length(u0),alg.jac_upper,alg.jac_lower)
elseif LinearSolver == :Diagonal
flag = @checkflag CVDiag(mem)
elseif LinearSolver == :GMRES
flag = @checkflag CVSpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :BCG
flag = @checkflag CVSpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :TFQMR
flag = @checkflag CVSptfqmr(mem,PREC_NONE,alg.krylov_dim)
end
end

push!(ures, copy(u0))
utmp = NVector(copy(u0))
tout = [tspan[1]]

# The Inner Loops : Style depends on save_timeseries
if save_timeseries
for k in 1:length(save_ts)
looped = false
while tdir*tout[end] < tdir*save_ts[k]
looped = true
flag = @checkflag CVode(mem,
save_ts[k], utmp, tout, CV_ONE_STEP)
push!(ures,copy(utmp))
push!(ts, tout...)
end
if looped
# Fix the end
flag = @checkflag CVodeGetDky(
mem, save_ts[k], Cint(0), ures[end])
ts[end] = save_ts[k]
else # Just push another value
flag = @checkflag CVodeGetDky(
mem, save_ts[k], Cint(0), utmp)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end
else # save_timeseries == false, so use CV_NORMAL style
for k in 1:length(save_ts)
push!(ures, copy(u0))
utmp = NVector(copy(u0))
tout = [tspan[1]]

# The Inner Loops : Style depends on save_timeseries
if save_timeseries
for k in 1:length(save_ts)
looped = false
while tdir*tout[end] < tdir*save_ts[k]
looped = true
flag = @checkflag CVode(mem,
save_ts[k], utmp, tout, CV_NORMAL)
save_ts[k], utmp, tout, CV_ONE_STEP)
push!(ures,copy(utmp))
push!(ts, tout...)
end
if looped
# Fix the end
flag = @checkflag CVodeGetDky(
mem, save_ts[k], Cint(0), ures[end])
ts[end] = save_ts[k]
else # Just push another value
flag = @checkflag CVodeGetDky(
mem, save_ts[k], Cint(0), utmp)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end
finally
CVodeFree(Ref{CVODEMemPtr}(mem))
else # save_timeseries == false, so use CV_NORMAL style
for k in 1:length(save_ts)
flag = @checkflag CVode(mem,
save_ts[k], utmp, tout, CV_NORMAL)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end

### Finishing Routine
Expand Down Expand Up @@ -206,89 +200,84 @@ function solve{uType,duType,tType,isinplace,F,LinearSolver}(
u = vec(u); du=vec(du); 0)
end

mem = IDACreate()
if mem == C_NULL
error("Failed to allocate IDA solver object")
end
mem_ptr = IDACreate()
(mem_ptr == C_NULL) && error("Failed to allocate IDA solver object")
mem = Handle(mem_ptr)

ures = Vector{Vector{Float64}}()
ts = [t0]

try
userfun = UserFunctionAndData(f!, userdata)
u0nv = NVector(u0)
flag = @checkflag IDAInit(mem, cfunction(idasolfun,
Cint, (realtype, N_Vector, N_Vector,
N_Vector, Ref{typeof(userfun)})),
t0, convert(N_Vector, u0),
convert(N_Vector, du0))
flag = @checkflag IDASetUserData(mem, userfun)
flag = @checkflag IDASStolerances(mem, reltol, abstol)
flag = @checkflag IDASetMaxNumSteps(mem, maxiter)
if LinearSolver == :Dense
flag = @checkflag IDADense(mem, length(u0))
elseif LinearSolver == :Band
flag = @checkflag IDABand(mem,length(u0),alg.jac_upper,alg.jac_lower)
elseif LinearSolver == :Diagonal
flag = @checkflag IDADiag(mem)
elseif LinearSolver == :GMRES
flag = @checkflag IDASpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :BCG
flag = @checkflag IDASpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :TFQMR
flag = @checkflag IDASptfqmr(mem,PREC_NONE,alg.krylov_dim)
end
userfun = UserFunctionAndData(f!, userdata)
u0nv = NVector(u0)
flag = @checkflag IDAInit(mem, cfunction(idasolfun,
Cint, (realtype, N_Vector, N_Vector,
N_Vector, Ref{typeof(userfun)})),
t0, convert(N_Vector, u0),
convert(N_Vector, du0))
flag = @checkflag IDASetUserData(mem, userfun)
flag = @checkflag IDASStolerances(mem, reltol, abstol)
flag = @checkflag IDASetMaxNumSteps(mem, maxiter)
if LinearSolver == :Dense
flag = @checkflag IDADense(mem, length(u0))
elseif LinearSolver == :Band
flag = @checkflag IDABand(mem,length(u0),alg.jac_upper,alg.jac_lower)
elseif LinearSolver == :Diagonal
flag = @checkflag IDADiag(mem)
elseif LinearSolver == :GMRES
flag = @checkflag IDASpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :BCG
flag = @checkflag IDASpgmr(mem,PREC_NONE,alg.krylov_dim)
elseif LinearSolver == :TFQMR
flag = @checkflag IDASptfqmr(mem,PREC_NONE,alg.krylov_dim)
end


push!(ures, copy(u0))
utmp = NVector(copy(u0))
dutmp = NVector(copy(u0))
tout = [tspan[1]]
push!(ures, copy(u0))
utmp = NVector(copy(u0))
dutmp = NVector(copy(u0))
tout = [tspan[1]]

rtest = zeros(length(u0))
f!(t0, u0, du0, rtest)
if any(abs.(rtest) .>= reltol)
if diffstates === nothing
error("Must supply diffstates argument to use IDA initial value solver.")
end
flag = @checkflag IDASetId(mem, collect(Float64, diffstates))
flag = @checkflag IDACalcIC(mem, IDA_YA_YDP_INIT, save_ts[2])
rtest = zeros(length(u0))
f!(t0, u0, du0, rtest)
if any(abs.(rtest) .>= reltol)
if diffstates === nothing
error("Must supply diffstates argument to use IDA initial value solver.")
end
flag = @checkflag IDASetId(mem, collect(Float64, diffstates))
flag = @checkflag IDACalcIC(mem, IDA_YA_YDP_INIT, save_ts[2])
end

# The Inner Loops : Style depends on save_timeseries
if save_timeseries
for k in 1:length(save_ts)
looped = false
while tdir*tout[end] < tdir*save_ts[k]
looped = true
flag = @checkflag IDASolve(mem,
save_ts[k], tout, utmp, dutmp, IDA_ONE_STEP)

push!(ures,copy(utmp))
push!(ts, tout...)
end
if looped
# Fix the end
flag = @checkflag IDAGetDky(
mem, save_ts[k], Cint(0), ures[end])
ts[end] = save_ts[k]
else # Just push another value
flag = @checkflag IDAGetDky(
mem, save_ts[k], Cint(0), utmp)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end
else # save_timeseries == false, so use IDA_NORMAL style
for k in 1:length(save_ts)
# The Inner Loops : Style depends on save_timeseries
if save_timeseries
for k in 1:length(save_ts)
looped = false
while tdir*tout[end] < tdir*save_ts[k]
looped = true
flag = @checkflag IDASolve(mem,
save_ts[k], tout, utmp, dutmp, IDA_NORMAL)
save_ts[k], tout, utmp, dutmp, IDA_ONE_STEP)

push!(ures,copy(utmp))
push!(ts, tout...)
end
if looped
# Fix the end
flag = @checkflag IDAGetDky(
mem, save_ts[k], Cint(0), ures[end])
ts[end] = save_ts[k]
else # Just push another value
flag = @checkflag IDAGetDky(
mem, save_ts[k], Cint(0), utmp)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end
finally
IDAFree(Ref{IDAMemPtr}(mem))
else # save_timeseries == false, so use IDA_NORMAL style
for k in 1:length(save_ts)
flag = @checkflag IDASolve(mem,
save_ts[k], tout, utmp, dutmp, IDA_NORMAL)
push!(ures,copy(utmp))
push!(ts, save_ts[k]...)
end
end

### Finishing Routine
Expand Down
3 changes: 2 additions & 1 deletion src/handle.jl
Expand Up @@ -33,6 +33,7 @@ immutable Handle{T <: AbstractSundialsObject}
ptr_ref::Ref{Ptr{T}} # pointer to a pointer

@compat function (::Type{Handle}){T}(ptr::Ptr{T})
(ptr == C_NULL) && throw(ArgumentError("Null pointer passed to Handle()"))
h = new{T}(Ref{Ptr{T}}(ptr))
finalizer(h.ptr_ref, release_handle)
return h
Expand All @@ -43,7 +44,7 @@ Base.convert{T}(::Type{Ptr{T}}, h::Handle{T}) = h.ptr_ref[]
Base.convert{T}(::Type{Ptr{Ptr{T}}}, h::Handle{T}) = convert(Ptr{Ptr{T}}, h.ptr_ref[])

release_handle{T}(ptr_ref::Ref{Ptr{T}}) = throw(MethodError("Freeing objects of type $T not supported"))
release_handle(ptr_ref::Ref{Ptr{KINMem}}) = KINSOLFree(ptr_ref)
release_handle(ptr_ref::Ref{Ptr{KINMem}}) = KINFree(ptr_ref)
release_handle(ptr_ref::Ref{Ptr{CVODEMem}}) = CVodeFree(ptr_ref)
release_handle(ptr_ref::Ref{Ptr{IDAMem}}) = IDAFree(ptr_ref)

Expand Down

0 comments on commit cf0634d

Please sign in to comment.