In [57]:
import numpy as np
from scipy.stats import geom, multivariate_normal

In [58]:
N_sample = 1024
dim_x = 20
dim_z = 20
r = 0.6

In [59]:
theta_0 = multivariate_normal.rvs(mean=np.zeros(dim_x), cov=np.eye(dim_x))
X = p_theta_x(theta_0).rvs(size=N_sample)
print(X[:2])

true_l = np.mean(np.log(p_theta_x(theta_0).pdf(X)))
print(true_l)

[[-1.22164858 -2.83952442  0.66721368  0.15193476  1.38345634 -0.5949272
   3.33365593  0.84768791 -0.71865092 -1.62451219 -0.23458977  0.11126794
  -2.83874075  1.02852013  0.82231303 -1.37545318 -0.34342638 -0.8384008
  -0.27486895 -1.24532479]
 [-0.70715156  2.96298072  0.06447447  1.35593835 -1.43705895  1.1619806
   1.78434866 -2.92224567 -1.59076792 -0.25269577  0.72233426 -1.25100265
  -0.47656747  2.67328624  1.43685174 -0.91295929 -1.68032872 -0.2067967
  -1.54216026 -1.11102375]]
-35.27464895661948


In [60]:
A = 1 / 2 * np.eye(dim_z, dim_x) #size = dim_z, dim_x
b = np.mean(X, axis=0) #size = dim_z, 1

In [61]:
def q_phi_z_given_x(x):
    mu = A @ x + b
    covariance = 2 / 3 * np.eye(dim_z)
    return multivariate_normal(mean=mu, cov=covariance)

def p_theta_x(theta):
    mu = theta
    covariance = 2 * np.eye(dim_x)
    return multivariate_normal(mean=mu, cov=covariance)

def p_theta_x_given_z(z):
    mu = z
    covariance = np.eye(dim_x)
    return multivariate_normal(mean=mu, cov=covariance)

def p_theta_z(theta):
    mu = theta
    covariance = np.eye(dim_z)
    return multivariate_normal(mean=mu, cov=covariance)

p = geom

def logmeanexp(data, axis=None):
    max_val = np.max(data, axis=axis)
    return max_val + np.log(np.mean(np.exp(data - max_val), axis=axis))

In [62]:
l_hat_ml_ss_list = []
l_hat_ml_rr_list = []

In [63]:
for x in X:
    K = p.rvs(r)
    size = 2 ** (K + 1)
    Z = q_phi_z_given_x(x).rvs(size=size)
    Z_E, Z_O = Z[::2], Z[1::2]
    log_weights = [np.log(p_theta_z(theta_0).pdf(z) * p_theta_x_given_z(z).pdf(x) / q_phi_z_given_x(x).pdf(z)) for z in Z]
    log_weights_E, log_weights_O = log_weights[::2], log_weights[1::2]
    I_0 = np.mean(log_weights)
    l_hat_E, l_hat_O = logmeanexp(log_weights_E), logmeanexp(log_weights_O)
    l_hat_O_E = logmeanexp(log_weights)
    delta_K = l_hat_O_E - 0.5 * (l_hat_O + l_hat_E)
    l_hat_ml_ss_x = I_0 + delta_K / p(r).pmf(K)
    l_hat_ml_ss_list.append(l_hat_ml_ss_x)
    delta_k_list = [logmeanexp(log_weights[:2**(k+1)]) - 0.5 * (logmeanexp(log_weights_O[:2**k]) + logmeanexp(log_weights_E[:2**k])) for k in range(K+1)]
    l_hat_ml_rr_x = I_0 + np.sum([delta_k_list[k] / (1 - p(r).cdf(k-1)) for k in range(K+1)])
    l_hat_ml_rr_list.append(l_hat_ml_rr_x)


In [64]:
l_hat_ml_ss = np.mean(l_hat_ml_ss_list)
print("estimateur ss =", l_hat_ml_ss)

l_hat_ml_rr = np.mean(l_hat_ml_rr_list)
print("estimateur rr =", l_hat_ml_rr)

estimateur ss = -36.46771081971367
estimateur rr = -35.23149816529057
