In [None]:
%matplotlib inline

from functools import partial
import pickle

# for parallel computing
from joblib import Parallel, delayed

import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
# use a better colormap and don't interpolate the pixels
matplotlib.rc('image', cmap='inferno', interpolation='none', origin='lower')

from scipy.signal import find_peaks
from scipy.ndimage import gaussian_filter

# our code for source separation
import scarlet
import scarlet.display
# code to detect sources
import sep

# to open fits files
from astropy.io import fits
from astropy.table import Table

In [None]:
# Root directory of the HSC cutouts; get_images_and_psfs() looks for
# "<data_dir>/<5-digit source id>/cutout_HSC-*.fits" and ".../psf_HSC-*.fits".
# NOTE: machine-specific absolute path -- change it before running elsewhere.
data_dir = '/home/czhao/Synced/Documents/PrincetonStuff/2019-20/spring/IW/hsc_images'

In [None]:
# source_list = Table.read('parent_sample/source_list.fits')
# source_list = Table.read('good_ir_merged.fits')
source_list = Table.read('parent_sample/EELR_HSCmag_from_SDSSspec.fits')

In [None]:
# find directory for source at RA/DEC

def get_source_i(ra, dec):
    """Return the row index of the entry in `source_list` closest to (ra, dec).

    Parameters
    ----------
    ra, dec : float
        Target coordinates, in the same units as the ``RA`` / ``DEC`` columns
        of the module-level `source_list` table (the 3600 factor below treats
        them as degrees).

    Returns
    -------
    int
        Index of the closest row. A warning is printed when that row is more
        than one arcsecond away from the requested position.
    """
    # Small-angle separation: RA offsets shrink by cos(dec) away from the equator.
    d_ra = (source_list['RA'] - ra) * np.cos(np.deg2rad(dec))
    d_dec = source_list['DEC'] - dec
    dist_sq = d_ra**2 + d_dec**2
    i = np.argmin(dist_sq)
    # dist_sq is a *squared* separation in deg^2; take the root before
    # converting to arcseconds, otherwise the check only fires at ~1 arcmin.
    sep_arcsec = np.sqrt(dist_sq[i]) * 3600
    if sep_arcsec > 1:
        print("Closest match more than 1 arcsec away: Proceed with caution!")
    return i

# most prominent candidate
# ra, dec=37.77, -3.75
# source_i = get_source_i(ra, dec)

source_i = 2

source_id = int(source_list[source_i]["OBJID"])

In [None]:
import glob

def get_images_and_psfs(source_id):
    """Load the five-band HSC cutouts and PSFs of one source.

    Files are looked up under ``{data_dir}/{source_id:05d}/`` with the patterns
    ``cutout_HSC-<B>_*_src_*.fits`` (image read from HDU 1) and
    ``psf_HSC-<B>_*_src_*.fits`` (PSF read from HDU 0) for B in G, R, I, Z, Y.

    Parameters
    ----------
    source_id : int
        Object id; formatted as a zero-padded 5-digit directory name.

    Returns
    -------
    images : ndarray
        Stacked cutouts with 40 pixels trimmed from every image edge.
    psfs : ndarray
        Stacked PSFs, zero-padded to a common (max height, max width) shape.

    Raises
    ------
    FileNotFoundError
        If a cutout or PSF file for some band is missing.
    """
    bands = ['G', 'R', 'I', 'Z', 'Y']

    def _find_one(kind, band):
        # Sort so that the choice is deterministic when several files match.
        pattern = "{}/{:05d}/{}_HSC-{}_*_src_*.fits".format(data_dir, source_id, kind, band)
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError("No file matching {}".format(pattern))
        return matches[0]

    images = []
    psfs = []
    for b in bands:
        # Context managers close the FITS handles even if reading fails.
        with fits.open(_find_one("cutout", b)) as hdulist:
            images.append(hdulist[1].data)
        with fits.open(_find_one("psf", b)) as hdulist:
            psfs.append(hdulist[0].data)
    # Trim a fixed 40-pixel border from every cutout.
    images = (np.array(images)[:,40:-40,40:-40]).copy()

    # pad PSFs to the same shape
    psf_height = max(psf.shape[0] for psf in psfs)
    psf_width = max(psf.shape[1] for psf in psfs)

    def _split(total):
        # When the size difference is odd, the extra pixel goes on the trailing
        # side so that every padded PSF ends up with exactly the target shape.
        before = total // 2
        return before, total - before

    psfs = np.stack([np.pad(psf, (_split(psf_height - psf.shape[0]),
                                  _split(psf_width - psf.shape[1])))
                     for psf in psfs])

    return images, psfs

In [None]:
images, psfs = get_images_and_psfs(source_id)

In [None]:
# get magnitude of line emitters
def mag2amplitude(mags):
    """Convert five per-band magnitudes into per-band amplitudes.

    Each magnitude is turned into a flux density via ``10**(-(48.6 + m) / 2.5)``
    and scaled by ``1.51e7 / width`` with a fixed width per band.
    NOTE(review): the widths look like fractional bandwidths for g, r, i, z, y
    and the result like a photon rate -- confirm units with the author.

    `mags` must broadcast against a length-5 array; masked arrays stay masked.
    """
    band_widths = np.array([.14, .14, .16, .13, .11])
    rate_per_unit_flux = 1.51e7 / band_widths
    flux_density = 10**((48.6+mags)/-2.5)
    return rate_per_unit_flux * flux_density

In [None]:
# from Ai-Lei spectral decomposition tables
# this will require a read-out method to work with other sources
# mags = np.array([26.3652646789765, 23.73486734207137, 21.898052328748093, 28, 21.05645664757267])

# mags = np.array([source_list[source_i][f"speclineMag_{band.lower()}"] for band in bands])

mags = np.nan_to_num(source_list[source_i]['MAG_AB_LINEONLY'], nan=30.0)

print(mags)

# last element (Y band) appears untrustworthy
band_mask = [0,0,0,0,1]

In [None]:
# detect all sources in the image
def makeCatalog(img):
    """Detect sources on the band-averaged image.

    Parameters
    ----------
    img : array
        Multi-band image cube; the first axis is averaged to build the
        detection image and iterated over for the per-band noise.

    Returns
    -------
    catalog, segmap
        Output of ``sep.extract`` (threshold 1.2 with the detection image's
        global RMS as error, ``deblend_nthresh=64``, ``deblend_cont=3e-4``).
    bg_rms : ndarray
        Global background RMS estimated separately for each band.
    """
    detection_image = img.mean(axis=0)
    detection_bkg = sep.Background(detection_image)
    catalog, segmap = sep.extract(
        detection_image,
        1.2,
        err=detection_bkg.globalrms,
        deblend_nthresh=64,
        deblend_cont=3e-4,
        segmentation_map=True,
    )
    per_band_rms = [sep.Background(channel).globalrms for channel in img]
    return catalog, segmap, np.array(per_band_rms)

def display_img(images, norm, catalog, ax=None):
    """Show `images` as an RGB picture and label every catalog entry.

    The cube is converted with ``scarlet.display.img_to_rgb`` using `norm`.
    When `ax` is None a new 6x6 inch figure is created and its current axes
    are used. Each catalog entry is annotated at its (x, y) position with its
    index in the catalog.
    """
    rgb = scarlet.display.img_to_rgb(images, norm=norm)

    target = ax
    if target is None:
        plt.figure(figsize=(6,6))
        target = plt.gca()
    target.imshow(rgb)

    # Mark all of the sources from the detection catalog
    for index, detection in enumerate(catalog):
        target.text(detection["x"], detection["y"], str(index), color="w")

In [None]:
catalog, segmap, bg_rms = makeCatalog(images)
# first define color stretch and convert 5 bands to RGB channels
stretch = 1
Q = 5
norm = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
display_img(images, norm, catalog)
plt.show()

In [None]:
# display psfs
psfs_min = min(psf.min() for psf in psfs)
pnorm = scarlet.display.AsinhMapping(minimum=psfs_min, stretch=1e-2, Q=1)
prgb = scarlet.display.img_to_rgb(psfs, norm=pnorm)
plt.figure()
plt.imshow(prgb)
plt.show()

In [None]:
# define Frame and Observation:
model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8), shape=(None, 8, 8))
bands = ['g', 'r', 'i', 'z', 'y']

frame = scarlet.Frame(images.shape, psfs=model_psf, channels=bands)

# no weight maps, use flat background noise variance instead
# weights = np.ones_like(images) / (bg_rms[:,None,None]**2)
observation = scarlet.Observation(images, psfs=psfs, channels=bands).match(frame)

In [None]:
class SEDConstraint(scarlet.Constraint):
    """Constrain a parameter vector to point along a fixed SED direction.

    Only the direction is fixed: the vector is replaced by its projection onto
    the SED (clipped at zero), so the overall flux scale remains free. When the
    SED is a masked array with masked entries, only the unmasked entries are
    projected and the masked ones are left untouched.
    """

    def __init__(self, sed):
        self._sed = sed

    def __call__(self, X, step):
        sed = self._sed
        if np.ma.is_masked(sed):
            # project only the entries whose SED value is trusted
            valid = ~sed.mask
            x_valid = X[valid]
            sed_valid = sed[valid]
            scale = np.dot(x_valid, sed_valid) / np.dot(sed_valid, sed_valid)
            X[:][valid] = np.maximum(scale * sed_valid, 0)
        else:
            # closest point to X on the ray spanned by the SED
            X[:] = np.maximum(np.dot(X, sed) / np.dot(sed, sed) * sed, 0)
        return X
    
class RadialMaskConstraint(scarlet.Constraint):
    """Force an image to be non-negative and zero beyond radius R of a pixel.

    `shape` is a 3-tuple whose last two entries are the image height and
    width; `pixel_center` is (row, column).
    """

    def __init__(self, shape, pixel_center, R):
        _, height, width = shape
        row_offsets = np.arange(height) - pixel_center[0]
        col_offsets = np.arange(width) - pixel_center[1]
        radius_sq = row_offsets[:, None]**2 + col_offsets[None, :]**2
        # True for pixels strictly outside the allowed radius
        self.mask = radius_sq > R**2

    def __call__(self, X, step):
        X[self.mask] = 0
        X[:,:] = np.maximum(X, 0)
        return X
        
    
class EELRSource(scarlet.RandomSource):
    """Source to describe EELR

    It has a free-form morphology, possibly constrained to be within R of the
    center, and its SED can be fixed up to a constant.
    """
    def __init__(self, frame, sky_coord, sed=None, R=None):
        super().__init__(frame)

        pixel = np.array(frame.get_pixel(sky_coord), dtype="float")
        self.pixel_center = tuple(np.round(pixel).astype("int"))

        # NOTE(review): parameter 0 is treated as the SED and parameter 1 as the
        # morphology of the RandomSource -- confirm against the scarlet version in use.
        if sed is not None:
            sed_param = self._parameters[0]
            sed_param.constraint = SEDConstraint(sed)
            # apply once so the initial value already satisfies the constraint
            sed_param[:] = sed_param.constraint(sed_param, 0)
        if R is not None:
            morph_param = self._parameters[1]
            morph_param.constraint = RadialMaskConstraint(frame.shape, self.pixel_center, R)
            morph_param[:,:] = morph_param.constraint(morph_param, 0)

In [None]:
def get_center_source(catalog, dim):
    """Return the index of the catalog entry nearest to the image center.

    `dim` is (height, width); the center is taken as (height / 2, width / 2).
    Entries are read through their 'y' and 'x' fields. Returns -1 when the
    catalog is empty or no entry lies closer than the image diagonal. On ties
    the earliest entry wins.
    """
    center_y, center_x = dim[0] / 2, dim[1] / 2
    best_distsq, best_index = dim[0]**2 + dim[1]**2, -1
    for index, entry in enumerate(catalog):
        distsq = (entry['y'] - center_y)**2 + (entry['x'] - center_x)**2
        if distsq < best_distsq:
            best_distsq, best_index = distsq, index
    return best_index

def create_sources(catalog, eelr_host_ind, mags, mask, frame, observation):
    """Build the list of scarlet sources for one blend.

    Every catalog entry becomes an ``ExtendedSource`` except the one at
    `eelr_host_ind`, which becomes a ``MultiComponentSource`` (the host)
    immediately followed by an ``EELRSource`` at the same position. The EELR
    SED is ``mag2amplitude`` of `mags` with `mask` applied as a masked array.
    As a result the EELR source sits at index ``eelr_host_ind + 1``.
    """
    sources = []
    for index, entry in enumerate(catalog):
        position = (entry['y'], entry['x'])
        if index != eelr_host_ind:
            sources.append(scarlet.ExtendedSource(frame, position, observation, shifting=True, thresh=0.5))
            continue

        # host galaxy
        sources.append(scarlet.MultiComponentSource(frame, position, observation, thresh=0.2, shifting=True))

        # EELR component with the (partially masked) line-emission SED
        masked_mags = np.ma.masked_array(mags, mask=mask)
        sources.append(EELRSource(frame, position, sed=mag2amplitude(masked_mags), R=None))
    return sources

In [None]:
eelr_host_ind = get_center_source(catalog, (images.shape[1], images.shape[2]))
print(eelr_host_ind)
print(catalog[eelr_host_ind]['x'], catalog[eelr_host_ind]['y'])
sources = create_sources(catalog, eelr_host_ind, mags, band_mask, frame, observation)
blend = scarlet.Blend(sources, observation)

In [None]:
# run the fitter
%time blend.fit(200, e_rel=1e-5)
print("scarlet ran for {0} iterations to logL = {1}".format(len(blend.loss), -blend.loss[-1]))
plt.plot(-np.array(blend.loss))
plt.xlabel('Iteration')
plt.ylabel('log-Likelihood')

In [None]:
fig = scarlet.display.show_scene(sources, observation=observation, norm=norm, show_observed=True, show_rendered=True, show_residual=True)
fig.savefig(f"{source_id}_scarlet_scene.png", dpi=200, bbox_inches="tight")
fig

In [None]:
fig = scarlet.display.show_sources(sources, observation, show_observed=True, show_rendered=True, norm=norm)
fig.savefig(f"{source_id}_scarlet_sources.png", dpi=200, bbox_inches="tight")
fig

# Resampling

In [None]:
prefix = f"{source_id:05d}_"
print(prefix)

In [None]:
def plot_asinh_stretch(img):
    """Show `img` on the current axes with an asinh stretch (stretch 0.05, Q 5)."""
    mapping = scarlet.display.AsinhMapping(minimum=0, stretch=0.05, Q=5)
    rgb = scarlet.display.img_to_rgb(img, norm=mapping)
    plt.imshow(rgb)

In [None]:
def do_one_eelr_sample(catalog, eelr_host_ind, mags, mask, frame, observation):
    """Fit one blend and return its host, EELR source, model and log-likelihood.

    Sources are built with ``create_sources`` and fitted for at most 200
    iterations with ``e_rel=1e-5``.

    Returns
    -------
    tuple
        ``(host source, EELR source, blend.get_model(), -blend.loss[-1])``.
    """
    sources = create_sources(catalog, eelr_host_ind, mags, mask, frame, observation)
    blend = scarlet.Blend(sources, observation)
    blend.fit(200, e_rel=1e-5)

    host = sources[eelr_host_ind]
    eelr = sources[eelr_host_ind+1]  # create_sources puts the EELR right after its host
    model = blend.get_model()
    log_likelihood = -blend.loss[-1]
    return host, eelr, model, log_likelihood

In [None]:
eelr_samples = []
for _ in range(2):
    eelr_samples.append(do_one_eelr_sample(catalog, eelr_host_ind, mags, band_mask, frame, observation)[1])

In [None]:
scarlet.display.show_sources(eelr_samples, observation, show_observed=True, show_rendered=True, norm=norm)

In [None]:
def compute_snr(X):
    """Per-element ratio of mean square to variance along the first axis.

    Elements whose mean square is exactly zero get a ratio of 0. Any remaining
    NaN is replaced by the largest non-NaN ratio, and infinities are mapped to
    finite values by ``np.nan_to_num``.
    """
    mean_square = np.square(X).mean(axis=0)
    variance = X.var(axis=0)
    ratio = mean_square / variance
    ratio[mean_square == 0] = 0
    nan_fill = np.nanmax(ratio)
    return np.nan_to_num(ratio, nan=nan_fill)

In [None]:
# Create heatmap of average (since multiple bands) SNR of each EELR pixel

eelr_models = np.stack([sample.get_model() for sample in eelr_samples])

snr_per_band = compute_snr(eelr_models)
snr = snr_per_band.mean(axis=0)

plt.imshow(snr)
plt.colorbar()
plt.savefig(f"{prefix}model_snr.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Create heatmap of average morphology

eelr_morphs = np.stack([sample.morph for sample in eelr_samples])
mean_morph = eelr_morphs.mean(axis=0)

plt.imshow(mean_morph)
plt.colorbar()
plt.savefig(f"{prefix}morph_avg.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
# Create heatmap of morphology SNR

morph_snr = compute_snr(eelr_morphs)

plt.imshow(morph_snr)
plt.colorbar()
plt.savefig(f"{prefix}morph_snr.png", dpi=200, bbox_inches="tight")
plt.show()

# Unmasked Y band

In [None]:
def get_unmasked_y_samples(catalog, eelr_host_ind, mags, num_samples, frame, observation, y_range=(22, 27)):
    """Fit `num_samples` blends, each with a random magnitude in the last band.

    For every sample the fifth magnitude (index 4) is replaced by a value drawn
    uniformly from `y_range`, and the blend is fitted with no band masked.

    Parameters
    ----------
    catalog, eelr_host_ind, frame, observation
        Passed through to ``do_one_eelr_sample``.
    mags : array-like
        Five per-band magnitudes. The caller's array is NOT modified.
    num_samples : int
        Number of independent fits.
    y_range : tuple of float, optional
        (low, high) bounds of the uniform draw; defaults to (22, 27).

    Returns
    -------
    tuple of lists
        Host sources, EELR sources, full models and log-likelihoods, one entry
        per sample.
    """
    unmasked_y_host_samples = []
    unmasked_y_eelr_samples = []
    unmasked_y_model_samples = []
    unmasked_y_logL = []
    y_low, y_high = y_range
    for _ in range(num_samples):
        y = np.random.uniform(y_low, y_high)
        # Work on a copy: assigning into `mags` itself would silently overwrite
        # the caller's Y-band magnitude (and whatever is computed from it later).
        this_mags = np.array(mags, copy=True, subok=True)
        this_mags[4] = y
        host_sample, eelr_sample, model_sample, logL = do_one_eelr_sample(catalog, eelr_host_ind, this_mags, [0, 0, 0, 0, 0], frame, observation)
        unmasked_y_host_samples.append(host_sample)
        unmasked_y_eelr_samples.append(eelr_sample)
        unmasked_y_model_samples.append(model_sample)
        unmasked_y_logL.append(logL)
    return unmasked_y_host_samples, unmasked_y_eelr_samples, unmasked_y_model_samples, unmasked_y_logL

In [None]:
num_samples = 10
unmasked_y_host_samples, unmasked_y_eelr_samples, unmasked_y_model_samples, unmasked_y_logL = get_unmasked_y_samples(catalog, eelr_host_ind, mags, num_samples, frame, observation)

In [None]:
scarlet.display.show_sources(unmasked_y_eelr_samples[:5], observation, show_observed=True, show_rendered=True, norm=norm)

## Model Averaging

In [None]:
def weighted_mean_and_var(samples, weights):
    """Weighted mean and unbiased weighted variance along the first axis.

    Parameters
    ----------
    samples : array-like, shape (N, ...)
        One sample per entry along axis 0.
    weights : array-like, shape (N,)
        Non-negative weights (e.g. relative likelihoods); need not be normalized.

    Returns
    -------
    mean, var : ndarray
        ``mean = sum_i w_i x_i`` and, with weights normalized to sum to one,
        ``var = sum_i w_i (x_i - mean)**2 / (1 - sum_i w_i**2)``
        (the "reliability weights" estimator). For equal weights this reduces
        to the usual sample variance with ``ddof=1``. With a single effective
        sample the denominator is zero and the variance is NaN/inf.
    """
    samples = np.asarray(samples)
    weights = np.asarray(weights, dtype=float)
    mean = np.average(samples, axis=0, weights=weights)
    w_normed = weights / weights.sum()
    # Put the sample axis last so the 1-D weights broadcast against it.
    sq_dev = np.moveaxis((samples - mean)**2, 0, -1)
    # The weights enter the numerator linearly; squaring them here would
    # shrink the result by roughly a factor of N.
    var_biased = (w_normed * sq_dev).sum(axis=-1)
    v2 = np.square(w_normed).sum(axis=-1)
    var = var_biased / (1 - v2)
    return mean, var

def get_outliers(arr, thresh=5):
    """Iteratively flag values lying more than `thresh` below the median.

    The median is recomputed from the values not yet flagged, and the procedure
    repeats until the set of outliers stops growing.

    Parameters
    ----------
    arr : ndarray
        Values to test.
    thresh : float, optional
        Distance below the median beyond which a value is an outlier.

    Returns
    -------
    ndarray of bool
        True where `arr` is an outlier.
    """
    mask = np.zeros(arr.shape, dtype=bool)
    last_size = -1
    cur_size = 0
    while cur_size > last_size:
        last_size = cur_size
        inliers = arr[~mask]
        if inliers.size == 0:
            break
        # Use only the current inliers; recomputing the median of the full
        # array would give the same mask on every pass.
        med = np.median(inliers)
        mask = arr < med - thresh
        cur_size = np.count_nonzero(mask)
    return mask

def zero_borders(X):
    """Set the outermost rows and columns of the 2-D array `X` to zero, in place."""
    for edge in (0, -1):
        X[edge, :] = 0
        X[:, edge] = 0

In [None]:
host_morphs = np.array([sample.components[1].morph for sample in unmasked_y_host_samples])
# host_morphs = np.array([sample.morph for sample in unmasked_y_host_samples])
eelr_morphs = np.array([sample.morph for sample in unmasked_y_eelr_samples])
models = np.array(unmasked_y_model_samples)
unmasked_y_logL = np.array(unmasked_y_logL)

# Drop likelihoods that are more than 7 orders of magnitude (in units of e) less than the max
mask = unmasked_y_logL > (unmasked_y_logL.max() - 7)
print(f"Using {np.count_nonzero(mask)} samples.")
scaled_likelihoods = np.exp(unmasked_y_logL[mask] - max(unmasked_y_logL[mask]))
host_morphs = host_morphs[mask]
eelr_morphs = eelr_morphs[mask]
models = models[mask]

host_avg, host_var = weighted_mean_and_var(host_morphs, scaled_likelihoods)
eelr_avg, eelr_var = weighted_mean_and_var(eelr_morphs, scaled_likelihoods)
model_avg, model_var = weighted_mean_and_var(models, scaled_likelihoods)

# zero out borders to remove artifacts due to PSF
zero_borders(host_avg)
zero_borders(eelr_avg)

# low pass filter
eelr_avg_filtered = gaussian_filter(eelr_avg, sigma=2)

# plt.imshow(host_avg)
plot_asinh_stretch(host_avg)
plt.colorbar()
plt.show()

# plt.imshow(host_var)
plot_asinh_stretch(host_var)
plt.colorbar()
plt.show()

display_img(model_avg, norm, catalog)
plt.show()

plt.imshow(eelr_avg)
plt.colorbar()
plt.savefig(f"{prefix}morph_avg.png", dpi=200, bbox_inches="tight")
plt.show()

plt.imshow(eelr_avg_filtered)
plt.colorbar()
plt.savefig(f"{prefix}morph_avg_filtered.png", dpi=200, bbox_inches="tight")
plt.show()

plt.imshow(eelr_var)
plt.colorbar()
plt.savefig(f"{prefix}morph_var.png", dpi=200, bbox_inches="tight")
plt.show()

In [None]:
for i, morph in enumerate(eelr_morphs):
    print(i)
    plt.imshow(morph)
    plt.colorbar()
    plt.show()

# EELR Classification

In [None]:
def fit_ellipse(morph):
    """Run ``sep`` on `morph` and return the ellipse of the strongest detection.

    Detection uses threshold 1.2 with the global background RMS as error. A
    message is printed unless exactly one object is found. The object with the
    largest ``cflux`` is selected.

    Returns
    -------
    tuple
        ``(x, y, a, b, theta)`` of the selected object, as reported by sep.
    """
    background = sep.Background(morph)
    detections = sep.extract(morph, 1.2, err=background.globalrms)
    if len(detections) != 1:
        print(f"Detected {len(detections)} objects in host!")

    def _cflux(index):
        return detections["cflux"][index]

    brightest = max(range(len(detections)), key=_cflux)
    return tuple(detections[field][brightest] for field in ("x", "y", "a", "b", "theta"))

In [None]:
from matplotlib.patches import Ellipse

host_ellipse = fit_ellipse(host_avg)
print(host_ellipse)

plot_asinh_stretch(host_avg)
e = Ellipse(xy=(host_ellipse[0], host_ellipse[1]),
                width=6*host_ellipse[2],
                height=6*host_ellipse[3],
                angle=host_ellipse[4] * 180. / np.pi)
e.set_facecolor('none')
e.set_edgecolor('red')
plt.gca().add_artist(e)
plt.show()

In [None]:
print(eelr_avg.max())

In [None]:
import math

# TODO: dynamically set this threshold
thresh = 1
center_x = host_avg.shape[1] / 2 - 0.5
center_y = host_avg.shape[0] / 2 - 0.5
thetas = []
for r in range(eelr_avg.shape[0]):
    for c in range(eelr_avg.shape[1]):
        if eelr_avg[r, c] > thresh:
            x = c - center_x
            y = r - center_y
            theta = math.atan2(y, x) - host_ellipse[4]
            theta = theta * 180 / math.pi
            if theta < 0:
                theta += 360
            thetas.append(theta)

In [None]:
plt.hist(thetas)
plt.xlabel("Theta (°)")
plt.ylabel("Number of Pixels")
plt.show()

In [None]:
def polar_projection(img, center, Rmax=None, resolution=100):
    """Evaluate img at the location of polar grid coordinates

    This method doesn't resample `img` on the polar grid, it merely
    transforms the coordinates and picks the nearest pixel.
    For resolved features, this is an acceptable approximation.

    Parameters
    ----------
    img : ndarray
        Image indexed as ``img[row, column]``.
    center : tuple of int
        (row, column) of the polar origin.
    Rmax : float, optional
        Largest radius sampled; defaults to the image diagonal.
    resolution : int, optional
        Number of samples along both the radial and the angular axis. Radii run
        from 0 to `Rmax`, angles from -pi to pi (both ends included).

    Returns
    -------
    ndarray
        For a 2-D image, shape (resolution, resolution), indexed as
        ``[radius index, angle index]``.

    Notes
    -----
    No bounds checking is done: sample positions past the far edges raise
    IndexError and negative ones wrap around, so choose `Rmax` such that the
    disc stays inside the image.
    """
    lims = img.shape
    if Rmax is None:
        Rmax = np.sqrt(lims[0]**2 + lims[1]**2)
    # `np.float` was removed in NumPy 1.24; the builtin float is equivalent.
    R, P = np.meshgrid(np.linspace(0, Rmax, resolution, dtype=float), np.linspace(-np.pi, np.pi, resolution))
    Y = np.round(R * np.sin(P)).astype('int') + center[0]
    X = np.round(R * np.cos(P)).astype('int') + center[1]
    # Fancy indexing does the nearest-pixel lookup for the whole grid at once;
    # the grid is (angle, radius), hence the transpose.
    polar = img[Y, X].T
    return polar

In [None]:
num_bins = 12  # number of angle bins
angles_per_bin = 5  # number of angles to sample per bin

eelr_polar = polar_projection(eelr_avg,
                              (host_avg.shape[0] // 2, host_avg.shape[1] // 2),
                              min(eelr_avg.shape[0] // 2 - 1, eelr_avg.shape[1] // 2 - 1),
                              resolution=num_bins * angles_per_bin)

# only consider pixels r_thresh or further from the center (units depend on polar resolution)
r_thresh = 2
angle_intensities = eelr_polar[r_thresh:].sum(axis=0)
angle_intensities = np.array([sum(angle_intensities[i:i+angles_per_bin]) for i in range(0, num_bins*angles_per_bin, angles_per_bin)])

In [None]:
def get_cyclic_peak_inds(arr):
    """Finds peaks in a cyclic array.

    ``find_peaks`` never reports the first or last element, so it is run twice:
    once on `arr` and once on `arr` rolled by half its length. The second set
    of indices is mapped back and merged with the first. Peaks need a
    prominence of at least 20% of the array's value range.

    Returns the sorted, unique peak indices.
    """
    size = len(arr)
    min_prominence = 0.2 * (arr.max() - arr.min())
    shift = size // 2

    direct, _ = find_peaks(arr, prominence=min_prominence)
    # peaks that sat at the edges show up once the array is rolled
    shifted, _ = find_peaks(np.roll(arr, shift), prominence=min_prominence)
    wrapped = (shifted - shift) % size

    # union1d already returns sorted unique values
    return np.union1d(direct, wrapped)

def plot_with_peaks(arr, peak_inds, ax=None):
    """Bar-plot `arr` over [-pi, pi], drawing the bars at `peak_inds` in red.

    Bars are placed at ``len(arr)`` evenly spaced positions from -pi to pi with
    width ``2*pi / len(arr)``. The current axes are used when `ax` is None.
    """
    n_bars = len(arr)
    colors = np.array(["blue" for _ in range(n_bars)])
    if len(peak_inds) > 0:
        colors[peak_inds] = "red"

    target = plt.gca() if ax is None else ax
    positions = np.linspace(-np.pi, np.pi, n_bars)
    target.bar(positions, arr, color=colors, width=2*np.pi/n_bars)
    target.set_xlabel("Theta")

In [None]:
cyclic_peak_inds = get_cyclic_peak_inds(angle_intensities)
print(cyclic_peak_inds)

In [None]:
plot_with_peaks(angle_intensities, cyclic_peak_inds)
plt.savefig(f"{prefix}angle_intensities.png", dpi=200, bbox_inches="tight")
plt.show()

# griz Sampling

In [None]:
# Load the pickled models used by do_one_griz_sample(); it looks them up under
# the keys "g-r", "r-i" and "i-z".
# NOTE: unpickling executes arbitrary code -- only load files you created yourself.
with open("gp_models.pickle", "rb") as infile:
    z_to_griz_models = pickle.load(infile)

print(z_to_griz_models.keys())

In [None]:
def do_one_griz_sample(z_):
    """Draw one set of [g, r, i, z] magnitudes for the input value `z_`.

    g is pinned to 25; each following band is the previous one minus a colour
    sampled from the matching model in `z_to_griz_models`
    ("g-r", "r-i", "i-z"), evaluated at `z_`.

    NOTE(review): if these models are scikit-learn GaussianProcessRegressor
    objects, ``sample_y`` defaults to ``random_state=0`` and returns the same
    draw on every call -- confirm that this is intended.
    """
    query = np.array([[z_]])
    mags = [25]
    for colour in ("g-r", "r-i", "i-z"):
        previous = mags[-1]
        mags.append(previous - z_to_griz_models[colour].sample_y(query)[0, 0])
    return mags

# Large Sample

In [None]:
first_df = pd.read_csv("parent_sample/FIRST.csv", comment="#")
first_ids = first_df["objid"].to_numpy()

FIG_DIR = "eelr_outputs_sdss_fixedvar/"

In [None]:
def gen_eelr_fig(result):
    """Save a three-panel summary figure for one `model_eelr` result.

    Panels: the observed images, the averaged model rendered through the
    observation, and the averaged EELR morphology (with colour bar). The
    figure is written to ``{FIG_DIR}{id:05d}_eelr.png`` and then closed, so it
    is not displayed inline and does not pile up in memory when this function
    is called for many sources.

    Parameters
    ----------
    result : dict or None
        Needs the keys "id", "model_avg" and "eelr_avg". ``None`` (a failed
        source) is skipped silently.
    """
    if result is None:
        return

    import os  # local import: only needed to create the output directory

    stretch = 1
    Q = 5
    norm = scarlet.display.AsinhMapping(minimum=0, stretch=stretch, Q=Q)
    bands = ['g', 'r', 'i', 'z', 'y']
    model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8), shape=(None, 8, 8))

    # load in source
    images, psfs = get_images_and_psfs(result["id"])
    catalog, _, _ = makeCatalog(images)

    # define scarlet frame and observation
    frame = scarlet.Frame(images.shape, psfs=model_psf, channels=bands)
    observation = scarlet.Observation(images, psfs=psfs, channels=bands).match(frame)

    fig, axs = plt.subplots(1, 3, figsize=(15, 5), gridspec_kw={'width_ratios': [1, 1, 1.1]})
    try:
        axs[0].set_title("Observation")
        axs[1].set_title("Rendered Model")
        axs[2].set_title("EELR Morphology")
        display_img(images, norm, catalog, axs[0])
        display_img(observation.render(result["model_avg"]), norm, catalog, axs[1])
        im = axs[2].imshow(result["eelr_avg"])
        fig.colorbar(im, ax=axs[2], fraction=0.04)

        # Optional footnote (disabled):
        #     footnote = "In FIRST catalog." if result['id'] in first_ids else "Not in FIRST catalog."
        #     plt.figtext(0.11, 0.07, footnote, fontsize="small", fontstyle="italic", ha="left", va="bottom")

        prefix = f"{result['id']:05d}_"
        out_path = f"{FIG_DIR}{prefix}eelr.png"
        out_dir = os.path.dirname(out_path)
        if out_dir:
            # savefig does not create missing directories
            os.makedirs(out_dir, exist_ok=True)
        fig.savefig(out_path, dpi=200, bbox_inches="tight")
    finally:
        # Release the figure even if rendering or saving failed.
        plt.close(fig)

In [None]:
import functools
import traceback

def safe_func(func):
    """Wrap `func` so that exceptions are printed instead of propagated.

    The wrapper returns whatever `func` returns. If `func` raises an
    ``Exception``, the traceback is printed and ``None`` is returned, which
    lets a batch run continue past individual failures.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        outcome = None
        try:
            outcome = func(*args, **kwargs)
        except Exception:
            traceback.print_exc()
        return outcome
    return wrapper

In [None]:
def model_eelr(objid, num_samples=10, gen_fig=True, use_sdss=False, z_noise=0):
    """Fit `num_samples` scarlet blends for one object and average the results.

    Parameters
    ----------
    objid : int
        Object id; matched against the integer-cast ``OBJID`` column of the
        module-level `source_list` and used to locate the image files.
    num_samples : int
        Number of fits to run.
    gen_fig : bool
        If True, call ``gen_eelr_fig`` on the result before returning.
    use_sdss : bool
        If True, take the EELR magnitudes from the ``MAG_AB_LINEONLY`` column
        (NaN replaced by 30.0) and sample via ``get_unmasked_y_samples``.
        Otherwise draw g, r, i, z from ``do_one_griz_sample`` at the object's
        ``Z`` value and the fifth magnitude uniformly from [18, 26).
    z_noise : float
        Half-width of the uniform noise added to ``Z`` (only when `use_sdss`
        is False).

    Returns
    -------
    dict
        Keys "id", "host_avg", "host_var", "eelr_avg", "eelr_var", "model_avg",
        "model_var" and "num_samples" (number of samples kept); when `use_sdss`
        is False also "z", "mags_avg" and "mags_var". Averages and variances
        come from ``weighted_mean_and_var`` with weights ``exp(logL - max logL)``.
    """

    # FIXME: scarlet breaks on 178
#     if objid == 178: return None
    
    # Select the table row(s) of this object and flatten the first one into a dict.
    source_info = source_list[source_list["OBJID"].astype(int) == objid]
    source_info = {k: source_info[k][0] for k in source_info.columns}
    
    bands = ['g', 'r', 'i', 'z', 'y']
    model_psf = scarlet.PSF(partial(scarlet.psf.gaussian, sigma=.8), shape=(None, 8, 8))
    
    # load in source
    images, psfs = get_images_and_psfs(objid)
    catalog, segmap, bg_rms = makeCatalog(images)
    if use_sdss:
        mags = np.nan_to_num(source_info['MAG_AB_LINEONLY'], nan=30.0)
    else:
        z = source_info["Z"]
        z += np.random.uniform(-z_noise, z_noise)

    # define scarlet frame and observation
    frame = scarlet.Frame(images.shape, psfs=model_psf, channels=bands)
    observation = scarlet.Observation(images, psfs=psfs, channels=bands).match(frame)

    # sampling
    eelr_host_ind = get_center_source(catalog, (images.shape[1], images.shape[2]))
    if use_sdss:
        unmasked_y_host_samples, unmasked_y_eelr_samples, unmasked_y_model_samples, unmasked_y_logL = get_unmasked_y_samples(catalog, eelr_host_ind, mags, num_samples, frame, observation)
    else:
        unmasked_y_host_samples = []
        unmasked_y_eelr_samples = []
        unmasked_y_model_samples = []
        mags_samples = []
        unmasked_y_logL = []
        for _ in range(num_samples):
            # fifth-band magnitude is drawn blindly; the other four come from the griz sampler
            y = np.random.uniform(18, 26)
            mags = do_one_griz_sample(z)
            mags.append(y)
            mags_samples.append(mags)
            host_sample, eelr_sample, model_sample, logL = do_one_eelr_sample(catalog, eelr_host_ind, mags, [0, 0, 0, 0, 0], frame, observation)
            unmasked_y_host_samples.append(host_sample)
            unmasked_y_eelr_samples.append(eelr_sample)
            unmasked_y_model_samples.append(model_sample)
            unmasked_y_logL.append(logL)
    
    # morphology likelihood-weighted mean and variance
    # (the host morphology is read from component index 1 of the host source)
    host_morphs = np.array([sample.components[1].morph for sample in unmasked_y_host_samples])
    # host_morphs = np.array([sample.morph for sample in unmasked_y_host_samples])
    eelr_morphs = np.array([sample.morph for sample in unmasked_y_eelr_samples])
    models = np.array(unmasked_y_model_samples)
    unmasked_y_logL = np.array(unmasked_y_logL)
    # Drop samples whose log-likelihood is more than 7 (natural-log units) below the maximum
    mask = unmasked_y_logL > (unmasked_y_logL.max() - 7)
    used_samples = np.count_nonzero(mask)
    print(f"Using {used_samples} samples.")
    # Subtracting the maximum keeps the exponentials from underflowing; the best sample gets weight 1.
    scaled_likelihoods = np.exp(unmasked_y_logL[mask] - max(unmasked_y_logL[mask]))
    host_morphs = host_morphs[mask]
    eelr_morphs = eelr_morphs[mask]
    models = models[mask]
    host_avg, host_var = weighted_mean_and_var(host_morphs, scaled_likelihoods)
    eelr_avg, eelr_var = weighted_mean_and_var(eelr_morphs, scaled_likelihoods)
    model_avg, model_var = weighted_mean_and_var(models, scaled_likelihoods)

#     print(mags_avg - mags_avg[0])
#     sdss_mags = np.nan_to_num(source_info['MAG_AB_LINEONLY'], nan=30.0)
#     print(sdss_mags - sdss_mags[0])
    
#     # zero out borders to remove artifacts due to PSF
#     zero_borders(host_avg)
#     zero_borders(eelr_avg)
    
#     # low pass filter
#     eelr_avg_filtered = gaussian_filter(eelr_avg, sigma=2)

#     # EELR angle intensities
#     num_bins = 12  # number of angle bins
#     angles_per_bin = 5  # number of angles to sample per bin
#     eelr_polar = polar_projection(eelr_avg_filtered,
#                                   (host_avg.shape[0] // 2, host_avg.shape[1] // 2),
#                                   min(eelr_avg.shape[0] // 2 - 1, eelr_avg.shape[1] // 2 - 1),
#                                   resolution=num_bins * angles_per_bin)
#     # only consider pixels r_thresh or further from the center (units depend on polar resolution)
#     r_thresh = 2
#     angle_intensities = eelr_polar[r_thresh:].sum(axis=0)
#     angle_intensities = np.array([sum(angle_intensities[i:i+angles_per_bin]) for i in range(0, num_bins*angles_per_bin, angles_per_bin)])

#     # peaks in angle space
#     cyclic_peak_inds = get_cyclic_peak_inds(angle_intensities)
    
    result = {"id": objid,
              "host_avg": host_avg,
              "host_var": host_var,
              "eelr_avg": eelr_avg,
              "eelr_var": eelr_var,
              "model_avg": model_avg,
              "model_var": model_var,
              "num_samples": used_samples,
#               "peaks": cyclic_peak_inds,
           }
    if not use_sdss:
        result["z"] = z
        # keep only the magnitude draws of the samples that passed the likelihood cut
        mags_samples = np.array(mags_samples)[mask]
        result["mags_avg"], result["mags_var"] = weighted_mean_and_var(mags_samples, scaled_likelihoods)
    
    if gen_fig:
        gen_eelr_fig(result)
    
    return result

In [None]:
# Which set of objects to model: "first" (objects in the FIRST catalog) or "full".
sample = "full"

if sample == "first":
    source_inds = first_ids
elif sample == "full":
    source_inds = source_list["OBJID"].astype(int)
else:
    # Without this, a typo would silently reuse a stale `source_inds` (or raise NameError).
    raise ValueError(f"Unknown sample {sample!r}; expected 'first' or 'full'")

# Failed sources come back as None (see safe_func) so that one bad object
# does not abort the whole batch.
results = Parallel(n_jobs=3, verbose=11)(delayed(safe_func(model_eelr))(i, num_samples=50, use_sdss=True, z_noise=0) for i in source_inds)

## Save/Load Results

In [None]:
with open("results_sdss_fixedvar.pickle", "wb") as outfile:
    pickle.dump(results, outfile)

In [None]:
with open("results_sdss.pickle", "rb") as infile:
    results = pickle.load(infile)

## Regenerate EELR Figures

In [None]:
sample = "full"

if sample == "first":
    source_inds = first_ids
elif sample == "full":
    source_inds = source_list["OBJID"].astype(int)

Parallel(n_jobs=6, verbose=11)(delayed(gen_eelr_fig)(res) for res in results)

## Differential Evaluation

In [None]:
with open("results_sdss.pickle", "rb") as infile:
    sdss_results = pickle.load(infile)

with open("results_gp_noisyz.pickle", "rb") as infile:
    gp_results = pickle.load(infile)

In [None]:
from scipy.stats import pearsonr


def _result_field(result, *keys):
    """Return result[key] for the first of `keys` that is present.

    The loaded pickles may use either the "*_bayes_mean" / "*_bayes_var" names
    or the "*_avg" / "*_var" names written by model_eelr.
    """
    for key in keys:
        if key in result:
            return result[key]
    raise KeyError(f"None of {keys} found in result {result.get('id')}")


sdss_mse, sdss_totinten, sdss_meanintenvar = [], [], []
gp_mse, gp_totinten, gp_meanintenvar = [], [], []
correlations = []

# half-size (pixels) of the central window used for intensity and correlation
win_half = 15

for i in range(len(sdss_results)):
    # A source that failed in either run cannot be compared.
    if sdss_results[i] is None or gp_results[i] is None: continue
    assert(sdss_results[i]["id"] == gp_results[i]["id"])

    sdss_model = _result_field(sdss_results[i], "model_bayes_mean", "model_avg")
    gp_model = _result_field(gp_results[i], "model_bayes_mean", "model_avg")
    sdss_eelr = _result_field(sdss_results[i], "eelr_bayes_mean", "eelr_avg")
    gp_eelr = _result_field(gp_results[i], "eelr_bayes_mean", "eelr_avg")
    sdss_eelr_var = _result_field(sdss_results[i], "eelr_bayes_var", "eelr_var")
    gp_eelr_var = _result_field(gp_results[i], "eelr_bayes_var", "eelr_var")

    images, psfs = get_images_and_psfs(sdss_results[i]["id"])
    sdss_mse.append(np.square(sdss_model - images).sum() / images.size)
    gp_mse.append(np.square(gp_model - images).sum() / images.size)

    mid_r = sdss_eelr.shape[0] // 2
    mid_c = sdss_eelr.shape[1] // 2
    cropped_sdss_eelr_bayes_mean = sdss_eelr[mid_r-win_half:mid_r+win_half, mid_c-win_half:mid_c+win_half]
    cropped_gp_eelr_bayes_mean = gp_eelr[mid_r-win_half:mid_r+win_half, mid_c-win_half:mid_c+win_half]

    sdss_totinten.append(cropped_sdss_eelr_bayes_mean.sum())
    gp_totinten.append(cropped_gp_eelr_bayes_mean.sum())

    sdss_meanintenvar.append(sdss_eelr_var.mean())
    gp_meanintenvar.append(gp_eelr_var.mean())

    correlations.append(pearsonr(cropped_sdss_eelr_bayes_mean.ravel(), cropped_gp_eelr_bayes_mean.ravel())[0])

sdss_mse, sdss_totinten, sdss_meanintenvar = np.array(sdss_mse), np.array(sdss_totinten), np.array(sdss_meanintenvar)
gp_mse, gp_totinten, gp_meanintenvar = np.array(gp_mse), np.array(gp_totinten), np.array(gp_meanintenvar)
correlations = np.array(correlations)

In [None]:
# y = x reference line spanning the full range of BOTH series, so that it
# always crosses the whole scatter cloud.
mse_lo = min(sdss_mse.min(), gp_mse.min())
mse_hi = max(sdss_mse.max(), gp_mse.max())

plt.scatter(sdss_mse, gp_mse)
plt.loglog([mse_lo, mse_hi], [mse_lo, mse_hi], color="red")
plt.xlabel("SDSS MSE")
plt.ylabel("GP MSE")
# plt.savefig("sdss_vs_gp_mse.png", dpi=200, bbox_inches="tight")
plt.show()

print("SDSS Median MSE:", np.median(sdss_mse))
print("GP Median MSE:", np.median(gp_mse))

print("Proportion where SDSS MSE < GP MSE:", np.count_nonzero(sdss_mse < gp_mse) / len(sdss_mse))

In [None]:
# y = x reference line spanning the full range of BOTH series, so that it
# always crosses the whole scatter cloud.
totinten_lo = min(sdss_totinten.min(), gp_totinten.min())
totinten_hi = max(sdss_totinten.max(), gp_totinten.max())

plt.scatter(sdss_totinten, gp_totinten)
plt.loglog([totinten_lo, totinten_hi], [totinten_lo, totinten_hi], color="red")
plt.xlabel("SDSS Total Intensity")
plt.ylabel("GP Total Intensity")
# plt.savefig("sdss_vs_gp_intensity.png", dpi=200, bbox_inches="tight")
plt.show()

print("SDSS Median Total Intensity:", np.median(sdss_totinten))
print("GP Median Total Intensity:", np.median(gp_totinten))

print("Proportion where SDSS intensity > GP intensity:", np.count_nonzero(sdss_totinten > gp_totinten) / len(sdss_totinten))

In [None]:
# y = x reference line spanning the full range of BOTH series, so that it
# always crosses the whole scatter cloud.
intenvar_lo = min(sdss_meanintenvar.min(), gp_meanintenvar.min())
intenvar_hi = max(sdss_meanintenvar.max(), gp_meanintenvar.max())

plt.scatter(sdss_meanintenvar, gp_meanintenvar)
plt.plot([intenvar_lo, intenvar_hi], [intenvar_lo, intenvar_hi], color="red")
plt.xlabel("SDSS MPIV")
plt.ylabel("GP MPIV")
# plt.savefig("sdss_vs_gp_intenvar.png", dpi=200, bbox_inches="tight")
plt.show()

print("SDSS Median MPIV:", np.median(sdss_meanintenvar))
print("GP Median MPIV:", np.median(gp_meanintenvar))

print("Proportion where SDSS intensity var < GP intensity var:", np.count_nonzero(sdss_meanintenvar < gp_meanintenvar) / len(sdss_meanintenvar))

In [None]:
plt.hist(correlations)
plt.xlabel("Correlation")
plt.ylabel("Frequency")
plt.savefig("sdss_vs_gp_noisyz_correlation.png", dpi=200, bbox_inches="tight")
plt.show()

## Comparison with FIRST

In [None]:
# # TODO: consider distance between peaks?

# jetlikes = [source_inds[i] for i in range(len(source_inds)) if results[i] and len(results[i]["peaks"]) == 2]
# print(len(jetlikes))

In [None]:
# first_detected, _, first_detected_inds = np.intersect1d(jetlikes, first_ids, return_indices=True)
# print(f"Detected {len(first_detected)} of {len(first_ids)} FIRST AGNs")
# print(first_detected)

In [None]:
# first_missed = np.delete(first_ids, first_detected_inds)
# print(first_missed)