Skip to content

Commit

Permalink
Merge branch 'master' into levy-ssm
Browse files Browse the repository at this point in the history
  • Loading branch information
yebai committed Feb 23, 2024
2 parents a71a47f + 1e5dfdd commit 3a9a7ba
Show file tree
Hide file tree
Showing 7 changed files with 80 additions and 50 deletions.
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.5.1"
version = "0.5.4"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand All @@ -24,7 +24,8 @@ Libtask = "0.8"
Random123 = "1.3"
Requires = "1.0"
StatsFuns = "0.9, 1"
julia = "1.3"
Random = "1.6"
julia = "1.6"

[extras]
Libtask = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
Expand Down
60 changes: 31 additions & 29 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,34 @@ end
"""
LibtaskModel{F}
State wrapper to hold `Libtask.CTask` model initiated from `f`
State wrapper to hold `Libtask.CTask` model initiated from `f`.
"""
struct LibtaskModel{F1,F2}
f::F1
ctask::Libtask.TapedTask{F2}

LibtaskModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask)
end

function LibtaskModel(f, args...)
return LibtaskModel(
function AdvancedPS.LibtaskModel(
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
) # Changed the API, need to take care of the RNG properly
return AdvancedPS.LibtaskModel(
f,
Libtask.TapedTask(f, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}),
Libtask.TapedTask(
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}
),
)
end

Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask))
"""
copy(model::AdvancedPS.LibtaskModel)
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R}
The task is copied (forked) and the inner model is deepcopied.
"""
function Base.copy(model::AdvancedPS.LibtaskModel)
return AdvancedPS.LibtaskModel(deepcopy(model.f), copy(model.ctask))
end

const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}

function AdvancedPS.Trace(
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
return AdvancedPS.Trace(LibtaskModel(model, args...), rng)
return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng)
end

# step to the next observe statement and
Expand All @@ -56,7 +60,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
end

# create a backward reference in task_local_storage
function addreference!(task::Task, trace::LibtaskTrace)
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
if task.storage === nothing
task.storage = IdDict()
end
Expand All @@ -65,9 +69,7 @@ function addreference!(task::Task, trace::LibtaskTrace)
return task
end

current_trace() = current_task().storage[:__trace]

function update_rng!(trace::LibtaskTrace)
function AdvancedPS.update_rng!(trace::LibtaskTrace)
rng, = trace.model.ctask.args
trace.rng = rng
return trace
Expand All @@ -76,12 +78,12 @@ end
# Task copying version of fork for Trace.
function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
newtrace = copy(trace)
update_rng!(newtrace)
AdvancedPS.update_rng!(newtrace)
isref && AdvancedPS.delete_retained!(newtrace.model.f)
isref && delete_seeds!(newtrace)

# add backward reference
addreference!(newtrace.model.ctask.task, newtrace)
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
end

Expand All @@ -94,11 +96,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
ctask = Libtask.TapedTask(
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)}
)
new_tapedmodel = LibtaskModel(newf, ctask)
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)

# add backward reference
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
addreference!(ctask.task, newtrace)
AdvancedPS.addreference!(ctask.task, newtrace)
AdvancedPS.gen_refseed!(newtrace)
return newtrace
end
Expand Down Expand Up @@ -135,9 +137,8 @@ function AbstractMCMC.step(
AdvancedPS.forkr(copy(state.trajectory))
else
trng = AdvancedPS.TracedRNG()
gen_model = LibtaskModel(deepcopy(model), trng)
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng)
addreference!(gen_model.ctask.task, trace) # Do we need it here ?
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
trace
end
end
Expand Down Expand Up @@ -174,9 +175,8 @@ function AbstractMCMC.sample(

traces = map(1:(sampler.nparticles)) do i
trng = AdvancedPS.TracedRNG()
gen_model = LibtaskModel(deepcopy(model), trng)
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng)
addreference!(gen_model.ctask.task, trace) # Do we need it here ?
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
trace
end

Expand All @@ -202,7 +202,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle)
trng = deepcopy(particle.rng)
Random123.set_counter!(trng.rng, 0)
trng.count = 1
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(particle.model.f), trng), trng)
trace = AdvancedPS.Trace(
AdvancedPS.LibtaskModel(deepcopy(particle.model.f), trng), trng
)
score = AdvancedPS.advance!(trace, true)
while !isnothing(score)
score = AdvancedPS.advance!(trace, true)
Expand Down
3 changes: 0 additions & 3 deletions ext/SSMProblemsExt.jl

This file was deleted.

12 changes: 6 additions & 6 deletions src/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ end
Update reference trajectory. Defaults to `nothing`
"""
function update_ref!(
particle::Trace, pc::ParticleContainer, sampler::AbstractParticleSampler
)
particle::Trace, pc::ParticleContainer, sampler::T
) where {T<:AbstractMCMC.AbstractSampler}
return nothing
end

Expand Down Expand Up @@ -171,11 +171,11 @@ of the particle `weights`. For Particle Gibbs sampling, one can provide a refere
function resample_propagate!(
::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractParticleSampler,
sampler::T,
randcat=DEFAULT_RESAMPLER,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
)
) where {T<:AbstractMCMC.AbstractSampler}
# sample ancestor indices
n = length(pc)
nresamples = ref === nothing ? n : n - 1
Expand Down Expand Up @@ -233,11 +233,11 @@ end
function resample_propagate!(
rng::Random.AbstractRNG,
pc::ParticleContainer,
sampler::AbstractParticleSampler,
sampler::T,
resampler::ResampleWithESSThreshold,
ref::Union{Particle,Nothing}=nothing;
weights=getweights(pc),
)
) where {T<:AbstractMCMC.AbstractSampler}
# Compute the effective sample size ``1 / ∑ wᵢ²`` with normalized weights ``wᵢ``
ess = inv(sum(abs2, weights))

Expand Down
31 changes: 22 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Trace{F,R}
Trace{F,R}
"""
mutable struct Trace{F,R}
model::F
Expand All @@ -10,24 +10,37 @@ const Particle = Trace
const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R}
const GenericTrace{R} = Trace{<:AbstractGenericModel,R}

# reset log probability
reset_logprob!(::AdvancedPS.Particle) = nothing

reset_model(f) = deepcopy(f)
delete_retained!(f) = nothing

Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))
"""
copy(trace::Trace)
# This is required to make it visible from outside extensions
function observe end
function replay end
Copy a trace. The `TracedRNG` is deep-copied. The inner model is shallow-copied.
"""
Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))

"""
gen_refseed!(particle::Particle)
gen_refseed!(particle::Particle)
Generate a new seed for the reference particle
Generate a new seed for the reference particle.
"""
function gen_refseed!(particle::Particle)
seed = split(state(particle.rng.rng), 1)
return safe_set_refseed!(particle.rng, seed[1])
end

# A few internal functions used in the Libtask extension. Since it is not possible to access objects defined
# in an extension, we just define dummy in the main module and implement them in the extension.
function observe end
function replay end
function addreference! end

current_trace() = current_task().storage[:__trace]

# We need this one to be visible outside of the extension for dispatching (Turing.jl).
struct LibtaskModel{F,T}
f::F
ctask::T
end
2 changes: 2 additions & 0 deletions src/rng.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,5 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
Increase the model step counter by `n`
"""
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n

function update_rng! end
17 changes: 16 additions & 1 deletion test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@

# Test task copy version of trace
trng = AdvancedPS.TracedRNG()
tr = AdvancedPS.Trace(Model(Ref(0)), trng, trng)
tr = AdvancedPS.Trace(Model(Ref(0)), trng)

consume(tr.model.ctask)
consume(tr.model.ctask)
Expand All @@ -143,6 +143,21 @@
@test consume(a.model.ctask) == 4
end

@testset "current trace" begin
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end

function (model::TaskIdModel)(rng::Random.AbstractRNG)
# Just print the task it's running in
id = objectid(AdvancedPS.current_trace())
return Libtask.produce(id)
end

trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
AdvancedPS.addreference!(trace.model.ctask.task, trace)

@test AdvancedPS.advance!(trace, false) === objectid(trace)
end

@testset "seed container" begin
seed = 1
n = 3
Expand Down

0 comments on commit 3a9a7ba

Please sign in to comment.