From a610fb3b351a2b7221bb3a79cc8a7c43d76e4101 Mon Sep 17 00:00:00 2001 From: Alex Arslan Date: Mon, 22 Apr 2019 15:48:55 -0700 Subject: [PATCH] Fix rrules for Symmetric and Diagonal constructors Currently these definitions are extending `rrule(::typeof(T), x)` where `T` is a type. However, `typeof(Diagonal) == UnionAll`, which means this is not defining the method it looks like it might be defining. The only reason this worked when originally implemented was that one of the `rrule` definitions was for `rrule(UnionAll, Matrix)` and the other for `rrule(UnionAll, Vector)`, so dispatch still worked. This replaces these problematic `::typeof(T)`s with `::Type{<:T}`. --- src/rules/linalg/diagonal.jl | 2 +- src/rules/linalg/symmetric.jl | 2 +- test/rules/linalg/diagonal.jl | 5 ++++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/rules/linalg/diagonal.jl b/src/rules/linalg/diagonal.jl index 89bc0316e..cd9a40959 100644 --- a/src/rules/linalg/diagonal.jl +++ b/src/rules/linalg/diagonal.jl @@ -1,2 +1,2 @@ -rrule(::typeof(Diagonal), d::AbstractVector) = Diagonal(d), Rule(diag) +rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) diff --git a/src/rules/linalg/symmetric.jl b/src/rules/linalg/symmetric.jl index 07d0e964c..4b1f861d6 100644 --- a/src/rules/linalg/symmetric.jl +++ b/src/rules/linalg/symmetric.jl @@ -1,4 +1,4 @@ -rrule(::typeof(Symmetric), A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) +rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) _symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) _symmetric_back(ΔΩ::Union{Diagonal, UpperTriangular}) = ΔΩ diff --git a/test/rules/linalg/diagonal.jl b/test/rules/linalg/diagonal.jl index 1320d0155..77de00d09 100644 --- a/test/rules/linalg/diagonal.jl +++ b/test/rules/linalg/diagonal.jl @@ -2,7 +2,10 @@ @testset "Diagonal" begin rng, N = MersenneTwister(123456), 3 rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N))) - rrule_test(Diagonal, Diagonal(randn(rng, N)), (randn(rng, N), randn(rng, N))) + D = Diagonal(randn(rng, N)) + rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N))) + # Concrete type instead of UnionAll + rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N))) end @testset "diag" begin rng, N = MersenneTwister(123456), 7