Skip to content

Commit

Permalink
Add vectorize method for LKJCholesky (#485)
Browse files Browse the repository at this point in the history
* using `LinearAlgebra.Cholesky`

* add `vectorize` for `LKJCholesky`

* add `vectorize` test

* add forgotten `end`

* Update test/utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fix typo

* add `reconstruct` methods for LKJ/LKJCholesky inv bijectors

* bump patch

* bump Bijectors compat

* Update src/utils.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* add Bijectors v0.13 compat

* add `inittrans` method for `CholeskyVariate`

* add `LKJ`/`LKJCholesky` tests
Co-authored-by: torfjelde

* include tests

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* make tests more accurate

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update test/lkj.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
harisorgn and github-actions[bot] committed Jun 23, 2023
1 parent 5f74696 commit fff3bd1
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.0"
version = "0.23.1"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
2 changes: 2 additions & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ using Setfield: Setfield
using ZygoteRules: ZygoteRules
using LogDensityProblems: LogDensityProblems

using LinearAlgebra: Cholesky

using DocStringExtensions

using Random: Random
Expand Down
2 changes: 1 addition & 1 deletion src/abstract_varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ end
# NOTE: `reconstruct` is no-op if `val` is already of correct shape.
"""
reconstruct_and_link(dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vi::VarName, dist, val)
reconstruct_and_link(vi::AbstractVarInfo, vn::VarName, dist, val)
Return linked `val` but reconstruct before linking, if necessary.
Expand Down
15 changes: 14 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ vectorize(d, r) = vec(r)
vectorize(d::UnivariateDistribution, r::Real) = [r]
vectorize(d::MultivariateDistribution, r::AbstractVector{<:Real}) = copy(r)
vectorize(d::MatrixDistribution, r::AbstractMatrix{<:Real}) = copy(vec(r))
vectorize(d::Distribution{CholeskyVariate}, r::Cholesky) = copy(vec(r.UL))

# NOTE:
# We cannot use reconstruct{T} because val is always Vector{Real} then T will be Real.
Expand All @@ -235,6 +236,13 @@ reconstruct(f, dist, val) = reconstruct(dist, val)
reconstruct(::UnivariateDistribution, val::Real) = val
reconstruct(::MultivariateDistribution, val::AbstractVector{<:Real}) = copy(val)
reconstruct(::MatrixDistribution, val::AbstractMatrix{<:Real}) = copy(val)
reconstruct(::Inverse{Bijectors.VecCorrBijector}, ::LKJ, val::AbstractVector) = copy(val)
function reconstruct(
::Inverse{Bijectors.VecCholeskyBijector}, ::LKJCholesky, val::AbstractVector
)
return copy(val)
end

# TODO: Implement no-op `reconstruct` for general array variates.

reconstruct(d::Distribution, val::AbstractVector) = reconstruct(size(d), val)
Expand Down Expand Up @@ -294,7 +302,12 @@ function inittrans(rng, dist::MatrixDistribution)
sz = Bijectors.output_size(b, size(dist))
return Bijectors.invlink(dist, randrealuni(rng, sz...))
end

function inittrans(rng, dist::Distribution{CholeskyVariate})
# Get the size of the unconstrained vector
b = link_transform(dist)
sz = Bijectors.output_size(b, size(dist))
return Bijectors.invlink(dist, randrealuni(rng, sz...))
end
################################
# Multi-sample initialisations #
################################
Expand Down
51 changes: 51 additions & 0 deletions test/lkj.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
using Bijectors: pd_from_upper

@model lkj_prior_demo() = x ~ LKJ(2, 1)
@model lkj_chol_prior_demo() = x ~ LKJCholesky(2, 1, 'U')

# Same for both distributions
target_mean = vec(Matrix{Float64}(I, 2, 2))

_lkj_atol = 0.05

@testset "Sample from x ~ LKJ(2, 1)" begin
model = lkj_prior_demo()
# `SampleFromPrior` will sample in constrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
_lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
@test mean(map(Base.Fix2(getindex, Colon()), samples)) target_mean atol =
_lkj_atol
end
end

@testset "Sample from x ~ LKJCholesky(2, 1, U)" begin
model = lkj_chol_prior_demo()
# `SampleFromPrior` will sample in unconstrained space.
@testset "SampleFromPrior" begin
samples = sample(model, SampleFromPrior(), 1_000)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = Float64.(reshape(s.metadata.vals, (2, 2)))
pd_from_upper(M)
end
@test vec(mean(corr_matrices)) target_mean atol = _lkj_atol
end

# `SampleFromUniform` will sample in unconstrained space.
@testset "SampleFromUniform" begin
samples = sample(model, SampleFromUniform(), 1_000)
# Build correlation matrix from factor
corr_matrices = map(samples) do s
M = Float64.(reshape(s.metadata.vals, (2, 2)))
pd_from_upper(M)
end
@test vec(mean(corr_matrices)) target_mean atol = _lkj_atol
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ include("test_util.jl")
include("serialization.jl")

include("loglikelihoods.jl")

include("lkj.jl")
end

@testset "compat" begin
Expand Down
6 changes: 6 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@
@test getargs_tilde(:(@. x ~ Normal(μ, $(Expr(:$, :(sqrt(v))))))) === nothing
@test getargs_tilde(:(@~ Normal.(μ, σ))) === nothing
end

@testset "vectorize" begin
dist = LKJCholesky(2, 1)
x = rand(dist)
@test vectorize(dist, x) == vec(x.UL)
end
end

0 comments on commit fff3bd1

Please sign in to comment.