Skip to content

Commit

Permalink
Fix mutable_operate! for array with non-mutable elements
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 16, 2019
1 parent 89694dc commit e5e9691
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 68 deletions.
13 changes: 11 additions & 2 deletions src/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,7 @@ function _mul!(output, A, B, α)
return mutable_operate!(add_mul, output, A, B, scaling(α))
end
function _mul!(output, A, B)
mutable_operate!(zero, output)
return mutable_operate!(add_mul, output, A, B)
return mutable_operate_to!(output, *, A, B)
end

function LinearAlgebra.mul!(ret::AbstractMatrix{<:AbstractMutable},
Expand Down Expand Up @@ -266,3 +265,13 @@ function Matrix(A::LinearAlgebra.Hermitian{<:AbstractMutable})
end
return B
end

# Called in `getindex` of `LinearAlgebra.LowerTriangular` and `LinearAlgebra.UpperTriangular`.
Base.zero(x::AbstractMutable) = zero(typeof(x))

# To determine whether the funtion is zero preserving, `LinearAlgebra` calls
# `zero` on the `eltype` of the broadcasted object and then check `_iszero`.
# `_iszero(x)` redirects to `iszero(x)` for numbers and to `x == 0` otherwise.
# `x == 0` returns false for types that implement `iszero` but not `==` such as
# `DummyBigInt` and MOI functions.
LinearAlgebra._iszero(x::AbstractMutable) = iszero(x)
38 changes: 10 additions & 28 deletions src/linear_algebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,7 @@ end
> WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
=#

function _mut_check(C, A, B)
if mutability(eltype(C), add_mul, eltype(C), eltype(A), eltype(B)) isa NotMutable
error("mutable_operate!(add_mul, ::$(typeof(C)), ::$(typeof(A)), ::$(typeof(B))) not implemented as $(eltype(C)) cannot be mutated to the result.")
end
end

function _dim_check(C::AbstractVector, A::AbstractMatrix, B::AbstractVector)
_mut_check(C, A, B)

mB = length(B)
mA, nA = size(A)
if mB != nA
Expand All @@ -159,8 +151,6 @@ function _dim_check(C::AbstractVector, A::AbstractMatrix, B::AbstractVector)
end

function _dim_check(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
_mut_check(C, A, B)

mB, nB = size(B)
mA, nA = size(A)
if mB != nA
Expand All @@ -176,15 +166,15 @@ function _add_mul_array(C::Vector, A::AbstractMatrix, B::AbstractVector)
# We need a buffer to hold the intermediate multiplication.
mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B))

#@inbounds begin
@inbounds begin
for k = eachindex(B)
aoffs = (k-1) * Astride
b = B[k]
for i = Base.OneTo(size(A, 1))
mutable_buffered_operate!(mul_buffer, add_mul, C[i], A[aoffs + i], b)
C[i] = buffered_operate!(mul_buffer, add_mul, C[i], A[aoffs + i], b)
end
end
#end # @inbounds
end # @inbounds

return C
end
Expand All @@ -193,15 +183,15 @@ end
function _add_mul_array(C::Matrix, A::AbstractMatrix, B::AbstractMatrix)
mul_buffer = buffer_for(add_mul, eltype(C), eltype(A), eltype(B))

#@inbounds begin
@inbounds begin
for i = 1:size(A, 1), j = 1:size(B, 2)
Ctmp = C[i, j]
mutable_operate!(zero, Ctmp)
Ctmp = zero!(C[i, j])
for k = 1:size(A, 2)
mutable_buffered_operate!(mul_buffer, add_mul, Ctmp, A[i, k], B[k, j])
Ctmp = buffered_operate!(mul_buffer, add_mul, Ctmp, A[i, k], B[k, j])
end
C[i, j] = Ctmp
end
#end # @inbounds
end # @inbounds

return C
end
Expand All @@ -214,19 +204,11 @@ end
function mutable_operate!(::typeof(zero), C::Union{Vector, Matrix})
# C may contain undefined values so we cannot call `zero!`
for i in eachindex(C)
#@inbounds C[i] = zero(eltype(C))
C[i] = zero(eltype(C))
@inbounds C[i] = zero(eltype(C))
end
end

function mutable_operate_to!(C::Union{Vector, Matrix}, ::typeof(*), A::AbstractMatrix, B::AbstractVecOrMat)
# If `mutability(S, muladd!, T, U)` is `NotMutable`, we might as well redirect to `LinearAlgebra.mul!(C, A, B)`
# in which case we can do `muladd_buf_impl!(mul_buffer, A[aoffs + i], b, C[i])` here instead of
# `A[aoffs + i] = muladd_buf!(mul_buffer, A[aoffs + i], b, C[i])`
if mutability(eltype(C), add_mul, eltype(C), eltype(A), eltype(B)) isa NotMutable
return LinearAlgebra.mul!(C, A, B)
end

function mutable_operate_to!(C::AbstractArray, ::typeof(*), A::AbstractArray, B::AbstractArray)
mutable_operate!(zero, C)
return mutable_operate!(add_mul, C, A, B)
end
Expand Down
43 changes: 8 additions & 35 deletions src/sparse_arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,7 @@ similar_array_type(::Type{SparseMat{Tv, Ti}}, ::Type{T}) where {T, Tv, Ti} = Spa
function mutable_operate!(::typeof(add_mul), output::SparseMat{T},
A::AbstractMatrix, B::AbstractMatrix) where T
C = Matrix{T}(undef, size(output)...)
mutable_operate!(zero, C)
mutable_operate!(add_mul, C, A, B)
mutable_operate_to!(C, *, A, B)
copyto!(output, C)
return output
end
Expand All @@ -56,11 +55,11 @@ function mutable_operate!(::typeof(add_mul), ret::VecOrMat{T},
for k 1:size(ret, 2)
for col 1:A.n
cur = ret[col, k]
# TODO replace by nzrange
for j A.colptr[col]:(A.colptr[col + 1] - 1)
for j SparseArrays.nzrange(A, col)
A_val = _mirror_transpose_or_adjoint(A_nonzeros[j], adjA)
mutable_operate!(add_mul, cur, A_val, B[A_rowvals[j], k], α...)
cur = operate!(add_mul, cur, A_val, B[A_rowvals[j], k], α...)
end
ret[col, k] = cur
end
end
return ret
Expand All @@ -75,7 +74,7 @@ function mutable_operate!(::typeof(add_mul), ret::VecOrMat{T},
for k 1:size(ret, 2)
αxj = *(B[col,k], α...)
for j SparseArrays.nzrange(A, col)
mutable_operate!(add_mul, ret[A_rowvals[j], k], A_nonzeros[j], αxj)
ret[A_rowvals[j], k] = operate!(add_mul, ret[A_rowvals[j], k], A_nonzeros[j], αxj)
end
end
end
Expand All @@ -91,8 +90,9 @@ function mutable_operate!(::typeof(add_mul), ret::Matrix{T},
for col 1:size(B, 2)
cur = ret[multivec_row, col]
for k SparseArrays.nzrange(B, col)
mutable_operate!(add_mul, cur, A[multivec_row, rowval[k]], B_nonzeros[k], α...)
cur = operate!(add_mul, cur, A[multivec_row, rowval[k]], B_nonzeros[k], α...)
end
ret[multivec_row, col] = cur
end
end
return ret
Expand All @@ -117,7 +117,7 @@ function mutable_operate!(::typeof(add_mul), ret::Matrix{T},
B_val = _mirror_transpose_or_adjoint(B_nonzeros[k], adjB)
αB_val = *(B_val, α...)
for A_row in 1:size(A, 1)
mutable_operate!(add_mul, ret[A_row, B_row], A[A_row, B_col], αB_val)
ret[A_row, B_row] = operate!(add_mul, ret[A_row, B_row], A[A_row, B_col], αB_val)
end
end
return ret
Expand All @@ -136,30 +136,3 @@ function mutable_operate!(::typeof(add_mul), ret::Matrix{T},
# TODO adapt implementation of `SparseArray.spmatmul`
mutable_operate!(add_mul, ret, Matrix{promote_operation(zero, eltype(A))}(A), B, α...)
end

# TODO
#function _densify_with_jump_eltype(x::SparseMat{V}) where {V <: AbstractVariableRef}
# return convert(Matrix{GenericAffExpr{Float64, V}}, x)
#end
#_densify_with_jump_eltype(x::AbstractMatrix) = convert(Matrix, x)
#
## TODO: Implement sparse * sparse code as in base/sparse/linalg.jl (spmatmul).
#function _mul!(ret::AbstractMatrix{<:AbstractMutable},
# A::SparseMat,
# B::SparseMat)
# return mul!(ret, A, _densify_with_jump_eltype(B))
#end
#
## TODO: Implement sparse * sparse code as in base/sparse/linalg.jl (spmatmul).
#function _mul!(ret::AbstractMatrix{<:AbstractMutable},
# A::TransposeOrAdjoint{<:Any, <:SparseMat},
# B::SparseMat)
# return mul!(ret, A, _densify_with_jump_eltype(B))
#end
#
## TODO: Implement sparse * sparse code as in base/sparse/linalg.jl (spmatmul).
#function _mul!(ret::AbstractMatrix{<:AbstractMutable},
# A::SparseMat,
# B::TransposeOrAdjoint{<:Any, <:SparseMat})
# return mul!(ret, _densify_with_jump_eltype(A), B)
#end
10 changes: 7 additions & 3 deletions test/dummy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ end

MA.promote_operation(::typeof(*), ::Type{DummyBigInt}, ::Type{DummyBigInt}) = DummyBigInt
Base.convert(::Type{DummyBigInt}, x::Int) = DummyBigInt(x)
Base.:(==)(x::DummyBigInt, y::DummyBigInt) = x.data == y.data
Base.zero(::Union{DummyBigInt, Type{DummyBigInt}}) = DummyBigInt(zero(BigInt))
Base.one(::Union{DummyBigInt, Type{DummyBigInt}}) = DummyBigInt(one(BigInt))
MA.isequal_canonical(x::DummyBigInt, y::DummyBigInt) = x.data == y.data
Base.iszero(x::DummyBigInt) = iszero(x.data)
# We don't define == to tests that implementation of MA can pass the tests without defining ==.
# This is the case for MOI functions for instance.
# For th same reason, we only define `zero` and `one` for `Type{DummyBigInt}`, not for `DummyBigInt`.
Base.zero(::Type{DummyBigInt}) = DummyBigInt(zero(BigInt))
Base.one(::Type{DummyBigInt}) = DummyBigInt(one(BigInt))
Base.:+(x::DummyBigInt) = DummyBigInt(+x.data)
Base.:+(x::DummyBigInt, y::DummyBigInt) = DummyBigInt(x.data + y.data)
Base.:+(x::DummyBigInt, y::Union{Integer, UniformScaling{<:Integer}}) = DummyBigInt(x.data + y)
Expand Down

0 comments on commit e5e9691

Please sign in to comment.