<img align="left" src = https://lsstdesc.org/assets/img/logo.png width=250 style="padding: 10px"> 
<b>Testing New DIA Kernel Bases</b> <br>
Contact author: Michael Wood-Vasey <br>
Last verified to run: 2023-06-05 <br>
LSST Science Pipelines version: Weekly 2023_21 <br>
Container Size: large <br>
Targeted learning level: intermediate <br>

Sets up an interactive, step-by-step walk through the tasks that perform image subtraction, to allow for easier modifications to StarSelector, Kernel bases, and Detection.

Note: This Notebook is written below the PipelineTask level.  Rather, it uses individual Tasks directly and reads/writes output products to the butler.  This is pedagogically useful to understand how that works, and practically helpful in working with the evolving `source_injection` package.  However, this structure is not scalable to larger runs (100+ images).  Such large-scale runs should be done as part of an integrated Task that can be connected and run through large-scale cluster job submission.

1. [x] Find set of images that overlap
2. [x] Pick one as template, one as science
3. [x] Also deepCoadd.  Be able to use either.
4. [x] Run subtractions through Tasks
5. [x] Run detection and measurement through Task

In [None]:
from collections.abc import Sequence
import os
from typing import Union

import astropy.table
from astropy.wcs import WCS

import matplotlib.pyplot as plt
import numpy as np

import lsst.afw.display as afwDisplay
import lsst.afw.image
from lsst.afw.math import Warper, WarperConfig
import lsst.afw.table
from lsst.daf.butler import Butler
import lsst.geom as geom
from lsst.ip.diffim import AlardLuptonSubtractConfig, AlardLuptonSubtractTask
from lsst.ip.diffim import GetTemplateConfig, GetTemplateTask
from lsst.ip.diffim import DetectAndMeasureConfig, DetectAndMeasureTask
import lsst.sphgeom

Some things we may or may not want to import for overriding or inheriting:

In [None]:
from lsst.pipe.tasks.makeWarp import MakeWarpConfig, MakeWarpTask
from lsst.ip.diffim import MakeKernelConfig, MakeKernelTask, PsfMatchConfig, PsfMatchConfigAL, PsfMatchConfigDF
from lsst.ip.diffim.subtractImages import _subtractImages
from lsst.ip.diffim.utils import evaluateMeanPsfFwhm, getPsfFwhm
from lsst.meas.algorithms import SourceDetectionTask, SubtractBackgroundTask
from lsst.meas.base import SingleFrameMeasurementTask
from lsst.pex.exceptions import InvalidParameterError

In [None]:
# Use matplotlib as the backend for all afw displays created below.
afwDisplay.setDefaultBackend('matplotlib')
# Apply the 'tableau-colorblind10' matplotlib style sheet to every plot.
plt.style.use('tableau-colorblind10')
%matplotlib inline

## Some helper utilities

In [None]:
def show_image_on_wcs(calexp, figsize=(8, 8), ax=None, x=None, y=None,
                           pixel_extent=None, stamp_size=None,
                           vmin=-200, vmax=400,
                           marker="o", color="red", size=20):
    """
    Show an image with an RA, Dec grid overlaid.  Optionally add markers.

    Parameters
    ----------
    calexp : exposure-like
        Object providing ``image.array``, ``width``, ``height`` and, when
        ``ax`` is not given, ``getWcs()``.
    figsize : `tuple`, optional
        Figure size used only when a new figure is created.
    ax : matplotlib axes, optional
        Axes to draw on.  If `None`, a new figure with a WCS projection built
        from ``calexp`` is created.
    x, y : scalar or sequence, optional
        Pixel positions to mark.  Markers are drawn only if both are given.
    pixel_extent : `tuple`, optional
        ``(x_begin, x_end, y_begin, y_end)`` region to show, in array pixels.
    stamp_size : `int`, optional
        If given together with ``x`` and ``y``, show a square of this size
        centred on the first (x, y) position.
    vmin, vmax : `float`, optional
        Display range passed to ``imshow``.
    marker, color, size : optional
        Marker style, edge colour and size passed to ``scatter``.

    Notes
    -----
    If ``stamp_size`` is given together with ``x`` and ``y`` it takes
    precedence over ``pixel_extent``.

    The region is clipped to ``[0, width] x [0, height]`` so that a stamp
    hanging over an image edge is shown truncated rather than producing
    negative slice indices (which numpy would interpret as counting from the
    far edge).
    """
    if ax is None:
        plt.figure(figsize=figsize)
        plt.subplot(projection=WCS(calexp.getWcs().getFitsMetadata()))
        ax = plt.gca()

    if stamp_size is not None and x is not None and y is not None:
        half_stamp = stamp_size / 2
        # If x and y are of different types, then user should clarify what they wanted
        if np.isscalar(x):
            first_x, first_y = x, y
        else:
            first_x, first_y = x[0], y[0]

        pixel_extent = (int(first_x - half_stamp), int(first_x + half_stamp),
                        int(first_y - half_stamp), int(first_y + half_stamp))
    if pixel_extent is None:
        pixel_extent = (0, calexp.width, 0, calexp.height)

    # Clip to the image so the array slice and the `extent` handed to imshow
    # always describe the same pixels.
    x_begin, x_end, y_begin, y_end = pixel_extent
    pixel_extent = (max(0, x_begin), min(calexp.width, x_end),
                    max(0, y_begin), min(calexp.height, y_end))

    # Image array is y, x.
    # So we select from the image array in [Y_Begin:Y_End, X_Begin:X_End]
    # But then `extent` is (X_Begin, X_End, Y_Begin, Y_End)
    ax.imshow(calexp.image.array[pixel_extent[2]:pixel_extent[3],
                                 pixel_extent[0]:pixel_extent[1]],
              cmap="gray", vmin=vmin, vmax=vmax,
              extent=pixel_extent, origin="lower")
    ax.grid(color="white", ls="solid")
    ax.set_xlabel("Right Ascension")
    ax.set_ylabel("Declination")
    if x is not None and y is not None:
        ax.scatter(x, y, s=size, marker=marker, edgecolor=color, facecolor="none")
        ax.set_xlim(pixel_extent[0:2])
        ax.set_ylim(pixel_extent[2:4])

In [None]:
def show_image_with_mask_plane(calexp, figsize=(8, 8)):
    """Display ``calexp`` with its mask planes overlaid and return the display.

    A new matplotlib figure of size ``figsize`` is created and wrapped in an
    afw display.  The image is shown with an asinh/zscale stretch, mask
    transparency 80, and the DETECTED plane drawn in blue.
    """
    figure, _ = plt.subplots(figsize=figsize)
    mask_display = afwDisplay.Display(frame=figure)

    # Configure the stretch and mask rendering before the image is drawn.
    mask_display.scale('asinh', 'zscale')
    mask_display.setMaskTransparency(80)
    mask_display.setMaskPlaneColor('DETECTED', 'blue')

    mask_display.mtv(calexp)
    plt.show()

    return mask_display

In [None]:
def htm_from_ra_dec_level(ra, dec, level):
    """Return the HTM index, at subdivision ``level``, for a sky position.

    ``ra`` and ``dec`` are interpreted as degrees (they are passed to
    ``LonLat.fromDegrees``).
    """
    htm_pixelization = lsst.sphgeom.HtmPixelization(level)
    sky_position = lsst.sphgeom.LonLat.fromDegrees(ra, dec)
    direction = lsst.sphgeom.UnitVector3d(sky_position)
    return htm_pixelization.index(direction)

In [None]:
def get_dataset_refs_from_htm_list(dataset_type, htm_ids, level, aggregate="intersection"):
    """Query the butler registry for datasets overlapping a list of HTM trixels.

    Uses the notebook-level ``butler`` and ``band`` variables.

    Parameters
    ----------
    dataset_type : `str`
        Dataset type name passed to ``queryDatasets``.
    htm_ids : sequence of `int`
        HTM trixel ids, all at the same ``level``.
    level : `int`
        HTM level; selects the ``htm{level}`` query dimension.
    aggregate : {"intersection", "union"}, optional
        How to combine the per-trixel results.  Only consulted when more than
        one trixel id is given.

    Returns
    -------
    dataset_refs : `list` or `None`
        Combined dataset references (unordered).  An empty list if
        ``htm_ids`` is empty; `None` (after printing a message) if
        ``aggregate`` is not recognised.
    """
    def query_trixel(htm_id):
        # queryDatasets returns an iterator, but each query is only a few
        # hundred results, so materialise it as a set for easy combination.
        htm_kwargs = {f"htm{level}": htm_id}
        return set(butler.registry.queryDatasets(dataset_type, dataId={"band": band}, **htm_kwargs))

    if len(htm_ids) == 0:
        return []

    dataset_refs = query_trixel(htm_ids[0])

    for htm_id in htm_ids[1:]:
        matches = query_trixel(htm_id)
        if aggregate == "intersection":
            dataset_refs = dataset_refs.intersection(matches)
        elif aggregate == "union":
            dataset_refs = dataset_refs.union(matches)
        else:
            print(f"Aggregation method '{aggregate}' not supported.")
            return None

    return list(dataset_refs)

## Defining Dataset based on Site

We can run this on either DC2 or HSC by choosing appropriate RA, Dec

Currently (2023-05-26) DC2 is only available at the IDF and HSC is only available at the USDF,
so we split by site.

In [None]:
# Which computing site this notebook is running at; selects everything below.
SITE = "IDF"

# Survey whose data is served at each site.
survey_site = {
    "USDF": "HSC",
    "IDF": "DC2",
    "NERSC": "DC2",
}

# Butler repository (path or alias) at each site.
repo_site = {
    "USDF": "/repo/main",
    "IDF": "dp02",
    "NERSC": "/global/cfs/cdirs/lsst/production/gen3/DC2/Run2.2i/repo",
}

# Input collection to read from at each site.
collection_site = {
    "USDF": "HSC/runs/RC2/w_2023_15/DM-38691",
    "IDF": "2.2i/runs/DP0.2",
    "NERSC": "u/descdm/coadds_Y1_4638",
}

# (RA, Dec) used for the spatial query in each survey.
ra_dec_survey = {
    "HSC": (150, 2.5),
    "DC2": (55, -30),
}

In [None]:
# Resolve the repository and input collection for the selected site.
repo_config = repo_site[SITE]
collection = collection_site[SITE]

# Outputs go to a per-user collection named after the USER environment variable.
user = os.environ.get("USER")
output_collection = f"u/{user}/test_dia"

In [None]:
# Butler that searches the user's output collection first, then the shared
# input collection; `run` is set to the user's output collection.
butler = Butler(repo_config, run=output_collection, collections=[output_collection, collection])

In [None]:
# Do a spatial query for calexps using HTM levels following example in 04b_Intermediate_Butler_Queries.ipynb
# The sky position (degrees) is the one registered for the survey served at SITE.
ra, dec = ra_dec_survey[survey_site[SITE]]
# Band used to restrict the dataset queries.
band = "i"

In [None]:
level = 20  # the resolution of the HTM grid
# Id of the level-20 trixel that contains (ra, dec).
htm_id = htm_from_ra_dec_level(ra, dec, level)

In [None]:
# HTM is a quad-tree: the four children of a trixel have ids
# 4 * parent + {0, 1, 2, 3}.  Integer division by 4 therefore gives the id of
# the parent trixel, and the list below is htm_id together with its three
# siblings.  (Dividing by 10 instead would pick four ids that need not be
# siblings and might not even include htm_id itself.)
# `parent_level` holds the parent trixel *id*; the name is kept for
# compatibility with other cells.
parent_level = htm_id // 4
htm_ids = [parent_level * 4 + i for i in [0, 1, 2, 3]]

In [None]:
# Display the trixel ids that will be queried.
htm_ids

In [None]:
# calexps that overlap every trixel in htm_ids (default "intersection").
dataset_refs = get_dataset_refs_from_htm_list("calexp", htm_ids, level)

# Sort by visitId to get a loose time order
ids_visit = [ref.dataId["visit"] for ref in dataset_refs]
visit_order = np.argsort(ids_visit)
dataset_refs = [dataset_refs[position] for position in visit_order]

print(dataset_refs)

In [None]:
# dataset_refs is already a list, so its length can be taken directly.
print(f"Found {len(dataset_refs)} calexps")

In [None]:
# Load the "visitTable" dataset (not referenced again in the cells shown here).
visit_table = butler.get("visitTable")

We should find 140 calexps for DC2.  (RA, Dec) = (55, -30)  

We should find 44 calexps for HSC COSMOS.  (RA, Dec) = (150, +2.5)

# Build template for subtraction


Also provide a single image template based on calexp[0] in the list.

In [None]:
# First calexp in the visit-sorted list, kept as a single-image template option.
single_image_template = butler.get("calexp", dataset_refs[0].dataId)

Run subtractions with calexp[1] as science.

In [None]:
# Second calexp in the visit-sorted list is used as the science image.
science_dr = dataset_refs[1]
science = butler.get("calexp", science_dr.dataId)

### Get a template from the deepCoadd
Here we get a template from the (tract, patch) deepCoadd, reassembled to be continuous across the calexp.

In [None]:
# The sky map is one of the inputs handed to the template task below.
sky_map = butler.get("skyMap")

In [None]:
# Template-building task constructed with its default configuration.
get_template_task_config = GetTemplateConfig()
get_template_task = GetTemplateTask(config=get_template_task_config)

In [None]:
# Use a much coarser HTM level (9) than for the calexp query so that a single
# trixel around (ra, dec) is used to find the candidate deepCoadd patches.
bigger_level = 9
bigger_htm_id = htm_from_ra_dec_level(ra, dec, level=bigger_level)

# Only one trixel id is passed, so `aggregate` has nothing to combine here.
coadd_exposure_refs = get_dataset_refs_from_htm_list("deepCoadd", [bigger_htm_id], level=bigger_level, aggregate="union")
# Deferred handles are requested rather than calling butler.get on each ref.
coadd_exposure_deferred_dataset_handles = [butler.getDeferred(dr) for dr in coadd_exposure_refs]

In [None]:
# Display the list of deferred deepCoadd handles.
coadd_exposure_deferred_dataset_handles

Check that we're close to original RA, Dec

In [None]:
# Sky coordinate of the centre of the science image's bounding box.
science.getWcs().pixelToSky(science.getBBox().getCenter())

In [None]:
# Inputs for GetTemplateTask.getOverlappingExposures, all defined by the
# science image's geometry plus the candidate coadd handles.
inputs = {"coaddExposures" : coadd_exposure_deferred_dataset_handles,
          "bbox": science.getBBox(),
          "skyMap": sky_map,
          "wcs": science.getWcs(),
          "visitInfo": science.visitInfo,
         }

In [None]:
# Select the coadd exposures (and their data ids) that the task reports as
# overlapping the science image.
results = get_template_task.getOverlappingExposures(inputs)
coadd_exposures = results.coaddExposures
data_ids = results.dataIds

In [None]:
# Build the template on the science image's bounding box and WCS; the
# template exposure is read from `.template` on the result further below.
deep_coadd_template = get_template_task.run(coadd_exposures, inputs["bbox"], inputs["wcs"], data_ids)

In [None]:
# `for ce in coadd_exposures: del ce` only unbinds the loop variable `ce` on
# each pass; the list (and `results.coaddExposures`) still reference every
# exposure, so nothing is released.  Drop the references themselves instead.
# Rebinding to None (rather than `del`) keeps this cell safe to re-run.
# NOTE(review): assumes nothing else holds these exposures; in the cells shown
# they are referenced only through these two names.
coadd_exposures = None
results = None

In [None]:
template = deep_coadd_template.template

# Side-by-side view: template (left) and science image (right), each on axes
# using its own WCS projection.
figsize = (12, 6)
fig = plt.figure(figsize=figsize)
ax1 = fig.add_subplot(1, 2, 1, projection=WCS(template.getWcs().getFitsMetadata()))
# The template is displayed with a much narrower range than the science image.
show_image_on_wcs(template, vmin=-2, vmax=+4, ax=ax1)

ax2 = fig.add_subplot(1, 2, 2, projection=WCS(science.getWcs().getFitsMetadata()))
show_image_on_wcs(science, vmin=-200, vmax=+400, ax=ax2)

plt.tight_layout()

## Subtraction

In [None]:
def warp(science, template):
    """Warp input template image to WCS and Bounding Box of the science image."""
    # Default-configured warper.
    warper = Warper.fromConfig(WarperConfig())

    warped_template = warper.warpExposure(
        science.getWcs(),
        template,
        destBBox=science.getBBox(),
    )

    # Add PSF.  I think doing this directly without warping is wrong.
    # At least the x,y mapping should be updated
    warped_template.setPsf(template.getPsf())

    return warped_template


def subtract(science, template, source_catalog, task=None, config=None):
    """Warp ``template`` onto ``science`` and run an image-subtraction task.

    If ``task`` is given it is used as is and ``config`` is ignored.
    Otherwise an `AlardLuptonSubtractTask` is built from ``config`` (or from
    a default `AlardLuptonSubtractConfig` when ``config`` is also `None`).
    Returns whatever ``task.run`` returns.
    """
    # https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L196
    if task is None:
        if config is None:
            config = AlardLuptonSubtractConfig()
        task = AlardLuptonSubtractTask(config=config)
    # Star Selection is done here:
    #   https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L603

    warped_template = warp(science, template)

    return task.run(warped_template, science, source_catalog)


def detect(science, subtraction):
    """Run default-configured detection and measurement on a subtraction.

    ``subtraction`` must provide ``matchedTemplate`` and ``difference``.
    Returns the result of `DetectAndMeasureTask.run`.
    """
    task = DetectAndMeasureTask(config=DetectAndMeasureConfig())

    # Run detection on subtraction
    return task.run(
        science,
        subtraction.matchedTemplate,
        subtraction.difference,
    )

## Provide a modified makeKernel

We can inherit from and then modify methods of the MakeKernelTask to test ideas for improvements.

The kernel used by AlardLuptonSubtractTask is a configurable option.

In [None]:
# https://github.com/lsst/ip_diffim/blob/w.2023.07/python/lsst/ip/diffim/makeKernel.py#L45

class ModifiedMakeKernelConfig(MakeKernelConfig):
    """Stub inherited class to leave room for future configuration passing"""
    # If you wanted to create a new config parameter to pass to the task:
    # foo = lsst.pex.config.Field(doc="foo threshold", dtype=float, default=1.0)
    pass

In [None]:
class ModifiedMakeKernelTask(MakeKernelTask):
    """Construct a kernel for PSF matching two exposures

    This Modified class is an example showing how to create your own kernel-solving class.
    """

    ConfigClass = ModifiedMakeKernelConfig
    _DefaultName = "makeModifiedKernel"

    # This is the routine we might want to replace with our own ideas
    # about finding a good convolution kernel
    # Needs to return an lsst.afw.math.LinearCombinationKernel
    # Original
    #  https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/makeKernel.py#L108
    def run(self, template, science, kernelSources, preconvolved=False):
        """Solve for the kernel and background model that best match two
        Exposures evaluated at the given source locations.

        Parameters
        ----------
        template : `lsst.afw.image.Exposure`
            Exposure that will be convolved.
        science : `lsst.afw.image.Exposure`
            The exposure that will be matched.
        kernelSources : `list` of `dict`
            A list of dicts having a "source" and "footprint"
            field for the Sources deemed to be appropriate for Psf
            matching. Can be the output from ``selectKernelSources``.
        preconvolved : `bool`, optional
            Was the science image convolved with its own PSF?

        Returns
        -------
        results : `lsst.pipe.base.Struct`

            ``psfMatchingKernel`` : `lsst.afw.math.LinearCombinationKernel`
                Spatially varying Psf-matching kernel.
            ``backgroundModel``  : `lsst.afw.math.Function2D`
                Spatially varying background-matching function.
        """
        # Just debugging that we're really running this modified task
        self.log.info("Running Modified Make Kernel Task")

        kernelCellSet = self._buildCellSet(
            template.maskedImage, science.maskedImage, kernelSources
        )
        # Calling getPsfFwhm on template.psf fails on some rare occasions when
        # the template has no input exposures at the average position of the
        # stars. So we try getPsfFwhm first on template, and if that fails we
        # evaluate the PSF on a grid specified by fwhmExposure* fields.
        # To keep consistent definitions for PSF size on the template and
        # science images, we use the same method for both.
        try:
            templateFwhmPix = getPsfFwhm(template.psf)
            scienceFwhmPix = getPsfFwhm(science.psf)
        except InvalidParameterError:
            self.log.debug(
                "Unable to evaluate PSF at the average position. "
                "Evaluting PSF on a grid of points."
            )
            templateFwhmPix = evaluateMeanPsfFwhm(
                template,
                fwhmExposureBuffer=self.config.fwhmExposureBuffer,
                fwhmExposureGrid=self.config.fwhmExposureGrid,
            )
            scienceFwhmPix = evaluateMeanPsfFwhm(
                science,
                fwhmExposureBuffer=self.config.fwhmExposureBuffer,
                fwhmExposureGrid=self.config.fwhmExposureGrid,
            )

        # A preconvolved science image has its FWHM scaled by sqrt(2).
        if preconvolved:
            scienceFwhmPix *= np.sqrt(2)

        ### THESE LINES ARE PROBABLY WHERE YOU WANT TO CHANGE: BEGIN ###
        basisList = self.makeKernelBasisList(
            templateFwhmPix, scienceFwhmPix, metadata=self.metadata
        )
        # `spatialSolution` is returned by the solver but not used here.
        spatialSolution, psfMatchingKernel, backgroundModel = self._solve(
            kernelCellSet, basisList
        )
        ### END: THESE LINES ARE PROBABLY WHERE YOU WANT TO CHANGE

        # NOTE(review): `lsst.pipe.base` is not imported explicitly in the
        # visible import cells; this presumably relies on it being loaded by
        # another lsst import -- confirm.
        return lsst.pipe.base.Struct(
            psfMatchingKernel=psfMatchingKernel,
            backgroundModel=backgroundModel,
        )

If we want to modify the run... method of the subtraction task itself, we would subclass AlardLuptonSubtractTask and modify the `run...` method.  Here we just pick "Modified" to be generic, but if one had a specific name, that would be good too.

Try importing calculation from al-algorithm.ipynb

In [None]:
from numpy.polynomial.polynomial import polyval2d
from numpy.polynomial.chebyshev import chebval2d
from scipy.stats import multivariate_normal
import scipy.signal
from sklearn import linear_model


def compute_xy_grids(x_len, y_len):
    """Return float32 meshgrids ``(xx, yy)`` of integer pixel offsets.

    Each axis has the requested length and unit spacing, running from
    ``-length // 2 + 1`` up to and including ``length // 2``.  Both returned
    arrays have shape ``(y_len, x_len)``.
    """
    def centred_axis(length):
        # Unit-spaced coordinates for one axis.
        return np.arange(-length // 2 + 1, length // 2 + 1, 1, dtype=np.float32)

    xx, yy = np.meshgrid(centred_axis(x_len), centred_axis(y_len))
    return xx, yy


def gaussian2d(xx, yy, m=(0.0, 0.0), cov=((1, 0), (0, 1))):
    """Evaluate a 2-D Gaussian probability density on a grid.

    Parameters
    ----------
    xx, yy : array-like
        Coordinate grids of equal shape (e.g. from ``np.meshgrid``).
    m : sequence of 2 floats, optional
        Mean of the Gaussian.  Default is the origin.
    cov : 2x2 array-like, optional
        Covariance matrix.  Default is the identity.

    Returns
    -------
    pdf : `numpy.ndarray`
        Density evaluated at every (xx, yy) grid point.

    Notes
    -----
    The defaults are immutable tuples rather than lists so that no mutable
    object is shared between calls.
    """
    # Stack into (..., 2) points as expected by multivariate_normal.pdf.
    grid = np.dstack((xx, yy))
    var = multivariate_normal(mean=m, cov=cov)
    return var.pdf(grid)


def chebGauss2d(xx, yy, gauss_cov, poly_deg):
    """Return a single 2-D Chebyshev term multiplied by a Gaussian.

    ``poly_deg`` is ``(x_deg, y_deg)``; the Chebyshev series evaluated has a
    unit coefficient for that one (x_deg, y_deg) term and zeros elsewhere.
    ``gauss_cov`` is passed to `gaussian2d` as its covariance.
    """
    # compute Gaussian
    gau = gaussian2d(xx, yy, cov=gauss_cov)

    # compute Chebyshev: coefficient matrix selecting only T_xdeg(x) * T_ydeg(y)
    x_deg, y_deg = poly_deg[0], poly_deg[1]
    coefs = np.zeros((x_deg + 1, y_deg + 1))
    coefs[x_deg, y_deg] = 1
    cheb = chebval2d(xx, yy, c=coefs)

    return cheb * gau


def compute_kernel_bases(kernel_size, sig_ls, poly_deg_ls):
    """Build the list of Chebyshev-times-Gaussian kernel basis images.

    For every ordered pair of entries of ``sig_ls`` (one for x, one for y)
    and every polynomial degree from 0 up to the matching entry of
    ``poly_deg_ls``, one ``kernel_size`` x ``kernel_size`` basis image is
    produced with `chebGauss2d`.

    NOTE(review): the ``sig_ls`` values are placed directly on the diagonal
    of the covariance matrix, i.e. they act as variances -- confirm this is
    intended given the "sig" naming.
    """
    xx, yy = compute_xy_grids(kernel_size, kernel_size)
    return [
        chebGauss2d(xx, yy, [[x_sig, 0.0], [0.0, y_sig]], (x_deg, y_deg))
        for id_x, x_sig in enumerate(sig_ls)
        for id_y, y_sig in enumerate(sig_ls)
        for x_deg in range(poly_deg_ls[id_x] + 1)
        for y_deg in range(poly_deg_ls[id_y] + 1)
    ]


def compute_base_image_matrix(template, kernel_bases):
    """Convolve ``template`` with each kernel basis and stack the results.

    Returns a 2-D array with one column per basis; each column is the
    flattened ``mode="same"`` FFT convolution of the template with that basis.
    """
    convolved_rows = [
        scipy.signal.fftconvolve(template, basis, mode="same").flatten()
        for basis in kernel_bases
    ]
    # Rows were built per basis; transpose so that bases run along columns.
    return np.vstack(convolved_rows).T


def compute_spatial_image_matrix(xx, yy, spatial_deg, verbose=False):
    """Evaluate every monomial x**i * y**j with i + j <= ``spatial_deg``.

    Returns a 2-D array with one column per monomial (ordered by increasing
    ``i``, then increasing ``j``), each column being the monomial evaluated
    on the (xx, yy) grid and flattened.  With ``verbose`` the (i, j) pair of
    each term is printed.
    """
    term_rows = []
    for x_deg in range(spatial_deg + 1):
        for y_deg in range(spatial_deg - x_deg + 1):
            if verbose:
                print(x_deg, y_deg)
            # Coefficient matrix that selects the single term x**x_deg * y**y_deg.
            coefs = np.zeros((x_deg + 1, y_deg + 1))
            coefs[x_deg, y_deg] = 1
            term_rows.append(polyval2d(xx, yy, c=coefs).flatten())
    return np.vstack(term_rows).T


def compute_base_spatial_image_matrix(base_image_matrix, kernel_spatial_image_matrix):
    """Multiply every basis column by every spatial-term column.

    Returns a 2-D array whose columns are the element-wise products
    ``base[:, i] * spatial[:, j]``, ordered with ``i`` varying slowest and
    ``j`` fastest.
    """
    # Iterating over the transpose walks the columns of each input matrix.
    product_rows = [
        base_vec * kernel_spatial_vec
        for base_vec in base_image_matrix.T
        for kernel_spatial_vec in kernel_spatial_image_matrix.T
    ]
    return np.vstack(product_rows).T


def compute_X(
    template,
    kernel_size,
    gauss_sig_ls,
    poly_deg_ls,
    kernel_spatial_deg: int = 0,
    background_spatial_deg: int = 0,
):
    """Assemble the design matrix for the linear kernel fit.

    The columns are, in order: every kernel-basis-convolved template image
    multiplied by every kernel spatial polynomial term, followed by the
    background spatial polynomial terms.  Each row corresponds to one pixel
    of ``template`` (flattened).
    """
    # Spatial coordinates: centred pixel offsets scaled by the image size.
    n_rows, n_cols = template.shape[0], template.shape[1]
    xx_norm, yy_norm = compute_xy_grids(n_cols, n_rows)
    xx_norm /= n_cols
    yy_norm /= n_rows

    # Template convolved with each kernel basis function.
    bases = compute_kernel_bases(kernel_size, gauss_sig_ls, poly_deg_ls)
    convolved = compute_base_image_matrix(template, bases)

    # Polynomial terms for the kernel's and the background's spatial variation.
    kernel_terms = compute_spatial_image_matrix(
        xx_norm, yy_norm, kernel_spatial_deg, verbose=False
    )
    background_terms = compute_spatial_image_matrix(
        xx_norm, yy_norm, background_spatial_deg, verbose=False
    )

    # Kernel columns first, background columns last.
    kernel_columns = compute_base_spatial_image_matrix(convolved, kernel_terms)
    return np.concatenate((kernel_columns, background_terms), axis=1)

In [None]:
def al_python(
    science_image_array, #: np.ndarray,
    template_image_array, #: np.ndarray,
    x_stars: Sequence[float],
    y_stars: Sequence[float],
    stamp_size: int = 51,
    gauss_sig_ls: Sequence[float] = (0.75, 1.5, 3.0),
    poly_deg_ls: Sequence[int] = (2, 2, 2),
    kernel_size: int = 31,
    kernel_spatial_deg: int = 0,
    kernel_background_deg: int = 0,
) -> np.ndarray:
    """
    Implement fitting and convolution in Python (numpy, scipy, sklearn)

    science and template image arrays must already be aligned.

    Parameters
    ----------
    science_image_array, template_image_array : `numpy.ndarray`
        Aligned 2-D pixel arrays of equal shape.
    x_stars, y_stars : sequence of `float`
        x, y positions of objects to use to fit kernel.
        Currently unused: the fit is done over the whole image.
    stamp_size : `int`, optional
        Size of the stamp around each object to extract.
        In the current Science Pipelines equivalent this is done as the
        footprint of the object.  Currently unused.
    gauss_sig_ls, poly_deg_ls, kernel_size : optional
        Kernel basis definition, forwarded to `compute_X`.
    kernel_spatial_deg : `int`, optional
        Polynomial degree of the spatial variation of the kernel.
    kernel_background_deg : `int`, optional
        Polynomial degree of the spatial variation of the background.

    Returns
    -------
    sci_pred : `numpy.ndarray`
        Flattened prediction of the science image from the linear fit.
    """

    print("Making X array")
    # Pass this function's own background degree; previously a notebook-level
    # `background_spatial_deg` was read and `kernel_background_deg` was ignored.
    X = compute_X(
        template_image_array,
        kernel_size,
        gauss_sig_ls,
        poly_deg_ls,
        kernel_spatial_deg,
        kernel_background_deg,
    )

    print("Fitting")
    sci_vec = science_image_array.flatten()
    lin = linear_model.LinearRegression()
    lin.fit(X, sci_vec)

    sci_pred = lin.predict(X)

    return sci_pred

In [None]:
"""
Subtract template image from image referred to by data_id and run detection.
"""
# Reload the science calexp and fetch the "src" dataset for the same data id;
# the latter is used below as the source catalog.
science = butler.get("calexp", science_dr.dataId)
source_catalog = butler.get("src", dataId=science_dr.dataId)

In [None]:
# Choose which template to subtract: swap the comment to use the single image.
# template = single_image_template
template = deep_coadd_template.template

In [None]:
nx = science.getWidth()
ny = science.getHeight()
# nx, ny = 125, 125

# Image arrays are indexed [y, x], so the first slice must be the y (row)
# range and the second the x (column) range.  Slicing [x, y] silently
# truncates or transposes the region whenever the image is not square.
template_x0 = template.getX0()
template_y0 = template.getY0()
template_image_array = template.image.array[
    0-template_y0:ny-template_y0,
    0-template_x0:nx-template_x0,
]

science_x0 = science.getX0()
science_y0 = science.getY0()
science_image_array = science.image.array[
    0-science_y0:ny-science_y0,
    0-science_x0:nx-science_x0,
]

In [None]:
# Trying to make a block sparse matrix for the footprints
from scipy.sparse import bsr_matrix, coo_array, dok_array

In [None]:
def build_sparse_stamps(
    image: "lsst.afw.image.ExposureF",
    source_catalog: "Union[lsst.afw.table.SourceCatalog, astropy.table.Table]",
    stamp_size: int = 50,
):
    """Returns stamps as a sparse matrix

    Parameters
    ----------
    image : `lsst.afw.image.ExposureF`
        Exposure to cut stamps from.  Pixels are read from
        ``image.image.array`` and positions are offset by the image origin
        (``getX0()``, ``getY0()``).
    source_catalog : `lsst.afw.table.SourceCatalog` or `astropy.table.Table`
        Catalog providing ``slot_Centroid_x`` and ``slot_Centroid_y`` columns.
    stamp_size : `int`, optional
        Side length in pixels of the square stamp around each source.

    Returns
    -------
    stamps : `scipy.sparse.dok_array`
        Array with the shape of the image, holding the image pixels inside
        each stamp and zero elsewhere.
    x_stars, y_stars : array-like
        Centroids of the sources that were kept.

    Notes
    -----
    Sources closer than ``stamp_size`` pixels to any image edge are rejected.
    Image arrays are indexed [y, x], so rows are compared against the image
    height and columns against the width.
    """
    pixels = image.image.array
    n_rows, n_cols = pixels.shape
    x0, y0 = image.getX0(), image.getY0()

    # Reject objects too close to the edge
    x_stars = source_catalog["slot_Centroid_x"]
    y_stars = source_catalog["slot_Centroid_y"]
    (w,) = np.where(
        (stamp_size < x_stars - x0)
        & (x_stars - x0 < n_cols - stamp_size)
        & (stamp_size < y_stars - y0)
        & (y_stars - y0 < n_rows - stamp_size)
    )
    x_stars = x_stars[w]
    y_stars = y_stars[w]

    # Build the collection of stamps in a dense buffer, then convert once.
    half = stamp_size // 2
    dense_stamps = np.zeros((n_rows, n_cols), dtype=np.float32)
    for x, y in zip(x_stars, y_stars):
        row = int(y - y0) - half
        col = int(x - x0) - half
        dense_stamps[row:row + stamp_size, col:col + stamp_size] = pixels[
            row:row + stamp_size, col:col + stamp_size
        ]

    return dok_array(dense_stamps), x_stars, y_stars

### Sparse stamps

In [None]:
# Both calls use the same source catalog; x_stars and y_stars are rebound by
# the second call.
template_sparse_stamps, x_stars, y_stars = build_sparse_stamps(template, source_catalog)
science_sparse_stamps, x_stars, y_stars = build_sparse_stamps(science, source_catalog)

Fit for the difference

In [None]:
# Kernel basis definition: Gaussian widths, polynomial degree per width, and
# kernel stamp size in pixels.
gauss_sig_ls = [0.75, 1.5, 3.0]
poly_deg_ls = [2, 2, 2]
kernel_size = 31
# spatial degree of freedom
kernel_spatial_deg = 2
background_spatial_deg = 0

print("Making X array")

# NOTE(review): the call below passes literal zeros for both spatial degrees,
# so the `kernel_spatial_deg = 2` set above is not used here -- confirm intent.
# NOTE(review): `template_sparse_stamps` is a scipy sparse array, while
# compute_X convolves its input with scipy.signal.fftconvolve; confirm that a
# sparse input is supported or densify it first.
X = compute_X(template_sparse_stamps, kernel_size, gauss_sig_ls, poly_deg_ls,
    kernel_spatial_deg=0, background_spatial_deg=0)

In [None]:
# NOTE(review): `s_stamp` and `lin_fits` are only assigned in later cells
# (the stack-stamps loop), so this cell fails with a NameError when the
# notebook is run top to bottom -- presumably it should fit
# `science_sparse_stamps` instead; confirm.
sci_vec = s_stamp.flatten()
lin = linear_model.LinearRegression()
lin.fit(X, sci_vec)
lin_fits.append(lin)


# Coefficients stored in lin.coef_
# Now fit for spatial variation of these coefficients

### Stack stamps

In [None]:
def build_stack_stamps(
    image: "lsst.afw.image.ExposureF",
    source_catalog: "Union[lsst.afw.table.SourceCatalog, astropy.table.Table]",
    stamp_size: int = 50,
) -> "tuple[np.ndarray, Sequence[float], Sequence[float]]":
    """Returns stamps, x, y as a 3D array

    Parameters
    ----------
    image : `lsst.afw.image.ExposureF`
        Exposure to cut stamps from.
    source_catalog : `lsst.afw.table.SourceCatalog` or `astropy.table.Table`
        Catalog providing ``slot_Centroid_x`` and ``slot_Centroid_y`` columns.
    stamp_size : `int`, optional
        Side length in pixels of each square stamp.

    Returns
    -------
    image_stamps : `numpy.ndarray`
        float32 array of shape ``(stamp_size, stamp_size, n_sources)``; each
        ``[:, :, i]`` plane is indexed [y, x] like the image array.
    x_stars, y_stars : array-like
        Centroids of the sources that were kept.

    Notes
    -----
    Sources closer than ``stamp_size`` pixels to any image edge are rejected.
    Centroids are converted to array indices by subtracting the image origin
    (``getX0()``, ``getY0()``).
    """
    pixels = image.image.array
    image_x0 = image.getX0()
    image_y0 = image.getY0()
    width = image.getWidth()
    height = image.getHeight()

    x_stars = source_catalog["slot_Centroid_x"]
    y_stars = source_catalog["slot_Centroid_y"]
    # x is compared against the width and y against the height.
    (w,) = np.where(
        (stamp_size < x_stars - image_x0)
        & (x_stars - image_x0 < width - stamp_size)
        & (stamp_size < y_stars - image_y0)
        & (y_stars - image_y0 < height - stamp_size)
    )
    x_stars = x_stars[w]
    y_stars = y_stars[w]

    half = stamp_size // 2
    image_stamps = np.empty(
        shape=(stamp_size, stamp_size, len(x_stars)), dtype=np.float32
    )
    for i, (x, y) in enumerate(zip(x_stars, y_stars)):
        # Slice by start + size so odd stamp sizes also yield the full shape.
        row = int(y - image_y0) - half
        col = int(x - image_x0) - half
        image_stamps[:, :, i] = pixels[row:row + stamp_size, col:col + stamp_size]

    return image_stamps, x_stars, y_stars

In [None]:
# Cut stamps at the same catalog positions from the template and the science
# image; x_stars and y_stars are rebound by the second call.
template_stamps, x_stars, y_stars = build_stack_stamps(template, source_catalog)
science_stamps, x_stars, y_stars = build_stack_stamps(science, source_catalog)

In [None]:
# Kernel basis definition: stamp size, Gaussian widths, and the polynomial
# degree that goes with each width.
kernel_size = 31
gauss_sig_ls = [0.75, 1.5, 3.0]
poly_deg_ls = [2, 2, 2]
# spatial degree of freedom
kernel_spatial_deg, background_spatial_deg = 2, 0

In [None]:
# Fit one linear model per stamp pair.  The result list must exist before the
# loop appends to it; without this the first `append` raises a NameError.
lin_fits = []
for i, (x, y) in enumerate(zip(x_stars, y_stars)):
    t_stamp = template_stamps[:, :, i]
    s_stamp = science_stamps[:, :, i]
    # Design matrix from the template stamp, with no spatial variation.
    X = compute_X(t_stamp, kernel_size, gauss_sig_ls, poly_deg_ls,
        kernel_spatial_deg=0, background_spatial_deg=0)
    sci_vec = s_stamp.flatten()
    lin = linear_model.LinearRegression()
    lin.fit(X, sci_vec)
    lin_fits.append(lin)

    

In [None]:
print("Making X array")
# One fitted LinearRegression per source stamp, in catalog order.
lin_fits = []
for i, (x, y) in enumerate(zip(x_stars, y_stars)):
    t_stamp = template_stamps[:, :, i]
    s_stamp = science_stamps[:, :, i]
    # Design matrix from the template stamp; both spatial degrees are 0 here,
    # regardless of the kernel_spatial_deg variable set earlier.
    X = compute_X(t_stamp, kernel_size, gauss_sig_ls, poly_deg_ls,
        kernel_spatial_deg=0, background_spatial_deg=0)
    # Regress the flattened science stamp on the design matrix.
    sci_vec = s_stamp.flatten()
    lin = linear_model.LinearRegression()
    lin.fit(X, sci_vec)
    lin_fits.append(lin)

    

In [None]:
# Coefficients stored in lin.coef_
# Now fit for spatial variation of these coefficients


In [None]:
# List the attributes available on a fitted model.
dir(lin_fits[0])

In [None]:
# Fitted coefficients for the first stamp (one per design-matrix column).
lin_fits[0].coef_

In [None]:
# Number of design-matrix columns seen by the most recently fitted model.
lin.n_features_in_

In [None]:
# NOTE(review): `lin` and `X` are whatever the last executed fit left behind;
# after the per-stamp loop that is the final stamp, not the full image.
sci_pred = lin.predict(X)

In [None]:
# Template pixels, shown with the same range used for the template earlier.
plt.imshow(template_image_array, vmin=-2, vmax=+4)

In [None]:
# Reshape to the science array's own (rows, columns) shape; `(nx, ny)` has
# the axes in the wrong order for a [y, x]-indexed image array.
# NOTE(review): requires `sci_pred` to come from a fit over the whole image.
plt.imshow(sci_pred.reshape(science_image_array.shape), vmin=-200, vmax=+400)

In [None]:
# Residual of the fit.  Reshape to the science array's own shape; `(nx, ny)`
# has the axes in the wrong order for a [y, x]-indexed image array.
# NOTE(review): requires `sci_pred` to come from a fit over the whole image.
plt.imshow(science_image_array - sci_pred.reshape(science_image_array.shape), vmin=-200, vmax=+400)

In [None]:
class ModifiedAlardLuptonSubtractConfig(AlardLuptonSubtractConfig):
    # Placeholder subclass: adds no fields of its own and serves as the
    # ConfigClass of ModifiedAlardLuptonSubtractTask.
    pass

class ModifiedAlardLuptonSubtractTask(AlardLuptonSubtractTask):
    """Alard-Lupton subtraction with the template convolution replaced.

    Only `runConvolveTemplate` is overridden: the PSF-matched template is
    produced by a user-supplied array-level algorithm instead of the stock
    kernel convolution.
    """

    ConfigClass = ModifiedAlardLuptonSubtractConfig
    _DefaultName = "modifiedAlardLuptonSubtract"

    # `run` calls `runConvolveTemplate` if the template PSF is better
    # Let's develop this example first before covering the `runConvolveScience` function.
    def runConvolveTemplate(self, template, science, selectSources):
        """Convolve the template image with a PSF-matching kernel and subtract
        from the science image.

        Parameters
        ----------
        template : `lsst.afw.image.ExposureF`
            Template exposure, warped to match the science exposure.
        science : `lsst.afw.image.ExposureF`
            Science exposure to subtract from the template.
        selectSources : `lsst.afw.table.SourceCatalog`
            Identified sources on the science exposure. This catalog is used to
            select sources in order to perform the AL PSF matching on stamp
            images around them.

        Returns
        -------
        results : `lsst.pipe.base.Struct`

            ``difference`` : `lsst.afw.image.ExposureF`
                Result of subtracting template and science.
            ``matchedTemplate`` : `lsst.afw.image.ExposureF`
                Warped and PSF-matched template exposure.
            ``matchedScience`` : `lsst.afw.image.ExposureF`
                The science exposure, passed through unchanged.
            ``backgroundModel`` : `lsst.afw.math.Function2D`
                Background model that was fit while solving for the PSF-matching kernel
            ``psfMatchingKernel`` : `lsst.afw.math.Kernel`
                Kernel used to PSF-match the template to the science image.
        """
        # Just some quick debugging to demonstrate that we are running the modifiedAlardLuptonSubtract
        self.log.info("Running Modified Subtraction Task")
        kernelSources = self.makeKernel.selectKernelSources(template, science,
                                                            candidateList=selectSources,
                                                            preconvolved=False)
        kernelResult = self.makeKernel.run(template, science, kernelSources,
                                           preconvolved=False)

        # https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L564
#         matchedTemplate = self._convolveExposure(template, kernelResult.psfMatchingKernel,
#                                                  self.convolutionControl,
#                                                  bbox=science.getBBox(),
#                                                  psf=science.psf,
#                                                  photoCalib=science.photoCalib)

        # Get the ndarray
        # Get the same BBox because our template might have intentional padding.
        # (because we're going to convolve it)

        nx = science.getWidth()
        ny = science.getHeight()

        # Image arrays are indexed [y, x]: slice rows with the y range first,
        # then columns with the x range.
        template_x0 = template.getX0()
        template_y0 = template.getY0()
        template_array = template.image.array[0-template_y0:ny-template_y0,
                                              0-template_x0:nx-template_x0]

        science_x0 = science.getX0()
        science_y0 = science.getY0()
        science_array = science.image.array[0-science_y0:ny-science_y0,
                                            0-science_x0:nx-science_x0]

        # NOTE(review): `new_algorithm` is not defined anywhere in the cells
        # shown; it must be supplied before this method is called and is
        # expected to return the predicted science pixels (presumably something
        # like `al_python`) -- confirm.
        sci_pred = new_algorithm(science_array, template_array)

        # Create an exposure object to hold the array
        matchedTemplate = template.clone()
        matchedTemplate.setPsf(science.psf)
        matchedTemplate.setPhotoCalib(science.photoCalib)
        convolvedImage = lsst.afw.image.MaskedImageF(template.getBBox())
        # Write the pixels through the image plane's numpy view.
        convolvedImage.image.array[:, :] = sci_pred.reshape(template.image.array.shape)
        matchedTemplate.setMaskedImage(convolvedImage)

        # The actual subtraction is very simple once we have the matchedTemplate and backgroundModel
        # It's so simple that it is a free function (that we imported above)
        difference = _subtractImages(science, matchedTemplate,
                                     backgroundModel=(kernelResult.backgroundModel
                                                      if self.config.doSubtractBackground else None))

        correctedExposure = self.finalize(template, science, difference,
                                          kernelResult.psfMatchingKernel,
                                          templateMatched=True)

        return lsst.pipe.base.Struct(difference=correctedExposure,
                                     matchedTemplate=matchedTemplate,
                                     matchedScience=science,
                                     backgroundModel=kernelResult.backgroundModel,
                                     psfMatchingKernel=kernelResult.psfMatchingKernel)
    


In [None]:
# Stock subtraction task, but with its makeKernel subtask retargeted to the
# modified kernel task defined above.
subtract_config = AlardLuptonSubtractConfig()
subtract_config.makeKernel.retarget(ModifiedMakeKernelTask, ConfigClass=ModifiedMakeKernelConfig)
task = AlardLuptonSubtractTask(config=subtract_config)

In [None]:
# Modified subtraction task with its default configuration (makeKernel is
# not retargeted here).
modified_subtract_config = ModifiedAlardLuptonSubtractConfig()
modified_task = ModifiedAlardLuptonSubtractTask(config=modified_subtract_config)

In [None]:
# Display the configuration of the modified task.
modified_subtract_config

In [None]:
# subtraction = subtract(science, template, source_catalog, task=task)

In [None]:
# Run the subtraction with the modified task; the template is warped to the
# science image inside `subtract`.
modified_subtraction = subtract(science, template, source_catalog, task=modified_task)

Memory usage grows to ~15 GB and then kernel gets killed.

I don't know why we don't get log messages from the modified task.  There's more to learn about logging.

In [None]:
# Template image, with a display range suited to its much smaller pixel values.
show_image_on_wcs(template, vmin=-1, vmax=+2)

In [None]:
# Science image with the default display range.
show_image_on_wcs(science)

In [None]:
# Difference image produced by the modified subtraction task.
show_image_on_wcs(modified_subtraction.difference)

The negative regions above are saturated stars, as indicated by the masked-image view below where "green" is saturated.

Interpreting the above image plane correctly requires marking the saturated regions.  Stars brighter than ~17th mag will saturate in LSST images.  This means that the recorded counts are not proportional to the flux, so the subtraction between two images of that field will not yield clean subtractions of the stars.  In general, in one of the images the stars will be a little more saturated than in the other and so have proportionally fewer counts.  In this case for DC2, it's the template image that has slightly more saturated stars (due to a higher sky brightness or a sharper PSF FWHM).

In [None]:
# Show the mask planes of the difference image displayed above.  The only
# subtraction result assigned in this notebook is `modified_subtraction`; the
# cell that would define `subtraction` is commented out, so referencing it
# raises a NameError.
display = show_image_with_mask_plane(modified_subtraction.difference)

In [None]:
# Report the colour assigned to each mask plane on the display above, then
# show the help text for changing those colours.
print("Mask plane bit definitions:\n", display.getMaskPlaneColor())
print("\nMask plane methods:\n")
help(display.setMaskPlaneColor)

In [None]:
# Detect and measure sources on the modified subtraction's difference image.
detection_catalog = detect(science, modified_subtraction)

## DIA Source Catalog

We can get a better sense of the true performance of the image subtraction by looking at the catalog of detected and measured sources, the DIA Source Catalog.

In [None]:
# Convert the DIA source catalog to an astropy table for column-wise work.
dia_src = detection_catalog.diaSources.asAstropy()

In [None]:
# Specific list.  But "base_PixelFlags_flag" should be set for any of these
full_list_pixelflags_indicating_bad_source = [
    "base_PixelFlags_flag_saturated",
    "base_PixelFlags_flag_saturatedCenter",
    "base_PixelFlags_flag_suspect",
    "base_PixelFlags_flag_suspectCenter",
    "base_PixelFlags_flag_offimage",
    "base_PixelFlags_flag_edge",
    "base_PixelFlags_flag_bad",
]

There seem to be objects with some of the above flags set, but where "base_PixelFlags_flag" is not set.  Investigate.  This is a bug.

Apply flags to keep only the sources that the pipeline indicates might be real transients.

In [None]:
# A source is treated as bad if any one of these flag columns is set:
# the pixel flags, plus shape, Gaussian-flux and dipole-fit flags.
flags_indicating_bad_source = [
    "base_PixelFlags_flag_saturated",
    "base_PixelFlags_flag_saturatedCenter",
    "base_PixelFlags_flag_suspect",
    "base_PixelFlags_flag_suspectCenter",
    "base_PixelFlags_flag_offimage",
    "base_PixelFlags_flag_edge",
    "base_PixelFlags_flag_bad",
    "base_SdssShape_flag",
    "ip_diffim_DipoleFit_flag_classification",
    "ip_diffim_DipoleFit_flag_classificationAttempted",
    "base_GaussianFlux_flag",
    "slot_Shape_flag",
]

In [None]:
# One row per flag column, one column per source; a source is bad as soon as
# any of its flags is set.
flag_rows = np.vstack([dia_src[flag] for flag in flags_indicating_bad_source])
bad = np.any(flag_rows, axis=0)

In [None]:
# Keep only the sources with none of the bad-source flags set.
good_dia_src = dia_src[~bad]

In [None]:
# Report how many sources survive the flag cut.
print(f"Found {len(good_dia_src)} good DIA sources out of {len(dia_src)} DIA sources.")

In [None]:
import re

# Group the flag columns of the table by name pattern:
# shape-related flags, SDSS-shape flags, and slot flags respectively.
shape_flags, sdss_flags, slot_flags = (
    [c for c in good_dia_src.columns if re.search(pattern, c)]
    for pattern in ("base.*Shape.*_.*flag", "base.*Sdss.*_.*flag", "slot_.*flag")
)

In [None]:
# Pixel positions of the good sources, taken from the shape slot's x/y columns.
x = good_dia_src["slot_Shape_x"]
y = good_dia_src["slot_Shape_y"]

In [None]:
# Compare the PSF flux measured on the difference image with the forced PSF
# flux and its error.
good_dia_src[["slot_PsfFlux_instFlux", "ip_diffim_forced_PsfFlux_instFlux", "ip_diffim_forced_PsfFlux_instFluxErr"]]

In [None]:
# Inspect the first good DIA source: 100-pixel stamps of the matched template,
# science and difference images, a masked cutout, and its flag columns.
# The only subtraction result assigned in this notebook is
# `modified_subtraction`, so that is what is used throughout.
if len(good_dia_src) >= 1:
    i = 0

    show_image_on_wcs(modified_subtraction.matchedTemplate, x=x[i], y=y[i], stamp_size=100)

    show_image_on_wcs(modified_subtraction.matchedScience, x=x[i], y=y[i], stamp_size=100)

    show_image_on_wcs(modified_subtraction.difference, x=x[i], y=y[i], stamp_size=100)

    # Cutout of the difference image centred on the source's sky position;
    # coord_ra/coord_dec are passed as radians.
    center = geom.SpherePoint(good_dia_src["coord_ra"][i], good_dia_src["coord_dec"][i], geom.radians)
    extent = geom.Extent2I(100, 100)
    cutout = modified_subtraction.difference.getCutout(center, extent)

    show_image_with_mask_plane(cutout)

    # A bare expression inside an `if` block is not echoed by the notebook,
    # so print the tables explicitly.
    print(good_dia_src[i][slot_flags])

    flags = [c for c in good_dia_src.columns if re.search("_flag", c)]

    print(good_dia_src[flags])