diff --git a/ext/AdvancedPSLibtaskExt.jl b/ext/AdvancedPSLibtaskExt.jl index 70b0d60..d6062fd 100644 --- a/ext/AdvancedPSLibtaskExt.jl +++ b/ext/AdvancedPSLibtaskExt.jl @@ -131,10 +131,9 @@ function AbstractMCMC.step( # Create reference trajectory. AdvancedPS.forkr(copy(state.trajectory)) else - println(model) trng = AdvancedPS.TracedRNG() trace = AdvancedPS.Trace(deepcopy(model), trng) - AdvancedPS.addreference!(trace.model.ctask.task, trace) # Do we need it here ? + AdvancedPS.addreference!(trace.model.ctask.task, trace) # TODO: Do we need it here ? trace end end diff --git a/test/container.jl b/test/container.jl index de3ab47..7a31229 100644 --- a/test/container.jl +++ b/test/container.jl @@ -143,6 +143,21 @@ @test consume(a.model.ctask) == 4 end + @testset "current trace" begin + struct TaskIdModel <: AdvancedPS.AbstractGenericModel end + + function (model::TaskIdModel)(rng::Random.AbstractRNG) + # Just print the task it's running in + id = objectid(AdvancedPS.current_trace()) + return Libtask.produce(id) + end + + trace = AdvancedPS.Trace(TaskIdModel(), AdvancedPS.TracedRNG()) + AdvancedPS.addreference!(trace.model.ctask.task, trace) + + @test AdvancedPS.advance!(trace, false) === objectid(trace) + end + @testset "seed container" begin seed = 1 n = 3