Skip to content

Commit

Permalink
Sparsity-preserving outer products (#24980)
Browse files Browse the repository at this point in the history
* Add indtype and nnz definitions for SparseColumnView

* Handle sparse outer products specially in broadcast

* Add specialized kron for sparse outer products

* Add tests

* Support unitful types

* Address review comments.

* Change is_specialcase_sparse_broadcast -> can_skip_sparsification.
* Lift parent(y) to one function earlier for clarify

* Simply call _copy instead of passing through the broadcast machinery again
  • Loading branch information
jmert authored and andreasnoack committed Jan 7, 2019
1 parent f8f2045 commit dffe119
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 3 deletions.
54 changes: 51 additions & 3 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import Base: map, map!, broadcast, copy, copyto!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange,
SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
using LinearAlgebra

Expand Down Expand Up @@ -92,6 +93,9 @@ is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_suppor
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)

can_skip_sparsification(f, rest...) = false
can_skip_sparsification(::typeof(*), ::SparseVectorUnion, ::AdjOrTransSparseVectorUnion) = true

# Dispatch on broadcast operations by number of arguments
const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} =
Broadcasted{Style,Axes,F,Tuple{}}
Expand Down Expand Up @@ -810,6 +814,48 @@ end
_finishempty!(C::SparseVector) = C
_finishempty!(C::SparseMatrixCSC) = (fill!(C.colptr, 1); C)

# special case - vector outer product
_copy(f::typeof(*), x::SparseVectorUnion, y::AdjOrTransSparseVectorUnion) = _outer(x, y)
@inline _outer(x::SparseVectorUnion, y::Adjoint) = return _outer(conj, x, parent(y))
@inline _outer(x::SparseVectorUnion, y::Transpose) = return _outer(identity, x, parent(y))
function _outer(trans::Tf, x, y) where Tf
nx = length(x)
ny = length(y)
rowvalx = nonzeroinds(x)
rowvaly = nonzeroinds(y)
nzvalsx = nonzeros(x)
nzvalsy = nonzeros(y)
nnzx = length(nzvalsx)
nnzy = length(nzvalsy)

nnzC = nnzx * nnzy
Tv = typeof(oneunit(eltype(x)) * oneunit(eltype(y)))
Ti = promote_type(indtype(x), indtype(y))
colptrC = zeros(Ti, ny + 1)
rowvalC = Vector{Ti}(undef, nnzC)
nzvalsC = Vector{Tv}(undef, nnzC)

idx = 0
@inbounds colptrC[1] = 1
@inbounds for jj = 1:nnzy
yval = nzvalsy[jj]
iszero(yval) && continue
col = rowvaly[jj]
yval = trans(yval)

for ii = 1:nnzx
xval = nzvalsx[ii]
iszero(xval) && continue
idx += 1
colptrC[col+1] += 1
rowvalC[idx] = rowvalx[ii]
nzvalsC[idx] = xval * yval
end
end
cumsum!(colptrC, colptrC)

return SparseMatrixCSC(nx, ny, colptrC, rowvalC, nzvalsC)
end

# (9) _broadcast_zeropres!/_broadcast_notzeropres! for more than two (input) sparse vectors/matrices
function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N}) where {Tf,N}
Expand Down Expand Up @@ -1079,8 +1125,10 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(

function copy(bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
if is_supported_sparse_broadcast(bcf.args...)
broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
if can_skip_sparsification(bcf.f, bcf.args...)
return _copy(bcf.f, bcf.args...)
elseif is_supported_sparse_broadcast(bcf.args...)
return _copy(bcf.f, map(_sparsifystructured, bcf.args)...)
else
return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
end
Expand Down
3 changes: 3 additions & 0 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,9 @@ kron(x::SparseVector, A::SparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
kron(A::Union{SparseVector,SparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B))
kron(A::VecOrMat, B::Union{SparseVector,SparseMatrixCSC}) = kron(sparse(A), B)

# sparse outer product
kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B

## det, inv, cond

inv(A::SparseMatrixCSC) = error("The inverse of a sparse matrix can often be dense and can cause the computer to run out of memory. If you are sure you have enough memory, please convert your matrix to a dense matrix.")
Expand Down
6 changes: 6 additions & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
# union of such a view and a SparseVector so we define an alias for such a union as well
const SparseColumnView{T} = SubArray{T,1,<:SparseMatrixCSC,Tuple{Base.Slice{Base.OneTo{Int}},Int},false}
const SparseVectorUnion{T} = Union{SparseVector{T}, SparseColumnView{T}}
const AdjOrTransSparseVectorUnion{T} = LinearAlgebra.AdjOrTrans{T, <:SparseVectorUnion{T}}

### Basic properties

Expand All @@ -58,6 +59,11 @@ function nonzeroinds(x::SparseColumnView)
return y
end

indtype(x::SparseColumnView) = indtype(parent(x))
function nnz(x::SparseColumnView)
rowidx, colidx = parentindices(x)
return length(nzrange(parent(x), colidx))
end

## similar
#
Expand Down
15 changes: 15 additions & 0 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -656,4 +656,19 @@ using SparseArrays.HigherOrderFns: SparseVecStyle
@test occursin("no method matching _copy(::typeof(rand))", sprint(showerror, err))
end

@testset "Sparse outer product, for type $T and vector $op" for
op in (transpose, adjoint),
T in (Float64, ComplexF64)
m, n, p = 100, 250, 0.1
A = sprand(T, m, n, p)
a, b = view(A, :, 1), sprand(T, m, p)
av, bv = Vector(a), Vector(b)
v = @inferred a .* op(b)
w = @inferred b .* op(a)
@test issparse(v)
@test issparse(w)
@test v == av .* op(bv)
@test w == bv .* op(av)
end

end # module
6 changes: 6 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ end
for (m,n) in ((5,10), (13,8), (14,10))
a = sprand(m, 5, 0.4); a_d = Matrix(a)
b = sprand(n, 6, 0.3); b_d = Matrix(b)
v = view(a, :, 1); v_d = Vector(v)
x = sprand(m, 0.4); x_d = Vector(x)
y = sprand(n, 0.3); y_d = Vector(y)
# mat ⊗ mat
Expand All @@ -370,6 +371,11 @@ end
@test Array(kron(x, b)) == kron(x_d, b_d)
@test Array(kron(x_d, b)) == kron(x_d, b_d)
@test Array(kron(x, b_d)) == kron(x_d, b_d)
# vec ⊗ vec'
@test issparse(kron(v, y'))
@test issparse(kron(x, y'))
@test Array(kron(v, y')) == kron(v_d, y_d')
@test Array(kron(x, y')) == kron(x_d, y_d')
# test different types
z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z)
@test Vector(kron(x, z)) == kron(x_d, z_d)
Expand Down

0 comments on commit dffe119

Please sign in to comment.