diff --git a/ext/LinearSolveCUDAExt.jl b/ext/LinearSolveCUDAExt.jl
index 58e2c3444..e446094e1 100644
--- a/ext/LinearSolveCUDAExt.jl
+++ b/ext/LinearSolveCUDAExt.jl
@@ -5,9 +5,11 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
     DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache, needs_concrete_A,
     error_no_cudss_lu, init_cacheval, OperatorAssumptions,
-    CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
+    CudaOffloadFactorization, CudaOffloadLUFactorization,
+    CudaOffloadQRFactorization,
     CUDAOffload32MixedLUFactorization,
-    SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
+    SparspakFactorization, KLUFactorization, UMFPACKFactorization,
+    LinearVerbosity
 
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
 using SciMLBase: AbstractSciMLOperator
 
@@ -23,11 +25,16 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
     if LinearSolve.cudss_loaded(A)
         LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
     else
-        if !LinearSolve.ALREADY_WARNED_CUDSS[]
-            @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
-            LinearSolve.ALREADY_WARNED_CUDSS[] = true
-        end
-        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
+    end
+end
+
+function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
+        assump::OperatorAssumptions{Bool}) where {Tv, Ti}
+    if LinearSolve.cudss_loaded(A)
+        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
+    else
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
     end
 end
 
@@ -38,6 +45,13 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
     nothing
 end
 
+function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
+    if !LinearSolve.cudss_loaded(A)
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
+    end
+    nothing
+end
+
 function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
         kwargs...)
     if cache.isfresh
@@ -52,14 +66,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     # Check if CUDA is functional before creating CUDA arrays
     if !CUDA.functional()
         return nothing
     end
-    
+
     T = eltype(A)
     noUnitT = typeof(zero(T))
     luT = LinearAlgebra.lutype(noUnitT)
@@ -87,7 +102,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
     if !CUDA.functional()
         return nothing
     end
-    
+
     qr(CUDA.CuArray(A))
 end
 
@@ -104,7 +119,8 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     qr(CUDA.CuArray(A))
@@ -112,27 +128,33 @@ end
 
 function LinearSolve.init_cacheval(
         ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 # Mixed precision CUDA LU implementation
-function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
         kwargs...)
     if cache.isfresh
-        fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        fact, A_gpu_f32,
+        b_gpu_f32,
+        u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
         # Compute 32-bit type on demand and convert
         T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
         A_f32 = T32.(cache.A)
@@ -141,12 +163,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32Mixe
         cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
         cache.isfresh = false
     end
-    fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
-    
+    fact, A_gpu_f32,
+    b_gpu_f32,
+    u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+
     # Compute types on demand for conversions
     T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
     Torig = eltype(cache.u)
-    
+
     # Convert b to Float32, solve, then convert back to original precision
     b_f32 = T32.(cache.b)
     copyto!(b_gpu_f32, b_f32)
diff --git a/ext/LinearSolveCUDSSExt.jl b/ext/LinearSolveCUDSSExt.jl
index 506ada99a..bbcf635f0 100644
--- a/ext/LinearSolveCUDSSExt.jl
+++ b/ext/LinearSolveCUDSSExt.jl
@@ -4,5 +4,6 @@ using LinearSolve: LinearSolve, cudss_loaded
 using CUDSS
 
 LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSR) = true
+LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSC) = true
 
 end
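
User-facing note: with this patch, `defaultalg` on a `CuSparseMatrixCSR`/`CuSparseMatrixCSC` errors when CUDSS.jl is not loaded, instead of warning once and silently falling back to GMRES. A minimal sketch of how that surfaces (illustrative only; the matrix/vector setup below is hypothetical and not taken from the patch or its tests):

```julia
using LinearSolve, LinearAlgebra, SparseArrays
using CUDA, CUDA.CUSPARSE

A = CuSparseMatrixCSC(sprand(100, 100, 0.05) + 10.0I)  # hypothetical sparse GPU matrix
b = CUDA.rand(Float64, 100)
prob = LinearProblem(A, b)

# Without CUDSS loaded, this now throws
# "CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library."
# rather than warning and falling back to KrylovJL_GMRES.
# solve(prob)

using CUDSS  # loading the extension defines cudss_loaded(::CuSparseMatrixCSC) = true
sol = solve(prob)  # default algorithm now resolves to DefaultAlgorithmChoice.LUFactorization
```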
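For reference, the `CUDAOffload32MixedLUFactorization` hunks above only reformat the existing mixed-precision path; the `T32`/`Torig` logic is unchanged. A simplified standalone sketch of that pattern, under the assumption that dense `lu` and `\` on `CuArray` are available (the real implementation instead caches the GPU arrays and factorization in `cache.cacheval`; `offload32_solve` is a hypothetical helper, not part of LinearSolve):

```julia
using LinearAlgebra, CUDA

function offload32_solve(A::AbstractMatrix, b::AbstractVector)
    # Demote to a 32-bit element type, mirroring the diff's T32 computation
    T32 = eltype(A) <: Complex ? ComplexF32 : Float32
    Torig = eltype(b)
    A_gpu = CUDA.CuArray(T32.(A))
    b_gpu = CUDA.CuArray(T32.(b))
    fact = lu(A_gpu)             # factorize on the GPU in reduced precision
    u_gpu = fact \ b_gpu         # triangular solves in Float32/ComplexF32
    return Torig.(Array(u_gpu))  # promote the solution back to the original eltype
end
```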