Skip to content

Commit

Permalink
Merge pull request #2052 from SciML/early_defaulting
Browse files Browse the repository at this point in the history
Remove early defaulting and fix factorization algs
  • Loading branch information
ChrisRackauckas committed Nov 6, 2023
2 parents 8f5474a + a8c5f45 commit 57d87ba
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 109 deletions.
47 changes: 3 additions & 44 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,50 +234,18 @@ function DiffEqBase.prepare_alg(alg::Union{
OrdinaryDiffEqExponentialAlgorithm{0, AD, FDT}},
u0::AbstractArray{T},
p, prob) where {AD, FDT, T}
if alg isa OrdinaryDiffEqExponentialAlgorithm
linsolve = nothing
elseif alg.linsolve === nothing
if (prob.f isa ODEFunction && prob.f.f isa AbstractSciMLOperator)
linsolve = LinearSolve.defaultalg(prob.f.f, u0)
elseif (prob.f isa SplitFunction &&
prob.f.f1.f isa AbstractSciMLOperator)
linsolve = LinearSolve.defaultalg(prob.f.f1.f, u0)
if (linsolve === nothing) || (linsolve isa LinearSolve.DefaultLinearSolver &&
linsolve.alg !== LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
msg = "Split ODE problem do not work with factorization linear solvers. Bug detailed in https://github.com/SciML/OrdinaryDiffEq.jl/pull/1643. Defaulting to linsolve=KrylovJL()"
@warn msg
linsolve = KrylovJL_GMRES()
end
elseif (prob isa ODEProblem || prob isa DDEProblem) &&
(prob.f.mass_matrix === nothing ||
(prob.f.mass_matrix !== nothing &&
!(prob.f.jac_prototype isa AbstractSciMLOperator)))
linsolve = LinearSolve.defaultalg(prob.f.jac_prototype, u0)
else
# If mm is a sparse matrix and A is a MatrixOperator, then let linear
# solver choose things later
linsolve = nothing
end
else
linsolve = alg.linsolve
end

# If not using autodiff or norecompile mode or very large bitsize (like a dual number u0 already)
# don't use a large chunksize as it will either error or not be beneficial
if !(alg_autodiff(alg) isa AutoForwardDiff) ||
(isbitstype(T) && sizeof(T) > 24) ||
(prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper)
if alg isa OrdinaryDiffEqExponentialAlgorithm
return remake(alg, chunk_size = Val{1}())
else
return remake(alg, chunk_size = Val{1}(), linsolve = linsolve)
end
return remake(alg, chunk_size = Val{1}())
end

L = StaticArrayInterface.known_length(typeof(u0))
if L === nothing # dynamic sized

# If chunksize is zero, pick chunksize right at the start of solve and
# then do function barrier to infer the full solve
x = if prob.f.colorvec === nothing
Expand All @@ -287,19 +255,10 @@ function DiffEqBase.prepare_alg(alg::Union{
end

cs = ForwardDiff.pickchunksize(x)

if alg isa OrdinaryDiffEqExponentialAlgorithm
return remake(alg, chunk_size = Val{cs}())
else
return remake(alg, chunk_size = Val{cs}(), linsolve = linsolve)
end
return remake(alg, chunk_size = Val{cs}())
else # statically sized
cs = pick_static_chunksize(Val{L}())
if alg isa OrdinaryDiffEqExponentialAlgorithm
return remake(alg, chunk_size = cs)
else
return remake(alg, chunk_size = cs, linsolve = linsolve)
end
return remake(alg, chunk_size = cs)
end
end

Expand Down
4 changes: 1 addition & 3 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,11 @@ function DiffEqBase.remake(thing::Union{
ST, CJ},
OrdinaryDiffEqImplicitAlgorithm{CS, AD, FDT, ST, CJ
},
DAEAlgorithm{CS, AD, FDT, ST, CJ}};
linsolve, kwargs...) where {CS, AD, FDT, ST, CJ}
DAEAlgorithm{CS, AD, FDT, ST, CJ}}; kwargs...) where {CS, AD, FDT, ST, CJ}
T = SciMLBase.remaker_of(thing)
T(; SciMLBase.struct_as_namedtuple(thing)...,
chunk_size = Val{CS}(), autodiff = Val{AD}(), standardtag = Val{ST}(),
concrete_jac = CJ === nothing ? CJ : Val{CJ}(),
linsolve = linsolve,
kwargs...)
end

Expand Down
82 changes: 20 additions & 62 deletions test/interface/linear_solver_split_ode_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,48 @@ using LinearAlgebra, LinearSolve
import OrdinaryDiffEq.dolinsolve

n = 8
dt = 1 / 16
dt = 1 / 1000
u0 = ones(n)
tspan = (0.0, 1.0)

M1 = 2ones(n) |> Diagonal #|> Array
M2 = 2ones(n) |> Diagonal #|> Array

f1 = M1 |> MatrixOperator
f2 = M2 |> MatrixOperator
f1 = (du,u,p,t) -> du .= M1 * u
f2 = (du,u,p,t) -> du .= M2 * u
prob = SplitODEProblem(f1, f2, u0, tspan)

for algname in (:SBDF2,
:SBDF3,
:KenCarp47)
@testset "$algname" begin
alg0 = @eval $algname()
alg1 = @eval $algname(linsolve = GenericFactorization())
alg1 = @eval $algname(linsolve = LUFactorization())

kwargs = (dt = dt,)

# expected error message
msg = "Split ODE problem do not work with factorization linear solvers. Bug detailed in https://github.com/SciML/OrdinaryDiffEq.jl/pull/1643. Defaulting to linsolve=KrylovJL()"
@test_logs (:warn, msg) solve(prob, alg0; kwargs...)
solve(prob, alg0; kwargs...)
@test DiffEqBase.__solve(prob, alg0; kwargs...).retcode == ReturnCode.Success
@test_broken DiffEqBase.__solve(prob, alg1; kwargs...).retcode == ReturnCode.Success
@test DiffEqBase.__solve(prob, alg1; kwargs...).retcode == ReturnCode.Success
end
end

#####
# deep dive
#####

alg0 = KenCarp47() # passing case
alg1 = KenCarp47(linsolve = GenericFactorization()) # failing case

## objects
ig0 = SciMLBase.init(prob, alg0; dt = dt)
ig1 = SciMLBase.init(prob, alg1; dt = dt)

nl0 = ig0.cache.nlsolver
nl1 = ig1.cache.nlsolver

lc0 = nl0.cache.linsolve
lc1 = nl1.cache.linsolve

W0 = lc0.A
W1 = lc1.A

# perform first step
OrdinaryDiffEq.loopheader!(ig0)
OrdinaryDiffEq.loopheader!(ig1)

OrdinaryDiffEq.perform_step!(ig0, ig0.cache)
OrdinaryDiffEq.perform_step!(ig1, ig1.cache)

@test !OrdinaryDiffEq.nlsolvefail(nl0)
@test OrdinaryDiffEq.nlsolvefail(nl1)

# check operators
@test W0._concrete_form != W1._concrete_form
@test_broken W0._func_cache == W1._func_cache

# check operator application
b = ones(n)
@test W0 * b == W1 * b
@test mul!(rand(n), W0, b) == mul!(rand(n), W1, b)
#@test W0 \ b == W1 \ b

# check linear solve
lc0.b .= 1.0
lc1.b .= 1.0

solve(lc0)
solve(lc1)
f1 = M1 |> MatrixOperator
f2 = M2 |> MatrixOperator
prob = SplitODEProblem(f1, f2, u0, tspan)

@test_broken lc0.u == lc1.u
for algname in (:SBDF2,
:SBDF3,
:KenCarp47)
@testset "$algname" begin
alg0 = @eval $algname()

# solve contried problem using OrdinaryDiffEq machinery
linres0 = dolinsolve(ig0, lc0; A = W0, b = b, linu = ones(n), reltol = 1e-8)
linres1 = dolinsolve(ig1, lc1; A = W1, b = b, linu = ones(n), reltol = 1e-8)
kwargs = (dt = dt,)

@test_broken linres0 == linres1
solve(prob, alg0; kwargs...)
@test DiffEqBase.__solve(prob, alg0; kwargs...).retcode == ReturnCode.Success
end
end

###
# custom linsolve function
Expand All @@ -101,6 +61,4 @@ end

alg = KenCarp47(linsolve = LinearSolveFunction(linsolve))

@test solve(prob, alg).retcode == ReturnCode.Success

nothing
@test solve(prob, alg).retcode == ReturnCode.Success

0 comments on commit 57d87ba

Please sign in to comment.