Skip to content

Commit

Permalink
Resolve merge conflict
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Jul 5, 2020
2 parents 11c3b71 + ce39b65 commit c739060
Show file tree
Hide file tree
Showing 3 changed files with 136 additions and 1 deletion.
95 changes: 95 additions & 0 deletions src/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,101 @@ function rrule(::typeof(*), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real}
return A * B, times_pullback
end

#####
##### `pinv`
#####

# Scalar rule for `pinv`. `Ω` is the primal output `pinv(x)`, so the partial is -Ω^2.
# Writing it in terms of Ω (rather than x) keeps the derivative at zero when the output
# is zero, which is what the scalar tests for `pinv` at `zero(T)` expect.
@scalar_rule pinv(x) -(Ω ^ 2)

# Forward rule for `pinv` of a column vector.
#
# `Δx` is the tangent of `x`; the tangent of `tol` is ignored. Returns the primal
# `y = pinv(x, tol)` together with its tangent `∂y`, wrapped in the same lazy wrapper
# (transpose or adjoint) as `y`.
function frule(
    (_, Δx),
    ::typeof(pinv),
    x::AbstractVector{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    # Underlying column vector wrapped by the row-vector result.
    p = parent(y)
    sqnorm = sum(abs2, p)
    overlap = 2real(y * Δx)
    ∂y′ = sqnorm .* Δx .- overlap .* p
    # Re-wrap the tangent so that it has the same row-vector type as the primal.
    if y isa Transpose
        ∂y = transpose(∂y′)
    else
        ∂y = adjoint(∂y′)
    end
    return y, ∂y
end

# Forward rule for `pinv` of a row vector (an adjoint or transpose of a vector).
#
# `Δx` is the tangent of `x`; the tangent of `tol` is ignored. Returns the primal
# `y = pinv(x, tol)` and its tangent `∂y`.
function frule(
    (_, Δx),
    ::typeof(pinv),
    x::LinearAlgebra.AdjOrTransAbsVec{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    sqnorm = sum(abs2, y)
    overlap = 2real(Δx * y)
    # `vec(Δx')` turns the row-shaped tangent into a plain column to match `y`.
    ∂y = sqnorm .* vec(Δx') .- overlap .* y
    return y, ∂y
end

# Formula for derivative adapted from Eq 4.12 of
# Golub, Gene H., and Victor Pereyra. "The Differentiation of Pseudo-Inverses and Nonlinear
# Least Squares Problems Whose Variables Separate."
# SIAM Journal on Numerical Analysis 10(2). (1973). 413-432. doi: 10.1137/0710036
#
# Forward rule for `pinv` of a matrix. `ΔA` is the tangent of `A`; `kwargs` are forwarded
# to `pinv`. Returns the primal `Y = pinv(A; kwargs...)` and its tangent `∂Y`, which is the
# sum of three terms:
#     -Y ΔA Y + (I - Y A) ΔA' Y' Y + Y Y' ΔA' (I - A Y)
# Both branches evaluate the same three terms; they differ only in how the products are
# parenthesised.
function frule((_, ΔA), ::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
    Y = pinv(A; kwargs...)
    m, n = size(A)
    # contract over the largest dimension
    if m <= n
        ∂Y = -Y * (ΔA * Y)
        _add!(∂Y, (ΔA' - Y * (A * ΔA')) * (Y' * Y)) # (I - Y A) ΔA' Y' Y
        _add!(∂Y, Y * (Y' * ΔA') * (I - A * Y)) # Y Y' ΔA' (I - A Y)
    else
        ∂Y = -(Y * ΔA) * Y
        _add!(∂Y, (I - Y * A) * (ΔA' * Y') * Y) # (I - Y A) ΔA' Y' Y
        _add!(∂Y, (Y * Y') * (ΔA' - (ΔA' * A) * Y)) # Y Y' ΔA' (I - A Y)
    end
    return Y, ∂Y
end

# Reverse rule for `pinv` of a column vector.
#
# Returns the primal `y = pinv(x, tol)` and a pullback mapping the cotangent `Δy` of `y`
# to `(NO_FIELDS, ∂x, Zero())`, i.e. no cotangent for the function itself, `∂x` for `x`,
# and a zero cotangent for `tol`.
function rrule(
    ::typeof(pinv),
    x::AbstractVector{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    function pinv_pullback(Δy)
        # Underlying column vector wrapped by the row-vector primal.
        p = parent(y)
        sqnorm = sum(abs2, p)
        overlap = 2real(y * Δy')
        # `vec(Δy')` turns the row-shaped cotangent into a plain column to match `x`.
        ∂x = sqnorm .* vec(Δy') .- overlap .* p
        return (NO_FIELDS, ∂x, Zero())
    end
    return y, pinv_pullback
end

# Reverse rule for `pinv` of a row vector (an adjoint or transpose of a vector).
#
# Returns the primal `y = pinv(x, tol)` and a pullback mapping the cotangent `Δy` of `y`
# to `(NO_FIELDS, ∂x, Zero())`, where `∂x` is wrapped the same way as `x`.
function rrule(
    ::typeof(pinv),
    x::LinearAlgebra.AdjOrTransAbsVec{T},
    tol::Real = 0,
) where {T<:Union{Real,Complex}}
    y = pinv(x, tol)
    function pinv_pullback(Δy)
        sqnorm = sum(abs2, y)
        overlap = 2real(y' * Δy)
        ∂x′ = sqnorm .* Δy .- overlap .* y
        # Re-wrap the column-shaped cotangent so that it matches the wrapper type of `x`.
        if x isa Transpose
            ∂x = transpose(conj(∂x′))
        else
            ∂x = adjoint(∂x′)
        end
        return (NO_FIELDS, ∂x, Zero())
    end
    return y, pinv_pullback
end

# Reverse rule for `pinv` of a matrix; the adjoint counterpart of the matrix `frule`.
#
# `kwargs` are forwarded to `pinv`. Returns the primal `Y = pinv(A; kwargs...)` and a
# pullback mapping the cotangent `ΔY` of `Y` to `(NO_FIELDS, ∂A)`, where `∂A` is the sum
# of three terms:
#     -Y' ΔY Y' + Y' Y ΔY' (I - Y A) + (I - A Y) ΔY' Y Y'
# Both branches evaluate the same three terms; they differ only in how the products are
# parenthesised.
function rrule(::typeof(pinv), A::AbstractMatrix{T}; kwargs...) where {T}
    Y = pinv(A; kwargs...)
    function pinv_pullback(ΔY)
        m, n = size(A)
        # contract over the largest dimension
        if m <= n
            ∂A = (Y' * -ΔY) * Y'
            _add!(∂A, (Y' * Y) * (ΔY' - (ΔY' * Y) * A)) # Y' Y ΔY' (I - Y A)
            _add!(∂A, (I - A * Y) * (ΔY' * Y) * Y') # (I - A Y) ΔY' Y Y'
        else
            # m > n; a plain `else` makes it evident that ∂A is assigned on every path
            ∂A = Y' * (-ΔY * Y')
            _add!(∂A, Y' * (Y * ΔY') * (I - Y * A)) # Y' Y ΔY' (I - Y A)
            _add!(∂A, (ΔY' - A * (Y * ΔY')) * (Y * Y')) # (I - A Y) ΔY' Y Y'
        end
        return (NO_FIELDS, ∂A)
    end
    return Y, pinv_pullback
end

#####
##### `/`
#####
Expand Down
2 changes: 1 addition & 1 deletion src/rulesets/LinearAlgebra/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function _eyesubx!(X::AbstractMatrix)
end

# X + Y, overwrites X if possible
function _add!(X::AbstractVecOrMat{T}, Y::AbstractVecOrMat{T}) where T<:Real
function _add!(X::AbstractVecOrMat, Y::AbstractVecOrMat)
@inbounds for i = eachindex(X, Y)
X[i] += Y[i]
end
Expand Down
40 changes: 40 additions & 0 deletions test/rulesets/LinearAlgebra/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,46 @@
frule_test(inv, (B, randn(T, N, N)))
rrule_test(inv, randn(T, N, N), (B, randn(T, N, N)))
end
# Tests for the `pinv` rules: scalars, column vectors, row vectors (transpose / adjoint
# wrappers) and dense matrices of every shape from 1x1 to 3x3.
@testset "pinv" begin
    @testset "$T" for T in (Float64, ComplexF64)
        test_scalar(pinv, randn(T))
        # At x = 0 both the forward and the reverse rule are expected to give a zero
        # derivative.
        @test frule((Zero(), randn(T)), pinv, zero(T))[2] ≈ zero(T)
        @test rrule(pinv, zero(T))[2](randn(T))[2] ≈ zero(T)
    end
    @testset "Vector{$T}" for T in (Float64, ComplexF64)
        n = 3
        x, ẋ, x̄ = randn(T, n), randn(T, n), randn(T, n)
        tol, ṫol, t̄ol = 0.0, randn(), randn()
        # Cotangent with the same (row-vector) type as the primal output.
        Δy = copyto!(similar(pinv(x)), randn(T, n))
        frule_test(pinv, (x, ẋ), (tol, ṫol))
        @test frule((Zero(), ẋ), pinv, x)[2] isa typeof(pinv(x))
        rrule_test(pinv, Δy, (x, x̄), (tol, t̄ol))
        @test rrule(pinv, x)[2](Δy)[2] isa typeof(x)
    end
    @testset "$F{Vector{$T}}" for T in (Float64, ComplexF64), F in (Transpose, Adjoint)
        n = 3
        x, ẋ, x̄ = F(randn(T, n)), F(randn(T, n)), F(randn(T, n))
        y = pinv(x)
        Δy = copyto!(similar(y), randn(T, n))
        frule_test(pinv, (x, ẋ))
        # Primal and tangent must keep the type of `pinv(x)`.
        y_fwd, ∂y_fwd = frule((Zero(), ẋ), pinv, x)
        @test y_fwd isa typeof(y)
        @test ∂y_fwd isa typeof(y)
        rrule_test(pinv, Δy, (x, x̄))
        # The cotangent of `x` must keep the wrapper type of `x`.
        y_rev, back = rrule(pinv, x)
        @test y_rev isa typeof(y)
        @test back(Δy)[2] isa typeof(x)
    end
    @testset "Matrix{$T} with size ($m,$n)" for T in (Float64, ComplexF64),
        m in 1:3,
        n in 1:3

        X, Ẋ, X̄ = randn(T, m, n), randn(T, m, n), randn(T, m, n)
        ΔY = randn(T, size(pinv(X))...)
        frule_test(pinv, (X, Ẋ))
        rrule_test(pinv, ΔY, (X, X̄))
    end
end
@testset "det(::Matrix{$T})" for T in (Float64, ComplexF64)
N = 3
B = generate_well_conditioned_matrix(T, N)
Expand Down

0 comments on commit c739060

Please sign in to comment.