In [1]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules before
# each cell execution, so edits to local source files are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [17]:
import numpy as np
import torch
from sbi import analysis as analysis
from sbi import utils as utils
from sbi.inference import NPE, simulate_for_sbi
from sbi.utils.user_input_checks import (
    check_sbi_inputs,
    process_prior,
    process_simulator,
)
# Remind the reader that a custom sbi build is needed: the training cell below
# passes a `model_info` argument to `train`.
# NOTE(review): presumably that argument only exists in the fork linked here — confirm.
print("WARNING: CUSTOM VERSION OF SBI REQUIRED")
print("This can be found at: https://github.com/james-alvey-42/sbi-acp")
import matplotlib.pyplot as plt
import corner

In [39]:
# Seed both random number generators used in this notebook (torch for the
# simulator, numpy for the reference experiments) so that a fresh
# "Restart Kernel -> Run All" reproduces the same numbers.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# Uniform box prior on [-2, 2] in each of the num_dim parameter dimensions.
num_dim = 3
prior = utils.BoxUniform(low=-2 * torch.ones(num_dim), high=2 * torch.ones(num_dim))

def simulator(theta):
    """Linear-Gaussian toy simulator.

    Shifts ``theta`` by +1 and adds independent Gaussian noise with standard
    deviation 0.1 to every entry. The output has the same shape as ``theta``.
    """
    shifted = theta + 1.0
    noise = torch.randn_like(theta) * 0.1
    return shifted + noise

# Check prior, simulator, consistency
prior, num_parameters, prior_returns_numpy = process_prior(prior)
simulator = process_simulator(simulator, prior, prior_returns_numpy)
check_sbi_inputs(simulator, prior)

# Ground-truth parameters (all ones) and the corresponding simulated
# observation. The width follows num_dim instead of a hard-coded 3 so that the
# observation stays consistent with the prior if the dimension is changed.
theta_obs = torch.ones(1, num_dim)
x_obs = simulator(theta_obs)

In [6]:
def get_R(samples):
    """
    Computes the Gelman-Rubin (GR) statistic for convergence assessment. The
    GR statistic is a convergence diagnostic used to assess whether multiple
    Markov chains have converged to the same distribution. Values close to 1
    indicate convergence. For details see
    https://en.wikipedia.org/wiki/Gelman-Rubin_statistic

    With B the between-chain variance and W the mean within-chain variance,
    the returned quantity is ``(N_steps - 1) / N_steps + B / (N_steps * W)``
    (no square root is taken).

    Parameters:
    -----------
    samples : array_like
        MCMC samples with dimensions (N_steps, N_chains, N_parameters).
        Anything accepted by ``numpy.asarray`` (e.g. nested lists) is allowed.

    Returns:
    --------
    R : numpy.ndarray
        Array of length N_parameters containing the Gelman-Rubin statistic for
        each parameter. Values close to 1 indicate convergence.

    Raises:
    -------
    ValueError
        If ``samples`` is not three-dimensional, or if it contains fewer than
        two chains or fewer than two steps (the statistic is undefined then).

    """

    # Accept any array-like, not only ndarrays
    samples = np.asarray(samples)
    if samples.ndim != 3:
        raise ValueError(
            "samples must have shape (N_steps, N_chains, N_parameters), "
            f"got an array with {samples.ndim} dimension(s)"
        )

    # Get the shapes
    n_steps, n_chains, _ = samples.shape
    if n_chains < 2:
        raise ValueError("the Gelman-Rubin statistic needs at least 2 chains")
    if n_steps < 2:
        raise ValueError("the Gelman-Rubin statistic needs at least 2 steps per chain")

    # Chain means, shape (N_chains, N_parameters)
    chain_mean = np.mean(samples, axis=0)

    # Global mean, shape (N_parameters,)
    global_mean = np.mean(chain_mean, axis=0)

    # Variance between the chain means (B)
    variance_of_means = (
        n_steps
        / (n_chains - 1)
        * np.sum((chain_mean - global_mean[None, :]) ** 2, axis=0)
    )

    # Unbiased variance within each individual chain
    intra_chain_variance = np.var(samples, axis=0, ddof=1)

    # And its averaged value over the chains (W)
    mean_intra_chain_variance = np.mean(intra_chain_variance, axis=0)

    # First term
    term_1 = (n_steps - 1) / n_steps

    # Second term
    term_2 = variance_of_means / mean_intra_chain_variance / n_steps

    # This is the R (as a vector running over the parameters)
    return term_1 + term_2

In [None]:
# Build an ensemble of independent NPE inference objects. Each member gets its
# own set of 2000 prior simulations plus bookkeeping slots that the training
# cell fills in later.
num_models = 10
models = {}
for idx in range(num_models):
    model_name = f"model_{idx}"
    inference = NPE(prior=prior)

    # Draw this member's training set and register it with the inference object
    theta, x = simulate_for_sbi(simulator, proposal=prior, num_simulations=2000)
    inference = inference.append_simulations(theta, x)

    models[model_name] = dict(
        inference=inference,
        posterior_samples=[],
        theta=theta,
        x=x,
        epochs_trained=0,
        train_losses=[],
        validation_losses=[],
        state_dict_loc=f"state_dict/{model_name}_epoch_",
    )

In [None]:
max_epochs = 200

# Advance every ensemble member by one `train` call per outer iteration and
# record 1000 posterior samples at x_obs after each call.
# NOTE(review): `max_num_epochs=0` together with the `model_info` argument
# presumably makes the custom sbi fork train for a single epoch per call —
# confirm against the fork.
for epoch in range(max_epochs):
    for model_info in models.values():
        trainer = model_info["inference"]
        estimator = trainer.train(
            max_num_epochs=0,
            # learning_rate=1e-4,
            model_info=model_info,
            force_first_round_loss=True,
        )
        posterior = trainer.build_posterior(estimator)
        model_info["posterior_samples"].append(posterior.sample((1000,), x=x_obs))
        model_info["epochs_trained"] += 1

# Copy the loss histories kept by each inference object into the bookkeeping dict
for model_info in models.values():
    summary = model_info["inference"]._summary
    model_info["train_losses"] = summary["training_loss"]
    model_info["validation_losses"] = summary["validation_loss"]


In [None]:
# Loss curves of all ensemble members: training (left) and validation (right).
# Explicit axes are used instead of re-selecting subplots through the pyplot
# state machine on every loop iteration, and both panels are labelled so the
# figure can be read on its own.
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(10, 5))
for model_info in models.values():
    ax_train.plot(model_info["train_losses"])
    ax_val.plot(model_info["validation_losses"])
ax_train.set(title="Training loss", xlabel="Epoch", ylabel="Loss")
ax_val.set(title="Validation loss", xlabel="Epoch", ylabel="Loss")
fig.tight_layout()

In [None]:
# Render one corner plot per epoch (all ensemble members overlaid) and stitch
# the frames into an animated GIF.
import os

import imageio

# Single source of truth for the output location: the frames are written to and
# read back from the same directory.
figure_dir = "../figures"
os.makedirs(figure_dir, exist_ok=True)

colors = plt.cm.viridis(np.linspace(0, 1, num_models))

frame_paths = []
for epoch in range(max_epochs):
    fig = None
    for m_idx, model in enumerate(models):
        model_info = models[model]
        epoch_samples = model_info["posterior_samples"][epoch].numpy()
        if m_idx == 0:
            # The first model creates the figure and carries labels and truths
            fig = corner.corner(
                epoch_samples,
                labels=[r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"],
                truths=[1.0, 1.0, 1.0],
                color=colors[m_idx],
            )
        else:
            # Remaining models are overlaid on the same figure
            corner.corner(epoch_samples, fig=fig, color=colors[m_idx])

    frame_path = f"{figure_dir}/epoch_{epoch}.png"
    fig.savefig(frame_path)
    # Close (not just clear) the figure; otherwise one figure per epoch stays
    # alive for the rest of the session.
    plt.close(fig)
    frame_paths.append(frame_path)

images = [imageio.imread(frame_path) for frame_path in frame_paths]
imageio.mimsave(f"{figure_dir}/posterior.gif", images, duration=1.5)

In [None]:
# Gelman-Rubin statistic across the ensemble after every epoch, treating each
# model's posterior samples as one chain.
R_values = []
for epoch in range(max_epochs):
    # One (N_samples, N_parameters) array per model ...
    per_model_samples = [
        info["posterior_samples"][epoch].numpy() for info in models.values()
    ]
    # ... rearranged to the (N_steps, N_chains, N_parameters) layout of get_R
    R_values.append(get_R(np.transpose(per_model_samples, (1, 0, 2))))
plt.plot(np.array(R_values))


In [None]:
# Loss curves of all ensemble members: training (left) and validation (right).
# Explicit axes are used instead of re-selecting subplots through the pyplot
# state machine on every loop iteration, and both panels are labelled so the
# figure can be read on its own.
fig, (ax_train, ax_val) = plt.subplots(1, 2, figsize=(10, 5))
for model_info in models.values():
    ax_train.plot(model_info["train_losses"])
    ax_val.plot(model_info["validation_losses"])
ax_train.set(title="Training loss", xlabel="Epoch", ylabel="Loss")
ax_val.set(title="Validation loss", xlabel="Epoch", ylabel="Loss")
fig.tight_layout()

In [None]:
# Reference experiment: R for 10 chains of 1000 i.i.d. standard-normal draws in
# 3 dimensions, repeated max_epochs times. All chains share one distribution,
# so this shows the scatter of R around 1 at this sample size.
R_values = [
    get_R(np.transpose([np.random.randn(1000, 3) for _ in range(10)], (1, 0, 2)))
    for _trial in range(max_epochs)
]
plt.plot(np.array(R_values))

In [None]:
# Reference experiment: R for 10 chains of 100 i.i.d. standard-normal draws in
# 3 dimensions, repeated max_epochs times. All chains share one distribution,
# so this shows the scatter of R around 1 at this sample size.
R_values = [
    get_R(np.transpose([np.random.randn(100, 3) for _ in range(10)], (1, 0, 2)))
    for _trial in range(max_epochs)
]
plt.plot(np.array(R_values))

In [None]:
# Reference experiment: R for 10 chains of only 10 i.i.d. standard-normal draws
# in 3 dimensions, repeated max_epochs times. All chains share one
# distribution, so this shows the scatter of R around 1 at this sample size.
R_values = [
    get_R(np.transpose([np.random.randn(10, 3) for _ in range(10)], (1, 0, 2)))
    for _trial in range(max_epochs)
]
plt.plot(np.array(R_values))

In [None]:
# Compare the pooled final-epoch posterior samples of the whole ensemble (default
# colour) with draws from the analytic posterior (red).
# The simulator is x = theta + 1 + N(0, 0.1) with a flat prior, so the posterior
# is a Gaussian centred on x_obs - 1 with standard deviation 0.1 (ignoring
# truncation at the edges of the prior box).
true_posterior = np.random.normal(x_obs - 1, 0.1, (10000, x_obs.shape[-1]))

post_samples = [
    model_info["posterior_samples"][-1].numpy() for model_info in models.values()
]
stacked_samples = np.array(post_samples)
print(stacked_samples.shape)

# Pool all models' samples. The shape is derived from the data rather than
# hard-coded, so this keeps working if the ensemble size, the number of samples
# or the parameter dimension changes.
pooled_samples = stacked_samples.reshape(-1, stacked_samples.shape[-1])
fig = corner.corner(pooled_samples)
corner.corner(true_posterior, fig=fig, color="red");