Skip to content

Commit

Permalink
Merge pull request #1037 from JuliaDiffEq/daeinit
Browse files Browse the repository at this point in the history
Basic DAE Initialization
  • Loading branch information
ChrisRackauckas committed Feb 18, 2020
2 parents e96a97f + 0398974 commit c6c58cf
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 11 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
2 changes: 2 additions & 0 deletions src/OrdinaryDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ module OrdinaryDiffEq

using ExponentialUtilities

using NLsolve
# Required by temporary fix in not in-place methods with 12+ broadcasts
# `MVector` is used by Nordsieck forms
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray
Expand Down Expand Up @@ -61,6 +62,7 @@ module OrdinaryDiffEq
include("misc_utils.jl")
include("algorithms.jl")
include("alg_utils.jl")
include("initialize_dae.jl")

include("nlsolve/type.jl")
include("nlsolve/utils.jl")
Expand Down
24 changes: 14 additions & 10 deletions src/caches/dae_caches.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
@cache mutable struct DImplicitEulerCache{uType,uNoUnitsType,N} <: OrdinaryDiffEqMutableCache
@cache mutable struct DImplicitEulerCache{uType,duType,uNoUnitsType,N} <: OrdinaryDiffEqMutableCache
u::uType
uprev::uType
uprev2::uType
du::duType
atmp::uNoUnitsType
nlsolver::N
end

mutable struct DImplicitEulerConstantCache{N} <: OrdinaryDiffEqConstantCache
mutable struct DImplicitEulerConstantCache{duType,N} <: OrdinaryDiffEqConstantCache
du::duType
nlsolver::N
end

Expand All @@ -16,7 +18,7 @@ function alg_cache(alg::DImplicitEuler,du,u,res_prototype,rate_prototype,uEltype
α = 1
nlsolver = build_nlsolver(alg,u,uprev,p,t,dt,f,res_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,α,Val(false))

DImplicitEulerConstantCache(nlsolver)
DImplicitEulerConstantCache(du,nlsolver)
end

function alg_cache(alg::DImplicitEuler,du,u,res_prototype,rate_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,
Expand All @@ -28,10 +30,11 @@ function alg_cache(alg::DImplicitEuler,du,u,res_prototype,rate_prototype,uEltype

atmp = similar(u,uEltypeNoUnits)

DImplicitEulerCache(u,uprev,uprev2,atmp,nlsolver)
DImplicitEulerCache(u,uprev,uprev2,du,atmp,nlsolver)
end

@cache mutable struct DABDF2ConstantCache{N,dtType,rate_prototype} <: OrdinaryDiffEqConstantCache
@cache mutable struct DABDF2ConstantCache{duType,N,dtType,rate_prototype} <: OrdinaryDiffEqConstantCache
du::duType
nlsolver::N
eulercache::DImplicitEulerConstantCache
dtₙ₋₁::dtType
Expand All @@ -43,18 +46,19 @@ function alg_cache(alg::DABDF2,du,u,res_prototype,rate_prototype,uEltypeNoUnits,
γ, c = 1//1, 1
α = 1//1
nlsolver = build_nlsolver(alg,u,uprev,p,t,dt,f,res_prototype,uEltypeNoUnits,uBottomEltypeNoUnits,tTypeNoUnits,γ,c,α,Val(false))
eulercache = DImplicitEulerConstantCache(nlsolver)
eulercache = DImplicitEulerConstantCache(du,nlsolver)

dtₙ₋₁ = one(dt)
fsalfirstprev = rate_prototype

DABDF2ConstantCache(nlsolver, eulercache, dtₙ₋₁, fsalfirstprev)
DABDF2ConstantCache(du,nlsolver, eulercache, dtₙ₋₁, fsalfirstprev)
end

@cache mutable struct DABDF2Cache{uType,rateType,uNoUnitsType,N,dtType} <: OrdinaryDiffEqMutableCache
@cache mutable struct DABDF2Cache{uType,duType,rateType,uNoUnitsType,N,dtType} <: OrdinaryDiffEqMutableCache
uₙ::uType
uₙ₋₁::uType
uₙ₋₂::uType
du::duType
fsalfirst::rateType
fsalfirstprev::rateType
atmp::uNoUnitsType
Expand All @@ -73,10 +77,10 @@ function alg_cache(alg::DABDF2,du,u,res_prototype,rate_prototype,uEltypeNoUnits,
fsalfirstprev = zero(rate_prototype)
atmp = similar(u,uEltypeNoUnits)

eulercache = DImplicitEulerCache(u,uprev,uprev2,atmp,nlsolver)
eulercache = DImplicitEulerCache(u,uprev,uprev2,du,atmp,nlsolver)

dtₙ₋₁ = one(dt)

DABDF2Cache(u,uprev,uprev2,fsalfirst,fsalfirstprev,atmp,
DABDF2Cache(u,uprev,uprev2,du,fsalfirst,fsalfirstprev,atmp,
nlsolver,eulercache,dtₙ₋₁)
end
93 changes: 93 additions & 0 deletions src/initialize_dae.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
abstract type DAEInitializationAlgorithm end

struct BrownFullBasicInit{T} <: DAEInitializationAlgorithm
abstol::T
end
BrownFullBasicInit() = BrownFullBasicInit(1e-10)

function initialize_dae!(integrator, u, du, differential_vars, alg::BrownFullBasicInit, ::Val{true})
@unpack p, t, f = integrator

tmp = get_tmp_cache(integrator)[1]
f(tmp, du, u, p, t)

if integrator.opts.internalnorm(tmp,t) <= alg.abstol
return
elseif differential_vars === nothing
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions or differential_vars")
end

nlequation = (out, x) -> begin
for i in 1:length(x)
if differential_vars[i]
du[i] = x[i]
else
u[i] = x[i]
end
end
f(out, du, u, p, t)
end

r = nlsolve(nlequation, zero(u))

for i in 1:length(u)
if differential_vars[i]
du[i] = r.zero[i]
else
u[i] = r.zero[i]
end
end

return
end

function initialize_dae!(integrator, _u, _du, differential_vars, alg::BrownFullBasicInit, ::Val{false})
@unpack p, t, f = integrator

if integrator.opts.internalnorm(f(_du, _u, p, t),t) <= alg.abstol
return
elseif differential_vars === nothing
error("differential_vars must be set for DAE initialization to occur. Either set consistent initial conditions or differential_vars")
end

if _u isa Number && _du isa Number
# This doesn't fix static arrays!
u = [_u]
du = [_du]
else
u = _u
du = _du
end

nlequation = (out,x) -> begin
for i in 1:length(x)
if differential_vars[i]
du[i] = x[i]
else
u[i] = x[i]
end
end
out .= f(du, u, p, t)
end

r = nlsolve(nlequation, zero(u))

for i in 1:length(u)
if differential_vars[i]
du[i] = r.zero[i]
else
u[i] = r.zero[i]
end
end

if _u isa Number && _du isa Number
# This doesn't fix static arrays!
integrator.u = first(u)
integrator.cache.du = first(du)
else
integrator.u = u
integrator.cache.du = du
end

return
end
4 changes: 4 additions & 0 deletions src/solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.
initialize_integrator = true,
alias_u0 = false,
alias_du0 = false,
initializealg = BrownFullBasicInit(),
kwargs...) where recompile_flag

if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm
Expand Down Expand Up @@ -382,6 +383,9 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.
isout,reeval_fsal,
u_modified,opts,destats)
if initialize_integrator
if isdae
initialize_dae!(integrator, u, du, prob.differential_vars, initializealg, Val(isinplace(prob)))
end
initialize_callbacks!(integrator, initialize_save)
initialize!(integrator,integrator.cache)
save_start && typeof(alg) <: CompositeAlgorithm && copyat_or_push!(alg_choice,1,integrator.cache.current)
Expand Down
2 changes: 1 addition & 1 deletion test/algconvergence/dae_convergence_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,4 @@ prob_dae_linear_oop = DAEProblem(

sim24 = test_convergence(dts,prob,DABDF2(;autodiff=false))
@test sim24.𝒪est[:final] 2 atol=testTol
end
end
73 changes: 73 additions & 0 deletions test/integrators/dae_initialization_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
using OrdinaryDiffEq, StaticArrays, Test

f = function (out,du,u,p,t)
out[1] = - 0.04u[1] + 1e4*u[2]*u[3] - du[1]
out[2] = + 0.04u[1] - 3e7*u[2]^2 - 1e4*u[2]*u[3] - du[2]
out[3] = u[1] + u[2] + u[3] - 1.0
end

u₀ = [1.0, 0, 0]
du₀ = [0.0, 0.0, 0.0]
tspan = (0.0,100000.0)
differential_vars = [true,true,false]
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
integrator = init(prob, DABDF2())

@test integrator.cache.du[1] -0.04 atol=1e-9
@test integrator.cache.du[2] 0.04 atol=1e-9
@test integrator.u[3] 0.0 atol=1e-9

integrator = init(prob, DImplicitEuler())

@test integrator.cache.du[1] -0.04 atol=1e-9
@test integrator.cache.du[2] 0.04 atol=1e-9
@test integrator.u[3] 0.0 atol=1e-9

# Need to be able to find the consistent solution of this problem, broken right now
# analytical solution:
# u[1](t) -> cos(t)
# u[2](t) -> -sin(t)
# u[3](t) -> 2cos(t)
f = function (out,du,u,p,t)
out[1] = du[1] - u[2]
out[2] = du[2] + u[3] - cos(t)
out[3] = u[1] - cos(t)
end

u₀ = [1.0, 0.0, 0.0]
du₀ = [0.0, 0.0, 0.0]
tspan = (0.0,1.0)
differential_vars = [true, true, false]
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
integrator = init(prob, DABDF2())

@test integrator.cache.du[1] 0.0 atol=1e-9
@test_broken integrator.cache.du[2] -1.0 atol=1e-9
@test_broken integrator.u[3] 2.0 atol=1e-9

f = function (du,u,p,t)
du - u
end

u₀ = 1.0
du₀ = 0.0
tspan = (0.0,1.0)
differential_vars = [true]
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
integrator = init(prob, DABDF2())

@test integrator.cache.du 1.0 atol=1e-9

f = function (du,u,p,t)
du .- u
end

u₀ = SA[1.0, 1.0]
du₀ = SA[0.0, 0.0]
tspan = (0.0,1.0)
differential_vars = [true, true]
prob = DAEProblem(f,du₀,u₀,tspan,differential_vars=differential_vars)
integrator = init(prob, DABDF2())

@test integrator.cache.du[1] 1.0 atol=1e-9
@test integrator.cache.du[2] 1.0 atol=1e-9
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ if !is_APPVEYOR && (GROUP == "All" || GROUP == "Integrators_II")
@time @safetestset "Reverse Directioned Event Tests" begin include("integrators/rev_events_tests.jl") end
@time @safetestset "Differentiation Direction Tests" begin include("integrators/diffdir_tests.jl") end
@time @safetestset "Resize Tests" begin include("integrators/resize_tests.jl") end
@time @safetestset "DAE Initialization Tests" begin include("integrators/dae_initialization_tests.jl") end
end

if !is_APPVEYOR && (GROUP == "All" || GROUP == "Regression")
Expand Down

0 comments on commit c6c58cf

Please sign in to comment.