diff --git a/src/compat/tracker.jl b/src/compat/tracker.jl index 6742ed16..8dab331a 100644 --- a/src/compat/tracker.jl +++ b/src/compat/tracker.jl @@ -477,8 +477,8 @@ _inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y) @inbounds for j in 1:K Δtmp = Δw[j,j] for i in j:-1:2 - Δz = Δw[i-1, j] * tmp_mat[i, j] + Δtmp * tmp_mat[i, j] * 1 / sqrt(1 - z_mat[i, j]^2) / 2 * (-2 * z_mat[i, j]) - Δy[i-1, j] = Δz * 1 / cosh(y[i-1, j])^2 + Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] + Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) end end @@ -527,15 +527,14 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w) for i in (j-1):-1:2 p = w[i, j] / tmp_mat[i-1, j] ftmp = sqrt(1 - p^2) - d_ftmp_p = 1 / ftmp / 2 * (-2 * p) + d_ftmp_p = -p / ftmp d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2 - Δp = Δz[i,j] * 1/(1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify + Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify Δw[i, j] = Δp / tmp_mat[i-1, j] - # Δtmp = Δp * d_p_tmp + Δtmp * (ftmp + tmp_mat[i-1, j] * d_ftmp_p * d_p_tmp) # update to "previous" Δtmp - Δtmp = Δp * d_p_tmp + Δtmp * (ftmp) # update to "previous" Δtmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp end - Δw[1, j] = Δz[1, j] * 1/(1-w[1,j]^2) + Δtmp / sqrt(1 - w[1,j]^2) / 2 * (-2 * w[1,j]) + Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j] end return (Δw,) diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl index ed9d6a3e..48ff3076 100644 --- a/src/compat/zygote.jl +++ b/src/compat/zygote.jl @@ -227,8 +227,8 @@ end @inbounds for j in 1:K Δtmp = Δw[j,j] for i in j:-1:2 - Δz = Δw[i-1, j] * tmp_mat[i, j] + Δtmp * tmp_mat[i, j] * 1 / sqrt(1 - z_mat[i, j]^2) / 2 * (-2 * z_mat[i, j]) - Δy[i-1, j] = Δz * 1 / cosh(y[i-1, j])^2 + Δz = Δw[i-1, j] * tmp_mat[i, j] - Δtmp * tmp_mat[i, j] / sqrt(1 - z_mat[i, j]^2) * z_mat[i, j] + Δy[i-1, j] = Δz / cosh(y[i-1, j])^2 Δtmp = Δw[i-1, j] * z_mat[i, j] + Δtmp * sqrt(1 - z_mat[i, j]^2) end end @@ -274,15 +274,14 @@ end for i in (j-1):-1:2 p = w[i, j] / tmp_mat[i-1, j] ftmp = sqrt(1 - p^2) - d_ftmp_p = 1 / ftmp / 2 * (-2 * p) + d_ftmp_p = -p / ftmp d_p_tmp = -w[i,j] / tmp_mat[i-1, j]^2 - Δp = Δz[i,j] * 1/(1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify + Δp = Δz[i,j] / (1-p^2) + Δtmp * tmp_mat[i-1, j] * d_ftmp_p # TODO: simplify Δw[i, j] = Δp / tmp_mat[i-1, j] - # Δtmp = Δp * d_p_tmp + Δtmp * (ftmp + tmp_mat[i-1, j] * d_ftmp_p * d_p_tmp) # update to "previous" Δtmp - Δtmp = Δp * d_p_tmp + Δtmp * (ftmp) # update to "previous" Δtmp + Δtmp = Δp * d_p_tmp + Δtmp * ftmp # update to "previous" Δtmp end - Δw[1, j] = Δz[1, j] * 1/(1-w[1,j]^2) + Δtmp / sqrt(1 - w[1,j]^2) / 2 * (-2 * w[1,j]) + Δw[1, j] = Δz[1, j] / (1-w[1,j]^2) - Δtmp / sqrt(1 - w[1,j]^2) * w[1,j] end return (Δw,)