-
Notifications
You must be signed in to change notification settings - Fork 26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Preserve block structure in broadcasting #61
Conversation
Codecov Report
@@ Coverage Diff @@
## master #61 +/- ##
==========================================
+ Coverage 62.74% 63.69% +0.94%
==========================================
Files 10 11 +1
Lines 459 482 +23
==========================================
+ Hits 288 307 +19
- Misses 171 175 +4
Continue to review full report at Codecov.
|
src/blockbroadcast.jl
Outdated
#### | ||
|
||
|
||
union.(([1,2,3],[4,5,6]), ([1,2,3],[4,5,6])) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is it a remnant of experiments or something?
Are you planning to optimize the computation at some point? More precisely, what I have in mind is something like the following: using BlockArrays
using BlockArrays: cumulsizes
using FillArrays
function blocks_to_vector(blocks)
R = eltype(blocks)
T = eltype(R)
@assert R <: AbstractVector
ba = BlockArray{T, 1, R}(undef_blocks, size.(blocks, 1))
ba.blocks .= blocks
return ba
end
n = 100
x = blocks_to_vector([Fill(111.0, 4n), Fill(222.0, n)])
y = blocks_to_vector([Fill(333.0, 4n), Fill(444.0, n)])
z = randn(size(x, 1))
function naive(x, y, z)
@. z = x * y * z
return z
end
function blocked(x, y, z)
for (xb, yb, b, e) in zip(x.blocks, y.blocks,
cumulsizes(x, 1)[1:end-1],
cumulsizes(x, 1)[2:end])
zb = @view z[b:e-1]
@. zb = xb * yb * zb
end
return z
end
@assert naive(x, y, copy(z)) == blocked(x, y, copy(z))
using BenchmarkTools
using Statistics
(b1 = @benchmark naive($x, $y, z) setup=(z=randn(size(x, 1)))) |> display
(b2 = @benchmark blocked($x, $y, z) setup=(z=randn(size(x, 1)))) |> display
judge(median(b2), median(b1)) I want Just FYI, the output of the script above is:
|
Yes optimising is next. |
@tkf In dl/fastbroadcast branch I've improved the speed (at least for block vectors/matrices). Unfortunately it's not quite working when the arguments have different block sizes. |
Let me know if you have any more comments, otherwise I'll merge |
test/test_blockarrayinterface.jl
Outdated
|
||
A = randn(5) | ||
@test blocksizes(A) == BlockArrays.BlockSizes([5]) | ||
A[Block(1)] == A |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing @test
?
src/blockbroadcast.jl
Outdated
PseudoBlockStyle{M}(::Val{N}) where {N,M} = PseudoBlockStyle{N}() | ||
BroadcastStyle(::Type{<:BlockArray{<:Any,N}}) where N = BlockStyle{N}() | ||
BroadcastStyle(::Type{<:PseudoBlockArray{<:Any,N}}) where N = PseudoBlockStyle{N}() | ||
BroadcastStyle(::DefaultArrayStyle{N}, b::AbstractBlockStyle{M}) where {M,N} = typeof(b)(_max(Val(M),Val(N))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you planning to "vendor" _max
at some point? Otherwise it doesn't run on Julia master.
I still haven't learned BlockArrays internal so please don't wait for me to do any non-nit-picking review :)
Does "different block sizes" include a case like |
In this PR that works: julia> BlockArray(randn(6), 1:3) .+ ones(6)
6-element BlockArray{Float64,1,Array{Float64,1}}:
1.630468442532147
───────────────────
2.1389068389040777
1.0502081079839254
───────────────────
1.4265672357406085
0.3687548937045446
-0.1642694693789739 In the other branch dl/fastbroadcast the issue is that we want to use blocks for arrays with compatible block sizes as the destination, and regular indexing for other arrays. I think |
The failure on nightly is now an unrelated change in |
This preserves the block structure under broadcasting.
One missing case is
BlockMatrix .+ BlockVector
.This is step 1 in resolving #31