# Fitting Methods

Here we will explore the various fitting methods in AstroPhot. You have already encountered some of the methods, but here we will take a more systematic approach and discuss their strengths/weaknesses. Each method will be applied to the same problem with the same initial conditions so you can see how they operate.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import torch
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Ellipse
from scipy.stats import gaussian_kde as kde
from scipy.stats import norm
from tqdm import tqdm

import astrophot as ap

In [None]:
# Setup a fitting problem. You can ignore this cell to start, it just makes some test data to fit


def true_params():
    """Return the "ground truth" parameters used to simulate the mock image.

    Feel free to edit these numbers to see how the fitters respond.

    Returns:
    - sersic_params (numpy array): shape (2, 7), one row per Sersic component,
      with columns (x, y, q, PA, n, Re, Ie); PA is converted from degrees to
      radians here.
    - sky_param (numpy array): shape (1,), the flat sky level.
    """

    def rad(angle_deg):
        return angle_deg * np.pi / 180

    galaxy_a = [
        58.44035491,  # x
        55.58516735,  # y
        0.54945988,  # q
        rad(37.19794926),  # PA
        2.14513004,  # n
        22.05219055,  # Re
        10**2.45583024,  # Ie
    ]
    galaxy_b = [
        44.00353786,  # x
        31.54430634,  # y
        0.40203928,  # q
        rad(172.03862521),  # PA
        2.88613347,  # n
        12.095631,  # Re
        10**2.76711163,  # Ie
    ]
    sky_level = 10**1.5

    return np.array([galaxy_a, galaxy_b]), np.array([sky_level])


def init_params():
    """Return deliberately offset starting parameters for the fitters.

    Returns:
    - sersic_params (numpy array): shape (2, 7), one row per Sersic component,
      with columns (x, y, q, PA, n, Re, Ie); PA is converted from degrees to
      radians here.
    - sky_param (numpy array): shape (1,), the starting flat sky level.
    """

    def rad(angle_deg):
        return angle_deg * np.pi / 180

    #          x     y     q    PA          n    Re    Ie
    first = [57.0, 56.0, 0.6, rad(40.0), 1.5, 25.0, 10**2.0]
    second = [45.0, 30.0, 0.5, rad(170.0), 2.0, 10.0, 10**3.0]

    return np.array([first, second]), np.array([10**1.4])


def initialize_model(target, use_true_params=True):
    """Assemble the sky + Sersic group model used throughout this notebook.

    Args:
    - target (ap.TargetImage): Image the models are attached to.
    - use_true_params (bool): If True, start from ``true_params()``; otherwise
      start from the offset values in ``init_params()``.

    Returns:
    - ap.Model: An initialized "group model" whose members are one
      "flat sky model" followed by one "sersic galaxy model" per Sersic row.
    """
    sersic_params, sky_param = true_params() if use_true_params else init_params()

    sky_model = ap.Model(
        name="sky",
        model_type="flat sky model",
        target=target,
        I=sky_param[0],
    )

    # Each Sersic row is laid out as (x, y, q, PA, n, Re, Ie)
    sersic_models = [
        ap.Model(
            name=f"sersic {i}",
            model_type="sersic galaxy model",
            target=target,
            center=[row[0], row[1]],
            q=row[2],
            PA=row[3],
            n=row[4],
            Re=row[5],
            Ie=row[6],
            # psf_convolve = True, # uncomment to try everything with PSF blurring (takes longer)
        )
        for i, row in enumerate(sersic_params)
    ]

    group = ap.Model(
        name="group",
        model_type="group model",
        models=[sky_model, *sersic_models],
        target=target,
    )
    # Make sure every model is ready to go
    group.initialize()
    return group


def generate_target():
    """Simulate and display the mock image that every fitter is run on.

    The model built by ``initialize_model`` at the true parameters is sampled
    on a 99x99 grid (pixelscale 1.0, Gaussian PSF). Signal-dependent Gaussian
    noise with variance ``img / 4`` is then added, and the same ``img / 4`` is
    stored as the target variance.

    Returns:
    - ap.TargetImage: The noisy mock target (also shown with ``plt.show``).
    """
    size = 99
    scale = 1.0
    rng = np.random.default_rng(42)  # fixed seed -> reproducible noise

    # Gaussian PSF (sigma of 2x pixelscale) on a 21-pixel grid, normalised to unit sum
    psf = ap.utils.initialize.gaussian_psf(2, 21, scale)
    psf /= np.sum(psf)

    target = ap.TargetImage(
        data=np.zeros((size, size)),
        pixelscale=scale,
        psf=psf,
    )

    truth = initialize_model(target, True)

    # Noise-free image at the true parameters. NOTE(review): the .T presumably
    # matches TargetImage's axis convention for `data` -- confirm.
    clean = truth().data.T.detach().cpu().numpy()
    # Gaussian noise with std sqrt(img)/2, i.e. variance img/4 (Poisson-like scaling)
    noisy = clean + rng.normal(scale=np.sqrt(clean) / 2)
    target.data = torch.Tensor(noisy)
    target.variance = torch.Tensor(clean / 4)

    fig, ax = plt.subplots(figsize=(8, 8))
    ap.plots.target_image(fig, ax, target)
    ax.axis("off")
    plt.show()

    return target


def corner_plot(
    chain,
    labels=None,
    bins=None,
    true_values=None,
    plot_density=True,
    plot_contours=True,
    figsize=(10, 10),
):
    """Draw a corner plot of samples using Gaussian kernel density estimates.

    Diagonal panels show the 1D KDE of each parameter. Lower-triangle panels
    show 3 contour levels of the 2D KDE for each parameter pair, and the upper
    triangle is hidden. Every axis spans the min-max range of its parameter in
    ``chain``.

    Args:
    - chain (numpy array): Samples, shape (num_samples, num_params).
    - labels (list of str, optional): Parameter names for the left/bottom panels.
    - bins (int, optional): Histogram bin count, only relevant to the histogram
      variant (currently commented out). Defaults to sqrt(num_samples).
    - true_values (array-like, optional): Reference values drawn as red lines.
    - plot_density (bool): Draw the 1D KDEs on the diagonal.
    - plot_contours (bool): Draw the 2D KDE contours below the diagonal.
    - figsize (tuple): Figure size passed to ``plt.subplots``.
    """
    num_params = chain.shape[1]

    _, grid = plt.subplots(num_params, num_params, figsize=figsize)
    plt.subplots_adjust(wspace=0.0, hspace=0.0)
    if bins is None:
        # Only used by the (disabled) histogram drawing, e.g.
        # ax.hist(chain[:, i], bins=bins, histtype="step", density=True, color="k", lw=1)
        bins = int(np.sqrt(chain.shape[0]))

    # (min, max) of every parameter, computed once and reused by every panel
    spans = [(np.min(chain[:, p]), np.max(chain[:, p])) for p in range(num_params)]

    for row in range(num_params):
        for col in range(num_params):
            ax = grid[row, col]
            row_span = spans[row]
            col_span = spans[col]

            if row < col:
                ax.axis("off")
            elif row == col:
                # 1D marginal of parameter `row`
                if plot_density:
                    xs = np.linspace(row_span[0], row_span[1], 100)
                    ax.plot(xs, kde(chain[:, row])(xs), color="green", lw=1)
                if true_values is not None:
                    ax.axvline(true_values[row], color="red", linestyle="-", lw=1)
                ax.set_xlim(row_span)
            else:
                # 2D marginal: x axis is parameter `col`, y axis is parameter `row`
                if plot_contours:
                    pair_kde = kde([chain[:, col], chain[:, row]])
                    gx, gy = np.mgrid[
                        col_span[0] : col_span[1] : 100j, row_span[0] : row_span[1] : 100j
                    ]
                    points = np.vstack([gx.ravel(), gy.ravel()])
                    surface = np.reshape(pair_kde(points).T, gx.shape)
                    ax.contour(gx, gy, surface, colors="green", linewidths=1, levels=3)
                if true_values is not None:
                    ax.axvline(true_values[col], color="red", linestyle="-", lw=1)
                    ax.axhline(true_values[row], color="red", linestyle="-", lw=1)
                ax.set_xlim(col_span)
                ax.set_ylim(row_span)

            # Labels only on the outer (left / bottom) edge; ticks removed everywhere
            if labels is not None:
                if col == 0:
                    ax.set_ylabel(labels[row])
                if row == num_params - 1:
                    ax.set_xlabel(labels[col])
            ax.yaxis.set_major_locator(plt.NullLocator())
            ax.xaxis.set_major_locator(plt.NullLocator())

    plt.show()


def _cov_ellipse_axes(cov):
    """Return the 1-sigma semi-axes and orientation of a 2x2 covariance ellipse.

    Args:
    - cov (array-like): 2x2 covariance matrix (symmetrized before use).

    Returns:
    - semi_axes (numpy array): Square roots of the two eigenvalues, ascending;
      negative round-off eigenvalues are clipped to 0, so the result is never NaN.
    - angle (float): Direction, in degrees, of the eigenvector paired with
      ``semi_axes[0]``.
    """
    cov = np.asarray(cov, dtype=float)
    # eigh is the symmetric/Hermitian solver: always real eigenpairs, unlike eig,
    # which can return complex values for nearly-degenerate matrices.
    eigvals, eigvecs = np.linalg.eigh(0.5 * (cov + cov.T))
    semi_axes = np.sqrt(np.clip(eigvals, 0.0, None))
    angle = np.rad2deg(np.arctan2(eigvecs[1, 0], eigvecs[0, 0]))
    return semi_axes, angle


def corner_plot_covariance(
    cov_matrix, mean, labels=None, figsize=(10, 10), true_values=None, ellipse_colors="g"
):
    """Corner plot of the multivariate Gaussian N(mean, cov_matrix).

    Diagonal panels show each 1D marginal pdf over +/-3 sigma. Lower-triangle
    panels draw the 1- and 2-sigma (Mahalanobis distance) ellipses of each 2D
    marginal, and the upper triangle is hidden.

    Args:
    - cov_matrix (numpy array): Covariance matrix, shape (P, P).
    - mean (numpy array): Centre of the Gaussian, shape (P,).
    - labels (list of str, optional): Parameter names for the outer axes.
    - figsize (tuple): Figure size passed to ``plt.subplots``.
    - true_values (array-like, optional): Reference values drawn as red lines.
    - ellipse_colors: Edge colour of the 2D ellipses.
    """
    num_params = cov_matrix.shape[0]
    fig, axes = plt.subplots(num_params, num_params, figsize=figsize)
    plt.subplots_adjust(wspace=0.0, hspace=0.0)

    for i in range(num_params):
        for j in range(num_params):
            ax = axes[i, j]

            if i == j:
                x = np.linspace(
                    mean[i] - 3 * np.sqrt(cov_matrix[i, i]),
                    mean[i] + 3 * np.sqrt(cov_matrix[i, i]),
                    100,
                )
                y = norm.pdf(x, mean[i], np.sqrt(cov_matrix[i, i]))
                ax.plot(x, y, color="g")
                ax.set_xlim(
                    mean[i] - 3 * np.sqrt(cov_matrix[i, i]), mean[i] + 3 * np.sqrt(cov_matrix[i, i])
                )
                if true_values is not None:
                    ax.axvline(true_values[i], color="red", linestyle="-", lw=1)
            elif j < i:
                # 2x2 sub-covariance of (param j on x, param i on y)
                semi_axes, angle = _cov_ellipse_axes(cov_matrix[np.ix_([j, i], [j, i])])
                for k in [1, 2]:
                    ellipse = Ellipse(
                        xy=(mean[j], mean[i]),
                        width=semi_axes[0] * k * 2,
                        height=semi_axes[1] * k * 2,
                        angle=angle,
                        edgecolor=ellipse_colors,
                        facecolor="none",
                    )
                    ax.add_artist(ellipse)

                # Set axis limits
                margin = 3
                ax.set_xlim(
                    mean[j] - margin * np.sqrt(cov_matrix[j, j]),
                    mean[j] + margin * np.sqrt(cov_matrix[j, j]),
                )
                ax.set_ylim(
                    mean[i] - margin * np.sqrt(cov_matrix[i, i]),
                    mean[i] + margin * np.sqrt(cov_matrix[i, i]),
                )

                if true_values is not None:
                    ax.axvline(true_values[j], color="red", linestyle="-", lw=1)
                    ax.axhline(true_values[i], color="red", linestyle="-", lw=1)

            if j > i:
                ax.axis("off")

            if i < num_params - 1:
                ax.set_xticklabels([])
            else:
                if labels is not None:
                    ax.set_xlabel(labels[j])
            ax.yaxis.set_major_locator(plt.NullLocator())

            if j > 0:
                ax.set_yticklabels([])
            else:
                if labels is not None:
                    ax.set_ylabel(labels[i])
            ax.xaxis.set_major_locator(plt.NullLocator())

    plt.show()


# Simulate the noisy mock image (it is also displayed) that every fitter below is applied to.
target = generate_target()

## Levenberg-Marquardt

This fitter is identified as `ap.fit.LM` and it employs a variant of the second-order Newton's method to converge very quickly to the local minimum. This is the generally accepted best algorithm for most use cases in $\chi^2$ minimization. If you don't know what to pick, start with this minimizer. The LM optimizer bridges the gap between first-order gradient descent and second-order Newton's method. When far from the minimum, Newton's method is unstable and can give wildly wrong results, so LM takes gradient descent steps. However, near the minimum it switches to Newton's method, which has "quadratic convergence": it takes only a few iterations to converge to several decimal places. This can be represented as:

$(H + LI)h = g$

Where H is the Hessian matrix of second derivatives of $\chi^2$, L is the damping parameter, I is the identity matrix, h is the step we will take in parameter space, and g is the negative gradient of $\chi^2$ (the downhill direction). We solve this linear system for h to get the next update step. The damping parameter L ranges from L >> 1, which corresponds to gradient descent, to L << 1, which corresponds to Newton's method. When L >> 1 the Hessian is negligible compared with $LI$, so $h \approx g/L$, which is just gradient descent with $1/L$ as the learning rate. In AstroPhot the damping parameter is treated somewhat differently, but the concept is the same.

LM can handle a lot of scenarios and converge to the minimum. Keep in mind, however, that it is seeking a local minimum, so it is best to start off the algorithm as close as possible to the best fit parameters. AstroPhot can automatically initialize, as discussed in other notebooks, but even that needs help sometimes (often in the form of a segmentation map).

The main drawback of LM is its memory consumption, which scales as $\mathcal{O}(PN)$ (the size of the Jacobian), where P is the number of pixels and N is the number of parameters.

In [None]:
MODEL = initialize_model(target, False)

# Levenberg-Marquardt starting from the deliberately offset initial parameters
lm_fitter = ap.fit.LM(MODEL, verbose=1)
res_lm = lm_fitter.fit()
print(res_lm.message)

In [None]:
# Compare a fresh, un-optimized copy of the model with the optimized MODEL
MODEL_init = initialize_model(target, False)
fig, axarr = plt.subplots(1, 4, figsize=(24, 5))
plt.subplots_adjust(wspace=0.1)
panels = (
    (MODEL_init, "Model before optimization", "Residuals before optimization"),
    (MODEL, "Model after optimization", "Residuals after optimization"),
)
for k, (shown, model_title, resid_title) in enumerate(panels):
    ax_model, ax_resid = axarr[2 * k], axarr[2 * k + 1]
    ap.plots.model_image(fig, ax_model, shown)
    ax_model.set_title(model_title)
    ap.plots.residual_image(fig, ax_resid, shown, normalize_residuals=True)
    ax_resid.set_title(resid_title)
plt.show()

Now that LM has found the $\chi^2$ minimum, we can do a really neat trick. Since LM needs the Hessian matrix, we have access to the Hessian at the minimum. Up to a factor of two, the Hessian of $\chi^2$ is the Fisher information matrix (equivalently, the Hessian of the log-likelihood is the negative Fisher information matrix). Taking the matrix inverse gives the covariance matrix of a multivariate Gaussian approximation to the $\chi^2$ surface near the minimum. With the covariance matrix we can create a corner plot just like we would with an MCMC. We will see later that the MCMC methods (at least the ones which converge) produce very similar results! 

In [None]:
param_names = list(MODEL.build_params_array_names())
# Use descriptive names: unpacking into `set` would shadow the built-in set type
sersic_true, sky_true = true_params()
# NOTE(review): assumes the parameter vector is ordered sky first, then each Sersic
# row in (x, y, q, PA, n, Re, Ie) order -- confirm against param_names.
corner_plot_covariance(
    res_lm.covariance_matrix.detach().cpu().numpy(),
    MODEL.build_params_array().detach().cpu().numpy(),
    labels=param_names,
    figsize=(20, 20),
    true_values=np.concatenate((sky_true, sersic_true.ravel())),
)

## Iterative Fit (models)

An iterative fitter is identified as `ap.fit.Iter`. This method is generally employed for large models where it is not feasible to hold all the relevant data in memory at once. The iterative fitter cycles through the models in a `GroupModel` object and fits them one at a time to the image, using the residuals from the previous cycle. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong. It is, however, more dependent on good initialization than other methods such as Levenberg-Marquardt. Also, it is possible for the Iter method to get stuck in a local minimum under certain circumstances.

Note that while the Iterative fitter needs a `GroupModel` object to iterate over, the sub-models need not be `ComponentModel` objects; they could be `GroupModel` objects as well. In this way it is possible to cycle through and fit "clusters" of objects that are nearby, so long as it doesn't consume too much memory.

By only fitting one model at a time it is possible to get caught in a local minimum, or to get out of a local minimum that a different fitter was stuck in. For this reason it can be good to mix-and-match the iterative optimizers so they can help each other get unstuck if a fit is very challenging. 

In [None]:
MODEL = initialize_model(target, False)

# Fit the sub-models of the group one at a time, cycling until done
iter_fitter = ap.fit.Iter(MODEL, verbose=1)
res_iter = iter_fitter.fit()

In [None]:
# Compare a fresh, un-optimized copy of the model with the optimized MODEL
MODEL_init = initialize_model(target, False)
fig, axarr = plt.subplots(1, 4, figsize=(24, 5))
plt.subplots_adjust(wspace=0.1)
panels = (
    (MODEL_init, "Model before optimization", "Residuals before optimization"),
    (MODEL, "Model after optimization", "Residuals after optimization"),
)
for k, (shown, model_title, resid_title) in enumerate(panels):
    ax_model, ax_resid = axarr[2 * k], axarr[2 * k + 1]
    ap.plots.model_image(fig, ax_model, shown)
    ax_model.set_title(model_title)
    ap.plots.residual_image(fig, ax_resid, shown, normalize_residuals=True)
    ax_resid.set_title(resid_title)
plt.show()

## Iterative Fit (parameters)

This is an iterative fitter identified as `ap.fit.IterParam` and is generally employed for complicated models where it is not feasible to hold all the relevant data in memory at once. This iterative fitter will cycle through chunks of parameters and fit them one at a time to the image. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong. This is very similar to the other iterative fitter, however it is necessary for certain fitting circumstances when the problem can't be broken down into individual component models. This occurs, for example, when the models have many shared (constrained) parameters and there is no obvious way to break down sub-groups of models.

Note that this is iterating over the parameters, not the models. This allows it to handle parameter covariances even for very large models (if they happen to land in the same chunk). However, for this to work it must evaluate the whole model at each iteration making it somewhat slower than the regular `Iter` fitter, though it can make up for it by fitting larger chunks at a time which makes the whole optimization faster.

By only fitting a subset of parameters at a time it is possible to get caught in a local minimum, or to get out of a local minimum that a different fitter was stuck in. For this reason it can be good to mix-and-match the iterative optimizers so they can help each other get unstuck. Since this iterative fitter chooses parameters randomly, it can sometimes get itself unstuck if it gets a lucky combination of parameters. Generally giving it more parameters to work with at a time is better.

In [None]:
# MODEL = initialize_model(target, False)
# fig, axarr = plt.subplots(1, 4, figsize=(24, 5))
# plt.subplots_adjust(wspace=0.1)
# ap.plots.model_image(fig, axarr[0], MODEL)
# axarr[0].set_title("Model before optimization")
# ap.plots.residual_image(fig, axarr[1], MODEL, normalize_residuals=True)
# axarr[1].set_title("Residuals before optimization")

# res_iterlm = ap.fit.Iter_LM(MODEL, chunks=11, verbose=1).fit()

# ap.plots.model_image(fig, axarr[2], MODEL)
# axarr[2].set_title("Model after optimization")
# ap.plots.residual_image(fig, axarr[3], MODEL, normalize_residuals=True)
# axarr[3].set_title("Residuals after optimization")
# plt.show()

## Scipy Minimize

Any AstroPhot model becomes a function `model(x)` where `x` is a 1D tensor of
all the current dynamic parameters. This functional format is common for
external packages to use. AstroPhot includes a wrapper to access the
`scipy.optimize.minimize` minimizer list. AstroPhot will ensure the minimizers
respect the valid ranges set for each parameter.

Typically, the AstroPhot LM optimizer is faster and more accurate than the Scipy
ones. The exact reason is unclear, but the Scipy minimizers are intended for
very general use, while the LM optimizer is specifically optimized for gaussian
log likelihoods.

In the case below, the minimizer reports that it has terminated successfully,
although in fact it is quite far from the minimum. Consider this a lesson in not
blindly trusting the "success" message from an optimizer. It turns out to be very
challenging to identify whether an optimizer is at a minimum, let alone the global minimum.

In [None]:
MODEL = initialize_model(target, False)

scipy_fitter = ap.fit.ScipyFit(MODEL, method="SLSQP", verbose=1)
res_scipy = scipy_fitter.fit()
# Presumably the underlying scipy.optimize result object (message, success flag, ...)
print(res_scipy.scipy_res)

In [None]:
# Compare a fresh, un-optimized copy of the model with the optimized MODEL
MODEL_init = initialize_model(target, False)
fig, axarr = plt.subplots(1, 4, figsize=(24, 5))
plt.subplots_adjust(wspace=0.1)
panels = (
    (MODEL_init, "Model before optimization", "Residuals before optimization"),
    (MODEL, "Model after optimization", "Residuals after optimization"),
)
for k, (shown, model_title, resid_title) in enumerate(panels):
    ax_model, ax_resid = axarr[2 * k], axarr[2 * k + 1]
    ap.plots.model_image(fig, ax_model, shown)
    ax_model.set_title(model_title)
    ap.plots.residual_image(fig, ax_resid, shown, normalize_residuals=True)
    ax_resid.set_title(resid_title)
plt.show()

## Gradient Descent

A gradient descent fitter is identified as `ap.fit.Grad` and uses standard first-order derivative methods as provided by PyTorch. These gradient descent methods include Adam, SGD, and LBFGS, to name a few. The first-order gradient is faster to evaluate and uses less memory than the LM Jacobian, but gradient descent is considerably slower to converge than Levenberg-Marquardt. Gradient descent with a small learning rate will reliably converge towards a local minimum; it will just do so slowly. 

In the example below we let it run for 1000 steps, and even then it has not converged. In general you should not use gradient descent to optimize a model. However, in a challenging fitting scenario the small step size of gradient descent can actually be an advantage: it will not take any unexpectedly large steps that could mix up some models, or hop over the $\chi^2$ minimum into impossible parameter space. Just make sure to finish with LM after using Grad so that it fully converges to a reliable minimum.

In [None]:
MODEL = initialize_model(target, False)

# Learning rate handed to the underlying PyTorch optimizer via optim_kwargs
grad_options = {"lr": 5e-2}
res_grad = ap.fit.Grad(MODEL, verbose=1, max_iter=1000, optim_kwargs=grad_options).fit()

In [None]:
# Compare a fresh, un-optimized copy of the model with the optimized MODEL
MODEL_init = initialize_model(target, False)
fig, axarr = plt.subplots(1, 4, figsize=(24, 5))
plt.subplots_adjust(wspace=0.1)
panels = (
    (MODEL_init, "Model before optimization", "Residuals before optimization"),
    (MODEL, "Model after optimization", "Residuals after optimization"),
)
for k, (shown, model_title, resid_title) in enumerate(panels):
    ax_model, ax_resid = axarr[2 * k], axarr[2 * k + 1]
    ap.plots.model_image(fig, ax_model, shown)
    ax_model.set_title(model_title)
    ap.plots.residual_image(fig, ax_resid, shown, normalize_residuals=True)
    ax_resid.set_title(resid_title)
plt.show()

## Metropolis Adjusted Langevin Algorithm (MALA)

This is one of the simplest gradient-based samplers, and it is very powerful. The standard Metropolis-Hastings algorithm uses a Gaussian proposal distribution followed by the Metropolis-Hastings accept/reject stage. MALA instead uses gradient information to build a better proposal distribution locally, then applies the Metropolis-Hastings accept/reject stage. Because the gradient-drifted proposal is not symmetric, this stage must include the proposal-density correction $q(x \mid x') / q(x' \mid x)$ to maintain detailed balance. We have not integrated this algorithm directly into AstroPhot; instead we write it all out below to show the simplicity and power of the method. Expand the cell below if you are interested!

In [None]:
def mala_sampler(
    initial_state,
    log_prob,
    log_prob_grad,
    num_samples,
    epsilon,
    mass_matrix,
    rng=None,
    progress=True,
):
    """Metropolis Adjusted Langevin Algorithm (MALA) sampler with batch dimension.

    For every chain independently, each step proposes

        x' = x + 0.5 * epsilon**2 * M^{-1} grad log p(x) + epsilon * xi,   xi ~ N(0, M^{-1}),

    i.e. x' ~ q(. | x) = N(mu(x), epsilon**2 M^{-1}). Because this proposal is
    not symmetric, the Metropolis-Hastings log acceptance ratio is

        log p(x') - log p(x) + log q(x | x') - log q(x' | x),

    and the correction terms are what make the chain satisfy detailed balance.
    Rejected proposals leave the chain at its current state.

    Args:
    - initial_state (numpy array): Initial states of the chains, shape (num_chains, dim).
    - log_prob (function): Maps a (num_chains, dim) array to log probabilities, shape (num_chains,).
    - log_prob_grad (function): Maps a (num_chains, dim) array to the gradients of the
      log probabilities, shape (num_chains, dim).
    - num_samples (int): Number of samples to generate.
    - epsilon (float): Step size for the Langevin dynamics.
    - mass_matrix (numpy array): Symmetric positive-definite mass matrix M, shape (dim, dim),
      used to scale the dynamics.
    - rng (optional): ``numpy.random.Generator`` (or anything with ``standard_normal`` and
      ``random``). Defaults to the global ``numpy.random`` state, as before.
    - progress (bool): Show a tqdm progress bar with the running acceptance rate.

    Returns:
    - samples (numpy array): Array of sampled values, shape (num_samples, num_chains, dim).
    """
    if rng is None:
        # Legacy global state, so np.random.seed(...) keeps controlling the sampler
        rng = np.random

    x_current = np.array(initial_state, dtype=float)
    num_chains, dim = x_current.shape
    samples = np.zeros((num_samples, num_chains, dim))
    # Copies, so they can be updated in place for accepted chains
    current_log_prob = np.array(log_prob(x_current), dtype=float)
    current_grad = np.array(log_prob_grad(x_current), dtype=float)

    mass_matrix = np.asarray(mass_matrix, dtype=float)
    inv_mass_matrix = np.linalg.inv(mass_matrix)
    chol_inv_mass_matrix = np.linalg.cholesky(inv_mass_matrix)
    half_eps2 = 0.5 * epsilon**2

    def drift(x, grad):
        # Mean of the Langevin proposal: x + (eps^2 / 2) M^{-1} grad log p(x)
        return x + half_eps2 * grad @ inv_mass_matrix

    def log_proposal(x_to, x_from, grad_from):
        # log N(x_to; drift(x_from), eps^2 M^{-1}), dropping the normalisation
        # constant, which is identical for the forward and reverse moves.
        diff = x_to - drift(x_from, grad_from)
        return -np.einsum("ci,ij,cj->c", diff, mass_matrix, diff) / (2 * epsilon**2)

    steps = range(num_samples)
    pbar = tqdm(steps) if progress else None
    n_accepted = 0
    for i in pbar if pbar is not None else steps:
        # noise ~ N(0, M^{-1}) since chol_inv_mass_matrix @ chol_inv_mass_matrix.T = M^{-1}
        noise = rng.standard_normal((num_chains, dim)) @ chol_inv_mass_matrix.T
        proposal = drift(x_current, current_grad) + epsilon * noise

        proposal_log_prob = np.asarray(log_prob(proposal), dtype=float)
        proposal_grad = np.asarray(log_prob_grad(proposal), dtype=float)

        # Metropolis-Hastings acceptance, with the asymmetric-proposal correction
        log_alpha = (
            proposal_log_prob
            - current_log_prob
            + log_proposal(x_current, proposal, proposal_grad)
            - log_proposal(proposal, x_current, current_grad)
        )
        # NaN log_alpha (e.g. invalid proposal) compares False -> rejected
        accept = np.log(rng.random(num_chains)) < log_alpha

        n_accepted += int(np.count_nonzero(accept))
        if pbar is not None:
            pbar.set_description(
                f"Acceptance rate: {n_accepted / ((i + 1) * num_chains):.2f}"
            )

        # Update states (and cached log prob / gradient) where accepted
        x_current[accept] = proposal[accept]
        current_log_prob[accept] = proposal_log_prob[accept]
        current_grad[accept] = proposal_grad[accept]

        samples[i] = x_current

    return samples

In [None]:
MODEL = initialize_model(target, False)

# Start the sampler from the LM best fit: a high-likelihood point, so no burn-in is needed
lm_fitter = ap.fit.LM(MODEL)
res1 = lm_fitter.fit()


def density(x):
    """Batched Gaussian log-likelihood of MODEL.

    Args:
    - x (array-like): Parameter vectors stacked along the first axis,
      e.g. shape (num_chains, num_params).

    Returns:
    - numpy array: One log-likelihood per row of ``x``.
    """
    batch = torch.as_tensor(x, dtype=ap.config.DTYPE)
    log_like = torch.vmap(MODEL.gaussian_log_likelihood)(batch)
    return log_like.detach().cpu().numpy()


# torch.func.grad differentiates the log-likelihood w.r.t. its parameter-vector
# argument; torch.vmap maps that over the leading (batch) dimension.
sim_grad = torch.vmap(torch.func.grad(MODEL.gaussian_log_likelihood))


def density_grad(x):
    """Batched gradient of MODEL's Gaussian log-likelihood.

    Args:
    - x (array-like): Parameter vectors stacked along the first axis,
      e.g. shape (num_chains, num_params).

    Returns:
    - numpy array: Gradients with the same shape as ``x``.
    """
    x = torch.as_tensor(x, dtype=ap.config.DTYPE)
    # .detach().cpu() (as in density) -- Tensor.numpy() fails for CUDA tensors
    return sim_grad(x).detach().cpu().numpy()


x0 = MODEL.build_params_array().detach().cpu().numpy()
# 8 chains started in a tight cloud (sigma = 0.001) around the LM best fit
jitter = np.random.normal(scale=0.001, size=(8, x0.shape[0]))
x0 = x0 + jitter

# Mass matrix = inverse of the LM covariance, which preconditions the Langevin steps
mass = torch.linalg.inv(res1.covariance_matrix).detach().cpu().numpy()
chain_mala = mala_sampler(
    initial_state=x0,
    log_prob=density,
    log_prob_grad=density_grad,
    num_samples=300,
    epsilon=2e-1,
    mass_matrix=mass,
)
# Merge (num_samples, num_chains, dim) into a single (num_samples * num_chains, dim) chain
chain_mala = chain_mala.reshape(-1, chain_mala.shape[-1])

In [None]:
# corner plot of the posterior
param_names = list(MODEL.build_params_array_names())

# Use descriptive names: unpacking into `set` would shadow the built-in set type
sersic_true, sky_true = true_params()
# NOTE(review): assumes the parameter vector is ordered sky first, then each Sersic
# row in (x, y, q, PA, n, Re, Ie) order -- confirm against param_names.
corner_plot(
    chain_mala,
    labels=param_names,
    figsize=(20, 20),
    true_values=np.concatenate((sky_true, sersic_true.ravel())),
)

## Hamiltonian Monte-Carlo (HMC)

The `ap.fit.HMC` sampler takes a fixed number of steps at a fixed step size following Hamiltonian dynamics. This is in contrast to NUTS, which attempts to choose these parameters optimally. A simple way to think of HMC is as several MALA-like gradient steps taken in one go. With `leapfrog_steps = 1` HMC is equivalent to MALA. With `leapfrog_steps = 10` it chains ten gradient-driven steps into one trajectory and applies the accept/reject test only once, at the end. HMC results will still have autocorrelation, which will depend on the problem and the choice of step parameters.

In [None]:
MODEL = initialize_model(target, False)

# Start from the LM best fit (high likelihood), so no burn-in is needed
lm_fitter = ap.fit.LM(MODEL)
res1 = lm_fitter.fit()

# HMC: 10 leapfrog steps of size 0.1 per sample, using the LM covariance as the inverse mass
hmc_settings = dict(
    warmup=1,
    max_iter=150,
    epsilon=1e-1,
    leapfrog_steps=10,
    inv_mass=res1.covariance_matrix,
)
res_hmc = ap.fit.HMC(MODEL, **hmc_settings).fit()

In [None]:
# corner plot of the posterior
param_names = list(MODEL.build_params_array_names())

# Use descriptive names: unpacking into `set` would shadow the built-in set type
sersic_true, sky_true = true_params()
# NOTE(review): assumes the parameter vector is ordered sky first, then each Sersic
# row in (x, y, q, PA, n, Re, Ie) order -- confirm against param_names.
corner_plot(
    res_hmc.chain.detach().cpu().numpy(),
    labels=param_names,
    figsize=(20, 20),
    true_values=np.concatenate((sky_true, sersic_true.ravel())),
)

## Metropolis Hastings

This is the more standard MCMC algorithm using the Metropolis-Hastings accept step, identified as `ap.fit.MHMCMC`. Under the hood, this is just a wrapper for the excellent `emcee` package; if you want to take advantage of more `emcee` features you can very easily use `ap.fit.MHMCMC` as a starting point. However, keep in mind that for large models it can take exceedingly long to actually converge to the posterior. Instead of waiting that long, we demonstrate the functionality with 100 steps (and 30 chains), but suggest using MALA for any real-world problem. Still, if there is something the gradient-based samplers (MALA, HMC, NUTS) can't handle, such as a likelihood that isn't differentiable, then MHMCMC can save the day (even if it takes all day to do it).

In [None]:
MODEL = initialize_model(target, False)

# Start from the LM best fit (high likelihood), so no burn-in is needed
print("running LM fit")
lm_fitter = ap.fit.LM(MODEL)
res1 = lm_fitter.fit()

# Metropolis-Hastings MCMC (an emcee wrapper, per the text above)
print("running MHMCMC sampling")
mh_sampler = ap.fit.MHMCMC(MODEL, verbose=1, max_iter=100)
res_mh = mh_sampler.fit()

In [None]:
# corner plot of the posterior
# note that, even 3000 samples is not enough to overcome the autocorrelation so the posterior has not converged.
param_names = list(MODEL.build_params_array_names())

# Use descriptive names: unpacking into `set` would shadow the built-in set type
sersic_true, sky_true = true_params()
# NOTE(review): assumes the parameter vector is ordered sky first, then each Sersic
# row in (x, y, q, PA, n, Re, Ie) order -- confirm against param_names.
corner_plot(
    res_mh.chain[::10],  # thin by a factor 10 so the plot works in reasonable time
    labels=param_names,
    figsize=(20, 20),
    true_values=np.concatenate((sky_true, sersic_true.ravel())),
)