62 changes: 43 additions & 19 deletions ext/LinearSolveCUDAExt.jl
@@ -5,9 +5,11 @@ using LinearSolve: LinearSolve, is_cusparse, defaultalg, cudss_loaded, DefaultLi
                    DefaultAlgorithmChoice, ALREADY_WARNED_CUDSS, LinearCache,
                    needs_concrete_A,
                    error_no_cudss_lu, init_cacheval, OperatorAssumptions,
-                   CudaOffloadFactorization, CudaOffloadLUFactorization, CudaOffloadQRFactorization,
+                   CudaOffloadFactorization, CudaOffloadLUFactorization,
+                   CudaOffloadQRFactorization,
                    CUDAOffload32MixedLUFactorization,
-                   SparspakFactorization, KLUFactorization, UMFPACKFactorization, LinearVerbosity
+                   SparspakFactorization, KLUFactorization, UMFPACKFactorization,
+                   LinearVerbosity
 using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterface
 using SciMLBase: AbstractSciMLOperator
 
@@ -23,11 +25,16 @@ function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSR{Tv, Ti}, b,
     if LinearSolve.cudss_loaded(A)
         LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
     else
-        if !LinearSolve.ALREADY_WARNED_CUDSS[]
-            @warn("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library. Falling back to Krylov")
-            LinearSolve.ALREADY_WARNED_CUDSS[] = true
-        end
-        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.KrylovJL_GMRES)
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSR. Please load this library.")
     end
 end
+
+function LinearSolve.defaultalg(A::CUDA.CUSPARSE.CuSparseMatrixCSC{Tv, Ti}, b,
+        assump::OperatorAssumptions{Bool}) where {Tv, Ti}
+    if LinearSolve.cudss_loaded(A)
+        LinearSolve.DefaultLinearSolver(LinearSolve.DefaultAlgorithmChoice.LUFactorization)
+    else
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
+    end
+end
 
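Behavior change worth calling out: a CuSparseMatrixCSR default-algorithm lookup now throws when CUDSS.jl is not loaded, where it previously warned once and silently fell back to KrylovJL_GMRES, and the same contract is extended to CuSparseMatrixCSC. A minimal usage sketch, not part of the diff; it assumes LinearSolve.jl's standard LinearProblem/solve entry points and CUDA.jl's CUSPARSE constructors:

using LinearSolve, SparseArrays, LinearAlgebra
using CUDA, CUDA.CUSPARSE
using CUDSS  # must be loaded up front now; without it, defaultalg throws instead of warning

A = CuSparseMatrixCSR(sprand(Float64, 100, 100, 0.05) + 10.0I)  # diagonally shifted, nonsingular
b = CUDA.rand(Float64, 100)
prob = LinearProblem(A, b)
sol = solve(prob)  # default choice resolves to LUFactorization via CUDSS
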
@@ -38,6 +45,13 @@ function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSR)
     nothing
 end
 
+function LinearSolve.error_no_cudss_lu(A::CUDA.CUSPARSE.CuSparseMatrixCSC)
+    if !LinearSolve.cudss_loaded(A)
+        error("CUDSS.jl is required for LU Factorizations on CuSparseMatrixCSC. Please load this library.")
+    end
+    nothing
+end
+
 function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFactorization;
         kwargs...)
     if cache.isfresh
@@ -52,14 +66,15 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadLUFact
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadLUFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     # Check if CUDA is functional before creating CUDA arrays
     if !CUDA.functional()
         return nothing
     end
 
     T = eltype(A)
     noUnitT = typeof(zero(T))
     luT = LinearAlgebra.lutype(noUnitT)
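The CUDA.functional() check above lets init_cacheval return nothing on machines without a working GPU rather than failing while the cache is being built. A hedged sketch of guarding an explicit offload-algorithm choice the same way; it assumes the factorization types are exported (otherwise qualify them as LinearSolve.CudaOffloadLUFactorization):

using LinearSolve, CUDA

A = rand(512, 512)
b = rand(512)
if CUDA.functional()
    # dense LU runs on the GPU; the CPU arrays are copied to CuArrays internally
    sol = solve(LinearProblem(A, b), CudaOffloadLUFactorization())
else
    sol = solve(LinearProblem(A, b))  # fall back to a CPU default
end
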
@@ -87,7 +102,7 @@ function LinearSolve.init_cacheval(alg::CudaOffloadQRFactorization, A, b, u, Pl,
     if !CUDA.functional()
         return nothing
     end
 
     qr(CUDA.CuArray(A))
 end
 
@@ -104,35 +119,42 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
     SciMLBase.build_linear_solution(alg, y, nothing, cache)
 end
 
-function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
+function LinearSolve.init_cacheval(
+        alg::CudaOffloadFactorization, A::AbstractArray, b, u, Pl, Pr,
         maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool},
         assumptions::OperatorAssumptions)
     qr(CUDA.CuArray(A))
 end
 
 function LinearSolve.init_cacheval(
         ::SparspakFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::KLUFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 function LinearSolve.init_cacheval(
         ::UMFPACKFactorization, A::CUDA.CUSPARSE.CuSparseMatrixCSR, b, u,
-        Pl, Pr, maxiters::Int, abstol, reltol, verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
+        Pl, Pr, maxiters::Int, abstol, reltol,
+        verbose::Union{LinearVerbosity, Bool}, assumptions::OperatorAssumptions)
     nothing
 end
 
 # Mixed precision CUDA LU implementation
-function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
+function SciMLBase.solve!(
+        cache::LinearSolve.LinearCache, alg::CUDAOffload32MixedLUFactorization;
         kwargs...)
     if cache.isfresh
-        fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+        fact, A_gpu_f32,
+        b_gpu_f32,
+        u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
         # Compute 32-bit type on demand and convert
         T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
         A_f32 = T32.(cache.A)
@@ -141,12 +163,14 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CUDAOffload32Mixe
         cache.cacheval = (fact, A_gpu_f32, b_gpu_f32, u_gpu_f32)
         cache.isfresh = false
     end
-    fact, A_gpu_f32, b_gpu_f32, u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
+    fact, A_gpu_f32,
+    b_gpu_f32,
+    u_gpu_f32 = LinearSolve.@get_cacheval(cache, :CUDAOffload32MixedLUFactorization)
 
     # Compute types on demand for conversions
     T32 = eltype(cache.A) <: Complex ? ComplexF32 : Float32
     Torig = eltype(cache.u)
 
     # Convert b to Float32, solve, then convert back to original precision
     b_f32 = T32.(cache.b)
     copyto!(b_gpu_f32, b_f32)
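
The mixed-precision path factorizes and solves in Float32 on the GPU and converts the right-hand side and solution at the boundaries, trading a little accuracy for roughly half the memory traffic of a Float64 factorization. A hedged usage sketch; the algorithm name comes from the diff, the problem setup is illustrative:

using LinearSolve, LinearAlgebra, CUDA

A = rand(Float64, 1024, 1024) + 1024.0I  # well conditioned, so Float32 factors lose little accuracy
b = rand(Float64, 1024)
sol = solve(LinearProblem(A, b), CUDAOffload32MixedLUFactorization())
eltype(sol.u)  # Float64: the solution is converted back to the original precision
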
1 change: 1 addition & 0 deletions ext/LinearSolveCUDSSExt.jl
@@ -4,5 +4,6 @@ using LinearSolve: LinearSolve, cudss_loaded
 using CUDSS
 
 LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSR) = true
+LinearSolve.cudss_loaded(A::CUDSS.CUDA.CUSPARSE.CuSparseMatrixCSC) = true
 
 end
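
With this one-line addition, loading CUDSS.jl marks CSC sparse GPU matrices as CUDSS-capable alongside CSR, so the default algorithm resolves to LU for them instead of erroring. A hedged sketch, assuming the same entry points as in the earlier examples:

using LinearSolve, SparseArrays, LinearAlgebra
using CUDA, CUDA.CUSPARSE, CUDSS

A = CuSparseMatrixCSC(sprand(Float64, 100, 100, 0.05) + 10.0I)
b = CUDA.rand(Float64, 100)
sol = solve(LinearProblem(A, b))  # cudss_loaded(A) is now true for CSC, so LU is chosen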