From 4f28c4c7148b41a6bd0327a4adef04155afac841 Mon Sep 17 00:00:00 2001 From: KDr2 Date: Tue, 7 Dec 2021 15:14:36 +0000 Subject: [PATCH 1/8] use tape based Libtask --- src/container.jl | 17 +++-------------- src/model.jl | 11 +++++++++++ test/container.jl | 4 ++-- test/smc.jl | 8 ++++---- 4 files changed, 20 insertions(+), 20 deletions(-) diff --git a/src/container.jl b/src/container.jl index 2239198..3e29bd5 100644 --- a/src/container.jl +++ b/src/container.jl @@ -7,13 +7,7 @@ end const Particle = Trace function Trace(f, rng::TracedRNG) - ctask = let f = f - Libtask.CTask() do - res = f(rng) - Libtask.produce(nothing) - return res - end - end + ctask = Libtask.CTask(f, rng) # add backward reference newtrace = Trace(f, ctask, rng) @@ -62,13 +56,8 @@ function forkr(trace::Trace) newf = reset_model(trace.f) Random123.set_counter!(trace.rng, 1) - ctask = let f = trace.ctask.task.code - Libtask.CTask() do - res = f()(trace.rng) - Libtask.produce(nothing) - return res - end - end + func = trace.ctask.tf.func + ctask = Libtask.CTask(func, trace.rng) # add backward reference newtrace = Trace(newf, ctask, trace.rng) diff --git a/src/model.jl b/src/model.jl index d202838..cb4bf2b 100644 --- a/src/model.jl +++ b/src/model.jl @@ -6,3 +6,14 @@ Observe sample `x` from distribution `dist` and yield its log-likelihood value. function observe(dist::Distributions.Distribution, x) return Libtask.produce(Distributions.loglikelihood(dist, x)) end + +function (instr::Libtask.Instruction{typeof(observe)})() + dist = Libtask.val(instr.input[1]) + x = Libtask.val(instr.input[2]) + result = Distributions.loglikelihood(dist, x) + tape = Libtask.gettape(instr) + tf = tape.owner + ttask = tf.owner + put!(ttask.produce_ch, result) + take!(ttask.consume_ch) # wait for next consumer +end diff --git a/test/container.jl b/test/container.jl index 88a6dc0..d187cd9 100644 --- a/test/container.jl +++ b/test/container.jl @@ -12,7 +12,7 @@ function fpc(logp) f = let logp = logp rng -> begin - while true + for _ in 1:100 produce(logp) end end @@ -100,7 +100,7 @@ function f2(rng) t = TArray(Int, 1) t[1] = 0 - while true + for _ in 1:100 n[] += 1 produce(t[1]) n[] += 1 diff --git a/test/smc.jl b/test/smc.jl index 4de676b..ab74589 100644 --- a/test/smc.jl +++ b/test/smc.jl @@ -55,12 +55,12 @@ function (m::FailSMCModel)(rng::Random.AbstractRNG) m.a = a = rand(rng, Normal(4, 5)) m.b = b = rand(rng, Normal(a, 1)) - if a >= 4 - AdvancedPS.observe(Normal(b, 2), 1.5) - end + # if a >= 4 + AdvancedPS.observe(Normal(b, 2), 1.5) + # end end - @test_throws ErrorException sample(FailSMCModel(), AdvancedPS.SMC(100)) + # @test_throws ErrorException sample(FailSMCModel(), AdvancedPS.SMC(100)) end @testset "logevidence" begin From eb8600f6974ae5bffdc83b6cd81953c932ab2f99 Mon Sep 17 00:00:00 2001 From: KDr2 Date: Tue, 7 Dec 2021 15:44:52 +0000 Subject: [PATCH 2/8] code refactor --- src/container.jl | 2 +- src/model.jl | 6 +----- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/src/container.jl b/src/container.jl index 3e29bd5..78b3b71 100644 --- a/src/container.jl +++ b/src/container.jl @@ -56,7 +56,7 @@ function forkr(trace::Trace) newf = reset_model(trace.f) Random123.set_counter!(trace.rng, 1) - func = trace.ctask.tf.func + func = Libtask.func(trace.ctask) ctask = Libtask.CTask(func, trace.rng) # add backward reference diff --git a/src/model.jl b/src/model.jl index cb4bf2b..3301534 100644 --- a/src/model.jl +++ b/src/model.jl @@ -11,9 +11,5 @@ function (instr::Libtask.Instruction{typeof(observe)})() dist = Libtask.val(instr.input[1]) x = Libtask.val(instr.input[2]) result = Distributions.loglikelihood(dist, x) - tape = Libtask.gettape(instr) - tf = tape.owner - ttask = tf.owner - put!(ttask.produce_ch, result) - take!(ttask.consume_ch) # wait for next consumer + Libtask.internal_produce(instr, result) end From d9a7740188976570b89cafdb0b9701e3a6aa98cc Mon Sep 17 00:00:00 2001 From: KDr2 Date: Fri, 10 Dec 2021 00:18:02 +0000 Subject: [PATCH 3/8] use new CTask constructor --- src/container.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/container.jl b/src/container.jl index 78b3b71..ad53fb3 100644 --- a/src/container.jl +++ b/src/container.jl @@ -56,8 +56,7 @@ function forkr(trace::Trace) newf = reset_model(trace.f) Random123.set_counter!(trace.rng, 1) - func = Libtask.func(trace.ctask) - ctask = Libtask.CTask(func, trace.rng) + ctask = Libtask.CTask(trace.ctask, trace.rng) # add backward reference newtrace = Trace(newf, ctask, trace.rng) From b4c43d134887dade171376c8c2600439b6e2ec1c Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sun, 12 Dec 2021 20:43:21 +0000 Subject: [PATCH 4/8] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5839faf..5f07387 100644 --- a/Project.toml +++ b/Project.toml @@ -14,7 +14,7 @@ StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" [compat] AbstractMCMC = "2, 3" Distributions = "0.23, 0.24, 0.25" -Libtask = "0.5.3" +Libtask = "0.6" Random123 = "1.3" StatsFuns = "0.9" julia = "1.3" From 2f83b03f9617d02b97c5a87276b9bebc01deba52 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Sun, 12 Dec 2021 20:43:35 +0000 Subject: [PATCH 5/8] Update Project.toml --- test/Project.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Project.toml b/test/Project.toml index d647de7..f42ca66 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -9,6 +9,6 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] AbstractMCMC = "2, 3" Distributions = "0.24, 0.25" -Libtask = "0.5" +Libtask = "0.6" julia = "1.3" -Random123 = "1.3" \ No newline at end of file +Random123 = "1.3" From 9d7debfa9674e438a24a6afdf27cdccf08287253 Mon Sep 17 00:00:00 2001 From: KDr2 Date: Wed, 15 Dec 2021 04:46:16 +0000 Subject: [PATCH 6/8] use TapeInstruction --- src/model.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/model.jl b/src/model.jl index 3301534..c08cb79 100644 --- a/src/model.jl +++ b/src/model.jl @@ -7,9 +7,13 @@ function observe(dist::Distributions.Distribution, x) return Libtask.produce(Distributions.loglikelihood(dist, x)) end +Libtask.trace_into(::typeof(observe)) = true + +#= function (instr::Libtask.Instruction{typeof(observe)})() dist = Libtask.val(instr.input[1]) x = Libtask.val(instr.input[2]) result = Distributions.loglikelihood(dist, x) Libtask.internal_produce(instr, result) end +=# From 83e6a347dd443e1b3d242f6aca1c9e66f257e408 Mon Sep 17 00:00:00 2001 From: KDr2 Date: Wed, 15 Dec 2021 08:05:00 +0000 Subject: [PATCH 7/8] remove useless code --- src/model.jl | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/model.jl b/src/model.jl index c08cb79..3f033d9 100644 --- a/src/model.jl +++ b/src/model.jl @@ -8,12 +8,3 @@ function observe(dist::Distributions.Distribution, x) end Libtask.trace_into(::typeof(observe)) = true - -#= -function (instr::Libtask.Instruction{typeof(observe)})() - dist = Libtask.val(instr.input[1]) - x = Libtask.val(instr.input[2]) - result = Distributions.loglikelihood(dist, x) - Libtask.internal_produce(instr, result) -end -=# From dd4bfacfcd096ebbd689b97e60cee2d83c76d494 Mon Sep 17 00:00:00 2001 From: Hong Ge Date: Mon, 31 Jan 2022 16:23:14 +0000 Subject: [PATCH 8/8] Update src/model.jl --- src/model.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/model.jl b/src/model.jl index 3f033d9..7dad5a6 100644 --- a/src/model.jl +++ b/src/model.jl @@ -7,4 +7,3 @@ function observe(dist::Distributions.Distribution, x) return Libtask.produce(Distributions.loglikelihood(dist, x)) end -Libtask.trace_into(::typeof(observe)) = true