Skip to content

Commit

Permalink
Bugfix in VarInfo. (#516)
Browse files Browse the repository at this point in the history
* Bugfix in `VarInfo`.

* Update Project.toml

* Update src/varinfo.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Update src/varinfo.jl

Co-authored-by: David Widmann <devmotion@users.noreply.github.com>

* Added tests.

* Apply suggestions from code review

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* Fix for tests in #516 (#517)

* fixed tests for linking of dirichlet with different dimensionality

* added usage of same logp in TestUtils.setup_varinfos

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: David Widmann <devmotion@users.noreply.github.com>
Co-authored-by: Tor Erlend Fjelde <tor.erlend95@gmail.com>
  • Loading branch information
4 people committed Aug 9, 2023
1 parent 4a986df commit 7ef5da7
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 17 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.12"
version = "0.23.13"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
3 changes: 2 additions & 1 deletion src/test_utils.jl
Expand Up @@ -51,9 +51,10 @@ function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
svi_typed = SimpleVarInfo(example_values)
svi_untyped = SimpleVarInfo(OrderedDict())

lp = getlogp(vi_typed)
return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
# Set them all to the same values.
update_values!!(vi, example_values, varnames)
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end
end

Expand Down
2 changes: 1 addition & 1 deletion src/varinfo.jl
Expand Up @@ -366,7 +366,7 @@ setall!(vi::TypedVarInfo, val) = _setall!(vi.metadata, val)
for f in names
length = :(sum(length, metadata.$f.ranges))
finish = :($start + $length - 1)
push!(expr.args, :(metadata.$f.vals .= val[($start):($finish)]))
push!(expr.args, :(copyto!(metadata.$f.vals, 1, val, $start, $length)))
start = :($start + $length)
end
return expr
Expand Down
33 changes: 19 additions & 14 deletions test/linking.jl
Expand Up @@ -91,21 +91,26 @@ end
end
end

# Related: https://github.com/TuringLang/DynamicPPL.jl/issues/504
@testset "dirichlet" begin
@model demo_dirichlet() = x ~ Dirichlet(2, 1.0)
model = demo_dirichlet()
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
@test length(vi[:]) == 2
@test iszero(getlogp(vi))
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == 1
@test !iszero(getlogp(vi_linked)) # should now include the log-absdet-jacobian correction
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == 2
@test iszero(getlogp(vi_invlinked))
@model demo_dirichlet(d::Int) = x ~ Dirichlet(d, 1.0)
@testset "d=$d" for d in [2, 3, 5]
model = demo_dirichlet(d)
vis = DynamicPPL.TestUtils.setup_varinfos(model, rand(model), (@varname(x),))
@testset "$(short_varinfo_name(vi))" for vi in vis
lp = logpdf(Dirichlet(d, 1.0), vi[:])
@test length(vi[:]) == d
@test getlogp(vi) lp
# Linked.
vi_linked = DynamicPPL.link!!(deepcopy(vi), model)
@test length(vi_linked[:]) == d - 1
# Should now include the log-absdet-jacobian correction.
@test !(getlogp(vi_linked) lp)
# Invlinked.
vi_invlinked = DynamicPPL.invlink!!(deepcopy(vi_linked), model)
@test length(vi_invlinked[:]) == d
@test getlogp(vi_invlinked) lp
end
end
end
end

2 comments on commit 7ef5da7

@torfjelde
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/89311

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.23.13 -m "<description of version>" 7ef5da709564802fc2ccae182e0f77d1b7af5958
git push origin v0.23.13

Please sign in to comment.