-
Notifications
You must be signed in to change notification settings - Fork 9
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Expose API for Turing integration #90
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -19,30 +19,29 @@ end | |||||
""" | ||||||
LibtaskModel{F} | ||||||
|
||||||
State wrapper to hold `Libtask.CTask` model initiated from `f` | ||||||
State wrapper to hold `Libtask.CTask` model initiated from `f`. | ||||||
""" | ||||||
struct LibtaskModel{F1,F2} | ||||||
f::F1 | ||||||
ctask::Libtask.TapedTask{F2} | ||||||
|
||||||
LibtaskModel(f::F1, ctask::Libtask.TapedTask{F2}) where {F1,F2} = new{F1,F2}(f, ctask) | ||||||
end | ||||||
|
||||||
function LibtaskModel(f, args...) | ||||||
return LibtaskModel( | ||||||
function AdvancedPS.LibtaskModel( | ||||||
f::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... | ||||||
) # Changed the API, need to take care of the RNG properly | ||||||
return AdvancedPS.LibtaskModel( | ||||||
f, | ||||||
Libtask.TapedTask(f, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)}), | ||||||
Libtask.TapedTask( | ||||||
f, rng, args...; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(f)} | ||||||
), | ||||||
) | ||||||
end | ||||||
|
||||||
Base.copy(model::LibtaskModel) = LibtaskModel(model.f, copy(model.ctask)) | ||||||
function Base.copy(model::AdvancedPS.LibtaskModel) | ||||||
return AdvancedPS.LibtaskModel(model.f, copy(model.ctask)) | ||||||
end | ||||||
|
||||||
const LibtaskTrace{R} = AdvancedPS.Trace{<:LibtaskModel,R} | ||||||
const LibtaskTrace{R} = AdvancedPS.Trace{<:AdvancedPS.LibtaskModel,R} | ||||||
|
||||||
function AdvancedPS.Trace( | ||||||
model::AdvancedPS.AbstractGenericModel, rng::Random.AbstractRNG, args... | ||||||
) | ||||||
return AdvancedPS.Trace(LibtaskModel(model, args...), rng) | ||||||
return AdvancedPS.Trace(AdvancedPS.LibtaskModel(model, rng, args...), rng) | ||||||
end | ||||||
|
||||||
# step to the next observe statement and | ||||||
|
@@ -56,7 +55,7 @@ function AdvancedPS.advance!(t::LibtaskTrace, isref::Bool=false) | |||||
end | ||||||
|
||||||
# create a backward reference in task_local_storage | ||||||
function addreference!(task::Task, trace::LibtaskTrace) | ||||||
function AdvancedPS.addreference!(task::Task, trace::LibtaskTrace) | ||||||
if task.storage === nothing | ||||||
task.storage = IdDict() | ||||||
end | ||||||
|
@@ -65,9 +64,7 @@ function addreference!(task::Task, trace::LibtaskTrace) | |||||
return task | ||||||
end | ||||||
|
||||||
current_trace() = current_task().storage[:__trace] | ||||||
|
||||||
function update_rng!(trace::LibtaskTrace) | ||||||
function AdvancedPS.update_rng!(trace::LibtaskTrace) | ||||||
rng, = trace.model.ctask.args | ||||||
trace.rng = rng | ||||||
return trace | ||||||
|
@@ -76,12 +73,12 @@ end | |||||
# Task copying version of fork for Trace. | ||||||
function AdvancedPS.fork(trace::LibtaskTrace, isref::Bool=false) | ||||||
newtrace = copy(trace) | ||||||
update_rng!(newtrace) | ||||||
AdvancedPS.update_rng!(newtrace) | ||||||
isref && AdvancedPS.delete_retained!(newtrace.model.f) | ||||||
isref && delete_seeds!(newtrace) | ||||||
|
||||||
# add backward reference | ||||||
addreference!(newtrace.model.ctask.task, newtrace) | ||||||
AdvancedPS.addreference!(newtrace.model.ctask.task, newtrace) | ||||||
return newtrace | ||||||
end | ||||||
|
||||||
|
@@ -94,11 +91,11 @@ function AdvancedPS.forkr(trace::LibtaskTrace) | |||||
ctask = Libtask.TapedTask( | ||||||
newf, trace.rng; deepcopy_types=Union{AdvancedPS.TracedRNG,typeof(trace.model.f)} | ||||||
) | ||||||
new_tapedmodel = LibtaskModel(newf, ctask) | ||||||
new_tapedmodel = AdvancedPS.LibtaskModel(newf, ctask) | ||||||
|
||||||
# add backward reference | ||||||
newtrace = AdvancedPS.Trace(new_tapedmodel, trace.rng) | ||||||
addreference!(ctask.task, newtrace) | ||||||
AdvancedPS.addreference!(ctask.task, newtrace) | ||||||
AdvancedPS.gen_refseed!(newtrace) | ||||||
return newtrace | ||||||
end | ||||||
|
@@ -134,10 +131,10 @@ function AbstractMCMC.step( | |||||
# Create reference trajectory. | ||||||
AdvancedPS.forkr(copy(state.trajectory)) | ||||||
else | ||||||
println(model) | ||||||
trng = AdvancedPS.TracedRNG() | ||||||
gen_model = LibtaskModel(deepcopy(model), trng) | ||||||
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng) | ||||||
addreference!(gen_model.ctask.task, trace) # Do we need it here ? | ||||||
trace = AdvancedPS.Trace(deepcopy(model), trng) | ||||||
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ? | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
trace | ||||||
end | ||||||
end | ||||||
|
@@ -174,9 +171,8 @@ function AbstractMCMC.sample( | |||||
|
||||||
traces = map(1:(sampler.nparticles)) do i | ||||||
trng = AdvancedPS.TracedRNG() | ||||||
gen_model = LibtaskModel(deepcopy(model), trng) | ||||||
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(model), trng), trng) | ||||||
addreference!(gen_model.ctask.task, trace) # Do we need it here ? | ||||||
trace = AdvancedPS.Trace(deepcopy(model), trng) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did we define There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think |
||||||
AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ? | ||||||
trace | ||||||
end | ||||||
|
||||||
|
@@ -202,7 +198,9 @@ function AdvancedPS.replay(particle::AdvancedPS.Particle) | |||||
trng = deepcopy(particle.rng) | ||||||
Random123.set_counter!(trng.rng, 0) | ||||||
trng.count = 1 | ||||||
trace = AdvancedPS.Trace(LibtaskModel(deepcopy(particle.model.f), trng), trng) | ||||||
trace = AdvancedPS.Trace( | ||||||
AdvancedPS.LibtaskModel(deepcopy(particle.model.f), trng), trng | ||||||
) | ||||||
score = AdvancedPS.advance!(trace, true) | ||||||
while !isnothing(score) | ||||||
score = AdvancedPS.advance!(trace, true) | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.