Skip to content

Commit

Permalink
Merge pull request #589 from JuliaDiffEq/myb/fix
Browse files Browse the repository at this point in the history
Fix functional iteration with mass matrix
  • Loading branch information
ChrisRackauckas committed Jan 7, 2019
2 parents 7747e76 + 7b7c89c commit 52292b1
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 37 deletions.
20 changes: 12 additions & 8 deletions src/nlsolve/functional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Equations II, Springer Series in Computational Mathematics. ISBN
978-3-642-05221-7. Section IV.8.
[doi:10.1007/978-3-642-05221-7](https://doi.org/10.1007/978-3-642-05221-7)
"""
function (S::NLFunctional{false})(integrator)
@muladd function (S::NLFunctional{false})(integrator)
nlcache = S.cache
@unpack t,dt,uprev,u,f,p = integrator
@unpack z,tmp,κ,tol,c,γ,min_iter,max_iter = nlcache
Expand All @@ -42,7 +42,8 @@ function (S::NLFunctional{false})(integrator)
if mass_matrix == I
z₊ = dt .* f(u, p, tstep)
else
z₊ = mass_matrix * (dt .* f(u, p, tstep))
mz = mass_matrix * z
z₊ = dt .* f(u, p, tstep) .- mz .+ z
end
ndz = integrator.opts.internalnorm(z₊ .- z)
z = z₊
Expand All @@ -59,7 +60,8 @@ function (S::NLFunctional{false})(integrator)
if mass_matrix == I
z₊ = dt .* f(u, p, tstep)
else
z₊ = mass_matrix * (dt .* f(u, p, tstep))
mz = mass_matrix * z
z₊ = dt .* f(u, p, tstep) .- mz .+ z
end

# check early stopping criterion
Expand All @@ -82,7 +84,7 @@ function (S::NLFunctional{false})(integrator)
z, η, iter, do_functional
end

function (S::NLFunctional{true})(integrator)
@muladd function (S::NLFunctional{true})(integrator)
nlcache = S.cache
@unpack t,dt,uprev,u,f,p = integrator
@unpack z,z₊,b,dz,tmp,κ,tol,k,c,γ,min_iter,max_iter = nlcache
Expand All @@ -104,8 +106,9 @@ function (S::NLFunctional{true})(integrator)
if mass_matrix == I
@. z₊ = dt*k
else
@. ztmp = dt*k
mul!(vec(z₊), mass_matrix, vec(ztmp))
@. z₊ = dt*k + z
mul!(ztmp, mass_matrix, z)
@. z₊ -= ztmp
end
@. dz = z₊ - z
ndz = integrator.opts.internalnorm(dz)
Expand All @@ -124,8 +127,9 @@ function (S::NLFunctional{true})(integrator)
if mass_matrix == I
@. z₊ = dt*k
else
@. ztmp = dt*k
mul!(z₊, mass_matrix, ztmp)
@. z₊ = dt*k + z
mul!(ztmp, mass_matrix, z)
@. z₊ -= ztmp
end
@. dz = z₊ - z
ndzprev = ndz
Expand Down
4 changes: 3 additions & 1 deletion src/nlsolve/type.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ end

NLSolverCache(;κ=nothing, tol=nothing, min_iter=1, max_iter=10) =
NLSolverCache(κ, tol, min_iter, max_iter, 0, true,
(nothing for i in 1:10)...)
ntuple(i->nothing, 4)...,
κ === nothing ? κ : zero(κ),
ntuple(i->nothing, 5)...)

# Default `iip` to `true`, but the whole type will be reinitialized in `alg_cache`
function NLFunctional(;kwargs...)
Expand Down
68 changes: 40 additions & 28 deletions test/mass_matrix_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,38 @@ using OrdinaryDiffEq, Test, LinearAlgebra, Statistics

# TODO: clean up
@testset "Mass Matrix Accuracy Tests" begin
mm_A = [-2.0 1 4
4 -2 1
2 1 3]
mm_b = mm_A*ones(3)
# iip
function mm_f_iip(du,u,p,t)
mul!(du,mm_A,u)
tmp = t*mm_b
du .+= tmp
end
mm_g_iip(du,u,p,t) = @. du = u + t
# oop
function mm_f_oop(u,p,t)
du = u./t
mm_f_iip(du,u,p,t)
du
end
function mm_g_oop(u,p,t)
du = u./t
mm_g_iip(du,u,p,t)
du
end
mm_analytic(u0,p,t) = @. 2u0*exp(t) - t - 1
for iip in (false, true)
function make_prob(mm_A; iip)
mm_b = mm_A*ones(3)
# iip
function mm_f_iip(du,u,p,t)
mul!(du,mm_A,u)
tmp = t*mm_b
du .+= tmp
end
mm_g_iip(du,u,p,t) = @. du = u + t
# oop
function mm_f_oop(u,p,t)
du = u./t
mm_f_iip(du,u,p,t)
du
end
function mm_g_oop(u,p,t)
du = u./t
mm_g_iip(du,u,p,t)
du
end
mm_analytic(u0,p,t) = @. 2u0*exp(t) - t - 1
f = ((mm_f_oop, mm_g_oop), (mm_f_iip, mm_g_iip))[iip+1]
prob = ODEProblem(ODEFunction(f[1],analytic=mm_analytic,mass_matrix=mm_A),ones(3),
(0.0,1.0))
prob2 = ODEProblem(ODEFunction(f[2],analytic=mm_analytic),ones(3),(0.0,1.0))
return prob, prob2
end
for iip in (false, true)
mm_A = [-2.0 1 4
4 -2 1
2 1 3]
prob, prob2 = make_prob(mm_A; iip=iip)

######################################### Test each method for exactness

Expand Down Expand Up @@ -72,17 +76,25 @@ using OrdinaryDiffEq, Test, LinearAlgebra, Statistics

sol = solve(prob, ImplicitEuler(),dt=1/10,adaptive=false)
sol2 = solve(prob2,ImplicitEuler(),dt=1/10,adaptive=false)
sol3 = solve(prob2,ImplicitEuler(nlsolve=NLFunctional(min_iter=9)),dt=1/10,adaptive=false)

@test norm(sol .- sol2) 0 atol=1e-7
@test norm(sol .- sol3) 0 atol=1e-7

sol = solve(prob, ImplicitMidpoint(extrapolant = :constant),dt=1/10)
sol2 = solve(prob2,ImplicitMidpoint(extrapolant = :constant),dt=1/10)
sol3 = solve(prob2,ImplicitMidpoint(extrapolant = :constant,nlsolve=NLFunctional(min_iter=7)),dt=1/10)

@test norm(sol .- sol2) 0 atol=1e-7
@test norm(sol .- sol3) 0 atol=1e-7

# Functional iteration
prob, prob2 = make_prob(Matrix{Float64}(1.01I, 3, 3); iip=iip)
sol = solve(prob,ImplicitEuler(
nlsolve=NLFunctional=2000.,tol=1e-7,min_iter=10,max_iter=100)),dt=1/10,adaptive=false)
sol2 = solve(prob2,ImplicitEuler(),dt=1/10,adaptive=false)
@test norm(sol .- sol2) 0 atol=1e-7

sol = solve(prob, ImplicitMidpoint(extrapolant = :constant,
nlsolve=NLFunctional=2000.,tol=1e-7,min_iter=10,max_iter=100)),dt=1/10,adaptive=false)
sol2 = solve(prob2,ImplicitMidpoint(extrapolant = :constant),dt=1/10)
@test norm(sol .- sol2) 0 atol=1e-7
end
end

Expand Down

0 comments on commit 52292b1

Please sign in to comment.