Skip to content

Commit

Permalink
Implements direct sum (blkdiag) for sparse matrices.
Browse files Browse the repository at this point in the history
- Provides new blkdiag export
- Provide simple test
- Documentation for blkdiag notes that it is only implemented for sparse
  matrices at the moment.
  • Loading branch information
tkelman authored and jiahao committed Feb 18, 2014
1 parent aded336 commit 26305e1
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 1 deletion.
1 change: 1 addition & 0 deletions base/exports.jl
Expand Up @@ -584,6 +584,7 @@ export
bkfact!, bkfact!,
bkfact, bkfact,
blas_set_num_threads, blas_set_num_threads,
blkdiag,
chol, chol,
cholfact!, cholfact!,
cholfact, cholfact,
Expand Down
2 changes: 1 addition & 1 deletion base/sparse.jl
Expand Up @@ -6,7 +6,7 @@ importall Base
import Base.NonTupleType, Base.float import Base.NonTupleType, Base.float


export SparseMatrixCSC, export SparseMatrixCSC,
dense, diag, diagm, droptol!, dropzeros!, etree, full, blkdiag, dense, diag, diagm, droptol!, dropzeros!, etree, full,
getindex, ishermitian, issparse, issym, istril, istriu, getindex, ishermitian, issparse, issym, istril, istriu,
setindex!, sparse, sparsevec, spdiagm, speye, spones, setindex!, sparse, sparsevec, spdiagm, speye, spones,
sprand, sprandbool, sprandn, spzeros, symperm, trace, tril, tril!, sprand, sprandbool, sprandn, spzeros, symperm, trace, tril, tril!,
Expand Down
31 changes: 31 additions & 0 deletions base/sparse/sparsematrix.jl
Expand Up @@ -1296,6 +1296,37 @@ function hvcat(rows::(Int...), X::SparseMatrixCSC...)
vcat(tmp_rows...) vcat(tmp_rows...)
end end


function blkdiag(X::SparseMatrixCSC...)
num = length(X)
mX = [ size(x, 1) for x in X ]
nX = [ size(x, 2) for x in X ]
m = sum(mX)
n = sum(nX)

Tv = promote_type(map(x->eltype(x.nzval), X)...)
Ti = promote_type(map(x->eltype(x.rowval), X)...)

colptr = Array(Ti, n + 1)
nnzX = [ nfilled(x) for x in X ]
nnz_res = sum(nnzX)
rowval = Array(Ti, nnz_res)
nzval = Array(Tv, nnz_res)

nnz_sofar = 0
nX_sofar = 0
mX_sofar = 0
for i = 1 : num
colptr[(1 : nX[i] + 1) + nX_sofar] = X[i].colptr + nnz_sofar
rowval[(1 : nnzX[i]) + nnz_sofar] = X[i].rowval + mX_sofar
nzval[(1 : nnzX[i]) + nnz_sofar] = X[i].nzval
nnz_sofar += nnzX[i]
nX_sofar += nX[i]
mX_sofar += mX[i]
end

SparseMatrixCSC(m, n, colptr, rowval, nzval)
end

## Structure query functions ## Structure query functions


function issym(A::SparseMatrixCSC) function issym(A::SparseMatrixCSC)
Expand Down
4 changes: 4 additions & 0 deletions doc/stdlib/linalg.rst
Expand Up @@ -330,6 +330,10 @@ Linear algebra functions in Julia are largely implemented by calling functions f


Kronecker tensor product of two vectors or two matrices. Kronecker tensor product of two vectors or two matrices.


.. function:: blkdiag(A...)

Concatenate matrices block-diagonally. Currently only implemented for sparse matrices.

.. function:: linreg(x, y) .. function:: linreg(x, y)


Determine parameters ``[a, b]`` that minimize the squared error between ``y`` and ``a+b*x``. Determine parameters ``[a, b]`` that minimize the squared error between ``y`` and ``a+b*x``.
Expand Down
3 changes: 3 additions & 0 deletions test/sparse.jl
Expand Up @@ -24,6 +24,9 @@ sz34 = spzeros(3, 4)
se77 = speye(7) se77 = speye(7)
@test all([se44 sz42 sz41; sz34 se33] == se77) @test all([se44 sz42 sz41; sz34 se33] == se77)


# check blkdiag concatenation
@test all(blkdiag(se33, se33) == sparse([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], ones(6)))

# check concatenation promotion # check concatenation promotion
sz41_f32 = spzeros(Float32, 4, 1) sz41_f32 = spzeros(Float32, 4, 1)
se33_i32 = speye(Int32, 3, 3) se33_i32 = speye(Int32, 3, 3)
Expand Down

0 comments on commit 26305e1

Please sign in to comment.