Skip to content

Commit

Permalink
Add basic support for SparseArrays
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Nov 23, 2019
1 parent fe9575b commit c3dd3fe
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 27 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
Expand Down
14 changes: 6 additions & 8 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,6 @@ module MutableArithmetics
# slowdown because it compiles something that works for any `N`. See
# https://github.com/JuliaLang/julia/issues/32761 for details.

# `copy(::BigInt)` and `copy(::Array)` does not copy its elements so we need `deepcopy`.
mutable_copy(x) = deepcopy(x)
mutable_copy(A::AbstractArray) = mutable_copy.(A)

"""
add_mul(a, args...)
Expand All @@ -36,10 +32,10 @@ include("Test/Test.jl")
# Implementation of the interface for Base types
import LinearAlgebra
const Scaling = Union{Number, LinearAlgebra.UniformScaling}
mutable_copy(A::LinearAlgebra.Symmetric) = LinearAlgebra.Symmetric(mutable_copy(parent(A)), ifelse(A.uplo == 'U', :U, :L))
# Broadcast applies the transpose
mutable_copy(A::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(mutable_copy(parent(A)))
mutable_copy(A::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(mutable_copy(parent(A)))
#mutable_copy(A::LinearAlgebra.Symmetric) = LinearAlgebra.Symmetric(mutable_copy(parent(A)), ifelse(A.uplo == 'U', :U, :L))
## Broadcast applies the transpose
#mutable_copy(A::LinearAlgebra.Transpose) = LinearAlgebra.Transpose(mutable_copy(parent(A)))
#mutable_copy(A::LinearAlgebra.Adjoint) = LinearAlgebra.Adjoint(mutable_copy(parent(A)))
include("bigint.jl")
include("linear_algebra.jl")

Expand All @@ -49,4 +45,6 @@ include("rewrite.jl")

include("dispatch.jl")

include("sparse_arrays.jl")

end # module
1 change: 1 addition & 0 deletions src/bigint.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mutability(::Type{BigInt}) = IsMutable()
mutable_copy(x::BigInt) = deepcopy(x)

# zero
promote_operation(::typeof(zero), ::Type{BigInt}) = BigInt
Expand Down
8 changes: 8 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ end
mutability(x, op, args::Vararg{Any, N}) where {N} = mutability(typeof(x), op, typeof.(args)...)
mutability(::Type) = NotMutable()

# `copy(::BigInt)` and `copy(::Array)` does not copy its elements so we need `deepcopy`.
function mutable_copy end
mutable_copy(x) = deepcopy(x)
mutable_copy(A::AbstractArray) = mutable_copy.(A)
copy_if_mutable_fallback(::NotMutable, x) = x
copy_if_mutable_fallback(::IsMutable, x) = mutable_copy(x)
copy_if_mutable(x) = copy_if_mutable_fallback(mutability(typeof(x)), x)

function mutable_operate_to_fallback(::NotMutable, output, op::Function, args...)
throw(ArgumentError("Cannot call `mutable_operate_to!($output, $op, $(args...))` as `$output` cannot be modifed to equal the result of the operation. Use `operate!` or `operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified."))
end
Expand Down
22 changes: 16 additions & 6 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
mutability(::Type{<:Array}) = IsMutable()
mutable_copy(A::Array) = copy_if_mutable.(A)

# Sum

Expand Down Expand Up @@ -44,20 +45,29 @@ end
function mutable_operate!(::typeof(add_mul), A::Array{S, N}, α1::Scaling, α2::Scaling, B::Array{T, N}, β::Vararg{Scaling, M}) where {S, T, N, M}
return mutable_operate!(add_mul, A, α1 * α2, B, β...)
end
# Fallback, we may be able to be more efficient in more cases by adding more specialized methods
function mutable_operate!(::typeof(add_mul), A::Array, x, y, args::Vararg{Any, N}) where N
return mutable_operate!(+, A, *(x, y, args...))
end

# Product

function promote_operation(op::typeof(*), ::Type{Array{T, N}}, ::Type{S}) where {S, T, N}
return Array{promote_operation(op, T, S), N}
similar_array_type(::Type{Array{T, N}}, ::Type{S}) where {S, T, N} = Array{S, N}
function promote_operation(op::typeof(*), A::Type{<:AbstractArray{T}}, ::Type{S}) where {S, T}
return similar_array_type(A, promote_operation(op, T, S))
end
function promote_operation(op::typeof(*), ::Type{S}, ::Type{Array{T, N}}) where {S, T, N}
return Array{promote_operation(op, S, T), N}
function promote_operation(op::typeof(*), ::Type{S}, A::Type{<:AbstractArray{T}}) where {S, T}
return similar_array_type(A, promote_operation(op, S, T))
end
# `{S}` and `{T}` are used to avoid ambiguity with above methods.
function promote_operation(op::typeof(*), A::Type{<:AbstractArray{S}}, B::Type{<:AbstractArray{T}}) where {S, T}
return promote_array_mul(A, B)
end

function promote_operation(::typeof(*), ::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T}
function promote_array_mul(::Type{Matrix{S}}, ::Type{Vector{T}}) where {S, T}
return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)}
end
function promote_operation(::typeof(*), ::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T}
function promote_array_mul(::Type{<:AbstractMatrix{S}}, ::Type{<:AbstractVector{T}}) where {S, T}
return Vector{Base.promote_op(LinearAlgebra.matprod, S, T)}
end
function mutable_operate_to!(C::Vector, ::typeof(*), A::AbstractMatrix, B::AbstractVector)
Expand Down
26 changes: 13 additions & 13 deletions src/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@

export @rewrite
macro rewrite(expr)
return rewrite(expr)
return rewrite_and_return(expr)
end

struct Zero end
# We need to copy `x` as it will be used as might be given by the user and be
# given as first argument of `operate!`.
Base.:(+)(zero::Zero, x) = mutable_copy(x)
Base.:(+)(zero::Zero, x) = copy_if_mutable(x)
# `add_mul(zero, ...)` redirects to `muladd(..., zero)` which calls `... + zero`.
Base.:(+)(x, zero::Zero) = mutable_copy(x)
Base.:(+)(x, zero::Zero) = copy_if_mutable(x)

using Base.Meta

Expand Down Expand Up @@ -100,17 +100,17 @@ end

_is_complex_expr(ex) = isa(ex, Expr) && !isexpr(ex, :ref)

function rewrite_and_return(x)
variable, code = rewrite(x)
return :($code; $variable)
end
function rewrite(x)
variable = gensym()
new_variable, code = _rewrite_toplevel(x, variable)
return quote
$variable = MutableArithmetics.Zero()
$code
$new_variable
end
new_variable, code = rewrite_to(x, variable)
return new_variable, :($variable = MutableArithmetics.Zero(); $code)
end

_rewrite_toplevel(x, variable::Symbol) = _rewrite(x, variable, [], [])
rewrite_to(x, variable::Symbol) = _rewrite(x, variable, [], [])

function _is_comparison(ex::Expr)
if isexpr(ex, :comparison)
Expand Down Expand Up @@ -192,7 +192,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Sym
for i in 2:length(x.args)
if _is_complex_expr(x.args[i])
s = gensym()
new_var_, parsed = _rewrite_toplevel(x.args[i], s)
new_var_, parsed = rewrite_to(x.args[i], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
x.args[i] = new_var_
else
Expand All @@ -209,7 +209,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Sym
if x.args[3] == 2
blk = Expr(:block)
s = gensym()
new_var_, parsed = _rewrite_toplevel(x.args[2], s)
new_var_, parsed = rewrite_to(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs..., new_var_, new_var_,
Expand All @@ -222,7 +222,7 @@ function _rewrite(x, aff::Symbol, lcoeffs::Vector, rcoeffs::Vector, new_var::Sym
else
blk = Expr(:block)
s = gensym()
new_var_, parsed = _rewrite_toplevel(x.args[2], s)
new_var_, parsed = rewrite_to(x.args[2], s)
push!(blk.args, :($s = MutableArithmetics.Zero(); $parsed))
push!(blk.args, :($new_var = MutableArithmetics.add_mul!(
$aff, $(Expr(:call, :*, lcoeffs...,
Expand Down
3 changes: 3 additions & 0 deletions src/sparse_arrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import SparseArrays
similar_array_type(::Type{SparseArrays.SparseVector{Tv, Ti}}, ::Type{T}) where {T, Tv, Ti} = SparseArrays.SparseVector{T, Ti}
similar_array_type(::Type{SparseArrays.SparseMatrixCSC{Tv, Ti}}, ::Type{T}) where {T, Tv, Ti} = SparseArrays.SparseMatrixCSC{T, Ti}
15 changes: 15 additions & 0 deletions test/rewrite.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,21 @@ function vectorized_test(x, X11, X23, Xd)
@test MA.isequal_canonical(A*X, B*X)
@test MA.isequal_canonical(A*X', B*X')
@test MA.isequal_canonical(A'*X, B'*X)

A = [1 2 3
0 4 5
6 0 7]
B = sparse(A)

@test_rewrite reshape(x, (1, 3)) * A * x .- 1
@test_rewrite x'*A*x .- 1
@test_rewrite x'*B*x .- 1
for (A1, A2) in [(A, A), (A, B), (B, A), (B, B)]
@test_rewrite (x'A1)' + 2A2*x
@test_rewrite (x'A1)' + 2A2*x .- 1
@test_rewrite (x'A1)' + 2A2*x .- [3:-1:1;]
@test_rewrite (x'A1)' + 2A2*x - [3:-1:1;]
end
end

function broadcast_test(x)
Expand Down

0 comments on commit c3dd3fe

Please sign in to comment.