From ae217077353cc2ae10ea074c13783defc173178d Mon Sep 17 00:00:00 2001 From: Milan Bouchet-Valat Date: Fri, 3 Jun 2016 14:09:13 +0200 Subject: [PATCH] Use Base.promote_op() for arithmetic operators This fixes support for cases in which the return type is different from the inputs, in particular Date. Also makes handling of / a bit less of a special case. --- src/broadcast.jl | 45 +++++++++++++++++++-------------------- src/operators.jl | 54 +++++++++++++++++++++++++---------------------- test/operators.jl | 29 ++++++++++++++++++++++++- 3 files changed, 79 insertions(+), 49 deletions(-) diff --git a/src/broadcast.jl b/src/broadcast.jl index 49cf34f..f4b1c25 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -2,20 +2,6 @@ using DataArrays, Base.@get! using Base.Broadcast: bitcache_chunks, bitcache_size, dumpbitcache, promote_eltype, broadcast_shape, eltype_plus -if isdefined(Base.Broadcast, :type_minus) - using Base.Broadcast: type_minus, type_div, type_pow - const _type_minus = type_minus - const _type_rdiv = type_div - const _type_ldiv = type_div - const _type_pow = type_pow -else - using Base.Broadcast: promote_op - _type_minus(T, S) = promote_op(@functorize(-), T, S) - _type_rdiv(T, S) = promote_op(@functorize(/), T, S) - _type_ldiv(T, S) = promote_op(@functorize(\), T, S) - _type_pow(T, S) = promote_op(@functorize(^), T, S) -end - # Check that all arguments are broadcast compatible with shape # Differs from Base in that we check for exact matches function check_broadcast_shape(shape::Dims, As::(@compat Union{AbstractArray,Number})...) @@ -304,30 +290,43 @@ end @da_broadcast_vararg (.*)(As...) = databroadcast(*, As...) @da_broadcast_binary (.%)(A, B) = databroadcast(%, A, B) @da_broadcast_vararg (.+)(As...) = broadcast!(+, DataArray(eltype_plus(As...), broadcast_shape(As...)), As...) -@da_broadcast_binary (.-)(A, B) = broadcast!(-, DataArray(_type_minus(eltype(A), eltype(B)), broadcast_shape(A,B)), A, B) -@da_broadcast_binary (./)(A, B) = broadcast!(/, DataArray(_type_rdiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B) -@da_broadcast_binary (.\)(A, B) = broadcast!(\, DataArray(_type_ldiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B) +@da_broadcast_binary (.-)(A, B) = + broadcast!(-, DataArray(promote_op(@functorize(-), eltype(A), eltype(B)), + broadcast_shape(A,B)), A, B) +@da_broadcast_binary (./)(A, B) = + broadcast!(/, DataArray(promote_op(@functorize(/), eltype(A), eltype(B)), + broadcast_shape(A, B)), A, B) +@da_broadcast_binary (.\)(A, B) = + broadcast!(\, DataArray(promote_op(@functorize(\), eltype(A), eltype(B)), + broadcast_shape(A, B)), A, B) (.^)(A::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}}), B::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}})) = databroadcast(>=, A, B) (.^)(A::BitArray, B::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}})) = databroadcast(>=, A, B) (.^)(A::AbstractArray{Bool}, B::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}})) = databroadcast(>=, A, B) (.^)(A::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}}), B::BitArray) = databroadcast(>=, A, B) (.^)(A::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}}), B::AbstractArray{Bool}) = databroadcast(>=, A, B) -@da_broadcast_binary (.^)(A, B) = broadcast!(^, DataArray(_type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B) +@da_broadcast_binary (.^)(A, B) = + broadcast!(^, DataArray(promote_op(@functorize(^), eltype(A), eltype(B)), + broadcast_shape(A, B)), A, B) # XXX is a PDA the right return type for these? Base.broadcast(f::Function, As::PooledDataArray...) = pdabroadcast(f, As...) (.*)(As::PooledDataArray...) = pdabroadcast(*, As...) (.%)(A::PooledDataArray, B::PooledDataArray) = pdabroadcast(%, A, B) -(.+)(As::PooledDataArray...) = broadcast!(+, PooledDataArray(eltype_plus(As...), broadcast_shape(As...)), As...) +(.+)(As::PooledDataArray...) = + broadcast!(+, PooledDataArray(eltype_plus(As...), broadcast_shape(As...)), As...) (.-)(A::PooledDataArray, B::PooledDataArray) = - broadcast!(-, PooledDataArray(_type_minus(eltype(A), eltype(B)), broadcast_shape(A,B)), A, B) + broadcast!(-, PooledDataArray(promote_op(@functorize(-), eltype(A), eltype(B)), + broadcast_shape(A,B)), A, B) (./)(A::PooledDataArray, B::PooledDataArray) = - broadcast!(/, PooledDataArray(_type_rdiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B) + broadcast!(/, PooledDataArray(promote_op(@functorize(/), eltype(A), eltype(B)), + broadcast_shape(A, B)), A, B) (.\)(A::PooledDataArray, B::PooledDataArray) = - broadcast!(\, PooledDataArray(_type_ldiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B) + broadcast!(\, PooledDataArray(promote_op(@functorize(\), eltype(A), eltype(B)), + broadcast_shape(A, B)), A, B) (.^)(A::PooledDataArray{Bool}, B::PooledDataArray{Bool}) = databroadcast(>=, A, B) (.^)(A::PooledDataArray, B::PooledDataArray) = - broadcast!(^, PooledDataArray(_type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B) + broadcast!(^, PooledDataArray(promote_op(@functorize(^), eltype(A), eltype(B)), + broadcast_shape(A, B)), A, B) for (sf, vf) in zip(scalar_comparison_operators, array_comparison_operators) @eval begin diff --git a/src/operators.jl b/src/operators.jl index 88e365d..05102d9 100644 --- a/src/operators.jl +++ b/src/operators.jl @@ -1,3 +1,11 @@ +promote_op{R,S}(f::Any, ::Type{R}, ::Type{S}) = + Base.promote_op(f, R, S) + +# Required for /(::Int, ::Int) +if VERSION < v"0.5.0-dev" + promote_op{R<:Integer,S<:Integer}(op, ::Type{R}, ::Type{S}) = typeof(op(one(R), one(S))) +end + const unary_operators = [:+, :-, :!, :*] const numeric_unary_operators = [:+, :-] @@ -223,6 +231,10 @@ macro dataarray_binary_scalar(vectorfunc, scalarfunc, outtype, swappable) # warnings Any[ begin + if outtype == :nothing + outtype = :(promote_op(@functorize($scalarfunc), + eltype(a), eltype(b))) + end fns = Any[ :(function $(vectorfunc)(a::DataArray, b::$t) data = a.data @@ -255,7 +267,7 @@ macro dataarray_binary_scalar(vectorfunc, scalarfunc, outtype, swappable) end # Binary operators with two array arguments -macro dataarray_binary_array(vectorfunc, scalarfunc, outtype) +macro dataarray_binary_array(vectorfunc, scalarfunc) esc(Expr(:block, # DataArray with other array Any[ @@ -263,7 +275,8 @@ macro dataarray_binary_array(vectorfunc, scalarfunc, outtype) function $(vectorfunc)(a::$atype, b::$btype) data1 = $(atype == :DataArray || atype == :(DataArray{Bool}) ? :(a.data) : :a) data2 = $(btype == :DataArray || btype == :(DataArray{Bool}) ? :(b.data) : :b) - res = Array($outtype, promote_shape(size(a), size(b))) + res = Array(promote_op(@functorize($vectorfunc), eltype(a), eltype(b)), + promote_shape(size(a), size(b))) resna = $narule @bitenumerate resna i na begin if !na @@ -284,7 +297,9 @@ macro dataarray_binary_array(vectorfunc, scalarfunc, outtype) Any[ quote function $(vectorfunc)(a::$atype, b::$btype) - res = similar($(asim ? :a : :b), $outtype, promote_shape(size(a), size(b))) + res = similar($(asim ? :a : :b), + promote_op(@functorize($vectorfunc), eltype(a), eltype(b)), + promote_shape(size(a), size(b))) for i = 1:length(a) res[i] = $(scalarfunc)(a[i], b[i]) end @@ -607,7 +622,7 @@ end (.^)(::Irrational{:e}, B::AbstractDataArray) = exp(B) for f in (:(+), :(.+), :(-), :(.-), - :(*), :(.*), :(.^), :(Base.div), + :(*), :(.*), :(/), :(./), :(.^), :(Base.div), :(Base.mod), :(Base.fld), :(Base.rem), :(Base.min), :(Base.max)) @eval begin @@ -685,7 +700,7 @@ end end # if isdefined(Base, :UniformScaling) -for f in (:(.+), :(.-), :(*), :(.*), +for f in (:(.+), :(.-), :(*), :(.*), :(./), :(.^), :(Base.div), :(Base.mod), :(Base.fld), :(Base.rem)) @eval begin # Array with NA @@ -693,7 +708,7 @@ for f in (:(.+), :(.-), :(*), :(.*), DataArray(Array(T, size(b)), trues(size(b))) # DataArray with scalar - @dataarray_binary_scalar $f $f promote_type(eltype(a), eltype(b)) true + @dataarray_binary_scalar $f $f nothing true end end @@ -708,33 +723,22 @@ end (^)(::NAtype, ::Integer) = NA (^)(::NAtype, ::Number) = NA -for (vf, sf) in ((:(+), :(+)), - (:(-), :(-))) +for f in (:(+), :(-)) @eval begin # Necessary to avoid ambiguity warnings - @swappable ($vf)(A::BitArray, B::AbstractDataArray{Bool}) = ($vf)(Array(A), B) - @swappable ($vf)(A::BitArray, B::DataArray{Bool}) = ($vf)(Array(A), B) - - @dataarray_binary_array $vf $sf promote_type(eltype(a), eltype(b)) - end -end + @swappable ($f)(A::BitArray, B::AbstractDataArray{Bool}) = ($f)(Array(A), B) + @swappable ($f)(A::BitArray, B::DataArray{Bool}) = ($f)(Array(A), B) -# / and ./ are defined separately since they promote to floating point -for f in (:(/), :(./)) - @eval begin - ($f)(::NAtype, ::NAtype) = NA - @swappable ($f)(d::NAtype, x::Number) = NA + @dataarray_binary_array $f $f end end +# / is defined separately since it is not swappable +(/)(::NAtype, ::NAtype) = NA +@swappable (/)(d::NAtype, x::Number) = NA (/){T,N}(b::AbstractArray{T,N}, ::NAtype) = DataArray(Array(T, size(b)), trues(size(b))) -@dataarray_binary_scalar(/, /, eltype(a) <: AbstractFloat || typeof(b) <: AbstractFloat ? - promote_type(eltype(a), typeof(b)) : Float64, false) -@swappable (./){T,N}(::NAtype, b::AbstractArray{T,N}) = - DataArray(Array(T, size(b)), trues(size(b))) -@dataarray_binary_scalar(./, /, eltype(a) <: AbstractFloat || typeof(b) <: AbstractFloat ? - promote_type(eltype(a), typeof(b)) : Float64, true) +@dataarray_binary_scalar(/, /, nothing, false) for f in biscalar_operators @eval begin diff --git a/test/operators.jl b/test/operators.jl index 07eeec7..95ef746 100644 --- a/test/operators.jl +++ b/test/operators.jl @@ -173,7 +173,9 @@ module TestOperators # Binary operations on pairs of DataVector's dv = convert(DataArray, ones(5)) - dv[1] = NA + # Dates are an example of type for which - and .- return a different type from its inputs + dvd = @data([Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05")]) + dv[1] = dvd[1] = NA @test_da_pda dv begin for f in map(eval, DataArrays.array_arithmetic_operators) for i in 1:length(dv) @@ -186,6 +188,12 @@ module TestOperators @assert f(bv, bv)[i] == f(bv[i], bv[i]) end end + for i in 1:length(dvd) + @assert isna((dvd - dvd)[i]) && isna(dvd[i]) || + (dvd - dvd)[i] == dvd[i] - dvd[i] + @assert isna((dvd .- dvd)[i]) && isna(dvd[i]) || + (dvd .- dvd)[i] == dvd[i] - dvd[i] + end end # + and - with UniformScaling @@ -226,27 +234,46 @@ module TestOperators # Pairwise vector operators on DataVector's dv = @data([911, 269, 835.0, 448, 772]) + # Dates are an example of type for which operations return a different type from their inputs + dvd = @data([Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05")]) for f in map(eval, DataArrays.pairwise_vector_operators) @assert isequal(f(dv), f(dv.data)) + @assert isequal(f(dvd), f(dvd.data)) end dv = @data([NA, 269, 835.0, 448, 772]) + dvd = @data([NA, Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05")]) for f in map(eval, DataArrays.pairwise_vector_operators) v = f(dv) @assert isna(v[1]) @assert isequal(v[2:4], f(dv.data)[2:4]) + + d = f(dvd) + @assert isna(d[1]) + @assert isequal(d[2:3], f(dvd.data)[2:3]) end dv = @data([911, NA, 835.0, 448, 772]) + dvd = @data([Base.Date("2000-01-01"), NA, Base.Date("2010-01-01"), Base.Date("2010-01-05")]) for f in map(eval, DataArrays.pairwise_vector_operators) v = f(dv) @assert isna(v[1]) @assert isna(v[2]) @assert isequal(v[3:4], f(dv.data)[3:4]) + + d = f(dvd) + @assert isna(d[1]) + @assert isna(d[2]) + @assert isequal(d[3:3], f(dvd.data)[3:3]) end dv = @data([911, 269, 835.0, 448, NA]) + dvd = @data([Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05"), NA]) for f in map(eval, DataArrays.pairwise_vector_operators) v = f(dv) @assert isna(v[4]) @assert isequal(v[1:3], f(dv.data)[1:3]) + + d = f(dvd) + @assert isna(d[3]) + @assert isequal(d[1:2], f(dvd.data)[1:2]) end # Cumulative vector operators on DataVector's