### Adding support for operations on functions in Python

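This notebook looks at how we can support operations on functions, so that functions can be combined just like the values they return. The code targets Python 2 (it uses `print` statements, the `__metaclass__` attribute and `__div__`). Some of the operations we would like to support: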

#### Arithmetic operations

$$
(f+g)(x) = f(x) + g(x)
$$

$$
(f - g)(x) = f(x) - g(x)
$$

$$
(f \cdot g)(x) = f(x) \cdot g(x)
$$

$$
\left(\frac{f}{g}\right)(x) = \frac{f(x)}{g(x)}
$$

etc ...

#### Additional operations

Augmentation (concatenation)
$$
(f | g)(x) = \begin{bmatrix} f(x) & g(x) \end{bmatrix}
$$

Composition
$$
(f \circ g)(x) = f(g(x))
$$
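Each of these operations can already be expressed without operator overloading, as an ordinary higher-order function that builds a new function. Below is a minimal Python 3 sketch. The helper names `pointwise`, `augment` and `compose` are our own, and augmentation returns a list here rather than a matrix row:

```python
import operator


def pointwise(op):
    """Lift a binary operator on values to a binary operator on functions."""
    def combine(f, g):
        return lambda *args, **kwargs: op(f(*args, **kwargs), g(*args, **kwargs))
    return combine


add = pointwise(operator.add)
sub = pointwise(operator.sub)
mul = pointwise(operator.mul)
div = pointwise(operator.truediv)


def augment(f, g):
    """(f | g)(x) = [f(x), g(x)]"""
    return lambda *args, **kwargs: [f(*args, **kwargs), g(*args, **kwargs)]


def compose(f, g):
    """(f o g)(x) = f(g(x))"""
    return lambda *args, **kwargs: f(g(*args, **kwargs))


f = lambda x: 4 * x
g = lambda x: x + 5

print(add(f, g)(5))      # 30
print(div(f, g)(5))      # 2.0
print(augment(f, g)(5))  # [20, 10]
print(compose(f, g)(5))  # 40
```

Every call builds a fresh closure. The rest of the notebook packages the same idea in a class so that `+`, `*` and friends can be written directly.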

In [206]:
def f(x):
    return 4*x

In [207]:
def g(x): 
    return x+5

In [208]:
f(5)

20

In [209]:
g(5)

10

In [210]:
f(5)+g(5)

30

Ideally we would write `(f+g)(5)` directly. This fails: Python functions are first-class objects, but the built-in function type does not support arithmetic operators.

In [211]:
(f+g)(5)

TypeError: unsupported operand type(s) for +: 'function' and 'function'

We therefore write a class, usable as a decorator, that wraps a function (or any callable) and adds arithmetic operators to it.

In [220]:
class FunctionWrapper:
    
    def __init__(self, fn):
        self.fn = fn
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return lambda *args, **kwargs: self(*args, **kwargs).__add__(other(*args, **kwargs))
    
    def __sub__(self, other):
        return lambda *args, **kwargs: self(*args, **kwargs).__sub__(other(*args, **kwargs))

    def __mul__(self, other):
        return lambda *args, **kwargs: self(*args, **kwargs).__mul__(other(*args, **kwargs))
    
    def __div__(self, other): 
        return lambda *args, **kwargs: self(*args, **kwargs).__div__(other(*args, **kwargs))
    
    # etc ...

In [221]:
@FunctionWrapper
def f(x):
    return 4*x

In [222]:
@FunctionWrapper
def g(x): 
    return x+5

In [223]:
g(5)

10

In [224]:
(f+g)(5)

30

In [225]:
(f*g)(5)

200

In [227]:
f(5)*g(5)+g(5)*g(5)

300

In [412]:
(f*g+g*g)(5)

TypeError: unsupported operand type(s) for +: 'function' and 'function'

This breaks down because the functions produced by these operations are not themselves wrapped. In the example above, `f*g` and `g*g` are plain lambdas, and plain functions do not support arithmetic operators.

### Recursive class definitions

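This is easily remedied with a recursive definition: each operator returns a new `FunctionWrapper`, so the result of an operation supports the same operations.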

In [230]:
class FunctionWrapper:
    
    def __init__(self, fn):
        self.fn = fn
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__add__(other(*args, **kwargs)))
    
    def __sub__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__sub__(other(*args, **kwargs)))

    def __mul__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__mul__(other(*args, **kwargs)))
    
    def __div__(self, other): 
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__div__(other(*args, **kwargs)))
    
    # etc ...

In [231]:
@FunctionWrapper
def f(x):
    return 4*x

In [232]:
@FunctionWrapper
def g(x): 
    return x+5

In [234]:
(f*g+g*g)(5)

300

In [235]:
h = f*g+g*g

In [236]:
h(5)

300
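The class above targets Python 2, where `/` calls `__div__`. Python 3 calls `__truediv__` instead. Here is a Python 3 sketch of the same recursive wrapper. The repeated lambda is factored into a private helper `_lift` (our own name), and the `operator` module supplies the arithmetic:

```python
import operator


class FunctionWrapper:
    """Wrap a callable so that arithmetic on wrappers yields new wrappers."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def _lift(self, op, other):
        # Returning a FunctionWrapper (not a bare lambda) is what lets
        # results such as f*g be combined further, e.g. f*g + g*g.
        return FunctionWrapper(
            lambda *args, **kwargs: op(self(*args, **kwargs), other(*args, **kwargs)))

    def __add__(self, other):
        return self._lift(operator.add, other)

    def __sub__(self, other):
        return self._lift(operator.sub, other)

    def __mul__(self, other):
        return self._lift(operator.mul, other)

    def __truediv__(self, other):  # Python 3 name for `/`
        return self._lift(operator.truediv, other)


@FunctionWrapper
def f(x):
    return 4 * x


@FunctionWrapper
def g(x):
    return x + 5


h = f * g + g * g
print(h(5))        # 300
print((f / g)(5))  # 2.0
```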

We can even create constant functions.

In [237]:
2*f(5)*g(5)+g(5)*g(5)+5

505

In [238]:
h = FunctionWrapper(lambda _: 2)*f*g+g*g+FunctionWrapper(lambda _: 5)

In [239]:
h(5)

505

We can support this directly by adding an alternative constructor, `const`:

In [243]:
class FunctionWrapper:
    
    def __init__(self, fn):
        self.fn = fn
        
    @classmethod
    def const(cls, c):
        return cls(lambda *args, **kwargs: c)  # ignore all arguments, whatever the signature
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__add__(other(*args, **kwargs)))
    
    def __sub__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__sub__(other(*args, **kwargs)))

    def __mul__(self, other):
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__mul__(other(*args, **kwargs)))
    
    def __div__(self, other): 
        return FunctionWrapper(lambda *args, **kwargs: self(*args, **kwargs).__div__(other(*args, **kwargs)))
    
    # etc ...

In [244]:
h = FunctionWrapper.const(2)*f*g+g*g+FunctionWrapper.const(5)

In [245]:
h(5)

505
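`FunctionWrapper.const` is needed because, in an expression such as `2*f`, Python first calls `int.__mul__`, which returns `NotImplemented`. It then looks for a reflected `f.__rmul__`, which the wrapper does not define. An alternative design avoids the explicit constructor call: the wrapper lifts non-callable operands to constant functions automatically and defines the reflected operators. This is our own Python 3 sketch, not the notebook's method:

```python
import operator


class FunctionWrapper:
    def __init__(self, fn):
        self.fn = fn

    @classmethod
    def const(cls, c):
        # Accept any arguments so constants combine with functions of any signature.
        return cls(lambda *args, **kwargs: c)

    @classmethod
    def lift(cls, obj):
        # Wrap callables; treat anything else (e.g. a number) as a constant.
        if isinstance(obj, cls):
            return obj
        return cls(obj) if callable(obj) else cls.const(obj)

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def _binary(self, op, other, reflected=False):
        other = self.lift(other)
        left, right = (other, self) if reflected else (self, other)
        return FunctionWrapper(
            lambda *args, **kwargs: op(left(*args, **kwargs), right(*args, **kwargs)))

    def __add__(self, other):
        return self._binary(operator.add, other)

    def __radd__(self, other):
        return self._binary(operator.add, other, reflected=True)

    def __sub__(self, other):
        return self._binary(operator.sub, other)

    def __rsub__(self, other):
        return self._binary(operator.sub, other, reflected=True)

    def __mul__(self, other):
        return self._binary(operator.mul, other)

    def __rmul__(self, other):
        return self._binary(operator.mul, other, reflected=True)


f = FunctionWrapper(lambda x: 4 * x)
g = FunctionWrapper(lambda x: x + 5)

h = 2 * f * g + g * g + 5
print(h(5))         # 505
print((10 - f)(5))  # -10
```

The trade-off is that any non-callable operand is silently treated as a constant. A mistaken operand (say, a string) therefore only fails when the combined function is called.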

This is a lot of repetitive code. We can greatly simplify this with metaclasses.

#### Dynamically obtaining binary operators 

The metaclass approach works because magic methods can be looked up dynamically by name with `getattr`:

In [246]:
6 + 8 == getattr(6, '__add__')(8)

True

In [247]:
6 * 8 == getattr(6, '__mul__')(8)

True
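Calling a magic method directly is not quite the same as using the operator. `a + b` falls back to `b.__radd__(a)` when `a.__add__(b)` returns `NotImplemented`, but a direct `a.__add__(b)` call does not. The standard `operator` module performs the full dispatch and also exposes dunder aliases such as `operator.__add__`. Those aliases can be fetched by name in exactly the same way (shown in Python 3):

```python
import operator

# int.__add__ does not know how to add a float, so it signals NotImplemented ...
print((6).__add__(8.5))                      # NotImplemented
# ... whereas `+` then tries float.__radd__ and succeeds.
print(6 + 8.5)                               # 14.5
# operator.__add__ dispatches like `+` and can still be looked up by name.
print(getattr(operator, '__add__')(6, 8.5))  # 14.5
```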

### Metaclasses to the rescue

In [248]:
class FunctionWrapperMeta(type):
    
    def __new__(cls, name, bases, dct):
        for method in dct['methods']:
            dct[method] = lambda self, other: lambda *args, **kwargs: getattr(self(*args, **kwargs), method)(other(*args, **kwargs))
        return super(FunctionWrapperMeta, cls).__new__(cls, name, bases, dct)

In [249]:
class FunctionWrapper:
    __metaclass__ = FunctionWrapperMeta
    
    methods = ['__add__', '__mul__', '__sub__', '__div__']
    
    def __init__(self, fn):
        self.fn = fn
        
    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

In [251]:
@FunctionWrapper
def f(x):
    return 4*x

In [252]:
@FunctionWrapper
def g(x): 
    return x+5

In [253]:
f(5)

20

In [254]:
g(5)

10

In [255]:
(f+g)(5)

2

In [256]:
(f-g)(5)

2

In [257]:
(f*g)(5)

2

In [258]:
(f/g)(5)

2

In [259]:
f(5)/g(5)

2

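#### GOTCHA: Python's closures are late binding!

Every operator above returns 2, which is `f(5)/g(5)`. Each lambda created in the loop in `FunctionWrapperMeta.__new__` refers to the variable `method` itself, not to the value it had when the lambda was created. The lambda body is only evaluated when an operator is called. By then the loop has finished and `method` is `'__div__'`, the last entry in `methods`. The same effect is easier to see in a smaller example: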

In [262]:
def multipliers(n):
    fns = []
    for m in range(n):
        fns.append(lambda x: m*x)
    return fns

In [263]:
for mul in multipliers(5):
    print mul(3)

12
12
12
12
12
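To see why every multiplier uses `m = 4`, you can inspect the closures (Python 3 syntax). All five lambdas hold a reference to the same cell for `m`. Once `multipliers` returns, that cell contains the loop variable's final value:

```python
def multipliers(n):
    fns = []
    for m in range(n):
        fns.append(lambda x: m * x)
    return fns


fns = multipliers(5)
cells = [fn.__closure__[0] for fn in fns]
# Every lambda shares one cell object ...
print(all(cell is cells[0] for cell in cells))  # True
# ... whose contents are the last value assigned to m.
print(cells[0].cell_contents)                   # 4
```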


In [264]:
multipliers = lambda n: [lambda x: m*x for m in range(n)]

In [265]:
for mul in multipliers(5):
    print mul(3)

12
12
12
12
12


In [268]:
def multipliers(n):
    fns = []
    for m in range(n):
        make_fn = lambda m: lambda x: m*x
        fns.append(make_fn(m))
    return fns

In [269]:
for mul in multipliers(5):
    print mul(3)

0
3
6
9
12


In [270]:
multipliers = lambda n: [(lambda m: lambda x: m*x)(m) for m in range(n)]

In [271]:
for mul in multipliers(5):
    print mul(3)

0
3
6
9
12


In [272]:
def multipliers(n):
    for m in range(n):
        yield lambda x: m*x

In [273]:
for mul in multipliers(5):
    print mul(3)

0
3
6
9
12


In [274]:
multipliers = lambda n: (lambda x: m*x for m in range(n))

In [275]:
for mul in multipliers(5):
    print mul(3)

0
3
6
9
12
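The two generator versions only appear to fix the problem. Each lambda is called before the generator resumes and rebinds `m`, and materialising the generator first brings the bug back. Two common ways to bind the value at creation time are a default argument and `functools.partial`. The sketch below is Python 3, and the function names are our own:

```python
import functools
import operator


def multipliers_gen(n):
    for m in range(n):
        yield lambda x: m * x


# Consumed lazily: each lambda runs while m still has "its" value.
print([mul(3) for mul in multipliers_gen(5)])        # [0, 3, 6, 9, 12]
# Materialised first: the generator has finished, so m == 4 for every lambda.
print([mul(3) for mul in list(multipliers_gen(5))])  # [12, 12, 12, 12, 12]


def multipliers_default(n):
    # Default values are evaluated when each lambda is created.
    return [lambda x, m=m: m * x for m in range(n)]


def multipliers_partial(n):
    # partial stores the current value of m as a bound first argument.
    return [functools.partial(operator.mul, m) for m in range(n)]


print([mul(3) for mul in multipliers_default(5)])  # [0, 3, 6, 9, 12]
print([mul(3) for mul in multipliers_partial(5)])  # [0, 3, 6, 9, 12]
```

The default-argument idiom is the shortest, but it adds an optional parameter `m` that a caller could accidentally override. `partial`, like the factory functions above, leaves the resulting signature clean.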


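Passing `method` as an argument to a factory binds its current value, just like `make_fn` above. We also wrap each result in `FunctionWrapper`, so that composites such as `f*g+g*g` keep supporting arithmetic. Finally, we use the `operator` module's dunder aliases (e.g. `operator.__add__`), which, unlike a direct call to a magic method, fall back to reflected methods such as `__radd__`.

In [423]:
import operator
make_method = lambda method: lambda self, other: FunctionWrapper(lambda *args, **kwargs: getattr(operator, method)(self(*args, **kwargs), other(*args, **kwargs)))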

In [424]:
class FunctionWrapperMeta(type):
    
    def __new__(cls, name, bases, dct):
        for method in dct['methods']:
            dct[method] = make_method(method)
        return super(FunctionWrapperMeta, cls).__new__(cls, name, bases, dct)

In [425]:
class FunctionWrapper:
    __metaclass__ = FunctionWrapperMeta
    
    methods = ['__add__', '__mul__', '__sub__', '__div__']
    
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

In [426]:
@FunctionWrapper
def f(x):
    return 4*x

In [427]:
@FunctionWrapper
def g(x): 
    return x+5

In [428]:
f(5)

20

In [429]:
g(5)

10

In [430]:
(f+g)(5)

30

In [431]:
(f-g)(5)

10

In [432]:
(f*g)(5)

200

In [433]:
(f/g)(5)

2
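Python 3 ignores a `__metaclass__` class attribute, because the metaclass is passed as a keyword in the class statement instead. It also maps `/` to `__truediv__`. Below is a Python 3 sketch of the fixed metaclass. It keeps the notebook's `FunctionWrapperMeta`/`make_method` structure, but is written with `def` and assumes the class's constructor takes a single callable, since results are rebuilt with `type(self)`:

```python
import operator


def make_method(method):
    # `method` is a parameter here, so each generated operator gets its own value.
    op = getattr(operator, method)

    def binary(self, other):
        # Assumes type(self) can be constructed from a single callable.
        wrapper = type(self)
        return wrapper(lambda *args, **kwargs: op(self(*args, **kwargs), other(*args, **kwargs)))

    binary.__name__ = method
    return binary


class FunctionWrapperMeta(type):
    def __new__(mcls, name, bases, dct):
        for method in dct.get('methods', ()):
            dct[method] = make_method(method)
        return super().__new__(mcls, name, bases, dct)


class FunctionWrapper(metaclass=FunctionWrapperMeta):
    methods = ['__add__', '__mul__', '__sub__', '__truediv__']

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


@FunctionWrapper
def f(x):
    return 4 * x


@FunctionWrapper
def g(x):
    return x + 5


print((f + g)(5), (f - g)(5), (f * g)(5), (f / g)(5))  # 30 10 200 2.0
print((f * g + g * g)(5))                              # 300
```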

### Beyond scalar functions: vector/matrix-valued functions

In [434]:
import numpy as np
from scipy.spatial.distance import cdist

In [435]:
@FunctionWrapper
def polynomial_basis(X, deg=2):
    return X**np.arange(deg+1)

In [436]:
X = np.random.randint(10, size=(10, 1)); X

array([[1],
       [9],
       [2],
       [6],
       [8],
       [5],
       [4],
       [1],
       [1],
       [8]])

In [437]:
polynomial_basis(X)

array([[ 1,  1,  1],
       [ 1,  9, 81],
       [ 1,  2,  4],
       [ 1,  6, 36],
       [ 1,  8, 64],
       [ 1,  5, 25],
       [ 1,  4, 16],
       [ 1,  1,  1],
       [ 1,  1,  1],
       [ 1,  8, 64]])

In [438]:
(polynomial_basis+polynomial_basis)(X)

array([[  2,   2,   2],
       [  2,  18, 162],
       [  2,   4,   8],
       [  2,  12,  72],
       [  2,  16, 128],
       [  2,  10,  50],
       [  2,   8,  32],
       [  2,   2,   2],
       [  2,   2,   2],
       [  2,  16, 128]])

#### Example: Covariance functions

In [439]:
@FunctionWrapper
def sqr_exp(x_p, x_q=None, len_scale=1.):

    if x_q is None:
        return np.ones(len(x_p))  # diagonal k(x, x) = 1: one entry per input row
    r_sq = cdist(x_p, x_q, 'sqeuclidean')
    return np.exp(-0.5*r_sq/len_scale**2)

In [440]:
@FunctionWrapper
def matern32(x_p, x_q=None, len_scale=1.):

    if x_q is None:
        return np.ones(len(x_p))  # diagonal k(x, x) = 1: one entry per input row
    r = cdist(x_p, x_q, 'euclidean')
    s = np.sqrt(3)*r/len_scale
    return (1+s)*np.exp(-s)

In [441]:
x_p = np.array([[ 0.25],
                [ 0.75]])

In [442]:
x_q = np.array([[ 0.60],
                [ 0.12],
                [ 0.75]])

In [443]:
sqr_exp(x_p, x_q, 0.5).round(2)

array([[ 0.78,  0.97,  0.61],
       [ 0.96,  0.45,  1.  ]])

In [444]:
matern32(x_p, x_q, 0.5).round(2)

array([[ 0.66,  0.92,  0.48],
       [ 0.9 ,  0.36,  1.  ]])

In [445]:
(sqr_exp(x_p, x_q, 0.5)+matern32(x_p, x_q, 0.5)).round(2)

array([[ 1.44,  1.89,  1.09],
       [ 1.86,  0.81,  2.  ]])

In [446]:
(sqr_exp + matern32)(x_p, x_q, 0.5).round(2)

array([[ 1.44,  1.89,  1.09],
       [ 1.86,  0.81,  2.  ]])
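Sums and elementwise products of covariance functions are again valid covariance functions, so a combination like `sqr_exp + matern32` still yields positive semi-definite covariance matrices. The following is a quick numerical check, offered as a sketch. It re-implements the two kernels with NumPy only, computing pairwise distances by broadcasting instead of `cdist`, and inspects the smallest eigenvalue of each combined matrix:

```python
import numpy as np


def sqdist(x_p, x_q):
    # Pairwise squared Euclidean distances between rows of x_p and x_q.
    return ((x_p[:, None, :] - x_q[None, :, :]) ** 2).sum(axis=-1)


def sqr_exp(x_p, x_q, len_scale=1.0):
    return np.exp(-0.5 * sqdist(x_p, x_q) / len_scale**2)


def matern32(x_p, x_q, len_scale=1.0):
    s = np.sqrt(3) * np.sqrt(sqdist(x_p, x_q)) / len_scale
    return (1 + s) * np.exp(-s)


x = np.linspace(0, 1, 20).reshape(-1, 1)
K_sum = sqr_exp(x, x, 0.5) + matern32(x, x, 0.5)
K_prod = sqr_exp(x, x, 0.5) * matern32(x, x, 0.5)

for K in (K_sum, K_prod):
    # eigvalsh assumes a symmetric matrix; allow for round-off near zero.
    print(np.linalg.eigvalsh(K).min() > -1e-10)  # True
```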

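#### Problem: Functions with heterogeneous parameters/arguments

Combining functions pointwise assumes that they all accept the same arguments. Basis functions typically differ in their hyperparameters (a polynomial degree versus RBF centres and a width). One approach is to move the hyperparameters into the constructor of a callable class and give every `__call__` the same signature, taking only `X`: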

In [398]:
class PolynomialBasis:

    def __init__(self, degree=2., include_bias=True):
        self.include_bias = include_bias
        self.degree = degree

    def __call__(self, X):
        powers = np.arange(self.degree+1)
        if not self.include_bias:
            powers = powers[1:]
        return X**powers

In [399]:
class RadialBasis:

    def __init__(self, mu=0., s=1.):
        self.mu = np.atleast_2d(mu)  # cdist needs 2-D centres; also makes the scalar default usable
        self.s = s

    def __call__(self, X):
        return np.exp(-cdist(X, self.mu, 'sqeuclidean')/(2*self.s**2))

In [447]:
f = FunctionWrapper(PolynomialBasis())

In [448]:
g = FunctionWrapper(RadialBasis(mu=np.arange(4, 7).reshape(-1, 1)))

In [449]:
f(X)

array([[  1.,   1.,   1.],
       [  1.,   9.,  81.],
       [  1.,   2.,   4.],
       [  1.,   6.,  36.],
       [  1.,   8.,  64.],
       [  1.,   5.,  25.],
       [  1.,   4.,  16.],
       [  1.,   1.,   1.],
       [  1.,   1.,   1.],
       [  1.,   8.,  64.]])

In [450]:
g(X).round(2)

array([[ 0.01,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.01],
       [ 0.14,  0.01,  0.  ],
       [ 0.14,  0.61,  1.  ],
       [ 0.  ,  0.01,  0.14],
       [ 0.61,  1.  ,  0.61],
       [ 1.  ,  0.61,  0.14],
       [ 0.01,  0.  ,  0.  ],
       [ 0.01,  0.  ,  0.  ],
       [ 0.  ,  0.01,  0.14]])

In [451]:
(f(X)+g(X)).round(2)

array([[  1.01,   1.  ,   1.  ],
       [  1.  ,   9.  ,  81.01],
       [  1.14,   2.01,   4.  ],
       [  1.14,   6.61,  37.  ],
       [  1.  ,   8.01,  64.14],
       [  1.61,   6.  ,  25.61],
       [  2.  ,   4.61,  16.14],
       [  1.01,   1.  ,   1.  ],
       [  1.01,   1.  ,   1.  ],
       [  1.  ,   8.01,  64.14]])

In [452]:
h = f+g
h(X).round(2)

array([[  1.01,   1.  ,   1.  ],
       [  1.  ,   9.  ,  81.01],
       [  1.14,   2.01,   4.  ],
       [  1.14,   6.61,  37.  ],
       [  1.  ,   8.01,  64.14],
       [  1.61,   6.  ,  25.61],
       [  2.  ,   4.61,  16.14],
       [  1.01,   1.  ,   1.  ],
       [  1.01,   1.  ,   1.  ],
       [  1.  ,   8.01,  64.14]])
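For plain functions, `functools.partial` is a lighter alternative to callable classes. You fix each function's hyperparameters first, so that every component takes only `X`, and then wrap the result. The sketch below uses Python 3 and NumPy only. `Fn` is a minimal stand-in for `FunctionWrapper` that supports just `+` and `*`, and `polynomial_basis` and `radial_basis` are illustrative re-implementations of the classes above:

```python
import functools
import operator

import numpy as np


class Fn:
    """Minimal stand-in for FunctionWrapper (only + and *)."""

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)

    def __add__(self, other):
        return Fn(lambda *a, **k: operator.add(self(*a, **k), other(*a, **k)))

    def __mul__(self, other):
        return Fn(lambda *a, **k: operator.mul(self(*a, **k), other(*a, **k)))


def polynomial_basis(X, degree=2, include_bias=True):
    powers = np.arange(degree + 1)
    if not include_bias:
        powers = powers[1:]
    return X**powers


def radial_basis(X, mu=np.zeros((1, 1)), s=1.0):
    # Squared distances from each row of X to each centre (row of mu).
    sq = ((X[:, None, :] - mu[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-sq / (2 * s**2))


# Fix the hyperparameters now; both results take just X.
f = Fn(functools.partial(polynomial_basis, degree=2))
g = Fn(functools.partial(radial_basis, mu=np.arange(4, 7).reshape(-1, 1)))

X = np.array([[1], [5], [6]])
print((f + g)(X).round(2))
```

The three printed rows match the rows for x = 1, 5 and 6 of `(f(X)+g(X))` above.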

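Alternatively, the metaclass can be applied to the callable classes themselves, so that their instances support arithmetic without being wrapped in `FunctionWrapper`:

In [410]: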
class PolynomialBasis:

    __metaclass__ = FunctionWrapperMeta

    methods = ['__add__', '__mul__', '__sub__', '__div__']

    def __init__(self, degree=2., include_bias=True):
        self.include_bias = include_bias
        self.degree = degree

    def __call__(self, X):
        powers = np.arange(self.degree+1)
        if not self.include_bias:
            powers = powers[1:]
        return X**powers

In [411]:
class RadialBasis:

    __metaclass__ = FunctionWrapperMeta

    methods = ['__add__', '__mul__', '__sub__', '__div__']
 
    def __init__(self, mu=0., s=1.):
        self.mu = np.atleast_2d(mu)  # cdist needs 2-D centres; also makes the scalar default usable
        self.s = s
        
    def __call__(self, X):
        return np.exp(-cdist(X, self.mu, 'sqeuclidean')/(2*self.s**2))

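The next cells were run before the cell that generated `X` above (compare the execution counts), so their outputs correspond to a different random draw of `X`.

In [392]: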
f = PolynomialBasis()

In [393]:
g = RadialBasis(mu=np.arange(4, 7).reshape(-1, 1))

In [394]:
f(X)

array([[  1.,   3.,   9.],
       [  1.,   5.,  25.],
       [  1.,   9.,  81.],
       [  1.,   2.,   4.],
       [  1.,   7.,  49.],
       [  1.,   4.,  16.],
       [  1.,   2.,   4.],
       [  1.,   9.,  81.],
       [  1.,   6.,  36.],
       [  1.,   9.,  81.]])

In [395]:
g(X).round(2)

array([[ 0.61,  0.14,  0.01],
       [ 0.61,  1.  ,  0.61],
       [ 0.  ,  0.  ,  0.01],
       [ 0.14,  0.01,  0.  ],
       [ 0.01,  0.14,  0.61],
       [ 1.  ,  0.61,  0.14],
       [ 0.14,  0.01,  0.  ],
       [ 0.  ,  0.  ,  0.01],
       [ 0.14,  0.61,  1.  ],
       [ 0.  ,  0.  ,  0.01]])

In [396]:
(f(X)+g(X)).round(2)

array([[  1.61,   3.14,   9.01],
       [  1.61,   6.  ,  25.61],
       [  1.  ,   9.  ,  81.01],
       [  1.14,   2.01,   4.  ],
       [  1.01,   7.14,  49.61],
       [  2.  ,   4.61,  16.14],
       [  1.14,   2.01,   4.  ],
       [  1.  ,   9.  ,  81.01],
       [  1.14,   6.61,  37.  ],
       [  1.  ,   9.  ,  81.01]])

In [397]:
h = f+g
h(X).round(2)

array([[  1.61,   3.14,   9.01],
       [  1.61,   6.  ,  25.61],
       [  1.  ,   9.  ,  81.01],
       [  1.14,   2.01,   4.  ],
       [  1.01,   7.14,  49.61],
       [  2.  ,   4.61,  16.14],
       [  1.14,   2.01,   4.  ],
       [  1.  ,   9.  ,  81.01],
       [  1.14,   6.61,  37.  ],
       [  1.  ,   9.  ,  81.01]])
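Because the metaclass only installs methods, an ordinary base class can do the same job: any callable class that inherits from it gets the operators. The following Python 3 sketch illustrates this. The `ArithmeticMixin` name and the `setattr` loop are our own, and composites are returned as `FunctionWrapper` instances so that they can be chained:

```python
import operator

import numpy as np


class ArithmeticMixin:
    """Gives any callable subclass pointwise +, -, *, / (results are FunctionWrapper)."""


def _make_method(op):
    # op is bound as a parameter, avoiding the late-binding gotcha.
    def method(self, other):
        return FunctionWrapper(
            lambda *args, **kwargs: op(self(*args, **kwargs), other(*args, **kwargs)))
    return method


for _name in ('add', 'sub', 'mul', 'truediv'):
    setattr(ArithmeticMixin, f'__{_name}__', _make_method(getattr(operator, _name)))


class FunctionWrapper(ArithmeticMixin):
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, **kwargs):
        return self.fn(*args, **kwargs)


class PolynomialBasis(ArithmeticMixin):
    def __init__(self, degree=2, include_bias=True):
        self.degree = degree
        self.include_bias = include_bias

    def __call__(self, X):
        powers = np.arange(self.degree + 1)
        if not self.include_bias:
            powers = powers[1:]
        return X**powers


class RadialBasis(ArithmeticMixin):
    def __init__(self, mu=0.0, s=1.0):
        self.mu = np.atleast_2d(mu)
        self.s = s

    def __call__(self, X):
        sq = ((X[:, None, :] - self.mu[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-sq / (2 * self.s**2))


f = PolynomialBasis()
g = RadialBasis(mu=np.arange(4, 7).reshape(-1, 1))
h = f + g * f  # g * f is a FunctionWrapper, so it can be combined again
X = np.array([[5]])
print(h(X).round(2))
```

Compared with the metaclass, the mixin needs no `methods` attribute in each class and avoids metaclass conflicts when combining base classes. The cost is that the list of operators is fixed in one place rather than chosen per class.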