diff --git a/src/rules/linalg/diagonal.jl b/src/rules/linalg/diagonal.jl index 89bc0316e..cd9a40959 100644 --- a/src/rules/linalg/diagonal.jl +++ b/src/rules/linalg/diagonal.jl @@ -1,2 +1,2 @@ -rrule(::typeof(Diagonal), d::AbstractVector) = Diagonal(d), Rule(diag) +rrule(::Type{<:Diagonal}, d::AbstractVector) = Diagonal(d), Rule(diag) rrule(::typeof(diag), A::AbstractMatrix) = diag(A), Rule(Diagonal) diff --git a/src/rules/linalg/symmetric.jl b/src/rules/linalg/symmetric.jl index 07d0e964c..4b1f861d6 100644 --- a/src/rules/linalg/symmetric.jl +++ b/src/rules/linalg/symmetric.jl @@ -1,4 +1,4 @@ -rrule(::typeof(Symmetric), A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) +rrule(::Type{<:Symmetric}, A::AbstractMatrix) = Symmetric(A), Rule(_symmetric_back) _symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ) _symmetric_back(ΔΩ::Union{Diagonal, UpperTriangular}) = ΔΩ diff --git a/test/rules/linalg/diagonal.jl b/test/rules/linalg/diagonal.jl index 1320d0155..77de00d09 100644 --- a/test/rules/linalg/diagonal.jl +++ b/test/rules/linalg/diagonal.jl @@ -2,7 +2,10 @@ @testset "Diagonal" begin rng, N = MersenneTwister(123456), 3 rrule_test(Diagonal, randn(rng, N, N), (randn(rng, N), randn(rng, N))) - rrule_test(Diagonal, Diagonal(randn(rng, N)), (randn(rng, N), randn(rng, N))) + D = Diagonal(randn(rng, N)) + rrule_test(Diagonal, D, (randn(rng, N), randn(rng, N))) + # Concrete type instead of UnionAll + rrule_test(typeof(D), D, (randn(rng, N), randn(rng, N))) end @testset "diag" begin rng, N = MersenneTwister(123456), 7