In [None]:
%matplotlib inline
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='inferno', interpolation='none', origin='lower')

# our code for source separation
import scarlet
import scarlet.display
# code to detect sources
import sep

# to open fits files
from astropy.io import fits
from astropy.table import Table

In [None]:
# Root directory of the HSC cutouts: one sub-directory per source, named by the
# zero-padded OBJID. Set the HSC_IMAGE_DIR environment variable to point
# elsewhere instead of editing this cell; the fallback is the original author's
# local path.
import os
data_dir = os.environ.get(
    "HSC_IMAGE_DIR",
    '/home/czhao/Synced/Documents/PrincetonStuff/2019-20/spring/IW/hsc_images',
)

In [None]:
# Load the source catalog. Columns read later in this notebook: RA, DEC
# (coordinate lookup), OBJID (cutout directory name) and speclineMag_<band>
# (spectral-line magnitudes). The path is relative to the kernel's working
# directory, unlike the image cutouts under `data_dir`.
# Previously used catalog, kept for reference:
# source_list = Table.read('parent_sample/source_list.fits')
source_list = Table.read('good_ir_merged.fits')

def get_source_i(ra, dec, table=None):
    """Return the row index of the catalog entry nearest to (ra, dec) on the sky.

    Parameters
    ----------
    ra, dec : float
        Target coordinates, in the same units as the catalog's RA/DEC columns
        (assumed degrees -- the 1-arcsec warning below relies on that).
    table : table-like, optional
        Anything indexable by 'RA' and 'DEC' returning array-likes (an astropy
        Table, a dict of arrays, ...). Defaults to the global ``source_list``.

    Returns
    -------
    Index (NumPy integer) of the closest row. Prints a warning when that row
    is more than 1 arcsec from the target.
    """
    if table is None:
        table = source_list
    cat_ra = np.asarray(table['RA'], dtype=float)
    cat_dec = np.asarray(table['DEC'], dtype=float)

    # Wrap the RA difference into [-180, 180) so targets straddling RA = 0/360
    # still match, then shrink it by cos(dec) so it is a true on-sky offset.
    dra = (cat_ra - ra + 180.0) % 360.0 - 180.0
    dra = dra * np.cos(np.deg2rad(dec))
    ddec = cat_dec - dec
    offset_deg = np.hypot(dra, ddec)

    i = np.argmin(offset_deg)
    # Compare the *angular distance* in arcsec (the old code multiplied a
    # squared separation in deg^2 by 3600, which has no meaning).
    if offset_deg[i] * 3600 > 1:
        print("Closest match more than 1 arcsec away: Proceed with caution!")
    return i

# Choose which catalog row to analyse.
# Alternative: select the most prominent candidate by its coordinates:
# ra, dec=37.77, -3.75
# source_i = get_source_i(ra, dec)

source_i = 1

# OBJID doubles as the name of the per-source cutout directory (formatted as a
# zero-padded 5-digit number when globbing the cutout files).
source_id = int(source_list[source_i]["OBJID"])

In [None]:
# open files of the source: one science cutout and one PSF image per band
import glob

bands = ['G', 'R', 'I', 'Z', 'Y']
# Border (pixels) trimmed from every side of each cutout before fitting.
CUTOUT_TRIM = 40


def _first_match(pattern):
    """Return the alphabetically first file matching ``pattern``.

    Raises FileNotFoundError when nothing matches (instead of an opaque
    IndexError from ``glob(...)[0]``) and warns when the pattern is ambiguous,
    since plain glob order is arbitrary.
    """
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError(f"no file matches {pattern!r}")
    if len(matches) > 1:
        print(f"Warning: {len(matches)} files match {pattern!r}; using {matches[0]}")
    return matches[0]


images = []
psfs = []
for b in bands:
    cutout_path = _first_match(
        "{}/{:05d}/cutout_HSC-{}_*_src_*.fits".format(data_dir, source_id, b))
    # `with` closes the file even if reading fails; copy the data so the
    # array does not depend on the (possibly memory-mapped) file afterwards.
    with fits.open(cutout_path) as hdulist:
        images.append(np.array(hdulist[1].data))

    psf_path = _first_match(
        "{}/{:05d}/psf_HSC-{}_*_src_*.fits".format(data_dir, source_id, b))
    with fits.open(psf_path) as hdulist:
        psfs.append(np.array(hdulist[0].data))

images = (np.array(images)[:, CUTOUT_TRIM:-CUTOUT_TRIM, CUTOUT_TRIM:-CUTOUT_TRIM]).copy()

# Pad every PSF with zeros to a common (height, width) so they can be stacked
# into one (bands, height, width) cube.
psf_height = max(psf.shape[0] for psf in psfs)
psf_width = max(psf.shape[1] for psf in psfs)


def _pad_psf_to(psf, height, width):
    """Zero-pad a 2-D PSF to exactly (height, width), as centred as possible.

    When the size difference is odd, the extra row/column goes on the trailing
    side. A single symmetric pad width of ``diff // 2`` would come up one pixel
    short and make ``np.stack`` fail.
    """
    dh = height - psf.shape[0]
    dw = width - psf.shape[1]
    return np.pad(psf, ((dh // 2, dh - dh // 2), (dw // 2, dw - dw // 2)))


psfs = np.stack([_pad_psf_to(psf, psf_height, psf_width) for psf in psfs])
print(psfs.shape)

In [None]:
# get magnitude of line emitters
def mag2amplitude(mags):
    """Convert per-band AB magnitudes into relative SED amplitudes.

    Parameters
    ----------
    mags : array-like or masked array, length 5 (g, r, i, z, y)
        AB magnitudes. A masked array stays masked in the result.

    Returns
    -------
    Array (masked if ``mags`` is) of the same length.

    NOTE(review): the flux density is in erg s^-1 cm^-2 Hz^-1, not Jy (no 1e-23
    factor is applied). The photon factor is also *divided* by the band width
    rather than multiplied. The first only rescales the SED; the second changes
    its shape across bands -- confirm intent.
    """
    # hard-coded per-band widths for the five bands, in the order above
    band_widths = np.array([.14, .14, .16, .13, .11])
    # AB magnitude definition: m = -2.5 log10(f_nu) - 48.6
    flux_density = 10 ** ((mags + 48.6) / -2.5)
    photons_per_jy = 1.51e7 / band_widths
    return flux_density * photons_per_jy

# Spectral-line magnitudes per band for the selected source, read from the
# catalog columns speclineMag_g ... speclineMag_y (these originate from
# Ai-Lei's spectral decomposition tables).
# Earlier hard-coded values for the first candidate, kept for reference:
# mags = np.array([26.3652646789765, 23.73486734207137, 21.898052328748093, 28, 21.05645664757267])

mags = np.array([source_list[source_i][f"speclineMag_{band.lower()}"] for band in bands])
print(mags)

# last element (Y band) appears untrustworthy: 1 marks a band to be masked
# when the EELR SED template is built from these magnitudes
mask = [0,0,0,0,1]

In [None]:
# detect all sources in the image
def makeCatalog(img, thresh=1.2, deblend_nthresh=64, deblend_cont=3e-4):
    """Detect sources on the band-averaged image and measure per-band noise.

    Parameters
    ----------
    img : array, shape (bands, height, width)
        Image cube; axis 0 is iterated as bands. FITS data are big-endian;
        they are converted to native byte order here because sep rejects
        non-native arrays.
    thresh : float
        Detection threshold in units of the detection image's global
        background RMS (passed to ``sep.extract``).
    deblend_nthresh, deblend_cont :
        Deblending parameters forwarded to ``sep.extract``.

    Returns
    -------
    catalog : sep detection catalog (structured array with 'x', 'y', ...)
    segmap : segmentation map from sep
    bg_rms : ndarray, shape (bands,), global background RMS of each band
    """
    img = np.asarray(img)
    if not img.dtype.isnative:
        # astype to the native-order dtype converts values (it does not merely reinterpret bytes)
        img = img.astype(img.dtype.newbyteorder('='))
    img = np.ascontiguousarray(img)

    detect = img.mean(axis=0)
    bkg = sep.Background(detect)
    catalog, segmap = sep.extract(detect, thresh, err=bkg.globalrms,
                                  deblend_nthresh=deblend_nthresh,
                                  deblend_cont=deblend_cont,
                                  segmentation_map=True)
    bg_rms = np.array([sep.Background(band).globalrms for band in img])
    return catalog, segmap, bg_rms

catalog, segmap, bg_rms = makeCatalog(images)

# make an image and label all sources
# first define color stretch and convert 5 bands to RGB channels;
# `norm` is reused by the scene/source displays further down
stretch = 1
Q = 5
norm = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
img_rgb = scarlet.display.img_to_rgb(images, norm=norm)

plt.figure(figsize=(6,6))
plt.imshow(img_rgb)
# Mark all of the sources from the detection catalog; the label is the
# catalog row index k, the same index used later to pick the EELR host
for k, src in enumerate(catalog):
    plt.text(src["x"], src["y"], str(k), color="w")

In [None]:
# display psfs
# Render the per-band PSF cube as one RGB image. The asinh stretch's lower
# bound is the smallest PSF pixel value across all bands.
psfs_min = min(psf.min() for psf in psfs)
pnorm = scarlet.display.AsinhMapping(minimum=psfs_min, stretch=1e-2, Q=1)
prgb = scarlet.display.img_to_rgb(psfs, norm=pnorm)
plt.figure()
plt.imshow(prgb)
plt.show()

In [None]:
# define Frame and Observation:
from functools import partial
# Model PSF: a narrow circular Gaussian. Sources are modelled at this
# resolution and the observation is matched to the measured per-band PSFs.
model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8), shape=(None, 8, 8))
# scarlet channel names. Derived from `bands` instead of rebinding `bands`
# itself, so re-running the file-loading cell (which globs on the upper-case
# band names) still works.
channels = [b.lower() for b in bands]

frame = scarlet.Frame(images.shape, psfs=model_psf, channels=channels)

# no weight maps, use flat background noise variance instead:
# inverse-variance weights from the per-band background RMS measured by sep.
# They must be handed to the Observation -- previously they were computed
# and then silently ignored.
weights = np.ones_like(images) / (bg_rms[:,None,None]**2)
observation = scarlet.Observation(images, psfs=psfs, weights=weights, channels=channels).match(frame)

In [None]:
class SEDConstraint(scarlet.Constraint):
    """Project an SED onto a fixed template direction, keeping it non-negative.

    Only the direction of the template is enforced; the overall flux scale
    stays free. If the template is a masked array with masked entries, those
    entries of the SED are left untouched and the projection uses the
    unmasked ones only.
    """

    def __init__(self, sed):
        self._sed = sed

    def __call__(self, X, step):
        template = self._sed
        if np.ma.is_masked(template):
            keep = ~template.mask
            x_kept = X[keep]
            t_kept = template[keep]
            scale = np.dot(x_kept, t_kept) / np.dot(t_kept, t_kept)
            X[keep] = np.maximum(scale * t_kept, 0)
            return X
        # closest vector to X along the template direction, clipped at zero
        scale = np.dot(X, template) / np.dot(template, template)
        X[:] = np.maximum(scale * template, 0)
        return X
    
class RadialMaskConstraint(scarlet.Constraint):
    """Zero a 2-D morphology outside a circle and clip it to be non-negative.

    Parameters
    ----------
    shape : (channels, ny, nx)
        Frame shape; only the spatial dimensions are used.
    pixel_center : sequence
        Circle centre as (row, column) pixel indices.
    R : float
        Circle radius in pixels; pixels strictly farther away are zeroed.
    """

    def __init__(self, shape, pixel_center, R):
        _, ny, nx = shape
        row_offsets = (np.arange(ny) - pixel_center[0])[:, None]
        col_offsets = (np.arange(nx) - pixel_center[1])[None, :]
        # True marks pixels outside the allowed circle
        self.mask = row_offsets**2 + col_offsets**2 > R**2

    def __call__(self, X, step):
        X[self.mask] = 0
        X[:, :] = np.maximum(X, 0)
        return X
        
    
class EELRSource(scarlet.RandomSource):
    """Source describing the extended emission-line region (EELR).

    Its morphology is free-form, optionally restricted to pixels within ``R``
    of the centre. Its SED, when a template is given, is fixed up to an
    overall multiplicative constant by ``SEDConstraint``.

    Parameters
    ----------
    frame : scarlet.Frame
        Model frame the source lives in.
    sky_coord :
        Position passed to ``frame.get_pixel``. In this notebook it is a
        (y, x) pixel pair from the sep catalog.
    sed : array-like or masked array, optional
        Template SED. Masked entries are left unconstrained by SEDConstraint.
    R : float, optional
        Radius in pixels of the allowed morphology footprint.
    """
    def __init__(self, frame, sky_coord, sed=None, R=None):
        super().__init__(frame)

        center = np.array(frame.get_pixel(sky_coord), dtype="float")
        # nearest integer pixel to the requested centre
        self.pixel_center = tuple(np.round(center).astype("int"))

        if sed is not None:
            # parameter 0 is treated as the SED here (scarlet-internal
            # `_parameters` layout -- TODO confirm against the scarlet version in use)
            self._parameters[0].constraint = SEDConstraint(sed)
            # apply the constraint once so the initial SED already satisfies it
            self._parameters[0][:] = self._parameters[0].constraint(self._parameters[0], 0)
        if R is not None:
            # parameter 1 is treated as the 2-D morphology (indexed with [:,:] below)
            self._parameters[1].constraint = RadialMaskConstraint(frame.shape, self.pixel_center, R)
            self._parameters[1][:,:] = self._parameters[1].constraint(self._parameters[1], 0)

In [None]:
def get_center_source(catalog, shape=None):
    """Return the index of the detection closest to the image centre.

    Parameters
    ----------
    catalog : iterable of records with 'x' and 'y' fields
        E.g. the sep catalog. Positions are in 0-indexed pixel coordinates
        (pixel centres at integer positions).
    shape : tuple, optional
        Image shape whose last two entries are (height, width). Defaults to
        the global ``images.shape``.

    Returns
    -------
    int
        Index of the closest entry (the first one on ties), or -1 if
        ``catalog`` is empty.
    """
    if shape is None:
        shape = images.shape
    height, width = shape[-2], shape[-1]
    # With pixel centres at integer coordinates, the geometric centre of an
    # n-pixel axis is at (n - 1) / 2; n / 2 is half a pixel off.
    center_y = (height - 1) / 2
    center_x = (width - 1) / 2

    ys = np.array([src['y'] for src in catalog], dtype=float)
    xs = np.array([src['x'] for src in catalog], dtype=float)
    if ys.size == 0:
        return -1
    dist2 = (ys - center_y) ** 2 + (xs - center_x) ** 2
    # argmin returns the first minimum, matching a strict '<' scan
    return int(np.argmin(dist2))

In [None]:
eelr_host_ind = get_center_source(catalog)
print(eelr_host_ind)
print(catalog[eelr_host_ind]['x'], catalog[eelr_host_ind]['y'])

# EELR SED template from the spectral-line magnitudes, with `mask` applied so
# masked bands are left free by the SED constraint. Uses its own name so the
# global `mags`, which later cells read again, is not rebound to a masked array.
eelr_mags = np.ma.masked_array(mags, mask=mask)
eelr_sed = mag2amplitude(eelr_mags)

sources = []
for k, src in enumerate(catalog):
    center = (src['y'], src['x'])
    if k == eelr_host_ind:
        # host galaxy followed immediately by its EELR component
        sources.append(scarlet.MultiComponentSource(frame, center, observation, thresh=0.2, shifting=True))
        sources.append(EELRSource(frame, center, sed=eelr_sed, R=None))
    else:
        sources.append(scarlet.ExtendedSource(frame, center, observation, shifting=True, thresh=0.5))

blend = scarlet.Blend(sources, observation)

In [None]:
# run the fitter
%time blend.fit(200, e_rel=1e-5)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(-np.array(blend.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
# Whole-scene check: observed data, rendered model and residual, using the same RGB stretch as above.
scarlet.display.show_scene(sources, observation=observation, norm=norm, show_observed=True, show_rendered=True, show_residual=True)

In [None]:
# Per-source view of every fitted component (host, EELR and neighbours).
scarlet.display.show_sources(sources, observation, show_observed=True, show_rendered=True, norm=norm)

# Resampling

In [None]:
# Filename prefix for the figures saved below: the zero-padded source id plus
# an underscore, e.g. "<id>_model_snr.png".
path_base = f"{source_id:05d}_"

print(path_base)

In [None]:
def do_one_eelr_sample(catalog, host_ind, mags, mask, max_iter=200, e_rel=1e-5):
    """Build a fresh blend (host + EELR + other detections), fit it, and return the EELR.

    Every call recreates all sources. EELRSource derives from
    scarlet.RandomSource, so its initial state presumably differs per call;
    repeated calls are used in this notebook as independent samples.

    Parameters
    ----------
    catalog : sep detection catalog (records with 'x' and 'y')
    host_ind : int
        Catalog index of the EELR host galaxy.
    mags : array-like, length 5
        Spectral-line magnitudes used to build the EELR SED template.
    mask : sequence of 0/1, length 5
        1 marks bands whose SED entry is left unconstrained.
    max_iter, e_rel :
        Forwarded to ``blend.fit``.

    Returns
    -------
    (EELRSource, float)
        The fitted EELR source and the final log-likelihood (-blend.loss[-1]).

    Raises
    ------
    ValueError
        If ``host_ind`` is not a valid index into ``catalog``. Previously an
        invalid index (e.g. -1 for an empty detection) silently returned the
        wrong source.
    """
    if not 0 <= host_ind < len(catalog):
        raise ValueError(
            f"host_ind={host_ind} is not a valid index into a catalog of {len(catalog)} sources")

    eelr_sed = mag2amplitude(np.ma.masked_array(mags, mask=mask))

    sources = []
    eelr_index = None
    for k, src in enumerate(catalog):
        center = (src['y'], src['x'])
        if k == host_ind:
            sources.append(scarlet.MultiComponentSource(frame, center, observation, thresh=0.2, shifting=True))
            # remember where the EELR lands instead of assuming host_ind + 1
            eelr_index = len(sources)
            sources.append(EELRSource(frame, center, sed=eelr_sed, R=None))
        else:
            sources.append(scarlet.ExtendedSource(frame, center, observation, shifting=True, thresh=0.5))

    blend = scarlet.Blend(sources, observation)
    blend.fit(max_iter, e_rel=e_rel)
    print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))

    # EELR source and logL
    return sources[eelr_index], -blend.loss[-1]

In [None]:
# Refit the full blend N_EELR_SAMPLES times from freshly built sources and keep
# only the fitted EELR component of each run (the returned logL is unused here).
# The comprehension keeps its loop variable local, so IPython's `_`
# (last output) is not overwritten.
N_EELR_SAMPLES = 5
eelr_samples = [do_one_eelr_sample(catalog, eelr_host_ind, mags, mask)[0]
                for _ in range(N_EELR_SAMPLES)]

In [None]:
# Compare the EELR component from each independent refit side by side.
scarlet.display.show_sources(eelr_samples, observation, show_observed=True, show_rendered=True, norm=norm)

In [None]:
def compute_snr(X):
    """Per-pixel signal-to-noise ratio across a stack of samples.

    SNR is defined as mean(X**2) / var(X) along axis 0, the sample axis.

    Parameters
    ----------
    X : ndarray, shape (n_samples, ...)

    Returns
    -------
    ndarray with the shape of ``X[0]``. Special cases:
      * 0 where every sample is zero (no signal);
      * where all samples are identical and non-zero, the variance vanishes
        and the ratio is infinite. Those pixels are set to the largest finite
        SNR in the map so colorbars and band averages stay usable. (The old
        ``np.nan_to_num`` call turned them into ~1.8e308, which overflows
        when averaged.)
    """
    sig = (X**2).mean(axis=0)
    noise = X.var(axis=0)
    with np.errstate(divide="ignore", invalid="ignore"):
        snr = sig / noise
    snr[sig == 0] = 0
    finite = np.isfinite(snr)
    snr[~finite] = snr[finite].max() if finite.any() else 0
    return snr

In [None]:
# Create heatmap of average (since multiple bands) SNR of each EELR pixel
# Stack the rendered EELR model of every resampled fit (axis 0 = sample),
# compute the SNR per band and pixel, then average over bands.
# Assumes get_model() returns a (bands, H, W) cube -- TODO confirm.

eelr_models = np.stack([sample.get_model() for sample in eelr_samples])

snr_per_band = compute_snr(eelr_models)
snr = snr_per_band.mean(axis=0)

plt.imshow(snr)
plt.colorbar()
plt.savefig(f"{path_base}model_snr.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Create heatmap of average morphology
# Pixel-wise mean of the EELR morphology over all resampled fits.

eelr_morphs = np.stack([sample.morph for sample in eelr_samples])
mean_morph = eelr_morphs.mean(axis=0)

plt.imshow(mean_morph)
plt.colorbar()
plt.savefig(f"{path_base}morph_avg.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Create heatmap of morphology SNR
# Same mean(X^2)/var(X) statistic as above, applied to the morphology stack.

morph_snr = compute_snr(eelr_morphs)

plt.imshow(morph_snr)
plt.colorbar()
plt.savefig(f"{path_base}morph_snr.png", dpi=200, bbox_inches="tight")
plt.show()

# Unmasked Y band

In [None]:
# Scan the untrusted Y-band line magnitude over a grid. For each value, fit
# `samples_per_y` blends with no band masked, so the values can later be
# weighted by their likelihood.
unmasked_y_samples = []
unmasked_y_logL = []
unmasked_y_values = []  # Y magnitude used for each entry of the two lists above
samples_per_y = 10
y_grid = range(22, 27+1)

for y in y_grid:
    print(f"y = {y}:")
    # Fresh float copy: `this_mags = mags` aliased the global array, so the
    # assignment below overwrote the catalog magnitudes (and any mask) in place.
    this_mags = np.array(mags, dtype=float)
    this_mags[4] = y
    for i in range(samples_per_y):
        sample, logL = do_one_eelr_sample(catalog, eelr_host_ind, this_mags, [0, 0, 0, 0, 0])
        unmasked_y_samples.append(sample)
        unmasked_y_logL.append(logL)
        unmasked_y_values.append(y)

In [None]:
# Samples are stored Y-value by Y-value in runs of `samples_per_y`, so a
# stride picks the first fit for each Y magnitude.
subsample = unmasked_y_samples[::samples_per_y]
scarlet.display.show_sources(subsample, observation, show_observed=True, show_rendered=True, norm=norm)

## Bayesian Inference

In [None]:
def weighted_mean_and_var(samples, weights):
    """Weighted pixel-wise mean and variance of a set of equally shaped arrays.

    Parameters
    ----------
    samples : sequence of arrays, all the same shape
        Stacked along a new leading axis.
    weights : 1-D array-like, one non-negative weight per sample
        Only relative values matter; they are normalised by their sum.

    Returns
    -------
    (mean, var) : arrays with the shape of one sample
        ``var`` is the weighted population variance
        sum_i w_i (x_i - mean)^2 / sum_i w_i.

    Works for samples of any dimensionality and for list weights. The previous
    ``np.moveaxis(..., 0, 2)`` form only handled 2-D samples and needed an
    ndarray for ``weights.sum()``.
    """
    stacked = np.stack(samples)
    weights = np.asarray(weights, dtype=float)
    mean = np.average(stacked, axis=0, weights=weights)
    var = np.average((stacked - mean) ** 2, axis=0, weights=weights)
    return mean, var

In [None]:
morphs = [sample.morph for sample in unmasked_y_samples]
# Likelihood weights proportional to exp(logL). Subtract the *maximum*
# before exponentiating: weighted_mean_and_var normalises by the weight sum,
# so any constant offset cancels. With the max subtracted every weight lies in
# (0, 1]. Subtracting the minimum made them >= 1 and overflowed to inf once
# the logL spread exceeded ~709, giving NaN results.
logL = np.asarray(unmasked_y_logL, dtype=float)
scaled_likelihoods = np.exp(logL - logL.max())
bayes_mean, bayes_var = weighted_mean_and_var(morphs, scaled_likelihoods)

plt.imshow(bayes_mean)
plt.colorbar()
plt.savefig(f"{path_base}morph_bayes_mean.png", dpi=200, bbox_inches="tight")
plt.show()

plt.imshow(bayes_var)
plt.colorbar()
plt.savefig(f"{path_base}morph_bayes_var.png", dpi=200, bbox_inches="tight")
plt.show()