Add LKJ bijector #125

Merged
merged 64 commits into from
Aug 24, 2020
Changes from 42 commits
a8c32a2
add a stan style corr bijector
yiyuezhuo Jul 30, 2020
ba1ab40
implement a version which is only compatible with ForwardDiff AD backend
yiyuezhuo Jul 31, 2020
a98ce04
add AD support except Tracker
yiyuezhuo Aug 2, 2020
74c902b
sync
yiyuezhuo Aug 6, 2020
9c1742c
add tests for interface
yiyuezhuo Aug 6, 2020
a1f8ec7
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1c9b4ea
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
a52bdce
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
ead7e71
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
32cab00
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1859672
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1ef50b0
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
5ecaaee
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
bf41efc
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
52e74e4
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
be3dddf
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
476ca0f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
bfd8d4f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
dd0fade
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
38785f5
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
463acdc
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
5bdea37
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
131c502
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
4bbe656
stash
yiyuezhuo Aug 11, 2020
8d2fc60
pass basic tests for all AD
yiyuezhuo Aug 11, 2020
1120763
clean up
yiyuezhuo Aug 11, 2020
eec4de0
add data only AD unit test
yiyuezhuo Aug 11, 2020
d9e0abf
fix type instability
yiyuezhuo Aug 11, 2020
1f24e65
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
09297e9
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
388fee0
Update src/compat/reversediff.jl
yiyuezhuo Aug 11, 2020
72a16a8
Update src/compat/tracker.jl
yiyuezhuo Aug 11, 2020
d2908d5
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
8eb0686
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
9260c61
remove link_lkj and inv_link_lkj
yiyuezhuo Aug 11, 2020
02a7d7e
switch looping order
yiyuezhuo Aug 11, 2020
d0984a4
rename logabsdetjac_lkj_inv and merge z
yiyuezhuo Aug 11, 2020
10c2aed
fix zeros type instability
yiyuezhuo Aug 11, 2020
fc6af4f
add @inbounds
yiyuezhuo Aug 11, 2020
5ee5826
replace log(1 - tanh(y[...])^2) with numerically stable implementation
yiyuezhuo Aug 13, 2020
917605f
removed logabsdetjac_lkj_inv
yiyuezhuo Aug 13, 2020
32497a1
rename two functions used in corr
yiyuezhuo Aug 14, 2020
c82e57c
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 18, 2020
c13d8a2
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
01250ce
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
de86456
Apply suggestions from code review
yiyuezhuo Aug 21, 2020
5238efe
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
3a01d90
add broken test
yiyuezhuo Aug 21, 2020
c86bcdf
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 21, 2020
79cbc5e
apply some suggestions
yiyuezhuo Aug 21, 2020
01db0f4
avoid some zero filling
yiyuezhuo Aug 21, 2020
8cfc22d
move dense transform as last as possible, add reversediff dense dense…
yiyuezhuo Aug 21, 2020
9c9846b
furthur remove extra zero filling and f(0) leveraging UpperTriangular…
yiyuezhuo Aug 21, 2020
9746499
forward special speedup, drop copy-paste maintain ability
yiyuezhuo Aug 21, 2020
7fb2675
Apply suggestions from code review
yiyuezhuo Aug 22, 2020
bad759b
optimize forwarddiff and remove custom AD for reversediff
yiyuezhuo Aug 22, 2020
7710187
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 22, 2020
5228310
fix suggestion mistake and clean up
yiyuezhuo Aug 22, 2020
ea71cd7
small optimizing and clean up
yiyuezhuo Aug 22, 2020
a577b05
add re-derived analytic gradient
yiyuezhuo Aug 23, 2020
50b837d
simplify code (though make code hard to read)
yiyuezhuo Aug 23, 2020
b04252b
add @bounds
yiyuezhuo Aug 23, 2020
6c525a7
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 24, 2020
7e46a09
bump version number to 0.8.3 and modify test and improve document
yiyuezhuo Aug 24, 2020
118 changes: 118 additions & 0 deletions src/bijectors/corr.jl
@@ -0,0 +1,118 @@
# See stan doc for parametrization method:
# https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
# (7/30/2020) their "manageable expression" is wrong...

struct CorrBijector <: Bijector{2} end

function (b::CorrBijector)(x::AbstractMatrix{<:Real})
    w = cholesky(x).U + zero(x) # convert to dense matrix
Member

Do we need a dense matrix? Seems a bit wasteful. I would have thought that Array(cholesky(x).U) or convert(Array, cholesky(x).U) would work as well.
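
For reference, a minimal sketch comparing the conversions discussed in this comment. It assumes only the standard-library `LinearAlgebra` module; the 2×2 matrix `x` is an arbitrary illustrative input and is not taken from the PR.

```julia
using LinearAlgebra

# Arbitrary positive-definite correlation matrix, for illustration only.
x = [1.0 0.5; 0.5 1.0]

U = cholesky(x).U            # UpperTriangular wrapper around the factor

w_sum   = U + zero(x)        # approach in the diff: add a dense zero matrix
w_array = Array(U)           # reviewer's suggestion: materialize directly
w_conv  = convert(Array, U)  # reviewer's alternative spelling

# All three hold the same entries; the last two skip the extra addition.
same_entries = (w_sum == w_array) && (w_array == w_conv)
```

The direct conversions avoid allocating a zero matrix and adding it element by element, which is the waste the comment points at.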

    r = _link_chol_lkj(w)
    return r
end

(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)

function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
    w = _inv_link_chol_lkj(y)
    return w' * w
end
(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y)


function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
    @assert size(y, 1) == size(y, 2)
    K = size(y, 1)

    left = zero(eltype(y)) # The initial summand may make the loop look weird :(
    @inbounds for j=2:K, i=1:(j-1)
        # left += (K-i-1) * log(1 - tanh(y[i, j])^2) # numerically unstable
        left += (K-i-1) * 2 * (log(2) - log(exp(y[i,j]) + exp(-y[i,j])))
    end

    right = zero(eltype(y))
    @inbounds for j=2:K, i=1:(j-1)
        right += log(cosh(y[i, j])^2)
    end

    return left / 2 - right
end
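
The rewrite of `log(1 - tanh(y)^2)` in the function above relies on the identity 1 − tanh²(y) = sech²(y) = (2 / (eʸ + e⁻ʸ))². The sketch below compares the naive form, the form used in the diff, and a further variant based on `abs` and `log1p`. The function names are hypothetical, and the third variant is an assumption of this note rather than something this diff contains.

```julia
# Three ways to evaluate log(1 - tanh(y)^2) = 2 * log(sech(y)).
naive(y)   = log(1 - tanh(y)^2)
as_diff(y) = 2 * (log(2) - log(exp(y) + exp(-y)))
# Hypothetical variant: factor exp(|y|) out of the sum so nothing overflows.
stable(y)  = 2 * (log(2) - abs(y) - log1p(exp(-2 * abs(y))))

# For moderate inputs all three agree.
for y in (0.0, 0.5, -3.0)
    @assert isapprox(naive(y), as_diff(y); atol = 1e-12)
    @assert isapprox(naive(y), stable(y); atol = 1e-12)
end

# Around |y| = 20, tanh(y)^2 rounds to 1.0 and the naive form returns -Inf,
# while the other two stay finite.
@assert naive(20.0) == -Inf
@assert isfinite(as_diff(20.0)) && isfinite(stable(20.0))

# For very large |y|, exp(y) overflows, so the form in the diff also breaks down.
@assert as_diff(800.0) == -Inf
@assert isfinite(stable(800.0))
```

The form in the diff fixes the cancellation problem in `1 - tanh(y)^2`; the `log1p` variant additionally guards against overflow in `exp`.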
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
    return -logabsdetjac(inv(b), (b(X))) # It may be more efficient if we can use the unconstrained value to avoid calling b
end
logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) = mapvcat(X) do x
    logabsdetjac(b, x)
end


function _inv_link_chol_lkj(y)
    @assert size(y, 1) == size(y, 2)
    K = size(y, 1)

    z = tanh.(y)
    w = similar(z)

    w[1,1] = 1
    @inbounds for j in 1:K
        w[1, j] = 1
    end

    @inbounds for j in 1:K
        for i in j+1:K
            w[i, j] = 0
        end
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
Member

The array z and the loops below and above are not needed; it seems you could just write

Suggested change
-        for i in j+1:K
-            w[i, j] = 0
-        end
-        for i in 2:j
-            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
-        end
+        w[1, j] = 1
+        for i in 2:j
+            u = w[i-1, j]
+            z = tanh(y[i-1, j])
+            w[i-1, j] = u * z
+            w[i, j] = u * sqrt(1 - z^2)
+        end
+        for i in (j + 1):K
+            w[i, j] = 0
+        end
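
A self-contained sketch of the single-loop variant suggested above, written as a standalone function. The name `inv_link_chol_sketch` and the test input `y` are hypothetical; the loop body follows the suggestion, and the check on `w' * w` is added here only for illustration.

```julia
# Hypothetical standalone version of the suggested single-loop inverse link.
# Maps an unconstrained square matrix y to an upper-triangular factor w whose
# columns have unit norm, so that w' * w has a unit diagonal.
function inv_link_chol_sketch(y::AbstractMatrix{<:Real})
    K = size(y, 1)
    @assert K == size(y, 2)
    w = similar(y, float(eltype(y)))
    for j in 1:K
        w[1, j] = 1
        for i in 2:j
            u = w[i-1, j]            # remaining column norm carried downward
            z = tanh(y[i-1, j])      # value in (-1, 1)
            w[i-1, j] = u * z
            w[i, j] = u * sqrt(1 - z^2)
        end
        for i in (j + 1):K
            w[i, j] = 0              # strictly lower triangle
        end
    end
    return w
end

y = [0.0 0.3 -1.2; 0.0 0.0 0.7; 0.0 0.0 0.0]
w = inv_link_chol_sketch(y)
R = w' * w   # candidate correlation matrix
```

Only the strictly upper-triangular entries of `y` are read, so no separate `z` array is needed in this non-AD code path.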

Member

BTW can't we just return a triangular matrix so we don't have to fill the other elements?

Contributor Author

It is not needed by ForwardDiff, but it's not free for other 3 reversediff backends since it can be used to avoid an extra tanh call.

Contributor Author

@yiyuezhuo Aug 21, 2020

There's a test to ensure input and output have same type:

@test typeof(xs) == typeof(ys)

So zero filling is necessary; a triangular matrix just postpones the filling to a later point.

Member

> It is not needed by ForwardDiff, but it's not free for other 3 reversediff backends since it can be used to avoid an extra tanh call.

But it's the standard method here that is used for ForwardDiff, for reverse-mode you defined custom adjoints anyways. And even in the reverse-mode case it probably makes sense to move everything into one loop and fill both z and w there to avoid all the unnecessary computations. Then you can even make z a LowerTriangular which would probably speed up the computations in the backward pass.

Member

> There's a test to ensure input and output have same type:

That seems weird but shouldn't be touched in this PR. @torfjelde Is that a general design choice?

Contributor Author

@yiyuezhuo Aug 21, 2020

I tried again and rediscovered a forgotten problem. ReverseDiff doesn't support arrays with Base.IndexStyle(::A) != Base.IndexLinear() (see limit), but IndexStyle(LowerTriangular(randn(3,3))) == IndexCartesian(). Passing a LowerTriangular throws an "AssertionError: IndexStyle(value) === IndexLinear()" error.
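
The index-style mismatch described in this comment can be reproduced with the standard library alone. The sketch below does not call ReverseDiff; it only shows the `IndexStyle` values involved and that materializing the wrapper restores linear indexing. Variable names are illustrative.

```julia
using LinearAlgebra

A = randn(3, 3)
L = LowerTriangular(A)

style_dense   = IndexStyle(A)   # plain Array: linear indexing
style_wrapped = IndexStyle(L)   # triangular wrapper: Cartesian indexing

# Materializing the wrapper gives back a linearly indexed dense matrix
# with the same entries (zeros above the diagonal).
D = Matrix(L)
style_materialized = IndexStyle(D)
```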

Member

I mean you could still do it differently for ReverseDiff and just not wrap it in a LowerTriangular, I guess? My main point here was that there is no need to define the default method in a Zygote/ReverseDiff/Tracker compatible way since these AD backends will use the custom adjoints anyway. For the default method we should use the most efficient method. And even in the custom adjoints we can use the same optimized implementation logic and just keep z (as LowerTriangular or not, depending on the backend).

Contributor Author

@yiyuezhuo Aug 21, 2020

OK, I converted it to dense only in ReverseDiff. I still need to convert it to dense in the last step anyway to make that test happy.

    end

    @inbounds for j in 2:K
        for i in 1:j-1
            w[i, j] = w[i, j] * z[i, j]
        end
    end

    return w
end

function _link_chol_lkj(w)
    @assert size(w, 1) == size(w, 2)
    K = size(w, 1)

    z = zero(w)

    @inbounds for j=2:K
        z[1, j] = w[1, j]
    end

    #=
    # This implementation will not work when w[i-1, j] = 0.
    # Though it is a zero measure set, unit matrix initialization will not work.

    for i=2:K, j=(i+1):K
        z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2))
    end

    For `(i, j)` in the loop below, we define

        z₍ᵢ₋₁,ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²))

    and so

        z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²))
               = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²))
               = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎
               = (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎)
               = (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²))

    which is the above implementation.
    =#
    @inbounds for j=3:K, i=2:j-1
        p = w[i, j]
        for ip in 1:(i-1)
            p /= sqrt(1-z[ip, j]^2)
        end
        z[i,j] = p
    end

    y = atanh.(z)
    return y
end
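
A round-trip check is a quick way to see that the link and inverse link in this file undo each other. The sketch below re-implements both as standalone functions following the loops in the diff; the names `link_sketch` and `inv_link_sketch` and the 3×3 input `x` are hypothetical and chosen only for illustration.

```julia
using LinearAlgebra

# Sketch of the forward link: upper Cholesky factor w -> unconstrained y.
function link_sketch(w::AbstractMatrix{<:Real})
    K = size(w, 1)
    z = zero(w)
    for j in 2:K
        z[1, j] = w[1, j]
        for i in 2:(j - 1)
            p = w[i, j]
            for ip in 1:(i - 1)
                p /= sqrt(1 - z[ip, j]^2)   # undo the norm already used above row i
            end
            z[i, j] = p
        end
    end
    return atanh.(z)
end

# Sketch of the inverse link: unconstrained y -> upper Cholesky factor w.
function inv_link_sketch(y::AbstractMatrix{<:Real})
    K = size(y, 1)
    z = tanh.(y)
    w = zero(z)
    for j in 1:K
        w[1, j] = 1
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
        for i in 1:(j - 1)
            w[i, j] *= z[i, j]
        end
    end
    return w
end

x = [1.0 0.3 -0.2; 0.3 1.0 0.5; -0.2 0.5 1.0]  # illustrative correlation matrix
w = Matrix(cholesky(x).U)
y = link_sketch(w)
w_back = inv_link_sketch(y)
x_back = w_back' * w_back
```

Only the strictly upper-triangular part of `y` carries information; the diagonal and lower triangle stay at zero, which matches the K(K−1)/2 free parameters of a K×K correlation matrix.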
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
@@ -15,4 +15,4 @@ function jacobian(
    x::AbstractVector{<:Real}
)
    return ForwardDiff.jacobian(b, x)
-end
+end
104 changes: 103 additions & 1 deletion src/compat/reversediff.jl
@@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
-    _simplex_inv_bijector, replace_diag, jacobian, getpd, lower
+    _simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
+    _inv_link_chol_lkj, _link_chol_lkj

using Compat: eachcol
using Distributions: LocationScale
@@ -180,4 +181,105 @@ lower(A::TrackedMatrix) = track(lower, A)
    return lower(Ad), Δ -> (lower(Δ),)
end


_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
Member

The suggestions above also apply to the custom adjoints.

    y = value(y_tracked)

    @assert size(y, 1) == size(y, 2)
    K = size(y, 1)

    z = tanh.(y)
    w = similar(z)

    w[1,1] = 1
    @inbounds for j in 1:K
        w[1, j] = 1
    end

    @inbounds for j in 1:K
        for i in j+1:K
            w[i, j] = 0
        end
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
    end

    w1 = copy(w) # cache result

    @inbounds for j in 2:K
        for i in 1:j-1
            w[i, j] = w[i, j] * z[i, j]
        end
    end

    return w, Δw -> begin
        @assert size(Δw, 1) == size(Δw, 2)
        Δz = zero(Δw)
        Δw1 = zero(Δw)
        @inbounds for j=2:K, i=1:j-1
            Δw1[i,j] = Δw[i,j] * z[i,j]
            Δz[i,j] = Δw[i,j] * w1[i,j]
        end
        @inbounds for i in 1:K
            Δw1[i,i] = Δw[i,i]
        end

        @inbounds for j=2:K, i=j:-1:2
            tz = sqrt(1 - z[i-1, j]^2)
            Δw1[i-1, j] += Δw1[i, j] * tz
            Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
        end

        Δy = Δz .* (1 ./ cosh.(y).^2)
        return (Δy,)
    end
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
    w = value(w_tracked)

    @assert size(w, 1) == size(w, 2)
    K = size(w, 1)
    z = zero(w)

    @inbounds for j=2:K
        z[1, j] = w[1, j]
    end

    @inbounds for j=3:K, i=2:j-1
        p = w[i, j]
        for ip in 1:(i-1)
            p *= 1 / sqrt(1-z[ip, j]^2)
        end
        z[i,j] = p
    end

    y = atanh.(z)

    return y, Δy -> begin
        @assert size(Δy, 1) == size(Δy, 2)
        zt0 = 1 ./ (1 .- z.^2)
        zt = sqrt.(zt0)
        Δz = Δy .* zt0
        Δw = zero(Δy)

        @inbounds for j=2:K, i=(j-1):-1:2
            pd = prod(zt[1:i-1,j])
            Δw[i,j] += Δz[i,j] * pd
            for ip in 1:(i-1)
                Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
            end
        end
        @inbounds for j=2:K
            Δw[1, j] += Δz[1, j]
        end

        return (Δw,)
    end

end

end
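
A hand-written adjoint like the one for `_inv_link_chol_lkj` above can be sanity-checked against finite differences without any AD package. The sketch below copies the forward and backward loops into plain functions and compares the pullback with central differences of the scalar loss `sum(C .* w(y))`. All names here (`invlink_forward`, `invlink_pullback`, `fd_gradient`) and the inputs `y` and `C` are hypothetical, and this is not how the PR's test suite checks gradients.

```julia
# Forward pass of the inverse link, also returning the cached intermediates.
function invlink_forward(y)
    K = size(y, 1)
    z = tanh.(y)
    w1 = zero(z)
    for j in 1:K
        w1[1, j] = 1
        for i in 2:j
            w1[i, j] = w1[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
    end
    w = copy(w1)
    for j in 2:K, i in 1:(j - 1)
        w[i, j] *= z[i, j]
    end
    return w, z, w1
end

# Pullback following the custom adjoint: maps Δw to Δy.
function invlink_pullback(y, Δw)
    K = size(y, 1)
    _, z, w1 = invlink_forward(y)
    Δz = zero(Δw)
    Δw1 = zero(Δw)
    for j in 2:K, i in 1:(j - 1)
        Δw1[i, j] = Δw[i, j] * z[i, j]
        Δz[i, j] = Δw[i, j] * w1[i, j]
    end
    for i in 1:K
        Δw1[i, i] = Δw[i, i]
    end
    for j in 2:K, i in j:-1:2
        tz = sqrt(1 - z[i-1, j]^2)
        Δw1[i-1, j] += Δw1[i, j] * tz
        Δz[i-1, j] -= Δw1[i, j] * w1[i-1, j] * z[i-1, j] / tz
    end
    return Δz ./ cosh.(y) .^ 2   # chain rule through z = tanh(y)
end

# Central finite differences of the scalar loss sum(C .* w(y)).
function fd_gradient(y, C; h = 1e-6)
    g = zero(y)
    for idx in eachindex(y)
        yp = copy(y); yp[idx] += h
        ym = copy(y); ym[idx] -= h
        lp = sum(C .* invlink_forward(yp)[1])
        lm = sum(C .* invlink_forward(ym)[1])
        g[idx] = (lp - lm) / (2h)
    end
    return g
end

y = [0.0 0.4 -0.9; 0.0 0.0 0.6; 0.0 0.0 0.0]
C = [1.0 2.0 3.0; 4.0 5.0 6.0; 7.0 8.0 9.0]
g_adjoint = invlink_pullback(y, C)
g_fd = fd_gradient(y, C)
```

The two gradients should agree up to finite-difference error, and both are zero on the diagonal and lower triangle because the forward pass never reads those entries of `y`.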
100 changes: 100 additions & 0 deletions src/compat/tracker.jl
@@ -440,3 +440,103 @@ lower(A::TrackedMatrix) = track(lower, A)
    Ad = data(A)
    return lower(Ad), Δ -> (lower(Δ),)
end

_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
    y = data(y_tracked)

    @assert size(y, 1) == size(y, 2)
    K = size(y, 1)

    z = tanh.(y)
    w = similar(z)

    w[1,1] = 1
    @inbounds for j in 1:K
        w[1, j] = 1
    end

    @inbounds for j in 1:K
        for i in j+1:K
            w[i, j] = 0
        end
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
    end

    w1 = copy(w) # cache result

    @inbounds for j in 2:K
        for i in 1:j-1
            w[i, j] = w[i, j] * z[i, j]
        end
    end

    return w, Δw -> begin
        @assert size(Δw, 1) == size(Δw, 2)
        Δz = zero(Δw)
        Δw1 = zero(Δw)
        @inbounds for j=2:K, i=1:j-1
            Δw1[i,j] = Δw[i,j] * z[i,j]
            Δz[i,j] = Δw[i,j] * w1[i,j]
        end
        @inbounds for i in 1:K
            Δw1[i,i] = Δw[i,i]
        end

        @inbounds for j=2:K, i=j:-1:2
            tz = sqrt(1 - z[i-1, j]^2)
            Δw1[i-1, j] += Δw1[i, j] * tz
            Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
        end

        Δy = Δz .* (1 ./ cosh.(y).^2)
        return (Δy,)
    end
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
    w = data(w_tracked)

    @assert size(w, 1) == size(w, 2)
    K = size(w, 1)
    z = zero(w)

    @inbounds for j=2:K
        z[1, j] = w[1, j]
    end

    @inbounds for j=3:K, i=2:j-1
        p = w[i, j]
        for ip in 1:(i-1)
            p *= 1 / sqrt(1-z[ip, j]^2)
        end
        z[i,j] = p
    end

    y = atanh.(z)

    return y, Δy -> begin
        @assert size(Δy, 1) == size(Δy, 2)
        zt0 = 1 ./ (1 .- z.^2)
        zt = sqrt.(zt0)
        Δz = Δy .* zt0
        Δw = zero(Δy)

        @inbounds for j=2:K, i=(j-1):-1:2
            pd = prod(zt[1:i-1,j])
            Δw[i,j] += Δz[i,j] * pd
            for ip in 1:(i-1)
                Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
            end
        end
        @inbounds for j=2:K
            Δw[1, j] += Δz[1, j]
        end

        return (Δw,)
    end

end