In [1]:
# Standard library
import time

# Third-party
import numpy as np
import ray

# To run JAX in 64-bit precision, uncomment the two lines below (set this at
# startup, before JAX is used):
# from jax.config import config
# config.update("jax_enable_x64", True)

# Attach to the already-running Ray cluster instead of starting a local one.
ray.init(address='auto')


2023-12-12 11:05:53,023	INFO worker.py:1489 -- Connecting to existing Ray cluster at address: 192.168.178.154:6379...
2023-12-12 11:05:53,032	INFO worker.py:1664 -- Connected to Ray cluster. View the dashboard at [1m[32mhttp://127.0.0.1:8265 [39m[22m


0,1
Python version:,3.11.5
Ray version:,2.8.1
Dashboard:,http://127.0.0.1:8265


In [2]:

ndims = 8  # dimensionality of the Gaussian test problem

# Sweep axes of the experiment grid.
k_array = np.array([0, 1, 2, 3, 4, 5, 7])  # num_phantom_save values
s_array = np.arange(1, 8)  # slice factors: num_slices = model.U_ndims * s
c_array = ndims * np.array([16, 32, 64, 128, 256])  # num_live_points values


In [3]:
try:
    import ray
    from ray.util.queue import Queue
except ImportError:
    print("Install ray first with `pip install ray`")
    raise


@ray.remote(num_cpus=1, num_gpus=0)
def run(ndims, ensemble_size, input_queue: Queue, output_queue: Queue):
    """Ray worker that runs ensembles of nested-sampling experiments taken from a queue.

    The worker builds one fixed test problem (Gaussian prior, Gaussian likelihood in
    ``ndims`` dimensions, with a closed-form evidence). It then serves jobs from
    ``input_queue`` until it receives ``None`` (poison pill).

    Each job is a tuple ``(s, k, c, store_indices)``:
        s: slice factor, ``num_slices = model.U_ndims * s``.
        k: ``num_phantom_save`` passed to ``UniDimSliceSampler``.
        c: ``num_live_points`` passed to ``StandardStaticNestedSampler``.
        store_indices: tag echoed back unchanged so the driver knows where to store results.

    For each job, ``ensemble_size`` independent runs are timed. The runs use PRNG keys
    ``0 .. ensemble_size - 1``. The worker then puts ``((dt, log_Z_mean, log_Z_uncert,
    num_likelihood_evals, total_num_samples, total_phantom_samples, true_logZ),
    store_indices)`` on ``output_queue``. Every per-run entry is a length-``ensemble_size``
    numpy array.

    When the worker exits -- normally or because of an exception -- it puts a single
    ``None`` on ``output_queue`` so the driver's collection loop can finish.
    """
    try:
        import time

        # from jax.config import config
        # config.update("jax_enable_x64", True)
        import jax
        import numpy as np
        import tensorflow_probability.substrates.jax as tfp
        from jax import random, numpy as jnp
        from jaxns import Prior, Model, TerminationCondition
        from jaxns.nested_sampler import StandardStaticNestedSampler
        from jaxns.samplers import UniDimSliceSampler
        tfpd = tfp.distributions

        prior_mu = jnp.zeros(ndims)
        prior_cov = jnp.eye(ndims)

        # Strongly correlated likelihood: unit variances, 0.99 off-diagonal covariance.
        data_mu = 15 * jnp.ones(ndims)
        data_cov = jnp.eye(ndims)
        data_cov = jnp.where(data_cov == 0., 0.99, data_cov)

        def prior_model():
            x = yield Prior(
                tfpd.MultivariateNormalTriL(
                    loc=prior_mu,
                    scale_tril=jnp.linalg.cholesky(prior_cov)
                )
            )
            return x

        def log_likelihood(x):
            return tfpd.MultivariateNormalTriL(
                loc=data_mu,
                scale_tril=jnp.linalg.cholesky(data_cov)
            ).log_prob(x)

        model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

        # Analytic evidence: Z = N(data_mu; prior_mu, prior_cov + data_cov).
        true_logZ = tfpd.MultivariateNormalTriL(
            loc=prior_mu,
            scale_tril=jnp.linalg.cholesky(prior_cov + data_cov)
        ).log_prob(data_mu)

        # Analytic posterior mean of this conjugate Gaussian model, printed as a sanity check.
        post_mu = prior_cov @ jnp.linalg.inv(prior_cov + data_cov) @ data_mu + data_cov @ jnp.linalg.inv(
            prior_cov + data_cov) @ prior_mu

        print(f"True post mu:{post_mu}")
        print(f"True log Z: {true_logZ}")

        while True:
            input_data = input_queue.get()
            if input_data is None:  # poison pill
                break
            (s, k, c, store_indices) = input_data
            nested_sampler = StandardStaticNestedSampler(
                model=model,
                num_live_points=c,
                max_samples=50000,
                sampler=UniDimSliceSampler(
                    model=model,
                    num_slices=model.U_ndims * s,
                    num_phantom_save=k,
                    midpoint_shrink=True,
                    perfect=True
                ),
                init_efficiency_threshold=0.1,
                num_parallel_workers=1
            )

            @jax.jit
            def ns_run(key):
                termination_reason, state = nested_sampler._run(key=key, term_cond=TerminationCondition())
                results = nested_sampler._to_results(termination_reason=termination_reason, state=state, trim=False)
                return results.log_Z_mean, results.log_Z_uncert, results.total_num_likelihood_evaluations, results.total_num_samples, results.total_phantom_samples

            # Compile ahead of time so compilation is excluded from the timings below.
            run_compiled = ns_run.lower(random.PRNGKey(0)).compile()

            dt = []
            results = []
            for i in range(ensemble_size):
                t0 = time.time()
                # Each ensemble member gets its own key so the runs are independent.
                results.append(run_compiled(random.PRNGKey(i)))
                # JAX dispatch is asynchronous: wait for the result before stopping the clock.
                results[-1][0].block_until_ready()
                dt.append(time.time() - t0)
            dt = np.asarray(dt)  # [m]
            print(f"Time taken s={s} k={k} c={c}: {sum(dt)}")
            log_Z_mean, log_Z_uncert, num_likelihood_evals, total_num_samples, total_phantom_samples = np.asarray(
                results, dtype=np.float64).T  # [:, m]
            output_data = (
                (dt, log_Z_mean, log_Z_uncert, num_likelihood_evals, total_num_samples, total_phantom_samples, true_logZ),
                store_indices)
            output_queue.put(output_data)
    finally:
        # Poison pill: always signal that this worker is done, even when it failed, so the
        # driver cannot block forever. Any exception still propagates and fails this task.
        output_queue.put(None)


# Ray-backed queues shared by the driver and all workers: jobs flow out through
# input_queue, and results (plus one None per finished worker) come back on output_queue.
num_workers = 10
input_queue, output_queue = Queue(), Queue()

In [None]:
from jaxns import TerminationCondition
from jaxns.samplers import UniDimSliceSampler
from jaxns.nested_sampler import StandardStaticNestedSampler
from ray.util.queue import Empty

m = 100  # ensemble size: independent runs per (s, k, c) grid point
Ns = len(s_array)
Nk = len(k_array)
Nc = len(c_array)
# Seconds to wait for a result before checking whether a worker died without sending
# its poison pill.
queue_poll_timeout_s = 60.0

# Per-run results, indexed [s_index, k_index, c_index, run].
log_Z_mean_array = np.zeros((Ns, Nk, Nc, m))
log_Z_uncert_array = np.zeros((Ns, Nk, Nc, m))
num_likelihood_evals_array = np.zeros((Ns, Nk, Nc, m))
run_time_array = np.zeros((Ns, Nk, Nc, m))
total_num_samples_array = np.zeros((Ns, Nk, Nc, m))
total_num_phantom_samples_array = np.zeros((Ns, Nk, Nc, m))

# One job per grid point.
for i, s in enumerate(s_array):
    for j, k in enumerate(k_array):
        for l, c in enumerate(c_array):
            store_indices = (i, j, l)
            input_data = (s, k, c, store_indices)
            input_queue.put(input_data)

# Poison pills: one per worker, queued after all jobs.
for _ in range(num_workers):
    input_queue.put(None)

workers = []
for _ in range(num_workers):
    workers.append(run.remote(ndims, m, input_queue, output_queue))

# Collect results until every worker has sent its poison pill.
num_poison_pills = 0
num_results = 0
true_logZ = None
while num_poison_pills < num_workers:
    try:
        output_data = output_queue.get(timeout=queue_poll_timeout_s)
    except Empty:
        # Nothing arrived for a while. A worker that already failed may never send its
        # pill, so re-raise its error instead of waiting forever.
        finished, _ = ray.wait(workers, num_returns=len(workers), timeout=0)
        ray.get(finished)
        continue
    if output_data is None:
        num_poison_pills += 1
        continue
    (
        (dt, log_Z_mean, log_Z_uncert, num_likelihood_evals, total_num_samples, total_phantom_samples, true_logZ),
        store_indices
    ) = output_data
    i, j, l = store_indices
    run_time_array[i, j, l, :] = dt
    log_Z_mean_array[i, j, l, :] = log_Z_mean
    log_Z_uncert_array[i, j, l, :] = log_Z_uncert
    num_likelihood_evals_array[i, j, l, :] = num_likelihood_evals
    total_num_samples_array[i, j, l, :] = total_num_samples
    total_num_phantom_samples_array[i, j, l, :] = total_phantom_samples
    num_results += 1

# Re-raise any worker error, and refuse to save a partially filled grid.
ray.get(workers)
if num_results != Ns * Nk * Nc:
    raise RuntimeError(f"Expected {Ns * Nk * Nc} results, received {num_results}.")

# Save the result arrays and axes into npz file
save_file = "bias_experiment_results.npz"
np.savez(
    save_file,
    run_time_array=np.asarray(run_time_array),
    log_Z_mean_array=np.asarray(log_Z_mean_array),
    log_Z_uncert_array=np.asarray(log_Z_uncert_array),
    num_likelihood_evals_array=np.asarray(num_likelihood_evals_array),
    total_num_samples_array=np.asarray(total_num_samples_array),
    total_num_phantom_samples_array=np.asarray(total_num_phantom_samples_array),
    s_array=np.asarray(s_array),
    k_array=np.asarray(k_array),
    c_array=np.asarray(c_array),
    true_logZ=true_logZ
)

INFO[2023-12-12 11:06:31,743]: Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO[2023-12-12 11:06:31,744]: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO[2023-12-12 11:06:31,745]: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
[36m(pid=543061)[0m INFO[2023-12-12 11:06:34,009]: Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[36m(pid=543061)[0m INFO[2023-12-12 11:06:34,009]: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
[36m(pid=543061)[0m INFO[2023-12-12 11:06:34,010]: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


[36m(run pid=543055)[0m True post mu:[1.6797328 1.6797304 1.6797318 1.6797323 1.6797304 1.6797314 1.6797304
[36m(run pid=543055)[0m  1.6797304]
[36m(run pid=543055)[0m True log Z: -109.26490783691406




In [None]:
# Load the grid-experiment results. On disk each per-run array has shape
# (len(s_array), len(k_array), len(c_array), m). Averaging over the last (ensemble)
# axis leaves (len(s_array), len(k_array), len(c_array)).
# NOTE(review): several plotting cells below index these as 2-D (s, k) arrays, which
# predates the live-points (c) axis -- select or reduce a c index before plotting.
save_file = "bias_experiment_results.npz"

# Kept open on purpose: a later cell reads further per-run arrays from npzfile.
npzfile = np.load(save_file)
run_time_array = npzfile['run_time_array'].mean(-1)
log_Z_mean_array = npzfile['log_Z_mean_array'].mean(-1)
# Typical per-run uncertainty: root-mean-square over the ensemble, sqrt(mean(u**2)).
log_Z_uncert_array = np.sqrt(np.square(npzfile['log_Z_uncert_array']).mean(-1))
num_likelihood_evals_array = npzfile['num_likelihood_evals_array'].mean(-1)
total_num_samples_array = npzfile['total_num_samples_array'].mean(-1)
total_num_phantom_samples_array = npzfile['total_num_phantom_samples_array'].mean(-1)
s_array = npzfile['s_array']  # (len(s_array),)
k_array = npzfile['k_array']  # (len(k_array),)
c_array = npzfile['c_array']  # (len(c_array),)
true_logZ = npzfile['true_logZ']  # ()


In [None]:
# Samples per likelihood evaluation for each run, then averaged over the ensemble axis.
sample_efficiency_array = np.mean(
    npzfile['total_num_samples_array'] / npzfile['num_likelihood_evals_array'],
    axis=-1,
)
# Ratios relative to the first k column (k_array[0]). The ":1" slice keeps that axis
# so the division broadcasts across every k.
run_time_speed_up_array = run_time_array[:, :1] / run_time_array
efficiency_improvement_array = sample_efficiency_array / sample_efficiency_array[:, :1]

# After a threshold number of slices, the bias is independent of the number of phantom samples
A crucial component of nested sampling is generating i.i.d. uniform samples from the likelihood-constrained prior distribution. When using Markov chain likelihood samplers, such as slice sampling, this is accomplished by sequentially drawing samples from an ergodic Markov chain. A well-known problem is that if the number of proposals between acceptances is too low, the samples will exhibit autocorrelation.

We observe this directly below by looking at the bias in the resulting log-evidence ($\log Z$) estimate as a function of the number of proposal steps between acceptances. Crucially, this property is independent of the number of phantom samples, which forms the crux of our discovery. Each point below corresponds to a particular phantom fraction and number of slices. In general, the bias decreases with an increasing number of slices and an increasing phantom fraction. However, after a threshold number of slices, the bias is independent of the phantom fraction. This is a crucial result: it means that we can use a smaller number of likelihood evaluations to generate a larger number of i.i.d. samples from the likelihood-constrained prior distribution, and thus achieve a high sample efficiency.

In [None]:
import matplotlib.pyplot as plt

# Shared diverging colormap used by the plots below. plt.get_cmap replaces the
# deprecated (and in recent Matplotlib, removed) cm.get_cmap.
cm = plt.get_cmap('PuOr')


def color(c, unique_c):
    """Map the value ``c`` to an RGBA colour from ``cm``.

    The colour scale is normalised linearly to ``[min(unique_c), max(unique_c)]``.
    """
    return cm(plt.Normalize(np.min(unique_c), np.max(unique_c))(c))


In [None]:
# Plot log Z (with error bars) vs num slices, colour coded by phantom fraction.
# NOTE(review): the load cell averages only the ensemble axis, so these arrays still
# carry a live-points (c) axis; errorbar likely needs a single c index selected first -- confirm.
plt.figure()
unique_c = np.unique(k_array / (k_array + 1))
for i, k in enumerate(k_array):
    phantom_fraction = k / (k + 1)
    plt.errorbar(s_array, log_Z_mean_array[:, i], yerr=log_Z_uncert_array[:, i], fmt='o',
                 c=color(phantom_fraction, unique_c),
                 label=f"Phantom Fraction: {phantom_fraction * 100:.0f}%")
plt.xlabel("Num Slices")
# The y values are log Z itself (the dashed line marks the true value), not the bias.
plt.ylabel(r"$\log Z$ (nats)")
plt.gca().axhline(true_logZ, color='k', linestyle='--')

# Red box around the region of interest: every point with slice factor >= the threshold,
# padded a little beyond the y error bars.
num_slice_factor_threshold = 4

mask = (s_array >= num_slice_factor_threshold)
x = s_array[mask]
y = log_Z_mean_array[mask]
yerr = log_Z_uncert_array[mask]

lower_left = [0.97 * np.min(x), np.min(y - 1.2 * yerr)]
upper_right = [1.01 * np.max(x), np.max(y + 1.2 * yerr)]
plt.gca().add_patch(plt.Rectangle(lower_left, upper_right[0] - lower_left[0], upper_right[1] - lower_left[1],
                                  fill=False, edgecolor='r', lw=2))

plt.legend(loc='lower right')
# plt.savefig("bias_vs_num_slices.png", dpi=300)
# plt.savefig("bias_vs_num_slices.pdf", dpi=300)
plt.show()

In [None]:
import matplotlib.pyplot as plt

num_slice_factor_threshold = 4

# For runs in the consistent region (slice factor >= threshold, i.e. the red box of the
# previous plot), plot log Z vs num likelihood evals, colour coded by slice factor.
plt.figure()
colors = s_array
unique_c = np.unique(colors)  # colour scale spans all slice factors, as in the other plots
for j, s in enumerate(s_array):
    if s < num_slice_factor_threshold:
        continue
    plt.errorbar(num_likelihood_evals_array[j, :], log_Z_mean_array[j, :], yerr=log_Z_uncert_array[j, :], fmt='o',
                 c=color(s, unique_c),
                 label=f"Num Slices: {s}")
plt.xlabel("Num Likelihood Evaluations")
# The y values are log Z itself (the dashed line marks the true value), not the bias.
plt.ylabel(r"$\log Z$ (nats)")
plt.gca().axhline(true_logZ, color='k', linestyle='--')
plt.legend(loc='lower right')
plt.show()

# Using phantom samples to improve sample efficiency
Using the above result, the sample efficiency can be significantly boosted by using a large enough number of slices and a larger phantom fraction. We can easily see this by looking at bias vs. run-time speed-up. The run-time speed-up is defined as the ratio of the run time with no phantom samples to the run time with phantom samples. We see a speed-up of almost 4x with a phantom fraction of 0.8 and a slice factor >= 3. This is a significant improvement in sample efficiency, and is the key to achieving high sample efficiency with nested sampling in high dimensions.

In [None]:
# Heat map of bias (mean log Z minus the analytic value) over the (slice factor, k) grid.
# Cells are drawn at integer index positions and labelled with the grid values. k_array
# is unevenly spaced (it skips 6), and an extent of min..max would treat the first and
# last grid values as cell edges, so a value-based extent misplaces cells.
# NOTE(review): imshow needs a 2-D array; the loaded arrays also carry a live-points axis.
bias_grid = log_Z_mean_array - true_logZ
bias_limit = np.nanmax(np.abs(bias_grid))

plt.figure()
# Symmetric limits centre the diverging colormap on zero bias.
plt.imshow(bias_grid, origin='lower', aspect='auto', cmap='PuOr',
           vmin=-bias_limit, vmax=bias_limit)
plt.xticks(np.arange(len(k_array)), k_array)
plt.yticks(np.arange(len(s_array)), s_array)
plt.xlabel("Num Phantom Samples")
plt.ylabel("Num Slices")
plt.colorbar(label="Bias (nats)")
plt.show()


In [None]:
# log Z (+/- uncertainty) against phantom fraction k / (k + 1); one series per slice factor.
phantom_fraction = k_array / (k_array + 1)
unique_c = np.unique(s_array)
plt.figure()
for i, s in enumerate(s_array):
    plt.errorbar(
        phantom_fraction,
        log_Z_mean_array[i, :],
        yerr=log_Z_uncert_array[i, :],
        fmt='o',
        c=color(s, unique_c),
        label=f"Num Slices: {s}",
    )
plt.gca().axhline(true_logZ, color='k', linestyle='--')  # analytic log Z
plt.xlabel("Phantom Fraction")
plt.ylabel(r" $\log Z$ (nats)")
plt.legend(loc='lower right')
plt.show()


In [None]:
# log Z (+/- uncertainty) against sample efficiency; one colour per slice factor.
colors = s_array
unique_c = np.unique(colors)
plt.figure()
for j, s in enumerate(s_array):
    plt.errorbar(
        sample_efficiency_array[j, :],
        log_Z_mean_array[j, :],
        yerr=log_Z_uncert_array[j, :],
        fmt='o',
        c=color(s, unique_c),
        label=f"Num Slices: {s}",
    )
plt.gca().axhline(true_logZ, color='k', linestyle='--')  # analytic log Z
plt.xlabel("Sample Efficiency")
plt.ylabel(r" $\log Z$ (nats)")
plt.legend(loc='lower right')
plt.show()

In [None]:
# log Z (+/- uncertainty) against run time; one series per number of phantom samples.
colors = k_array
unique_c = np.unique(colors)
plt.figure()
for i, k in enumerate(k_array):
    plt.errorbar(
        run_time_array[:, i],
        log_Z_mean_array[:, i],
        yerr=log_Z_uncert_array[:, i],
        fmt='o',
        c=color(k, unique_c),
        label=f"Num Phantom Samples: {k}",
    )
plt.gca().axhline(true_logZ, color='k', linestyle='--')  # analytic log Z
plt.xlabel("Run Time (s)")
plt.ylabel(r" $\log Z$ (nats)")
plt.legend(loc='lower right')
plt.show()

In [None]:
# Run-time speed-up (relative to the first k column, k = k_array[0]) against phantom
# fraction; one series per slice factor.
# NOTE(review): every slice factor is plotted -- no "slice factor >= 3" filter is applied
# here; add one if only the consistent region is wanted.
colors = s_array
unique_c = np.unique(colors)
phantom_fraction = k_array / (k_array + 1)
plt.figure()
for j, s in enumerate(s_array):
    plt.errorbar(
        phantom_fraction,
        run_time_speed_up_array[j, :],
        fmt='o',
        c=color(s, unique_c),
        label=f"Num Slices: {s}",
    )
plt.gca().axhline(1, color='k', linestyle='--')  # break-even: no speed-up
plt.xlabel("Phantom Fraction")
plt.ylabel("Run Time Speed Up")
plt.legend(loc='lower right')
plt.show()

# Ablation study, large values of phantom fraction should introduce autocorrelation

We explore the impact of large values of phantom fraction on the resulting bias. We see that for large values of phantom fraction, the bias increases significantly. This is due to the fact that the phantom samples are no longer i.i.d. and thus the resulting log-evidence estimate is biased.

We restrict ourselves to `slice_factor=6`, which lies above the threshold identified earlier, so the bias is independent of the number of slices. With `s` the slice factor and `D` the number of dimensions, we explore `num_phantom` in `{0, 1, ..., s - 1} ∪ {s, 2s, ..., s(D - 1)} ∪ {sD - 1}`.


In [None]:
# Ablation: fix the slice factor and sweep num_phantom_save up to s * D - 1. For each
# setting, record the bias and RMS of log Z against the analytic evidence, and its cost.
# This cell is self-contained: the Ray worker's model lives inside that function, so the
# same Gaussian test problem is rebuilt here and every dependency is imported locally.
import time

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import random, numpy as jnp
from jaxns import Model, Prior, TerminationCondition
from jaxns.nested_sampler import StandardStaticNestedSampler
from jaxns.samplers import UniDimSliceSampler

tfpd = tfp.distributions


def _build_gaussian_problem(ndims):
    """Build the worker's Gaussian test problem and return ``(model, true_logZ)``.

    The prior is N(0, I). The log-likelihood is log N(x; 15 * 1, C), where C has unit
    variances and 0.99 off-diagonal covariance. The evidence is available in closed
    form as N(15 * 1; 0, I + C).
    """
    prior_mu = jnp.zeros(ndims)
    prior_cov = jnp.eye(ndims)
    data_mu = 15 * jnp.ones(ndims)
    data_cov = jnp.eye(ndims)
    data_cov = jnp.where(data_cov == 0., 0.99, data_cov)

    def prior_model():
        x = yield Prior(
            tfpd.MultivariateNormalTriL(
                loc=prior_mu,
                scale_tril=jnp.linalg.cholesky(prior_cov)
            )
        )
        return x

    def log_likelihood(x):
        return tfpd.MultivariateNormalTriL(
            loc=data_mu,
            scale_tril=jnp.linalg.cholesky(data_cov)
        ).log_prob(x)

    problem_model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
    problem_true_logZ = tfpd.MultivariateNormalTriL(
        loc=prior_mu,
        scale_tril=jnp.linalg.cholesky(prior_cov + data_cov)
    ).log_prob(data_mu)
    return problem_model, problem_true_logZ


model, ablation_true_logZ = _build_gaussian_problem(ndims)

bias_array = []
rms_array = []
num_likelihood_evals_array = []
num_slice_factor_array = []
phantom_fraction_array = []
run_time_array = []
total_num_samples_array = []
total_num_phantom_samples_array = []
m = 100  # independent runs per num_phantom_save value
num_slice_factor = 6
# {0, ..., s - 1} U {s, 2s, ..., s(D - 1)} U {sD - 1}, with s = slice factor and D = ndims.
num_phantom_save_array = (list(range(num_slice_factor))
                          + [num_slice_factor * i for i in range(1, ndims)]
                          + [num_slice_factor * ndims - 1])
for num_phantom_save in num_phantom_save_array:
    samples_per_iter = (1 + num_phantom_save)
    num_live_points_effective = ndims * 64  # lcm is not valid anymore.
    num_live_points = int(num_live_points_effective / samples_per_iter) + 1

    nested_sampler = StandardStaticNestedSampler(
        model=model,
        num_live_points=num_live_points,
        max_samples=100000,
        sampler=UniDimSliceSampler(
            model=model,
            num_slices=model.U_ndims * num_slice_factor,
            num_phantom_save=num_phantom_save,
            midpoint_shrink=True,
            perfect=True
        ),
        init_efficiency_threshold=0.1,
        num_parallel_workers=1
    )

    # Named ns_run (not run) so the Ray remote function `run` is not shadowed.
    @jax.jit
    def ns_run(key):
        termination_reason, state = nested_sampler._run(key=key, term_cond=TerminationCondition())
        results = nested_sampler._to_results(termination_reason=termination_reason, state=state, trim=False)
        return results.log_Z_mean, results.total_num_likelihood_evaluations, results.total_num_samples, results.total_phantom_samples

    # Compile once up front and time only the compiled executable.
    run_compiled = ns_run.lower(random.PRNGKey(0)).compile()

    t0 = time.time()
    results = []
    for i in range(m):
        out = run_compiled(random.PRNGKey(i))
        out[0].block_until_ready()  # JAX dispatch is asynchronous
        results.append(out)
    elapsed = time.time() - t0

    # Keep per-run values (length m each) so the RMS is taken over runs; float64 avoids
    # float32 rounding of large evaluation counts.
    per_run = np.asarray(results, dtype=np.float64)  # [m, 4]
    log_Z, num_likelihood_evals, total_num_samples, total_phantom_samples = per_run.T
    log_Z_error = log_Z - float(ablation_true_logZ)

    print(f"Time taken n={num_live_points} k={num_phantom_save}: {elapsed}")
    run_time_array.append(elapsed / m)
    bias_array.append(np.mean(log_Z_error))
    rms_array.append(np.sqrt(np.mean(np.square(log_Z_error))))
    num_likelihood_evals_array.append(np.mean(num_likelihood_evals))
    num_slice_factor_array.append(num_slice_factor)
    phantom_fraction_array.append(num_phantom_save / samples_per_iter)
    total_num_samples_array.append(np.mean(total_num_samples))
    total_num_phantom_samples_array.append(np.mean(total_phantom_samples))

# Save the results into npz file
save_file = "bias_experiment_results_ablation.npz"
np.savez(
    save_file,
    bias_array=np.asarray(bias_array),
    rms_array=np.asarray(rms_array),
    num_likelihood_evals_array=np.asarray(num_likelihood_evals_array),
    num_slice_factor_array=np.asarray(num_slice_factor_array),
    phantom_fraction_array=np.asarray(phantom_fraction_array),
    run_time_array=np.asarray(run_time_array),
    total_num_samples_array=np.asarray(total_num_samples_array),
    total_num_phantom_samples_array=np.asarray(total_num_phantom_samples_array)
)

In [None]:
# Load the ablation results (one entry per num_phantom_save value) into arrays of the same names.
save_file = "bias_experiment_results_ablation.npz"

# The context manager closes the underlying zip file; each array is read into memory
# when it is accessed inside the block.
with np.load(save_file) as npzfile:
    bias_array = npzfile['bias_array']
    rms_array = npzfile['rms_array']
    num_likelihood_evals_array = npzfile['num_likelihood_evals_array']
    num_slice_factor_array = npzfile['num_slice_factor_array']
    phantom_fraction_array = npzfile['phantom_fraction_array']
    run_time_array = npzfile['run_time_array']
    total_num_samples_array = npzfile['total_num_samples_array']
    total_num_phantom_samples_array = npzfile['total_num_phantom_samples_array']



In [None]:
# Plot bias with RMS y error bars vs number of retained phantom samples. Points are
# coloured by number of likelihood evaluations, shown with a colour bar rather than labels.
plt.figure()
colors = num_likelihood_evals_array
unique_c = np.unique(colors)
# Invert phantom_fraction = k / (k + 1)  ->  k = f / (1 - f). Unlike 1 / (1/f - 1), this
# stays finite at f = 0 (no phantom samples) without a divide-by-zero warning.
num_phantom_retained = phantom_fraction_array / (1 - phantom_fraction_array)
for _k, _b, _rms, _c in zip(num_phantom_retained, bias_array, rms_array, colors):
    plt.errorbar(_k, _b, yerr=_rms, fmt='o', c=color(_c, unique_c))
plt.xlabel("Num Phantom Samples Retained")
plt.ylabel("Bias (nats)")
plt.gca().axhline(0, color='k', linestyle='--')
# Reference line at the problem dimension. Uses ndims instead of relying on a
# notebook-level `model`.
# NOTE(review): model.U_ndims should equal ndims for this ndims-dim Gaussian prior -- confirm.
plt.gca().axvline(ndims, color='r', linestyle='--')
# Create custom mappable for colorbar (errorbar() doesn't make a mappable like scatter)
sm = plt.cm.ScalarMappable(cmap=cm, norm=plt.Normalize(np.min(colors), np.max(colors)))
sm.set_array([])
plt.colorbar(sm, label="Num Likelihood Evaluations", ax=plt.gca())
plt.savefig("bias_vs_phantom_fraction_ablation.png", dpi=300)
plt.savefig("bias_vs_phantom_fraction_ablation.pdf", dpi=300)
plt.show()