Skip to content

Commit

Permalink
Merge fe9575b into c88b584
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Nov 22, 2019
2 parents c88b584 + fe9575b commit c0426a4
Show file tree
Hide file tree
Showing 14 changed files with 1,022 additions and 14 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@ os:
- linux
- osx
julia:
- 1.0
- 1.1
- 1.2
- 1.3
- nightly
matrix:
allow_failures:
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
julia = "1.1"

[extras]
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test"]
test = ["Test", "OffsetArrays", "SparseArrays"]
27 changes: 27 additions & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,41 @@ module MutableArithmetics
# slowdown because it compiles something that works for any `N`. See
# https://github.com/JuliaLang/julia/issues/32761 for details.

# `copy(::BigInt)` and `copy(::Array)` does not copy its elements so we need `deepcopy`.
mutable_copy(x) = deepcopy(x)
mutable_copy(A::AbstractArray) = mutable_copy.(A)

"""
add_mul(a, args...)
Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`.
"""
function add_mul end
add_mul(a, b) = a + b
add_mul(a, b, c) = muladd(b, c, a)
add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...))

include("interface.jl")
include("shortcuts.jl")
include("broadcast.jl")

# Test that can be used to test an implementation of the interface
include("Test/Test.jl")

# Implementation of the interface for Base types
import LinearAlgebra
const Scaling = Union{Number, LinearAlgebra.UniformScaling}
mutable_copy(A::LinearAlgebra.Symmetric) = LinearAlgebra.Symmetric(mutable_copy(parent(A)), ifelse(A.uplo == 'U', :U, :L))
# Broadcast applies the transpose
mutable_copy(A::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(mutable_copy(parent(A)))
mutable_copy(A::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(mutable_copy(parent(A)))
include("bigint.jl")
include("linear_algebra.jl")

isequal_canonical(a, b) = a == b

include("rewrite.jl")

include("dispatch.jl")

end # module
16 changes: 14 additions & 2 deletions src/bigint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,22 @@ promote_operation(::typeof(+), ::Vararg{Type{BigInt}, N}) where {N} = BigInt
function mutable_operate_to!(output::BigInt, ::typeof(+), a::BigInt, b::BigInt)
return Base.GMP.MPZ.add!(output, a, b)
end
#function mutable_operate_to!(output::BigInt, op::typeof(+), a::BigInt, b::LinearAlgebra.UniformScaling)
# return mutable_operate_to!(output, op, a, b.λ)
#end

# *
promote_operation(::typeof(*), ::Vararg{Type{BigInt}, N}) where {N} = BigInt
function mutable_operate_to!(output::BigInt, ::typeof(*), a::BigInt, b::BigInt)
return Base.GMP.MPZ.mul!(output, a, b)
end

function mutable_operate_to!(output::BigInt, op::Union{typeof(*), typeof(+)},
a::BigInt, b::BigInt, c::Vararg{BigInt, N}) where N
mutable_operate_to!(output, op, a, b)
return mutable_operate!(op, output, c...)
end

# add_mul
function mutable_operate_to!(output::BigInt, ::typeof(add_mul), args::Vararg{BigInt, N}) where N
return mutable_buffered_operate_to!(BigInt(), output, add_mul, args...)
Expand All @@ -30,6 +39,9 @@ function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, ::typeof(a
return mutable_operate_to!(output, +, a, buffer)
end

function mutable_operate_to!(output::BigInt, op::Function, a::Integer, b::Integer)
return mutable_operate_to!(output, op, convert(BigInt, a), convert(BigInt, b))
scaling_to_bigint(x::BigInt) = x
scaling_to_bigint(x::Number) = convert(BigInt, x)
scaling_to_bigint(J::LinearAlgebra.UniformScaling) = scaling_to_bigint(J.λ)
function mutable_operate_to!(output::BigInt, op::Function, args::Vararg{Scaling, N}) where N
return mutable_operate_to!(output, op, scaling_to_bigint.(args)...)
end
76 changes: 76 additions & 0 deletions src/broadcast.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Eltype}) where {N, Eltype}
return Array{Eltype, N}
end
function broadcasted_type(::Broadcast.DefaultArrayStyle{N}, ::Type{Bool}) where N
return BitArray{N}
end

# Same as `Base.Broadcast.combine_styles` but with types as argument.
combine_styles() = Broadcast.DefaultArrayStyle{0}()
combine_styles(c::Type) = Broadcast.result_style(Broadcast.BroadcastStyle(c))
combine_styles(c1::Type, c2::Type) = Broadcast.result_style(combine_styles(c1), combine_styles(c2))
@inline combine_styles(c1::Type, c2::Type, cs::Vararg{Type, N}) where N = Broadcast.result_style(combine_styles(c1), combine_styles(c2, cs...))

function promote_broadcast(op::Function, args::Vararg{Any, N}) where N
# FIXME we could use `promote_operation` instead as
# `combine_eltypes` uses `return_type` hence it may return a non-concrete type
# and we do not handle that case.
T = Base.Broadcast.combine_eltypes(op, args)
return broadcasted_type(combine_styles(args...), T)
end

"""
broadcast_mutability(T::Type, ::typeof(op), args::Type...)::MutableTrait
Return `IsMutable` to indicate an object of type `T` can be modified to be
equal to `broadcast(op, args...)`.
"""
function broadcast_mutability(T::Type, op, args::Vararg{Type, N}) where N
if mutability(T) isa IsMutable && promote_broadcast(op, args...) == T
return IsMutable()
else
return NotMutable()
end
end
broadcast_mutability(x, op, args::Vararg{Any, N}) where {N} = broadcast_mutability(typeof(x), op, typeof.(args)...)
broadcast_mutability(::Type) = NotMutable()

"""
mutable_broadcast!(op::Function, args...)
Modify the value of `args[1]` to be equal to the value of `broadcast(op, args...)`. Can
only be called if `mutability(args[1], op, args...)` returns `true`.
"""
function mutable_broadcast! end

function mutable_broadcasted(broadcasted::Broadcast.Broadcasted{S}) where S
function f(args::Vararg{Any, N}) where N
return operate!(broadcasted.f, args...)
end
return Broadcast.Broadcasted{S}(f, broadcasted.args, broadcasted.axes)
end

# If A is `Symmetric`, we cannot do that as we might modify the same entry twice.
# See https://github.com/JuliaOpt/JuMP.jl/issues/2102
function mutable_broadcast!(op::Function, A::Array, args::Vararg{Any, N}) where N
bc = Broadcast.broadcasted(op, A, args...)
instantiated = Broadcast.instantiate(bc)
return copyto!(A, mutable_broadcasted(instantiated))
end


"""
broadcast!(op::Function, args...)
Returns the value of `broadcast(op, args...)`, possibly modifying `args[1]`.
"""
function broadcast!(op::Function, args::Vararg{Any, N}) where N
return broadcast_fallback!(broadcast_mutability(args[1], op, args...), op, args...)
end

function broadcast_fallback!(::NotMutable, op::Function, args::Vararg{Any, N}) where N
return broadcast(op, args...)
end
function broadcast_fallback!(::IsMutable, op::Function, args::Vararg{Any, N}) where N
return mutable_broadcast!(op, args...)
end
9 changes: 9 additions & 0 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
abstract type AbstractMutable end

# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)}))
_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A))
function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable})
@assert _one_indexed(x) # `LinearAlgebra.diagm` doesn't work for non-one-indexed arrays in general.
ZeroType = promote_operation(zero, eltype(x))
return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x))
end
15 changes: 14 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,18 @@ function promote_operation end
function promote_operation(op::Function, args::Vararg{Type, N}) where N
return typeof(op(zero.(args)...))
end
promote_operation(::typeof(*), ::Type{T}) where {T} = T
function promote_operation(::typeof(*), ::Type{S}, ::Type{T}, ::Type{U}, args::Vararg{Type, N}) where {S, T, U, N}
return promote_operation(*, promote_operation(*, S, T), U, args...)
end

# Helpful error for common mistake
function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, A::Type{<:Array}, α::Type{<:Number})
error("Operation `$op` between `$A` and `` is not allowed. You should use broadcast.")
end
function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, α::Type{<:Number}, A::Type{<:Array})
error("Operation `$op` between `` and `$A` is not allowed. You should use broadcast.")
end

# Define Traits
abstract type MutableTrait end
Expand Down Expand Up @@ -40,7 +52,8 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args...
end

function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...)
error("`mutable_operate_to!($op, $(args...))` is not implemented yet.")
error("`mutable_operate_to!($(typeof(output)), $op, ", join(typeof.(args), ", "),
")` is not implemented yet.")
end

"""
Expand Down
60 changes: 58 additions & 2 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
import LinearAlgebra
mutability(::Type{<:Array}) = IsMutable()

mutability(::Type{<:Vector}) = IsMutable()
# Sum

function promote_operation(op::Union{typeof(+), typeof(-)}, ::Type{Array{S, N}}, ::Type{Array{T, N}}) where {S, T, N}
return Array{promote_operation(op, S, T), N}
end
function mutable_operate!(op::Union{typeof(+), typeof(-)}, A::Array{S, N}, B::Array{T, N}) where {S, T, N}
for i in eachindex(A)
A[i] = operate!(op, A[i], B[i])
end
return A
end

# UniformScaling
function promote_operation(op::typeof(+), ::Type{Array{T, 2}}, ::Type{LinearAlgebra.UniformScaling{S}}) where {S, T}
return Array{promote_operation(op, T, S), 2}
end
function promote_operation(op::typeof(+), ::Type{LinearAlgebra.UniformScaling{S}}, ::Type{Array{T, 2}}) where {S, T}
return Array{promote_operation(op, S, T), 2}
end
function mutable_operate!(::typeof(+), A::Matrix, B::LinearAlgebra.UniformScaling)
n = LinearAlgebra.checksquare(A)
for i in 1:n
A[i, i] = operate!(+, A[i, i], B)
end
return A
end
function mutable_operate!(::typeof(add_mul), A::Matrix, B::Scaling, C::Scaling, D::Vararg{Scaling, N}) where N
return mutable_operate!(+, A, *(B, C, D...))
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::Array{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M}
for i in eachindex(A)
A[i] = operate!(add_mul, A[i], B[i], α...)
end
return A
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M}
for i in eachindex(A)
A[i] = operate!(add_mul, A[i], α, B[i], β...)
end
return A
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α1::Scaling, α2::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M}
return mutable_operate!(add_mul, A, α1 * α2, B, β...)
end

# Product

function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N}
return Array{promote_operation(op, T, S), N}
end
function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N}
return Array{promote_operation(op, S, T), N}
end

function promote_operation(::typeof(*), ::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T}
return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)}
end
function promote_operation(::typeof(*), ::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T}
return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)}
end
Expand Down

0 comments on commit c0426a4

Please sign in to comment.