# Getting Started with AstroPhot

In this notebook you will walk through the very basics of AstroPhot: how to make models, how to set them up for fitting, and how to view the results. These core elements come up every time you use AstroPhot; in later notebooks you will learn how to take advantage of its more advanced features.

In [None]:
%load_ext autoreload
%autoreload 2

import astrophot as ap
import numpy as np
import torch
from astropy.io import fits
from astropy.wcs import WCS
import matplotlib.pyplot as plt

%matplotlib inline

## Your first model

The basic format for making an AstroPhot model is given below. Once a model object is constructed, it can be manipulated and updated in various ways.

In [None]:
model1 = ap.models.Model(
    name="model1",  # every model must have a unique name
    model_type="sersic galaxy model",  # this specifies the kind of model
    center=[50, 50],  # here we set initial values for each parameter
    q=0.6,
    PA=60 * np.pi / 180,
    n=2,
    Re=10,
    logIe=1,
    target=ap.image.TargetImage(
        data=np.zeros((100, 100)), zeropoint=22.5, pixelscale=1.0
    ),  # every model needs a target, more on this later
)
model1.initialize()  # before using the model it is good practice to call initialize so the model can get itself ready

# We can print the model's current state
print(model1)

In [None]:
# AstroPhot has built-in methods to plot relevant information. We gave this model a blank 100x100 target
# image, so it is evaluated over that whole 100x100 pixel window. Unless you are very lucky, this won't
# line up with what you're trying to fit, so next we'll see how to give the model a real target.
fig, ax = plt.subplots(figsize=(8, 7))
ap.plots.model_image(fig, ax, model1)
plt.show()
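For reference, the Sersic profile behind the `"sersic galaxy model"` used above is conventionally written $I(R) = I_e \exp\{-b_n[(R/R_e)^{1/n} - 1]\}$, where $I_e$ is the intensity at the effective (half-light) radius $R_e$ and $n$ is the Sersic index. The sketch below evaluates it with NumPy. It assumes `logIe` means $\log_{10} I_e$ and uses the Ciotti & Bertin (1999) series approximation for $b_n$; it illustrates the profile shape and is not AstroPhot's internal implementation.

```python
import numpy as np


def sersic_bn(n):
    """Approximate b_n so that R_e encloses half the total light (Ciotti & Bertin 1999 series)."""
    return 2.0 * n - 1.0 / 3.0 + 4.0 / (405.0 * n) + 46.0 / (25515.0 * n**2)


def sersic_intensity(R, n, Re, logIe):
    """Sersic surface brightness at radius R, assuming logIe = log10(intensity at Re)."""
    Ie = 10.0**logIe
    R = np.asarray(R, dtype=float)
    return Ie * np.exp(-sersic_bn(n) * ((R / Re) ** (1.0 / n) - 1.0))


# Same shape parameters as model1 above: n=2, Re=10, logIe=1
radii = np.array([0.0, 5.0, 10.0, 20.0, 40.0])
print(sersic_intensity(radii, n=2, Re=10, logIe=1))
```

At $R = R_e$ the profile returns exactly $I_e$, which is why $R_e$ and $I_e$ make a convenient parameter pair.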

## Giving the model a Target

Typically, the main goal when constructing an AstroPhot model is to fit to an image. We need to give the model access to the image and some information about it to get started.

In [None]:
# first let's download an image to play with
hdu = fits.open(
    "https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=700&layer=ls-dr9&pixscale=0.262&bands=r"
)
target_data = np.array(hdu[0].data, dtype=np.float64)
plt.imshow(
    target_data,
    origin="lower",
    cmap="gray_r",
    vmin=np.percentile(target_data, 1),
    vmax=np.percentile(target_data, 99),
)
plt.colorbar()
plt.title("Target Image")
plt.show()


# Create a target object with specified pixelscale and zeropoint
target = ap.image.TargetImage(
    data=target_data,
    pixelscale=0.262,  # Every target image needs to know its pixelscale in arcsec/pixel
    zeropoint=22.5,  # optionally, you can give a zeropoint to tell AstroPhot what the pixel flux units are
    variance="auto",  # Automatic variance estimate for testing and demo purposes, in real analysis use weight maps, counts, gain, etc to compute variance!
)
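# Sanity check on pixel indexing: pixel_center_meshgrid gives the (i, j) coordinates of every pixel centre.
# Cast to integers and used as indices, they should reproduce the original data array exactly.
i, j = target.pixel_center_meshgrid()
print(torch.all(torch.tensor(target_data) == target_data[i.int(), j.int()]))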

# The default AstroPhot target plotting method uses log scaling in bright areas and histogram scaling in faint areas
fig3, ax3 = plt.subplots(figsize=(8, 8))
ap.plots.target_image(fig3, ax3, target)
plt.show()

In [None]:
# This model now has a target that it will attempt to match
model2 = ap.models.Model(
    name="model with target",
    model_type="sersic galaxy model",  # feel free to swap out sersic with other profile types
    target=target,  # now the model knows what it's trying to match
)

# Instead of giving initial values for all the parameters, you can simply call "initialize" and AstroPhot
# will try to guess initial values for every parameter, assuming the galaxy is roughly centered. It is also
# possible to set just a few parameters and let AstroPhot figure out the rest. For example, you could give it
# an initial guess for the center and it will work from there.
model2.initialize()

# Plotting the initial parameters and residuals, we see it gets the rough shape of the galaxy right, but still has some fitting to do
fig4, ax4 = plt.subplots(1, 2, figsize=(16, 6))
ap.plots.model_image(fig4, ax4[0], model2)
ap.plots.residual_image(fig4, ax4[1], model2)
plt.show()

In [None]:
# Now that the model has been set up with a target and initialized with parameter values, it is time to fit the image
result = ap.fit.LM(model2, verbose=1).fit()

# Note that we use ap.fit.LM, the Levenberg-Marquardt chi^2 minimization method. It is the recommended technique
# for most least-squares problems, but in some situations other optimizers are preferable, so the ap.fit package
# includes several to choose from. The various fitting methods are described in a separate tutorial.
print("Fit message:", result.message)  # the fitter will return a message about its convergence
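To give a feel for what `ap.fit.LM` does, here is a generic, textbook Levenberg-Marquardt loop in NumPy that fits a toy 1D exponential profile $I(x) = A e^{-x/h}$. It is a minimal sketch, not AstroPhot's implementation: the damping schedule (divide or multiply $\lambda$ by 10) and Marquardt's diagonal scaling are common textbook choices, and the model, data and starting guess are made up for illustration.

```python
import numpy as np


def profile(x, p):
    """Toy 1D exponential profile with parameters p = (amplitude, scale length)."""
    amp, scale = p
    return amp * np.exp(-x / scale)


def profile_jacobian(x, p):
    """Analytic derivatives of profile() with respect to (amplitude, scale length)."""
    amp, scale = p
    e = np.exp(-x / scale)
    return np.stack([e, amp * x / scale**2 * e], axis=1)


def levenberg_marquardt(x, y, sigma, p0, lam=1e-3, max_iter=100, rtol=1e-12):
    """Minimize chi^2 = sum(((y - profile) / sigma)^2) with a basic LM loop."""
    p = np.asarray(p0, dtype=float)
    w = 1.0 / sigma**2
    chi2 = np.sum(w * (y - profile(x, p)) ** 2)
    for _ in range(max_iter):
        J = profile_jacobian(x, p)
        A = (J.T * w) @ J                      # approximate Hessian J^T W J
        g = (J.T * w) @ (y - profile(x, p))    # gradient term J^T W r
        # Damped normal equations: (A + lam * diag(A)) step = g
        step = np.linalg.solve(A + lam * np.diag(np.diag(A)), g)
        trial = p + step
        chi2_trial = np.sum(w * (y - profile(x, trial)) ** 2)
        if chi2_trial < chi2:
            improvement = chi2 - chi2_trial
            p, chi2 = trial, chi2_trial
            lam /= 10.0                        # step accepted: behave more like Gauss-Newton
            if improvement <= rtol * chi2:
                break
        else:
            lam *= 10.0                        # step rejected: behave more like gradient descent
    return p, chi2


x = np.linspace(0.0, 10.0, 50)
y = profile(x, (5.0, 2.0))                     # noiseless toy data
best, chi2 = levenberg_marquardt(x, y, sigma=np.ones_like(x), p0=(4.0, 1.5))
print(best, chi2)
```

The damping parameter is what makes LM robust: far from the minimum it shrinks the step toward a cautious gradient-descent move, and near the minimum it relaxes to fast Gauss-Newton steps.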

In [None]:
print(model2)
# we now plot the fitted model and the image residuals
fig5, ax5 = plt.subplots(1, 2, figsize=(16, 6))
ap.plots.model_image(fig5, ax5[0], model2)
ap.plots.residual_image(fig5, ax5[1], model2, normalize_residuals=True)
plt.show()

In [None]:
# Plot surface brightness profile

# We now plot the model profile and a data profile. The model profile is computed from the model parameters;
# the data profile is the median of the pixel values at each radius. Notice that the model profile is slightly
# higher than the data profile. This is because other objects in the image are not being modelled: the median
# used for the data profile ignores them, but the model is fit to all pixels.
fig10, ax10 = plt.subplots(figsize=(8, 8))
ap.plots.radial_light_profile(fig10, ax10, model2)
ap.plots.radial_median_profile(fig10, ax10, model2)
plt.show()
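Why a median-based data profile shrugs off other objects can be seen with plain NumPy: in an annulus crossed by a bright interloper, the mean is pulled up while the median barely moves. This sketch uses a made-up toy image (an exponential disk plus one bright pixel); it is not AstroPhot's `radial_median_profile` implementation.

```python
import numpy as np

# Toy 101x101 image: smooth exponential "galaxy" centred at (50, 50)
yy, xx = np.mgrid[0:101, 0:101]
R = np.hypot(xx - 50.0, yy - 50.0)
image = 100.0 * np.exp(-R / 8.0)
image[50, 75] += 1.0e4  # a bright unmodelled source exactly 25 pixels from the centre

edges = np.arange(0.0, 40.0, 2.0)  # annulus boundaries in pixels
mean_profile, median_profile = [], []
for r_in, r_out in zip(edges[:-1], edges[1:]):
    in_annulus = (R >= r_in) & (R < r_out)
    mean_profile.append(image[in_annulus].mean())
    median_profile.append(np.median(image[in_annulus]))

k = 12  # the annulus [24, 26) contains the interloper
print(edges[k], edges[k + 1], mean_profile[k], median_profile[k])
```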

## Update uncertainty estimates

After running a fit, the `ap.fit.LM` optimizer can update the uncertainty for each parameter. In fact it can return the full covariance matrix if needed. For a demo of what can be done with the covariance matrix see the `FittingMethods` tutorial. One important note is that the variance image needs to be correct for the uncertainties to be meaningful!

In [None]:
result.update_uncertainty()
print(model2)

Note that these uncertainties are pure statistical uncertainties that come from evaluating the structure of the $\chi^2$ minimum. Systematic uncertainties are not included and these often significantly outweigh the standard errors. As can be seen in the residual plot above, there is certainly plenty of unmodelled structure there. Use caution when interpreting the errors from these fits.
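The "structure of the $\chi^2$ minimum" is its curvature. With the Jacobian $J$ of the model with respect to the parameters and weights $W = \mathrm{diag}(1/\sigma^2)$, the covariance is approximately $C = (J^T W J)^{-1}$, the standard errors are $\sqrt{C_{ii}}$, and $C_{ij}/\sqrt{C_{ii} C_{jj}}$ gives the parameter correlations that show up in a covariance-matrix plot. The sketch below checks this on a toy straight-line fit, where the relation is exact, against NumPy's `polyfit`. It illustrates the principle and is not a claim about exactly how `update_uncertainty` computes its result; note how the answer depends directly on the assumed $\sigma$, which is why the variance image must be correct.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0.0, 10.0, 40)
sigma = np.full_like(x, 0.5)                  # assumed-known per-point uncertainties
y = 1.5 + 0.8 * x + rng.normal(0.0, sigma)    # toy straight-line data

# Linear model y = a + b*x, so the Jacobian columns are d/da = 1 and d/db = x
J = np.stack([np.ones_like(x), x], axis=1)
W = 1.0 / sigma**2
cov = np.linalg.inv((J.T * W) @ J)            # C = (J^T W J)^-1
std_err = np.sqrt(np.diag(cov))
corr = cov / np.outer(std_err, std_err)       # correlation matrix
print("standard errors (a, b):", std_err)
print("correlation matrix:\n", corr)

# Cross-check with NumPy; polyfit orders coefficients highest degree first, i.e. (b, a)
_, cov_np = np.polyfit(x, y, 1, w=1.0 / sigma, cov="unscaled")
print(np.allclose(cov[::-1, ::-1], cov_np))
```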

In [None]:
# Plot the uncertainty matrix

# While the scale of the uncertainty may not be meaningful if the image variance is not accurate, we
# can still see how the covariance of the parameters plays out in a given fit.
fig, ax = ap.plots.covariance_matrix(
    result.covariance_matrix.detach().cpu().numpy(),
    model2.build_params_array().detach().cpu().numpy(),
    model2.build_params_array_names(),
)
plt.show()

## Giving the model a specific target window

Sometimes an object isn't nicely centered in the image, and may not even be the dominant object in the image. It is therefore nice to be able to specify what part of the image we should analyze.

In [None]:
# note, we don't provide a name here. A unique name will automatically be generated using the model type
model3 = ap.models.Model(
    model_type="sersic galaxy model",
    target=target,
    window=[555, 665, 480, 595],  # this is a region in pixel coordinates (imin,imax,jmin,jmax)
)

print(f"automatically generated name: '{model3.name}'")

# We can plot the "model window" to show us what part of the image will be analyzed by that model
fig6, ax6 = plt.subplots(figsize=(8, 8))
ap.plots.target_image(fig6, ax6, model3.target)
ap.plots.model_window(fig6, ax6, model3)
plt.show()

In [None]:
model3.initialize()
result = ap.fit.LM(model3, verbose=1).fit()

In [None]:
# Note that when only a window is fit, the default plotting methods will only show that window
print(model3)
fig7, ax7 = plt.subplots(1, 2, figsize=(16, 6))
ap.plots.model_image(fig7, ax7[0], model3)
ap.plots.residual_image(fig7, ax7[1], model3, normalize_residuals=True)
plt.show()

## Setting parameter constraints

A common feature of fit parameters is that they are constrained and cannot take any value in $(-\infty, \infty)$. AstroPhot handles this by remapping each constrained parameter to a space where it can take any real value, at least for the purpose of fitting. For most parameters these constraints are applied by default; for example, the axis ratio q is required to lie in the range (0, 1). Other parameters, such as the position angle (PA), are cyclic: they lie in the range $(0, \pi)$ but wrap around. You can also set these constraints manually while constructing a model.

In general, adding constraints makes fitting more difficult: there is a chance that the fitting process runs up against a constraint boundary and gets stuck. However, sometimes constraints are necessary, so the capability is included.
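To make the idea of remapping concrete: one common way to turn a bounded parameter into an unconstrained one is a scaled logistic (sigmoid) transform and its inverse, the logit. The optimizer works with an unconstrained value $u \in \mathbb{R}$, and the model sees $\theta = a + (b - a)/(1 + e^{-u})$, which always lies in $(a, b)$. This is a generic sketch; AstroPhot's actual transforms may differ, and the bounds (0.4, 0.6) simply mirror the `valid` range for q used in the example that follows.

```python
import numpy as np


def to_bounded(u, low, high):
    """Map an unconstrained value u in (-inf, inf) into the open interval (low, high)."""
    return low + (high - low) / (1.0 + np.exp(-u))


def to_unconstrained(theta, low, high):
    """Inverse (logit) map: take theta in (low, high) back to the real line."""
    s = (theta - low) / (high - low)
    return np.log(s / (1.0 - s))


def wrap_angle(pa):
    """Cyclic parameter: fold any position angle into [0, pi)."""
    return np.mod(pa, np.pi)


u = to_unconstrained(0.5, 0.4, 0.6)   # q = 0.5 is mid-range, so u is approximately 0
print(u, to_bounded(u, 0.4, 0.6))
print(to_bounded(np.array([-20.0, 20.0]), 0.4, 0.6))  # large |u| only approaches the bounds
print(wrap_angle(np.pi + 0.3))         # approximately 0.3: same ellipse orientation as pi + 0.3
```

The second print also shows why fits can stall near a boundary: once $|u|$ is large, big changes in $u$ barely move $\theta$, so the gradient with respect to $u$ all but vanishes.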

In [None]:
# here we make a sersic model that can only have q and n in a narrow range
# Also, we give PA an initial value and lock it so that it does not change during fitting
constrained_param_model = ap.models.Model(
    name="constrained parameters",
    model_type="sersic galaxy model",
    q={"valid": (0.4, 0.6)},
    n={"valid": (2, 3)},
    PA={"value": 60 * np.pi / 180},
    target=target,
)

Aside from constraints on an individual parameter, it is sometimes desirable to have different models share parameter values. For example you may wish to combine multiple simple models into a more complex model (more on that in a different tutorial), and you may wish for them all to have the same center. This can be accomplished with "equality constraints" as shown below.

In [None]:
# model 1 is a sersic model
model_1 = ap.models.Model(
    model_type="sersic galaxy model", center=[50, 50], PA=np.pi / 4, target=target
)
# model 2 is an exponential model
model_2 = ap.models.Model(model_type="exponential galaxy model", target=target)

# Here we constrain "PA" to be the same for both models by assigning
# model_1's PA parameter to model_2, so both models now share a single
# parameter.
model_2.PA = model_1.PA

# Both models now report the same PA value, and a change made through one is seen by the other
print(
    "initial values: model_1 PA",
    model_1.PA.value.item(),
    "model_2 PA",
    model_2.PA.value.item(),
)
# Now we modify the PA for model_1
model_1.PA.value = np.pi / 3
print(
    "change model_1: model_1 PA",
    model_1.PA.value.item(),
    "model_2 PA",
    model_2.PA.value.item(),
)

## Basic things to do with a model

Now that we know how to create a model and fit it to an image, let's get to know the model a bit better.

In [None]:
# Save the model state to a file

model2.save_state("current_spot.hdf5", appendable=True)  # save as it is
model2.q = 0.1  # do some updates to the model
model2.PA = 0.1
model2.n = 0.9
model2.Re = 0.1
model2.append_state("current_spot.hdf5")  # save the updated model state as often as you like

In [None]:
# load a model state from a file

model2.load_state("current_spot.hdf5", index=0)  # load the first state from the file
print(model2)  # see that the values are back to where they started

In [None]:
# Save the model image to a file

model_image_sample = model2()
model_image_sample.save("model2.fits")

saved_image_hdu = fits.open("model2.fits")
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(
    np.log10(saved_image_hdu[0].data),
    origin="lower",
    cmap="viridis",
)
plt.show()

In [None]:
# Plot model image with discrete levels

# this is very useful for visualizing subtle features and for eyeballing the brightness at a given location.
# just add the "cmap_levels" keyword to the model_image call and tell it how many levels you want
fig11, ax11 = plt.subplots(figsize=(8, 8))
ap.plots.model_image(fig11, ax11, model2, cmap_levels=15)
plt.show()
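Outside AstroPhot, the same discrete-level effect can be reproduced in Matplotlib by resampling a colormap into a fixed number of colours. This sketch uses a made-up exponential image and `matplotlib.colormaps[...].resampled(...)` (available in Matplotlib 3.6 and later); it is an analogue of `cmap_levels`, not necessarily how AstroPhot implements it.

```python
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Made-up smooth image standing in for a galaxy model
yy, xx = np.mgrid[0:100, 0:100]
toy_image = 100.0 * np.exp(-np.hypot(xx - 50.0, yy - 50.0) / 10.0)

levels = 15
stepped_cmap = matplotlib.colormaps["viridis"].resampled(levels)  # 15 discrete colours

fig, ax = plt.subplots(figsize=(8, 8))
im = ax.imshow(np.log10(toy_image), origin="lower", cmap=stepped_cmap)
fig.colorbar(im, ax=ax)
plt.show()
```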

In [None]:
# Save and load a target image

target.save("target.fits")

# Note that it is often also possible to load from regular FITS files
new_target = ap.image.TargetImage(filename="target.fits")

fig, ax = plt.subplots(figsize=(8, 8))
ap.plots.target_image(fig, ax, new_target)
plt.show()

In [None]:
# Access the model image pixels directly

fig2, ax2 = plt.subplots(figsize=(8, 8))

pixels = model2().data.detach().cpu().numpy()

im = plt.imshow(
    np.log10(pixels),  # take log10 for better dynamic range
    origin="lower",
    cmap=ap.plots.visuals.cmap_grad,  # gradient colourmap default for AstroPhot
)
plt.colorbar(im)
plt.show()
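The `.detach().cpu().numpy()` chain above is standard PyTorch: `detach()` drops the autograd graph (NumPy cannot hold tensors that require gradients), `cpu()` copies the data off a GPU if necessary, and `numpy()` exposes it as a NumPy array. Here is a stand-alone illustration with a plain tensor standing in for `model2().data`:

```python
import numpy as np
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
# Stand-in for model pixels: a tensor that takes part in autograd, possibly on a GPU
pixels_t = torch.linspace(1.0, 100.0, 16, dtype=torch.float64, device=device).reshape(4, 4)
pixels_t.requires_grad_(True)

# pixels_t.numpy() would raise a RuntimeError here because the tensor requires grad
pixels_np = pixels_t.detach().cpu().numpy()
print(type(pixels_np), pixels_np.dtype, pixels_np.shape)
print(np.log10(pixels_np))
```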

## Load target with WCS information

In [None]:
# download the same image again, this time keeping its header so we can build a WCS from it
filename = "https://www.legacysurvey.org/viewer/fits-cutout?ra=36.3684&dec=-25.6389&size=700&layer=ls-dr9&pixscale=0.262&bands=r"
hdu = fits.open(filename)
target_data = np.array(hdu[0].data, dtype=np.float64)

wcs = WCS(hdu[0].header)

# Create a target object with WCS which will specify the pixelscale and origin for us!
target = ap.image.TargetImage(
    data=target_data,
    zeropoint=22.5,
    wcs=wcs,
)

fig3, ax3 = plt.subplots(figsize=(8, 8))
ap.plots.target_image(fig3, ax3, target)
plt.show()

## Even better, just load directly from a FITS file

AstroPhot recognizes standard FITS keywords to extract a target image. Note that this won't work for all FITS files, only ones that define the following keywords: `CTYPE1`, `CTYPE2`, `CRVAL1`, `CRVAL2`, `CRPIX1`, `CRPIX2`, `CD1_1`, `CD1_2`, `CD2_1`, `CD2_2`, and `MAGZP`, with their usual meanings. AstroPhot can also handle SIP distortions; see the SIP tutorial for details.

Further AstroPhot-specific keywords, used for advanced features such as multi-band fitting, are `CRTAN1` and `CRTAN2`, used for aligning images, and `IDNTY`, used to identify when two images are actually cutouts of the same image. When AstroPhot writes a FITS file, it also stores the `PSF`, `WEIGHT`, and `MASK` in extra extensions.

In [None]:
target = ap.image.TargetImage(filename=filename)

fig3, ax3 = plt.subplots(figsize=(8, 8))
ap.plots.target_image(fig3, ax3, target)
plt.show()

In [None]:
# List all the available model names

# AstroPhot keeps track of all subclasses of the AstroPhot Model object, so this list
# includes every model, even ones added by the user
print(ap.models.Model.List_Models(usable=True, types=True))
print("---------------------------")
# It is also possible to get all sub models of a specific Type
print("only galaxy models: ", ap.models.GalaxyModel.List_Models(types=True))
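Python makes this kind of bookkeeping straightforward because every class records its direct subclasses via `__subclasses__()`. Below is a minimal, hypothetical sketch of walking such a hierarchy recursively; the class names are made up and `List_Models` itself may be implemented differently.

```python
class BaseModel:
    """Hypothetical stand-in for a model base class."""

    @classmethod
    def list_models(cls):
        """Recursively collect every subclass of cls, including user-defined ones."""
        found = []
        for sub in cls.__subclasses__():
            found.append(sub)
            found.extend(sub.list_models())
        return found


class GalaxyBase(BaseModel):
    pass


class SersicGalaxy(GalaxyBase):
    pass


class MyCustomGalaxy(SersicGalaxy):  # a "user-added" model is picked up automatically
    pass


print([c.__name__ for c in BaseModel.list_models()])
print([c.__name__ for c in GalaxyBase.list_models()])
```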

## Using GPU acceleration

This one is easy! If you have a CUDA-enabled GPU available, AstroPhot will automatically detect it and use that device.

In [None]:
# check if AstroPhot has detected your GPU
print(ap.AP_config.ap_device)  # most likely this will say "cpu" unless you already have a cuda GPU,
# in which case it should say "cuda:0"
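The detection itself follows the usual PyTorch pattern. Here is a minimal sketch, assuming a simple "use the first CUDA device if present" rule; AstroPhot's exact logic may differ.

```python
import torch

# Pick the first CUDA device if PyTorch can see one, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

# Tensors (and anything built from them) must be created on, or moved to, that device
x = torch.zeros(3, device=device)
print(x.device)
```

This is also why the device should be set before anything else is created: objects built earlier keep the device they were created on.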

In [None]:
# If you have a GPU but want to use the cpu for some reason, just set:
ap.AP_config.ap_device = "cpu"
# BEFORE creating anything else (models, images, etc.)

## Boost GPU acceleration with single precision float32

If you are using a GPU, you can get significant improvements in both memory use and speed by switching from double precision (the AstroPhot default) to single-precision floating point numbers. The trade-off is reduced precision, which can cause some unexpected behaviors. For example, an optimizer may keep iterating forever if it is trying to optimize down to a precision below what float32 can track. float32 numbers are typically good to about 6 significant digits, and by default AstroPhot only attempts to minimize the $\chi^2$ to 3 places. However, to ensure the fit is secure to 3 places, it often checks what is happening at 4 or 5 places, so issues can arise. For the most part you can go ahead with float32; if you run into a weird bug, try float64 before looking further.
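The numbers behind this warning are easy to check with NumPy: float32 carries a 24-bit significand, giving a machine epsilon of about $1.2\times10^{-7}$, versus about $2.2\times10^{-16}$ for float64. A relative change smaller than that simply vanishes in single precision, which is how an optimizer chasing tiny $\chi^2$ improvements can stall. The $\chi^2$ value below is made up for illustration.

```python
import numpy as np

print(np.finfo(np.float32).eps, np.finfo(np.float64).eps)

chi2 = 1.0e4                      # a made-up chi^2 value
tiny_improvement = 1.0e-4         # a relative change of 1e-8

# In float64 the improvement is still visible; in float32 it is rounded away entirely
print(np.float64(chi2) - np.float64(tiny_improvement) < np.float64(chi2))
print(np.float32(chi2) - np.float32(tiny_improvement) < np.float32(chi2))
```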

In [None]:
# Again do this BEFORE creating anything else
ap.AP_config.ap_dtype = torch.float32

# Now new AstroPhot objects will be made with single precision
T1 = ap.image.TargetImage(data=np.zeros((100, 100)), pixelscale=1.0)
T1.to()
print("now a single:", T1.data.dtype)

# Here we switch back to double precision
ap.AP_config.ap_dtype = torch.float64
T2 = ap.image.TargetImage(data=np.zeros((100, 100)), pixelscale=1.0)
T2.to()
print("back to double:", T2.data.dtype)
print("old image is still single!:", T1.data.dtype)

See how the image created as float32 stays that way? That's really bad to have lying around! Make sure to change the data type before creating anything!

## Tracking output

The AstroPhot optimizers, and occasionally other AstroPhot objects, provide status updates about themselves, which can be very useful for debugging problems or just keeping tabs on progress. AstroPhot has many use cases, each with different desired output behaviors, so to accommodate all users it implements a general logging system. The object `ap.AP_config.ap_logger` is a logging object which by default writes both to the console and to AstroPhot.log in the local directory. You can set that logger to be any logging object you like, for arbitrary complexity. Most users, however, will simply want to control the filename, or send output to the screen instead of a file. Below you can see examples of how to do that.

In [None]:
# note that the log file will be where these tutorial notebooks are in your filesystem

# Here we change the settings so AstroPhot only prints to a log file
ap.AP_config.set_logging_output(stdout=False, filename="AstroPhot.log")
ap.AP_config.ap_logger.info("message 1: this should only appear in the AstroPhot log file")

# Here we change the settings so AstroPhot only prints to console
ap.AP_config.set_logging_output(stdout=True, filename=None)
ap.AP_config.ap_logger.info("message 2: this should only print to the console")

# Here we change the settings so AstroPhot prints to both, which is the default
ap.AP_config.set_logging_output(stdout=True, filename="AstroPhot.log")
ap.AP_config.ap_logger.info("message 3: this should appear in both the console and the log file")

You can also change the logging level and/or formatter for the stdout and filename options (see `help(ap.AP_config.set_logging_output)` for details). However, at that point you may want to simply make your own logger object and assign it to the `ap.AP_config.ap_logger` variable.
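For example, a custom logger built with Python's standard `logging` module might look like the sketch below. The logger name, file name, format string and levels are illustrative choices. As described above, the final step is to assign the object to `ap.AP_config.ap_logger`; that line is left as a comment so the snippet runs without AstroPhot.

```python
import logging
import sys

logger = logging.getLogger("my_astrophot_run")   # hypothetical logger name
logger.setLevel(logging.DEBUG)
logger.propagate = False                          # avoid duplicate messages via the root logger

fmt = logging.Formatter("%(asctime)s %(levelname)s: %(message)s")

console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)                     # console: INFO and above
console.setFormatter(fmt)

logfile = logging.FileHandler("my_fit.log", mode="w")  # hypothetical file name
logfile.setLevel(logging.DEBUG)                    # file: everything, including DEBUG
logfile.setFormatter(fmt)

logger.handlers.clear()                            # keep re-runs of this cell from stacking handlers
logger.addHandler(console)
logger.addHandler(logfile)

logger.debug("only written to my_fit.log")
logger.info("written to the console and my_fit.log")

# ap.AP_config.ap_logger = logger   # hand the logger to AstroPhot
```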