Skip to content

Commit

Permalink
Merge 20aab8b into b157064
Browse files Browse the repository at this point in the history
  • Loading branch information
FredericWantiez committed Nov 3, 2023
2 parents b157064 + 20aab8b commit 6449fbd
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 40 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AdvancedPS"
uuid = "576499cb-2369-40b2-a588-c64705576edc"
authors = ["TuringLang"]
version = "0.5.1"
version = "0.5.2"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Expand Down
55 changes: 26 additions & 29 deletions ext/AdvancedPSLibtaskExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,30 +19,29 @@ end
"""
LibtaskModel{F}
State wrapper to hold `Libtask.CTask` model initiated from `f`
State wrapper to hold `Libtask.CTask` model initiated from `f`.
"""
struct LibtaskModel{F1,F2}
f::F1
ctask::Libtask.TapedTask{F2}

LibtaskModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask)
end

function LibtaskModel(f, args...)
return LibtaskModel(
function AdvancedPS.LibtaskModel(
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
) # Changed the API, need to take care of the RNG properly
return AdvancedPS.LibtaskModel(
f,
Libtask.TapedTask(f, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}),
Libtask.TapedTask(
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}
),
)
end

Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask))
function Base.copy(model::AdvancedPS.LibtaskModel)
return AdvancedPS.LibtaskModel(model.f, copy(model.ctask))
end

const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R}
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R}

function AdvancedPS.Trace(
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args...
)
return AdvancedPS.Trace(LibtaskModel(model, args...), rng)
return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng)
end

# step to the next observe statement and
Expand All @@ -56,7 +55,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false)
end

# create a backward reference in task_local_storage
function addreference!(task::Task, trace::LibtaskTrace)
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace)
if task.storage === nothing
task.storage = IdDict()
end
Expand All @@ -65,9 +64,7 @@ function addreference!(task::Task, trace::LibtaskTrace)
return task
end

current_trace() = current_task().storage[:__trace]

function update_rng!(trace::LibtaskTrace)
function AdvancedPS.update_rng!(trace::LibtaskTrace)
rng, = trace.model.ctask.args
trace.rng = rng
return trace
Expand All @@ -76,12 +73,12 @@ end
# Task copying version of fork for Trace.
function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false)
newtrace = copy(trace)
update_rng!(newtrace)
AdvancedPS.update_rng!(newtrace)
isref && AdvancedPS.delete_retained!(newtrace.model.f)
isref && delete_seeds!(newtrace)

# add backward reference
addreference!(newtrace.model.ctask.task, newtrace)
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace)
return newtrace
end

Expand All @@ -94,11 +91,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace)
ctask = Libtask.TapedTask(
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)}
)
new_tapedmodel = LibtaskModel(newf, ctask)
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask)

# add backward reference
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng)
addreference!(ctask.task, newtrace)
AdvancedPS.addreference!(ctask.task, newtrace)
AdvancedPS.gen_refseed!(newtrace)
return newtrace
end
Expand Down Expand Up @@ -135,9 +132,8 @@ function AbstractMCMC.step(
AdvancedPS.forkr(copy(state.trajectory))
else
trng = AdvancedPS.TracedRNG()
gen_model = LibtaskModel(deepcopy(model), trng)
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng)
addreference!(gen_model.ctask.task, trace) # Do we need it here ?
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ?
trace
end
end
Expand Down Expand Up @@ -174,9 +170,8 @@ function AbstractMCMC.sample(

traces = map(1:(sampler.nparticles)) do i
trng = AdvancedPS.TracedRNG()
gen_model = LibtaskModel(deepcopy(model), trng)
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng)
addreference!(gen_model.ctask.task, trace) # Do we need it here ?
trace = AdvancedPS.Trace(deepcopy(model), trng)
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ?
trace
end

Expand All @@ -202,7 +197,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle)
trng = deepcopy(particle.rng)
Random123.set_counter!(trng.rng, 0)
trng.count = 1
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(particle.model.f), trng), trng)
trace = AdvancedPS.Trace(
AdvancedPS.LibtaskModel(deepcopy(particle.model.f), trng), trng
)
score = AdvancedPS.advance!(trace, true)
while !isnothing(score)
score = AdvancedPS.advance!(trace, true)
Expand Down
31 changes: 22 additions & 9 deletions src/model.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Trace{F,R}
Trace{F,R}
"""
mutable struct Trace{F,R}
model::F
Expand All @@ -10,24 +10,37 @@ const Particle = Trace
const SSMTrace{R} = Trace{<:AbstractStateSpaceModel,R}
const GenericTrace{R} = Trace{<:AbstractGenericModel,R}

# reset log probability
reset_logprob!(::AdvancedPS.Particle) = nothing

reset_model(f) = deepcopy(f)
delete_retained!(f) = nothing

Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))
"""
copy(trace::Trace)
# This is required to make it visible from outside extensions
function observe end
function replay end
Copy a trace. The `TracedRNG` is deep-copied. The inner model is shallow-copied.
"""
Base.copy(trace::Trace) = Trace(copy(trace.model), deepcopy(trace.rng))

"""
gen_refseed!(particle::Particle)
gen_refseed!(particle::Particle)
Generate a new seed for the reference particle
Generate a new seed for the reference particle.
"""
function gen_refseed!(particle::Particle)
seed = split(state(particle.rng.rng), 1)
return safe_set_refseed!(particle.rng, seed[1])
end

# A few internal functions used in the Libtask extension. Since it is not possible to access objects defined
# in an extension, we just define dummy in the main module and implement them in the extension.
function observe end
function replay end
function addreference! end

current_trace() = current_task().storage[:__trace]

# We need this one to be visible outside of the extension for dispatching (Turing.jl).
struct LibtaskModel{F,T}
f::F
ctask::T
end
2 changes: 2 additions & 0 deletions src/rng.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,5 @@ Random123.set_counter!(r::TracedRNG, n::Integer) = r.count = n
Increase the model step counter by `n`
"""
inc_counter!(r::TracedRNG, n::Integer=1) = r.count += n

function update_rng! end
17 changes: 16 additions & 1 deletion test/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@

# Test task copy version of trace
trng = AdvancedPS.TracedRNG()
tr = AdvancedPS.Trace(Model(Ref(0)), trng, trng)
tr = AdvancedPS.Trace(Model(Ref(0)), trng)

consume(tr.model.ctask)
consume(tr.model.ctask)
Expand All @@ -143,6 +143,21 @@
@test consume(a.model.ctask) == 4
end

@testset "current trace" begin
struct TaskIdModel <: AdvancedPS.AbstractGenericModel end

function (model::TaskIdModel)(rng::Random.AbstractRNG)
# Just print the task it's running in
id = objectid(AdvancedPS.current_trace())
return Libtask.produce(id)
end

trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG())
AdvancedPS.addreference!(trace.model.ctask.task, trace)

@test AdvancedPS.advance!(trace, false) === objectid(trace)
end

@testset "seed container" begin
seed = 1
n = 3
Expand Down

0 comments on commit 6449fbd

Please sign in to comment.