Skip to content

Commit

Permalink
Merge pull request #1672 from larrybradley/fix-compound-psf
Browse files Browse the repository at this point in the history
Fix compound PSF model bug
  • Loading branch information
larrybradley committed Nov 27, 2023
2 parents a955e84 + 8c7cde1 commit e7690b5
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 5 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ New Features
Bug Fixes
^^^^^^^^^

- ``photutils.psf``

- Fixed an issue where PSF models produced by ``make_psf_model`` would
raise an error with ``PSFPhotometry`` if the fit did not converge.
[#1672]

API Changes
^^^^^^^^^^^

Expand Down
10 changes: 5 additions & 5 deletions photutils/psf/photometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,26 +602,26 @@ def _make_fit_results(self, models, infos):
nfitparam = len(self._fitted_psf_param_names)
for model, info in zip(models, infos):
model_nsub = model.n_submodels
npsf_models = model_nsub // psf_nsub

param_cov = info.get('param_cov', None)
if param_cov is None:
if nfitparam == 0: # model params are all fixed
nfitparam = 3
param_err = np.array([np.nan] * nfitparam * model_nsub)
nfitparam = 3 # x_err, y_err, and flux_err are np.nan
param_err = np.array([np.nan] * nfitparam * npsf_models)
else:
param_err = np.sqrt(np.diag(param_cov))

# model is for a single source (which may be compound)
if model_nsub == psf_nsub:
if npsf_models == 1:
fit_models.append(model)
fit_infos.append(info)
fit_param_errs.append(param_err)
continue

# model is a grouped model for multiple sources
fit_models.extend(self._split_compound_model(model, psf_nsub))
nsources = model_nsub // psf_nsub
fit_infos.extend([info] * nsources) # views
fit_infos.extend([info] * npsf_models) # views
fit_param_errs.extend(self._split_param_errs(param_err, nfitparam))

if len(fit_models) != len(fit_infos):
Expand Down
50 changes: 50 additions & 0 deletions photutils/psf/tests/test_photometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,56 @@ def test_psf_photometry(test_data):
assert resid_data3.unit == unit


@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required')
def test_psf_photometry_compound(test_data):
"""
Test compound models output from ``make_psf_model``.
"""
data, error, sources = test_data
x_stddev = y_stddev = 1.8
psf_func = Gaussian2D(amplitude=1, x_mean=0, y_mean=0, x_stddev=x_stddev,
y_stddev=y_stddev)
psf_model = make_psf_model(psf_func, x_name='x_mean', y_name='y_mean')

fit_shape = (5, 5)
finder = DAOStarFinder(5.0, 3.0)

psfphot = PSFPhotometry(psf_model, fit_shape, finder=finder,
aperture_radius=4)
phot = psfphot(data, error=error)
assert isinstance(phot, QTable)
assert len(phot) == len(sources)

# test results when fit does not converge (fitter_maxiters=10)
match = r'One or more fit\(s\) may not have converged.'
with pytest.warns(AstropyUserWarning, match=match):
psfphot = PSFPhotometry(psf_model, fit_shape, finder=finder,
aperture_radius=4, fitter_maxiters=10)
phot = psfphot(data, error=error)
columns1 = ['x_err', 'y_err', 'flux_err']
for column in columns1:
assert np.all(np.isnan(phot[column]))

# allow other parameters to vary
psf_model.x_stddev_2.fixed = False
psf_model.y_stddev_2.fixed = False
psfphot = PSFPhotometry(psf_model, fit_shape, finder=finder,
aperture_radius=4, fitter_maxiters=400)
phot = psfphot(data, error=error)
columns2 = ['x_stddev_2_init', 'y_stddev_2_init', 'x_stddev_2_fit',
'y_stddev_2_fit', 'x_stddev_2_err', 'y_stddev_2_err']
for column in columns2:
assert column in phot.colnames

# test results when fit does not converge (fitter_maxiters=10)
with pytest.warns(AstropyUserWarning, match=match):
psfphot = PSFPhotometry(psf_model, fit_shape, finder=finder,
aperture_radius=4, fitter_maxiters=10)
phot = psfphot(data, error=error)
for column in columns1 + ['x_stddev_2_err', 'y_stddev_2_err']:
assert np.all(np.isnan(phot[column]))


@pytest.mark.skipif(not HAS_SCIPY, reason='scipy is required')
def test_psf_photometry_mask(test_data):
data, error, sources = test_data
Expand Down

0 comments on commit e7690b5

Please sign in to comment.