Skip to content

Commit

Permalink
Fix rrules for Symmetric and Diagonal constructors
Browse files Browse the repository at this point in the history
Currently these definitions are extending `rrule(::typeof(T), x)` where
`T` is a type. However, `typeof(Diagonal) == UnionAll`, which means this
is not defining the method it looks like it might be defining. The only
reason this worked when originally implemented was that one of the
`rrule` definitions was for `rrule(UnionAll, Matrix)` and the other for
`rrule(UnionAll, Vector)`, so dispatch still worked.

This replaces these problematic `::typeof(T)`s with `::Type{<:T}`.
  • Loading branch information
ararslan committed Apr 23, 2019
1 parent f783dff commit a610fb3
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 3 deletions.
2 changes: 1 addition & 1 deletion src/rules/linalg/diagonal.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
rrule(::typeof(Diagonal), d::AbstractVector) = Diagonal(d), Rule(diag)
rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag)
rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal)
2 changes: 1 addition & 1 deletion src/rules/linalg/symmetric.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
rrule(::typeof(Symmetric), A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)
rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back)

_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal, UpperTriangular}) = ΔΩ
5 changes: 4 additions & 1 deletion test/rules/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
@testset "Diagonal" begin
rng, N = MersenneTwister(123456), 3
rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N)))
rrule_test(Diagonal, Diagonal(randn(rng, N)), (randn(rng, N), randn(rng, N)))
D = Diagonal(randn(rng, N))
rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N)))
# Concrete type instead of UnionAll
rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N)))
end
@testset "diag" begin
rng, N = MersenneTwister(123456), 7
Expand Down

0 comments on commit a610fb3

Please sign in to comment.