Skip to content

Commit

Permalink
Merge f08eba5 into ec70751
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Dec 24, 2019
2 parents ec70751 + f08eba5 commit 71a521d
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 94 deletions.
15 changes: 14 additions & 1 deletion src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,20 @@ Return `a + *(args...)`. Note that `add_mul(a, b, c) = muladd(b, c, a)`.
function add_mul end
add_mul(a, b) = a + b
add_mul(a, b, c) = muladd(b, c, a)
add_mul(a, b, c::Vararg{Any, N}) where {N} = add_mul(a, b *(c...))
add_mul(a, b, c, d, args::Vararg{Any, N}) where {N} = add_mul(a, b, *(c, d, args...))

"""
sub_mul(a, args...)
Return `a + *(args...)`. Note that `sub_mul(a, b, c) = muladd(b, c, a)`.
"""
function sub_mul end
sub_mul(a, b) = a - b
sub_mul(a, b, c, args::Vararg{Any, N}) where {N} = a - *(b, c, args...)

const AddSubMul = Union{typeof(add_mul), typeof(sub_mul)}
add_sub_op(::typeof(add_mul)) = +
add_sub_op(::typeof(sub_mul)) = -

"""
iszero!(x)
Expand Down
67 changes: 65 additions & 2 deletions src/Test/int.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,30 @@ function int_add_test(::Type{T}) where T

a = t(165)
b = t(255)
@test MA.isequal_canonical(MA.add!(a, b), t(420))
@test MA.isequal_canonical(a, t(420))
expected = t(420)
@test MA.isequal_canonical(MA.add!(a, b), expected)
@test MA.isequal_canonical(a, expected)
end
end
function int_sub_test(::Type{T}) where T
@testset "sub_to! / sub!" begin
@test MA.mutability(T, -, T, T) isa MA.IsMutable

t(n) = convert(T, n)
a = t(5)
b = t(28)
c = t(41)
expected = t(-13)
@test MA.isequal_canonical(MA.sub_to!(a, b, c), expected)
@test MA.isequal_canonical(a, expected)
@test MA.isequal_canonical(MA.sub!(b, c), expected)
@test MA.isequal_canonical(b, expected)

a = t(165)
b = t(255)
expected = t(-90)
@test MA.isequal_canonical(MA.sub!(a, b), expected)
@test MA.isequal_canonical(a, expected)
end
end
function int_mul_test(::Type{T}) where T
Expand Down Expand Up @@ -76,6 +98,47 @@ function int_add_mul_test(::Type{T}) where T
@test MA.isequal_canonical(a, t(420))
end
end
function int_sub_mul_test(::Type{T}) where T
@testset "sub_mul_to! / sub_mul! / sub_mul_buf_to! / sub_mul_buf!" begin
@test MA.mutability(T, MA.sub_mul, T, T) isa MA.IsMutable
@test MA.mutability(T, MA.sub_mul, T, T, T) isa MA.IsMutable
@test MA.mutability(T, MA.sub_mul, T, T, T, T) isa MA.IsMutable

t(n) = convert(T, n)
a = t(5)
b = t(9)
c = t(3)
d = t(20)
buf = t(24)

expected = t(-51)
@test MA.isequal_canonical(MA.sub_mul_to!(a, b, c, d), expected)
@test MA.isequal_canonical(a, expected)
a = t(5)
@test MA.isequal_canonical(MA.sub_mul!(b, c, d), expected)
@test MA.isequal_canonical(b, expected)
b = t(9)

@test MA.isequal_canonical(MA.sub_mul_buf_to!(buf, a, b, c, d), expected)
@test MA.isequal_canonical(a, expected)
@test MA.isequal_canonical(MA.sub_mul_buf!(buf, b, c, d), expected)
@test MA.isequal_canonical(b, expected)

a = t(148)
b = t(16)
c = t(17)
d = t(42)
buf = t(56)
expected = t(-124)
@test MA.isequal_canonical(MA.sub_mul!(a, b, c), expected)
@test MA.isequal_canonical(a, expected)
a = t(148)
@test MA.isequal_canonical(MA.sub_mul_buf_to!(buf, d, a, b, c), expected)
@test MA.isequal_canonical(d, expected)
@test MA.isequal_canonical(MA.sub_mul_buf!(buf, a, b, c), expected)
@test MA.isequal_canonical(a, expected)
end
end

function int_zero_test(::Type{T}) where T
@testset "zero!" begin
Expand Down
31 changes: 19 additions & 12 deletions src/bigfloat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,21 @@ end
# return mutable_operate_to!(output, op, a, b.λ)
#end

# -
promote_operation(::typeof(-), ::Vararg{Type{BigFloat}, N}) where {N} = BigFloat
function mutable_operate_to!(output::BigFloat, ::typeof(-), a::BigFloat, b::BigFloat)
ccall((:mpfr_sub, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Ref{BigFloat}, MPFRRoundingMode), output, a, b, Base.MPFR.ROUNDING_MODE[])
return output
end

# *
promote_operation(::typeof(*), ::Vararg{Type{BigFloat}, N}) where {N} = BigFloat
function mutable_operate_to!(output::BigFloat, ::typeof(*), a::BigFloat, b::BigFloat)
ccall((:mpfr_mul, :libmpfr), Int32, (Ref{BigFloat}, Ref{BigFloat}, Ref{BigFloat}, MPFRRoundingMode), output, a, b, Base.MPFR.ROUNDING_MODE[])
return output
end

function mutable_operate_to!(output::BigFloat, op::Union{typeof(*), typeof(+)},
function mutable_operate_to!(output::BigFloat, op::Union{typeof(+), typeof(-), typeof(*)},
a::BigFloat, b::BigFloat, c::Vararg{BigFloat, N}) where N
mutable_operate_to!(output, op, a, b)
return mutable_operate!(op, output, c...)
Expand All @@ -45,34 +52,34 @@ function mutable_operate!(op::Function, x::BigFloat, args::Vararg{Any, N}) where
mutable_operate_to!(x, op, x, args...)
end

# add_mul
# add_mul and sub_mul
# Buffer to hold the product
buffer_for(::typeof(add_mul), args::Vararg{Type{BigFloat}, N}) where {N} = BigFloat()
function mutable_operate_to!(output::BigFloat, ::typeof(add_mul), x::BigFloat, y::BigFloat, z::BigFloat, args::Vararg{BigFloat, N}) where N
return mutable_buffered_operate_to!(BigFloat(), output, add_mul, x, y, z, args...)
buffer_for(::AddSubMul, args::Vararg{Type{BigFloat}, N}) where {N} = BigFloat()
function mutable_operate_to!(output::BigFloat, op::AddSubMul, x::BigFloat, y::BigFloat, z::BigFloat, args::Vararg{BigFloat, N}) where N
return mutable_buffered_operate_to!(BigFloat(), output, op, x, y, z, args...)
end

function mutable_buffered_operate_to!(buffer::BigFloat, output::BigFloat, ::typeof(add_mul),
function mutable_buffered_operate_to!(buffer::BigFloat, output::BigFloat, op::AddSubMul,
a::BigFloat, x::BigFloat, y::BigFloat, args::Vararg{BigFloat, N}) where N
mutable_operate_to!(buffer, *, x, y, args...)
return mutable_operate_to!(output, +, a, buffer)
return mutable_operate_to!(output, add_sub_op(op), a, buffer)
end
function mutable_buffered_operate!(buffer::BigFloat, op::typeof(add_mul), x::BigFloat, args::Vararg{Any, N}) where N
function mutable_buffered_operate!(buffer::BigFloat, op::AddSubMul, x::BigFloat, args::Vararg{Any, N}) where N
return mutable_buffered_operate_to!(buffer, x, op, x, args...)
end

scaling_to_bigfloat(x::BigFloat) = x
scaling_to_bigfloat(x::Number) = convert(BigFloat, x)
scaling_to_bigfloat(J::LinearAlgebra.UniformScaling) = scaling_to_bigfloat(J.λ)
function mutable_operate_to!(output::BigFloat, op::Union{typeof(+), typeof(*)}, args::Vararg{Scaling, N}) where N
function mutable_operate_to!(output::BigFloat, op::Union{typeof(+), typeof(-), typeof(*)}, args::Vararg{Scaling, N}) where N
return mutable_operate_to!(output, op, scaling_to_bigfloat.(args)...)
end
function mutable_operate_to!(output::BigFloat, op::typeof(add_mul), x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N
function mutable_operate_to!(output::BigFloat, op::AddSubMul, x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N
return mutable_operate_to!(
output, op, scaling_to_bigfloat(x), scaling_to_bigfloat(y),
scaling_to_bigfloat(z), scaling_to_bigfloat.(args)...)
end
# Called for instance if `args` is `(v', v)` for a vector `v`.
function mutable_operate_to!(output::BigFloat, op::typeof(add_mul), x, y, z, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, +, x, *(y, z, args...))
function mutable_operate_to!(output::BigFloat, op::AddSubMul, x, y, z, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, add_sub_op(op), x, *(y, z, args...))
end
30 changes: 18 additions & 12 deletions src/bigint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,19 @@ end
# return mutable_operate_to!(output, op, a, b.λ)
#end

# -
promote_operation(::typeof(-), ::Vararg{Type{BigInt}, N}) where {N} = BigInt
function mutable_operate_to!(output::BigInt, ::typeof(-), a::BigInt, b::BigInt)
return Base.GMP.MPZ.sub!(output, a, b)
end

# *
promote_operation(::typeof(*), ::Vararg{Type{BigInt}, N}) where {N} = BigInt
function mutable_operate_to!(output::BigInt, ::typeof(*), a::BigInt, b::BigInt)
return Base.GMP.MPZ.mul!(output, a, b)
end

function mutable_operate_to!(output::BigInt, op::Union{typeof(*), typeof(+)},
function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(-), typeof(*)},
a::BigInt, b::BigInt, c::Vararg{BigInt, N}) where N
mutable_operate_to!(output, op, a, b)
return mutable_operate!(op, output, c...)
Expand All @@ -33,34 +39,34 @@ function mutable_operate!(op::Function, x::BigInt, args::Vararg{Any, N}) where N
mutable_operate_to!(x, op, x, args...)
end

# add_mul
# add_mul and sub_mul
# Buffer to hold the product
buffer_for(::typeof(add_mul), args::Vararg{Type{BigInt}, N}) where {N} = BigInt()
function mutable_operate_to!(output::BigInt, ::typeof(add_mul), x::BigInt, y::BigInt, z::BigInt, args::Vararg{BigInt, N}) where N
return mutable_buffered_operate_to!(BigInt(), output, add_mul, x, y, z, args...)
buffer_for(::AddSubMul, args::Vararg{Type{BigInt}, N}) where {N} = BigInt()
function mutable_operate_to!(output::BigInt, op::AddSubMul, x::BigInt, y::BigInt, z::BigInt, args::Vararg{BigInt, N}) where N
return mutable_buffered_operate_to!(BigInt(), output, op, x, y, z, args...)
end

function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, ::typeof(add_mul),
function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, op::AddSubMul,
a::BigInt, x::BigInt, y::BigInt, args::Vararg{BigInt, N}) where N
mutable_operate_to!(buffer, *, x, y, args...)
return mutable_operate_to!(output, +, a, buffer)
return mutable_operate_to!(output, add_sub_op(op), a, buffer)
end
function mutable_buffered_operate!(buffer::BigInt, op::typeof(add_mul), x::BigInt, args::Vararg{Any, N}) where N
function mutable_buffered_operate!(buffer::BigInt, op::AddSubMul, x::BigInt, args::Vararg{Any, N}) where N
return mutable_buffered_operate_to!(buffer, x, op, x, args...)
end

scaling_to_bigint(x::BigInt) = x
scaling_to_bigint(x::Number) = convert(BigInt, x)
scaling_to_bigint(J::LinearAlgebra.UniformScaling) = scaling_to_bigint(J.λ)
function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(*)}, args::Vararg{Scaling, N}) where N
function mutable_operate_to!(output::BigInt, op::Union{typeof(+), typeof(-), typeof(*)}, args::Vararg{Scaling, N}) where N
return mutable_operate_to!(output, op, scaling_to_bigint.(args)...)
end
function mutable_operate_to!(output::BigInt, op::typeof(add_mul), x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N
function mutable_operate_to!(output::BigInt, op::AddSubMul, x::Scaling, y::Scaling, z::Scaling, args::Vararg{Scaling, N}) where N
return mutable_operate_to!(
output, op, scaling_to_bigint(x), scaling_to_bigint(y),
scaling_to_bigint(z), scaling_to_bigint.(args)...)
end
# Called for instance if `args` is `(v', v)` for a vector `v`.
function mutable_operate_to!(output::BigInt, op::typeof(add_mul), x, y, z, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, +, x, *(y, z, args...))
function mutable_operate_to!(output::BigInt, op::AddSubMul, x, y, z, args::Vararg{Any, N}) where N
return mutable_operate_to!(output, add_sub_op(op), x, *(y, z, args...))
end
14 changes: 7 additions & 7 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ function promote_operation(::typeof(*), ::Type{S}, ::Type{T}, ::Type{U}, args::V
end

# Helpful error for common mistake
function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, A::Type{<:Array}, α::Type{<:Number})
function promote_operation(op::Union{typeof(+), typeof(-), AddSubMul}, A::Type{<:Array}, α::Type{<:Number})
error("Operation `$op` between `$A` and `` is not allowed. You should use broadcast.")
end
function promote_operation(op::Union{typeof(+), typeof(-), typeof(add_mul)}, α::Type{<:Number}, A::Type{<:Array})
function promote_operation(op::Union{typeof(+), typeof(-), AddSubMul}, α::Type{<:Number}, A::Type{<:Array})
error("Operation `$op` between `` and `$A` is not allowed. You should use broadcast.")
end

Expand Down Expand Up @@ -85,7 +85,7 @@ function operate end
# API without altering `x` and `y`. If it is not the case, implement a
# custom `operate` method.
operate(::typeof(-), x) = -x
operate(op::Union{typeof(+), typeof(-), typeof(*), typeof(add_mul)}, x, y, args::Vararg{Any, N}) where {N} = op(x, y, args...)
operate(op::Union{typeof(+), typeof(-), typeof(*), AddSubMul}, x, y, args::Vararg{Any, N}) where {N} = op(x, y, args...)

operate(::Union{typeof(+), typeof(*)}, x) = copy_if_mutable(x)

Expand Down Expand Up @@ -169,8 +169,8 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args...
throw(ArgumentError("Cannot call `mutable_operate_to!(::$(typeof(output)), $op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(output))` cannot be modifed to equal the result of the operation. Use `operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified."))
end

function mutable_operate_to_fallback(::IsMutable, output, op::typeof(add_mul), x, y)
return mutable_operate_to!(output, +, x, y)
function mutable_operate_to_fallback(::IsMutable, output, op::AddSubMul, x, y)
return mutable_operate_to!(output, add_sub_op(op), x, y)
end
function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...)
error("`mutable_operate_to!(::$(typeof(output)), $op, ::", join(typeof.(args), ", ::"),
Expand Down Expand Up @@ -201,8 +201,8 @@ function mutable_operate_fallback(::NotMutable, op::Function, args...)
throw(ArgumentError("Cannot call `mutable_operate!($op, ::$(join(typeof.(args), ", ::")))` as objects of type `$(typeof(args[1]))` cannot be modifed to equal the result of the operation. Use `operate!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified."))
end

function mutable_operate_fallback(::IsMutable, op::typeof(add_mul), x, y)
return mutable_operate!(+, x, y)
function mutable_operate_fallback(::IsMutable, op::AddSubMul, x, y)
return mutable_operate!(add_sub_op(op), x, y)
end
function mutable_operate_fallback(::IsMutable, op::Function, args...)
error("`mutable_operate!($op, ::", join(typeof.(args), ", ::"),
Expand Down

0 comments on commit 71a521d

Please sign in to comment.