Skip to content

Commit

Permalink
Merge 50b837d into 3a2b1e4
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyuezhuo committed Aug 23, 2020
2 parents 3a2b1e4 + 50b837d commit 310285c
Show file tree
Hide file tree
Showing 10 changed files with 379 additions and 3 deletions.
130 changes: 130 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
    CorrBijector <: Bijector{2}

A bijector implementation of Stan's parametrization method for correlation matrices:
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html

Note (7/30/2020): the "manageable expression" given there is wrong; the expression used
here was derived from scratch.
"""
struct CorrBijector <: Bijector{2} end # singleton type: declares no fields

# Forward transform of a single matrix `x`: take the upper Cholesky factor of `x` and
# pass it through `_link_chol_lkj`.
function (b::CorrBijector)(x::AbstractMatrix{<:Real})
    # The factor is kept in its triangular wrapper (not densified) at this point, which
    # can avoid some computation.
    upper = cholesky(x).U
    # Adding `zero(x)` produces the dense format that a test requires, see
    # https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
    return _link_chol_lkj(upper) + zero(x)
end

# Forward transform of a collection of matrices: applies `b` to each element via `map`.
function (b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(b, X)
end

# Inverse transform of a single matrix `y`: rebuild the factor with
# `_inv_link_chol_lkj` and return the product of its adjoint with itself.
function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
    factor = _inv_link_chol_lkj(y)
    return adjoint(factor) * factor
end
# Inverse transform of a collection of matrices: applies `ib` to each element via `map`.
function (ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(ib, Y)
end


# Log-abs-det-Jacobian of the inverse transform evaluated at the square matrix `y`.
#
# Only the strictly upper-triangular entries of `y` are read. Each entry `y[i, j]`
# contributes `(K - i + 1) * (logtwo - (|y| + log1pexp(-2|y|)))`.
# NOTE(review): assuming `logtwo` is log(2) and `log1pexp(x)` is log(1 + exp(x)), the
# bracketed term equals log(sech(|y|)) — confirm against the definitions in scope.
function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
    K = LinearAlgebra.checksquare(y)

    # Promote to a floating-point accumulator so integer inputs do not truncate.
    acc = float(zero(eltype(y)))
    for col in 2:K
        for row in 1:(col - 1)
            @inbounds magnitude = abs(y[row, col])
            acc += (K - row + 1) * (logtwo - (magnitude + log1pexp(-2 * magnitude)))
        end
    end

    return acc
end
# Log-abs-det-Jacobian of the forward transform at `X`, computed as the negated
# log-abs-det-Jacobian of the inverse evaluated at `b(X)`.
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
    #=
    This has to evaluate `b(X)` first. If the unconstrained value is already available,
    it may be more efficient to call
    `logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
    directly and avoid the extra call of `b`.
    =#
    ib = inv(b)
    y = b(X)
    return -logabsdetjac(ib, y)
end
# Log-abs-det-Jacobian for a collection of matrices: evaluates the single-matrix method
# on each element and combines the results with `mapvcat`.
function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
    return mapvcat(x -> logabsdetjac(b, x), X)
end


# Inverse link: build the upper-triangular factor `w` from the unconstrained square
# matrix `y`.
#
# Only the strictly upper-triangular entries of `y` are read. Column `j` of `w` is
# built by a running recursion: starting from 1, each `tanh(y[i, j])` splits the running
# value into the entry stored at row `i` and the remainder carried to row `i + 1`.
# Entries below the diagonal are set to zero.
function _inv_link_chol_lkj(y)
    K = LinearAlgebra.checksquare(y)

    w = similar(y)

    @inbounds for col in 1:K
        w[1, col] = 1
        for row in 1:(col - 1)
            z = tanh(y[row, col])
            remainder = w[row, col]
            w[row, col] = z * remainder
            w[row + 1, col] = remainder * sqrt(1 - z^2)
        end
        for row in (col + 1):K
            w[row, col] = 0
        end
    end

    return w
end

"""
    _link_chol_lkj(w)

Link function for the Cholesky factor.

An alternative and maybe more efficient implementation was considered:
```
for i=2:K, j=(i+1):K
z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2))
end
```
But this implementation does not work when `w[i-1, j] = 0`.
Although that is a set of measure zero, it means that initialization with the unit
(identity) matrix would not work.

For the equivalence of the two, the following explanation was given by @torfjelde:
For `(i, j)` in the loop below, we define
z₍ᵢ₋₁,ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²))
and so
z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²))
= (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²))
= (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎
= (w₍ᵢ,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎)
= (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²))
which is the alternative implementation above.
"""
function _link_chol_lkj(w)
    K = LinearAlgebra.checksquare(w)

    # `similar` keeps the container type of `w` (an `UpperTriangular` input gives an
    # `UpperTriangular` output) but leaves the contents uninitialized, so every entry
    # must be written below. This matters when `w` is a dense matrix.
    z = similar(w)

    @inbounds for j in 1:K
        # The first column has no strictly-upper entries. It must skip this branch
        # because `w[1, 1] != 0` and must not be passed through `atanh`.
        if j > 1
            z[1, j] = atanh(w[1, j])
            # `tmp` is the running product of sqrt(1 - p^2) over the rows handled so far.
            tmp = sqrt(1 - w[1, j]^2)
            for i in 2:(j - 1)
                p = w[i, j] / tmp
                tmp *= sqrt(1 - p^2)
                z[i, j] = atanh(p)
            end
        end
        # Diagonal and everything below it is zero. Writing zeros into the lower part
        # is also what `_inv_link_chol_lkj` does for its output.
        for i in j:K
            z[i, j] = 0
        end
    end

    return z
end
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ function jacobian(
x::AbstractVector{<:Real}
)
return ForwardDiff.jacobian(b, x)
end
end
4 changes: 3 additions & 1 deletion src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_inv_link_chol_lkj, _link_chol_lkj

using Compat: eachcol
using Distributions: LocationScale
Expand Down Expand Up @@ -180,4 +181,5 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end


end
102 changes: 102 additions & 0 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,105 @@ lower(A::TrackedMatrix) = track(lower, A)
Ad = data(A)
return lower(Ad), Δ -> (lower(Δ),)
end

_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)

# Custom gradient for `_inv_link_chol_lkj`: runs the forward recursion on the untracked
# data while caching the intermediates that the pullback needs.
@grad function _inv_link_chol_lkj(y_tracked)
    y = data(y_tracked)

    K = LinearAlgebra.checksquare(y)

    w = similar(y)

    # Caches for the pullback. Only entries (i, j) with 2 <= i <= j are written, and the
    # pullback reads exactly those entries.
    z_mat = similar(y)
    tmp_mat = similar(y)

    @inbounds for j in 1:K
        w[1, j] = 1
        for i in 2:j
            z = tanh(y[i-1, j])
            tmp = w[i-1, j]

            z_mat[i, j] = z
            tmp_mat[i, j] = tmp

            w[i-1, j] = z * tmp
            w[i, j] = tmp * sqrt(1 - z^2)
        end
        for i in (j+1):K
            w[i, j] = 0
        end
    end

    function pullback_inv_link_chol_lkj(Δw)
        # The loop below runs under `@inbounds` with indices derived from `K`, so the
        # cotangent must be K×K, not merely square.
        n = LinearAlgebra.checksquare(Δw)
        n == K || throw(DimensionMismatch(
            "cotangent has size ($n, $n), expected ($K, $K)"
        ))

        Δy = zero(y)

        @inbounds for j in 1:K
            # Δtmp is the cotangent of the running value `tmp`, propagated backwards
            # through the recursion. The last running value is what ends up in w[j, j].
            Δtmp = Δw[j,j]
            for i in j:-1:2
                Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j]
                Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 # derivative of tanh is 1 / cosh^2
                Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2)
            end
        end

        return (Δy,)
    end

    return w, pullback_inv_link_chol_lkj
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)

# Custom gradient for `_link_chol_lkj`: runs the forward recursion on the untracked data
# while caching the running value `tmp` for the pullback.
@grad function _link_chol_lkj(w_tracked)
    w = data(w_tracked)

    K = LinearAlgebra.checksquare(w)

    # Zero-initialized: the recursion only writes the strictly upper-triangular entries,
    # so the diagonal and the lower triangle must already hold zeros (an uninitialized
    # `similar` buffer would leak garbage into the result for dense `w`).
    z = zero(w)

    tmp_mat = similar(w) # cache for pullback; only entries (i, j) with i < j are used.

    @inbounds for j=2:K
        z[1, j] = atanh(w[1, j])
        tmp = sqrt(1 - w[1, j]^2)
        tmp_mat[1, j] = tmp
        for i in 2:(j - 1)
            p = w[i, j] / tmp
            tmp *= sqrt(1 - p^2)
            tmp_mat[i, j] = tmp
            z[i, j] = atanh(p)
        end
    end

    function pullback_link_chol_lkj(Δz)
        LinearAlgebra.checksquare(Δz)

        # Zero-initialized for the same reason as `z`: the diagonal and lower triangle
        # of `w` do not enter the forward computation, so their cotangent is zero.
        Δw = zero(w)

        for j=2:K
            Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j]
            for i in (j-1):-1:2
                p = w[i, j] / tmp_mat[i-1, j]
                ftmp = sqrt(1 - p^2)
                d_ftmp_p = -p / ftmp
                d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

                # derivative of atanh(p) is 1 / (1 - p^2)
                Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
                Δw[i, j] = Δp / tmp_mat[i-1, j]
                Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
            end
            Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j]
        end

        return (Δw,)
    end

    return z, pullback_link_chol_lkj
end
98 changes: 98 additions & 0 deletions src/compat/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,3 +193,101 @@ end
return replace_diag(log, Y)
end
end

# Custom adjoint for `_inv_link_chol_lkj`: runs the forward recursion while caching the
# intermediates that the pullback needs.
@adjoint function _inv_link_chol_lkj(y)
    K = LinearAlgebra.checksquare(y)

    w = similar(y)

    # Caches for the pullback. Only entries (i, j) with 2 <= i <= j are written, and the
    # pullback reads exactly those entries.
    z_mat = similar(y)
    tmp_mat = similar(y)

    @inbounds for j in 1:K
        w[1, j] = 1
        for i in 2:j
            z = tanh(y[i-1, j])
            tmp = w[i-1, j]

            z_mat[i, j] = z
            tmp_mat[i, j] = tmp

            w[i-1, j] = z * tmp
            w[i, j] = tmp * sqrt(1 - z^2)
        end
        for i in (j+1):K
            w[i, j] = 0
        end
    end

    function pullback_inv_link_chol_lkj(Δw)
        # The loop below runs under `@inbounds` with indices derived from `K`, so the
        # cotangent must be K×K, not merely square.
        n = LinearAlgebra.checksquare(Δw)
        n == K || throw(DimensionMismatch(
            "cotangent has size ($n, $n), expected ($K, $K)"
        ))

        Δy = zero(y)

        @inbounds for j in 1:K
            # Δtmp is the cotangent of the running value `tmp`, propagated backwards
            # through the recursion. The last running value is what ends up in w[j, j].
            Δtmp = Δw[j,j]
            for i in j:-1:2
                Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j]
                Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 # derivative of tanh is 1 / cosh^2
                Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2)
            end
        end

        return (Δy,)
    end

    return w, pullback_inv_link_chol_lkj
end

# Custom adjoint for `_link_chol_lkj`: runs the forward recursion while caching the
# running value `tmp` for the pullback.
@adjoint function _link_chol_lkj(w)
    K = LinearAlgebra.checksquare(w)

    # Zero-initialized: the recursion only writes the strictly upper-triangular entries,
    # so the diagonal and the lower triangle must already hold zeros (an uninitialized
    # `similar` buffer would leak garbage into the result for dense `w`).
    z = zero(w)

    tmp_mat = similar(w) # cache for pullback; only entries (i, j) with i < j are used.

    @inbounds for j=2:K
        z[1, j] = atanh(w[1, j])
        tmp = sqrt(1 - w[1, j]^2)
        tmp_mat[1, j] = tmp
        for i in 2:(j - 1)
            p = w[i, j] / tmp
            tmp *= sqrt(1 - p^2)
            tmp_mat[i, j] = tmp
            z[i, j] = atanh(p)
        end
    end

    function pullback_link_chol_lkj(Δz)
        LinearAlgebra.checksquare(Δz)

        # Zero-initialized for the same reason as `z`: the diagonal and lower triangle
        # of `w` do not enter the forward computation, so their cotangent is zero.
        Δw = zero(w)

        for j=2:K
            Δtmp = zero(eltype(Δz)) # Δtmp_mat[j-1,j]
            for i in (j-1):-1:2
                p = w[i, j] / tmp_mat[i-1, j]
                ftmp = sqrt(1 - p^2)
                d_ftmp_p = -p / ftmp
                d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

                # derivative of atanh(p) is 1 / (1 - p^2)
                Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
                Δw[i, j] = Δp / tmp_mat[i-1, j]
                Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
            end
            Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j]
        end

        return (Δw,)
    end

    return z, pullback_link_chol_lkj
end

1 change: 1 addition & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ include("bijectors/shift.jl")
include("bijectors/permute.jl")
include("bijectors/simplex.jl")
include("bijectors/pd.jl")
include("bijectors/corr.jl")
include("bijectors/truncated.jl")

# Normalizing flow related
Expand Down
1 change: 1 addition & 0 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d)
bijector(d::PDMatDistribution) = PDBijector()
bijector(d::MatrixBeta) = PDBijector()

bijector(d::LKJ) = CorrBijector()

##############################
# Distributions.jl interface #
Expand Down

0 comments on commit 310285c

Please sign in to comment.