Skip to content

Commit

Permalink
Merge 87b1d2f into 487a943
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Apr 17, 2021
2 parents 487a943 + 87b1d2f commit c997045
Show file tree
Hide file tree
Showing 8 changed files with 210 additions and 433 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ authors = ["TuringLang"]
version = "0.2.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"

[compat]
AbstractMCMC = "2"
Distributions = "0.23, 0.24"
Libtask = "0.5"
StatsFuns = "0.9"
Expand Down
3 changes: 3 additions & 0 deletions src/AdvancedPS.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
module AdvancedPS

import AbstractMCMC
import Distributions
import Libtask
import Random
import StatsFuns

include("resampling.jl")
include("container.jl")
include("smc.jl")
include("model.jl")

end
13 changes: 9 additions & 4 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ advance!(t::Trace) = Libtask.consume(t.ctask)
# reset log probability
reset_logprob!(t::Trace) = nothing

reset_model(f) = nothing
reset_model(f) = deepcopy(f)
delete_retained!(f) = nothing

# Task copying version of fork for Trace.
Expand Down Expand Up @@ -59,7 +59,7 @@ function forkr(trace::Trace)
# add backward reference
newtrace = Trace(newf, ctask)
addreference!(ctask.task, newtrace)

return newtrace
end

Expand Down Expand Up @@ -96,6 +96,11 @@ Base.collect(pc::ParticleContainer) = pc.vals
Base.length(pc::ParticleContainer) = length(pc.vals)
Base.@propagate_inbounds Base.getindex(pc::ParticleContainer, i::Int) = pc.vals[i]

function Base.rand(rng::Random.AbstractRNG, pc::ParticleContainer)
index = randcat(rng, getweights(pc))
return pc[index]
end

# registers a new x-particle in the container
function Base.push!(pc::ParticleContainer, p::Particle)
push!(pc.vals, p)
Expand Down Expand Up @@ -319,7 +324,7 @@ function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler)
logZ0 = logZ(pc)

# Reweight the particles by including the first observation ``y₁``.
isdone = reweight!(rng, pc)
isdone = reweight!(pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand All @@ -337,7 +342,7 @@ function sweep!(rng::Random.AbstractRNG, pc::ParticleContainer, resampler)
logZ0 = logZ(pc)

# Reweight the particles by including the next observation ``yₜ``.
isdone = reweight!(rng, pc)
isdone = reweight!(pc)

# Compute the normalizing constant ``Z₁`` after reweighting.
logZ1 = logZ(pc)
Expand Down
8 changes: 8 additions & 0 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
"""
observe(dist::Distribution, x)
Observe sample `x` from distribution `dist` and yield its log-likelihood value.
"""
function observe(dist::Distributions.Distribution, x)
return Libtask.produce(Distributions.loglikelihood(dist, x))
end

0 comments on commit c997045

Please sign in to comment.