diff --git a/Project.toml b/Project.toml
index 8cdff5c..72afd49 100644
--- a/Project.toml
+++ b/Project.toml
@@ -18,25 +18,28 @@ Tensors = "48a634ad-e948-5137-8d70-aa71f2a747f4"
 
 [weakdeps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 Exodus = "f57ae99e-f805-4780-bdca-96e224be1e5a"
 
 [extensions]
 FiniteElementContainersAdaptExt = "Adapt"
+FiniteElementContainersAMDGPUExt = ["Adapt", "AMDGPU"]
 FiniteElementContainersCUDAExt = ["Adapt", "CUDA"]
 FiniteElementContainersExodusExt = "Exodus"
 
 [compat]
 AcceleratedKernels = "0.3"
 Adapt = "3, 4"
+AMDGPU = "1"
 Aqua = "0.8"
 Atomix = "1"
 CUDA = "5"
 DocStringExtensions = "0.9"
 Exodus = "0.13"
 JET = "0.9"
-Krylov = "0.9"
 KernelAbstractions = "0.9"
+Krylov = "0.9"
 LinearAlgebra = "1"
 Parameters = "0.12"
 Reexport = "1"
diff --git a/ext/FiniteElementContainersAMDGPUExt.jl b/ext/FiniteElementContainersAMDGPUExt.jl
new file mode 100644
index 0000000..91f3d30
--- /dev/null
+++ b/ext/FiniteElementContainersAMDGPUExt.jl
@@ -0,0 +1,46 @@
+# AMDGPU (ROCm) backend extension for FiniteElementContainers, loaded via the
+# `FiniteElementContainersAMDGPUExt = ["Adapt", "AMDGPU"]` entry in Project.toml.
+module FiniteElementContainersAMDGPUExt
+
+using Adapt
+using AMDGPU
+using FiniteElementContainers
+using KernelAbstractions
+
+# NOTE(review): this method is untyped in `x`; if another GPU extension defines
+# the same `gpu(x)` signature, loading both would overwrite one — TODO confirm.
+"""
+    FiniteElementContainers.gpu(x)
+
+Adapt `x` to `ROCArray` storage through `Adapt.adapt_structure`.
+"""
+FiniteElementContainers.gpu(x) = adapt_structure(ROCArray, x)
+
+"""
+    AMDGPU.rocSPARSE.ROCSparseMatrixCSC(asm::SparseMatrixAssembler)
+
+Build a `ROCSparseMatrixCSC` from the CSC arrays stored in `asm.pattern`.
+
+The result is square, with both dimensions equal to
+`FiniteElementContainers.num_unknowns(asm.dof)`.
+
+# Throws
+- `AssertionError` if the assembler's backend is not a `ROCBackend`, or if
+  `asm.pattern.cscnzval` is empty or contains a zero entry.
+"""
+function AMDGPU.rocSPARSE.ROCSparseMatrixCSC(asm::SparseMatrixAssembler)
+  # `ROCBackend` is the KernelAbstractions backend type exported by AMDGPU.
+  @assert get_backend(asm) isa ROCBackend "Assembler is not on an AMDGPU device"
+  nzval = asm.pattern.cscnzval
+  @assert length(nzval) > 0 "Need to assemble the assembler once with SparseArrays.sparse!(assembler)"
+  @assert all(x -> x != zero(eltype(nzval)), nzval) "Need to assemble the assembler once with SparseArrays.sparse!(assembler)"
+  n_dofs = FiniteElementContainers.num_unknowns(asm.dof)
+  return AMDGPU.rocSPARSE.ROCSparseMatrixCSC(
+    asm.pattern.csccolptr,
+    asm.pattern.cscrowval,
+    nzval,
+    (n_dofs, n_dofs)
+  )
+end
+
+end # module