# Modelling a Lens image using a No U-Turn Sampler

In this hypothetical scenario we have an image of a galaxy-galaxy strong lens and we would like to recover a model of the scene. To do so we need to determine parameters for the background source light, the lensing galaxy light, and the lensing galaxy mass distribution. A common technique for analyzing strong lensing systems is Markov Chain Monte Carlo (MCMC), which explores the parameter space and provides both a model and uncertainties on all parameters. Since caustics is differentiable, we have access to especially efficient gradient-based MCMC algorithms. A very convenient one is the No U-Turn Sampler, or NUTS, which uses derivatives to explore the distribution efficiently by treating the negative log-probability like a potential in which a point mass moves. The NUTS version we use, as implemented in the Pyro package, needs almost no hand tuning: the trajectory length is chosen automatically and the step size is adapted during warmup, so we can simply give it a start point and it will explore for as many iterations as we ask. What's more, NUTS is efficient enough that the autocorrelation length of the samples is very often approximately 1, meaning that consecutive samples are nearly independent! This is especially handy in the complex non-linear space of strong lensing models.
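
To make the "point mass in a potential" picture concrete, here is a minimal sketch of plain Hamiltonian Monte Carlo on a one-dimensional standard normal. It is an illustration only, not the Pyro implementation: it uses a fixed trajectory length, whereas NUTS chooses the length adaptively. The names `potential`, `leapfrog`, and `hmc_step`, and the chosen step size and number of steps, are assumptions made for this example.

```python
import math
import random


def potential(q):
    """Negative log-density of a standard normal, up to a constant."""
    return 0.5 * q * q


def grad_potential(q):
    """Derivative of the potential with respect to the position q."""
    return q


def leapfrog(q, p, step_size, num_steps):
    """Integrate Hamilton's equations with the leapfrog scheme."""
    p = p - 0.5 * step_size * grad_potential(q)  # initial half step in momentum
    for _ in range(num_steps - 1):
        q = q + step_size * p  # full step in position
        p = p - step_size * grad_potential(q)  # full step in momentum
    q = q + step_size * p  # last full step in position
    p = p - 0.5 * step_size * grad_potential(q)  # final half step in momentum
    return q, p


def hmc_step(q, rng, step_size=0.2, num_steps=10):
    """One HMC transition: draw a momentum, simulate, then accept or reject."""
    p0 = rng.gauss(0.0, 1.0)
    q_new, p_new = leapfrog(q, p0, step_size, num_steps)
    h_old = potential(q) + 0.5 * p0 * p0
    h_new = potential(q_new) + 0.5 * p_new * p_new
    # Metropolis correction for the small energy error of the integrator
    if rng.random() < math.exp(min(0.0, h_old - h_new)):
        return q_new
    return q


rng = random.Random(0)
q = 3.0  # deliberately start away from the mode
samples = []
for _ in range(2000):
    q = hmc_step(q, rng)
    samples.append(q)

mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
print(f"sample mean = {mean:.2f}, sample variance = {var:.2f}")
```

The gradient is what lets each proposal travel far across the distribution while still being accepted with high probability, which is why gradient-based samplers tend to produce weakly correlated samples.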

In [None]:
import caustics
import numpy as np
import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import norm

import pyro
import pyro.distributions as dist
from pyro.infer import MCMC as pyro_MCMC
from pyro.infer import NUTS as pyro_NUTS

## Specs for the data

These are some properties of the data that aren't very interesting for the demo: the size of the image, the pixel scale, the noise level, etc.

In [None]:
# Data specs

background_rms = 0.005  #  background noise per pixel
exp_time = 500.0  #  exposure time (arbitrary units, flux per pixel is in units #photons/exp_time unit)
numPix = 60  #  cutout size in pixels per axis
pixelscale = 0.05  #  pixel size in arcsec (area per pixel = pixelscale**2)
fwhm = 0.05  # full width at half maximum of PSF
psf_sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
psf_type = "GAUSSIAN"  # 'GAUSSIAN', 'PIXEL', 'NONE'

cosmology = caustics.FlatLambdaCDM(name="cosmo")
cosmology.to(dtype=torch.float32)

upsample_factor = 1
thx, thy = caustics.utils.get_meshgrid(
    pixelscale / upsample_factor,
    upsample_factor * numPix,
    upsample_factor * numPix,
    dtype=torch.float32,
)
z_l = torch.tensor(0.5, dtype=torch.float32)
z_s = torch.tensor(1.5, dtype=torch.float32)

## Build simulator forward model

Here we build the caustics simulator, which handles the lensing and generates the images we fit against. It includes a model for the lens mass distribution, the lens light, and the source light. We also include a simple Gaussian PSF for extra realism, though for simplicity we use the same PSF model both for simulating the mock data and for fitting.

In [None]:
# Set up the forward model

# Lens mass model (SIE + shear)
lens_sie = caustics.SIE(name="galaxylens", cosmology=cosmology, z_l=1.0)
lens_shear = caustics.ExternalShear(
    name="externalshear", cosmology=cosmology, x0=0.0, y0=0.0, z_l=1.0
)
lens_mass_model = caustics.SinglePlane(
    name="lensmass", cosmology=cosmology, lenses=[lens_sie, lens_shear], z_l=1.0
)

# Lens light model (sersic)
lens_light_model = caustics.Sersic(name="lenslight")

# Source light model (sersic)
source_light_model = caustics.Sersic(name="sourcelight")

# Gaussian PSF Model
psf_image = caustics.utils.gaussian(
    nx=upsample_factor * 6 + 1,
    ny=upsample_factor * 6 + 1,
    pixelscale=pixelscale / upsample_factor,
    sigma=psf_sigma,
    upsample=2,
)

# Image plane simulator
sim = caustics.Lens_Source(
    lens=lens_mass_model,
    lens_light=lens_light_model,
    source=source_light_model,
    psf=psf_image,
    pixels_x=numPix,
    pixelscale=pixelscale,
    upsample_factor=upsample_factor,
    z_s=2.0,
)

## Sample some mock data

Here we write out the true values for all the parameters in the model. In total there are 21 parameters, so this is quite a complex model already! We then plot the data so we can see what we're trying to fit.

Note that when we sample the simulator we call it with `quad_level=7`. This means the simulator uses Gaussian quadrature for sub-pixel integration, so that the brightness of each pixel is computed very accurately.
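
As a rough illustration of why sub-pixel integration matters, the sketch below averages a sharply peaked profile over a single pixel with a 3-point Gauss-Legendre rule per axis and compares it with simply sampling the pixel center. This is not how caustics implements `quad_level`; the profile `brightness`, its width, and the helper `pixel_mean_quadrature` are hypothetical choices made for this example.

```python
import math

# Nodes and weights of the 3-point Gauss-Legendre rule on [-1, 1]
NODES = (-math.sqrt(3.0 / 5.0), 0.0, math.sqrt(3.0 / 5.0))
WEIGHTS = (5.0 / 9.0, 8.0 / 9.0, 5.0 / 9.0)

SIGMA = 0.03  # width of the hypothetical profile, in arcsec
PIXELSCALE = 0.05  # pixel size in arcsec, as in the data specs


def brightness(x, y):
    """Hypothetical sharply peaked surface brightness profile."""
    return math.exp(-(x * x + y * y) / (2.0 * SIGMA**2))


def pixel_mean_quadrature(f, x_center, y_center, pixelscale):
    """Mean of f over one square pixel using a tensor-product Gauss-Legendre rule."""
    half = 0.5 * pixelscale
    total = 0.0
    for xi, wx in zip(NODES, WEIGHTS):
        for yj, wy in zip(NODES, WEIGHTS):
            total += wx * wy * f(x_center + half * xi, y_center + half * yj)
    # The weights sum to 2 along each axis, so divide by 4 to get the mean
    return total / 4.0


# Exact pixel mean for this separable Gaussian profile, via the error function
half = 0.5 * PIXELSCALE
mean_1d = SIGMA * math.sqrt(2.0 * math.pi) * math.erf(half / (SIGMA * math.sqrt(2.0))) / PIXELSCALE
exact_value = mean_1d**2

center_value = brightness(0.0, 0.0)  # naive: evaluate only at the pixel center
quad_value = pixel_mean_quadrature(brightness, 0.0, 0.0, PIXELSCALE)

print(f"exact mean      = {exact_value:.4f}")
print(f"center sample   = {center_value:.4f}")
print(f"quadrature mean = {quad_value:.4f}")
```

Even with only nine evaluations per pixel, the quadrature estimate lands much closer to the exact pixel mean than the single center sample does.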

In [None]:
# Generate the mock data
true_params = {
    "galaxylens": {
        "x0": 0.05,
        "y0": 0.0,
        "q": 0.86,
        "phi": -0.20,
        "b": 0.66,
    },
    "externalshear": {"gamma_1": 0.0, "gamma_2": -0.05},
    "sourcelight": {
        "x0": 0.1,
        "y0": 0.0,
        "q": 0.75,
        "phi": 1.18,
        "n": 1.0,
        "Re": 0.1 / np.sqrt(0.75),
        "Ie": 16 * pixelscale**2,
    },
    "lenslight": {
        "x0": 0.05,
        "y0": 0.0,
        "q": 0.75,
        "phi": 1.18,
        "n": 2.0,
        "Re": 0.6 / np.sqrt(0.75),
        "Ie": 16 * pixelscale**2,
    },
}
allparams = []
for model in true_params:
    for key in true_params[model]:
        allparams.append(true_params[model][key])
allparams = torch.tensor(allparams)
print(true_params)

# simulate lens, crop extra evaluation for PSF
true_system = sim(allparams, quad_level=7)  # simulate at high resolution

fig, axarr = plt.subplots(1, 2, figsize=(15, 8))
axarr[0].imshow(
    np.log10(true_system.detach().cpu().numpy()), cmap="inferno", origin="lower"
)
axarr[0].axis("off")
axarr[0].set_title("Mock Lens System")
torch.manual_seed(42)
shot_noise = torch.normal(
    mean=torch.zeros_like(true_system),
    std=torch.sqrt(torch.abs(true_system) / exp_time),
)
background = torch.normal(
    mean=torch.zeros_like(true_system), std=torch.tensor(background_rms)
)
variance = (torch.abs(true_system) / exp_time) + background_rms**2
obs_system = true_system + shot_noise + background
# Reduced chi^2 of the noise realization (should be close to 1)
print(((obs_system - true_system) ** 2 / variance).sum().item() / numPix**2)
axarr[1].imshow(
    np.log10(obs_system.detach().cpu().numpy()), cmap="inferno", origin="lower"
)
axarr[1].axis("off")
axarr[1].set_title("Mock Observation")
# plt.colorbar()
plt.show()

## Fit using NUTS

We now model the data using NUTS. First we need to construct a log-likelihood function, since this is what NUTS will be exploring. In our case it is just minus one half of the sum of squared residuals, each divided by the variance in its pixel. As a prior we set extremely wide Gaussians so that we effectively explore just the likelihood; in general one would want to pick more informative priors. The rest is specific to the Pyro NUTS implementation. There are other codes which implement MCMC samplers (for example emcee), and it is possible to use any of them with caustics!

Note that we use 50 warmup steps for Pyro. This lets it automatically determine an appropriate step size and compute a "mass matrix", which helps the sampler explore much more efficiently!
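
The log-likelihood described above can be sketched in isolation on a toy one-dimensional "image". This is only an illustration in plain Python: the pixel values in `true_pixels` and the 50% brightness offset are hypothetical, while the noise model (shot noise plus background) follows the same form as the mock data in this notebook.

```python
import math
import random

background_rms = 0.005  # background noise per pixel, as in the data specs
exp_time = 500.0  # exposure time, as in the data specs

rng = random.Random(42)

# Hypothetical noiseless "image": 400 pixel values with a bright bump in the middle
true_pixels = [0.01 + 0.04 * math.exp(-0.5 * ((i - 200) / 40.0) ** 2) for i in range(400)]

# Per-pixel variance: shot noise plus background noise
variance = [abs(t) / exp_time + background_rms**2 for t in true_pixels]

# One noisy realization of the data
observed = [t + rng.gauss(0.0, math.sqrt(v)) for t, v in zip(true_pixels, variance)]


def log_likelihood(model_pixels):
    """Gaussian log-likelihood up to an additive constant: -0.5 * chi^2."""
    chi2 = sum((m - d) ** 2 / v for m, d, v in zip(model_pixels, observed, variance))
    return -0.5 * chi2


ll_true = log_likelihood(true_pixels)
ll_wrong = log_likelihood([1.5 * t for t in true_pixels])  # model 50% too bright
reduced_chi2 = -2.0 * ll_true / len(observed)

print(f"log-likelihood at truth:         {ll_true:.1f}")
print(f"log-likelihood, 50% too bright:  {ll_wrong:.1f}")
print(f"reduced chi^2 at truth:          {reduced_chi2:.2f}")
```

The true model scores a much higher log-likelihood than the offset one, and its reduced chi-squared sits near 1, which is the same sanity check printed when the mock data were generated.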

In [None]:
def step(model, prior):
    x = pyro.sample("x", prior)
    # Log-likelihood function
    res = model(x)
    log_likelihood_value = -0.5 * torch.sum(((res - obs_system) ** 2) / variance)
    # Observe the log-likelihood
    pyro.factor("obs", log_likelihood_value)


prior = dist.Normal(
    allparams,
    torch.ones_like(allparams) * 1e2 + torch.abs(allparams) * 1e2,
)
nuts_kwargs = {
    "jit_compile": True,
    "ignore_jit_warnings": True,
    "step_size": 1e-3,
    "full_mass": True,
    "adapt_step_size": True,
    "adapt_mass_matrix": True,
}

nuts_kernel = pyro_NUTS(step, **nuts_kwargs)
init_params = {"x": allparams}

# Run MCMC with the NUTS sampler and the initial guess
mcmc_kwargs = {
    "num_samples": 100,
    "warmup_steps": 50,
    "initial_params": init_params,
    "disable_progbar": False,
}

mcmc = pyro_MCMC(nuts_kernel, **mcmc_kwargs)

mcmc.run(sim, prior)

We have only taken 100 samples in this demo; in general you would want many more. However, it's always a good idea to plot the chains and check that they look uncorrelated. Everything seems fine here!
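
A visual check can be complemented by a number. The sketch below estimates the lag-1 autocorrelation of a chain and contrasts an independent chain with a strongly correlated one. Both chains are synthetic, and the helper `autocorrelation` is a hypothetical name written for this example; the same function could be applied to each column of the NUTS chain.

```python
import random


def autocorrelation(chain, lag):
    """Normalized sample autocorrelation of a 1D chain at the given lag."""
    n = len(chain)
    mean = sum(chain) / n
    var = sum((x - mean) ** 2 for x in chain) / n
    cov = sum((chain[t] - mean) * (chain[t + lag] - mean) for t in range(n - lag)) / n
    return cov / var


rng = random.Random(1)

# Independent draws: what a well-mixed sampler should resemble
independent = [rng.gauss(0.0, 1.0) for _ in range(2000)]

# AR(1) process with coefficient 0.9: a stand-in for a slowly mixing chain
correlated = [0.0]
for _ in range(1999):
    correlated.append(0.9 * correlated[-1] + rng.gauss(0.0, 1.0))

rho_independent = autocorrelation(independent, 1)
rho_correlated = autocorrelation(correlated, 1)

print(f"lag-1 autocorrelation, independent chain: {rho_independent:.3f}")
print(f"lag-1 autocorrelation, correlated chain:  {rho_correlated:.3f}")
```

A lag-1 value near zero is consistent with nearly independent samples, while a value near one means many more iterations are needed for the same effective sample size.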

In [None]:
chain = mcmc.get_samples()["x"]
chain = chain.numpy()

plt.plot(
    range(len(chain)),
    (chain - np.mean(chain, axis=0)) / np.std(chain, axis=0)
    + 5 * np.arange(len(allparams)),
)
plt.title("Chain for each parameter")
plt.show()

## Examine uncertainties

Just like in the LM example, we can produce a corner plot of our parameters and parameter-pair uncertainties. To keep with the format of the LM example, we use the 100 samples from NUTS to compute a covariance matrix for the parameters. As you can see, it is nearly identical to what we recovered using LM. This makes sense: we are analyzing the same problem with the same likelihood surface, so the two methods should give the same results up to the approximation in LM and the sampling statistics in NUTS.
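
For a single pair of parameters, turning samples into a covariance ellipse can be sketched as follows. The two-parameter chain here is synthetic, and the closed-form 2x2 eigen-decomposition is an alternative chosen to keep the example dependency-free; the plotting function in this notebook uses `np.linalg.eig` instead.

```python
import math
import random

rng = random.Random(3)

# Hypothetical two-parameter chain with correlation 0.8 between the parameters
chain = []
for _ in range(5000):
    z1, z2 = rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)
    chain.append((z1, 0.8 * z1 + 0.6 * z2))

n = len(chain)
mean_x = sum(x for x, _ in chain) / n
mean_y = sum(y for _, y in chain) / n

# Unbiased sample covariance entries (the same normalization np.cov uses by default)
var_x = sum((x - mean_x) ** 2 for x, _ in chain) / (n - 1)
var_y = sum((y - mean_y) ** 2 for _, y in chain) / (n - 1)
cov_xy = sum((x - mean_x) * (y - mean_y) for x, y in chain) / (n - 1)

# Closed-form eigenvalues of the symmetric matrix [[var_x, cov_xy], [cov_xy, var_y]]
center = 0.5 * (var_x + var_y)
radius = math.hypot(0.5 * (var_x - var_y), cov_xy)
major_var, minor_var = center + radius, center - radius

# Orientation of the major axis, measured from the x axis
angle_deg = math.degrees(0.5 * math.atan2(2.0 * cov_xy, var_x - var_y))

# Semi-axes of the 1-sigma ellipse are the square roots of the eigenvalues
semi_major, semi_minor = math.sqrt(major_var), math.sqrt(minor_var)

print(f"semi-major = {semi_major:.2f}, semi-minor = {semi_minor:.2f}, angle = {angle_deg:.1f} deg")
```

An elongated, tilted ellipse like this one signals a degeneracy between the two parameters, which is exactly what the off-diagonal panels of the corner plot are meant to reveal.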

In [None]:
def corner_plot(
    chain,
    labels=None,
    figsize=(10, 10),
    true_values=None,
    ellipse_colors="g",
):
    num_params = chain.shape[1]
    cov_matrix = np.cov(chain, rowvar=False)
    mean = np.mean(chain, axis=0)
    fig, axes = plt.subplots(num_params, num_params, figsize=figsize)
    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    for i in range(num_params):
        for j in range(num_params):
            ax = axes[i, j]

            if i == j:
                x = np.linspace(
                    mean[i] - 3 * np.sqrt(cov_matrix[i, i]),
                    mean[i] + 3 * np.sqrt(cov_matrix[i, i]),
                    100,
                )
                y = norm.pdf(x, mean[i], np.sqrt(cov_matrix[i, i]))
                ax.plot(x, y, color="g")
                ax.set_xlim(
                    mean[i] - 3 * np.sqrt(cov_matrix[i, i]),
                    mean[i] + 3 * np.sqrt(cov_matrix[i, i]),
                )
                if true_values is not None:
                    ax.axvline(true_values[i], color="red", linestyle="-", lw=1)
            elif j < i:
                ax.scatter(chain[:, j], chain[:, i], color="c", s=0.1, zorder=0)
                cov = cov_matrix[np.ix_([j, i], [j, i])]
                lambda_, v = np.linalg.eig(cov)
                lambda_ = np.sqrt(lambda_)
                angle = np.rad2deg(np.arctan2(v[1, 0], v[0, 0]))
                for k in [1, 2]:
                    ellipse = Ellipse(
                        xy=(mean[j], mean[i]),
                        width=lambda_[0] * k * 2,
                        height=lambda_[1] * k * 2,
                        angle=angle,
                        edgecolor=ellipse_colors,
                        facecolor="none",
                    )
                    ax.add_artist(ellipse)

                # Set axis limits
                margin = 3
                ax.set_xlim(
                    mean[j] - margin * np.sqrt(cov_matrix[j, j]),
                    mean[j] + margin * np.sqrt(cov_matrix[j, j]),
                )
                ax.set_ylim(
                    mean[i] - margin * np.sqrt(cov_matrix[i, i]),
                    mean[i] + margin * np.sqrt(cov_matrix[i, i]),
                )

                if true_values is not None:
                    ax.axvline(true_values[j], color="red", linestyle="-", lw=1)
                    ax.axhline(true_values[i], color="red", linestyle="-", lw=1)

            if j > i:
                ax.axis("off")

            if i < num_params - 1:
                ax.set_xticklabels([])
            else:
                if labels is not None:
                    ax.set_xlabel(labels[j])
            ax.yaxis.set_major_locator(plt.NullLocator())

            if j > 0:
                ax.set_yticklabels([])
            else:
                if labels is not None:
                    ax.set_ylabel(labels[i])
            ax.xaxis.set_major_locator(plt.NullLocator())

    return fig

In this figure the green contours show the covariance matrix computed from the samples, the cyan points are the samples themselves, and the red lines are ground truth.

In [None]:
fig = corner_plot(chain, true_values=allparams.numpy())
plt.show()