Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 17 additions & 22 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,15 +58,12 @@ Keep this in mind when testing discontinuous rules for functions like [ReLU](htt
```jldoctest ex; output = false
using ChainRulesTestUtils

test_frule(two2three, 3.33, -7.77)
test_frule(two2three, 3.33, -7.77);

# output
Test Summary: | Pass Total
Tuple{Float64,Float64,Float64}.1 | 1 1
Test Summary: | Pass Total
Tuple{Float64,Float64,Float64}.2 | 1 1
Test Summary: | Pass Total
Tuple{Float64,Float64,Float64}.3 | 1 1
Test Passed
Test Summary: | Pass Total
test_frule: two2three at (3.33, -7.77) | 5 5
Test.DefaultTestSet("test_frule: two2three at (3.33, -7.77)", Any[Test.DefaultTestSet("Tuple{Float64,Float64,Float64}.1", Any[], 1, false), Test.DefaultTestSet("Tuple{Float64,Float64,Float64}.2", Any[], 1, false), Test.DefaultTestSet("Tuple{Float64,Float64,Float64}.3", Any[], 1, false)], 2, false)
```

### Testing the `rrule`
Expand All @@ -75,11 +72,12 @@ Test Passed
The call will test the `rrule` for function `f` at the point `x`, and similarly to `frule` some rules should be tested at multiple points in the domain.

```jldoctest ex; output = false
test_rrule(two2three, 3.33, -7.77)
test_rrule(two2three, 3.33, -7.77);

# output
Test Summary: |
Don't thunk only non_zero argument | No tests
Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)
Test Summary: | Pass Total
test_rrule: two2three at (3.33, -7.77) | 6 6
Test.DefaultTestSet("test_rrule: two2three at (3.33, -7.77)", Any[Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)], 6, false)
```

## Scalar example
Expand All @@ -104,18 +102,15 @@ with the `frule` and `rrule` defined with the help of `@scalar_rule` macro
`test_scalar` function is provided to test both the `frule` and the `rrule` with a single
call.
```jldoctest ex; output = false
test_scalar(relu, 0.5)
test_scalar(relu, -0.5)
test_scalar(relu, 0.5);
test_scalar(relu, -0.5);

# output
Test Summary: | Pass Total
relu at 0.5, with tangent 1.0 | 3 3
Test Summary: | Pass Total
relu at 0.5, with cotangent 1.0 | 4 4
Test Summary: | Pass Total
relu at -0.5, with tangent 1.0 | 3 3
Test Summary: | Pass Total
relu at -0.5, with cotangent 1.0 | 4 4
Test Summary: | Pass Total
test_scalar: relu at 0.5 | 7 7
Test Summary: | Pass Total
test_scalar: relu at -0.5 | 7 7
Test.DefaultTestSet("test_scalar: relu at -0.5", Any[Test.DefaultTestSet("with tangent 1.0", Any[Test.DefaultTestSet("test_frule: relu at (ChainRulesTestUtils.PrimalAndTangent{Float64,Float64}(-0.5, 1.0),)", Any[], 3, false)], 0, false), Test.DefaultTestSet("with cotangent 1.0", Any[Test.DefaultTestSet("test_rrule: relu at (ChainRulesTestUtils.PrimalAndTangent{Float64,Float64}(-0.5, 1.0),)", Any[Test.DefaultTestSet("Don't thunk only non_zero argument", Any[], 0, false)], 4, false)], 0, false)], 0, false)
```

## Specifying Tangents
Expand Down
204 changes: 105 additions & 99 deletions src/testers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,51 +17,53 @@ function test_scalar(f, z; rtol=1e-9, atol=1e-9, fdm=_fdm, fkwargs=NamedTuple(),
rule_test_kwargs = (; rtol=rtol, atol=atol, fdm=fdm, fkwargs=fkwargs, check_inferred=check_inferred, kwargs...)
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Ω = f(z; fkwargs...)

# test jacobian using forward mode
Δx = one(z)
@testset "$f at $z, with tangent $Δx" begin
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
test_frule(f, z ⊢ Δx; rule_test_kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
_, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...)
_, embedded_tangent = frule((Zero(), Δx), f, z; fkwargs...)
check_equal(real_tangent, embedded_tangent; isapprox_kwargs...)
@testset "test_scalar: $f at $z" begin
_ensure_not_running_on_functor(f, "test_scalar")
# z = x + im * y
# Ω = u(x, y) + im * v(x, y)
Ω = f(z; fkwargs...)

# test jacobian using forward mode
Δx = one(z)
@testset "with tangent $Δx" begin
# check ∂u_∂x and (if Ω is complex) ∂v_∂x via forward mode
test_frule(f, z ⊢ Δx; rule_test_kwargs...)
if z isa Complex
# check that same tangent is produced for tangent 1.0 and 1.0 + 0.0im
_, real_tangent = frule((Zero(), real(Δx)), f, z; fkwargs...)
_, embedded_tangent = frule((Zero(), Δx), f, z; fkwargs...)
check_equal(real_tangent, embedded_tangent; isapprox_kwargs...)
end
end
end
if z isa Complex
Δy = one(z) * im
@testset "$f at $z, with tangent $Δy" begin
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
test_frule(f, z ⊢ Δy; rule_test_kwargs...)
if z isa Complex
Δy = one(z) * im
@testset "with tangent $Δy" begin
# check ∂u_∂y and (if Ω is complex) ∂v_∂y via forward mode
test_frule(f, z ⊢ Δy; rule_test_kwargs...)
end
end
end

# test jacobian transpose using reverse mode
Δu = one(Ω)
@testset "$f at $z, with cotangent $Δu" begin
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
test_rrule(f, z ⊢ Δx; output_tangent=Δu, rule_test_kwargs...)
if Ω isa Complex
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
_, back = rrule(f, z)
_, real_cotangent = back(real(Δu))
_, embedded_cotangent = back(Δu)
check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...)
# test jacobian transpose using reverse mode
Δu = one(Ω)
@testset "with cotangent $Δu" begin
# check ∂u_∂x and (if z is complex) ∂u_∂y via reverse mode
test_rrule(f, z ⊢ Δx; output_tangent=Δu, rule_test_kwargs...)
if Ω isa Complex
# check that same cotangent is produced for cotangent 1.0 and 1.0 + 0.0im
_, back = rrule(f, z)
_, real_cotangent = back(real(Δu))
_, embedded_cotangent = back(Δu)
check_equal(real_cotangent, embedded_cotangent; isapprox_kwargs...)
end
end
end
if Ω isa Complex
Δv = one(Ω) * im
@testset "$f at $z, with cotangent $Δv" begin
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
test_rrule(f, z ⊢ Δx; output_tangent=Δv, rule_test_kwargs...)
if Ω isa Complex
Δv = one(Ω) * im
@testset "with cotangent $Δv" begin
# check ∂v_∂x and (if z is complex) ∂v_∂y via reverse mode
test_rrule(f, z ⊢ Δx; output_tangent=Δv, rule_test_kwargs...)
end
end
end
end # top-level testset
end


Expand Down Expand Up @@ -96,28 +98,30 @@ function test_frule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

_ensure_not_running_on_functor(f, "test_frule")
@testset "test_frule: $f at $inputs" begin
_ensure_not_running_on_functor(f, "test_frule")

xẋs = auto_primal_and_tangent.(inputs)
xs = primal.(xẋs)
ẋs = tangent.(xẋs)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
end
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
check_equal(Ω_ad, Ω; isapprox_kwargs...)

ẋs_is_ignored = ẋs .== nothing
# Correctness testing via finite differencing.
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)

acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
_check_add!!_behaviour(acc, dΩ_ad; rtol=rtol, atol=atol, kwargs...)
xẋs = auto_primal_and_tangent.(inputs)
xs = primal.(xẋs)
ẋs = tangent.(xẋs)
if check_inferred && _is_inferrable(f, deepcopy(xs)...; deepcopy(fkwargs)...)
_test_inferred(frule, (NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
end
res = frule((NO_FIELDS, deepcopy(ẋs)...), f, deepcopy(xs)...; deepcopy(fkwargs)...)
res === nothing && throw(MethodError(frule, typeof((f, xs...))))
res isa Tuple || error("The frule should return (y, ∂y), not $res.")
Ω_ad, dΩ_ad = res
Ω = f(deepcopy(xs)...; deepcopy(fkwargs)...)
check_equal(Ω_ad, Ω; isapprox_kwargs...)

ẋs_is_ignored = ẋs .== nothing
# Correctness testing via finite differencing.
dΩ_fd = _make_jvp_call(fdm, (xs...) -> f(deepcopy(xs)...; deepcopy(fkwargs)...), Ω, xs, ẋs, ẋs_is_ignored)
check_equal(dΩ_ad, dΩ_fd; isapprox_kwargs...)

acc = output_tangent isa Auto ? rand_tangent(Ω) : output_tangent
_check_add!!_behaviour(acc, dΩ_ad; rtol=rtol, atol=atol, kwargs...)
end # top-level testset
end


Expand Down Expand Up @@ -152,47 +156,49 @@ function test_rrule(
# To simplify some of the calls we make later lets group the kwargs for reuse
isapprox_kwargs = (; rtol=rtol, atol=atol, kwargs...)

_ensure_not_running_on_functor(f, "test_rrule")
@testset "test_rrule: $f at $inputs" begin
_ensure_not_running_on_functor(f, "test_rrule")

# Check correctness of evaluation.
xx̄s = auto_primal_and_tangent.(inputs)
xs = primal.(xx̄s)
accumulated_x̄ = tangent.(xx̄s)
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
_test_inferred(rrule, f, xs...; fkwargs...)
end
res = rrule(f, xs...; fkwargs...)
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
y_ad, pullback = res
y = f(xs...; fkwargs...)
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct

ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

check_inferred && _test_inferred(pullback, ȳ)
∂s = pullback(ȳ)
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS # No internal fields

# Correctness testing via finite differencing.
x̄s_is_dne = accumulated_x̄ .== nothing
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
if accumulated_x̄ === nothing # then we marked this argument as not differentiable
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
else
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)

# The main test of the actual deriviative being correct:
check_equal(x̄_ad, x̄_fd; isapprox_kwargs...)
_check_add!!_behaviour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
# Check correctness of evaluation.
xx̄s = auto_primal_and_tangent.(inputs)
xs = primal.(xx̄s)
accumulated_x̄ = tangent.(xx̄s)
if check_inferred && _is_inferrable(f, xs...; fkwargs...)
_test_inferred(rrule, f, xs...; fkwargs...)
end
res = rrule(f, xs...; fkwargs...)
res === nothing && throw(MethodError(rrule, typeof((f, xs...))))
y_ad, pullback = res
y = f(xs...; fkwargs...)
check_equal(y_ad, y; isapprox_kwargs...) # make sure primal is correct

ȳ = output_tangent isa Auto ? rand_tangent(y) : output_tangent

check_inferred && _test_inferred(pullback, ȳ)
∂s = pullback(ȳ)
∂s isa Tuple || error("The pullback must return (∂self, ∂args...), not $∂s.")
∂self = ∂s[1]
x̄s_ad = ∂s[2:end]
@test ∂self === NO_FIELDS # No internal fields

# Correctness testing via finite differencing.
x̄s_is_dne = accumulated_x̄ .== nothing
x̄s_fd = _make_j′vp_call(fdm, (xs...) -> f(xs...; fkwargs...), ȳ, xs, x̄s_is_dne)
for (accumulated_x̄, x̄_ad, x̄_fd) in zip(accumulated_x̄, x̄s_ad, x̄s_fd)
if accumulated_x̄ === nothing # then we marked this argument as not differentiable
@assert x̄_fd === nothing # this is how `_make_j′vp_call` works
@test x̄_ad isa DoesNotExist # we said it wasn't differentiable.
else
x̄_ad isa AbstractThunk && check_inferred && _test_inferred(unthunk, x̄_ad)

# The main test of the actual deriviative being correct:
check_equal(x̄_ad, x̄_fd; isapprox_kwargs...)
_check_add!!_behaviour(accumulated_x̄, x̄_ad; isapprox_kwargs...)
end
end
end

check_thunking_is_appropriate(x̄s_ad)
check_thunking_is_appropriate(x̄s_ad)
end # top-level testset
end

function check_thunking_is_appropriate(x̄s)
Expand Down
5 changes: 4 additions & 1 deletion test/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ end

# we defined these functions at top of file to throw errors unless we pass `err=false`
@test_throws ErrorException futestkws(randn())
@test_throws ErrorException test_scalar(futestkws, randn())
@test errors(
()->test_scalar(futestkws, randn()),
"futestkws_err",
)
@test_throws ErrorException frule((nothing, randn()), futestkws, randn())
@test_throws ErrorException rrule(futestkws, randn())

Expand Down
Loading