-
Notifications
You must be signed in to change notification settings - Fork 9
/
container.jl
124 lines (108 loc) · 4.21 KB
/
container.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
# Tests for AdvancedPS particle containers (weights, reweighting, resampling,
# log-weight bookkeeping) and for forking of resumable traces.
@testset "container.jl" begin
    @testset "copy particle container" begin
        # Copying an (empty) container must preserve its log-weights and concrete type.
        pc = AdvancedPS.ParticleContainer(AdvancedPS.Trace[])
        newpc = copy(pc)
        @test newpc.logWs == pc.logWs
        @test typeof(pc) === typeof(newpc)
    end

    @testset "particle container" begin
        # Create a resumable function that produces the same log probability `logp`
        # on each of its (up to 100) resumptions.
        function fpc(logp)
            # `logp` is never reassigned, so capturing it directly needs no `let`.
            return rng -> begin
                for _ in 1:100
                    produce(logp)
                end
            end
        end

        # Build one fresh trace per log probability in `logws`.
        function make_particles(logws)
            return [AdvancedPS.Trace(fpc(lp), AdvancedPS.TracedRNG()) for lp in logws]
        end

        # Check that container `c` holds `n` particles with all-zero log-weights, i.e.
        # uniform normalized weights, `logZ` equal to `log(n)` and an ESS of `n`.
        function test_uniform_weights(c, n)
            @test c.logWs == zeros(n)
            @test AdvancedPS.getweights(c) == fill(1 / n, n)
            @test all(AdvancedPS.getweight(c, i) == 1 / n for i in 1:n)
            @test AdvancedPS.logZ(c) ≈ log(n)
            # ESS is a floating-point reduction over the weights: compare approximately.
            @test AdvancedPS.effectiveSampleSize(c) ≈ n
        end

        # Check the normalized weights and `logZ` of `c` against the values implied by
        # the unnormalized log-weights `logws`. Uses `≈` throughout, since these are
        # floating-point results of `exp`, sums and divisions.
        function test_weights_match(c, logws)
            expected = exp.(logws) ./ sum(exp, logws)
            @test AdvancedPS.getweights(c) ≈ expected
            @test all(AdvancedPS.getweight(c, i) ≈ expected[i] for i in eachindex(logws))
            @test AdvancedPS.logZ(c) ≈ log(sum(exp, logws))
        end

        # Seeded RNG for the resampling steps so the resampled particle set (and every
        # assertion made after resampling) is reproducible across runs.
        resample_rng = Random.MersenneTwister(1234)

        # Create particle container.
        logps = [0.0, -1.0, -2.0]
        nparticles = length(logps)
        pc = AdvancedPS.ParticleContainer(make_particles(logps))

        # Initial state: no log-weight has been accumulated yet.
        test_uniform_weights(pc, nparticles)

        # Reweight particles: each particle adds its produced `logp` to its log-weight.
        AdvancedPS.reweight!(pc)
        @test pc.logWs == logps
        test_weights_match(pc, logps)

        # Reweight particles again: log-weights accumulate to `2 * logp`.
        AdvancedPS.reweight!(pc)
        @test pc.logWs == 2 .* logps
        test_weights_match(pc, 2 .* logps)

        # Resample and propagate particles, keeping the last particle as the reference.
        particles_ref = make_particles(logps)
        pc_ref = AdvancedPS.ParticleContainer(particles_ref)
        AdvancedPS.resample_propagate!(
            resample_rng, pc_ref, AdvancedPS.resample_systematic, particles_ref[end]
        )
        test_uniform_weights(pc_ref, nparticles)
        # The reference particle itself (by identity) must end up in the last slot.
        @test pc_ref.vals[end] === particles_ref[end]

        # Resample and propagate particles without a reference particle.
        AdvancedPS.resample_propagate!(resample_rng, pc)
        test_uniform_weights(pc, nparticles)

        # Reweight particles: every resampled particle's log-weight is one of `logps`.
        AdvancedPS.reweight!(pc)
        @test pc.logWs ⊆ logps
        test_weights_match(pc, pc.logWs)

        # Increase the unnormalized log-weight of the second particle only.
        logws = copy(pc.logWs)
        AdvancedPS.increase_logweight!(pc, 2, 1.41)
        @test pc.logWs == logws + [0, 1.41, 0]

        # Reset unnormalized log-weights; the reset must happen in place.
        logws = pc.logWs
        AdvancedPS.reset_logweights!(pc)
        @test pc.logWs === logws
        @test all(iszero, pc.logWs)
    end

    @testset "trace" begin
        # Counts steps of `f2`. It is incremented but not asserted on below.
        # NOTE(review): consider asserting on `n[]` once it is confirmed whether a
        # captured `Ref` is shared between a trace and its `fork`.
        n = Ref(0)
        # Resumable function producing 0, 1, 2, ... from a counter stored in a `TArray`.
        function f2(rng)
            t = TArray(Int, 1)
            t[1] = 0
            for _ in 1:100
                n[] += 1
                produce(t[1])
                n[] += 1
                t[1] = 1 + t[1]
            end
        end
        # Test task copy version of trace
        tr = AdvancedPS.Trace(f2, AdvancedPS.TracedRNG())
        consume(tr.ctask)  # produces 0
        consume(tr.ctask)  # produces 1
        a = AdvancedPS.fork(tr)
        consume(a.ctask)  # the fork continues from the parent's state: produces 2
        consume(a.ctask)  # produces 3
        # The parent resumes from its own state, unaffected by the fork's progress.
        @test consume(tr.ctask) == 2
        @test consume(a.ctask) == 4
    end
end