# Deblending with *Scarlet*
<br>Owner(s): **Fred Moolekamp** ([@fred3m](https://github.com/LSSTScienceCollaborations/StackClub/issues/new?body=@fred3m))
<br>Last Verified to Run: **2018-08-17**
<br>Verified Stack Release: **w_2018_32**

The purpose of this tutorial is to familiarize you with the basics of using *scarlet* to model blended scenes, and how tweaking various objects and parameters affects the resulting model. A tutorial that is more specific to using scarlet in the context of the LSST DM Science Pipelines is also available.

### Learning Objectives:

After working through this tutorial you should be able to: 
1. Configure and run _scarlet_ on a test list of objects;
2. Understand its various model assumptions and applied constraints.

Before attempting this tutorial it will be useful to read the [introduction](http://scarlet.readthedocs.io/en/latest/user_docs.html) to the *scarlet* User Guide, and many of the exercises below may require referencing the *scarlet* [docs](http://scarlet.readthedocs.io/en/latest/index.html).

### Logistics
This notebook is intended to be runnable on `lsst-lspdev.ncsa.illinois.edu` from a local git clone of https://github.com/LSSTScienceCollaborations/StackClub.

## Set-up

In [None]:
# Import the necessary libraries
import os

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# don't interpolate the pixels
matplotlib.rc('image', interpolation='none')

import numpy as np
import scarlet
import scarlet.display

# Display functions

Below are several useful functions used throughout this tutorial to visualize the data and models.

In [None]:
# Display the sources
def display_sources(sources, image, norm=None, subset=None, combine=False, show_sed=True, filter_indices=None):
    """Display the data and model for all sources in a blend

    This convenience function is used to display all (or a subset) of
    the sources and (optionally) their SEDs.
    """
    if subset is None:
        # Show all sources in the blend
        subset = range(len(sources))
    if filter_indices is None:
        filter_indices = [3,2,1]
    for m in subset:
        # Load the model for the source
        src = sources[m]
        model = [comp.get_model() for comp in src]

        # Select the image patch that overlaps with the source and convert it to an RGB image
        img_rgb = scarlet.display.img_to_rgb(image[src[0].bb], filter_indices=filter_indices, norm=norm)

        # Build a model for each component in the model
        rgb = []
        for _model in model:
            # Convert the model to an RGB image
            _rgb = scarlet.display.img_to_rgb(_model, filter_indices=filter_indices, norm=norm)
            rgb.append(_rgb)

        # Display the image and model
        figsize = [6,3]
        columns = 2
        # Calculate the number of columns needed and shape of the figure
        if show_sed:
            figsize[0] += 3
            columns += 1
        if not combine:
            figsize[0] += 3*(len(model)-1)
            columns += len(model)-1
        # Build the figure
        fig = plt.figure(figsize=figsize)
        ax = [fig.add_subplot(1,columns,n+1) for n in range(columns)]
        ax[0].imshow(img_rgb)
        ax[0].set_title("Data: Source {0}".format(m))
        for n, _rgb in enumerate(rgb):
            ax[n+1].imshow(_rgb)
            if combine:
                ax[n+1].set_title("Initial Model")
            else:
                ax[n+1].set_title("Component {0}".format(n))
        if show_sed:
            for comp in src:
                ax[-1].plot(comp.sed)
            ax[-1].set_title("SED")
            ax[-1].set_xlabel("Band")
            ax[-1].set_ylabel("Intensity")
        # Mark the current source in the image
        y,x = src[0].center
        ax[0].plot(x-src[0].bb[2].start, y-src[0].bb[1].start, 'x', color="#5af916", mew=2)
        plt.tight_layout()
        plt.show()

def display_model_residual(images, blend, peaks, norm, filter_indices=None):
    """Display the data, model, and residual for a given result
    """
    if filter_indices is None:
        filter_indices = [3,2,1]
    model = blend.get_model()
    residual = images-model
    print("Data range: {0:.3f} to {1:.3f}\nresidual range: {2:.3f} to {3:.3f}\nrms: {4:.3f}".format(
        np.min(images),
        np.max(images),
        np.min(residual),
        np.max(residual),
        np.sqrt(np.std(residual)**2+np.mean(residual)**2)
    ))
    # Create RGB images
    img_rgb = scarlet.display.img_to_rgb(images, filter_indices=filter_indices, norm=norm)
    model_rgb = scarlet.display.img_to_rgb(model, filter_indices=filter_indices, norm=norm)
    residual_norm = scarlet.display.Linear(img=residual)
    residual_rgb = scarlet.display.img_to_rgb(residual, filter_indices=filter_indices, norm=residual_norm)

    # Show the data, model, and residual
    fig = plt.figure(figsize=(15,5))
    ax = [fig.add_subplot(1,3,n+1) for n in range(3)]
    ax[0].imshow(img_rgb)
    ax[0].set_title("Data")
    ax[1].imshow(model_rgb)
    ax[1].set_title("Model")
    ax[2].imshow(residual_rgb)
    ax[2].set_title("Residual")
    for k,component in enumerate(blend.components):
        y,x = component.center
        #px, py = peaks[k]
        ax[0].plot(x, y, "gx")
        #ax[0].plot(px, py, "rx")
        ax[1].text(x, y, k, color="r")
    plt.show()

def show_psfs(psfs, filters, norm=None):
    """Display the PSF image in each band, up to three per row
    """
    rows = int(np.ceil(len(psfs)/3))
    columns = min(len(psfs), 3)
    figsize = (45/columns, rows*5)
    fig = plt.figure(figsize=figsize)
    ax = [fig.add_subplot(rows, columns, n+1) for n in range(len(psfs))]
    for n, psf in enumerate(psfs):
        im = ax[n].imshow(psf, norm=norm)
        ax[n].set_title("{0}-band PSF".format(filters[n]))
        plt.colorbar(im, ax=ax[n])
    plt.show()

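def display_diff_kernels(psf_blend, diff_kernels):
    """Display the PSF, modeled PSF, difference kernel, and residual in each band

    Note: this function reads the observed PSFs from the global `psfs`
    array that is loaded with the data later in this notebook.
    """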
    model = psf_blend.get_model()
    for b, component in enumerate(psf_blend.components):
        fig = plt.figure(figsize=(15,2.5))
        ax = [fig.add_subplot(1,4,n+1) for n in range(4)]
        # Display the psf
        ax[0].set_title("psf")
        _img = ax[0].imshow(psfs[b])
        fig.colorbar(_img, ax=ax[0])
        # Display the model
        ax[1].set_title("modeled psf")
        _model = np.ma.array(model[b], mask=model[b]==0)
        _img = ax[1].imshow(_model)
        fig.colorbar(_img, ax=ax[1])
        # Display the difference kernel
        ax[2].set_title("difference kernel")
        _img = ax[2].imshow(np.ma.array(diff_kernels[b], mask=diff_kernels[b]==0))
        fig.colorbar(_img, ax=ax[2])
        # Display the residual
        ax[3].set_title("residual")
        residual = psfs[b]-model[b]
        vabs = np.max(np.abs(residual))
        _img = ax[3].imshow(residual, vmin=-vabs, vmax=vabs, cmap='seismic')
        fig.colorbar(_img, ax=ax[3])
        plt.show()

# Load and Display the data

The `file_path` points to a directory with 147 HSC blends from the COSMOS field detected by the LSST pipeline. Changing `idx` below will select a different blend.

In [None]:
# Load the sample images
idx = 53
file_path = "/project/shared/data/testdata_deblender/real_data/hsc_cosmos/not_matched"
files = os.listdir(file_path)
data = np.load(os.path.join(file_path, files[idx]))
image = data["images"]
wmap = data["weights"]
peaks = data["peaks"]
psfs = data["psfs"]
filters = ["G", "R", "I", "Z", "Y"]
# Only a rough estimate of the background is needed
# to initialize and resize the sources
bg_rms = np.std(image, axis=(1,2))
print("Background RMS: {0}".format(bg_rms))

# Use Asinh scaling for the images
norm = scarlet.display.Asinh(img=image, Q=10)
# Map z,i,r -> RGB (indices 3,2,1 of the `filters` list above)
filter_indices = [3,2,1]
# Convert the image to an RGB image
img_rgb = scarlet.display.img_to_rgb(image, filter_indices=filter_indices, norm=norm)
plt.imshow(img_rgb)
plt.title("Image: {0}".format(idx))
for src in peaks:
    plt.plot(src[0], src[1], "rx", mew=2)
plt.show()

# Initializing Sources

Astrophysical objects are modeled in scarlet as a collection of components, where each component has a single SED that is constant over its morphology (band-independent intensity). So a single source might have multiple components, like a bulge and a disk, or a single component.
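
As a minimal numpy sketch of this idea (not scarlet's internal code), a component can be pictured as the outer product of a per-band SED and a single 2D morphology. The arrays and variable names below are made up for illustration.

```python
import numpy as np

# Hypothetical 5-band SED (one amplitude per band) and a 3x3 morphology.
sed = np.array([0.1, 0.2, 0.3, 0.25, 0.15])
morph = np.array([[0.0, 1.0, 0.0],
                  [1.0, 4.0, 1.0],
                  [0.0, 1.0, 0.0]])

# Broadcasting builds a (bands, height, width) cube: the same morphology
# in every band, scaled by that band's SED amplitude.
component_model = sed[:, None, None] * morph[None, :, :]

# A source with several components is the sum of its component cubes.
# Here a second, flat "disk" component with a bluer (made-up) SED is added.
disk_sed = np.array([0.3, 0.25, 0.2, 0.15, 0.1])
disk_morph = np.ones((3, 3))
source_model = component_model + disk_sed[:, None, None] * disk_morph

print(component_model.shape)
```

Because each component has one SED, color gradients within a source can only arise from adding components with different SEDs, as in the bulge and disk case.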

The different classes that inherit from `Source` mainly differ in how they are initialized, and otherwise behave similarly during the optimization routine. This section illustrates the differences between different source initialization classes.

The simplest source is a single component initialized with only a single pixel (at the center of the object) turned on.

### <span style="color:red"> *WARNING* </span>
Scarlet accepts source positions using the numpy/C++ convention of (y,x), which differs from the astropy and LSST stack convention of (x,y).
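
The following numpy-only sketch shows why the order matters. The toy image and the `peaks_xy` list are hypothetical stand-ins for the `image` and `peaks` arrays loaded above.

```python
import numpy as np

# Hypothetical single-band image with one bright pixel at row y=2, column x=5.
image_2d = np.zeros((4, 8))
image_2d[2, 5] = 1.0

# A peak catalog stored in (x, y) order.
peaks_xy = [(5, 2)]

# Swap each pair into (y, x) order before using it as an array index
# or handing it to code that follows the numpy convention.
peaks_yx = [(y, x) for (x, y) in peaks_xy]

y, x = peaks_yx[0]
# numpy indexes as [row, column], i.e. [y, x]
print(image_2d[y, x])
```

Indexing the same array with the unswapped pair would either read the wrong pixel or raise an `IndexError` when `x` exceeds the image height.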

In [None]:
sources = [scarlet.PointSource((peak[1], peak[0]), image) for peak in peaks]

# Display the initial guess for each source
display_sources(sources, image, norm=norm)

## Exercise:

* Experiment with the above code by using `ExtendedSource`, which initializes each object as a single component with maximum flux at the peak that falls off monotonically and has 180 degree symmetry; and using `MultiComponentSource`, which models a source as two components (a bulge and a disk) that are each symmetric and monotonically decreasing from the peak.

# Deblending a scene

The `Blend` class contains the list of sources, the image, and any other configuration parameters necessary to fit the data, including routines to fit the center positions and resize the bounding box containing the sources (if necessary). Once a blend has been initialized with a list of sources, the image and background RMS values must be set (the background RMS is used to determine when to truncate the bounding box around a source).

In [None]:
blend = scarlet.Blend(sources)
blend.set_data(image, bg_rms=bg_rms)

Next we can fit a model, given a maximum number of iterations and the relative error required for convergence.

In [None]:
blend.fit(100, 1e-2)
print("Deblending completed in {0} iterations".format(blend.it))
display_model_residual(image, blend, peaks, norm)
display_sources(sources, image, norm)

## Exercises

* Experiment by running the above code using different source models (for example `ExtendedSource`) to see how initialization affects the deblending results.

* The code above initialized the sources at their exact centers. Try offsetting the initial positions by `0.5` pixels in `x` and/or `y` and passing a `shift_center=0` argument when initializing the source. This prevents the source from updating its position, so notice how that affects the resulting model.

# Constraints

The above models used the default constraints: perfect symmetry and a weighted monotonicity that decreases from the peak. So each source is defined (internally during initialization) with the constraints:

In [None]:
import scarlet.constraint as sc
constraints = (sc.SimpleConstraint(),
               sc.DirectMonotonicityConstraint(use_nearest=False),
               sc.DirectSymmetryConstraint())

where `SimpleConstraint` forces the SED and morphology to be non-negative, the SED to be normalized to unity, and the peak to have some (minimal) flux at the center.
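
To build intuition for what these constraints do to a model, here is a rough numpy sketch. It is not scarlet's implementation: the helper names `project_simple` and `symmetrize` are hypothetical, "normalized to unity" is assumed to mean that the SED sums to one, and the minimal-flux-at-the-peak condition is not shown.

```python
import numpy as np

def project_simple(sed, morph):
    """Clip negative values and normalize the SED to sum to one (illustrative)."""
    sed = np.clip(sed, 0, None)
    morph = np.clip(morph, 0, None)
    total = sed.sum()
    if total > 0:
        sed = sed / total
    return sed, morph

def symmetrize(morph):
    """Average a morphology with its 180-degree rotation about the center.

    Assumes an odd-sized array whose central pixel is the source peak.
    """
    return 0.5 * (morph + morph[::-1, ::-1])

# Made-up inputs containing negative values and an asymmetric morphology.
raw_sed = np.array([2.0, -1.0, 2.0])
raw_morph = np.array([[0.0, 2.0, 0.0],
                      [0.0, 4.0, -1.0],
                      [0.0, 0.0, 0.0]])

sed, morph = project_simple(raw_sed, raw_morph)
sym = symmetrize(morph)
print(sed)
print(sym)
```

Averaging with the rotated copy is only one way to impose symmetry; a partial symmetry level would mix the original and symmetrized arrays instead of replacing one with the other.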

## Exercises

* Go back to the source initialization cell and pass a custom set of constraints. For example, pass `DirectSymmetryConstraint` a number between 0 and 1 to set the level of symmetry required, or eliminate the symmetry constraint altogether and see how that affects deblending.

* Set `use_nearest=True` in the `DirectMonotonicityConstraint`.

* Add `L0Constraint` or `L1Constraint` to the list of constraints and observe the results.
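
For intuition about sparsity penalties of this kind, the sketch below shows the standard thresholding operators usually associated with L1 and L0 penalties. This is generic numpy code with hypothetical function names, not the code behind scarlet's `L1Constraint` or `L0Constraint`, and the threshold value is arbitrary.

```python
import numpy as np

def soft_threshold(x, thresh):
    """L1-style operator: shrink every value toward zero by `thresh`."""
    return np.sign(x) * np.maximum(np.abs(x) - thresh, 0)

def hard_threshold(x, thresh):
    """L0-style operator: zero out values with |x| <= thresh, keep the rest."""
    return np.where(np.abs(x) > thresh, x, 0)

# A made-up 1D slice through a morphology with faint wings.
morph_slice = np.array([0.05, 0.2, 1.0, 0.3, 0.08])

l1_result = soft_threshold(morph_slice, 0.1)
l0_result = hard_threshold(morph_slice, 0.1)
print(l1_result)
print(l0_result)
```

Both operators remove the faint wings, but soft thresholding also lowers the surviving pixels by the threshold, while hard thresholding leaves them unchanged.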

# Configuration

There are additional configuration parameters that can be used to initialize a source, as described in http://scarlet.readthedocs.io/en/latest/config.html#Configuration-(scarlet.config).

## Exercises

* Initialize the sources with a custom configuration where `refine_skip=2`, which updates the positions and box sizes on every other step, and see how the results compare.

# PSF Deconvolution

When analyzing real images, the PSF will differ from band to band unless the images have been PSF matched. In general, deblending should not be performed on PSF-matched coadds, as matching increases the blending in bands with better seeing. Instead, scarlet can be used to build a deconvolved model, which is a sparser (and less blended) representation of the data, and then convolve that model with the PSF in each band to compare to the input data.
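
The forward step (convolving one deconvolved model with a different PSF in each band) can be sketched with plain numpy. This is an illustration under stated assumptions, not scarlet's convolution code: `convolve_same` is a hypothetical helper that assumes odd-sized kernels and zero padding, and the PSFs are toy arrays.

```python
import numpy as np

def convolve_same(image, kernel):
    """Direct 2D convolution returning an array the same shape as `image`.

    Assumes an odd-sized kernel and pads the image edges with zeros.
    """
    ky, kx = kernel.shape
    py, px = ky // 2, kx // 2
    padded = np.pad(image, ((py, py), (px, px)), mode="constant")
    out = np.zeros(image.shape, dtype=float)
    for dy in range(ky):
        for dx in range(kx):
            # Flipped kernel indices make this a convolution, not a correlation.
            weight = kernel[ky - 1 - dy, kx - 1 - dx]
            out += weight * padded[dy:dy + image.shape[0], dx:dx + image.shape[1]]
    return out

# Hypothetical deconvolved model: a single point source in a 7x7 scene.
model_band = np.zeros((7, 7))
model_band[3, 3] = 2.0

# Two made-up, normalized PSFs standing in for good and poor seeing.
narrow = np.array([[0.0, 1.0, 0.0], [1.0, 4.0, 1.0], [0.0, 1.0, 0.0]])
narrow /= narrow.sum()
wide = np.ones((3, 3)) / 9

# Convolve the same model with each band's PSF to predict the observations.
observed = np.stack([convolve_same(model_band, psf) for psf in (narrow, wide)])
print(observed.shape)
```

The point source stays a single pixel in the model, while its predicted footprint is broader in the band with the wider PSF, which is why the deconvolved representation is less blended.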

To initialize a source with a PSF, pass the PSF as an input to the new source:

In [None]:
# peaks are stored as (x,y), but scarlet expects (y,x)
scarlet.ExtendedSource((peaks[0][1], peaks[0][0]), image, bg_rms, psf=psfs)

# Partial PSF Deconvolution

As discussed in the tutorial http://scarlet.readthedocs.io/en/latest/psf_matching.html, the data is noisy and the fully deconvolved scene is undersampled, making the application of the constraints and the full convolution kernel unstable and prone to biases. Instead we can create a target PSF and model the sources in the partially deconvolved target PSF scene.

First we need to specify the target PSF. *scarlet* includes a `fit_target_psf` function to fit the PSF in each band to either a `moffat`, `gaussian`, or `double_gaussian` function. 

In [None]:
import scarlet.psf_match

show_psfs(psfs, filters)

# Find the target PSF
target_psf = scarlet.psf_match.fit_target_psf(psfs, scarlet.psf_match.moffat)
plt.imshow(target_psf)
plt.title("target PSF")
plt.colorbar()
plt.show()
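
For reference, the Gaussian and Moffat profiles mentioned above have simple closed forms. The sketch below evaluates them on a pixel grid with numpy; the function names, arguments, and normalization are assumptions for illustration and do not reproduce the signatures in `scarlet.psf_match`.

```python
import numpy as np

def _radius_squared(shape):
    """Squared distance of every pixel from the center of the array."""
    y, x = np.indices(shape)
    cy, cx = (shape[0] - 1) / 2, (shape[1] - 1) / 2
    return (y - cy) ** 2 + (x - cx) ** 2

def gaussian_2d(shape, sigma):
    """Circular Gaussian, exp(-r^2 / (2 sigma^2)), normalized to sum to one."""
    img = np.exp(-_radius_squared(shape) / (2 * sigma ** 2))
    return img / img.sum()

def moffat_2d(shape, alpha, beta):
    """Circular Moffat, (1 + r^2 / alpha^2)^(-beta), normalized to sum to one."""
    img = (1 + _radius_squared(shape) / alpha ** 2) ** (-beta)
    return img / img.sum()

# Made-up parameter values, chosen only to show the difference in the wings.
gauss = gaussian_2d((21, 21), sigma=2.0)
moff = moffat_2d((21, 21), alpha=3.0, beta=2.5)

# The Moffat profile keeps more flux far from the center than the Gaussian.
print(gauss[10, 0], moff[10, 0])
```

The heavier wings of the Moffat profile are the usual reason it is preferred over a single Gaussian for seeing-limited PSFs.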

Once we have the target PSF we can find the difference kernel in each band using *scarlet*. The `build_diff_kernels` function basically treats the PSF image as a blend, where the PSF in each band is a monochromatic source, and fits the difference kernels using the minimum number of pixels necessary.

In [None]:
diff_kernels, psf_blend = scarlet.psf_match.build_diff_kernels(psfs, target_psf)
display_diff_kernels(psf_blend, diff_kernels)
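
To see what a difference kernel represents, consider the idealized case where both the observed and target PSFs are Gaussians. Convolving a narrow Gaussian with a second Gaussian gives a broader Gaussian whose variance is the sum of the two, so the second Gaussian plays the role of the difference kernel. The 1D numpy sketch below checks this numerically; the widths are made up, and real difference kernels fit by scarlet are not required to be Gaussian.

```python
import numpy as np

def gaussian_1d(x, sigma):
    """Sampled 1D Gaussian normalized to sum to one."""
    g = np.exp(-x ** 2 / (2 * sigma ** 2))
    return g / g.sum()

x = np.arange(-30, 31)

# Hypothetical widths: target PSF, difference kernel, and observed PSF.
sigma_target = 1.5
sigma_diff = 2.0
sigma_observed = np.sqrt(sigma_target ** 2 + sigma_diff ** 2)

target = gaussian_1d(x, sigma_target)
diff_kernel = gaussian_1d(x, sigma_diff)
observed = gaussian_1d(x, sigma_observed)

# Convolving the target PSF with the difference kernel recovers the observed PSF.
reconstructed = np.convolve(target, diff_kernel, mode="same")
print(np.max(np.abs(reconstructed - observed)))
```

In this picture, a band whose PSF is already close to the target needs only a very compact difference kernel, which is consistent with fitting the kernels using as few pixels as necessary.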

## Exercises

* Try building the difference kernels while varying the parameters in `build_diff_kernels`, for example using larger and smaller values for `l0_thresh`.

* Go back up to source initialization and use `psf=psfs` to fully deconvolve the scene and fit the blend.

* Try the same thing but set `psf=diff_kernels` for each source to partially deconvolve the scene.