Skip to content

Commit

Permalink
Merge 457ed82 into 7b69721
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Apr 22, 2019
2 parents 7b69721 + 457ed82 commit 24d8eec
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ include("differentials.jl")
include("rules.jl")
include("rules/base.jl")
include("rules/broadcast.jl")
include("rules/linalg.jl")
include("rules/linalg/dense.jl")
include("rules/linalg/diagonal.jl")
include("rules/linalg/symmetric.jl")
include("rules/blas.jl")
include("rules/nanmath.jl")
include("rules/specialfunctions.jl")
Expand Down
File renamed without changes.
2 changes: 2 additions & 0 deletions src/rules/linalg/diagonal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
rrule(::typeof(Diagonal), d::AbstractVector) = Diagonal(d), Rule(diag)
rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)
4 changes: 4 additions & 0 deletions src/rules/linalg/symmetric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
rrule(::typeof(Symmetric), A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal, UpperTriangular}) = ΔΩ
File renamed without changes.
14 changes: 14 additions & 0 deletions test/rules/linalg/diagonal.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@testset "diagonal" begin
@testset "Diagonal" begin
rng, N = MersenneTwister(123456), 3
rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
rrule_test(Diagonal, Diagonal(randn(rng, N)), (randn(rng, N), randn(rng, N)))
end
@testset "diag" begin
rng, N = MersenneTwister(123456), 7
rrule_test(diag, randn(rng, N), (randn(rng, N, N), randn(rng, N, N)))
rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), randn(rng, N, N)))
rrule_test(diag, randn(rng, N), (randn(rng, N, N), Diagonal(randn(rng, N))))
rrule_test(diag, randn(rng, N), (Diagonal(randn(rng, N)), Diagonal(randn(rng, N))))
end
end
6 changes: 6 additions & 0 deletions test/rules/linalg/symmetric.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@testset "symmetric" begin
@testset "Symmetric" begin
rng, N = MersenneTwister(123456), 3
rrule_test(Symmetric, randn(rng, N, N), (randn(rng, N, N), randn(rng, N, N)))
end
end
6 changes: 5 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@ include("test_util.jl")
include("rules.jl")
@testset "rules" begin
include(joinpath("rules", "base.jl"))
@testset "linalg" begin
include(joinpath("rules", "linalg", "dense.jl"))
include(joinpath("rules", "linalg", "diagonal.jl"))
include(joinpath("rules", "linalg", "symmetric.jl"))
end
include(joinpath("rules", "broadcast.jl"))
include(joinpath("rules", "linalg.jl"))
include(joinpath("rules", "blas.jl"))
include(joinpath("rules", "nanmath.jl"))
include(joinpath("rules", "specialfunctions.jl"))
Expand Down

0 comments on commit 24d8eec

Please sign in to comment.