<img align="left" src = https://lsstdesc.org/assets/img/logo.png width=250 style="padding: 10px"> 
<b>Testing New DIA Kernel Bases</b> <br>
Contact author: Michael Wood-Vasey <br>
Last verified to run: 2023-06-05 <br>
LSST Science Pipelines version: Weekly 2023_21 <br>
Container Size: large <br>
Targeted learning level: intermediate <br>

Sets up an interactive, step-by-step walk through the image-subtraction tasks, to make it easier to modify the StarSelector, kernel bases, and detection.

Note: This notebook is written below the PipelineTask level.  Rather, it uses individual Tasks directly and reads/writes output products to the butler.  This is pedagogically useful for understanding how that works, and practically helpful in working with the evolving `source_injection` package.  However, this structure is not scalable to larger runs (100+ images).  Such large-scale runs should be done as part of an integrated Task that can be connected and run through large-scale cluster job submission.

1. [x] Find set of images that overlap
2. [x] Pick one as template, one as science
3. [ ] Also deepCoadd.  Be able to use either.
4. [ ] Run subtractions through Tasks
5. [ ] Run detection and measurement through Task

In [None]:
import os

from astropy.coordinates import SkyCoord
from astropy.table import Table
from astropy.wcs import WCS
import astropy.units as u

import gc

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import lsst.afw.display as afwDisplay
from lsst.afw.image import MultibandExposure
from lsst.afw.math import Warper, WarperConfig
from lsst.daf.butler import Butler, DeferredDatasetHandle
import lsst.geom as geom
from lsst.ip.diffim import AlardLuptonSubtractConfig, AlardLuptonSubtractTask
from lsst.ip.diffim import GetTemplateConfig, GetTemplateTask
from lsst.ip.diffim import DetectAndMeasureConfig, DetectAndMeasureTask
from lsst.pipe.tasks.makeWarp import MakeWarpConfig, MakeWarpTask
import lsst.sphgeom

In [None]:
# Render afwDisplay output through matplotlib, using a colorblind-friendly style.
afwDisplay.setDefaultBackend('matplotlib')
plt.style.use('tableau-colorblind10')
# IPython magic: draw matplotlib figures inline in the notebook.
%matplotlib inline

## Some helper utilities

In [None]:
def show_image_on_wcs(calexp, figsize=(8, 8), ax=None, x=None, y=None,
                      pixel_extent=None, stamp_size=None,
                      vmin=-200, vmax=400,
                      marker="o", color="red", size=20):
    """
    Show an image with an RA, Dec grid overlaid.  Optionally add markers.

    Parameters
    ----------
    calexp : `lsst.afw.image.Exposure`
        Image to display.  When ``ax`` is None its WCS defines the axes
        projection used for the RA, Dec grid.
    figsize : `tuple`, optional
        Figure size used when a new figure is created (``ax`` is None).
    ax : `matplotlib.axes.Axes`, optional
        Axes to draw into.  If None, a new figure with a WCS projection is made.
    x, y : `float` or array-like, optional
        Marker positions in the image's parent pixel frame (the same frame as
        catalog centroids).
    pixel_extent : `tuple` [`int`], optional
        (X_Begin, X_End, Y_Begin, Y_End) region to show, in parent pixels.
        Defaults to the whole image.
    stamp_size : `int`, optional
        If given together with ``x`` and ``y``, show a square stamp of this
        size centered on the (first) position.
    vmin, vmax : `float`, optional
        Display stretch limits.
    marker, color, size : optional
        Marker shape, edge color, and size for the overlaid positions.

    Raises
    ------
    ValueError
        If the requested region does not overlap the image.

    Notes
    -----
    ``stamp_size`` (with ``x`` and ``y``) takes precedence over
    ``pixel_extent``.  The requested region is clipped to the image bounds, so
    a stamp near an edge is shown partially instead of being sliced with
    negative (wrapping) array indices.
    """
    if ax is None:
        plt.figure(figsize=figsize)
        plt.subplot(projection=WCS(calexp.getWcs().getFitsMetadata()))
        ax = plt.gca()

    # Positions are in the parent frame; the pixel array itself starts at xy0.
    x0, y0 = calexp.getX0(), calexp.getY0()
    width, height = calexp.width, calexp.height

    if stamp_size is not None and x is not None and y is not None:
        half_stamp = stamp_size / 2
        # For array inputs, center the stamp on the first position.
        first_x, first_y = (x, y) if np.isscalar(x) else (x[0], y[0])
        pixel_extent = (int(first_x - half_stamp), int(first_x + half_stamp),
                        int(first_y - half_stamp), int(first_y + half_stamp))
    if pixel_extent is None:
        pixel_extent = (x0, x0 + width, y0, y0 + height)

    # Clip the requested (X_Begin, X_End, Y_Begin, Y_End) region to the image.
    x_begin = max(pixel_extent[0], x0)
    x_end = min(pixel_extent[1], x0 + width)
    y_begin = max(pixel_extent[2], y0)
    y_end = min(pixel_extent[3], y0 + height)
    if x_end <= x_begin or y_end <= y_begin:
        raise ValueError(f"Requested region {pixel_extent} does not overlap the image.")

    # The image array is indexed [y, x] relative to xy0, while `extent` is
    # (X_Begin, X_End, Y_Begin, Y_End).  Pixel centers lie on integer
    # coordinates, so pixel edges are offset by half a pixel.
    ax.imshow(calexp.image.array[y_begin - y0:y_end - y0, x_begin - x0:x_end - x0],
              cmap="gray", vmin=vmin, vmax=vmax,
              extent=(x_begin - 0.5, x_end - 0.5, y_begin - 0.5, y_end - 0.5),
              origin="lower")
    ax.grid(color="white", ls="solid")
    ax.set_xlabel("Right Ascension")
    ax.set_ylabel("Declination")
    if x is not None and y is not None:
        ax.scatter(x, y, s=size, marker=marker, edgecolor=color, facecolor="none")
        # Keep the full requested window so the target stays centered.
        ax.set_xlim(pixel_extent[0:2])
        ax.set_ylim(pixel_extent[2:4])

In [None]:
def show_image_with_mask_plane(calexp, figsize=(8, 8)):
    """
    Display an image with its mask planes overlaid, using afwDisplay.

    Parameters
    ----------
    calexp : `lsst.afw.image.Exposure`
        Image (with mask) to display.
    figsize : `tuple`, optional
        Size of the matplotlib figure.

    Returns
    -------
    display : `lsst.afw.display.Display`
        The display used, so mask-plane settings can be inspected afterwards.
    """
    figure, _ = plt.subplots(figsize=figsize)
    mask_display = afwDisplay.Display(frame=figure)
    # asinh stretch with zscale limits; DETECTED pixels drawn in blue.
    mask_display.scale('asinh', 'zscale')
    mask_display.setMaskTransparency(80)
    mask_display.setMaskPlaneColor('DETECTED', 'blue')
    mask_display.mtv(calexp)
    plt.show()
    return mask_display

In [None]:
def htm_from_ra_dec_level(ra, dec, level):
    """
    Return the HTM trixel index containing a sky position.

    Parameters
    ----------
    ra, dec : `float`
        Sky position in degrees.
    level : `int`
        HTM subdivision level.

    Returns
    -------
    htm_id : `int`
        Index of the trixel at ``level`` that contains (ra, dec).
    """
    position = lsst.sphgeom.LonLat.fromDegrees(ra, dec)
    return lsst.sphgeom.HtmPixelization(level).index(lsst.sphgeom.UnitVector3d(position))

In [None]:
def get_dataset_refs_from_htm_list(dataset_type, htm_ids, level, aggregate="intersection",
                                   *, repo=None, band_name=None):
    """
    Query the butler for datasets overlapping a list of HTM trixels.

    Parameters
    ----------
    dataset_type : `str`
        Dataset type to query, e.g. "calexp" or "deepCoadd".
    htm_ids : sequence of `int`
        HTM trixel indices at ``level``.
    level : `int`
        HTM level of ``htm_ids``; used to build the ``htm{level}`` constraint.
    aggregate : {"intersection", "union"}, optional
        Keep only datasets found for every trixel ("intersection") or for any
        trixel ("union").
    repo : `lsst.daf.butler.Butler`, optional
        Butler to query.  Defaults to the notebook-level ``butler``.
    band_name : `str`, optional
        Band to restrict the query to.  Defaults to the notebook-level ``band``.

    Returns
    -------
    dataset_refs : `list`
        Matching, de-duplicated dataset refs, in the order they were first
        returned by the queries (deterministic, unlike a plain set).

    Raises
    ------
    ValueError
        If ``aggregate`` is not supported or ``htm_ids`` is empty.
    """
    if aggregate not in ("intersection", "union"):
        raise ValueError(f"Aggregation method '{aggregate}' not supported.")
    if len(htm_ids) == 0:
        raise ValueError("htm_ids must contain at least one HTM index.")

    # Fall back to the notebook globals only when no explicit value is given.
    repo = butler if repo is None else repo
    band_name = band if band_name is None else band_name

    def query(htm_id):
        # Each query is only a few hundred results, so materialize it.
        # dict.fromkeys de-duplicates while preserving query order.
        return dict.fromkeys(repo.registry.queryDatasets(
            dataset_type, dataId={"band": band_name}, **{f"htm{level}": htm_id}))

    dataset_refs = query(htm_ids[0])
    for htm_id in htm_ids[1:]:
        refs = query(htm_id)
        if aggregate == "intersection":
            dataset_refs = {ref: None for ref in dataset_refs if ref in refs}
        else:
            dataset_refs.update(refs)

    return list(dataset_refs)

## Defining Dataset based on Site

We can run this on either DC2 or HSC by choosing appropriate RA, Dec

Currently (2023-05-26) DC2 is only available at the IDF and HSC is only available at the USDF,
so we split by site.

In [None]:
# Data facility this notebook runs at: "IDF" (DC2) or "USDF" (HSC).
SITE = "IDF"

# Per-site survey name, butler repo, and input collection.
survey_site = dict(USDF="HSC", IDF="DC2")
repo_site = dict(USDF="/repo/main", IDF="dp02")
collection_site = dict(USDF="HSC/runs/RC2/w_2023_15/DM-38691", IDF="2.2i/runs/DP0.2")

# Default (RA, Dec) query position, in degrees, for each survey.
ra_dec_survey = dict(HSC=(150, 2.5), DC2=(55, -30))

In [None]:
repo_config = repo_site[SITE]
collection = collection_site[SITE]

# Personal output RUN collection, namespaced by the current user.
user = os.environ.get("USER")
output_collection = "u/{}/test_dia".format(user)

In [None]:
# Search our output collection first, then the shared input collection;
# anything put() goes into output_collection.
butler = Butler(
    repo_config,
    collections=[output_collection, collection],
    run=output_collection,
)

In [None]:
band = "i"
# Position for a spatial calexp query using HTM levels, following the example
# in 04b_Intermediate_Butler_Queries.ipynb.
ra, dec = ra_dec_survey[survey_site[SITE]]

In [None]:
# Resolution of the HTM grid used for the spatial queries.
level = 20
htm_id = htm_from_ra_dec_level(ra, dec, level=level)

In [None]:
# HTM indices add 2 bits per level: the four children of trixel P are
# 4*P + {0, 1, 2, 3}.  Query the four trixels at `level` that share htm_id's
# parent (htm_id itself is always one of them).
parent_htm_id = htm_id // 4
htm_ids = [parent_htm_id * 4 + i for i in range(4)]

In [None]:
# Inspect the HTM trixel indices used in the spatial queries below.
htm_ids

In [None]:
dataset_refs = get_dataset_refs_from_htm_list("calexp", htm_ids, level)

# Sort by visit to get a loose time order, breaking ties by detector so that the
# ordering (and hence the template/science choice below) is reproducible.
# np.argsort's default quicksort is not stable, and the refs may arrive in
# arbitrary order.
dataset_refs = sorted(dataset_refs,
                      key=lambda ref: (ref.dataId["visit"], ref.dataId["detector"]))

print(dataset_refs)

In [None]:
print("Found", len(dataset_refs), "calexps")

In [None]:
# Per-visit summary table.  NOTE(review): visit_table is not used in the cells shown here.
visit_table = butler.get("visitTable")

We should find 17 calexps for DC2 at (RA, Dec) = (55, -30).  
[2023-06-02] Wait, now we find 140 under w_2023_21?  Perhaps I haven't run this on DC2 in a while, or I changed the HTM level or something.

We should find 44 calexps for HSC COSMOS.  (RA, Dec) = (150, +2.5)

# Run subtraction between calexps 1 and 0.

In [None]:
# The first calexp in the sorted list serves as a single-visit template.
single_image_template = butler.get("calexp", dataId=dataset_refs[0].dataId)

In [None]:
science_dr = dataset_refs[1]
# The second calexp in the sorted list is the science image.
science = butler.get("calexp", dataId=science_dr.dataId)

### Get a template from the deepCoadd
Here we build a template from the deepCoadd (tract, patch) images, reassembled to be continuous across the calexp.

In [None]:
# Sky map describing the tract/patch layout of the coadds.
sky_map = butler.get("skyMap")

In [None]:
# Task used below to assemble templates from coadds for the science bbox/WCS.
get_template_task_config = GetTemplateConfig()
get_template_task = GetTemplateTask(config=get_template_task_config)

In [None]:
# Union: keep any coadd that overlaps at least one of the HTM trixels.
coadd_exposure_refs = get_dataset_refs_from_htm_list(
    "deepCoadd", htm_ids, level, aggregate="union",
)
print(len(coadd_exposure_refs))

In [None]:
# Inspect the coadd dataset refs found above.
coadd_exposure_refs

In [None]:
help(sky_map.findClosestTractPatchList)

In [None]:
# Tract/patch lookup for the query position (ra, dec).
tract_info = sky_map.findClosestTractPatchList([geom.SpherePoint(ra, dec, geom.degrees)])

In [None]:
# Presumably a (tract info, list of patch infos) pair — confirm against the skymap docs.
t, p = tract_info[0]

In [None]:
# Sequential index of the first (closest) patch.
p[0].sequential_index

In [None]:
# Unused alternative to butler.getDeferred: construct the handles directly.
# coadd_exposure_deferred_dataset_handles = [DeferredDatasetHandle(butler, dr, parameters=None) for dr in coadd_exposure_refs]

In [None]:
# Deferred handles: coadd pixels are read only when a handle is accessed.
coadd_exposure_deferred_dataset_handles = list(map(butler.getDeferred, coadd_exposure_refs))

Check that we're close to the original RA, Dec.

In [None]:
# Science-image geometry: pixel bounding box and WCS.
bbox, wcs = science.getBBox(), science.getWcs()

In [None]:
# Sky position of the science bbox center, to compare with (ra, dec).
science.getWcs().pixelToSky(bbox.getCenter())

In [None]:
# Inputs to GetTemplateTask.getOverlappingExposures, keyed by connection name.
inputs = dict(
    coaddExposures=coadd_exposure_deferred_dataset_handles,
    bbox=science.getBBox(),
    skyMap=sky_map,
    wcs=science.getWcs(),
    visitInfo=science.visitInfo,
)

In [None]:
# `data_ids` is only defined by the getOverlappingExposures cell below.  Look it
# up defensively so running the notebook top to bottom does not stop with a
# NameError here (shows None until that cell has been run).
globals().get("data_ids")

In [None]:
results = get_template_task.getOverlappingExposures(inputs)
coadd_exposures, data_ids = results.coaddExposures, results.dataIds

# Build the template for the science image's bbox and WCS.
deep_coadd_template = get_template_task.run(
    coadd_exposures, inputs["bbox"], inputs["wcs"], data_ids
)

In [None]:
# Build a template from the 3x3 block of patches centered on patch 36 of DC2
# tract 4638 (row-major numbering, so the neighbors are 36 +/- 1 and 36 +/- 7).
# These tract/patch/skymap values are specific to DC2, so fail clearly elsewhere.
if survey_site[SITE] != "DC2":
    raise RuntimeError("The hard-coded tract/patch list below is only valid for DC2 (SITE='IDF').")

patches = [28, 29, 30,
           35, 36, 37,
           42, 43, 44]
expanded_data_ids = [{"band": band, "skymap": "DC2", "tract": 4638, "patch": patch}
                     for patch in patches]
coadd_exposures = [butler.get("deepCoadd", dataId=data_id) for data_id in expanded_data_ids]

expanded_deep_coadd_template = get_template_task.run(coadd_exposures, inputs["bbox"], inputs["wcs"], expanded_data_ids)


In [None]:
# Show the 3x3-patch template with a narrow display stretch (vmin=-1, vmax=+2).
show_image_on_wcs(expanded_deep_coadd_template.template, vmin=-1, vmax=+2)

## Subtraction

In [None]:
def warp(science, template):
    """
    Resample ``template`` onto the pixel grid (WCS and bbox) of ``science``.

    Parameters
    ----------
    science : `lsst.afw.image.Exposure`
        Image whose WCS and bounding box define the output grid.
    template : `lsst.afw.image.Exposure`
        Image to resample.

    Returns
    -------
    warped_template : `lsst.afw.image.Exposure`
        The resampled template, carrying the template's (unwarped) PSF.
    """
    warper = Warper.fromConfig(WarperConfig())
    warped_template = warper.warpExposure(
        science.getWcs(), template, destBBox=science.getBBox()
    )
    # NOTE(review): the PSF is copied over without warping it, which ignores the
    # template-to-science pixel mapping; at least the x,y mapping should be updated.
    warped_template.setPsf(template.getPsf())
    return warped_template


def subtract(science, template, source_catalog):
    """
    Warp ``template`` onto ``science`` and run Alard-Lupton image subtraction.

    See https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L196

    Parameters
    ----------
    science : `lsst.afw.image.Exposure`
        Science image.
    template : `lsst.afw.image.Exposure`
        Template image; it is warped onto the science grid with `warp` first.
    source_catalog : `lsst.afw.table.SourceCatalog`
        Sources from the science image, passed to the task.

    Returns
    -------
    subtraction : `lsst.pipe.base.Struct`
        Output of `AlardLuptonSubtractTask.run`, including ``difference``,
        ``matchedTemplate`` and ``matchedScience``.
    """
    task = AlardLuptonSubtractTask(config=AlardLuptonSubtractConfig())
    # Star selection for the kernel fit is done inside the task:
    #   https://github.com/lsst/ip_diffim/blob/main/python/lsst/ip/diffim/subtractImages.py#L603
    return task.run(warp(science, template), science, source_catalog)


def detect(science, subtraction):
    """
    Run DIA source detection and measurement on a subtraction result.

    Parameters
    ----------
    science : `lsst.afw.image.Exposure`
        Science image that was subtracted.
    subtraction : `lsst.pipe.base.Struct`
        Output of `subtract`; its ``matchedTemplate`` and ``difference``
        exposures are used.

    Returns
    -------
    detect_and_measure : `lsst.pipe.base.Struct`
        Output of `DetectAndMeasureTask.run` (e.g. ``diaSources``).
    """
    task = DetectAndMeasureTask(config=DetectAndMeasureConfig())
    return task.run(science, subtraction.matchedTemplate, subtraction.difference)

In [None]:
# Choose the template to subtract: the single-visit calexp (commented out)
# or the template built from the 3x3 block of coadd patches.
# template = single_image_template
template = expanded_deep_coadd_template.template

In [None]:
# Subtract the template from the science calexp selected by science_dr, then
# run detection and measurement on the difference image.
# None of the calls in this cell write to the butler; the results are kept in
# `subtraction` and `detection_catalog`.
science = butler.get("calexp", science_dr.dataId)
source_catalog = butler.get("src", dataId=science_dr.dataId)

subtraction = subtract(science, template, source_catalog)

detection_catalog = detect(science, subtraction)

In [None]:
# Template (narrow stretch), science image, and difference image, each with an RA/Dec grid.
show_image_on_wcs(template, vmin=-1, vmax=+2)

In [None]:
show_image_on_wcs(science)

In [None]:
show_image_on_wcs(subtraction.difference)

The negative regions above are saturated stars, as indicated by the masked-image view below, where "green" is saturated.

Interpreting the above image plane correctly requires marking the saturated regions.  Stars brighter than ~17th mag will saturate in LSST images.  This means that the recorded counts are not proportional to the flux, so subtracting two images of the same field will not yield clean subtractions of these stars.  In general, the stars will be a little more saturated in one of the images than in the other, and so have fewer proportional counts.  In this case for DC2, it's the template image that has slightly more saturated stars (due to a higher sky brightness or a sharper PSF FWHM).

In [None]:
# Keep the returned Display: the next cell queries its mask-plane colors.
# NOTE: this rebinds `display`, shadowing IPython's display() in this notebook.
display = show_image_with_mask_plane(subtraction.difference)

In [None]:
# Requires `display` to be the afwDisplay Display returned by show_image_with_mask_plane.
# NOTE(review): getMaskPlaneColor() presumably reports plane colors rather than bit definitions — confirm.
print("Mask plane bit definitions:\n", display.getMaskPlaneColor())
print("\nMask plane methods:\n")
help(display.setMaskPlaneColor)

## DIA Source Catalog

We can get a better sense of the true performance of the image subtraction by looking at the catalog of detected and measured sources, the DIA Source Catalog.

In [None]:
# Convert the DIA source catalog to an astropy Table for column-based selection.
dia_src = detection_catalog.diaSources.asAstropy()

In [None]:
# Specific list.  But "base_PixelFlags_flag" should be set for any of these
# Specific pixel-flag list.  "base_PixelFlags_flag" should be set whenever any of these is.
full_list_pixelflags_indicating_bad_source = [
    "base_PixelFlags_flag_saturated",
    "base_PixelFlags_flag_saturatedCenter",
    "base_PixelFlags_flag_suspect",
    "base_PixelFlags_flag_suspectCenter",
    "base_PixelFlags_flag_offimage",
    "base_PixelFlags_flag_edge",
    "base_PixelFlags_flag_bad",
]

There seem to be objects with some of the above flags set but with "base_PixelFlags_flag" not set.  This looks like a bug; investigate.

Reject sources with any of the following flags set; the pipeline uses these flags to mark detections that are unlikely to be real transients.

In [None]:
# A source is rejected if any of these flags is set.
flags_indicating_bad_source = [
    # Pixel-level problems at or near the source.
    "base_PixelFlags_flag_saturated",
    "base_PixelFlags_flag_saturatedCenter",
    "base_PixelFlags_flag_suspect",
    "base_PixelFlags_flag_suspectCenter",
    "base_PixelFlags_flag_offimage",
    "base_PixelFlags_flag_edge",
    "base_PixelFlags_flag_bad",
    # Measurement failures.
    "base_SdssShape_flag",
    # Dipole classification.  NOTE(review): rejecting on "classificationAttempted"
    # may drop every source for which classification was merely attempted — confirm intent.
    "ip_diffim_DipoleFit_flag_classification",
    "ip_diffim_DipoleFit_flag_classificationAttempted",
    "base_GaussianFlux_flag",
    "slot_Shape_flag",
]

In [None]:
# Row-wise OR over the flag columns: True where any rejection flag is set.
bad = np.any(np.vstack([dia_src[flag] for flag in flags_indicating_bad_source]), axis=0)

In [None]:
# Keep only the DIA sources with none of the rejection flags set.
good_dia_src = dia_src[~bad]

In [None]:
print(f"Found {len(good_dia_src)} good DIA sources.")

In [None]:
import re

# Flag-column names grouped by measurement family (regex search on column names).
shape_flags, sdss_flags, slot_flags = (
    [column for column in good_dia_src.columns if re.search(pattern, column)]
    for pattern in ("base.*Shape.*_.*flag", "base.*Sdss.*_.*flag", "slot_.*flag")
)

In [None]:
# Pixel positions of the good DIA sources, used to center stamps below.
x, y = good_dia_src["slot_Shape_x"], good_dia_src["slot_Shape_y"]

In [None]:
# NOTE(review): unexplained hard-coded counts; record where these numbers come from or remove.
print(3810/171751)
print(5876/549468)

In [None]:
# Compare the DIA-source PSF flux with the forced PSF flux measurement and its error.
good_dia_src[["slot_PsfFlux_instFlux", "ip_diffim_forced_PsfFlux_instFlux", "ip_diffim_forced_PsfFlux_instFluxErr"]]

In [None]:
# Row index of the good DIA source to inspect in the following cells.
i = 1

In [None]:
# 100-pixel stamps centered on source i: matched template, matched science, difference.
show_image_on_wcs(subtraction.matchedTemplate, x=x[i], y=y[i], stamp_size=100)

In [None]:
show_image_on_wcs(subtraction.matchedScience, x=x[i], y=y[i], stamp_size=100)

In [None]:
show_image_on_wcs(subtraction.difference, x=x[i], y=y[i], stamp_size=100)

In [None]:
geom.Extent2I(100, 100)

In [None]:
# 100x100-pixel cutout of the difference image around source i (coords in radians).
extent = geom.Extent2I(100, 100)
center = geom.SpherePoint(good_dia_src["coord_ra"][i], good_dia_src["coord_dec"][i], geom.radians)
cutout = subtraction.difference.getCutout(center, extent)

In [None]:
# Mask-plane view of the cutout around source i.
show_image_with_mask_plane(cutout)

In [None]:
# slot_* flag values for source i.
good_dia_src[i][slot_flags]

In [None]:
# All columns whose name contains "_flag".
flags = [c for c in good_dia_src.columns if re.search("_flag", c)]

In [None]:
good_dia_src[flags]

In [None]:
dia_src.columns

In [None]:
len(dia_src.columns)