In [None]:
# Third-party
import astropy.units as u
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
%matplotlib inline

In [None]:
import inspect
from astropy.utils.decorators import wraps
from astropy.units.decorators import _validate_arg_value
from astropy.units import add_enabled_units, add_enabled_equivalencies
from astropy.tests.helper import quantity_allclose
import pytest

In [None]:
class StripInputUnits:
    """Decorator that strips astropy units from selected function arguments.

    Target units come either from keyword arguments given to the decorator
    (``@strip_input_units(rv=u.km/u.s)``) or from the wrapped function's
    parameter annotations (``def f(rv: u.km/u.s)``); decorator kwargs win.
    Each targeted argument is validated with astropy's ``_validate_arg_value``,
    converted to its target unit, and handed to the wrapped function as a bare
    ``.value``. If the wrapped function carries a return annotation other than
    ``None``, the return value is multiplied by that unit.

    An optional ``equivalencies`` keyword is used for validation, for the
    conversion, and is enabled while the wrapped function runs.
    """

    @classmethod
    def as_decorator(cls, func=None, **kwargs):
        """Build the decorator; supports both ``@dec`` and ``@dec(...)`` forms.

        Parameters
        ----------
        func : callable, optional
            Function to wrap. Supplied implicitly when used as a bare
            ``@strip_input_units``.
        **kwargs
            Per-parameter target units, plus an optional ``equivalencies``.

        Returns
        -------
        callable
            The wrapped function if ``func`` was given, otherwise a decorator
            instance waiting to receive the function.
        """
        self = cls(**kwargs)
        # Wrap immediately whenever a function is supplied, even alongside
        # unit kwargs (previously that combination returned an unapplied
        # decorator).
        if func is not None:
            return self(func)
        return self

    def __init__(self, func=None, **kwargs):
        # ``func`` is accepted for call-signature compatibility only; the
        # actual wrapping happens in ``__call__``.
        self.equivalencies = kwargs.pop('equivalencies', [])
        self.decorator_kwargs = kwargs

    def __call__(self, wrapped_function):
        """Return ``wrapped_function`` wrapped so its Quantity inputs are stripped.

        Raises (at call time of the wrapper) whatever ``Signature.bind`` or
        ``_validate_arg_value`` raise for bad arguments, e.g. ``TypeError``
        for a missing unit or ``UnitsError`` for a wrong physical type.
        """
        # Extract the function signature for the function we are wrapping.
        wrapped_signature = inspect.signature(wrapped_function)

        @wraps(wrapped_function)
        def wrapper(*func_args, **func_kwargs):
            bound_args = wrapped_signature.bind(*func_args, **func_kwargs)
            # bind() does not fill in defaults; do so, so that every named
            # parameter (e.g. an omitted ``time=None``) is present below.
            bound_args.apply_defaults()

            for param in wrapped_signature.parameters.values():
                # *args / **kwargs are forwarded untouched (no unit handling).
                if param.kind in (inspect.Parameter.VAR_KEYWORD,
                                  inspect.Parameter.VAR_POSITIONAL):
                    continue

                arg = bound_args.arguments[param.name]

                # Target unit: decorator kwargs take precedence over annotations.
                if param.name in self.decorator_kwargs:
                    target_unit = self.decorator_kwargs[param.name]
                else:
                    target_unit = param.annotation

                # No target unit given for this parameter -> leave it alone.
                if target_unit is inspect.Parameter.empty:
                    continue

                # A None argument with a None default passes through unchanged
                # even when a target unit is specified.
                if arg is None and param.default is None:
                    continue

                target_unit = u.Unit(target_unit)

                _validate_arg_value(param.name, wrapped_function.__name__,
                                    arg, [target_unit], self.equivalencies)

                # Convert with the same equivalencies used for validation;
                # otherwise an equivalency-based input validates and then
                # fails to convert.
                bound_args.arguments[param.name] = arg.to(
                    target_unit, equivalencies=self.equivalencies).value

            # Forward via BoundArguments so *args values and positional-only
            # parameters reach the wrapped function intact.
            with add_enabled_equivalencies(self.equivalencies):
                return_ = wrapped_function(*bound_args.args, **bound_args.kwargs)

            return_unit = wrapped_signature.return_annotation
            # ``-> None`` (or no annotation) means: do not attach a unit.
            if return_unit is inspect.Signature.empty or return_unit is None:
                return return_
            return return_ * u.Unit(return_unit)

        return wrapper


strip_input_units = StripInputUnits.as_decorator

In [None]:
@strip_input_units(rv=u.km / u.s, time=u.day)
def my_func1(rv, name, time=None, **kwargs):
    """Return ``rv * time`` as a Quantity in km/s*day.

    The decorator hands ``rv`` (km/s) and ``time`` (day) over as bare values;
    ``time`` falls back to 1 when omitted. ``name`` and extra keyword
    arguments are accepted but unused. There is no return annotation, so the
    output unit is attached explicitly here.
    """
    duration = 1. if time is None else time
    product = rv * duration
    return product * u.km / u.s * u.day

@strip_input_units(rv=u.km / u.s, time=u.day)
def my_func2(rv, name, *args, time=None, **kwargs):
    """Like ``my_func1`` but ``time`` is keyword-only (it follows ``*args``).

    Extra positional and keyword arguments are accepted but unused; ``time``
    falls back to 1 when omitted.
    """
    duration = 1. if time is None else time
    product = rv * duration
    return product * u.km / u.s * u.day

@strip_input_units
def my_func3(rv: u.km/u.s, name, time: u.day = None, **kwargs) -> u.km/u.s*u.day:
    """Annotation-driven variant: units come from the parameter annotations.

    ``rv`` and ``time`` arrive as bare values in km/s and day; ``time`` falls
    back to 1 when omitted. The return annotation makes the decorator attach
    km/s*day to the plain result.
    """
    return rv * (1. if time is None else time)

In [None]:
rv_arr = np.array([12, 3.])
t_arr = np.array([1, 2.])

# All three variants must agree, whether the units come from decorator
# kwargs (my_func1, my_func2) or from annotations (my_func3).
kms_day = u.km/u.s*u.day
for func in (my_func1, my_func2, my_func3):
    # time omitted -> defaults to 1; unknown keyword arguments are tolerated
    expected = rv_arr*kms_day
    assert quantity_allclose(func(rv_arr*u.km/u.s, 'bob', derp='asdf'), expected)
    assert quantity_allclose(func(rv_arr*u.km/u.s, name='bob', derp='asdf'), expected)

    # explicit time given directly in the target units
    assert quantity_allclose(func(rv=rv_arr*u.km/u.s, name='bob', time=t_arr*u.day),
                             rv_arr*t_arr*kms_day)

    # compatible but different input units are converted before stripping
    mixed = func(rv=rv_arr*u.kpc/u.Myr, name='bob', time=t_arr*u.year)
    assert quantity_allclose(mixed, (rv_arr*u.kpc/u.Myr * t_arr*u.year).to(kms_day))

In [None]:
# Expected failures, as (exception type, keyword arguments) pairs:
#  - rv in kpc is a length, not a velocity -> UnitsError
#  - time is a plain list with no unit attribute -> TypeError
bad_calls = [
    (u.UnitsError, dict(rv=[12, 3]*u.kpc, name='bob', time=[1, 2]*u.year)),
    (TypeError, dict(rv=[12, 3]*u.kpc/u.yr, name='bob', time=[1, 2])),
]
for expected_exc, call_kwargs in bad_calls:
    with pytest.raises(expected_exc):
        my_func1(**call_kwargs)