In [23]:

from typing import NamedTuple, Dict, List, Tuple
import doctest

class ConcreteMap(NamedTuple):
    name: str
    target: str
    options: list
    order: int

Parameter = lambda name: ConcreteMap(name,
                                     'params',
                                     ['policy', 'state_update'],
                                     0)

Timestep = lambda name: ConcreteMap(name,
                                    'timestep',
                                    ['policy', 'state_update'],
                                    1)

StateHistory = lambda name: ConcreteMap(name,
                                        'state_history',
                                        ['policy', 'state_update'],
                                        2)

State = lambda name: ConcreteMap(name,
                                 'prev_state',
                                 ['policy', 'state_update'],
                                 3)

Signal = lambda name: ConcreteMap(name,
                                  'policy_input',
                                  ['state_update'],
                                  4)


def check_args(*args: list, option : str=None, target: str=None):
    """
    Checks if args are ConcreteMap.
    
    >>> check_args(Signal('a'), State('a'))
    True
    >>> check_args(Signal('a'), 2.0)
    False
    >>> check_args(Signal('a'), State('a'), option='policy')
    False
    >>> check_args(Signal('a'), State('a'), target='policy_input')
    False
    >>> check_args(Signal('a'), target='policy_input')
    True
    >>> check_args(Signal('a'), option='state_update')
    True
    >>> check_args(Signal('a'), target='policy_input', option='state_update')
    True
    """
    arg_types = [type(arg) is ConcreteMap
                 and (option is None or option in arg.options)
                 and (target is None or target == arg.target)
                 for arg in args]
    arg_options = []
    return not (False in arg_types)

def make_policy(function: callable, 
                *signals: Tuple[ConcreteMap],
                pos_args: Tuple[ConcreteMap]=(), 
                kw_args: Dict[str, ConcreteMap]={},
                **kwargs: Dict[str, ConcreteMap]):

    # Check if arguments are valid.
    args_are_valid = check_args(*pos_args, *kwargs.values(),
                          *kwargs.values(), option='policy')
    args_are_valid &= check_args(*signals, target='policy_input')
    if not args_are_valid:
        raise TypeError('Arguments are not valid.') 
        
    # Prepare mappings
    
    

    def wrapped_function(*args):
        arg_map = [args[arg.order][arg.name] for arg in pos_args]
        kwarg_map = [args[arg.order][arg.name] for arg in kw_args]
        signal_map = None
        output = function(*arg_map, **kwarg_map)
        return {}
    return wrapped_function

In [19]:
doctest.testmod()

TestResults(failed=0, attempted=7)

In [8]:
f = lambda a, x: x

make_policy(f,
            Signal('variable'),
            pos_args=(Parameter('a'), State('x'))
            )

<function __main__.make_policy.<locals>.wrapped_function(*args)>

In [15]:
check_args(Signal('a'), target='policy_input')

True