Commit

bump version number to 0.8.3 and modify test and improve document
yiyuezhuo committed Aug 24, 2020
1 parent 6c525a7 commit 7e46a09
Showing 5 changed files with 69 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.8.2"
version = "0.8.3"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
58 changes: 56 additions & 2 deletions src/bijectors/corr.jl
@@ -4,8 +4,62 @@
A bijector implementing Stan's parametrization of correlation matrices:
https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
Note (7/30/2020): their "manageable expression" appears to be wrong; the expression used
here is derived from scratch.
In short, an unconstrained strictly upper triangular matrix `y` is transformed into
a correlation matrix by the following readable, though not very efficient, procedure:
```
K = size(y, 1)
z = tanh.(y)
w = similar(z)  # every entry is assigned in the loop below
for j=1:K, i=1:K
    if i>j
        w[i,j] = 0
    elseif 1==i==j
        w[i,j] = 1
    elseif 1<i==j
        w[i,j] = prod(sqrt.(1 .- z[1:i-1, j].^2))
    elseif 1==i<j
        w[i,j] = z[i,j]
    elseif 1<i<j
        w[i,j] = z[i,j] * prod(sqrt.(1 .- z[1:i-1, j].^2))
    end
end
```
It is easy to see that every column is a unit vector, for example:
```
w3' w3 ==
w[1,3]^2 + w[2,3]^2 + w[3,3]^2 ==
z[1,3]^2 + (z[2,3] * sqrt(1 - z[1,3]^2))^2 + (sqrt(1-z[1,3]^2) * sqrt(1-z[2,3]^2))^2 ==
z[1,3]^2 + z[2,3]^2 * (1-z[1,3]^2) + (1-z[1,3]^2) * (1-z[2,3]^2) ==
z[1,3]^2 + z[2,3]^2 - z[2,3]^2 * z[1,3]^2 + 1 -z[1,3]^2 - z[2,3]^2 + z[1,3]^2 * z[2,3]^2 ==
1
```
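The same cancellation works for any column `j`, not only the third. A sketch of the general argument follows; the shorthand $P_i$ is introduced here purely for illustration and does not appear in the source.

```latex
\[
P_i = \prod_{k=1}^{i-1} \left(1 - z_{kj}^2\right), \qquad P_1 = 1, \qquad
w_{ij}^2 = z_{ij}^2 P_i = P_i - P_{i+1} \quad (i < j), \qquad
w_{jj}^2 = P_j
\]
\[
\sum_{i=1}^{j} w_{ij}^2 = \sum_{i=1}^{j-1} \left(P_i - P_{i+1}\right) + P_j
= P_1 - P_j + P_j = 1
\]
```

The sum telescopes because $P_{i+1} = P_i \left(1 - z_{ij}^2\right)$.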
The diagonal elements of `w` are positive, so `w` is a Cholesky factor of a positive definite matrix:
```
x = w' * w
```
Consider the block matrix representation of `x`:
```
x = [w1'; w2'; ... wn'] * [w1 w2 ... wn] ==
[w1'w1 w1'w2 ... w1'wn;
w2'w1 w2'w2 ... w2'wn;
...
]
```
The diagonal elements are given by `wk'wk = 1`, thus `x` is a correlation matrix.
Every step is invertible, so the transform is a bijection (a bijector).
Note: The implementation doesn't follow their "manageable expression" directly,
because their equation seems wrong (7/30/2020). Instead, it follows the definition
given above the "manageable expression", which is also described in the doc linked above.
"""
struct CorrBijector <: Bijector{2} end
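
To make the docstring's argument concrete, here is a minimal standalone sketch of the forward map it describes, together with an inverse that undoes each step. The function names `corr_factor` and `corr_factor_inverse` are hypothetical and chosen for this illustration only; they are not part of the Bijectors API, and the package's actual implementation may be organised differently. The sketch assumes `y` is square and strictly upper triangular.

```julia
using LinearAlgebra

# Hypothetical helper: maps a strictly upper triangular `y` to the upper
# triangular factor `w` described in the docstring, column by column.
function corr_factor(y::AbstractMatrix{<:Real})
    K = size(y, 1)
    z = tanh.(y)
    w = zeros(eltype(z), K, K)
    for j in 1:K
        acc = one(eltype(z))          # running product of sqrt(1 - z[k,j]^2), k < i
        for i in 1:j-1
            w[i, j] = z[i, j] * acc
            acc *= sqrt(1 - z[i, j]^2)
        end
        w[j, j] = acc                 # equals 1 when j == 1
    end
    return w
end

# Hypothetical helper: recovers `y` from `w` by reversing each step above.
function corr_factor_inverse(w::AbstractMatrix{<:Real})
    T = float(eltype(w))
    K = size(w, 1)
    y = zeros(T, K, K)
    for j in 2:K
        acc = one(T)
        for i in 1:j-1
            z = w[i, j] / acc         # undo the scaling by the running product
            y[i, j] = atanh(z)        # undo tanh
            acc *= sqrt(1 - z^2)
        end
    end
    return y
end

# Small deterministic example: a 4x4 strictly upper triangular input.
K = 4
y = triu([0.1 * i - 0.2 * j for i in 1:K, j in 1:K], 1)
w = corr_factor(y)
x = w' * w                            # candidate correlation matrix
y_back = corr_factor_inverse(w)
```

With these definitions, `diag(x)` should be all ones up to rounding, `Symmetric(x)` should be positive definite, and `y_back` should match `y`, which mirrors the three claims in the docstring: unit columns, a valid Cholesky factor, and invertibility.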

2 changes: 1 addition & 1 deletion src/compat/tracker.jl
@@ -530,7 +530,7 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
d_ftmp_p = -p / ftmp
d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p
Δw[i, j] = Δp / tmp_mat[i-1, j]
Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
end
2 changes: 1 addition & 1 deletion src/compat/zygote.jl
@@ -277,7 +277,7 @@ end
d_ftmp_p = -p / ftmp
d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p
Δw[i, j] = Δp / tmp_mat[i-1, j]
Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
end
14 changes: 10 additions & 4 deletions test/ad/distributions.jl
@@ -10,6 +10,12 @@
B = rand(dim, dim)
C = rand(dim, dim)

dim_big = 10

# Some LKJ problems may stay hidden when the test matrix is too small
A_big = rand(dim_big, dim_big)
B_big = rand(dim_big, dim_big)

# Create a random number
alpha = rand()

@@ -314,7 +320,7 @@
DistSpec((df, A) -> InverseWishart(df, to_posdef(A)), (3.0, A), B, to_posdef),
DistSpec((df, A) -> TuringWishart(df, to_posdef(A)), (3.0, A), B, to_posdef),
DistSpec((df, A) -> TuringInverseWishart(df, to_posdef(A)), (3.0, A), B, to_posdef),
DistSpec(() -> LKJ(3, 1.), (), A, to_corr),
DistSpec(() -> LKJ(10, 1.), (), A_big, to_corr),

# Vector of matrices x
DistSpec(
@@ -348,9 +354,9 @@
x -> map(to_posdef, x),
),
DistSpec(
() -> LKJ(3, 1.),
() -> LKJ(10, 1.),
(),
[A, B],
[A_big, B_big],
x -> map(to_corr, x),
)
]
@@ -376,7 +382,7 @@
B,
to_posdef,
),
DistSpec((eta) -> LKJ(3, eta), (1.), A, to_corr)
DistSpec((eta) -> LKJ(10, eta), (1.), A_big, to_corr)
# AD for the parameters of LKJ requires more support in DistributionsAD
]

