In [1]:
import numpy as np
from scipy.stats import geom, multivariate_normal

In [15]:
from scipy.optimize import minimize

In [2]:
N_sample = 1024
dim_x = 20
dim_z = 20
r = 0.6


In [3]:
def p_theta_x(theta):
    mu = theta
    covariance = 2 * np.eye(dim_x)
    return multivariate_normal(mean=mu, cov=covariance)

def p_theta_x_given_z(z):
    mu = z
    covariance = np.eye(dim_x)
    return multivariate_normal(mean=mu, cov=covariance)

def p_theta_z(theta):
    mu = theta
    covariance = np.eye(dim_z)
    return multivariate_normal(mean=mu, cov=covariance)

p = geom

def logmeanexp(data, axis=None):
    max_val = np.max(data, axis=axis)
    return max_val + np.log(np.mean(np.exp(data - max_val), axis=axis))

Simulation des observations x : on fixe theta_0 qui sera estimé par la suite. 

In [5]:
theta_0 = (1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20)
X = p_theta_x(theta_0).rvs(size=N_sample)
print(X[:2])

l_true = np.mean(np.log(p_theta_x(theta_0).pdf(X)))
print(l_true)

[[-0.36198458  1.83326535  3.44773523  3.20245931  6.95326216  5.62672769
   3.86313666  9.35129607  8.03692123  9.49520592 10.76627178  8.19647969
  14.31245154 13.84770641 16.98895568 17.24321066 17.49234977 19.65081174
  18.91208332 19.30810068]
 [ 3.24508534  2.69624654  3.09693583  4.53538671  4.94141811  6.60038776
   4.63585538  6.4480126   7.98035145  8.6297854   8.81499425 13.93746605
  12.52637103 12.71976865 14.6728023  16.0704305  18.05753615 15.72401424
  20.72180722 21.09564372]]
-35.42179933119693


In [7]:
A = 1 / 2 * np.eye(dim_z, dim_x) #size = dim_z, dim_x
b = np.mean(X, axis=0) #size = dim_z, 1

In [8]:
def q_phi_z_given_x(x):
    mu = A @ x + b
    covariance = 2 / 3 * np.eye(dim_z)
    return multivariate_normal(mean=mu, cov=covariance)

In [10]:
l_hat_ml_ss_list = []
l_hat_ml_rr_list = []
l_hat_iwae_list = []
l_hat_sumo_list = []


In [12]:
for x in X:
    
    K = p.rvs(r)
    size = 2 ** (K + 1)
    
    Z = q_phi_z_given_x(x).rvs(size=size)
    Z_E, Z_O = Z[::2], Z[1::2]
    
    log_weights = [np.log(p_theta_z(theta_0).pdf(z) * p_theta_x_given_z(z).pdf(x) / q_phi_z_given_x(x).pdf(z)) for z in Z]
    log_weights_E, log_weights_O = log_weights[::2], log_weights[1::2]
    
    I_0 = np.mean(log_weights)
    l_hat_E, l_hat_O = logmeanexp(log_weights_E), logmeanexp(log_weights_O)
    l_hat_O_E = logmeanexp(log_weights)
    
    delta_K = l_hat_O_E - 0.5 * (l_hat_O + l_hat_E)
    l_hat_ml_ss_x = I_0 + delta_K / p(r).pmf(K)
    l_hat_ml_ss_list.append(l_hat_ml_ss_x)
    
    delta_k_list_rr = [logmeanexp(log_weights[:2**(k+1)]) - 0.5 * (logmeanexp(log_weights_O[:2**k]) + logmeanexp(log_weights_E[:2**k])) for k in range(K+1)]
    l_hat_ml_rr_x = I_0 + np.sum([delta_k_list_rr[k] / (1 - p(r).cdf(k-1)) for k in range(K+1)])
    l_hat_ml_rr_list.append(l_hat_ml_rr_x)

    l_hat_iwae_list.append(logmeanexp(log_weights))

    delta_k_list_sumo = [logmeanexp(log_weights[:2**(k+1)]) - logmeanexp(log_weights[:2**k]) for k in range(K+1)]
    l_hat_sumo_x = I_0 + np.sum([delta_k_list_sumo[k] / (1 - p(r).cdf(k-1)) for k in range(K+1)])
    l_hat_sumo_list.append(l_hat_sumo_x)

  log_weights = [np.log(p_theta_z(theta_0).pdf(z) * p_theta_x_given_z(z).pdf(x) / q_phi_z_given_x(x).pdf(z)) for z in Z]
  return max_val + np.log(np.mean(np.exp(data - max_val), axis=axis))


In [14]:
l_hat_ml_ss = np.mean(l_hat_ml_ss_list)
print("estimateur ss =", l_hat_ml_ss)

l_hat_ml_rr = np.mean(l_hat_ml_rr_list)
print("estimateur rr =", l_hat_ml_rr)

l_hat_iwae = np.mean(l_hat_iwae_list)
print("estimateur iwae =", l_hat_iwae)

l_hat_sumo = np.mean(l_hat_sumo_list)
print("estimateur sumo =", l_hat_sumo)

empirical_bias_squared_ss = (l_hat_ml_ss - l_true) ** 2
print("carré du biais empirique de l'estimateur ss :", empirical_bias_squared_ss)

empirical_bias_squared_rr = (l_hat_ml_rr - l_true) ** 2
print("carré du biais empirique de l'estimateur rr :", empirical_bias_squared_rr)

empirical_bias_squared_iwae = (l_hat_iwae - l_true) ** 2
print("carré du biais empirique de l'estimateur iwae :", empirical_bias_squared_iwae)

empirical_bias_squared_sumo = (l_hat_sumo - l_true) ** 2
print("carré du biais empirique de l'estimateur sumo :", empirical_bias_squared_sumo)

estimateur ss = nan
estimateur rr = nan
estimateur iwae = nan
estimateur sumo = nan
carré du biais empirique de l'estimateur ss : nan
carré du biais empirique de l'estimateur rr : nan
carré du biais empirique de l'estimateur iwae : nan
carré du biais empirique de l'estimateur sumo : nan


In [17]:
# Définition de l'estimateur 
def estimateur_ml_ss(theta, x):
    Z = q_phi_z_given_x(x).rvs(size=size)
    Z_E, Z_O = Z[::2], Z[1::2]
    log_weights = [np.log(p_theta_z(theta).pdf(z) * p_theta_x_given_z(z).pdf(x) / q_phi_z_given_x(x).pdf(z)) for z in Z]
    log_weights_E, log_weights_O = log_weights[::2], log_weights[1::2]
    I_0 = np.mean(log_weights)
    l_hat_E, l_hat_O = logmeanexp(log_weights_E), logmeanexp(log_weights_O)
    l_hat_O_E = logmeanexp(log_weights)
    delta_K = l_hat_O_E - 0.5 * (l_hat_O + l_hat_E)
    l_hat_ml_ss_x = I_0 + delta_K / p(r).pmf(K)
    
    
    
    
    # Retourner l'estimation
    return l_hat_ml_ss_x

def aggregate_estimator_ml_ss(theta):
    l_hat_ml_ss_list = []

    for x in X:
        l_hat_ml_ss_list.append(estimateur_ml_ss(theta, x))
    return np.mean(l_hat_ml_ss_list)
    


In [20]:

# Définir une fonction objectif pour l'optimisation
def objectif_ml_ss(theta):
    return -aggregate_estimator_ml_ss(theta)  # On minimise l'opposé de l'estimateur

# Supposer une valeur initiale pour theta
theta_init = np.random.randn(20)

# Minimiser la fonction objectif pour trouver theta optimal
resultat_optimisation = minimize(objectif_ml_ss, theta_init)

# Récupérer le theta optimal trouvé par l'optimisation
theta_optimal = resultat_optimisation.x

# Imprimer le résultat
print("Theta optimal trouvé par l'optimisation:", theta_optimal)

  log_weights = [np.log(p_theta_z(theta).pdf(z) * p_theta_x_given_z(z).pdf(x) / q_phi_z_given_x(x).pdf(z)) for z in Z]
  return max_val + np.log(np.mean(np.exp(data - max_val), axis=axis))


Theta optimal trouvé par l'optimisation: [-0.77817354  1.28672765 -0.61929193 -3.05927229 -1.2251555   1.0795783
  1.66435668  0.52353452  0.79062414 -0.97544221  1.9209496   1.20537603
  3.17550023  2.49634022 -0.87951735  2.14183195 -0.25013713  2.75519566
  1.36497097  1.01526797]
