Skip to content

Commit

Permalink
Do not call complete when reduced
Browse files Browse the repository at this point in the history
Previously (<= v0.2.1), Transducers.jl was using

    val isa Reduced && return reduced(complete(rf, unreduced(val)))

everywhere.  Or more precisely an equivalent macro:

    @​return_if_reduced complete(rf, val)

This actually was a bad idea because `complete` can call `next` hence
the reducing function at the "bottom."  This violates the purpose of
`Reduced`; the reducing function must not be called after returning a
`Reduced`.  So, the correct way to do this is simply

    val isa Reduced && return val

(I was initially doing this but it was changed to the v0.2.1 form
apparently to "Improve PartitionBy" 12dd581.)

Fixing this requires some related changes:

* Previously, the private state of the transducers are `unwrap`'ed
  during the `complete` phase.  However, now that `complete` will not
  be called always _by the transducible processes_ (e.g., `foldl`),
  the `unwrap`ping has to be done during the `next` call chain where
  the `Reduced` is created.

  This is treated by changing `wrap(rf, state, iresult::Reduced)`.

* Now that the transducible processes are not responsible for calling
  `complete`, aborting transducers must initiate `complete` for inner
  reducing functions (e.g., in `Take(4) |> TakeLast(2)`, transducer
  `Take(4)` must call `complete` for `TakeLast(2)` to flush the buffer
  and invoke the downstream reducing functions).

  This change is introduced to `Take`, `TakeWhile`, and `Inject`.
  • Loading branch information
tkf committed Jun 23, 2019
1 parent 03ce32d commit 74f8961
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 42 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Transducers"
uuid = "28d57a85-8fef-5791-bfe6-a80928e7c999"
authors = ["Takafumi Arakaki <aka.tkf@gmail.com>"]
version = "0.2.2-DEV"
version = "0.3.0-DEV"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
2 changes: 1 addition & 1 deletion examples/reducibles.jl
Expand Up @@ -23,7 +23,7 @@ function Transducers.__foldl__(rf, val, vov::VecOfVec)
for vector in vov.vectors
for x in vector
val = next(rf, val, x)
@return_if_reduced complete(rf, val)
@return_if_reduced val
end
end
return complete(rf, val)
Expand Down
51 changes: 31 additions & 20 deletions src/core.jl
Expand Up @@ -104,36 +104,37 @@ unreduced(x::Reduced) = x.value
unreduced(x) = x

"""
@return_if_reduced complete(rf, val)
@return_if_reduced val
It transforms the given expression to:
```julia
val isa Reduced && return reduced(complete(rf, unreduced(val)))
val isa Reduced && return val
```
That is to say, if `val` is `Reduced`, unpack it, call `complete`,
re-pack into `Reduced`, and then finally return it.
# Examples
```jldoctest:
julia> using Transducers: @return_if_reduced
julia> @macroexpand @return_if_reduced complete(rf, val)
:(val isa Transducers.Reduced && return (Transducers.reduced)(complete(rf, (Transducers.unreduced)(val))))
julia> @macroexpand @return_if_reduced val
:(val isa Transducers.Reduced && return val)
```
"""
macro return_if_reduced(ex)
if !(ex.head == :call && length(ex.args) == 3)
error(
"`@return_if_reduced` only accepts an expression of the form",
" `complete(rf, val)`.",
" Given:\n",
ex,
)
if ex isa Expr && ex.head == :call && length(ex.args) == 3
val = esc(ex)
return quote
if $(esc(ex.args[1])) === complete
error("""
Calling `@return_if_reduced complete(rf, val)` is now an error.
Please use `@return_if_reduced val`.
""")
end
$val isa Reduced && return $val
end
end
complete, rf, val = esc.(ex.args)
:($val isa Reduced && return reduced($complete($rf, unreduced($val))))
val = esc(ex)
:($val isa Reduced && return $val)
end

abstract type Transducer end
Expand Down Expand Up @@ -427,6 +428,13 @@ to the outer reducing function.
This is intended to be used only in [`start`](@ref). Inside
[`next`](@ref), use [`wrapping`](@ref).
!!! note "Implementation detail"
If `iresult` is a [`Reduced`](@ref), `wrap` actually _un_wraps all
internal state `iresult` recursively. However, this is an
implementation detail that should not matter when writing
transducers.
Consider a reducing step constructed as
rf = Reduction(xf₁ |> xf₂ |> xf₃, f, intype)
Expand Down Expand Up @@ -455,14 +463,17 @@ result₀ = wrap(rf, state₁, result₁)
The inner most step function receives the original `result` as the
first argument while transducible processes such as [`mapfoldl`](@ref)
only sees the outer-most "tree" `result₀` during the reduction. The
whole tree is [`unwrap`](@ref)ed during the [`complete`](@ref) phase.
only sees the outer-most "tree" `result₀` during the reduction.
See [`wrapping`](@ref), [`unwrap`](@ref), and [`start`](@ref).
"""
wrap(rf::T, state, iresult) where {T} = PrivateState(rf, state, iresult)
wrap(rf, state, iresult::Reduced) =
Reduced(PrivateState(rf, state, unreduced(iresult)))
wrap(rf, state, iresult::Reduced) = unwrap_all(iresult) :: Reduced
#
# Note: `unwrap_all` is required since any transducer in arbitrary
# location of the `Reduction` chain can create a `Reduced`.
#
# But `unwrap_all`ing in `wrap` sounds counter intuitive. Maybe rename?

"""
wrapping(f, rf, result)
Expand Down
18 changes: 9 additions & 9 deletions src/library.jl
Expand Up @@ -390,7 +390,7 @@ next(rf::R_{Take}, result, input) =
n -= 1
end
if n <= 0
iresult = reduced(iresult)
iresult = reduced(complete(inner(rf), iresult))
end
return n, iresult
end
Expand Down Expand Up @@ -447,16 +447,16 @@ function complete(rf::R_{TakeLast}, result)
if c <= 0 # buffer is not full (or c is just wrapping)
for i in 1:(c + length(buffer))
iresult = next(inner(rf), iresult, @inbounds buffer[i])
@return_if_reduced complete(inner(rf), iresult)
@return_if_reduced iresult
end
else
for i in c+1:length(buffer)
iresult = next(inner(rf), iresult, @inbounds buffer[i])
@return_if_reduced complete(inner(rf), iresult)
@return_if_reduced iresult
end
for i in 1:c
iresult = next(inner(rf), iresult, @inbounds buffer[i])
@return_if_reduced complete(inner(rf), iresult)
@return_if_reduced iresult
end
end
return complete(inner(rf), iresult)
Expand Down Expand Up @@ -490,7 +490,7 @@ next(rf::R_{TakeWhile}, result, input) =
if xform(rf).pred(input)
next(inner(rf), result, input)
else
reduced(result)
reduced(complete(inner(rf), result))
end

# https://clojure.github.io/clojure/clojure.core-api.html#clojure.core/take-nth
Expand Down Expand Up @@ -830,7 +830,7 @@ function complete(rf::R_{Partition}, result)
iinput = @view iinput[s + 1:end]
iinput :: DenseSubVector
iresult = next(inner(rf), iresult, iinput)
@return_if_reduced complete(inner(rf), iresult)
@return_if_reduced iresult
end
return complete(inner(rf), iresult)
end
Expand Down Expand Up @@ -893,7 +893,7 @@ function complete(rf::R_{PartitionBy}, ps)
(iinput, _), iresult = unwrap(rf, ps)
if !isempty(iinput)
iresult = next(inner(rf), iresult, iinput)
@return_if_reduced complete(inner(rf), iresult)
@return_if_reduced iresult
end
return complete(inner(rf), iresult)
end
Expand Down Expand Up @@ -1215,7 +1215,7 @@ function complete(rf::R_{ScanEmit}, result)
u, iresult = unwrap(rf, result)
if xform(rf).onlast !== nothing
iresult = next(inner(rf), iresult, xform(rf).onlast(u))
@return_if_reduced complete(inner(rf), iresult)
@return_if_reduced iresult
end
return complete(inner(rf), iresult)
end
Expand Down Expand Up @@ -1742,7 +1742,7 @@ start(rf::R_{Inject}, result) =
wrap(rf, iterate(xform(rf).iterator), start(inner(rf), result))
next(rf::R_{Inject}, result, input) =
wrapping(rf, result) do istate, iresult
istate === nothing && return istate, reduced(iresult)
istate === nothing && return istate, reduced(complete(inner(rf), iresult))
y, s = istate
iresult2 = next(inner(rf), iresult, (input, y))
return iterate(xform(rf).iterator, s), iresult2
Expand Down
18 changes: 9 additions & 9 deletions src/processes.jl
Expand Up @@ -17,7 +17,7 @@ For a simple iterable type `MyType`, a valid implementation is:
function __foldl__(rf, val, itr::MyType)
for x in itr
val = next(rf, val, x)
@return_if_reduced complete(rf, val)
@return_if_reduced val
end
return complete(rf, val)
end
Expand All @@ -43,11 +43,11 @@ function __foldl__(rf, init, coll)
# optimization to cover a good amount of cases anyway.
x, state = ret
val = next(rf, init, x)
@return_if_reduced complete(rf, val)
@return_if_reduced val
while (ret = iterate(coll, state)) !== nothing
x, state = ret
val = next(rf, val, x)
@return_if_reduced complete(rf, val)
@return_if_reduced val
end
return complete(rf, val)
end
Expand All @@ -57,11 +57,11 @@ end
isempty(arr) && return complete(rf, init)
idxs = eachindex(arr)
val = next(rf, init, @inbounds arr[idxs[firstindex(idxs)]])
@return_if_reduced complete(rf, val)
@return_if_reduced val
@simd_if rf for k in firstindex(idxs) + 1:lastindex(idxs)
i = @inbounds idxs[k]
val = next(rf, val, @inbounds arr[i])
@return_if_reduced complete(rf, val)
@return_if_reduced val
end
return complete(rf, val)
end
Expand All @@ -77,10 +77,10 @@ end
isempty(zs) && return complete(rf, init)
idxs = eachindex(zs.is...)
val = next(rf, init, _getvalues(firstindex(idxs), zs.is...))
@return_if_reduced complete(rf, val)
@return_if_reduced val
@simd_if rf for i in firstindex(idxs) + 1:lastindex(idxs)
val = next(rf, val, _getvalues(i, zs.is...))
@return_if_reduced complete(rf, val)
@return_if_reduced val
end
return complete(rf, val)
end
Expand Down Expand Up @@ -112,7 +112,7 @@ end
# inner-most non-tuple iterators should use @simd_if.
@simd_if rf for input in iterator
val_ = next(rf, val, (input, outer...))
@return_if_reduced complete(rf, val_)
@return_if_reduced val_
val = val_
end
return val
Expand All @@ -121,7 +121,7 @@ end
function __simple_foldl__(rf, val, itr)
for x in itr
val = next(rf, val, x)
@return_if_reduced complete(rf, val)
@return_if_reduced val
end
return complete(rf, val)
end
Expand Down
4 changes: 2 additions & 2 deletions src/show.jl
Expand Up @@ -10,11 +10,11 @@ function __foldl__(rf, val, xff::TransducerFolder)
xf = _normalize(xff.xform)
while xf isa Composition
val = next(rf, val, xf.outer)
@return_if_reduced complete(rf, val)
@return_if_reduced val
xf = xf.inner
end
val = next(rf, val, xf)
@return_if_reduced complete(rf, val)
@return_if_reduced val
return complete(rf, val)
end

Expand Down
43 changes: 43 additions & 0 deletions test/test_library.jl
Expand Up @@ -87,6 +87,39 @@ end
@test eltype(eduction(
ScanEmit(tuple, Initializer(_ -> rand(Int))), Int[])) === Int
end

@testset "Do not call `complete` when reduced" begin
xs = 1:8
xf = ScanEmit(Initializer(_ -> []), identity) do u, x
push!(u, x)
if x % 3 == 0
return u, []
else
return nothing, u
end
end |> NotA(Nothing)

@testset "foreach" begin
called_with = []
@test foreach(xf, xs) do chunk
push!(called_with, copy(chunk))
5 chunk && reduced(true)
end == true
@test called_with == [1:3, 4:6]
end

@testset "foreach" begin
called_with = []
history = []
@test foldl(xf, xs; init=false) do state, chunk
push!(history, state)
push!(called_with, copy(chunk))
5 chunk && reduced(true)
end == true
@test called_with == [1:3, 4:6]
@test history == [false, false]
end
end
end

@testset "TeeZip" begin
Expand Down Expand Up @@ -198,6 +231,12 @@ end
@testset for xs in iterator_variants(1:5)
@test collect(TakeWhile(x -> x < 3), xs) == 1:2
end
@testset "Combination with stateful transducers" begin
@testset for xs in iterator_variants(1:5)
@test collect(TakeWhile(x -> x 4) |> TakeLast(2), xs) == 3:4
@test collect(TakeLast(4) |> TakeWhile(x -> x 3), xs) == 2:3
end
end
end

@testset "TakeNth" begin
Expand Down Expand Up @@ -419,6 +458,10 @@ end
@testset for xs in iterator_variants(1:3)
@test collect(Inject(xs) |> Take(2), xs) == collect(zip(1:2, 1:2))
@test collect(Take(2) |> Inject(xs), xs) == collect(zip(1:2, 1:2))
@test collect(Inject(xs) |> TakeLast(2), xs) == collect(zip(2:3, 2:3))
@test collect(TakeLast(2) |> Inject(xs), xs) == collect(zip(2:3, 1:2))
@test collect(Inject(1:1) |> TakeLast(2), xs) == collect(zip(1:1, 1:1))
@test collect(TakeLast(2) |> Inject(1:1), xs) == collect(zip(2:2, 1:1))
end
end
end
Expand Down

0 comments on commit 74f8961

Please sign in to comment.