Skip to content

Commit

Permalink
Merge 7ce445c into 7b69721
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Apr 20, 2019
2 parents 7b69721 + 7ce445c commit 162df91
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 15 deletions.
4 changes: 2 additions & 2 deletions test/rules/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ end
@test dx(z̄) == extern(accumulate(zeros(3, 2), dx, z̄))
@test dy(z̄) == extern(accumulate(zeros(2, 5), dy, z̄))

test_adjoint!(rand(3, 2), dx, z̄, z̄ * y')
test_adjoint!(rand(2, 5), dy, z̄, x' * z̄)
test_accumulation(rand(3, 2), dx, z̄, z̄ * y')
test_accumulation(rand(2, 5), dy, z̄, x' * z̄)
end
@testset "hypot(x, y)" begin
x, y = rand(2)
Expand Down
61 changes: 48 additions & 13 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
@test cr_isapprox(x̄_ad, x̄_fd, rtol, atol)

# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
test_adjoint!(x̄, dx, ȳ, x̄_ad)
test_accumulation(x̄, dx, ȳ, x̄_ad)
test_accumulation(Zero(), dx, ȳ, x̄_ad)
end

function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm=_fdm)
Expand All @@ -56,7 +57,11 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
@test all(map((Δx_ad, Δx_fd)->cr_isapprox(Δx_ad, Δx_fd, rtol, atol), Δxs_ad, Δxs_fd))

# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
map((x̄, Δx_rule, Δx_ad)->test_adjoint!(x̄, Δx_rule, ȳ, Δx_ad), x̄s, Δx_rules, Δxs_ad)
map(x̄s, Δx_rules, Δxs_ad) do x̄, Δx_rule, Δx_ad
test_accumulation(x̄, Δx_rule, ȳ, Δx_ad)
test_accumulation(Zero(), Δx_rule, ȳ, Δx_ad)
return nothing
end
end

function cr_isapprox(d_ad, d_fd, rtol, atol)
Expand All @@ -75,21 +80,51 @@ function cr_isapprox(d_ad::Thunk, d_fd, rtol, atol)
return isapprox(extern(d_ad), d_fd; rtol=rtol, atol=atol)
end

function test_adjoint!(x̄, dx, ȳ, partial)
x̄_old = copy(x̄)
x̄_zeros = zero.(x̄)
function test_accumulation(x̄, dx, ȳ, partial)
@test all(extern(ChainRules.add(x̄, partial)) .== extern(x̄) .+ extern(partial))
test_accumulate(x̄, dx, ȳ, partial)
test_accumulate!(x̄, dx, ȳ, partial)
test_store!(x̄, dx, ȳ, partial)
return nothing
end

function test_accumulate(x̄::Zero, dx, ȳ, partial)
@test extern(accumulate(x̄, dx, ȳ)) == extern(partial)
return nothing
end

function test_accumulate(x̄::Number, dx, ȳ, partial)
@test extern(accumulate(x̄, dx, ȳ)) == extern(x̄) + extern(partial)
return nothing
end

@test all(accumulate(Zero(), dx, ȳ) .== accumulate(x̄_zeros, dx, ȳ))
@test all(accumulate(x̄, dx, ȳ) .== (x̄ .+ partial))
function test_accumulate(x̄::AbstractArray, dx, ȳ, partial)
x̄_old = copy(x̄)
@test all(extern(accumulate(x̄, dx, ȳ)) .== (extern(x̄) .+ extern(partial)))
@test== x̄_old
return nothing
end

accumulate!(x̄, dx, ȳ)
@test== (x̄_old .+ partial)
x̄ .= x̄_old
test_accumulate!(x̄::Zero, dx, ȳ, partial) = nothing

function test_accumulate!(x̄::Number, dx, ȳ, partial)
@test accumulate!(x̄, dx, ȳ) == accumulate(x̄, dx, ȳ)
return nothing
end

function test_accumulate!(x̄::AbstractArray, dx, ȳ, partial)
x̄_copy = copy(x̄)
accumulate!(x̄_copy, dx, ȳ)
@test extern(x̄_copy) == (extern(x̄) .+ extern(partial))
return nothing
end

store!(x̄, dx, ȳ)
@test all(x̄ .== partial)
x̄ .= x̄_old
test_store!(x̄::Zero, dx, ȳ, partial) = nothing
test_store!(x̄::Number, dx, ȳ, partial) = nothing

function test_store!(x̄::AbstractArray, dx, ȳ, partial)
x̄_copy = copy(x̄)
store!(x̄_copy, dx, ȳ)
@test all(x̄_copy .== extern(partial))
return nothing
end

0 comments on commit 162df91

Please sign in to comment.