Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preserve block structure in broadcasting #61

Merged
merged 8 commits into from
Nov 29, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions src/BlockArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,16 @@ import Base: @propagate_inbounds, Array, to_indices, to_index,
unsafe_indices, first, last, size, length, unsafe_length,
unsafe_convert,
getindex, show,
broadcast, eltype, convert, broadcast,
broadcast, eltype, convert, similar,
@_inline_meta, _maybetail, tail, @_propagate_inbounds_meta, reindex,
RangeIndex, Int, Integer, Number,
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate
+, -, min, max, *, isless, in, copy, copyto!, axes, @deprecate,
BroadcastStyle



import Base: (:), IteratorSize, iterate, axes1
import Base.Broadcast: broadcasted, DefaultArrayStyle
import Base.Broadcast: broadcasted, DefaultArrayStyle, AbstractArrayStyle, Broadcasted
import LinearAlgebra: lmul!, rmul!, AbstractTriangular, HermOrSym, AdjOrTrans


Expand All @@ -41,6 +42,8 @@ include("views.jl")
include("blockindexrange.jl")
include("show.jl")
include("blockarrayinterface.jl")
include("blockbroadcast.jl")
# include("linalg.jl")

include("deprecate.jl")

Expand Down
48 changes: 48 additions & 0 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@

# Here we override broadcasting for banded matrices.
# The design is to to exploit the broadcast machinery so that
# banded matrices that conform to the banded matrix interface but are not
# <: AbstractBandedMatrix can get access to fast copyto!, lmul!, rmul!, axpy!, etc.
# using broadcast variants (B .= A, B .= 2.0 .* A, etc.)


abstract type AbstractBlockStyle{N} <: AbstractArrayStyle{N} end
struct BlockStyle{N} <: AbstractBlockStyle{N} end
struct PseudoBlockStyle{N} <: AbstractBlockStyle{N} end

BlockStyle{N}(::Val{N}) where N = BlockStyle{N}()
PseudoBlockStyle{N}(::Val{N}) where N = PseudoBlockStyle{N}()
BroadcastStyle(::Type{<:BlockArray{<:Any,N}}) where N = BlockStyle{N}()
BroadcastStyle(::Type{<:PseudoBlockArray{<:Any,N}}) where N = PseudoBlockStyle{N}()
BroadcastStyle(::DefaultArrayStyle{N}, ::AbstractBlockStyle{N}) where N = DefaultArrayStyle{N}()
BroadcastStyle(::AbstractBlockStyle{N}, ::DefaultArrayStyle{N}) where N = DefaultArrayStyle{N}()
BroadcastStyle(::BlockStyle{N}, ::PseudoBlockStyle{N}) where N = BlockStyle{N}()
BroadcastStyle(::PseudoBlockStyle{N}, ::BlockStyle{N}) where N = BlockStyle{N}()



####
# Default to standard Array broadcast
####


union.(([1,2,3],[4,5,6]), ([1,2,3],[4,5,6]))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it a remnant of experiments or something?


_broadcast_blocksizes(::Val{N}, ::AbstractArrayStyle{0}, _) where N =
BlockSizes(ntuple(_ -> Int[], N))
_broadcast_blocksizes(::Val{N}, ::AbstractArrayStyle{N}, A) where N =
blocksizes(A)
_broadcast_blocksizes(::Val{N}, A) where N =
_broadcast_blocksizes(Val{N}(), BroadcastStyle(typeof(A)), A)

blocksizes(A::Broadcasted{<:AbstractArrayStyle{N}}) where N =
BlockSizes(sort!.(union.(cumulsizes.(_broadcast_blocksizes.(Val{N}(), A.args))...)))

copyto!(dest::AbstractArray, bc::Broadcasted{<:AbstractBlockStyle}) =
copyto!(dest, Broadcasted{DefaultArrayStyle{2}}(bc.f, bc.args, bc.axes))

similar(bc::Broadcasted{<:AbstractBlockStyle{N}}, ::Type{T}) where {T,N} =
BlockArray{T,N}(undef, blocksizes(bc))

similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} =
PseudoBlockArray{T,N}(undef, blocksizes(bc))
2 changes: 0 additions & 2 deletions src/deprecate.jl
Original file line number Diff line number Diff line change
@@ -1,2 +0,0 @@
@deprecate getindex(block_sizes::BlockSizes, i) cumulsizes(block_sizes, i)
@deprecate getindex(block_sizes::BlockSizes, i, j) cumulsizes(block_sizes, i, j)
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ include("test_blockviews.jl")
include("test_blockrange.jl")
include("test_blockarrayinterface.jl")

include("test_broadcast.jl")

include("../docs/make.jl")
81 changes: 81 additions & 0 deletions test/test_broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
using BlockArrays, Test

@testset "broadcast" begin
@testset "BlockArray" begin
A = BlockArray(randn(6), 1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{1}()

@test exp.(A) == exp.(Vector(A))
@test blocksizes(A) == blocksizes(exp.(A))

@test A+A isa BlockArray
@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)

A = BlockArray(randn(6,6), 1:3,1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.BlockStyle{2}()

@test exp.(A) == exp.(Matrix(A))
@test blocksizes(A) == blocksizes(exp.(A))


@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)
end

@testset "PseudoBlockArray" begin
A = PseudoBlockArray(randn(6), 1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.PseudoBlockStyle{1}()

@test exp.(A) == exp.(Vector(A))
@test blocksizes(A) == blocksizes(exp.(A))

@test A+A isa PseudoBlockArray
@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)

A = PseudoBlockArray(randn(6,6), 1:3,1:3)

@test BlockArrays.BroadcastStyle(typeof(A)) == BlockArrays.PseudoBlockStyle{2}()

@test exp.(A) == exp.(Matrix(A))
@test blocksizes(A) == blocksizes(exp.(A))


@test blocksizes(A + A) == blocksizes(A .+ A) == blocksizes(A)
@test blocksizes(A .+ 1) == blocksizes(A)
end

@testset "Mixed" begin
A = BlockArray(randn(6), 1:3)
B = PseudoBlockArray(randn(6), 1:3)

@test A + B isa BlockArray
@test B + A isa BlockArray

@test blocksizes(A + B) == blocksizes(A)

C = randn(6)

@test A + C isa Vector{Float64}
@test C + A isa Vector{Float64}
@test B + C isa Vector{Float64}
@test C + B isa Vector{Float64}
end


@testset "Mixed block sizes" begin
A = BlockArray(randn(6), 1:3)
B = BlockArray(randn(6), fill(2,3))
@test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2])

A = BlockArray(randn(6,6), 1:3, 1:3)
B = BlockArray(randn(6,6), fill(2,3), fill(3,2))

@test blocksizes(A+B) == BlockArrays.BlockSizes([1,1,1,1,2], 1:3)
end

end