# Let's try source injection in Gen 3

In [None]:
# Generic python packages
import warnings
import numpy as np
import pylab as plt

# Set a standard figure size to use
plt.rcParams['figure.figsize'] = (8.0, 8.0)

# LSST Science Pipelines (Stack) packages
import lsst.daf.butler as dafButler
import lsst.afw.display as afwDisplay
from astropy.visualization import make_lupton_rgb
from lsst.afw.image import MultibandExposure
import gc                            # imports python's garbage collector

import lsst.geom as geom
import lsst.afw.coord as afwCoord
import lsst.geom as afwGeom
afwDisplay.setDefaultBackend('matplotlib')

from desclamp import postage

from tqdm.notebook import tqdm
import desclamp as lamp


In [None]:
def createRGB(image, bgr="gri", stretch=1, Q=10, scale=None):
    """
    Create an RGB color composite image.

    Parameters
    ----------
    image : `MultibandImage`
        Multiband image to display. ``image[band].array`` must return a
        NumPy array for each band used, so pass e.g. the ``.image`` of a
        `MultibandExposure` rather than the exposure itself.
    bgr : sequence
        A 3-element sequence of band names (keys of ``image``) giving the
        band to use for the blue, green and red channels, in that order.
        If `image` has exactly three bands this parameter is ignored and
        ``image.filters`` is used instead.
    stretch : float
        The linear stretch of the image, passed to `make_lupton_rgb`.
    Q : float
        The asinh softening parameter, passed to `make_lupton_rgb`.
    scale : sequence of 3 floats, optional
        Multiplicative factors applied to the red, green and blue channels,
        in that order (i.e. the reverse of ``bgr``). ``None`` (default)
        leaves the channels unscaled.

    Returns
    -------
    rgb : ndarray
        RGB (8 bits per channel) colour image as an NxMx3 numpy array, as
        returned by `make_lupton_rgb`.
    """
    # With exactly three bands, the channel assignment is taken from the
    # image's own band order (assumed to run blue-to-red -- TODO confirm).
    # Nothing is reversed here; the caller's ``bgr`` is simply overridden.
    if len(image) == 3:
        bgr = image.filters

    # ``bgr`` is blue-first, so the red channel comes from its last entry.
    r_im = image[bgr[2]].array
    g_im = image[bgr[1]].array
    b_im = image[bgr[0]].array

    if scale is not None:
        # Manually re-scale each channel; ``scale`` is in R, G, B order.
        r_im = r_im * scale[0]
        g_im = g_im * scale[1]
        b_im = b_im * scale[2]

    # ``stretch`` and ``Q`` control the asinh stretch of the pixel values.
    return make_lupton_rgb(image_r=r_im,
                           image_g=g_im,
                           image_b=b_im,
                           stretch=stretch, Q=Q)



def cutout_coadd(butler, ra, dec, band='r', datasetType='deepCoadd',
                 skymap=None, cutoutSideLength=51, **kwargs):
    """
    Produce a cutout from a coadd at the given ra, dec position.

    Adapted from DC2 tutorial notebook by Michael Wood-Vasey.

    Parameters
    ----------
    butler: lsst.daf.butler.Butler
        Gen 3 Butler providing access to a data repository.
    ra: float
        Right ascension of the center of the cutout, in degrees.
    dec: float
        Declination of the center of the cutout, in degrees.
    band: string
        Band of the image to load.
    datasetType: string ['deepCoadd']
        Which type of coadd to load.  Doesn't support 'calexp'.
    skymap: lsst.skymap.BaseSkyMap [optional]
        Pass in to avoid the Butler read.  Useful if you have lots of them.
    cutoutSideLength: int [optional]
        Side length of the square cutout region, in pixels.
    **kwargs
        Accepted for backward compatibility; currently ignored.

    Returns
    -------
    cutout_image : object or None
        Whatever ``butler.get(datasetType, ...)`` returns for the cutout
        bounding box (presumably an Exposure for ``deepCoadd`` -- TODO
        confirm), or ``None`` if the read raised an exception (e.g. no coadd
        exists for that tract/patch/band). A warning describing the failure
        is emitted in that case.
    """
    radec = geom.SpherePoint(ra, dec, geom.degrees)
    cutoutSize = geom.ExtentI(cutoutSideLength, cutoutSideLength)

    if skymap is None:
        skymap = butler.get("skyMap")

    # Look up the tract and patch containing the position.
    tractInfo = skymap.findTract(radec)
    patchInfo = tractInfo.findPatch(radec)

    # Centre a cutoutSize box on the position, in the tract's pixel frame.
    wcs = tractInfo.getWcs()
    xy = geom.PointI(wcs.skyToPixel(radec))
    bbox = geom.BoxI(xy - cutoutSize // 2, cutoutSize)

    # The data ID uses the patch's sequential index within the tract.
    patch = tractInfo.getSequentialPatchIndex(patchInfo)
    coaddId = {'tract': tractInfo.getId(), 'patch': patch, 'band': band}

    parameters = {'bbox': bbox}
    try:
        # ``bbox`` asks the Butler to read only the cutout region.
        cutout_image = butler.get(datasetType, parameters=parameters,
                                  dataId=coaddId)
    except Exception as err:
        # Best effort: callers treat None as "no cutout available". Catching
        # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate, and the warning keeps failures visible.
        warnings.warn(f"cutout_coadd: could not read {datasetType} for "
                      f"{coaddId} at (ra={ra}, dec={dec}): {err!r}",
                      stacklevel=2)
        return None
    return cutout_image
    


def remove_figure(fig):
    """Tear down a matplotlib figure so its memory can be reclaimed.

    Every image artist on each of the figure's axes is removed, the figure
    is cleared and closed, and the garbage collector is then run.
    """
    # Gather all image artists first, then detach them one by one.
    image_artists = [artist
                     for axis in fig.get_axes()
                     for artist in axis.get_images()]
    for artist in image_artists:
        artist.remove()

    fig.clf()        # drop remaining artists/axes from the figure
    plt.close(fig)   # unregister the figure from pyplot
    gc.collect()     # reclaim the now-unreferenced image buffers




In [None]:

def display_lenses(ra, dec, transient_pos, skymap=None):
    """Display g/r/i colour cutouts around (ra, dec) with marked positions.

    The same 51x51-pixel g/r/i ``deepCoadd`` cutout is shown twice, side
    by side: once as a plain Lupton RGB composite and once with the R, G, B
    channels re-scaled by (0.6, 0.7, 1.0). The positions in
    ``transient_pos`` are overplotted as black dots on both panels.

    All data access goes through the module-level ``butler``.

    Parameters
    ----------
    ra, dec : float
        Centre of the cutouts, in degrees.
    transient_pos : iterable of sequences
        Positions to mark; ``pos[0]`` and ``pos[1]`` of each entry are
        taken as (ra, dec) in degrees.
    skymap : optional
        Sky map to use; if None it is read with ``butler.get("skyMap")``.

    Returns
    -------
    None
        Nothing is displayed if any of the three band cutouts is unavailable
        (``cutout_coadd`` returned None).
    """
    if skymap is None:
        # Same Gen 3 dataset type name that cutout_coadd uses.
        skymap = butler.get("skyMap")

    bands = ['g', 'r', 'i']
    cutouts = []
    for band in bands:
        cutout = cutout_coadd(butler,
                              ra,
                              dec,
                              band=band,
                              datasetType='deepCoadd',
                              cutoutSideLength=51,
                              skymap=skymap)
        if cutout is None:
            # No data in this band: skip instead of crashing in fromExposures.
            return
        cutouts.append(cutout)

    # Convert the sky positions to pixel offsets from the cutout's minimum
    # corner, i.e. array indices, which imshow(origin='lower') plots at.
    wcs = cutouts[0].getWcs()
    bbox = cutouts[0].getBBox()
    x, y = [], []
    for pos in transient_pos:
        point = geom.SpherePoint(pos[0], pos[1], geom.degrees)
        xp, yp = wcs.skyToPixel(point)
        x.append(xp - bbox.getMinX())
        y.append(yp - bbox.getMinY())

    # Multiband exposures need a list of images and filters.
    coadds = MultibandExposure.fromExposures(bands, cutouts)

    fig, ax = plt.subplots(figsize=(20, 20), nrows=1, ncols=2)

    # Left: original make_lupton_rgb without any scaling.
    rgb_original = createRGB(coadds.image, bgr=bands, scale=None)
    ax[0].imshow(rgb_original, origin='lower')
    ax[0].set_title('original', fontsize=30)
    ax[0].plot(x, y, 'ok')

    # Right: make_lupton_rgb with re-scaled R, G, B channels.
    rgb_scaled = createRGB(coadds.image, bgr=bands, scale=[0.6, 0.7, 1.0])
    ax[1].imshow(rgb_scaled, origin='lower')
    ax[1].set_title('re-scaled', fontsize=30)
    ax[1].plot(x, y, 'ok', markersize=5)

    for axis in ax:
        axis.set_axis_off()
    plt.show()

    # Clean up memory held by the (large) figure.
    remove_figure(fig)