From 8259818f7f11332b17ae51c8d5e1e9a42168e7ba Mon Sep 17 00:00:00 2001 From: agisga Date: Wed, 27 May 2015 17:33:13 -0500 Subject: [PATCH] Arrays as input to #block_diagonal --- lib/nmatrix/shortcuts.rb | 40 +++++++++++++++++++++++++--------------- spec/shortcuts_spec.rb | 30 +++++++++++++++++------------- 2 files changed, 42 insertions(+), 28 deletions(-) diff --git a/lib/nmatrix/shortcuts.rb b/lib/nmatrix/shortcuts.rb index 9c384df7..9af205c7 100644 --- a/lib/nmatrix/shortcuts.rb +++ b/lib/nmatrix/shortcuts.rb @@ -282,35 +282,45 @@ def diagonal(entries, opts={}) # can receive any number of arguments. Optionally, the last entry of +params+ is # a hash of options from NMatrix#initialize. All other entries of +params+ are # the blocks of the desired block-diagonal matrix, which are supplied - # as square 2D NMatrix objects. + # as square 2D NMatrix objects, or alternatively as arrays of arrays + # (with dimensions corresponding to square matrices). # * *Returns* # - NMatrix of block-diagonal form filled with specified matrices # as the blocks along the diagonal. # # * *Example* # - # a = NMatrix.new([2,2],[1,2,3,4]) - # b = NMatrix.new([1,1],[123],dtype: :int32) - # c = NMatrix.new([3,3],[1,2,3,1,2,3,1,2,3], dtype: :float64) - # m = NMatrix.block_diagonal(a,b,c,dtype: :int64, stype: :yale) + # a = NMatrix.new([2,2], [1,2,3,4]) + # b = NMatrix.new([1,1], [123], dtype: :float64) + # c = Array.new(2) { [[10,10], [10,10]] } + # d = Array[[1,2,3], [4,5,6], [7,8,9]] + # m = NMatrix.block_diagonal(a, b, *c, d, dtype: :int64, stype: :yale) # => # [ - # [1, 2, 0, 0, 0, 0] - # [3, 4, 0, 0, 0, 0] - # [0, 0, 123, 0, 0, 0] - # [0, 0, 0, 1, 2, 3] - # [0, 0, 0, 1, 2, 3] - # [0, 0, 0, 1, 2, 3] + # [1, 2, 0, 0, 0, 0, 0, 0, 0, 0] + # [3, 4, 0, 0, 0, 0, 0, 0, 0, 0] + # [0, 0, 123, 0, 0, 0, 0, 0, 0, 0] + # [0, 0, 0, 10, 10, 0, 0, 0, 0, 0] + # [0, 0, 0, 10, 10, 0, 0, 0, 0, 0] + # [0, 0, 0, 0, 0, 10, 10, 0, 0, 0] + # [0, 0, 0, 0, 0, 10, 10, 0, 0, 0] + # [0, 0, 0, 0, 0, 0, 0, 1, 2, 3] + # [0, 0, 0, 0, 0, 0, 0, 4, 5, 6] + # [0, 0, 0, 0, 0, 0, 0, 7, 8, 9] # ] - # + # def block_diagonal(*params) options = params.last.is_a?(Hash) ? params.pop : {} + params.each_index do |i| + params[i] = params[i].to_nm if params[i].is_a?(Array) # Convert Array to NMatrix + end + block_sizes = [] #holds the size of each matrix block params.each do |b| - raise ArgumentError, "Only NMatrix objects allowed" unless b.is_a?(NMatrix) - raise ArgumentError, "Only 2D matrices allowed" unless b.shape.size == 2 - raise ArgumentError, "Only square matrices allowed" unless b.shape[0] == b.shape[1] + raise ArgumentError, "Only NMatrix or Array objects allowed" unless b.is_a?(NMatrix) + raise ArgumentError, "Only 2D matrices or 2D arrays allowed" unless b.shape.size == 2 + raise ArgumentError, "Only square-shaped blocks allowed" unless b.shape[0] == b.shape[1] block_sizes << b.shape[0] end diff --git a/spec/shortcuts_spec.rb b/spec/shortcuts_spec.rb index 677864fc..07f187a3 100644 --- a/spec/shortcuts_spec.rb +++ b/spec/shortcuts_spec.rb @@ -69,19 +69,23 @@ [:dense, :yale, :list].each do |stype| context "#block_diagonal #{dtype} #{stype}" do it "block_diagonal() creates a block-diagonal NMatrix" do - a = NMatrix.new([2,2],[1,2, - 3,4]) - b = NMatrix.new([1,1],[123.0]) - c = NMatrix.new([3,3],[1,2,3, - 1,2,3, - 1,2,3]) - m = NMatrix.block_diagonal(a,b,c, dtype: dtype, stype: stype) - expect(m).to eq(NMatrix.new([6,6], [1, 2, 0, 0, 0, 0, - 3, 4, 0, 0, 0, 0, - 0, 0, 123, 0, 0, 0, - 0, 0, 0, 1, 2, 3, - 0, 0, 0, 1, 2, 3, - 0, 0, 0, 1, 2, 3], dtype: dtype, stype: stype)) + a = NMatrix.new([2,2], [1,2, + 3,4]) + b = NMatrix.new([1,1], [123.0]) + c = NMatrix.new([3,3], [1,2,3, + 1,2,3, + 1,2,3]) + d = Array[ [1,1,1], [2,2,2], [3,3,3] ] + m = NMatrix.block_diagonal(a, b, c, d, dtype: dtype, stype: stype) + expect(m).to eq(NMatrix.new([9,9], [1, 2, 0, 0, 0, 0, 0, 0, 0, + 3, 4, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 123, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 1, 2, 3, 0, 0, 0, + 0, 0, 0, 1, 2, 3, 0, 0, 0, + 0, 0, 0, 1, 2, 3, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 2, 2, 2, + 0, 0, 0, 0, 0, 0, 3, 3, 3], dtype: dtype, stype: stype)) end end end