Skip to content

Commit

Permalink
simplify code (though make code hard to read)
Browse files Browse the repository at this point in the history
  • Loading branch information
yiyuezhuo committed Aug 23, 2020
1 parent a577b05 commit 50b837d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 14 deletions.
13 changes: 6 additions & 7 deletions src/compat/tracker.jl
Expand Up @@ -477,8 +477,8 @@ _inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
@inbounds for j in 1:K
Δtmp = Δw[j,j]
for i in j:-1:2
Δz = Δw[i-1, j] * tmp_mat[i, j] + Δtmp * tmp_mat[i, j] * 1 / sqrt(1 - z_mat[i, j]^2) / 2 * (-2 * z_mat[i, j])
Δy[i-1, j] = Δz * 1 / cosh(y[i-1, j])^2
Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j]
Δy[i-1, j] = Δz / cosh(y[i-1, j])^2
Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2)
end
end
Expand Down Expand Up @@ -527,15 +527,14 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
for i in (j-1):-1:2
p = w[i, j] / tmp_mat[i-1, j]
ftmp = sqrt(1 - p^2)
d_ftmp_p = 1 / ftmp / 2 * (-2 * p)
d_ftmp_p = -p / ftmp
d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

Δp = Δz[i,j] * 1/(1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
Δw[i, j] = Δp / tmp_mat[i-1, j]
# Δtmp = Δp * d_p_tmp + Δtmp * (ftmp + tmp_mat[i-1, j] * d_ftmp_p * d_p_tmp) # update to "previous" Δtmp
Δtmp = Δp * d_p_tmp + Δtmp * (ftmp) # update to "previous" Δtmp
Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
end
Δw[1, j] = Δz[1, j] * 1/(1-w[1,j]^2) + Δtmp / sqrt(1 - w[1,j]^2) / 2 * (-2 * w[1,j])
Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j]
end

return (Δw,)
Expand Down
13 changes: 6 additions & 7 deletions src/compat/zygote.jl
Expand Up @@ -227,8 +227,8 @@ end
@inbounds for j in 1:K
Δtmp = Δw[j,j]
for i in j:-1:2
Δz = Δw[i-1, j] * tmp_mat[i, j] + Δtmp * tmp_mat[i, j] * 1 / sqrt(1 - z_mat[i, j]^2) / 2 * (-2 * z_mat[i, j])
Δy[i-1, j] = Δz * 1 / cosh(y[i-1, j])^2
Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j]
Δy[i-1, j] = Δz / cosh(y[i-1, j])^2
Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2)
end
end
Expand Down Expand Up @@ -274,15 +274,14 @@ end
for i in (j-1):-1:2
p = w[i, j] / tmp_mat[i-1, j]
ftmp = sqrt(1 - p^2)
d_ftmp_p = 1 / ftmp / 2 * (-2 * p)
d_ftmp_p = -p / ftmp
d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2

Δp = Δz[i,j] * 1/(1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify
Δw[i, j] = Δp / tmp_mat[i-1, j]
# Δtmp = Δp * d_p_tmp + Δtmp * (ftmp + tmp_mat[i-1, j] * d_ftmp_p * d_p_tmp) # update to "previous" Δtmp
Δtmp = Δp * d_p_tmp + Δtmp * (ftmp) # update to "previous" Δtmp
Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp
end
Δw[1, j] = Δz[1, j] * 1/(1-w[1,j]^2) + Δtmp / sqrt(1 - w[1,j]^2) / 2 * (-2 * w[1,j])
Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j]
end

return (Δw,)
Expand Down

0 comments on commit 50b837d

Please sign in to comment.