Skip to content

Commit

Permalink
Zygote AD failure workarounds & test cleanup (#414)
Browse files Browse the repository at this point in the history
Zygote AD failures:

* revert #409 (test_utils workaround for broken Zygote - now working again)

* disable broken Zygote AD test for ChainTransform

Improved tests:

* finer-grained testsets

* add missing test cases to test_AD

* replace test_FiniteDiff with test_AD(..., :FiniteDiff, ...)

* remove code duplication
  • Loading branch information
st-- committed Dec 18, 2021
1 parent 3c49949 commit 2d17212
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 111 deletions.
160 changes: 50 additions & 110 deletions test/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ function gradient(f, ::Val{:FiniteDiff}, args)
return first(FiniteDifferences.grad(FDM, f, args))
end

function compare_gradient(f, ::Val{:FiniteDiff}, args)
@test_nowarn gradient(f, :FiniteDiff, args)
end

function compare_gradient(f, AD::Symbol, args)
grad_AD = gradient(f, AD, args)
grad_FD = gradient(f, :FiniteDiff, args)
Expand All @@ -88,7 +92,7 @@ testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B))
function test_ADs(
kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3]
)
test_fd = test_FiniteDiff(kernelfunction, args, dims)
test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims)
if !test_fd.anynonpass
for AD in ADs
test_AD(AD, kernelfunction, args, dims)
Expand All @@ -100,7 +104,7 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context())
@inferred f(args...)
@inferred Zygote._pullback(ctx, f, args...)
out, pb = Zygote._pullback(ctx, f, args...)
@test_throws ErrorException @inferred pb(out)
@inferred pb(out)
end

function test_ADs(
Expand All @@ -114,70 +118,6 @@ function test_ADs(
end
end

function test_FiniteDiff(kernelfunction, args=nothing, dims=[3, 3])
# Init arguments :
k = if args === nothing
kernelfunction()
else
kernelfunction(args)
end
rng = MersenneTwister(42)
@testset "FiniteDifferences" begin
if k isa SimpleKernel
for d in log.([eps(), rand(rng)])
@test_nowarn gradient(:FiniteDiff, [d]) do x
kappa(k, exp(first(x)))
end
end
end
## Testing Kernel Functions
x = rand(rng, dims[1])
y = rand(rng, dims[1])
@test_nowarn gradient(:FiniteDiff, x) do x
k(x, y)
end
if !(args === nothing)
@test_nowarn gradient(:FiniteDiff, args) do p
kernelfunction(p)(x, y)
end
end
## Testing Kernel Matrices
A = rand(rng, dims...)
B = rand(rng, dims...)
for dim in 1:2
@test_nowarn gradient(:FiniteDiff, A) do a
testfunction(k, a, dim)
end
@test_nowarn gradient(:FiniteDiff, A) do a
testfunction(k, a, B, dim)
end
@test_nowarn gradient(:FiniteDiff, B) do b
testfunction(k, A, b, dim)
end
if !(args === nothing)
@test_nowarn gradient(:FiniteDiff, args) do p
testfunction(kernelfunction(p), A, B, dim)
end
end

@test_nowarn gradient(:FiniteDiff, A) do a
testdiagfunction(k, a, dim)
end
@test_nowarn gradient(:FiniteDiff, A) do a
testdiagfunction(k, a, B, dim)
end
@test_nowarn gradient(:FiniteDiff, B) do b
testdiagfunction(k, A, b, dim)
end
if args !== nothing
@test_nowarn gradient(:FiniteDiff, args) do p
testdiagfunction(kernelfunction(p), A, B, dim)
end
end
end
end
end

function test_FiniteDiff(k::MOKernel, dims=(in=3, out=2, obs=3))
rng = MersenneTwister(42)
@testset "FiniteDifferences" begin
Expand Down Expand Up @@ -224,68 +164,68 @@ end

function test_AD(AD::Symbol, kernelfunction, args=nothing, dims=[3, 3])
@testset "$(AD)" begin
# Test kappa function
k = if args === nothing
kernelfunction()
else
kernelfunction(args)
end
rng = MersenneTwister(42)

if k isa SimpleKernel
for d in log.([eps(), rand(rng)])
compare_gradient(AD, [d]) do x
kappa(k, exp(x[1]))
@testset "kappa function" begin
for d in log.([eps(), rand(rng)])
compare_gradient(AD, [d]) do x
kappa(k, exp(x[1]))
end
end
end
end
# Testing kernel evaluations
x = rand(rng, dims[1])
y = rand(rng, dims[1])
compare_gradient(AD, x) do x
k(x, y)
end
compare_gradient(AD, y) do y
k(x, y)
end
if !(args === nothing)
compare_gradient(AD, args) do p
kernelfunction(p)(x, y)
end
end
# Testing kernel matrices
A = rand(rng, dims...)
B = rand(rng, dims...)
for dim in 1:2
compare_gradient(AD, A) do a
testfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testfunction(k, a, B, dim)

@testset "kernel evaluations" begin
x = rand(rng, dims[1])
y = rand(rng, dims[1])
compare_gradient(AD, x) do x
k(x, y)
end
compare_gradient(AD, B) do b
testfunction(k, A, b, dim)
compare_gradient(AD, y) do y
k(x, y)
end
if !(args === nothing)
compare_gradient(AD, args) do p
testfunction(kernelfunction(p), A, dim)
@testset "hyperparameters" begin
compare_gradient(AD, args) do p
kernelfunction(p)(x, y)
end
end
end
end

compare_gradient(AD, A) do a
testdiagfunction(k, a, dim)
end
compare_gradient(AD, A) do a
testdiagfunction(k, a, B, dim)
end
compare_gradient(AD, B) do b
testdiagfunction(k, A, b, dim)
end
if args !== nothing
compare_gradient(AD, args) do p
testdiagfunction(kernelfunction(p), A, dim)
@testset "kernel matrices" begin
A = rand(rng, dims...)
B = rand(rng, dims...)
@testset "$(_testfn)" for _testfn in (testfunction, testdiagfunction)
for dim in 1:2
compare_gradient(AD, A) do a
_testfn(k, a, dim)
end
compare_gradient(AD, A) do a
_testfn(k, a, B, dim)
end
compare_gradient(AD, B) do b
_testfn(k, A, b, dim)
end
if !(args === nothing)
@testset "hyperparameters" begin
compare_gradient(AD, args) do p
_testfn(kernelfunction(p), A, dim)
end
compare_gradient(AD, args) do p
_testfn(kernelfunction(p), A, B, dim)
end
end
end
end
end
end
end # kernel matrices
end
end

Expand Down
4 changes: 3 additions & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
@test repr(tp tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> SEKernel() (ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
randn(rng, 4);
ADs=[:ForwardDiff, :ReverseDiff], # explicitly pass ADs to exclude :Zygote
)
@test_broken "test_AD of chain transform is currently broken in Zygote, see GitHub issue #263"
end

0 comments on commit 2d17212

Please sign in to comment.