Skip to content

Commit

Permalink
Merge pull request #1664 from larrybradley/gridpsf-oversamp
Browse files Browse the repository at this point in the history
GriddedPSF oversampling in x and y
  • Loading branch information
larrybradley committed Nov 21, 2023
2 parents 345a914 + 9fd444b commit 21f97e2
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 60 deletions.
7 changes: 7 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ New Features
- Added ``make_psf_model`` function for making a PSF model from a
2D Astropy model. Compound models are also supported. [#1658]

- The ``GriddedPSFModel`` oversampling can now be different in the x
and y directions. The ``oversampling`` attribute is now stored as
a 1D ``numpy.ndarray`` with two elements. [#1664]

- ``photutils.segmentation``

- The ``SegmentationImage`` ``make_source_mask`` method now uses a
Expand Down Expand Up @@ -99,6 +103,9 @@ API Changes
first sorted by y then by x. As a result, the order of the ``data``
and ``xygrid`` attributes may be different. [#1661]

- The ``oversampling`` attribute is now stored as a 1D
``numpy.ndarray`` with two elements. [#1664]

- A ``ValueError`` is raised if ``GriddedPSFModel`` is called with x
and y arrays that have more than 2 dimensions. [#1662]

Expand Down
36 changes: 20 additions & 16 deletions photutils/psf/griddedpsfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from astropy.nddata import NDData, reshape_as_blocks
from astropy.visualization import simple_norm

from photutils.utils._parameters import as_pair

__all__ = ['GriddedPSFModel', 'ModelGridPlotMixin', 'stdpsf_reader',
'webbpsf_reader', 'STDPSFGrid']
__doctest_skip__ = ['GriddedPSFModelRead', 'STDPSFGrid']
Expand Down Expand Up @@ -310,8 +312,10 @@ class GriddedPSFModel(ModelGridPlotMixin, Fittable2DModel):
words, ``grid_xypos[i]`` should be the (x, y) position of
the reference ePSF defined in ``data[i]``.
* ``'oversampling'``: The integer oversampling factor of the
ePSF.
* ``'oversampling'``: The integer oversampling factor(s) of
the ePSF. If ``oversampling`` is a scalar then it will be
used for both axes. If ``oversampling`` has two elements,
they must be in ``(y, x)`` order.
The meta attribute may contain other properties such as the
telescope, instrument, detector, and filter of the ePSF.
Expand Down Expand Up @@ -341,10 +345,13 @@ class GriddedPSFModel(ModelGridPlotMixin, Fittable2DModel):
def __init__(self, nddata, *, flux=flux.default, x_0=x_0.default,
y_0=y_0.default, fill_value=0.0):

self._data_input = self._validate_data(nddata)
self._nddata = self._validate_data(nddata)
self.data, self.grid_xypos = self._define_grid(nddata)
self._meta = nddata.meta # use _meta to avoid the meta descriptor
self.oversampling = nddata.meta['oversampling']
self.oversampling = as_pair('oversampling',
nddata.meta['oversampling'],
lower_bound=(0, 1))

self.fill_value = fill_value

self._grid_xpos, self._grid_ypos = np.transpose(self.grid_xypos)
Expand Down Expand Up @@ -385,8 +392,6 @@ def _validate_data(data):
if 'oversampling' not in data.meta:
raise ValueError('"oversampling" must be in the nddata meta '
'dictionary.')
if not np.isscalar(data.meta['oversampling']):
raise ValueError('oversampling must be a scalar value')

return data

Expand Down Expand Up @@ -429,8 +434,7 @@ def __str__(self):
cls_info.extend([('Number of ePSFs', len(self.grid_xypos)),
('ePSF shape (oversampled pixels)',
self.data.shape[1:]),
('Oversampling', self.oversampling),
])
('Oversampling', tuple(self.oversampling))])

with np.printoptions(threshold=25, edgeitems=5):
fmt = [f'{key}: {val}' for key, val in cls_info]
Expand All @@ -447,7 +451,7 @@ def copy(self):
Note that the ePSF grid data is not copied. Use the `deepcopy`
method if you want to copy the ePSF grid data.
"""
return self.__class__(self._data_input, flux=self.flux.value,
return self.__class__(self._nddata, flux=self.flux.value,
x_0=self.x_0.value, y_0=self.y_0.value,
fill_value=self.fill_value)

Expand Down Expand Up @@ -633,8 +637,8 @@ def evaluate(self, x, y, flux, x_0, y_0):

# now evaluate the ePSF at the (x_0, y_0) subpixel position on
# the input (x, y) values
xi = self.oversampling * (np.asarray(x, dtype=float) - x_0)
yi = self.oversampling * (np.asarray(y, dtype=float) - y_0)
xi = self.oversampling[1] * (np.asarray(x, dtype=float) - x_0)
yi = self.oversampling[0] * (np.asarray(y, dtype=float) - y_0)

# define origin at the ePSF image center
ny, nx = self.data.shape[1:]
Expand Down Expand Up @@ -957,7 +961,7 @@ def stdpsf_reader(filename, detector_id=None):
# itertools.product iterates over the last input first
xy_grid = [yx[::-1] for yx in itertools.product(ygrid, xgrid)]

oversampling = 4
oversampling = 4 # assumption for STDPSF files
nxpsfs = xgrid.shape[0]
nypsfs = ygrid.shape[0]
meta = {'grid_xypos': xy_grid,
Expand Down Expand Up @@ -1124,9 +1128,10 @@ def __init__(self, filename):
self._ygrid = grid_data['ygrid']
xy_grid = [yx[::-1] for yx in itertools.product(self._ygrid,
self._xgrid)]
oversampling = 4
oversampling = 4 # assumption for STDPSF files
self.grid_xypos = xy_grid
self.oversampling = oversampling
self.oversampling = as_pair('oversampling', oversampling,
lower_bound=(0, 1))
meta = {'grid_shape': (len(self._ygrid), len(self._xgrid)),
'grid_xypos': xy_grid,
'oversampling': oversampling}
Expand All @@ -1152,8 +1157,7 @@ def __str__(self):
cls_info.extend([('Number of ePSFs', len(self.grid_xypos)),
('ePSF shape (oversampled pixels)',
self.data.shape[1:]),
('Oversampling', self.oversampling),
])
('Oversampling', self.oversampling)])

with np.printoptions(threshold=25, edgeitems=5):
fmt = [f'{key}: {val}' for key, val in cls_info]
Expand Down
42 changes: 20 additions & 22 deletions photutils/psf/tests/test_griddedpsfmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from astropy.modeling.models import Gaussian2D
from astropy.nddata import NDData
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal

from photutils.psf import GriddedPSFModel, STDPSFGrid
from photutils.segmentation import SourceCatalog, detect_sources
Expand Down Expand Up @@ -64,8 +64,8 @@ def test_gridded_psf_model(self, psfmodel):
assert key in psfmodel.meta
grid_xypos = psfmodel.grid_xypos
assert len(grid_xypos) == 16
assert psfmodel.oversampling == 4
assert psfmodel.meta['oversampling'] == psfmodel.oversampling
assert_equal(psfmodel.oversampling, [4, 4])
assert_equal(psfmodel.meta['oversampling'], psfmodel.oversampling)
assert psfmodel.data.shape == (16, 101, 101)

idx = np.lexsort((grid_xypos[:, 0], grid_xypos[:, 1]))
Expand Down Expand Up @@ -143,8 +143,7 @@ def test_gridded_psf_model_invalid_inputs(self):
GriddedPSFModel(nddata)

# check grid_xypos length
meta = {'grid_xypos': [[0, 0], [1, 0], [1, 0]],
'oversampling': 4}
meta = {'grid_xypos': [[0, 0], [1, 0], [1, 0]], 'oversampling': 4}
nddata = NDData(data, meta=meta)
with pytest.raises(ValueError):
GriddedPSFModel(nddata)
Expand All @@ -162,13 +161,6 @@ def test_gridded_psf_model_invalid_inputs(self):
with pytest.raises(ValueError):
GriddedPSFModel(nddata)

# check oversampling is a scalar
meta = {'grid_xypos': [[0, 0], [0, 1], [1, 0], [1, 1]],
'oversampling': [4, 4]}
nddata = NDData(data, meta=meta)
with pytest.raises(ValueError):
GriddedPSFModel(nddata)

@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required')
def test_gridded_psf_model_eval(self, psfmodel):
"""
Expand All @@ -178,9 +170,9 @@ def test_gridded_psf_model_eval(self, psfmodel):
shape = (200, 200)
data = np.zeros(shape)
eval_xshape = (np.ceil(psfmodel.data.shape[2]
/ psfmodel.oversampling)).astype(int)
/ psfmodel.oversampling[1])).astype(int)
eval_yshape = (np.ceil(psfmodel.data.shape[1]
/ psfmodel.oversampling)).astype(int)
/ psfmodel.oversampling[0])).astype(int)

xx = [40, 50, 160, 160]
yy = [60, 150, 50, 140]
Expand Down Expand Up @@ -213,7 +205,7 @@ def test_copy(self, psfmodel):
new_model = psfmodel.copy()
new_model.flux = 100
assert new_model.flux.value != flux
assert new_model._data_input is psfmodel._data_input
assert new_model._nddata is psfmodel._nddata

def test_deepcopy(self, psfmodel):
flux = psfmodel.flux.value
Expand Down Expand Up @@ -244,6 +236,12 @@ def test_repr_str(self, psfmodel):
for key in keys:
assert key in repr(psfmodel)

def test_gridded_psf_oversampling(self, psfmodel):
nddata = psfmodel._nddata
nddata.meta['oversampling'] = [4, 4]
psfmodel2 = GriddedPSFModel(nddata)
assert_equal(psfmodel2.oversampling, psfmodel.oversampling)

def test_read_stdpsf(self):
"""
Test STDPSF read for a single detector.
Expand All @@ -252,8 +250,8 @@ def test_read_stdpsf(self):
filename = op.join(op.dirname(op.abspath(__file__)), 'data', filename)
psfmodel = GriddedPSFModel.read(filename)
assert psfmodel.data.shape[0] == len(psfmodel.meta['grid_xypos'])
assert psfmodel.oversampling == 4
assert psfmodel.meta['oversampling'] == psfmodel.oversampling
assert_equal(psfmodel.oversampling, [4, 4])
assert_equal(psfmodel.meta['oversampling'], psfmodel.oversampling)

@pytest.mark.parametrize(('filename', 'detector_id'),
list(product(STDPSF_FILENAMES[1:], (1, 2))))
Expand All @@ -265,8 +263,8 @@ def test_read_stdpsf_multi_detector(self, filename, detector_id):
psfmodel = GriddedPSFModel.read(filename, detector_id=detector_id,
format='stdpsf')
assert psfmodel.data.shape[0] == len(psfmodel.meta['grid_xypos'])
assert psfmodel.oversampling == 4
assert psfmodel.meta['oversampling'] == psfmodel.oversampling
assert_equal(psfmodel.oversampling, [4, 4])
assert_equal(psfmodel.meta['oversampling'], psfmodel.oversampling)

# test format auto-detect
filename = op.join(op.dirname(op.abspath(__file__)), 'data', filename)
Expand All @@ -287,8 +285,8 @@ def test_read_webbpsf(self, filename):
filename = op.join(op.dirname(op.abspath(__file__)), 'data', filename)
psfmodel = GriddedPSFModel.read(filename, format='webbpsf')
assert psfmodel.data.shape[0] == len(psfmodel.meta['grid_xypos'])
assert psfmodel.oversampling == 4
assert psfmodel.meta['oversampling'] == psfmodel.oversampling
assert_equal(psfmodel.oversampling, [4, 4])
assert_equal(psfmodel.meta['oversampling'], psfmodel.oversampling)
psfmodel.plot_grid()

# test format auto-detect
Expand Down Expand Up @@ -316,7 +314,7 @@ def test_stdpsfgrid(filename):
psfgrid = STDPSFGrid(filename)
assert 'grid_xypos' in psfgrid.meta
assert 'oversampling' in psfgrid.meta
assert psfgrid.oversampling == 4
assert_equal(psfgrid.oversampling, [4, 4])
assert psfgrid.data.shape[0] == len(psfgrid.meta['grid_xypos'])

psfgrid.plot_grid()
Expand Down
8 changes: 4 additions & 4 deletions photutils/psf/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from astropy.nddata import NDData
from astropy.table import Table
from astropy.utils.exceptions import AstropyDeprecationWarning
from numpy.testing import assert_allclose
from numpy.testing import assert_allclose, assert_equal

from photutils import datasets
from photutils.detection import find_peaks
Expand Down Expand Up @@ -475,8 +475,8 @@ def test_basic_test_grid_from_epsfs(self):
psf_grid = grid_from_epsfs(self.epsfs)

assert np.all(psf_grid.oversampling == self.epsfs[0].oversampling)
assert psf_grid.data.shape == (4, psf_grid.oversampling * 25 + 1,
psf_grid.oversampling * 25 + 1)
assert psf_grid.data.shape == (4, psf_grid.oversampling[0] * 25 + 1,
psf_grid.oversampling[1] * 25 + 1)

def test_grid_xypos(self):
"""
Expand Down Expand Up @@ -514,5 +514,5 @@ def test_meta(self):
for key in keys + ['extra_key']:
assert key in psf_grid.meta
assert psf_grid.meta['grid_xypos'].sort() == self.grid_xypos.sort()
assert psf_grid.meta['oversampling'] == 4
assert_equal(psf_grid.meta['oversampling'], [4, 4])
assert psf_grid.meta['fill_value'] == 0.0
22 changes: 4 additions & 18 deletions photutils/psf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -692,16 +692,7 @@ def grid_from_epsfs(epsfs, grid_xypos=None, meta=None):
data_arrs.append(epsf.data)

if i == 0:
# EPSFModel allows a tuple for oversampling factor in x, y,
# but GriddedPSFModel requires it to be a single scalar value.
# Keep this condition for now by checking that x and y match
if np.isscalar(epsf.oversampling):
oversampling = epsf.oversampling
else:
if epsf.oversampling[0] != epsf.oversampling[1]:
raise ValueError('Oversampling must be the same in x and '
'y.')
oversampling = epsf.oversampling[0]
oversampling = epsf.oversampling

# same for fill value and flux, grid will have a single value
# so it should be the same for all input, and error if not.
Expand All @@ -720,14 +711,9 @@ def grid_from_epsfs(epsfs, grid_xypos=None, meta=None):
pass # just keep as None

else:
if np.isscalar(epsf.oversampling):
if epsf.oversampling != oversampling:
raise ValueError('All input EPSFModels must have the same '
'value for ``oversampling``.')
if (epsf.oversampling[0] != epsf.oversampling[1]
!= oversampling):
raise ValueError('All input EPSFModels must have the '
'same value for ``oversampling``.')
if np.any(epsf.oversampling != oversampling):
raise ValueError('All input EPSFModels must have the same '
'value for ``oversampling``.')

if epsf.fill_value != fill_value:
raise ValueError('All input EPSFModels must have the same '
Expand Down

0 comments on commit 21f97e2

Please sign in to comment.