Skip to content

Commit

Permalink
Merge a3bd02e into df431b9
Browse files Browse the repository at this point in the history
  • Loading branch information
mattBrzezinski committed May 1, 2020
2 parents df431b9 + a3bd02e commit 3f52394
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 168 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.5.1"
version = "0.5.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
25 changes: 11 additions & 14 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
@testset "reshape" begin
rng = MersenneTwister(1)
A = randn(rng, 4, 5)
A = rand(4, 5)
B, pullback = rrule(reshape, A, (5, 4))
@test B == reshape(A, (5, 4))
= randn(rng, 4, 5)
= randn(4, 5)

(s̄, Ā, d̄) = pullback(Ȳ)
@test== NO_FIELDS
Expand All @@ -13,7 +12,7 @@
B, pullback = rrule(reshape, A, 5, 4)
@test B == reshape(A, 5, 4)

= randn(rng, 4, 5)
= randn(4, 5)
(s̄, Ā, d̄1, d̄2) = pullback(Ȳ)
@test== NO_FIELDS
@test d̄1 isa DoesNotExist
Expand All @@ -22,13 +21,12 @@
end

@testset "hcat" begin
rng = MersenneTwister(2)
A = randn(rng, 3, 2)
B = randn(rng, 3)
C = randn(rng, 3, 3)
A = randn(3, 2)
B = randn(3)
C = randn(3, 3)
H, pullback = rrule(hcat, A, B, C)
@test H == hcat(A, B, C)
= randn(rng, 3, 6)
= randn(3, 6)
(ds, dA, dB, dC) = pullback(H̄)
@test ds == NO_FIELDS
@test dA view(H̄, :, 1:2)
Expand All @@ -37,13 +35,12 @@ end
end

@testset "vcat" begin
rng = MersenneTwister(3)
A = randn(rng, 2, 4)
B = randn(rng, 1, 4)
C = randn(rng, 3, 4)
A = randn(2, 4)
B = randn(1, 4)
C = randn(3, 4)
V, pullback = rrule(vcat, A, B, C)
@test V == vcat(A, B, C)
= randn(rng, 6, 4)
= randn(6, 4)
(ds, dA, dB, dC) = pullback(V̄)
@test ds == NO_FIELDS
@test dA view(V̄, 1:2, :)
Expand Down
34 changes: 15 additions & 19 deletions test/rulesets/Base/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -133,34 +133,31 @@
end

@testset "binary function ($f)" for f in (hypot, atan, mod, rem, ^)
rng = MersenneTwister(123456)
x, Δx, x̄ = 10rand(rng, 3)
y, Δy, ȳ = rand(rng, 3)
Δz = rand(rng)
x, Δx, x̄ = 10rand(3)
y, Δy, ȳ = rand(3)
Δz = rand()

frule_test(f, (x, Δx), (y, Δy))
rrule_test(f, Δz, (x, x̄), (y, ȳ))
end

@testset "x^n for x<0" begin
rng = MersenneTwister(123456)
x = -15*rand(rng)
Δx, x̄ = 10rand(rng, 2)
y, Δy, ȳ = rand(rng, 3)
Δz = rand(rng)
x = -15*rand()
Δx, x̄ = 10rand(2)
y, Δy, ȳ = rand(3)
Δz = rand()

frule_test(^, (-x, Δx), (y, Δy))
rrule_test(^, Δz, (-x, x̄), (y, ȳ))
end

@testset "identity" begin
rng = MersenneTwister(1)
rrule_test(identity, randn(rng), (randn(rng), randn(rng)))
rrule_test(identity, randn(rng, 4), (randn(rng, 4), randn(rng, 4)))
rrule_test(identity, randn(), (randn(), randn()))
rrule_test(identity, randn(4), (randn(4), randn(4)))

rrule_test(
identity, Tuple(randn(rng, 3)),
(Composite{Tuple}(randn(rng, 3)...), Composite{Tuple}(randn(rng, 3)...))
identity, Tuple(randn(3)),
(Composite{Tuple}(randn(3)...), Composite{Tuple}(randn(3)...))
)
end

Expand Down Expand Up @@ -188,11 +185,10 @@
end

@testset "trinary ($f)" for f in (muladd, fma)
rng = MersenneTwister(123456)
x, Δx, x̄ = 10randn(rng, 3)
y, Δy, ȳ = randn(rng, 3)
z, Δz, z̄ = randn(rng, 3)
Δk = randn(rng)
x, Δx, x̄ = 10randn(3)
y, Δy, ȳ = randn(3)
z, Δz, z̄ = randn(3)
Δk = randn()

frule_test(f, (x, Δx), (y, Δy), (z, Δz))
rrule_test(f, Δk, (x, x̄), (y, ȳ), (z, z̄))
Expand Down
23 changes: 11 additions & 12 deletions test/rulesets/Base/mapreduce.jl
Original file line number Diff line number Diff line change
@@ -1,26 +1,25 @@
@testset "Maps and Reductions" begin
@testset "sum" begin
@testset "Vector" begin
rng, M = MersenneTwister(123456), 3
frule_test(sum, (randn(rng, M), randn(rng, M)))
rrule_test(sum, randn(rng), (randn(rng, M), randn(rng, M)))
M = 3
frule_test(sum, (randn(M), randn(M)))
rrule_test(sum, randn(), (randn(M), randn(M)))
end
@testset "Matrix" begin
rng, M, N = MersenneTwister(123456), 3, 4
frule_test(sum, (randn(rng, M, N), randn(rng, M, N)))
rrule_test(sum, randn(rng), (randn(rng, M, N), randn(rng, M, N)))
M, N = 3, 4
frule_test(sum, (randn(M, N), randn(M, N)))
rrule_test(sum, randn(), (randn(M, N), randn(M, N)))
end
@testset "Array{T, 3}" begin
rng, M, N, P = MersenneTwister(123456), 3, 7, 11
frule_test(sum, (randn(rng, M, N, P), randn(rng, M, N, P)))
rrule_test(sum, randn(rng), (randn(rng, M, N, P), randn(rng, M, N, P)))
M, N, P = 3, 7, 11
frule_test(sum, (randn(M, N, P), randn(M, N, P)))
rrule_test(sum, randn(), (randn(M, N, P), randn(M, N, P)))
end
@testset "keyword arguments" begin
rng = MersenneTwister(33)
n = 4
X = randn(rng, n, n+1)
X = randn(n, n+1)
y, pullback = rrule(sum, X; dims=2)
= randn(rng, size(y))
= randn(size(y))
_, x̄_ad = pullback(ȳ)
x̄_fd = only(j′vp(central_fdm(5, 1), x->sum(x, dims=2), ȳ, X))
@test x̄_ad x̄_fd atol=1e-9 rtol=1e-9
Expand Down
30 changes: 14 additions & 16 deletions test/rulesets/LinearAlgebra/blas.jl
Original file line number Diff line number Diff line change
@@ -1,39 +1,37 @@
@testset "BLAS" begin
@testset "gemm" begin
rng = MersenneTwister(1)
dims = 3:5
for m in dims, n in dims, p in dims, tA in ('N', 'T'), tB in ('N', 'T')
α = randn(rng)
A = randn(rng, tA === 'N' ? (m, n) : (n, m))
B = randn(rng, tB === 'N' ? (n, p) : (p, n))
α = randn()
A = randn(tA === 'N' ? (m, n) : (n, m))
B = randn(tB === 'N' ? (n, p) : (p, n))
C = gemm(tA, tB, α, A, B)
= randn(rng, size(C)...)
= randn(size(C)...)
rrule_test(
gemm,
ȳ,
(tA, nothing),
(tB, nothing),
(α, randn(rng)),
(A, randn(rng, size(A))),
(B, randn(rng, size(B))),
(α, randn()),
(A, randn(size(A))),
(B, randn(size(B))),
)
end
end
@testset "gemv" begin
rng = MersenneTwister(2)
for n in 3:5, m in 3:5, t in ('N', 'T')
α = randn(rng)
A = randn(rng, m, n)
x = randn(rng, t === 'N' ? n : m)
α = randn()
A = randn(m, n)
x = randn(t === 'N' ? n : m)
y = α * (t === 'N' ? A : A') * x
= randn(rng, size(y)...)
= randn(size(y)...)
rrule_test(
gemv,
ȳ,
(t, nothing),
(α, randn(rng)),
(A, randn(rng, size(A))),
(x, randn(rng, size(x))),
(α, randn()),
(A, randn(size(A))),
(x, randn(size(x))),
)
end
end
Expand Down

0 comments on commit 3f52394

Please sign in to comment.