Skip to content

Commit

Permalink
Add acclogp_assume!! and acclogp_observe!! (#565)
Browse files Browse the repository at this point in the history
* add hooks for acclogp!! depending on whether it's from an `assume` or
`observe` statement

* bump patch version

* Update src/context_implementations.jl

* Update Project.toml

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
torfjelde and yebai committed Nov 22, 2023
1 parent b52e4c2 commit 58bc18e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
3 changes: 2 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.24.3"

version = "0.24.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
20 changes: 15 additions & 5 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,15 @@ alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
require_gradient(spl::Sampler) = false
require_particles(spl::Sampler) = false

# Allows samplers, etc. to hook into the final logp accumulation in the tilde-pipeline.
function acclogp_assume!!(context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp!!(context, vi, logp)
end

function acclogp_observe!!(context::AbstractContext, vi::AbstractVarInfo, logp)
return acclogp!!(context, vi, logp)
end

# assume
"""
tilde_assume(context::SamplingContext, right, vn, vi)
Expand Down Expand Up @@ -115,7 +124,7 @@ probability of `vi` with the returned value.
"""
function tilde_assume!!(context, right, vn, vi)
value, logp, vi = tilde_assume(context, right, vn, vi)
return value, acclogp!!(context, vi, logp)
return value, acclogp_assume!!(context, vi, logp)
end

# observe
Expand Down Expand Up @@ -181,7 +190,7 @@ probability of `vi` with the returned value.
"""
function tilde_observe!!(context, right, left, vi)
logp, vi = tilde_observe(context, right, left, vi)
return left, acclogp!!(context, vi, logp)
return left, acclogp_observe!!(context, vi, logp)
end

function assume(rng, spl::Sampler, dist)
Expand Down Expand Up @@ -383,7 +392,7 @@ Falls back to `dot_tilde_assume(context, right, left, vn, vi)`.
"""
function dot_tilde_assume!!(context, right, left, vn, vi)
value, logp, vi = dot_tilde_assume(context, right, left, vn, vi)
return value, acclogp!!(context, vi, logp), vi
return value, acclogp_assume!!(context, vi, logp), vi
end

# `dot_assume`
Expand Down Expand Up @@ -539,7 +548,8 @@ function get_and_set_val!(
if istrans(vi)
push!!.((vi,), vns, reconstruct_and_link.((vi,), vns, dists, r), dists, (spl,))
# NOTE: Need to add the correction.
acclogp!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
# FIXME: This is not great.
acclogp_assume!!(vi, sum(logabsdetjac.(bijector.(dists), r)))
# `push!!` sets the trans-flag to `false` by default.
settrans!!.((vi,), true, vns)
else
Expand Down Expand Up @@ -634,7 +644,7 @@ Falls back to `dot_tilde_observe(context, right, left, vi)`.
"""
function dot_tilde_observe!!(context, right, left, vi)
logp, vi = dot_tilde_observe(context, right, left, vi)
return left, acclogp!!(context, vi, logp)
return left, acclogp_observe!!(context, vi, logp)
end

# Falls back to non-sampler definition.
Expand Down

0 comments on commit 58bc18e

Please sign in to comment.