Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic DAE Initialization #1037

Merged
merged 9 commits into from Feb 18, 2020
1 change: 1 addition & 0 deletions Project.toml
Expand Up @@ -15,6 +15,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221"
NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down
2 changes: 2 additions & 0 deletions src/OrdinaryDiffEq.jl
Expand Up @@ -28,6 +28,7 @@ module OrdinaryDiffEq

using ExponentialUtilities

using NLsolve
# Required by temporary fix in not in-place methods with 12+ broadcasts
# `MVector` is used by Nordsieck forms
import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray
Expand Down Expand Up @@ -61,6 +62,7 @@ module OrdinaryDiffEq
include("misc_utils.jl")
include("algorithms.jl")
include("alg_utils.jl")
include("initialize_dae.jl")

include("nlsolve/type.jl")
include("nlsolve/utils.jl")
Expand Down
54 changes: 54 additions & 0 deletions src/initialize_dae.jl
@@ -0,0 +1,54 @@
abstract type DAEInitializationAlgorithm end

struct BrownFullBasicInit <: DAEInitializationAlgorithm end

function initialize_dae!(integrator, u, du, differential_vars, alg::BrownFullBasicInit, ::Val{true})
@unpack p, t, f = integrator

nlequation = (out, x) -> begin
for i in 1:length(x)
if differential_vars[i]
du[i] = x[i]
else
u[i] = x[i]
end
end
f(out, du, u, p, t)
end

r = nlsolve(nlequation, zero(u))

for i in 1:length(u)
if differential_vars[i]
du[i] = r.zero[i]
else
u[i] = r.zero[i]
end
end
end

function initialize_dae!(integrator, u, du, differential_vars, alg::BrownFullBasicInit, ::Val{true})
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
@unpack p, t, f = integrator

nlequation = (x) -> begin
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
for i in 1:length(x)
if differential_vars[i]
du[i] = x[i]
else
u[i] = x[i]
end
end
f(du, u, p, t)
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
end

r = nlsolve(nlequation, zero(u))

for i in 1:length(u)
if differential_vars[i]
du[i] = r.zero[i]
else
u[i] = r.zero[i]
end
end
end

5 changes: 5 additions & 0 deletions src/solve.jl
Expand Up @@ -61,6 +61,7 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.
initialize_integrator = true,
alias_u0 = false,
alias_du0 = false,
initializealg = BrownFullBasicInit(),
kwargs...) where recompile_flag

if prob isa DiffEqBase.AbstractDAEProblem && alg isa OrdinaryDiffEqAlgorithm
Expand Down Expand Up @@ -382,6 +383,10 @@ function DiffEqBase.__init(prob::Union{DiffEqBase.AbstractODEProblem,DiffEqBase.
isout,reeval_fsal,
u_modified,opts,destats)
if initialize_integrator
if isdae
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

evaluate to see if initialization is required first? Or do that in initialize?

initialize_dae!(integrator, u, du, prob.differential_vars, initializealg, Val(isinplace(prob)))
@show du
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved
end
initialize_callbacks!(integrator, initialize_save)
initialize!(integrator,integrator.cache)
save_start && typeof(alg) <: CompositeAlgorithm && copyat_or_push!(alg_choice,1,integrator.cache.current)
Expand Down