diff --git a/src/host/linalg.jl b/src/host/linalg.jl index 42952e57b..0cc40bf18 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -1,6 +1,6 @@ # integration with LinearAlgebra stdlib -using LinearAlgebra: MulAddMul, wrap +using LinearAlgebra: MulAddMul, wrap, diagm ## transpose and adjoint @@ -62,6 +62,42 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w copyto!(A, Transpose(Array(parent(B)))) end +## diagm + +LinearAlgebra.diagm(kv::Pair{<:Integer,<:AbstractGPUVector}...) = _gpu_diagm(nothing, kv...) +LinearAlgebra.diagm(m::Integer, n::Integer, kv::Pair{<:Integer,<:AbstractGPUVector}...) = _gpu_diagm((Int(m),Int(n)), kv...) +LinearAlgebra.diagm(v::AbstractGPUVector) = LinearAlgebra.diagm(0 => v) +LinearAlgebra.diagm(m::Integer, n::Integer, v::AbstractGPUVector) = LinearAlgebra.diagm(m, n, 0 => v) + +function _gpu_diagm(size, kv::Pair{<:Integer,<:AbstractGPUVector}...) + A = LinearAlgebra.diagm_container(size, kv...) + for p in kv + inds = LinearAlgebra.diagind(A, p.first) + copyto!(view(A, inds), p.second) + end + return A +end + +function LinearAlgebra.diagm_container(size, kv::Pair{<:Integer,<:AbstractGPUVector}...) + T = promote_type(map(x -> eltype(x.second), kv)...) + U = promote_type(T, typeof(zero(T))) + A = similar(kv[1].second, U, LinearAlgebra.diagm_size(size, kv...)...) + fill!(A, zero(U)) + return A +end + +function LinearAlgebra.diagm_size(size::Nothing, kv::Pair{<:Integer,<:AbstractGPUVector}...) + mnmax = mapreduce(x -> length(x.second) + abs(Int(x.first)), max, kv; init=0) + return mnmax, mnmax +end +function LinearAlgebra.diagm_size(size::Tuple{Int,Int}, kv::Pair{<:Integer,<:AbstractGPUVector}...) + mmax = mapreduce(x -> length(x.second) - min(0,Int(x.first)), max, kv; init=0) + nmax = mapreduce(x -> length(x.second) + max(0,Int(x.first)), max, kv; init=0) + m, n = size + (m ≥ mmax && n ≥ nmax) || throw(DimensionMismatch(lazy"invalid size=$size")) + return m, n +end + ## trace function LinearAlgebra.tr(A::AnyGPUMatrix) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index 613ceda2d..804912d82 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -321,6 +321,24 @@ end end + @testset "diagm" begin + @testset "$elty" for elty in (Float32, ComplexF32) + m = 128 + A = AT(rand(elty, m)) + B = AT(rand(elty, m - 1)) + diagA = diagm(A) + diagB = diagm(1 => B) + @test eltype(diagA) == elty + @test eltype(diagB) == elty + @test size(diagA) == (m, m) + @test size(diagB) == (m, m) + diagind_A = diagind(diagA, 0) + diagind_B = diagind(diagB, 1) + @test collect(diagA[diagind_A]) == collect(A) + @test collect(diagB[diagind_B]) == collect(B) + end + end + @testset "mul! + UniformScaling" begin @testset "$elty" for elty in (Float32, ComplexF32) if !(elty in eltypes)