Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement QR pullback #306

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

Implement QR pullback #306

wants to merge 4 commits into from

Conversation

Kolaru
Copy link

@Kolaru Kolaru commented Nov 15, 2020

Implement the pullback for the QR decomposition, following:

Walter and Lehmann, 2018, Algorithmic Differentiation of Linear Algebra Functions with Application in Optimum Experimental Design

Comment on lines 274 to 592
if size(F.R, 2) != size(F.R, 1)
throw(ArgumentError("Pullback for QR decomposition is only supported for m × n matrices with m >= n"))
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because I have not found a reference for that case.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Section 3.2 of https://arxiv.org/pdf/2009.10071.pdf gives an rrule for qr for wide matrices (m<n). I haven't tested it, though.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will look into adding that.


# Explicitely convert to Matrix since FiniteDifferences seem to
# be broken for LinearAlgebra.QRCompactWYQ (infinite to_vec
# recursion)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can an issue openned for this an linked here?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


F, dX_pullback = rrule(qr, X)
for p in [:Q, :R]
Y, dF_pullback = rrule(getproperty, F, p)
Copy link
Member

@oxinabox oxinabox Nov 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does rrule_test not work on this?
If it doesn't an issue needs to be openned on ChainRulesTestUtils.jl

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It does not work on getproperty because of the failure of FiniteDifferences.jl mentionned elsewhere on this PR.

It does not work directly on qr either, because the object returned by qr can not be collected. See JuliaDiff/ChainRulesTestUtils.jl#74

@@ -138,4 +138,40 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test chol_unblocked_rev(B̄, B, true) ≈ chol_blocked_rev(B̄, B, 10, true)
end
end
@testset "qr" begin
@testset "the thing" begin
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just reused the same structure as the test for cholesky. I can give it a more explicit name.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please do.

Comment on lines 287 to 292
C = Composite{T}
∂F = if x === :Q
C(Q=Ȳ,)
elseif x === :R
C(R=Ȳ,)
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it worth defining C ?

Suggested change
C = Composite{T}
∂F = if x === :Q
C(Q=Ȳ,)
elseif x === :R
C(R=Ȳ,)
end
∂F = if x === :Q
Composite{T}(Q=Ȳ,)
elseif x === :R
Composite{T}(R=Ȳ,)
end

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did it analog to the svd case. I think I could go even simpler with

∂F = Composite{T}(; x => Ȳ)

Should I port that kind of cleanup to svd and cholesky as well?

@@ -27,3 +27,17 @@ function _eyesubx!(X::AbstractMatrix)
end

_extract_imag(x) = complex(0, imag(x))

# Lower triangle of X - X' overwrite X if possible
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to always overwrite?

if x === :Q
# Return thing Q for consistency
n = size(F.R, 1)
return F.Q[:, 1:n], getproperty_qr_pullback
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this on to do to the primal value?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was because the anything else does not contribute, and the tests used the result of this to infer dimensions. However after pondering on it I figured out it is probably better to let the getproperty pullback do just the obvious and have the qr pullback make sure to only take in account what is relevant and make sure the dimensions match.

return NO_FIELDS, ∂F, DoesNotExist()
end
if x === :Q
# Return thing Q for consistency
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was very confused

Suggested change
# Return thing Q for consistency
# Return thin Q for consistency

@sethaxen
Copy link
Member

I wonder if the QR rule implied by Seeger et al in https://arxiv.org/pdf/1710.08717.pdf is more performant than the one in Walter and Lehmann? (they actually define an LQ rule, but the same approach produces a QR rule). The below reimplementation of this PR's qr_rev using this rule seems to outperform qr_rev in a simple benchmark while yielding the same result.

function qr_rev2(QR_::ChainRules.QR_TYPE, Q̄, R̄)
    Q, R = QR_
    Q = Matrix(Q)
    Q̄ =isa Zero ?: @view Q̄[:, axes(Q, 2)]
    V =*R' - Q'*Q̄
    Ā = (Q̄ + Q * Hermitian(V)) / R'
    returnend

julia> A = randn(4, 4);

julia> F = qr(A);

julia> ΔF = Composite{typeof(F)}(Q = randn(eltype(F.Q), size(Matrix(F.Q))), R = randn(eltype(F.R), size(F.R)));

julia> @btime ChainRules.qr_rev($F, $(ΔF.Q), $(ΔF.R));
  16.374 μs (58 allocations: 7.11 KiB)

julia> @btime qr_rev2($F, $(ΔF.Q), $(ΔF.R));
  4.323 μs (12 allocations: 2.27 KiB)

julia> ChainRules.qr_rev(F, ΔF.Q, ΔF.R)  qr_rev2(F, ΔF.Q, ΔF.R)
true

julia> A = randn(10, 4);

julia> F = qr(A);

julia> ΔF = Composite{typeof(F)}(Q = randn(eltype(F.Q), size(Matrix(F.Q))), R = randn(eltype(F.R), size(F.R)));

julia> @btime ChainRules.qr_rev($F, $(ΔF.Q), $(ΔF.R));
  56.299 μs (134 allocations: 20.86 KiB)

julia> @btime qr_rev2($F, $(ΔF.Q), $(ΔF.R));
  5.112 μs (12 allocations: 3.39 KiB)

julia> ChainRules.qr_rev(F, ΔF.Q, ΔF.R)  qr_rev2(F, ΔF.Q, ΔF.R)
true

Or is the rule you have implemented expected to be more numerically stable?

@Kolaru
Copy link
Author

Kolaru commented Nov 22, 2020

Or is the rule you have implemented expected to be more numerically stable?

I have no idea. The two references we are using do not directly compare each other, and I do not know how to determine this myself.

@sethaxen
Copy link
Member

sethaxen commented Feb 4, 2021

Or is the rule you have implemented expected to be more numerically stable?

I have no idea. The two references we are using do not directly compare each other, and I do not know how to determine this myself.

The article I linked in #306 (comment) (https://arxiv.org/pdf/2009.10071.pdf), which covers wide and tall matrices as well, also uses the simpler rule from the Seeger et al paper. While I didn't implement their rules for wide and tall matrices, I ended up using a similar approach for the LU decomposition of wide and tall matrices in #354. For these reasons, I'm thinking the Seeger approach is preferable.

@Kolaru Kolaru marked this pull request as draft June 29, 2021 16:05
@Kolaru
Copy link
Author

Kolaru commented Jul 22, 2021

I finally had time to come back to this, and it has been kind of a nightmare, because QR decompositions are represented in a weird way that do not play nicely with the tests and the comparison with FiniteDifferences.

After quite a lot of experimentations, I gave up on trying to make everything work with the default type returned by qr. Instead I define a new custom type ExplicitQR that stores the Q and R matrices explicitely. I use this struct for the tests, and compare the end result with the one from the qr method to ensure the latter is correct as well.

I hope this is sufficient. Otherwise I must admit I am out of idea about what should be done to test qr directly.

I am aware of #469 that has a somewhat different approach. I am not currently sure which is better.

Alos I implemented the algorithm suggested by @sethaxen.

So provided my way of testing is okay, this should be ready.

@Kolaru Kolaru marked this pull request as ready for review July 22, 2021 02:13
@oxinabox oxinabox requested a review from sethaxen July 22, 2021 10:28
@sethaxen
Copy link
Member

Thanks, @Kolaru! I'll review #469 first, then this, then I'll recommend how to proceed.

@sethaxen
Copy link
Member

sethaxen commented Oct 8, 2021

What I like about this approach is that it completely sidesteps much of the complexity of the objects returned by the qr methods, as described in #469 (comment).

What I don't like is that the object being returned by the rrule is completely different from the one the user requested, which will cause some code that worked before to suddenly fail. For example, this code is totally valid with the qr return values in the std library:

A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)
y = Q*w + Q*v
@assert size(y) == (10,)

This works because AbstractQ objects are treated as thin or full depending on what they're multiplied by, but if Q is dense, then this no longer works.

@Kolaru
Copy link
Author

Kolaru commented Oct 9, 2021

I tried the following

using ChainRules: rrule, ExplicitQR
using LinearAlgebra

A = randn(10, 5)
Q, _ = qr(A)
v = randn(5)
w = randn(10)

F, F_pullback = rrule(qr, A)
Q, Q_pullback = rrule(getproperty, F, :Q)
y1, y1_pullback = rrule(*, Q, v)
y2, y2_pullback = rrule(*, Q, w)

ȳ1 = rand(10)
_, Q̄1 = y1_pullback(ȳ1)
_, F̄1 = Q_pullback(Q̄1)
_, Ā1 = F_pullback(F̄1) 

ȳ2 = rand(10)
_, Q̄2 = y2_pullback(ȳ2)
_, F̄2 = Q_pullback(Q̄2)
_, Ā2 = F_pullback(F̄2)

and everything seems to be fine (i.e. nothing error, I haven't tested correctness). Is this the correct way to test your point? Whether this PR or #469 is used, adding proper test for quirks of Q seems unavoidable.

Note that the ExplicitQR object I define is only used in the test (and maybe its definition should be moved to the tests for clarity).

As far as I understand, ChainRules should be able to handle the QR objects properly. The problem I had was to create the tangent objects for the tests. The main issues is that we are taking derivative with respect to fields that are not stored in the object directly[*]. After some tries, I just gave up and sidestepped the whole testing issue, hoping this is still enough to ensure correctness.

[*] The most common error I had was:

ArgumentError: Tangent fields do not match primal fields.
Tangent fields: (:Q,). Primal (LinearAlgebra.QRCompactWY{Float64, Matrix{Float64}}) fields: (:factors, :T)

Using a custom struct with explicit fields Q and R naturally solved it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants