Skip to content

Commit

Permalink
Merge pull request #11 from JuliaDiff/aa/more-tests
Browse files Browse the repository at this point in the history
Add some more tests
  • Loading branch information
ararslan committed Apr 12, 2019
2 parents 23a7ab7 + f18ad4c commit 6cfd4b9
Showing 1 changed file with 75 additions and 1 deletion.
76 changes: 75 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,81 @@
# TODO: more tests!

using ChainRules, Test
using ChainRules: One, Zero, rrule, frule, extern, cast, accumulate, accumulate!, store!
using ChainRules: rrule, frule, extern, accumulate, accumulate!, store!, @scalar_rule,
Wirtinger, wirtinger_primal, wirtinger_conjugate, add_wirtinger, mul_wirtinger,
Zero, add_zero, mul_zero, One, add_one, mul_one,
Casted, cast, add_casted, mul_casted
using Base.Broadcast: broadcastable

cool(x) = x + 1
@testset "frule and rrule" begin
@test frule(cool, 1) === nothing
@test rrule(cool, 1) === nothing
ChainRules.@scalar_rule(Main.cool(x), one(x))
frx, fr = frule(cool, 1)
@test frx == 2
@test fr(1) == 1
rrx, rr = rrule(cool, 1)
@test rrx == 2
@test rr(1) == 1
end

@testset "iterating rules" begin
_, rule = frule(+, 1)
i = 0
for r in rule
@test r === rule
i += 1
end
@test i == 1 # rules only iterate once, yielding themselves
end

@testset "Differentials" begin
@testset "Wirtinger" begin
w = Wirtinger(1+1im, 2+2im)
@test wirtinger_primal(w) == 1+1im
@test wirtinger_conjugate(w) == 2+2im
@test add_wirtinger(w, w) == Wirtinger(2+2im, 4+4im)
# TODO: other add_wirtinger methods stack overflow
@test_throws ErrorException mul_wirtinger(w, w)
@test_throws ErrorException extern(w)
for x in w
@test x === w
end
@test broadcastable(w) == w
@test_throws ErrorException conj(w)
end
@testset "Zero" begin
z = Zero()
@test extern(z) === false
@test add_zero(z, z) == z
@test add_zero(z, 1) == 1
@test add_zero(1, z) == 1
@test mul_zero(z, z) == z
@test mul_zero(z, 1) == z
@test mul_zero(1, z) == z
for x in z
@test x === z
end
@test broadcastable(z) isa Ref{Zero}
@test conj(z) == z
end
@testset "One" begin
o = One()
@test extern(o) === true
@test add_one(o, o) == 2
@test add_one(o, 1) == 2
@test add_one(1, o) == 2
@test mul_one(o, o) == o
@test mul_one(o, 1) == 1
@test mul_one(1, o) == 1
for x in o
@test x === o
end
@test broadcastable(o) isa Ref{One}
@test conj(o) == o
end
end

#####
##### `*(x, y)`
Expand Down

0 comments on commit 6cfd4b9

Please sign in to comment.