Skip to content

Commit

Permalink
Merge 79d4393 into f6cdbcb
Browse files Browse the repository at this point in the history
  • Loading branch information
nickrobinson251 committed Jan 22, 2020
2 parents f6cdbcb + 79d4393 commit 120ed6b
Show file tree
Hide file tree
Showing 6 changed files with 6 additions and 77 deletions.
3 changes: 1 addition & 2 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ end
where `y = foo(args; kwargs...)`, and `pushforward` is a function to propagate the derivative information forwards at that point (more later).



The `rrule` for some function `foo`, which takes the positional argument `args` and keyword argument `kwargs`, is written:

```julia
Expand Down Expand Up @@ -234,7 +233,7 @@ The most important `AbstractDifferential`s when getting started are the ones abo

#### Other `AbstractDifferential`s: don't worry about them right now
- `Casted`: it implements broadcasting mechanics. See [#10](https://github.com/JuliaDiff/ChainRulesCore.jl/issues/10)
- `InplaceableThunk`: it is like a Thunk but it can do `store!` and `accumulate!` in-place.
- `InplaceableThunk`: it is like a `Thunk` but it can do in-place `add!`.

-------------------------------

Expand Down
4 changes: 0 additions & 4 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@ using Reexport
# to the normal rule of only overload via `ChainRulesCore.rrule`.
import ChainRulesCore: rrule, frule

# Deal with name clashes, by defining in this module which one we mean.
const accumulate = ChainRulesCore.accumulate
const accumulate! = ChainRulesCore.accumulate!

using LinearAlgebra
using LinearAlgebra.BLAS
using Requires
Expand Down
7 changes: 2 additions & 5 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,8 @@

@test ds === NO_FIELDS

@test extern(dx) == extern(accumulate(zeros(3, 2), dx))
@test extern(dy) == extern(accumulate(zeros(2, 5), dy))

test_accumulation(rand(3, 2), dx)
test_accumulation(rand(2, 5), dy)
@test extern(dx) == extern(zeros(3, 2) .+ dx)
@test extern(dy) == extern(zeros(2, 5) .+ dy)
end

@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
Expand Down
7 changes: 2 additions & 5 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,11 @@

x̄, ȳ = rand(), rand()
∂x = pullback(ȳ)[3]
@test isequal(
extern(ChainRules.accumulate(x̄, ∂x)),
.+.* cos.(x)
)
@test isequal(extern(x̄ .+ ∂x), x̄ .+.* cos.(x))

x̄, ȳ = Zero(), rand(3, 3)
∂x = pullback(ȳ)[3]
@test extern(extern(accumulate(x̄, ∂x))) ==.* cos.(x)
@test extern(extern(.+ ∂x)) ==.* cos.(x)
end
@testset "frule" begin
x = rand(3, 3)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ using Statistics
using Test

# For testing purposes we use a lot of
using ChainRulesCore: extern, accumulate, accumulate!, store!, @scalar_rule,
using ChainRulesCore: extern, @scalar_rule,
Zero, One, DoesNotExist, Thunk, AbstractDifferential

Random.seed!(1) # Set seed that all testsets should reset to.
Expand Down
60 changes: 0 additions & 60 deletions test/test_util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,6 @@ function rrule_test(f, ȳ, (x, x̄)::Tuple{Any, Any}; rtol=1e-9, atol=1e-9, fdm
# Correctness testing via finite differencing.
x̄_fd = j′vp(fdm, f, ȳ, x)
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)

# Assuming x̄_ad to be correct, check that other ChainRules mechanisms are correct.
test_accumulation(x̄, x̄_ad)
test_accumulation(Zero(), x̄_ad)
end

function _make_fdm_call(fdm, f, ȳ, xs, ignores)
Expand Down Expand Up @@ -177,13 +173,6 @@ function rrule_test(f, ȳ, xx̄s::Tuple{Any, Any}...; rtol=1e-9, atol=1e-9, fdm
@test isapprox(x̄_ad, x̄_fd; rtol=rtol, atol=atol, kwargs...)
end
end

# Assuming the above to be correct, check that other ChainRules mechanisms are correct.
for (x̄, x̄_ad) in zip(x̄s, x̄s_ad)
=== nothing && continue
test_accumulation(x̄, x̄_ad)
test_accumulation(Zero(), x̄_ad)
end
end

function Base.isapprox(d_ad::DoesNotExist, d_fd; kwargs...)
Expand All @@ -193,52 +182,3 @@ end
function Base.isapprox(d_ad::AbstractDifferential, d_fd; kwargs...)
return isapprox(extern(d_ad), d_fd; kwargs...)
end

function test_accumulation(x̄, ∂x)
@test all(extern(x̄ + ∂x) .≈ extern(x̄) .+ extern(∂x))
test_accumulate(x̄, ∂x)
test_accumulate!(x̄, ∂x)
test_store!(x̄, ∂x)
end

function test_accumulate(x̄::Zero, ∂x)
@test extern(accumulate(x̄, ∂x)) extern(∂x)
end

function test_accumulate(x̄::Number, ∂x)
@test extern(accumulate(x̄, ∂x)) extern(x̄) + extern(∂x)
end

function test_accumulate(x̄::AbstractArray, ∂x)
x̄_old = copy(x̄)
@test all(extern(accumulate(x̄, ∂x)) .≈ (extern(x̄) .+ extern(∂x)))
@test== x̄_old # make sure didn't mutate x̄
end

test_accumulate!(x̄::Zero, ∂x) = nothing

function test_accumulate!(x̄::Number, ∂x)
# This case won't have been inplace as `Number` is immutable
@test accumulate!(x̄, ∂x) accumulate(x̄, ∂x)
end

function test_accumulate!(x̄::AbstractArray, ∂x)
x̄_copy = copy(x̄)

accumulate!(x̄_copy, ∂x) # this should have actually been in-place
@test extern(x̄_copy) (extern(x̄) .+ extern(∂x))
end

test_store!(x̄::Zero, ∂x) = nothing
test_store!(x̄::Number, ∂x) = nothing

function test_store!(x̄::AbstractArray, ∂x)
x̄_store = copy(x̄)
store!(x̄_store, ∂x)
@test x̄_store extern(∂x)

# store! is the same as `accumulate!` to a zero array
x̄_acc = false.*
accumulate!(x̄_acc, ∂x)
@test x̄_acc x̄_store
end

0 comments on commit 120ed6b

Please sign in to comment.