Skip to content

Commit

Permalink
dense output converging
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Jul 6, 2017
1 parent 0698191 commit 57bf853
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 76 deletions.
1 change: 1 addition & 0 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ module OrdinaryDiffEq
include("interp_func.jl")
include("dense/generic_dense.jl")
include("dense/interpolants.jl")
include("dense/rosenbrock_interpolants.jl")
include("dense/stiff_addsteps.jl")
include("dense/low_order_rk_addsteps.jl")
include("dense/verner_addsteps.jl")
Expand Down
61 changes: 0 additions & 61 deletions src/dense/interpolants.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,67 +56,6 @@ end
end
end

"""
From MATLAB ODE Suite by Shampine
"""
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{0}})
d = cache.d
c1 = Θ*(1-Θ)/(1-2d)
c2 = Θ*-2d)/(1-2d)
#@. y₀ + dt*(c1*k[1] + c2*k[2])
y₀ + dt*(c1*k[1] + c2*k[2])
end

# First Derivative of the dense output
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{1}})
d = cache.d
c1diff = (1-2*Θ)/(1-2*d)
c2diff = (2*Θ-2*d)/(1-2*d)
#@. c1diff*k[1] + c2diff*k[2]
c1diff*k[1] + c2diff*k[2]
end

"""
From MATLAB ODE Suite by Shampine
"""
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{0}})
d = cache.tab.d
c1 = Θ*(1-Θ)/(1-2d)
c2 = Θ*-2d)/(1-2d)
if out == nothing
return y₀[idxs] + dt*(c1*k[1][idxs] + c2*k[2][idxs])
elseif idxs == nothing
#@. out = y₀ + dt*(c1*k[1] + c2*k[2])
@inbounds for i in eachindex(out)
out[i] = y₀[i] + dt*(c1*k[1][i] + c2*k[2][i])
end
else
#@views @. out = y₀[idxs] + dt*(c1*k[1][idxs] + c2*k[2][idxs])
@inbounds for (j,i) in enumerate(idxs)
out[j] = y₀[i] + dt*(c1*k[1][i] + c2*k[2][i])
end
end
end

@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{1}})
d = cache.tab.d
c1diff = (1-2*Θ)/(1-2*d)
c2diff = (2*Θ-2*d)/(1-2*d)
if out == nothing
return c1diff*k[1][idxs] + c2diff*k[2][idxs]
elseif idxs == nothing
#@. out = c1diff*k[1] + c2diff*k[2]
@inbounds for i in eachindex(out)
out[i] = c1diff*k[1][i] + c2diff*k[2][i]
end
else
#@views @. out = c1diff*k[1][idxs] + c2diff*k[2][idxs]
@inbounds for (j,i) in enumerate(idxs)
out[j] = c1diff*k[1][i] + c2diff*k[2][i]
end
end
end

"""
Second order strong stability preserving (SSP) interpolant.
Expand Down
64 changes: 64 additions & 0 deletions src/dense/rosenbrock_interpolants.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
From MATLAB ODE Suite by Shampine
"""
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{0}})
d = cache.d
c1 = Θ*(1-Θ)/(1-2d)
c2 = Θ*-2d)/(1-2d)
#@. y₀ + dt*(c1*k[1] + c2*k[2])
y₀ + dt*(c1*k[1] + c2*k[2])
end

# First Derivative of the dense output
@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23ConstantCache,Rosenbrock32ConstantCache},idxs,T::Type{Val{1}})
d = cache.d
c1diff = (1-2*Θ)/(1-2*d)
c2diff = (2*Θ-2*d)/(1-2*d)
#@. c1diff*k[1] + c2diff*k[2]
c1diff*k[1] + c2diff*k[2]
end

"""
From MATLAB ODE Suite by Shampine
"""
@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{0}})
d = cache.tab.d
c1 = Θ*(1-Θ)/(1-2d)
c2 = Θ*-2d)/(1-2d)
if out == nothing
return y₀[idxs] + dt*(c1*k[1][idxs] + c2*k[2][idxs])
elseif idxs == nothing
#@. out = y₀ + dt*(c1*k[1] + c2*k[2])
@inbounds for i in eachindex(out)
out[i] = y₀[i] + dt*(c1*k[1][i] + c2*k[2][i])
end
else
#@views @. out = y₀[idxs] + dt*(c1*k[1][idxs] + c2*k[2][idxs])
@inbounds for (j,i) in enumerate(idxs)
out[j] = y₀[i] + dt*(c1*k[1][i] + c2*k[2][i])
end
end
end

@inline function ode_interpolant!(out,Θ,dt,y₀,y₁,k,cache::Union{Rosenbrock23Cache,Rosenbrock32Cache},idxs,T::Type{Val{1}})
d = cache.tab.d
c1diff = (1-2*Θ)/(1-2*d)
c2diff = (2*Θ-2*d)/(1-2*d)
if out == nothing
return c1diff*k[1][idxs] + c2diff*k[2][idxs]
elseif idxs == nothing
#@. out = c1diff*k[1] + c2diff*k[2]
@inbounds for i in eachindex(out)
out[i] = c1diff*k[1][i] + c2diff*k[2][i]
end
else
#@views @. out = c1diff*k[1][idxs] + c2diff*k[2][idxs]
@inbounds for (j,i) in enumerate(idxs)
out[j] = c1diff*k[1][i] + c2diff*k[2][i]
end
end
end

@inline function ode_interpolant(Θ,dt,y₀,y₁,k,cache::Rodas4ConstantCache,idxs,T::Type{Val{0}})
y₀*(1-Θ)+Θ*(y₁+(1-Θ)*(k[1] + Θ*k[2]))
end
28 changes: 16 additions & 12 deletions src/integrators/rosenbrock_integrators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -881,8 +881,8 @@ end
#### Rodas4 type method

@inline function initialize!(integrator,cache::Rodas4ConstantCache,f=integrator.f)
integrator.kshortsize = 2
k = eltype(integrator.sol.k)(2)
integrator.kshortsize = 3
k = eltype(integrator.sol.k)(3)
integrator.k = k
integrator.fsalfirst = f(integrator.t,integrator.uprev)
end
Expand Down Expand Up @@ -954,25 +954,29 @@ end
integrator.EEst = integrator.opts.internalnorm(atmp)
end

integrator.k[1] = integrator.fsalfirst
integrator.k[2] = du
if integrator.opts.calck
@unpack h21,h22,h23,h24,h25,h31,h32,h33,h34,h35 = cache.tab
integrator.k[1] = h21*k1 + h22*k2 + h23*k3 + h24*k4 + h25*k5
integrator.k[2] = h31*k1 + h32*k2 + h33*k3 + h34*k4 + h35*k5
end

integrator.fsallast = du
@pack integrator = t,dt,u,k
end

@inline function initialize!(integrator,cache::Rodas4Cache,f=integrator.f)
integrator.kshortsize = 2
@unpack fsalfirst,fsallast = cache
integrator.kshortsize = 3
@unpack fsalfirst,fsallast,k1,k2,k3,k4 = cache
integrator.fsalfirst = fsalfirst
integrator.fsallast = fsallast
integrator.k = [fsalfirst,fsallast]
integrator.k = [k1,k2,k3,k4]
f(integrator.t,integrator.uprev,integrator.fsalfirst)
end

@inline function perform_step!(integrator,cache::Rodas4Cache,f=integrator.f)
@unpack t,dt,uprev,u,k = integrator
uidx = eachindex(integrator.uprev)
@unpack du,du1,du2,vectmp,vectmp2,vectmp3,vectmp4,vectmp5,vectmp6,fsalfirst,fsallast,dT,J,W,uf,tf,linsolve_tmp,linsolve_tmp_vec,jac_config = cache
@unpack du,du1,du2,k1,k2,k3,k4,vectmp,vectmp2,vectmp3,vectmp4,vectmp5,vectmp6,fsalfirst,fsallast,dT,J,W,uf,tf,linsolve_tmp,linsolve_tmp_vec,jac_config = cache
jidx = eachindex(J)
@unpack a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,C21,C31,C32,C41,C42,C43,C51,C52,C53,C54,C61,C62,C63,C64,C65,gamma,c2,c3,c4,d1,d2,d3,d4 = cache.tab
mass_matrix = integrator.sol.prob.mass_matrix
Expand Down Expand Up @@ -1022,7 +1026,7 @@ end
integrator.alg.linsolve(vectmp,W,linsolve_tmp_vec,true)
end

k1 = reshape(vectmp,sizeu...)
recursivecopy!(k1,reshape(vectmp,size(u)...))

@tight_loop_macros for i in uidx
@inbounds u[i] = uprev[i]+a21*k1[i]
Expand All @@ -1040,7 +1044,7 @@ end
integrator.alg.linsolve(vectmp2,W,linsolve_tmp_vec)
end

k2 = reshape(vectmp2,sizeu...)
recursivecopy!(k2,reshape(vectmp2,size(u)...))

@tight_loop_macros for i in uidx
@inbounds u[i] = uprev[i] + a31*k1[i] + a32*k2[i]
Expand All @@ -1058,7 +1062,7 @@ end
integrator.alg.linsolve(vectmp3,W,linsolve_tmp_vec)
end

k3 = reshape(vectmp3,sizeu...)
recursivecopy!(k3,reshape(vectmp3,size(u)...))

@tight_loop_macros for i in uidx
@inbounds u[i] = uprev[i] + a41*k1[i] + a42*k2[i] + a43*k3[i]
Expand All @@ -1076,7 +1080,7 @@ end
integrator.alg.linsolve(vectmp4,W,linsolve_tmp_vec)
end

k4 = reshape(vectmp4,sizeu...)
recursivecopy!(k4,reshape(vectmp4,size(u)...))

@tight_loop_macros for i in uidx
@inbounds u[i] = uprev[i] + a51*k1[i] + a52*k2[i] + a53*k3[i] + a54*k4[i]
Expand Down
9 changes: 6 additions & 3 deletions test/ode/ode_rosenbrock_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -288,20 +288,23 @@ sol = solve(prob,Ros4LStab(linsolve=LinSolveFactorize(qrfact!)))

prob = prob_ode_linear

sim = test_convergence(dts,prob,Rodas4())
sim = test_convergence(dts,prob,Rodas4(),dense_errors=true)
@test abs(sim.𝒪est[:final]-4) < testTol
@test abs(sim.𝒪est[:L2]-4) < testTol

sol = solve(prob,Rodas4())
@test length(sol) < 20

sim = test_convergence(dts,prob,Rodas42())
sim = test_convergence(dts,prob,Rodas42(),dense_errors=true)
@test abs(sim.𝒪est[:final]-4) < testTol
@test abs(sim.𝒪est[:L2]-4) < testTol

sol = solve(prob,Rodas42())
@test length(sol) < 20

sim = test_convergence(dts,prob,Rodas4P())
sim = test_convergence(dts,prob,Rodas4P(),dense_errors=true)
@test abs(sim.𝒪est[:final]-4) < testTol
@test abs(sim.𝒪est[:L2]-4) < testTol

sol = solve(prob,Rodas4P())
@test length(sol) < 20
Expand Down

0 comments on commit 57bf853

Please sign in to comment.