Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LKJ bijector #125

Merged
merged 64 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
a8c32a2
add a stan style corr bijector
yiyuezhuo Jul 30, 2020
ba1ab40
implement a version which is only compatible with ForwardDiff AD backend
yiyuezhuo Jul 31, 2020
a98ce04
add AD support except Tracker
yiyuezhuo Aug 2, 2020
74c902b
sync
yiyuezhuo Aug 6, 2020
9c1742c
add tests for interface
yiyuezhuo Aug 6, 2020
a1f8ec7
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1c9b4ea
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
a52bdce
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
ead7e71
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
32cab00
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1859672
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1ef50b0
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
5ecaaee
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
bf41efc
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
52e74e4
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
be3dddf
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
476ca0f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
bfd8d4f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
dd0fade
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
38785f5
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
463acdc
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
5bdea37
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
131c502
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
4bbe656
stash
yiyuezhuo Aug 11, 2020
8d2fc60
pass basic tests for all AD
yiyuezhuo Aug 11, 2020
1120763
clean up
yiyuezhuo Aug 11, 2020
eec4de0
add data only AD unit test
yiyuezhuo Aug 11, 2020
d9e0abf
fix type instability
yiyuezhuo Aug 11, 2020
1f24e65
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
09297e9
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
388fee0
Update src/compat/reversediff.jl
yiyuezhuo Aug 11, 2020
72a16a8
Update src/compat/tracker.jl
yiyuezhuo Aug 11, 2020
d2908d5
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
8eb0686
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
9260c61
remove link_lkj and inv_link_lkj
yiyuezhuo Aug 11, 2020
02a7d7e
switch looping order
yiyuezhuo Aug 11, 2020
d0984a4
rename logabsdetjac_lkj_inv and merge z
yiyuezhuo Aug 11, 2020
10c2aed
fix zeros type instability
yiyuezhuo Aug 11, 2020
fc6af4f
add @inbounds
yiyuezhuo Aug 11, 2020
5ee5826
replace log(1 - tanh(y[...])^2) with numerically stable implementation
yiyuezhuo Aug 13, 2020
917605f
removed logabsdetjac_lkj_inv
yiyuezhuo Aug 13, 2020
32497a1
rename two functions used in corr
yiyuezhuo Aug 14, 2020
c82e57c
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 18, 2020
c13d8a2
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
01250ce
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
de86456
Apply suggestions from code review
yiyuezhuo Aug 21, 2020
5238efe
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
3a01d90
add broken test
yiyuezhuo Aug 21, 2020
c86bcdf
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 21, 2020
79cbc5e
apply some suggestions
yiyuezhuo Aug 21, 2020
01db0f4
avoid some zero filling
yiyuezhuo Aug 21, 2020
8cfc22d
move dense transform as last as possible, add reversediff dense dense…
yiyuezhuo Aug 21, 2020
9c9846b
furthur remove extra zero filling and f(0) leveraging UpperTriangular…
yiyuezhuo Aug 21, 2020
9746499
forward special speedup, drop copy-paste maintain ability
yiyuezhuo Aug 21, 2020
7fb2675
Apply suggestions from code review
yiyuezhuo Aug 22, 2020
bad759b
optimize forwarddiff and remove custom AD for reversediff
yiyuezhuo Aug 22, 2020
7710187
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 22, 2020
5228310
fix suggestion mistake and clean up
yiyuezhuo Aug 22, 2020
ea71cd7
small optimizing and clean up
yiyuezhuo Aug 22, 2020
a577b05
add re-derived analytic gradient
yiyuezhuo Aug 23, 2020
50b837d
simplify code (though make code hard to read)
yiyuezhuo Aug 23, 2020
b04252b
add @bounds
yiyuezhuo Aug 23, 2020
6c525a7
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 24, 2020
7e46a09
bump version number to 0.8.3 and modify test and improve document
yiyuezhuo Aug 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.2"
version = "0.8.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
184 changes: 184 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
"""
CorrBijector <: Bijector{2}

A bijector implementation of Stan's parametrization method for Correlation matrix:
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html

Basically, a unconstrained strictly upper triangular matrix `y` is transformed to
a correlation matrix by following readable but not that efficient form:

```
K = size(y, 1)
z = tanh.(y)

for j=1:K, i=1:K
if i>j
w[i,j] = 0
elseif 1==i==j
w[i,j] = 1
elseif 1<i==j
w[i,j] = prod(sqrt(1 .- z[1:i-1, j].^2))
elseif 1==i<j
w[i,j] = z[i,j]
elseif 1<i<j
w[i,j] = z[i,j] * prod(sqrt(1 .- z[1:i-1, j].^2))
end
end
```

It is easy to see that every column is a unit vector, for example:

```
w3' w3 ==
w[1,3]^2 + w[2,3]^2 + w[3,3]^2 ==
z[1,3]^2 + (z[2,3] * sqrt(1 - z[1,3]^2))^2 + (sqrt(1-z[1,3]^2) * sqrt(1-z[2,3]^2))^2 ==
z[1,3]^2 + z[2,3]^2 * (1-z[1,3]^2) + (1-z[1,3]^2) * (1-z[2,3]^2) ==
z[1,3]^2 + z[2,3]^2 - z[2,3]^2 * z[1,3]^2 + 1 -z[1,3]^2 - z[2,3]^2 + z[1,3]^2 * z[2,3]^2 ==
1
```

And diagonal elements are positive, so `w` is a cholesky factor for a positive matrix.

```
x = w' * w
```

Consider block matrix representation for `x`

```
x = [w1'; w2'; ... wn'] * [w1 w2 ... wn] ==
[w1'w1 w1'w2 ... w1'wn;
w2'w1 w2'w2 ... w2'wn;
...
]
```

The diagonal elements are given by `wk'wk = 1`, thus `x` is a correlation matrix.

Every step is invertible, so this is a bijection(bijector).

Note: The implementation doesn't follow their "manageable expression" directly,
because their equation seems wrong (7/30/2020). Insteadly it follows definition
above the "manageable expression" directly, which is also described in above doc.
"""
struct CorrBijector <: Bijector{2} end

function (b::CorrBijector)(x::AbstractMatrix{<:Real})
w = cholesky(x).U # keep LowerTriangular until here can avoid some computation
r = _link_chol_lkj(w)
return r + zero(x)
# This dense format itself is required by a test, though I can't get the point.
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
end

(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)

function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
w = _inv_link_chol_lkj(y)
return w' * w
end
(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y)


function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
K = LinearAlgebra.checksquare(y)

yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
result = float(zero(eltype(y)))
for j in 2:K, i in 1:(j - 1)
@inbounds abs_y_i_j = abs(y[i, j])
result += (K - i + 1) * (logtwo - (abs_y_i_j + log1pexp(-2 * abs_y_i_j)))
end

return result
end
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
#=
It may be more efficient if we can use un-contraint value to prevent call of b
It's recommended to directly call
`logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
if possible.
=#
return -logabsdetjac(inv(b), (b(X)))
end
function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(X) do x
logabsdetjac(b, x)
end
end


function _inv_link_chol_lkj(y)
K = LinearAlgebra.checksquare(y)

w = similar(y)

@inbounds for j in 1:K
w[1, j] = 1
for i in 2:j
z = tanh(y[i-1, j])
tmp = w[i-1, j]
w[i-1, j] = z * tmp
w[i, j] = tmp * sqrt(1 - z^2)
end
for i in (j+1):K
w[i, j] = 0
end
end

return w
end

"""
function _link_chol_lkj(w)

Link function for cholesky factor.

An alternative and maybe more efficient implementation was considered:

```
for i=2:K, j=(i+1):K
z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2))
end
```

But this implementation will not work when w[i-1, j] = 0.
Though it is a zero measure set, unit matrix initialization will not work.

For equivelence, following explanations is given by @torfjelde:

For `(i, j)` in the loop below, we define

z₍ᵢ₋₁, ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²))

and so

z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²))
= (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²))
= (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎
= (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎)
= (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²))

which is the above implementation.
"""
function _link_chol_lkj(w)
K = LinearAlgebra.checksquare(w)

z = similar(w) # z is also UpperTriangular.
# Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero.

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
z[1, j] = atanh(w[1, j])
tmp = sqrt(1 - w[1, j]^2)
for i in 2:(j - 1)
p = w[i, j] / tmp
tmp *= sqrt(1 - p^2)
z[i, j] = atanh(p)
end
z[j, j] = 0
end

return z
end
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ function jacobian(
x::AbstractVector{<:Real}
)
return ForwardDiff.jacobian(b, x)
end
end
4 changes: 3 additions & 1 deletion src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_inv_link_chol_lkj, _link_chol_lkj

using Compat: eachcol
using Distributions: LocationScale
Expand Down Expand Up @@ -180,4 +181,5 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end


end
102 changes: 102 additions & 0 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,105 @@ lower(A::TrackedMatrix) = track(lower, A)
Ad = data(A)
return lower(Ad), Δ -> (lower(Δ),)
end

_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
y = data(y_tracked)

K = LinearAlgebra.checksquare(y)

w = similar(y)

z_mat = similar(y) # cache for adjoint
tmp_mat = similar(y)

@inbounds for j in 1:K
w[1, j] = 1
for i in 2:j
z = tanh(y[i-1, j])
tmp = w[i-1, j]

z_mat[i, j] = z
tmp_mat[i, j] = tmp

w[i-1, j] = z * tmp
w[i, j] = tmp * sqrt(1 - z^2)
end
for i in (j+1):K
w[i, j] = 0
end
end

function pullback_inv_link_chol_lkj(Δw)
LinearAlgebra.checksquare(Δw)

Δy = zero(y)

@inbounds for j in 1:K
Δtmp = Δw[j,j]
for i in j:-1:2
Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j]
Δy[i-1, j] = Δz / cosh(y[i-1, j])^2
Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2)
end
end

return (Δy,)
end

return w, pullback_inv_link_chol_lkj
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
w = data(w_tracked)

K = LinearAlgebra.checksquare(w)

z = similar(w)

@inbounds z[1, 1] = 0

tmp_mat = similar(w) # cache for pullback.

@inbounds for j=2:K
z[1, j] = atanh(w[1, j])
tmp = sqrt(1 - w[1, j]^2)
tmp_mat[1, j] = tmp
for i in 2:(j - 1)
p = w[i, j] / tmp
tmp *= sqrt(1 - p^2)
tmp_mat[i, j] = tmp
z[i, j] = atanh(p)
end
z[j, j] = 0
end

function pullback_link_chol_lkj(Δz)
LinearAlgebra.checksquare(Δz)

Δw = similar(w)

@inbounds Δw[1,1] = zero(eltype(Δz))

@inbounds for j=2:K
Δw[j, j] = 0
Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j]
for i in (j-1):-1:2
p = w[i, j] / tmp_mat[i-1, j]
ftmp = sqrt(1 - p^2)
d_ftmp_p = -p / ftmp
d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p
Δw[i, j] = Δp / tmp_mat[i-1, j]
Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
end
Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j]
end

return (Δw,)
end

return z, pullback_link_chol_lkj
end