New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add a new decorator for functions accepting quantities. #3072
Changes from 22 commits
9325fbb
07c4549
27e1f7e
c71a17c
e77b066
cbfb0c8
58b8170
457ef6a
e83cf96
d959ae4
af99f46
d61fb42
df4d4a1
a3a995e
dc263d2
a7996d6
9182a44
b21a9b6
0368c73
2a0ebaf
9ca243e
50a89ca
df4968f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
import cmath | ||
|
||
import inspect | ||
import collections | ||
import textwrap | ||
import warnings | ||
import numpy as np | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
# -*- coding: utf-8 -*- | ||
# Licensed under a 3-clause BSD style license - see LICENSE.rst | ||
|
||
__all__ = ['quantity_input'] | ||
|
||
from ..utils.decorators import wraps | ||
from ..utils.compat import funcsigs | ||
|
||
from .core import UnitsError, add_enabled_equivalencies | ||
|
||
class QuantityInput(object): | ||
|
||
@classmethod | ||
def as_decorator(cls, func=None, **kwargs): | ||
""" | ||
A decorator for validating the units of arguments to functions. | ||
|
||
Unit specifications can be provided as keyword arguments to the decorator, | ||
or by using Python 3's function annotation syntax. Arguments to the decorator | ||
take precidence over any function annotations present. | ||
|
||
A `~astropy.units.UnitsError` will be raised if the unit attribute of | ||
the argument is not equivalent to the unit specified to the decorator | ||
or in the annotation. | ||
If the argument has no unit attribute, i.e. it is not a Quantity object, a | ||
`~exceptions.ValueError` will be raised. | ||
|
||
Where an equivalency is specified in the decorator, the function will be | ||
executed with that equivalency in force. | ||
|
||
Examples | ||
-------- | ||
|
||
Python 2 and 3:: | ||
|
||
import astropy.units as u | ||
@u.quantity_input(myangle=u.arcsec) | ||
def myfunction(myangle): | ||
return myangle**2 | ||
|
||
Python 3 only:: | ||
|
||
import astropy.units as u | ||
@u.quantity_input | ||
def myfunction(myangle: u.arcsec): | ||
return myangle**2 | ||
|
||
Using equivalencies:: | ||
|
||
import astropy.units as u | ||
@u.quantity_input(myenergy=u.eV, equivalencies=u.mass_energy()) | ||
def myfunction(myenergy): | ||
return myenergy**2 | ||
|
||
""" | ||
self = cls(**kwargs) | ||
if func is not None and not kwargs: | ||
return self(func) | ||
else: | ||
return self | ||
|
||
def __init__(self, func=None, **kwargs): | ||
self.equivalencies = kwargs.pop('equivalencies', []) | ||
self.decorator_kwargs = kwargs | ||
|
||
def __call__(self, wrapped_function): | ||
|
||
# Extract the function signature for the function we are wrapping. | ||
wrapped_signature = funcsigs.signature(wrapped_function) | ||
|
||
# Define a new function to return in place of the wrapped one | ||
@wraps(wrapped_function) | ||
def wrapper(*func_args, **func_kwargs): | ||
# Bind the arguments to our new function to the signature of the original. | ||
bound_args = wrapped_signature.bind(*func_args, **func_kwargs) | ||
|
||
# Iterate through the parameters of the original signature | ||
for param in wrapped_signature.parameters.values(): | ||
# Catch the (never triggered) case where bind relied on a default value. | ||
if param.name not in bound_args.arguments and param.default is not param.empty: | ||
bound_args.arguments[param.name] = param.default | ||
|
||
# Get the value of this parameter (argument to new function) | ||
arg = bound_args.arguments[param.name] | ||
|
||
# Get target unit, either from decotrator kwargs or annotations | ||
if param.name in self.decorator_kwargs: | ||
target_unit = self.decorator_kwargs[param.name] | ||
else: | ||
target_unit = param.annotation | ||
|
||
# If the target unit is empty, then no unit was specified so we | ||
# move past it | ||
if target_unit is not funcsigs.Parameter.empty: | ||
try: | ||
equivalent = arg.unit.is_equivalent(target_unit, | ||
equivalencies=self.equivalencies) | ||
|
||
if not equivalent: | ||
raise UnitsError("Argument '{0}' to function '{1}'" | ||
" must be in units convertable to" | ||
" '{2}'.".format(param.name, | ||
wrapped_function.__name__, | ||
target_unit.to_string())) | ||
|
||
# Either there is no .unit or no .is_equivalent | ||
except AttributeError: | ||
if hasattr(arg, "unit"): | ||
error_msg = "a 'unit' attribute without an 'is_equivalent' method" | ||
else: | ||
error_msg = "no 'unit' attribute" | ||
raise TypeError("Argument '{0}' to function has '{1}' {2}. " | ||
"You may want to pass in an astropy Quantity instead." | ||
.format(param.name, wrapped_function.__name__, error_msg)) | ||
|
||
# Call the original function with any equivalencies in force. | ||
with add_enabled_equivalencies(self.equivalencies): | ||
return wrapped_function(*func_args, **func_kwargs) | ||
|
||
return wrapper | ||
|
||
quantity_input = QuantityInput.as_decorator |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
# -*- coding: utf-8 -*- | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Include license line at the top of the file too (not sure which needs to go first) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed. |
||
# Licensed under a 3-clause BSD style license - see LICENSE.rst | ||
|
||
from functools import wraps | ||
from textwrap import dedent | ||
|
||
from ... import units as u | ||
from ...extern import six | ||
from ...tests.helper import pytest | ||
|
||
|
||
def py3only(func): | ||
if not six.PY3: | ||
return pytest.mark.skipif('not six.PY3')(func) | ||
else: | ||
@wraps(func) | ||
def wrapper(*args, **kwargs): | ||
code = compile(dedent(func.__doc__), __file__, 'exec') | ||
# This uses an unqualified exec statement illegally in Python 2, | ||
# but perfectly allowed in Python 3 so in fact we eval the exec | ||
# call :) | ||
eval('exec(code)') | ||
|
||
return wrapper | ||
|
||
|
||
@py3only | ||
def test_args3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary: u.arcsec): | ||
return solarx, solary | ||
|
||
solarx, solary = myfunc_args(1*u.arcsec, 1*u.arcsec) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, u.Quantity) | ||
|
||
assert solarx.unit == u.arcsec | ||
assert solary.unit == u.arcsec | ||
""" | ||
|
||
|
||
@py3only | ||
def test_args_noconvert3(): | ||
""" | ||
@u.quantity_input() | ||
def myfunc_args(solarx: u.arcsec, solary: u.arcsec): | ||
return solarx, solary | ||
|
||
solarx, solary = myfunc_args(1*u.deg, 1*u.arcmin) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, u.Quantity) | ||
|
||
assert solarx.unit == u.deg | ||
assert solary.unit == u.arcmin | ||
""" | ||
|
||
|
||
@py3only | ||
def test_args_nonquantity3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary): | ||
return solarx, solary | ||
|
||
solarx, solary = myfunc_args(1*u.arcsec, 100) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, int) | ||
|
||
assert solarx.unit == u.arcsec | ||
""" | ||
|
||
|
||
@py3only | ||
def test_arg_equivalencies3(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this function, you should add a test that the function can use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed. |
||
""" | ||
@u.quantity_input(equivalencies=u.mass_energy()) | ||
def myfunc_args(solarx: u.arcsec, solary: u.eV): | ||
return solarx, solary+(10*u.J) # Add an energy to check equiv is working | ||
|
||
solarx, solary = myfunc_args(1*u.arcsec, 100*u.gram) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, u.Quantity) | ||
|
||
assert solarx.unit == u.arcsec | ||
assert solary.unit == u.gram | ||
""" | ||
|
||
|
||
@py3only | ||
def test_wrong_unit3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary: u.deg): | ||
return solarx, solary | ||
|
||
with pytest.raises(u.UnitsError) as e: | ||
solarx, solary = myfunc_args(1*u.arcsec, 100*u.km) | ||
assert str(e.value) == "Argument 'solary' to function 'myfunc_args' must be in units convertable to 'deg'." | ||
""" | ||
|
||
|
||
@py3only | ||
def test_not_quantity3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary: u.deg): | ||
return solarx, solary | ||
|
||
with pytest.raises(TypeError) as e: | ||
solarx, solary = myfunc_args(1*u.arcsec, 100) | ||
assert str(e.value) == "Argument 'solary' to function has 'myfunc_args' no 'unit' attribute. You may want to pass in an astropy Quantity instead." | ||
""" | ||
|
||
|
||
@py3only | ||
def test_decorator_override(): | ||
""" | ||
@u.quantity_input(solarx=u.arcsec) | ||
def myfunc_args(solarx: u.km, solary: u.arcsec): | ||
return solarx, solary | ||
|
||
solarx, solary = myfunc_args(1*u.arcsec, 1*u.arcsec) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, u.Quantity) | ||
|
||
assert solarx.unit == u.arcsec | ||
assert solary.unit == u.arcsec | ||
""" | ||
|
||
|
||
@py3only | ||
def test_kwargs3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary, myk: u.arcsec=1*u.arcsec): | ||
return solarx, solary, myk | ||
|
||
solarx, solary, myk = myfunc_args(1*u.arcsec, 100, myk=100*u.deg) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, int) | ||
assert isinstance(myk, u.Quantity) | ||
|
||
assert myk.unit == u.deg | ||
""" | ||
|
||
|
||
@py3only | ||
def test_unused_kwargs3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary, myk: u.arcsec=1*u.arcsec, myk2=1000): | ||
return solarx, solary, myk, myk2 | ||
|
||
solarx, solary, myk, myk2 = myfunc_args(1*u.arcsec, 100, myk=100*u.deg, myk2=10) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(solary, int) | ||
assert isinstance(myk, u.Quantity) | ||
assert isinstance(myk2, int) | ||
|
||
assert myk.unit == u.deg | ||
assert myk2 == 10 | ||
""" | ||
|
||
|
||
@py3only | ||
def test_kwarg_equivalencies3(): | ||
""" | ||
@u.quantity_input(equivalencies=u.mass_energy()) | ||
def myfunc_args(solarx: u.arcsec, energy: u.eV=10*u.eV): | ||
return solarx, energy+(10*u.J) # Add an energy to check equiv is working | ||
|
||
solarx, energy = myfunc_args(1*u.arcsec, 100*u.gram) | ||
|
||
assert isinstance(solarx, u.Quantity) | ||
assert isinstance(energy, u.Quantity) | ||
|
||
assert solarx.unit == u.arcsec | ||
assert energy.unit == u.gram | ||
""" | ||
|
||
|
||
@py3only | ||
def test_kwarg_wrong_unit3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary: u.deg=10*u.deg): | ||
return solarx, solary | ||
|
||
with pytest.raises(u.UnitsError) as e: | ||
solarx, solary = myfunc_args(1*u.arcsec, solary=100*u.km) | ||
assert str(e.value) == "Argument 'solary' to function 'myfunc_args' must be in units convertable to 'deg'." | ||
""" | ||
|
||
|
||
@py3only | ||
def test_kwarg_not_quantity3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary: u.deg=10*u.deg): | ||
return solarx, solary | ||
|
||
with pytest.raises(TypeError) as e: | ||
solarx, solary = myfunc_args(1*u.arcsec, solary=100) | ||
assert str(e.value) == "Argument 'solary' to function has 'myfunc_args' no 'unit' attribute. You may want to pass in an astropy Quantity instead." | ||
""" | ||
|
||
|
||
@py3only | ||
def test_kwarg_default3(): | ||
""" | ||
@u.quantity_input | ||
def myfunc_args(solarx: u.arcsec, solary: u.deg=10*u.deg): | ||
return solarx, solary | ||
|
||
solarx, solary = myfunc_args(1*u.arcsec) | ||
""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
D'oh--just realized something went awry here. This should be moved above, to the section for
astropy.units
(each changelog section has subsections for each Astropy subpackage).