From 31f18fe27215880a08742ac452f0344b61180163 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Thu, 16 Jan 2020 02:10:33 -0800 Subject: [PATCH] Support reduce with Iterators.product (#156) * Support reduce with Iterators.product * Test reduce with Iterators.product only in Julia >= 1.3 --- examples/tutorial_parallel.jl | 7 +++++++ src/reduce.jl | 13 +++++++++++++ test/test_parallel_reduce.jl | 12 ++++++++++++ 3 files changed, 32 insertions(+) diff --git a/examples/tutorial_parallel.jl b/examples/tutorial_parallel.jl index 26468abef6..c66c9b724d 100644 --- a/examples/tutorial_parallel.jl +++ b/examples/tutorial_parallel.jl @@ -38,6 +38,13 @@ dreduce(+, Map(sin), xs) reduce(+, eduction(sin(x) for x in xs if abs(x) < 1); basesize = 500_000) +#- + +if VERSION >= v"1.3" #src +@test 36 == #src +reduce(+, eduction(x * y for x in 1:3, y in 1:3)) +end #src + # You can omit `eduction` when using Transducers.jl-specific functions # like [`tcollect`](@ref)/[`dcollect`](@ref): diff --git a/src/reduce.jl b/src/reduce.jl index cd800df452..1044ba5c49 100644 --- a/src/reduce.jl +++ b/src/reduce.jl @@ -82,6 +82,19 @@ function halve(arr::AbstractArray) return (left, right) end +function halve(product::Iterators.ProductIterator) + i = findfirst(x -> length(x) > 1, product.iterators) + if i === nothing + error( + "Unreachable reached. A bug in `issmall`?", + " length(product) = ", + length(product), + ) + end + left, right = halve(product.iterators[i]) + return (@set(product.iterators[i] = left), @set(product.iterators[i] = right)) +end + struct TaskContext listening::Vector{Threads.Atomic{Bool}} cancellables::Vector{Threads.Atomic{Bool}} diff --git a/test/test_parallel_reduce.jl b/test/test_parallel_reduce.jl index ab9887f487..72664ef69d 100644 --- a/test/test_parallel_reduce.jl +++ b/test/test_parallel_reduce.jl @@ -79,6 +79,18 @@ end ) == StructVector(a = 1:3) end +@testset "product" begin + if VERSION >= v"1.3" + @test reduce(+, MapSplat(*), Iterators.product(1:3, 1:3); basesize = 1) == 36 + @test reduce(+, eduction(x * y for x in 1:3, y in 1:3); basesize = 1) == 36 + end + + @test_throws( + ErrorException("Unreachable reached. A bug in `issmall`? length(product) = 0"), + Transducers.halve(Iterators.product((), ())) + ) +end + @testset "withprogress" begin xf = Map() do x x