In [10]:
import importlib

from matplotlib import pyplot as plt
import torch
from torch import Tensor

import deep_tensor as dt

from examples.double_banana.double_banana import DoubleBanana
from examples.plotting import corner_plot

In [11]:
# Plot style and global torch RNG seed, so every stochastic cell below is reproducible.
plt.style.use("../plotstyle.mplstyle")
torch.manual_seed(1)

# Re-execute deep_tensor/__init__.py so edits there are picked up without a kernel
# restart. The trailing ";" suppresses the echoed module repr, which otherwise
# writes an absolute local path into the saved notebook output.
# NOTE: importlib.reload is shallow -- submodules already in sys.modules, and objects
# imported from them earlier (e.g. by `examples.*`), are NOT reloaded. Prefer
# `%load_ext autoreload` + `%autoreload 2` for development.
importlib.reload(dt);

<module 'deep_tensor' from '/Users/adeb0907/Documents/usyd/deep-tensor-py/deep_tensor/__init__.py'>

Build a BAE-inspired, sample-based preconditioner from joint samples $(y, x)$ of the model.

In [12]:
# Lagrange1 polynomial basis with 30 elements.
poly = dt.Lagrange1(num_elems=30)

# Double-banana model; `sigma` is presumably the observation-noise scale (confirm in
# DoubleBanana) and `data` the observed values.
sigma = 0.3
data = torch.tensor([3.0, 5.0])
model = DoubleBanana(sigma, data)

# Linear sample-based preconditioner fitted to joint samples (y, x): x is drawn
# from a standard normal, and y = param_to_obs(x) with no noise added.
n_xs = 1000
xs = torch.normal(0.0, 1.0, size=(n_xs, 2))
ys = model.param_to_obs(xs)
samples = torch.cat((ys, xs), dim=1)
preconditioner = dt.SampleBasedPreconditioner(samples)

Build a DIRT approximation of the joint density, using a tempering bridge to construct its layers.

In [13]:
# Tempering bridge: DIRT layers are built for an increasing sequence of the
# tempering parameter (printed as "Beta" during construction), ending at 1.
bridge = dt.Tempering()

dirt = dt.DIRT(model.neglogpri_joint, model.negloglik_joint, preconditioner, poly, bridge=bridge)

[32m[DIRT][0m Iter:  1 | Cum. Fevals: 2.00e+03 | Cum. Time: 2.11e-02 s | Beta: 0.0001 | ESS: 0.9988
[94m[ALS][0m  Iter | Func Evals | Max Rank | Max Local Error | Mean Local Error | Max Debug Error | Mean Debug Error
[94m[ALS][0m     1 |      16492 |       22 |     1.00000e+00 |      1.00000e+00 |     1.61962e-02 |      6.28852e-03
[94m[ALS][0m  TT-cross complete. Final TT ranks: [22, 22, 1].
[32m[DIRT][0m Iter:  2 | Cum. Fevals: 3.70e+04 | Cum. Time: 9.54e-02 s | Beta: 0.0132 | ESS: 0.5014 | DHell: 0.0044
[94m[ALS][0m  Iter | Func Evals | Max Rank | Max Local Error | Mean Local Error | Max Debug Error | Mean Debug Error
[94m[ALS][0m     1 |      17143 |       24 |     8.50784e-01 |      7.48471e-01 |     1.26647e+00 |      3.89890e-01
[94m[ALS][0m  TT-cross complete. Final TT ranks: [24, 21, 1].
[32m[DIRT][0m Iter:  3 | Cum. Fevals: 7.33e+04 | Cum. Time: 3.74e-01 s | Beta: 0.0882 | ESS: 0.4766 | DHell: 0.0907
[94m[ALS][0m  Iter | Func Evals | Max Rank | Max Local E



[94m[ALS][0m     1 |      23467 |       28 |     5.24098e-01 |      5.13670e-01 |     7.57166e-01 |      3.05050e-01
[94m[ALS][0m  TT-cross complete. Final TT ranks: [28, 25, 1].
[32m[DIRT][0m Iter:  5 | Cum. Fevals: 1.65e+05 | Cum. Time: 2.17e+00 s | Beta: 1.0000 | ESS: 0.4546 | DHell: 0.1480
[94m[ALS][0m  Iter | Func Evals | Max Rank | Max Local Error | Mean Local Error | Max Debug Error | Mean Debug Error
[94m[ALS][0m     1 |      27001 |       30 |     8.28918e-01 |      7.03403e-01 |     7.88920e-01 |      2.65924e-01
[94m[ALS][0m  TT-cross complete. Final TT ranks: [30, 27, 1].
[32m[DIRT][0m DIRT construction complete.
[32m[DIRT][0m  • Layers: 5.
[32m[DIRT][0m  • Total function evaluations: 218568.
[32m[DIRT][0m  • Total time: 3.26 s.


In [14]:
# Held-out test parameters and their simulated (noise-free) observations.
xs_test = torch.randn(10, 2)
ys_test = model.param_to_obs(xs_test)
# NOTE(review): if `var_error` is a variance rather than a standard deviation, the
# commented-out noise line below should scale by its square root -- confirm.
# ys_test += model.var_error * torch.randn_like(ys_test)  # add error

# Reference samples for the Monte Carlo Hellinger-distance (DHell) estimate.
n_rs = 50000
rs = dirt.reference.random(d=2, n=n_rs)

In [15]:
def plot_density_comparison(xs_grid, ys_grid, fxs_true, fxs_dirt, x_true, dhell, fname, close=False):
    """Plot the true and DIRT conditional densities side by side and save the figure.

    Parameters
    ----------
    xs_grid, ys_grid:
        1D grid coordinates along x_0 (horizontal axis) and x_1 (vertical axis).
    fxs_true, fxs_dirt:
        Density values on the grid, passed directly to ``pcolormesh`` -- presumably
        of shape ``(len(ys_grid), len(xs_grid))``; TODO confirm at the call site.
    x_true:
        Point ``(x_0, x_1)`` marked with a black cross on both panels.
    dhell:
        Hellinger distance shown in the figure's super-title.
    fname:
        Output path passed to ``savefig``. Missing parent directories are created,
        so paths such as ``"figures/..."`` work on a fresh checkout.
    close:
        If True, close the figure after saving. Use this when calling in a loop so
        open figures do not accumulate. The default (False) keeps the figure open
        so it is still displayed inline in the notebook.
    """
    # Local import so this cell does not depend on extra top-level imports.
    from pathlib import Path

    fig, axes = plt.subplots(
        nrows=1, ncols=2,
        figsize=(8, 4),
        sharex=True, sharey=True,
    )

    axes[0].pcolormesh(xs_grid, ys_grid, fxs_true, rasterized=True)
    axes[1].pcolormesh(xs_grid, ys_grid, fxs_dirt, rasterized=True)
    axes[0].set_title(r"$f(x)$ (True)")
    axes[1].set_title(r"$\hat{f}(x)$ (DIRT)")
    axes[0].set_ylabel(r"$x_{1}$")

    for ax in axes:
        ax.scatter(*x_true, c="k", marker="x", s=5)
        ax.set_xlabel(r"$x_{0}$")

    # Use the Figure API rather than the pyplot state machine so the title and the
    # save always target this figure, regardless of which figure is "current".
    fig.suptitle(r"$\mathcal{D}_{\mathrm{H}}$" + f": {dhell:.4f}")

    # savefig raises if the target directory does not exist.
    Path(fname).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(fname)

    if close:
        plt.close(fig)

Generate a uniform 2D grid on $[-3, 3]^2$ on which to evaluate the target PDF.

In [None]:
# Uniform square evaluation grid over [-grid_lim, grid_lim]^2.
n_grid = 100
grid_lim = 3.0

xs_grid = torch.linspace(-grid_lim, grid_lim, n_grid)
ys_grid = torch.linspace(-grid_lim, grid_lim, n_grid)

# Grid spacing. The grid is uniform and square, so dx**2 is the cell area used when
# normalising densities evaluated on it.
dx = xs_grid[1] - xs_grid[0]

# Flattened (n_grid**2, 2) array of points (x_0, x_1) with x_0 varying fastest. That is
# the row-major order of `reshape(n_grid, n_grid)` with rows indexed by x_1, matching
# pcolormesh(xs_grid, ys_grid, C). Vectorised instead of a Python loop over 0-d tensors.
xx_grid, yy_grid = torch.meshgrid(xs_grid, ys_grid, indexing="xy")
grid = torch.stack((xx_grid.flatten(), yy_grid.flatten()), dim=1)

In [None]:
# NOTE(review): this whole cell is commented out. It is exploratory code that compares
# the true and DIRT conditional densities on the grid, estimates Hellinger distances and
# plots a short pCN trace. Either re-enable it or remove it from the final notebook.
# If re-enabled, make sure `rs` still holds the n_rs reference samples drawn in the
# test-data cell (a later cell may reuse that name), otherwise the
# `y_i.repeat(n_rs, 1)` / `xs` hstack below will not line up.
# for i, y_i in enumerate(ys_test):

#     y_is = y_i.repeat(n_grid**2, 1)
#     yx_is = torch.hstack((y_is, grid))

#     # Evaluate true conditional density on grid
#     neglogfxs_true = model.potential_joint(yx_is)
#     fxs_true = torch.exp(-neglogfxs_true)
#     fxs_true /= (fxs_true.sum() * dx**2)
#     fxs_true = fxs_true.reshape(n_grid, n_grid)

#     # Evaluate CIRT density on grid
#     neglogfxs_ys = dirt.eval_potential(y_is)
#     rs_grid, neglogfxs_grid = dirt.eval_rt(torch.hstack((y_is, grid)))
#     neglogfxs_dirt = neglogfxs_grid - neglogfxs_ys
#     fxs_dirt = torch.exp(-neglogfxs_dirt).reshape(n_grid, n_grid)

#     # Estimate Hellinger distance
#     xs, neglogfxs_dirt = dirt.eval_cirt(y_i, rs)
#     yxs = torch.hstack((y_i.repeat(n_rs, 1), xs))
#     neglogfxs_true = model.potential_joint(yxs)

#     def potential(xs: Tensor) -> Tensor:
#         y_is = y_i.repeat(xs.shape[0], 1)
#         yxs = torch.hstack((y_is, xs))
#         return model.potential_joint(yxs)
    
#     ms = dt.run_dirt_pcn(potential, dirt, n=1_000, dt=10.0, y_obs=y_i)

#     fig, ax = plt.subplots()
#     ax.pcolormesh(xs_grid, ys_grid, fxs_true, rasterized=True)
#     ax.plot(*ms[:10].T, c="white")
#     plt.show()

    # dhell2 = dt.compute_f_divergence(-neglogfxs_dirt, -neglogfxs_true, div="h2")
    # dhell = dhell2.sqrt()
    # print(f"Posterior {i}: DHell {dhell:.4f}")

    # plot_density_comparison(
    #     xs_grid, 
    #     ys_grid, 
    #     fxs_true, 
    #     fxs_dirt, 
    #     xs_test[i], 
    #     dhell=dhell, 
    #     fname=f"figures/02_posterior_{i}.pdf"
    # )



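Run pCN-MCMC, an independence sampler and importance sampling for each test observation, using the conditional DIRT approximation.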

In [None]:
# Number of pCN steps, and of conditional-DIRT samples, per test observation.
n_samples = 5000


def make_conditional_potential(y_obs: Tensor):
    """Return ``x -> model.potential_joint([y_obs, x])`` for a batch of parameters ``x``.

    ``model.potential_joint`` is the negative log of the (unnormalised) joint density,
    so the returned function is the unnormalised negative log posterior of ``x`` given
    ``y_obs``. ``x`` is presumably of shape ``(n, 2)``.
    """
    def potential(xs: Tensor) -> Tensor:
        y_is = y_obs.repeat(xs.shape[0], 1)
        yxs = torch.hstack((y_is, xs))
        return model.potential_joint(yxs)

    return potential


for i, y_i in enumerate(ys_test):

    potential = make_conditional_potential(y_i)

    # pCN-MCMC driven by the DIRT approximation conditioned on y_i.
    pcn_res = dt.run_dirt_pcn(potential, dirt, n=n_samples, y_obs=y_i)

    # i.i.d. samples from the conditional DIRT (CIRT) density. Loop-local names keep
    # the module-level `rs` (n_rs reference samples for the Hellinger estimate) and
    # `xs` (preconditioner samples) intact, so earlier cells can still be re-run.
    rs_i = dirt.reference.random(d=2, n=n_samples)
    xs_i, neglogfxs_cirt = dirt.eval_cirt(y_i, rs_i)
    neglogfxs_target = potential(xs_i)

    # The same CIRT samples serve as independence-sampler proposals and as
    # importance-sampling draws.
    indep_res = dt.run_independence_sampler(xs_i, neglogfxs_cirt, neglogfxs_target)
    is_res = dt.run_importance_sampling(neglogfxs_cirt, neglogfxs_target)

    print(
        f"Posterior {i}: "
        f"pCN acceptance rate: {pcn_res.acceptance_rate} | "
        f"independence-sampler acceptance rate: {indep_res.acceptance_rate} | "
        f"IS ESS: {is_res.ess}"
    )