Skip to content

Commit

Permalink
Preserve block structure in broadcasting (#61)
Browse files Browse the repository at this point in the history
* Preserve block structure

* remove currently unused linalg.jl

* Support Matrix .+ Vector

* fix tests

* Arrays and numbers implement BlockArray interface (#63)

* try fixing 0.7 tests

* make test approx eq

* Support StructuredMatrixStyle's
  • Loading branch information
dlfivefifty committed Nov 29, 2018
1 parent d92f0b4 commit 9024edf
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 7 deletions.
12 changes: 8 additions & 4 deletions src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,18 @@ import Base: @propagate_inbounds, Array, to_indices, to_index,
unsafe_indices, first, last, size, length, unsafe_length,
unsafe_convert,
getindex, show,
broadcast, eltype, convert, broadcast,
broadcast, eltype, convert, similar,
@_inline_meta, _maybetail, tail, @_propagate_inbounds_meta, reindex,
RangeIndex, Int, Integer, Number,
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate,
BroadcastStyle



import Base: (:), IteratorSize, iterate, axes1
import Base.Broadcast: broadcasted, DefaultArrayStyle
import LinearAlgebra: lmul!, rmul!, AbstractTriangular, HermOrSym, AdjOrTrans
import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted
import LinearAlgebra: lmul!, rmul!, AbstractTriangular, HermOrSym, AdjOrTrans,
StructuredMatrixStyle


include("abstractblockarray.jl")
Expand All @@ -41,6 +43,8 @@ include("views.jl")
include("blockindexrange.jl")
include("show.jl")
include("blockarrayinterface.jl")
include("blockbroadcast.jl")
# include("linalg.jl")

include("deprecate.jl")

Expand Down
1 change: 1 addition & 0 deletions src/abstractblockarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ end

Block{N, T}(n::Vararg{T, N}) where {N,T} = Block{N, T}(n)
Block{N}(n::Vararg{T, N}) where {N,T} = Block{N, T}(n)
Block() = Block{0,Int}()
Block(n::Vararg{T, N}) where {N,T} = Block{N, T}(n)
Block{1}(n::Tuple{T}) where {T} = Block{1, T}(n)
Block{N}(n::NTuple{N, T}) where {N,T} = Block{N, T}(n)
Expand Down
4 changes: 4 additions & 0 deletions src/blockarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@

blocksizes(A) = BlockSizes(vcat.(size(A))...)
getindex(a::Number, ::Block{0}) = a


blocksizes(A::AbstractTriangular) = blocksizes(parent(A))


Expand Down
58 changes: 58 additions & 0 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@

# Here we override broadcasting for banded matrices.
# The design is to to exploit the broadcast machinery so that
# banded matrices that conform to the banded matrix interface but are not
# <: AbstractBandedMatrix can get access to fast copyto!, lmul!, rmul!, axpy!, etc.
# using broadcast variants (B .= A, B .= 2.0 .* A, etc.)


abstract type AbstractBlockStyle{N} <: AbstractArrayStyle{N} end
struct BlockStyle{N} <: AbstractBlockStyle{N} end
struct PseudoBlockStyle{N} <: AbstractBlockStyle{N} end


BlockStyle(::Val{N}) where {N} = BlockStyle{N}()
PseudoBlockStyle(::Val{N}) where {N} = PseudoBlockStyle{N}()
BlockStyle{M}(::Val{N}) where {N,M} = BlockStyle{N}()
PseudoBlockStyle{M}(::Val{N}) where {N,M} = PseudoBlockStyle{N}()
BroadcastStyle(::Type{<:BlockArray{<:Any,N}}) where N = BlockStyle{N}()
BroadcastStyle(::Type{<:PseudoBlockArray{<:Any,N}}) where N = PseudoBlockStyle{N}()
BroadcastStyle(::DefaultArrayStyle{N}, b::AbstractBlockStyle{M}) where {M,N} = typeof(b)(Val(max(M,N)))
BroadcastStyle(a::AbstractBlockStyle{N}, ::DefaultArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M,N)))
BroadcastStyle(::StructuredMatrixStyle, b::AbstractBlockStyle{M}) where {M} = typeof(b)(Val(max(M,2)))
BroadcastStyle(a::AbstractBlockStyle{M}, ::StructuredMatrixStyle) where {M} = typeof(a)(Val(max(M,2)))
BroadcastStyle(::BlockStyle{M}, ::PseudoBlockStyle{N}) where {M,N} = BlockStyle(Val(max(M,N)))
BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(Val(max(M,N)))


####
# Default to standard Array broadcast
####


# following code modified from julia/base/broadcast.jl
broadcast_cumulsizes(::Number) = ()
broadcast_cumulsizes(A::AbstractArray) = cumulsizes(blocksizes(A))
broadcast_cumulsizes(A::Broadcasted) = cumulsizes(blocksizes(A))

combine_cumulsizes(A) = A
combine_cumulsizes(A, B, C...) = combine_cumulsizes(_cms(A,B), C...)

_cms(::Tuple{}, ::Tuple{}) = ()
_cms(::Tuple{}, newshape::Tuple) = (newshape[1], _cms((), tail(newshape))...)
_cms(shape::Tuple, ::Tuple{}) = (shape[1], _cms(tail(shape), ())...)
_cms(shape::Tuple, newshape::Tuple) = (sort!(union(shape[1], newshape[1])), _cms(tail(shape), tail(newshape))...)


blocksizes(A::Broadcasted{<:AbstractArrayStyle{N}}) where N =
BlockSizes(combine_cumulsizes(broadcast_cumulsizes.(A.args)...))


copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractBlockStyle}) =
copyto!(dest, Broadcasted{DefaultArrayStyle{2}}(bc.f, bc.args, bc.axes))

similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} =
BlockArray{T,N}(undef, blocksizes(bc))

similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} =
PseudoBlockArray{T,N}(undef, blocksizes(bc))
6 changes: 6 additions & 0 deletions src/blocksizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ abstract type AbstractBlockSizes{N} end
struct BlockSizes{N} <: AbstractBlockSizes{N}
cumul_sizes::NTuple{N, Vector{Int}}
# Takes a tuple of sizes, accumulates them and create a `BlockSizes`
BlockSizes{N}() where N = new{N}()
BlockSizes{N}(cs::NTuple{N,Vector{Int}}) where N = new{N}(cs)
end

BlockSizes() = BlockSizes{0}()

BlockSizes(cs::NTuple{N,Vector{Int}}) where N = BlockSizes{N}(cs)

function BlockSizes(sizes::Vararg{Vector{Int}, N}) where {N}
cumul_sizes = ntuple(k -> _cumul_vec(sizes[k]), Val(N))
return BlockSizes(cumul_sizes)
Expand Down
2 changes: 0 additions & 2 deletions src/deprecate.jl
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
@deprecate getindex(block_sizes::BlockSizes, i) cumulsizes(block_sizes, i)
@deprecate getindex(block_sizes::BlockSizes, i, j) cumulsizes(block_sizes, i, j)
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ include("test_blockviews.jl")
include("test_blockrange.jl")
include("test_blockarrayinterface.jl")

include("test_blockbroadcast.jl")

include("../docs/make.jl")
17 changes: 17 additions & 0 deletions test/test_blockarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using BlockArrays

struct PartiallyImplementedBlockVector <: AbstractBlockArray{Float64,1} end

@testset "partially implemented block array" begin
Expand Down Expand Up @@ -25,6 +27,21 @@ struct PartiallyImplementedBlockVector <: AbstractBlockArray{Float64,1} end
end
end

@testset "Array block interface" begin
@test blocksizes(1) == BlockArrays.BlockSizes{0}()
@test 1[Block()] == 1

A = randn(5)
@test blocksizes(A) == BlockArrays.BlockSizes([5])
@test A[Block(1)] == A
view(A,Block(1))[1] = 2
@test A[1] == 2
@test_throws BoundsError A[Block(2)]

A = randn(5,5)
@test A[Block(1,1)] == A
end

@testset "Triangular/Symmetric/Hermitian block arrays" begin
A = PseudoBlockArray{ComplexF64}(undef, (1:4), (1:4))
A .= randn.() .+ randn.().*im
Expand Down
2 changes: 1 addition & 1 deletion test/test_blockarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ end
@test strides(A) == (1, size(A,1))
x = randn(size(A,2))
y = similar(x)
@test BLAS.gemv!('N', 2.0, A, x, 0.0, y) == 2A*x
@test BLAS.gemv!('N', 2.0, A, x, 0.0, y) 2A*x
end

@testset "lmul!/rmul!" begin
Expand Down
89 changes: 89 additions & 0 deletions test/test_blockbroadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
using BlockArrays, Test

@testset "broadcast" begin
@testset "BlockArray" begin
A = BlockArray(randn(6), 1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{1}()

@test exp.(A) == exp.(Vector(A))
@test blocksizes(A) == blocksizes(exp.(A))

@test A+A isa BlockArray
@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)

A = BlockArray(randn(6,6), 1:3,1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{2}()

@test exp.(A) == exp.(Matrix(A))
@test blocksizes(A) == blocksizes(exp.(A))


@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)
end

@testset "PseudoBlockArray" begin
A = PseudoBlockArray(randn(6), 1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.PseudoBlockStyle{1}()


@test exp.(A) == exp.(Vector(A))
@test blocksizes(A) == blocksizes(exp.(A))

@test A+A isa PseudoBlockArray
@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)

B = PseudoBlockArray(randn(6,6), 1:3,1:3)

@test BlockArrays.BroadcastStyle(typeof(B)) == BlockArrays.PseudoBlockStyle{2}()

@test exp.(B) == exp.(Matrix(B))
@test blocksizes(B) == blocksizes(exp.(B))

@test blocksizes(B + B) == blocksizes(B .+ B) == blocksizes(B)
@test blocksizes(B .+ 1) == blocksizes(B)
@test blocksizes(A .+ 1 .+ B) == blocksizes(B)
@test A .+ 1 .+ B == Vector(A) .+ 1 .+ B == Vector(A) .+ 1 .+ Matrix(B)
end

@testset "Mixed" begin
A = BlockArray(randn(6), 1:3)
B = PseudoBlockArray(randn(6), 1:3)

@test A + B isa BlockArray
@test B + A isa BlockArray

@test blocksizes(A + B) == blocksizes(A)

C = randn(6)

@test A + C isa BlockVector{Float64}
@test C + A isa BlockVector{Float64}
@test B + C isa PseudoBlockVector{Float64}
@test C + B isa PseudoBlockVector{Float64}

@test blocksizes(A+C) == blocksizes(C+A) == blocksizes(A)
@test blocksizes(B+C) == blocksizes(C+B) == blocksizes(B)

A = BlockArray(randn(6,6), 1:3, 1:3)
D = Diagonal(ones(6))
@test blocksizes(A + D) == blocksizes(A)
@test blocksizes(B .+ D) == BlockArrays.BlockSizes([1,2,3],[6])
end

@testset "Mixed block sizes" begin
A = BlockArray(randn(6), 1:3)
B = BlockArray(randn(6), fill(2,3))
@test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2])

A = BlockArray(randn(6,6), 1:3, 1:3)
B = BlockArray(randn(6,6), fill(2,3), fill(3,2))

@test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2], 1:3)
end
end

0 comments on commit 9024edf

Please sign in to comment.