Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparsity-preserving outer products #24980

Merged
merged 7 commits into from
Jan 7, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 51 additions & 3 deletions stdlib/SparseArrays/src/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ import Base: map, map!, broadcast, copy, copyto!

using Base: front, tail, to_shape
using ..SparseArrays: SparseVector, SparseMatrixCSC, AbstractSparseVector,
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange
AbstractSparseMatrix, AbstractSparseArray, indtype, nnz, nzrange,
SparseVectorUnion, AdjOrTransSparseVectorUnion, nonzeroinds, nonzeros
using Base.Broadcast: BroadcastStyle, Broadcasted, flatten
using LinearAlgebra

Expand Down Expand Up @@ -92,6 +93,9 @@ is_supported_sparse_broadcast(t::Union{Transpose, Adjoint}, rest...) = is_suppor
is_supported_sparse_broadcast(x, rest...) = axes(x) === () && is_supported_sparse_broadcast(rest...)
is_supported_sparse_broadcast(x::Ref, rest...) = is_supported_sparse_broadcast(rest...)

can_skip_sparsification(f, rest...) = false
can_skip_sparsification(::typeof(*), ::SparseVectorUnion, ::AdjOrTransSparseVectorUnion) = true

# Dispatch on broadcast operations by number of arguments
const Broadcasted0{Style<:Union{Nothing,BroadcastStyle},Axes,F} =
Broadcasted{Style,Axes,F,Tuple{}}
Expand Down Expand Up @@ -810,6 +814,48 @@ end
_finishempty!(C::SparseVector) = C
_finishempty!(C::SparseMatrixCSC) = (fill!(C.colptr, 1); C)

# special case - vector outer product
_copy(f::typeof(*), x::SparseVectorUnion, y::AdjOrTransSparseVectorUnion) = _outer(x, y)
@inline _outer(x::SparseVectorUnion, y::Adjoint) = return _outer(conj, x, parent(y))
@inline _outer(x::SparseVectorUnion, y::Transpose) = return _outer(identity, x, parent(y))
function _outer(trans::Tf, x, y) where Tf
nx = length(x)
ny = length(y)
rowvalx = nonzeroinds(x)
rowvaly = nonzeroinds(y)
nzvalsx = nonzeros(x)
nzvalsy = nonzeros(y)
nnzx = length(nzvalsx)
nnzy = length(nzvalsy)

nnzC = nnzx * nnzy
Tv = typeof(oneunit(eltype(x)) * oneunit(eltype(y)))
Ti = promote_type(indtype(x), indtype(y))
colptrC = zeros(Ti, ny + 1)
rowvalC = Vector{Ti}(undef, nnzC)
nzvalsC = Vector{Tv}(undef, nnzC)

idx = 0
@inbounds colptrC[1] = 1
@inbounds for jj = 1:nnzy
yval = nzvalsy[jj]
iszero(yval) && continue
col = rowvaly[jj]
yval = trans(yval)

for ii = 1:nnzx
xval = nzvalsx[ii]
iszero(xval) && continue
idx += 1
colptrC[col+1] += 1
rowvalC[idx] = rowvalx[ii]
nzvalsC[idx] = xval * yval
end
end
cumsum!(colptrC, colptrC)

return SparseMatrixCSC(nx, ny, colptrC, rowvalC, nzvalsC)
end

# (9) _broadcast_zeropres!/_broadcast_notzeropres! for more than two (input) sparse vectors/matrices
function _broadcast_zeropres!(f::Tf, C::SparseVecOrMat, As::Vararg{SparseVecOrMat,N}) where {Tf,N}
Expand Down Expand Up @@ -1079,8 +1125,10 @@ broadcast(f::Tf, A::SparseMatrixCSC, ::Type{T}) where {Tf,T} = broadcast(x -> f(

function copy(bc::Broadcasted{PromoteToSparse})
bcf = flatten(bc)
if is_supported_sparse_broadcast(bcf.args...)
broadcast(bcf.f, map(_sparsifystructured, bcf.args)...)
if can_skip_sparsification(bcf.f, bcf.args...)
return _copy(bcf.f, bcf.args...)
elseif is_supported_sparse_broadcast(bcf.args...)
return _copy(bcf.f, map(_sparsifystructured, bcf.args)...)
else
return copy(convert(Broadcasted{Broadcast.DefaultArrayStyle{length(axes(bc))}}, bc))
end
Expand Down
3 changes: 3 additions & 0 deletions stdlib/SparseArrays/src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,9 @@ kron(x::SparseVector, A::SparseMatrixCSC) = kron(SparseMatrixCSC(x), A)
kron(A::Union{SparseVector,SparseMatrixCSC}, B::VecOrMat) = kron(A, sparse(B))
kron(A::VecOrMat, B::Union{SparseVector,SparseMatrixCSC}) = kron(sparse(A), B)

# sparse outer product
kron(A::SparseVectorUnion, B::AdjOrTransSparseVectorUnion) = A .* B

## det, inv, cond

inv(A::SparseMatrixCSC) = error("The inverse of a sparse matrix can often be dense and can cause the computer to run out of memory. If you are sure you have enough memory, please convert your matrix to a dense matrix.")
Expand Down
6 changes: 6 additions & 0 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ SparseVector(n::Integer, nzind::Vector{Ti}, nzval::Vector{Tv}) where {Tv,Ti} =
# union of such a view and a SparseVector so we define an alias for such a union as well
const SparseColumnView{T} = SubArray{T,1,<:SparseMatrixCSC,Tuple{Base.Slice{Base.OneTo{Int}},Int},false}
const SparseVectorUnion{T} = Union{SparseVector{T}, SparseColumnView{T}}
const AdjOrTransSparseVectorUnion{T} = LinearAlgebra.AdjOrTrans{T, <:SparseVectorUnion{T}}

### Basic properties

Expand All @@ -58,6 +59,11 @@ function nonzeroinds(x::SparseColumnView)
return y
end

indtype(x::SparseColumnView) = indtype(parent(x))
function nnz(x::SparseColumnView)
rowidx, colidx = parentindices(x)
return length(nzrange(parent(x), colidx))
end

## similar
#
Expand Down
15 changes: 15 additions & 0 deletions stdlib/SparseArrays/test/higherorderfns.jl
Original file line number Diff line number Diff line change
Expand Up @@ -656,4 +656,19 @@ using SparseArrays.HigherOrderFns: SparseVecStyle
@test occursin("no method matching _copy(::typeof(rand))", sprint(showerror, err))
end

@testset "Sparse outer product, for type $T and vector $op" for
op in (transpose, adjoint),
T in (Float64, ComplexF64)
m, n, p = 100, 250, 0.1
A = sprand(T, m, n, p)
a, b = view(A, :, 1), sprand(T, m, p)
av, bv = Vector(a), Vector(b)
v = @inferred a .* op(b)
w = @inferred b .* op(a)
@test issparse(v)
@test issparse(w)
@test v == av .* op(bv)
@test w == bv .* op(av)
end

end # module
6 changes: 6 additions & 0 deletions stdlib/SparseArrays/test/sparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ end
for (m,n) in ((5,10), (13,8), (14,10))
a = sprand(m, 5, 0.4); a_d = Matrix(a)
b = sprand(n, 6, 0.3); b_d = Matrix(b)
v = view(a, :, 1); v_d = Vector(v)
x = sprand(m, 0.4); x_d = Vector(x)
y = sprand(n, 0.3); y_d = Vector(y)
# mat ⊗ mat
Expand All @@ -370,6 +371,11 @@ end
@test Array(kron(x, b)) == kron(x_d, b_d)
@test Array(kron(x_d, b)) == kron(x_d, b_d)
@test Array(kron(x, b_d)) == kron(x_d, b_d)
# vec ⊗ vec'
@test issparse(kron(v, y'))
@test issparse(kron(x, y'))
@test Array(kron(v, y')) == kron(v_d, y_d')
@test Array(kron(x, y')) == kron(x_d, y_d')
# test different types
z = convert(SparseVector{Float16, Int8}, y); z_d = Vector(z)
@test Vector(kron(x, z)) == kron(x_d, z_d)
Expand Down