we want to debug some info about function


In [9]:
def trace(f):
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    return inner


def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

max = trace(max)

def foo():
    max(-10, -1, -3)
foo()


max(-10, -1, -3) = ...
max(-10, -1, -3) = 0


Same using python syntax

In [10]:
DEBUG = True

def trace(f):
    
    if not DEBUG:
        return f
    
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    return inner

@trace
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

def foo():
    max(-10, -1, -3)
foo()

max(-10, -1, -3) = ...
max(-10, -1, -3) = 0


But there is a problem with all function attributes

In [11]:
help(max) 

Help on function inner in module __main__:

inner(*args, **kwargs)



In [12]:

DEBUG = True

def trace(f):
    
    if not DEBUG:
        return f
    
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    inner.__name__ = f.__name__
    inner.__doc__ = f.__doc__
    inner.__module__ = f.__module__
    
    return inner

@trace
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

help(max)

Help on function max in module __main__:

max(*args, **kwargs)
    Finds the largest argument.



We can extract this wrapper update functionality

In [13]:
def update_wrapper(wrapped, wrapper):
    for attr in ["__name__", "__doc__", "__module__"]:
        setattr(wrapper, attr, getattr(wrapped, attr))
        wrapper.__wrapped__ = wrapped

DEBUG = True

def trace(f):
    
    if not DEBUG:
        return f
    
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    update_wrapper(f, inner)
    
    return inner

@trace
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

help(max)

Help on function max in module __main__:

max(*args)
    Finds the largest argument.



This also can be decorated

In [15]:
import functools

def update_wrapper(wrapped, wrapper):
    for attr in ["__name__", "__doc__", "__module__"]:
        setattr(wrapper, attr, getattr(wrapped, attr))
        wrapper.__wrapped__ = wrapped
    return wrapper

DEBUG = True

def trace(f):
    if not DEBUG:
        return f
    
    wraps = functools.partial(update_wrapper, f)
    
    @wraps
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    return inner

@trace
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

help(max)

Help on function max in module __main__:

max(*args)
    Finds the largest argument.



Same without functools


In [1]:

def update_wrapper(wrapped, wrapper):
    for attr in ["__name__", "__doc__", "__module__"]:
        setattr(wrapper, attr, getattr(wrapped, attr))
    return wrapper

DEBUG = True

def wraps(f):
    def deco(g):
        update_wrapper(f, g)
        return  g
    return  deco

def trace(f):
    if not DEBUG:
        return f
        
    @wraps(f)
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    return inner

@trace
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

help(max)

Help on function max in module __main__:

max(*args)
    Finds the largest argument.



This functionality is already in functools


In [5]:
import functools

DEBUG = True

def trace(f):
    if not DEBUG:
        return f
        
    @functools.wraps(f)
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...")
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}")
        return  ret
    
    return inner

@trace
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

help(max)

Help on function max in module __main__:

max(*args)
    Finds the largest argument.



Now we want to output the trace result to stderr

In [8]:
import functools
import sys

DEBUG = True

def trace(stream=sys.stdout):
    def decorator(f):
        if not DEBUG:
            return f
        
        @functools.wraps(f)
        def inner(*args, **kwargs):
            call = ", ".join(
                [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
            )
            print(f"{f.__name__}({call}) = ...", file=stream)
            ret = f(*args, **kwargs)
            print(f"{f.__name__}({call}) = {ret}", file=stream)
            return  ret
        
        return inner
    return decorator

@trace(sys.stderr)
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

max(1, 2)

max(1, 2) = ...
max(1, 2) = 2


2

But if we don't want to put arguments every time @trace(sys.stderr). Sometimes we want just @trace

In [15]:
import functools
import sys

DEBUG = True

def trace(f=None, *, stream=sys.stdout):
    if f is None:
        return functools.partial(trace, stream=stream)
    
    if not DEBUG:
        return f
    
    @functools.wraps(f)
    def inner(*args, **kwargs):
        call = ", ".join(
            [str(a) for a in args] + [f"{k}={v}" for k, v in kwargs]
        )
        print(f"{f.__name__}({call}) = ...", file=stream)
        ret = f(*args, **kwargs)
        print(f"{f.__name__}({call}) = {ret}", file=stream)
        return  ret
        
    return inner


@trace(stream=sys.stderr)
def max(*args):
    """Finds the largest argument."""
    ret = 0
    for x in args:
        ret = ret if x < ret else x
    return ret

max(1, 2)

max(1, 2) = ...
max(1, 2) = 2


2

## Examples
### Logger initialization
we want to make init logger callable only once

In [19]:
import warnings

def once(f):
    called = False
    def inner(*args, **kwargs):
        nonlocal called
        if not called:
            called = True
            res = f(*args, **kwargs)
            assert res is None
    
    return inner

    
    
    
def deprecated(f):
    called = False
    def inner(*args, **kwargs):
        warnings.warn(f"{f.__name__} is deprecated", category=DeprecationWarning)
        print(f"Don't use {f.__name__}, use ... instead", file=sys.stderr)
        f(*args, **kwargs)
    
    return inner


@once
@deprecated
def init_logger():
    print("initializing logger")
    


def foo():
    init_logger()

foo()
foo()
    

initializing logger


Don't use init_logger, use ... instead


### Time and cache

In [27]:
import time
import functools


def profile(f):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        start = time.perf_counter()
        res = f(*args, **kwargs)
        elapsed = time.perf_counter() - start
        inner.__n_calls__ += 1
        inner.__total_time__ += elapsed
        return res
    
    inner.__n_calls__ = 0
    inner.__total_time__ = 0
    return inner


def memoize(f):
    cache = {}
    
    @functools.wraps(f)
    def inner(*args, **kwargs):
        key = (args,   frozenset(kwargs))
        if key not in cache:
            cache[key] = f(*args, **kwargs)
        return cache[key]
    inner.__cache__ = cache
    return inner
    
@profile
@memoize
def fib(n):
    return 1 if n <= 1 else fib(n-1) + fib(n-2)


print(fib(22))
print(fib.__n_calls__)
print(fib.__total_time__)
print(fib.__cache__)

## another cache
@profile
@functools.lru_cache(maxsize=None)
def fib(n):
    return 1 if n <= 1 else fib(n-1) + fib(n-2)

28657
43
0.001493199998549244
{((1,), frozenset()): 1, ((0,), frozenset()): 1, ((2,), frozenset()): 2, ((3,), frozenset()): 3, ((4,), frozenset()): 5, ((5,), frozenset()): 8, ((6,), frozenset()): 13, ((7,), frozenset()): 21, ((8,), frozenset()): 34, ((9,), frozenset()): 55, ((10,), frozenset()): 89, ((11,), frozenset()): 144, ((12,), frozenset()): 233, ((13,), frozenset()): 377, ((14,), frozenset()): 610, ((15,), frozenset()): 987, ((16,), frozenset()): 1597, ((17,), frozenset()): 2584, ((18,), frozenset()): 4181, ((19,), frozenset()): 6765, ((20,), frozenset()): 10946, ((21,), frozenset()): 17711, ((22,), frozenset()): 28657}


## Singledispatch

example we can call len from many objects


len("foo")
len([1, 2, 3])
len({})

"foo".__len__()

if we want to define such function:

In [29]:
@functools.singledispatch
def json(x):
    assert False, f"json not supported for {type(x)}"
    
@json.register(type(None))
def _(x):
    return "null"

@json.register(int)
def _(x):
    return str(x)

@json.register(list)
def _(xs):
    contents = ", ".join(json(x) for x in xs)
    return f"[{contents}]"

print(json(None))
print(json(92))
print(json([92, None]))

null
92
[92, null]


In [31]:
## Functools.reduce

In [32]:
res = functools.reduce(lambda x, y: x + ", " + y , ["a", "b", "c"], "initial")

print(res)

initial, a, b, c


In [34]:
def max(xs):
    return functools.reduce(lambda x, y : x if x >= y else y, xs, float("-inf"))
print (max([1, 2, 3, 4, 5, 3]))


5
