Skip to content

Commit

Permalink
Faster and more general reductions for sparse matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
simonster committed Mar 16, 2015
1 parent 9cd496e commit ff3723d
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 73 deletions.
296 changes: 233 additions & 63 deletions base/sparse/sparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -714,73 +714,191 @@ end # macro
(.<)(A::SparseMatrixCSC, B::Number) = (.<)(full(A), B)
(.<)(A::Number, B::SparseMatrixCSC) = (.<)(A, full(B))

# Reductions

# TODO: Should the results of sparse reductions be sparse?
function reducedim{Tv,Ti}(f::Function, A::SparseMatrixCSC{Tv,Ti}, region, v0)
if region == 1 || region == (1,)

S = Array(Tv, 1, A.n)
@inbounds for i = 1 : A.n
Si = v0
ccount = 0
for j = A.colptr[i] : A.colptr[i+1]-1
Si = f(Si, A.nzval[j])
ccount += 1
end
if ccount != A.m; Si = f(Si, zero(Tv)); end
S[i] = Si
## Reductions

# In general, output of sparse matrix reductions will not be sparse,
# and computing reductions along columns into SparseMatrixCSC is
# non-trivial, so use Arrays for output
Base.reducedim_initarray{R}(A::SparseMatrixCSC, region, v0, ::Type{R}) =
fill!(Array(R,Base.reduced_dims(A,region)), v0)
Base.reducedim_initarray0{R}(A::SparseMatrixCSC, region, v0, ::Type{R}) =
fill!(Array(R,Base.reduced_dims0(A,region)), v0)

# General mapreduce
function _mapreducezeros(f, op, T::Type, nzeros::Int, v0)
nzeros == 0 && return v0

# Reduce over first zero
zeroval = f(zero(T))
v = op(v0, zeroval)
isequal(v, v0) && return v

# Reduce over remaining zeros
for i = 2:nzeros
lastv = v
v = op(v, zeroval)
# Bail out early if we reach a fixed point
isequal(v, lastv) && break
end

v
end

function Base._mapreduce{T}(f, op, A::SparseMatrixCSC{T})
z = nnz(A)
n = length(A)
if z == 0
if n == 0
Base.mr_empty(f, op, T)
else
_mapreducezeros(f, op, T, n-z-1, f(zero(T)))
end
return S

elseif region == 2 || region == (2,)
else
_mapreducezeros(f, op, T, n-z, Base._mapreduce(f, op, A.nzval))
end
end

S = fill(v0, A.m, 1)
rcounts = zeros(Ti, A.m)
@inbounds for i = 1 : A.n, j = A.colptr[i] : A.colptr[i+1]-1
row = A.rowval[j]
S[row] = f(S[row], A.nzval[j])
rcounts[row] += 1
end
for i = 1:A.m
if rcounts[i] != A.n; S[i] = f(S[i], zero(Tv)); end
end
return S
# Specialized mapreduce for AddFun/MulFun
_mapreducezeros(f, ::Base.AddFun, T::Type, nzeros::Int, v0) =
nzeros == 0 ? v0 : f(zero(T))*nzeros + v0
_mapreducezeros(f, ::Base.MulFun, T::Type, nzeros::Int, v0) =
nzeros == 0 ? v0 : f(zero(T))^nzeros * v0

elseif region == (1,2)
function Base._mapreduce{T}(f, op::Base.MulFun, A::SparseMatrixCSC{T})
nzeros = length(A)-nnz(A)
if nzeros == 0
# No zeros, so don't compute f(0) since it might throw
Base._mapreduce(f, op, A.nzval)
else
v = f(zero(T))^(nzeros)
# Bail out early if initial reduction value is zero
v == zero(T) ? v : v*Base._mapreduce(f, op, A.nzval)
end
end

S = v0
@inbounds for i = 1 : A.n, j = A.colptr[i] : A.colptr[i+1]-1
S = f(S, A.nzval[j])
# General mapreducedim
function _mapreducerows!{T}(f, op, R::AbstractArray, A::SparseMatrixCSC{T})
colptr = A.colptr
rowval = A.rowval
nzval = A.nzval
m, n = size(A)
@inbounds for col = 1:n
r = R[1, col]
@simd for j = colptr[col]:colptr[col+1]-1
r = op(r, f(nzval[j]))
end
if nnz(A) != A.m*A.n; S = f(S, zero(Tv)); end

return fill(S, 1, 1)

else
throw(ArgumentError("invalid value for region; must be 1, 2, or (1,2)"))
R[1, col] = _mapreducezeros(f, op, T, m-(colptr[col+1]-colptr[col]), r)
end
R
end

function maximum{T}(A::SparseMatrixCSC{T})
isempty(A) && throw(ArgumentError("argument must not be empty"))
(nnz(A) == 0) && (return zero(T))
m = maximum(A.nzval)
nnz(A)!=length(A) ? max(m,zero(T)) : m
function _mapreducecols!{Tv,Ti}(f, op, R::AbstractArray, A::SparseMatrixCSC{Tv,Ti})
colptr = A.colptr
rowval = A.rowval
nzval = A.nzval
m, n = size(A)
rownz = fill(convert(Ti, n), m)
@inbounds for col = 1:n
@simd for j = colptr[col]:colptr[col+1]-1
row = rowval[j]
R[row, 1] = op(R[row, 1], f(nzval[j]))
rownz[row] -= 1
end
end
@inbounds for i = 1:m
R[i, 1] = _mapreducezeros(f, op, Tv, rownz[i], R[i, 1])
end
R
end

maximum{T}(A::SparseMatrixCSC{T}, region) =
isempty(A) ? similar(A, reduced_dims0(A,region)) : reducedim(Base.scalarmax,A,region,typemin(T))
function Base._mapreducedim!{T}(f, op, R::AbstractArray, A::SparseMatrixCSC{T})
lsiz = Base.check_reducedims(R,A)
isempty(A) && return R

function minimum{T}(A::SparseMatrixCSC{T})
isempty(A) && throw(ArgumentError("argument must not be empty"))
(nnz(A) == 0) && (return zero(T))
m = minimum(A.nzval)
nnz(A)!=length(A) ? min(m,zero(T)) : m
if size(R, 1) == size(R, 2) == 1
# Reduction along both columns and rows
R[1, 1] = mapreduce(f, op, A)
elseif size(R, 1) == 1
# Reduction along rows
_mapreducerows!(f, op, R, A)
elseif size(R, 2) == 1
# Reduction along columns
_mapreducecols!(f, op, R, A)
else
# Reduction along a dimension > 2
# Compute op(R, f(A))
m, n = size(A)
nzval = A.nzval
if length(nzval) == m*n
# No zeros, so don't compute f(0) since it might throw
for col = 1:n
@simd for row = 1:size(A, 1)
@inbounds R[row, col] = op(R[row, col], f(nzval[(col-1)*m+row]))
end
end
else
colptr = A.colptr
rowval = A.rowval
zeroval = f(zero(T))
@inbounds for col = 1:n
lastrow = 0
for j = colptr[col]:colptr[col+1]-1
row = rowval[j]
@simd for i = lastrow+1:row-1 # Zeros before this nonzero
R[i, col] = op(R[i, col], zeroval)
end
R[row, col] = op(R[row, col], f(nzval[j]))
lastrow = row
end
@simd for i = lastrow+1:m # Zeros at end
R[i, col] = op(R[i, col], zeroval)
end
end
end
end
R
end

minimum{T}(A::SparseMatrixCSC{T}, region) =
isempty(A) ? similar(A, reduced_dims0(A,region)) : reducedim(Base.scalarmin,A,region,typemax(T))
# Specialized mapreducedim for AddFun cols to avoid allocating a
# temporary array when f(0) == 0
function _mapreducecols!{Tv,Ti}(f, op::Base.AddFun, R::AbstractArray, A::SparseMatrixCSC{Tv,Ti})
nzval = A.nzval
m, n = size(A)
if length(nzval) == m*n
# No zeros, so don't compute f(0) since it might throw
for col = 1:n
@simd for row = 1:size(A, 1)
@inbounds R[row, 1] = op(R[row, 1], f(nzval[(col-1)*m+row]))
end
end
else
colptr = A.colptr
rowval = A.rowval
zeroval = f(zero(Tv))
if isequal(zeroval, zero(Tv))
# Case where f(0) == 0
@inbounds for col = 1:size(A, 2)
@simd for j = colptr[col]:colptr[col+1]-1
R[rowval[j], 1] += f(nzval[j])
end
end
else
# Case where f(0) != 0
rownz = fill(convert(Ti, n), m)
@inbounds for col = 1:size(A, 2)
@simd for j = colptr[col]:colptr[col+1]-1
row = rowval[j]
R[row, 1] += f(nzval[j])
rownz[row] -= 1
end
end
for i = 1:m
R[i, 1] += rownz[i]*zeroval
end
end
end
R
end

# findmax/min and indmax/min methods
function _findz{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}, rows=1:A.m, cols=1:A.n)
Expand Down Expand Up @@ -881,15 +999,6 @@ findmax{Tv,Ti}(A::SparseMatrixCSC{Tv,Ti}) = (r=findmax(A,(1,2)); (r[1][1], r[2][
indmin(A::SparseMatrixCSC) = findmin(A)[2]
indmax(A::SparseMatrixCSC) = findmax(A)[2]


sum{T}(A::SparseMatrixCSC{T}) = sum(A.nzval)
sum{T}(A::SparseMatrixCSC{T}, region) = reducedim(+,A,region,zero(T))

prod{T}(A::SparseMatrixCSC{T}) = nnz(A)!=length(A) ? zero(T) : prod(A.nzval)
prod{T}(A::SparseMatrixCSC{T}, region) = reducedim(*,A,region,one(T))

mean(A::SparseMatrixCSC, region::Integer) = sum(A, region) / size(A, region)

#all(A::SparseMatrixCSC{Bool}, region) = reducedim(all,A,region,true)
#any(A::SparseMatrixCSC{Bool}, region) = reducedim(any,A,region,false)
#sum(A::SparseMatrixCSC{Bool}, region) = reducedim(+,A,region,0,Int)
Expand Down Expand Up @@ -2376,3 +2485,64 @@ function hash{T}(A::SparseMatrixCSC{T}, h::UInt)
h = hashrun(lastnz, runlength, h) # Hash previous run
hashrun(0, length(A)-lastidx, h) # Hash zeros at end
end

## Statistics

# This is the function that does the reduction underlying var/std
function Base.centralize_sumabs2!{S,Tv,Ti}(R::AbstractArray{S}, A::SparseMatrixCSC{Tv,Ti}, means::AbstractArray)
lsiz = Base.check_reducedims(R,A)
size(means) == size(R) || error("size of means must match size of R")
isempty(R) || fill!(R, zero(S))
isempty(A) && return R

colptr = A.colptr
rowval = A.rowval
nzval = A.nzval
m = size(A, 1)
n = size(A, 2)

if size(R, 1) == size(R, 2) == 1
# Reduction along both columns and rows
R[1, 1] = Base.centralize_sumabs2(A, means[1])
elseif size(R, 1) == 1
# Reduction along rows
@inbounds for col = 1:n
mu = means[col]
r = convert(S, (m-colptr[col+1]+colptr[col])*abs2(mu))
@simd for j = colptr[col]:colptr[col+1]-1
r += abs2(nzval[j] - mu)
end
R[1, col] = r
end
elseif size(R, 2) == 1
# Reduction along columns
rownz = fill(convert(Ti, n), m)
@inbounds for col = 1:n
@simd for j = colptr[col]:colptr[col+1]-1
row = rowval[j]
R[row, 1] += abs2(nzval[j] - means[row])
rownz[row] -= 1
end
end
for i = 1:m
R[i, 1] += rownz[i]*abs2(means[i])
end
else
# Reduction along a dimension > 2
@inbounds for col = 1:n
lastrow = 0
@simd for j = colptr[col]:colptr[col+1]-1
row = rowval[j]
for i = lastrow+1:row-1
R[i, col] = abs2(means[i, col])
end
R[row, col] = abs2(nzval[j] - means[row, col])
lastrow = row
end
for i = lastrow+1:m
R[i, col] = abs2(means[i, col])
end
end
end
return R
end
4 changes: 3 additions & 1 deletion base/statistics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ immutable CentralizedAbs2Fun{T<:Number} <: Func{1}
m::T
end
call(f::CentralizedAbs2Fun, x) = abs2(x - f.m)
centralize_sumabs2(A::AbstractArray, m::Number) =
mapreduce(CentralizedAbs2Fun(m), AddFun(), A)
centralize_sumabs2(A::AbstractArray, m::Number, ifirst::Int, ilast::Int) =
mapreduce_impl(CentralizedAbs2Fun(m), AddFun(), A, ifirst, ilast)

Expand Down Expand Up @@ -137,7 +139,7 @@ function varm{T}(A::AbstractArray{T}, m::Number; corrected::Bool=true)
n = length(A)
n == 0 && return convert(momenttype(T), NaN)
n == 1 && return convert(momenttype(T), abs2(A[1] - m)/(1 - Int(corrected)))
return centralize_sumabs2(A, m, 1, n) / (n - Int(corrected))
return centralize_sumabs2(A, m) / (n - Int(corrected))
end

function varm!{S}(R::AbstractArray{S}, A::AbstractArray, m::AbstractArray; corrected::Bool=true)
Expand Down
6 changes: 3 additions & 3 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,13 @@ const × = cross
include("broadcast.jl")
importall .Broadcast

# statistics
include("statistics.jl")

# sparse matrices and sparse linear algebra
include("sparse.jl")
importall .SparseMatrix

# statistics
include("statistics.jl")

# signal processing
include("fftw.jl")
include("dsp.jl")
Expand Down
Loading

0 comments on commit ff3723d

Please sign in to comment.