Skip to content

Commit

Permalink
Make tests more modular (#27)
Browse files Browse the repository at this point in the history
* Make tests more modular

* Remove special-casing for FDMBackend2

* Convert to tuple if necessary

* Hard-code constant values

* Allow provision of an rng

* Set seed in each test file

* Remove const

* Revert "Hard-code constant values"

This reverts commit dd67276.

* Revert "Set seed in each test file"

This reverts commit 55d91ee.

* Include test_utils just once

* Test also backend with user-specified chunk size
  • Loading branch information
sethaxen committed Jan 22, 2022
1 parent 9b9479c commit 979fa04
Show file tree
Hide file tree
Showing 4 changed files with 692 additions and 672 deletions.
251 changes: 251 additions & 0 deletions test/defaults.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,251 @@
using AbstractDifferentiation
using Test
using FiniteDifferences, ForwardDiff, Zygote

const FDM = FiniteDifferences

## FiniteDifferences
struct FDMBackend1{A} <: AD.AbstractFiniteDifference
alg::A
end
FDMBackend1() = FDMBackend1(central_fdm(5, 1))
const fdm_backend1 = FDMBackend1()
# Minimal interface
AD.@primitive function jacobian(ab::FDMBackend1, f, xs...)
return FDM.jacobian(ab.alg, f, xs...)
end

struct FDMBackend2{A} <: AD.AbstractFiniteDifference
alg::A
end
FDMBackend2() = FDMBackend2(central_fdm(5, 1))
const fdm_backend2 = FDMBackend2()
AD.@primitive function pushforward_function(ab::FDMBackend2, f, xs...)
return function (vs)
ws = FDM.jvp(ab.alg, f, tuple.(xs, vs)...)
return length(xs) == 1 ? (ws,) : ws
end
end

struct FDMBackend3{A} <: AD.AbstractFiniteDifference
alg::A
end
FDMBackend3() = FDMBackend3(central_fdm(5, 1))
const fdm_backend3 = FDMBackend3()
AD.@primitive function pullback_function(ab::FDMBackend3, f, xs...)
return function (vs)
# Supports only single output
if vs isa AbstractVector
return FDM.j′vp(ab.alg, f, vs, xs...)
else
@assert length(vs) == 1
return FDM.j′vp(ab.alg, f, vs[1], xs...)
end
end
end
##


## ForwardDiff
struct ForwardDiffBackend1 <: AD.AbstractForwardMode end
const forwarddiff_backend1 = ForwardDiffBackend1()
AD.@primitive function jacobian(ab::ForwardDiffBackend1, f, xs)
if xs isa Number
return (ForwardDiff.derivative(f, xs),)
elseif xs isa AbstractArray
out = f(xs)
if out isa Number
return (adjoint(ForwardDiff.gradient(f, xs)),)
else
return (ForwardDiff.jacobian(f, xs),)
end
elseif xs isa Tuple
error(typeof(xs))
else
error(typeof(xs))
end
end
AD.primal_value(::ForwardDiffBackend1, ::Any, f, xs) = ForwardDiff.value.(f(xs...))

struct ForwardDiffBackend2 <: AD.AbstractForwardMode end
const forwarddiff_backend2 = ForwardDiffBackend2()
AD.@primitive function pushforward_function(ab::ForwardDiffBackend2, f, xs...)
# jvp = f'(x)*v, i.e., differentiate f(x + h*v) wrt h at 0
return function (vs)
if xs isa Tuple
@assert length(xs) <= 2
if length(xs) == 1
(ForwardDiff.derivative(h->f(xs[1]+h*vs[1]),0),)
else
ForwardDiff.derivative(h->f(xs[1]+h*vs[1], xs[2]+h*vs[2]),0)
end
else
ForwardDiff.derivative(h->f(xs+h*vs),0)
end
end
end
AD.primal_value(::ForwardDiffBackend2, ::Any, f, xs) = ForwardDiff.value.(f(xs...))
##

## Zygote
struct ZygoteBackend1 <: AD.AbstractReverseMode end
const zygote_backend1 = ZygoteBackend1()
AD.@primitive function pullback_function(ab::ZygoteBackend1, f, xs...)
return function (vs)
# Supports only single output
_, back = Zygote.pullback(f, xs...)
if vs isa AbstractVector
back(vs)
else
@assert length(vs) == 1
back(vs[1])
end
end
end

@testset "defaults" begin
@testset "Utils" begin
test_higher_order_backend(fdm_backend1, fdm_backend2, fdm_backend3, zygote_backend1, forwarddiff_backend2)
end
@testset "FiniteDifferences" begin
@testset "Derivative" begin
test_derivatives(fdm_backend1)
test_derivatives(fdm_backend2)
test_derivatives(fdm_backend3)
end
@testset "Gradient" begin
test_gradients(fdm_backend1)
test_gradients(fdm_backend2)
test_gradients(fdm_backend3)
end
@testset "Jacobian" begin
test_jacobians(fdm_backend1)
test_jacobians(fdm_backend2)
test_jacobians(fdm_backend3)
end
@testset "Hessian" begin
test_hessians(fdm_backend1)
test_hessians(fdm_backend2)
test_hessians(fdm_backend3)
end
@testset "jvp" begin
test_jvp(fdm_backend1)
test_jvp(fdm_backend2; vaugmented=true)
test_jvp(fdm_backend3)
end
@testset "j′vp" begin
test_j′vp(fdm_backend1)
test_j′vp(fdm_backend2)
test_j′vp(fdm_backend3)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(fdm_backend1)
test_lazy_derivatives(fdm_backend2)
test_lazy_derivatives(fdm_backend3)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(fdm_backend1)
test_lazy_gradients(fdm_backend2)
test_lazy_gradients(fdm_backend3)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(fdm_backend1)
test_lazy_jacobians(fdm_backend2; vaugmented=true)
test_lazy_jacobians(fdm_backend3)
end
@testset "Lazy Hessian" begin
test_lazy_hessians(fdm_backend1)
test_lazy_hessians(fdm_backend2)
test_lazy_hessians(fdm_backend3)
end
end
@testset "ForwardDiff" begin
@testset "Derivative" begin
test_derivatives(forwarddiff_backend1; multiple_inputs=false)
test_derivatives(forwarddiff_backend2)
end
@testset "Gradient" begin
test_gradients(forwarddiff_backend1; multiple_inputs=false)
test_gradients(forwarddiff_backend2)
end
@testset "Jacobian" begin
test_jacobians(forwarddiff_backend1; multiple_inputs=false)
test_jacobians(forwarddiff_backend2)
end
@testset "Hessian" begin
test_hessians(forwarddiff_backend1; multiple_inputs=false)
test_hessians(forwarddiff_backend2)
end
@testset "jvp" begin
test_jvp(forwarddiff_backend1; multiple_inputs=false)
test_jvp(forwarddiff_backend2; vaugmented=true)
end
@testset "j′vp" begin
test_j′vp(forwarddiff_backend1; multiple_inputs=false)
test_j′vp(forwarddiff_backend2)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(forwarddiff_backend1; multiple_inputs=false)
test_lazy_derivatives(forwarddiff_backend2)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(forwarddiff_backend1; multiple_inputs=false)
test_lazy_gradients(forwarddiff_backend2)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(forwarddiff_backend1; multiple_inputs=false)
test_lazy_jacobians(forwarddiff_backend2; vaugmented=true)
end
@testset "Lazy Hessian" begin
test_lazy_hessians(forwarddiff_backend1; multiple_inputs=false)
test_lazy_hessians(forwarddiff_backend2)
end
end
@testset "Zygote" begin
@testset "Derivative" begin
test_derivatives(zygote_backend1)
end
@testset "Gradient" begin
test_gradients(zygote_backend1)
end
@testset "Jacobian" begin
test_jacobians(zygote_backend1)
end
@testset "Hessian" begin
# Zygote over Zygote problems
backends = AD.HigherOrderBackend((forwarddiff_backend2,zygote_backend1))
test_hessians(backends)
if VERSION >= v"1.3"
backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1))
test_hessians(backends)
end
# fails:
# backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend2))
# test_hessians(backends)
end
@testset "jvp" begin
test_jvp(zygote_backend1)
end
@testset "j′vp" begin
test_j′vp(zygote_backend1)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(zygote_backend1)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(zygote_backend1)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(zygote_backend1)
end
@testset "Lazy Hessian" begin
# Zygote over Zygote problems
backends = AD.HigherOrderBackend((forwarddiff_backend2,zygote_backend1))
test_lazy_hessians(backends)
if VERSION >= v"1.3"
backends = AD.HigherOrderBackend((zygote_backend1,forwarddiff_backend1))
test_lazy_hessians(backends)
end
end
end
end
42 changes: 42 additions & 0 deletions test/forwarddiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
using AbstractDifferentiation
using Test
using ForwardDiff

@testset "ForwardDiffBackend" begin
backends = [
@inferred(AD.ForwardDiffBackend())
@inferred(AD.ForwardDiffBackend(; chunksize=Val{1}()))
]
@testset for backend in backends
@testset "Derivative" begin
test_derivatives(backend)
end
@testset "Gradient" begin
test_gradients(backend)
end
@testset "Jacobian" begin
test_jacobians(backend)
end
@testset "Hessian" begin
test_hessians(backend)
end
@testset "jvp" begin
test_jvp(backend; vaugmented=true)
end
@testset "j′vp" begin
test_j′vp(backend)
end
@testset "Lazy Derivative" begin
test_lazy_derivatives(backend)
end
@testset "Lazy Gradient" begin
test_lazy_gradients(backend)
end
@testset "Lazy Jacobian" begin
test_lazy_jacobians(backend; vaugmented=true)
end
@testset "Lazy Hessian" begin
test_lazy_hessians(backend)
end
end
end
Loading

0 comments on commit 979fa04

Please sign in to comment.