In [244]:
import re
import threading
import functools
import types
from typing import Optional, Union


class Scope(object):
    _leaf = object()
    context = threading.local()

    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __enter__(self):
        type(self).get_contexts().append(self)
        return self

    def __exit__(self, typ, value, traceback):
        type(self).get_contexts().pop()

    def __getattr__(self, item):
        return self.__dict__.get(item)

    @classmethod
    def get_contexts(cls):
        # no race-condition here, cls.contexts is a thread-local object
        # be sure not to override contexts in a subclass however!
        if not hasattr(cls.context, "stack"):
            cls.context.stack = []
        return cls.context.stack

    @classmethod
    def chain(cls, attr, *, leaf=_leaf, predicate=lambda _: True, drop_none=False):
        for c in cls.get_contexts():
            if predicate(c):
                val = getattr(c, attr)
                if drop_none and val is None:
                    continue
                else:
                    yield val
        if leaf is not cls._leaf:
            if not (drop_none and leaf is None):
                yield leaf

    @classmethod
    def variable_name(cls, name: str) -> Optional[str]:
        """
        Generate PyMC4 variable name based on name scope we are currently in.

        Parameters
        ----------
        name : str|None
            The desired target name for a variable, can be any, including None

        Returns
        -------
        str : scoped name

        Examples
        --------
        >>> with Scope(name="inner"):
        ...     print(Scope.variable_name("leaf"))
        inner/leaf
        >>> with Scope(name="inner"):
        ...     with Scope():
        ...         print(Scope.variable_name("leaf1"))
        inner/leaf1

        empty name results in None name
        >>> assert Scope.variable_name(None) is None
        >>> assert Scope.variable_name("") is None
        """
        value = "/".join(map(str, cls.chain("name", leaf=name, drop_none=True)))
        if not value:
            return None
        else:
            return value

    @classmethod
    def transformed_variable_name(cls, transform_name: str, name: str) -> Optional[str]:
        return cls.variable_name("__{}_{}".format(transform_name, name))

    def __repr__(self):
        return "Scope({})".format(self.__dict__)


def name_scope(name):
    return Scope(name=name)


class NameParts:
    NAME_RE = re.compile(r"^(?:__(?P<transform>[^_]+)_)?(?P<name>[^_].*)$")
    NAME_ERROR_MESSAGE = (
        "Invalid name: `{}`, the correct one should look like: `__transform_name` or `name`, "
        "note only one underscore between the transform and actual name"
    )
    UNTRANSFORMED_NAME_ERROR_MESSAGE = (
        "Invalid name: `{}`, the correct one should look like: " "`name` without leading underscore"
    )
    __slots__ = ("path", "transform_name", "untransformed_name")

    @classmethod
    def is_valid_untransformed_name(cls, name):
        print("Name received is:", name)
        match = cls.NAME_RE.match(name)
        return match is not None and match["transform"] is None

    @classmethod
    def is_valid_name(cls, name):
        match = cls.NAME_RE.match(name)
        return match is not None

    def __init__(self, path, transform_name, untransformed_name):
        self.path = tuple(path)
        self.untransformed_name = untransformed_name
        self.transform_name = transform_name

    @classmethod
    def from_name(cls, name):
        split = name.split("/")
        path, original_name = split[:-1], split[-1]
        match = cls.NAME_RE.match(original_name)
        if not cls.is_valid_name(name):
            raise ValueError(cls.NAME_ERROR_MESSAGE.format(name))
        return cls(path, match["transform"], match["name"])

    @property
    def original_name(self):
        if self.is_transformed:
            return "__{}_{}".format(self.transform_name, self.untransformed_name)
        else:
            return self.untransformed_name

    @property
    def full_original_name(self):
        return "/".join(self.path + (self.original_name,))

    @property
    def full_untransformed_name(self):
        return "/".join(self.path + (self.untransformed_name,))

    @property
    def is_transformed(self):
        return self.transform_name is not None

    def __repr__(self):
        return "<NameParts of {}>".format(self.full_original_name)

    def replace_transform(self, transform_name):
        return self.__class__(self.path, transform_name, self.untransformed_name)

In [290]:
# Biwrap function
# Biwrap absorbs all the positional and keyword arguments.

def biwrap(wrapper):
    """Allow for optional keyword arguments in lower level decoratrors.

    Notes
    -----
    Currently this is only used to wrap pm.Model to capture model runtime flags such as
    keep_auxiliary and keep_return. See pm.Model for all possible keyword parameters

    """

    @functools.wraps(wrapper)
    def enhanced(*args, **kwargs):
        print("Args in enhanced:", args)
        print("Kwargs in enhanced:", kwargs)

        # Check if decorated method is bound to a class
        print("Wrapper name is:", wrapper.__name__)
        print("Length of args: ", len(args))
        is_bound_method = hasattr(args[0], wrapper.__name__) if args else False
        if is_bound_method:
            # If bound to a class, `self` will be an argument
            print("Is bound")
            count = 1
        else:
            count = 0
        if len(args) > count:
            # If lower level decorator is not called user model will be an argument
            # fill in parameters and call pm.Model
            newfn = wrapper(*args, **kwargs)
            print("Returning called function")
            return newfn
        else:
            # If lower level decorator is called user model will not be passed in as an argument
            # prefill args and kwargs but do not call pm.Model
            newwrapper = functools.partial(wrapper, *args, **kwargs)
            print("Returning functools.partial")
            return newwrapper

    return enhanced

## Behaviour 1
If at least one positional argument is supplied at the function call, then actual function is called with same arguments. Can result in error if mismatch in positional arguments

In [291]:
@biwrap
def model(name, times):
    print(f"Hi, {name}\n"*times)

In [292]:
model("Sayam", 10)

Args in enhanced: ('Sayam', 10)
Kwargs in enhanced: {}
Wrapper name is: model
Length of args:  2
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam
Hi, Sayam

Returning called function


In [293]:
model('tiger')  # Error because times argument is not supplied

Args in enhanced: ('tiger',)
Kwargs in enhanced: {}
Wrapper name is: model
Length of args:  1


TypeError: model() missing 1 required positional argument: 'times'

## Behaviour 2
If no positional arguments are supplied, then calling the function returns a functools.partial func that again needs to be called.

In [294]:
@biwrap
def model(name, times):
    print(f"Hi, {name}\n"*times)

In [295]:
partial_object = model()
partial_object('sayam', 10)

Args in enhanced: ()
Kwargs in enhanced: {}
Wrapper name is: model
Length of args:  0
Returning functools.partial
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam
Hi, sayam



## Behaviour 3 
I can hack this line by giving a wrapping name to any function in dir(args[0])
```python
is_bound_method = hasattr(args[0], wrapper.__name__) if args else False
```

In [296]:
@biwrap
def upper(name):
    print("Uppercased name is:", name.upper())

In [297]:
upper("Hello")()

Args in enhanced: ('Hello',)
Kwargs in enhanced: {}
Wrapper name is: upper
Length of args:  1
Is bound
Returning functools.partial
Uppercased name is: HELLO


In [298]:
class Model:
    """Base coroutine object.

    Supports iteration over random variables via `.control_flow`.
    """

    # this is gonna be used for generator-like objects,
    # prohibit modification of this dict wrapping it into a MappingProxy
    default_model_info = types.MappingProxyType(
        dict(keep_auxiliary=True, keep_return=False, scope=name_scope(None), name=None)
    )

    @staticmethod
    def validate_name(name: Optional[Union[int, str]]) -> Optional[str]:
        """Validate the type of the name argument."""
        if name is not None and not isinstance(name, (int, str)):
            raise ValueError("name should be either `str` or `int`, got type {}".format(type(name)))
        elif name is None:
            return None
        else:
            return str(name)

    def __init__(self, genfn, *, name=None, keep_auxiliary=True, keep_return=True):
        self.genfn = genfn
        self.name = self.validate_name(name)
        self.model_info = dict(
            keep_auxiliary=keep_auxiliary,
            keep_return=keep_return,
            scope=name_scope(self.name),
            name=self.name,
        )

    def control_flow(self):
        """Iterate over the random variables in the model."""
        return (yield from self.genfn())


In [310]:
_no_name_provided = object()

def get_name(default, base_fn, name) -> Optional[str]:
    """Parse the name of an rv from arguments.

    Parameters
    ----------
    default : _no_name_provided, str, or None
        Default to fall back to if it is not _no_name_provided
    base_fn : callable
        In case the random variable has a name attribute
        and defualt is _no_name_provided, use that
    name : _no_name_provided, str, or None
        Provided argument

    Returns
    -------
    str or None
    """
    if name is _no_name_provided:
        if default is not _no_name_provided:
            name = default
        elif hasattr(base_fn, "name"):
            name = getattr(base_fn, "name")
        elif hasattr(base_fn, "__name__"):
            name = base_fn.__name__
    return name


class ModelTemplate:
    """Model Template -- generative model with metadata.

    ModelTemplate is a callable object that represents a generative process. A generative process samples
    from prior distributions and allows them to interact in arbitrarily-complex, user-defined ways.

    Parameters
    ----------
    template : callable
        Generative process, that accepts any arguments as conditioners and returns realizations if any.
    keep_auxiliary : bool
        Generative process may require some auxiliary variables to be created, but they are probably will not be used
        anywhere else. In that case it is useful to tell PyMC4 engine that we can get rid of auxiliary variables
        as long as they are not needed any more.
    keep_return : bool
        The return value of the model will be recorded
    """

    def __init__(self, template, *, name=None, keep_auxiliary=True, keep_return=True):
        self.template = template
        print("Inside ModelTemplate Constructor self.name:", name)
        self.name = name
        self.keep_auxiliary = keep_auxiliary
        self.keep_return = keep_return

    def __call__(
        self, *args, name=_no_name_provided, keep_auxiliary=None, keep_return=None, **kwargs
    ):
        genfn = functools.partial(self.template, *args, **kwargs)
        name = get_name(self.name, self.template, name)
        if name is not None and not NameParts.is_valid_untransformed_name(name):
            # throw an informative message to fix a name
            raise ValueError(NameParts.UNTRANSFORMED_NAME_ERROR_MESSAGE)
        if keep_auxiliary is None:
            keep_auxiliary = self.keep_auxiliary
        if keep_return is None:
            keep_return = self.keep_return

        return Model(genfn, name=name, keep_auxiliary=keep_auxiliary, keep_return=keep_return)


In [313]:
_no_name_provided = object()

@biwrap
def model(genfn, *, name=_no_name_provided, keep_auxiliary=True, keep_return=True, method=False):
    """Flexibly wrap a generator function into a Model template."""
    
    print("Inside model, before if-else:, name:", name, "genfn_name:", genfn.__name__)
    if method:
        # What is this block for?
        template = ModelTemplate(
            genfn, name=name, keep_auxiliary=keep_auxiliary, keep_return=keep_return
        )

        @functools.wraps(genfn)
        def wrapped(*args, **kwargs):
            return template(*args, **kwargs)

        return wrapped
    else:
        template = ModelTemplate(
            genfn, name=name, keep_auxiliary=keep_auxiliary, keep_return=keep_return
        )
        print("genfn:", genfn)
        print("name:", name)
        print("keep_auxiliary:", keep_auxiliary)
        print("keep_return:", keep_return)
        return template

As the model function accepts no positional arguments, implies len(args) = 0, implies biwrap will return functools.partial object

In [318]:
@model()
def superman():
    print("Superman")

superman()

Args in enhanced: ()
Kwargs in enhanced: {}
Wrapper name is: model
Length of args:  0
Returning functools.partial
Inside model, before if-else:, name: <object object at 0x1102b8c90> genfn_name: superman
Inside ModelTemplate Constructor self.name: <object object at 0x1102b8c90>
genfn: <function superman at 0x110fd34d0>
name: <object object at 0x1102b8c90>
keep_auxiliary: True
keep_return: True
Name received is: <object object at 0x1102b8440>


TypeError: expected string or bytes-like object

In [309]:
def superman():
    print("Superman")

decorated_superman = model(name="tiger")(superman)

Args in enhanced: ()
Kwargs in enhanced: {'name': 'tiger'}
Wrapper name is: model
Length of args:  0
Returning functools.partial
Inside model, before if-else:, name: tiger genfn_name: superman
Inside ModelTemplate Constructor self.name: tiger
genfn: <function superman at 0x110ff9b90>
name: tiger
keep_auxiliary: True
keep_return: True
