# Roman SN Simulation modeling with AstroPhot

Use the AstroPhot package to model the lightcurve of a supernova (SN) in Roman simulations.

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import iqr

import torch

import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.table import Table
from astropy.wcs import WCS

import astrophot as ap
from astrophot.image.window_object import Window

Data. We'll be using data from the Roman simulations of SNe described in
"A synthetic Roman Space Telescope High-Latitude Time-Domain Survey: supernovae in the deep field"
Wang et al. 2023, MNRAS, 523, 3, 3874.
https://ui.adsabs.harvard.edu/abs/2023MNRAS.523.3874W

Data page
https://roman.ipac.caltech.edu/sims/SN_Survey_Image_sim.html

Here we try out 3 images, selected by Lauren Aldoroty, that contain a SN. You can download them with the following bash script:

Note that this script works, but it downloads much more data than we need. Each of the three image tarballs is 1.3 GB and covers all 18 detectors of the Roman WFI instrument. We only need one detector, but the tarballs are for the full focal plane. The rotate_Y_truth catalog is 288 MB; the remaining catalog files are much smaller.

```
DATADIR=data
mkdir -p ${DATADIR}
curl "https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_update_Y106_132.tar.gz" --output ${DATADIR}/rotate_update_Y106_132.tar.gz
curl "https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_update_Y106_174.tar.gz" --output ${DATADIR}/rotate_update_Y106_174.tar.gz
curl "https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_update_Y106_175.tar.gz" --output ${DATADIR}/rotate_update_Y106_175.tar.gz

# SN input catalog / SN
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/WFIRST_AKARI_FIXED_HEAD.FITS --output ${DATADIR}/WFIRST_AKARI_FIXED_HEAD.FITS

# SN input lightcurves
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/WFIRST_AKARI_FIXED_PHOT.FITS.gz --output ${DATADIR}/WFIRST_AKARI_FIXED_PHOT.FITS.gz

# SN truth
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_Y_truth.tar.gz --output ${DATADIR}/rotate_Y_truth.tar.gz

# Image Metadata
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/paper_rotate.fits --output ${DATADIR}/paper_rotate.fits

cd data; (for f in rotate_update_Y106_*.tar.gz; do tar xvzf $f "*_1.fits.gz"; done); gunzip *.fits.gz; cd -
```


In [None]:
# Position of the target SN (RA, Dec in degrees).
sn = {"ra": 71.30192566051916, "dec": -53.60051728973533}
sn_coord = SkyCoord(sn["ra"], sn["dec"], unit=u.deg)

# Filter and detector number; both are encoded in the image file names below.
band = "Y106"
detector = 1

data_dir = "data"
# The metadata for each file is stored (in paper_rotate.fits) at the row index that
# is encoded in the filename. We'll use it later to look up the information for each file.
image_info_row = [132, 174, 175]
image_file_basenames = [f"rotate_update_{band}_{idx}_{detector}.fits" for idx in image_info_row]

image_files = [os.path.join(data_dir, bn) for bn in image_file_basenames]

In [None]:
# Load the SN input catalog; its RA/DEC columns are used for the ds9 region file below.
sn_metadata_basename = "WFIRST_AKARI_FIXED_HEAD.FITS"
sn_metadata_filename = os.path.join(data_dir, sn_metadata_basename)
sn_metadata = Table.read(sn_metadata_filename)

In [None]:
# Write out RA, Dec as a ds9 region file so the SN positions can be overlaid on the images.
def write_ds9_region_file(
    coordinate_table, region_filename="ds9.reg", ra_colname="RA", dec_colname="DEC"
):
    """Write one ds9 ``point`` region per row of ``coordinate_table``.

    coordinate_table: table-like object that supports selecting columns with a
        list of names and iterating over the selected rows as (ra, dec) pairs,
        e.g. an astropy Table or a NumPy structured array.
    region_filename: str, path of the region file to (over)write.
    ra_colname, dec_colname: str, names of the RA and Dec columns.

    The regions are written in the ``wcs; icrs`` coordinate system, so the
    columns are expected to hold ICRS RA/Dec values.
    """
    lines = ["wcs; icrs;\n"]
    lines.extend(
        f"point({ra},{dec});\n"
        for ra, dec in coordinate_table[[ra_colname, dec_colname]]
    )
    # Honour the caller's filename (it used to be silently replaced by "sn.reg").
    with open(region_filename, "w", encoding="utf-8") as region_file:
        region_file.writelines(lines)

In [None]:
# Set overwrite = True to (re)write sn.reg with a point region for every row of the SN catalog.
overwrite = False
if overwrite:
    write_ds9_region_file(coordinate_table=sn_metadata, region_filename="sn.reg")

If you like, you can open the FITS files and the region file with
```
ds9 data/rotate_update_Y106_132_1.fits data/rotate_update_Y106_174_1.fits data/rotate_update_Y106_175_1.fits -region load all sn.reg -scale mode zscale -scale match -frame lock wcs -zoom to fit
```

The images don't have dates in the header, so we get them from `paper_rotate.fits`

In [None]:
# Per-image metadata table (the images themselves carry no dates in their headers).
image_metadata_basename = "paper_rotate.fits"
image_metadata_filename = os.path.join(data_dir, image_metadata_basename)
image_metadata = Table.read(image_metadata_filename)

In [None]:
# Show the metadata rows for the images used here (row index is encoded in each filename).
image_metadata[image_info_row]

In [None]:
# Inspect the WCS of the first image. The context manager closes the FITS file
# once the header has been read (WCS keeps its own copy of the header values).
with fits.open(image_files[0]) as img0:
    wcs0 = WCS(img0[1].header)
print(wcs0)
print(wcs0.pixel_scale_matrix)
# Signed pixel area in the WCS axis units squared (negative if the axes are flipped).
print(np.linalg.det(wcs0.pixel_scale_matrix))

In [None]:
# These are 4k x 4k images
pixel_scale = 0.11  # "/pixel
fwhm = 0.2  # "


def make_target(
    image_filepath,
    coord: SkyCoord = None,
    window_size: float = 5,
    fwhm: float = fwhm,
    psf_size: int = 51,
    pixel_scale: float = pixel_scale,
    pixel_shape: tuple = (100, 100),
    zeropoint: float = 22.5,
    image_hdu_idx: int = 1,
    variance_hdu_idx: int = 2,
    mask_hdu_idx: int = 3,
):
    """Make an AstroPhot target image from a FITS file.

    image_filepath: str, path to the FITS file.
        By default HDU 1 holds the image (its header must carry the WCS),
        HDU 2 the variance and HDU 3 the mask; override with the
        ``*_hdu_idx`` arguments.
    coord: SkyCoord or None. If given, its RA/Dec (degrees) is passed to
        AstroPhot as ``reference_radec``.
    window_size: float, currently unused.
    fwhm: float, PSF full width at half maximum in arcsec.
    psf_size: int, width of the PSF image in pixels.
    pixel_scale: float, "/pix.
        Together with fwhm and psf_size this defines a Gaussian PSF model.
        An actual PSF model derived from the image would be better.
    pixel_shape: (int, int), currently unused.
    zeropoint: float, photometric zeropoint (calibration of counts in the image).
    image_hdu_idx, variance_hdu_idx: int, HDU indices of the image and variance.
    mask_hdu_idx: int, HDU index of the mask; currently unused (see TODO below).

    Returns an ``ap.image.Target_Image``.
    """
    with fits.open(image_filepath) as hdul:
        # Copy into native-endian float64 arrays while the file is open, so the
        # data stay valid after it is closed (FITS data are big-endian on disk).
        img = np.array(hdul[image_hdu_idx].data, dtype=np.float64)
        var = np.array(hdul[variance_hdu_idx].data, dtype=np.float64)
        wcs = WCS(hdul[image_hdu_idx].header)

    # TODO: use the mask HDU (mask_hdu_idx). It is an informative mask, not a
    # bad-pixel mask: e.g. one of its values marks pixels that are part of the
    # footprint of a valid object, and we don't want to mask those. It needs
    # translating into a bad-pixel mask before being passed to AstroPhot.

    # Basic Gaussian PSF from sigma (arcsec), image width (pixels) and pixel
    # scale (arcsec/pixel); FWHM = 2*sqrt(2 ln 2)*sigma ~= 2.355*sigma.
    psf = ap.utils.initialize.gaussian_psf(fwhm / 2.355, psf_size, pixel_scale)

    optional_kwargs = {}
    if coord is not None:
        optional_kwargs["reference_radec"] = (coord.ra.degree, coord.dec.degree)

    return ap.image.Target_Image(
        data=img,
        variance=var,
        zeropoint=zeropoint,
        psf=psf,
        wcs=wcs,
        **optional_kwargs,
    )

In [None]:
# Target for the first image, with its reference RA/Dec set to the SN position.
target_0 = make_target(image_files[0], coord=sn_coord)

In [None]:
# Full window of the first image.
target_0.window

In [None]:
# Target for the second image, also referenced to the SN position.
target_1 = make_target(image_files[1], coord=sn_coord)

In [None]:
# Full window of the second image.
target_1.window

Plot just the area of interest

In [None]:
# Cutout parameters around the SN.
# NOTE(review): coord, npix and pixel_shape are not used by the cells shown here;
# make_window_for_target relies on its own default npix=100.
coord = (sn_coord.ra.degree, sn_coord.dec.degree)
npix = 100
pixel_shape = (npix, npix)

In [None]:
# Pixel position of the SN in the first image.
center_xy = target_0.window.world_to_pixel(sn["ra"], sn["dec"])
print(center_xy)

In [None]:
def make_window_for_target(target, ra, dec, npix=100):
    """Return a copy of ``target``'s window cropped to a box around (ra, dec).

    The crop is applied to ``target.window.copy()``. Along each pixel axis the
    bounds are the pixel position of (ra, dec) plus/minus ``npix // 2``.
    """
    half_width = npix // 2
    cropped = target.window.copy()
    center = cropped.world_to_pixel(ra, dec)
    pixel_bounds = [
        [center[axis] - half_width, center[axis] + half_width] for axis in (0, 1)
    ]
    cropped.crop_to_pixel(pixel_bounds)
    return cropped

In [None]:
print(repr(target_0.window))

In [None]:
# Preview the cutout window around the SN in the first image.
make_window_for_target(target_0, sn["ra"], sn["dec"])

In [None]:
# Cutout windows around the SN in each image, used for plotting and fitting below.
window0 = make_window_for_target(target_0, sn["ra"], sn["dec"])
window1 = make_window_for_target(target_1, sn["ra"], sn["dec"])

In [None]:
# Show the SN cutouts from the first two images side by side.
fig1, ax1 = plt.subplots(1, 2, figsize=(12, 6))
ap.plots.target_image(fig1, ax1[0], target_0, window=window0, flipx=True)
ax1[0].set_title(image_file_basenames[0])
ap.plots.target_image(fig1, ax1[1], target_1, window=window1, flipx=True)
ax1[1].set_title(image_file_basenames[1])

plt.show()

The coordinate axes are in arcseconds, but in the local relative coordinate system of each image. AstroPhot uses the pixel scale to convert pixels to arcseconds.

In [None]:
# Host-only Sersic model, restricted to the cutout window of the first image.
# psf_mode="full" presumably convolves the model with the target's PSF -- confirm in the AstroPhot docs.
model_0 = ap.models.AstroPhot_Model(
    name="host model",
    model_type="sersic galaxy model",
    target=target_0,
    psf_mode="full",
    window=window0,
)

In [None]:
# SN position in the local plane coordinates of the second image. Both targets are
# referenced to the SN position, so this is presumably close to (0, 0) in either image;
# it is used below as the initial center of the models for both images.
sn_xy_1 = target_1.world_to_plane(sn["ra"], sn["dec"])
# Start the host center at the SN position.
host_xy_1 = sn_xy_1

In [None]:
target_0[window0].window

In [None]:
target_1[window1].window

In [None]:
# Before initialize() the center parameter has no value yet.
print(model_0.parameters["center"])

We have to initialize the model so that there is a value for `parameters["center"]`

In [None]:
# Give the model parameters (including center) their initial values.
model_0.initialize()

In [None]:
print(model_0.parameters["center"])

In [None]:
print(model_0.parameters)

In [None]:
# Fit the host-only model (ap.fit.LM, presumably Levenberg-Marquardt).
result = ap.fit.LM(model_0, verbose=True).fit()
print(result.message)

In [None]:
# Fitted parameters.
print(model_0.parameters)

In [None]:
# We divide up because "model_image" expects a single axis object if single image
# while it wants an array of axis objects if there are multiple images in the image list
# model_image will not accept a one-element array if there is no image_list
def plot_target_model(model, *args, **kwargs):
    """Plot target, model and residual images for an AstroPhot model.

    Call either as ``plot_target_model(model, ...)``, which plots the model's
    own ``target``, or as ``plot_target_model(target, model, ...)``.

    kwargs: forwarded to the plotting helpers (``window``, ``title``, ``figsize``).

    Raises TypeError if more than two positional arguments are given.
    """
    if len(args) > 1:
        raise TypeError(
            "plot_target_model() takes (model) or (target, model) as positional arguments"
        )
    if args:
        # Called as plot_target_model(target, model, ...).
        target, model = model, args[0]
    else:
        target = model.target

    if hasattr(target, "image_list"):
        _plot_target_model_multiple(model, target=target, **kwargs)
    else:
        _plot_target_model_single(target, model, **kwargs)


def _plot_target_model_multiple(
    model, window=None, title=None, figsize=(12, 12), target=None
):
    """Plot one (target, model, residual) row per image of an image-list target."""
    if target is None:
        target = model.target
    n = len(target.image_list)
    # squeeze=False keeps ax two-dimensional even for a one-image list.
    fig, ax = plt.subplots(n, 3, figsize=figsize, squeeze=False)
    ap.plots.target_image(fig, ax[:, 0], target, window=window, flipx=True)
    if title is not None:
        ax[0, 0].set_title(title)
    ap.plots.model_image(fig, ax[:, 1], model, window=window, flipx=True)
    ax[0, 1].set_title("Model")
    ap.plots.residual_image(fig, ax[:, 2], model, window=window, flipx=True)
    ax[0, 2].set_title("Residual")
    plt.show()


def _plot_target_model_single(target, model, window=None, title=None, figsize=(16, 4)):
    """Plot target, model and residual side by side for a single-image target."""
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    # Same panel order as the multi-image layout: target, model, residual.
    ap.plots.target_image(fig, ax[0], target, window=window, flipx=True)
    if title is not None:
        ax[0].set_title(title)
    ap.plots.model_image(fig, ax[1], model, window=window, flipx=True)
    ax[1].set_title("Model")
    ap.plots.residual_image(fig, ax[2], model, window=window, flipx=True)
    ax[2].set_title("Residual")
    plt.show()

In [None]:
plot_target_model(target_0, model_0, window=window0)  # target / model / residual for the host-only fit

In [None]:
# Host (Sersic) + SN (point source) models for the first image, both starting at the SN position.
model_host_0 = ap.models.AstroPhot_Model(
    name="host model 0",
    model_type="sersic galaxy model",
    target=target_0,
    psf_mode="full",
    parameters = {"center": host_xy_1},
    window=window0,
)
model_sn_0 = ap.models.AstroPhot_Model(
    name="SN model 0",
    model_type="psf star model",
    target=target_0,
    psf_mode="full",
    parameters = {"center": sn_xy_1},
    window=window0,
)

# Combine host and SN so they are fit together.
model_host_sn_0 = ap.models.AstroPhot_Model(
    name="Host+SN",
    model_type="group model",
    models=[model_host_0, model_sn_0],
    target=target_0,
)

In [None]:
model_host_sn_0.initialize()

In [None]:
result = ap.fit.LM(model_host_sn_0, verbose=True).fit()
print(result.message)

In [None]:
print(model_host_sn_0.parameters)

In [None]:
plot_target_model(target_0, model_host_sn_0, window=window0, title=image_file_basenames[0])

Now fit the model jointly across both images.

In [None]:
# Re-create the host and SN models for image 0 (the previous ones were already fitted
# above), plus a matching pair for image 1. All start from the positions computed in
# image 1's plane; both targets are referenced to the SN position, so these are
# presumably equivalent in either image.
model_host_0 = ap.models.AstroPhot_Model(
    name="host model 0",
    model_type="sersic galaxy model",
    target=target_0,
    psf_mode="full",
    parameters = {"center": host_xy_1},
    window=window0,
)
model_sn_0 = ap.models.AstroPhot_Model(
    name="SN model 0",
    model_type="psf star model",
    target=target_0,
    psf_mode="full",
    parameters = {"center": sn_xy_1},
    window=window0,
)
model_host_1 = ap.models.AstroPhot_Model(
    name="host model 1",
    model_type="sersic galaxy model",
    target=target_1,
    psf_mode="full",
    parameters = {"center": host_xy_1},
    window=window1,
)
model_sn_1 = ap.models.AstroPhot_Model(
    name="SN model 1",
    model_type="psf star model",
    target=target_1,
    psf_mode="full",
    parameters = {"center": sn_xy_1},
    window=window1,
)

In [None]:
# Tie all host parameters (position, shape and brightness) across the two images,
# and tie the SN position to the SN model of the first image. The SN flux is left
# free per image so it can follow the light curve.
model_host_1.add_equality_constraint(model_host_0, ["center", "q", "PA", "n", "Re", "Ie"])
model_sn_1.add_equality_constraint(model_sn_0, ["center"])

In [None]:
# One group model over both images; the target is the list of the two targets.
model_host_sn_0_1 = ap.models.AstroPhot_Model(
    name="Host+SN",
    model_type="group model",
    models=[model_host_0, model_host_1, model_sn_0, model_sn_1],
    target=ap.image.Target_Image_List((target_0, target_1))
)

In [None]:
model_host_sn_0_1.initialize()

In [None]:
# Joint fit over both images.
result = ap.fit.LM(model_host_sn_0_1, verbose=True).fit()
print(result.message)

In [None]:
print(model_host_sn_0_1.parameters)

In [None]:
# One row per image; the title labels the first image.
plot_target_model(
    model_host_sn_0_1,
    window=[window0, window1],
    title=image_file_basenames[0],
    figsize=(12, 8),
)