From 9024edf6704f751d14ee0c21467e4870703bb0c0 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Thu, 29 Nov 2018 10:59:31 +0000 Subject: [PATCH] Preserve block structure in broadcasting (#61) * Preserve block structure * remove currently unused linalg.jl * Support Matrix .+ Vector * fix tests * Arrays and numbers implement BlockArray interface (#63) * try fixing 0.7 tests * make test approx eq * Support StructuredMatrixStyle's --- src/BlockArrays.jl | 12 +++-- src/abstractblockarray.jl | 1 + src/blockarrayinterface.jl | 4 ++ src/blockbroadcast.jl | 58 +++++++++++++++++++++ src/blocksizes.jl | 6 +++ src/deprecate.jl | 2 - test/runtests.jl | 2 + test/test_blockarrayinterface.jl | 17 ++++++ test/test_blockarrays.jl | 2 +- test/test_blockbroadcast.jl | 89 ++++++++++++++++++++++++++++++++ 10 files changed, 186 insertions(+), 7 deletions(-) create mode 100644 src/blockbroadcast.jl create mode 100644 test/test_blockbroadcast.jl diff --git a/src/BlockArrays.jl b/src/BlockArrays.jl index 337b5821..3f73b568 100644 --- a/src/BlockArrays.jl +++ b/src/BlockArrays.jl @@ -18,16 +18,18 @@ import Base: @propagate_inbounds, Array, to_indices, to_index, unsafe_indices, first, last, size, length, unsafe_length, unsafe_convert, getindex, show, - broadcast, eltype, convert, broadcast, + broadcast, eltype, convert, similar, @_inline_meta, _maybetail, tail, @_propagate_inbounds_meta, reindex, RangeIndex, Int, Integer, Number, - +, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate + +, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate, + BroadcastStyle import Base: (:), IteratorSize, iterate, axes1 -import Base.Broadcast: broadcasted, DefaultArrayStyle -import LinearAlgebra: lmul!, rmul!, AbstractTriangular, HermOrSym, AdjOrTrans +import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted +import LinearAlgebra: lmul!, rmul!, AbstractTriangular, HermOrSym, AdjOrTrans, + StructuredMatrixStyle include("abstractblockarray.jl") @@ -41,6 +43,8 @@ include("views.jl") include("blockindexrange.jl") include("show.jl") include("blockarrayinterface.jl") +include("blockbroadcast.jl") +# include("linalg.jl") include("deprecate.jl") diff --git a/src/abstractblockarray.jl b/src/abstractblockarray.jl index ff3d03c5..43963fcb 100644 --- a/src/abstractblockarray.jl +++ b/src/abstractblockarray.jl @@ -79,6 +79,7 @@ end Block{N, T}(n::Vararg{T, N}) where {N,T} = Block{N, T}(n) Block{N}(n::Vararg{T, N}) where {N,T} = Block{N, T}(n) +Block() = Block{0,Int}() Block(n::Vararg{T, N}) where {N,T} = Block{N, T}(n) Block{1}(n::Tuple{T}) where {T} = Block{1, T}(n) Block{N}(n::NTuple{N, T}) where {N,T} = Block{N, T}(n) diff --git a/src/blockarrayinterface.jl b/src/blockarrayinterface.jl index 9dc47457..67c269c1 100644 --- a/src/blockarrayinterface.jl +++ b/src/blockarrayinterface.jl @@ -1,4 +1,8 @@ +blocksizes(A) = BlockSizes(vcat.(size(A))...) +getindex(a::Number, ::Block{0}) = a + + blocksizes(A::AbstractTriangular) = blocksizes(parent(A)) diff --git a/src/blockbroadcast.jl b/src/blockbroadcast.jl new file mode 100644 index 00000000..11e282cb --- /dev/null +++ b/src/blockbroadcast.jl @@ -0,0 +1,58 @@ + +# Here we override broadcasting for banded matrices. +# The design is to to exploit the broadcast machinery so that +# banded matrices that conform to the banded matrix interface but are not +# <: AbstractBandedMatrix can get access to fast copyto!, lmul!, rmul!, axpy!, etc. +# using broadcast variants (B .= A, B .= 2.0 .* A, etc.) + + +abstract type AbstractBlockStyle{N} <: AbstractArrayStyle{N} end +struct BlockStyle{N} <: AbstractBlockStyle{N} end +struct PseudoBlockStyle{N} <: AbstractBlockStyle{N} end + + +BlockStyle(::Val{N}) where {N} = BlockStyle{N}() +PseudoBlockStyle(::Val{N}) where {N} = PseudoBlockStyle{N}() +BlockStyle{M}(::Val{N}) where {N,M} = BlockStyle{N}() +PseudoBlockStyle{M}(::Val{N}) where {N,M} = PseudoBlockStyle{N}() +BroadcastStyle(::Type{<:BlockArray{<:Any,N}}) where N = BlockStyle{N}() +BroadcastStyle(::Type{<:PseudoBlockArray{<:Any,N}}) where N = PseudoBlockStyle{N}() +BroadcastStyle(::DefaultArrayStyle{N}, b::AbstractBlockStyle{M}) where {M,N} = typeof(b)(Val(max(M,N))) +BroadcastStyle(a::AbstractBlockStyle{N}, ::DefaultArrayStyle{M}) where {M,N} = typeof(a)(Val(max(M,N))) +BroadcastStyle(::StructuredMatrixStyle, b::AbstractBlockStyle{M}) where {M} = typeof(b)(Val(max(M,2))) +BroadcastStyle(a::AbstractBlockStyle{M}, ::StructuredMatrixStyle) where {M} = typeof(a)(Val(max(M,2))) +BroadcastStyle(::BlockStyle{M}, ::PseudoBlockStyle{N}) where {M,N} = BlockStyle(Val(max(M,N))) +BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(Val(max(M,N))) + + +#### +# Default to standard Array broadcast +#### + + +# following code modified from julia/base/broadcast.jl +broadcast_cumulsizes(::Number) = () +broadcast_cumulsizes(A::AbstractArray) = cumulsizes(blocksizes(A)) +broadcast_cumulsizes(A::Broadcasted) = cumulsizes(blocksizes(A)) + +combine_cumulsizes(A) = A +combine_cumulsizes(A, B, C...) = combine_cumulsizes(_cms(A,B), C...) + +_cms(::Tuple{}, ::Tuple{}) = () +_cms(::Tuple{}, newshape::Tuple) = (newshape[1], _cms((), tail(newshape))...) +_cms(shape::Tuple, ::Tuple{}) = (shape[1], _cms(tail(shape), ())...) +_cms(shape::Tuple, newshape::Tuple) = (sort!(union(shape[1], newshape[1])), _cms(tail(shape), tail(newshape))...) + + +blocksizes(A::Broadcasted{<:AbstractArrayStyle{N}}) where N = + BlockSizes(combine_cumulsizes(broadcast_cumulsizes.(A.args)...)) + + +copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractBlockStyle}) = + copyto!(dest, Broadcasted{DefaultArrayStyle{2}}(bc.f, bc.args, bc.axes)) + +similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} = + BlockArray{T,N}(undef, blocksizes(bc)) + +similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} = + PseudoBlockArray{T,N}(undef, blocksizes(bc)) diff --git a/src/blocksizes.jl b/src/blocksizes.jl index 2c0a10ab..8d0ccb6a 100644 --- a/src/blocksizes.jl +++ b/src/blocksizes.jl @@ -8,8 +8,14 @@ abstract type AbstractBlockSizes{N} end struct BlockSizes{N} <: AbstractBlockSizes{N} cumul_sizes::NTuple{N, Vector{Int}} # Takes a tuple of sizes, accumulates them and create a `BlockSizes` + BlockSizes{N}() where N = new{N}() + BlockSizes{N}(cs::NTuple{N,Vector{Int}}) where N = new{N}(cs) end +BlockSizes() = BlockSizes{0}() + +BlockSizes(cs::NTuple{N,Vector{Int}}) where N = BlockSizes{N}(cs) + function BlockSizes(sizes::Vararg{Vector{Int}, N}) where {N} cumul_sizes = ntuple(k -> _cumul_vec(sizes[k]), Val(N)) return BlockSizes(cumul_sizes) diff --git a/src/deprecate.jl b/src/deprecate.jl index 9ab9ad12..e69de29b 100644 --- a/src/deprecate.jl +++ b/src/deprecate.jl @@ -1,2 +0,0 @@ -@deprecate getindex(block_sizes::BlockSizes, i) cumulsizes(block_sizes, i) -@deprecate getindex(block_sizes::BlockSizes, i, j) cumulsizes(block_sizes, i, j) diff --git a/test/runtests.jl b/test/runtests.jl index 17c31f7a..edf7cc7f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,4 +6,6 @@ include("test_blockviews.jl") include("test_blockrange.jl") include("test_blockarrayinterface.jl") +include("test_blockbroadcast.jl") + include("../docs/make.jl") diff --git a/test/test_blockarrayinterface.jl b/test/test_blockarrayinterface.jl index 34004897..f34f14cf 100644 --- a/test/test_blockarrayinterface.jl +++ b/test/test_blockarrayinterface.jl @@ -1,3 +1,5 @@ +using BlockArrays + struct PartiallyImplementedBlockVector <: AbstractBlockArray{Float64,1} end @testset "partially implemented block array" begin @@ -25,6 +27,21 @@ struct PartiallyImplementedBlockVector <: AbstractBlockArray{Float64,1} end end end +@testset "Array block interface" begin + @test blocksizes(1) == BlockArrays.BlockSizes{0}() + @test 1[Block()] == 1 + + A = randn(5) + @test blocksizes(A) == BlockArrays.BlockSizes([5]) + @test A[Block(1)] == A + view(A,Block(1))[1] = 2 + @test A[1] == 2 + @test_throws BoundsError A[Block(2)] + + A = randn(5,5) + @test A[Block(1,1)] == A +end + @testset "Triangular/Symmetric/Hermitian block arrays" begin A = PseudoBlockArray{ComplexF64}(undef, (1:4), (1:4)) A .= randn.() .+ randn.().*im diff --git a/test/test_blockarrays.jl b/test/test_blockarrays.jl index 92d5d2f5..0344240b 100644 --- a/test/test_blockarrays.jl +++ b/test/test_blockarrays.jl @@ -310,7 +310,7 @@ end @test strides(A) == (1, size(A,1)) x = randn(size(A,2)) y = similar(x) - @test BLAS.gemv!('N', 2.0, A, x, 0.0, y) == 2A*x + @test BLAS.gemv!('N', 2.0, A, x, 0.0, y) ≈ 2A*x end @testset "lmul!/rmul!" begin diff --git a/test/test_blockbroadcast.jl b/test/test_blockbroadcast.jl new file mode 100644 index 00000000..f656e210 --- /dev/null +++ b/test/test_blockbroadcast.jl @@ -0,0 +1,89 @@ +using BlockArrays, Test + +@testset "broadcast" begin + @testset "BlockArray" begin + A = BlockArray(randn(6), 1:3) + + @test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{1}() + + @test exp.(A) == exp.(Vector(A)) + @test blocksizes(A) == blocksizes(exp.(A)) + + @test A+A isa BlockArray + @test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A) + @test blocksizes(A .+ 1) == blocksizes(A) + + A = BlockArray(randn(6,6), 1:3,1:3) + + @test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{2}() + + @test exp.(A) == exp.(Matrix(A)) + @test blocksizes(A) == blocksizes(exp.(A)) + + + @test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A) + @test blocksizes(A .+ 1) == blocksizes(A) + end + + @testset "PseudoBlockArray" begin + A = PseudoBlockArray(randn(6), 1:3) + + @test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.PseudoBlockStyle{1}() + + + @test exp.(A) == exp.(Vector(A)) + @test blocksizes(A) == blocksizes(exp.(A)) + + @test A+A isa PseudoBlockArray + @test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A) + @test blocksizes(A .+ 1) == blocksizes(A) + + B = PseudoBlockArray(randn(6,6), 1:3,1:3) + + @test BlockArrays.BroadcastStyle(typeof(B)) == BlockArrays.PseudoBlockStyle{2}() + + @test exp.(B) == exp.(Matrix(B)) + @test blocksizes(B) == blocksizes(exp.(B)) + + @test blocksizes(B + B) == blocksizes(B .+ B) == blocksizes(B) + @test blocksizes(B .+ 1) == blocksizes(B) + @test blocksizes(A .+ 1 .+ B) == blocksizes(B) + @test A .+ 1 .+ B == Vector(A) .+ 1 .+ B == Vector(A) .+ 1 .+ Matrix(B) + end + + @testset "Mixed" begin + A = BlockArray(randn(6), 1:3) + B = PseudoBlockArray(randn(6), 1:3) + + @test A + B isa BlockArray + @test B + A isa BlockArray + + @test blocksizes(A + B) == blocksizes(A) + + C = randn(6) + + @test A + C isa BlockVector{Float64} + @test C + A isa BlockVector{Float64} + @test B + C isa PseudoBlockVector{Float64} + @test C + B isa PseudoBlockVector{Float64} + + @test blocksizes(A+C) == blocksizes(C+A) == blocksizes(A) + @test blocksizes(B+C) == blocksizes(C+B) == blocksizes(B) + + A = BlockArray(randn(6,6), 1:3, 1:3) + D = Diagonal(ones(6)) + @test blocksizes(A + D) == blocksizes(A) + @test blocksizes(B .+ D) == BlockArrays.BlockSizes([1,2,3],[6]) + end + + @testset "Mixed block sizes" begin + A = BlockArray(randn(6), 1:3) + B = BlockArray(randn(6), fill(2,3)) + @test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2]) + + A = BlockArray(randn(6,6), 1:3, 1:3) + B = BlockArray(randn(6,6), fill(2,3), fill(3,2)) + + @test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2], 1:3) + end +end