Skip to content

Commit

Permalink
Test FastMath
Browse files Browse the repository at this point in the history
  • Loading branch information
oxinabox committed May 7, 2020
1 parent e35ef39 commit 5273b41
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 60 deletions.
63 changes: 3 additions & 60 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
@testset "base" begin
@testset "Trig" begin
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
test_scalar(sin, x)
test_scalar(cos, x)
test_scalar(tan, x)
test_scalar(sec, x)
test_scalar(csc, x)
test_scalar(cot, x)
test_scalar(sinpi, x)
test_scalar(cospi, x)
end
@testset "Hyperbolic" for x = (Float64(π)-0.01, Complex-0.01, π/2))
test_scalar(sinh, x)
test_scalar(cosh, x)
test_scalar(tanh, x)
test_scalar(sech, x)
test_scalar(csch, x)
test_scalar(coth, x)
Expand All @@ -28,9 +22,6 @@
test_scalar(cotd, x)
end
@testset "Inverses" for x = (0.5, Complex(0.5, 0.25))
test_scalar(asin, x)
test_scalar(acos, x)
test_scalar(atan, x)
test_scalar(asec, 1/x)
test_scalar(acsc, 1/x)
test_scalar(acot, 1/x)
Expand All @@ -52,55 +43,25 @@
test_scalar(acscd, 1/x)
test_scalar(acotd, 1/x)
end
@testset "Multivariate" begin
@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig

@testset "math" begin
@testset "Angles" begin
for x in (-0.1, 6.4)
test_scalar(deg2rad, x)
test_scalar(rad2deg, x)

test_scalar(inv, x)

test_scalar(exp, x)
test_scalar(exp2, x)
test_scalar(exp10, x)

test_scalar(cbrt, x)

if x >= 0
test_scalar(sqrt, x)
test_scalar(log, x)
test_scalar(log2, x)
test_scalar(log10, x)
test_scalar(log1p, x)
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4)
test_scalar(real, x)
test_scalar(imag, x)

test_scalar(abs, x)
test_scalar(hypot, x)

test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)
test_scalar(adjoint, x)
end
end


@testset "*(x, y) (scalar)" begin
# This is pretty important so testing it fairly heavily
test_points = (0.0, -2.1, 3.2, 3.7+2.12im, 14.2-7.1im)
Expand Down Expand Up @@ -132,7 +93,7 @@
@test extern(dy) == extern(zeros(2, 5) .+ dy)
end

@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
@testset "binary function ($f)" for f in (mod, \)
x, Δx, x̄ = 10rand(3)
y, Δy, ȳ = rand(3)
Δz = rand()
Expand Down Expand Up @@ -166,24 +127,6 @@
test_scalar(zero, x)
end

@testset "sign" begin
@testset "at points" for x in (-1.1, -1.1, 0.5, 100)
test_scalar(sign, x)
end

@testset "Zero over the point discontinuity" begin
# Can't do finite differencing because we are lying
# following the subgradient convention.

_, pb = rrule(sign, 0.0)
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, ẏ = frule((Zero(), 10.5), sign, 0.0)
@test extern(ẏ) == 0
end
end

@testset "trinary ($f)" for f in (muladd, fma)
x, Δx, x̄ = 10randn(3)
y, Δy, ȳ = randn(3)
Expand Down
107 changes: 107 additions & 0 deletions test/rulesets/Base/fastmath_able.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Add tests to the quote for functions with FastMath varients.
const FASTABLE_AST = quote
@testset "Trig" begin
@testset "Basics" for x = (Float64(π)-0.01, Complex(π, π/2))
test_scalar(sin, x)
test_scalar(cos, x)
test_scalar(tan, x)
end
@testset "Hyperbolic" for x = (Float64(π)-0.01, Complex-0.01, π/2))
test_scalar(sinh, x)
test_scalar(cosh, x)
test_scalar(tanh, x)
end
@testset "Inverses" for x = (0.5, Complex(0.5, 0.25))
test_scalar(asin, x)
test_scalar(acos, x)
test_scalar(atan, x)
end
@testset "Multivariate" begin
@testset "sincos" begin
x, Δx, x̄ = randn(3)
Δz = (randn(), randn())

frule_test(sincos, (x, Δx))
rrule_test(sincos, Δz, (x, x̄))
end
end
end # Trig

@testset "exponents" begin
for x in (-0.1, 6.4)
test_scalar(inv, x)

test_scalar(exp, x)
test_scalar(exp2, x)
test_scalar(exp10, x)
test_scalar(expm1, x)

test_scalar(cbrt, x)

if x >= 0
test_scalar(sqrt, x)
test_scalar(log, x)
test_scalar(log2, x)
test_scalar(log10, x)
test_scalar(log1p, x)
end
end
end

@testset "Unary complex functions" begin
for x in (-4.1, 6.4)
test_scalar(abs, x)

test_scalar(angle, x)
test_scalar(abs2, x)
test_scalar(conj, x)

end
end

@testset "Unary functions" begin
for x in (-4.1, 6.4)
test_scalar(+, x)
test_scalar(-, x)
end
end

@testset "binary function ($f)" for f in (/, +, -, hypot, atan, rem, ^, max, min)
x, Δx, x̄ = 10rand(3)
y, Δy, ȳ = rand(3)
Δz = rand()

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end



@testset "sign" begin
@testset "at points" for x in (-1.1, -1.1, 0.5, 100)
test_scalar(sign, x)
end

@testset "Zero over the point discontinuity" begin
# Can't do finite differencing because we are lying
# following the subgradient convention.

_, pb = rrule(sign, 0.0)
_, x̄ = pb(10.5)
@test extern(x̄) == 0

_, ẏ = frule((Zero(), 10.5), sign, 0.0)
@test extern(ẏ) == 0
end
end
end

# Now we generate tests for fast and nonfast versions
@eval @testset "fastmath_able Base functions" begin
$FASTABLE_AST
end


@eval @testset "fastmath_able FastMath functions" begin
$(Base.FastMath.make_fastmath(FASTABLE_AST))
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ println("Testing ChainRules.jl")
@testset "rulesets" begin
@testset "Base" begin
include(joinpath("rulesets", "Base", "base.jl"))
include(joinpath("rulesets", "Base", "fastmath_able.jl"))
include(joinpath("rulesets", "Base", "array.jl"))
include(joinpath("rulesets", "Base", "mapreduce.jl"))
end
Expand Down

0 comments on commit 5273b41

Please sign in to comment.