### Adding support for operations on functions in Python

This notebook looks at how to support operations on functions in Python. Some of these operations include:

#### Arithmetic operations

$$
(f+g)(x) = f(x) + g(x)
$$

$$
(f - g)(x) = f(x) - g(x)
$$

$$
(f \cdot g)(x) = f(x) \cdot g(x)
$$

$$
\left(\frac{f}{g}\right)(x) = \frac{f(x)}{g(x)}
$$

etc ...

#### Additional operations

Augmentation (concatenation)
$$
(f | g)(x) = \begin{bmatrix} f(x) & g(x) \end{bmatrix}
$$

Composition
$$
(f \circ g)(x) = f(g(x))
$$
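As a quick worked example, take $f(x) = 4x$ and $g(x) = x + 5$ (the functions used in the cells below) and evaluate each operation at $x = 5$, where $f(5) = 20$ and $g(5) = 10$:

$$
\begin{aligned}
(f+g)(5) &= 20 + 10 = 30 \\
(f-g)(5) &= 20 - 10 = 10 \\
(f \cdot g)(5) &= 20 \cdot 10 = 200 \\
\left(\frac{f}{g}\right)(5) &= \frac{20}{10} = 2 \\
(f | g)(5) &= \begin{bmatrix} 20 & 10 \end{bmatrix} \\
(f \circ g)(5) &= f(g(5)) = f(10) = 40
\end{aligned}
$$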

In [1]:
def f(x):
    return 4*x

In [2]:
def g(x): 
    return x+5

In [3]:
f(5)

20

In [4]:
g(5)

10

In [5]:
f(5)+g(5)

30

However, we can't yet write `(f+g)(5)`: although Python functions are first-class objects, they do not support arithmetic operators.

In [6]:
(f+g)(5)

TypeError: unsupported operand type(s) for +: 'function' and 'function'

We write a class, usable as a decorator, that augments a function (or any callable) with arithmetic operators.

In [7]:
class FunctionWrapper:
    
    def __init__(self, fn):
        self.fn = fn
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return lambda *args, **kwargs: self(*args, **kwargs).__add__(other(*args, **kwargs))
    
    def __sub__(self, other):
        return lambda *args, **kwargs: self(*args, **kwargs).__sub__(other(*args, **kwargs))

    def __mul__(self, other):
        return lambda *args, **kwargs: self(*args, **kwargs).__mul__(other(*args, **kwargs))
    
    def __div__(self, other): 
        return lambda *args, **kwargs: self(*args, **kwargs).__div__(other(*args, **kwargs))
    
    # etc ...

In [8]:
@FunctionWrapper
def f(x):
    return 4*x

In [9]:
@FunctionWrapper
def g(x): 
    return x+5

In [10]:
g(5)

10

In [11]:
(f+g)(5)

30

In [12]:
(f*g)(5)

200

In [13]:
f(5)*g(5)+g(5)*g(5)

300

In [14]:
(f*g+g*g)(5)

TypeError: unsupported operand type(s) for +: 'function' and 'function'

This breaks down because the functions that result from operations on wrapped functions are not themselves wrapped. In the example above, `f*g` and `g*g` are plain lambdas, which do not support arithmetic operators.

### Recursive class definitions

This is easily remedied with a recursive class definition: each operator returns a new `FunctionWrapper`, so its result can be combined further.

In [15]:
class FunctionWrapper:
    
    def __init__(self, fn):
        self.fn = fn
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__add__(other(*args, **kwargs)))
    
    def __sub__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__sub__(other(*args, **kwargs)))

    def __mul__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__mul__(other(*args, **kwargs)))
    
    def __div__(self, other): 
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__div__(other(*args, **kwargs)))
    
    # etc ...

In [16]:
@FunctionWrapper
def f(x):
    return 4*x

In [17]:
@FunctionWrapper
def g(x): 
    return x+5

In [18]:
(f*g+g*g)(5)

300

In [19]:
h = f*g+g*g

In [20]:
h(5)

300

We can even include constants, by wrapping constant functions such as `lambda _: 2`.

In [21]:
2*f(5)*g(5)+g(5)*g(5)+5

505

In [22]:
h = FunctionWrapper(lambda _: 2)*f*g+g*g+FunctionWrapper(lambda _: 5)

In [23]:
h(5)

505

We can support this more directly with an alternative constructor, `const`, that builds a constant function.

In [24]:
class FunctionWrapper:
    
    def __init__(self, fn):
        self.fn = fn
        
    @classmethod
    def const(cls, c):
        # ignore all arguments, so constants combine with functions of any signature
        return cls(lambda *args, **kwargs: c)
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__add__(other(*args, **kwargs)))
    
    def __sub__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__sub__(other(*args, **kwargs)))

    def __mul__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__mul__(other(*args, **kwargs)))
    
    def __div__(self, other): 
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__div__(other(*args, **kwargs)))
    
    # etc ...

In [25]:
h = FunctionWrapper.const(2)*f*g+g*g+FunctionWrapper.const(5)

In [26]:
h(5)

505

This is a lot of repetitive code. We can greatly simplify this with a metaclass.

#### Dynamically obtaining binary operators 

The metaclass approach works because magic methods can be looked up dynamically by name with `getattr`:

In [27]:
6 + 8 == getattr(6, '__add__')(8)

True

In [28]:
6 * 8 == getattr(6, '__mul__')(8)

True
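One caveat, shown here as a small Python 3 check: calling a dunder method on the left operand skips Python's fallback to the right operand's reflected method (e.g. `__radd__`), so `getattr(2, '__add__')(2.5)` returns the `NotImplemented` sentinel instead of `4.5`. The functions in the standard `operator` module go through the full binary-operator protocol, so `getattr(operator, 'add')` is a dynamic lookup that avoids this.

```python
import operator

# Looking up the dunder on the left operand: int.__add__ does not handle a float,
# and nothing falls back to float.__radd__, so we get the NotImplemented sentinel.
print(getattr(2, '__add__')(2.5))                # NotImplemented

# operator.add applies the full protocol, including the reflected __radd__.
print(getattr(operator, 'add')(2, 2.5))          # 4.5

# The dynamic lookup still works for the ordinary cases above.
print(getattr(operator, 'mul')(6, 8) == 6 * 8)   # True
```

This matters for `FunctionWrapper` whenever two wrapped functions return values of different types, for example an `int` and a `float`.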

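### Metaclasses to the rescue

A metaclass lets us generate the operator methods when the class is created: `FunctionWrapperMeta.__new__` below reads the method names listed in the class attribute `methods` and adds a method for each one to the class dictionary before the class is built.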

In [29]:
class FunctionWrapperMeta(type):
    
    def __new__(cls, name, bases, dct):
        for method in dct['methods']:
            dct[method] = lambda self, other: lambda *args, **kwargs: getattr(self(*args, **kwargs), method)(other(*args, **kwargs))
        return super(FunctionWrapperMeta, cls).__new__(cls, name, bases, dct)

In [30]:
class FunctionWrapper:
    __metaclass__ = FunctionWrapperMeta
    
    methods = ['__add__', '__mul__', '__sub__', '__div__']
    
    def __init__(self, fn):
        self.fn = fn
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

In [31]:
@FunctionWrapper
def f(x):
    return 4*x

In [32]:
@FunctionWrapper
def g(x): 
    return x+5

In [33]:
f(5)

20

In [34]:
g(5)

10

In [35]:
(f+g)(5)

2

In [36]:
(f-g)(5)

2

In [37]:
(f*g)(5)

2

In [38]:
(f/g)(5)

2

In [39]:
f(5)/g(5)

2

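Every result above is 2, i.e. `f(5)/g(5)`: all four operators ended up performing division, the last method listed in `methods`.

#### GOTCHA: Python's closures are late binding!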

In [40]:
def multipliers(n):
    fns = []
    for m in range(n):
        fns.append(lambda x: m*x)
    return fns

In [41]:
for mul in multipliers(5):
    print mul(3)

12
12
12
12
12
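
Each lambda closes over the variable `m`, not its value at the time the lambda was created. `m` is only looked up when the lambda is called, after the loop has finished, so every lambda sees its final value, 4. The metaclass above fails for the same reason: every generated method looks up `method` after the loop has ended, when it is `'__div__'`.
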
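#### Solution

Move the lambda into a factory function. Each call to `make_fn` creates a new scope in which `m` is a parameter, so each returned lambda captures its own value of `m`.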

In [45]:
make_fn = lambda m: lambda x: m*x

In [48]:
def multipliers(n):
    fns = []
    for m in range(n):
        fns.append(make_fn(m))
    return fns

In [49]:
for mul in multipliers(5):
    print mul(3)

0
3
6
9
12
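Two other common fixes, sketched in Python 3 (the notebook itself uses Python 2 `print` statements): bind the current value as a default argument, `lambda x, m=m: m*x`, since defaults are evaluated when the lambda is created; or use `functools.partial` from the standard library. The names `multipliers_default` and `multipliers_partial` are illustrative.

```python
import functools
import operator

def multipliers_default(n):
    # The default m=m is evaluated when each lambda is created,
    # so every lambda keeps the value of m from its own iteration.
    return [lambda x, m=m: m * x for m in range(n)]

def multipliers_partial(n):
    # functools.partial stores m as the first argument of operator.mul.
    return [functools.partial(operator.mul, m) for m in range(n)]

print([mul(3) for mul in multipliers_default(5)])  # [0, 3, 6, 9, 12]
print([mul(3) for mul in multipliers_partial(5)])  # [0, 3, 6, 9, 12]
```

The default-argument version has one downside: a caller can accidentally override `m` by passing a second argument, which the factory and `partial` versions do not allow.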
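
The `make_fn`-style factory also fixes the metaclass: `make_method` binds each method name in its own scope.
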
In [50]:
make_method = lambda method: lambda self, other: lambda *args, **kwargs: getattr(self(*args, **kwargs), method)(other(*args, **kwargs))

In [51]:
class FunctionWrapperMeta(type):
    
    def __new__(cls, name, bases, dct):
        for method in dct['methods']:
            dct[method] = make_method(method)
        return super(FunctionWrapperMeta, cls).__new__(cls, name, bases, dct)

In [52]:
class FunctionWrapper:
    __metaclass__ = FunctionWrapperMeta
    
    methods = ['__add__', '__mul__', '__sub__', '__div__']
    
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

In [53]:
@FunctionWrapper
def f(x):
    return 4*x

In [54]:
@FunctionWrapper
def g(x): 
    return x+5

In [55]:
f(5)

20

In [56]:
g(5)

10

In [57]:
(f+g)(5)

30

In [58]:
(f-g)(5)

10

In [59]:
(f*g)(5)

200

In [60]:
(f/g)(5)

2

# TODO: Recursive class definition
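One way this TODO might be addressed, sketched in Python 3 (where the metaclass is passed with the `metaclass=` keyword rather than a `__metaclass__` attribute, and `/` calls `__truediv__` rather than `__div__`): have each generated method wrap its result in `FunctionWrapper`. Two further choices are assumptions of this sketch, not the notebook's code: `methods` maps method names to `operator` functions, and `FunctionWrapper` is looked up by name when a method runs, which works because the class exists by then.

```python
import operator

def make_method(op):
    # op is a parameter here, so each generated method keeps its own operator
    def method(self, other):
        # FunctionWrapper is looked up when method runs, after the class exists
        return FunctionWrapper(
            lambda *args, **kwargs: op(self(*args, **kwargs), other(*args, **kwargs)))
    return method

class FunctionWrapperMeta(type):
    def __new__(mcs, name, bases, dct):
        for method_name, op in dct.get('methods', {}).items():
            dct[method_name] = make_method(op)
        return super().__new__(mcs, name, bases, dct)

class FunctionWrapper(metaclass=FunctionWrapperMeta):
    # Maps each magic method to the operator it should apply
    methods = {
        '__add__': operator.add,
        '__sub__': operator.sub,
        '__mul__': operator.mul,
        '__truediv__': operator.truediv,
    }

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

f = FunctionWrapper(lambda x: 4 * x)
g = FunctionWrapper(lambda x: x + 5)
print((f * g + g * g)(5))  # 300
print((f / g)(5))          # 2.0
```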

### Beyond scalar functions: vector/matrix-valued functions

This extends directly to vector- and matrix-valued functions. In general, an operation on wrapped functions works whenever the same operation is well-defined on their return values; for NumPy arrays, `+`, `-`, `*` and `/` act elementwise, with broadcasting.

In [62]:
import numpy as np
from scipy.spatial.distance import cdist

In [63]:
@FunctionWrapper
def polynomial_basis(X, deg=2):
    return X**np.arange(deg+1)

In [64]:
X = np.random.randint(10, size=(10, 1)); X

array([[5],
       [5],
       [1],
       [1],
       [6],
       [6],
       [7],
       [9],
       [4],
       [1]])

In [65]:
polynomial_basis(X)

array([[ 1,  5, 25],
       [ 1,  5, 25],
       [ 1,  1,  1],
       [ 1,  1,  1],
       [ 1,  6, 36],
       [ 1,  6, 36],
       [ 1,  7, 49],
       [ 1,  9, 81],
       [ 1,  4, 16],
       [ 1,  1,  1]])

In [66]:
(polynomial_basis+polynomial_basis)(X)

array([[  2,  10,  50],
       [  2,  10,  50],
       [  2,   2,   2],
       [  2,   2,   2],
       [  2,  12,  72],
       [  2,  12,  72],
       [  2,  14,  98],
       [  2,  18, 162],
       [  2,   8,  32],
       [  2,   2,   2]])

#### Example: Covariance functions

In [67]:
@FunctionWrapper
def sqr_exp(x_p, x_q=None, len_scale=1.):

    if x_q is None:
        return np.ones_like(x_p).ravel()
    r_sq = cdist(x_p, x_q, 'sqeuclidean')
    return np.exp(-0.5*r_sq/len_scale**2)

In [68]:
@FunctionWrapper
def matern32(x_p, x_q=None, len_scale=1.):

    if x_q is None:
        return np.ones_like(x_p).ravel()
    r = cdist(x_p, x_q, 'euclidean')
    s = np.sqrt(3)*r/len_scale
    return (1+s)*np.exp(-s)

In [69]:
x_p = np.array([[ 0.25],
                [ 0.75]])

In [70]:
x_q = np.array([[ 0.60],
                [ 0.12],
                [ 0.75]])

In [71]:
sqr_exp(x_p, x_q, 0.5).round(2)

array([[ 0.78,  0.97,  0.61],
       [ 0.96,  0.45,  1.  ]])

In [72]:
matern32(x_p, x_q, 0.5).round(2)

array([[ 0.66,  0.92,  0.48],
       [ 0.9 ,  0.36,  1.  ]])

In [73]:
(sqr_exp(x_p, x_q, 0.5)+matern32(x_p, x_q, 0.5)).round(2)

array([[ 1.44,  1.89,  1.09],
       [ 1.86,  0.81,  2.  ]])

In [74]:
(sqr_exp + matern32)(x_p, x_q, 0.5).round(2)

array([[ 1.44,  1.89,  1.09],
       [ 1.86,  0.81,  2.  ]])
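This style of composition suits covariance functions because sums and elementwise products of valid covariance functions are again valid covariance functions. A quick numerical sanity check as a Python 3 sketch: the minimal `FunctionWrapper` and the NumPy-only `sq_dists` helper (standing in for `cdist`) are assumptions made to keep it self-contained, and a non-negative smallest eigenvalue on one grid is evidence, not proof.

```python
import operator
import numpy as np

class FunctionWrapper:
    # Minimal Python 3 wrapper: every operator returns another FunctionWrapper
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def _combine(self, other, op):
        return FunctionWrapper(lambda *a, **kw: op(self(*a, **kw), other(*a, **kw)))

    def __add__(self, other):
        return self._combine(other, operator.add)

    def __mul__(self, other):
        return self._combine(other, operator.mul)

def sq_dists(x_p, x_q):
    # Pairwise squared Euclidean distances between rows, via broadcasting
    return ((x_p[:, None, :] - x_q[None, :, :]) ** 2).sum(axis=-1)

@FunctionWrapper
def sqr_exp(x_p, x_q, len_scale=1.):
    return np.exp(-0.5 * sq_dists(x_p, x_q) / len_scale**2)

@FunctionWrapper
def matern32(x_p, x_q, len_scale=1.):
    s = np.sqrt(3) * np.sqrt(sq_dists(x_p, x_q)) / len_scale
    return (1 + s) * np.exp(-s)

x = np.linspace(0, 1, 20).reshape(-1, 1)
for kernel in (sqr_exp + matern32, sqr_exp * matern32):
    K = kernel(x, x, 0.5)
    # Smallest eigenvalue of the Gram matrix, allowing for floating-point error
    print(np.linalg.eigvalsh(K).min() >= -1e-8)  # True
```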

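#### Problem: Functions with heterogeneous parameters/arguments

In the example above, `sqr_exp` and `matern32` take the same arguments, so `(sqr_exp + matern32)(x_p, x_q, 0.5)` can pass the same length scale to both. If the functions take different parameters, a single call can no longer supply the right arguments to each. One way around this is to fix each function's parameters ahead of time, using classes whose instances are callable: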

In [75]:
class PolynomialBasis:

    def __init__(self, degree=2., include_bias=True):
        self.include_bias = include_bias
        self.degree = degree

    def __call__(self, X):
        powers = np.arange(self.degree+1)
        if not self.include_bias:
            powers = powers[1:]
        return X**powers

In [76]:
class RadialBasis:

    def __init__(self, mu=0., s=1.):
        self.mu = np.atleast_2d(mu)  # cdist needs 2-D arrays; promotes the scalar default
        self.s = s

    def __call__(self, X):
        return np.exp(-cdist(X, self.mu, 'sqeuclidean')/(2*self.s**2))

Note that it is *instances* of these basis classes that act as basis functions; calling the class itself just constructs an instance. So we must wrap *instances* of these bases rather than decorate the classes.

In [80]:
f = FunctionWrapper(PolynomialBasis())

In [81]:
g = FunctionWrapper(RadialBasis(mu=np.arange(4, 7).reshape(-1, 1)))

In [82]:
f(X)

array([[  1.,   5.,  25.],
       [  1.,   5.,  25.],
       [  1.,   1.,   1.],
       [  1.,   1.,   1.],
       [  1.,   6.,  36.],
       [  1.,   6.,  36.],
       [  1.,   7.,  49.],
       [  1.,   9.,  81.],
       [  1.,   4.,  16.],
       [  1.,   1.,   1.]])

In [83]:
g(X).round(2)

array([[ 0.61,  1.  ,  0.61],
       [ 0.61,  1.  ,  0.61],
       [ 0.01,  0.  ,  0.  ],
       [ 0.01,  0.  ,  0.  ],
       [ 0.14,  0.61,  1.  ],
       [ 0.14,  0.61,  1.  ],
       [ 0.01,  0.14,  0.61],
       [ 0.  ,  0.  ,  0.01],
       [ 1.  ,  0.61,  0.14],
       [ 0.01,  0.  ,  0.  ]])

In [84]:
(f(X)+g(X)).round(2)

array([[  1.61,   6.  ,  25.61],
       [  1.61,   6.  ,  25.61],
       [  1.01,   1.  ,   1.  ],
       [  1.01,   1.  ,   1.  ],
       [  1.14,   6.61,  37.  ],
       [  1.14,   6.61,  37.  ],
       [  1.01,   7.14,  49.61],
       [  1.  ,   9.  ,  81.01],
       [  2.  ,   4.61,  16.14],
       [  1.01,   1.  ,   1.  ]])

In [85]:
h = f+g
h(X).round(2)

array([[  1.61,   6.  ,  25.61],
       [  1.61,   6.  ,  25.61],
       [  1.01,   1.  ,   1.  ],
       [  1.01,   1.  ,   1.  ],
       [  1.14,   6.61,  37.  ],
       [  1.14,   6.61,  37.  ],
       [  1.01,   7.14,  49.61],
       [  1.  ,   9.  ,  81.01],
       [  2.  ,   4.61,  16.14],
       [  1.01,   1.  ,   1.  ]])
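Callable class instances are one way to bind each function's parameters ahead of time; `functools.partial` from the standard library is a lighter-weight alternative when the function already exists. A Python 3 sketch, assuming a minimal `FunctionWrapper` with only `__add__` and a NumPy-only `sqr_exp` (standing in for the `cdist`-based version above), that sums two squared-exponential terms with different length scales:

```python
import functools
import numpy as np

class FunctionWrapper:
    # Minimal Python 3 wrapper supporting only +
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return FunctionWrapper(lambda *a, **kw: self(*a, **kw) + other(*a, **kw))

def sqr_exp(x_p, x_q, len_scale=1.):
    # Pairwise squared distances via broadcasting, then the squared-exponential kernel
    sq = ((x_p[:, None, :] - x_q[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq / len_scale**2)

# Each term carries its own length scale; the sum only needs (x_p, x_q).
narrow = FunctionWrapper(functools.partial(sqr_exp, len_scale=0.1))
broad = FunctionWrapper(functools.partial(sqr_exp, len_scale=2.0))
k = narrow + broad

x_p = np.array([[0.25], [0.75]])
x_q = np.array([[0.60], [0.12], [0.75]])
print(k(x_p, x_q).round(2))
```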

Since instances of these basis classes implement `__call__` and are therefore callable, we can skip wrapping each instance and instead specify `FunctionWrapperMeta` as the metaclass of the basis classes directly.

In [120]:
class PolynomialBasis:

    __metaclass__ = FunctionWrapperMeta

    methods = ['__add__', '__mul__', '__sub__', '__div__']

    def __init__(self, degree=2., include_bias=True):
        self.include_bias = include_bias
        self.degree = degree

    def __call__(self, X):
        powers = np.arange(self.degree+1)
        if not self.include_bias:
            powers = powers[1:]
        return X**powers

In [121]:
class RadialBasis:

    __metaclass__ = FunctionWrapperMeta

    methods = ['__add__', '__mul__', '__sub__', '__div__']
 
    def __init__(self, mu=0., s=1.):
        self.mu = np.atleast_2d(mu)  # cdist needs 2-D arrays; promotes the scalar default
        self.s = s
        
    def __call__(self, X):
        return np.exp(-cdist(X, self.mu, 'sqeuclidean')/(2*self.s**2))

Now we needn't wrap basis instances to support operations: the operator methods are added once, when the class is created, rather than each time an object is created.

In [129]:
f = PolynomialBasis()

In [130]:
g = RadialBasis(mu=np.arange(4, 7).reshape(-1, 1))

In [131]:
f(X)

array([[  1.,   5.,  25.],
       [  1.,   5.,  25.],
       [  1.,   1.,   1.],
       [  1.,   1.,   1.],
       [  1.,   6.,  36.],
       [  1.,   6.,  36.],
       [  1.,   7.,  49.],
       [  1.,   9.,  81.],
       [  1.,   4.,  16.],
       [  1.,   1.,   1.]])

In [132]:
g(X).round(2)

array([[ 0.61,  1.  ,  0.61],
       [ 0.61,  1.  ,  0.61],
       [ 0.01,  0.  ,  0.  ],
       [ 0.01,  0.  ,  0.  ],
       [ 0.14,  0.61,  1.  ],
       [ 0.14,  0.61,  1.  ],
       [ 0.01,  0.14,  0.61],
       [ 0.  ,  0.  ,  0.01],
       [ 1.  ,  0.61,  0.14],
       [ 0.01,  0.  ,  0.  ]])

In [133]:
(f(X)+g(X)).round(2)

array([[  1.61,   6.  ,  25.61],
       [  1.61,   6.  ,  25.61],
       [  1.01,   1.  ,   1.  ],
       [  1.01,   1.  ,   1.  ],
       [  1.14,   6.61,  37.  ],
       [  1.14,   6.61,  37.  ],
       [  1.01,   7.14,  49.61],
       [  1.  ,   9.  ,  81.01],
       [  2.  ,   4.61,  16.14],
       [  1.01,   1.  ,   1.  ]])

In [134]:
h = f+g
h(X).round(2)

array([[  1.61,   6.  ,  25.61],
       [  1.61,   6.  ,  25.61],
       [  1.01,   1.  ,   1.  ],
       [  1.01,   1.  ,   1.  ],
       [  1.14,   6.61,  37.  ],
       [  1.14,   6.61,  37.  ],
       [  1.01,   7.14,  49.61],
       [  1.  ,   9.  ,  81.01],
       [  2.  ,   4.61,  16.14],
       [  1.01,   1.  ,   1.  ]])
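A design note: a metaclass is not the only way to share these operators. An ordinary base class that defines them gives every subclass the same behavior through inheritance, and by returning a wrapped callable it also lets results be combined further (the `h = f+g` above is a plain lambda, so `h + f` would fail). A Python 3 sketch; `Operable`, `Combined`, and `Scaled` are hypothetical names.

```python
import operator

class Operable:
    """Hypothetical base class: callable subclasses inherit arithmetic operators."""

    def _combine(self, other, op):
        # Wrap the result so it can itself be combined again
        return Combined(lambda *a, **kw: op(self(*a, **kw), other(*a, **kw)))

    def __add__(self, other):
        return self._combine(other, operator.add)

    def __sub__(self, other):
        return self._combine(other, operator.sub)

    def __mul__(self, other):
        return self._combine(other, operator.mul)

    def __truediv__(self, other):
        return self._combine(other, operator.truediv)

class Combined(Operable):
    # Generic callable holder for the results of operations
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

class Scaled(Operable):
    # Hypothetical parameterized callable, analogous to the basis classes above
    def __init__(self, c):
        self.c = c

    def __call__(self, x):
        return self.c * x

f, g = Scaled(4), Scaled(2)
h = f * g + f
print((h - g)(5))  # 210
```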

# TODO: Concatenation, composition
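A possible sketch of these two operations, in Python 3. Assumptions: augmentation uses the `|` operator, mirroring the notation $(f | g)$ at the top of this notebook, and joins results with `np.hstack`; since Python has no $\circ$ operator, composition is a hypothetical `compose` method.

```python
import numpy as np

class FunctionWrapper:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __or__(self, other):
        # Augmentation: (f | g)(x) = [f(x)  g(x)], joined column-wise
        return FunctionWrapper(
            lambda *args, **kwargs: np.hstack((self(*args, **kwargs), other(*args, **kwargs))))

    def compose(self, other):
        # Composition: (f o g)(x) = f(g(x)); only g sees the original arguments
        return FunctionWrapper(lambda *args, **kwargs: self(other(*args, **kwargs)))

f = FunctionWrapper(lambda x: 4 * x)
g = FunctionWrapper(lambda x: x + 5)
print((f | g)(5))        # [20 10]
print(f.compose(g)(5))   # 40
print(g.compose(f)(5))   # 25

X = np.arange(3).reshape(-1, 1)
print((f | g)(X).shape)  # (3, 2)
```

For 2-D outputs such as the basis matrices above, `np.hstack` places the feature columns of `f` and `g` side by side.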

# TODO
### Appendix: Mind-bending experiments

In [102]:
identity = lambda s: s

In [104]:
(FunctionWrapper(identity) * FunctionWrapper(len))('hello ')

'hello hello hello hello hello hello '

In [118]:
def strange_repeat(s):
    return s * len(s) + s.rjust(20)

In [119]:
strange_repeat('ATCG')

'ATCGATCGATCGATCG                ATCG'
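Viewed through the wrapper, `strange_repeat` is just a sum and a product of simpler functions. A Python 3 sketch, assuming a minimal `FunctionWrapper` that applies operators through the `operator` module rather than calling dunders directly (with direct dunder calls, `length * identity` would return `NotImplemented`, since `int.__mul__` does not accept a string):

```python
import operator

class FunctionWrapper:
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def _combine(self, other, op):
        return FunctionWrapper(lambda *a, **kw: op(self(*a, **kw), other(*a, **kw)))

    def __add__(self, other):
        return self._combine(other, operator.add)

    def __mul__(self, other):
        return self._combine(other, operator.mul)

identity = FunctionWrapper(lambda s: s)
length = FunctionWrapper(len)
rjust20 = FunctionWrapper(lambda s: s.rjust(20))

# s * len(s) + s.rjust(20), assembled from wrapped pieces
strange_repeat_fw = identity * length + rjust20
print(strange_repeat_fw('ATCG') == 'ATCG' * 4 + 'ATCG'.rjust(20))  # True

# int * str also works, because operator.mul falls back to str.__rmul__
print((length * identity)('ab'))  # abab
```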