Skip to content
This repository has been archived by the owner on May 4, 2019. It is now read-only.

Commit

Permalink
Use Base.promote_op() for arithmetic operators
Browse files Browse the repository at this point in the history
This fixes support for cases in which the return type is different
from the inputs, in particular Date. Also makes handling of /
a bit less of a special case.
  • Loading branch information
nalimilan committed Jun 3, 2016
1 parent 2642f6c commit ae21707
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 49 deletions.
45 changes: 22 additions & 23 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,6 @@ using DataArrays, Base.@get!
using Base.Broadcast: bitcache_chunks, bitcache_size, dumpbitcache,
promote_eltype, broadcast_shape, eltype_plus

if isdefined(Base.Broadcast, :type_minus)
using Base.Broadcast: type_minus, type_div, type_pow
const _type_minus = type_minus
const _type_rdiv = type_div
const _type_ldiv = type_div
const _type_pow = type_pow
else
using Base.Broadcast: promote_op
_type_minus(T, S) = promote_op(@functorize(-), T, S)
_type_rdiv(T, S) = promote_op(@functorize(/), T, S)
_type_ldiv(T, S) = promote_op(@functorize(\), T, S)
_type_pow(T, S) = promote_op(@functorize(^), T, S)
end

# Check that all arguments are broadcast compatible with shape
# Differs from Base in that we check for exact matches
function check_broadcast_shape(shape::Dims, As::(@compat Union{AbstractArray,Number})...)
Expand Down Expand Up @@ -304,30 +290,43 @@ end
@da_broadcast_vararg (.*)(As...) = databroadcast(*, As...)
@da_broadcast_binary (.%)(A, B) = databroadcast(%, A, B)
@da_broadcast_vararg (.+)(As...) = broadcast!(+, DataArray(eltype_plus(As...), broadcast_shape(As...)), As...)
@da_broadcast_binary (.-)(A, B) = broadcast!(-, DataArray(_type_minus(eltype(A), eltype(B)), broadcast_shape(A,B)), A, B)
@da_broadcast_binary (./)(A, B) = broadcast!(/, DataArray(_type_rdiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
@da_broadcast_binary (.\)(A, B) = broadcast!(\, DataArray(_type_ldiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
@da_broadcast_binary (.-)(A, B) =
broadcast!(-, DataArray(promote_op(@functorize(-), eltype(A), eltype(B)),
broadcast_shape(A,B)), A, B)
@da_broadcast_binary (./)(A, B) =
broadcast!(/, DataArray(promote_op(@functorize(/), eltype(A), eltype(B)),
broadcast_shape(A, B)), A, B)
@da_broadcast_binary (.\)(A, B) =
broadcast!(\, DataArray(promote_op(@functorize(\), eltype(A), eltype(B)),
broadcast_shape(A, B)), A, B)
(.^)(A::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}}), B::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}})) = databroadcast(>=, A, B)
(.^)(A::BitArray, B::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}})) = databroadcast(>=, A, B)
(.^)(A::AbstractArray{Bool}, B::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}})) = databroadcast(>=, A, B)
(.^)(A::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}}), B::BitArray) = databroadcast(>=, A, B)
(.^)(A::(@compat Union{DataArray{Bool}, PooledDataArray{Bool}}), B::AbstractArray{Bool}) = databroadcast(>=, A, B)
@da_broadcast_binary (.^)(A, B) = broadcast!(^, DataArray(_type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
@da_broadcast_binary (.^)(A, B) =
broadcast!(^, DataArray(promote_op(@functorize(^), eltype(A), eltype(B)),
broadcast_shape(A, B)), A, B)

# XXX is a PDA the right return type for these?
Base.broadcast(f::Function, As::PooledDataArray...) = pdabroadcast(f, As...)
(.*)(As::PooledDataArray...) = pdabroadcast(*, As...)
(.%)(A::PooledDataArray, B::PooledDataArray) = pdabroadcast(%, A, B)
(.+)(As::PooledDataArray...) = broadcast!(+, PooledDataArray(eltype_plus(As...), broadcast_shape(As...)), As...)
(.+)(As::PooledDataArray...) =
broadcast!(+, PooledDataArray(eltype_plus(As...), broadcast_shape(As...)), As...)
(.-)(A::PooledDataArray, B::PooledDataArray) =
broadcast!(-, PooledDataArray(_type_minus(eltype(A), eltype(B)), broadcast_shape(A,B)), A, B)
broadcast!(-, PooledDataArray(promote_op(@functorize(-), eltype(A), eltype(B)),
broadcast_shape(A,B)), A, B)
(./)(A::PooledDataArray, B::PooledDataArray) =
broadcast!(/, PooledDataArray(_type_rdiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
broadcast!(/, PooledDataArray(promote_op(@functorize(/), eltype(A), eltype(B)),
broadcast_shape(A, B)), A, B)
(.\)(A::PooledDataArray, B::PooledDataArray) =
broadcast!(\, PooledDataArray(_type_ldiv(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
broadcast!(\, PooledDataArray(promote_op(@functorize(\), eltype(A), eltype(B)),
broadcast_shape(A, B)), A, B)
(.^)(A::PooledDataArray{Bool}, B::PooledDataArray{Bool}) = databroadcast(>=, A, B)
(.^)(A::PooledDataArray, B::PooledDataArray) =
broadcast!(^, PooledDataArray(_type_pow(eltype(A), eltype(B)), broadcast_shape(A, B)), A, B)
broadcast!(^, PooledDataArray(promote_op(@functorize(^), eltype(A), eltype(B)),
broadcast_shape(A, B)), A, B)

for (sf, vf) in zip(scalar_comparison_operators, array_comparison_operators)
@eval begin
Expand Down
54 changes: 29 additions & 25 deletions src/operators.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
promote_op{R,S}(f::Any, ::Type{R}, ::Type{S}) =
Base.promote_op(f, R, S)

# Required for /(::Int, ::Int)
if VERSION < v"0.5.0-dev"
promote_op{R<:Integer,S<:Integer}(op, ::Type{R}, ::Type{S}) = typeof(op(one(R), one(S)))
end

const unary_operators = [:+, :-, :!, :*]

const numeric_unary_operators = [:+, :-]
Expand Down Expand Up @@ -223,6 +231,10 @@ macro dataarray_binary_scalar(vectorfunc, scalarfunc, outtype, swappable)
# warnings
Any[
begin
if outtype == :nothing
outtype = :(promote_op(@functorize($scalarfunc),
eltype(a), eltype(b)))
end
fns = Any[
:(function $(vectorfunc)(a::DataArray, b::$t)
data = a.data
Expand Down Expand Up @@ -255,15 +267,16 @@ macro dataarray_binary_scalar(vectorfunc, scalarfunc, outtype, swappable)
end

# Binary operators with two array arguments
macro dataarray_binary_array(vectorfunc, scalarfunc, outtype)
macro dataarray_binary_array(vectorfunc, scalarfunc)
esc(Expr(:block,
# DataArray with other array
Any[
quote
function $(vectorfunc)(a::$atype, b::$btype)
data1 = $(atype == :DataArray || atype == :(DataArray{Bool}) ? :(a.data) : :a)
data2 = $(btype == :DataArray || btype == :(DataArray{Bool}) ? :(b.data) : :b)
res = Array($outtype, promote_shape(size(a), size(b)))
res = Array(promote_op(@functorize($vectorfunc), eltype(a), eltype(b)),
promote_shape(size(a), size(b)))
resna = $narule
@bitenumerate resna i na begin
if !na
Expand All @@ -284,7 +297,9 @@ macro dataarray_binary_array(vectorfunc, scalarfunc, outtype)
Any[
quote
function $(vectorfunc)(a::$atype, b::$btype)
res = similar($(asim ? :a : :b), $outtype, promote_shape(size(a), size(b)))
res = similar($(asim ? :a : :b),
promote_op(@functorize($vectorfunc), eltype(a), eltype(b)),
promote_shape(size(a), size(b)))
for i = 1:length(a)
res[i] = $(scalarfunc)(a[i], b[i])
end
Expand Down Expand Up @@ -607,7 +622,7 @@ end
(.^)(::Irrational{:e}, B::AbstractDataArray) = exp(B)

for f in (:(+), :(.+), :(-), :(.-),
:(*), :(.*), :(.^), :(Base.div),
:(*), :(.*), :(/), :(./), :(.^), :(Base.div),
:(Base.mod), :(Base.fld), :(Base.rem), :(Base.min),
:(Base.max))
@eval begin
Expand Down Expand Up @@ -685,15 +700,15 @@ end

end # if isdefined(Base, :UniformScaling)

for f in (:(.+), :(.-), :(*), :(.*),
for f in (:(.+), :(.-), :(*), :(.*), :(./),
:(.^), :(Base.div), :(Base.mod), :(Base.fld), :(Base.rem))
@eval begin
# Array with NA
@swappable $(f){T,N}(::NAtype, b::AbstractArray{T,N}) =
DataArray(Array(T, size(b)), trues(size(b)))

# DataArray with scalar
@dataarray_binary_scalar $f $f promote_type(eltype(a), eltype(b)) true
@dataarray_binary_scalar $f $f nothing true
end
end

Expand All @@ -708,33 +723,22 @@ end
(^)(::NAtype, ::Integer) = NA
(^)(::NAtype, ::Number) = NA

for (vf, sf) in ((:(+), :(+)),
(:(-), :(-)))
for f in (:(+), :(-))
@eval begin
# Necessary to avoid ambiguity warnings
@swappable ($vf)(A::BitArray, B::AbstractDataArray{Bool}) = ($vf)(Array(A), B)
@swappable ($vf)(A::BitArray, B::DataArray{Bool}) = ($vf)(Array(A), B)

@dataarray_binary_array $vf $sf promote_type(eltype(a), eltype(b))
end
end
@swappable ($f)(A::BitArray, B::AbstractDataArray{Bool}) = ($f)(Array(A), B)
@swappable ($f)(A::BitArray, B::DataArray{Bool}) = ($f)(Array(A), B)

# / and ./ are defined separately since they promote to floating point
for f in (:(/), :(./))
@eval begin
($f)(::NAtype, ::NAtype) = NA
@swappable ($f)(d::NAtype, x::Number) = NA
@dataarray_binary_array $f $f
end
end

# / is defined separately since it is not swappable
(/)(::NAtype, ::NAtype) = NA
@swappable (/)(d::NAtype, x::Number) = NA
(/){T,N}(b::AbstractArray{T,N}, ::NAtype) =
DataArray(Array(T, size(b)), trues(size(b)))
@dataarray_binary_scalar(/, /, eltype(a) <: AbstractFloat || typeof(b) <: AbstractFloat ?
promote_type(eltype(a), typeof(b)) : Float64, false)
@swappable (./){T,N}(::NAtype, b::AbstractArray{T,N}) =
DataArray(Array(T, size(b)), trues(size(b)))
@dataarray_binary_scalar(./, /, eltype(a) <: AbstractFloat || typeof(b) <: AbstractFloat ?
promote_type(eltype(a), typeof(b)) : Float64, true)
@dataarray_binary_scalar(/, /, nothing, false)

for f in biscalar_operators
@eval begin
Expand Down
29 changes: 28 additions & 1 deletion test/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ module TestOperators

# Binary operations on pairs of DataVector's
dv = convert(DataArray, ones(5))
dv[1] = NA
# Dates are an example of type for which - and .- return a different type from its inputs
dvd = @data([Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05")])
dv[1] = dvd[1] = NA
@test_da_pda dv begin
for f in map(eval, DataArrays.array_arithmetic_operators)
for i in 1:length(dv)
Expand All @@ -186,6 +188,12 @@ module TestOperators
@assert f(bv, bv)[i] == f(bv[i], bv[i])
end
end
for i in 1:length(dvd)
@assert isna((dvd - dvd)[i]) && isna(dvd[i]) ||
(dvd - dvd)[i] == dvd[i] - dvd[i]
@assert isna((dvd .- dvd)[i]) && isna(dvd[i]) ||
(dvd .- dvd)[i] == dvd[i] - dvd[i]
end
end

# + and - with UniformScaling
Expand Down Expand Up @@ -226,27 +234,46 @@ module TestOperators

# Pairwise vector operators on DataVector's
dv = @data([911, 269, 835.0, 448, 772])
# Dates are an example of type for which operations return a different type from their inputs
dvd = @data([Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05")])
for f in map(eval, DataArrays.pairwise_vector_operators)
@assert isequal(f(dv), f(dv.data))
@assert isequal(f(dvd), f(dvd.data))
end
dv = @data([NA, 269, 835.0, 448, 772])
dvd = @data([NA, Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05")])
for f in map(eval, DataArrays.pairwise_vector_operators)
v = f(dv)
@assert isna(v[1])
@assert isequal(v[2:4], f(dv.data)[2:4])

d = f(dvd)
@assert isna(d[1])
@assert isequal(d[2:3], f(dvd.data)[2:3])
end
dv = @data([911, NA, 835.0, 448, 772])
dvd = @data([Base.Date("2000-01-01"), NA, Base.Date("2010-01-01"), Base.Date("2010-01-05")])
for f in map(eval, DataArrays.pairwise_vector_operators)
v = f(dv)
@assert isna(v[1])
@assert isna(v[2])
@assert isequal(v[3:4], f(dv.data)[3:4])

d = f(dvd)
@assert isna(d[1])
@assert isna(d[2])
@assert isequal(d[3:3], f(dvd.data)[3:3])
end
dv = @data([911, 269, 835.0, 448, NA])
dvd = @data([Base.Date("2000-01-01"), Base.Date("2010-01-01"), Base.Date("2010-01-05"), NA])
for f in map(eval, DataArrays.pairwise_vector_operators)
v = f(dv)
@assert isna(v[4])
@assert isequal(v[1:3], f(dv.data)[1:3])

d = f(dvd)
@assert isna(d[3])
@assert isequal(d[1:2], f(dvd.data)[1:2])
end

# Cumulative vector operators on DataVector's
Expand Down

0 comments on commit ae21707

Please sign in to comment.