/
pg_lg_approx_check.m
124 lines (79 loc) · 3.7 KB
/
pg_lg_approx_check.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
% Sample a latent trajectory from a linear Gaussian state-space model with a
% replica-based conditional SMC sweep (approximate predictive look-ahead),
% followed by backward sampling.
function x_new = pg_lg_approx_check(y, x_cur, x_tilde_cur, M, ...
    L_post_init, L_post, L_init, L, phi_vec, sigma_2, ...
    mean_1_mat, mean_i_mat)
    % PG_LG_APPROX_CHECK  Replica cSMC with approximate predictive + backward sampling.
    %
    % Reading mvn_lpdf_L(x, mu, F) as log N(x; mu, F*F') (it is defined
    % elsewhere -- assumed, TODO confirm), the model used by this code is
    %     x_{t+1} | x_t ~ N(diag(phi_vec) * x_t, L*L'),
    %     y_t     | x_t ~ N(x_t, sigma_2 * I).
    %
    % Inputs
    %   y           Observations, one x_d-vector per column (at least n columns).
    %   x_cur       x_d-by-n reference trajectory; kept as particle 1 at every
    %               time step.
    %   x_tilde_cur x_d-by-(>=n)-by-K_c array of replica trajectories; their
    %               state at t+1 drives the proposal and weights at time t.
    %   M           Number of particles, including the reference particle.
    %   L_post_init x_d-by-x_d matrix scaling standard-normal proposal noise at t = 1.
    %   L_post      x_d-by-x_d matrix scaling standard-normal proposal noise at 1 < t < n.
    %   L_init      Factor of C_init = L_init*L_init', used in the t = 1 weight.
    %   L           Transition noise factor (covariance L*L').
    %   phi_vec     Diagonal of the transition matrix (row, column or scalar).
    %   sigma_2     Observation noise variance.
    %   mean_1_mat  x_d-by-x_d map from x_tilde_cur(:, 2, l) to the t = 1
    %               proposal mean.
    %   mean_i_mat  x_d-by-x_d map from (ancestor + x_tilde_cur(:, t+1, l)) to
    %               the proposal mean at 1 < t < n.
    %
    % Output
    %   x_new       x_d-by-n trajectory drawn by backward sampling.
    %
    % Requires n >= 2. Also depends on add_logs_mat (defined elsewhere;
    % presumably a row-wise log-sum-exp -- TODO confirm).
    x_d = size(x_cur, 1);
    n = size(x_cur, 2);
    K_c = size(x_tilde_cur, 3);
    if n < 2
        error('pg_lg_approx_check:tooShort', ...
            'At least two time steps are required (n = %d).', n);
    end

    % Accept phi_vec as row, column or scalar: bsxfun needs an x_d-by-1 column
    % so that component r of every particle is scaled by phi(r).
    phi_col = phi_vec(:);
    Phi = diag(phi_col);
    C = L*L';
    C_init = L_init*L_init';
    % Phi\x_{t+1} = x_t + Phi\e_{t+1}, so the transition noise contributes
    % Phi^{-1} C Phi^{-1}, which is symmetric. The former Phi^2\C agrees with
    % it only for diagonal C, and chol(...,'lower') silently reads just the
    % lower triangle of a non-symmetric input.
    C_back = Phi \ C / Phi;
    L_wgt = chol(C + C_back, 'lower');
    L_wgt_init = chol(C_init + C_back, 'lower');
    L_obs = sqrt(sigma_2)*eye(x_d);   % observation noise factor (loop invariant)

    x_pool = zeros(x_d, M, n);    % particles; column 1 is the reference
    mean_mat = zeros(x_d, M, n);  % phi .* particle, i.e. one-step predicted means
    lw = zeros(M, n);             % log weights
    bwd_hat = zeros(M, n);        % bwd_hat(k, t+1): replica look-ahead of particle k at t

    % ---- t = 1: propose around replica states at t = 2 ----
    l = randi(K_c, M-1, 1);
    l_all = [randi(K_c); l];      % replica index for the reference, then the rest
    x_pool(:, 1, 1) = x_cur(:, 1);
    x_mean = mean_1_mat*reshape(x_tilde_cur(:, 2, l), x_d, M-1);
    x_pool(:, 2:end, 1) = x_mean + L_post_init*randn(x_d, M-1);
    mean_mat(:, :, 1) = bsxfun(@times, phi_col, x_pool(:, :, 1));
    bwd_hat(:, 2) = replica_lookahead(x_tilde_cur(:, 2, :), mean_mat(:, :, 1), L);
    lw(:, 1) = mvn_lpdf_L(y(:, 1), x_pool(:, :, 1), L_obs) + ...
        mvn_lpdf_L(Phi\reshape(x_tilde_cur(:, 2, l_all), x_d, M), ...
        zeros(x_d, 1), L_wgt_init);
    w = exp(lw(:, 1) - max(lw(:, 1)));

    % ---- 1 < t < n: resample ancestors, propose using replica states at t+1 ----
    for i = 2 : n - 1
        l = randi(K_c, M-1, 1);
        l_all = [randi(K_c); l];
        mix_ind_vec = randsample(M, M-1, true, w);
        all_past_ind = [1; mix_ind_vec];   % the reference keeps ancestor 1
        x_pool(:, 1, i) = x_cur(:, i);
        x_mean = mean_i_mat*(x_pool(:, mix_ind_vec, i-1) + ...
            reshape(x_tilde_cur(:, i+1, l), x_d, M-1));
        x_pool(:, 2:end, i) = x_mean + L_post*randn(x_d, M-1);
        mean_mat(:, :, i) = bsxfun(@times, phi_col, x_pool(:, :, i));
        bwd_hat(:, i+1) = replica_lookahead(x_tilde_cur(:, i+1, :), ...
            mean_mat(:, :, i), L);
        lw(:, i) = mvn_lpdf_L(y(:, i), x_pool(:, :, i), L_obs) - ...
            bwd_hat(all_past_ind, i) + ...
            mvn_lpdf_L(Phi\reshape(x_tilde_cur(:, i+1, l_all), x_d, M), ...
            mean_mat(:, all_past_ind, i-1), L_wgt);
        w = exp(lw(:, i) - max(lw(:, i)));
    end

    % ---- t = n: resample ancestors, propose from the transition density ----
    mix_ind_vec = randsample(M, M-1, true, w);
    all_past_ind = [1; mix_ind_vec];
    x_pool(:, 1, n) = x_cur(:, n);
    % mean_mat(:, :, n-1) already holds phi .* x_pool(:, :, n-1).
    x_pool(:, 2:end, n) = mean_mat(:, mix_ind_vec, n-1) + L*randn(x_d, M-1);
    lw(:, n) = mvn_lpdf_L(y(:, n), x_pool(:, :, n), L_obs) - ...
        bwd_hat(all_past_ind, n);

    % ---- Backward sampling of a single trajectory ----
    x_new = zeros(x_d, n);
    x_new(:, n) = x_pool(:, draw_index(lw(:, n)), n);
    for i = n - 1 : -1 : 1
        hid_lpr = mvn_lpdf_L(x_new(:, i+1), mean_mat(:, :, i), L) + ...
            lw(:, i) - bwd_hat(:, i+1);
        x_new(:, i) = x_pool(:, draw_index(hid_lpr), i);
    end
end

function lb = replica_lookahead(x_tilde_next, mean_cur, L)
% Per-particle replica look-ahead: for each column of mean_cur, evaluate
% mvn_lpdf_L of every replica state in x_tilde_next (x_d-by-1-by-K_c) and
% combine the K_c values with add_logs_mat. Returns one value per particle.
    n_part = size(mean_cur, 2);
    K_c = size(x_tilde_next, 3);
    lw_tmp = zeros(n_part, K_c);
    for j = 1 : K_c
        lw_tmp(:, j) = mvn_lpdf_L(x_tilde_next(:, 1, j), mean_cur, L);
    end
    lb = add_logs_mat(lw_tmp);
end

function k = draw_index(lpr)
% Draw an index in 1:numel(lpr) with probability proportional to exp(lpr),
% by inverse CDF with a single rand(1) draw. Rounding can leave the
% normalised cumulative sum just below 1; a draw above it would index past
% the end, so it is assigned to the last index with positive weight.
    pr = exp(lpr - max(lpr));
    k = sum(rand(1) >= cumsum(pr ./ sum(pr))) + 1;
    if k > numel(pr)
        k = find(pr > 0, 1, 'last');
    end
end