Skip to content

Commit

Permalink
Implement mutability of Array
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Nov 19, 2019
1 parent 7ee6d28 commit af65f13
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ abstract type AbstractMutable end
# Special-case because the the base version wants to do fill!(::Array{AbstractVariableRef}, zero(GenericAffExpr{Float64,eltype(x)}))
_one_indexed(A) = all(x -> isa(x, Base.OneTo), axes(A))
function LinearAlgebra.diagm(x::AbstractVector{<:AbstractMutable})
@assert _one_indexed(x) # Base.diagm doesn't work for non-one-indexed arrays in general.
@assert _one_indexed(x) # `LinearAlgebra.diagm` doesn't work for non-one-indexed arrays in general.
ZeroType = promote_operation(zero, eltype(x))
return diagm(0 => copyto!(similar(x, ZeroType), x))
return LinearAlgebra.diagm(0 => copyto!(similar(x, ZeroType), x))
end
14 changes: 6 additions & 8 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,10 @@ function promote_operation end
function promote_operation(op::Function, args::Vararg{Type, N}) where N
return typeof(op(zero.(args)...))
end
#promote_operation(::typeof(*), ::Type{T}) where {T} = T
#function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N}
# return Array{promote_operation(op, T, S), N}
#end
#function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N}
# return Array{promote_operation(op, S, T), N}
#end
promote_operation(::typeof(*), ::Type{T}) where {T} = T
function promote_operation(::typeof(*), ::Type{S}, ::Type{T}, ::Type{U}, args::Vararg{Type, N}) where {S, T, U, N}
return promote_operation(*, promote_operation(*, S, T), U, args...)
end

# Define Traits
abstract type MutableTrait end
Expand Down Expand Up @@ -47,7 +44,8 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args...
end

function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...)
error("`mutable_operate_to!($op, $(args...))` is not implemented yet.")
error("`mutable_operate_to!($(typeof(output)), $op, ", join(typeof.(args), ", "),
")` is not implemented yet.")
end

"""
Expand Down
58 changes: 57 additions & 1 deletion src/linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,62 @@
import LinearAlgebra

mutability(::Type{<:Vector}) = IsMutable()
mutability(::Type{<:Array}) = IsMutable()

# Sum

function promote_operation(op::typeof(+), ::Type{Array{S, N}}, ::Type{Array{T, N}}) where {S, T, N}
return Array{promote_operation(op, S, T), N}
end
function mutable_operate!(::typeof(+), A::Array{S, N}, B::Array{T, N}) where{S, T, N}
for i in eachindex(A)
A[i] = operate!(+, A[i], B[i])
end
return A
end

# UniformScaling
const Scaling = Union{Number, LinearAlgebra.UniformScaling}
function promote_operation(op::typeof(+), ::Type{Array{T, 2}}, ::Type{LinearAlgebra.UniformScaling{S}}) where {S, T}
return Array{promote_operation(op, T, S), 2}
end
function promote_operation(op::typeof(+), ::Type{LinearAlgebra.UniformScaling{S}}, ::Type{Array{T, 2}}) where {S, T}
return Array{promote_operation(op, S, T), 2}
end
function mutable_operate!(::typeof(+), A::Matrix, B::LinearAlgebra.UniformScaling)
n = LinearAlgebra.checksquare(A)
for i in 1:n
A[i, i] = operate!(+, A[i, i], B)
end
return A
end
function mutable_operate!(::typeof(add_mul), A::Matrix, B::Scaling, C::Scaling, D::Vararg{Scaling, N}) where N
return mutable_operate!(+, A, *(B, C, D...))
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, B::Array{T, N}, α::Vararg{Scaling, M}) where {S, T, N, M}
for i in eachindex(A)
A[i] = operate!(add_mul, A[i], B[i], α...)
end
return A
end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M}
for i in eachindex(A)
A[i] = operate!(add_mul, A[i], α, B[i], β...)
end
return A
end

# Product

function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N}
return Array{promote_operation(op, T, S), N}
end
function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N}
return Array{promote_operation(op, S, T), N}
end

function promote_operation(::typeof(*), ::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T}
return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)}
end
function promote_operation(::typeof(*), ::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T}
return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)}
end
Expand Down
25 changes: 21 additions & 4 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,22 @@ Base.:(+)(x, zero::Zero) = copy(x)

using Base.Meta

# See `JuMP._try_parse_idx_set`
function _try_parse_idx_set(arg::Expr)
# [i=1] and x[i=1] parse as Expr(:vect, Expr(:(=), :i, 1)) and
# Expr(:ref, :x, Expr(:kw, :i, 1)) respectively.
if arg.head === :kw || arg.head === :(=)
@assert length(arg.args) == 2
return true, arg.args[1], arg.args[2]
elseif isexpr(arg, :call) && arg.args[1] === :in
return true, arg.args[2], arg.args[3]
else
return false, nothing, nothing
end
end

function _parse_idx_set(arg::Expr)
parse_done, idxvar, idxset = Containers._try_parse_idx_set(arg)
parse_done, idxvar, idxset = _try_parse_idx_set(arg)
if parse_done
return idxvar, idxset
end
Expand Down Expand Up @@ -61,6 +75,9 @@ function _parse_gen(ex, atleaf)
return loop
end

# See `JuMP._is_sum`
_is_sum(s::Symbol) = (s == :sum) || (s == :∑) || (s == )

function _parse_generator(x::Expr, aff::Symbol, lcoeffs, rcoeffs, newaff=gensym())
@assert isexpr(x,:call)
@assert length(x.args) > 1
Expand Down Expand Up @@ -175,7 +192,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb
x.args[i] = esc(x.args[i])
end
end
callexpr = Expr(:call, :operate!, add_mul, aff,
callexpr = Expr(:call, :(MutableArithmetics.operate!), add_mul, aff,
lcoeffs..., x.args[2:end]..., rcoeffs...)
push!(blk.args, :($newaff = $callexpr))
return newaff, blk
Expand All @@ -187,7 +204,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb
s = gensym()
newaff_, parsed = _rewrite_toplevel(x.args[2], s)
push!(blk.args, :($s = Zero(); $parsed))
push!(blk.args, :($newaff = operate!(add_mul,
push!(blk.args, :($newaff = MutableArithmetics.operate!(add_mul,
$aff, $(Expr(:call, :*, lcoeffs..., newaff_, newaff_,
rcoeffs...)))))
return newaff, blk
Expand Down Expand Up @@ -225,6 +242,6 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, newaff::Symb
" become a syntax error in a future release."
end
# at the lowest level
callexpr = Expr(:call, :operate!, add_mul, aff, lcoeffs..., esc(x), rcoeffs...)
callexpr = Expr(:call, :(MutableArithmetics.operate!), add_mul, aff, lcoeffs..., esc(x), rcoeffs...)
return newaff, :($newaff = $callexpr)
end
2 changes: 2 additions & 0 deletions src/shortcuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ function promote_operation(::typeof(add_mul), T::Type, args::Vararg{Type, N}) wh
return promote_operation(+, T, promote_operation(*, args...))
end

mutable_operate!(::typeof(add_mul), x, y) = mutable_operate!(+, x, y)

"""
add_mul_to!(output, args...)
Expand Down
4 changes: 4 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ function vectorized_test(x, X11, X23, Xd)
v = [4, 5, 6]

@testset "Sum of matrices" begin
@test_rewrite(x + x)
@test_rewrite(x + 2x)
@test_rewrite(x + x * 2)
@test_rewrite(x + 2x * 2)
@test_rewrite(Xd + Yd)
@test_rewrite(Xd + 2Yd)
@test_rewrite(Xd + Yd * 2)
Expand Down

0 comments on commit af65f13

Please sign in to comment.