Add Trace without Turing #17
```diff
@@ -1,6 +1,10 @@
module AdvancedPS

import Distributions
import Libtask
import StatsFuns

include("resampling.jl")
include("container.jl")

end
```
```diff
@@ -1,81 +1,79 @@
-mutable struct Trace{Tspl<:AbstractSampler, Tvi<:AbstractVarInfo, Tmodel<:Model}
-    model::Tmodel
-    spl::Tspl
-    vi::Tvi
-    ctask::CTask
-
-    function Trace{SampleFromPrior}(model::Model, spl::AbstractSampler, vi::AbstractVarInfo)
-        return new{SampleFromPrior,typeof(vi),typeof(model)}(model, SampleFromPrior(), vi)
-    end
-    function Trace{S}(model::Model, spl::S, vi::AbstractVarInfo) where S<:Sampler
-        return new{S,typeof(vi),typeof(model)}(model, spl, vi)
-    end
-end
-
-function Base.copy(trace::Trace)
-    vi = deepcopy(trace.vi)
-    res = Trace{typeof(trace.spl)}(trace.model, trace.spl, vi)
-    res.ctask = copy(trace.ctask)
-    return res
-end
+struct Trace{F}
+    f::F
+    ctask::Libtask.CTask
+end
+
+const Particle = Trace
```
Member: The purpose of differentiating …

Member (author): Sounds good 👍
```diff
-# NOTE: this function is called by `forkr`
-function Trace(f, m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
-    res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
-    ctask = CTask() do
-        res = f()
-        produce(nothing)
-        return res
-    end
-    task = ctask.task
-    if task.storage === nothing
-        task.storage = IdDict()
-    end
-    task.storage[:turing_trace] = res # create a backward reference in task_local_storage
-    res.ctask = ctask
-    return res
-end
-
-function Trace(m::Model, spl::AbstractSampler, vi::AbstractVarInfo)
-    res = Trace{typeof(spl)}(m, spl, deepcopy(vi))
-    reset_num_produce!(res.vi)
-    ctask = CTask() do
-        res = m(vi, spl)
-        produce(nothing)
-        return res
-    end
-    task = ctask.task
-    if task.storage === nothing
-        task.storage = IdDict()
-    end
-    task.storage[:turing_trace] = res # create a backward reference in task_local_storage
-    res.ctask = ctask
-    return res
-end
+function Trace(f)
+    ctask = let f=f
+        Libtask.CTask() do
+            res = f()
+            Libtask.produce(nothing)
+            return res
+        end
+    end
+
+    # add backward reference
+    newtrace = Trace(f, ctask)
```
Member: It is slightly unclear how … EDIT: I found the answer in the companion PR for Turing.

Member (author): Yes, the …
```diff
+    addreference!(ctask.task, newtrace)
```
Member: Good idea to wrap these into a new function - it makes the code more readable!
|
|
||
| return newtrace | ||
| end | ||
```diff
-# step to the next observe statement, return log likelihood
-Libtask.consume(t::Trace) = (increment_num_produce!(t.vi); consume(t.ctask))
+Base.copy(trace::Trace) = Trace(trace.f, copy(trace.ctask))
+
+# step to the next observe statement and
+# return the log probability of the transition (or nothing if done)
+advance!(t::Trace) = Libtask.consume(t.ctask)
```
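With this interface a particle no longer needs a `Model`, sampler, and `VarInfo`: any callable that reports log-likelihood terms via `Libtask.produce` can be wrapped. A minimal sketch under that reading (`toy_model` is invented for illustration and is not part of the PR):

```julia
import AdvancedPS
import Libtask

# Hypothetical model: three latent transitions, each reporting a
# log-likelihood increment through `Libtask.produce`.
function toy_model()
    x = 0.0
    for _ in 1:3
        x += randn()                   # sample a latent transition
        Libtask.produce(-abs2(x) / 2)  # report a log-likelihood term
    end
end

trace = AdvancedPS.Trace(toy_model)
logweight = AdvancedPS.advance!(trace)  # run until the first `produce`
```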
Member: It's sensible to use …

Member (author): Yes, I agree. But I noticed that this would require additional changes since we have to propagate the state of the iterators in some way, if we do not assume that they are mutating. Hence I did not switch to the interface in this PR.
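For reference, a sketch of what the iteration-interface version discussed above might look like, assuming (as the author notes) that advancing mutates the trace; this is not part of the PR:

```julia
# Hypothetical: expose a Trace as a Julia iterator over log-probabilities.
# Advancing mutates the underlying task, so the iteration state is trivial.
function Base.iterate(t::AdvancedPS.Trace, state=nothing)
    logprob = AdvancedPS.advance!(t)
    return logprob === nothing ? nothing : (logprob, nothing)
end
```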
```diff
+# reset log probability
+reset_logprob!(t::Trace) = nothing
+
+reset_model(f) = nothing
+delete_retained!(f) = nothing
```
Member: As a note for future, I think these types and abstract functions for …

Member (author): That's a nice goal but I think it needs more time. I am not completely convinced that the methods and abstractions in this PR are the best ones for future development, so I guess it might be useful to stabilize some API in AdvancedPS first 🙂

Member: Yeah, let's just keep an eye on what kinds of functions end up being useful to see where generalizations are possible.
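These no-op fallbacks are exactly the hooks a downstream package would overload for its own model type, as the companion Turing PR does. A hypothetical sketch (`MyTracedModel` is invented here):

```julia
import AdvancedPS

# Hypothetical downstream wrapper around some model object `m`.
struct MyTracedModel{M}
    m::M
end

(model::MyTracedModel)() = nothing  # run the wrapped model (stubbed here)

# Rebuild the model with fresh state for the reference particle in `forkr`.
AdvancedPS.reset_model(f::MyTracedModel) = deepcopy(f)

# Remove retained randomness when forking a reference particle.
AdvancedPS.delete_retained!(f::MyTracedModel) = nothing
```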
```diff
 # Task copying version of fork for Trace.
-function fork(trace :: Trace, is_ref :: Bool = false)
+function fork(trace::Trace, isref::Bool = false)
     newtrace = copy(trace)
-    is_ref && set_retained_vns_del_by_spl!(newtrace.vi, newtrace.spl)
-    newtrace.ctask.task.storage[:turing_trace] = newtrace
+    isref && delete_retained!(newtrace.f)
+
+    # add backward reference
+    addreference!(newtrace.ctask.task, newtrace)
+
     return newtrace
 end
```
```diff
 # PG requires keeping all randomness for the reference particle
 # Create new task and copy randomness
 function forkr(trace::Trace)
-    newtrace = Trace(trace.ctask.task.code, trace.model, trace.spl, deepcopy(trace.vi))
-    newtrace.spl = trace.spl
-    reset_num_produce!(newtrace.vi)
+    newf = reset_model(trace.f)
+    ctask = let f=trace.ctask.task.code
+        Libtask.CTask() do
+            res = f()
+            Libtask.produce(nothing)
+            return res
+        end
+    end
+
+    # add backward reference
+    newtrace = Trace(newf, ctask)
+    addreference!(ctask.task, newtrace)
+
     return newtrace
 end
```
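A usage sketch of the two forking modes, reusing the invented `toy_model` from the sketch above:

```julia
t = AdvancedPS.Trace(toy_model)
AdvancedPS.advance!(t)      # run to the first observation

child = AdvancedPS.fork(t)  # task copy: continues where `t` currently is
ref = AdvancedPS.forkr(t)   # fresh task: replays the program from the start
```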
```diff
-current_trace() = current_task().storage[:turing_trace]
-
-const Particle = Trace
+# create a backward reference in task_local_storage
+function addreference!(task::Task, trace::Trace)
+    if task.storage === nothing
+        task.storage = IdDict()
+    end
+    task.storage[:__trace] = trace
+
+    return task
+end
+
+current_trace() = current_task().storage[:__trace]
```
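The backward reference makes the owning particle reachable from inside the model itself. A sketch (invented model body) of how `current_trace` can be used:

```julia
# Hypothetical model that inspects its own particle via the task-local
# backward reference installed by `addreference!`.
function introspective_model()
    particle = AdvancedPS.current_trace()  # the Trace wrapping this task
    @assert particle isa AdvancedPS.Trace
    Libtask.produce(0.0)
end
```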
| """ | ||
| Data structure for particle filters | ||
|
|
@@ -141,7 +139,7 @@ end | |
|
|
||
| Compute the normalized weights of the particles. | ||
| """ | ||
| getweights(pc::ParticleContainer) = softmax(pc.logWs) | ||
| getweights(pc::ParticleContainer) = StatsFuns.softmax(pc.logWs) | ||
|
|
||
| """ | ||
| getweight(pc::ParticleContainer, i) | ||
|
|
@@ -155,7 +153,7 @@ getweight(pc::ParticleContainer, i) = exp(pc.logWs[i] - logZ(pc)) | |
|
|
||
| Return the logarithm of the normalizing constant of the unnormalized logarithmic weights. | ||
| """ | ||
| logZ(pc::ParticleContainer) = logsumexp(pc.logWs) | ||
| logZ(pc::ParticleContainer) = StatsFuns.logsumexp(pc.logWs) | ||
|
|
||
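As a numeric sanity check (toy log-weights, not from the PR), `getweights` and `logZ` are consistent: the normalized weights are the exponentiated log-weights shifted by the log normalizing constant.

```julia
import StatsFuns

logWs = [-1.2, -0.5, -2.3]         # toy unnormalized log-weights
w = StatsFuns.softmax(logWs)       # what `getweights` computes
logZ = StatsFuns.logsumexp(logWs)  # what `logZ` computes

@assert w ≈ exp.(logWs .- logZ)    # `getweight(pc, i)` for each i
@assert sum(w) ≈ 1
```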
| """ | ||
| effectiveSampleSize(pc::ParticleContainer) | ||
|
|
@@ -168,7 +166,7 @@ function effectiveSampleSize(pc::ParticleContainer) | |
| end | ||
|
|
||
| """ | ||
| resample_propagate!(pc::ParticleContainer[, randcat = resample_systematic, ref = nothing; | ||
| resample_propagate!(pc::ParticleContainer[, randcat = resample, ref = nothing; | ||
| weights = getweights(pc)]) | ||
|
|
||
| Resample and propagate the particles in `pc`. | ||
|
|
```diff
@@ -179,7 +177,7 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a reference
 """
 function resample_propagate!(
     pc::ParticleContainer,
-    randcat = Turing.Inference.resample_systematic,
+    randcat = resample,
     ref::Union{Particle, Nothing} = nothing;
     weights = getweights(pc)
 )
```
```diff
@@ -231,6 +229,22 @@ function resample_propagate!(
     pc
 end

+function resample_propagate!(
+    pc::ParticleContainer,
+    resampler::ResampleWithESSThreshold,
+    ref::Union{Particle,Nothing} = nothing;
+    weights = getweights(pc)
+)
+    # Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
+    ess = inv(sum(abs2, weights))
+
+    if ess ≤ resampler.threshold * length(pc)
+        resample_propagate!(pc, resampler.resampler, ref; weights = weights)
+    end
+
+    pc
+end

 """
     reweight!(pc::ParticleContainer)
```
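A small numeric illustration of the ESS criterion (toy weights, invented): uniform weights give an effective sample size equal to the number of particles, while degenerate weights push it towards 1 and trigger resampling.

```julia
uniform = fill(0.25, 4)
inv(sum(abs2, uniform))    # 4.0: no resampling for any threshold < 1

degenerate = [0.97, 0.01, 0.01, 0.01]
inv(sum(abs2, degenerate)) # ≈ 1.06 ≤ 0.5 * 4, so a threshold of 0.5 resamples
```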
```diff
@@ -249,19 +263,18 @@ function reweight!(pc)
         # the execution of the model is finished.
         # Here ``yᵢ`` are observations, ``xᵢ`` variables of the particle filter, and
         # ``θᵢ`` are variables of other samplers.
-        score = Libtask.consume(p)
+        score = advance!(p)

         if score === nothing
             numdone += 1
         else
-            # Increase the unnormalized logarithmic weights, accounting for the variables
-            # of other samplers.
-            increase_logweight!(pc, i, score + getlogp(p.vi))
+            # Increase the unnormalized logarithmic weights.
+            increase_logweight!(pc, i, score)
```
Member: IIRC, … EDIT: Found the answer in Turing's companion PR. Maybe consider adding a comment here to avoid future confusion.

Member (author): The main motivation for this change is that the setup felt too Turing-specific - it seemed like it should be possible to just implement a model where …
```diff
-            # Reset the accumulator of the log probability in the model so that we can
-            # accumulate log probabilities of variables of other samplers until the next
-            # observation.
-            resetlogp!(p.vi)
+            reset_logprob!(p)
         end
     end
```
```diff
@@ -333,19 +346,3 @@ function sweep!(pc::ParticleContainer, resampler)

     return logevidence
 end
-
-function resample_propagate!(
-    pc::ParticleContainer,
-    resampler::ResampleWithESSThreshold,
-    ref::Union{Particle,Nothing} = nothing;
-    weights = getweights(pc)
-)
-    # Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
-    ess = inv(sum(abs2, weights))
-
-    if ess ≤ resampler.threshold * length(pc)
-        resample_propagate!(pc, resampler.resampler, ref; weights = weights)
-    end
-
-    pc
-end
```
```diff
@@ -1,5 +1,7 @@
 [deps]
+Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [compat]
+Libtask = "0.5"
 julia = "1.3"
```
Member: Clever trick to unify `Function` and `TracedModel`! Maybe consider making `TracedModel` and `TracedFun` subtypes of `AbstractMCMC.AbstractModel`? It slightly improves clarity, and also prevents potential misuses. https://github.com/TuringLang/AbstractMCMC.jl/blob/master/src/AbstractMCMC.jl#L48

Member (author): Yes, maybe that could be done. On the other hand, such type constraints can be quite restrictive and it didn't seem necessary to introduce them at this stage. Also maybe even `FixedModel` would be more appropriate since `TracedModel` wraps a `Model` and a specific sampler and `VarInfo` object (I guess it should also include the RNG, but this is not part of the current implementation and hence I did not want to change it in these PRs).