In [1]:
%env JAX_ENABLE_X64 True

env: JAX_ENABLE_X64=True


In [2]:
import time
start_time = time.time()

In [38]:
import logging
import multiprocessing

import numpy as np

from scipy.integrate import solve_ivp
from scipy.stats import gaussian_kde
from scipy.stats import multivariate_normal as mvn

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as jmvn
from jax.scipy.stats import gaussian_kde as jgaussian_kde

import numdifftools

from stein_thinning.thinning import thin_gf

import lotka_volterra
import utils.caching
from utils.caching import make_cached_s3
from utils.parallel import apply_along_axis_parallel, get_map_parallel_joblib
from utils.paths import S3_BUCKET_NAME

In [4]:
logging.basicConfig()
logging.getLogger(utils.caching.__name__).setLevel(logging.DEBUG)

In [5]:
cached = make_cached_s3(S3_BUCKET_NAME)

In [6]:
map_parallel = get_map_parallel_joblib(multiprocessing.cpu_count())

Read-only input datasets:

In [7]:
@cached(batch_size=lotka_volterra.n_chains, read_only=True)
def rw_samples(i: int) -> np.ndarray:
    raise NotImplementedError

In [20]:
@cached(batch_size=lotka_volterra.n_chains, read_only=True)
def rw_log_p(i: int) -> np.ndarray:
    raise NotImplementedError

### Calculate log-pdf and gradient

In [9]:
def calculate_logpdf_and_grad(point, kernel_locations, kernel_cov):
    pdf = mvn.pdf(kernel_locations, mean=point, cov=kernel_cov)
    weighted_grad = np.einsum('ij,kj->ki', np.linalg.inv(kernel_cov), kernel_locations - point) * pdf.reshape(-1, 1)
    mixture_pdf = np.average(pdf)
    return np.concatenate([[np.log(mixture_pdf)], np.average(weighted_grad, axis=0) / mixture_pdf])

#### Verify KDE values against `scipy`

In [28]:
rng = np.random.default_rng(12345)

In [30]:
mean = np.array([0., 0.])
cov = np.array([
    [1., 0.8],
    [0.8, 1.],
])

In [31]:
n = 1000

In [32]:
sample = mvn.rvs(mean=mean, cov=cov, size=n, random_state=rng)

In [34]:
scipy_kde = gaussian_kde(sample.T, bw_method='silverman')

In [69]:
def scipy_logpdf(x):
    return scipy_kde.logpdf(x.T).squeeze()

In [76]:
logpdf_and_grad = np.apply_along_axis(lambda point: calculate_logpdf_and_grad(point, sample, scipy_kde.covariance), 1, sample)

In [77]:
baseline_logpdf = np.apply_along_axis(scipy_logpdf, 1, sample)
baseline_grad = np.apply_along_axis(numdifftools.Gradient(scipy_logpdf), 1, sample)

In [78]:
np.testing.assert_allclose(logpdf_and_grad[:, 0], baseline_logpdf)

In [79]:
np.testing.assert_allclose(logpdf_and_grad[:, 1:], baseline_grad)

### Check if a JAX-based version provide a performance improvement

In [10]:
@jax.jit
def calculate_logpdf_and_grad_jax(point, kernel_locations, kernel_cov):
    pdf = jmvn.pdf(kernel_locations, mean=point, cov=kernel_cov)
    kernel_cov_inv = jnp.linalg.inv(kernel_cov)
    weighted_grad = ((kernel_cov_inv @ (kernel_locations - point).T) * pdf).T
    mixture_pdf = jnp.average(pdf)
    return jnp.concatenate([jnp.log(mixture_pdf).reshape(1), jnp.average(weighted_grad, axis=0) / mixture_pdf])

In [11]:
i = 0
kernel_cov = gaussian_kde(rw_samples[i].T, bw_method='silverman').covariance
kernel_locations = rw_samples[i]
point = rw_samples[i][0]

DEBUG:utils.caching:Reading from disk cache: rw_samples_0
DEBUG:utils.caching:Reading from S3 gradient-free-mcmc-postprocessing/rw_samples_0.npy


In [12]:
%%timeit
calculate_logpdf_and_grad(point, kernel_locations, kernel_cov)

118 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%%timeit
calculate_logpdf_and_grad_jax(jnp.array(point), jnp.array(kernel_locations), jnp.array(kernel_cov))

275 μs ± 151 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)


### KDE approximation for the MCMC output

In [None]:
chunk_size = 1000

In [81]:
kde = gaussian_kde(rw_samples[0].T, bw_method='silverman')

In [87]:
def scipy_logpdf(x):
    return kde.logpdf(x.T).squeeze()

In [90]:
%%time
grad100 = np.apply_along_axis(numdifftools.Gradient(scipy_logpdf), 1, rw_samples[0][:100])

CPU times: user 10min 29s, sys: 3min 16s, total: 13min 46s
Wall time: 7min 12s


In [92]:
%%time
logpdf100 = np.apply_along_axis(kde.logpdf, 1, rw_samples[0][:100])

CPU times: user 5.27 s, sys: 1.85 s, total: 7.11 s
Wall time: 3.75 s


In [96]:
np.testing.assert_allclose(logpdf100.squeeze(), rw_kde_log_grad_q[0][:100, 0])

In [98]:
np.argmax(np.abs(grad100 - rw_kde_log_grad_q[0][:100, 1:]))

np.int64(155)

In [99]:
grad100.flatten()[155]

np.float64(10858.240920035441)

In [100]:
rw_kde_log_grad_q[0][:100, 1:].flatten()[155]

np.float64(2.7800323043627637e-12)

In [97]:
np.testing.assert_allclose(grad100, rw_kde_log_grad_q[0][:100, 1:])

AssertionError: 
Not equal to tolerance rtol=1e-07, atol=0

Mismatched elements: 356 / 400 (89%)
Max absolute difference among violations: 10858.24092004
Max relative difference among violations: 3.4427717e+20
 ACTUAL: array([[ 1.584052e+02, -1.084620e+02,  7.956389e+01,  2.050563e+01],
       [ 1.584052e+02, -1.084620e+02,  7.956389e+01,  2.050563e+01],
       [-5.509490e+02,  3.772758e+02, -2.767567e+02, -7.132723e+01],...
 DESIRED: array([[ 1.583927e+02, -1.084618e+02,  7.956389e+01,  2.050563e+01],
       [ 1.583927e+02, -1.084618e+02,  7.956389e+01,  2.050563e+01],
       [-5.509567e+02,  3.772759e+02, -2.767567e+02, -7.132723e+01],...

In [14]:
@cached(batch_size=len(lotka_volterra.theta_inits))
def rw_kde_log_grad_q(i: int) -> np.ndarray:
    kernel_cov = gaussian_kde(rw_samples[i].T, bw_method='silverman').covariance
    def evaluate_for_row(row):
        logpdf, grad_logpdf = calculate_logpdf_and_grad(row, rw_samples[i], kernel_cov)
        return np.concatenate([[logpdf], grad_logpdf])
    return apply_along_axis_parallel(evaluate_for_row, 1, rw_samples[i], chunk_size, map_parallel)

In [15]:
@cached(batch_size=len(lotka_volterra.theta_inits))
def rw_kde_log_grad_q_jax(i: int) -> np.ndarray:
    points = jnp.array(rw_samples[i])
    kernel_cov = jgaussian_kde(points.T, bw_method='silverman').covariance
    def evaluate_for_row(point):
        return np.array(calculate_logpdf_and_grad_jax(jnp.array(point), points, kernel_cov))
    return apply_along_axis_parallel(evaluate_for_row, 1, rw_samples[i], chunk_size, map_parallel)

In [17]:
[np.max(np.abs(rw_kde_log_grad_q_jax(i) - rw_kde_log_grad_q(i)) / np.median(np.abs(rw_kde_log_grad_q(i)), axis=0)) for i in range(lotka_volterra.n_chains)]

DEBUG:utils.caching:Reading from disk cache: rw_kde_log_grad_q_0
DEBUG:utils.caching:Reading from S3 gradient-free-mcmc-postprocessing/rw_kde_log_grad_q_0.npy
DEBUG:utils.caching:Reading from disk cache: rw_kde_log_grad_q_1
DEBUG:utils.caching:Reading from S3 gradient-free-mcmc-postprocessing/rw_kde_log_grad_q_1.npy
DEBUG:utils.caching:Reading from disk cache: rw_kde_log_grad_q_2
DEBUG:utils.caching:Reading from S3 gradient-free-mcmc-postprocessing/rw_kde_log_grad_q_2.npy
DEBUG:utils.caching:Reading from disk cache: rw_kde_log_grad_q_3
DEBUG:utils.caching:Reading from S3 gradient-free-mcmc-postprocessing/rw_kde_log_grad_q_3.npy
DEBUG:utils.caching:Reading from disk cache: rw_kde_log_grad_q_4
DEBUG:utils.caching:Reading from S3 gradient-free-mcmc-postprocessing/rw_kde_log_grad_q_4.npy


[np.float64(1.3631493522886603e-12),
 np.float64(1.7654255704939414e-12),
 np.float64(4.5276080328998e-11),
 np.float64(8.261661579816863e-11),
 np.float64(1.0059566778282408e-11)]

In [27]:
fig, ax = plt.subplots()
ax.s

np.float64(10.121746406743608)

### Apply gradient-free thinning

In [23]:
n_points = 100

In [24]:
idx = thin_gf(
    rw_samples[0],
    rw_log_p[0],
    rw_kde_log_grad_q[0][:, 0],
    rw_kde_log_grad_q[0][:, 1:],
    n_points=n_points,
    preconditioner='med',
)
idx

  np.exp(log_q_m_p[ind1] + log_q_m_p[ind2]) * vfk0(sample[ind1], sample[ind2], gradient_q[ind1], gradient_q[ind2])
  k0 += 2 * integrand(slice(None), [idx[i - 1]])


array([189892,      2,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0,      0,      0,      0,      0,
            0,      0,      0,      0], dtype=uint32)

Notebook execution took:

In [18]:
time.time() - start_time

31.76407742500305