<img align="left" src = https://lsstdesc.org/assets/img/logo.png width=250 style="padding: 10px"> 
<b>Testing New DIA Kernel Bases</b> <br>
Contact author: Michael Wood-Vasey <br>
Last verified to run: 2023-06-01 <br>
LSST Science Pipelines version: Weekly 2023_21 <br>
Container Size: large <br>
Targeted learning level: intermediate <br>

Sets up an interactive, step-by-step run of the tasks that perform image subtraction, to make it easier to modify the star selector, kernel bases, and detection.

Note: This notebook is written below the PipelineTask level.  Rather, it uses individual Tasks directly and reads/writes output products to the butler.  This is pedagogically useful to understand how that works, and practically helpful in working with the evolving `source_injection` package.  However, this structure is not scalable to larger runs (100+ images).  Such large-scale runs should be done as part of an integrated Task that can be connected and run through large-scale cluster job submission.

1. [ ] Find set of images that overlap
2. [ ] Pick one as template, one as science
3. [ ] Also deepCoadd.  Be able to use either.
4. [ ] Run subtractions through Tasks
5. [ ] Run detection and measurement through Task

In [None]:
import os

from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy.wcs import WCS
import astropy.units as u

import gc

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import lsst.afw.display as afwDisplay
from lsst.afw.image import MultibandExposure
from lsst.afw.math import Warper, WarperConfig
import lsst.geom as geom
from lsst.daf.butler import Butler, DimensionUniverse, DatasetType, CollectionType
from lsst.daf.butler.registry import MissingCollectionError
from lsst.ip.diffim import AlardLuptonSubtractConfig, AlardLuptonSubtractTask
from lsst.ip.diffim import DetectAndMeasureConfig, DetectAndMeasureTask
from lsst.pipe.tasks.makeWarp import MakeWarpConfig, MakeWarpTask
import lsst.sphgeom

In [None]:
# Render afwDisplay output through matplotlib, use matplotlib's
# 'tableau-colorblind10' style, and show figures inline in the notebook.
afwDisplay.setDefaultBackend('matplotlib')
plt.style.use('tableau-colorblind10')
%matplotlib inline

This notebook can run on either DC2 or HSC data by choosing an appropriate (RA, Dec).

Currently (2023-05-26) DC2 is only available at the IDF and HSC is only available at the USDF,
so we split by site.

In [None]:
# Where this notebook is running; every lookup below is keyed off SITE.
SITE = "IDF"

# Per-site survey name, butler repository, and input collection.
survey_site = dict(USDF="HSC", IDF="DC2")
repo_site = dict(USDF="/repo/main", IDF="dp02")
collection_site = dict(USDF="HSC/runs/RC2/w_2023_15/DM-38691", IDF="2.2i/runs/DP0.2")

# Search position (RA, Dec) in degrees for each survey.
ra_dec_survey = dict(HSC=(150, 2.5), DC2=(55, -30))

In [None]:
# Input collection and butler repository for the selected SITE.
collection, repo_config = collection_site[SITE], repo_site[SITE]

# Outputs go to a personal RUN collection named after the current user.
user = os.getenv("USER")
output_collection = f"u/{user}/test_dia"

In [None]:
# Writeable butler: new datasets go to the RUN collection `output_collection`;
# lookups search `output_collection` and then the input `collection`.
butler = Butler(repo_config, run=output_collection, collections=[output_collection, collection])

In [None]:
# Do a spatial query for calexps using HTM levels following example in 04b_Intermediate_Butler_Queries.ipynb
# Search position for the selected survey, and the band to search in.
ra, dec = ra_dec_survey[survey_site[SITE]]
band = "i"

In [None]:
def htm_from_ra_dec_level(ra, dec, level):
    """Return the HTM index of the trixel containing a sky position.

    Parameters
    ----------
    ra, dec : `float`
        Sky position in degrees.
    level : `int`
        HTM subdivision level of the pixelization.

    Returns
    -------
    htm_id : `int`
        Index of the level-``level`` trixel that contains (ra, dec).
    """
    position = lsst.sphgeom.UnitVector3d(lsst.sphgeom.LonLat.fromDegrees(ra, dec))
    return lsst.sphgeom.HtmPixelization(level).index(position)

In [None]:
level = 20  # the resolution (subdivision level) of the HTM grid
# Index of the level-`level` trixel containing (ra, dec).
htm_id = htm_from_ra_dec_level(ra, dec, level)

In [None]:
# HTM indices are hierarchical in base 4: each level appends two bits, so the
# children of trixel p are 4*p + k for k = 0..3.  Integer-divide by 4 to get
# the parent trixel (one level up), then enumerate its four children; htm_id
# is one of them.  (Despite its name, `parent_level` holds the parent's index.)
parent_level = htm_id // 4
htm_ids = [parent_level * 4 + i for i in range(4)]

In [None]:
# Display the HTM trixel IDs that will be queried.
htm_ids

In [None]:
# Seed the search with every calexp in `band` that overlaps the first trixel.
# queryDatasets returns an iterable query result; each query only yields a few
# hundred refs, so collecting them into a set is cheap and lets the next cell
# intersect with the other trixels.
hi = htm_ids[0]
htm_kwargs = {f"htm{level}": hi}
dataset_refs = set(
    butler.registry.queryDatasets("calexp", dataId={"band": band}, **htm_kwargs)
)

In [None]:
# Display the search position.
ra, dec

In [None]:
# Keep only the calexps that also overlap each remaining trixel.
for hi in htm_ids[1:]:
    htm_kwargs = {f"htm{level}": hi}
    dr = list(butler.registry.queryDatasets("calexp", dataId={"band": band}, **htm_kwargs))
    dataset_refs = dataset_refs & set(dr)

In [None]:
dataset_refs = list(dataset_refs)
# Visit of each ref, in the (unsorted) order of the list above.
ids_visit = [dr.dataId["visit"] for dr in dataset_refs]
# Sort by visit to get a loose time order, breaking ties on detector.  The
# explicit tie-break makes the order, and hence which refs are used as
# template and science below, independent of the set's iteration order.
dataset_refs = sorted(
    dataset_refs,
    key=lambda ref: (ref.dataId["visit"], ref.dataId["detector"]),
)

print(dataset_refs)

In [None]:
# dataset_refs is already a list, so its length can be taken directly.
print(f"Found {len(dataset_refs)} calexps")

In [None]:
# Load the `visitTable` dataset (presumably a per-visit summary table).
# No dataId is given, so it is not restricted to the calexps found above.
visit_table = butler.get("visitTable")

We should find 17 calexps for DC2.  (RA, Dec) = (55, -30)

We should find 44 calexps for HSC COSMOS.  (RA, Dec) = (150, +2.5)

# Run subtraction of calexp 0 (template) from calexp 1 (science).

In [None]:
# Use the first calexp in the visit-sorted list as a single-image template.
single_image_template = butler.get("calexp", dataset_refs[0].dataId)

In [None]:
def warp(science, template):
    """Resample ``template`` onto the pixel grid of ``science``.

    Parameters
    ----------
    science : `lsst.afw.image.Exposure`
        Exposure whose WCS and bounding box define the output grid.
    template : `lsst.afw.image.Exposure`
        Exposure to resample.

    Returns
    -------
    warped_template : `lsst.afw.image.Exposure`
        ``template`` warped onto the WCS and bounding box of ``science``,
        with ``template``'s PSF attached.
    """
    warper = Warper.fromConfig(WarperConfig())
    warped_template = warper.warpExposure(
        science.getWcs(), template, destBBox=science.getBBox()
    )
    # NOTE(review): the template PSF is attached as-is, i.e. still expressed
    # in the template's pixel frame rather than mapped through the
    # template->science pixel transform.  At least the x,y mapping should be
    # updated; confirm how much this matters for the kernel-basis tests.
    warped_template.setPsf(template.getPsf())
    return warped_template


def subtract(science, template, source_catalog, config=None):
    """Warp ``template`` onto ``science`` and run Alard-Lupton subtraction.

    Parameters
    ----------
    science : `lsst.afw.image.Exposure`
        Science exposure.
    template : `lsst.afw.image.Exposure`
        Template exposure; it is warped onto the WCS and bounding box of
        ``science`` with `warp` before subtraction.
    source_catalog :
        Sources for ``science`` (the ``src`` dataset in this notebook),
        passed to the task, which selects stars from them.
    config : `lsst.ip.diffim.AlardLuptonSubtractConfig`, optional
        Task configuration to use, e.g. when experimenting with kernel bases
        or star selection.  A default config is used if None.

    Returns
    -------
    subtraction :
        Result of `AlardLuptonSubtractTask.run`; this notebook uses its
        ``difference`` and ``matchedTemplate`` attributes.
    """
    # Task: https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L196
    # Star selection is done inside the task:
    #   https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L603
    # (line anchors on `main` drift as that file changes).
    if config is None:
        config = AlardLuptonSubtractConfig()
    task = AlardLuptonSubtractTask(config=config)

    warped_template = warp(science, template)

    return task.run(warped_template, science, source_catalog)


def detect(science, subtraction, config=None):
    """Run detection and measurement on the difference image.

    Parameters
    ----------
    science : `lsst.afw.image.Exposure`
        Science exposure that the template was subtracted from.
    subtraction :
        Result returned by `subtract`; its ``matchedTemplate`` and
        ``difference`` attributes are passed to the task.
    config : `lsst.ip.diffim.DetectAndMeasureConfig`, optional
        Task configuration to use, e.g. when experimenting with detection
        settings.  A default config is used if None.

    Returns
    -------
    detect_and_measure :
        Result of `DetectAndMeasureTask.run`.
    """
    if config is None:
        config = DetectAndMeasureConfig()
    detect_and_measure_task = DetectAndMeasureTask(config=config)

    return detect_and_measure_task.run(science,
                                       subtraction.matchedTemplate,
                                       subtraction.difference)

In [None]:
# Science: second calexp in visit order.  Template: the first one, loaded above.
science_dr, template = dataset_refs[1], single_image_template

In [None]:
"""
Subtract template image from image referred to by data_id and run detection.

Butler needs to be writeable to store output of subtraction and detection.
"""
science = butler.get("calexp", science_dr.dataId)
source_catalog = butler.get("src", dataId=science_dr.dataId)

subtraction = subtract(science, template, source_catalog)

detection_catalog = detect(science, subtraction)

## Some helper utilities for plotting

In [None]:
def show_image_on_wcs(calexp, figsize=(8, 8), ax=None, x=None, y=None,
                      pixel_extent=None, stamp_size=None,
                      marker="o", color="red", size=20,
                      vmin=-200.0, vmax=400):
    """
    Show an image with an RA, Dec grid overlaid.  Optionally add markers.

    Parameters
    ----------
    calexp : `lsst.afw.image.Exposure`
        Exposure to display.  Its WCS defines the RA, Dec grid when new axes
        are created.
    figsize : `tuple` [`float`, `float`], optional
        Size of the new figure.  Ignored if ``ax`` is given.
    ax : `matplotlib.axes.Axes`, optional
        Axes to draw on.  If None, a new figure with WCS axes built from
        ``calexp``'s WCS is created.  A caller-supplied ``ax`` should use a
        matching WCS projection for the grid and axis labels to be meaningful.
    x, y : `float` or array-like, optional
        Marker positions in the exposure's (parent) pixel coordinates.
    pixel_extent : `tuple` [`int`, `int`, `int`, `int`], optional
        Region to show as (x_begin, x_end, y_begin, y_end) in parent pixel
        coordinates, end-exclusive.  Defaults to the full bounding box.
    stamp_size : `int`, optional
        If given together with ``x`` and ``y``, show a square stamp of this
        size centered on the first (x, y) position.  Takes precedence over
        ``pixel_extent``.
    marker, color, size : optional
        Marker shape, edge color, and size passed to ``ax.scatter``.
    vmin, vmax : `float`, optional
        Display range of the image stretch.

    Raises
    ------
    ValueError
        Raised if the requested region does not overlap the exposure.

    Notes
    -----
    The requested region is clipped to the exposure's bounding box, so a
    stamp near an edge is truncated instead of wrapping around.
    """
    bbox = calexp.getBBox()
    if ax is None:
        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(projection=WCS(calexp.getWcs().getFitsMetadata()))

    if stamp_size is not None and x is not None and y is not None:
        half_stamp = stamp_size / 2
        # With several positions, center the stamp on the first one.
        first_x = x if np.isscalar(x) else x[0]
        first_y = y if np.isscalar(y) else y[0]
        pixel_extent = (int(first_x - half_stamp), int(first_x + half_stamp),
                        int(first_y - half_stamp), int(first_y + half_stamp))
    if pixel_extent is None:
        pixel_extent = (bbox.beginX, bbox.endX, bbox.beginY, bbox.endY)

    # Clip to the bounding box; a negative start would otherwise wrap around
    # in the numpy slice below.
    x_begin = max(int(pixel_extent[0]), bbox.beginX)
    x_end = min(int(pixel_extent[1]), bbox.endX)
    y_begin = max(int(pixel_extent[2]), bbox.beginY)
    y_end = min(int(pixel_extent[3]), bbox.endY)
    if x_begin >= x_end or y_begin >= y_end:
        raise ValueError(f"Requested region {pixel_extent} does not overlap {bbox}.")

    # The image array is indexed [y, x] relative to the exposure origin
    # (the bbox begin), while the extent above is in parent coordinates.
    x0, y0 = bbox.beginX, bbox.beginY
    stamp = calexp.image.array[y_begin - y0:y_end - y0, x_begin - x0:x_end - x0]

    # imshow's `extent` gives pixel *edges*.  Shift by half a pixel so pixel
    # centers sit on integer coordinates (imshow's own default convention is
    # (-0.5, n - 0.5)), keeping markers at (x, y) on their pixels.
    ax.imshow(stamp, cmap="gray", vmin=vmin, vmax=vmax,
              extent=(x_begin - 0.5, x_end - 0.5, y_begin - 0.5, y_end - 0.5),
              origin="lower")
    ax.grid(color="white", ls="solid")
    ax.set_xlabel("Right Ascension")
    ax.set_ylabel("Declination")
    if x is not None and y is not None:
        ax.scatter(x, y, s=size, marker=marker, edgecolor=color, facecolor="none")

In [None]:
# Science calexp with an RA, Dec grid.
show_image_on_wcs(science)

In [None]:
# Difference image produced by the subtraction.
show_image_on_wcs(subtraction.difference)

In [None]:
# List the attributes available on the subtraction result.
dir(subtraction)