Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LKJ bijector #125

Merged
merged 64 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
a8c32a2
add a stan style corr bijector
yiyuezhuo Jul 30, 2020
ba1ab40
implement a version which is only compatible with ForwardDiff AD backend
yiyuezhuo Jul 31, 2020
a98ce04
add AD support except Tracker
yiyuezhuo Aug 2, 2020
74c902b
sync
yiyuezhuo Aug 6, 2020
9c1742c
add tests for interface
yiyuezhuo Aug 6, 2020
a1f8ec7
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1c9b4ea
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
a52bdce
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
ead7e71
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
32cab00
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1859672
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1ef50b0
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
5ecaaee
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
bf41efc
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
52e74e4
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
be3dddf
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
476ca0f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
bfd8d4f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
dd0fade
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
38785f5
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
463acdc
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
5bdea37
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
131c502
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
4bbe656
stash
yiyuezhuo Aug 11, 2020
8d2fc60
pass basic tests for all AD
yiyuezhuo Aug 11, 2020
1120763
clean up
yiyuezhuo Aug 11, 2020
eec4de0
add data only AD unit test
yiyuezhuo Aug 11, 2020
d9e0abf
fix type instability
yiyuezhuo Aug 11, 2020
1f24e65
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
09297e9
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
388fee0
Update src/compat/reversediff.jl
yiyuezhuo Aug 11, 2020
72a16a8
Update src/compat/tracker.jl
yiyuezhuo Aug 11, 2020
d2908d5
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
8eb0686
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
9260c61
remove link_lkj and inv_link_lkj
yiyuezhuo Aug 11, 2020
02a7d7e
switch looping order
yiyuezhuo Aug 11, 2020
d0984a4
rename logabsdetjac_lkj_inv and merge z
yiyuezhuo Aug 11, 2020
10c2aed
fix zeros type instability
yiyuezhuo Aug 11, 2020
fc6af4f
add @inbounds
yiyuezhuo Aug 11, 2020
5ee5826
replace log(1 - tanh(y[...])^2) with numerically stable implementation
yiyuezhuo Aug 13, 2020
917605f
removed logabsdetjac_lkj_inv
yiyuezhuo Aug 13, 2020
32497a1
rename two functions used in corr
yiyuezhuo Aug 14, 2020
c82e57c
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 18, 2020
c13d8a2
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
01250ce
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
de86456
Apply suggestions from code review
yiyuezhuo Aug 21, 2020
5238efe
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
3a01d90
add broken test
yiyuezhuo Aug 21, 2020
c86bcdf
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 21, 2020
79cbc5e
apply some suggestions
yiyuezhuo Aug 21, 2020
01db0f4
avoid some zero filling
yiyuezhuo Aug 21, 2020
8cfc22d
move dense transform as last as possible, add reversediff dense dense…
yiyuezhuo Aug 21, 2020
9c9846b
furthur remove extra zero filling and f(0) leveraging UpperTriangular…
yiyuezhuo Aug 21, 2020
9746499
forward special speedup, drop copy-paste maintain ability
yiyuezhuo Aug 21, 2020
7fb2675
Apply suggestions from code review
yiyuezhuo Aug 22, 2020
bad759b
optimize forwarddiff and remove custom AD for reversediff
yiyuezhuo Aug 22, 2020
7710187
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 22, 2020
5228310
fix suggestion mistake and clean up
yiyuezhuo Aug 22, 2020
ea71cd7
small optimizing and clean up
yiyuezhuo Aug 22, 2020
a577b05
add re-derived analytic gradient
yiyuezhuo Aug 23, 2020
50b837d
simplify code (though make code hard to read)
yiyuezhuo Aug 23, 2020
b04252b
add @bounds
yiyuezhuo Aug 23, 2020
6c525a7
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 24, 2020
7e46a09
bump version number to 0.8.3 and modify test and improve document
yiyuezhuo Aug 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
135 changes: 135 additions & 0 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# See stan doc for parametrization method:
# https://mc-stan.org/docs/2_23/reference-manual/correlation-matrix-transform-section.html
# (7/30/2020) their "manageable expression" is wrong...
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved

function upper1(AT, A)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we make use of UpperTriangular (which uses a view) instead of copying the matrix?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AU = zero(AT)
for i=1:size(A,1), j=(i+1):size(A,2)
AU[i,j] = A[i,j]
end
AU
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end

struct CorrBijector <: Bijector{2} end

(b::CorrBijector)(X::AbstractMatrix{<:Real}) = link_lkj(X)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)

(ib::Inverse{<:CorrBijector})(Y::AbstractMatrix{<:Real}) = inv_link_lkj(Y)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
(ib::Inverse{<:CorrBijector})(Y::AbstractArray{<:AbstractMatrix{<:Real}}) = map(ib, Y)


logabsdetjac(::Inverse{CorrBijector}, y::AbstractMatrix{<:Real}) = log_abs_det_jac_lkj(y)
function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})

yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
if !LinearAlgebra.isposdef(X)
println("!isposdef(X)")
return NaN # prevent Cholesky decomposition to break inference
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end

-log_abs_det_jac_lkj(b(X)) # It may be more efficient if we can use un-contraint value to prevent call of b
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end
logabsdetjac(b::CorrBijector, X::AbstractArray{<:AbstractMatrix{<:Real}}) = mapvcat(X) do x
logabsdetjac(b, x)
end
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved


function log_abs_det_jac_lkj(y)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
# println("log_abs_det_jac_lkj $(typeof(y) == Matrix{Float64})")
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
# it's defined on inverse mapping
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
K = size(y, 1)

z = tanh.(y)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
left = 0
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
for i = 1:(K-1), j = (i+1):K
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
left += (K-i-1) * log(1 - z[i, j]^2)
end

right = 0
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
for i = 1:(K-1), j = (i+1):K
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
right += log(cosh(y[i, j])^2)
end

return (0.5 * left - right)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end

function inv_link_w_lkj(y)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this should just be

Suggested change
function inv_link_w_lkj(y)
function (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real})

and return w' * w instead. All these different functions with (IMO) not really descriptive names should be avoided if possible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm guessing the reason why @yiyuezhuo has done this is to be able to define custom adjoints. Not possible to do so if we use (b::Bijector)(x).

So I think leaving the (inv_)link_w_lkj is fine, but redefine (inv_)link_lkj to their corresponding b::Bijector definitions.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Zygote (or rather ChainRules nowadays) it's definitely possible to define adjoints for (::Bijector)(...) (see e.g. the support of Distances in Zygote). Does it not work for other AD backends?

Copy link
Contributor Author

@yiyuezhuo yiyuezhuo Aug 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I defined custom adjoint on link_w_lkj but not on link_lkj, so link_w_lkj is just an "anonymous" code block requiring a custom gradient. A name is required, though not that descriptive.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I'm not sure how that relates to my comment? I just wanted to point out that in Zygote and ChainRules you can also define adjoints for functions such as (ib::Inverse{<:CorrBijector})(y::AbstractMatrix{<:Real}) = ... if needed - there is no need to introduce some function inv_link_lkj for this purpose.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woah, really? If so I guess there's no need. Back when we started this overhaul of Bijectors, it was not possible to define adjoints for those kind of signatures. If that's not the case anymore, then I agree with your comments @devmotion

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@yiyuezhuo yiyuezhuo Aug 12, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean, let's say these is a function f which needs AD:

function f(x)
  y = f1(x)
  z = f2(y)
  return z
end

where f1 is AD-friendly and f2 needs a custom adjoint. So I wrote an adjoint for f2 only. Are you suggesting to write an adjoint for f and call pullback for f1 in there to avoid adjoint for f2? It's not clear to me.

# println("inv_link_w_lkj $(typeof(y) == Matrix{Float64})")
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
K = size(y, 1)

z = tanh.(y)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
w = similar(z)

w[1,1] = 1
for j in 1:K
w[1, j] = 1
end

for i in 2:K
for j in 1:(i-1)
w[i, j] = 0
end
for j in i:K
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

for i in 1:K
for j in (i+1):K
w[i, j] = w[i, j] * z[i, j]
end
end
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be good to avoid this loop and the allocation of z if possible (haven't thought through if it is possible).

Copy link
Contributor Author

@yiyuezhuo yiyuezhuo Aug 13, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are you suggesting w = w .* z? It's wrong since diagonal elements are not zeros. And z is used in 3 Reverse AD backends.


return w
end

function inv_link_lkj(y)
w = inv_link_w_lkj(y)
return w' * w
end

yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
function link_w_lkj(w)
Copy link
Member

@devmotion devmotion Aug 11, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably this should just be

Suggested change
function link_w_lkj(w)
function (b::CorrBijector)(x::AbstractMatrix{<:Real})

with the additional code from link_lkj below.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See above comments.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You mean because of the adjoint? At least for Zygote that should work in exactly the same way, regardless of how you define the function.

# println("link_w_lkj $(typeof(w) == Matrix{Float64})")
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
K = size(w, 1)

# z = zero(w) # `zero` isn't compatible with ReverseDiff
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
z = similar(w)
for i=1:K, j=1:K
z[i,j] = 0
end

for j=2:K
z[1, j] = w[1, j]
end

#=
# This implementation will not works when w[i-1, j] = 0.
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
# Though it is a zero measure set, unit matrix initialization will not works.
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved

for i=2:K, j=(i+1):K
z[i, j] = w[i, j] / w[i-1, j] * z[i-1, j] / sqrt(1 - z[i-1, j]^2)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
=#
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
for i=2:K, j=(i+1):K
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end
z[i,j] = p
end

y = atanh.(z)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
return y
end

function link_lkj(x)
# println("link_lkj $(typeof(x) == Matrix{Float64})")
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
w = cholesky(x).U
# w = collect(cholesky(x).U)
# w = convert(typeof(x), cholesky(x).U) # ? test requires it, such quirk
# w = upper(parent(cholesky(x).U))
# return link_w_lkj(w)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
r = link_w_lkj(w)
# return r - lower(parent(r)) # test requires it, such quirk
upper1(x, r)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end
2 changes: 1 addition & 1 deletion src/compat/forwarddiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ function jacobian(
x::AbstractVector{<:Real}
)
return ForwardDiff.jacobian(b, x)
end
end
90 changes: 90 additions & 0 deletions src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,94 @@ lower(A::TrackedMatrix) = track(lower, A)
return lower(Ad), Δ -> (lower(Δ),)
end

#=
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
inv_link_w_lkj(y::TrackedMatrix) = track(inv_link_w_lkj, y)
@grad function inv_link_w_lkj(y_tracked::AbstractMatrix)
y = value(y_tracked)

K = size(y, 1)

z = tanh.(y)
w = similar(z)

w[1,1] = 1
for j in 1:K
w[1, j] = 1
end

for i in 2:K
for j in 1:(i-1)
w[i, j] = 0
end
for j in i:K
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

w1 = copy(w) # cache result

for i in 1:K
for j in (i+1):K
w[i, j] = w[i, j] * z[i, j]
end
end

return w, Δw -> begin
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
Δz = zeros(size(Δw))
Δw1 = zeros(size(Δw))
for i in 1:K, j in (i+1):K
Δw1[i,j] = Δw[i,j] * z[i,j]
Δz[i,j] = Δw[i,j] * w1[i,j]
end
for i in 1:K
Δw1[i,i] = Δw[i,i]
end

for j=2:K, i=j:-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw1[i-1, j] += Δw1[i, j] * tz
Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
end

Δy = Δz .* (1 ./ cosh.(y).^2)
return (Δy,)
end
end

link_w_lkj(w::TrackedMatrix) = track(link_w_lkj, w)
@grad function link_w_lkj(w_tracked::AbstractMatrix)
w = value(w_tracked)

K = size(w, 1)
z = zero(w)

for j=2:K
z[1, j] = w[1, j]
end

for i=2:K, j=(i+1):K
z[i, j] = w[i, j] / w[i-1, j] * z[i-1, j] / sqrt(1 - z[i-1, j]^2)
end

y = atanh.(z)

return y, Δy -> begin
Δz = Δy .* (1 ./ (1. .- z.^2))
Δw = zeros(size(Δz))
for j=2:K, i=(j-1):-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw[i,j] += Δz[i,j] / w[i-1,j] * z[i-1, j] / tz
Δw[i-1,j] += Δz[i,j] * w[i,j] * z[i-1, j] / tz * (-1 / w[i-1, j]^2)
Δz[i-1,j] += Δz[i,j] * w[i,j] / w[i-1, j] * ((tz - z[i-1,j] * 0.5 / tz * (-2*z[i-1,j])) / tz^2)
end

for j=2:K
Δw[1, j] += Δz[1, j]
end

return (Δw,)
end
end
=#
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved

end
147 changes: 147 additions & 0 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ using .Tracker: Tracker,
using Compat: eachcol
using LinearAlgebra

import Base.*

yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
maporbroadcast(f, x::TrackedArray...) = f.(x...)
function maporbroadcast(
f,
Expand Down Expand Up @@ -440,3 +442,148 @@ lower(A::TrackedMatrix) = track(lower, A)
Ad = data(A)
return lower(Ad), Δ -> (lower(Δ),)
end

inv_link_w_lkj(y::TrackedMatrix) = track(inv_link_w_lkj, y)
@grad function inv_link_w_lkj(y_tracked)
y = data(y_tracked)

K = size(y, 1)

z = tanh.(y)
w = similar(z)

w[1,1] = 1
for j in 1:K
w[1, j] = 1
end

for i in 2:K
for j in 1:(i-1)
w[i, j] = 0
end
for j in i:K
w[i, j] = w[i-1, j] * sqrt(1 - z[i-1, j]^2)
end
end

w1 = copy(w) # cache result

for i in 1:K
for j in (i+1):K
w[i, j] = w[i, j] * z[i, j]
end
end

return w, Δw -> begin
Δz = zeros(size(Δw))
Δw1 = zeros(size(Δw))
for i in 1:K, j in (i+1):K
Δw1[i,j] = Δw[i,j] * z[i,j]
Δz[i,j] = Δw[i,j] * w1[i,j]
end
for i in 1:K
Δw1[i,i] = Δw[i,i]
end

for j=2:K, i=j:-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw1[i-1, j] += Δw1[i, j] * tz
Δz[i-1, j] += Δw1[i, j] * w1[i-1, j] * 0.5 / tz * (-2 * z[i-1, j])
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end

Δy = Δz .* (1 ./ cosh.(y).^2)
return (Δy,)
end
end

link_w_lkj(w::TrackedMatrix) = track(link_w_lkj, w)
@grad function link_w_lkj(w_tracked)
w = data(w_tracked)

# println("link_w_lkj $(typeof(w) == Matrix{Float64})")
K = size(w, 1)
z = zero(w)

for j=2:K
z[1, j] = w[1, j]
end

#=
for i=2:K, j=(i+1):K
z[i, j] = w[i, j] / w[i-1, j] * z[i-1, j] / sqrt(1 - z[i-1, j]^2)
end
=#
for i=2:K, j=(i+1):K
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
end
z[i,j] = p
end

y = atanh.(z)

return y, Δy -> begin
zt0 = 1 ./ (1 .- z.^2)
zt = sqrt.(zt0)
Δz = Δy .* zt0
Δw = zeros(size(Δy))

for j=2:K, i=(j-1):-1:2
pd = prod(zt[1:i-1,j])
Δw[i,j] += Δz[i,j] * pd
for ip in 1:(i-1)
Δw[ip, j] += Δz[i,j] * w[i,j] * pd / (1-z[ip,j]^2) * z[ip,j]
end
end
for j=2:K
Δw[1, j] += Δz[1, j]
end

(Δw,)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end

#=
return y, Δy -> begin
Δz = Δy .* (1 ./ (1. .- z.^2))
Δw = zeros(size(Δz))
for j=2:K, i=(j-1):-1:2
tz = sqrt(1 - z[i-1, j]^2)
Δw[i,j] += Δz[i,j] / w[i-1,j] * z[i-1, j] / tz
Δw[i-1,j] += Δz[i,j] * w[i,j] * z[i-1, j] / tz * (-1 / w[i-1, j]^2)
Δz[i-1,j] += Δz[i,j] * w[i,j] / w[i-1, j] * ((tz - z[i-1,j] * 0.5 / tz * (-2*z[i-1,j])) / tz^2)
end

for j=2:K
Δw[1, j] += Δz[1, j]
end

return (Δw,)
end
=#
end

upper1(AT::TrackedMatrix, A::TrackedMatrix) = track(upper1, AT, A)
@grad function upper1(AT_tracked, A_tracked)
AT = data(AT_tracked)
A = data(A_tracked)
return upper1(AT, A), Δ -> (nothing, upper1(AT, Δ))
end

# Workaround for Tracker ambiguous bug. See: https://github.com/FluxML/Tracker.jl/issues/74
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
# (*)(X::Diagonal, Y::TrackedArray{T,2,A} where A where T) = collect(X) * Y
x::Diagonal * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::Diagonal = track(*, x, y)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved

function LinearAlgebra.isposdef(w_tracked::TrackedMatrix)
# w = data(w_tracked)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
w = w_tracked.data
LinearAlgebra.isposdef(w)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end

function LinearAlgebra.isposdef(w_tracked::Symmetric{<:Any, <:TrackedMatrix})
# w = data(w_tracked)
w = w_tracked.data
LinearAlgebra.isposdef(w)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
end