# Demonstration of `photutils.psf` with an image-based PSF Model

In [None]:
import os
import sys

import numpy as np

from astropy import units as u
from astropy.table import Table
from astropy.io import fits
from astropy import wcs

%matplotlib inline
from matplotlib import style, pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Notebook-wide plotting defaults.
plt.rcParams['image.cmap'] = 'viridis'
# origin='lower' puts array row 0 at the bottom of displayed images
plt.rcParams['image.origin'] = 'lower'
# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and later
# releases dropped the old names, so use whichever name this matplotlib provides.
_seaborn_deep = 'seaborn-v0_8-deep' if 'seaborn-v0_8-deep' in style.library else 'seaborn-deep'
plt.rcParams['axes.prop_cycle'] = style.library[_seaborn_deep]['axes.prop_cycle']
plt.rcParams['figure.figsize'] = (14, 8)
plt.rcParams['axes.titlesize'] = plt.rcParams['axes.labelsize'] = 16
plt.rcParams['xtick.labelsize'] = plt.rcParams['ytick.labelsize'] = 14

plt.rcParams['image.interpolation'] = 'nearest'

In [None]:
import photutils
from photutils import psf
from astropy.modeling import models

# Display the installed photutils version (useful when reproducing these results).
photutils.__version__

# Build a simulated image with a funky elliptical Moffat-like PSF (but no noise)

In [None]:
# An intentionally "funky" PSF: offset the inputs, rotate by -20 degrees, stretch
# the second (y) coordinate by 1.5, then evaluate a Moffat profile centered at 0.
# The pieces are composed left to right exactly as a single expression would be,
# so the compound model's leaves (and parameter names) keep the same order.
shift_xy = models.Shift(-5) & models.Shift(2)
stretch_y = models.Identity(1) & models.Scale(1.5)
psfmodel = shift_xy | models.Rotation2D(-20) | stretch_y | models.Moffat2D(1, 0, 0, 6, 4.76)

# Render the offset model on a temporary box and show it (transposed via .T).
psfmodel.bounding_box = ((-10, 10), (-10, 10))
psfim = psfmodel.render().T
plt.imshow(psfim)

# Zero both input shifts so the profile is centered, then render the template image.
psfmodel.offset_0 = 0
psfmodel.offset_1 = 0
psfimcen = psfmodel.render()
del psfmodel.bounding_box

psfmodel

In [None]:
# Scatter 100 copies of the PSF at random positions, with random non-negative
# amplitudes, onto a blank, noiseless 100x100 image.
rng = np.random.RandomState(0)  # fixed seed so re-running the notebook reproduces the image

im = np.zeros((100, 100))

amps = rng.randn(100)**2
xs = im.shape[1] * rng.rand(amps.size)  # column (x) positions
ys = im.shape[0] * rng.rand(amps.size)  # row (y) positions

# The Moffat amplitude's parameter name carries a suffix that depends on the
# Moffat's position among the compound model's leaves. Look it up rather than
# hard-coding a suffix; a wrong name just sets an unused attribute silently.
amp_param = next(name for name in psfmodel.param_names if name.startswith('amplitude'))

for x, y, amp in zip(xs, ys, amps):
    setattr(psfmodel, amp_param, amp)
    # The leading Shift(offset_0) & Shift(offset_1) act on x and y respectively,
    # so shifting by (-x, -y) centers the profile at (x, y).
    psfmodel.offset_0 = -x
    psfmodel.offset_1 = -y

    psfmodel.render(im)
plt.imshow(im)

## Now we use `FittableImageModel` on a *rendered* version of the PSF model, with no pixel subsampling

In [None]:
# Wrap the centered, noiseless PSF rendering in an image-based model.
# NOTE(review): normalize=True is expected to rescale the image to unit total flux —
# confirm against the FittableImageModel docs for the photutils version in use.
psf_im_model = psf.FittableImageModel(psfimcen, normalize=True)

plt.imshow(psfimcen)
plt.colorbar();

In [None]:
# Render the image-based model on a temporary box so it can be displayed, then
# delete the box so the model's default bounding box is restored.
render_box = ((-10, 10), (-10, 10))
psf_im_model.bounding_box = render_box
psfrendered = psf_im_model.render()
del psf_im_model.bounding_box

plt.imshow(psfrendered)
plt.colorbar()

## Now let's try doing photometry

First we need to find stars.  We'll use the DAOFIND star-finding algorithm from DAOPHOT (which at its core is the same as in most other PSF photometry tools).

First we estimate the background noise level in the image (a robust, sigma-clipped RMS) to give us some guess as to what might be a good threshold for star-finding.

In [None]:
from astropy.stats import SigmaClip

# Biweight-scale RMS of the image after 3-sigma clipping.
rms_estimator = photutils.background.BiweightScaleBackgroundRMS(sigma_clip=SigmaClip(3))
# NOTE: despite its name, `bkg_var` holds an RMS (standard-deviation-like) value,
# not a variance; the name is kept because later cells use it.
bkg_var = rms_estimator(im)
bkg_var

Then we create a `DAOStarFinder` object and run it on the image.

In [None]:
# DAOStarFinder lives in photutils.detection; importing it from there avoids
# relying on a `photutils.findstars` attribute that the package does not provide.
from photutils.detection import DAOStarFinder

# Threshold at half the background RMS estimated above; fwhm=5 pixels is
# presumably a rough guess at the simulated PSF's width.
star_finder = DAOStarFinder(threshold=bkg_var/2, fwhm=5)
found_stars = star_finder(im)

plt.imshow(im)
plt.scatter(found_stars['xcentroid'], found_stars['ycentroid'], color='k')

found_stars

And then we create the object to do the photometry, and run it on the table of stars we found.

In [None]:
# Positional args: grouping (stars closer than 10 px are fit together), no
# background estimator, the image-based PSF model, and a 5x5 fit shape.
ph = psf.BasicPSFPhotometry(psf.DAOGroup(10), None, psf_im_model,
                            (5, 5), aperture_radius=10)

# Photometry expects initial guesses in x_0 / y_0 / flux_0 columns. Rename on a
# copy so the `found_stars` table displayed above is left untouched and this cell
# can be re-run any number of times without re-running the star finder.
init_guesses = found_stars.copy()
for old_name, new_name in (('xcentroid', 'x_0'),
                           ('ycentroid', 'y_0'),
                           ('flux', 'flux_0')):
    if old_name in init_guesses.colnames:
        init_guesses.rename_column(old_name, new_name)

res = ph.do_photometry(im, init_guesses)


plt.imshow(im)
plt.colorbar()
# black = initial guesses, red = fitted positions
plt.scatter(res['x_0'], res['y_0'], color='k')
plt.scatter(res['x_fit'], res['y_fit'], color='r', s=3, lw=0)

res

And now we make a residual image to see how well it did.

In [None]:
# Side-by-side: the simulated image and the residual produced by the photometry object.
vmin, vmax = -.3, 2.  # shared color limits so both panels are directly comparable

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.imshow(im, vmin=vmin, vmax=vmax)
ax1.set_title('Simulated image')

ax2.imshow(ph.get_residual_image(), vmin=vmin, vmax=vmax)
ax2.scatter(res['x_fit'], res['y_fit'], color='r', s=3, lw=0)
ax2.set_title('Residual from ph.get_residual_image()');

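Well, that looks OK, except that it looks ugly because the fitting region (`fitshape=(5, 5)`) is smaller than the PSF, so the residual image only has the stars removed inside those small boxes.  So let's try subtracting the *full* PSF model instead.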

In [None]:
# No `subshape` is passed here, so presumably the whole PSF model is subtracted at
# each fitted position rather than just the fit box — confirm in the photutils docs.
subtracted_image = psf.subtract_psf(im, psf_im_model, res)

vmin, vmax = -.3, 2.  # same color limits as the previous comparison

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

ax1.imshow(im, vmin=vmin, vmax=vmax)
ax1.set_title('Simulated image')

ax2.imshow(subtracted_image, vmin=vmin, vmax=vmax)
ax2.scatter(res['x_fit'], res['y_fit'], color='r', s=3, lw=0)
ax2.set_title('Residual from psf.subtract_psf()');

# Simulated NIRCam data 

Now let's try something like the above, but with simulated NIRCam data, using an oversampled PSF.

In principle this is one of two modes one might use with real JWST data.  For many cases, using a provided PSF (or one generated with `webbpsf`) will be sufficient.  But `photutils` will also support high-level tools to build "empirical" PSFs, i.e. PSFs built directly from the image.

You will need to download the simulated NIRCam image:

https://stsci.box.com/s/z2sbv2vuqbtsj75fnjdnalnsrvrdcvgt
(the downloaded file should be called `simulated_nircam_1.fits`)

and PSF image:

https://stsci.box.com/s/5kxh7vsvctc5u10ovvdeyv8n6w5tcds0
(the downloaded file should be called `simulated_nircam_psf_1.fits`)

Place both of these files in the directory from which you run this notebook.

In [None]:
# Both downloads share a stem; the PSF file inserts '_psf' before the trailing '_1'.
nircam_stem = 'simulated_nircam'
im1fn = nircam_stem + '_1.fits'
psf1fn = nircam_stem + '_psf_1.fits'

In [None]:
# fits.getdata opens the file, reads the requested HDU and closes it again, so
# no FITS file handles are left open (fits.open without close() leaked them).
# The science image is in extension 1.
im1, im1h = fits.getdata(im1fn, ext=1, header=True)
im1wcs = wcs.WCS(im1h)

# The PSF image is in the primary HDU.
psf1, psf1h = fits.getdata(psf1fn, ext=0, header=True)
psf1wcs = wcs.WCS(psf1h)

In [None]:
# Quick-and-easy display rescaling with astropy.visualization: clip to the central
# 99th-percentile interval, then apply a log stretch. `viz` is the combined transform
# and can be called on any image array.
from astropy.visualization import LogStretch, PercentileInterval

log_stretch = LogStretch()
percentile_interval = PercentileInterval(99)
viz = log_stretch + percentile_interval

In [None]:
# The full simulated image, using the log/percentile display transform `viz`.
plt.imshow(viz(im1))
plt.title('Simulated NIRCam image (log stretch)');

OK, let's histogram it so we can see roughly where the threshold should be.

In [None]:
# Log-scaled histogram of pixel values, used to pick a star-finding threshold.
plt.hist(im1.ravel(), bins=100, histtype='step', range=(-10, 200), log=True)
plt.xlabel('Pixel value')
# trailing ';' suppresses the repr of the returned object
plt.ylabel('Number of pixels');

In [None]:
# Detect stars in the NIRCam image: threshold=100 is read off the histogram above,
# and fwhm=5 is the same rough width used for the simulated image earlier.
dsf = photutils.DAOStarFinder(threshold=100, fwhm=5)
found_stars = dsf(im1)

# Rename centroid/flux columns to the initial-guess names used by PSF photometry.
for old_name, new_name in (('xcentroid', 'x_0'),
                           ('ycentroid', 'y_0'),
                           ('flux', 'flux_0')):
    found_stars.rename_column(old_name, new_name)
found_stars

In [None]:
# Overlay detections on the central part of the image.
zoom_window = (500, 1500)
plt.imshow(viz(im1))
plt.scatter(found_stars['x_0'], found_stars['y_0'], lw=0, s=3, c='k')
plt.xlim(*zoom_window)
plt.ylim(*zoom_window)

Now we build the actual PSF model from the PSF file. This uses external knowledge that the PSF is 5x oversampled relative to the image. The single oversampling factor below only works as-is because the sampling is the same along both axes (square pixels), which is true here.

In [None]:
# note that this is *not* the same pixel scale as the image above
plt.imshow(viz(psf1))

psfmodel = psf.FittableImageModel(psf1, oversampling=5)

Let's now zoom in on part of the image and see how the model compares to a star at the actual image pixel scale.

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2)

ax1.imshow(viz(im1))
ax1.set_xlim(1080, 1131)
ax1.set_ylim(1100, 1151)
ax1.set_title('Simulated image')

# np.mgrid returns the grid that varies along axis 0 (rows, i.e. y) first, so
# unpack it as (yg, xg). Evaluating psfmodel(xg, yg) then yields an array
# indexed [y, x], matching imshow's convention; the old `xg, yg = ...`
# unpacking displayed the PSF transposed.
yg, xg = np.mgrid[-25:25, -25:25]
ax2.imshow(viz(psfmodel(xg, yg)))
ax2.set_title('PSF');

Now we build a PSF photometry runner that is auto-configured to work basically the same as DAOPHOT.  All of the steps in photometry are customizable if you like, but for now we'll just use this because it will be familiar to many people.

In [None]:
# DAOPHOT-style runner: star finding, grouping, fitting and subtraction in one object.
psfphot = psf.DAOPhotPSFPhotometry(
    crit_separation=5,
    threshold=100,
    fwhm=5,
    psf_model=psfmodel,
    fitshape=(9, 9),
    niters=1,
    aperture_radius=5,
)

# Fit only the first `n_fit_stars` rows of the detection table.
n_fit_stars = 100
results = psfphot(im1, found_stars[:n_fit_stars])

Now let's inspect the residual image as a whole, and zoomed in on a few particular stars.

In [None]:
# Residual image for the whole frame, shown with the same display transform.
res_im1 = psfphot.get_residual_image()
res_im1_view = viz(res_im1)
plt.imshow(res_im1_view)
results

In [None]:
# Zoom in on a few fitted stars, showing the original image next to the residual.
# Ideally one of these is an isolated star and one is near the crowded core, but
# which rows qualify depends on the fit; edit `requested_stars` to inspect others.
requested_stars = (10, 50)
window_rad = 20

# Indices past the end of the results table are skipped rather than raising.
stars_to_inspect = [i for i in requested_stars if i < len(results)]

# The display transforms do not depend on the star, so compute them once.
im1_view = viz(im1)
res_im1_view = viz(res_im1)

for i in stars_to_inspect:
    resulti = results[i]
    x_fit, y_fit = resulti['x_fit'], resulti['y_fit']

    fig, axes = plt.subplots(1, 2)
    panels = [(im1_view, 'Original image (star #{})'.format(i)),
              (res_im1_view, 'Subtracted image (star #{})'.format(i))]

    for ax, (view, title) in zip(axes, panels):
        ax.imshow(view)
        ax.set_xlim(x_fit - window_rad, x_fit + window_rad)
        ax.set_ylim(y_fit - window_rad, y_fit + window_rad)
        # red = initial guess (x_0, y_0), white = fitted position
        ax.scatter([resulti['x_0']], [resulti['y_0']], color='r', s=7)
        ax.scatter([x_fit], [y_fit], color='w', s=7)
        ax.set_title(title)

OK, looks like at least some of them worked great, but in the crowded areas more iterations/tweaks to the input parameters are needed.  See if you can tweak the parameters to make it better!

For this simulated data set we only have one band, so there's not much "science" output to show... But below you can see how to get magnitudes out.

In [None]:
# A calibration helper does not exist yet... but the plan is that something like
#     jwst_calibrated_mags(results['flux_fit'], im1h)
# will be available at launch. Until then, compute magnitudes by hand:

# This zero-point is just a made-up number right now,
# but it's something the instrument team will provide.
zero_point = 31.2

flux_fit = np.asarray(results['flux_fit'], dtype=float)
# A magnitude is only defined for positive flux: mark non-positive (or NaN) fits
# as NaN explicitly instead of letting log10 emit warnings and produce inf/NaN.
with np.errstate(divide='ignore', invalid='ignore'):
    inst_mag = np.where(flux_fit > 0, -2.5 * np.log10(flux_fit), np.nan)
results['cal_mag'] = zero_point + inst_mag

# if you scroll to the right you'll see the new column
results

#### Using DS9 or ginga from python (OPTIONAL)

The next few cells showing calls to `imexam` are not critical, but if you have `imexam` installed you can use it to view the found stars in ds9 (by default) or ginga.  These cells provide a simple example of how to use `imexam` to connect to an external interactive image viewer.

The example here uses `ginga`.  For ds9 to work as the external viewer, you must uncomment the code noted below and have both ds9 (available at http://ds9.si.edu/site/Download.html) and the XPA executables (available at http://ds9.si.edu/site/XPA.html) installed.  **MacOS/Linux Users**: An easy way to install ds9 is with the command:

`conda install ds9 --override-channels -c http://ssb.stsci.edu/astroconda`

In [None]:
# Import and connect in their own cell, separate from imex.load_fits, so that the
# connection to the viewer is established BEFORE we try to load the FITS image.
import imexam

viewer_name = 'ginga'
imex = imexam.connect(viewer=viewer_name)
# To use ds9 instead, replace the connect call above with:
# imex = imexam.connect()

In [None]:
imex.load_fits(im1fn)
# Apply imexam's display scaling to the loaded image
# (presumably its default intensity scale — confirm in the imexam docs).
imex.scale()
# ds9 only (this feature is NOT implemented for ginga): uncomment to mark the found stars.
# NOTE(review): ds9 presumably uses 1-based FITS pixel coordinates while photutils
# positions are 0-based, so a +1 offset may be needed — confirm before relying on it.
# tomark = [(row['x_0'], row['y_0'], i) for i, row in enumerate(found_stars)]
# imex.mark_region_from_array(tomark)

In [None]:
# Shut down the connection to the external viewer opened with imexam.connect
imex.close()

## Exercises

As you can see, particularly for brighter stars, the above procedure doesn't do well, primarily because the star finder doesn't find them efficiently along with the faint ones.  Try manually editing the `found_stars` table and inserting a few bright stars by hand.  See if you can get them to subtract well.

Now try playing around with the various options and see if you can do better *automatically*.