Support for qr decomposition pullback #469

Open
wants to merge 1 commit into
base: main


@rkube rkube commented Jul 12, 2021

Added an rrule for the QR decomposition. @sethaxen

@sethaxen sethaxen self-requested a review July 12, 2021 19:38
@Kolaru Kolaru mentioned this pull request Jul 22, 2021
@rkube
Author

rkube commented Aug 25, 2021

I've ported this pullback to CUDA.jl: https://gist.github.com/rkube/b17ef683409d76a3f01bcc590b85de6e
Where would be a good place for that code?

@oxinabox
Member

oxinabox commented Sep 21, 2021

pokes @sethaxen
(I can't really review this)

Comment on lines +288 to +290
∂T = d === :R ? Ȳ : nothing

∂F = Tangent{LinearAlgebra.QRCompactWY}(; factors=∂factors, T=∂T)
Member

∂factors isn't defined

# R. Schreiber and C. van Loan, Sci. Stat. Comput. 10, 53-57 (1989).
# Instead of backpropagating Q̄ and R̄ through (factors)bar and T̄, we re-use factors to carry Q̄ and T to carry R̄
# in the Tangent object.
∂T = d === :R ? Ȳ : nothing
Member

We do not use nothing to represent an unused tangent.
We use ZeroTangent for that (and NoTangent when there is no meaningful tangent space).
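
A minimal sketch of that convention, assuming ChainRulesCore is installed. The names F, Ȳ and d and the 4×3 input are hypothetical stand-ins for illustration, not the PR's actual code:

```julia
using ChainRulesCore, LinearAlgebra

F = qr(randn(4, 3))          # a factorization to build a structural tangent for
Ȳ = ones(3, 3)               # hypothetical incoming cotangent
d = :Q                       # hypothetical: the property that was requested

# Use ZeroTangent() rather than `nothing` for a field that carries no cotangent.
∂T = d === :R ? Ȳ : ZeroTangent()
∂F = Tangent{typeof(F)}(; factors=ZeroTangent(), T=∂T)

# ZeroTangent acts as an additive identity, so accumulation keeps working.
accumulated = ∂F.T + Ȳ

# `nothing` has no `+` method, so the same accumulation would throw.
nothing_fails = try
    nothing + Ȳ
    false
catch err
    err isa MethodError
end
```

The point of the sketch is that downstream accumulation never has to special-case the unused field.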

@sethaxen
Member

> pokes @sethaxen
> (I can't really review this)

Yeah, sorry, I took the deep dive studying the various QR parameterizations a few weeks back in prep for reviewing this but haven't had the chance to yet. Sorry for the delay, @rkube!

@sethaxen sethaxen closed this Sep 28, 2021
@sethaxen sethaxen reopened this Sep 28, 2021
@sethaxen
Member

sethaxen commented Oct 8, 2021

So this is a really tricky set of rules to define, perhaps trickier than any of the other rules we have in ChainRules currently. Here are just a few complications:

  • The signatures for qr are all changing with Julia v1.7 (below I use the 1.7 signatures)
  • qr can produce 4 different types in the standard library, summarized below:
# returns QRCompactWY via LAPACK.geqrt!
qr(A::StridedMatrix{<:BlasFloat}, pivot = NoPivot(); kwargs...)
qr!(A::StridedMatrix{<:BlasFloat}, ::NoPivot; kwargs...)

# returns QR via qrfactUnblocked!
qr(A::AbstractMatrix, pivot = NoPivot())
qr!(A::AbstractMatrix, ::NoPivot)

# returns QRPivoted via qrfactPivotedUnblocked!
qr(A::AbstractMatrix, ::ColumnNorm)
qr!(A::AbstractMatrix, ::ColumnNorm)

# returns SuiteSparse.SPQR.QRSparse
qr(A::SparseMatrixCSC, pivot = NoPivot())
  • None of the QR objects generate the Q matrix. Instead, they represent it in a compact form, where factors contains Householder reflectors in the strict lower trapezoid, and R in the upper trapezoid. Computing rules in terms of these compact elements is challenging, roughly as challenging as implementing the qr functions themselves.
  • Calling .Q on one of these factorizations produces an AbstractQ <: AbstractMatrix object that has essentially the same fields as the factorization. Because AbstractQ objects are AbstractMatrix subtypes, they hit all of our AbstractMatrix rules by default and will therefore end up with AbstractMatrix cotangents unless we write custom rrules for every function one might call on a QR object.
  • The AbstractQ objects are weird. For an n×k matrix A, size(qr(A).Q) == (n, n). However, Q can also be multiplied by matrices of size (k, m). So consider code like the following:
A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
y = Q*w + Q*v
@assert size(y) == (10,)

This is completely allowed, but note that the cotangent of Q will be ∂Q = ∂y * w' + ∂y * v'. This adds two matrices of size (10, 10) and (10, 5), respectively. The AD engine performs this addition, and it will error. It is therefore necessary to use ProjectTo to pad the (10, 5) matrix with zeros to make it (10, 10), but this is very wasteful for very tall matrices, where one may never use the full (10, 10) version.
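
The size mismatch can be reproduced with plain arrays. This is only a sketch: ∂y is a hypothetical incoming cotangent, and the hcat padding stands in for what a ProjectTo method might do rather than showing an existing ChainRulesCore method:

```julia
using LinearAlgebra

A = randn(10, 5)
Q = qr(A).Q
w = randn(10)
v = randn(5)
∂y = randn(10)                       # hypothetical cotangent of y = Q*w + Q*v

∂Q_w = ∂y * w'                       # contribution from Q*w, size (10, 10)
∂Q_v = ∂y * v'                       # contribution from Q*v, size (10, 5)

# Naive accumulation of the two contributions fails on the shape mismatch.
naive_fails = try
    ∂Q_w + ∂Q_v
    false
catch err
    err isa DimensionMismatch
end

# Padding with zero columns makes the shapes agree. It mirrors the fact that
# Q*v behaves like Q applied to v extended with zeros.
∂Q_v_padded = hcat(∂Q_v, zeros(10, 10 - 5))
∂Q = ∂Q_w + ∂Q_v_padded
```

For a very tall A, the padded matrix is n×n even if only the n×k part is ever used, which is the waste described above.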

I don't think we can address a subset of these complications one at a time. Once we start adding rules, which will override AD systems' default behavior of differentiating through the qr! fallback (for operator-overloading ADs), we will need more rules to make sure all of our rules compose nicely. I need to think more about whether this can be handled without a tremendous amount of complication.
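
For reference, the dense return types and the compact storage listed earlier can be inspected directly. This is a sketch that assumes Julia 1.7 or later (for ColumnNorm); the 6×4 input and the variable names are arbitrary, and BigFloat is used only as a convenient non-BLAS element type:

```julia
using LinearAlgebra

A = randn(6, 4)

F_blas = qr(A)                     # BLAS floats: QRCompactWY
F_generic = qr(big.(A))            # non-BLAS element type: generic QR
F_pivoted = qr(A, ColumnNorm())    # column pivoting: QRPivoted

# None of these store Q explicitly. `factors` holds R in its upper triangle
# and the Householder reflectors below the diagonal.
R_from_factors = triu(F_blas.factors[1:4, :])

# Materializing the thin Q is an explicit extra step.
Q_thin = Matrix(F_blas.Q)
```

Here fieldnames(typeof(F_blas)) shows only the compact fields (factors and T), which is what a structural tangent for the factorization would have to be written in terms of.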

3 participants