Skip to content

Commit

Permalink
Fixes for SimpleVarInfo with Ref (#527)
Browse files Browse the repository at this point in the history
* added missing getlogp impl for SimpleVarInfo with Ref

* included SimpleVarInfo with Ref in the TestUtils.setup_varinfos

* bump patch version

* moved impls of acclogp!! and setlogp!! for SimpleVarInfo next to each other
  • Loading branch information
torfjelde committed Sep 1, 2023
1 parent 549d9b1 commit e2178c6
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.23.14"
version = "0.23.15"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
18 changes: 10 additions & 8 deletions src/simple_varinfo.jl
Expand Up @@ -259,17 +259,11 @@ end
Base.isempty(vi::SimpleVarInfo) = isempty(vi.values)

getlogp(vi::SimpleVarInfo) = vi.logp
getlogp(vi::SimpleVarInfo{<:Any,<:Ref}) = vi.logp[]

setlogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = logp
acclogp!!(vi::SimpleVarInfo, logp) = Setfield.@set vi.logp = getlogp(vi) + logp

"""
keys(vi::SimpleVarInfo)
Return an iterator of keys present in `vi`.
"""
Base.keys(vi::SimpleVarInfo) = keys(vi.values)
Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values))

function setlogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
vi.logp[] = logp
return vi
Expand All @@ -280,6 +274,14 @@ function acclogp!!(vi::SimpleVarInfo{<:Any,<:Ref}, logp)
return vi
end

"""
keys(vi::SimpleVarInfo)
Return an iterator of keys present in `vi`.
"""
Base.keys(vi::SimpleVarInfo) = keys(vi.values)
Base.keys(vi::SimpleVarInfo{<:NamedTuple}) = map(k -> VarName{k}(), keys(vi.values))

function Base.show(io::IO, ::MIME"text/plain", svi::SimpleVarInfo)
if !(svi.transformation isa NoTransformation)
print(io, "Transformed ")
Expand Down
12 changes: 9 additions & 3 deletions src/test_utils.jl
Expand Up @@ -43,16 +43,22 @@ Return a tuple of instances for different implementations of `AbstractVarInfo` w
each `vi`, supposedly, satisfying `vi[vn] == get(example_values, vn)` for `vn` in `varnames`.
"""
function setup_varinfos(model::Model, example_values::NamedTuple, varnames)
# <:VarInfo
# VarInfo
vi_untyped = VarInfo()
model(vi_untyped)
vi_typed = DynamicPPL.TypedVarInfo(vi_untyped)
# <:SimpleVarInfo
# SimpleVarInfo
svi_typed = SimpleVarInfo(example_values)
svi_untyped = SimpleVarInfo(OrderedDict())

# SimpleVarInfo{<:Any,<:Ref}
svi_typed_ref = SimpleVarInfo(example_values, Ref(getlogp(svi_typed)))
svi_untyped_ref = SimpleVarInfo(OrderedDict(), Ref(getlogp(svi_untyped)))

lp = getlogp(vi_typed)
return map((vi_untyped, vi_typed, svi_typed, svi_untyped)) do vi
return map((
vi_untyped, vi_typed, svi_typed, svi_untyped, svi_typed_ref, svi_untyped_ref
)) do vi
# Set them all to the same values.
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
end
Expand Down

0 comments on commit e2178c6

Please sign in to comment.