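This notebook evaluates the log-likelihood of the PSF and background parameters obtained by wake-sleep, conditional on the true (Hubble) catalog, measured with a chi-square statistic. It produces Table 2 in the text.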

In [1]:
import numpy as np
import pathlib 

import matplotlib.pyplot as plt

import torch

from torch.utils.data import Dataset

import deblending_runjingdev.sdss_dataset_lib as sdss_dataset_lib
import deblending_runjingdev.psf_transform_lib as psf_transform_lib
import deblending_runjingdev.wake_lib as wake_lib
from deblending_runjingdev.which_device import device

from astropy.io import fits
from astropy.wcs import WCS

import os

# Load M2 data

In [2]:
# Load the SDSS image and background for the M2 field, together with the
# Hubble catalog (star locations and fluxes) that serves as the true catalog.
(sdss_image, sdss_background,
 hubble_locs, hubble_fluxes,
 sdss_data, wcs) = sdss_dataset_lib.load_m2_data()

loading sdss image from ../../sdss_stage_dir/2583/2/136/frame-r-002583-2-0136.fits
loading sdss image from ../../sdss_stage_dir/2583/2/136/frame-i-002583-2-0136.fits
loading hubble data from  ../hubble_data/hlsp_acsggct_hst_acs-wfc_ngc7089_r.rdviq.cal.adj.zpt
getting sdss coordinates from:  ../../sdss_stage_dir/2583/2/136/frame-r-002583-2-0136.fits

 aligning images. 
 Getting sdss coordinates from:  ../../sdss_stage_dir/2583/2/136/frame-i-002583-2-0136.fits


the RADECSYS keyword is deprecated, use RADESYSa. [astropy.wcs.wcs]



 returning image at x0 = 630, x1 = 310


In [3]:
# Number of stars in the Hubble catalog (rows of hubble_locs), as a length-1
# int64 tensor built directly on `device` (no float32 round-trip).
hubble_n_stars = torch.tensor([hubble_locs.shape[0]], dtype=torch.long, device=device)

# Load initial PSF and background

In [4]:
# Band indices to model. The loader output above shows the r and i frames,
# so these presumably index r and i in ugriz order.
bands = [2, 3]

# SDSS psField file for this field; its PSF is the initial ("PHOTO") estimate.
psfield_file = '../../sdss_stage_dir/2583/2/136/psField-002583-2-0136.fit'

init_psf_params = (
    psf_transform_lib
    .get_psf_params(psfield_file, bands=bands)
    .to(device)
)

In [5]:
# Initial background: one row of 3 parameters per band. Column 0 is set to the
# mean of the SDSS background image for that band; the other two columns stay
# zero (presumably the slopes of a planar background -- TODO confirm).
init_background_params = torch.zeros(len(bands), 3, device=device)
init_background_params[:, 0] = sdss_background.mean(-1).mean(-1)

In [6]:
# Model built from the initial PSF and background, with a leading batch
# dimension added to the image.
# NOTE(review): not referenced by the later visible cells; get_chi2_loss
# constructs its own ModelParams.
model_params = wake_lib.ModelParams(
    sdss_image.unsqueeze(0), init_psf_params, init_background_params
)

# Function to evaluate the log-likelihood conditional on the true catalog

In [7]:
def get_chi2_loss(powerlaw_psf_params, planar_back_params, band = 0, border = 5):
    """Chi-square misfit of the model image rendered from the true Hubble catalog.

    Builds a ``wake_lib.ModelParams`` from the given PSF and background
    parameters, renders the expected image at the Hubble star locations and
    fluxes, and returns the Pearson chi-square statistic
    ``sum((recon_mean - image)**2 / recon_mean)`` over one band, excluding a
    frame of ``border`` pixels on every side of the image.

    Parameters
    ----------
    powerlaw_psf_params : torch.Tensor
        PSF parameters passed to ``wake_lib.ModelParams``.
    planar_back_params : torch.Tensor
        Background parameters passed to ``wake_lib.ModelParams``.
    band : int, optional
        Index along the image's band axis to evaluate. The default 0 is the
        first modelled band (the r-band in this notebook).
    border : int, optional
        Number of edge pixels excluded on each side. For a 100x100 image the
        default reproduces the original 5:95 crop (image size presumably
        100x100 here -- TODO confirm).

    Returns
    -------
    numpy.ndarray
        0-d array holding the chi-square statistic.
    """
    # Pure evaluation: no gradients are needed, so skip building the graph.
    with torch.no_grad():
        # construct model from psf and background
        model_params = wake_lib.ModelParams(sdss_image.unsqueeze(0),
                                            powerlaw_psf_params,
                                            planar_back_params)

        # evaluate the model at the true hubble parameters
        recon_mean, _ = model_params.get_loss(use_cached_stars = False,
                                              locs = hubble_locs.unsqueeze(0),
                                              fluxes = hubble_fluxes.unsqueeze(0),
                                              n_stars = hubble_n_stars)

        # keep model and observed image on the same device before combining
        recon_mean = recon_mean.to(device)
        image = sdss_image.unsqueeze(0).to(device)

        # crop a symmetric border derived from the image size
        height, width = image.shape[-2:]
        rows = slice(border, height - border)
        cols = slice(border, width - border)

        # only evaluate the loss in the requested band
        chi2 = ((recon_mean - image)**2 / recon_mean)[:, band, rows, cols].sum()

    return chi2.detach().cpu().numpy()

In [11]:
# estimated psf parameters
est_psf_params = torch.Tensor(np.load(
                '../fits/starnet-iter1-powerlaw_psf_params.npy')).to(device)

# estimated background parameters
est_back_params = torch.Tensor(np.load(
    '../fits/starnet-iter1-planarback_params.npy')).to(device)


In [12]:
# Chi-square loss for every combination of initial vs. StarNet-estimated PSF
# and background. The order matches the row labels of the table below.
losses_vec = [
    get_chi2_loss(psf_params, back_params)
    for psf_params, back_params in (
        (init_psf_params, init_background_params),  # initial psf + initial background
        (est_psf_params, init_background_params),   # estimated psf + initial background
        (init_psf_params, est_back_params),         # initial psf + estimated background
        (est_psf_params, est_back_params),          # estimated psf + estimated background
    )
]

In [13]:
import pandas as pd

In [14]:
# One row per PSF/background combination, in the same order as losses_vec.
# The 'Neg. Loglik' column holds the values returned by get_chi2_loss.
chi_sq_stats_df = pd.DataFrame(
    {
        'Model Estimate': [
            'PHOTO',
            'StarNet PSF',
            'StarNet background',
            'StarNet background + PSF',
        ],
        'Neg. Loglik': losses_vec,
    }
)

# Table 2 in our paper



In [16]:
# Display Table 2.
# NOTE(review): the 'Neg. Loglik' column holds the chi-square statistic from
# get_chi2_loss (sum of squared residuals over the model mean), not a full
# negative log-likelihood.
chi_sq_stats_df

Unnamed: 0,Model Estimate,Neg. Loglik
0,PHOTO,867079.6
1,StarNet PSF,866502.9
2,StarNet background,365066.4
3,StarNet background + PSF,339508.75
