Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 116 additions & 38 deletions ext/LinearSolveSparseArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,25 @@ function LinearSolve.init_cacheval(alg::RFLUFactorization,
end

function LinearSolve.handle_sparsematrixcsc_lu(A::AbstractSparseMatrixCSC)
lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
@static if Base.USE_GPL_LIBS
lu(SparseMatrixCSC(size(A)..., getcolptr(A), rowvals(A), nonzeros(A)),
check = false)
else
error("Sparse LU factorization requires GPL libraries (UMFPACK). Use `using Sparspak` for a non-GPL alternative or rebuild Julia with USE_GPL_LIBS=1")
end
end

@static if Base.USE_GPL_LIBS
function LinearSolve.defaultalg(
A::Symmetric{<:BLASELTYPES, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CHOLMODFactorization)
end
else
function LinearSolve.defaultalg(
A::Symmetric{<:BLASELTYPES, <:SparseMatrixCSC}, b, ::OperatorAssumptions{Bool})
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.CholeskyFactorization)
end
end # @static if Base.USE_GPL_LIBS

function LinearSolve.defaultalg(A::AbstractSparseMatrixCSC{Tv, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Tv, Ti}
Expand All @@ -71,9 +82,13 @@ function LinearSolve.init_cacheval(alg::GenericFactorization,
LinearSolve.do_factorization(alg, newA, b, u)
end

@static if Base.USE_GPL_LIBS

const PREALLOCATED_UMFPACK = SparseArrays.UMFPACK.UmfpackLU(SparseMatrixCSC(0, 0, [1],
Int[], Float64[]))

end # @static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(
alg::LUFactorization, A::AbstractSparseArray{<:Number, <:Integer}, b, u,
Pl, Pr,
Expand All @@ -98,6 +113,8 @@ function LinearSolve.init_cacheval(
nothing
end

@static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(
alg::LUFactorization, A::AbstractSparseArray{Float64, Int64}, b, u,
Pl, Pr,
Expand Down Expand Up @@ -132,6 +149,8 @@ function LinearSolve.init_cacheval(
end
end

end # @static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(
alg::LUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
Pl, Pr,
Expand All @@ -140,14 +159,6 @@ function LinearSolve.init_cacheval(
ArrayInterface.lu_instance(A)
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64, Int}, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_UMFPACK
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
Pl, Pr,
Expand All @@ -156,6 +167,16 @@ function LinearSolve.init_cacheval(
nothing
end

@static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::AbstractSparseArray{Float64, Int}, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_UMFPACK
end

function LinearSolve.init_cacheval(
alg::UMFPACKFactorization, A::AbstractSparseArray{T, Int64}, b, u,
Pl, Pr,
Expand Down Expand Up @@ -211,8 +232,14 @@ function SciMLBase.solve!(
end
end

const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
Float64[]))
else

function SciMLBase.solve!(
cache::LinearSolve.LinearCache, alg::UMFPACKFactorization; kwargs...)
error("UMFPACKFactorization requires GPL libraries (UMFPACK). Rebuild Julia with USE_GPL_LIBS=1 or use an alternative algorithm like SparspakFactorization")
end

end # @static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractArray, b, u, Pl,
Expand All @@ -222,14 +249,6 @@ function LinearSolve.init_cacheval(
nothing
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int64}, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_KLU
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::LinearSolve.GPUArraysCore.AnyGPUArray, b, u,
Pl, Pr,
Expand All @@ -238,6 +257,19 @@ function LinearSolve.init_cacheval(
nothing
end

@static if Base.USE_GPL_LIBS

const PREALLOCATED_KLU = KLU.KLUFactorization(SparseMatrixCSC(0, 0, [1], Int[],
Float64[]))

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int64}, b, u, Pl, Pr,
maxiters::Int, abstol,
reltol,
verbose::Bool, assumptions::OperatorAssumptions)
PREALLOCATED_KLU
end

function LinearSolve.init_cacheval(
alg::KLUFactorization, A::AbstractSparseArray{Float64, Int32}, b, u, Pl, Pr,
maxiters::Int, abstol,
Expand All @@ -247,7 +279,6 @@ function LinearSolve.init_cacheval(
0, 0, [Int32(1)], Int32[], Float64[]))
end

# TODO: guard this against errors
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...)
A = cache.A
A = convert(AbstractMatrix, A)
Expand Down Expand Up @@ -282,6 +313,24 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization;
end
end

else

function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::KLUFactorization; kwargs...)
error("KLUFactorization requires GPL libraries (KLU/SuiteSparse). Rebuild Julia with USE_GPL_LIBS=1 or use an alternative algorithm like SparspakFactorization")
end

end # @static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
A::AbstractArray, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end

@static if Base.USE_GPL_LIBS

const PREALLOCATED_CHOLMOD = cholesky(sparse(reshape([1.0], 1, 1)))

function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
Expand All @@ -302,13 +351,7 @@ function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
cholesky(sparse(reshape([one(T)], 1, 1)))
end

function LinearSolve.init_cacheval(alg::CHOLMODFactorization,
A::AbstractArray, b, u,
Pl, Pr,
maxiters::Int, abstol, reltol,
verbose::Bool, assumptions::OperatorAssumptions)
nothing
end
end # @static if Base.USE_GPL_LIBS

function LinearSolve.init_cacheval(alg::NormalCholeskyFactorization,
A::Union{AbstractSparseArray{T}, LinearSolve.GPUArraysCore.AnyGPUArray,
Expand All @@ -321,39 +364,58 @@ end
# Specialize QR for the non-square case
# Missing ldiv! definitions: https://github.com/JuliaSparse/SparseArrays.jl/issues/242
function LinearSolve._ldiv!(x::Vector,
A::Union{QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::Vector)
A::Union{QR, LinearAlgebra.QRCompactWY}, b::Vector)
x .= A \ b
end

function LinearSolve._ldiv!(x::AbstractVector,
A::Union{QR, LinearAlgebra.QRCompactWY,
SparseArrays.SPQR.QRSparse,
SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
A::Union{QR, LinearAlgebra.QRCompactWY}, b::AbstractVector)
x .= A \ b
end

# Ambiguity removal
function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
A::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY},
b::AbstractVector)
(A \ b)
end
function LinearSolve._ldiv!(::SVector,
A::Union{LinearAlgebra.QR, LinearAlgebra.QRCompactWY},
b::SVector)
(A \ b)
end

@static if Base.USE_GPL_LIBS
# SPQR and CHOLMOD Factor support
function LinearSolve._ldiv!(x::Vector,
A::Union{SparseArrays.SPQR.QRSparse, SparseArrays.CHOLMOD.Factor}, b::Vector)
x .= A \ b
end

function LinearSolve._ldiv!(x::AbstractVector,
A::Union{SparseArrays.SPQR.QRSparse, SparseArrays.CHOLMOD.Factor}, b::AbstractVector)
x .= A \ b
end

function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, SparseArrays.SPQR.QRSparse},
b::AbstractVector)
(A \ b)
end
function LinearSolve._ldiv!(::SVector,
A::Union{SparseArrays.CHOLMOD.Factor, LinearAlgebra.QR,
LinearAlgebra.QRCompactWY, SparseArrays.SPQR.QRSparse},
A::Union{SparseArrays.CHOLMOD.Factor, SparseArrays.SPQR.QRSparse},
b::SVector)
(A \ b)
end
end # @static if Base.USE_GPL_LIBS

function LinearSolve.pattern_changed(fact, A::SparseArrays.SparseMatrixCSC)
!(SparseArrays.decrement(SparseArrays.getcolptr(A)) ==
fact.colptr && SparseArrays.decrement(SparseArrays.getrowval(A)) ==
fact.rowval)
end

@static if Base.USE_GPL_LIBS
function LinearSolve.defaultalg(
A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Ti}
Expand All @@ -367,6 +429,22 @@ function LinearSolve.defaultalg(
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
end
end
else
function LinearSolve.defaultalg(
A::AbstractSparseMatrixCSC{<:Union{Float64, ComplexF64}, Ti}, b,
assump::OperatorAssumptions{Bool}) where {Ti}
ext = Base.get_extension(LinearSolve, :LinearSolveSparspakExt)
if assump.issq && ext !== nothing
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.SparspakFactorization)
elseif !assump.issq
LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.QRFactorization)
elseif ext === nothing
error("SparspakFactorization required for sparse matrix LU without GPL libraries. Do `using Sparspak` to enable this functionality")
else
error("Unreachable reached. Please report this error with a reproducer.")
end
end
end # @static if Base.USE_GPL_LIBS

# SPQR Handling
function LinearSolve.init_cacheval(
Expand Down
18 changes: 15 additions & 3 deletions src/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -396,9 +396,17 @@ function algchoice_to_alg(alg::Symbol)
elseif alg === :SparspakFactorization
SparspakFactorization(throwerror = false)
elseif alg === :KLUFactorization
KLUFactorization()
@static if Base.USE_GPL_LIBS
KLUFactorization()
else
error("KLUFactorization requires GPL libraries. Rebuild Julia with USE_GPL_LIBS=1 or use a different algorithm")
end
elseif alg === :UMFPACKFactorization
UMFPACKFactorization()
@static if Base.USE_GPL_LIBS
UMFPACKFactorization()
else
error("UMFPACKFactorization requires GPL libraries. Rebuild Julia with USE_GPL_LIBS=1 or use a different algorithm")
end
elseif alg === :KrylovJL_GMRES
KrylovJL_GMRES()
elseif alg === :GenericLUFactorization
Expand All @@ -408,7 +416,11 @@ function algchoice_to_alg(alg::Symbol)
elseif alg === :BunchKaufmanFactorization
BunchKaufmanFactorization()
elseif alg === :CHOLMODFactorization
CHOLMODFactorization()
@static if Base.USE_GPL_LIBS
CHOLMODFactorization()
else
error("CHOLMODFactorization requires GPL libraries. Rebuild Julia with USE_GPL_LIBS=1 or use CholeskyFactorization instead")
end
elseif alg === :CholeskyFactorization
CholeskyFactorization()
elseif alg === :NormalCholeskyFactorization
Expand Down
Loading