diff --git a/Project.toml b/Project.toml index 2594e1ba9..627ad6fbb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.6.3" +version = "0.6.4" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 5daa046f4..84caab6ea 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -7,22 +7,22 @@ using LinearAlgebra.BLAS: gemv, gemv!, gemm!, trsm!, axpy!, ger! function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) F = svd(X) - function svd_pullback(Ȳ::Composite{<:SVD}) - ∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)) + function svd_pullback(Ȳ::Composite) + ∂X = @thunk(svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)) return (NO_FIELDS, ∂X) end return F, svd_pullback end function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD - function getproperty_svd_pullback(Ȳ) + function getproperty_svd_pullback(Ȳ) C = Composite{T} ∂F = if x === :U - C(U=Ȳ,) + C(U=Ȳ,) elseif x === :S - C(S=Ȳ,) + C(S=Ȳ,) elseif x === :V - C(V=Ȳ,) + C(V=Ȳ,) elseif x === :Vt # TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106 throw(ArgumentError("Vt is unsupported; use V and transpose the result")) @@ -32,8 +32,8 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD return getproperty(F, x), getproperty_svd_pullback end -# When not `Zero`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix` -function svd_rev(USV::SVD, Ū, s̄, V̄) +# When not `Zero`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix` +function svd_rev(USV::SVD, Ū, s̄, V̄) # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default U = USV.U s = USV.S @@ -49,7 +49,7 @@ function svd_rev(USV::SVD, Ū, s̄, V̄) # place functions here are significantly faster than their out-of-place, naively # implemented counterparts, and allocate no additional memory. Ut = U' - FUᵀŪ = _mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) + FUᵀŪ = _mulsubtrans!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) FVᵀV̄ = _mulsubtrans!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ @@ -58,11 +58,11 @@ function svd_rev(USV::SVD, Ū, s̄, V̄) S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) # TODO: consider using MuladdMacro here - Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt - Ā = _add!(Ā, U * S̄ * Vt) - Ā = _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ)) + Ā = _add!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt + Ā = _add!(Ā, U * S̄ * Vt) + Ā = _add!(Ā, U * _add!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ)) - return Ā + return Ā end ##### @@ -71,11 +71,11 @@ end function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) F = cholesky(X) - function cholesky_pullback(Ȳ::Composite{<:Cholesky}) + function cholesky_pullback(Ȳ::Composite) ∂X = if F.uplo === 'U' - @thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true)) + @thunk(chol_blocked_rev(Ȳ.U, F.U, 25, true)) else - @thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false)) + @thunk(chol_blocked_rev(Ȳ.L, F.L, 25, false)) end return (NO_FIELDS, ∂X) end @@ -83,19 +83,19 @@ function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real}) end function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky - function getproperty_cholesky_pullback(Ȳ) + function getproperty_cholesky_pullback(Ȳ) C = Composite{T} ∂F = @thunk if x === :U if F.uplo === 'U' - C(U=UpperTriangular(Ȳ),) + C(U=UpperTriangular(Ȳ),) else - C(L=LowerTriangular(Ȳ'),) + C(L=LowerTriangular(Ȳ'),) end elseif x === :L if F.uplo === 'L' - C(L=LowerTriangular(Ȳ),) + C(L=LowerTriangular(Ȳ),) else - C(U=UpperTriangular(Ȳ'),) + C(U=UpperTriangular(Ȳ'),) end end return NO_FIELDS, ∂F, DoesNotExist() @@ -159,14 +159,14 @@ function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool) end """ - chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool) + chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool) Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner. If `upper` is `false`, then the sensitivites are computed from and stored in the lower triangle -of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the -upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output -`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and -`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. +of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the +upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output +`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and +`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`. """ function chol_unblocked_rev!(Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool) where T<:Real n = checksquare(Σ̄)