diff --git a/Project.toml b/Project.toml index 31c4bf0b..8557c2f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.20.0" +version = "7.20.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -24,7 +24,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" ArrayInterfaceBandedMatricesExt = "BandedMatrices" ArrayInterfaceBlockBandedMatricesExt = "BlockBandedMatrices" ArrayInterfaceCUDAExt = "CUDA" -ArrayInterfaceCUDSSExt = "CUDSS" +ArrayInterfaceCUDSSExt = ["CUDSS", "CUDA"] ArrayInterfaceChainRulesCoreExt = "ChainRulesCore" ArrayInterfaceChainRulesExt = "ChainRules" ArrayInterfaceGPUArraysCoreExt = "GPUArraysCore" @@ -39,7 +39,7 @@ Adapt = "4" BandedMatrices = "1" BlockBandedMatrices = "0.13" CUDA = "5" -CUDSS = "0.2, 0.3, 0.4" +CUDSS = "0.5, 0.6" ChainRules = "1" ChainRulesCore = "1" ChainRulesTestUtils = "1" diff --git a/ext/ArrayInterfaceCUDSSExt.jl b/ext/ArrayInterfaceCUDSSExt.jl index 01fb2395..e6a4908d 100644 --- a/ext/ArrayInterfaceCUDSSExt.jl +++ b/ext/ArrayInterfaceCUDSSExt.jl @@ -2,14 +2,18 @@ module ArrayInterfaceCUDSSExt using ArrayInterface using CUDSS +using CUDA function ArrayInterface.lu_instance(A::CUDSS.CuSparseMatrixCSR) ArrayInterface.LinearAlgebra.checksquare(A) - fact = CudssSolver(A, "G", 'F') T = eltype(A) - n = size(A,1) - x = CudssMatrix(T, n) - b = CudssMatrix(T, n) + n = size(A, 1) + + # Use standard CUDA types (CuVector) instead of deprecated CudssMatrix + x = CUDA.CuVector{T}(undef, n) + b = CUDA.CuVector{T}(undef, n) + + fact = CudssSolver(A, "G", 'F') cudss("analysis", fact, x, b) fact end