In [None]:
%matplotlib inline

from functools import partial

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='inferno', interpolation='none', origin='lower')

# our code for source separation
import scarlet
import scarlet.display
# code to detect sources
import sep

# to open fits files
from astropy.io import fits
from astropy.table import Table

In [None]:
# Root directory of the HSC cutouts: one sub-directory per zero-padded OBJID.
# Set the HSC_IMAGES_DIR environment variable to run on another machine;
# otherwise the original local path is used.
import os

data_dir = os.environ.get(
    'HSC_IMAGES_DIR',
    '/home/czhao/Synced/Documents/PrincetonStuff/2019-20/spring/IW/hsc_images',
)

In [None]:
# Parent sample: one row per EELR candidate. Later cells read its RA/DEC,
# OBJID and MAG_AB_LINEONLY columns.
# Earlier inputs: 'parent_sample/source_list.fits', 'good_ir_merged.fits'.
sample_table_path = 'parent_sample/EELR_HSCmag_from_SDSSspec.fits'
source_list = Table.read(sample_table_path)

In [None]:
# find directory for source at RA/DEC

def get_source_i(ra, dec):
    d = (source_list['RA'] - ra)**2 + (source_list['DEC'] - dec)**2
    i = np.argmin(d)
    if d[i] * 3600 > 1:
        print("Closest match more than 1 arcsec away: Proceed with caution!")
    return i

# Choose the source to analyse by table row. To select by position instead,
# e.g. the most prominent candidate near (RA, Dec) = (37.77, -3.75):
#     source_i = get_source_i(37.77, -3.75)
source_i = 2
source_id = int(source_list[source_i]["OBJID"])

In [None]:
import glob

def get_images_and_psfs(source_id, crop=40, directory=None):
    """Load the five HSC band cutouts (G, R, I, Z, Y) and their PSFs for one source.

    Files are located with the glob patterns
    ``{directory}/{source_id:05d}/cutout_HSC-{BAND}_*_src_*.fits`` (image read from HDU 1)
    and ``{directory}/{source_id:05d}/psf_HSC-{BAND}_*_src_*.fits`` (PSF read from HDU 0).
    If several files match, the alphabetically first one is used.

    Parameters
    ----------
    source_id : int
        OBJID of the source; names its zero-padded sub-directory.
    crop : int, optional
        Pixels trimmed from every edge of each image (default 40; 0 keeps the
        whole cutout).
    directory : str, optional
        Root data directory; defaults to the notebook-level ``data_dir``.

    Returns
    -------
    images : ndarray, shape (5, H - 2*crop, W - 2*crop)
    psfs : ndarray, shape (5, h, w)
        PSFs zero-padded around their centre to the largest PSF height/width.

    Raises
    ------
    FileNotFoundError
        If no file matches the image or PSF pattern for some band.
    """
    if directory is None:
        directory = data_dir
    bands = ['G', 'R', 'I', 'Z', 'Y']

    def _read_hdu(kind, band, hdu):
        # Copy the HDU data inside `with` so nothing depends on the file staying
        # open, and the handle is closed even if reading fails.
        pattern = "{}/{:05d}/{}_HSC-{}_*_src_*.fits".format(directory, source_id, kind, band)
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError("No file matches {}".format(pattern))
        with fits.open(matches[0]) as hdulist:
            return np.array(hdulist[hdu].data)

    images = [_read_hdu("cutout", b, 1) for b in bands]
    psfs = [_read_hdu("psf", b, 0) for b in bands]

    edge = slice(crop, -crop) if crop > 0 else slice(None)
    images = np.array(images)[:, edge, edge].copy()

    psf_height = max(psf.shape[0] for psf in psfs)
    psf_width = max(psf.shape[1] for psf in psfs)

    def _pad_to_common_shape(psf):
        # Split the padding before/after; an odd leftover pixel goes after, so
        # every PSF ends up exactly (psf_height, psf_width) and np.stack works.
        dy = psf_height - psf.shape[0]
        dx = psf_width - psf.shape[1]
        return np.pad(psf, ((dy // 2, dy - dy // 2), (dx // 2, dx - dx // 2)))

    psfs = np.stack([_pad_to_common_shape(psf) for psf in psfs])
    return images, psfs

In [None]:
# Cropped five-band cutouts and matching (padded) PSFs for the selected source.
images, psfs = get_images_and_psfs(source_id=source_id)

In [None]:
# get magnitude of line emitters
def mag2amplitude(mags):
    """Turn per-band AB magnitudes into relative per-band amplitudes (an SED).

    ``mags`` must broadcast against the five-element width array below (one
    entry per band). A masked-array input gives a masked result.

    The AB definition gives the flux density in cgs units,
    f_nu = 10**(-(m + 48.6) / 2.5) erg s^-1 cm^-2 Hz^-1. Despite the old
    variable name, this is not in Jy. 1.51e7 is presumably the usual
    "1 Jy ~ 1.51e7 photons s^-1 m^-2 (dlambda/lambda)^-1" factor, and the widths
    are presumably filter widths in microns. TODO: confirm both, and whether
    dividing (rather than multiplying) by the width is intended.
    The absolute scale is therefore not physical; only band-to-band ratios carry
    meaning.
    """
    band_widths = np.array([.14, .14, .16, .13, .11])
    fnu_cgs = 10 ** ((48.6 + mags) / -2.5)
    photons_per_jy = 1.51e7 / band_widths
    return photons_per_jy * fnu_cgs

In [None]:
# Line-only AB magnitudes for this source (five bands, presumably g, r, i, z, y,
# matching mag2amplitude). Earlier versions hard-coded values taken from
# Ai-Lei's spectral decomposition tables. A NaN (no measurement) becomes 30 mag,
# i.e. negligible line flux.
mags = np.nan_to_num(source_list[source_i]['MAG_AB_LINEONLY'], nan=30.0)
print(mags)

# SED mask: 1 marks a band left out of the EELR SED constraint.
# The last element (Y band) appears untrustworthy, so only it is masked.
band_mask = [0, 0, 0, 0, 1]

In [None]:
# detect all sources in the image
def makeCatalog(img):
    detect = img.mean(axis=0)
    bkg = sep.Background(detect)
    #catalog = sep.extract(detect, 1.5, err=bkg.globalrms,deblend_nthresh=64,deblend_cont=3e-4)
    catalog, segmap = sep.extract(detect, 1.2, err=bkg.globalrms,deblend_nthresh=64,deblend_cont=3e-4, segmentation_map=True)
    bg_rms = np.array([sep.Background(band).globalrms for band in img])
    return catalog, segmap, bg_rms

def display_img(images, norm, catalog):
    """Show ``images`` as RGB (via ``norm``) with each catalog entry labelled by its index."""
    rgb = scarlet.display.img_to_rgb(images, norm=norm)
    _fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(rgb)
    # Write each detection's catalog index in white at its (x, y) position.
    for idx, obj in enumerate(catalog):
        ax.text(obj["x"], obj["y"], str(idx), color="w")
    plt.show()

In [None]:
catalog, segmap, bg_rms = makeCatalog(images)

# Asinh colour stretch (floor 0) used to map the five bands onto RGB.
stretch, Q = 1, 5
norm = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
display_img(images, norm, catalog)

In [None]:
# Visual check of the padded PSFs: RGB rendering with a gentle asinh stretch
# whose floor is the smallest PSF value.
psfs_min = min(map(np.min, psfs))
pnorm = scarlet.display.AsinhMapping(minimum=psfs_min, stretch=1e-2, Q=1)
prgb = scarlet.display.img_to_rgb(psfs, norm=pnorm)
psf_fig, psf_ax = plt.subplots()
psf_ax.imshow(prgb)
plt.show()

In [None]:
# Model frame over all five bands, rendered with a narrow Gaussian model PSF
# (sigma=.8); the observation is then matched to this frame.
bands = ['g', 'r', 'i', 'z', 'y']
model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8), shape=(None, 8, 8))
frame = scarlet.Frame(images.shape, psfs=model_psf, channels=bands)

# No weights are passed to the Observation. Disabled alternative (flat per-band
# background variance):
#     weights = np.ones_like(images) / (bg_rms[:, None, None]**2)
observation = scarlet.Observation(images, psfs=psfs, channels=bands)
observation = observation.match(frame)

In [None]:
class SEDConstraint(scarlet.Constraint):
    """Constrain a source SED to a fixed direction, leaving its amplitude free.

    ``sed`` is the target direction S (plain or masked array). On each call:
    - Channels where S is not masked are replaced by the non-negative part of the
      projection of X onto S, max(((X . S) / (S . S)) S, 0).
    - Masked channels are not tied to S. They are only clipped to be
      non-negative, so free bands obey the same positivity as constrained ones.
    - If S is zero on every tied channel, those channels are set to 0 rather
      than NaN.

    X is modified in place and also returned.
    """

    def __init__(self, sed):
        self._sed = sed

    def __call__(self, X, step):
        S = self._sed
        free = np.ma.getmaskarray(S)  # True for channels excluded from the constraint
        tied = ~free
        S_tied = np.ma.getdata(S)[tied]
        norm2 = np.dot(S_tied, S_tied)
        if norm2 > 0:
            # Closest vector along S: only the direction is constrained, which
            # allows flux rescaling.
            X[tied] = np.maximum(np.dot(X[tied], S_tied) / norm2 * S_tied, 0)
        else:
            # A zero direction only admits the zero vector (avoids 0/0 -> NaN).
            X[tied] = 0
        X[free] = np.maximum(X[free], 0)
        return X
    
class RadialMaskConstraint(scarlet.Constraint):
    """Morphology constraint: zero every pixel farther than ``R`` from a centre,
    then clip all pixels to be non-negative.

    ``shape`` is unpacked as (channels, height, width) and ``pixel_center`` is a
    (row, col) pair. The constraint modifies X in place and returns it.
    """

    def __init__(self, shape, pixel_center, R):
        _, height, width = shape
        rows, cols = np.ogrid[:height, :width]
        r2 = (rows - pixel_center[0]) ** 2 + (cols - pixel_center[1]) ** 2
        # True outside the disc of radius R.
        self.mask = r2 > R ** 2

    def __call__(self, X, step):
        outside = self.mask
        X[outside] = 0
        clipped = np.maximum(X, 0)
        X[:, :] = clipped
        return X
        
    
class EELRSource(scarlet.RandomSource):
    """Free-form source describing the extended emission-line region (EELR).

    - When ``sed`` is given, the SED parameter is constrained to that direction
      by SEDConstraint, so only its overall amplitude is free.
    - When ``R`` is given, the morphology is zeroed beyond R pixels of the
      rounded centre and clipped to be non-negative.

    Both constraints are applied once at construction.
    """

    def __init__(self, frame, sky_coord, sed=None, R=None):
        super().__init__(frame)

        pixel = np.array(frame.get_pixel(sky_coord), dtype="float")
        self.pixel_center = tuple(np.round(pixel).astype("int"))

        if sed is not None:
            sed_param = self._parameters[0]
            sed_param.constraint = SEDConstraint(sed)
            sed_param[:] = sed_param.constraint(sed_param, 0)
        if R is not None:
            morph_param = self._parameters[1]
            morph_param.constraint = RadialMaskConstraint(frame.shape, self.pixel_center, R)
            morph_param[:, :] = morph_param.constraint(morph_param, 0)

In [None]:
def get_center_source(catalog, dim):
    """Index of the catalog entry closest to the centre of an image of shape ``dim``.

    Parameters
    ----------
    catalog : iterable of records with 'x' and 'y' pixel coordinates
        For example a sep catalog (structured array) or a list of dicts.
    dim : (height, width)

    Returns
    -------
    int
        Index of the nearest entry (the first one on ties), or -1 if the catalog
        is empty.
    """
    if len(catalog) == 0:
        return -1
    # sep reports zero-indexed pixel-centre coordinates, so the geometric centre
    # of an N-pixel axis is (N - 1) / 2, the same convention as the angle
    # analysis further down the notebook.
    center_y = (dim[0] - 1) / 2
    center_x = (dim[1] - 1) / 2
    distsq = [(src['y'] - center_y) ** 2 + (src['x'] - center_x) ** 2 for src in catalog]
    return int(np.argmin(distsq))

def create_sources(catalog, eelr_host_ind, mags, mask, frame, observation):
    """Build one scarlet source per catalog entry, plus an EELR source for the host.

    - The host ``catalog[eelr_host_ind]`` becomes a ``MultiComponentSource``,
      immediately followed by an ``EELRSource`` at the same position. The EELR
      SED direction comes from ``mags`` via ``mag2amplitude``; bands flagged in
      ``mask`` are left free.
    - Every other entry becomes an ``ExtendedSource``.

    When the host index matches an entry, the EELR source therefore sits at
    index ``eelr_host_ind + 1`` of the returned list.
    """
    sources = []
    for idx, obj in enumerate(catalog):
        position = (obj['y'], obj['x'])
        if idx != eelr_host_ind:
            sources.append(scarlet.ExtendedSource(frame, position, observation, shifting=True, thresh=0.5))
            continue
        sources.append(scarlet.MultiComponentSource(frame, position, observation, thresh=0.2, shifting=True))
        # SED direction of the EELR from its line-only magnitudes
        eelr_sed = mag2amplitude(np.ma.masked_array(mags, mask=mask))
        sources.append(EELRSource(frame, position, sed=eelr_sed, R=None))
    return sources

In [None]:
# The EELR host is the detection nearest the image centre.
eelr_host_ind = get_center_source(catalog, images.shape[1:])
print(eelr_host_ind)
host_entry = catalog[eelr_host_ind]
print(host_entry['x'], host_entry['y'])

sources = create_sources(catalog, eelr_host_ind, mags, band_mask, frame, observation)
blend = scarlet.Blend(sources, observation)

In [None]:
# run the fitter
%time blend.fit(200, e_rel=1e-5)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(-np.array(blend.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
# Whole scene: observed data, rendered model and residual.
scarlet.display.show_scene(
    sources,
    observation=observation,
    norm=norm,
    show_observed=True,
    show_rendered=True,
    show_residual=True,
)

In [None]:
# Each fitted source individually, next to the observation.
scarlet.display.show_sources(
    sources,
    observation,
    norm=norm,
    show_observed=True,
    show_rendered=True,
)

# Resampling: repeated fits of the EELR model

In [None]:
# Filename prefix for saved figures: the zero-padded OBJID plus an underscore.
path_base = "{:05d}_".format(source_id)
print(path_base)

In [None]:
def plot_asinh_stretch(img, stretch=0.05, Q=5):
    """Draw ``img`` with plt.imshow through scarlet's asinh RGB mapping (floor 0).

    ``stretch`` and ``Q`` are the AsinhMapping parameters; the defaults reproduce
    the previously fixed values (0.05 and 5).
    """
    norm = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
    plt.imshow(scarlet.display.img_to_rgb(img, norm=norm))

In [None]:
def do_one_eelr_sample(catalog, eelr_host_ind, mags, mask, frame, observation,
                       max_iter=200, e_rel=1e-5):
    """Fit one freshly built blend and return (host source, EELR source, final logL).

    Sources are rebuilt with ``create_sources`` on every call. The EELR source
    derives from scarlet's ``RandomSource``, so repeated calls presumably start
    from different random initialisations (TODO confirm).

    Parameters
    ----------
    catalog, eelr_host_ind, mags, mask, frame, observation
        Passed to ``create_sources``.
    max_iter, e_rel : optional
        Passed to ``Blend.fit`` (defaults 200 and 1e-5).

    Returns
    -------
    (host_source, eelr_source, logL)
        ``create_sources`` puts the EELR source right after the host, hence
        index ``eelr_host_ind + 1``. logL is minus the final loss.

    Raises
    ------
    ValueError
        If ``eelr_host_ind`` is not a valid index into ``catalog`` (e.g. the -1
        returned by ``get_center_source`` for an empty catalog). With no host,
        no EELR source is created and the returned sources would be wrong.
    """
    if not 0 <= eelr_host_ind < len(catalog):
        raise ValueError(
            f"eelr_host_ind={eelr_host_ind} is not a valid index into a catalog of {len(catalog)} sources")
    sources = create_sources(catalog, eelr_host_ind, mags, mask, frame, observation)
    blend = scarlet.Blend(sources, observation)
    blend.fit(max_iter, e_rel=e_rel)
    return sources[eelr_host_ind], sources[eelr_host_ind + 1], -blend.loss[-1]

In [None]:
# Two independent fits; keep only the EELR source from each.
eelr_samples = [
    do_one_eelr_sample(catalog, eelr_host_ind, mags, band_mask, frame, observation)[1]
    for _ in range(2)
]

In [None]:
# Compare the EELR components from the independent fits.
scarlet.display.show_sources(
    eelr_samples,
    observation,
    norm=norm,
    show_observed=True,
    show_rendered=True,
)

In [None]:
def compute_snr(X):
    """Per-element signal-to-noise of a stack of samples along axis 0.

    Signal is the mean of X**2 over axis 0 and noise is the population variance
    over axis 0, so the result has shape ``X.shape[1:]``.

    - Elements with zero signal get SNR 0.
    - Elements whose SNR is infinite (zero variance) or NaN (NaNs in X) get the
      largest finite SNR in the result, so they stand out without wrecking a
      colour scale. ``np.nan_to_num`` would instead turn +inf into ~1.8e308.
      If no finite value exists, the dtype's maximum is used.
    """
    sig = (X ** 2).mean(axis=0)
    noise = X.var(axis=0)
    with np.errstate(divide='ignore', invalid='ignore'):
        snr = sig / noise
    snr[sig == 0] = 0
    bad = ~np.isfinite(snr)
    if bad.any():
        finite = snr[~bad]
        snr[bad] = finite.max() if finite.size else np.finfo(snr.dtype).max
    return snr

In [None]:
# Per-pixel SNR of the rendered EELR model across the samples: computed per band,
# then averaged over bands.
eelr_models = np.stack([sample.get_model() for sample in eelr_samples])
snr_per_band = compute_snr(eelr_models)
snr = np.mean(snr_per_band, axis=0)

model_snr_path = path_base + "model_snr.png"
plt.imshow(snr)
plt.colorbar()
plt.savefig(model_snr_path, dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Pixel-wise mean of the EELR morphology over the samples.
eelr_morphs = np.stack([sample.morph for sample in eelr_samples])
mean_morph = np.mean(eelr_morphs, axis=0)

morph_avg_path = path_base + "morph_avg.png"
plt.imshow(mean_morph)
plt.colorbar()
plt.savefig(morph_avg_path, dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Pixel-wise SNR of the EELR morphology across the samples.
morph_snr = compute_snr(eelr_morphs)

morph_snr_path = path_base + "morph_snr.png"
plt.imshow(morph_snr)
plt.colorbar()
plt.savefig(morph_snr_path, dpi=200, bbox_inches="tight")
plt.show()

# Unmasked Y band: scanning the assumed Y-band line magnitude

In [None]:
def get_unmasked_y_samples(catalog, eelr_host_ind, mags, samples_per_y, frame, observation,
                           y_mags=range(22, 28)):
    """Fit the blend repeatedly while scanning the assumed Y-band line magnitude.

    For each value in ``y_mags`` (default 22..27), the Y entry of the magnitudes
    (index 4) is set to that value. Then ``samples_per_y`` fits are run with no
    band masked. ``mags`` itself is not modified; a copy is edited.

    Returns
    -------
    (host_samples, eelr_samples, logLs)
        Three parallel lists ordered by Y magnitude, then by sample.
    """
    host_samples, eelr_samples, logLs = [], [], []
    # Work on a copy so the caller's `mags` keeps its original Y value.
    this_mags = np.array(mags)
    for y in y_mags:
        print(f"y = {y}:")
        this_mags[4] = y
        for _ in range(samples_per_y):
            host_sample, eelr_sample, logL = do_one_eelr_sample(
                catalog, eelr_host_ind, this_mags, [0, 0, 0, 0, 0], frame, observation)
            host_samples.append(host_sample)
            eelr_samples.append(eelr_sample)
            logLs.append(logL)
    return host_samples, eelr_samples, logLs

In [None]:
# Scan the Y-band line magnitude, running `samples_per_y` fits per value.
samples_per_y = 5
(unmasked_y_host_samples,
 unmasked_y_eelr_samples,
 unmasked_y_logL) = get_unmasked_y_samples(
    catalog, eelr_host_ind, mags, samples_per_y, frame, observation)

In [None]:
# The first fit for each scanned Y magnitude.
subsample = unmasked_y_eelr_samples[::samples_per_y]
scarlet.display.show_sources(subsample, observation, norm=norm,
                             show_observed=True, show_rendered=True)

## Bayesian Inference

In [None]:
def weighted_mean_and_var(samples, weights):
    """Weighted ("Bayes") mean and variance of ``samples`` over their first axis.

    Parameters
    ----------
    samples : array_like, shape (N, ...)
        Any number of trailing dimensions.
    weights : array_like, shape (N,)
        Non-negative weights, not all zero; only their ratios matter.

    Returns
    -------
    mean, var : ndarray, shape ``samples.shape[1:]``
        var = sum_i w_i (x_i - mean)**2 / sum_i w_i (population form).
    """
    samples = np.asarray(samples)
    weights = np.asarray(weights)
    mean = np.average(samples, axis=0, weights=weights)
    var = np.average((samples - mean) ** 2, axis=0, weights=weights)
    return mean, var

def get_outliers(arr, thresh=5):
    """Boolean mask of values more than ``thresh`` below the median, clipped iteratively.

    The median is recomputed from the values not yet flagged, until the mask
    stops growing. A cluster of very low values therefore cannot drag the median
    down and hide moderately low outliers. Removing values below the median can
    only raise it, so the mask grows monotonically and the loop terminates.
    For natural-log likelihoods, thresh=5 flags samples whose likelihood is more
    than e**5 (~150) times below the median one.

    Returns
    -------
    ndarray of bool, same shape as ``arr``
    """
    arr = np.asarray(arr)
    mask = np.zeros(arr.shape, dtype=bool)
    while True:
        med = np.median(arr[~mask])
        new_mask = arr < med - thresh
        if np.count_nonzero(new_mask) <= np.count_nonzero(mask):
            return new_mask
        mask = new_mask

def zero_borders(X):
    """Zero the outer border of ``X`` in place (first/last row and column for 2-D input)."""
    for edge in (np.s_[0, :], np.s_[-1, :], np.s_[:, 0], np.s_[:, -1]):
        X[edge] = 0

In [None]:
# Likelihood-weighted ("Bayes") mean and variance of the host and EELR
# morphologies over the Y-band scan.
# Host morphology uses component 1 of the host MultiComponentSource
# (alternative: the whole host via `sample.morph`).
host_morphs = np.array([sample.components[1].morph for sample in unmasked_y_host_samples])
eelr_morphs = np.array([sample.morph for sample in unmasked_y_eelr_samples])
unmasked_y_logL = np.array(unmasked_y_logL)

# Drop samples whose log-likelihood lies more than 5 below the median (see
# get_outliers). That is a likelihood ratio above e**5 ~ 150 if these are
# natural-log likelihoods, not "five orders of magnitude".
mask = get_outliers(unmasked_y_logL)
print(f"Dropping {np.count_nonzero(mask)} samples.")
kept_logL = unmasked_y_logL[~mask]
# Relative likelihood weights. Subtracting the maximum keeps exp() within (0, 1];
# subtracting the minimum would overflow to inf once logL spans more than ~709.
# Only relative weights matter for the weighted mean/variance.
scaled_likelihoods = np.exp(kept_logL - kept_logL.max())
host_morphs = host_morphs[~mask]
eelr_morphs = eelr_morphs[~mask]

host_bayes_mean, host_bayes_var = weighted_mean_and_var(host_morphs, scaled_likelihoods)
eelr_bayes_mean, eelr_bayes_var = weighted_mean_and_var(eelr_morphs, scaled_likelihoods)

# zero out borders to remove artifacts due to PSF
zero_borders(host_bayes_mean)
zero_borders(eelr_bayes_mean)

# TODO: is it expected that the likelihoods vary by tens of orders of magnitude?

plot_asinh_stretch(host_bayes_mean)
plt.colorbar()
plt.show()

plot_asinh_stretch(host_bayes_var)
plt.colorbar()
plt.show()

plt.imshow(eelr_bayes_mean)
plt.colorbar()
plt.savefig(f"{path_base}morph_bayes_mean.png", dpi=200, bbox_inches="tight")
plt.show()

plt.imshow(eelr_bayes_var)
plt.colorbar()
plt.savefig(f"{path_base}morph_bayes_var.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Inspect each retained EELR morphology sample on its own.
for sample_idx in range(len(eelr_morphs)):
    print(sample_idx)
    plt.imshow(eelr_morphs[sample_idx])
    plt.colorbar()
    plt.show()

# EELR Classification

In [None]:
def fit_ellipse(morph, thresh=1.2):
    """Fit an ellipse to the brightest object sep detects in ``morph``.

    Parameters
    ----------
    morph : 2-D array
    thresh : float, optional
        Detection threshold passed to ``sep.extract``, relative to the global
        background RMS (default 1.2).

    Returns
    -------
    (x, y, a, b, theta)
        sep's shape parameters for the detection with the largest ``cflux``
        (theta in radians). A message is printed if the number of detections
        is not exactly one.

    Raises
    ------
    ValueError
        If sep detects no object.
    """
    bkg = sep.Background(morph)
    objects = sep.extract(morph, thresh, err=bkg.globalrms)
    if len(objects) != 1:
        print(f"Detected {len(objects)} objects in host!")
    if len(objects) == 0:
        raise ValueError("sep detected no objects in the host morphology")
    i = int(np.argmax(objects["cflux"]))
    return objects["x"][i], objects["y"][i], objects["a"][i], objects["b"][i], objects["theta"][i]

In [None]:
from matplotlib.patches import Ellipse

host_ellipse = fit_ellipse(host_bayes_mean)
print(host_ellipse)

# Host morphology with its fitted sep ellipse overlaid in red. Width/height are
# 6x sep's a/b; the angle is sep's theta converted from radians to degrees.
plot_asinh_stretch(host_bayes_mean)
host_patch = Ellipse(
    xy=(host_ellipse[0], host_ellipse[1]),
    width=6 * host_ellipse[2],
    height=6 * host_ellipse[3],
    angle=host_ellipse[4] * 180. / np.pi,
    facecolor='none',
    edgecolor='red',
)
plt.gca().add_artist(host_patch)
plt.show()

In [None]:
# Peak value of the Bayes-mean EELR morphology (useful when choosing `thresh` below).
print(np.max(eelr_bayes_mean))

In [None]:
import math

# Angle (degrees, in [0, 360)) of every EELR pixel brighter than `thresh`,
# measured about the image centre and relative to the host ellipse's theta.
# TODO: dynamically set this threshold
thresh = 1
center_x = host_bayes_mean.shape[1] / 2 - 0.5
center_y = host_bayes_mean.shape[0] / 2 - 0.5
thetas = []
rows, cols = np.nonzero(eelr_bayes_mean > thresh)  # row-major order
for r, c in zip(rows.tolist(), cols.tolist()):
    theta = math.atan2(r - center_y, c - center_x) - host_ellipse[4]
    theta = theta * 180 / math.pi
    thetas.append(theta + 360 if theta < 0 else theta)

In [None]:
# Distribution of EELR pixel angles relative to the host ellipse orientation.
hist_ax = plt.gca()
hist_ax.hist(thetas)
hist_ax.set_xlabel("Theta (°)")
hist_ax.set_ylabel("Number of Pixels")
plt.show()

In [None]:
def polar_projection(img, center, Rmax=None, resolution=12):
    """Evaluate ``img`` at the points of a polar (radius, angle) grid.

    This does not resample ``img`` on the polar grid: each grid point is rounded
    to the nearest pixel, which is an acceptable approximation for resolved
    features.

    Parameters
    ----------
    img : array, shape (H, W[, ...])
    center : (row, col)
        Integer pixel used as the polar origin.
    Rmax : float, optional
        Largest sampled radius. Defaults to the distance from ``center`` to the
        nearest image edge, so every sample lies inside ``img``. The full
        diagonal would run off the image.
    resolution : int, optional
        Number of radii (0..Rmax) and of angles (-pi..pi, both ends included).

    Returns
    -------
    polar : ndarray
        ``img[rows, cols]`` transposed. For 2-D ``img``, shape is
        (resolution, resolution) and ``polar[j, i]`` is the pixel at radius
        index j and angle index i.
    """
    height, width = img.shape[:2]
    if Rmax is None:
        Rmax = min(center[0], center[1], height - 1 - center[0], width - 1 - center[1])
    radii, angles = np.meshgrid(np.linspace(0, Rmax, resolution, dtype=float),
                                np.linspace(-np.pi, np.pi, resolution))
    rows = np.round(radii * np.sin(angles)).astype('int') + center[0]
    cols = np.round(radii * np.cos(angles)).astype('int') + center[1]
    # Gather all samples at once, then transpose so radius runs along axis 0.
    return img[rows, cols].T

In [None]:
# Polar view of the EELR about the image centre, out to the largest radius that
# fits inside the image. Summing over radius gives intensity versus angle.
polar_center = (host_bayes_mean.shape[0] // 2, host_bayes_mean.shape[1] // 2)
polar_rmax = min(eelr_bayes_mean.shape[0] // 2, eelr_bayes_mean.shape[1] // 2) - 1
eelr_polar = polar_projection(eelr_bayes_mean, polar_center, polar_rmax)
angle_intensities = eelr_polar.sum(axis=0)

In [None]:
from scipy.signal import find_peaks

def get_cyclic_peak_inds(arr, edge_offset=3):
    """Indices of prominent peaks in a cyclic (wrap-around) 1-D array.

    Peaks must have a prominence of at least 30% of the array's range and be
    at least ``len(arr) // 2 - 1`` samples apart (minimum 1).

    ``scipy.signal.find_peaks`` cannot report a peak at either end of an array.
    So after a first pass, the array is rolled to put the first peak found at
    index ``edge_offset`` (default 3, reduced for very short arrays so the roll
    does not wrap the peak onto the boundary), and searched again. This lets
    peaks near the original wrap-around point be found. Indices are mapped back
    to positions in ``arr``.

    Returns
    -------
    ndarray of int
        Sorted peak indices. The array is empty, but still integer-typed, when
        no peak is found, so it can be used directly as an index.
    """
    arr = np.asarray(arr)
    if arr.size == 0:
        return np.array([], dtype=int)
    prom_thresh = 0.3 * (arr.max() - arr.min())  # min prominence of peaks
    dist_thresh = max(1, len(arr) // 2 - 1)  # min distance between peaks (find_peaks needs >= 1)
    flattened_peaks, _ = find_peaks(arr, prominence=prom_thresh, distance=dist_thresh)
    if len(flattened_peaks) == 0:
        return np.array([], dtype=int)
    # Roll so the first peak sits `offset` samples from the start.
    offset = min(edge_offset, (len(arr) - 1) // 2)
    roll = offset - flattened_peaks[0]
    rotated = np.roll(arr, roll)
    rotated_peaks, _ = find_peaks(rotated, prominence=prom_thresh, distance=dist_thresh)
    return np.sort((rotated_peaks - roll) % len(arr))

def plot_with_peaks(arr, peak_inds):
    """Bar chart of ``arr`` over angles -pi..pi, with the bars at ``peak_inds`` in red."""
    bar_colors = np.array(["blue"] * len(arr))
    # Cast so an empty (possibly float-typed) peak array is still a valid index.
    bar_colors[np.asarray(peak_inds, dtype=int)] = "red"
    plt.bar(np.linspace(-np.pi, np.pi, len(arr)), arr, color=bar_colors, width=2*np.pi/len(arr))

In [None]:
# Prominent peaks of the EELR intensity versus angle.
cyclic_peak_inds = get_cyclic_peak_inds(arr=angle_intensities)
print(cyclic_peak_inds)

In [None]:
# Angular intensity profile with detected peaks in red, saved alongside the other figures.
angle_plot_path = path_base + "angle_intensities.png"
plot_with_peaks(angle_intensities, cyclic_peak_inds)
plt.savefig(angle_plot_path, dpi=200, bbox_inches="tight")
plt.show()

# Full sample: jet-like classification of every source

In [None]:
# Run the whole pipeline on every source in the parent sample and collect the
# OBJIDs whose EELR angular profile has exactly two prominent peaks ("jet-like").
stretch = 1
Q = 5
norm = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
bands = ['g', 'r', 'i', 'z', 'y']
model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8), shape=(None, 8, 8))
samples_per_y = 3  # fits per scanned Y-band magnitude

jetlikes = []

for source_i in range(len(source_list)):
    # load in source
    source_id = int(source_list[source_i]["OBJID"])
    print("Source", source_id)
    images, psfs = get_images_and_psfs(source_id)
    catalog, segmap, bg_rms = makeCatalog(images)
    display_img(images, norm, catalog)
    mags = np.nan_to_num(source_list[source_i]['MAG_AB_LINEONLY'], nan=30.0)

    # define scarlet frame and observation
    frame = scarlet.Frame(images.shape, psfs=model_psf, channels=bands)
    observation = scarlet.Observation(images, psfs=psfs, channels=bands).match(frame)

    # sampling over assumed Y-band line magnitudes
    eelr_host_ind = get_center_source(catalog, (images.shape[1], images.shape[2]))
    unmasked_y_host_samples, unmasked_y_eelr_samples, unmasked_y_logL = get_unmasked_y_samples(
        catalog, eelr_host_ind, mags, samples_per_y, frame, observation)

    # morphology Bayes mean and variance (likelihood-weighted over samples)
    host_morphs = np.array([sample.components[1].morph for sample in unmasked_y_host_samples])
    eelr_morphs = np.array([sample.morph for sample in unmasked_y_eelr_samples])
    unmasked_y_logL = np.array(unmasked_y_logL)
    mask = get_outliers(unmasked_y_logL)
    kept_logL = unmasked_y_logL[~mask]
    # Subtract the max so exp() stays within (0, 1] and cannot overflow;
    # only relative weights matter.
    scaled_likelihoods = np.exp(kept_logL - kept_logL.max())
    host_morphs = host_morphs[~mask]
    eelr_morphs = eelr_morphs[~mask]
    host_bayes_mean, host_bayes_var = weighted_mean_and_var(host_morphs, scaled_likelihoods)
    eelr_bayes_mean, eelr_bayes_var = weighted_mean_and_var(eelr_morphs, scaled_likelihoods)
    plt.imshow(eelr_bayes_mean)
    plt.colorbar()
    plt.show()

    # zero out borders to remove artifacts due to PSF
    zero_borders(host_bayes_mean)
    zero_borders(eelr_bayes_mean)

    # EELR angle intensities
    eelr_polar = polar_projection(eelr_bayes_mean,
                                  (host_bayes_mean.shape[0] // 2, host_bayes_mean.shape[1] // 2),
                                  min(eelr_bayes_mean.shape[0] // 2 - 1, eelr_bayes_mean.shape[1] // 2 - 1))
    angle_intensities = eelr_polar.sum(axis=0)
    plt.bar(np.linspace(-np.pi, np.pi, eelr_polar.shape[1]), angle_intensities)
    plt.show()

    # peaks in angle space: exactly two prominent peaks => jet-like
    cyclic_peak_inds = get_cyclic_peak_inds(angle_intensities)
    print(cyclic_peak_inds)
    if len(cyclic_peak_inds) == 2:
        jetlikes.append(source_id)

In [None]:
# OBJIDs classified as jet-like.
print(f"{jetlikes}")