From 26305e1bfe8fb7d6f96ad21b99e3763c72c6bf08 Mon Sep 17 00:00:00 2001 From: Tony Kelman Date: Sat, 15 Feb 2014 17:35:58 -0800 Subject: [PATCH] Implements direct sum (blkdiag) for sparse matrices. - Provides new blkdiag export - Provide simple test - Documentation for blkdiag notes that it is only implemented for sparse matrices at the moment. --- base/exports.jl | 1 + base/sparse.jl | 2 +- base/sparse/sparsematrix.jl | 31 +++++++++++++++++++++++++++++++ doc/stdlib/linalg.rst | 4 ++++ test/sparse.jl | 3 +++ 5 files changed, 40 insertions(+), 1 deletion(-) diff --git a/base/exports.jl b/base/exports.jl index 1cce33811b14d..a1ac582d6dad9 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -584,6 +584,7 @@ export bkfact!, bkfact, blas_set_num_threads, + blkdiag, chol, cholfact!, cholfact, diff --git a/base/sparse.jl b/base/sparse.jl index a3421e7bd6e05..7a4f6df7fd363 100644 --- a/base/sparse.jl +++ b/base/sparse.jl @@ -6,7 +6,7 @@ importall Base import Base.NonTupleType, Base.float export SparseMatrixCSC, - dense, diag, diagm, droptol!, dropzeros!, etree, full, + blkdiag, dense, diag, diagm, droptol!, dropzeros!, etree, full, getindex, ishermitian, issparse, issym, istril, istriu, setindex!, sparse, sparsevec, spdiagm, speye, spones, sprand, sprandbool, sprandn, spzeros, symperm, trace, tril, tril!, diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index ec1d0bf1ca7fe..669f075c0e939 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1296,6 +1296,37 @@ function hvcat(rows::(Int...), X::SparseMatrixCSC...) vcat(tmp_rows...) end +function blkdiag(X::SparseMatrixCSC...) + num = length(X) + mX = [ size(x, 1) for x in X ] + nX = [ size(x, 2) for x in X ] + m = sum(mX) + n = sum(nX) + + Tv = promote_type(map(x->eltype(x.nzval), X)...) + Ti = promote_type(map(x->eltype(x.rowval), X)...) + + colptr = Array(Ti, n + 1) + nnzX = [ nfilled(x) for x in X ] + nnz_res = sum(nnzX) + rowval = Array(Ti, nnz_res) + nzval = Array(Tv, nnz_res) + + nnz_sofar = 0 + nX_sofar = 0 + mX_sofar = 0 + for i = 1 : num + colptr[(1 : nX[i] + 1) + nX_sofar] = X[i].colptr + nnz_sofar + rowval[(1 : nnzX[i]) + nnz_sofar] = X[i].rowval + mX_sofar + nzval[(1 : nnzX[i]) + nnz_sofar] = X[i].nzval + nnz_sofar += nnzX[i] + nX_sofar += nX[i] + mX_sofar += mX[i] + end + + SparseMatrixCSC(m, n, colptr, rowval, nzval) +end + ## Structure query functions function issym(A::SparseMatrixCSC) diff --git a/doc/stdlib/linalg.rst b/doc/stdlib/linalg.rst index 9deaa95bada9f..849fa20701635 100644 --- a/doc/stdlib/linalg.rst +++ b/doc/stdlib/linalg.rst @@ -330,6 +330,10 @@ Linear algebra functions in Julia are largely implemented by calling functions f Kronecker tensor product of two vectors or two matrices. +.. function:: blkdiag(A...) + + Concatenate matrices block-diagonally. Currently only implemented for sparse matrices. + .. function:: linreg(x, y) Determine parameters ``[a, b]`` that minimize the squared error between ``y`` and ``a+b*x``. diff --git a/test/sparse.jl b/test/sparse.jl index 0bb29741c3782..3d9b80de8cc2c 100644 --- a/test/sparse.jl +++ b/test/sparse.jl @@ -24,6 +24,9 @@ sz34 = spzeros(3, 4) se77 = speye(7) @test all([se44 sz42 sz41; sz34 se33] == se77) +# check blkdiag concatenation +@test all(blkdiag(se33, se33) == sparse([1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6], ones(6))) + # check concatenation promotion sz41_f32 = spzeros(Float32, 4, 1) se33_i32 = speye(Int32, 3, 3)