In [None]:
%matplotlib inline
import glob
import os

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='inferno', interpolation='none', origin='lower')

# our code for source separation
import scarlet
import scarlet.display
# code to detect sources
import sep

# to open FITS files and read their WCS
from astropy.io import fits
from astropy.wcs import WCS

In [None]:
# Root directory of the HSC cutouts for this analysis.
# Set the EELR_DATA_DIR environment variable to use another location instead of
# editing this absolute, machine-specific default.
data_dir = os.environ.get("EELR_DATA_DIR", "/Users/pmelchior/data/eelr2019/hsc_images/")

In [None]:
# Find the cutout directory whose center lies closest to the target position.
ra, dec = 37.77, -3.75
n_cutouts = 445  # cutout directories are numbered 00000 .. 00444

coords = []
for idx in range(n_cutouts):
    pattern = os.path.join(data_dir, "{:05d}".format(idx), "cutout_HSC-G_*_src_*.fits")
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError("no cutout matches {}".format(pattern))
    # `with` closes each file once its header has been read
    with fits.open(matches[0]) as hdulist:
        header = hdulist[1].header
        # world coordinates of the central pixel (0-based pixel convention)
        lon, lat = WCS(header).wcs_pix2world(header["NAXIS1"] // 2, header["NAXIS2"] // 2, 0)
    coords.append((lon, lat))
coords = np.array(coords)

# Offsets in degrees: wrap RA differences into [-180, 180) and scale them by
# cos(dec) so both axes measure angles on the sky.
dra = ((coords[:, 0] - ra + 180) % 360 - 180) * np.cos(np.deg2rad(dec))
ddec = coords[:, 1] - dec
closest = int(np.argmin(dra**2 + ddec**2))
closest

In [None]:
# Load the images and PSFs of all HSC bands for the selected cutout directory.
# NOTE(review): 52 is presumably the index found by the nearest-cutout search above — confirm.
d = 52
bands = ['G', 'R', 'I', 'Z', 'Y']
crop = 40  # pixels trimmed from every edge of the images


def _read_fits_data(pattern, ext):
    """Return a copy of the data in HDU ``ext`` of the first file matching ``pattern``.

    The copy is converted to native byte order (FITS data is stored big-endian,
    which C extensions such as sep reject) and remains valid after the file is
    closed.

    Raises
    ------
    FileNotFoundError
        If no file matches ``pattern``.
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError("no file matches {}".format(pattern))
    with fits.open(matches[0]) as hdulist:
        data = hdulist[ext].data
        return data.astype(data.dtype.newbyteorder("="))


cutout_dir = os.path.join(data_dir, "{:05d}".format(d))
images = []
psfs = []
for b in bands:
    images.append(_read_fits_data(os.path.join(cutout_dir, "cutout_HSC-{}_*_src_*.fits".format(b)), 1))
    psfs.append(_read_fits_data(os.path.join(cutout_dir, "psf_HSC-{}_*_src_*.fits".format(b)), 0))

images = np.array(images)
n_bands, ny, nx = images.shape
# trim the border; copy() so `images` owns a contiguous buffer rather than a view
images = images[:, crop:ny - crop, crop:nx - crop].copy()
psfs = np.array(psfs)

In [None]:
# Display the sources
def display_sources(sources, observation, norm=None, subset=None, combine=False, show_sed=True):
    """Display the data and model for all (or a subset of) the sources in a blend.

    For each selected source one figure is drawn with
    * the data (the global ``images`` array) with the source center marked,
    * the PSF-convolved model: one panel per component, or a single panel with
      the full source model when ``combine`` is True or the source has only
      one component,
    * optionally (``show_sed``) a panel with the SED of every component.

    Parameters
    ----------
    sources : sequence
        Sources to draw from; ``subset`` indexes into it.
    observation : scarlet.Observation
        Its ``render`` method convolves each model with the observed PSFs.
    norm : astropy.visualization mapping, optional
        Intensity mapping used for every panel. If None, a mapping is chosen
        separately for each source from the range of its rendered model.
    subset : iterable of int, optional
        Indices into ``sources`` to display (default: all). The index is shown
        in the title of the data panel.
    combine : bool, optional
        Show the full source model in one panel instead of one panel per
        component.
    show_sed : bool, optional
        Add a panel with the component SEDs.

    Notes
    -----
    Reads the notebook global ``images`` and, when ``norm`` is None, also
    ``bg_rms``, ``AsinhMapping`` and ``LinearMapping``; they must be defined
    before this function is called.
    """
    if subset is None:
        # Show all sources in the blend
        subset = range(len(sources))
    for m in subset:
        src = sources[m]
        # Treat a single-component source as a one-element component list so
        # models, SEDs and the center marker are handled uniformly below.
        parts = list(src.components) if hasattr(src, "components") else [src]
        # Convolve the model with the psfs in the observation
        model = observation.render(src.get_model())

        # Pick the stretch per source without overwriting `norm`: assigning to
        # `norm` would make every later source reuse the first one's mapping.
        src_norm = norm
        if src_norm is None:
            if model.max() > 10 * bg_rms.max():
                src_norm = AsinhMapping(minimum=model.min(), stretch=model.max()*.05, Q=10)
            else:
                src_norm = LinearMapping(minimum=model.min(), maximum=model.max())

        # The data panel shows the full image, not a cutout around the source
        img_rgb = scarlet.display.img_to_rgb(images, norm=src_norm)

        if combine or len(parts) == 1:
            # One panel with the complete source model
            rgb = [scarlet.display.img_to_rgb(model, norm=src_norm)]
        else:
            # One panel per component
            rgb = [
                scarlet.display.img_to_rgb(observation.render(part.get_model()), norm=src_norm)
                for part in parts
            ]

        # One data panel, one panel per model image, optionally one SED panel;
        # 3 inches of width per panel plus 4 inches of margin.
        columns = 1 + len(rgb) + (1 if show_sed else 0)
        fig = plt.figure(figsize=(3 * columns + 4, 4))
        ax = [fig.add_subplot(1, columns, n + 1) for n in range(columns)]
        ax[0].imshow(img_rgb)
        ax[0].set_title("Data: Source {0}".format(m))
        for n, _rgb in enumerate(rgb):
            ax[n+1].imshow(_rgb)
            if combine:
                ax[n+1].set_title("Initial Model")
            else:
                ax[n+1].set_title("Component {0}".format(n))
        if show_sed:
            frame = src.frame
            for part in parts:
                ax[-1].plot(part.sed)
            ax[-1].set_xticks(range(frame.C))
            ax[-1].set_xticklabels(frame.channels)
            ax[-1].set_title("SED")
            ax[-1].set_xlabel("Band")
            ax[-1].set_ylabel("Intensity")
        # Mark the current source (center of its first component) in the data panel
        y, x = parts[0].pixel_center
        ax[0].plot(x, y, 'wx', mew=2)
        plt.tight_layout()
        plt.show()

In [None]:
def makeCatalog(img, thresh=1.2, deblend_nthresh=64, deblend_cont=3e-4):
    """Detect sources on the band-averaged image and estimate the noise per band.

    Parameters
    ----------
    img : array_like
        Multi-band image stack; the first axis is averaged to form the
        detection image and iterated to get one background estimate per band.
    thresh : float, optional
        Detection threshold passed to ``sep.extract``; because ``err`` is
        given, it is relative to the global background RMS of the detection
        image.
    deblend_nthresh, deblend_cont : optional
        Deblending parameters forwarded to ``sep.extract``.

    Returns
    -------
    catalog : structured array
        Detections from ``sep.extract``.
    segmap : ndarray
        Segmentation map from ``sep.extract``.
    bg_rms : ndarray
        Global background RMS of each band.
    """
    img = np.asarray(img)
    # sep only accepts C-contiguous, native-byte-order arrays; FITS data is
    # big-endian, so convert (a no-op when the input already qualifies).
    img = np.ascontiguousarray(img, dtype=img.dtype.newbyteorder("="))
    detect = img.mean(axis=0)
    bkg = sep.Background(detect)
    catalog, segmap = sep.extract(detect, thresh, err=bkg.globalrms,
                                  deblend_nthresh=deblend_nthresh,
                                  deblend_cont=deblend_cont,
                                  segmentation_map=True)
    bg_rms = np.array([sep.Background(band).globalrms for band in img])
    return catalog, segmap, bg_rms

catalog, segmap, bg_rms = makeCatalog(images)

from astropy.visualization.lupton_rgb import AsinhMapping, LinearMapping

# Asinh stretch for an RGB composite of the data; `norm` is reused for the model later
stretch, Q = 1, 5
norm = AsinhMapping(minimum=0, stretch=stretch, Q=Q)
img_rgb = scarlet.display.img_to_rgb(images, norm=norm)

fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(img_rgb)
# Label every detection with its index in the catalog
for idx, obj in enumerate(catalog):
    ax.text(obj["x"], obj["y"], str(idx), color="w", size=24)

In [None]:
# Observed PSFs of all bands, shown together as one RGB image
pnorm = AsinhMapping(minimum=psfs.min(), stretch=1e-2, Q=1)
prgb = scarlet.display.img_to_rgb(psfs, norm=pnorm)
fig, ax = plt.subplots()
ax.imshow(prgb)
plt.show()

# Fit a Moffat target PSF for the model frame, so that the observed PSFs only
# need to be partially deconvolved
target_fit = scarlet.psf.fit_target_psf(psfs, scarlet.psf.moffat)
model_psf, _, _ = target_fit
model_psf = model_psf[np.newaxis, :, :]  # add a leading channel axis

# Show the target PSF with the same scaling as the observed PSFs
model_rgb = scarlet.display.img_to_rgb(model_psf, norm=pnorm)
fig, ax = plt.subplots()
ax.imshow(model_rgb)
plt.show()

In [None]:
class EELRHostSource(scarlet.ExtendedSource):
    """Extended host source whose SED is locked to the direction of a template.

    After every update the SED is replaced by its projection onto the template
    ``sed``, clipped at zero, so only the overall amplitude is free. The
    morphology is handled by ``scarlet.ExtendedSource`` with ``symmetric=False``.
    """

    def __init__(self, frame, sky_coord, observation, bg_rms, sed):
        # Stored before the base initializer runs — presumably because it may
        # already call update_sed(); verify against scarlet.
        self._s = sed.copy()
        super().__init__(frame, sky_coord, observation, bg_rms, symmetric=False)

    def update(self):
        """Apply the SED constraint, then the base-class update."""
        self.update_sed()
        return super().update()

    def update_sed(self):
        """Project the SED onto the template direction and clip negatives to zero."""
        template = self._s
        current = self._sed
        amplitude = np.dot(current, template) / np.dot(template, template)
        current[:] = np.maximum(amplitude * template, 0)
        return self

class FreeSource(scarlet.RandomSource):
    """Free-form source, optionally tied to an SED template and confined to a disk.

    Parameters
    ----------
    frame : scarlet.Frame
        Model frame; its ``shape`` is unpacked as (channels, ny, nx).
    center : tuple
        Pixel center as (y, x); stored as ``pixel_center``.
    sed : array, optional
        Initial SED. When given, every update keeps the SED along this
        direction (non-negative projection onto it).
    R : float, optional
        When given, morphology pixels farther than ``R`` pixels from ``center``
        are set to zero and the morphology is clipped to be non-negative on
        every update. Without ``R`` the morphology is left unconstrained.
    """

    def __init__(self, frame, center, sed=None, R=None):
        super().__init__(frame)

        if sed is None:
            self._s = None
        else:
            self._sed[:] = sed
            self._s = sed.copy()

        self.pixel_center = center

        self.mask = None
        if R is not None:
            _, ny, nx = frame.shape
            cy, cx = self.pixel_center
            yy, xx = np.ogrid[:ny, :nx]
            # True outside the disk of radius R around the center
            self.mask = (yy - cy)**2 + (xx - cx)**2 > R**2

        self.update()

    def update(self):
        """Apply the SED constraint, then the morphology constraint."""
        self.update_sed()
        self.update_morph()
        return self

    def update_sed(self):
        """Project the SED onto the template direction (clipped at zero), if a template exists."""
        template = self._s
        if template is None:
            return self
        current = self._sed
        current[:] = np.maximum(np.dot(current, template) / np.dot(template, template) * template, 0)
        return self

    def update_morph(self):
        """Zero the morphology outside the disk and clip it at zero, if a radius was given."""
        if self.mask is None:
            return self
        morph = self._morph
        morph[self.mask] = 0
        morph[:, :] = np.maximum(morph, 0)
        return self
    
def mag2amplitude(mags, dlambda=(.14, .14, .16, .13, .11)):
    """Convert AB magnitudes into the SED amplitudes used by the sources.

    Parameters
    ----------
    mags : array_like
        AB magnitudes, one per band (lists are accepted).
    dlambda : array_like, optional
        Per-band bandwidth factor; the default holds the five values used for
        the g, r, i, z, y bands in this notebook.

    Returns
    -------
    ndarray
        ``1.51e7 / dlambda * 10**((48.6 + mags) / -2.5)``.

    Notes
    -----
    ``10**((48.6 + m) / -2.5)`` is the AB flux density in cgs units
    (erg/s/cm^2/Hz), not in Jy, so the result carries an overall constant
    factor. The SED projections in this notebook only use the SED direction,
    but ``FreeSource`` copies these values as its initial SED.
    NOTE(review): the 1.51e7 photons/s/m^2 per Jy figure is usually quoted per
    unit Δλ/λ (i.e. multiplied by it); here it is divided by ``dlambda`` —
    confirm which convention is intended.
    """
    mags = np.asarray(mags, dtype=float)
    dlambda = np.asarray(dlambda, dtype=float)
    fnu_cgs = 10**((48.6 + mags) / -2.5)
    photons_per_jy = 1.51e7 / dlambda
    return photons_per_jy * fnu_cgs

In [None]:
# Model frame for the g, r, i, z, y cutouts, sampled with the fitted target PSF
bands = ['g', 'r', 'i', 'z', 'y']
frame = scarlet.Frame(images.shape, channels=bands, psfs=model_psf)
# Inverse-variance weights from the global background RMS of each band.
# They must be passed to the Observation; otherwise they are never used.
weights = np.ones_like(images) / (bg_rms[:, None, None]**2)
observation = scarlet.Observation(images, psfs=psfs, weights=weights, channels=bands).match(frame)

# SEDs from Ai-Lei's spectral decomposition, as AB magnitudes in g, r, i, z, y.
# The EELR z and y magnitudes are set to a faint 28.
eelr_mags = np.array([26.3652646789765, 23.73486734207137, 21.898052328748093, 28, 28])  # earlier value: 21.05645664757267
eelr_sed = mag2amplitude(eelr_mags)
host_mags = np.array([21.981511938811067, 20.824537185641706, 20.148344573001445, 19.814933194560677, 19.870960800812682])
host_sed = mag2amplitude(host_mags)

# Build the source list from the detection catalog:
# - entry 1 becomes two sources at the same position: a host with a fixed SED
#   direction and a free-form EELR component confined to R=32 pixels;
# - entry 5 is skipped (NOTE(review): reason not stated — presumably part of
#   the host/EELR; confirm);
# - every other entry is a standard extended source.
sources = []
for k, src in enumerate(catalog):
    center = (src['y'], src['x'])
    if k == 1:
        sources.append(EELRHostSource(frame, center, observation, bg_rms, host_sed))
        sources.append(FreeSource(frame, center, sed=eelr_sed, R=32))
    elif k != 5:
        sources.append(scarlet.ExtendedSource(frame, center, observation, bg_rms))

In [None]:
blend = scarlet.Blend(sources, observation)
%time blend.fit(200, e_rel=1e-5)
print("scarlet ran for {} iterations to MSE = {}".format(len(blend.mse), blend.mse[-1]))
plt.semilogy(blend.mse)

In [None]:
# Render the fitted model into the observed frame and form the residual
model = blend.get_model()
model_ = observation.render(model)
residual = images - model_
# RGB composites: the model uses the data's stretch, the residual the default mapping
model_rgb = scarlet.display.img_to_rgb(model_, norm=norm)
residual_rgb = scarlet.display.img_to_rgb(residual)

# Data, model and residual side by side
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax = list(axes)
for panel, rgb, title in zip(ax, (img_rgb, model_rgb, residual_rgb), ("Data", "Model", "Residual")):
    panel.imshow(rgb)
    panel.set_title(title)

# Number every source at its center in the data and model panels
for k, src in enumerate(blend.sources):
    y, x = src.pixel_center if hasattr(src, "pixel_center") else src[0].pixel_center
    for panel in ax[:2]:
        panel.text(x, y, k, color="w")
plt.show()

In [None]:
# Inspect sources 1-4 (the host, the free EELR component and the next two
# catalog sources). Using `subset` keeps the titles equal to the indices in
# `sources`; slicing would renumber them from 0.
display_sources(sources, observation, norm=norm, subset=range(1, 5))