Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cosmo: validating and setting Parameters #12190

Merged
merged 11 commits into from Nov 3, 2021
5 changes: 3 additions & 2 deletions astropy/cosmology/__init__.py
Expand Up @@ -8,16 +8,17 @@

"""

from . import core, flrw, funcs, units, utils
from . import core, flrw, funcs, parameter, units, utils

from . import io # needed before 'realizations' # isort: split
from . import realizations
from .core import *
from .flrw import *
from .funcs import *
from .parameter import *
from .realizations import *
from .utils import *

__all__ = (core.__all__ + flrw.__all__ # cosmology classes
+ realizations.__all__ # instances thereof
+ funcs.__all__ + utils.__all__) # utils
+ funcs.__all__ + parameter.__all__ + utils.__all__) # utils
mhvk marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 5 additions & 1 deletion astropy/cosmology/connect.py
Expand Up @@ -3,7 +3,9 @@
import copy
import warnings

from astropy.cosmology import units as cu
from astropy.io import registry as io_registry
from astropy.units import add_enabled_units
from astropy.utils.exceptions import AstropyUserWarning

__all__ = ["CosmologyRead", "CosmologyWrite",
Expand Down Expand Up @@ -87,7 +89,9 @@ def __call__(self, *args, **kwargs):
"keyword argument `cosmology` must be either the class "
f"{valid[0]} or its qualified name '{valid[1]}'")

cosmo = self.registry.read(self._cls, *args, **kwargs)
with add_enabled_units(cu):
cosmo = self.registry.read(self._cls, *args, **kwargs)

return cosmo


Expand Down
187 changes: 18 additions & 169 deletions astropy/cosmology/core.py
@@ -1,7 +1,6 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

import abc
import copy
import functools
import inspect
from types import FunctionType, MappingProxyType
Expand All @@ -14,6 +13,7 @@
from astropy.utils.metadata import MetaData

from .connect import CosmologyFromFormat, CosmologyRead, CosmologyToFormat, CosmologyWrite
from .parameter import Parameter

# Originally authored by Andrew Becker (becker@astro.washington.edu),
# and modified by Neil Crighton (neilcrighton@gmail.com), Roban Kramer
Expand All @@ -22,7 +22,7 @@
# Many of these adapted from Hogg 1999, astro-ph/9905116
# and Linder 2003, PRL 90, 91301

__all__ = ["Cosmology", "CosmologyError", "FlatCosmologyMixin", "Parameter"]
__all__ = ["Cosmology", "CosmologyError", "FlatCosmologyMixin"]

__doctest_requires__ = {} # needed until __getattr__ removed

Expand All @@ -34,166 +34,6 @@ class CosmologyError(Exception):
pass


class Parameter:
"""Cosmological parameter (descriptor).

Should only be used with a :class:`~astropy.cosmology.Cosmology` subclass.
For automatic default value and unit inference make sure the Parameter
attribute has a corresponding initialization argument (see Examples below).

Parameters
----------
fget : callable or None, optional
Function to get the value from instances of the cosmology class.
If None (default) returns the corresponding private attribute.
Often not set here, but as a decorator with ``getter``.
doc : str or None, optional
Parameter description. If 'doc' is None and 'fget' is not, then 'doc'
is taken from ``fget.__doc__``.
unit : unit-like or None (optional, keyword-only)
The `~astropy.units.Unit` for the Parameter. If None (default) no
unit as assumed.
equivalencies : `~astropy.units.Equivalency` or sequence thereof
Unit equivalencies for this Parameter.
fmt : str (optional, keyword-only)
`format` specification, used when making string representation
of the containing Cosmology.
See https://docs.python.org/3/library/string.html#formatspec

Examples
--------
The most common use case of ``Parameter`` is to access the corresponding
private attribute.

>>> from astropy.cosmology import LambdaCDM
>>> from astropy.cosmology.core import Parameter
>>> class Example1(LambdaCDM):
... param = Parameter(doc="example parameter", unit=u.m)
... def __init__(self, param=15 * u.m):
... super().__init__(70, 0.3, 0.7)
... self._param = param << self.__class__.param.unit
>>> Example1.param
<Parameter 'param' at ...
>>> Example1.param.unit
Unit("m")

>>> ex = Example1(param=12357)
>>> ex.param
<Quantity 12357. m>

``Parameter`` also supports custom ``getter`` methods.
:attr:`~astropy.cosmology.FLRW.m_nu` is a good example.

>>> import astropy.units as u
>>> class Example2(LambdaCDM):
... param = Parameter(doc="example parameter", unit="m")
... def __init__(self, param=15):
... super().__init__(70, 0.3, 0.7)
... self._param = param << self.__class__.param.unit
... @param.getter
... def param(self):
... return self._param << u.km

>>> ex2 = Example2(param=12357)
>>> ex2.param
<Quantity 12.357 km>

.. doctest::
:hide:

>>> from astropy.cosmology.core import _COSMOLOGY_CLASSES
>>> _ = _COSMOLOGY_CLASSES.pop(Example1.__qualname__)
>>> _ = _COSMOLOGY_CLASSES.pop(Example2.__qualname__)
"""

def __init__(self, fget=None, doc=None, *, unit=None, equivalencies=[], fmt=".3g"):
# modeled after https://docs.python.org/3/howto/descriptor.html#properties
self.__doc__ = fget.__doc__ if (doc is None and fget is not None) else doc
self.fget = fget if not hasattr(fget, "fget") else fget.__get__
# TODO! better detection if descriptor.

# units stuff
self._unit = u.Unit(unit) if unit is not None else None
self._equivalencies = equivalencies

# misc
self._fmt = str(fmt)
self.__wrapped__ = fget # so always have access to `fget`
self.__name__ = getattr(fget, "__name__", None) # compat with other descriptors

def __set_name__(self, objcls, name):
# attribute name
self._attr_name = name
self._attr_name_private = "_" + name

# update __name__, if not already set
self.__name__ = self.__name__ or name

@property
def name(self):
"""Parameter name."""
return self._attr_name

@property
def unit(self):
"""Parameter unit."""
return self._unit

@property
def equivalencies(self):
"""Equivalencies used when initializing Parameter."""
return self._equivalencies

@property
def format_spec(self):
"""String format specification."""
return self._fmt

# -------------------------------------------
# descriptor

def __get__(self, obj, objcls=None):
# get from class
if obj is None:
return self
# get from obj, allowing for custom ``getter``
if self.fget is None: # default to private attr (diff from `property`)
return getattr(obj, self._attr_name_private)
return self.fget(obj)

def __set__(self, obj, value):
raise AttributeError("can't set attribute")

def __delete__(self, obj):
raise AttributeError("can't delete attribute")

# -------------------------------------------
# from 'property'

def getter(self, fget):
"""Make new Parameter with custom ``fget``.

Note: ``Parameter.getter`` must be the top-most descriptor decorator.

Parameters
----------
fget : callable

Returns
-------
`~astropy.cosmology.Parameter`
Copy of this Parameter but with custom ``fget``.
"""
return type(self)(fget=fget, doc=self.__doc__,
unit=self.unit, equivalencies=self.equivalencies,
fmt=self.format_spec)

# -------------------------------------------

def __repr__(self):
return f"<Parameter {self._attr_name!r} at {hex(id(self))}>"


class Cosmology(metaclass=abc.ABCMeta):
"""Base-class for all Cosmologies.

Expand Down Expand Up @@ -230,6 +70,7 @@ class Cosmology(metaclass=abc.ABCMeta):

# Parameters
__parameters__ = ()
__all_parameters__ = ()
mhvk marked this conversation as resolved.
Show resolved Hide resolved

# ---------------------------------------------------------------

Expand All @@ -253,18 +94,26 @@ def __init_subclass__(cls):
# Parameters

# Get parameters that are still Parameters, either in this class or above.
parameters = [n for n in cls.__parameters__ if isinstance(getattr(cls, n), Parameter)]
parameters = []
derived_parameters = []
for n in cls.__parameters__:
p = getattr(cls, n)
if isinstance(p, Parameter):
derived_parameters.append(n) if p.derived else parameters.append(n)

# Add new parameter definitions
parameters += [n for n, v in cls.__dict__.items()
if (n not in parameters
and not n.startswith("_")
and isinstance(v, Parameter))]
for n, v in cls.__dict__.items():
if n in parameters or n.startswith("_") or not isinstance(v, Parameter):
continue
derived_parameters.append(n) if v.derived else parameters.append(n)

# reorder to match signature
ordered = [parameters.pop(parameters.index(n))
for n in cls._init_signature.parameters.keys()
if n in parameters]
parameters = ordered + parameters # place "unordered" at the end
cls.__parameters__ = tuple(parameters)
cls.__all_parameters__ = cls.__parameters__ + tuple(derived_parameters)

# -------------------
# register as a Cosmology subclass
Expand Down Expand Up @@ -402,9 +251,9 @@ def __equiv__(self, other):

# check all parameters in 'other' match those in 'self' and 'other' has
# no extra parameters (latter part should never happen b/c same class)
params_eq = (set(self.__parameters__) == set(other.__parameters__)
params_eq = (set(self.__all_parameters__) == set(other.__all_parameters__)
and all(np.all(getattr(self, k) == getattr(other, k))
for k in self.__parameters__))
for k in self.__all_parameters__))
return params_eq

def __eq__(self, other):
Expand Down