Skip to content

Commit

Permalink
Merge 01250ce into 3a2b1e4
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyuezhuo committed Aug 21, 2020
2 parents 3a2b1e4 + 01250ce commit b79b0df
Show file tree
Hide file tree
Showing 10 changed files with 482 additions and 3 deletions.
119 changes: 119 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
# See stan doc for parametrization method:
# https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
# Note (7/30/2020): the "manageable expression" given on that page is wrong, so it is not used here.

# Field-less bijector declared as a `Bijector{2}`. Its methods in this file take
# the upper Cholesky factor of the input matrix and map it to an unconstrained
# matrix through `_link_chol_lkj`; the inverse rebuilds `w' * w` from
# `_inv_link_chol_lkj`.
struct CorrBijector <: Bijector{2} end

# Forward transform of a single matrix `x`: take its upper Cholesky factor and
# pass it through `_link_chol_lkj`.
function (b::CorrBijector)(x::AbstractMatrix{<:Real})
    # Adding `zero(x)` converts the triangular factor to a dense matrix.
    upper = cholesky(x).U + zero(x)
    return _link_chol_lkj(upper)
end

# Apply the forward transform to every matrix in the collection `X`.
function (b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(b, X)
end

# Inverse transform of a single matrix `y`: build the factor `w` with
# `_inv_link_chol_lkj` and return `w' * w`.
function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})
    factor = _inv_link_chol_lkj(y)
    return adjoint(factor) * factor
end
# Apply the inverse transform to every matrix in the collection `Y`.
function (ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}})
    return map(ib, Y)
end


# Log absolute determinant of the Jacobian of the inverse transform at `y`.
#
# Only the strictly upper-triangular entries `y[i, j]` (i < j) contribute.
# Per entry the quantity is
#     (K - i - 1) / 2 * log(1 - tanh(y)^2) - log(cosh(y)^2)
# and since `log(1 - tanh(y)^2) == -2 * log(cosh(y))` this collapses to
#     -(K - i + 1) * log(cosh(y)).
#
# `log(cosh(y))` is evaluated as `|y| + log1p(exp(-2|y|)) - log(2)`. The direct
# forms `exp(y) + exp(-y)` and `cosh(y)^2` overflow to `Inf` for large `|y|`
# and would make the result `-Inf`.
#
# Throws `DimensionMismatch` (from `LinearAlgebra.checksquare`) when `y` is not
# square.
function logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
    K = LinearAlgebra.checksquare(y)

    result = zero(eltype(y))
    @inbounds for j in 2:K, i in 1:(j - 1)
        a = abs(y[i, j])
        # log(cosh(y[i, j])), using only exponentials of non-positive arguments.
        logcosh = a + log1p(exp(-2 * a)) - log(2)
        result -= (K - i + 1) * logcosh
    end

    return result
end
# Log absolute determinant of the Jacobian of the forward transform at `X`,
# obtained as the negated value for the inverse transform at `b(X)`.
# It may be more efficient to work from the unconstrained value directly and
# avoid the call to `b`.
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
    y = b(X)
    return -logabsdetjac(inv(b), y)
end
# Evaluate `logabsdetjac(b, x)` for every matrix `x` in `X`, combining the
# results with `mapvcat`.
function logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}})
    return mapvcat(x -> logabsdetjac(b, x), X)
end


# Map a square matrix `y` to an upper-triangular matrix `w` whose columns have
# unit Euclidean norm.
#
# With `z = tanh.(y)`, column `j` of `w` is built as
#     w[i, j] = z[i, j] * ∏ₖ₌₁ⁱ⁻¹ √(1 - z[k, j]²)   for i < j
#     w[j, j] =           ∏ₖ₌₁ʲ⁻¹ √(1 - z[k, j]²)
#     w[i, j] = 0                                    for i > j
# Only the strictly upper-triangular entries of `y` influence the result.
#
# Returns a new matrix shaped like `tanh.(y)`. A 0×0 input yields a 0×0 output.
# Throws `DimensionMismatch` (from `LinearAlgebra.checksquare`) when `y` is not
# square.
function _inv_link_chol_lkj(y)
    K = LinearAlgebra.checksquare(y)

    z = tanh.(y)
    w = similar(z)

    @inbounds for j in 1:K
        # Running products ∏ₖ₌₁ⁱ⁻¹ √(1 - z[k, j]²) down the column.
        w[1, j] = 1
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
        # Scale the entries above the diagonal only after the running products
        # are complete, because the recursion above reads the unscaled values.
        for i in 1:(j-1)
            w[i, j] *= z[i, j]
        end
        # Below the diagonal the result is zero.
        for i in (j+1):K
            w[i, j] = 0
        end
    end

    return w
end

# Map a square matrix `w` to `y = atanh.(z)`, where, for the strictly
# upper-triangular entries,
#     z[1, j] = w[1, j]
#     z[i, j] = w[i, j] / ∏ₖ₌₁ⁱ⁻¹ √(1 - z[k, j]²)   for 1 < i < j
# and `z` is zero on and below the diagonal (so `y` is zero there too).
#
# Algebraically the recursion could also be written as
#     z[i, j] = (w[i, j] / w[i-1, j]) * (z[i-1, j] / √(1 - z[i-1, j]²)),
# because z[i-1, j] = w[i-1, j] / ∏ₖ₌₁ⁱ⁻² √(1 - z[k, j]²). That form divides by
# w[i-1, j] and therefore breaks whenever w[i-1, j] == 0 (e.g. for an identity
# matrix), so it is not used. Instead the denominator is carried as a running
# product down each column, which needs no division by `w` and costs O(K²) in
# total.
#
# Throws `DimensionMismatch` (from `LinearAlgebra.checksquare`) when `w` is not
# square.
function _link_chol_lkj(w)
    K = LinearAlgebra.checksquare(w)

    z = zero(w)

    @inbounds for j in 2:K
        z[1, j] = w[1, j]
        # denom == ∏ₖ₌₁ⁱ⁻¹ √(1 - z[k, j]²) at the point where row `i` is computed.
        denom = one(eltype(z))
        for i in 2:(j-1)
            denom *= sqrt(1 - z[i-1, j]^2)
            z[i, j] = w[i, j] / denom
        end
    end

    return atanh.(z)
end
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ function jacobian(
x::AbstractVector{<:Real}
)
return ForwardDiff.jacobian(b, x)
end
end
104 changes: 103 additions & 1 deletion src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ using ..Bijectors: Log, SimplexBijector, maphcat, simplex_link_jacobian,
simplex_invlink_jacobian, simplex_logabsdetjac_gradient, ADBijector,
ReverseDiffAD, Inverse
import ..Bijectors: _eps, logabsdetjac, _logabsdetjac_scale, _simplex_bijector,
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower
_simplex_inv_bijector, replace_diag, jacobian, getpd, lower,
_inv_link_chol_lkj, _link_chol_lkj

using Compat: eachcol
using Distributions: LocationScale
Expand Down Expand Up @@ -180,4 +181,105 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end


# Custom reverse-mode rule for `_inv_link_chol_lkj` on tracked matrices.
#
# The forward pass repeats the primal computation on the untracked value and
# keeps `w1`, the running products ∏ₖ₌₁ⁱ⁻¹ √(1 - z[k, j]²) before they are
# scaled by `z`, for use in the pullback. A 0×0 input yields a 0×0 output.
# Throws `DimensionMismatch` when `y` is not square, or when the cotangent
# passed to the pullback does not have the shape of the output.
_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
    y = value(y_tracked)

    size(y, 1) == size(y, 2) ||
        throw(DimensionMismatch("matrix is not square: dimensions are $(size(y))"))
    K = size(y, 1)

    z = tanh.(y)
    w = similar(z)

    @inbounds for j in 1:K
        w[1, j] = 1
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
        for i in (j+1):K
            w[i, j] = 0
        end
    end

    w1 = copy(w) # cache the unscaled running products for the pullback

    @inbounds for j in 2:K, i in 1:(j-1)
        w[i, j] *= z[i, j]
    end

    return w, Δw -> begin
        # The loops below index `Δw` with the primal `K` under `@inbounds`, so
        # the shapes must agree.
        size(Δw) == size(w) ||
            throw(DimensionMismatch("cotangent has size $(size(Δw)), expected $(size(w))"))
        Δz = zero(Δw)
        Δw1 = zero(Δw)

        # w[i, j] = w1[i, j] * z[i, j] above the diagonal; w[j, j] = w1[j, j].
        @inbounds for j in 2:K, i in 1:(j-1)
            Δw1[i, j] = Δw[i, j] * z[i, j]
            Δz[i, j] = Δw[i, j] * w1[i, j]
        end
        @inbounds for i in 1:K
            Δw1[i, i] = Δw[i, i]
        end

        # Reverse of w1[i, j] = w1[i-1, j] * √(1 - z[i-1, j]²), walking each
        # column upwards so Δw1[i, j] is complete before it is propagated.
        @inbounds for j in 2:K, i in j:-1:2
            tz = sqrt(1 - z[i-1, j]^2)
            Δw1[i-1, j] += Δw1[i, j] * tz
            # d/dz √(1 - z²) = -z / √(1 - z²)
            Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] / tz * (-z[i-1, j])
        end

        # z = tanh(y), so dz/dy = 1 / cosh(y)².
        Δy = Δz .* (1 ./ cosh.(y).^2)
        return (Δy,)
    end
end

# Custom reverse-mode rule for `_link_chol_lkj` on tracked matrices.
#
# Forward pass (on the untracked value), for the strictly upper triangle:
#     z[1, j] = w[1, j]
#     z[i, j] = w[i, j] * ∏ₖ₌₁ⁱ⁻¹ 1 / √(1 - z[k, j]²)   for 1 < i < j
#     y = atanh.(z)
#
# Pullback: z[i, j] depends on w[i, j] directly and on every earlier
# z[ip, j] (ip < i) of the same column, with
#     ∂z[i, j] / ∂w[i, j]  = ∏ₖ₌₁ⁱ⁻¹ 1 / √(1 - z[k, j]²)
#     ∂z[i, j] / ∂z[ip, j] = z[i, j] * z[ip, j] / (1 - z[ip, j]²)
# The second kind of contribution must be accumulated into the adjoint of
# z[ip, j] (not straight into Δw[ip, j]), so that it is later pushed through
# ∂z[ip, j] / ∂w[ip, j] and through the rows above it. Each column is walked
# from the bottom up so the adjoint of z[i, j] is complete when row `i` is
# processed.
#
# Throws `DimensionMismatch` when `w` is not square, or when the cotangent
# passed to the pullback does not have the shape of the output.
_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
    w = value(w_tracked)

    size(w, 1) == size(w, 2) ||
        throw(DimensionMismatch("matrix is not square: dimensions are $(size(w))"))
    K = size(w, 1)
    z = zero(w)

    @inbounds for j in 2:K
        z[1, j] = w[1, j]
    end

    @inbounds for j in 3:K, i in 2:(j-1)
        p = w[i, j]
        for ip in 1:(i-1)
            p *= 1 / sqrt(1 - z[ip, j]^2)
        end
        z[i, j] = p
    end

    y = atanh.(z)

    return y, Δy -> begin
        # The loops below index with the primal `K` under `@inbounds`, so the
        # shapes must agree.
        size(Δy) == size(y) ||
            throw(DimensionMismatch("cotangent has size $(size(Δy)), expected $(size(y))"))
        zt0 = 1 ./ (1 .- z.^2)   # 1 / (1 - z²), also d atanh(z) / dz
        zt = sqrt.(zt0)          # 1 / √(1 - z²)
        Δz = Δy .* zt0           # fresh array; used as the adjoint accumulator for z
        Δw = zero(Δy)

        @inbounds for j in 2:K
            for i in (j-1):-1:2
                pd = prod(view(zt, 1:(i-1), j))
                Δw[i, j] = Δz[i, j] * pd
                for ip in 1:(i-1)
                    Δz[ip, j] += Δz[i, j] * w[i, j] * pd * zt0[ip, j] * z[ip, j]
                end
            end
            # z[1, j] = w[1, j], so its adjoint passes through unchanged.
            Δw[1, j] = Δz[1, j]
        end

        return (Δw,)
    end

end

end
100 changes: 100 additions & 0 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,103 @@ lower(A::TrackedMatrix) = track(lower, A)
Ad = data(A)
return lower(Ad), Δ -> (lower(Δ),)
end

# Custom reverse-mode rule for `_inv_link_chol_lkj` on tracked matrices.
#
# The forward pass repeats the primal computation on the underlying data and
# keeps `w1`, the running products ∏ₖ₌₁ⁱ⁻¹ √(1 - z[k, j]²) before they are
# scaled by `z`, for use in the pullback. A 0×0 input yields a 0×0 output.
# Throws `DimensionMismatch` when `y` is not square, or when the cotangent
# passed to the pullback does not have the shape of the output.
_inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@grad function _inv_link_chol_lkj(y_tracked)
    y = data(y_tracked)

    size(y, 1) == size(y, 2) ||
        throw(DimensionMismatch("matrix is not square: dimensions are $(size(y))"))
    K = size(y, 1)

    z = tanh.(y)
    w = similar(z)

    @inbounds for j in 1:K
        w[1, j] = 1
        for i in 2:j
            w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
        end
        for i in (j+1):K
            w[i, j] = 0
        end
    end

    w1 = copy(w) # cache the unscaled running products for the pullback

    @inbounds for j in 2:K, i in 1:(j-1)
        w[i, j] *= z[i, j]
    end

    return w, Δw -> begin
        # The loops below index `Δw` with the primal `K` under `@inbounds`, so
        # the shapes must agree.
        size(Δw) == size(w) ||
            throw(DimensionMismatch("cotangent has size $(size(Δw)), expected $(size(w))"))
        Δz = zero(Δw)
        Δw1 = zero(Δw)

        # w[i, j] = w1[i, j] * z[i, j] above the diagonal; w[j, j] = w1[j, j].
        @inbounds for j in 2:K, i in 1:(j-1)
            Δw1[i, j] = Δw[i, j] * z[i, j]
            Δz[i, j] = Δw[i, j] * w1[i, j]
        end
        @inbounds for i in 1:K
            Δw1[i, i] = Δw[i, i]
        end

        # Reverse of w1[i, j] = w1[i-1, j] * √(1 - z[i-1, j]²), walking each
        # column upwards so Δw1[i, j] is complete before it is propagated.
        @inbounds for j in 2:K, i in j:-1:2
            tz = sqrt(1 - z[i-1, j]^2)
            Δw1[i-1, j] += Δw1[i, j] * tz
            # d/dz √(1 - z²) = -z / √(1 - z²)
            Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] / tz * (-z[i-1, j])
        end

        # z = tanh(y), so dz/dy = 1 / cosh(y)².
        Δy = Δz .* (1 ./ cosh.(y).^2)
        return (Δy,)
    end
end

# Custom reverse-mode rule for `_link_chol_lkj` on tracked matrices.
#
# Forward pass (on the underlying data), for the strictly upper triangle:
#     z[1, j] = w[1, j]
#     z[i, j] = w[i, j] * ∏ₖ₌₁ⁱ⁻¹ 1 / √(1 - z[k, j]²)   for 1 < i < j
#     y = atanh.(z)
#
# Pullback: z[i, j] depends on w[i, j] directly and on every earlier
# z[ip, j] (ip < i) of the same column, with
#     ∂z[i, j] / ∂w[i, j]  = ∏ₖ₌₁ⁱ⁻¹ 1 / √(1 - z[k, j]²)
#     ∂z[i, j] / ∂z[ip, j] = z[i, j] * z[ip, j] / (1 - z[ip, j]²)
# The second kind of contribution must be accumulated into the adjoint of
# z[ip, j] (not straight into Δw[ip, j]), so that it is later pushed through
# ∂z[ip, j] / ∂w[ip, j] and through the rows above it. Each column is walked
# from the bottom up so the adjoint of z[i, j] is complete when row `i` is
# processed.
#
# Throws `DimensionMismatch` when `w` is not square, or when the cotangent
# passed to the pullback does not have the shape of the output.
_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
@grad function _link_chol_lkj(w_tracked)
    w = data(w_tracked)

    size(w, 1) == size(w, 2) ||
        throw(DimensionMismatch("matrix is not square: dimensions are $(size(w))"))
    K = size(w, 1)
    z = zero(w)

    @inbounds for j in 2:K
        z[1, j] = w[1, j]
    end

    @inbounds for j in 3:K, i in 2:(j-1)
        p = w[i, j]
        for ip in 1:(i-1)
            p *= 1 / sqrt(1 - z[ip, j]^2)
        end
        z[i, j] = p
    end

    y = atanh.(z)

    return y, Δy -> begin
        # The loops below index with the primal `K` under `@inbounds`, so the
        # shapes must agree.
        size(Δy) == size(y) ||
            throw(DimensionMismatch("cotangent has size $(size(Δy)), expected $(size(y))"))
        zt0 = 1 ./ (1 .- z.^2)   # 1 / (1 - z²), also d atanh(z) / dz
        zt = sqrt.(zt0)          # 1 / √(1 - z²)
        Δz = Δy .* zt0           # fresh array; used as the adjoint accumulator for z
        Δw = zero(Δy)

        @inbounds for j in 2:K
            for i in (j-1):-1:2
                pd = prod(view(zt, 1:(i-1), j))
                Δw[i, j] = Δz[i, j] * pd
                for ip in 1:(i-1)
                    Δz[ip, j] += Δz[i, j] * w[i, j] * pd * zt0[ip, j] * z[ip, j]
                end
            end
            # z[1, j] = w[1, j], so its adjoint passes through unchanged.
            Δw[1, j] = Δz[1, j]
        end

        return (Δw,)
    end

end

0 comments on commit b79b0df

Please sign in to comment.