From ce39b6508918de5a0d9fb7764662271c73d79474 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 5 Jul 2020 01:41:09 -0700 Subject: [PATCH] Add rules for pinv (#225) * Add matrix pinv rules * Add vector and scalar pinv rules * Make sure (co)tangents are right type * Avoid unnecessary computation * Add pinv tests * Test return type * Contract over larger dimension * Release type constraint * Don't assume is left or right inverse * Make keyword arg positional * Explicitly call parent * Add pinv for transpose and adjoint vectors * Simplify expression * Add reference * Drop default type from signature * Don't call rules internally * Increment version number --- Project.toml | 2 +- src/rulesets/LinearAlgebra/dense.jl | 95 ++++++++++++++++++++++++++++ src/rulesets/LinearAlgebra/utils.jl | 2 +- test/rulesets/LinearAlgebra/dense.jl | 40 ++++++++++++ 4 files changed, 137 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 7388caafb..57b6aa85f 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.2" +version = "0.7.3" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index cbe27cb05..c949c3b96 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -118,6 +118,101 @@ function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real} return A * B, times_pullback end +##### +##### `pinv` +##### + +@scalar_rule pinv(x) -(Ω ^ 2) + +function frule( + (_, Δx), + ::typeof(pinv), + x::AbstractVector{T}, + tol::Real = 0, +) where {T<:Union{Real,Complex}} + y = pinv(x, tol) + ∂y′ = sum(abs2, parent(y)) .* Δx .- 2real(y * Δx) .* parent(y) + ∂y = y isa Transpose ? transpose(∂y′) : adjoint(∂y′) + return y, ∂y +end + +function frule( + (_, Δx), + ::typeof(pinv), + x::LinearAlgebra.AdjOrTransAbsVec{T}, + tol::Real = 0, +) where {T<:Union{Real,Complex}} + y = pinv(x, tol) + ∂y = sum(abs2, y) .* vec(Δx') .- 2real(Δx * y) .* y + return y, ∂y +end + +# Formula for derivative adapted from Eq 4.12 of +# Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses and Nonlinear +# Least Squares Problems Whose Variables Separate." +# SIAM Journal on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036 +function frule((_, ΔA), ::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} + Y = pinv(A; kwargs...) + m, n = size(A) + # contract over the largest dimension + if m ≤ n + ∂Y = -Y * (ΔA * Y) + _add!(∂Y, (ΔA' - Y * (A * ΔA')) * (Y' * Y)) # (I - Y A) ΔA' Y' Y + _add!(∂Y, Y * (Y' * ΔA') * (I - A * Y)) # Y Y' ΔA' (I - A Y) + else + ∂Y = -(Y * ΔA) * Y + _add!(∂Y, (I - Y * A) * (ΔA' * Y') * Y) # (I - Y A) ΔA' Y' Y + _add!(∂Y, (Y * Y') * (ΔA' - (ΔA' * A) * Y)) # Y Y' ΔA' (I - A Y) + end + return Y, ∂Y +end + +function rrule( + ::typeof(pinv), + x::AbstractVector{T}, + tol::Real = 0, +) where {T<:Union{Real,Complex}} + y = pinv(x, tol) + function pinv_pullback(Δy) + ∂x = sum(abs2, parent(y)) .* vec(Δy') .- 2real(y * Δy') .* parent(y) + return (NO_FIELDS, ∂x, Zero()) + end + return y, pinv_pullback +end + +function rrule( + ::typeof(pinv), + x::LinearAlgebra.AdjOrTransAbsVec{T}, + tol::Real = 0, +) where {T<:Union{Real,Complex}} + y = pinv(x, tol) + function pinv_pullback(Δy) + ∂x′ = sum(abs2, y) .* Δy .- 2real(y' * Δy) .* y + ∂x = x isa Transpose ? transpose(conj(∂x′)) : adjoint(∂x′) + return (NO_FIELDS, ∂x, Zero()) + end + return y, pinv_pullback +end + +function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T} + Y = pinv(A; kwargs...) + function pinv_pullback(ΔY) + m, n = size(A) + # contract over the largest dimension + if m ≤ n + ∂A = (Y' * -ΔY) * Y' + _add!(∂A, (Y' * Y) * (ΔY' - (ΔY' * Y) * A)) # Y' Y ΔY' (I - Y A) + _add!(∂A, (I - A * Y) * (ΔY' * Y) * Y') # (I - A Y) ΔY' Y Y' + elseif m > n + ∂A = Y' * (-ΔY * Y') + _add!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A) + _add!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y' + end + return (NO_FIELDS, ∂A) + end + return Y, pinv_pullback +end + ##### ##### `/` ##### diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index dbba25209..caf661447 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -27,7 +27,7 @@ function _eyesubx!(X::AbstractMatrix) end # X + Y, overwrites X if possible -function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real +function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat) @inbounds for i = eachindex(X, Y) X[i] += Y[i] end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 580333294..3c664fd92 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -48,6 +48,46 @@ frule_test(inv, (B, randn(T, N, N))) rrule_test(inv, randn(T, N, N), (B, randn(T, N, N))) end + @testset "pinv" begin + @testset "$T" for T in (Float64, ComplexF64) + test_scalar(pinv, randn(T)) + @test frule((Zero(), randn(T)), pinv, zero(T))[2] ≈ zero(T) + @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T) + end + @testset "Vector{$T}" for T in (Float64, ComplexF64) + n = 3 + x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n) + tol, ṫol, t̄ol = 0.0, randn(), randn() + Δy = copyto!(similar(pinv(x)), randn(T, n)) + frule_test(pinv, (x, ẋ), (tol, ṫol)) + @test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x)) + rrule_test(pinv, Δy, (x, x̄), (tol, t̄ol)) + @test rrule(pinv, x)[2](Δy)[2] isa typeof(x) + end + @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint) + n = 3 + x, ẋ, x̄ = F(randn(T, n)), F(randn(T, n)), F(randn(T, n)) + y = pinv(x) + Δy = copyto!(similar(y), randn(T, n)) + frule_test(pinv, (x, ẋ)) + y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x) + @test y_fwd isa typeof(y) + @test ∂y_fwd isa typeof(y) + rrule_test(pinv, Δy, (x, x̄)) + y_rev, back = rrule(pinv, x) + @test y_rev isa typeof(y) + @test back(Δy)[2] isa typeof(x) + end + @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64), + m in 1:3, + n in 1:3 + + X, Ẋ, X̄ = randn(T, m, n), randn(T, m, n), randn(T, m, n) + ΔY = randn(T, size(pinv(X))...) + frule_test(pinv, (X, Ẋ)) + rrule_test(pinv, ΔY, (X, X̄)) + end + end @testset "det(::Matrix{$T})" for T in (Float64, ComplexF64) N = 3 B = generate_well_conditioned_matrix(T, N)