# Inject variable sources into multiple images

Contact author: Jeff Carlin

Date last verified to run: Mon Apr 29 2024

RSP environment version: Weekly 2024_16

**Summary:**
A demo of how to inject a variable source into a set of `calexp` images, with the magnitude appropriate to each image's observation time, then "warp" those images to a common WCS so that they are aligned, and finally extract cutout images around the injected source.

Import packages and then instantiate a butler for DP0.2.

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import astropy.units as u
from astropy.table import Table
from astropy.coordinates import SkyCoord

from lsst.daf.butler import Butler
from lsst.daf.butler.registry import ConflictingDefinitionError
import lsst.afw.display as afwDisplay
import lsst.geom as geom
from lsst.source.injection import ingest_injection_catalog, generate_injection_catalog
from lsst.source.injection import VisitInjectConfig, VisitInjectTask
import lsst.sphgeom
from lsst.pipe.tasks.registerImage import RegisterConfig, RegisterTask

# Send afwDisplay image output to the matplotlib backend.
afwDisplay.setDefaultBackend('matplotlib')
# Plot style for all matplotlib figures in this notebook.
plt.style.use('tableau-colorblind10')

In [None]:
def cutout(im, xcen, ycen, size):
    '''Create a square cutout of an input image centered on a pixel

    Parameters
    ----------
    im: `ExposureF`, `Image`, or 2-D array
        Input image to cut down; anything supporting 2-D slicing of the
        form ``im[x0:x1, y0:y1]`` with integer bounds
    xcen, ycen: `int`
        Integer XY coordinates to center the cutout on
    size: `int`
        Width in pixels of the resulting image

    Returns
    -------
    cutout:
        The sub-image ``im[x0:x0 + size, y0:y0 + size]``, spanning exactly
        ``size`` pixels along each axis
    '''
    # Integer division keeps the slice bounds integral (``size/2`` would
    # produce float bounds). Starting ``size // 2`` below the center and
    # spanning ``size`` pixels gives exactly ``size`` pixels for odd and
    # even sizes alike; for odd sizes the center pixel is exactly in the
    # middle.
    half = size // 2
    x0 = xcen - half
    y0 = ycen - half
    return im[x0:x0 + size, y0:y0 + size]

In [None]:
def sinusoidal_variability(period_days, amplitude_mags, mean_mag, exposure_midpts_mjd,
                           t0=None):
    '''Create a sinusoidally varying light curve evaluated at exposure midpoints

    The magnitude at time ``t`` is
    ``mean_mag + amplitude_mags*sin(2*pi*(t - t0)/period_days)``,
    so it equals ``mean_mag`` at ``t0``.

    Parameters:
    -----------
    period_days: `float`
        Sinusoidal period, in days
    amplitude_mags: `float`
        Amplitude of the variability, in magnitudes
    mean_mag: `float`
        Mean magnitude
    exposure_midpts_mjd: `array` of `floats`
        Midpoint times of exposures for which to calculate the magnitude
        (any array-like, e.g. a list or a numpy array)
    t0: `float`, optional
        Reference time (MJD) of zero phase, where the magnitude equals
        ``mean_mag``. Defaults to the first element of
        ``exposure_midpts_mjd``, in which case the result depends on the
        order of that array.

    Returns
    -------
    mags: `numpy.ndarray`
        Magnitude at each input time (empty if no times are given)
    '''
    # Accept lists as well as arrays; arithmetic below needs an ndarray.
    times = np.asarray(exposure_midpts_mjd, dtype=float)
    if t0 is None:
        # Phase zero defaults to the first time given. With no times there
        # is nothing to compute, so any anchor works.
        t0 = times.flat[0] if times.size else 0.0
    # Fraction of a period elapsed since t0, wrapped into [0, 1):
    phase = np.mod((times - t0) / period_days, 1.0)
    return mean_mag + amplitude_mags*np.sin(phase*2.0*np.pi)
    

In [None]:
def warp_img(ref_img, img_to_warp, ref_wcs, wcs_to_warp):
    '''Resample an image onto the pixel grid of a reference image

    Parameters
    ----------
    ref_img: `ExposureF`
        Reference image; only its bounding box is used
    img_to_warp: `ExposureF`
        Image to be resampled into the reference frame
    ref_wcs: `WCS` object
        WCS of the reference image (the target frame)
    wcs_to_warp: `WCS` object
        WCS describing ``img_to_warp``

    Returns
    -------
    warpedExp:
        The warped exposure returned by ``RegisterTask.warpExposure`` for
        the reference WCS and the reference bounding box
    '''
    # A default-configured RegisterTask is built on each call; only its
    # warpExposure method is used.
    register_task = RegisterTask(name="register", config=RegisterConfig())
    target_bbox = ref_img.getBBox()
    return register_task.warpExposure(img_to_warp, wcs_to_warp, ref_wcs, target_bbox)

In [None]:
# Butler for the DP0.2 data repository (label 'dp02'). `collections` sets
# the default collection used for the dataset lookups and registry
# queries made through this butler below.
butler_config = 'dp02'
collections = '2.2i/runs/DP0.2'
butler = Butler(butler_config, collections=collections)

### Find calexps overlapping a given position on the sky:

In [None]:
# Sky position (degrees) at which the variable star will be injected.
ra, dec = 62.149, -35.796

In [None]:
# HTM pixelization at level 20. The level must match the `htm20`
# dimension used in the calexp query further down.
level = 20  # the resolution of the HTM grid
pixelization = lsst.sphgeom.HtmPixelization(level)

In [None]:
# Index of the HTM pixel (at this pixelization's level) that contains the
# target position.
target_lonlat = lsst.sphgeom.LonLat.fromDegrees(ra, dec)
htm_id = pixelization.index(lsst.sphgeom.UnitVector3d(target_lonlat))

In [None]:
# Report the size of the HTM pixel via the opening angle of its bounding
# circle, converted from degrees to arcseconds.
circle = pixelization.triangle(htm_id).getBoundingCircle()
level = pixelization.getLevel()
scale = 3600. * circle.getOpeningAngle().asDegrees()
print(f'HTM ID={htm_id} at level={level} is bounded by a circle of radius ~{scale:0.2f} arcsec.')

In [None]:
# All i-band calexps whose footprint includes the level-20 HTM pixel
# containing the target position.
datasetRefs = butler.registry.queryDatasets("calexp", htm20=htm_id,
                                            where="band = 'i'")

# Materialize the results once and count the list. The query result is a
# lazy iterable, so iterating it a second time (as `list(datasetRefs)` did)
# can re-execute the registry query.
datasetRefs_list = list(datasetRefs)

print(f"Found {len(datasetRefs_list)} calexps")

### Extract the time at the midpoint of each exposure:

In [None]:
ccd_visit = butler.get('ccdVisitTable')

# Pull the three needed columns out as plain arrays once. Each lookup below
# is then a vectorized comparison, rather than boolean-indexing the whole
# table, which builds a new table with every column on each iteration.
ccd_visit_ids = np.asarray(ccd_visit['visitId'])
ccd_detectors = np.asarray(ccd_visit['detector'])
ccd_midpts = np.asarray(ccd_visit['expMidptMJD'])

exp_midpoints = []
visits = []
detectors = []

for d in datasetRefs_list:
    did = d.dataId
    # Look up the info by visit and detector:
    ccdrow = (ccd_visit_ids == did['visit']) & (ccd_detectors == did['detector'])
    matched_midpts = ccd_midpts[ccdrow]
    if matched_midpts.size == 0:
        raise LookupError(f"No ccdVisitTable entry for visit={did['visit']}, "
                          f"detector={did['detector']}")
    exp_midpoints.append(matched_midpts[0])
    visits.append(did['visit'])
    detectors.append(did['detector'])

# Parallel arrays, in the same order as datasetRefs_list.
exp_midpoints = np.array(exp_midpoints)
visits = np.array(visits)
detectors = np.array(detectors)

In [None]:
# Display the exposure midpoint times (from the ccdVisitTable 'expMidptMJD'
# column), in the same order as datasetRefs_list.
exp_midpoints

### Assign variable magnitudes to inject

Use the `sinusoidal_variability` function defined above to compute the magnitude of a sinusoidally varying star at each exposure midpoint.

In [None]:
# Light-curve parameters: period [days], amplitude [mag], mean magnitude.
per, amp, mag = 100.0, 3.0, 20.0
var_mags = sinusoidal_variability(period_days=per,
                                  amplitude_mags=amp,
                                  mean_mag=mag,
                                  exposure_midpts_mjd=exp_midpoints)

#### Plot the lightcurve we just created

In [None]:
# Evaluate the model on a 1-day grid spanning all exposures.
# sinusoidal_variability measures phase from the first element of its
# input, and var_mags was phased from exp_midpoints[0]. exp_midpoints is in
# query order, not necessarily time order. So prepend exp_midpoints[0] to
# the grid (and drop it afterwards) to keep this curve in phase with
# var_mags.
tmp_midpts = np.arange(np.min(exp_midpoints), np.max(exp_midpoints) + 1, 1)
tmp_mags = sinusoidal_variability(per, amp, mag,
                                  np.insert(tmp_midpts, 0, exp_midpoints[0]))[1:]
plt.plot(tmp_midpts, tmp_mags, color='Gray')
plt.plot(exp_midpoints, var_mags, 'k.')
start_ind = 0
# Index range of the calexps processed below (highlighted in red), capped
# at the number of calexps actually found.
finish_ind = min(18, len(exp_midpoints))
plt.plot(exp_midpoints[start_ind:finish_ind], var_mags[start_ind:finish_ind], 'r.')
plt.xlabel('exposure midpoint (MJD)')
plt.ylabel('injected magnitude')
plt.gca().invert_yaxis()  # brighter (smaller) magnitudes at the top

plt.show()

### Combine all the information into a catalog of sources

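The catalog contains one row per calexp: the same star (identical RA and Dec) in every row, with the magnitude computed for that exposure's midpoint time.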

In [None]:
# One entry per calexp: a running integer ID, plus the same sky position
# and source type repeated on every row.
n_inject = len(var_mags)
id_arr = np.arange(n_inject)
ra_arr = np.full(n_inject, ra)
dec_arr = np.full(n_inject, dec)
src_type_arr = np.full(n_inject, 'Star')

In [None]:
# Column name -> data, in the order the columns appear in the table.
inject_columns = {
    'injection_id': id_arr,
    'visit': visits,
    'detector': detectors,
    'ra': ra_arr,
    'dec': dec_arr,
    'source_type': src_type_arr,
    'exp_midpoint': exp_midpoints,
    'mag': var_mags,
}
inject_table = Table(list(inject_columns.values()),
                     names=list(inject_columns.keys()))

inject_table

### Inject a single star into each image

First, initialize the injection task. Then extract information about the first calexp in the list, which we'll use as the "reference" image that all the others are warped to match.

In [None]:
# Injection task used below to add the star to each calexp, with the
# default configuration.
inject_config = VisitInjectConfig()
inject_task = VisitInjectTask(config=inject_config)

In [None]:
# The calexp at start_ind serves as the reference frame for warping.
ref_dataId = datasetRefs_list[start_ind].dataId
calexp_ref = butler.get('calexp', dataId=ref_dataId)
psf_ref, photo_calib_ref, wcs_ref = (calexp_ref.getPsf(),
                                     calexp_ref.getPhotoCalib(),
                                     calexp_ref.getWcs())

# Pixel position of the injection coordinates in the reference image,
# rounded to the nearest whole pixel.
target_coord = geom.SpherePoint(ra*geom.degrees, dec*geom.degrees)
xy_ref = wcs_ref.skyToPixel(target_coord)
x_ref, y_ref = (int(np.round(c)) for c in (xy_ref.x, xy_ref.y))

Loop over (a subset of) the visits, injecting a star of the appropriate magnitude based on the lightcurve, warping the resulting image to the orientation of the reference image, then extracting a cutout image around that star.

In [None]:
cutouts = []
dataids = []
mjd_mid_times = []
mags_injected = []

cutout_size = 301  # width of each square cutout, in pixels
star_coord = geom.SpherePoint(ra*geom.degrees, dec*geom.degrees)

# Cap the end index so the loop never runs past the calexps that were found.
for i in range(start_ind, min(finish_ind, len(datasetRefs_list))):
    dataId_i = datasetRefs_list[i].dataId
    calexp_i = butler.get('calexp', dataId=dataId_i)
    psf_i = calexp_i.getPsf()
    photo_calib_i = calexp_i.getPhotoCalib()
    wcs_i = calexp_i.getWcs()

    try:
        # Inject only this calexp's catalog row (as a one-row Table) into a
        # clone, so calexp_i itself is left unmodified.
        injected_output_i = inject_task.run(
            injection_catalogs=[inject_table[i:i + 1]],
            input_exposure=calexp_i.clone(),
            psf=psf_i,
            photo_calib=photo_calib_i,
            wcs=wcs_i,
        )
        injected_exposure_i = injected_output_i.output_exposure

        # Resample onto the reference image's pixel grid, then cut out a
        # stamp centered on the injected star.
        img_warped = warp_img(calexp_ref, injected_exposure_i, wcs_ref, wcs_i)
        xy = img_warped.getWcs().skyToPixel(star_coord)
        x = int(np.round(xy.x))
        y = int(np.round(xy.y))
        cutout_image = cutout(img_warped, x, y, cutout_size)
    except Exception as err:
        # Some visits don't actually overlap the point where we're injecting
        # a star, and other steps (e.g. the cutout) may also fail. Report
        # the actual error rather than assuming the cause. Unlike a bare
        # `except:`, this still lets KeyboardInterrupt stop the loop.
        print(f"Skipping visit {inject_table[i]['visit']}: {err}")
        continue

    cutouts.append(cutout_image)
    mjd_mid_times.append(inject_table[i]['exp_midpoint'])
    mags_injected.append(inject_table[i]['mag'])
    dataids.append(dataId_i)


In [None]:
# Size the grid to the number of cutouts (3 per row), so none are silently
# dropped. Keep at least one row so subplots() is valid when there are none.
ncols = 3
nrows = max(1, int(np.ceil(len(cutouts) / ncols)))
fig, axs = plt.subplots(nrows, ncols, figsize=(9, 2*nrows), dpi=150)

for i, ax in enumerate(fig.axes):
    plt.sca(ax)
    if i >= len(cutouts):
        # The last row may have more panels than remaining cutouts.
        ax.axis('off')
        continue
    display0 = afwDisplay.Display(frame=fig)
    # Fixed linear stretch (rather than 'zscale') so every panel shares the
    # same scale and brightness changes are directly comparable.
    display0.scale('linear', min=-100, max=250)
    display0.mtv(cutouts[i].image)
    vis = dataids[i]['visit']
    mjd = mjd_mid_times[i]
    # Per-panel name: reusing `mag` would overwrite the mean magnitude set
    # in the light-curve cell.
    mag_i = mags_injected[i]
    plt.title(f'visit: {vis}, expMid: {mjd:0.5F}, mag={mag_i:0.2F}',
              fontsize=8)

plt.tight_layout()
plt.show()