Skip to content

Commit

Permalink
Apply vararg fastmath methods to all types
Browse files Browse the repository at this point in the history
  • Loading branch information
Zentrik committed May 15, 2024
1 parent 5f68f1a commit 769e583
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
36 changes: 27 additions & 9 deletions base/fastmath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ export @fastmath
import Core.Intrinsics: sqrt_llvm_fast, neg_float_fast,
add_float_fast, sub_float_fast, mul_float_fast, div_float_fast,
eq_float_fast, ne_float_fast, lt_float_fast, le_float_fast
import Base: afoldl

const fast_op =
Dict(# basic arithmetic
Expand Down Expand Up @@ -168,11 +169,6 @@ sub_fast(x::T, y::T) where {T<:FloatTypes} = sub_float_fast(x, y)
# Two-argument fast-math multiplication for operands of one identical
# floating-point type; forwards directly to the `mul_float_fast` intrinsic.
function mul_fast(x::T, y::T) where {T<:FloatTypes}
    return mul_float_fast(x, y)
end

# Two-argument fast-math division for operands of one identical
# floating-point type; forwards directly to the `div_float_fast` intrinsic.
function div_fast(x::T, y::T) where {T<:FloatTypes}
    return div_float_fast(x, y)
end

add_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} =
add_fast(add_fast(x, y), zs...)
mul_fast(x::T, y::T, zs::T...) where {T<:FloatTypes} =
mul_fast(mul_fast(x, y), zs...)

@fastmath begin
cmp_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(x==y, 0, ifelse(x<y, -1, +1))
log_fast(b::T, x::T) where {T<:FloatTypes} = log_fast(x)/log_fast(b)
Expand Down Expand Up @@ -245,9 +241,6 @@ ComplexTypes = Union{ComplexF32, ComplexF64}
max_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, y, x)
min_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, x, y)
minmax_fast(x::T, y::T) where {T<:FloatTypes} = ifelse(y > x, (x,y), (y,x))

max_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = max_fast(max_fast(x, y), z...)
min_fast(x::T, y::T, z::T...) where {T<:FloatTypes} = min_fast(min_fast(x, y), z...)
end

# fall-back implementations and type promotion
Expand All @@ -260,7 +253,7 @@ for op in (:abs, :abs2, :conj, :inv, :sign)
end
end

for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
for op in (:-, :/, :(==), :!=, :<, :<=, :cmp, :rem, :minmax)
op_fast = fast_op[op]
@eval begin
# fall-back implementation for non-numeric types
Expand All @@ -273,6 +266,31 @@ for op in (:+, :-, :*, :/, :(==), :!=, :<, :<=, :cmp, :rem, :min, :max, :minmax)
end
end

# Generate the generic methods for the fast variants of `+`, `*`, `min` and
# `max`. Besides the two-argument fall-backs, these operators also get a
# one-argument method and methods for three or more arguments, so that chained
# calls such as `@fastmath a + b + c` keep going through the `*_fast` functions.
for op in (:+, :*, :min, :max)
    op_fast = fast_op[op]
    @eval begin
        # single-argument call: forward to the regular operator
        $op_fast(x) = $op(x)
        # fall-back implementation for non-numeric types
        $op_fast(x, y) = $op(x, y)
        # type promotion
        $op_fast(x::Number, y::Number) =
            $op_fast(promote(x, y)...)
        # fall-back implementation that applies after promotion
        $op_fast(x::T, y::T) where {T<:Number} = $op(x, y)

        # note: these definitions must not cause a dispatch loop when +(a,b) is
        # not defined, and must only try to call 2-argument definitions, so
        # that defining +(a,b) is sufficient for full functionality.
        function ($op_fast)(a, b, c, xs...)
            @inline
            # left fold: ((a op b) op c) op xs[1] ...
            return afoldl($op_fast, ($op_fast)(($op_fast)(a, b), c), xs...)
        end

        # a further concern is that it's easy for a type like (Int,Int...)
        # to match many definitions, so we need to keep the number of
        # definitions down to avoid losing type information.

        # type promotion: bring every argument to a common type, then
        # re-dispatch to the same-type method below. All of `a`, `b`, `c` and
        # `xs` must be forwarded; the parameters here are named `a`/`b`, not
        # `x`/`y` as in the two-argument methods.
        $op_fast(a::Number, b::Number, c::Number, xs::Number...) =
            $op_fast(promote(a, b, c, xs...)...)

        # fall-back implementation that applies after promotion
        function ($op_fast)(a::T, b::T, c::T, xs::T...) where {T<:Number}
            @inline
            return afoldl($op_fast, ($op_fast)(($op_fast)(a, b), c), xs...)
        end
    end
end

# Math functions
exp2_fast(x::Union{Float32,Float64}) = Base.Math.exp2_fast(x)
Expand Down
23 changes: 23 additions & 0 deletions test/fastmath.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,30 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

using InteractiveUtils: code_llvm
# fast math

# Checks that three-operand `@fastmath` expressions produce LLVM IR carrying
# fast-math flags. Each case prints the IR of a small function with
# `code_llvm` and searches the resulting text for a marker substring.
@testset "check fast present in LLVM" begin
# real and complex floating-point element types
for T in (Float16, Float32, Float64, ComplexF32, ComplexF64)
f(x) = @fastmath x + x + x
llvm = sprint(code_llvm, f, (T,))
# NOTE(review): the sum is matched against `fmul fast`, not `fadd fast` —
# presumably the optimizer rewrites x + x + x as a multiplication; confirm
# this is intended.
@test occursin("fmul fast", llvm)

g(x) = @fastmath x * x * x
llvm = sprint(code_llvm, g, (T,))
@test occursin("fmul fast", llvm)
end

# three-argument min/max, real floating-point types only
for T in (Float16, Float32, Float64)
f(x, y, z) = @fastmath min(x, y, z)
llvm = sprint(code_llvm, f, (T,T,T))
# NOTE(review): this only looks for the substring `fast`, which is looser
# than the `fcmp fast` check used for `max` below — confirm that is deliberate.
@test occursin("fast", llvm)

g(x, y, z) = @fastmath max(x, y, z)
llvm = sprint(code_llvm, g, (T,T,T))
@test occursin("fcmp fast", llvm)
end
end

@testset "check expansions" begin
@test macroexpand(Main, :(@fastmath 1+2)) == :(Base.FastMath.add_fast(1,2))
@test macroexpand(Main, :(@fastmath +)) == :(Base.FastMath.add_fast)
Expand Down

0 comments on commit 769e583

Please sign in to comment.