Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
args...;
alias_A = default_alias_A(alg, prob.A, prob.b),
alias_b = default_alias_b(alg, prob.A, prob.b),
abstol = default_tol(eltype(prob.b)),
reltol = default_tol(eltype(prob.b)),
abstol = default_tol(real(eltype(prob.b))),
reltol = default_tol(real(eltype(prob.b))),
maxiters::Int = length(prob.b),
verbose::Bool = false,
Pl = IdentityOperator(size(prob.A)[1]),
Expand Down Expand Up @@ -151,8 +151,8 @@ function SciMLBase.init(prob::LinearProblem, alg::SciMLLinearSolveAlgorithm,
end

# Guard against type mismatch for user-specified reltol/abstol
reltol = eltype(prob.b)(reltol)
abstol = eltype(prob.b)(abstol)
reltol = real(eltype(prob.b))(reltol)
abstol = real(eltype(prob.b))(abstol)

cacheval = init_cacheval(alg, A, b, u0, Pl, Pr, maxiters, abstol, reltol, verbose,
assumptions)
Expand Down
2 changes: 1 addition & 1 deletion src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function _ldiv!(x::Vector, A::Factorization, b::Vector)
ldiv!(A, x)
end

#RF Bad fallback: will fail if `A` is just a stand-in
# RF Bad fallback: will fail if `A` is just a stand-in
# This should instead just create the factorization type.
function init_cacheval(alg::AbstractFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
reltol, verbose::Bool, assumptions::OperatorAssumptions)
Expand Down
76 changes: 49 additions & 27 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,37 +9,52 @@ const Dual64 = ForwardDiff.Dual{Nothing, Float64, 1}
n = 8
A = Matrix(I, n, n)
b = ones(n)
# Real-valued systems
A1 = A / 1;
b1 = rand(n);
x1 = zero(b);
# A2 is similar to A1; created to test cache reuse
A2 = A / 2;
b2 = rand(n);
x2 = zero(b);
# Complex systems + mismatched types with eltype(tol)
A3 = A1 .|> ComplexF32
b3 = b1 .|> ComplexF32
x3 = x1 .|> ComplexF32
# A4 is similar to A3; created to test cache reuse
A4 = A2 .|> ComplexF32
b4 = b2 .|> ComplexF32
x4 = x2 .|> ComplexF32

prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
prob3 = LinearProblem(A3, b3; u0 = x3)
prob4 = LinearProblem(A4, b4; u0 = x4)

cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)

function test_interface(alg, prob1, prob2)
A1 = prob1.A
b1 = prob1.b
x1 = prob1.u0
A2 = prob2.A
b2 = prob2.b
x2 = prob2.u0
A1, b1 = prob1.A, prob1.b
A2, b2 = prob2.A, prob2.b

sol = solve(prob1, alg; cache_kwargs...)
@test A1 * sol.u ≈ b1

sol = solve(prob2, alg; cache_kwargs...)
@test A2 * sol.u ≈ b2

# Test cache resue: base mechanism
cache = SciMLBase.init(prob1, alg; cache_kwargs...) # initialize cache
sol = solve!(cache)
@test A1 * sol.u ≈ b1

# Test cache resue: only A changes
cache.A = deepcopy(A2)
sol = solve!(cache; cache_kwargs...)
@test A2 * sol.u ≈ b1

cache.A = A2
# Test cache resue: both A and b change
cache.A = deepcopy(A2)
cache.b = b2
sol = solve!(cache; cache_kwargs...)
@test A2 * sol.u ≈ b2
Expand All @@ -50,6 +65,7 @@ end
@testset "LinearSolve" begin
@testset "Default Linear Solver" begin
test_interface(nothing, prob1, prob2)
test_interface(nothing, prob3, prob4)

A1 = prob1.A * prob1.A'
b1 = prob1.b
Expand Down Expand Up @@ -202,25 +218,24 @@ end
end
end

test_algs = if VERSION >= v"1.9" && LinearSolve.usemkl
(LUFactorization(),
QRFactorization(),
SVDFactorization(),
RFLUFactorization(),
MKLLUFactorization(),
LinearSolve.defaultalg(prob1.A, prob1.b))
else
(LUFactorization(),
QRFactorization(),
SVDFactorization(),
RFLUFactorization(),
LinearSolve.defaultalg(prob1.A, prob1.b))

test_algs = [
LUFactorization(),
QRFactorization(),
SVDFactorization(),
RFLUFactorization(),
LinearSolve.defaultalg(prob1.A, prob1.b),
]

if VERSION >= v"1.9" && LinearSolve.usemkl
push!(test_algs, MKLLUFactorization())
end

@testset "Concrete Factorizations" begin
for alg in test_algs
@testset "$alg" begin
test_interface(alg, prob1, prob2)
VERSION >= v"1.9" && (alg isa MKLLUFactorization || test_interface(alg, prob3, prob4))
end
end
if LinearSolve.appleaccelerate_isavailable()
Expand All @@ -232,15 +247,16 @@ end
for fact_alg in (lu, lu!,
qr, qr!,
cholesky,
#cholesky!,
#ldlt, ldlt!,
# cholesky!,
# ldlt, ldlt!,
bunchkaufman, bunchkaufman!,
lq, lq!,
svd, svd!,
LinearAlgebra.factorize)
@testset "fact_alg = $fact_alg" begin
alg = GenericFactorization(fact_alg = fact_alg)
test_interface(alg, prob1, prob2)
test_interface(alg, prob3, prob4)
end
end
end
Expand All @@ -251,13 +267,17 @@ end

@testset "KrylovJL" begin
kwargs = (; gmres_restart = 5)
for alg in (("Default", KrylovJL(kwargs...)),
algorithms = (
("Default", KrylovJL(kwargs...)),
("CG", KrylovJL_CG(kwargs...)),
("GMRES", KrylovJL_GMRES(kwargs...)),
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
("MINRES", KrylovJL_MINRES(kwargs...)))
@testset "$(alg[1])" begin
test_interface(alg[2], prob1, prob2)
# ("BICGSTAB",KrylovJL_BICGSTAB(kwargs...)),
("MINRES", KrylovJL_MINRES(kwargs...))
)
for (name, algorithm) in algorithms
@testset "$name" begin
test_interface(algorithm, prob1, prob2)
test_interface(algorithm, prob3, prob4)
end
end
end
Expand All @@ -274,6 +294,7 @@ end
)
@testset "$(alg[1])" begin
test_interface(alg[2], prob1, prob2)
test_interface(alg[2], prob3, prob4)
end
end
end
Expand All @@ -287,6 +308,7 @@ end
("GMRES", KrylovKitJL_GMRES(kwargs...)))
@testset "$(alg[1])" begin
test_interface(alg[2], prob1, prob2)
test_interface(alg[2], prob3, prob4)
end
@test alg[2] isa KrylovKitJL
end
Expand Down