Skip to content

Commit

Permalink
Promote returned values to a common type in replace()
Browse files Browse the repository at this point in the history
  • Loading branch information
nalimilan committed Oct 16, 2017
1 parent 2060fe5 commit 4cc05fa
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 4 deletions.
20 changes: 16 additions & 4 deletions src/Nulls.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ xor(::Integer, ::Null) = null
Return an iterator wrapping iterable `itr` which replaces [`null`](@ref) values with
`replacement`. When applicable, the size of `itr` is preserved.
If the type of `replacement` differs from the element type of `itr`,
returned values are promoted to a common type.
See also: [`Nulls.skip`](@ref), [`Nulls.fail`](@ref)
Expand All @@ -132,14 +134,23 @@ julia> collect(Nulls.replace([1, null, 2], 0))
0
2
julia> collect(Nulls.replace([1, null, 2], 0.0))
3-element Array{Float64,1}:
1.0
0.0
2.0
julia> collect(Nulls.replace([1 null; 2 null], 0))
2×2 Array{Int64,2}:
1 0
2 0
```
"""
replace(itr, replacement) = EachReplaceNull(itr, replacement)
function replace(itr, replacement)
U = promote_type(eltype(itr), typeof(replacement))
EachReplaceNull(itr, convert(U, replacement))
end
struct EachReplaceNull{T, U}
x::T
replacement::U
Expand All @@ -152,11 +163,12 @@ Base.length(itr::EachReplaceNull) = length(itr.x)
Base.size(itr::EachReplaceNull) = size(itr.x)
Base.start(itr::EachReplaceNull) = start(itr.x)
Base.done(itr::EachReplaceNull, state) = done(itr.x, state)
Base.eltype(itr::EachReplaceNull) =
Union{Nulls.T(eltype(itr.x)), typeof(itr.replacement)}
Base.eltype(itr::EachReplaceNull{T, U}) where {T, U} =
promote_type(Nulls.T(eltype(itr.x)), U)
@inline function Base.next(itr::EachReplaceNull, state)
v, s = next(itr.x, state)
((isnull(v) ? itr.replacement : v)::eltype(itr), s)
el = convert(eltype(itr), isnull(v) ? itr.replacement : v)
(el::eltype(itr), s)
end

"""
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,12 @@ using Base.Test, Nulls
@test size(x) == (4,)
@test collect(x) == collect(1:4)
@test collect(x) isa Vector{Int}
x = Nulls.replace([1, 2, null, 4], 3.0)
@test eltype(x) === Float64
@test length(x) == 4
@test size(x) == (4,)
@test collect(x) == collect(1:4)
@test collect(x) isa Vector{Float64}
x = Nulls.replace([1 2; null 4], 3)
@test eltype(x) === Int
@test length(x) == 4
Expand Down

0 comments on commit 4cc05fa

Please sign in to comment.