Skip to content

Simple quick fix for refactor issue #207

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 22 additions & 8 deletions src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ end

Base.@kwdef struct UMFPACKFactorization <: AbstractFactorization
reuse_symbolic::Bool = true
check_pattern::Bool = true # Check factorization re-use
end

function init_cacheval(alg::UMFPACKFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
Expand All @@ -290,7 +291,13 @@ function SciMLBase.solve(cache::LinearCache, alg::UMFPACKFactorization; kwargs..
if cache.isfresh
if cache.cacheval !== nothing && alg.reuse_symbolic
# Caches the symbolic factorization: https://github.com/JuliaLang/julia/pull/33738
fact = lu!(cache.cacheval, A)
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
cache.cacheval.colptr &&
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
fact = lu(A)
else
fact = lu!(cache.cacheval, A)
end
else
fact = lu(A)
end
Expand All @@ -303,6 +310,7 @@ end

Base.@kwdef struct KLUFactorization <: AbstractFactorization
reuse_symbolic::Bool = true
check_pattern::Bool = true
end

function init_cacheval(alg::KLUFactorization, A, b, u, Pl, Pr, maxiters::Int, abstol,
Expand All @@ -316,14 +324,20 @@ function SciMLBase.solve(cache::LinearCache, alg::KLUFactorization; kwargs...)
A = convert(AbstractMatrix, A)
if cache.isfresh
if cache.cacheval !== nothing && alg.reuse_symbolic
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
# This won't recompute if it does.
KLU.klu_analyze!(cache.cacheval)
copyto!(cache.cacheval.nzval, A.nzval)
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
KLU.klu_factor!(cache.cacheval)
if alg.check_pattern && !(SuiteSparse.decrement(SparseArrays.getcolptr(A)) ==
cache.cacheval.colptr &&
SuiteSparse.decrement(SparseArrays.getrowval(A)) == cache.cacheval.rowval)
fact = KLU.klu(A)
else
# If we have a cacheval already, run umfpack_symbolic to ensure the symbolic factorization exists
# This won't recompute if it does.
KLU.klu_analyze!(cache.cacheval)
copyto!(cache.cacheval.nzval, A.nzval)
if cache.cacheval._numeric === C_NULL # We MUST have a numeric factorization for reuse, unlike UMFPACK.
KLU.klu_factor!(cache.cacheval)
end
fact = KLU.klu!(cache.cacheval, A)
end
fact = KLU.klu!(cache.cacheval, A)
else
# New fact each time since the sparsity pattern can change
# and thus it needs to reallocate
Expand Down
37 changes: 20 additions & 17 deletions test/basictests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,46 +79,49 @@ end
end

@testset "UMFPACK Factorization" begin
A1 = A / 1
A1 = sparse(A / 1)
b1 = rand(n)
x1 = zero(b)
A2 = A / 2
A2 = sparse(A / 2)
b2 = rand(n)
x2 = zero(b)

prob1 = LinearProblem(sparse(A1), b1; u0 = x1)
prob2 = LinearProblem(sparse(A2), b2; u0 = x2)
prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
test_interface(UMFPACKFactorization(), prob1, prob2)
test_interface(UMFPACKFactorization(reuse_symbolic = false), prob1, prob2)

# Test that refactoring wrong throws.
# Test that refactoring is checked and handled.
cache = SciMLBase.init(prob1, UMFPACKFactorization(); cache_kwargs...) # initialize cache
y = solve(cache)
cache = LinearSolve.set_A(cache, sprand(n, n, 0.8))
@test_throws ArgumentError solve(cache)
cache = LinearSolve.set_A(cache, A2)
@test A2 * solve(cache) ≈ b1
X = sprand(n, n, 0.8)
cache = LinearSolve.set_A(cache, X)
@test X * solve(cache) ≈ b1
end

@testset "KLU Factorization" begin
A1 = A / 1
A1 = sparse(A / 1)
b1 = rand(n)
x1 = zero(b)
A2 = A / 2
A2 = sparse(A / 2)
b2 = rand(n)
x2 = zero(b)

prob1 = LinearProblem(sparse(A1), b1; u0 = x1)
prob2 = LinearProblem(sparse(A2), b2; u0 = x2)
prob1 = LinearProblem(A1, b1; u0 = x1)
prob2 = LinearProblem(A2, b2; u0 = x2)
test_interface(KLUFactorization(), prob1, prob2)
test_interface(KLUFactorization(reuse_symbolic = false), prob1, prob2)

# Test that refactoring wrong throws.
# Test that refactoring wrong is checked and handled.
cache = SciMLBase.init(prob1, KLUFactorization(); cache_kwargs...) # initialize cache
y = solve(cache)
X = copy(A1)
X[8, 8] = 0.0
X[7, 8] = 1.0
cache = LinearSolve.set_A(cache, sparse(X))
@test_throws ArgumentError solve(cache)
cache = LinearSolve.set_A(cache, A2)
@test A2 * solve(cache) ≈ b1
X = sprand(n, n, 0.8)
cache = LinearSolve.set_A(cache, X)
@test X * solve(cache) ≈ b1
end

@testset "FastLAPACK Factorizations" begin
Expand Down
5 changes: 5 additions & 0 deletions test/zeroinittests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ A = Diagonal(ones(4))
b = rand(4)
A = sparse(A)
Anz = deepcopy(A)
C = copy(A)
C[begin, end] = 1e-8
A.nzval .= 0
cache_kwargs = (; verbose = true, abstol = 1e-8, reltol = 1e-8, maxiter = 30)

Expand All @@ -14,6 +16,9 @@ function test_nonzero_init(alg = nothing)
cache = LinearSolve.set_A(cache, Anz)
sol = solve(cache; cache_kwargs...)
@test sol.u == b
cache = LinearSolve.set_A(cache, C)
sol = solve(cache; cache_kwargs...)
@test sol.u ≈ b
end

test_nonzero_init()
Expand Down