### This Notebook applies a correction to saturated pixels in JWST images using matched JWST PSFs.

Run the Notebook "analyze_ngc1365_miri_psf_webbpsf.ipynb" first! That will produce the correction factors stored in json files. This Notebook will read the correction factors and saturated spot center coordinates then apply that to the real data.

Daizhong Liu (dzliu@mpe.mpg.de; astro.dzliu@gmail.com)


In [1]:
import os, sys, re, copy, glob, shutil, json
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.visualization import simple_norm
from astropy.wcs import WCS, FITSFixedWarning
from astropy.wcs.utils import proj_plane_pixel_area, proj_plane_pixel_scales
from astropy.wcs.utils import wcs_to_celestial_frame
from astropy.nddata import Cutout2D
from astropy.coordinates import SkyCoord, FK5, ICRS
from astropy.units import Quantity
from collections import OrderedDict
from matplotlib import ticker
from mpl_toolkits.axes_grid1.axes_divider import make_axes_locatable
from reproject import reproject_interp
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
warnings.filterwarnings("ignore", category=FITSFixedWarning)
plt.rcParams['savefig.dpi'] = 300
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['xtick.top'] = True
plt.rcParams['ytick.right'] = True


In [2]:
# Load functions from 

%run analyze_ngc1365_miri_psf_webbpsf.ipynb

assert 'add_mock_wcs_to_psf' in globals()
assert 'get_position_angle' in globals()


NIRCam F200W position angle: 261.1078648908976
NIRCam F300M position angle: 261.1090399455477
NIRCam F335M position angle: 261.1089846394891
NIRCam F360M position angle: 261.0893218163628
MIRI F770W position angle: 265.5578648908976
MIRI F1000W position angle: 265.5578648908976
MIRI F1130W position angle: 265.5578648908976
MIRI F2100W position angle: 265.5578648908976
NIRCam F300M, applying astrometry correction wcs shift 0.0 arcsec -0.015 arcsec
NIRCam F335M, applying astrometry correction wcs shift 0.0 arcsec -0.015 arcsec
NIRCam F360M, applying astrometry correction wcs shift 0.0 arcsec -0.015 arcsec
MIRI F770W, applying astrometry correction wcs shift 0.16 arcsec -0.11 arcsec
MIRI F1000W, applying astrometry correction wcs shift 0.16 arcsec -0.11 arcsec
MIRI F1130W, applying astrometry correction wcs shift 0.05 arcsec -0.06 arcsec
MIRI F2100W, applying astrometry correction wcs shift 0.05 arcsec -0.06 arcsec
NIRCam F200W pixelscale: 0.00311
NIRCam F300M pixelscale: 0.0063
NIRCam F3

In [3]:
# Here are the results from running "analyze_ngc1365_miri_psf_webbpsf.ipynb" in advance

center_skycoord_lists = OrderedDict()
cutout_size_lists = OrderedDict()
rescaling_base_lists = OrderedDict()
rescaling_factor_lists = OrderedDict()
saturated_core_lists = OrderedDict()

for json_file in [
        'out_analyzing_ngc1365_miri_psf_53.40154167_-36.14041694_fk5.json', 
        'out_analyzing_ngc1365_miri_psf_53.40254792_-36.13764111_fk5.json', 
        'out_analyzing_ngc1365_miri_psf_53.40321167_-36.13849722_fk5.json', 
        'out_analyzing_ngc1365_miri_psf_53.40169417_-36.13840611_fk5.json', 
    ]:
    with open(json_file, 'r') as fp:
        info_list = json.load(fp)
        for key in info_list:
            info_dict = info_list[key]
            if info_dict['saturated_core'][0] > 0.0:
                if key not in center_skycoord_lists:
                    center_skycoord_lists[key] = []
                    cutout_size_lists[key] = []
                    rescaling_base_lists[key] = []
                    rescaling_factor_lists[key] = []
                    saturated_core_lists[key] = []
                # 
                ra = Quantity(*info_dict['star_skycoord_ra'])
                dec = Quantity(*info_dict['star_skycoord_dec'])
                frame = info_dict['star_skycoord_frame']
                center_skycoord_lists[key].append(SkyCoord(ra, dec, frame=frame))
                cutout_size_lists[key].append(Quantity(*info_dict['cutout_size']))
                rescaling_base_lists[key].append(info_dict['rescaling_base'])
                rescaling_factor_lists[key].append(info_dict['rescaling_factor'])
                saturated_core_lists[key].append(Quantity(*info_dict['saturated_core']))

print('center_skycoord_lists.keys()', center_skycoord_lists.keys())


center_skycoord_lists.keys() odict_keys(['NIRCam F200W', 'NIRCam F300M', 'NIRCam F335M', 'NIRCam F360M', 'MIRI F770W', 'MIRI F1000W', 'MIRI F1130W', 'MIRI F2100W'])


In [4]:
def fix_saturated_pixels(
        sci_image = None, 
        sci_header = None, 
        sci_posangle = None, 
        rms_image = None, 
        rms_header = None, 
        psf_image = None, 
        psf_header = None, 
        center_skycoord_list = None, # list of SkyCoord
        cutout_size_list = None, # list of Quantity (arcsec)
        rescaling_factor_list = None, 
        rescaling_base_list = None, 
        saturated_core_list = None, # list of Quantity (arcsec)
        out_sci_file = None, 
        out_rms_file = None, 
    ):
    
    # global sci_files
    # global sci_images
    # global sci_headers
    # global sci_posangles
    # global rms_files
    # global rms_images
    # global rms_headers
    # global psf_files
    # global psf_images
    # global psf_headers
    # global cutout_sizes
    
    print('Process fix_saturated_pixels started..')
    
    sci_wcs = WCS(sci_header, naxis=2)

    pixel_scale = get_pixel_scale(sci_wcs)

    psf_oversample = get_oversampling_factor(psf_header)

    sci_position_angle = sci_posangle

    for center_skycoord, cutout_size, rescaling_factor, rescaling_base, saturated_core \
        in zip(center_skycoord_list, cutout_size_list, rescaling_factor_list, rescaling_base_list, saturated_core_list):
        
        # make cutout for the sci data
        sci_image_cutout = make_cutout(
            image = sci_image, 
            wcs = sci_wcs, 
            scoord = center_skycoord,
            cutsize = cutout_size, 
        )

        # make cutout for the rms data
        (ymin, ymax), (xmin, xmax) = sci_image_cutout.bbox_original
        rms_image_cutout_data = rms_image[ymin:ymax+1, xmin:xmax+1]

        # add wcs to the psf data and cutout to the specified size
        psf_image_cutout = add_mock_wcs_to_psf(
            psf_image, 
            psf_header, 
            center_ra_dec = center_skycoord, 
            pixel_scale = pixel_scale / psf_oversample,
            cutout_size = cutout_size * 1.25, # make it slightly larger for later reproject/cutout
            position_angle = sci_position_angle,
        )

        # rescale psf data
        #print('np.nanmax(psf_image)', np.nanmax(psf_image))
        #print('np.nanmax(psf_image_cutout.data)', np.nanmax(psf_image_cutout.data))
        psf_image_cutout.data = psf_image_cutout.data * rescaling_factor + rescaling_base
        #print('np.nanmax(psf_image_cutout.data)', np.nanmax(psf_image_cutout.data))

        # resample psf cutout to the sci original-pixel-scale cutout image dimension 
        psf_image_resampled_data = reproject_interp(
            (psf_image_cutout.data, psf_image_cutout.wcs), 
            sci_image_cutout.wcs, 
            sci_image_cutout.data.shape,
            return_footprint = False,
        )

        # prepare pixel grid arrays
        ny, nx = psf_image_resampled_data.shape
        gy, gx = np.mgrid[0:ny,0:nx]
        sc = center_skycoord.transform_to(wcs_to_celestial_frame(sci_image_cutout.wcs))
        cx, cy = sci_image_cutout.wcs.all_world2pix([sc.ra.deg], [sc.dec.deg], 0)
        cx, cy = cx[0], cy[0]
        gradii = np.sqrt((gx-cx)**2 + (gy-cy)**2) * pixel_scale.to(u.arcsec)

        # fix sci image data within the cutout
        sci_image_cutout_data = copy.copy(sci_image_cutout.data)
        saturated_mask = (gradii <= saturated_core)
        #sci_image_cutout_data[np.logical_and(saturated_mask, np.isnan(sci_image_cutout_data))] = 0.0
        sci_image_cutout_data[np.logical_and(saturated_mask, sci_image_cutout_data==0.0)] = np.nan # saturated pixels NaN
        
        psf_weight = np.full(psf_image_resampled_data.shape, fill_value=0.0)
        blo, bhi = 0.95, 1.25
        psf_weight[gradii/saturated_core<=blo] = 1.0
        buffer_area = np.logical_and(gradii/saturated_core>blo, gradii/saturated_core<bhi) # buffer 
        psf_weight[buffer_area] = 1.0 - (gradii[buffer_area]/saturated_core-blo)/(bhi-blo)
        # saturated core radius is where psf surf. dens. equals that of sci data.
        sci_image_fixed_data = psf_image_resampled_data * psf_weight + \
                               sci_image_cutout_data * (1.0 - psf_weight)
        mask = np.isnan(sci_image_cutout_data)
        sci_image_fixed_data[mask] = psf_image_resampled_data[mask]
        mask = (sci_image_fixed_data < sci_image_cutout_data)
        sci_image_fixed_data[mask] = sci_image_cutout_data[mask]
        #print('np.nanmax(sci_image_fixed_data)', np.nanmax(sci_image_fixed_data))
        #print('np.nanmax(sci_image_cutout.data)', np.nanmax(sci_image_cutout.data))

        # combine rms too
        psf_image_resampled_error = psf_image_resampled_data * 0.20 # 20% error
        rms_image_cutout_data[np.logical_and(saturated_mask, np.isnan(rms_image_cutout_data))] = 0.0
        #rms_image_fixed_data = psf_image_resampled_error
        rms_image_fixed_data = psf_image_resampled_error * psf_weight + \
                               rms_image_cutout_data * (1.0 - psf_weight)
        mask = (rms_image_fixed_data < rms_image_cutout_data)
        rms_image_fixed_data[mask] = rms_image_cutout_data[mask]

        # fill-in back to the sci original data array
        (ymin, ymax), (xmin, xmax) = sci_image_cutout.bbox_original
        sci_image_fixed = copy.copy(sci_image)
        sci_image_fixed[ymin:ymax+1, xmin:xmax+1] = sci_image_fixed_data[:, :]

        rms_image_fixed = copy.copy(rms_image)
        rms_image_fixed[ymin:ymax+1, xmin:xmax+1] = rms_image_fixed_data[:, :]

        fixing_history_str = 'Fixed saturated pixels within Cutout2D ' + \
                str(sci_image_cutout.bbox_original) + \
                ' at RA Dec {} {}'.format(center_skycoord.to_string('hmsdms', sep=':', precision=6), \
                                          center_skycoord.frame.name.upper()) + \
                ' with PSF scaling factor {}'.format(rescaling_factor) + \
                ' and saturated core radius {} arcsec by dzliu.'.format(saturated_core)
        print(fixing_history_str)
        
        sci_header_fixed = copy.deepcopy(sci_header)
        sci_header_fixed['HISTORY'] = fixing_history_str

        rms_header_fixed = copy.deepcopy(rms_header)
        rms_header_fixed['HISTORY'] = fixing_history_str

        try:
            WCS(rms_header_fixed, naxis=2)
        except:
            for k in sci_wcs.to_header():
                rms_header_fixed[k] = sci_header_fixed[k]
        
        # output
        if os.path.isfile(out_sci_file):
            shutil.move(out_sci_file, out_sci_file+'.backup')
        fits.PrimaryHDU(data=sci_image_fixed, header=sci_header_fixed).writeto(out_sci_file)
        print('Output to {!r}'.format(out_sci_file))

        if os.path.isfile(out_rms_file):
            shutil.move(out_rms_file, out_rms_file+'.backup')
        fits.PrimaryHDU(data=rms_image_fixed, header=rms_header_fixed).writeto(out_rms_file)
        print('Output to {!r}'.format(out_rms_file))
        
        # store back to sci_image etc.
        sci_image = sci_image_fixed
        #sci_header = sci_header_fixed
        rms_image = rms_image_fixed
        #rms_header = rms_header_fixed

    print('Process fix_saturated_pixels finished!')


In [5]:
if __name__ == '__main__' and '__file__' not in globals():
    
    for key in sci_images:

        sci_file = sci_files[key]
        out_sci_file = re.sub(r'\.fits$', r'', sci_file) + '_fixedsatur.sci.dzliu.fits'
        out_rms_file = re.sub(r'\.fits$', r'', sci_file) + '_fixedsatur.rms.dzliu.fits'
        
        fix_saturated_pixels(
            sci_image = sci_images[key],
            sci_header = sci_headers[key],
            sci_posangle = sci_posangles[key],
            rms_image = rms_images[key],
            rms_header = rms_headers[key],
            psf_image = psf_images[key],
            psf_header = psf_headers[key],
            center_skycoord_list = center_skycoord_lists[key], 
            cutout_size_list = cutout_size_lists[key], 
            rescaling_factor_list = rescaling_factor_lists[key], 
            rescaling_base_list = rescaling_base_lists[key], 
            saturated_core_list = saturated_core_lists[key], 
            out_sci_file = out_sci_file, 
            out_rms_file = out_rms_file, 
        )
    
    
    
    

Process fix_saturated_pixels started..
Fixed saturated pixels within Cutout2D ((3264, 3320), (3672, 3728)) at RA Dec 03:33:36.370000 -36:08:25.501000 FK5 with PSF scaling factor 433792532.98767394 and saturated core radius 0.2573049928559211 arcsec arcsec by dzliu.
Output to '/Users/dzliu/Data/PHANGS-JWST/jwst_reprocessed/NIRCam_v0p4p2/ngc1365_nircam_lv3_f200w_i2d_align_fixedsatur.sci.dzliu.fits'
Output to '/Users/dzliu/Data/PHANGS-JWST/jwst_reprocessed/NIRCam_v0p4p2/ngc1365_nircam_lv3_f200w_i2d_align_fixedsatur.rms.dzliu.fits'
Fixed saturated pixels within Cutout2D ((3588, 3644), (3578, 3634)) at RA Dec 03:33:36.611500 -36:08:15.508000 FK5 with PSF scaling factor 20089047.750874165 and saturated core radius 0.19407031497565794 arcsec arcsec by dzliu.
Output to '/Users/dzliu/Data/PHANGS-JWST/jwst_reprocessed/NIRCam_v0p4p2/ngc1365_nircam_lv3_f200w_i2d_align_fixedsatur.sci.dzliu.fits'
Output to '/Users/dzliu/Data/PHANGS-JWST/jwst_reprocessed/NIRCam_v0p4p2/ngc1365_nircam_lv3_f200w_i2d_ali