Skip to content

Commit

Permalink
Merge 79cbc5e into 3a2b1e4
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyuezhuo committed Aug 21, 2020
2 parents 3a2b1e4 + 79cbc5e commit 1154d2b
Show file tree
Hide file tree
Showing 10 changed files with 503 additions and 3 deletions.
130 changes: 130 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
"""
CorrBijector <: Bijector{2}
A bijector implementation of Stan's parametrization method for Correlation matrix:
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
Note:(7/30/2020) their "manageable expression" is wrong, used expression is derived from
scratch.
"""
struct CorrBijector <: Bijector{2} end

function (b::CorrBijector)(x::AbstractMatrix{<:Real})
w = cholesky(x).U + zero(x) # convert to dense matrix
r = _link_chol_lkj(w)
return r
end

(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)

function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
w = _inv_link_chol_lkj(y)
return w' * w
end
(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y)


function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
K = LinearAlgebra.checksquare(y)

result = float(zero(eltype(y)))
for j in 2:K, i in 1:(j - 1)
@inbounds abs_y_i_j = abs(y[i, j])
result += (K - i + 1) * (logtwo - (abs_y_i_j + log1pexp(-2 * abs_y_i_j)))
end

return result
end
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
#=
It may be more efficient if we can use un-contraint value to prevent call of b
It's recommended to directly call
`logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})`
if possible.
=#
return -logabsdetjac(inv(b), (b(X)))
end
function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
return mapvcat(X) do x
logabsdetjac(b, x)
end
end


function _inv_link_chol_lkj(y)
K = LinearAlgebra.checksquare(y)

z = tanh.(y)
w = similar(z)

@inbounds for j in 1:K
w[1, j] = 1
for i in j+1:K
w[i, j] = 0
end
for i in 2:j
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

@inbounds for j in 2:K
for i in 1:j-1
w[i, j] = w[i, j] * z[i, j]
end
end

return w
end

"""
function _link_chol_lkj(w)
Link function for cholesky factor.
An alternative and maybe more efficient implementation was considered:
```
for i=2:K, j=(i+1):K
z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / sqrt(1 - z[i-1, j]^2))
end
```
But this implementation will not work when w[i-1, j] = 0.
Though it is a zero measure set, unit matrix initialization will not work.
For equivelence, following explanations is given by @torfjelde:
For `(i, j)` in the loop below, we define
z₍ᵢ₋₁, ⱼ₎ = w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² (1 / √(1 - z₍ₖ,ⱼ₎²))
and so
z₍ᵢ,ⱼ₎ = w₍ᵢ,ⱼ₎ * ∏ₖ₌₁ⁱ⁻¹ (1 / √(1 - z₍ₖ,ⱼ₎²))
= (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²))
= (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (w₍ᵢ₋₁,ⱼ₎ * ∏ₖ₌₁ⁱ⁻² 1 / √(1 - z₍ₖ,ⱼ₎²)) / w₍ᵢ₋₁,ⱼ₎
= (w₍ᵢ,ⱼ₎ * / √(1 - z₍ᵢ₋₁,ⱼ₎²)) * (z₍ᵢ₋₁,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎)
= (w₍ᵢ,ⱼ₎ / w₍ᵢ₋₁,ⱼ₎) * (z₍ᵢ₋₁,ⱼ₎ / √(1 - z₍ᵢ₋₁,ⱼ₎²))
which is the above implementation.
"""
function _link_chol_lkj(w)
K = LinearAlgebra.checksquare(w)

z = zero(w)

@inbounds for j=2:K
z[1, j] = w[1, j]
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i,j] = p
end

y = atanh.(z)
return y
end
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ function jacobian(
x::AbstractVector{<:Real}
)
return ForwardDiff.jacobian(b, x)
end
end
108 changes: 107 additions & 1 deletion src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_inv_link_chol_lkj, _link_chol_lkj

using Compat: eachcol
using Distributions: LocationScale
Expand Down Expand Up @@ -180,4 +181,109 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end


_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
y = value(y_tracked)

LinearAlgebra.checksquare(y)

K = size(y, 1)

z = tanh.(y)
w = similar(z)

@inbounds for j in 1:K
w[1, j] = 1
for i in j+1:K
w[i, j] = 0
end
for i in 2:j
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

w1 = copy(w) # cache result

@inbounds for j in 2:K
for i in 1:j-1
w[i, j] = w[i, j] * z[i, j]
end
end

function pullback_inv_link_chol_lkj(Δw)
LinearAlgebra.checksquare(Δw)

Δz = zero(Δw)
Δw1 = zero(Δw)
@inbounds for j=2:K, i=1:j-1
Δw1[i,j] = Δw[i,j] * z[i,j]
Δz[i,j] = Δw[i,j] * w1[i,j]
end
@inbounds for i in 1:K
Δw1[i,i] = Δw[i,i]
end

@inbounds for j=2:K, i=j:-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw1[i-1, j] += Δw1[i, j] * tz
Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
end

Δy = Δz .* (1 ./ cosh.(y).^2)
return (Δy,)
end

return w, pullback_inv_link_chol_lkj
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
w = value(w_tracked)

LinearAlgebra.checksquare(w)

K = size(w, 1)
z = zero(w)

@inbounds for j=2:K
z[1, j] = w[1, j]
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
end
z[i,j] = p
end

y = atanh.(z)

function pull_link_chol_lkj(Δy)
LinearAlgebra.checksquare(Δy)

zt0 = 1 ./ (1 .- z.^2)
zt = sqrt.(zt0)
Δz = Δy .* zt0
Δw = zero(Δy)

@inbounds for j=2:K, i=(j-1):-1:2
pd = prod(zt[1:i-1,j])
Δw[i,j] += Δz[i,j] * pd
for ip in 1:(i-1)
Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
end
end
@inbounds for j=2:K
Δw[1, j] += Δz[1, j]
end

return (Δw,)
end

return y, pull_link_chol_lkj

end

end
103 changes: 103 additions & 0 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,106 @@ lower(A::TrackedMatrix) = track(lower, A)
Ad = data(A)
return lower(Ad), Δ -> (lower(Δ),)
end

_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
y = data(y_tracked)

LinearAlgebra.checksquare(y)

K = size(y, 1)

z = tanh.(y)
w = similar(z)

@inbounds for j in 1:K
w[1, j] = 1
for i in j+1:K
w[i, j] = 0
end
for i in 2:j
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

w1 = copy(w) # cache result

@inbounds for j in 2:K
for i in 1:j-1
w[i, j] = w[i, j] * z[i, j]
end
end

function pullback_inv_link_chol_lkj(Δw)
LinearAlgebra.checksquare(Δw)

Δz = zero(Δw)
Δw1 = zero(Δw)
@inbounds for j=2:K, i=1:j-1
Δw1[i,j] = Δw[i,j] * z[i,j]
Δz[i,j] = Δw[i,j] * w1[i,j]
end
@inbounds for i in 1:K
Δw1[i,i] = Δw[i,i]
end

@inbounds for j=2:K, i=j:-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw1[i-1, j] += Δw1[i, j] * tz
Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
end

Δy = Δz .* (1 ./ cosh.(y).^2)
return (Δy,)
end

return w, pullback_inv_link_chol_lkj
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
w = data(w_tracked)

LinearAlgebra.checksquare(w)

K = size(w, 1)
z = zero(w)

@inbounds for j=2:K
z[1, j] = w[1, j]
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
end
z[i,j] = p
end

y = atanh.(z)

function pullback_link_chol_lkj(Δy)
LinearAlgebra.checksquare(Δy)

zt0 = 1 ./ (1 .- z.^2)
zt = sqrt.(zt0)
Δz = Δy .* zt0
Δw = zero(Δy)

@inbounds for j=2:K, i=(j-1):-1:2
pd = prod(zt[1:i-1,j])
Δw[i,j] += Δz[i,j] * pd
for ip in 1:(i-1)
Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
end
end
@inbounds for j=2:K
Δw[1, j] += Δz[1, j]
end

return (Δw,)
end

return y, pullback_link_chol_lkj
end

0 comments on commit 1154d2b

Please sign in to comment.