# Constrained Models

AstroPhot models support several kinds of parameter constraints, which make it possible to set up specialized fitting problems. This tutorial covers three of them: valid ranges, equality constraints, and function constraints.

In [None]:
import astrophot as ap
import numpy as np
import torch
from astropy.io import fits
import matplotlib.pyplot as plt
from time import time

%matplotlib inline

## Valid Range

The simplest constraint restricts a parameter to a range. The range is set when the parameter is created by giving its lower and upper endpoints, both of which are exclusive.

In [None]:
target = ap.TargetImage(data=np.zeros((100, 100)), crpix=[49.5, 49.5], pixelscale=1)
gal1 = ap.Model(
    name="galaxy1",
    model_type="sersic galaxy model",
    # Here we set the limits; note that they can differ for each element of center.
    # The valid range is a tuple with two elements, the lower limit and the
    # upper limit. Either one can be None.
    center={
        "value": [0, 0],
        "valid": ([-10, -20], [10, 20]),
    },
    # One-sided limits can be used, for example, to ensure a value is positive
    Re={"valid": (0, None)},
    target=target,
)

# Now if we try to set a value outside the range we get a warning
gal1.center.value = [25, 25]
gal1.center.value = [0, 0]  # set back to good value

AstroPhot's internal functions track these limits and should not move a parameter outside them under normal circumstances (contact us if you find it happening!).
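
How AstroPhot enforces limits internally is not shown in this tutorial. As a generic illustration only, one common way for an optimizer to respect an open interval is to fit an unbounded variable and map it into the interval. The helpers `to_bounded` and `to_unbounded` below are hypothetical names and are not part of AstroPhot.

```python
import math


def to_bounded(u, lower, upper):
    """Map any real number u into the interval between lower and upper."""
    # Numerically stable logistic function
    if u >= 0:
        s = 1.0 / (1.0 + math.exp(-u))
    else:
        e = math.exp(u)
        s = e / (1.0 + e)
    return lower + (upper - lower) * s


def to_unbounded(x, lower, upper):
    """Inverse of to_bounded; x must lie strictly inside (lower, upper)."""
    if not lower < x < upper:
        raise ValueError(f"{x} is outside the open interval ({lower}, {upper})")
    t = (x - lower) / (upper - lower)
    return math.log(t / (1.0 - t))


# Start from a valid value, take a step in the unbounded variable,
# and map back: the result is still inside the limits.
u = to_unbounded(3.0, -10.0, 10.0)
x_new = to_bounded(u + 5.0, -10.0, 10.0)
print(x_new)
```

Because the endpoints map to infinity in the unbounded variable, this kind of transform naturally treats the limits as exclusive. For very large steps, floating point rounding can still land exactly on an endpoint.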

## Equality constraints

Another form of constraint is an equality constraint. You can fix one parameter to track another's value so that they will always be equal.

In [None]:
gal1 = ap.Model(
    name="galaxy1",
    model_type="sersic galaxy model",
    center=[-25, -25],
    PA=0,
    q=0.9,
    n=2,
    Re=5,
    Ie=1.0,
    target=target,
)
gal2 = ap.Model(
    name="galaxy2",
    model_type="sersic galaxy model",
    center=[25, 25],
    PA=0,
    q=0.9,
    Ie=1.0,
    target=target,
)

# here we set the equality constraint, setting the values for gal2 equal to the parameters of gal1
gal2.n = gal1.n
gal2.Re = gal1.Re

# we make a group model to use both galaxy models together
gals = ap.Model(
    name="gals",
    model_type="group model",
    models=[gal1, gal2],
    target=target,
)

fig, ax = plt.subplots()
ap.plots.model_image(fig, ax, gals)
plt.show()

gals.graphviz()

In [None]:
# We can now change a parameter value and both models will change
gal1.n.value = 1

fig, ax = plt.subplots()
ap.plots.model_image(fig, ax, gals)
plt.show()

Now that these two parameters are linked, optimization algorithms in AstroPhot will take this into account and fit the constrained set of parameters.
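
One way to picture an equality constraint is as two models holding a reference to the same parameter object, so that the fit has fewer free parameters. The sketch below uses hypothetical stand-in classes (`SharedParam`, `ToyModel`) in plain Python; it is not a description of AstroPhot's internals.

```python
class SharedParam:
    """Hypothetical stand-in for a parameter object (not AstroPhot's Param)."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


class ToyModel:
    """Hypothetical model that only stores its parameters."""

    def __init__(self, name, n, Re):
        self.name = name
        self.n = n
        self.Re = Re


a = ToyModel("a", SharedParam("n", 2.0), SharedParam("Re", 5.0))
b = ToyModel("b", SharedParam("n", 4.0), SharedParam("Re", 8.0))

# Equality constraint: b now reuses a's parameter objects
b.n = a.n
b.Re = a.Re

# Changing the value through one model is visible through the other
a.n.value = 1.0
print(b.n.value)

# Count the distinct parameter objects an optimizer would need to fit
unique_ids = {id(p) for m in (a, b) for p in (m.n, m.Re)}
n_free = len(unique_ids)
print(n_free)
```

Before linking, the two toy models carry four independent parameters; after linking, only two distinct objects remain.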

## Function constraints

In some situations a parameter can be constrained by making it a function of other parameters. Some concrete examples of this include:

- A spatially varying PSF can be forced to obey some smoothing function such as a plane or spline
- The SED of a multiband fit may be constrained to follow some pre-determined form
- A light curve model could be used to constrain the brightness in a multi-epoch analysis

The possibilities with this kind of constraint capability are quite extensive. If you do something creative with these functional constraints please let us know!

In [None]:
# Here we will demo a spatially varying PSF where the moffat "n" parameter changes across the image
target = ap.TargetImage(data=np.zeros((100, 100)), crpix=[49.5, 49.5], pixelscale=1)

psf_target = ap.PSFImage(data=np.zeros((55, 55)), pixelscale=1)

# We make parameters and a function to control the moffat n parameter
intercept = ap.Param("intercept", 3)
slope = ap.Param("slope", [1 / 50, -1 / 50])


def constrained_moffat_n(n_param):
    return n_param.intercept.value + torch.sum(n_param.slope.value * n_param.center.value)


# Next we make all the star and PSF objects
allstars = []
allpsfs = []
for x in [-30, 0, 30]:
    for y in [-30, 0, 30]:
        psf = ap.Model(
            name="psf",
            model_type="moffat psf model",
            Rd=2,
            n={"value": constrained_moffat_n},
            target=psf_target,
        )
        if len(allstars) > 0:
            psf.Rd = allstars[0].psf.Rd
        allstars.append(
            ap.Model(
                name=f"star {x} {y}",
                model_type="point model",
                center=[x, y],
                flux=1,
                target=target,
                psf=psf,
            )
        )

        # Note that we link the star's center as well, so that the function can use it
        psf.n.link((intercept, slope, allstars[-1].center))


# A group model holds all the stars together
sky = ap.Model(name="sky", model_type="flat sky model", I=1e-5, target=target)
MODEL = ap.Model(
    name="spatial PSF",
    model_type="group model",
    models=[sky] + allstars,
    target=target,
)

fig, ax = plt.subplots()
ap.plots.model_image(fig, ax, MODEL)
plt.show()

See how the PSF parameters vary across the image. This model could now be optimized to fit some data, and the parameters of the plane (`intercept` and `slope`) would be optimized alongside everything else, so the best-fit values account for everything in the image.
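
To make the plane concrete, the sketch below evaluates the same expression as `constrained_moffat_n` (intercept plus the dot product of slope and center) in plain Python for the nine star positions, using the initial values `intercept = 3` and `slope = [1/50, -1/50]` from the cell above. The helper name `plane_n` is introduced here for illustration only.

```python
intercept = 3.0
slope = (1 / 50, -1 / 50)


def plane_n(center):
    """Moffat n at a given (x, y) center: intercept + slope . center."""
    return intercept + sum(s * c for s, c in zip(slope, center))


# Evaluate on the same 3x3 grid of star centers used in the model
grid = {(x, y): plane_n((x, y)) for x in (-30, 0, 30) for y in (-30, 0, 30)}

for (x, y), n in sorted(grid.items()):
    print(f"center=({x:>3}, {y:>3})  n={n:.2f}")
```

With these starting values, `n` ranges from about 1.8 at center (-30, 30) to about 4.2 at center (30, -30), and equals 3 wherever x and y are equal.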