We explore importance sampling on some trivial test cases.

The steps of importance sampling are described in Appendix D of the Overleaf write-up.

We will implement them first for the 1-dimensional z-test, then the 2- and 3-dimensional z-tests.

In [55]:
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import scipy.stats as stats
import pandas 
import sklearn
from sklearn.cluster import KMeans

In [4]:
# Pilot simulations for the 1-D z-test: 1000 standard-normal draws shifted by
# each of 11 candidate means. The same noise vector z is reused for every theta,
# so data[i, j] = mu[j] + z[i] and data has shape (1000, 11).
SEED = 0
np.random.seed(SEED)  # seed the global stream so every later np.random call is reproducible on Run All
mu = np.linspace(-2, 2, 11)
z = np.random.normal(0, 1, 1000)
data = mu[None,:] + z[:,None]

In [20]:
# Flatten every (simulation, theta) pair and keep only the z-test rejections
# (statistic above 1.96); k-means is run on these. A standardized copy is also
# kept for clustering on a unit scale.
flat_data = np.ravel(data).copy()
selection = np.greater(flat_data, 1.96)
rejections = np.compress(selection, flat_data)
standardized_rejections = (rejections - rejections.mean()) / rejections.std()

In [22]:
# k-means setup: 5 clusters, random initialisation, 10 restarts, at most 300
# iterations, and a fixed random_state so the clustering is reproducible.
n_clusters = 5
kmeans_params = dict(n_clusters=n_clusters, init="random", n_init=10,
                     max_iter=300, random_state=42)
kmeans = KMeans(**kmeans_params)

In [23]:
# Fit on the raw (unstandardized) rejections as a single-feature column;
# standardized_rejections from the previous cell is not used in this run.
kmeans.fit(rejections[:, np.newaxis])

In [24]:
# Cluster centers on the original statistic scale. Shape is (n_clusters, 1)
# because k-means was fit on a single-feature column.
mu_cluster_centers = kmeans.cluster_centers_
mu_cluster_centers

array([[2.15110038],
       [3.52505375],
       [4.34462706],
       [2.5527284 ],
       [2.99012636]])

In [26]:
# Cluster label of each fitted rejection, in the same order as `rejections`.
kmeans.labels_

array([0, 0, 0, ..., 4, 1, 1], dtype=int32)

In [27]:
# Scatter the per-rejection cluster labels back onto the full flattened grid.
# Entries that were not rejections keep the sentinel label -1.
n_orig = flat_data.shape[0]
flat_labels = np.full(n_orig, -1, dtype=np.int32)
flat_labels[np.flatnonzero(selection)] = kmeans.labels_
np.unique(flat_labels, return_counts=True)

(array([-1,  0,  1,  2,  3,  4], dtype=int32),
 array([9649,  513,  151,   43,  383,  261]))

In [28]:
# Per-theta count of pilot simulations falling in each cluster.
# Row 0 counts non-rejections (label -1); row c+1 counts cluster c.
# Despite the name, the entries are raw counts, not fractions.
labels = flat_labels.reshape(data.shape)
n_sims_per_theta = data.shape[0] # 1000 for now
n_theta = data.shape[1] # 11 for now
target_fraction = np.full((n_clusters + 1, n_theta),-1)
labelset = np.unique(labels)
# Bin edges span the full label range -1..n_clusters-1 rather than only the
# labels actually observed: edges derived from `labelset` would give fewer
# than n_clusters + 1 bins whenever a label is absent, and the column
# assignment below would then fail with a shape mismatch.
labelbins = np.arange(-1, n_clusters + 1) - 0.5
for i in range(n_theta):
    target_fraction[:,i] = np.histogram(labels[:,i], bins = labelbins)[0]

In [30]:
# Rows: label -1 (no rejection), then clusters 0..n_clusters-1; columns: the thetas in `mu`.
# Each column holds counts out of n_sims_per_theta pilot simulations.
target_fraction
# looks good so far!

array([[1000,  999,  999,  996,  992,  973,  938,  874,  770,  635,  473],
       [   0,    1,    0,    3,    4,   18,   34,   62,  100,  130,  161],
       [   0,    0,    0,    0,    1,    1,    3,    6,   17,   41,   82],
       [   0,    0,    0,    0,    0,    0,    1,    1,    4,    9,   28],
       [   0,    0,    1,    0,    3,    5,   20,   37,   68,  110,  139],
       [   0,    0,    0,    1,    0,    3,    4,   20,   41,   75,  117]])

Now we draw the importance samples and compute the re-weighted estimates.

In [33]:
# Build the weights matrix: how many sims are planned for each value of theta?
n_per_thetaj = 1000
# Drop row 0 (non-rejections). What remains is, per theta, the number of pilot
# rejections in each cluster. Thetas with no pilot rejections are excluded.
# (The self-weight, intended to be ~1% of the total, is set in the next cell.)
sum_rejects = target_fraction[1:].copy()
any_successes = sum_rejects.sum(axis=0) > 0
relevant_mu = mu[any_successes]
any_successes

array([False,  True,  True,  True,  True,  True,  True,  True,  True,
        True,  True])

In [36]:
# Turn each theta's cluster counts into proportions (column totals are floored
# at 1 to avoid dividing by zero), then scale them to share 1 - wjj of that
# theta's weight. The remaining wjj goes to the theta's own proposal.
wjj = 0.01
column_totals = np.maximum(sum_rejects.sum(axis=0), 1)
temp = sum_rejects / column_totals[np.newaxis, :]
important_weights = temp[:, any_successes] * (1 - wjj)
important_weights.shape

(5, 10)

In [37]:
# Sanity check: one column of cluster weights per relevant theta.
[relevant_mu.shape , important_weights.shape]

[(10,), (5, 10)]

In [39]:
# Self-weight block: a square diagonal matrix with wjj on the diagonal, one row/column per relevant theta.
np.diag(np.full_like(relevant_mu,wjj)).shape

(10, 10)

In [47]:
# Stack the self-proposal weights (diagonal, wjj each) on top of the cluster
# weights. Rows index proposal distributions, columns index relevant thetas;
# the displayed column sums should all be 1.
self_weights = np.eye(relevant_mu.shape[0]) * wjj
full_weights = np.vstack((self_weights, important_weights))
full_weights.sum(axis=0)

array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

In [48]:
# (relevant thetas + clusters) proposal rows by relevant-theta columns.
full_weights.shape

(15, 10)

In [50]:
# Now we do the importance samples: n_per_thetaj draws from N(mu_k, 1) for every
# proposal mean mu_k. The proposals are the relevant thetas followed by the
# k-means cluster centers, matching the row order of full_weights.
# mu_cluster_centers has shape (n_clusters, 1), so it is flattened explicitly.
mu_importance = np.append(relevant_mu, np.ravel(mu_cluster_centers))
# Use the planned per-proposal sample size instead of repeating the literal 1000.
z_importance = np.random.normal(0, 1, size=(n_per_thetaj, len(mu_importance)))
data_importance = mu_importance[None,:] + z_importance

In [51]:
# Shape check: samples are (draws, proposals); the transposed weights are (targets, proposals).
data_importance.shape, mu_importance.shape, relevant_mu.shape, np.transpose(full_weights).shape

((1000, 15), (15,), (10,), (10, 15))

In [71]:
def mixture_ratio(j, X_ik, wj, mus):
    """Importance-sampling denominator at a single sample.

    Computes sum_k wj[k] * p_k(X_ik) / p_j(X_ik), where p_k is the N(mus[k], 1)
    density and the target density p_j uses mus[j]. The log density ratio is
    x * (mu_k - mu_j) - (mu_k**2 - mu_j**2) / 2.
    """
    mu_target = mus[j]
    log_ratio = X_ik * (mus - mu_target) - (mus ** 2 - mu_target ** 2) / 2
    return jnp.dot(wj, jnp.exp(log_ratio))


def average_rej(j, rejs, X_k, wj, mus):
    """Mean of rejs / denominator over the draws from one proposal.

    In this notebook rejs and X_k are matching 1-D rows (rejection indicators
    and the samples they came from); j selects the target mean mus[j].
    """
    denom_at = lambda x: mixture_ratio(j, x, wj, mus)
    denominators = jax.vmap(denom_at)(X_k)
    return (rejs / denominators).mean()


def imp_est_j(j, rejs, X, wj, mus):
    """Weighted importance-sampling estimate of the rejection rate at mus[j].

    Row k of rejs / X holds the draws from proposal k. The per-proposal means
    are combined with the target's weight row wj.
    """
    per_proposal = jax.vmap(
        lambda r_k, x_k: average_rej(j, r_k, x_k, wj, mus)
    )(rejs, X)
    return jnp.dot(wj, per_proposal)


@jax.jit
def imp_est(rejs, X, ws, mus):
    """Estimate rejection rates for every target theta at once.

    Target j is taken to be mus[j], so the first ws.shape[0] entries of mus must
    be the target thetas (mu_importance lists relevant_mu first). Returns one
    estimate per row of ws.
    """
    n_targets = ws.shape[0]

    def one_target(j, w_row):
        return imp_est_j(j, rejs, X, w_row, mus)

    return jax.vmap(one_target)(jnp.arange(n_targets), ws)

In [72]:
# Run the JAX estimator on the importance samples: one row per proposal
# distribution, with rejection meaning statistic above 1.96.
X, ws, mus = data_importance.T, full_weights.T, mu_importance
rejs = X > 1.96
imp_est(rejs, X, ws, mus)

Array([1.8215892e-04, 7.9667533e-04, 2.8277736e-03, 9.1251247e-03,
       2.4826376e-02, 5.9318118e-02, 1.2270814e-01, 2.2361267e-01,
       3.5971874e-01, 5.1689267e-01], dtype=float32)

In [75]:
# NumPy cross-check of the JAX estimator, fully vectorized.
# Axes: i (simulation draw), j (target theta in relevant_mu), k (proposal that
# generated the draw), m (second copy of the proposal index, summed over in the
# mixture denominator).
n_draws = data_importance.shape[0]  # draws per proposal; replaces the hard-coded 1000
weights_jk = np.transpose(full_weights)  # (targets j, proposals m)
# log p_m(x_ik) / p_j(x_ik) for unit-variance normals.
inside_exponent = (data_importance[:, None, :, None] * (mu_importance[None, None, None, :] - relevant_mu[None, :, None, None])
                   - mu_importance[None, None, None, :] ** 2 / 2 + relevant_mu[None, :, None, None] ** 2 / 2)
likelihood_ratios = np.exp(inside_exponent)
# Mixture denominator sum_m w_jm p_m(x_ik) / p_j(x_ik), shape (i, j, k).
denoms = np.sum(likelihood_ratios * weights_jk[None, :, None, :], axis=3)
rejects = data_importance > 1.96
per_draw = rejects[:, None, :] / denoms
inner_mean = np.mean(per_draw, axis=0)  # (j, k): inner sum divided by n_draws
# Empirical (ddof=0) variance of each per-draw term.
inner_mse_estimate = np.mean(per_draw ** 2, axis=0) - inner_mean ** 2
final_result = np.sum(inner_mean * weights_jk, axis=1)
# Draws from different proposals are independent, so variances add with squared weights.
final_variance_estimate = np.sum((inner_mse_estimate / n_draws) * weights_jk ** 2, axis=1)

In [77]:
# NumPy estimate per relevant theta; should match the JAX imp_est output above.
final_result

array([1.82158972e-04, 7.96675437e-04, 2.82777322e-03, 9.12512642e-03,
       2.48263719e-02, 5.93181096e-02, 1.22708153e-01, 2.23612629e-01,
       3.59718700e-01, 5.16892710e-01])

In [79]:
# Estimated variance of each entry of final_result.
final_variance_estimate

array([1.33535664e-10, 2.55633023e-09, 1.81151078e-08, 1.07858147e-07,
       7.71020676e-07, 3.06608957e-06, 9.18504517e-06, 2.13519917e-05,
       3.64442843e-05, 4.88504788e-05])

In [82]:
# Estimated sample-size ratio: the binomial variance of a plain Monte Carlo
# estimate of the rejection rate using n_draws simulations, divided by the
# importance-sampling variance estimate. A ratio above 1 means the importance
# estimate has lower variance than plain Monte Carlo with the same number of draws.
n_draws = data_importance.shape[0]  # replaces the hard-coded 1000
((final_result * (1 - final_result)) / n_draws) / final_variance_estimate
# EXCELLENT!

array([1363.8737789 ,  311.39980869,  155.65885344,   83.83102044,
         31.39996616,   18.19890459,   11.72023219,    8.1308584 ,
          6.3198156 ,    5.11181553])

In [83]:
# Proposal means: the relevant thetas followed by the k-means cluster centers.
mu_importance

array([-1.6       , -1.2       , -0.8       , -0.4       ,  0.        ,
        0.4       ,  0.8       ,  1.2       ,  1.6       ,  2.        ,
        2.15110038,  3.52505375,  4.34462706,  2.5527284 ,  2.99012636])

The denominator formula:

$$\text{denom}_j(x) = \sum_k w_{jk}\,\frac{p_k(x)}{p_j(x)}$$

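For unit-variance normals the likelihood ratio is $\frac{p_k(x)}{p_j(x)} = \exp\left(\frac{(x-\mu_j)^2}{2} - \frac{(x-\mu_k)^2}{2}\right) = \exp\left(x(\mu_k - \mu_j) - \frac{\mu_k^2}{2} + \frac{\mu_j^2}{2}\right)$.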

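Next, let's generalize this to two-dimensional mu. As a starting point, the cells below re-run the whole one-dimensional pipeline end to end (pilot run, then the real run); $\mu$ is still scalar here.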

In [None]:
# Full pilot pipeline in one cell (still one-dimensional): simulate, cluster the
# rejections with k-means on the standardized scale, and count pilot
# simulations per cluster for each theta.
mu = np.linspace(-2, 2, 11)
z = np.random.normal(0, 1, 1000)
data = mu[None,:] + z[:,None]
flat_data = data.flatten()
# Select the rejections so k-means can be run on them.
selection = flat_data > 1.96
rejections = flat_data[selection]
rej_mean, rej_std = np.mean(rejections), np.std(rejections)
standardized_rejections = (rejections - rej_mean) / rej_std
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters,
                init="random",
                n_init=10,
                max_iter=300,
                random_state=42)
kmeans.fit(standardized_rejections.reshape(-1, 1))
# Map the centers back to the original statistic scale; shape (n_clusters, 1).
mu_cluster_centers = kmeans.cluster_centers_ * rej_std + rej_mean
n_orig = len(flat_data)
flat_labels = np.full(n_orig, -1, dtype=np.int32)
flat_labels[selection] = kmeans.labels_
labels = flat_labels.reshape(data.shape)
n_sims_per_theta = data.shape[0] # 1000 for now
n_theta = data.shape[1] # 11 for now
# Row 0 counts non-rejections (label -1); row c+1 counts cluster c (raw counts).
target_fraction = np.full((n_clusters + 1, n_theta), -1)
labelset = np.unique(labels)
# Fixed bin edges over the full label range -1..n_clusters-1: deriving them
# from the observed labels would drop bins (and break the assignment below)
# whenever a label happens to be absent.
labelbins = np.arange(-1, n_clusters + 1) - 0.5
for i in range(n_theta):
    target_fraction[:, i] = np.histogram(labels[:, i], bins=labelbins)[0]

In [None]:
# Pilot sims done, now the real run: build the mixture weights, draw the
# importance samples, and compute the NumPy estimator and its variance.
n_per_thetaj = 1000
sum_rejects = np.delete(target_fraction, 0, axis = 0)
any_successes = np.sum(sum_rejects, axis = 0) > 0
relevant_mu = mu[any_successes]
# Floor the column totals at 1 so thetas with no pilot rejections do not divide
# by zero (those columns are dropped by any_successes right after).
temp = sum_rejects / np.maximum(np.sum(sum_rejects, axis = 0)[None, :], 1)
# Each theta's own proposal gets weight wjj (~1% of the total); clusters share the rest.
wjj = 0.01
important_weights = temp[:, any_successes] * (1 - wjj)
full_weights = np.append(np.diag(np.full_like(relevant_mu, wjj)), important_weights, axis = 0)
# Importance samples: n_per_thetaj draws from each proposal mean (relevant thetas, then cluster centers).
mu_importance = np.append(relevant_mu, np.ravel(mu_cluster_centers))
z_importance = np.random.normal(0, 1, size=(n_per_thetaj, len(mu_importance)))
data_importance = mu_importance[None, :] + z_importance
# Axes below: i (draw), j (target theta), k (generating proposal), m (mixture component).
weights_jk = np.transpose(full_weights)
inside_exponent = (data_importance[:, None, :, None] * (mu_importance[None, None, None, :] - relevant_mu[None, :, None, None])
                   - mu_importance[None, None, None, :] ** 2 / 2 + relevant_mu[None, :, None, None] ** 2 / 2)
likelihood_ratios = np.exp(inside_exponent)
denoms = np.sum(likelihood_ratios * weights_jk[None, :, None, :], axis = 3)
rejects = data_importance > 1.96
per_draw = rejects[:, None, :] / denoms
inner_mean = np.mean(per_draw, axis = 0)  # inner sum divided by n_per_thetaj
# Empirical (ddof=0) variance of each per-draw term.
inner_mse_estimate = np.mean(per_draw ** 2, axis = 0) - inner_mean ** 2
final_result = np.sum(inner_mean * weights_jk, axis = 1)
final_variance_estimate = np.sum((inner_mse_estimate / n_per_thetaj) * weights_jk ** 2, axis = 1)


In [None]:
# Rejection-rate estimates, one per theta in relevant_mu.
final_result