# Quick Start Guide
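This tutorial shows how to use the resampling capabilities in scarlet to model a low-resolution multi-band image (HSC) jointly with a high-resolution single-band image (HST).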

In [None]:
# Import Packages and setup
import logging

import numpy as np

import scarlet
import scarlet.display
import astropy.io.fits as fits
from astropy.wcs import WCS
import sep

%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

# Default image style for every plot below: the inferno colormap and no
# interpolation between pixels.
matplotlib.rc('image', cmap='inferno', interpolation='none')

## Load and display the sample data

### Load the sample data and source catalog
Loading the images, PSFs and WCS information requires astropy. Source detection is done with `sep`: the `makeCatalog` helper below runs `sep` on the image (averaged over bands for a multi-band cube) and also returns the background RMS.

In [None]:

def makeCatalog(img, thresh=4):
    """Detect sources with ``sep`` and estimate the background RMS.

    Parameters
    ----------
    img : numpy.ndarray
        Either a single 2D image or a 3D cube whose first axis indexes
        bands. Any byte order is accepted: the data are converted to the
        machine's native byte order (values unchanged) because ``sep``
        rejects non-native arrays.
    thresh : float, optional
        Detection threshold passed to ``sep.extract``, in units of the
        global background RMS. Defaults to 4, the previously hard-coded value.

    Returns
    -------
    catalog : numpy structured array
        The output of ``sep.extract`` on the detection image.
    bg_rms : numpy.ndarray or float
        For a 3D cube, an array with the global background RMS of each band;
        for a 2D image, the global background RMS of that image.
    """
    # Convert to native byte order (and C order, which sep also requires)
    # without changing the values. This is a no-op for data that are already
    # native, unlike the old ``byteswap().newbyteorder()`` idiom: that idiom
    # turned native data into non-native data, and it relies on
    # ``ndarray.newbyteorder``, which NumPy 2.0 removed.
    img = np.ascontiguousarray(img, dtype=img.dtype.newbyteorder("="))

    if img.ndim == 3:
        detect = img.mean(axis=0)  # simple average of the bands for detection
    else:
        detect = img

    bkg = sep.Background(detect)
    catalog = sep.extract(detect, thresh, err=bkg.globalrms)

    if img.ndim == 3:
        bg_rms = np.array([sep.Background(band).globalrms for band in img])
    else:
        # The detection image is the input itself: reuse its background.
        bg_rms = bkg.globalrms
    return catalog, bg_rms




In [None]:
# Load the low-resolution HSC cube, its WCS and its PSF.
# FITS data are big-endian. They are converted to native byte order (values
# unchanged) because sep rejects non-native arrays. ``ndarray.newbyteorder``
# was removed in NumPy 2.0, so a dtype conversion replaces the old
# ``byteswap().newbyteorder()`` idiom. Files are opened with ``with`` so the
# handles are closed once the data are loaded.
with fits.open('../data/test_resampling/Cut_HSC1.fits') as obs_hdu:
    data_hsc = obs_hdu[0].data
    data_hsc = data_hsc.astype(data_hsc.dtype.newbyteorder('='))
    wcs_hsc = WCS(obs_hdu[0].header)
with fits.open('../data/test_resampling/PSF_HSC.fits') as psf_hdu:
    psf_hsc = psf_hdu[0].data

Np1, Np2 = psf_hsc[0].shape

# Load the high-resolution single-band HST image, its WCS and its PSF.
# The PSF gets a leading band axis.
with fits.open('../data/test_resampling/Cut_HST1.fits') as hst_hdu:
    data_hst = hst_hdu[0].data
    data_hst = data_hst.astype(data_hst.dtype.newbyteorder('='))
    wcs_hst = WCS(hst_hdu[0].header)
with fits.open('../data/test_resampling/PSF_HST.fits') as psf_hdu:
    psf_hst = psf_hdu[0].data[np.newaxis, :, :]

# Hand-tuned shift of the HST WCS reference value, presumably to register it
# with HSC. Units are those of CRVAL (presumably degrees). The shift is
# applied along the direction given by `hst_shift_angle`.
hst_shift = 2.4750118475607095e-05
hst_shift_angle = 0.4136047623181346  # radians (used with np.cos / np.sin)
wcs_hst.wcs.crval += hst_shift * np.array([-np.cos(hst_shift_angle), -np.sin(hst_shift_angle)])

# Add a band axis to the HST image and rescale it so its peak matches the HSC peak.
n1, n2 = np.shape(data_hst)
data_hst = data_hst.reshape(1, n1, n2) * np.max(data_hsc) / np.max(data_hst)

r, N1, N2 = data_hsc.shape  # r = number of HSC bands

catalog_hst, bg_rms_hst = makeCatalog(data_hst)
catalog_hsc, bg_rms_hsc = makeCatalog(data_hsc)

# Plotting setup: asinh stretch computed on the HSC data.
norm = scarlet.display.Asinh(img=data_hsc, Q=20)
# Map i,r,g -> RGB
filter_indices = [2, 0, 1]

# HST detections in HST pixel coordinates (sep: x = column, y = row).
xo, yo = catalog_hst['x'], catalog_hst['y']

# Convert the HST detections to sky coordinates, then to HSC pixel coordinates.
# NOTE(review): (yo, xo) is passed where wcs_pix2world expects (x, y), and the
# HSC x/y are unpacked as (Yo, Xo). The swap is applied consistently, and
# (ra, dec) is what the sources are initialized from below, so it is left
# unchanged. Confirm this matches scarlet's expected coordinate convention.
ra, dec = wcs_hst.wcs_pix2world(yo, xo, 0)
# The HSC WCS has a third (band) axis: pass 0 for it; `l` is that axis' pixel value.
Yo, Xo, l = wcs_hsc.wcs_world2pix(ra, dec, 0, 0)

img_rgb = scarlet.display.img_to_rgb(data_hsc, filter_indices=filter_indices, norm=norm)

plt.figure(figsize=(15, 5))
plt.subplot(121)
plt.title('HSC')
plt.imshow(img_rgb)
plt.plot(Xo, Yo, 'o')
plt.subplot(122)
plt.title('HST')
plt.imshow(data_hst[0] + 0.5, cmap='gist_stern')
plt.plot(xo, yo, 'o')
plt.show()

## Initialize the sources
Each source is a list of fundamental `scarlet.Component` instances and must be based on `scarlet.Source` or a derived class. Here we use `CombinedExtendedSource`, which takes both observations. With `symmetric=0, monotonic=1`, each source is constrained to be monotonic but not symmetric.

In [None]:
# One observation per instrument. `structure` appears to flag which of the
# r + 1 scene channels each observation covers: the r HSC bands first, then
# the single HST band last. These flags are derived from r instead of being
# hard-coded for r == 3.
hst_structure = np.array([0] * r + [1])
hsc_structure = np.array([1] * r + [0])
obs_hst = scarlet.Observation(data_hst, wcs=wcs_hst, psfs=None, structure=hst_structure)
obs_hsc = scarlet.ObservationToResample(data_hsc, wcs=wcs_hsc, psfs=psf_hsc, structure=hsc_structure)

# The model lives on the HST (high-resolution) grid with r + 1 channels.
n_channels = r + 1
scene = scarlet.Scene((n_channels, n1, n2), wcs=wcs_hst, psfs=psf_hst)

obs = [obs_hsc, obs_hst]

# Background RMS per channel, in the scene's channel order (HSC bands, then HST).
bg_rms = np.concatenate((bg_rms_hsc, bg_rms_hst))

# One source per HST detection, placed at its sky position.
sources = [
    scarlet.CombinedExtendedSource((ra_i, dec_i), scene, obs, bg_rms, symmetric=0, monotonic=1)
    for ra_i, dec_i in zip(ra, dec)
]

# Sum the initial source models: channel i is the morphology times sed[i].
# The channel count matches the scene instead of a hard-coded 4.
rgb = np.zeros((n_channels, n1, n2))
for src in sources:
    rgb += np.asarray(src.sed)[:n_channels, None, None] * src.morph

im_rgb = scarlet.display.img_to_rgb(rgb, filter_indices=filter_indices, norm=norm)

plt.figure(figsize=(15, 5))
plt.subplot(121)
plt.title('initial model')
plt.imshow(im_rgb)
plt.subplot(122)
plt.title('HSC image')
plt.imshow(img_rgb)
plt.show()

## Create and fit the model
The `scarlet.Blend` class represents the sources as a tree and has the machinery to fit all of the sources to the given observations. In this example the fit runs for at most 200 iterations, but stops early once the likelihood and all of the constraints converge (here with relative tolerance `e_rel=1e-3`).

In [None]:
# Combine the scene, the initialized sources and both observations (HSC, HST)
# into a blend; it is fit jointly in the next cell.
blend = scarlet.Blend(scene, sources, obs)

In [None]:
# Fit for at most 200 iterations. e_rel is presumably the relative
# convergence tolerance that lets the fit stop early.
blend.fit(200, e_rel=1e-3)
n_iterations = blend.it
print("scarlet ran for {0} iterations".format(n_iterations))

### View the full model
First we compute the model for the entire blend and compare it with the HST image and its residual in the high-resolution frame. Then we resample the model to the HSC frame and display the model, the residual and the HSC image, all with the same $\sinh^{-1}$ stretch.

In [None]:
# Compute the model for the entire blend on the high-resolution (HST) grid and
# compare its last (HST) channel with the HST image.
# (A previously computed but never used resampled-HST array `im` was removed.)
model = blend.get_model()

modelhr_rgb = scarlet.display.img_to_rgb(model, filter_indices=filter_indices, norm=norm)

plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.title('model (HR)')
plt.imshow(modelhr_rgb)
plt.subplot(132)
plt.title('HST image')
plt.imshow(data_hst[0], cmap='gist_stern')
plt.subplot(133)
plt.title('HST residuals')
plt.imshow(data_hst[0] - model[-1], cmap='gist_stern')
plt.show()

# Resample the model to the HSC frame and compute the HSC residual.
# All three panels use the same asinh stretch.
img = obs_hsc.get_model_image(model)

model_rgb = scarlet.display.img_to_rgb(img, filter_indices=filter_indices, norm=norm)
residual = data_hsc - img
residual_rgb = scarlet.display.img_to_rgb(residual, filter_indices=filter_indices, norm=norm)

plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.title('model')
plt.imshow(model_rgb)
plt.subplot(132)
plt.title('residuals')
plt.imshow(residual_rgb)
plt.subplot(133)
plt.title('image')
plt.imshow(img_rgb)
plt.show()

### View the source models
It can also be useful to view the model for each source. For each source we show its high-resolution morphology, that morphology resampled to the HSC pixel grid, and its SED. We then show the source model and its residuals in both the HSC and HST frames.

In [None]:
# For each fitted component, show:
# - its high-resolution morphology,
# - that morphology resampled onto the HSC grid,
# - its SED,
# - its model and residuals in both the HSC and HST frames.
for src in blend.components:

    # Get the model for a single source
    component = blend.get_model(src.morph)
    print(component.shape)

    # High-resolution morphology of this component.
    _img = src.morph
    print(src.sed, src.morph.max())

    plt.figure(figsize=(15, 5))
    plt.subplot(131)
    # `interpolation` is the imshow keyword. The former `interpolate=` is not a
    # valid imshow / AxesImage property, so matplotlib raised on it.
    plt.imshow(_img[:, :], interpolation='none', cmap='gist_stern')
    plt.plot(xo, yo, 'o')
    plt.title('Model HR')

    # Resample the morphology onto the HSC pixel grid with the observation's
    # resampling operator (relies on private ObservationToResample attributes).
    im = np.zeros((N1, N2))
    im[obs_hsc._over_lr[0].astype(int), obs_hsc._over_lr[1]] = np.dot(_img.flatten(), obs_hsc.resconv_op[0, :, :])

    plt.subplot(132)
    plt.imshow(im, interpolation='none', cmap='gist_stern')
    plt.plot(Xo, Yo, 'o')
    plt.title('Model LR')
    plt.subplot(133)
    plt.plot(src.sed)
    plt.title('SED')  # was a copy-pasted 'Model LR'
    plt.show()

    # Full model of this component, resampled to the HSC frame, and its residual.
    _img = src.get_model()
    _img_lr = obs_hsc.get_model_image(_img)
    img_lr_rgb = scarlet.display.img_to_rgb(_img_lr, filter_indices=filter_indices, norm=norm)
    res = data_hsc - _img_lr
    res_rgb = scarlet.display.img_to_rgb(res, filter_indices=filter_indices, norm=norm)

    # Set the figure size from the source's shape.
    ratio = src.shape[2] / src.shape[1]
    fig_height = 3 * src.shape[1] / 20
    fig_width = max(3 * fig_height * ratio, 2)
    plt.figure(figsize=(fig_width, fig_height))

    plt.subplot(131)
    plt.imshow(img_lr_rgb)
    plt.plot(Xo, Yo, 'o', markersize=35)
    plt.subplot(132)
    plt.imshow(res_rgb)
    plt.plot(Xo, Yo, 'o', markersize=35)
    plt.subplot(133)
    plt.imshow(img_rgb)
    plt.plot(Xo, Yo, 'o', markersize=35)
    plt.show()

    # Same component in the HST frame: model (last channel), residual, and data.
    img_hr = obs_hst.get_model(_img)
    res = data_hst - img_hr[-1]

    print(res.shape)
    plt.figure(figsize=(15, 5))
    plt.subplot(131)
    plt.imshow(img_hr[-1], cmap='gist_stern')
    plt.plot(xo, yo, 'o', markersize=5)
    plt.subplot(132)
    plt.imshow(res[0], cmap='gist_stern')
    plt.plot(xo, yo, 'o', markersize=5)
    plt.subplot(133)
    plt.imshow(data_hst[0], cmap='gist_stern')
    plt.plot(xo, yo, 'o', markersize=5)
    plt.show()
