Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add LKJ bijector #125

Merged
merged 64 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
64 commits
Select commit Hold shift + click to select a range
a8c32a2
add a stan style corr bijector
yiyuezhuo Jul 30, 2020
ba1ab40
implement a version which is only compatible with ForwardDiff AD backend
yiyuezhuo Jul 31, 2020
a98ce04
add AD support except Tracker
yiyuezhuo Aug 2, 2020
74c902b
sync
yiyuezhuo Aug 6, 2020
9c1742c
add tests for interface
yiyuezhuo Aug 6, 2020
a1f8ec7
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1c9b4ea
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
a52bdce
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
ead7e71
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
32cab00
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1859672
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
1ef50b0
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
5ecaaee
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
bf41efc
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
52e74e4
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
be3dddf
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
476ca0f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
bfd8d4f
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
dd0fade
Update src/compat/zygote.jl
yiyuezhuo Aug 9, 2020
38785f5
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
463acdc
Update src/compat/reversediff.jl
yiyuezhuo Aug 9, 2020
5bdea37
Update src/compat/tracker.jl
yiyuezhuo Aug 9, 2020
131c502
Update src/bijectors/corr.jl
yiyuezhuo Aug 9, 2020
4bbe656
stash
yiyuezhuo Aug 11, 2020
8d2fc60
pass basic tests for all AD
yiyuezhuo Aug 11, 2020
1120763
clean up
yiyuezhuo Aug 11, 2020
eec4de0
add data only AD unit test
yiyuezhuo Aug 11, 2020
d9e0abf
fix type instability
yiyuezhuo Aug 11, 2020
1f24e65
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
09297e9
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
388fee0
Update src/compat/reversediff.jl
yiyuezhuo Aug 11, 2020
72a16a8
Update src/compat/tracker.jl
yiyuezhuo Aug 11, 2020
d2908d5
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
8eb0686
Update src/bijectors/corr.jl
yiyuezhuo Aug 11, 2020
9260c61
remove link_lkj and inv_link_lkj
yiyuezhuo Aug 11, 2020
02a7d7e
switch looping order
yiyuezhuo Aug 11, 2020
d0984a4
rename logabsdetjac_lkj_inv and merge z
yiyuezhuo Aug 11, 2020
10c2aed
fix zeros type instability
yiyuezhuo Aug 11, 2020
fc6af4f
add @inbounds
yiyuezhuo Aug 11, 2020
5ee5826
replace log(1 - tanh(y[...])^2) with numerically stable implementation
yiyuezhuo Aug 13, 2020
917605f
removed logabsdetjac_lkj_inv
yiyuezhuo Aug 13, 2020
32497a1
rename two functions used in corr
yiyuezhuo Aug 14, 2020
c82e57c
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 18, 2020
c13d8a2
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
01250ce
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
de86456
Apply suggestions from code review
yiyuezhuo Aug 21, 2020
5238efe
Update src/bijectors/corr.jl
yiyuezhuo Aug 21, 2020
3a01d90
add broken test
yiyuezhuo Aug 21, 2020
c86bcdf
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 21, 2020
79cbc5e
apply some suggestions
yiyuezhuo Aug 21, 2020
01db0f4
avoid some zero filling
yiyuezhuo Aug 21, 2020
8cfc22d
move dense transform as last as possible, add reversediff dense dense…
yiyuezhuo Aug 21, 2020
9c9846b
furthur remove extra zero filling and f(0) leveraging UpperTriangular…
yiyuezhuo Aug 21, 2020
9746499
forward special speedup, drop copy-paste maintain ability
yiyuezhuo Aug 21, 2020
7fb2675
Apply suggestions from code review
yiyuezhuo Aug 22, 2020
bad759b
optimize forwarddiff and remove custom AD for reversediff
yiyuezhuo Aug 22, 2020
7710187
Merge branch 'Stan_corr_impl_adjoint' of https://github.com/yiyuezhuo…
yiyuezhuo Aug 22, 2020
5228310
fix suggestion mistake and clean up
yiyuezhuo Aug 22, 2020
ea71cd7
small optimizing and clean up
yiyuezhuo Aug 22, 2020
a577b05
add re-derived analytic gradient
yiyuezhuo Aug 23, 2020
50b837d
simplify code (though make code hard to read)
yiyuezhuo Aug 23, 2020
b04252b
add @bounds
yiyuezhuo Aug 23, 2020
6c525a7
Merge branch 'master' into Stan_corr_impl_adjoint
yiyuezhuo Aug 24, 2020
7e46a09
bump version number to 0.8.3 and modify test and improve document
yiyuezhuo Aug 24, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 17 additions & 11 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ scratch.
struct CorrBijector <: Bijector{2} end

function (b::CorrBijector)(x::AbstractMatrix{<:Real})
w = cholesky(x).U + zero(x) # convert to dense matrix
w = cholesky(x).U #+ zero(x) # convert to dense matrix
r = _link_chol_lkj(w)
return r
return r + zero(x) # keep LowerTriangular until here can avoid some computation
# This dense format itself is required by a test, though I can't get the point.
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

# This is a quirk of the current implementation, of which it would be nice to be rid.
already states that we actually don't want this? What happens if your remove the + zero(x)?
This hack is really really ugly and unintuitive - e.g., if we would an Array in the end (I'm questioning that) it would be much more natural to call just Array(r) instead of some weird addition.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

correlation matrix: Test Failed at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:67
Expression: typeof(x) == typeof(y)
Evaluated: Array{Float64,2} == UpperTriangular{Float64,Array{Float64,2}}
Stacktrace:
[1] single_sample_tests(::LKJ{Float64,Int64}) at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:67
[2] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:203
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
[4] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:201
correlation matrix: Test Failed at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:82
Expression: typeof(xs) == typeof(ys)
Evaluated: Array{Array{Float64,2},1} == Array{UpperTriangular{Float64,Array{Float64,2}},1}
Stacktrace:
[1] multi_sample_tests(::LKJ{Float64,Int64}, ::Array{Float64,2}, ::Array{Array{Float64,2},1}, ::Int64) at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:82
[2] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:220
[3] top-level scope at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Test/src/Test.jl:1115
[4] top-level scope at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:201
Test Summary: | Pass Fail Total
correlation matrix | 9 2 11
ERROR: LoadError: LoadError: Some tests did not pass: 9 passed, 2 failed, 0 errored, 0 broken.
in expression starting at /home/yiyuezhuo/.julia/dev/Bijectors/test/transform.jl:199
in expression starting at /home/yiyuezhuo/.julia/dev/Bijectors/test/runtests.jl:22
ERROR: Package Bijectors errored during testing

Copy link
Member

@devmotion devmotion Aug 21, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know why this test exists at all. My suggestion would be to just remove (or comment out) the test. But I guess @torfjelde might know why the test exists?

end

(b::CorrBijector)(X::AbstractArray{<:AbstractMatrix{<:Real}}) = map(b, X)
Expand Down Expand Up @@ -111,18 +113,22 @@ which is the above implementation.
function _link_chol_lkj(w)
K = LinearAlgebra.checksquare(w)

z = zero(w)

z = similar(w) # z is also UpperTriangular except ReverseDiff.
# Some zero filling can be avoided. Though diagnoal is still needed to be filled with zero.

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
z[1, j] = w[1, j]
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
z[j, j] = 0
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
z[i,j] = p
end

y = atanh.(z)
yiyuezhuo marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
30 changes: 20 additions & 10 deletions src/compat/reversediff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,25 +237,35 @@ _inv_link_chol_lkj(y::TrackedMatrix) = track(_inv_link_chol_lkj, y)
return w, pullback_inv_link_chol_lkj
end

_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
# zero(x) is used to convert w to dense matrix, because ReverseDiff doesn't support track
# LowerTriangular matrix (concretely, Base.IndexStyle(::A) != Base.IndexLinear())
# See: http://www.juliadiff.org/ReverseDiff.jl/limits/
_link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w + zero(x))
@grad function _link_chol_lkj(w_tracked)
w = value(w_tracked)

LinearAlgebra.checksquare(w)

K = size(w, 1)
z = zero(w)

@inbounds for j=2:K
z[1, j] = w[1, j]
z = similar(w)

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds for i=1:K
z[i, 1] = 0
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
@inbounds for j=2:K
z[1, j] = w[1, j]
for i=j:K
z[i, j] = 0
end
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
z[i,j] = p
end

y = atanh.(z)
Expand Down
21 changes: 12 additions & 9 deletions src/compat/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -503,18 +503,21 @@ _link_chol_lkj(w::TrackedMatrix) = track(_link_chol_lkj, w)
LinearAlgebra.checksquare(w)

K = size(w, 1)
z = zero(w)

z = similar(w)

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
z[1, j] = w[1, j]
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
z[j, j] = 0
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
z[i,j] = p
end

y = atanh.(z)
Expand Down
23 changes: 13 additions & 10 deletions src/compat/zygote.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ end

@adjoint function _inv_link_chol_lkj(y)
LinearAlgebra.checksquare(y)

K = size(y, 1)

z = tanh.(y)
Expand Down Expand Up @@ -250,18 +250,21 @@ end
LinearAlgebra.checksquare(w)

K = size(w, 1)
z = zero(w)

z = similar(w)

# This block can't be integrated with loop below, because w[1,1] != 0.
@inbounds z[1, 1] = 0

@inbounds for j=2:K
z[1, j] = w[1, j]
end

@inbounds for j=3:K, i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p *= 1 / sqrt(1-z[ip, j]^2)
z[j, j] = 0
for i=2:j-1
p = w[i, j]
for ip in 1:(i-1)
p /= sqrt(1-z[ip, j]^2)
end
z[i, j] = p
end
z[i,j] = p
end

y = atanh.(z)
Expand Down