Skip to content

Commit

Permalink
Merge pull request #14 from JuliaOpt/bl/misc
Browse files Browse the repository at this point in the history
Miscellaneous improvements/fixes
  • Loading branch information
blegat committed Nov 16, 2019
2 parents 964d29d + 89df88e commit c88b584
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 32 deletions.
6 changes: 6 additions & 0 deletions src/MutableArithmetics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@

module MutableArithmetics

# Performance note:
# We use `Vararg` instead of splatting `...` as using `where N` forces Julia to
# specialize in the number of arguments `N`. Otherwise, we get allocations and
# slowdown because it compiles something that works for any `N`. See
# https://github.com/JuliaLang/julia/issues/32761 for details.

include("interface.jl")
include("shortcuts.jl")

Expand Down
14 changes: 7 additions & 7 deletions src/bigint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@ promote_operation(::typeof(one), ::Type{BigInt}) = BigInt
mutable_operate!(::typeof(one), x::BigInt) = Base.GMP.MPZ.set_si!(x, 1)

# +
promote_operation(::typeof(+), ::Type{BigInt}...) = BigInt
promote_operation(::typeof(+), ::Vararg{Type{BigInt}, N}) where {N} = BigInt
function mutable_operate_to!(output::BigInt, ::typeof(+), a::BigInt, b::BigInt)
return Base.GMP.MPZ.add!(output, a, b)
end

# *
promote_operation(::typeof(*), ::Type{BigInt}...) = BigInt
promote_operation(::typeof(*), ::Vararg{Type{BigInt}, N}) where {N} = BigInt
function mutable_operate_to!(output::BigInt, ::typeof(*), a::BigInt, b::BigInt)
return Base.GMP.MPZ.mul!(output, a, b)
end

# add_mul
function mutable_operate_to!(output::BigInt, ::typeof(add_mul), args::BigInt...)
function mutable_operate_to!(output::BigInt, ::typeof(add_mul), args::Vararg{BigInt, N}) where N
return mutable_buffered_operate_to!(BigInt(), output, add_mul, args...)
end
# We use `Vararg` instead of splatting `...` as using `where N` forces Julia to
# specialize in the number of arguments `N`. Otherwise, we get allocations and
# slowdown because it compiles something that works for any `N`. See
# https://github.com/JuliaLang/julia/issues/32761 for details.
function mutable_buffered_operate_to!(buffer::BigInt, output::BigInt, ::typeof(add_mul),
a::BigInt, args::Vararg{BigInt, N}) where N
mutable_operate_to!(buffer, *, args...)
return mutable_operate_to!(output, +, a, buffer)
end

function mutable_operate_to!(output::BigInt, op::Function, a::Integer, b::Integer)
return mutable_operate_to!(output, op, convert(BigInt, a), convert(BigInt, b))
end
31 changes: 17 additions & 14 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ Returns the type returned to the call `operate(op, args...)` where the types of
the arguments `args` are `ArgsTypes`.
"""
function promote_operation end
function promote_operation(op::Function, args::Vararg{Type, N}) where N
return typeof(op(zero.(args)...))
end

# Define Traits
abstract type MutableTrait end
Expand All @@ -36,7 +39,7 @@ function mutable_operate_to_fallback(::NotMutable, output, op::Function, args...
throw(ArgumentError("Cannot call `mutable_operate_to!($output, $op, $(args...))` as `$output` cannot be modifed to equal the result of the operation. Use `operate!` or `operate_to!` instead which returns the value of the result (possibly modifying the first argument) to write generic code that also works when the type cannot be modified."))
end

function mutable_operate_to_fallback(::IsMutable, op::Function, args...)
function mutable_operate_to_fallback(::IsMutable, output, op::Function, args...)
error("`mutable_operate_to!($op, $(args...))` is not implemented yet.")
end

Expand All @@ -46,8 +49,8 @@ end
Modify the value of `output` to be equal to the value of `op(args...)`. Can
only be called if `mutability(output, op, args...)` returns `true`.
"""
function mutable_operate_to!(output, op::Function, args...)
mutable_operate_fallback(mutability(output, op, args...), output, op, args...)
function mutable_operate_to!(output, op::Function, args::Vararg{Any, N}) where N
mutable_operate_to_fallback(mutability(output, op, args...), output, op, args...)
end

"""
Expand All @@ -56,7 +59,7 @@ end
Modify the value of `args[1]` to be equal to the value of `op(args...)`. Can
only be called if `mutability(args[1], op, args...)` returns `true`.
"""
function mutable_operate!(op::Function, args...)
function mutable_operate!(op::Function, args::Vararg{Any, N}) where N
mutable_operate_to!(args[1], op, args...)
end

Expand All @@ -76,7 +79,7 @@ Modify the value of `args[1]` to be equal to the value of `op(args...)`,
possibly modifying `buffer`. Can only be called if
`mutability(args[1], op, args...)` returns `true`.
"""
function mutable_buffered_operate!(buffer, op::Function, args...)
function mutable_buffered_operate!(buffer, op::Function, args::Vararg{Any, N}) where N
mutable_buffered_operate_to!(buffer, args[1], op, args...)
end

Expand All @@ -101,14 +104,14 @@ end
Returns the value of `op(args...)`, possibly modifying `args[1]`.
"""
function operate!(op::Function, args...)
function operate!(op::Function, args::Vararg{Any, N}) where N
return operate_fallback!(mutability(args[1], op, args...), op, args...)
end

function operate_fallback!(::NotMutable, op::Function, args...)
function operate_fallback!(::NotMutable, op::Function, args::Vararg{Any, N}) where N
return op(args...)
end
function operate_fallback!(::IsMutable, op::Function, args...)
function operate_fallback!(::IsMutable, op::Function, args::Vararg{Any, N}) where N
return mutable_operate!(op, args...)
end

Expand All @@ -117,15 +120,15 @@ end
Returns the value of `op(args...)`, possibly modifying `buffer` and `output`.
"""
function buffered_operate_to!(buffer, output, op::Function, args...)
function buffered_operate_to!(buffer, output, op::Function, args::Vararg{Any, N}) where N
return buffered_operate_to_fallback!(mutability(output, op, args...),
buffer, output, op, args...)
end

function buffered_operate_to_fallback!(::NotMutable, buffer, output, op::Function, args...)
function buffered_operate_to_fallback!(::NotMutable, buffer, output, op::Function, args::Vararg{Any, N}) where N
return op(args...)
end
function buffered_operate_to_fallback!(::IsMutable, buffer, output, op::Function, args...)
function buffered_operate_to_fallback!(::IsMutable, buffer, output, op::Function, args::Vararg{Any, N}) where N
return mutable_buffered_operate_to!(buffer, output, op, args...)
end

Expand All @@ -134,14 +137,14 @@ end
Returns the value of `op(args...)`, possibly modifying `buffer`.
"""
function buffered_operate!(buffer, op::Function, args...)
function buffered_operate!(buffer, op::Function, args::Vararg{Any, N}) where N
return buffered_operate_fallback!(mutability(args[1], op, args...),
buffer, op, args...)
end

function buffered_operate_fallback!(::NotMutable, buffer, op::Function, args...)
function buffered_operate_fallback!(::NotMutable, buffer, op::Function, args::Vararg{Any, N}) where N
return op(args...)
end
function buffered_operate_fallback!(::IsMutable, buffer, op::Function, args...)
function buffered_operate_fallback!(::IsMutable, buffer, op::Function, args::Vararg{Any, N}) where N
return mutable_buffered_operate!(buffer, op, args...)
end
14 changes: 7 additions & 7 deletions src/shortcuts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
Return the sum of `b` and `c`, possibly modifying `a`.
"""
add_to!(output, args...) = operate_to!(output, +, args...)
add_to!(output, args::Vararg{Any, N}) where {N} = operate_to!(output, +, args...)

"""
add!(a, b, ...)
Return the sum of `a`, `b`, ..., possibly modifying `a`.
"""
add!(args...) = operate!(+, args...)
add!(args::Vararg{Any, N}) where {N} = operate!(+, args...)

"""
mul_to!(a, b, c)
Expand All @@ -24,7 +24,7 @@ mul_to!(output, args::Vararg{Any, N}) where {N} = operate_to!(output, *, args...
Return the product of `a`, `b`, ..., possibly modifying `a`.
"""
mul!(args...) = operate!(*, args...)
mul!(args::Vararg{Any, N}) where {N} = operate!(*, args...)

"""
add_mul(a, args...)
Expand All @@ -43,21 +43,21 @@ end
Return `add_mul(args...)`, possibly modifying `output`.
"""
add_mul_to!(output, args...) = operate_to!(output, add_mul, args...)
add_mul_to!(output, args::Vararg{Any, N}) where {N} = operate_to!(output, add_mul, args...)

"""
add_mul!(args...)
Return `add_mul(args...)`, possibly modifying `args[1]`.
"""
add_mul!(args...) = operate!(add_mul, args...)
add_mul!(args::Vararg{Any, N}) where {N} = operate!(add_mul, args...)

"""
add_mul_buf_to!(buffer, output, args...)
Return `add_mul(args...)`, possibly modifying `output` and `buffer`.
"""
function add_mul_buf_to!(buffer, output, args...)
function add_mul_buf_to!(buffer, output, args::Vararg{Any, N}) where {N}
buffered_operate_to!(buffer, output, add_mul, args...)
end

Expand All @@ -66,7 +66,7 @@ end
Return `add_mul(args...)`, possibly modifying `args[1]` and `buffer`.
"""
function add_mul_buf!(buffer, args...)
function add_mul_buf!(buffer, args::Vararg{Any, N}) where {N}
buffered_operate!(buffer, add_mul, args...)
end

Expand Down
8 changes: 8 additions & 0 deletions test/bigint.jl
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
MA.Test.int_test(BigInt)

@testset "Allocation" begin
a = BigInt(2)
b = BigInt(3)
c = BigInt(4)
alloc_test(() -> MA.add!(a, b), 0)
alloc_test(() -> MA.add_to!(c, a, b), 0)
end
8 changes: 8 additions & 0 deletions test/int.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
@testset "promote_operation" begin
@test MA.promote_operation(MA.zero, Int) == Int
@test MA.promote_operation(MA.one, Int) == Int
@test MA.promote_operation(+, Int, Int) == Int
@test MA.promote_operation(-, Int, Int) == Int
@test MA.promote_operation(*, Int, Int) == Int
@test MA.promote_operation(MA.add_mul, Int, Int, Int) == Int
end
@testset "add_to! / add!" begin
@test MA.mutability(Int, MA.add_to!, Int, Int) isa MA.NotMutable
@test MA.mutability(Int, MA.add!, Int) isa MA.NotMutable
Expand Down
4 changes: 0 additions & 4 deletions test/matmul.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,3 @@
function alloc_test(f, n)
f() # compile
@test n == @allocated f()
end
@testset "Matrix multiplication" begin
@testset "matrix-vector product" begin
A = BigInt[1 1 1; 1 1 1; 1 1 1]
Expand Down
5 changes: 5 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@ using Test
import MutableArithmetics
const MA = MutableArithmetics

function alloc_test(f, n)
f() # compile
@test n == @allocated f()
end

@testset "Int" begin
include("int.jl")
end
Expand Down

0 comments on commit c88b584

Please sign in to comment.