diff --git a/stdlib/SparseArrays/src/higherorderfns.jl b/stdlib/SparseArrays/src/higherorderfns.jl
index 63d318af434f9..66ab8b3f60e52 100644
--- a/stdlib/SparseArrays/src/higherorderfns.jl
+++ b/stdlib/SparseArrays/src/higherorderfns.jl
@@ -8,7 +8,8 @@ import Base: map, map!, broadcast, copy, copyto!
 using Base: front, tail, to_shape
 using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
-    AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
+    AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange,
+    SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros
 using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
 using LinearAlgebra
@@ -92,6 +93,9 @@ is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_suppor
 is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
 is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)

+can_skip_sparsification(f, rest...) = false
+can_skip_sparsification(::typeof(*), ::SparseVectorUnion, ::AdjOrTransSparseVectorUnion) = true
+
 # Dispatch on broadcast operations by number of arguments
 const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} = Broadcasted{Style,Axes,F,Tuple{}}
@@ -810,6 +814,48 @@ end
 _finishempty!(C::SparseVector) = C
 _finishempty!(C::SparseMatrixCSC) = (fill!(C.colptr, 1); C)

+# special case - vector outer product
+_copy(f::typeof(*), x::SparseVectorUnion, y::AdjOrTransSparseVectorUnion) = _outer(x, y)
+@inline _outer(x::SparseVectorUnion, y::Adjoint) = return _outer(conj, x, parent(y))
+@inline _outer(x::SparseVectorUnion, y::Transpose) = return _outer(identity, x, parent(y))
+function _outer(trans::Tf, x, y) where Tf
+    nx = length(x)
+    ny = length(y)
+    rowvalx = nonzeroinds(x)
+    rowvaly = nonzeroinds(y)
+    nzvalsx = nonzeros(x)
+    nzvalsy = nonzeros(y)
+    nnzx = length(nzvalsx)
+    nnzy = length(nzvalsy)
+
+    nnzC = nnzx * nnzy
+    Tv = typeof(oneunit(eltype(x)) * oneunit(eltype(y)))
+    Ti = promote_type(indtype(x), indtype(y))
+    colptrC = zeros(Ti, ny + 1)
+    rowvalC = Vector{Ti}(undef, nnzC)
+    nzvalsC = Vector{Tv}(undef, nnzC)
+
+    idx = 0
+    @inbounds colptrC[1] = 1
+    @inbounds for jj = 1:nnzy
+        yval = nzvalsy[jj]
+        iszero(yval) && continue
+        col = rowvaly[jj]
+        yval = trans(yval)
+
+        for ii = 1:nnzx
+            xval = nzvalsx[ii]
+            iszero(xval) && continue
+            idx += 1
+            colptrC[col+1] += 1
+            rowvalC[idx] = rowvalx[ii]
+            nzvalsC[idx] = xval * yval
+        end
+    end
+    cumsum!(colptrC, colptrC)
+
+    return SparseMatrixCSC(nx, ny, colptrC, rowvalC, nzvalsC)
+end

 # (9) _broadcast_zeropres!/_broadcast_notzeropres! for more than two (input) sparse vectors/matrices
 function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N}) where {Tf,N}
@@ -1079,8 +1125,10 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(
 function copy(bc::Broadcasted{PromoteToSparse})
     bcf = flatten(bc)
-    if is_supported_sparse_broadcast(bcf.args...)
-        broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
+    if can_skip_sparsification(bcf.f, bcf.args...)
+        return _copy(bcf.f, bcf.args...)
+    elseif is_supported_sparse_broadcast(bcf.args...)
+        return _copy(bcf.f, map(_sparsifystructured, bcf.args)...)
     else
         return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
     end
diff --git a/stdlib/SparseArrays/src/linalg.jl b/stdlib/SparseArrays/src/linalg.jl
index 0c7fea10b32ac..c0d352bafcaa7 100644
--- a/stdlib/SparseArrays/src/linalg.jl
+++ b/stdlib/SparseArrays/src/linalg.jl
@@ -1198,6 +1198,9 @@ kron(x::SparseVector, A::SparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
 kron(A::Union{SparseVector,SparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B))
 kron(A::VecOrMat, B::Union{SparseVector,SparseMatrixCSC}) = kron(sparse(A), B)

+# sparse outer product
+kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B
+
 ## det, inv, cond
 inv(A::SparseMatrixCSC) = error("The inverse of a sparse matrix can often be dense and can cause the computer to run out of memory. If you are sure you have enough memory, please convert your matrix to a dense matrix.")
diff --git a/stdlib/SparseArrays/src/sparsevector.jl b/stdlib/SparseArrays/src/sparsevector.jl
index ea46ebd213c2d..305e93ee1a0bb 100644
--- a/stdlib/SparseArrays/src/sparsevector.jl
+++ b/stdlib/SparseArrays/src/sparsevector.jl
@@ -34,6 +34,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
 # union of such a view and a SparseVector so we define an alias for such a union as well
 const SparseColumnView{T} = SubArray{T,1,<:SparseMatrixCSC,Tuple{Base.Slice{Base.OneTo{Int}},Int},false}
 const SparseVectorUnion{T} = Union{SparseVector{T}, SparseColumnView{T}}
+const AdjOrTransSparseVectorUnion{T} = LinearAlgebra.AdjOrTrans{T, <:SparseVectorUnion{T}}

 ### Basic properties
@@ -58,6 +59,11 @@ function nonzeroinds(x::SparseColumnView)
     return y
 end

+indtype(x::SparseColumnView) = indtype(parent(x))
+function nnz(x::SparseColumnView)
+    rowidx, colidx = parentindices(x)
+    return length(nzrange(parent(x), colidx))
+end

 ## similar
 #
diff --git a/stdlib/SparseArrays/test/higherorderfns.jl b/stdlib/SparseArrays/test/higherorderfns.jl
index 5b1423582f0d5..5d353f431cb38 100644
--- a/stdlib/SparseArrays/test/higherorderfns.jl
+++ b/stdlib/SparseArrays/test/higherorderfns.jl
@@ -656,4 +656,19 @@ using SparseArrays.HigherOrderFns: SparseVecStyle
     @test occursin("no method matching _copy(::typeof(rand))", sprint(showerror, err))
 end

+@testset "Sparse outer product, for type $T and vector $op" for
+        op in (transpose, adjoint),
+        T in (Float64, ComplexF64)
+    m, n, p = 100, 250, 0.1
+    A = sprand(T, m, n, p)
+    a, b = view(A, :, 1), sprand(T, m, p)
+    av, bv = Vector(a), Vector(b)
+    v = @inferred a .* op(b)
+    w = @inferred b .* op(a)
+    @test issparse(v)
+    @test issparse(w)
+    @test v == av .* op(bv)
+    @test w == bv .* op(av)
+end
+
 end # module
diff --git a/stdlib/SparseArrays/test/sparse.jl b/stdlib/SparseArrays/test/sparse.jl
index dfb9aa9a284fa..1c8bf76e166e5 100644
--- a/stdlib/SparseArrays/test/sparse.jl
+++ b/stdlib/SparseArrays/test/sparse.jl
@@ -352,6 +352,7 @@ end
     for (m,n) in ((5,10), (13,8), (14,10))
         a = sprand(m, 5, 0.4); a_d = Matrix(a)
         b = sprand(n, 6, 0.3); b_d = Matrix(b)
+        v = view(a, :, 1); v_d = Vector(v)
         x = sprand(m, 0.4); x_d = Vector(x)
         y = sprand(n, 0.3); y_d = Vector(y)
         # mat ⊗ mat
@@ -370,6 +371,11 @@ end
         @test Array(kron(x, b)) == kron(x_d, b_d)
         @test Array(kron(x_d, b)) == kron(x_d, b_d)
        @test Array(kron(x, b_d)) == kron(x_d, b_d)
+        # vec ⊗ vec'
+        @test issparse(kron(v, y'))
+        @test issparse(kron(x, y'))
+        @test Array(kron(v, y')) == kron(v_d, y_d')
+        @test Array(kron(x, y')) == kron(x_d, y_d')
         # test different types
         z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z)
         @test Vector(kron(x, z)) == kron(x_d, z_d)
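
Usage sketch (illustrative, not part of the patch): with this change, broadcasting * between a sparse vector and the adjoint or transpose of a sparse vector takes the specialized _outer path and returns a SparseMatrixCSC directly, and kron of such a pair forwards to the same broadcast. The snippet below assumes a build with this patch applied; the variable names are hypothetical.

    using SparseArrays, LinearAlgebra

    x = sprand(5, 0.5)                # sparse vector with ~50% stored entries
    y = sprand(4, 0.5)

    M = x .* y'                       # sparse 5x4 outer product via the new _outer kernel
    issparse(M)                       # true: the result is a SparseMatrixCSC
    M == Vector(x) * Vector(y)'       # matches the dense outer product
    kron(x, y') == M                  # kron(vec, vec') reuses the same sparse path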