Skip to content

Commit

Permalink
Add rules for pinv (#225)
Browse files Browse the repository at this point in the history
* Add matrix pinv rules

* Add vector and scalar pinv rules

* Make sure (co)tangents are right type

* Avoid unnecessary computation

* Add pinv tests

* Test return type

* Contract over larger dimension

* Release type constraint

* Don't assume is left or right inverse

* Make keyword arg positional

* Explicitly call parent

* Add pinv for transpose and adjoint vectors

* Simplify expression

* Add reference

* Drop default type from signature

* Don't call rules internally

* Increment version number
  • Loading branch information
sethaxen committed Jul 5, 2020
1 parent 224e553 commit ce39b65
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.2"
version = "0.7.3"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
95 changes: 95 additions & 0 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,101 @@ function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}
return A * B, times_pullback
end

#####
##### `pinv`
#####

# Scalar rule: d/dx pinv(x) = -pinv(x)^2, written in terms of the primal output `Ω`.
@scalar_rule pinv(x) -(Ω ^ 2)

# Forward rule for `pinv` of a real or complex vector `x`.
# Returns the primal `y = pinv(x, tol)` and its tangent along `Δx`; the tangent is
# re-wrapped with the same lazy wrapper (`transpose` or `adjoint`) that `y` carries.
function frule(
    (_, Δx),
    ::typeof(pinv),
    x::AbstractVector{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    p = parent(y)  # storage wrapped by the lazy transpose/adjoint `y`
    ∂p = sum(abs2, p) .* Δx .- 2real(y * Δx) .* p
    if y isa Transpose
        return y, transpose(∂p)
    else
        return y, adjoint(∂p)
    end
end

# Forward rule for `pinv` of a transposed or adjoint vector `x`.
# Returns the primal `y = pinv(x, tol)` and its tangent along `Δx`.
function frule(
    (_, Δx),
    ::typeof(pinv),
    x::LinearAlgebra.AdjOrTransAbsVec{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    Δcol = vec(Δx')          # conjugate-transposed input tangent, flattened to a vector
    ysq = sum(abs2, y)
    proj = 2real(Δx * y)
    ∂y = @. ysq * Δcol - proj * y
    return y, ∂y
end

# Forward rule for `pinv` of a matrix `A`; keyword arguments are forwarded to `pinv`.
#
# Formula for derivative adapted from Eq 4.12 of
# Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses and Nonlinear
# Least Squares Problems Whose Variables Separate."
# SIAM Journal on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036
#
# With `Y = pinv(A)`, the tangent is the sum of three terms:
#     ∂Y = -Y ΔA Y + (I - Y A) ΔA' Y' Y + Y Y' ΔA' (I - A Y)
# Both branches evaluate this same expression. They differ only in how the products are
# parenthesized, so that every square intermediate has the smaller dimension of `A`.
function frule((_, ΔA), ::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
    Y = pinv(A; kwargs...)
    m, n = size(A)
    # contract over the largest dimension
    if m <= n
        ∂Y = -Y * (ΔA * Y)
        _add!(∂Y, (ΔA' - Y * (A * ΔA')) * (Y' * Y)) # (I - Y A) ΔA' Y' Y
        _add!(∂Y, Y * (Y' * ΔA') * (I - A * Y)) # Y Y' ΔA' (I - A Y)
    else
        ∂Y = -(Y * ΔA) * Y
        _add!(∂Y, (I - Y * A) * (ΔA' * Y') * Y) # (I - Y A) ΔA' Y' Y
        _add!(∂Y, (Y * Y') * (ΔA' - (ΔA' * A) * Y)) # Y Y' ΔA' (I - A Y)
    end
    return Y, ∂Y
end

# Reverse rule for `pinv` of a real or complex vector `x`.
# Returns the primal `y = pinv(x, tol)` and a pullback mapping the output cotangent `Δy`
# to `(NO_FIELDS, ∂x, Zero())`; `tol` receives a `Zero()` cotangent.
function rrule(
    ::typeof(pinv),
    x::AbstractVector{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    function pinv_pullback(Δy)
        p = parent(y)  # storage wrapped by the lazy transpose/adjoint `y`
        Δcol = vec(Δy')
        proj = 2real(y * Δy')
        ∂x = sum(abs2, p) .* Δcol .- proj .* p
        return (NO_FIELDS, ∂x, Zero())
    end
    return y, pinv_pullback
end

# Reverse rule for `pinv` of a transposed or adjoint vector `x`.
# Returns the primal `y = pinv(x, tol)` and a pullback mapping the output cotangent `Δy`
# to `(NO_FIELDS, ∂x, Zero())`, where `∂x` is wrapped like `x` (transpose or adjoint).
function rrule(
    ::typeof(pinv),
    x::LinearAlgebra.AdjOrTransAbsVec{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    function pinv_pullback(Δy)
        ysq = sum(abs2, y)
        proj = 2real(y' * Δy)
        ∂col = @. ysq * Δy - proj * y
        if x isa Transpose
            # transpose does not conjugate, so conjugate explicitly first
            ∂x = transpose(conj(∂col))
        else
            ∂x = adjoint(∂col)
        end
        return (NO_FIELDS, ∂x, Zero())
    end
    return y, pinv_pullback
end

# Reverse rule for `pinv` of a matrix `A`; keyword arguments are forwarded to `pinv`.
#
# With `Y = pinv(A)` and output cotangent `ΔY`, the pullback returns
#     ∂A = -Y' ΔY Y' + Y' Y ΔY' (I - Y A) + (I - A Y) ΔY' Y Y'
# Both branches evaluate this same expression. They differ only in how the products are
# parenthesized, so that every square intermediate has the smaller dimension of `A`.
function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
    Y = pinv(A; kwargs...)
    function pinv_pullback(ΔY)
        m, n = size(A)
        # contract over the largest dimension
        if m <= n
            ∂A = (Y' * -ΔY) * Y'
            _add!(∂A, (Y' * Y) * (ΔY' - (ΔY' * Y) * A)) # Y' Y ΔY' (I - Y A)
            _add!(∂A, (I - A * Y) * (ΔY' * Y) * Y') # (I - A Y) ΔY' Y Y'
        else
            # plain `else` (rather than `elseif m > n`) so ∂A is assigned on every path
            ∂A = Y' * (-ΔY * Y')
            _add!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A)
            _add!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y'
        end
        return (NO_FIELDS, ∂A)
    end
    return Y, pinv_pullback
end

#####
##### `/`
#####
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function _eyesubx!(X::AbstractMatrix)
end

# X + Y, overwrites X if possible
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat)
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
Expand Down
40 changes: 40 additions & 0 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,46 @@
frule_test(inv, (B, randn(T, N, N)))
rrule_test(inv, randn(T, N, N), (B, randn(T, N, N)))
end
@testset "pinv" begin
    # Scalar rules, including the point x = 0 where both rules are expected to give zero.
    @testset "$T" for T in (Float64, ComplexF64)
        test_scalar(pinv, randn(T))
        @test frule((Zero(), randn(T)), pinv, zero(T))[2] ≈ zero(T)
        @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T)
    end
    # Plain vectors, with `tol` passed positionally; also checks (co)tangent types.
    @testset "Vector{$T}" for T in (Float64, ComplexF64)
        n = 3
        x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n)
        tol, ṫol, t̄ol = 0.0, randn(), randn()
        Δy = copyto!(similar(pinv(x)), randn(T, n))
        frule_test(pinv, (x, ẋ), (tol, ṫol))
        @test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x))
        rrule_test(pinv, Δy, (x, x̄), (tol, t̄ol))
        @test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
    end
    # Transposed and adjoint vectors; primal and (co)tangent types must match.
    @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
        n = 3
        x, ẋ, x̄ = F(randn(T, n)), F(randn(T, n)), F(randn(T, n))
        y = pinv(x)
        Δy = copyto!(similar(y), randn(T, n))
        frule_test(pinv, (x, ẋ))
        y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x)
        @test y_fwd isa typeof(y)
        @test ∂y_fwd isa typeof(y)
        rrule_test(pinv, Δy, (x, x̄))
        y_rev, back = rrule(pinv, x)
        @test y_rev isa typeof(y)
        @test back(Δy)[2] isa typeof(x)
    end
    # Matrices of every shape from 1×1 to 3×3, covering m < n, m == n and m > n.
    @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64),
        m in 1:3,
        n in 1:3

        X, Ẋ, X̄ = randn(T, m, n), randn(T, m, n), randn(T, m, n)
        ΔY = randn(T, size(pinv(X))...)
        frule_test(pinv, (X, Ẋ))
        rrule_test(pinv, ΔY, (X, X̄))
    end
end
@testset "det(::Matrix{$T})" for T in (Float64, ComplexF64)
N = 3
B = generate_well_conditioned_matrix(T, N)
Expand Down

2 comments on commit ce39b65

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/17468

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.3 -m "<description of version>" ce39b6508918de5a0d9fb7764662271c73d79474
git push origin v0.7.3

Please sign in to comment.