# Roman, Rubin SN Simulation modeling with AstroPhot

Author: Michael Wood-Vasey <wmwv@pitt.edu>  
Last Verified to run: 2023-11-09

Use the [AstroPhot](https://autostronomy.github.io/AstroPhot/) package to model the lightcurves of SNe in the Roman HLTDS [Supernova Survey](https://ui.adsabs.harvard.edu/abs/2023MNRAS.523.3874W) simulations
or in the Rubin LSST [LSST DESC DC2](https://data.lsstdesc.org/doc/dc2_sim_sky_survey) simulations.

Notable Requirements:  
astrophot  
astropy  
torch  

Major TODOs:
  * [~] Get reasonable PSF model for image
    - Done for DC2
  * [ ] Get calibrated zeropoints and make sure they're correct for the given PSF (i.e., aperture corrections).
  * [~] Implement SIP WCS in AstroPhot to deal with slight variation in object positions
    - Instead implemented a per-image (but not per object) astrometric shift.
  * [x] Update to v0.13 AstroPhot constraints.
  * [ ] Fit in linear flux
  * [ ] Actually use a GPU (start with NERSC, hope for Apple Silicon MPS development)
  * [ ] Compare to truth catalog
  * [ ] SED Modeling
  * [ ] Mock up transmission function
  * [ ] Integrate with SN lightcurve model (e.g., SALT3 or friends)
  * [ ] Come up with galaxy SED model

## Data

This tutorial presents two options for datasets: "DC2" and "RomanSN".  These datasets were chosen because they were of interest to the primary author (MWV).  Unfortunately, downloading the data takes a bit of work.  For the Roman SN simulations, you need to download the full focal-plane images, even though we use just one detector (out of 18) from each.  For the Rubin LSST DESC DC2 simulations, you need either to use Globus to download the data from the [DESC Data Archive](https://data.lsstdesc.org/doc/download), or to download the DP0.2 data from the [Rubin Science Platform](https://data.lsst.cloud) through the Portal or API interfaces.

In [None]:
# Select `DATASET` as either "DC2" or "RomanSN"
DATASET = "DC2"
# DATASET = "RomanSN"

### LSST DESC Data Challenge 2 Data

The first data set uses data from the LSST DESC DC2 simulated data set, as processed by the LSST Science Pipelines for Data Preview 0.2 (DP0.2).

https://arxiv.org/abs/2101.04855  
https://ui.adsabs.harvard.edu/abs/2021ApJS..253...31L


Option 1: Access to DESC DC2 through the [DESC Data Archive](https://data.lsstdesc.org/doc/download) requires creating a Globus account and having a Globus endpoint wherever you want to put the data.  You then select the datasets to download through the GUI.

Option 2: Access to the Rubin-processed DP0.2 data requires [registering to be a DP0 Delegate](https://dp0-2.lsst.io/dp0-delegate-resources/index.html#) and being an LSST Data Rights Holder.

This tutorial was written using data that were downloaded through the https://data.lsst.cloud Portal as the g-, r-, and i-band images overlapping the position of a supernova simulated in the DC2 data, ICRS (RA, Dec): (60.2901401, -44.142051) degrees, in the date range 2025-08-01 through 2025-12-31.

This tutorial assumes that the DC2 image files will be placed into:
`DATADIR = "data/DC2"`

The DC2 tables of simulated SNe can be downloaded via Globus at:

https://data.lsstdesc.org/browse/dataset/4dab60f0-1b22-4304-b01d-519311783c4c

### Roman HLTDS Data

The other data set uses data from Roman simulations of SN in 
"A synthetic Roman Space Telescope High-Latitude Time-Domain Survey: supernovae in the deep field"
Wang et al. 2023, MNRAS, 523, 3, 3874.
https://ui.adsabs.harvard.edu/abs/2023MNRAS.523.3874W

Data page
https://roman.ipac.caltech.edu/sims/SN_Survey_Image_sim.html

Here we try out 3 images, selected by Lauren Aldoroty, that contain a SN.  You can download them with the following bash script:

Note that this script works, but it downloads much more data than we need. The three image tarballs are 1.3 GB each and cover all 18 detectors of the Roman Wide Field Instrument (WFI).  We only need one detector, but the tarballs contain the full focal plane.  The rotate_Y_truth catalog is 288 MB.  The remaining catalog files are much smaller.

```
DATADIR=data/RomanSN
mkdir -p ${DATADIR}
curl "https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_update_Y106_132.tar.gz" --output ${DATADIR}/rotate_update_Y106_132.tar.gz
curl "https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_update_Y106_174.tar.gz" --output ${DATADIR}/rotate_update_Y106_174.tar.gz
curl "https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_update_Y106_175.tar.gz" --output ${DATADIR}/rotate_update_Y106_175.tar.gz

# SN input catalog / SN
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/WFIRST_AKARI_FIXED_HEAD.FITS --output ${DATADIR}/WFIRST_AKARI_FIXED_HEAD.FITS

# SN input lightcurves
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/WFIRST_AKARI_FIXED_PHOT.FITS.gz --output ${DATADIR}/WFIRST_AKARI_FIXED_PHOT.FITS.gz

# SN truth
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/rotate_Y_truth.tar.gz --output ${DATADIR}/rotate_Y_truth.tar.gz

# Image Metadata
curl https://roman.ipac.caltech.edu/data/sims/sn_image_sims/paper_rotate.fits --output ${DATADIR}/paper_rotate.fits

cd ${DATADIR}; (for f in rotate_update_Y106_*.tar.gz; do tar xvzf $f "*_1.fits.gz"; done); gunzip *.fits.gz; cd -
```


If you like, you can view the FITS files and the SN region file (see below) with
```
ds9 data/RomanSN/rotate_update_Y106_132_1.fits data/RomanSN/rotate_update_Y106_174_1.fits data/RomanSN/rotate_update_Y106_175_1.fits -region load all roman_sn.reg -scale mode zscale -scale match -frame lock wcs -zoom to fit
```

In [None]:
import os
import re
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.stats import iqr

import torch

import astropy.units as u
from astropy.coordinates import SkyCoord
from astropy.io import fits
from astropy.table import Table
from astropy.time import Time
from astropy.wcs import WCS

import astrophot as ap
from astrophot.image import PSF_Image, Window

### SN and Host position

In [None]:
if DATASET == "DC2":
    # DC2
    # fitting window will be a npix x npix box
    npix_dict = {41021613806: 100,
                 11392192729110: 100}

    sn_dict = {
        41021613806: {"ra": 60.2901401, "dec": -44.142051},
        11392192729110: {"ra": 60.1996112, "dec": -44.2708990},
    }
    host_dict = {
        41021613806: {"ra": 60.288242, "dec": -44.139890},
        11392192729110: {"ra": 60.1996379, "dec": -44.2705822},
    }

    sn_id = 11392192729110
    sn = sn_dict[sn_id]
    host = host_dict[sn_id]
    npix = npix_dict[sn_id]

elif DATASET == "RomanSN":
    # RomanSN
    sn = {"ra": 71.30192566051916, "dec": -53.60051728973533}
    # We don't have a separate galaxy position here
    host = sn
    npix = 50  # window will be a npix x npix box

In [None]:
sn_coord = SkyCoord(sn["ra"], sn["dec"], unit=u.degree)

### Data files

In [None]:
if DATASET == "DC2":
    data_dir = "data/DC2"

    image_file_basenames_dict = {
        41021613806: [
            "image_calexp-g-944022-R43_S11-2025-10-14T04.fits",
            "image_calexp-g-944052-R31_S12-2025-10-14T04.fits",
            "image_calexp-g-944236-R31_S11-2025-10-14T06.fits",
            "image_calexp-g-960109-R03_S21-2025-11-14T01.fits",
            "image_calexp-g-975987-R43_S01-2025-12-08T02.fits",
            "image_calexp-r-909835-R30_S01-2025-08-18T07.fits",
            "image_calexp-r-909869-R31_S10-2025-08-18T08.fits",
            "image_calexp-r-909956-R01_S02-2025-08-18T09.fits",
            "image_calexp-r-910001-R12_S10-2025-08-18T09.fits",
            "image_calexp-r-942690-R02_S11-2025-10-12T08.fits",
            "image_calexp-r-942722-R23_S10-2025-10-12T08.fits",
            "image_calexp-r-943370-R24_S12-2025-10-13T06.fits",
            "image_calexp-r-943372-R22_S00-2025-10-13T06.fits",
            "image_calexp-r-943428-R01_S22-2025-10-13T07.fits",
            "image_calexp-r-963953-R21_S02-2025-11-19T02.fits",
            "image_calexp-r-963987-R32_S20-2025-11-19T02.fits",
            "image_calexp-r-964110-R03_S22-2025-11-19T04.fits",
            "image_calexp-r-969979-R34_S10-2025-11-28T04.fits",
            "image_calexp-r-970022-R24_S10-2025-11-28T04.fits",
            "image_calexp-i-915668-R24_S22-2025-08-25T07.fits",
            "image_calexp-i-915698-R01_S11-2025-08-25T07.fits",
            "image_calexp-i-924121-R20_S20-2025-09-19T06.fits",
            "image_calexp-i-934006-R31_S02-2025-10-01T06.fits",
            "image_calexp-i-934015-R20_S12-2025-10-01T06.fits",
            "image_calexp-i-941091-R11_S02-2025-10-10T08.fits",
            "image_calexp-i-945601-R10_S12-2025-10-16T03.fits",
            "image_calexp-i-966789-R01_S02-2025-11-24T01.fits",
            "image_calexp-i-966821-R23_S10-2025-11-24T01.fits",
            "image_calexp-i-966822-R24_S00-2025-11-24T01.fits",
            "image_calexp-i-976120-R10_S11-2025-12-08T03.fits",
            "image_calexp-i-976272-R33_S22-2025-12-08T05.fits",
            "image_calexp-i-976304-R41_S00-2025-12-08T05.fits",
            "image_calexp-i-976336-R23_S20-2025-12-08T06.fits",
        ],
        11392192729110: [
            "image_calexp-r-896748-R14_S20-2025-08-03T08.fits",
            "image_calexp-r-909835-R30_S01-2025-08-18T07.fits",
            "image_calexp-r-963953-R21_S02-2025-11-19T02.fits",
            "image_calexp-r-964110-R03_S22-2025-11-19T04.fits",
            "image_calexp-r-969979-R34_S10-2025-11-28T04.fits",
            "image_calexp-r-970022-R24_S10-2025-11-28T04.fits",
            "image_calexp-i-941091-R11_S02-2025-10-10T08.fits",
            "image_calexp-i-976272-R33_S22-2025-12-08T05.fits",
        ],
    }
    image_file_basenames = image_file_basenames_dict[sn_id]
    # SN live from MJD 60930 to 61020

    select_filter = re.compile("image_calexp-([a-zA-Z0-9]+)-")
    band = [select_filter.match(f).groups()[0] for f in image_file_basenames]

    # Optionally, limit to only the "r" band for testing
#    r_band, = np.where(np.array(band) == "r")
#    image_file_basenames = np.array(image_file_basenames)[r_band]
#    band = np.array(band)[r_band]

elif DATASET == "RomanSN":
    data_dir = "data/RomanSN"
    band = "Y106"

    # The metadata for the files is stored by row index, which is encoded in the filename.
    # We'll use that later to look up the information for each file.
    image_info_row = [132, 174, 175]
    detectors = [1, 1, 1]
    image_file_basenames = [
        f"rotate_update_{band}_{idx}_{det}.fits"
        for idx, det in zip(image_info_row, detectors)
    ]

    # One band entry per image (multiplying the string itself would give "Y106Y106Y106")
    band = len(image_file_basenames) * [band]

else:
    raise ValueError(f"Unsupported DATASET: {DATASET}")

In [None]:
image_files = [os.path.join(data_dir, bn) for bn in image_file_basenames]

Get the PSF from the calexp PSFEx model
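Why resample at all? PSFEx stores its PSF model on a grid sampled every `pixstep` image pixels. If `pixstep` is 0.5 (an assumption for this sketch; check `_pixstep` in your files), averaging 2x2 blocks brings the model back to roughly native image pixels. The numpy-only sketch below shows the idea for even-sized arrays using a reshape; `block_average_2x2` is a hypothetical helper for illustration, not the notebook's `resample_psf_2x2`, which also handles odd sizes by keeping the central row and column.

```python
import numpy as np


def block_average_2x2(img):
    """Average non-overlapping 2x2 blocks of an even-sized 2-D array."""
    ny, nx = img.shape
    if ny % 2 or nx % 2:
        raise ValueError("block_average_2x2 expects even dimensions")
    # Split each axis into (n/2, 2) and average over the length-2 axes
    return img.reshape(ny // 2, 2, nx // 2, 2).mean(axis=(1, 3))


# A Gaussian PSF sampled at twice the image resolution (pixstep = 0.5)
n_over = 64
sigma_over = 4.0  # in oversampled pixels, i.e. 2 image pixels
y, x = np.mgrid[0:n_over, 0:n_over] - (n_over - 1) / 2
psf_over = np.exp(-(x**2 + y**2) / (2 * sigma_over**2))
psf_over /= psf_over.sum()

psf_native = block_average_2x2(psf_over)
psf_native /= psf_native.sum()  # re-normalize, as make_target does for its PSF
print(psf_native.shape)  # (32, 32)
```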

In [None]:
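def resample_psf_2x2(psf):
    """Downsample a stack of PSF component images by 2 along each spatial axis.

    psf: array of shape (n_components, nx, ny)

    Pixels are averaged in 2x2 blocks. For odd nx (or ny), the central
    row (or column) is kept as-is so that the PSF stays centered.
    Assumes nx//2 and ny//2 are even (e.g., 81 -> 41 and 40 -> 20).
    """
    nc, nx, ny = psf.shape

    # Rows: average adjacent pairs of rows in the first and last halves
    # (e.g., rows 0:20 and 20:40 for nx = 40; rows 0:40 and 41:81 for nx = 81).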
    a = psf[:, :nx//2, :].reshape(nc, nx//4, 2, ny).sum(axis=2) / 2
    b = psf[:, -(nx-1)//2:, :].reshape(nc, nx//4, 2, ny).sum(axis=2) / 2
    if nx % 2 == 1:
        x_pieces = [a, psf[:, nx//2:nx//2+1, :], b]
    else:
        x_pieces = [a, b]
    c = np.concatenate(x_pieces, axis=1)
    
    # Columns
    d = c[:, :, :ny//2].reshape(nc, (nx+1)//2, ny//4, 2).sum(axis=3) / 2
    e = c[:, :, -(ny-1)//2:].reshape(nc, (nx+1)//2, ny//4, 2).sum(axis=3) / 2
    if ny % 2 == 1:
        y_pieces = [d, c[:, :, ny//2:ny//2+1], e]
    else:
        y_pieces = [d, e]
    new_psf = np.concatenate(y_pieces, axis=2)
    
    return new_psf

In [None]:
def test_resample_psf_2x2_odd():
    test_array_odd = np.zeros((6, 81, 81))
    smaller_test_array_odd = resample_psf_2x2(test_array_odd)
    assert np.shape(smaller_test_array_odd) == (6, 41, 41)

def test_resample_psf_2x2_even():
    test_array_even = np.zeros((6, 40, 40))
    smaller_test_array_even = resample_psf_2x2(test_array_even)
    assert np.shape(smaller_test_array_even) == (6, 20, 20)

In [None]:
test_resample_psf_2x2_odd()

In [None]:
test_resample_psf_2x2_even()

In [None]:
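def read_psfex_image(hdu_info, hdu_data, resample=False):
    """Read the PSFEx PSF component images stored in a calexp.

    hdu_info: HDU with the PSFEx model metadata (including `_pixstep`)
    hdu_data: HDU with the PSFEx component images (`_size`, `_comp`)
    resample: if True, downsample the components 2x2 with `resample_psf_2x2`

    Returns (pixstep, image), where the first axis of image indexes the PSF components.
    """
#     group = hdu.data["group"]
#     degree = hdu.data["degree"]
#     basis = hdu.data["basis"]
#     coeff = hdu.data["coeff"]
    size = hdu_data.data["_size"]
    comp = hdu_data.data["_comp"]
#    print(hdu_data.data["_context_first"])
#    print(hdu_data.data["_context_second"])

    image = comp.reshape(*size[0][::-1])

    # The PSF model is sampled every `pixstep` image pixels,
    # so pixstep < 1 means the model is oversampled.
    pixstep = hdu_info.data._pixstep[0]

    if resample:
        image = resample_psf_2x2(image)

    return pixstep, image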

In [None]:
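def calc_fwhm_from_psf_image(image):
    """Estimate the FWHM of a PSF image from its second moments.

    image: 2-D PSF image, or a 3-D stack of PSF components, in which case
        only the first (main) component is used.
    Returns the FWHM in units of the PSF image pixels, assuming a
    circular Gaussian profile: FWHM = 2 sqrt(2 ln 2) sigma ~ 2.355 sigma.
    """
    if np.ndim(image) == 3:
        image = image[0]
    ny, nx = np.shape(image)

    # meshgrid with the default "xy" indexing returns arrays of shape (ny, nx)
    x, y = np.meshgrid(np.arange(nx) - nx//2, np.arange(ny) - ny//2)

    norm = np.sum(image)

    # Calculate first moments
    x_center = np.sum(image * x) / norm
    y_center = np.sum(image * y) / norm

    # Calculate second moments
    I_xx = np.sum(image * (x - x_center)**2) / norm
    I_xy = np.sum(image * (x - x_center) * (y - y_center)) / norm
    I_yy = np.sum(image * (y - y_center)**2) / norm

    # For a circular Gaussian, I_xx = I_yy = sigma**2
    sigma = np.sqrt((I_xx + I_yy) / 2)
    fwhm = 2 * np.sqrt(2 * np.log(2)) * sigma

    return fwhm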

In [None]:
f = image_files[0]
hdu = fits.open(f)

In [None]:
hdu[11].data._pixstep[0]

In [None]:
pixstep, image = read_psfex_image(hdu[11], hdu[12])
fwhm = calc_fwhm_from_psf_image(image)
print(fwhm)

In [None]:
plt.imshow(image[0, :, :], cmap="rainbow")
plt.colorbar()
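
As a sanity check on moment-based FWHM estimates: for a circular Gaussian, the second moments satisfy `I_xx = I_yy = sigma**2`, so `FWHM = 2 sqrt(2 ln 2) sigma ~ 2.355 sigma`. The standalone sketch below (the helper name `moment_fwhm` is hypothetical) builds a Gaussian stamp of known width and recovers its FWHM from the moments.

```python
import numpy as np


def moment_fwhm(img):
    """FWHM (in pixels) from second moments, assuming a circular Gaussian."""
    ny, nx = img.shape
    # meshgrid's default "xy" indexing gives arrays of shape (ny, nx)
    x, y = np.meshgrid(np.arange(nx), np.arange(ny))
    norm = img.sum()
    xc = (img * x).sum() / norm
    yc = (img * y).sum() / norm
    i_xx = (img * (x - xc) ** 2).sum() / norm
    i_yy = (img * (y - yc) ** 2).sum() / norm
    sigma = np.sqrt((i_xx + i_yy) / 2)
    return 2 * np.sqrt(2 * np.log(2)) * sigma


# Gaussian with sigma = 3 pixels centered in a 51 x 51 stamp
sigma_true = 3.0
yy, xx = np.mgrid[0:51, 0:51] - 25
gauss = np.exp(-(xx**2 + yy**2) / (2 * sigma_true**2))
print(moment_fwhm(gauss))  # close to 2.355 * 3
```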

### Detector, Image, and FITS file order

In [None]:
# These are 4k x 4k images
pixel_scale = {"DC2": 0.2, "RomanSN": 0.11}  # "/pixel
fwhm = {"DC2": 0.6, "RomanSN": 0.2}  # "

# The HDU order is different between the two datasets
HDU_IDX = {
    "DC2": {"image": 1, "mask": 2, "variance": 3, "psfex_info": 11, "psfex_data": 12},
    "RomanSN": {"image": 1, "mask": 3, "variance": 2},
}
# as are the FITS extension names
HDU_NAMES = {
    "DC2": {"image": "image", "mask": "mask", "variance": "variance"},
    "RomanSN": {"image": "SCI", "mask": "DQ", "variance": "ERR"},
}
# so we have to use a translation regardless.

In [None]:
### Bad pixel mask values
bad_pixel_bitmask = {}

## DC2
# Pixel mask values are defined in
# https://github.com/lsst/afw/blob/29afe694f19d80cba34b10ffd361dc6ca8d49dd1/src/image/detail/MaskDict.cc#L206
# The "right" way is to use `getMaskPlaneDict` in the LSST Science Pipelines
# but we don't want to introduce that dependency here.

# "EDGE": 4 is not necessarily bad; although it could be a cause for concern, we'll accept it for now.
# "DETECTED": 5 and "DETECTED_NEGATIVE": 6 are both bits indicating detection of objects
# and are not "bad".
basic_mask_plane_bits = {
    "BAD": 0,
    "SAT": 1,
    "INTRP": 2,
    "CR": 3,
    "SUSPECT": 7,
    "NO_DATA": 8,
}
bad_pixel_bitmask["DC2"] = sum(
    2 ** np.array([v for v in basic_mask_plane_bits.values()])
)

## Roman
bad_pixel_bitmask["RomanSN"] = 0b1
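
To see how the bitmask works: each mask plane is one bit, and summing `2**bit` over the "bad" planes gives a single integer. A pixel is bad if its mask value shares any set bit with that integer. The sketch below copies the DC2 bit numbers from the cell above; the example mask values are made up for illustration.

```python
import numpy as np

# Bit numbers of the "bad" DC2 mask planes, copied from the cell above
bad_bits = {"BAD": 0, "SAT": 1, "INTRP": 2, "CR": 3, "SUSPECT": 7, "NO_DATA": 8}
bitmask = sum(1 << b for b in bad_bits.values())
print(bitmask, bin(bitmask))  # 399 0b110001111

# Illustrative mask values: clean, DETECTED (bit 5), SAT (bit 1), DETECTED + CR (bits 5 and 3)
mask = np.array([0, 1 << 5, 1 << 1, (1 << 5) | (1 << 3)])
is_bad = (mask & bitmask) != 0
print(is_bad)  # [False False  True  True]
```

A pixel flagged only as DETECTED survives, while one that is both DETECTED and a cosmic ray is rejected.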

### Image metadata table

If you have the information, create an image metadata table here called `image_metadata`.  It will be used below to add key information to the AstroPhot target header metadata, which will in turn be used to build the lightcurve table for the photometry.  The table should have rows in the same order as the `image_file_basenames` (and `image_files`) arrays, and it is expected to have "mjd" and "band" columns.

#### DC2
For DC2, the DATE-START, DATE-END, and DATE-AVG are stored in the header, so we get the MJD by reading the headers and converting DATE-AVG to MJD.

#### RomanSN
For the Roman SN simulation, the dates are stored in a separate file: `paper_rotate.fits`.  We look up the MJD from the appropriate row using the index in the simulated filename.
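Both cases come down to an MJD per image. For reference, MJD counts days since 1858-11-17T00:00 (MJD = JD - 2400000.5). The notebook uses `astropy.time.Time` for the conversion; the standard-library sketch below shows only the arithmetic and ignores time-scale differences (e.g., TAI vs. UTC), so treat it as an approximation. The helper name `iso_to_mjd` is hypothetical.

```python
from datetime import datetime

MJD_EPOCH = datetime(1858, 11, 17)


def iso_to_mjd(iso_string):
    """Approximate MJD from an ISO-8601 timestamp (no time-scale handling)."""
    dt = datetime.fromisoformat(iso_string)
    delta = dt - MJD_EPOCH
    return delta.days + delta.seconds / 86400 + delta.microseconds / 86400e6


print(iso_to_mjd("2023-02-25T00:00:00"))  # 60000.0
print(iso_to_mjd("2025-10-14T06:00:00"))  # 60962.25
```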

In [None]:
lightcurve_truth = None

if DATASET == "DC2":
    mjd = []
    for f in image_files:
        header = fits.getheader(f)
        dt = Time(header["DATE-AVG"], scale=header["TIMESYS"].lower())
        mjd.append(dt.mjd)
        
    summary_file = os.path.join(data_dir, "truth_sn_summary_v1-0-0.parquet")
    variability_file = os.path.join(data_dir, "truth_sn_variability_v1-0-0.parquet")

    df = pd.read_parquet(variability_file)
    lightcurve_truth = df[df["id"] == sn_id]
    
    lightcurve_truth = Table.from_pandas(lightcurve_truth)

    # delta_flux is in nJy: 1 Jy is AB mag 8.90, and 1 nJy = 1e-9 Jy adds 2.5 * 9 = 22.5 mag
    zp = 8.90 + 2.5 * 9
    lightcurve_truth["mag"] = -2.5 * np.log10(lightcurve_truth["delta_flux"]) + zp
    lightcurve_truth["band"] = lightcurve_truth["bandpass"]

elif DATASET == "RomanSN":
    image_metadata_basename = "paper_rotate.fits"
    image_metadata_filename = os.path.join(data_dir, image_metadata_basename)
    image_metadata = Table.read(image_metadata_filename)

    mjd = image_metadata[image_info_row]["date"]

In [None]:
import pandas as pd
summary_file = os.path.join(data_dir, "truth_sn_summary_v1-0-0.parquet")
df = pd.read_parquet(summary_file)

In [None]:
df[df["id"] == sn_id]

In [None]:
image_metadata = Table({"image_filename": image_file_basenames, "mjd": mjd})
image_metadata["band"] = band

#### Make a convenient region file of the SN coordinates

In [None]:
# Write out RA, Dec to prepare for making a ds9 region file.
def write_ds9_region_file(
    coordinate_table, region_filename="ds9.reg", ra_colname="RA", dec_colname="DEC", id_colname=None,
):
    with open(region_filename, "w") as f:
        f.write("wcs; icrs;\n")
        if id_colname is None:
            for r, d in coordinate_table[[ra_colname, dec_colname]]:
                f.write(f"point({r},{d})\n")
        else:
            for r, d, i in coordinate_table[[ra_colname, dec_colname, id_colname]]:
                f.write(f"point({r},{d})  # text={{{i}}}\n")            

In [None]:
if DATASET == "DC2":
    sn_metadata_basename = "truth_sn_summary_v1-0-0.parquet"
    sn_metadata_filename = os.path.join(data_dir, sn_metadata_basename)
    sn_metadata = Table.read(sn_metadata_filename)

    min_mjd, max_mjd = 60888, 61040
    in_date_range = (min_mjd < sn_metadata["t0"]) & (sn_metadata["t0"] < max_mjd)
    mag_threshold = 24
    in_mag_range = sn_metadata["mB"] < mag_threshold 
    in_range = in_date_range & in_mag_range
    
    region_basename = "dc2_sn.reg"
    region_filename = os.path.join(data_dir, region_basename)
    
    overwrite = False
    if overwrite:
        write_ds9_region_file(
            coordinate_table=sn_metadata[in_range],
            ra_colname="ra",
            dec_colname="dec",
            id_colname="id",
            region_filename=region_filename,
        )

In [None]:
if DATASET == "RomanSN":
    sn_metadata_basename = "WFIRST_AKARI_FIXED_HEAD.FITS"
    sn_metadata_filename = os.path.join(data_dir, sn_metadata_basename)
    sn_metadata = Table.read(sn_metadata_filename)

    overwrite = False
    if overwrite:
        write_ds9_region_file(coordinate_table=sn_metadata, region_filename="roman_sn.reg")

### General SN+host fitting

The rest of this Notebook should work in general for any data set (`image_files`), SN coordinates (`sn`), host coordinates (`host`), and image metadata table with "mjd" and "band" columns (`image_metadata`) set up above.
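The fitting code relies on the usual zeropoint relation, `mag = -2.5 log10(flux) + ZP`. Three conventions appear in this notebook: images scaled to nanomaggies have ZP = 22.5 (`DEFAULT_ZP`); the DC2 truth-table zeropoint 8.90 + 2.5 * 9 = 31.4 corresponds to AB fluxes in nJy; and a per-second zeropoint such as MAGZERO becomes a zeropoint for total counts by adding 2.5 log10(EXPTIME). The sketch below checks these relations numerically; the helper names are hypothetical.

```python
import numpy as np


def flux_to_mag(flux, zeropoint):
    """Magnitude from a flux in the units that `zeropoint` is defined for."""
    return -2.5 * np.log10(flux) + zeropoint


def counts_zeropoint(magzero_per_second, exptime):
    """Zeropoint for total counts, given a zeropoint for counts per second."""
    return magzero_per_second + 2.5 * np.log10(exptime)


# 1 nanomaggy is 22.5 AB mag by definition
print(flux_to_mag(1.0, 22.5))  # 22.5

# AB: 1 Jy is 8.90 mag, so 1 nJy is 8.90 + 22.5 = 31.4 mag
print(round(flux_to_mag(1.0, 8.90 + 2.5 * 9), 6))  # 31.4

# 100 counts in 10 s with a per-second ZP of 25 gives the same mag as 10 counts/s
zp_counts = counts_zeropoint(25.0, 10.0)
print(flux_to_mag(100.0, zp_counts), flux_to_mag(10.0, 25.0))
```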

In [None]:
DEFAULT_ZP = 22.5  # Appropriate if the image was calibrated and scaled to nanomaggies


def make_target(
    image_filepath,
    coord: Optional[SkyCoord] = None,
    fwhm: float = fwhm[DATASET],
    psf_size: int = 51,
    pixel_scale: float = pixel_scale[DATASET],
    zeropoint: Optional[float] = None,
    hdu_idx: dict = HDU_IDX[DATASET],
    bad_pixel_bitmask: Optional[int] = bad_pixel_bitmask[DATASET],
    do_mask=False,
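):
    """Make an AstroPhot target.

    image_filepath: str, Filepath to image file.
        Image file assumed to have [image, mask, variance] HDUs, located via `hdu_idx`.
        WCS assumed to be present in image HDU header.
    coord: SkyCoord object with center of window
    fwhm: float, Full-Width at Half-Maximum in arcsec
    psf_size: int, width of the PSF image in pixels
    pixel_scale: float, "/pix
       This is used along with fwhm, psf_size to set a Gaussian PSF model
       when no PSFEx model is available in the file.
    zeropoint: float, calibration of counts in image.
        If None, derived from MAGZERO and EXPTIME in the primary header,
        falling back to DEFAULT_ZP.
    hdu_idx: dict, HDU indices of the "image", "mask", and "variance"
        (and optionally "psfex_info" and "psfex_data") extensions.
    bad_pixel_bitmask: int, mask bits that mark a pixel as bad.
    do_mask: bool, whether to apply the bad-pixel mask to the target.
    """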
    hdu = fits.open(image_filepath)
    header = hdu[0].header  # Primary header
    img = hdu[hdu_idx["image"]].data  # Image HDU
    var = hdu[hdu_idx["variance"]].data  # Variance HDU

    sigma_to_fwhm = 2.355

    if do_mask:
        # We need to translate the informative mask into a bad-pixel mask.
        # E.g., in an LSST Science Pipelines mask, one of the mask bits
        # flags a pixel as part of the footprint of a detected object.
        # We don't want to mask those!
        informative_mask = hdu[hdu_idx["mask"]].data  # Mask
        # True where any of the "bad" bits is set
        bad_pixel_mask = (informative_mask & bad_pixel_bitmask) != 0

    # LSST Science Pipelines processed data will store a zeropoint in MAGZERO
    if zeropoint is None:
        try:
            zeropoint = header["MAGZERO"] + 2.5 * np.log10(header["EXPTIME"])
        except KeyError:
            zeropoint = DEFAULT_ZP

    wcs = WCS(hdu[hdu_idx["image"]].header)

    # If a PSF image is available, use it to calculate FWHM
    if "psfex_info" in hdu_idx.keys():
        pixstep, image = read_psfex_image(
            hdu[hdu_idx["psfex_info"]], hdu[hdu_idx["psfex_data"]], resample=True,
        )
        pixel_scale = 3600 * wcs.pixel_scale_matrix

        psf_upscale = round(1 / pixstep)
        # Tensor expects float64
        psf = image[0, :, :].astype("float64")  # just take main component
        psf /= np.sum(psf)
#         print(psf_upscale)
#         psf = PSF_Image(
#             data=psf,
#             psf_upscale=psf_upscale,
#             pixelscale=pixel_scale,
#         )
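        # Each pixel of the 2x2-resampled PSF spans 2 * pixstep image pixels.
        # Convert to arcsec with the mean linear pixel scale of the CD matrix.
        mean_pixel_scale = np.sqrt(np.abs(np.linalg.det(pixel_scale)))
        fwhm = calc_fwhm_from_psf_image(image) * 2 * pixstep * mean_pixel_scale
        print(f"PSF FWHM: {fwhm:.3f} arcsec")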
    else:
        # we construct a basic gaussian psf for each image
        # by giving the simga (arcsec), image width (pixels), and pixelscale (arcsec/pixel)
        psf = ap.utils.initialize.gaussian_psf(
            fwhm / sigma_to_fwhm, psf_size, pixel_scale
        )

    target_kwargs = {
        "data": np.array(img, dtype=np.float64),
        "variance": var,
        "zeropoint": zeropoint,
        "psf": psf,
        "wcs": wcs,
    }

    if do_mask:
        target_kwargs["mask"] = bad_pixel_mask
    if coord is not None:
        target_kwargs["reference_radec"] = (coord.ra.degree, coord.dec.degree)

    target = ap.image.Target_Image(**target_kwargs)

    target.header.filename = image_filepath

    return target

In [None]:
targets = ap.image.Target_Image_List(make_target(f, coord=sn_coord) for f in image_files)

Add the MJD and band from the `image_metadata` table to the target metadata

In [None]:
# Assume targets and image_metadata are in the same order.
# Could do a more robust lookup by filename, but that is not implemented for now.
for i, r in enumerate(image_metadata):
    targets[i].header.mjd = r["mjd"]
    targets[i].header.band = r["band"]

Plot just the area of interest
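The window helpers in the next cell crop each target to an `npix` x `npix` box centered on the SN's pixel position. The same bookkeeping can be sketched in plain numpy as a slice of the image array. Rounding the center to the nearest pixel and clipping at the detector edges are choices made for this sketch (the hypothetical `cutout_bounds`), not necessarily what AstroPhot's `crop_to_pixel` does.

```python
import numpy as np


def cutout_bounds(center_xy, npix, shape):
    """Pixel bounds of an npix x npix box around center_xy, clipped to the image.

    center_xy: (x, y) pixel position; shape: (ny, nx) of the image array.
    Returns (xmin, xmax, ymin, ymax) as slice bounds (xmax, ymax exclusive).
    """
    ny, nx = shape
    xc, yc = (int(round(c)) for c in center_xy)
    half = npix // 2
    xmin, xmax = max(xc - half, 0), min(xc + half, nx)
    ymin, ymax = max(yc - half, 0), min(yc + half, ny)
    return xmin, xmax, ymin, ymax


demo_image = np.arange(200 * 300).reshape(200, 300)  # ny = 200, nx = 300
xmin, xmax, ymin, ymax = cutout_bounds((150.4, 80.6), npix=100, shape=demo_image.shape)
stamp = demo_image[ymin:ymax, xmin:xmax]  # numpy arrays are indexed [y, x]
print(stamp.shape)  # (100, 100)
```

Note the `[y, x]` order when slicing: mixing it up is an easy way to crop the wrong region of a non-square detector.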

In [None]:
def make_window_for_target(target, ra, dec, npix=npix):
    window = target.window.copy()
    center_xy = window.world_to_pixel(ra, dec)

    xmin = center_xy[0] - npix // 2
    xmax = center_xy[0] + npix // 2
    ymin = center_xy[1] - npix // 2
    ymax = center_xy[1] + npix // 2

    window.crop_to_pixel([[xmin, xmax], [ymin, ymax]])
    return window

def make_windows_for_targets(targets, ra, dec, npix=npix):
    windows = [make_window_for_target(t, ra, dec, npix=npix) for t in targets]
    return windows

In [None]:
windows = make_windows_for_targets(targets, sn["ra"], sn["dec"])

In [None]:
n = len(targets.image_list)
side = int(np.sqrt(n)) + 1
fig, ax = plt.subplots(side, side, figsize=(3 * side, 3 * side))

for i in range(n):
    ap.plots.target_image(fig, ax.ravel()[i], targets[i], window=windows[i], flipx=True)

plt.show()

The coordinate axes are in arcseconds, in the local relative coordinate system of each image.  AstroPhot uses the pixel scale to translate pixels to arcsec.

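Translate the SN and host positions to projection-plane positions for the targets.  Because every target was constructed with the SN coordinate as its reference RA, Dec, these projection-plane positions are the same for all targets, so we compute them from the first one.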
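For intuition about what these projection-plane positions are, the standard gnomonic (TAN) projection maps (RA, Dec) to tangent-plane offsets (xi, eta) about a reference point. Assuming AstroPhot's `world_to_plane` behaves like such a projection (its internal conventions, e.g. the sign of the RA axis, may differ), the host of DC2 SN 11392192729110 sits about 1.14 arcsec north and 0.07 arcsec east of the SN. The numpy sketch below uses a hypothetical helper, `gnomonic_offset`.

```python
import numpy as np


def gnomonic_offset(ra, dec, ra0, dec0):
    """Tangent-plane offsets (xi, eta) in arcsec of (ra, dec) about (ra0, dec0).

    All inputs in degrees. xi increases toward east (increasing RA), eta toward north.
    """
    ra, dec, ra0, dec0 = np.radians([ra, dec, ra0, dec0])
    dra = ra - ra0
    denom = np.sin(dec) * np.sin(dec0) + np.cos(dec) * np.cos(dec0) * np.cos(dra)
    xi = np.cos(dec) * np.sin(dra) / denom
    eta = (np.sin(dec) * np.cos(dec0) - np.cos(dec) * np.sin(dec0) * np.cos(dra)) / denom
    return np.degrees(xi) * 3600, np.degrees(eta) * 3600


# SN 11392192729110 and its host, from the DC2 dictionaries above
sn_radec = (60.1996112, -44.2708990)
host_radec = (60.1996379, -44.2705822)
xi, eta = gnomonic_offset(*host_radec, *sn_radec)
print(f"xi = {xi:.3f} arcsec, eta = {eta:.3f} arcsec")
```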

In [None]:
sn_xy = targets[0].world_to_plane(sn["ra"], sn["dec"])
host_xy = targets[0].world_to_plane(host["ra"], host["dec"])

### Plotting Convenience Function

In [None]:
# We split the two cases because "model_image" expects a single axis object for a single image,
# but an array of axis objects when there are multiple images in the image list.
# model_image will not accept a one-element array if there is no image_list.
def plot_target_model(model, **kwargs):
    if hasattr(model.target, "image_list"):
        _plot_target_model_multiple(model, **kwargs)
    else:
        _plot_target_model_single(model, **kwargs)


def _plot_target_model_multiple(model, window=None, titles=None, base_figsize=(12, 4), figsize=None):
    n = len(model.target.image_list)
    if figsize is None:
        figsize = (base_figsize[0], n*base_figsize[1])
    fig, ax = plt.subplots(n, 3, figsize=figsize)
    # Would like to just call this, but window isn't parsed as a list
    # https://github.com/Autostronomy/AstroPhot/issues/142
    #    ap.plots.target_image(fig, ax[:, 0], model.target, window=window, flipx=True)
    for axt, mod, win in zip(ax[:, 0], model.target.image_list, window):
        ap.plots.target_image(fig, axt, mod, win, flipx=True)

    if titles is not None:
        for i, title in enumerate(titles):
            ax[i, 0].set_title(title)
    ap.plots.model_image(fig, ax[:, 1], model, window=window, flipx=True)
    ax[0, 1].set_title("Model")
    ap.plots.residual_image(fig, ax[:, 2], model, window=window, flipx=True)
    ax[0, 2].set_title("Residual")
    plt.show()


def _plot_target_model_single(model, window=None, title=None, figsize=(16, 4)):
    fig, ax = plt.subplots(1, 3, figsize=figsize)
    ap.plots.target_image(fig, ax[0], model.target, window=window, flipx=True)
    ax[0].set_title(title)
    ap.plots.model_image(fig, ax[1], model, window=window, flipx=True)
    ax[1].set_title("Model")
    ap.plots.residual_image(fig, ax[2], model, window=window, flipx=True)
    ax[2].set_title("Residual")
    plt.show()

### Jointly fit model across images

In [None]:
model_sky = []
model_host = []
model_sn = []

# The DC2 images are calibrated exposures that have the sky subtracted
# The RomanSN images are "raw" science images with sky.
FIT_SKY = {"DC2": False, "RomanSN": True}
FIT_HOST = True
FIT_SN = True
CORRECT_SIP = True

if FIT_SKY[DATASET]:
    for i, (target, window) in enumerate(zip(targets, windows)):
        model_sky.append(
            ap.models.AstroPhot_Model(
                name=f"sky model {i}",
                model_type="flat sky model",
                target=target,
                window=window,
            )
        )
    
if FIT_HOST:
    model_host_band = {}
    for i, (b, target, window) in enumerate(zip(band, targets, windows)):
        model_host.append(
            ap.models.AstroPhot_Model(
                name=f"host model {i}",
                model_type="sersic galaxy model",
                target=target,
                psf_mode="full",
                parameters={"center": host_xy},
                window=window,
            )
        )
        # Record the index of the first host model in each band.
        # The other host models in that band will share its shape and brightness
        # parameters (see the constraint cell below). The initialization step
        # assumes that the reference model gets initialized first, so we use
        # the first model in the list for each band.
        if b not in model_host_band.keys():
            model_host_band[b] = i  
        
if FIT_SN:
    for i, (target, window) in enumerate(zip(targets, windows)):
        model_sn.append(
            ap.models.AstroPhot_Model(
                name=f"SN model {i}",
                model_type="psf star model",
                psf=target.psf,
                target=target,
                psf_mode="none",
                parameters={"center": sn_xy},
                window=window,
            )
        )

In [None]:
# AstroPhot doesn't handle SIP WCS yet.
# We'll roughly work around this by allowing a small shift in position
# for all (both) objects on the image.
CORRECT_SIP = True
if CORRECT_SIP:
    def calc_center(params):
        return params["nominal_center"].value + params["astrometric"].value

    if FIT_HOST and FIT_SN:
        host_center = ap.param.Parameter_Node(
            name = "nominal_center",
            value = host_xy    
        )

        sn_center = ap.param.Parameter_Node(
            name = "nominal_center",
            value = sn_xy
        )
            
        for i in range(len(model_host)):
            # The x, y delta is the same for both the SN and host
            # but can be different for each image.
            P_astrometric = ap.param.Parameter_Node(
                name = "astrometric",
                value = [0, 0],
            )

            model_host[i]["center"].value = calc_center
            model_host[i]["center"].link(host_center, P_astrometric)
            
            model_sn[i]["center"].value = calc_center
            model_sn[i]["center"].link(sn_center, P_astrometric)

Constrain the host model parameters (shape and brightness) to be the same within each band
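The next cell ties each host model's parameters to the first host model in the same band. The bookkeeping amounts to a first-occurrence lookup; here is a plain-Python sketch of it (the helper name `reference_index_by_band` is hypothetical) that lists which model index each non-reference model would be linked to.

```python
def reference_index_by_band(bands):
    """Map each band to the index of its first occurrence in `bands`."""
    reference = {}
    for i, b in enumerate(bands):
        # setdefault only stores i the first time a band is seen
        reference.setdefault(b, i)
    return reference


demo_bands = ["r", "r", "i", "r", "i"]
reference = reference_index_by_band(demo_bands)
links = [(i, reference[b]) for i, b in enumerate(demo_bands) if i != reference[b]]
print(reference)  # {'r': 0, 'i': 2}
print(links)  # [(1, 0), (3, 0), (4, 2)]
```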

In [None]:
for b, model in zip(band, model_host):
    if model.name == model_host[model_host_band[b]].name:
        continue
    for parameter in ["q", "PA", "n", "Re", "Ie"]:
        model[parameter].value = model_host[model_host_band[b]][parameter]

In [None]:
# Create a two-tier hierarchy of group models
# following recommendation from Connor Stone.

# Group model for each class: sky, host, sn
all_model_list = []
if len(model_sky) > 0:
    sky_group_model = ap.models.AstroPhot_Model(
        name="Sky",
        model_type="group model",
        models=[*model_sky],
        target=targets,
    )
    all_model_list.extend(sky_group_model)

if len(model_host) > 0:
    host_group_model = ap.models.AstroPhot_Model(
        name="Host",
        model_type="group model",
        models=[*model_host],
        target=targets,
    )
    all_model_list.extend(host_group_model)

if len(model_sn) > 0:
    sn_group_model = ap.models.AstroPhot_Model(
        name="SN",
        model_type="group model",
        models=[*model_sn],
        target=targets,
    )
    all_model_list.extend(sn_group_model)

# Group model holds all the classes
model_host_sn = ap.models.AstroPhot_Model(
    name="Host+SN",
    model_type="group model",
    models=all_model_list,
    target=targets,
)

We have to initialize the model so that there is a value for `parameters["center"]`

In [None]:
model_host_sn.initialize()

In [None]:
print(model_host_sn.parameters)

In [None]:
result = ap.fit.LM(model_host_sn, verbose=True).fit()
print(result.message)

In [None]:
print(result.model.parameters)

The uncertainties above aren't real uncertainties.  They are the initial uncertainties set when initializing the object.  To update the uncertainties, we can use the covariance matrix.  We'll just take the diagonal for now.

But there could be noticeable off-diagonal elements. E.g., the SN flux estimates are [anti]correlated with the host-galaxy estimate.
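Taking the diagonal gives per-parameter standard deviations, `sigma_i = sqrt(C_ii)`, and the off-diagonal structure is easier to read as a correlation matrix, `rho_ij = C_ij / (sigma_i sigma_j)`. A small numpy sketch with a made-up 2x2 covariance (think SN flux vs. host brightness); the same two lines would apply to the `covar` array computed from `result.covariance_matrix`.

```python
import numpy as np

# Made-up covariance for two anticorrelated parameters
cov = np.array([[4.0, -1.2],
                [-1.2, 1.0]])

sigma = np.sqrt(np.diag(cov))
corr = cov / np.outer(sigma, sigma)
print(sigma)  # [2. 1.]
print(corr)
```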

In [None]:
result.update_uncertainty()

The uncertainties for the center positions and the astrometric shifts aren't calculated correctly right now.

But the flux uncertainties are reasonable.

In [None]:
print(result.model.parameters)

In [None]:
covar = result.covariance_matrix.detach().cpu().numpy()
plt.imshow(
    covar,
    origin="lower",
    vmin=1e-8, vmax=1e-1, norm="log",
)
plt.colorbar()

Let's focus on the SN flux uncertainties:

This is a little clunky because I don't have a better way of looking up the names of the parameters in the covariance matrix.
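A less index-dependent alternative is to look the parameters up by name. Assuming you can obtain a list of parameter names in the same order as the rows of the covariance matrix (how to get that list from AstroPhot is not shown here; `param_names` below is made up), `np.ix_` extracts the sub-block even when the SN flux parameters are not contiguous or not at the end:

```python
import numpy as np

# Hypothetical parameter names, in covariance-matrix order
param_names = ["host Re", "host n", "SN flux 0", "host Ie", "SN flux 1"]
rng = np.random.default_rng(0)
a = rng.normal(size=(5, 5))
demo_covar = a @ a.T  # any symmetric positive semi-definite matrix will do

sn_idx = [i for i, name in enumerate(param_names) if name.startswith("SN flux")]
sn_block = demo_covar[np.ix_(sn_idx, sn_idx)]
print(sn_idx)  # [2, 4]
print(sn_block.shape)  # (2, 2)
```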

In [None]:
sn_flux_starts_at_parameter_idx = -len(targets.image_list)
covar = result.covariance_matrix.detach().cpu().numpy()
plt.imshow(
    covar[sn_flux_starts_at_parameter_idx:, sn_flux_starts_at_parameter_idx:],
    origin="lower",
    vmin=1e-6, vmax=1, norm="log",
)
plt.colorbar()

In [None]:
sn_model_names = [f"SN model {i}" for i in range(len(targets.image_list))]

In [None]:
filenames = [model_host_sn.models[m].target.header.filename for m in sn_model_names]
bands = [model_host_sn.models[m].target.header.band for m in sn_model_names]
# The "flux" parameter of the SN (psf star) model is log10(flux),
# so mag = -2.5 * log10(flux) + zeropoint
mag = [
    -2.5 * model_host_sn.models[m].parameters["flux"].value.detach().cpu().numpy()
    + model_host_sn.models[m].target.zeropoint.detach().cpu().numpy()
    for m in sn_model_names
]
# mag_err = 2.5 * diagonal_parameter_std[sn_flux_starts_at_parameter_idx:]
# Propagate the log10(flux) uncertainty: sigma_mag = 2.5 * sigma_log10(flux)
mag_err = [
    2.5 * model_host_sn.models[m].parameters["flux"].uncertainty.detach().cpu().numpy()
    for m in sn_model_names
]

In [None]:
[model_host_sn.models[m].target.zeropoint.detach().cpu().numpy() for m in sn_model_names]

In [None]:
lightcurve = Table(
    {"filename": filenames, "band": bands, "mjd": mjd, "mag": mag, "mag_err": mag_err}
)

In [None]:
lightcurve["mjd"].info.format = ">10.3f"
lightcurve["mag"].info.format = ">7.4f"
lightcurve["mag_err"].info.format = ">7.4f"

In [None]:
lightcurve

The current model is in log10flux instead of linear flux, so non-detections appear with very large mag_err.
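Why non-detections blow up in log space: with `mag = -2.5 log10(F) + ZP`, a log10-flux uncertainty maps to `sigma_mag = 2.5 sigma_log10F`, and when F is consistent with zero, log10(F) is poorly constrained, so `sigma_log10F` becomes large. In linear flux the first-order equivalent is `sigma_mag = (2.5 / ln 10) sigma_F / F ~ 1.0857 sigma_F / F`, which is only meaningful for clear detections; a linear-flux fit would instead report a finite (possibly negative) flux for a non-detection. A sketch of both propagations, with hypothetical helper names:

```python
import numpy as np


def mag_from_log10_flux(log10_flux, sigma_log10_flux, zeropoint):
    """Magnitude and error from a log10(flux) estimate and its uncertainty."""
    mag = -2.5 * log10_flux + zeropoint
    mag_err = 2.5 * sigma_log10_flux
    return mag, mag_err


def mag_from_linear_flux(flux, sigma_flux, zeropoint):
    """Magnitude and first-order error from a linear flux; only sensible if flux > 0."""
    mag = -2.5 * np.log10(flux) + zeropoint
    mag_err = 2.5 / np.log(10) * sigma_flux / flux
    return mag, mag_err


print(mag_from_log10_flux(2.0, 0.01, 22.5))
print(mag_from_linear_flux(100.0, 1.0, 22.5))
```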

In [None]:
color_for_band = {
    "u": "purple",
    "g": "blue",
    "r": "green",
    "i": "red",
    "z": "black",
    "y": "yellow",
}

# Don't plot non-detections until we switch to flux fitting
mag_err_threshold = 0.5
for b in np.unique(lightcurve["band"]):
    (idx,) = np.where(
        (lightcurve["band"] == b) & (lightcurve["mag_err"] < mag_err_threshold)
    )
    plt.errorbar(
        lightcurve[idx]["mjd"],
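        # Ad hoc +3.7 mag offset, presumably compensating for the
        # not-yet-calibrated zeropoints (see the TODOs above)
        lightcurve[idx]["mag"] + 3.7,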
        lightcurve[idx]["mag_err"],
        marker="o",
        markerfacecolor=color_for_band[b],
        markeredgecolor=color_for_band[b],
        ecolor=color_for_band[b],
        linestyle="none",
        label=f"{b}",
    )
plt.ylabel("mag")
plt.xlabel("MJD")
plt.title(f"Proof of Concept: {DATASET}")
# plt.ylim(23.5, 17)


if lightcurve_truth is not None:
    for b in np.unique(lightcurve["band"]):
        (idx,) = np.where(lightcurve_truth["band"] == b)
        plt.scatter(
            lightcurve_truth[idx]["MJD"],
            lightcurve_truth[idx]["mag"],
            color=color_for_band[b],
            marker="*",
            label=f"model {b}",
        )
        
plt.legend()

plt.ylim(plt.ylim()[::-1]);


In [None]:
lightcurve_truth[lightcurve_truth["band"] == "r"]

The current spread in measured values within a single night gives a rough guide to how much the lightcurve extraction in this Notebook needs to improve.
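One way to quantify that spread is to group the measurements by band and night and compute the scatter within each group. The sketch below assumes a night boundary at 12:00 UTC, i.e. `night = floor(mjd - 0.5)`, which keeps a Chilean observing night together (adapt this for other sites), and uses made-up magnitudes.

```python
from collections import defaultdict

import numpy as np

# Made-up measurements
demo_mjd = np.array([60962.10, 60962.15, 60962.20, 60963.12, 60963.30, 60962.16])
demo_band = np.array(["r", "r", "r", "r", "r", "i"])
demo_mag = np.array([21.50, 21.62, 21.41, 21.70, 21.74, 21.30])

# Night index with the day boundary at 12:00 UTC
night = np.floor(demo_mjd - 0.5).astype(int)

groups = defaultdict(list)
for n, b, m in zip(night, demo_band, demo_mag):
    groups[(b, n)].append(m)

for (b, n), mags in sorted(groups.items()):
    if len(mags) > 1:
        # ddof=1: sample standard deviation of the measurements in that night
        print(b, n, len(mags), round(float(np.std(mags, ddof=1)), 3))
```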

In [None]:
plot_target_model(
    model_host_sn,
    window=windows,
    titles=image_file_basenames,
)