Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve block structure in broadcasting #61

Merged
merged 8 commits into from
Nov 29, 2018
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ import Base: @propagate_inbounds, Array, to_indices, to_index,
unsafe_indices, first, last, size, length, unsafe_length,
unsafe_convert,
getindex, show,
broadcast, eltype, convert, broadcast,
broadcast, eltype, convert, similar,
@_inline_meta, _maybetail, tail, @_propagate_inbounds_meta, reindex,
RangeIndex, Int, Integer, Number,
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate,
BroadcastStyle



import Base: (:), IteratorSize, iterate, axes1
import Base.Broadcast: broadcasted, DefaultArrayStyle
import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted, _max
import LinearAlgebra: lmul!, rmul!, AbstractTriangular, HermOrSym, AdjOrTrans


Expand All @@ -41,6 +42,8 @@ include("views.jl")
include("blockindexrange.jl")
include("show.jl")
include("blockarrayinterface.jl")
include("blockbroadcast.jl")
# include("linalg.jl")

include("deprecate.jl")

Expand Down
1 change: 1 addition & 0 deletions src/abstractblockarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ end

Block{N, T}(n::Vararg{T, N}) where {N,T} = Block{N, T}(n)
Block{N}(n::Vararg{T, N}) where {N,T} = Block{N, T}(n)
Block() = Block{0,Int}()
Block(n::Vararg{T, N}) where {N,T} = Block{N, T}(n)
Block{1}(n::Tuple{T}) where {T} = Block{1, T}(n)
Block{N}(n::NTuple{N, T}) where {N,T} = Block{N, T}(n)
Expand Down
4 changes: 4 additions & 0 deletions src/blockarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@

blocksizes(A) = BlockSizes(vcat.(size(A))...)
getindex(a::Number, ::Block{0}) = a


blocksizes(A::AbstractTriangular) = blocksizes(parent(A))


Expand Down
56 changes: 56 additions & 0 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

# Here we override broadcasting for banded matrices.
# The design is to to exploit the broadcast machinery so that
# banded matrices that conform to the banded matrix interface but are not
# <: AbstractBandedMatrix can get access to fast copyto!, lmul!, rmul!, axpy!, etc.
# using broadcast variants (B .= A, B .= 2.0 .* A, etc.)


abstract type AbstractBlockStyle{N} <: AbstractArrayStyle{N} end
struct BlockStyle{N} <: AbstractBlockStyle{N} end
struct PseudoBlockStyle{N} <: AbstractBlockStyle{N} end


BlockStyle(::Val{N}) where {N} = BlockStyle{N}()
PseudoBlockStyle(::Val{N}) where {N} = PseudoBlockStyle{N}()
BlockStyle{M}(::Val{N}) where {N,M} = BlockStyle{N}()
PseudoBlockStyle{M}(::Val{N}) where {N,M} = PseudoBlockStyle{N}()
BroadcastStyle(::Type{<:BlockArray{<:Any,N}}) where N = BlockStyle{N}()
BroadcastStyle(::Type{<:PseudoBlockArray{<:Any,N}}) where N = PseudoBlockStyle{N}()
BroadcastStyle(::DefaultArrayStyle{N}, b::AbstractBlockStyle{M}) where {M,N} = typeof(b)(_max(Val(M),Val(N)))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you planning to "vendor" _max at some point? Otherwise it doesn't run on Julia master.

BroadcastStyle(a::AbstractBlockStyle{N}, ::DefaultArrayStyle{M}) where {M,N} = typeof(a)(_max(Val(M),Val(N)))
BroadcastStyle(::BlockStyle{M}, ::PseudoBlockStyle{N}) where {M,N} = BlockStyle(_max(Val(M),Val(N)))
BroadcastStyle(::PseudoBlockStyle{M}, ::BlockStyle{N}) where {M,N} = BlockStyle(_max(Val(M),Val(N)))


####
# Default to standard Array broadcast
####


# following code modified from julia/base/broadcast.jl
broadcast_cumulsizes(::Number) = ()
broadcast_cumulsizes(A::AbstractArray) = cumulsizes(blocksizes(A))
broadcast_cumulsizes(A::Broadcasted) = cumulsizes(blocksizes(A))

combine_cumulsizes(A) = A
combine_cumulsizes(A, B, C...) = combine_cumulsizes(_cms(A,B), C...)

_cms(::Tuple{}, ::Tuple{}) = ()
_cms(::Tuple{}, newshape::Tuple) = (newshape[1], _cms((), tail(newshape))...)
_cms(shape::Tuple, ::Tuple{}) = (shape[1], _cms(tail(shape), ())...)
_cms(shape::Tuple, newshape::Tuple) = (sort!(union(shape[1], newshape[1])), _cms(tail(shape), tail(newshape))...)


blocksizes(A::Broadcasted{<:AbstractArrayStyle{N}}) where N =
BlockSizes(combine_cumulsizes(broadcast_cumulsizes.(A.args)...))


copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractBlockStyle}) =
copyto!(dest, Broadcasted{DefaultArrayStyle{2}}(bc.f, bc.args, bc.axes))

similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} =
BlockArray{T,N}(undef, blocksizes(bc))

similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} =
PseudoBlockArray{T,N}(undef, blocksizes(bc))
6 changes: 6 additions & 0 deletions src/blocksizes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,14 @@ abstract type AbstractBlockSizes{N} end
struct BlockSizes{N} <: AbstractBlockSizes{N}
cumul_sizes::NTuple{N, Vector{Int}}
# Takes a tuple of sizes, accumulates them and create a `BlockSizes`
BlockSizes{N}() where N = new{N}()
BlockSizes{N}(cs::NTuple{N,Vector{Int}}) where N = new{N}(cs)
end

BlockSizes() = BlockSizes{0}()

BlockSizes(cs::NTuple{N,Vector{Int}}) where N = BlockSizes{N}(cs)

function BlockSizes(sizes::Vararg{Vector{Int}, N}) where {N}
cumul_sizes = ntuple(k -> _cumul_vec(sizes[k]), Val(N))
return BlockSizes(cumul_sizes)
Expand Down
2 changes: 0 additions & 2 deletions src/deprecate.jl
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
@deprecate getindex(block_sizes::BlockSizes, i) cumulsizes(block_sizes, i)
@deprecate getindex(block_sizes::BlockSizes, i, j) cumulsizes(block_sizes, i, j)
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ include("test_blockviews.jl")
include("test_blockrange.jl")
include("test_blockarrayinterface.jl")

include("test_broadcast.jl")

include("../docs/make.jl")
17 changes: 17 additions & 0 deletions test/test_blockarrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using BlockArrays

struct PartiallyImplementedBlockVector <: AbstractBlockArray{Float64,1} end

@testset "partially implemented block array" begin
Expand Down Expand Up @@ -25,6 +27,21 @@ struct PartiallyImplementedBlockVector <: AbstractBlockArray{Float64,1} end
end
end

@testset "Array block interface" begin
@test blocksizes(1) == BlockArrays.BlockSizes{0}()
@test 1[Block()] == 1

A = randn(5)
@test blocksizes(A) == BlockArrays.BlockSizes([5])
A[Block(1)] == A
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing @test?

view(A,Block(1))[1] = 2
@test A[1] == 2
@test_throws BoundsError A[Block(2)]

A = randn(5,5)
@test A[Block(1,1)] == A
end

@testset "Triangular/Symmetric/Hermitian block arrays" begin
A = PseudoBlockArray{ComplexF64}(undef, (1:4), (1:4))
A .= randn.() .+ randn.().*im
Expand Down
2 changes: 1 addition & 1 deletion test/test_blockarrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ end
@test strides(A) == (1, size(A,1))
x = randn(size(A,2))
y = similar(x)
@test BLAS.gemv!('N', 2.0, A, x, 0.0, y) == 2A*x
@test BLAS.gemv!('N', 2.0, A, x, 0.0, y) 2A*x
end

@testset "lmul!/rmul!" begin
Expand Down
87 changes: 87 additions & 0 deletions test/test_broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using BlockArrays, Test

@testset "broadcast" begin
@testset "BlockArray" begin
A = BlockArray(randn(6), 1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{1}()

@test exp.(A) == exp.(Vector(A))
@test blocksizes(A) == blocksizes(exp.(A))

@test A+A isa BlockArray
@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)

A = BlockArray(randn(6,6), 1:3,1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{2}()

@test exp.(A) == exp.(Matrix(A))
@test blocksizes(A) == blocksizes(exp.(A))


@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)
end

@testset "PseudoBlockArray" begin
A = PseudoBlockArray(randn(6), 1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.PseudoBlockStyle{1}()


@test exp.(A) == exp.(Vector(A))
@test blocksizes(A) == blocksizes(exp.(A))

@test A+A isa PseudoBlockArray
@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)

B = PseudoBlockArray(randn(6,6), 1:3,1:3)

@test BlockArrays.BroadcastStyle(typeof(B)) == BlockArrays.PseudoBlockStyle{2}()

@test exp.(B) == exp.(Matrix(B))
@test blocksizes(B) == blocksizes(exp.(B))

@test blocksizes(B + B) == blocksizes(B .+ B) == blocksizes(B)
@test blocksizes(B .+ 1) == blocksizes(B)
@test blocksizes(A .+ 1 .+ B) == blocksizes(B)
@test A .+ 1 .+ B == Vector(A) .+ 1 .+ B == Vector(A) .+ 1 .+ Matrix(B)
end

@testset "Mixed" begin
A = BlockArray(randn(6), 1:3)
B = PseudoBlockArray(randn(6), 1:3)

@test A + B isa BlockArray
@test B + A isa BlockArray

@test blocksizes(A + B) == blocksizes(A)

C = randn(6)

A .+ C


@test A + C isa BlockVector{Float64}
@test C + A isa BlockVector{Float64}
@test B + C isa PseudoBlockVector{Float64}
@test C + B isa PseudoBlockVector{Float64}

blocksizes(A+C) == blocksizes(C+A) == blocksizes(A)
blocksizes(B+C) == blocksizes(C+B) == blocksizes(B)
end

@testset "Mixed block sizes" begin
A = BlockArray(randn(6), 1:3)
B = BlockArray(randn(6), fill(2,3))
@test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2])

A = BlockArray(randn(6,6), 1:3, 1:3)
B = BlockArray(randn(6,6), fill(2,3), fill(3,2))

@test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2], 1:3)
end
end