Skip to content

Commit

Permalink
Merge a072d8c into f7920fa
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 23, 2019
2 parents f7920fa + a072d8c commit 6585c3c
Show file tree
Hide file tree
Showing 8 changed files with 196 additions and 47 deletions.
2 changes: 2 additions & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ scaling_convert(T::Type, x::LinearAlgebra.UniformScaling) = convert(T, x.λ)
scaling_convert(T::Type, x) = convert(T, x)
include("bigint.jl")
include("bigfloat.jl")

include("reduce.jl")
include("linear_algebra.jl")
include("sparse_arrays.jl")

Expand Down
30 changes: 4 additions & 26 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,34 +7,12 @@
abstract type AbstractMutable end

function Base.sum(a::AbstractArray{<:AbstractMutable})
return mapreduce(identity, add!, a, init = zero(promote_operation(+, eltype(a), eltype(a))))
return operate(sum, a)
end

LinearAlgebra.dot(lhs::AbstractArray{<:AbstractMutable}, rhs::AbstractArray) = _dot(lhs, rhs)
LinearAlgebra.dot(lhs::AbstractArray, rhs::AbstractArray{<:AbstractMutable}) = _dot(lhs, rhs)
LinearAlgebra.dot(lhs::AbstractArray{<:AbstractMutable}, rhs::AbstractArray{<:AbstractMutable}) = _dot(lhs, rhs)

function _dot(x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
end
if iszero(lx)
return LinearAlgebra.dot(zero(eltype(x)), zero(eltype(y)))
end

# We need a buffer to hold the intermediate multiplication.

SumType = promote_operation(add_mul, eltype(x), eltype(x), eltype(y))
mul_buffer = buffer_for(add_mul, SumType, eltype(x), eltype(y))
s = zero(SumType)

for (Ix, Iy) in zip(eachindex(x), eachindex(y))
s = @inbounds buffered_operate!(mul_buffer, add_mul, s, x[Ix], y[Iy])
end

return s
end
LinearAlgebra.dot(lhs::AbstractArray{<:AbstractMutable}, rhs::AbstractArray) = operate(LinearAlgebra.dot, lhs, rhs)
LinearAlgebra.dot(lhs::AbstractArray, rhs::AbstractArray{<:AbstractMutable}) = operate(LinearAlgebra.dot, lhs, rhs)
LinearAlgebra.dot(lhs::AbstractArray{<:AbstractMutable}, rhs::AbstractArray{<:AbstractMutable}) = operate(LinearAlgebra.dot, lhs, rhs)

# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)}))
_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A))
Expand Down
74 changes: 70 additions & 4 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,72 @@ function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, α:
error("Operation `$op` between `` and `$A` is not allowed. You should use broadcast.")
end

"""
operate(op::Function, args...)
Return an object equal to the result of `op(args...)` that can be mutated
through the MultableArithmetics API without affecting the arguments.
By default:
* `operate(+, x)` and `operate(+, x)` redirect to `copy_if_mutable(x)` so a
mutable type `T` can return the same instance from unary operators
`+(x::T) = x` and `*(x::T) = x`.
* `operate(+, args...)` (resp. `operate(-, args...)` and `operate(*, args...)`)
redirect to `+(args...)` (resp. `-(args...)` and `*(args...)`) if `length(args)`
is at least 2 (or the operation is `-`).
Note that when `op` is a `Base` function whose implementation can be improved
for mutable arguments, `operate(op, args...)` may have an implementation in
this package relying on the MutableArithmetics API instead of redirecting to
`op(args...)`. This is the case for instance:
* for `Base.sum`,
* for `LinearAlgebra.dot` and
* for matrix-matrix product and matrix-vector product.
Therefore, for mutable arguments, there may be a performance advantage to call
`operate(op, args...)` instead of `op(args...)`.
## Example
If for a mutable type `T`, the following is defined:
```julia
function Base.:*(a::Bool, x::T)
return a ? x : zero(T)
end
```
then `operate(*, a, x)` will return the instance `x` whose modification will
affect the argument of `operate`. Therefore, the following method need to
be implemented
```julia
function MA.operate(::typeof(*), a::Bool, x::T)
return a ? MA.mutable_copy(x) : zero(T)
end
```
"""
function operate end

# /!\ We assume these two return an object that can be modified through the MA
# API without altering `x` and `y`. If it is not the case, implement a
# custom `operate` method.
operate(::typeof(-), x) = -x
operate(op::Union{typeof(+), typeof(-), typeof(*)}, x, y) where {N} = op(x, y)

# We only give the type to `zero` and `one` to be sure that modifying the
# returned object cannot alter `x`.
operate(::typeof(zero), x) = zero(typeof(x))
operate(::typeof(one), x) = one(typeof(x))

operate(::Union{typeof(+), typeof(*)}, x) = copy_if_mutable(x)
function operate(op::Union{typeof(+), typeof(*)}, x, y, z, args::Vararg{Any, N}) where N
return operate(op, x, operate(op, y, z, args...))
end

operate(::typeof(add_mul), x, y) = operate(+, x, y)
function operate(::typeof(add_mul), x, y, z, args::Vararg{Any, N}) where N
return operate(+, x, operate(*, y, z, args...))
end

# Define Traits

"""
Expand Down Expand Up @@ -196,7 +262,7 @@ function operate_to!(output, op::Function, args::Vararg{Any, N}) where N
end

function operate_to_fallback!(::NotMutable, output, op::Function, args::Vararg{Any, N}) where N
return op(args...)
return operate(op, args...)
end
function operate_to_fallback!(::IsMutable, output, op::Function, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, op, args...)
Expand All @@ -212,7 +278,7 @@ function operate!(op::Function, args::Vararg{Any, N}) where N
end

function operate_fallback!(::NotMutable, op::Function, args::Vararg{Any, N}) where N
return op(args...)
return operate(op, args...)
end
function operate_fallback!(::IsMutable, op::Function, args::Vararg{Any, N}) where N
return mutable_operate!(op, args...)
Expand All @@ -229,7 +295,7 @@ function buffered_operate_to!(buffer, output, op::Function, args::Vararg{Any, N}
end

function buffered_operate_to_fallback!(::NotMutable, buffer, output, op::Function, args::Vararg{Any, N}) where N
return op(args...)
return operate(op, args...)
end
function buffered_operate_to_fallback!(::IsMutable, buffer, output, op::Function, args::Vararg{Any, N}) where N
return mutable_buffered_operate_to!(buffer, output, op, args...)
Expand All @@ -246,7 +312,7 @@ function buffered_operate!(buffer, op::Function, args::Vararg{Any, N}) where N
end

function buffered_operate_fallback!(::NotMutable, buffer, op::Function, args::Vararg{Any, N}) where N
return op(args...)
return operate(op, args...)
end
function buffered_operate_fallback!(::IsMutable, buffer, op::Function, args::Vararg{Any, N}) where N
return mutable_buffered_operate!(buffer, op, args...)
Expand Down
76 changes: 64 additions & 12 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,6 @@ function promote_array_mul(::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{
return Vector{promote_operation(add_mul, S, S, T)}
end

const TransposeOrAdjoint{T, MT} = Union{LinearAlgebra.Transpose{T, MT}, LinearAlgebra.Adjoint{T, MT}}
_mirror_transpose_or_adjoint(x, ::LinearAlgebra.Transpose) = LinearAlgebra.transpose(x)
_mirror_transpose_or_adjoint(x, ::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(x)
function promote_array_mul(::Type{<:TransposeOrAdjoint{S, <:AbstractVector}}, ::Type{<:AbstractVector{T}}) where {S, T}
return promote_operation(add_mul, S, S, T)
end

################################################################################
# We roll our own matmul here (instead of using Julia's generic fallbacks)
# because doing so allows us to accumulate the expressions for the inner loops
Expand Down Expand Up @@ -213,18 +206,18 @@ function mutable_operate_to!(C::AbstractArray, ::typeof(*), A::AbstractArray, B:
return mutable_operate!(add_mul, C, A, B)
end

# `mul` does what `LinearAlgebra/src/matmul.jl` does for abstract
# matrices and vector, i.e., use `matprod` to estimate the resulting element
# type, allocate the resulting array but it redirects to `mul_to!` instead of
# Does what `LinearAlgebra/src/matmul.jl` does for abstract
# matrices and vector, estimate the resulting element type,
# allocate the resulting array but it redirects to `mul_to!` instead of
# `LinearAlgebra.mul!`.
function mul(A::AbstractMatrix{S}, B::AbstractVector{T}) where {T, S}
function operate(::typeof(*), A::AbstractMatrix{S}, B::AbstractVector{T}) where {T, S}
U = promote_operation(add_mul, S, S, T)
# `similar` gives SparseMatrixCSC if `B` is SparseMatrixCSC
#C = similar(B, U, axes(A, 1))
C = Vector{U}(undef, size(A, 1))
return mutable_operate_to!(C, *, A, B)
end
function mul(A::AbstractMatrix{S}, B::AbstractMatrix{T}) where {T, S}
function operate(::typeof(*), A::AbstractMatrix{S}, B::AbstractMatrix{T}) where {T, S}
U = promote_operation(add_mul, S, S, T)
# `similar` gives SparseMatrixCSC if `B` is SparseMatrixCSC
#C = similar(B, U, axes(A, 1), axes(B, 2))
Expand All @@ -236,3 +229,62 @@ end
# Broadcast applies the transpose
#mutable_copy(A::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(mutable_copy(parent(A)))
#mutable_copy(A::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(mutable_copy(parent(A)))

const TransposeOrAdjoint{T, MT} = Union{LinearAlgebra.Transpose{T, MT}, LinearAlgebra.Adjoint{T, MT}}
_mirror_transpose_or_adjoint(x, ::LinearAlgebra.Transpose) = LinearAlgebra.transpose(x)
_mirror_transpose_or_adjoint(x, ::LinearAlgebra.Adjoint) = LinearAlgebra.adjoint(x)
# dot product
function promote_array_mul(::Type{<:TransposeOrAdjoint{S, <:AbstractVector}}, ::Type{<:AbstractVector{T}}) where {S, T}
return promote_operation(add_mul, S, S, T)
end
function operate(::typeof(*), x::LinearAlgebra.Adjoint{<:Any, <:AbstractVector}, y::AbstractVector)
return operate(LinearAlgebra.dot, parent(x), y)
end
function operate(::typeof(*), x::TransposeOrAdjoint{<:Any, <:AbstractVector}, y::AbstractMatrix)
return _mirror_transpose_or_adjoint(
operate(*, _mirror_transpose_or_adjoint(y, x), parent(x)), x)
end

function operate(::typeof(*), x::LinearAlgebra.Transpose{<:Any, <:AbstractVector}, y::AbstractVector)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
end
if iszero(lx)
return zero(promote_operation(add_mul, eltype(x), eltype(y)))
end

# We need a buffer to hold the intermediate multiplication.

SumType = promote_operation(add_mul, eltype(x), eltype(x), eltype(y))
mul_buffer = buffer_for(add_mul, SumType, eltype(x), eltype(y))
s = zero(SumType)

for (Ix, Iy) in zip(eachindex(x), eachindex(y))
s = @inbounds buffered_operate!(mul_buffer, add_mul, s, x[Ix], y[Iy])
end

return s
end

function operate(::typeof(LinearAlgebra.dot), x::AbstractArray, y::AbstractArray)
lx = length(x)
if lx != length(y)
throw(DimensionMismatch("first array has length $(lx) which does not match the length of the second, $(length(y))."))
end
if iszero(lx)
return LinearAlgebra.dot(zero(eltype(x)), zero(eltype(y)))
end

# We need a buffer to hold the intermediate multiplication.

SumType = promote_operation(add_mul, eltype(x), eltype(x), eltype(y))
mul_buffer = buffer_for(add_mul, SumType, eltype(x), eltype(y))
s = zero(SumType)

for (Ix, Iy) in zip(eachindex(x), eachindex(y))
s = @inbounds buffered_operate!(mul_buffer, add_mul, s, LinearAlgebra.adjoint(x[Ix]), y[Iy])
end

return s
end
3 changes: 3 additions & 0 deletions src/reduce.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
function operate(::typeof(sum), a::AbstractArray)
return mapreduce(identity, add!, a, init = zero(promote_operation(+, eltype(a), eltype(a))))
end
13 changes: 8 additions & 5 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,14 @@ macro rewrite(expr)
end

struct Zero end
# We need to copy `x` as it will be used as might be given by the user and be
# given as first argument of `operate!`.
Base.:(+)(zero::Zero, x) = copy_if_mutable(x)
# `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`.
Base.:(+)(x, zero::Zero) = copy_if_mutable(x)
## We need to copy `x` as it will be used as might be given by the user and be
## given as first argument of `operate!`.
#Base.:(+)(zero::Zero, x) = copy_if_mutable(x)
## `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`.
#Base.:(+)(x, zero::Zero) = copy_if_mutable(x)
function operate(::typeof(add_mul), ::Zero, args::Vararg{Any, N}) where {N}
return operate(*, args...)
end
broadcast!(::Union{typeof(add_mul), typeof(+)}, ::Zero, x) = copy_if_mutable(x)
broadcast!(::typeof(add_mul), ::Zero, x, y) = x * y

Expand Down
8 changes: 8 additions & 0 deletions src/shortcuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ Return the product of `a`, `b`, ..., possibly modifying `a`.
"""
mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...)

"""
mul(a, b, ...)
Shortcut for `operate(*, a, b, ...)`, see [`operate`](@ref).
"""
mul(args::Vararg{Any, N}) where {N} = operate(*, args...)


# `Vararg` gives extra allocations on Julia v1.3, see https://travis-ci.com/JuliaOpt/MutableArithmetics.jl/jobs/260666164#L215-L238
function promote_operation(::typeof(add_mul), T::Type, x::Type, y::Type)
return promote_operation(+, T, promote_operation(*, x, y))
Expand Down
37 changes: 37 additions & 0 deletions test/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,32 @@ struct CustomArray{T, N} <: AbstractArray{T, N} end

import LinearAlgebra

function dot_test(x, y)
@test MA.operate(LinearAlgebra.dot, x, y) == LinearAlgebra.dot(x, y)
@test MA.operate(LinearAlgebra.dot, y, x) == LinearAlgebra.dot(y, x)
@test MA.operate(*, x', y) == x' * y
@test MA.operate(*, y', x) == y' * x
@test MA.operate(*, LinearAlgebra.transpose(x), y) == LinearAlgebra.transpose(x) * y
@test MA.operate(*, LinearAlgebra.transpose(y), x) == LinearAlgebra.transpose(y) * x
end

@testset "dot" begin
x = [1im]
y = [1]
A = reshape(x, 1, 1)
B = reshape(y, 1, 1)
dot_test(x, x)
dot_test(y, y)
dot_test(A, A)
dot_test(B, B)
dot_test(x, y)
dot_test(x, A)
dot_test(x, B)
dot_test(y, A)
dot_test(y, B)
dot_test(A, B)
end

@testset "promote_operation" begin
x = [1]
@test MA.promote_operation(*, typeof(x'), typeof(x)) == Int
Expand All @@ -32,6 +58,17 @@ end
B = zeros(2, 2)
err = DimensionMismatch("Cannot sum matrices of size `(1, 1)` and size `(2, 2)`, the size of the two matrices must be equal.")
@test_throws err MA.@rewrite A + B
x = ones(1)
y = ones(2)
err = DimensionMismatch("first array has length 1 which does not match the length of the second, 2.")
@test_throws err MA.operate(*, x', y)
@test_throws err MA.operate(*, LinearAlgebra.transpose(x), y)
err = DimensionMismatch("matrix A has dimensions (2,2), vector B has length 1")
@test_throws err MA.operate(*, x', B)
a = zeros(0)
@test iszero(@inferred MA.operate(LinearAlgebra.dot, a, a))
@test iszero(@inferred MA.operate(*, a', a))
@test iszero(@inferred MA.operate(*, LinearAlgebra.transpose(a), a))
end
end

Expand Down

0 comments on commit 6585c3c

Please sign in to comment.