Skip to content

Commit

Permalink
Merge pull request #10980 from rrjbca/broadcast_redshift
Browse files Browse the repository at this point in the history
Fix specialized comoving distance functions for iterable and broadcastable redshift arguments
  • Loading branch information
mhvk committed Dec 13, 2020
2 parents 2388de9 + b513cfe commit 0d9a916
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 26 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,9 @@ astropy.coordinates
astropy.cosmology
^^^^^^^^^^^^^^^^^

- Fixed an issue where specializations of the comoving distance calculation
for certain cosmologies could not handle redshift arrays. [#10980]

astropy.extern
^^^^^^^^^^^^^^

Expand Down
42 changes: 16 additions & 26 deletions astropy/cosmology/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1758,14 +1758,10 @@ def _elliptic_comoving_distance_z1z2(self, z1, z2):
Comoving distance in Mpc between each input redshift.
"""
from scipy.special import ellipkinc
if isiterable(z1):
z1 = np.asarray(z1)
if isiterable(z2):
z2 = np.asarray(z2)
if isiterable(z1) and isiterable(z2):
if z1.shape != z2.shape:
msg = "z1 and z2 have different shapes"
raise ValueError(msg)
try:
z1, z2 = np.broadcast_arrays(z1, z2)
except ValueError as e:
raise ValueError("z1 and z2 have different shapes") from e

# The analytic solution is not valid for any of Om0, Ode0, Ok0 == 0.
# Use the explicit integral solution for these cases.
Expand Down Expand Up @@ -1829,12 +1825,10 @@ def _dS_comoving_distance_z1z2(self, z1, z2):
d : `~astropy.units.Quantity`
Comoving distance in Mpc between each input redshift.
"""
if isiterable(z1):
z1 = np.asarray(z1)
z2 = np.asarray(z2)
if z1.shape != z2.shape:
msg = "z1 and z2 have different shapes"
raise ValueError(msg)
try:
z1, z2 = np.broadcast_arrays(z1, z2)
except ValueError as e:
raise ValueError("z1 and z2 have different shapes") from e

return self._hubble_distance * (z2 - z1)

Expand All @@ -1858,12 +1852,10 @@ def _EdS_comoving_distance_z1z2(self, z1, z2):
d : `~astropy.units.Quantity`
Comoving distance in Mpc between each input redshift.
"""
if isiterable(z1):
z1 = np.asarray(z1)
z2 = np.asarray(z2)
if z1.shape != z2.shape:
msg = "z1 and z2 have different shapes"
raise ValueError(msg)
try:
z1, z2 = np.broadcast_arrays(z1, z2)
except ValueError as e:
raise ValueError("z1 and z2 have different shapes") from e

prefactor = 2 * self._hubble_distance
return prefactor * ((1+z1)**(-1./2) - (1+z2)**(-1./2))
Expand Down Expand Up @@ -1891,12 +1883,10 @@ def _hypergeometric_comoving_distance_z1z2(self, z1, z2):
d : `~astropy.units.Quantity`
Comoving distance in Mpc between each input redshift.
"""
if isiterable(z1):
z1 = np.asarray(z1)
z2 = np.asarray(z2)
if z1.shape != z2.shape:
msg = "z1 and z2 have different shapes"
raise ValueError(msg)
try:
z1, z2 = np.broadcast_arrays(z1, z2)
except ValueError as e:
raise ValueError("z1 and z2 have different shapes") from e

s = ((1 - self._Om0) / self._Om0) ** (1./3)
# Use np.sqrt here to handle negative s (Om0>1).
Expand Down
49 changes: 49 additions & 0 deletions astropy/cosmology/tests/test_cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -1637,3 +1637,52 @@ def test_elliptic_comoving_distance_z1z2():
cosmo._integral_comoving_distance_z1z2(0., z))
assert allclose(cosmo._elliptic_comoving_distance_z1z2(0., z),
cosmo._integral_comoving_distance_z1z2(0., z))


SPECIALIZED_COMOVING_DISTANCE_COSMOLOGIES = [
core.FlatLambdaCDM(H0=70, Om0=0.0, Tcmb0=0.0), # de Sitter
core.FlatLambdaCDM(H0=70, Om0=1.0, Tcmb0=0.0), # Einstein - de Sitter
core.FlatLambdaCDM(H0=70, Om0=0.3, Tcmb0=0.0), # Hypergeometric
core.LambdaCDM(H0=70, Om0=0.3, Ode0=0.6, Tcmb0=0.0), # Elliptic
]


ITERABLE_REDSHIFTS = [
(0, 1, 2, 3, 4), # tuple
[0, 1, 2, 3, 4], # list
np.array([0, 1, 2, 3, 4]), # array
]


@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize('cosmo', SPECIALIZED_COMOVING_DISTANCE_COSMOLOGIES)
@pytest.mark.parametrize('z', ITERABLE_REDSHIFTS)
def test_comoving_distance_iterable_argument(cosmo, z):
"""
Regression test for #10980
Test that specialized comoving distance methods handle iterable arguments.
"""

assert allclose(cosmo.comoving_distance(z),
cosmo._integral_comoving_distance_z1z2(0., z))


@pytest.mark.skipif('not HAS_SCIPY')
@pytest.mark.parametrize('cosmo', SPECIALIZED_COMOVING_DISTANCE_COSMOLOGIES)
def test_comoving_distance_broadcast(cosmo):
"""
Regression test for #10980
Test that specialized comoving distance methods broadcast array arguments.
"""

z1 = np.zeros((2, 5))
z2 = np.ones((3, 1, 5))
z3 = np.ones((7, 5))
output_shape = np.broadcast(z1, z2).shape

# Check compatible array arguments return an array with the correct shape
assert cosmo._comoving_distance_z1z2(z1, z2).shape == output_shape

# Check incompatible array arguments raise an error
with pytest.raises(ValueError, match='z1 and z2 have different shapes'):
cosmo._comoving_distance_z1z2(z1, z3)

0 comments on commit 0d9a916

Please sign in to comment.