# DESC Lamp demo 
## Catalog extraction
This notebook shows how to pull postage stamps to use for strong gravitational lens searches.

# DC2: Generate Postage Stamps (Cutouts) for objects in the Object Catalog

Owner: **Rémy Joseph** ([@herjy](https://github.com/herjy/DESC-Lamp))
<br>Last Verified to Run: **2021-11-22** (by @herjy)

This notebook is partly based on the `dm_butler_postage_stamps_for_object_catalogs` notebook by Yao-Yuan Mao, on the earlier `dm_butler_postage_stamps` notebook by Michael Wood-Vasey, and on the Stack Club `ButlerTutorial` by Daniel Perrefort.

Light curve extraction follows the `dia_sn_vs_truth` notebook by Michael Wood-Vasey.

Here we simply copy what was in Yao-Yuan's notebook and trim it to the use case of selecting galaxies from the DC2 object catalogs with magnitude cuts. The notebook will eventually evolve to incorporate the functions we design to streamline this preselection process.

### Logistics
This is intended to be runnable at NERSC through the https://jupyter.nersc.gov interface from a local git clone of https://github.com/herjy/DESC-Lamp in your NERSC directory. But you can also run it elsewhere, with appropriate adjustment of the 'repo' location to point to a place where you have a Butler repo with all of the images.

This notebook uses the `desc-stack-weekly-latest` kernel. Instructions for setting up the proper DESC python kernel can be found here: https://confluence.slac.stanford.edu/x/o5DVE

## Set up

First we will load the needed modules and DC2 DR6 data sets: object catalogs (with `GCRCatalogs`) and DRP products (with `desc_dc2_dm_data`).

In [10]:
import desclamp
# A few common packages
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
%matplotlib inline

# We will use astropy's WCS and ZScaleInterval for plotting
from astropy.wcs import WCS
from astropy.visualization import ZScaleInterval
# Also to convert sky coordinates
from astropy.coordinates import SkyCoord
import astropy.units as u

# We will use several stack functions
import lsst.geom
import lsst.afw.display as afwDisplay
import lsst.afw.display.rgb as rgb

# And also DESC packages to get the data path
import GCRCatalogs
from GCRCatalogs import GCRQuery
import desc_dc2_dm_data


We will be using the DC2 Run 2.2i DR6 v2 data. The catalogs and their validation are described here: https://arxiv.org/pdf/2110.03769.pdf


In [3]:
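dc2_data_version = "2.2i_dr6"
# List the available catalogs matching this data version
print(GCRCatalogs.get_available_catalogs(names_only=True, name_contains=dc2_data_version))
# DR6 object catalog
cat = GCRCatalogs.load_catalog("dc2_object_run" + dc2_data_version)
# DIA source and DIA object catalogs (Run 1.2p test release), used for light curves
diaSrc = GCRCatalogs.load_catalog('dc2_dia_source_run1.2p_test')
diaObject = GCRCatalogs.load_catalog('dc2_dia_object_run1.2p_test')
# Truth summary for the Run 2.2i supernovae
truth_lc = GCRCatalogs.load_catalog('dc2_truth_run2.2i_sn_truth_summary')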


All the catalog names (including old catalogs) can be listed by uncommenting the line in the next cell.
Information about these catalogs can be found in the [DC2 data product overview](https://confluence.slac.stanford.edu/display/LSSTDESC/DC2+Data+Product+Overview). The Rubin project's [Data Products Definition Document](https://docushare.lsstcorp.org/docushare/dsweb/Get/LSE-163/LSE-163_DataProductsDefinitionDocumentDPDD.pdf) (DPDD) provides further insight into the content of these catalogs.

These are [GCR catalogs](https://github.com/LSSTDESC/gcr-catalogs) that may use slightly different definitions and naming conventions than those used in Rubin's DPDD. The entries of these catalogs are described in the [GCR Catalogs schema description](https://github.com/LSSTDESC/gcr-catalogs/blob/master/GCRCatalogs/SCHEMA.md#schema-for-dc2-object-catalogs).

In [4]:
#print('\n'.join(sorted(GCRCatalogs.get_available_catalogs(include_default_only=False))))

Summoning Alfred (the data Butler):

In [5]:
butler = desc_dc2_dm_data.get_butler(dc2_data_version)

## Select a sample of galaxies based on selection criteria

Here we use somewhat arbitrary selection criteria (taken from [Rojas et al. 2021](https://arxiv.org/pdf/2109.00014.pdf)) to extract a few mock patches.

To learn what columns are in the object catalogs, refer to [this schema table](https://github.com/LSSTDESC/gcr-catalogs/blob/master/GCRCatalogs/SCHEMA.md#schema-for-dc2-object-catalogs). And sometimes it'd be helpful to look at the [source code](https://github.com/LSSTDESC/gcr-catalogs/blob/master/GCRCatalogs/dc2_object.py#L341).
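The query in the next cell is a conjunction of simple inequalities on the g, r and i cModel magnitudes, their colours, the per-band signal-to-noise ratios and the extendedness flag. As a minimal sketch, assuming these quantities are already in a pandas DataFrame with the GCR column names, the same cuts can be written as a boolean mask. This can be handy for tweaking cuts on a table you have already loaded without re-querying. The GCR `clean` flag is left out, and `lens_candidate_mask` is our own hypothetical helper, not part of GCRCatalogs.

```python
import pandas as pd


def lens_candidate_mask(df: pd.DataFrame) -> pd.Series:
    """Boolean mask of rows passing the colour, magnitude, S/N and extendedness cuts.

    Hypothetical helper mirroring `bright_galaxy_query` (without the `clean` flag).
    """
    g, r, i = df["mag_g_cModel"], df["mag_r_cModel"], df["mag_i_cModel"]
    mask = (
        (df["extendedness"] == 1)
        & (g - i > 1.8) & (g - i < 5)      # g - i colour window
        & (g - r > 0.6) & (g - r < 3)      # g - r colour window
        & (r > 18) & (r < 22.5)            # r-band magnitude range
        & (g > 20) & (i > 18.2)            # exclude very bright g- and i-band objects
    )
    for band in "gri":
        mask &= df[f"snr_{band}_cModel"] > 10  # S/N > 10 in each band
    return mask
```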

In [6]:
bright_galaxy_query = GCRQuery(
    "clean",
    "extendedness == 1",
    "mag_g_cModel - mag_i_cModel < 5",
    "mag_g_cModel - mag_i_cModel > 1.8",
    "mag_g_cModel - mag_r_cModel < 3",
    "mag_g_cModel - mag_r_cModel > 0.6",
    "mag_r_cModel < 22.5",
    "mag_r_cModel > 18",
    "mag_g_cModel > 20",
    "mag_i_cModel > 18.2",
    "snr_g_cModel > 10",
    "snr_r_cModel > 10",
    "snr_i_cModel > 10",
)

columns_to_get = ["objectId", "ra", "dec", "tract", "patch"]
assert cat.has_quantities(columns_to_get)
# Here we use native_filters to limit the query to tract 4639 to save some load time
# Tracts that can be added to the list: 4430, 4431, 4432, 4638, 4639, 4640
tracts = [4639]
# Build a native filter such as "(tract == 4639) | (tract == 4640)"
filters = " | ".join(f"(tract == {t})" for t in tracts)

objects = cat.get_quantities(columns_to_get, filters=bright_galaxy_query, native_filters=filters)




In [7]:
print(len(objects['tract']))

9617


In [8]:
# make it a pandas data frame for the ease of manipulation
objects = pd.DataFrame(objects)

## Extracting postage stamps

Now we need to extract postage stamps from the coadded images. For that we need, for each object, its coordinates but also its tract and patch numbers.

In [None]:
skymap = butler.get('deepCoadd_skyMap')

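Now, for each selected object, we use the skymap to convert its (RA, Dec) position into pixel coordinates in its tract, define a bounding box around it, and fetch the i, r and g cutouts from the Butler to build an RGB image: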

In [None]:
fig = plt.figure(figsize=(36, 36), dpi=100)
gs = plt.GridSpec(6, 6, fig)

# Limit to the first 16 objects (unlike .loc, .iloc excludes the end position)
objects_sel = objects.iloc[:16]

cutout_size = 100
cutout_extent = lsst.geom.ExtentI(cutout_size, cutout_size)
for (_, object_this), gs_this in zip(objects_sel.iterrows(), gs):
    # Pixel position of the object in the WCS of its tract
    radec = lsst.geom.SpherePoint(object_this["ra"], object_this["dec"], lsst.geom.degrees)
    center = skymap.findTract(radec).getWcs().skyToPixel(radec)
    # cutout_size x cutout_size box centred on the object
    bbox = lsst.geom.BoxI(lsst.geom.Point2I((center.x - cutout_size*0.5, center.y - cutout_size*0.5)), cutout_extent)

    # Coadd cutouts in the i, r and g bands (combined into an RGB image below)
    cutouts = [butler.get("deepCoadd_sub", bbox=bbox, tract=object_this["tract"], patch=object_this["patch"], filter=band) for band in "irg"]
    wcs_fits_meta = cutouts[0].getWcs().getFitsMetadata()
    image_rgb = rgb.makeRGB(*cutouts)
    del cutouts  # let gc save some memory for us

    ax = plt.subplot(gs_this, projection=WCS(wcs_fits_meta), label=str(object_this["objectId"]))
    ax.imshow(image_rgb, origin='lower')
    del image_rgb  # let gc save some memory for us
    
    for c in ax.coords:
        c.set_ticklabel(exclude_overlapping=True, size=10)
        c.set_axislabel('', size=0)
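The bounding box in the loop above is the object's pixel position in its tract, shifted by half the cutout size. As a minimal pure-Python sketch of that arithmetic (no stack required), we assume the fractional corner is floored to an integer pixel, which may differ from how `lsst.geom.Point2I` converts a floating-point position, and we report inclusive maximum bounds, which we believe matches integer `lsst.geom` boxes. `cutout_bounds` is a hypothetical helper of ours.

```python
import math
from typing import Tuple


def cutout_bounds(center_x: float, center_y: float, size: int) -> Tuple[int, int, int, int]:
    """Inclusive integer pixel bounds (x_min, y_min, x_max, y_max) of a size x size cutout.

    The lower-left corner is the centre shifted by half the cutout size and floored
    to an integer pixel; the box then spans `size` pixels along each axis.
    """
    x_min = math.floor(center_x - 0.5 * size)
    y_min = math.floor(center_y - 0.5 * size)
    return x_min, y_min, x_min + size - 1, y_min + size - 1
```

For a 100-pixel cutout centred on (1234.7, 567.2), this gives a box from (1184, 517) to (1283, 616).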


## Extracting light curves

To search for lensed transients, it is useful to have access to light curves. These are extracted from Difference Imaging Analysis (DIA). We used the scripts provided in the `dia_sn_vs_truth` notebook by Michael Wood-Vasey to produce the following light curves.

First, we need to match the position of objects to sources in the DIA catalogs:

In [None]:
# Match on RA, Dec
# Load the DIA object positions once, outside of any loop
diaObjects_cat = diaObject.get_quantities(['ra', 'dec', 'diaObjectId'])
diaObject_positions = SkyCoord(diaObjects_cat['ra'], diaObjects_cat['dec'], unit='deg')

# Positions of the selected galaxies
object_positions = SkyCoord(objects_sel['ra'].values, objects_sel['dec'].values, unit='deg')

# For each galaxy, find the nearest DIA object on the sky
idx, sep2d, _ = object_positions.match_to_catalog_sky(diaObject_positions)

for i, sep in zip(idx, sep2d.to(u.arcsec)):
    print(f'Index: {i} is {sep:0.6f} away')
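`match_to_catalog_sky` returns, for each position, the index of and separation to the nearest entry of the other catalog. The brute-force numpy sketch below spells out that computation with the haversine formula and adds a maximum separation (a hypothetical 1 arcsec by default) beyond which a galaxy is treated as unmatched. The helper names are ours, and astropy performs the same nearest-neighbour search with a KD-tree, which scales much better than the full separation matrix built here.

```python
import numpy as np


def angular_separation_deg(ra1, dec1, ra2, dec2):
    """Great-circle separation in degrees via the haversine formula (inputs in degrees)."""
    ra1, dec1, ra2, dec2 = (np.radians(np.asarray(x, dtype=float)) for x in (ra1, dec1, ra2, dec2))
    hav = (np.sin((dec2 - dec1) / 2) ** 2
           + np.cos(dec1) * np.cos(dec2) * np.sin((ra2 - ra1) / 2) ** 2)
    return np.degrees(2 * np.arcsin(np.sqrt(np.clip(hav, 0.0, 1.0))))


def nearest_match(ra, dec, cat_ra, cat_dec, max_sep_arcsec=1.0):
    """Nearest catalog entry for each position.

    Returns (idx, sep_arcsec); idx is -1 where the nearest entry is farther than max_sep_arcsec.
    """
    ra, dec = np.atleast_1d(ra).astype(float), np.atleast_1d(dec).astype(float)
    cat_ra, cat_dec = np.asarray(cat_ra, dtype=float), np.asarray(cat_dec, dtype=float)
    # Full (n_positions, n_catalog) separation matrix: simple, but O(n * m) in memory
    sep = angular_separation_deg(ra[:, None], dec[:, None], cat_ra[None, :], cat_dec[None, :]) * 3600.0
    idx = np.argmin(sep, axis=1)
    best = sep[np.arange(len(ra)), idx]
    return np.where(best <= max_sep_arcsec, idx, -1), best
```

With the notebook's variables this would be called as `nearest_match(objects_sel['ra'], objects_sel['dec'], diaObjects_cat['ra'], diaObjects_cat['dec'])`; for 16 galaxies the separation matrix stays small even for a large DIA catalog.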

In [None]:
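diaObjects_df = pd.DataFrame(diaObjects_cat)
# Limit to the first 36 DIA objects, one per panel of a 6x6 grid
dia_sel = diaObjects_df.iloc[:36]

# Start a new figure rather than reusing the grid of the previous one
fig = plt.figure(figsize=(36, 36), dpi=100)
gs = plt.GridSpec(6, 6, fig)

cutout_size = 100
cutout_extent = lsst.geom.ExtentI(cutout_size, cutout_size)
for (_, object_this), gs_this in zip(dia_sel.iterrows(), gs):
    radec = lsst.geom.SpherePoint(object_this["ra"], object_this["dec"], lsst.geom.degrees)
    # Look up the tract and patch containing this position instead of hard-coding them
    tract_info = skymap.findTract(radec)
    patch = "%d,%d" % tuple(tract_info.findPatch(radec).getIndex())
    center = tract_info.getWcs().skyToPixel(radec)
    bbox = lsst.geom.BoxI(lsst.geom.Point2I((center.x - cutout_size*0.5, center.y - cutout_size*0.5)), cutout_extent)

    cutouts = [butler.get("deepCoadd_sub", bbox=bbox, tract=tract_info.getId(), patch=patch, filter=band) for band in "irg"]
    wcs_fits_meta = cutouts[0].getWcs().getFitsMetadata()
    image_rgb = rgb.makeRGB(*cutouts)
    del cutouts  # let gc save some memory for us

    # DIA objects are identified by diaObjectId (there is no objectId column)
    ax = plt.subplot(gs_this, projection=WCS(wcs_fits_meta), label=str(object_this["diaObjectId"]))
    ax.imshow(image_rgb, origin='lower')
    del image_rgb  # let gc save some memory for us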
    
    for c in ax.coords:
        c.set_ticklabel(exclude_overlapping=True, size=10)
        c.set_axislabel('', size=0)
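Both cutout cells pass the i, r and g cutouts to `rgb.makeRGB`, which, as far as we know, applies an asinh stretch in the spirit of Lupton et al. (2004) and takes its images in (red, green, blue) order, so the i band is shown as red and the g band as blue. The numpy sketch below illustrates that kind of stretch. It is our own simplified version, not the stack implementation: the parameter names only echo `makeRGB`'s and the defaults are assumptions.

```python
import numpy as np


def asinh_rgb(red, green, blue, minimum=0.0, stretch=5.0, Q=8.0):
    """Map three same-shape images to an (ny, nx, 3) uint8 array with an asinh stretch.

    The mean intensity I = (R + G + B) / 3 is stretched as asinh(Q * I / stretch) / Q and
    every channel is multiplied by the same factor f(I) / I, which preserves colour ratios.
    """
    # Subtract the black level from each channel and stack them along the last axis
    channels = np.stack([np.asarray(c, dtype=float) - minimum for c in (red, green, blue)], axis=-1)
    intensity = channels.mean(axis=-1)
    # f(I) / I, using its limit 1 / stretch where I == 0 to avoid dividing by zero
    safe = np.where(intensity == 0, 1.0, intensity)
    scale = np.where(intensity == 0, 1.0 / stretch, np.arcsinh(Q * safe / stretch) / (Q * safe))
    scaled = np.clip(channels * scale[..., None], 0.0, None)
    # Where a channel exceeds 1, divide all three by the brightest one to keep the hue
    peak = np.maximum(scaled.max(axis=-1, keepdims=True), 1.0)
    return np.round(255 * scaled / peak).astype(np.uint8)
```

Sharing one scale factor across the three channels is what keeps the colours of bright galaxy cores from washing out to white, which matters when looking for blue arcs around red lens galaxies.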


In [None]:
def plot_lightcurve(df, plot='mag', flux_col_names=None,
                    title=None, marker='o', linestyle='none',
                    colors=None, label_prefix='',
                    **kwargs):
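    """Plot a lightcurve from a DataFrame.

    The DataFrame must have 'mjd' and 'filter' columns plus the flux (or magnitude)
    column and, optionally, its uncertainty column.
    """
    # Plot in wavelength order rather than lexicographical order.
    # Assume a fixed list of filters.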
    filter_order = ['u', 'g', 'r', 'i', 'z', 'y']

    if colors is None:
        colors = {'u': 'violet', 'g': 'indigo', 'r': 'blue', 'i': 'green', 'z': 'orange', 'y': 'red'}
    
    if flux_col_names is not None:
        flux_col, flux_err_col = flux_col_names
    else:
        if plot == 'flux':
            flux_col = 'psFlux'
            flux_err_col = 'psFluxErr'
        else:
            flux_col = 'mag'
            flux_err_col = 'mag_err'
        
    for filt in filter_order:
        this_filter = df.query(f'filter == "{filt}"')
        if this_filter.empty:
            continue
        # Default style for this filter; any extra **kwargs override it.
        plot_kwargs = {'linestyle': linestyle, 'marker': marker, 'color': colors[filt],
                       'label': f'{label_prefix} {filt}'.strip()}
        plot_kwargs.update(kwargs)

        if flux_err_col in this_filter.columns:
            plt.errorbar(this_filter['mjd'], this_filter[flux_col], this_filter[flux_err_col],
                         **plot_kwargs)
                        
        else:
            if marker is None:
                plt.plot(this_filter['mjd'], this_filter[flux_col], **plot_kwargs)

            else:
                plot_kwargs.pop('linestyle')
                plt.scatter(this_filter['mjd'], this_filter[flux_col], **plot_kwargs)



    plt.xlabel('MJD')

    if plot == 'flux':
        plt.ylabel('psFlux [nJy]')
    else:
        # Ensure that y-axis decreases as one goes up
        # Because plot_lightcurve could be called several times on the same axis,
        # simply inverting is not correct.  We have to reverse a sorted list.
        plt.ylim(sorted(plt.ylim(), reverse=True))
        plt.ylabel('mag [AB]')

    if title is not None:
        plt.title(title)
    plt.legend()
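With `plot='mag'`, `plot_lightcurve` expects `mag` and `mag_err` columns, while DIA sources carry `psFlux` and `psFluxErr` in nJy (as the function's flux-axis label indicates). A minimal sketch for adding magnitude columns, assuming fluxes in nJy: the AB zero point for nJy is about 31.4 mag, since 3631 Jy corresponds to m_AB = 0. Difference-image fluxes can be zero or negative, so those rows get NaN magnitudes here. `add_mag_columns` is a hypothetical helper, not part of the DESC or LSST code.

```python
import numpy as np
import pandas as pd

# AB magnitude of a 1 nJy source: m = -2.5 log10(f / nJy) + 31.4
AB_ZP_NJY = 31.4


def add_mag_columns(df: pd.DataFrame, flux_col="psFlux", flux_err_col="psFluxErr") -> pd.DataFrame:
    """Return a copy of df with 'mag' and 'mag_err' columns computed from fluxes in nJy."""
    out = df.copy()
    flux = out[flux_col].to_numpy(dtype=float)
    flux_err = out[flux_err_col].to_numpy(dtype=float)
    positive = flux > 0
    safe = np.where(positive, flux, 1.0)  # avoid log10 of non-positive fluxes
    out["mag"] = np.where(positive, -2.5 * np.log10(safe) + AB_ZP_NJY, np.nan)
    # First-order error propagation: sigma_m = 2.5 / ln(10) * sigma_f / f
    out["mag_err"] = np.where(positive, 2.5 / np.log(10) * flux_err / safe, np.nan)
    return out
```

A light-curve table `lc` with `mjd`, `filter`, `psFlux` and `psFluxErr` columns could then be shown in magnitudes with `plot_lightcurve(add_mag_columns(lc), plot='mag')`.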