Permalink
Browse files

Implements direct sum (blkdiag) for sparse matrices.

- Provides new blkdiag export
- Provide simple test
- Documentation for blkdiag notes that it is only implemented for sparse
  matrices at the moment.
  • Loading branch information...
tkelman authored and jiahao committed Feb 16, 2014
1 parent aded336 commit 26305e1bfe8fb7d6f96ad21b99e3763c72c6bf08
Showing with 40 additions and 1 deletion.
  1. +1 −0 base/exports.jl
  2. +1 −1 base/sparse.jl
  3. +31 −0 base/sparse/sparsematrix.jl
  4. +4 −0 doc/stdlib/linalg.rst
  5. +3 −0 test/sparse.jl
View
@@ -584,6 +584,7 @@ export
bkfact!,
bkfact,
blas_set_num_threads,
+ blkdiag,
chol,
cholfact!,
cholfact,
View
@@ -6,7 +6,7 @@ importall Base
import Base.NonTupleType, Base.float
export SparseMatrixCSC,
- dense, diag, diagm, droptol!, dropzeros!, etree, full,
+ blkdiag, dense, diag, diagm, droptol!, dropzeros!, etree, full,
getindex, ishermitian, issparse, issym, istril, istriu,
setindex!, sparse, sparsevec, spdiagm, speye, spones,
sprand, sprandbool, sprandn, spzeros, symperm, trace, tril, tril!,
@@ -1296,6 +1296,37 @@ function hvcat(rows::(Int...), X::SparseMatrixCSC...)
vcat(tmp_rows...)
end
+function blkdiag(X::SparseMatrixCSC...)
+ num = length(X)
+ mX = [ size(x, 1) for x in X ]
+ nX = [ size(x, 2) for x in X ]
+ m = sum(mX)
+ n = sum(nX)
+
+ Tv = promote_type(map(x->eltype(x.nzval), X)...)
+ Ti = promote_type(map(x->eltype(x.rowval), X)...)
+
+ colptr = Array(Ti, n + 1)
+ nnzX = [ nfilled(x) for x in X ]
+ nnz_res = sum(nnzX)
+ rowval = Array(Ti, nnz_res)
+ nzval = Array(Tv, nnz_res)
+
+ nnz_sofar = 0
+ nX_sofar = 0
+ mX_sofar = 0
+ for i = 1 : num
+ colptr[(1 : nX[i] + 1) + nX_sofar] = X[i].colptr + nnz_sofar
+ rowval[(1 : nnzX[i]) + nnz_sofar] = X[i].rowval + mX_sofar
+ nzval[(1 : nnzX[i]) + nnz_sofar] = X[i].nzval
+ nnz_sofar += nnzX[i]
+ nX_sofar += nX[i]
+ mX_sofar += mX[i]
+ end
+
+ SparseMatrixCSC(m, n, colptr, rowval, nzval)
+end
+
## Structure query functions
function issym(A::SparseMatrixCSC)
View
@@ -330,6 +330,10 @@ Linear algebra functions in Julia are largely implemented by calling functions f
Kronecker tensor product of two vectors or two matrices.
+.. function:: blkdiag(A...)
+
+ Concatenate matrices block-diagonally. Currently only implemented for sparse matrices.
+
.. function:: linreg(x, y)
Determine parameters ``[a, b]`` that minimize the squared error between ``y`` and ``a+b*x``.
View
@@ -24,6 +24,9 @@ sz34 = spzeros(3, 4)
se77 = speye(7)
@test all([se44 sz42 sz41; sz34 se33] == se77)
+# check blkdiag concatenation
+@test all(blkdiag(se33, se33) == sparse([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], ones(6)))
+
# check concatenation promotion
sz41_f32 = spzeros(Float32, 4, 1)
se33_i32 = speye(Int32, 3, 3)

0 comments on commit 26305e1

Please sign in to comment.