# Fitting Methods

Here we will explore the various fitting methods in AutoProf. You have already encountered some of the methods, but here we will take a more systematic approach and discuss their strengths/weaknesses. Each method will be applied to the same problem with the same initial conditions so you can see how they operate.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gaussian_kde as kde

%matplotlib inline
import autoprof as ap

In [None]:
# Setup a fitting problem. You can ignore this cell to start, it just makes some test data to fit

def true_params():
    """Return the parameter values used to generate the mock image.

    Returns:
        tuple: ``(sersic_params, sky_param)``. ``sersic_params`` is a (3, 7)
        array with one row per sersic component and ``sky_param`` is a
        length-1 array holding the sky value.
    """
    # Column order matches how initialize_model reads each row:
    # center x, center y, q, PA (written in degrees, converted to radians), n, Re, Ie
    galaxy_rows = [
        [68.44035491, 65.58516735, 0.54945988, 127.19794926 * np.pi / 180, 2.14513004, 22.05219055, 2.45583024],
        [54.00353786, 41.54430634, 0.40203928, 82.03862521 * np.pi / 180, 2.88613347, 12.095631, 2.76711163],
        [43.13601431, 58.3422508, 0.71894728, 167.07973506 * np.pi / 180, 3.964371, 5.3767236, 2.41520244],
    ]
    sky_value = [1.5]

    return np.array(galaxy_rows), np.array(sky_value)

def init_params():
    """Return the rounded starting guesses that the fitters begin from.

    Returns:
        tuple: ``(sersic_params, sky_param)`` with the same layout as
        ``true_params()``: a (3, 7) array of sersic rows and a length-1
        array holding the sky value.
    """
    # Column order: center x, center y, q, PA (degrees converted to radians), n, Re, Ie
    guess_rows = [
        [67., 66., 0.6, 130. * np.pi / 180, 1.5, 25., 2.],
        [55., 40., 0.5, 80. * np.pi / 180, 2., 10., 3.],
        [41., 60., 0.8, 170. * np.pi / 180, 3., 4., 2.],
    ]
    sky_guess = [1.4]

    return np.array(guess_rows), np.array(sky_guess)

def initialize_model(target, use_true_params = True):
    """Build and initialize the group model (sky plus one sersic model per row).

    Args:
        target: target image object handed to every model that is created.
        use_true_params: when truthy the models start from ``true_params()``,
            otherwise from ``init_params()``.

    Returns:
        The ``ap.models.Group_Model`` after ``initialize()`` has been called.
    """
    # Choose which parameter set seeds the models
    param_source = true_params if use_true_params else init_params
    sersic_params, sky_param = param_source()

    # The sky model is always the first entry of the list
    sky_model = ap.models.AutoProf_Model(
        name = "sky",
        model_type = "flat sky model",
        target = target,
        parameters = {"sky": sky_param[0]},
    )
    model_list = [sky_model]

    # One sersic model per parameter row
    for i, (cen_x, cen_y, axis_ratio, pos_angle, sersic_n, r_eff, i_eff) in enumerate(sersic_params):
        galaxy = ap.models.AutoProf_Model(
            name = f"sersic {i}",
            model_type = "sersic galaxy model",
            target = target,
            parameters = {
                "center": [cen_x, cen_y],
                "q": axis_ratio,
                "PA": pos_angle,
                "n": sersic_n,
                "Re": r_eff,
                "Ie": i_eff,
            },
            #psf_mode = "full", # uncomment to try everything with PSF blurring (takes longer)
        )
        # NOTE(review): each sersic model is appended wrapped in its own
        # one-element list; presumably Group_Model flattens nested lists -- confirm.
        model_list.append([galaxy])

    MODEL = ap.models.Group_Model(
        name = "group",
        model_list = model_list,
        target = target,
    )
    # Make sure every model is ready to go
    MODEL.initialize()

    return MODEL

def generate_target():
    """Create, display and return the noisy mock target image.

    The image is rendered from the model built with the true parameters,
    noise from a fixed-seed generator is added, and the result is plotted.

    Returns:
        The ``ap.image.Target_Image`` with its ``data`` and ``variance`` set.
    """
    n_pix = 100
    pixelscale = 1.
    rng = np.random.default_rng(42)  # fixed seed so the mock data is reproducible

    # Gaussian PSF (arguments 2, 21, pixelscale), normalized to unit sum
    psf = ap.utils.initialize.gaussian_psf(2, 21, pixelscale)
    psf /= np.sum(psf)

    # Start from an empty image; the data is filled in below
    target = ap.image.Target_Image(
        data = np.zeros((n_pix, n_pix)),
        pixelscale = pixelscale,
        psf = psf,
    )

    mock_model = initialize_model(target, True)

    # Sample the model with the true values to make a mock image
    img = mock_model().data.detach().cpu().numpy()

    # Add Gaussian noise whose standard deviation is sqrt(img)/2, i.e. it
    # scales like Poisson noise; the matching variance is therefore img/4.
    noise_sigma = np.sqrt(img) / 2
    noisy_img = img + rng.normal(scale = noise_sigma)
    target.data = torch.Tensor(noisy_img)
    target.variance = torch.Tensor(img / 4)

    # Show the mock image
    fig, ax = plt.subplots(figsize = (8, 8))
    ap.plots.target_image(fig, ax, target)
    ax.axis("off")
    plt.show()

    return target


def corner_plot(chain, labels=None, bins=50, true_values=None, plot_density=True, plot_contours=True, figsize=(10, 10)):
    """Draw a corner plot (lower triangle) of MCMC samples and show it.

    Diagonal panels hold the 1D histogram of each parameter, panels below the
    diagonal hold the 2D histogram of each parameter pair, and panels above
    the diagonal are switched off.

    Args:
        chain: 2D array of samples indexed as ``chain[sample, parameter]``.
        labels: optional sequence with one axis label per parameter. When it
            is None no axis labels are drawn and every tick locator is nulled.
        bins: number of bins passed to the 1D and 2D histograms.
        true_values: optional sequence with one reference value per
            parameter, drawn as red dashed lines.
        plot_density: if True, overlay a Gaussian KDE curve on each 1D histogram.
        plot_contours: if True, overlay Gaussian KDE contours on each 2D histogram.
        figsize: figure size forwarded to ``plt.subplots``.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    ndim = chain.shape[1]

    # squeeze=False keeps ``axes`` two-dimensional even when ndim == 1;
    # without it a single Axes object is returned and axes[i, j] fails.
    fig, axes = plt.subplots(ndim, ndim, figsize=figsize, squeeze=False)

    for i in range(ndim):
        for j in range(ndim):
            ax = axes[i, j]

            if i == j:
                # Plot the histogram of parameter i
                ax.hist(chain[:, i], bins=bins, histtype="step", density=True, color="k", lw=1)

                if plot_density:
                    # Plot the kernel density estimate over the range the histogram set
                    x_lo, x_hi = ax.get_xlim()
                    kde_x = np.linspace(x_lo, x_hi, 100)
                    kde_y = kde(chain[:, i])(kde_x)
                    ax.plot(kde_x, kde_y, color="green", lw=1)

                if true_values is not None:
                    ax.axvline(true_values[i], color='red', linestyle='--', lw=1)

            elif i > j:
                # Plot the 2D histogram of parameters i and j
                ax.hist2d(chain[:, j], chain[:, i], bins=bins, cmap="Greys")

                if plot_contours:
                    # Evaluate the 2D kernel density estimate on a 100x100 grid
                    # spanning the current axis limits and draw its contours
                    kde_ij = kde([chain[:, j], chain[:, i]])
                    x_lo, x_hi = ax.get_xlim()
                    y_lo, y_hi = ax.get_ylim()
                    x, y = np.mgrid[x_lo:x_hi:100j, y_lo:y_hi:100j]
                    positions = np.vstack([x.ravel(), y.ravel()])
                    kde_pos = np.reshape(kde_ij(positions).T, x.shape)
                    ax.contour(x, y, kde_pos, colors="green", linewidths=1, levels=5)

                if true_values is not None:
                    ax.axvline(true_values[j], color='red', linestyle='--', lw=1)
                    ax.axhline(true_values[i], color='red', linestyle='--', lw=1)

            else:
                # Upper triangle is left empty
                ax.axis("off")

            # Only the first column keeps y labels/ticks
            if j == 0 and labels is not None:
                ax.set_ylabel(labels[i])
            else:
                ax.yaxis.set_major_locator(plt.NullLocator())

            # Only the bottom row keeps x labels/ticks
            if i == ndim - 1 and labels is not None:
                ax.set_xlabel(labels[j])
            else:
                ax.xaxis.set_major_locator(plt.NullLocator())

    plt.show()

# Build (and display) the mock target image that the fitting cells below reuse
target = generate_target()

## Levenberg-Marquardt

This fitter is identified as `ap.fit.LM` and it employs a variant of the second order Newton's method to converge very quickly to the local minimum. This is the generally accepted best algorithm for most use cases in Chi^2 minimization. If you don't know what to pick, start with this minimizer. The LM optimizer bridges the gap between first-order gradient descent and second order Newton's method. When far from the minimum, Newton's method is unstable and can give wildly wrong results, however, near the minimum it has "quadratic convergence." This means that once near the minimum it takes only a few iterations to converge to several decimal places. The "L" scale parameter goes from L >> 1 which represents gradient descent to L << 1 which is Newton's Method.

LM can handle a lot of scenarios and converge to the minimum. Keep in mind, however, that it is seeking a local minimum, so it is best to start off the algorithm as close as possible to the best fit parameters. AutoProf can automatically initialize, as discussed in other notebooks, but even that needs help sometimes (often in the form of a segmentation map).

In [None]:
# Start from the offset initial guess rather than the true values
MODEL = initialize_model(target, use_true_params=False)

# One row of four panels: model/residuals before the fit, then after it
fig, axarr = plt.subplots(nrows=1, ncols=4, figsize=(24, 5))
fig.subplots_adjust(wspace=0.1)
ax_model_pre, ax_resid_pre, ax_model_post, ax_resid_post = axarr

ap.plots.model_image(fig, ax_model_pre, MODEL)
ax_model_pre.set_title("Model before optimization")
ap.plots.residual_image(fig, ax_resid_pre, MODEL)
ax_resid_pre.set_title("Residuals before optimization")

# Run the Levenberg-Marquardt fit; MODEL is re-plotted afterwards to show the result
lm_fitter = ap.fit.LM(MODEL, verbose=1)
res = lm_fitter.fit()

ap.plots.model_image(fig, ax_model_post, MODEL)
ax_model_post.set_title("Model after optimization")
ap.plots.residual_image(fig, ax_resid_post, MODEL)
ax_resid_post.set_title("Residuals after optimization")
plt.show()

## Iterative Fit (models)

An iterative fitter is identified as `ap.fit.Iter` and this makes use of the other fitters under certain circumstances. This method is generally employed for large models where it is not feasible to hold all the relevant data in memory at once. The iterative fitter will cycle through the models in a `Group_Model` object and fit them one at a time to the image, using the residuals from the previous cycle. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong. It is however more dependent on good initialization than other methods like the Levenberg-Marquardt. Also, it is possible for the Iter method to get stuck under certain circumstances.

Note that while the Iterative fitter needs a `Group_Model` object to iterate over, it is not necessarily true that the sub models are `Component_Model` objects, they could be `Group_Model` objects as well. In this way it is possible to cycle through and fit "clusters" of objects that are nearby, so long as it doesn't consume too much memory.

In [None]:
# Start from the offset initial guess rather than the true values
MODEL = initialize_model(target, use_true_params=False)

# One row of four panels: model/residuals before the fit, then after it
fig, axarr = plt.subplots(nrows=1, ncols=4, figsize=(24, 5))
fig.subplots_adjust(wspace=0.1)
ax_model_pre, ax_resid_pre, ax_model_post, ax_resid_post = axarr

ap.plots.model_image(fig, ax_model_pre, MODEL)
ax_model_pre.set_title("Model before optimization")
ap.plots.residual_image(fig, ax_resid_pre, MODEL)
ax_resid_pre.set_title("Residuals before optimization")

# Run the iterative (model-by-model) fit; MODEL is re-plotted afterwards to show the result
iter_fitter = ap.fit.Iter(MODEL, verbose=1)
res = iter_fitter.fit()

ap.plots.model_image(fig, ax_model_post, MODEL)
ax_model_post.set_title("Model after optimization")
ap.plots.residual_image(fig, ax_resid_post, MODEL)
ax_resid_post.set_title("Residuals after optimization")
plt.show()

## Iterative Fit (parameters)

This is an iterative fitter identified as `ap.fit.Iter_LM`, and this method is generally employed for large models where it is not feasible to hold all the relevant data in memory at once. The iterative fitter will cycle through chunks of parameters (you can choose how many parameters at a time) and fit them one chunk at a time to the image, using the residuals from the previous cycle. This can be a very robust way to deal with some fits, especially if the overlap between models is not too strong.

Note that this is iterating over the parameters, not the models. This allows it to handle parameter covariances even for very large models (if they happen to land in the same chunk). However, for this to work it must evaluate the whole model at each iteration making it somewhat slower in principle than the regular Iter fitter, though it can make up for it by fitting larger chunks at a time which makes the whole optimization faster.

In [None]:
# Start from the offset initial guess rather than the true values
MODEL = initialize_model(target, use_true_params=False)

# One row of four panels: model/residuals before the fit, then after it
fig, axarr = plt.subplots(nrows=1, ncols=4, figsize=(24, 5))
fig.subplots_adjust(wspace=0.1)
ax_model_pre, ax_resid_pre, ax_model_post, ax_resid_post = axarr

ap.plots.model_image(fig, ax_model_pre, MODEL)
ax_model_pre.set_title("Model before optimization")
ap.plots.residual_image(fig, ax_resid_pre, MODEL)
ax_resid_pre.set_title("Residuals before optimization")

# Run the parameter-chunked LM fit (chunks = 11); MODEL is re-plotted afterwards
iter_lm_fitter = ap.fit.Iter_LM(MODEL, chunks=11, verbose=1)
res = iter_lm_fitter.fit()

ap.plots.model_image(fig, ax_model_post, MODEL)
ax_model_post.set_title("Model after optimization")
ap.plots.residual_image(fig, ax_resid_post, MODEL)
ax_resid_post.set_title("Residuals after optimization")
plt.show()

## Gradient Descent

A gradient descent fitter is identified as ap.fit.Grad and uses standard first order derivative methods as provided by pytorch. These gradient descent methods include Adam, classic plus momentum, and LBFGS to name a few. The first order gradient is faster to evaluate and uses less memory, however it is considerably slower to converge than Levenberg-Marquardt. The gradient descent method with a small learning rate will reliably converge towards a local minimum, it will just do so slowly.

In [None]:
# Start from the offset initial guess rather than the true values
MODEL = initialize_model(target, use_true_params=False)

# One row of four panels: model/residuals before the fit, then after it
fig, axarr = plt.subplots(nrows=1, ncols=4, figsize=(24, 5))
fig.subplots_adjust(wspace=0.1)
ax_model_pre, ax_resid_pre, ax_model_post, ax_resid_post = axarr

ap.plots.model_image(fig, ax_model_pre, MODEL)
ax_model_pre.set_title("Model before optimization")
ap.plots.residual_image(fig, ax_resid_pre, MODEL)
ax_resid_pre.set_title("Residuals before optimization")

# Run gradient descent for at most 500 iterations with learning rate 5e-3;
# MODEL is re-plotted afterwards to show the result
grad_fitter = ap.fit.Grad(
    MODEL,
    max_iter=500,
    verbose=1,
    optim_kwargs={"lr": 5e-3},
)
res = grad_fitter.fit()

ap.plots.model_image(fig, ax_model_post, MODEL)
ax_model_post.set_title("Model after optimization")
ap.plots.residual_image(fig, ax_resid_post, MODEL)
ax_resid_post.set_title("Residuals after optimization")
plt.show()

## No U-Turn Sampler (NUTS)

Unlike the above methods, NUTS does not strictly seek a minimum chi^2, instead it is an MCMC method which seeks to explore the likelihood space and provide a full posterior in the form of random samples. The NUTS method in AutoProf is actually just a wrapper for the pyro implementation (__[link here](https://docs.pyro.ai/en/stable/index.html)__). Most of the functionality can be accessed this way, though for very advanced applications it may be necessary to manually interface with pyro (this is not very challenging as AutoProf is fully differentiable).

The first iteration of NUTS is always very slow since it compiles the forward method on the fly, after that each sample is drawn very quickly. The warmup iterations take longer as the method is exploring the space and determining the ideal step size for fast integration with minimal numerical error. Once the algorithm begins sampling it is able to move quickly (for an MCMC) through the parameter space. For many models, the NUTS sampler is able to collect nearly completely uncorrelated samples, meaning that even 100 is enough to get a good estimate of the posterior sometimes.

NUTS is far faster than other MCMC implementations such as a standard Metropolis Hastings MCMC. However, it is still a lot slower than the other optimizers (LM) since it is doing more than seeking a single high likelihood point, it is fully exploring the likelihood space.

In [None]:
# Start from the offset initial guess rather than the true values
MODEL = initialize_model(target, use_true_params=False)

# Two panels showing the model and residuals before any fitting
fig, axarr = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
fig.subplots_adjust(wspace=0.1)
ax_model_pre, ax_resid_pre = axarr

ap.plots.model_image(fig, ax_model_pre, MODEL)
ax_model_pre.set_title("Model before optimization")
ap.plots.residual_image(fig, ax_resid_pre, MODEL)
ax_resid_pre.set_title("Residuals before optimization")
plt.show()

# Run LM first so the sampler starts at a high likelihood location, no burn-in needed!
lm_fitter = ap.fit.LM(MODEL)
res1 = lm_fitter.fit()

# Run the NUTS sampler: 20 warmup iterations, max_iter of 100
nuts_sampler = ap.fit.NUTS(MODEL, warmup=20, max_iter=100)
res = nuts_sampler.fit()

In [None]:
# corner plot of the posterior

# Build the axis labels: strip spaces from every parameter name and replace
# each "center" entry with a separate x label followed by a y label.
param_names = []
for raw_name in MODEL.parameter_order():
    label = raw_name.replace(" ", "")
    if "center" in label:
        param_names.append(label.replace("center", "x"))
        param_names.append(label.replace("center", "y"))
    else:
        param_names.append(label)

# Reference values drawn on the plot: sky first, then each sersic row flattened.
# NOTE(review): this assumes the chain columns follow the same order -- confirm
# against MODEL.parameter_order().
ser, sky = true_params()
truths = np.concatenate((sky, ser.ravel()))

posterior_samples = res.chain.detach().cpu().numpy()
corner_plot(
    posterior_samples,
    labels=param_names,
    figsize=(20, 20),
    true_values=truths,
)