4 changes: 4 additions & 0 deletions Project.toml
@@ -5,7 +5,11 @@ version = "0.1.0"

[deps]
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
Distributions = "0.23, 0.24"
Libtask = "0.5"
StatsFuns = "0.9"
julia = "1.3"
4 changes: 4 additions & 0 deletions src/AdvancedPS.jl
@@ -1,6 +1,10 @@
module AdvancedPS

import Distributions
import Libtask
import StatsFuns

include("resampling.jl")
include("container.jl")

end
163 changes: 80 additions & 83 deletions src/container.jl
@@ -1,81 +1,79 @@
mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model}
model::Tmodel
spl::Tspl
vi::Tvi
ctask::CTask

function Trace{SampleFromPrior}(model::Model, spl::AbstractSampler, vi::AbstractVarInfo)
return new{SampleFromPrior,typeof(vi),typeof(model)}(model, SampleFromPrior(), vi)
end
function Trace{S}(model::Model, spl::S, vi::AbstractVarInfo) where S<:Sampler
return new{S,typeof(vi),typeof(model)}(model, spl, vi)
end
struct Trace{F}
f::F
Member: Clever trick to unify Function and TracedModel!

Maybe consider making TracedModel and TracedFun subtypes of AbstractMCMC.AbstractModel? It slightly improves clarity and also prevents potential misuse.

https://github.com/TuringLang/AbstractMCMC.jl/blob/master/src/AbstractMCMC.jl#L48

Member Author: Yes, maybe that could be done; a sketch follows the struct below. On the other hand, such type constraints can be quite restrictive, and it did not seem necessary to introduce them at this stage. Also, maybe even FixedModel would be more appropriate, since TracedModel wraps a Model together with a specific sampler and VarInfo object. (It should arguably also include the RNG, but that is not part of the current implementation, and hence I did not want to change it in these PRs.)

ctask::Libtask.CTask
end
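For illustration, a minimal sketch of the subtyping suggested in the thread above, assuming AbstractMCMC were added as a dependency; the TracedFunction name and its field are assumptions for this sketch, not part of the PR:

import AbstractMCMC

# Hypothetical wrapper that turns a plain closure into an AbstractMCMC model;
# a TracedModel wrapper around a DynamicPPL model could follow the same pattern.
struct TracedFunction{F} <: AbstractMCMC.AbstractModel
    f::F
end

# Calling the wrapper just runs the wrapped function.
(tf::TracedFunction)() = tf.f()

With such wrappers, the struct field could be tightened to f::AbstractMCMC.AbstractModel, rejecting accidental misuse at construction time.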

function Base.copy(trace::Trace)
vi = deepcopy(trace.vi)
res = Trace{typeof(trace.spl)}(trace.model, trace.spl, vi)
res.ctask = copy(trace.ctask)
return res
end
const Particle = Trace
Member: The purpose of differentiating Particle and Trace no longer holds since AdvancedPS is independent of Turing. We can probably clean up the terminology, e.g., by getting rid of Trace and using Particle everywhere. But this can be done in a separate PR, as you suggested.

Member Author: Sounds good 👍


# NOTE: this function is called by `forkr`
function Trace(f, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
ctask = CTask() do
res = f()
produce(nothing)
return res
end
task = ctask.task
if task.storage === nothing
task.storage = IdDict()
function Trace(f)
ctask = let f=f
Libtask.CTask() do
res = f()
Libtask.produce(nothing)
return res
end
end
task.storage[:turing_trace] = res # create a backward reference in task_local_storage
res.ctask = ctask
return res
end

function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
reset_num_produce!(res.vi)
ctask = CTask() do
res = m(vi, spl)
produce(nothing)
return res
end
task = ctask.task
if task.storage === nothing
task.storage = IdDict()
end
task.storage[:turing_trace] = res # create a backward reference in task_local_storage
res.ctask = ctask
return res
# add backward reference
newtrace = Trace(f, ctask)
Member: It is slightly unclear how deepcopy(vi) is performed for f::TracedModel here.

EDIT: I found the answer in the companion PR for Turing.
Ref: https://github.com/TuringLang/Turing.jl/pull/1482/files#diff-f2a243e03d83fd80c6af74a78557a37eb7bcf265c509d29b9faa7ec19873b0c5R20

Member Author: Yes, the deepcopy(vi) parts are a bit annoying and prone to bugs. They are all handled in the Turing PR, and the behaviour should be exactly the same as in the current implementation. However, I noticed some places where operations seem redundant (e.g., a deepcopy of a deepcopy, or multiple calls of reset_num_produce!), but it was difficult to figure out during the split where things could be dropped, so I kept everything for now.

addreference!(ctask.task, newtrace)
Member: Good idea to wrap these into a new function - it makes the code more readable!


return newtrace
end

# step to the next observe statement, return log likelihood
Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.ctask))
Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))

# step to the next observe statement and
# return the log probability of the transition (or nothing if done)
advance!(t::Trace) = Libtask.consume(t.ctask)
Member: It's sensible to use advance! here. But if we are going to support the Iterator interface, shouldn't it be iterate? See https://docs.julialang.org/en/v1/manual/interfaces/

Member Author: Yes, I agree. But I noticed that this would require additional changes, since we would have to propagate the state of the iterators in some way if we do not assume that they are mutating. Hence I did not switch to the interface in this PR; a sketch of what it could look like follows.
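As a hedged sketch, an iterate-based wrapper around advance! could look as follows if the mutating behaviour is accepted and the iteration state is therefore trivial; this is not part of the PR:

# Hypothetical sketch: each iteration advances the trace to the next
# observation and yields its log probability; iteration stops once the
# underlying task finishes and advance! returns nothing.
function Base.iterate(t::Trace, state = nothing)
    logp = advance!(t)
    return logp === nothing ? nothing : (logp, nothing)
end

# The number of observations is not known up front.
Base.IteratorSize(::Type{<:Trace}) = Base.SizeUnknown()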


# reset log probability
reset_logprob!(t::Trace) = nothing

reset_model(f) = nothing
delete_retained!(f) = nothing
Member: As a note for the future, I think these types and abstract functions for Trace and VarInfo should go into a lightweight package AbstractPPL. Then we could gradually switch to AbstractPPL in DynamicPPL and talk with other PPL library developers (e.g., Birch, Soss, Gen) to create a minimal PPL base package in Julia (and beyond!).

@cpfiffer @phipsgabler

Member Author: That's a nice goal, but I think it needs more time. I am not completely convinced that the methods and abstractions in this PR are the best ones for future development, so it might be useful to stabilize some API in AdvancedPS first 🙂

Member: Yeah, let's just keep an eye on what kinds of functions end up being useful to see where generalizations are possible.


# Task copying version of fork for Trace.
function fork(trace :: Trace, is_ref :: Bool = false)
function fork(trace::Trace, isref::Bool = false)
newtrace = copy(trace)
is_ref && set_retained_vns_del_by_spl!(newtrace.vi, newtrace.spl)
newtrace.ctask.task.storage[:turing_trace] = newtrace
isref && delete_retained!(newtrace.f)

# add backward reference
addreference!(newtrace.ctask.task, newtrace)

return newtrace
end

# PG requires keeping all randomness for the reference particle
# Create new task and copy randomness
function forkr(trace::Trace)
newtrace = Trace(trace.ctask.task.code, trace.model, trace.spl, deepcopy(trace.vi))
newtrace.spl = trace.spl
reset_num_produce!(newtrace.vi)
newf = reset_model(trace.f)
ctask = let f=trace.ctask.task.code
Libtask.CTask() do
res = f()
Libtask.produce(nothing)
return res
end
end

# add backward reference
newtrace = Trace(newf, ctask)
addreference!(ctask.task, newtrace)

return newtrace
end
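For orientation, a hedged sketch of how fork and forkr are typically combined in a Particle Gibbs sweep; the variable names here are made up for illustration:

# Hypothetical usage: the reference particle must replay all of its recorded
# randomness, so it is rebuilt from scratch with forkr; other particles are
# cheap task copies obtained with fork.
reference = forkr(trace_from_previous_sweep)
child = fork(reference, true)  # isref = true discards retained randomness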

current_trace() = current_task().storage[:turing_trace]

const Particle = Trace

# create a backward reference in task_local_storage
function addreference!(task::Task, trace::Trace)
    if task.storage === nothing
        task.storage = IdDict()
    end
    task.storage[:__trace] = trace
    return task
end

current_trace() = current_task().storage[:__trace]

"""
Data structure for particle filters
@@ -141,7 +139,7 @@ end

Compute the normalized weights of the particles.
"""
getweights(pc::ParticleContainer) = softmax(pc.logWs)
getweights(pc::ParticleContainer) = StatsFuns.softmax(pc.logWs)

"""
getweight(pc::ParticleContainer, i)
@@ -155,7 +153,7 @@ getweight(pc::ParticleContainer, i) = exp(pc.logWs[i] - logZ(pc))

Return the logarithm of the normalizing constant of the unnormalized logarithmic weights.
"""
logZ(pc::ParticleContainer) = logsumexp(pc.logWs)
logZ(pc::ParticleContainer) = StatsFuns.logsumexp(pc.logWs)
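To make the relation between these helpers concrete, a small sketch using only the StatsFuns functions imported above; the numbers are made up:

# Sketch: normalizing log-weights in a numerically stable way.
logWs = [-1.2, -0.5, -2.0]          # unnormalized log-weights
logz = StatsFuns.logsumexp(logWs)   # log of the normalizing constant
weights = exp.(logWs .- logz)       # normalized weights
# agrees (up to floating-point error) with StatsFuns.softmax(logWs)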

"""
effectiveSampleSize(pc::ParticleContainer)
@@ -168,7 +166,7 @@ function effectiveSampleSize(pc::ParticleContainer)
end

"""
resample_propagate!(pc::ParticleContainer[, randcat = resample_systematic, ref = nothing;
resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing;
weights = getweights(pc)])

Resample and propagate the particles in `pc`.
@@ -179,7 +177,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a reference
"""
function resample_propagate!(
pc::ParticleContainer,
randcat = Turing.Inference.resample_systematic,
randcat = resample,
ref::Union{Particle, Nothing} = nothing;
weights = getweights(pc)
)
@@ -231,6 +229,22 @@
pc
end

function resample_propagate!(
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
weights = getweights(pc)
)
# Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
ess = inv(sum(abs2, weights))

if ess ≤ resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
end

pc
end
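A quick worked check of the threshold logic above, with numbers made up for illustration:

# Sketch: with 4 particles and normalized weights [0.7, 0.1, 0.1, 0.1],
# ess = 1 / (0.7^2 + 3 * 0.1^2) = 1 / 0.52 ≈ 1.92.
# The default threshold is 0.5, so resampling triggers: 1.92 ≤ 0.5 * 4 = 2.
weights = [0.7, 0.1, 0.1, 0.1]
ess = inv(sum(abs2, weights))  # ≈ 1.92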

"""
reweight!(pc::ParticleContainer)

@@ -249,19 +263,18 @@ function reweight!(pc::ParticleContainer)
# the execution of the model is finished.
# Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
# ``θᵢ`` are variables of other samplers.
score = Libtask.consume(p)
score = advance!(p)

if score === nothing
numdone += 1
else
# Increase the unnormalized logarithmic weights, accounting for the variables
# of other samplers.
increase_logweight!(pc, i, score + getlogp(p.vi))
# Increase the unnormalized logarithmic weights.
increase_logweight!(pc, i, score)
Member: IIRC, score + getlogp(p.vi) is necessary for Gibbs sampling to work, since some variables might be updated outside particle Gibbs. I am slightly confused by how the new mechanism handles such cases.

EDIT: Found the answer in Turing's companion PR. Maybe consider adding a comment here to avoid future confusion.
Ref: https://github.com/TuringLang/Turing.jl/pull/1482/files#diff-f2a243e03d83fd80c6af74a78557a37eb7bcf265c509d29b9faa7ec19873b0c5R24

Member Author: The main motivation for this change is that the setup felt too Turing-specific - it seemed like it should be possible to implement a model where produce yields the correct log probabilities right away; a sketch of such a model follows this function. Also, getlogp would introduce a dependency on DynamicPPL.


# Reset the accumulator of the log probability in the model so that we can
# accumulate log probabilities of variables of other samplers until the next
# observation.
resetlogp!(p.vi)
reset_logprob!(p)
end
end
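A hedged sketch of such a model, with made-up distributions and observations, showing how produce can yield each observation's log-likelihood directly:

# Hypothetical toy model for illustration only: each observation's
# log-likelihood is handed to the particle filter via Libtask.produce,
# so no getlogp or DynamicPPL machinery is needed.
function toy_model()
    x = rand(Distributions.Normal())
    for y in (0.1, -0.3)
        Libtask.produce(Distributions.logpdf(Distributions.Normal(x, 1.0), y))
    end
end

trace = Trace(toy_model)
advance!(trace)  # log-likelihood of the first observation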

@@ -333,19 +346,3 @@ function sweep!(pc::ParticleContainer, resampler)

return logevidence
end

function resample_propagate!(
pc::ParticleContainer,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing} = nothing;
weights = getweights(pc)
)
# Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
ess = inv(sum(abs2, weights))

if ess ≤ resampler.threshold * length(pc)
resample_propagate!(pc, resampler.resampler, ref; weights = weights)
end

pc
end
22 changes: 13 additions & 9 deletions src/resampling.jl
@@ -16,11 +16,6 @@ function ResampleWithESSThreshold(resampler = resample)
ResampleWithESSThreshold(resampler, 0.5)
end

# Default resampling scheme
function resample(w::AbstractVector{<:Real}, num_particles::Integer=length(w))
return resample_systematic(w, num_particles)
end

# More stable, faster version of rand(Categorical)
function randcat(p::AbstractVector{<:Real})
T = eltype(p)
@@ -36,11 +31,17 @@ function randcat(p::AbstractVector{<:Real})
return s
end

function resample_multinomial(w::AbstractVector{<:Real}, num_particles::Integer)
function resample_multinomial(
w::AbstractVector{<:Real},
num_particles::Integer = length(w),
)
return rand(Distributions.sampler(Distributions.Categorical(w)), num_particles)
end

function resample_residual(w::AbstractVector{<:Real}, num_particles::Integer)
function resample_residual(
w::AbstractVector{<:Real},
num_particles::Integer = length(w),
)
# Pre-allocate array for resampled particles
indices = Vector{Int}(undef, num_particles)

@@ -79,7 +80,7 @@ are selected according to the multinomial distribution defined by the normalized
i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer)
function resample_stratified(weights::AbstractVector{<:Real}, n::Integer = length(weights))
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
@@ -124,7 +125,7 @@ numbers `u₁`, ..., `uₙ` where ``uₖ = (u + k − 1) / n``. Based on these n
normalized `weights`, i.e., `xᵢ = j` if and only if
``uᵢ \\in [\\sum_{s=1}^{j-1} weights_{s}, \\sum_{s=1}^{j} weights_{s})``.
"""
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer)
function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))
# check input
m = length(weights)
m > 0 || error("weight vector is empty")
@@ -157,3 +158,6 @@ function resample_systematic(weights::AbstractVector{<:Real}, n::Integer = length(weights))

return samples
end

# Default resampling scheme
const resample = resample_systematic
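A brief usage sketch of the schemes defined in this file; the weight vector is made up:

# Sketch: every scheme maps normalized weights to a vector of particle indices.
w = [0.5, 0.3, 0.2]
idx = resample(w)                  # systematic resampling, the default
idx6 = resample_multinomial(w, 6)  # six multinomial draws instead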
4 changes: 3 additions & 1 deletion test/Project.toml
@@ -1,5 +1,7 @@
[deps]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
julia = "1.3"
Libtask = "0.5"
julia = "1.3"