Add LKJ bijector #125

Merged
merged 64 commits into TuringLang:master on Aug 24, 2020

Conversation

yiyuezhuo
Contributor

Fix #108

Implement the LKJ bijector (transform), following Stan's spec.
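
For context, here is a minimal sketch of the forward link following Stan's construction (my illustration, not necessarily the exact code in this PR): take the upper Cholesky factor w of the correlation matrix, recover the canonical partial correlations column by column, and atanh-transform them.

using LinearAlgebra

# Illustrative only: map the upper Cholesky factor w of a correlation matrix
# to the strictly upper triangular matrix z of atanh-transformed canonical
# partial correlations (the unconstrained representation in Stan's manual).
function link_chol_lkj_sketch(w::AbstractMatrix)
    K = LinearAlgebra.checksquare(w)
    z = zeros(eltype(w), K, K)
    for j in 2:K
        z[1, j] = atanh(w[1, j])
        tmp = sqrt(1 - w[1, j]^2)  # running product of sqrt(1 - z[k, j]^2)
        for i in 2:(j - 1)
            p = w[i, j] / tmp      # canonical partial correlation
            tmp *= sqrt(1 - p^2)
            z[i, j] = atanh(p)
        end
    end
    return z
end

x = [1.0 0.3 0.2; 0.3 1.0 0.1; 0.2 0.1 1.0]
z = link_chol_lkj_sketch(cholesky(x).U)

The inverse reverses this recursion with tanh and fills in the diagonal so that each column of w has unit norm.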

Because of #124, I expect the CI will not pass, though all tests related to this PR should pass.

Implementing the example from #108:

Setup:

using Zygote
using ForwardDiff
using ReverseDiff
using Tracker

using Distributions
using DistributionsAD

using Turing

using Bijectors
using LinearAlgebra
using PDMats
using Serialization

sigma = [1,2,3]
Omega = [1 0.3 0.2;
         0.3 1 0.1;
         0.2 0.1 1]

Sigma = Diagonal(sigma) * Omega * Diagonal(sigma)
N = 100
J = 3
# y = rand(MvNormal(zeros(J), Sigma), N)'
y = deserialize("y.ser") # same data as used for the comparison with CmdStan

Model definition:

@model correlation(J, N, y, Zero0) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    _Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
    # Sigma = PDMat(_Sigma) # PDMat requires extra AD support
    Sigma = _Sigma

    for i in 1:N
        y[i,:] ~ MvNormal(Zero0, Sigma) # sampling distribution of the observations
    end
    return Sigma
end

Since truncated(Cauchy(0., 5.), 0., Inf) is a weakly informative prior, it occasionally draws values large enough to break sampling, so specifying initial values is useful.

model = correlation(J, N, y, zeros(J))
varinfo = Turing.VarInfo(model)
model(varinfo, Turing.SampleFromPrior(), Turing.PriorContext((
    sigma=ones(3),
    Omega=Float64.(collect(I(3))) # unit matrix
)))
init_theta = varinfo[Turing.SampleFromPrior()]
setadbackend(:forwarddiff)
sample(correlation(J, N, y, zeros(J)), HMC(0.01, 5), 100, init_theta = init_theta)

#=
┌ Info: Using passed-in initial variable values
│   init_theta = [1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]
└ @ Turing.Inference /home/yiyuezhuo/.julia/packages/Turing/p8FFd/src/inference/Inference.jl:282
Sampling: 100%|█████████████████████████████████████████| Time: 0:00:00
Chains MCMC chain (100×21×1 Array{Float64,3}):

Iterations        = 1:100
Thinning interval = 1
Chains            = 1
Samples per chain = 100
parameters        = Omega[1,1], Omega[1,2], Omega[1,3], Omega[2,1], Omega[2,2], Omega[2,3], Omega[3,1], Omega[3,2], Omega[3,3], sigma[1], sigma[2], sigma[3]
internals         = acceptance_rate, hamiltonian_energy, hamiltonian_energy_error, is_accept, log_density, lp, n_steps, nom_step_size, step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse       ess      rhat  
      Symbol   Float64   Float64    Float64   Missing   Float64   Float64  
                                                                           
  Omega[1,1]    1.0000    0.0000     0.0000   missing       NaN       NaN  
  Omega[1,2]    0.2779    0.0964     0.0096   missing   10.4459    0.9932  
  Omega[1,3]    0.0572    0.0809     0.0081   missing   14.2873    0.9924  
  Omega[2,1]    0.2779    0.0964     0.0096   missing   10.4459    0.9932  
  Omega[2,2]    1.0000    0.0000     0.0000   missing   87.0968    0.9899  
  Omega[2,3]    0.1385    0.0858     0.0086   missing   13.3596    1.0244  
  Omega[3,1]    0.0572    0.0809     0.0081   missing   14.2873    0.9924  
  Omega[3,2]    0.1385    0.0858     0.0086   missing   13.3596    1.0244  
  Omega[3,3]    1.0000    0.0000     0.0000   missing   75.9334    0.9899  
    sigma[1]    1.0294    0.0934     0.0093   missing   10.0810    0.9911  
    sigma[2]    2.1043    0.1837     0.0184   missing   14.3855    1.0325  
    sigma[3]    2.8150    0.2709     0.0271   missing    9.9845    0.9955  

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%  
      Symbol   Float64   Float64   Float64   Float64   Float64  
                                                                
  Omega[1,1]    1.0000    1.0000    1.0000    1.0000    1.0000  
  Omega[1,2]    0.0563    0.2390    0.2869    0.3380    0.4501  
  Omega[1,3]   -0.0743    0.0020    0.0507    0.1134    0.2157  
  Omega[2,1]    0.0563    0.2390    0.2869    0.3380    0.4501  
  Omega[2,2]    1.0000    1.0000    1.0000    1.0000    1.0000  
  Omega[2,3]   -0.0018    0.0721    0.1273    0.2092    0.2812  
  Omega[3,1]   -0.0743    0.0020    0.0507    0.1134    0.2157  
  Omega[3,2]   -0.0018    0.0721    0.1273    0.2092    0.2812  
  Omega[3,3]    1.0000    1.0000    1.0000    1.0000    1.0000  
    sigma[1]    0.8667    0.9653    1.0302    1.0917    1.1970  
    sigma[2]    1.7321    1.9863    2.0986    2.2369    2.4306  
    sigma[3]    2.3502    2.6356    2.8207    2.9906    3.2872  
=#

While the tests pass for ForwardDiff, ReverseDiff, Zygote, and Tracker, the model above will not work with the Tracker backend.

The NUTS sampler is not very robust for this model: errors such as a non-positive-definite matrix break the whole sampling run, rather than just rejecting the current step as Stan does. So the following adaptation is proposed, based on TuringLang/Turing.jl#702.

@model correlation(J, N, y, Zero) = begin
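    # Peek at the proposed Omega in the sampler's VarInfo and manually reject
    # the step (accumulate -Inf log density) if the implied correlation matrix
    # is not positive definite, instead of letting the error abort sampling.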
    
    if :Omega in fieldnames(typeof(_varinfo.metadata))
        if :vals in fieldnames(typeof(_varinfo.metadata.Omega))
            if :value in fieldnames(typeof(_varinfo.metadata.Omega.vals[1]))
                if det(Bijectors.inv_link_lkj(reshape(map(x->x.value, _varinfo.metadata.Omega.vals), 3, 3)))<=0
                    println("reject 1!")
                    Turing.acclogp!(_varinfo, -Inf)
                    return
                end
            end
        end
    end
    
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    _Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
    if !isposdef(_Sigma)
        println("reject 2!")
        Turing.acclogp!(_varinfo, -Inf)
        return
    end
    Sigma = PDMat(_Sigma)

    for i in 1:N
        y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
    end
    return Sigma
end

The code above only works with the ForwardDiff and ReverseDiff backends.

@yebai yebai requested a review from torfjelde August 7, 2020 21:49
Member

@torfjelde torfjelde left a comment

This is really great work @yiyuezhuo ! Thank you so much!

I've added quite a few comments, mostly suggested style-changes, but also some questions so I can understand what's going on.

I haven't looked at correctness of impl; I'll do that next 👍

# https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
# (7/30/2020) their "manageable expression" is wrong...

function upper1(AT, A)
Member

Can we make use of UpperTriangular (which uses a view) instead of copying the matrix?
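
For illustration (my example, not code from this PR), UpperTriangular wraps the parent array without copying:

using LinearAlgebra

A = rand(3, 3)
U = UpperTriangular(A)  # lazy wrapper around A; no copy is made
parent(U) === A         # true
A[1, 2] = 42.0
U[1, 2] == 42.0         # true: mutations of A are visible through U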

yiyuezhuo and others added 18 commits August 9, 2020 08:25
Co-authored-by: Tor Erlend Fjelde <tor.github@gmail.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
@yiyuezhuo
Contributor Author

Something seems broken 🤔. I will push the remaining modifications and add an IJulia gist to demonstrate correctness later.

Comment on lines 15 to 17
return r + zero(x) # keeping LowerTriangular until here avoids some computation
# The dense format itself is required by a test, though I don't see the point.
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
Member

The comment

# This is a quirk of the current implementation, of which it would be nice to be rid.

already states that we actually don't want this? What happens if you remove the + zero(x)?
This hack is really, really ugly and unintuitive. E.g., if we do want an Array in the end (and I'm questioning that), it would be much more natural to just call Array(r) instead of some weird addition.
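
A quick check (illustrative only) that Array(r) materializes the same dense matrix as the + zero(x) trick:

using LinearAlgebra

r = UpperTriangular([1.0 0.5; 0.0 1.0])
x = Matrix{Float64}(I, 2, 2)  # stand-in for the bijector input
Array(r) == r + zero(x)       # true: both produce a dense Matrix{Float64}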

Contributor Author

correlation matrix: Test Failed at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:67
Expression: typeof(x) == typeof(y)
Evaluated: Array{Float64,2} == UpperTriangular{Float64,Array{Float64,2}}
Stacktrace:
[1] single_sample_tests(::LKJ{Float64,Int64}) at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:67
[2] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:203
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
[4] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:201
correlation matrix: Test Failed at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:82
Expression: typeof(xs) == typeof(ys)
Evaluated: Array{Array{Float64,2},1} == Array{UpperTriangular{Float64,Array{Float64,2}},1}
Stacktrace:
[1] multi_sample_tests(::LKJ{Float64,Int64}, ::Array{Float64,2}, ::Array{Array{Float64,2},1}, ::Int64) at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:82
[2] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:220
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
[4] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:201
Test Summary: | Pass Fail Total
correlation matrix | 9 2 11
ERROR: LoadError: LoadError: Some tests did not pass: 9 passed, 2 failed, 0 errored, 0 broken.
in expression starting at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:199
in expression starting at /home/yiyuezhuo/.julia/dev/Bijectors/test/runtests.jl:22
ERROR: Package Bijectors errored during testing

Member

@devmotion devmotion Aug 21, 2020

I don't know why this test exists at all. My suggestion would be to just remove (or comment out) the test. But I guess @torfjelde might know why the test exists?

@yiyuezhuo
Contributor Author

I can see some benefits to requiring a dense matrix now. Type-based AD backends such as ReverseDiff and Tracker face the problem of where to place their TrackedArray tag. It's not obvious, as can be seen by inserting a @show into the following code to inspect what type is used:

function (b::CorrBijector)(x::AbstractMatrix{<:Real})    
    w = cholesky(x).U
    @show typeof(w)
    r = _link_chol_lkj(w) 

ForwardDiff:

typeof(w) = UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{typeof(b_f),Float64},Float64,9},Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(b_f),Float64},Float64,9},2}}

Zygote:

typeof(w) = UpperTriangular{Float64,Array{Float64,2}}

ReverseDiff:

typeof(w) = UpperTriangular{ReverseDiff.TrackedReal{Float64,Float64,ReverseDiff.TrackedArray{Float64,Float64,2,Array{Float64,2},Array{Float64,2}}},ReverseDiff.TrackedArray{Float64,Float64,2,Array{Float64,2},Array{Float64,2}}}

Tracker:

typeof(w) = TrackedArray{…,UpperTriangular{Float64,Array{Float64,2}}}

So ReverseDiff and Tracker nest the wrappers in different orders.

In fact, ReverseDiff's type cannot be matched by ::TrackedMatrix, so a hidden bug in the current code goes undetected. The complexity of the UpperTriangular and TrackedArray (or Adjoint, for transposes) nesting order can easily be removed by converting to a dense matrix with + zero(x)-style code.
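
The dispatch pitfall can be illustrated with plain types (an analogy of mine, not the PR's code): a method written for the bare array type silently stops firing once the array is wrapped.

using LinearAlgebra

f(::Matrix{Float64}) = "dense method"
f(::AbstractMatrix) = "generic fallback"

A = rand(3, 3)
f(A)                   # "dense method"
f(UpperTriangular(A))  # "generic fallback": the wrapper changes the dispatch
                       # type, just as ReverseDiff's UpperTriangular{TrackedReal, ...}
                       # fails to match ::TrackedMatrix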

@yiyuezhuo
Contributor Author

While + zero(x) is an ugly hack, is the PDBijector approach any better?

# pd.jl
(b::PDBijector)(X::AbstractMatrix{<:Real}) = pd_link(X)
function pd_link(X)
    Y = lower(parent(cholesky(X; check = true).L))
    return replace_diag(log, Y)
end

lower(A::AbstractMatrix) = convert(typeof(A), LowerTriangular(A))

# zygote.jl
@adjoint function pd_link(X::AbstractMatrix{<:Real})
    return pullback(X) do X
        Y = cholesky(X; check = true).L
        return replace_diag(log, Y)
    end
end

@adjoint function lower(A::AbstractMatrix)
    return lower(A), Δ -> (lower(Δ),)
end

# tracker.jl
lower(A::TrackedMatrix) = track(lower, A)
@grad function lower(A::AbstractMatrix)
    Ad = data(A)
    return lower(Ad), Δ -> (lower(Δ),)
end

# reversediff.jl
lower(A::TrackedMatrix) = track(lower, A)
@grad function lower(A::AbstractMatrix)
    Ad = value(A)
    return lower(Ad), Δ -> (lower(Δ),)
end

@devmotion
Member

> Type-based AD backends such as ReverseDiff and Tracker face the problem of where to place their TrackedArray tag. It's not obvious, as can be seen by inserting a @show into the following code to inspect what type is used:

I don't think it should matter to us (apart from dispatching, but do we actually dispatch on these types?) how the AD backends handle the tags.

Member

@devmotion devmotion left a comment

I have just two minor comments left, IMO we can merge this if they are addressed and tests pass. Could you also bump the version number to 0.8.3 so that we can make a new release (after we fixed the Distribution bounds)?

Great work and thanks a lot for this implementation (and your patience with all of my comments 😄)!

@yiyuezhuo
Contributor Author

> I have just two minor comments left, IMO we can merge this if they are addressed and tests pass. Could you also bump the version number to 0.8.3 so that we can make a new release (after we fixed the Distribution bounds)?
>
> Great work and thanks a lot for this implementation (and your patience with all of my comments 😄)!

@devmotion I appreciate your comments and suggestions very much and have learned many Julia best practices and numerical tricks from them. 🎉

@devmotion devmotion merged commit a3ece91 into TuringLang:master Aug 24, 2020
@devmotion
Member

Thanks again @yiyuezhuo!
