Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes for VectorContinuousCallback #754

Merged
merged 14 commits into from
Jun 6, 2019
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[compat]
julia = "1"
DiffEqBase = ">= 5.8.1"
DiffEqBase = ">= 5.10.0"
DiffEqOperators = ">= 3.2.0"
Parameters = ">= 0.10.0"
ForwardDiff = ">= 0.10.3"
Expand All @@ -31,6 +31,6 @@ NLsolve = ">= 0.14.1"
RecursiveArrayTools = ">= 0.18.6"
DiffEqDiffTools = ">= 0.4.0"
MuladdMacro = ">= 0.2.1"
StaticArrays = ">= 0.10.3"
StaticArrays = "0.10.3"
DataStructures = ">= 0.15.0"
ExponentialUtilities = ">= 1.2.0"
8 changes: 2 additions & 6 deletions src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,15 +334,11 @@ function calc_W!(integrator, cache::OrdinaryDiffEqMutableCache, dtgamma, repeat_
(!repeat_step && W_transform) ? f.invW_t(W, uprev, p, dtgamma, t) : f.invW(W, uprev, p, dtgamma, t) # W == inverse W
is_compos && calc_J!(integrator, cache, true)
elseif DiffEqBase.has_jac(f) && f.jac_prototype !== nothing
# skip calculation of J if step is repeated
new_jac && DiffEqBase.update_coefficients!(W,uprev,p,t)
# skip calculation of W if step is repeated
isnewton || DiffEqBase.update_coefficients!(W,uprev,p,t) # we will call `update_coefficients!` in NLNewton
@label J2W
new_W && (W.transform = W_transform; set_gamma!(W, dtgamma))
W.transform = W_transform; set_gamma!(W, dtgamma)
else # concrete W using jacobian from `calc_J!`
# skip calculation of J if step is repeated
new_jac && calc_J!(integrator, cache, is_compos)
# skip calculation of W if step is repeated
new_W && jacobian2W!(W, mass_matrix, dtgamma, J, W_transform)
end
if isnewton
Expand Down
6 changes: 4 additions & 2 deletions src/integrators/integrator_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,13 +227,15 @@ function handle_callbacks!(integrator)
discrete_modified = false
saved_in_cb = false
if !(typeof(continuous_callbacks)<:Tuple{})
time,upcrossing,event_occurred,idx,counter =
time,upcrossing,event_occurred,event_idx,idx,counter =
DiffEqBase.find_first_continuous_callback(integrator,continuous_callbacks...)
if event_occurred
integrator.event_last_time = idx
continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing)
integrator.vector_event_last_time = event_idx
continuous_modified,saved_in_cb = DiffEqBase.apply_callback!(integrator,continuous_callbacks[idx],time,upcrossing,event_idx)
else
integrator.event_last_time = 0
integrator.vector_event_last_time = 1
end
end
if !integrator.force_stepfail && !(typeof(discrete_callbacks)<:Tuple{})
Expand Down
18 changes: 10 additions & 8 deletions src/integrators/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ integrator.opts.abstol = 1e-9
```
For more info see the linked documentation page.
"""
mutable struct ODEIntegrator{algType<:OrdinaryDiffEqAlgorithm,IIP,uType,tType,pType,eigenType,QT,tdirType,ksEltype,SolType,F,CacheType,O,FSALType,EventErrorType} <: DiffEqBase.AbstractODEIntegrator{algType,IIP,uType,tType}
mutable struct ODEIntegrator{algType<:OrdinaryDiffEqAlgorithm,IIP,uType,tType,pType,eigenType,QT,tdirType,ksEltype,SolType,F,CacheType,O,FSALType,EventErrorType,CallbackCacheType} <: DiffEqBase.AbstractODEIntegrator{algType,IIP,uType,tType}
sol::SolType
u::uType
k::ksEltype
Expand All @@ -101,11 +101,13 @@ mutable struct ODEIntegrator{algType<:OrdinaryDiffEqAlgorithm,IIP,uType,tType,pT
saveiter::Int
saveiter_dense::Int
cache::CacheType
callback_cache::CallbackCacheType
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
kshortsize::Int
force_stepfail::Bool
last_stepfail::Bool
just_hit_tstop::Bool
event_last_time::Int
vector_event_last_time::Int
last_event_error::EventErrorType
accept_step::Bool
isout::Bool
Expand All @@ -117,24 +119,24 @@ mutable struct ODEIntegrator{algType<:OrdinaryDiffEqAlgorithm,IIP,uType,tType,pT
fsallast::FSALType

function ODEIntegrator{algType,IIP,uType,tType,pType,eigenType,tTypeNoUnits,tdirType,ksEltype,SolType,
F,CacheType,O,FSALType,EventErrorType}(
F,CacheType,O,FSALType,EventErrorType,CallbackCacheType}(
sol,u,k,t,dt,f,p,uprev,uprev2,tprev,
alg,dtcache,dtchangeable,dtpropose,tdir,
eigen_est,EEst,qold,q11,erracc,dtacc,success_iter,
iter,saveiter,saveiter_dense,cache,
iter,saveiter,saveiter_dense,cache,callback_cache,
kshortsize,force_stepfail,last_stepfail,just_hit_tstop,
event_last_time,last_event_error,
event_last_time,vector_event_last_time,last_event_error,
accept_step,isout,reeval_fsal,u_modified,opts,destats) where {algType,IIP,uType,tType,pType,eigenType,tTypeNoUnits,tdirType,ksEltype,SolType,
F,CacheType,O,FSALType,EventErrorType}
F,CacheType,O,FSALType,EventErrorType,CallbackCacheType}

new{algType,IIP,uType,tType,pType,eigenType,tTypeNoUnits,tdirType,ksEltype,SolType,
F,CacheType,O,FSALType,EventErrorType}(
F,CacheType,O,FSALType,EventErrorType,CallbackCacheType}(
sol,u,k,t,dt,f,p,uprev,uprev2,tprev,
alg,dtcache,dtchangeable,dtpropose,tdir,
eigen_est,EEst,qold,q11,erracc,dtacc,success_iter,
iter,saveiter,saveiter_dense,cache,
iter,saveiter,saveiter_dense,cache,callback_cache,
kshortsize,force_stepfail,last_stepfail,just_hit_tstop,
event_last_time,last_event_error,
event_last_time,vector_event_last_time,last_event_error,
accept_step,isout,reeval_fsal,u_modified,opts,destats) # Leave off fsalfirst and last
end
end
Expand Down
14 changes: 11 additions & 3 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,13 @@ function DiffEqBase.__init(

callbacks_internal = CallbackSet(callback,prob.callback)

max_len_cb = DiffEqBase.max_vector_callback_length(callbacks_internal)
if max_len_cb isa VectorContinuousCallback
callback_cache = DiffEqBase.CallbackCache(max_len_cb.len,uBottomEltype,uBottomEltype)
else
callback_cache = nothing
end

### Algorithm-specific defaults ###
if save_idxs === nothing
ksEltype = Vector{rateType}
Expand Down Expand Up @@ -302,6 +309,7 @@ function DiffEqBase.__init(
force_stepfail = false
last_stepfail = false
event_last_time = 0
vector_event_last_time = 1
last_event_error = zero(uBottomEltypeNoUnits)
dtchangeable = isdtchangeable(alg)
q11 = tTypeNoUnits(1)
Expand All @@ -313,14 +321,14 @@ function DiffEqBase.__init(
QT,typeof(tdir),typeof(k),SolType,
FType,cacheType,
typeof(opts),fsal_typeof(alg,rate_prototype),
typeof(last_event_error)}(
typeof(last_event_error),typeof(callback_cache)}(
sol,u,k,t,tType(dt),f,p,uprev,uprev2,tprev,
alg,dtcache,dtchangeable,
dtpropose,tdir,eigen_est,EEst,QT(qoldinit),q11,
erracc,dtacc,success_iter,
iter,saveiter,saveiter_dense,cache,
iter,saveiter,saveiter_dense,cache,callback_cache,
kshortsize,force_stepfail,last_stepfail,
just_hit_tstop,event_last_time,last_event_error,
just_hit_tstop,event_last_time,vector_event_last_time,last_event_error,
accept_step,
isout,reeval_fsal,
u_modified,opts,destats)
Expand Down
6 changes: 3 additions & 3 deletions test/algconvergence/linear_nonlinear_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ end
prob = SplitODEProblem(linnonlin_fun_iip,u0,(0.0,1.0))

dts = 1 ./2 .^(8:-1:4) #14->7 good plot
for Alg in [GenericIIF1,GenericIIF2,LawsonEuler,NorsettEuler,ETDRK2,ETDRK3,ETDRK4,HochOst4,ETD2,KenCarp3]
sim = test_convergence(dts,prob,Alg())
@test sim.𝒪est[:l2] ≈ alg_order(Alg()) atol=0.1
for Alg in [GenericIIF1(),GenericIIF2(),LawsonEuler(),NorsettEuler(),ETDRK2(),ETDRK3(),ETDRK4(),HochOst4(),ETD2(),KenCarp3(linsolve=LinSolveGMRES(tol=1e-6))]
sim = test_convergence(dts,prob,Alg)
@test sim.𝒪est[:l2] ≈ alg_order(Alg) atol=0.1
end
sim = test_convergence(dts,prob,ETDRK4(),dense_errors=true)
@test sim.𝒪est[:l2] ≈ 4 atol=0.1
Expand Down
16 changes: 16 additions & 0 deletions test/interface/linear_nonlinear_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
using OrdinaryDiffEq, Test, Random
Random.seed!(123)

using OrdinaryDiffEq, DiffEqOperators, LinearAlgebra
A = 0.01*rand(3, 3)
rn = (du, u, p, t) -> begin
mul!(du, A, u)
end
u0 = rand(3)
prob = ODEProblem(ODEFunction(rn, jac_prototype=JacVecOperator{Float64}(rn, u0; autodiff=false)), u0, (0, 10.))
@test_nowarn sol = solve(prob, TRBDF2(autodiff=false));
@test_nowarn sol = solve(prob, TRBDF2(autodiff=false, linsolve=LinSolveGMRES()));
@test_nowarn sol = solve(prob, TRBDF2(autodiff=false, linsolve=LinSolveGMRES(), smooth_est=false));
@test_nowarn sol = solve(prob, TRBDF2(autodiff=false, linsolve=LinSolveGMRES(Pl=lu(A)), smooth_est=false));
@test_nowarn sol = solve(prob, TRBDF2(autodiff=false, linsolve=LinSolveGMRES(Pr=lu(A)), smooth_est=false));
@test_nowarn sol = solve(prob, TRBDF2(autodiff=false, linsolve=LinSolveGMRES(Pl=lu(A), Pr=lu(A)), smooth_est=false));
2 changes: 1 addition & 1 deletion test/interface/utility_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,6 @@ end

sol1_ip = solve(ODEProblem(fun1_ip,u0,tspan), Alg(); adaptive=false, dt=0.01)
sol2_ip = solve(ODEProblem(fun2_ip,u0,tspan), Alg(linsolve=LinSolveFactorize(lu)); adaptive=false, dt=0.01)
@test sol1_ip(1.0) ≈ sol2_ip(1.0)
@test sol1_ip(1.0) ≈ sol2_ip(1.0) atol=1e-5
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ if group == "All" || group == "Interface"
@time @safetestset "AD Tests" begin include("interface/ad_tests.jl") end
@time @safetestset "No Index Tests" begin include("interface/noindex_tests.jl") end
@time @safetestset "Units Tests" begin include("interface/units_tests.jl") end
@time @safetestset "Linear Nonlinear Solver Tests" begin include("interface/linear_nonlinear_tests.jl") end
end

if group == "All" || group == "Integrators"
Expand Down