In [1]:
# Import the usual libraries
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

# Enable inline plotting
%matplotlib inline

# Progress bar
from tqdm.auto import trange, tqdm

In [2]:
import webbpsf_ext

In [3]:
# Only show webbpsf_ext log messages of level WARN and above (printed to the screen)
webbpsf_ext.setup_logging('WARN')

webbpsf_ext log messages of level WARN and above will be shown.
webbpsf_ext log outputs will be directed to the screen.


In [28]:
from scipy.interpolate import interp1d
from webbpsf_ext.webbpsf_ext_core import _transmission_map, _calc_psf_from_coeff

class NIRCam_PSF():
    """Quick NIRCam coronagraphic PSF generator for one filter / occulting-mask setup.

    Two ``NIRCam_ext`` models are built that share the same Lyot stop:
    ``nrc_on`` has the occulting mask in place and ``nrc_off`` does not.
    Reference PSFs from both models are computed once at construction time.
    ``gen_psf_xy`` can then approximate a PSF at any position as a
    transmission-weighted linear combination of the two.
    """
    
    from webbpsf_ext import NIRCam_ext
    
    def __init__(self, filter, image_mask, fov_pix, oversample, sp=None, use_coeff=True, **kwargs):
        """Build the on-mask / off-mask NIRCam models and their reference PSFs.

        Parameters
        ==========
        filter : str
            NIRCam filter name (e.g. 'F300M').
        image_mask : str or None
            Occulting mask name.
            * Names ending in 'R' (round masks) use the 'CIRCLYOT' pupil.
            * Names ending in 'B' (bar masks) use the 'WEDGELYOT' pupil and
              store a grid of on-mask PSFs along the bar.
            * ``None`` means no occulter: no pupil mask, single on-axis PSF.
        fov_pix : int
            PSF image size in pixels.
        oversample : int
            Oversampling factor given to ``NIRCam_ext``. The stored PSFs are
            requested with ``return_oversample=False``.
        sp : optional
            Source spectrum forwarded to every PSF calculation.
        use_coeff : bool
            If True, generate PSF coefficients and use ``calc_psf_from_coeff``.
            Otherwise use ``calc_psf`` directly.
        **kwargs
            Extra keyword arguments forwarded to both ``NIRCam_ext`` models.

        Raises
        ======
        ValueError
            If ``image_mask`` is not None and does not end in 'R' or 'B'.
            This is checked before any expensive PSF generation.
        """

        # Choose Lyot stop based on coronagraphic mask input; validate early so
        # an unsupported mask fails before any costly model/PSF generation.
        is_bar = False
        if image_mask is None:
            pupil_mask = None
        elif image_mask.endswith('R'):
            pupil_mask = 'CIRCLYOT'
        elif image_mask.endswith('B'):
            pupil_mask = 'WEDGELYOT'
            is_bar = True
        else:
            raise ValueError(f"Unsupported image_mask {image_mask!r}: expected None "
                             "or a mask name ending in 'R' (round) or 'B' (bar).")
            
        nrc_on = self.NIRCam_ext(filter=filter, image_mask=image_mask, pupil_mask=pupil_mask,
                                 fov_pix=fov_pix, oversample=oversample, **kwargs)

        # Same configuration and Lyot stop, but without the occulting mask
        nrc_off = self.NIRCam_ext(filter=filter, image_mask=None, pupil_mask=pupil_mask,
                                  fov_pix=fov_pix, oversample=oversample, **kwargs)
        
        # Turn off jitter (jitter_sigma = 0) in both models
        nrc_on.options['jitter_sigma'] = 0
        nrc_off.options['jitter_sigma'] = 0
        
        print('Generating initial PSFs...')
        if use_coeff:
            nrc_on.gen_psf_coeff()
            nrc_off.gen_psf_coeff()
            func_on = nrc_on.calc_psf_from_coeff
            func_off = nrc_off.calc_psf_from_coeff
        else:
            func_on = nrc_on.calc_psf
            func_off = nrc_off.calc_psf
            
        # On-mask PSF(s)
        if is_bar:
            # Bar masks: sample PSFs along the bar center at 9 idl x-positions
            # (arcsec) from -8 to +8; gen_psf_xy interpolates between them.
            xvals = np.linspace(-8, 8, 9)
            self.psf_bar_xvals = xvals
            
            if use_coeff:
                nrc_on.gen_wfemask_coeff(large_grid=True)

            psf_bar_arr = []
            for xv in tqdm(xvals, desc='Bar PSFs', leave=False):
                psf = func_on(sp=sp, return_oversample=False, return_hdul=False, 
                              coord_vals=(xv,0), coord_frame='idl')
                psf_bar_arr.append(psf)
            self.psf_on = np.array(psf_bar_arr)
        else:
            # Round mask, or no occulter: a single on-axis PSF
            self.psf_on = func_on(sp=sp, return_oversample=False, return_hdul=False)
            
        # Off-mask PSF
        self.psf_off = func_off(sp=sp, return_oversample=False, return_hdul=False)

        # Store NIRCam classes
        self.nrc_on  = nrc_on
        self.nrc_off = nrc_off
        
        # PSF generation functions for later use
        self._use_coeff = use_coeff
        self._func_on  = func_on
        self._func_off = func_off

        self.sp = sp
        
    @property
    def fov_pix(self):
        return self.nrc_on.fov_pix
    @property
    def oversample(self):
        return self.nrc_on.oversample

    @property
    def filter(self):
        return self.nrc_on.filter
    @property
    def image_mask(self):
        return self.nrc_on.image_mask
    @property
    def pupil_mask(self):
        return self.nrc_on.pupil_mask
    
    @property
    def use_coeff(self):
        return self._use_coeff
    
    def rth_to_xy(self, r, th, PA_V3=0):
        """ Convert (r,th) location to (x,y) in idl coords

        Assume (r,th) in coordinate system with North up East to the left.
        Then convert to NIRCam detector orientation (idl coord frame).
        Units assumed to be in arcsec.

        r : float or ndarray
            Radial offset(s) in arcsec.
        th : float
            Position angle (positive angles East of North) in degrees.
            Can also be an array; must match size of `r`.
        PA_V3 : float
            V3 position angle in degrees. It is added to the aperture's
            ``V3IdlYAngle`` to get the aperture PA.
        """

        # Convert to aperture PA
        PA_ap = PA_V3 + self.nrc_on.siaf_ap.V3IdlYAngle
        # Get theta relative to detector orientation (idl frame)
        th_fin = th - PA_ap
        # Return (x,y) in idl frame
        return webbpsf_ext.coords.rtheta_to_xy(r, th_fin)

    
    def gen_psf_xy(self, coord_vals, coord_frame, quick=True):
        """ Generate offset PSF
        
        Parameters
        ==========
        coord_vals : tuple or None
            Coordinates (in arcsec or pixels) to calculate field-dependent PSF.
            If multiple values, then this should be an array ([xvals], [yvals]).
        coord_frame : str
            Type of input coordinates. 

                * 'tel': arcsecs V2,V3
                * 'sci': pixels, in DMS axes orientation; aperture-dependent
                * 'det': pixels, in raw detector read out axes orientation
                * 'idl': arcsecs relative to aperture reference location.
        quick : bool
            Use a linear combination of the on-mask and off-mask PSFs, weighted
            by the coronagraphic mask throughput, to generate the PSF. This is
            much faster than the full calculation done when ``quick=False``,
            which evaluates the on-mask PSF function at the requested
            coordinates.

        Returns
        =======
        ndarray
            With ``quick=True``: shape (fov_pix, fov_pix) for one position, or
            (N, fov_pix, fov_pix) for N positions. Singleton axes are squeezed.
            With ``quick=False``: the output of the on-mask PSF function,
            computed with ``return_oversample=False``.
        """
        
        if not quick:
            return self._func_on(sp=self.sp, coord_vals=coord_vals, coord_frame=coord_frame,
                                 return_oversample=False, return_hdul=False)

        t_temp, cx_idl, _ = _transmission_map(self.nrc_on, coord_vals, coord_frame)

        # Linear combination of min/max to determine PSF:
        # weight a = squared transmission on the off-mask PSF, b = 1 - a on the
        # on-mask PSF. asarray so scalar (single-position) results can reshape.
        avals = np.asarray(t_temp**2).reshape([-1,1,1])
        bvals = 1 - avals

        ny = nx = self.fov_pix
        mask = self.image_mask
        if mask is not None and mask.endswith('B'):
            # Interpolate the bar PSF grid to each requested idl x-position
            # (linear extrapolation outside the sampled x range)
            finterp = interp1d(self.psf_bar_xvals, self.psf_on, kind='linear',
                               fill_value='extrapolate', axis=0)
            psf_on = finterp(cx_idl)
        else:
            psf_on = self.psf_on

        # (-1, ny, nx) keeps one on-mask PSF per position for bar masks
        # (N interpolated PSFs) while still broadcasting a single PSF.
        psf_on = np.reshape(psf_on, (-1, ny, nx))
        psf_off = np.reshape(self.psf_off, (1, ny, nx))

        res = avals * psf_off + bvals * psf_on
        return res.squeeze()

In [29]:
# G2V stellar spectrum used to weight the PSF calculations
sp = webbpsf_ext.stellar_spectrum('G2V')

filt = 'F300M' 
image_mask = 'MASK335R'

fov_pix = 257
oversample = 2

kwargs = {}
# Pass `sp` explicitly; otherwise NIRCam_PSF falls back to its default sp=None
inst = NIRCam_PSF(filt, image_mask, fov_pix, oversample, sp=sp, **kwargs)

Generating initial PSFs...


In [47]:
xv = np.linspace(-10,10,101)
yv = np.linspace(-10,10,101)

coords = (xv, yv)
%time test = inst.gen_psf_xy(coords, coord_frame='idl')

CPU times: user 34.2 ms, sys: 30 ms, total: 64.2 ms
Wall time: 61.5 ms
