diff --git a/src/host/construction.jl b/src/host/construction.jl index a8da22d2..454ff3d9 100644 --- a/src/host/construction.jl +++ b/src/host/construction.jl @@ -59,6 +59,8 @@ function _one(unit::T, x::AbstractGPUMatrix) where {T} m,n = size(x) m==n || throw(DimensionMismatch("multiplicative identity defined only for square matrices")) I = similar(x, T) + isempty(I) && return I + fill!(I, zero(T)) kernel = identity_kernel(get_backend(I)) kernel(I, m, unit; ndrange=m) diff --git a/test/testsuite/construction.jl b/test/testsuite/construction.jl index 0ed33933..50ae02e2 100644 --- a/test/testsuite/construction.jl +++ b/test/testsuite/construction.jl @@ -130,6 +130,10 @@ @test A isa AT{T,2} @test Array(A) == one(rand(T, 2, 2)) + A = one(AT(rand(T, 0, 0))) + @test A isa AT{T,2} + @test Array(A) == one(rand(T, 0, 0)) + A = oneunit(AT(rand(T, 2, 2))) @test A isa AT{T,2} @test Array(A) == oneunit(rand(T, 2, 2)) @@ -193,7 +197,7 @@ x1 = AT{Float32, 2}(I, (0, 3)) @test Array(x1) ≈ x - + copyto!(x1, I) @test Array(x1) ≈ x end