Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use AbstractVector in LKJ and LKJCholesky bijectors #253

Merged
merged 75 commits into from
Jun 6, 2023

Conversation

harisorgn
Copy link
Member

@harisorgn harisorgn commented Apr 6, 2023

Expands on #246 .

Use ::AbstractVector in VecCorrBijector operations, so we won't need to transform to ::AbstractMatrix and back.

Add bijector for LKJCholesky. I believe this was missing and in practice it is the more efficient alternative when working with correlation matrices (avoids Cholesky decompositions on every call).
In LKJCholesky there is control over the returned factor ('U' -> UpperTriangular or 'L' -> LowerTriangular). I was wondering whether we want to respect the factor choice and always return the same triangular factor. If yes, we can use VecTriuBijector and VecTrilBijector to retain information about the original factor in LKJCholesky and return it. If no, we can always work with one type, e.g. UpperTriangular.

TO DO :

  • Add ChainRulesCore.rrules for all link functions that work on ::AbstractVector, defined in this PR. I have only added one rule for the forward link function, but ChainRulesTestUtils.test_rrule complains about type instability and value mismatch. When comparing the values returned by the pullback inside the closure of rrule against the one defined for Zygote I'm getting the same output though. I will have more of a look next week.
  • Document how I ended up with _logabsdetjac_inv_chol, so it can be verified. This was based on the Stan manual pages for correlation matrices and Cholesky factors of correlation matrices.
    EDIT 2: I have not documented the formula derivation but added a test for it that passes.
  • Remove this dispatch
    function _link_chol_lkj(W::LowerTriangular)

    and use transpose(W::UpperTriangular)
    .

Related to the second point, right above, in general it would be nice if we could test these analytical formulas for logabsdetjac derived by hand. I played around with it a bit, but couldn't come up with something.
EDIT : This can be done using AD. I see there is something already implemented along these lines in test/transform.jl, just needs some tweaking.

cc @torfjelde if you want to have a look already

to `Matrix` row index
to `_logabsdetjac_inv_corr`
[WIP] LKJ and LKJCholesky bijectors
@harisorgn
Copy link
Member Author

Have another look and If you think all issues were addressed, we can merge.

Actually let's wait, as the AD tests I've just added on the roundtrip transformation fail, will have a look.

@harisorgn
Copy link
Member Author

harisorgn commented Apr 19, 2023

Looked more into test_ad using the roundtrip inverse-and-then-forward transformation :

Test is failing on cholesky(Matrix{ForwardDiff.Dual...}) . I see there is no frule defined for cholesky, not sure if something ForwardDiff specific exists elsewhere.
EDIT: ForwardDiff is not officially using ChainRules to define rules, but also could not find a forward rule for cholesky in ForwardDiffChainRules.

Unsure about the Tracker error. Is the plan to keep supporting Tracker in general?

Zygote is returning nothing gradients, that's its "hard zero" IIUC. Not sure if it has to do with my usage of getproperty on Cholesky, UpperTriangular and LowerTriangular types.

ReverseDiff is passing all tests after fixing the rule for pd_from_upper.

All these tests are for AD through transform which IIUC are not relevant for Turing.jl usage. Since this PR is adding on the previous one that aims to get LKJ priors working for Turing, it might be worth merging and tackling the remaining AD issues in a future PR?

@torfjelde
Copy link
Member

Regarding the test-failures, it's a bit strange.

This seems related: JuliaDiff/ForwardDiff.jl#606.

But I thought this was fixed because they pulled 0.10.33 after the discussion there, and deferred the breaking changes to 0.11. The tests are running on 0.10.35 so I don't get why we're seeing this 😕

Think this needs a bit further inspection.

And we do actually need to AD through the transform in some places, e.g. ADVI.

(Btw, I'm not done with my review, will continue later)

@harisorgn
Copy link
Member Author

harisorgn commented Apr 25, 2023

Good points @torfjelde , thanks! Here's more about the AD issues :

Tracker

From discussions elsewhere (Slack) I understand that we agree to drop support for this.

ForwardDiff

This seems related: JuliaDiff/ForwardDiff.jl#606.

It might actually not be. It seems like a numerical issue when comparing values in ishermitian.

I found two samples from the same LKJ where one passes and one fails. MWE :

using Bijectors, DistributionsAD, LinearAlgebra
using Bijectors: VecCorrBijector
using ForwardDiff
using ForwardDiff: Dual

b = VecCorrBijector('C') # bijector(LKJ(5,1))
binv = inverse(b)

f = x -> sum(b(binv(x)))

# x_f ~ LKJ(5,1)
x_f = [
    1.0  0.38808945715615550398  0.55251148082365042491   0.06333711952583508109  -0.51630779311225594164
    0.38808945715615550398  1.0  0.31760367441586356829   0.34585990227668395036   0.06051504059466897290
    0.55251148082365042491  0.31760367441586356829   1.0   0.17416714618194936715  -0.02825518349677474950
    0.06333711952583508109  0.34585990227668395036   0.17416714618194936715   1.0  -0.07513830680477201485
    -0.51630779311225594164  0.06051504059466897290  -0.02825518349677474950   -0.07513830680477201485   1.0
]
df_f = ForwardDiff.gradient(f, b(x_f)) # Errors, ishermitian returns false


# x_s ~ LKJ(5,1)
x_s = [
    1.0  -0.01569213125090618277 -0.79039374741027101923  -0.03400980954333766848   0.54371128016847525277
    -0.01569213125090618277   1.0 -0.19877390203937703173   -0.37124942960738860354  -0.39209191569764001439
    -0.79039374741027101923  -0.19877390203937703173   1.0   0.03430683023840974677  -0.62744676631878926187
    -0.03400980954333766848  -0.37124942960738860354   0.03430683023840974677   1.0   0.50841756191547016197
    0.54371128016847525277  -0.39209191569764001439    -0.62744676631878926187  0.50841756191547016197   1.0
]
df_s = ForwardDiff.gradient(f, b(x_s)) # Runs, ishermitian returns true

# Let's see where x_f fails
function ish(A::AbstractMatrix)
    # Just a copy of ishermitian with a `@show`
    indsm, indsn = axes(A)
    if indsm != indsn
        return false
    end
    for i = indsn, j = i:last(indsn)
        if A[i,j] != adjoint(A[j,i])
            @show abs(A[i,j] - adjoint(A[j,i]))
            return false
        end
   end
    return true
end

y_f = b(x_f)
ish(binv(Dual.(y_f))) # Returns false, shows abs(A[i, j] - adjoint(A[j, i])) = Dual{Nothing}(2.0816681711721685e-17)

# Without using `Dual`s though, all is good
ish(binv(y_f)) # Returns true

So ishermitian fails because of a very small difference between a single pair of adjoint elements. This is consistent across other samples from LKJ(5,1). Shall we remove the ishermitian check altogether? Not sure how safe that is, but by trying out this uniform LKJ(5,1) over correlation matrices, all I get is tiny errors with Duals like in the example.

EDIT: Tried using cholesky(x; check = false) but the gradients for these problematic samples are way off (1e-1), even if the matrices are not hermitian by very little (1e-17).

Zygote

This indeed has to do with getproperty(::Cholesky, :UL). In ChainRules there is an rrule defined for getproperty(::Cholesky, ::Symbol) that only accounts for the cases of :U and :L. So we have:

using Bijectors, DistributionsAD, LinearAlgebra
using Zygote

dist = LKJ(5, 1)
x = rand(dist)

g = x -> sum(cholesky(x).U)
dg = Zygote.gradient(g, x) # Returns correct gradient

h = x -> sum(cholesky(x).UL)
dh = Zygote.gradient(h, x) # Returns (nothing, )

So Zygote can be fixed by changing

cholesky_factor(X::Cholesky) = X.UL

to X.U , take the potential extra allocation (if uplo === :L) and always work with UpperTriangular downstream. Using PDMats.chol_upper as suggested here results in the same issue by accessing getproperty(::Cholesky, :factors).

Any thoughts on how to handle the ForwardDiff and Zygote cases? I think the Zygote changes are more straightforward unless I'm missing something.

@harisorgn
Copy link
Member Author

I think the Zygote changes are more straightforward unless I'm missing something.

It is for the case of LKJ (changing X.UL to X.U works) but not for LKJCholesky. In the latter case, we construct a Cholesky during the inverse transform, as this is the support of the distribution. I'm guessing the Cholesky constructor needs an rrule for Zygote to work.

@harisorgn
Copy link
Member Author

Zygote is fixed. It was more straightforward than writing new rrules, just passing a X::Matrix instead of X::UpperTriangular or X::LowerTriangular to Cholesky and avoid doing X.data.

ForwardDiff passed twice on the latest commit, but I changed nothing to fix it. Probably has to do with the stochastic nature of the numerical error, like the example above.

@harisorgn
Copy link
Member Author

I restarted the Inference tests multiple times and the ForwardDiff test passes (only fails are from Tracker not being broken). I can't recreate this locally, I still get some fails and passes like the example above, and have matched package versions, so I'm confused 😅

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Almost there! But I think the cholesky-version should just be its own struct so we avoid the type-stabilities.

Otherwise it's looking pretty dank!

And I'll have a look at the ForwardDiff issue.


# Fields
- mode :`Symbol`. Controls the inverse tranformation :
- if `mode === :C` returns a correlation matrix
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? I'm personally happy to just support U or L.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is, make the cholesky version into a separate type, e.g. VecCholCorrBijector. This will avoid the type-instabilities + moves the conditional handling you have in some functions into multiple dispatch instead.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@harisorgn Any updates on this?:)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was completely off last week. Agree with splitting/specialising the structures, will implement it this week!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, no worries! Sweet!

src/bijectors/corr.jl Outdated Show resolved Hide resolved
src/bijectors/corr.jl Outdated Show resolved Hide resolved
Comment on lines 318 to 321
return UpperTriangular(X)' * UpperTriangular(X), Δ -> begin
Xu = UpperTriangular(X)
return ChainRulesCore.NoTangent(), UpperTriangular(Xu * Δ + Xu * Δ')
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, maybe add a rrule test? That would have caught the missing unthunk.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The thing is I was testing the rrules locally as I was adding them and this was passing. Probably unthunking would be needed if it's part of multiple function calls that get differentiated? I am adding it anyway.

src/compat/zygote.jl Outdated Show resolved Hide resolved
src/compat/zygote.jl Outdated Show resolved Hide resolved
test/bijectors/utils.jl Outdated Show resolved Hide resolved
@harisorgn
Copy link
Member Author

harisorgn commented Jun 1, 2023

@torfjelde , I implemented your suggestions, thanks for the feedback again : )

I couldn't locally reproduce the DomainError that comes up in the AD test.

Also disregard my previous confusion about reproducing the ForwardDiff numerical error. I was restarting an interface test that wasn't hitting it, hence it was passing. When the right interface test of CI was run, test failed as it fails locally (see comments above).

So there are still these two errors, plus the stack related one that is addressed in another PR here.

(Apologies for the format, only have phone access for now)

Copy link
Member

@torfjelde torfjelde left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great stuff @harisorgn :) Really close now!

I had a super-quick look, and made some very minor comments + changes. Once those are addressed, I think we should be good go!

Again, awesome work; I imagine this isn't the most fun PR to work on, so appreciate you seeing this through ❤️

src/bijectors/corr.jl Show resolved Hide resolved
@@ -182,7 +188,23 @@ end

upperinds = [LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]]
J = ForwardDiff.jacobian(x->link(dist, x), x)
J = J[upperinds, upperinds]
J = J[:, upperinds]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What was this for again? Sorry, we might have discussed this before.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't think we have : ) . It's because the output of dist is an AbstractVector now, so the indices of upper triagular elements don't apply anymore. In this test, x is 3x3 matrix, link(dist, x) is length 3 vector, the Jacobian is then a 3x9 matrix, and we are keeping all output elements (as they are all relevant now) and only the upperinds of the input elements.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aaah yeah now I remember:) We didn't discuss this but I had a think through it myself the last time I looked at it 😅 Just had a vague memory of at some point being befuddled about it and then figuring it out.

test/transform.jl Show resolved Hide resolved
@harisorgn harisorgn merged commit a3c7f57 into TuringLang:torfjelde/vec-corr Jun 6, 2023
3 of 21 checks passed
@harisorgn
Copy link
Member Author

@torfjelde accidental merge, sorry, was setting up git in a new machine 😅 . Please revert it and I'll implement the last changes.

@torfjelde
Copy link
Member

Is it maybe easier if you just take over the other PR?:) #246

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants