Add LKJ bijector #125

Merged
merged 64 commits into from
Aug 24, 2020
Changes from 54 commits
a8c32a2
add a stan style corr bijector
yiyuezhuo Jul 30, 2020
ba1ab40
implement a version which is only compatible with ForwardDiff AD backend
yiyuezhuo Jul 31, 2020
a98ce04
add AD support except Tracker
yiyuezhuo Aug 2, 2020
74c902b
sync
yiyuezhuo Aug 6, 2020
9c1742c
add tests for interface
yiyuezhuo Aug 6, 2020
a1f8ec7
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1c9b4ea
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
a52bdce
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
ead7e71
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
32cab00
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1859672
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1ef50b0
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
5ecaaee
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
bf41efc
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
52e74e4
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
be3dddf
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
476ca0f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
bfd8d4f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
dd0fade
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
38785f5
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
463acdc
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
5bdea37
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
131c502
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
4bbe656
stash
yiyuezhuo Aug 11, 2020
8d2fc60
pass basic tests for all AD
yiyuezhuo Aug 11, 2020
1120763
clean up
yiyuezhuo Aug 11, 2020
eec4de0
add data only AD unit test
yiyuezhuo Aug 11, 2020
d9e0abf
fix type instability
yiyuezhuo Aug 11, 2020
1f24e65
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
09297e9
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
388fee0
Update src/compat/reversediff.jl
yiyuezhuo Aug 11, 2020
72a16a8
Update src/compat/tracker.jl
yiyuezhuo Aug 11, 2020
d2908d5
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
8eb0686
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
9260c61
remove link_lkj and inv_link_lkj
yiyuezhuo Aug 11, 2020
02a7d7e
switch looping order
yiyuezhuo Aug 11, 2020
d0984a4
rename logabsdetjac_lkj_inv and merge z
yiyuezhuo Aug 11, 2020
10c2aed
fix zeros type instability
yiyuezhuo Aug 11, 2020
fc6af4f
add @inbounds
yiyuezhuo Aug 11, 2020
5ee5826
replace log(1 - tanh(y[...])^2) with numerically stable implementation
yiyuezhuo Aug 13, 2020
917605f
removed logabsdetjac_lkj_inv
yiyuezhuo Aug 13, 2020
32497a1
rename two functions used in corr
yiyuezhuo Aug 14, 2020
c82e57c
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 18, 2020
c13d8a2
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
01250ce
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
de86456
Apply suggestions from code review
yiyuezhuo Aug 21, 2020
5238efe
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
3a01d90
add broken test
yiyuezhuo Aug 21, 2020
c86bcdf
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 21, 2020
79cbc5e
apply some suggestions
yiyuezhuo Aug 21, 2020
01db0f4
avoid some zero filling
yiyuezhuo Aug 21, 2020
8cfc22d
move dense transform as last as possible, add reversediff dense dense…
yiyuezhuo Aug 21, 2020
9c9846b
furthur remove extra zero filling and f(0) leveraging UpperTriangular…
yiyuezhuo Aug 21, 2020
9746499
forward special speedup, drop copy-paste maintain ability
yiyuezhuo Aug 21, 2020
7fb2675
Apply suggestions from code review
yiyuezhuo Aug 22, 2020
bad759b
optimize forwarddiff and remove custom AD for reversediff
yiyuezhuo Aug 22, 2020
7710187
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 22, 2020
5228310
fix suggestion mistake and clean up
yiyuezhuo Aug 22, 2020
ea71cd7
small optimizing and clean up
yiyuezhuo Aug 22, 2020
a577b05
add re-derived analytic gradient
yiyuezhuo Aug 23, 2020
50b837d
simplify code (though make code hard to read)
yiyuezhuo Aug 23, 2020
b04252b
add @bounds
yiyuezhuo Aug 23, 2020
6c525a7
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 24, 2020
7e46a09
bump version number to 0.8.3 and modify test and improve document
yiyuezhuo Aug 24, 2020
131 changes: 131 additions & 0 deletions src/bijectors/corr.jl
@@ -0,0 +1,131 @@
"""
CorrBijector <: Bijector{2}

A bijector implementation of Stan's parametrization method for correlation matrices:
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html

Note (7/30/2020): their "manageable expression" is wrong; the expression used here is derived
from scratch.
"""
struct CorrBijector <: Bijector{2} end

function (b::CorrBijector)(x::AbstractMatrix{<:Real})
w = cholesky(x).U
r = _link_chol_lkj(w)
return r + zero(x) # Keeping r UpperTriangular until here avoids some computation.
# The dense format itself is required by a test, though I don't see the point of it:
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
Member

# This is a quirk of the current implementation, of which it would be nice to be rid.
already states that we actually don't want this. What happens if you remove the + zero(x)?
This hack is really ugly and unintuitive: if we wanted an Array in the end (which I'm questioning), it would be much more natural to just call Array(r) instead of some weird addition.
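
To make the suggestion concrete, here is a minimal sketch comparing the two ways of getting a dense matrix out of an `UpperTriangular` wrapper. The matrices `r` and `x` are made-up stand-ins, not values from this PR, and the snippet only assumes the `LinearAlgebra` standard library:

```julia
using LinearAlgebra

# Hypothetical stand-in for the triangular result of `_link_chol_lkj`.
r = UpperTriangular([0.0 0.3; 0.0 0.0])
# Hypothetical dense input with the same size as `r`.
x = [1.0 0.3; 0.3 1.0]

# Workaround from the diff: adding a dense zero matrix yields a dense result.
dense_via_add = r + zero(x)

# Explicit conversion: states the intent directly and needs no extra input.
dense_via_array = Array(r)

# Both approaches produce the same entries.
@assert dense_via_add == dense_via_array
```

Both lines give the same values; `Array(r)` just says what is meant and does not depend on `x`.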

Contributor Author

correlation matrix: Test Failed at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:67
Expression: typeof(x) == typeof(y)
Evaluated: Array{Float64,2} == UpperTriangular{Float64,Array{Float64,2}}
Stacktrace:
[1] single_sample_tests(::LKJ{Float64,Int64}) at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:67
[2] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:203
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
[4] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:201
correlation matrix: Test Failed at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:82
Expression: typeof(xs) == typeof(ys)
Evaluated: Array{Array{Float64,2},1} == Array{UpperTriangular{Float64,Array{Float64,2}},1}
Stacktrace:
[1] multi_sample_tests(::LKJ{Float64,Int64}, ::Array{Float64,2}, ::Array{Array{Float64,2},1}, ::Int64) at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:82
[2] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:220
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
[4] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:201
Test Summary: | Pass Fail Total
correlation matrix | 9 2 11
ERROR: LoadError: LoadError: Some tests did not pass: 9 passed, 2 failed, 0 errored, 0 broken.
in expression starting at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:199
in expression starting at /home/yiyuezhuo/.julia/dev/Bijectors/test/runtests.jl:22
ERROR: Package Bijectors errored during testing

Member

@devmotion devmotion Aug 21, 2020

I don't know why this test exists at all. My suggestion would be to just remove (or comment out) the test. But I guess @torfjelde might know why the test exists?

end

(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)

function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
w = _inv_link_chol_lkj(y)
return w' * w
end
(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y)


function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
K = LinearAlgebra.checksquare(y)

result = float(zero(eltype(y)))
for j in 2:K, i in 1:(j - 1)
@inbounds abs_y_i_j = abs(y[i, j])
result += (K - i + 1) * (logtwo - (abs_y_i_j + log1pexp(-2 * abs_y_i_j)))
end

return result
end
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
#=
It may be more efficient to work with the unconstrained value directly and avoid the call to b.
If possible, call
`logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
directly instead.
=#
return -logabsdetjac(inv(b), (b(X)))
end
function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(X) do x
logabsdetjac(b, x)
end
end


function _inv_link_chol_lkj(y)
K = LinearAlgebra.checksquare(y)

w = similar(y)

@inbounds for j in 1:K
w[1, j] = 1
for i in j+1:K
w[i, j] = 0
end
for i in 2:j
z = tanh(y[i-1, j])
w[i, j] = w[i-1, j] * sqrt(1 - z^2)
w[i-1, j] *= z
end
end

return w
end

"""
function _link_chol_lkj(w)

Link function for the Cholesky factor.

An alternative and maybe more efficient implementation was considered:

```
for i=2:K, j=(i+1):K
z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2))
end
```

But this implementation does not work when w[i-1, j] = 0.
Although that set has measure zero, it contains the identity matrix, so initializing with the identity matrix would fail.

For equivalence, the following explanation was given by @torfjelde:

For `(i, j)` in the loop below, we define

z₍ᵢ₋₁, ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²))

and so

z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²))
       = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²))
       = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎
       = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎)
       = (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²))

which is the above implementation.
"""
function _link_chol_lkj(w)
K = LinearAlgebra.checksquare(w)

z = similar(w) # z is also UpperTriangular, except under ReverseDiff.
# Some zero filling can be avoided, though the diagonal still needs to be filled with zeros.

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
z[1, j] = w[1, j]
z[j, j] = 0
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
end

y = atanh.(z)
return y
end
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
@@ -15,4 +15,4 @@ function jacobian(
x::AbstractVector{<:Real}
)
return ForwardDiff.jacobian(b, x)
end
end
118 changes: 117 additions & 1 deletion src/compat/reversediff.jl
@@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_inv_link_chol_lkj, _link_chol_lkj

using Compat: eachcol
using Distributions: LocationScale
@@ -180,4 +181,119 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end


_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
Member

The suggestions above also apply to the custom adjoints.

y = value(y_tracked)

LinearAlgebra.checksquare(y)

K = size(y, 1)

z = tanh.(y)
w = similar(z)

@inbounds for j in 1:K
w[1, j] = 1
for i in j+1:K
w[i, j] = 0
end
for i in 2:j
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

w1 = copy(w) # cache result

@inbounds for j in 2:K
for i in 1:j-1
w[i, j] = w[i, j] * z[i, j]
end
end

function pullback_inv_link_chol_lkj(Δw)
LinearAlgebra.checksquare(Δw)

Δz = zero(Δw)
Δw1 = zero(Δw)
@inbounds for j=2:K, i=1:j-1
Δw1[i,j] = Δw[i,j] * z[i,j]
Δz[i,j] = Δw[i,j] * w1[i,j]
end
@inbounds for i in 1:K
Δw1[i,i] = Δw[i,i]
end

@inbounds for j=2:K, i=j:-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw1[i-1, j] += Δw1[i, j] * tz
Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
end

Δy = Δz .* (1 ./ cosh.(y).^2)
return (Δy,)
end

return w, pullback_inv_link_chol_lkj
end

# zero(w) is used to convert w to a dense matrix, because ReverseDiff doesn't support tracking
# a LowerTriangular matrix (concretely, Base.IndexStyle(::A) != Base.IndexLinear())
# See: http://www.juliadiff.org/ReverseDiff.jl/limits/
_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w + zero(w))
@grad function _link_chol_lkj(w_tracked)
w = value(w_tracked)

LinearAlgebra.checksquare(w)

K = size(w, 1)
z = similar(w)

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds for i=1:K
z[i, 1] = 0
end

@inbounds for j=2:K
z[1, j] = w[1, j]
for i=j:K
z[i, j] = 0
end
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
end

y = atanh.(z)

function pull_link_chol_lkj(Δy)
LinearAlgebra.checksquare(Δy)

zt0 = 1 ./ (1 .- z.^2)
zt = sqrt.(zt0)
Δz = Δy .* zt0
Δw = zero(Δy)

@inbounds for j=2:K, i=(j-1):-1:2
pd = prod(zt[1:i-1,j])
Δw[i,j] += Δz[i,j] * pd
for ip in 1:(i-1)
Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
end
end
@inbounds for j=2:K
Δw[1, j] += Δz[1, j]
end

return (Δw,)
end

return y, pull_link_chol_lkj

end

end
106 changes: 106 additions & 0 deletions src/compat/tracker.jl
@@ -440,3 +440,109 @@ lower(A::TrackedMatrix) = track(lower, A)
Ad = data(A)
return lower(Ad), Δ -> (lower(Δ),)
end

_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
y = data(y_tracked)

LinearAlgebra.checksquare(y)

K = size(y, 1)

z = tanh.(y)
w = similar(z)

@inbounds for j in 1:K
w[1, j] = 1
for i in j+1:K
w[i, j] = 0
end
for i in 2:j
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

w1 = copy(w) # cache result

@inbounds for j in 2:K
for i in 1:j-1
w[i, j] = w[i, j] * z[i, j]
end
end

function pullback_inv_link_chol_lkj(Δw)
LinearAlgebra.checksquare(Δw)

Δz = zero(Δw)
Δw1 = zero(Δw)
@inbounds for j=2:K, i=1:j-1
Δw1[i,j] = Δw[i,j] * z[i,j]
Δz[i,j] = Δw[i,j] * w1[i,j]
end
@inbounds for i in 1:K
Δw1[i,i] = Δw[i,i]
end

@inbounds for j=2:K, i=j:-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw1[i-1, j] += Δw1[i, j] * tz
Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
end

Δy = Δz .* (1 ./ cosh.(y).^2)
return (Δy,)
end

return w, pullback_inv_link_chol_lkj
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
w = data(w_tracked)

LinearAlgebra.checksquare(w)

K = size(w, 1)
z = similar(w)

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
z[1, j] = w[1, j]
z[j, j] = 0
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
end

y = atanh.(z)

function pullback_link_chol_lkj(Δy)
LinearAlgebra.checksquare(Δy)

zt0 = 1 ./ (1 .- z.^2)
zt = sqrt.(zt0)
Δz = Δy .* zt0
Δw = zero(Δy)

@inbounds for j=2:K, i=(j-1):-1:2
pd = prod(zt[1:i-1,j])
Δw[i,j] += Δz[i,j] * pd
for ip in 1:(i-1)
Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
end
end
@inbounds for j=2:K
Δw[1, j] += Δz[1, j]
end

return (Δw,)
end

return y, pullback_link_chol_lkj
end