# DIY AD

Minimalistic implementation of AD in python.
The implementation is supposed to be minimal both in terms of lines-of-code and the concepts required for its implementation.
Namely, we will use exactly one core concept: linearizations.

In [None]:
import numpy as np
import jax

A linearization can be thought of as the implicit Jacobian matrix of a function (yes, the name "linearization" is misleading as it might imply having an offset that it does not have here).
We need our implicit Jacobian matrix to have four attributes: a forward pass (Jacobian-Vector-Product), a backwards pass (Vector-Jacobian-Product), a reference to the parameters with respect to which we are going to differentiate, and the value of the linearized function.
The last attribute is required in order to amend the Jacobian.

The forward pass of the Jacobian reapplies all linearized operators of our function in the order in which they were originally called.
A simple recursion into the individual Jacobians of the operators is sufficient if each of them applies the chain rule correctly.
The backwards pass is slightly trickier as we need to watch out for accumulating gradients.
For the purpose of tracking accumulating gradients, our backwards pass will share a `tape` onto which the gradient for each input to a function is accumulated.

In [None]:
class Linearization():
    """Implicit Jacobian of a function, linearized at a fixed point.

    An instance bundles four things: the value(s) `p` of the linearized
    function, a forward pass `_fwd` (Jacobian-vector product), a backward
    pass `_bwd` (vector-Jacobian product) and `_wrt`, a reference to the root
    linearization of the parameters we differentiate with respect to.

    Calling an instance runs `_fwd`; `.T` returns the transposed
    linearization, whose call runs `_bwd` instead.
    """

    def __init__(self, p, fwd=None, bwd=None, _wrt=None):
        """Create a linearization.

        Parameters
        ----------
        p
            Value of the linearized function. Anything that is neither a
            tuple nor `None` is wrapped into a 1-tuple. `None` marks a
            transposed linearization (this is what `T` passes).
        fwd : callable, optional
            Called as `fwd(*t, tape=tape)`. Defaults to the identity.
        bwd : callable, optional
            Called as `bwd(*t, tape=tape)`. Defaults to the identity.
        _wrt : Linearization, optional
            Root linearization of the parameters. Defaults to `self`, i.e. a
            freshly created linearization is its own root.
        """
        self.p = (p, ) if not (isinstance(p, tuple) or p is None) else p
        if fwd is None:
            # Identity forward pass: a single tangent is returned unwrapped
            def fwd(*t, tape): return t[0] if len(self.p) == 1 else t
        if bwd is None:
            # Identity backward pass: cotangents are handed back as a tuple
            def bwd(*t, tape): return t
        self._fwd, self._bwd = fwd, bwd
        # Reference initial parameters w.r.t. which we want to differentiate
        self._wrt = self if _wrt is None else _wrt

    def __call__(self, *t, tape=None):
        """Apply the pass stored in `_fwd` to `t`.

        With a `tape` the call is part of a larger pass and simply forwards
        to `_fwd`. Without one it is the outermost call: a fresh tape is
        created and, for a transposed linearization (`p is None`), the
        accumulated gradients are returned as a tuple.
        """
        if tape is not None:
            return self._fwd(*t, tape=tape)
        # Outermost call of Linearization (i.e. without a tape yet)
        # `tape` serves as shared list for gradient accumulation; it gets one
        # zero-initialized slot per argument in `t`
        tape = [0.] * len(t)
        o = self._fwd(*t, tape=tape)
        if self.p is not None:
            return o
        # Transposed call. Chained backward passes return `None` and leave
        # their result on the tape. A pass that is not chained (e.g. the
        # identity of a root linearization) returns its cotangents directly
        # and never writes to the tape, so returning the tape here would
        # silently yield zeros.
        return tuple(tape) if o is None else o

    @property
    def T(self):
        """Transposed linearization: forward and backward pass swapped.

        The result carries `p=None`, which `__call__` uses to tell a
        transposed linearization apart from a regular one.
        """
        return self.__class__(None, self._bwd, self._fwd, _wrt=self._wrt)

    @classmethod
    def chain(cls, lins, p, fwd, bwd):
        """Chain the linearized operator `(fwd, bwd)` after `lins`.

        Parameters
        ----------
        lins : Linearization or sequence of Linearization
            Linearizations of the inputs of the operator. All of them must
            share the same root (`_wrt`).
        p
            Value of the operator at the linearization point.
        fwd : callable
            `fwd(*tangents, tape=tape)`, one tangent per entry of `lins`.
        bwd : callable
            `bwd(*cotangents, tape=tape)`. Must return one tuple of
            cotangents per entry of `lins`; for a single input that tuple is
            returned directly, for several inputs a tuple of such tuples.

        Returns
        -------
        Linearization
            The chained linearization, sharing the root of `lins`.
        """
        lins = (lins, ) if isinstance(lins, Linearization) else tuple(lins)
        wrt = lins[0]._wrt

        def chained_fwd(*t, tape=None):
            # NOTE, `t` is the full input of our overall model. All
            # linearizations that come after must by definition work on all of
            # `t`. Thus, we can skip bookkeeping which input gets where.
            return fwd(*(lin(*t, tape=tape) for lin in lins), tape=tape)

        def chained_bwd(*t, tape=None):
            bs = bwd(*t, tape=tape)
            # Single-input operators return their cotangent tuple unwrapped
            bs = (bs, ) if len(lins) == 1 else bs
            assert isinstance(bs, tuple) and len(bs) == len(lins)
            for lin, b in zip(lins, bs):
                o = lin.T(*b, tape=tape)
                assert isinstance(o, tuple) or o is None
                # Only the root hands gradients back; everything in between
                # returns `None` and has already written to the tape.
                if lin is wrt:
                    assert len(o) == len(tape)
                    for i, ne in enumerate(o):
                        tape[i] += ne
            return None  # all information is stored on tape

        assert all(lin._wrt is wrt for lin in lins)
        return cls(p, chained_fwd, chained_bwd, _wrt=wrt)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.p}, {self._fwd}, {self._bwd})"

Let's start by defining two simple functions (and their linearizations): `exp` and `add`.
First, let's have a look at the unary function `exp`.
Its Jacobian is a diagonal matrix as it is a point-wise operator.
In terms of code, the forward and reverse of the linearization are thus simple multiplications.
The diagonal operator, i.e. the multiplication, is amended once to the left of the Jacobian for the forward pass and once to the right for the backwards pass.

Notice how in the implementation below the computation of the diagonal is hoisted out of the forward and reverse mode using a closure.
This specific choice of rematerialization strategy can be easily configured by writing custom rules for functions.
A custom rule is nothing more than a simple `isinstance` check for a `Linearization` in a function.

The binary operator `add` is slightly more elaborate in that it takes two arguments and has to handle all permutations of linearizations versus no linearizations as input.
The `isinstance` checks thus have to distinguish four cases but conceptually work the same as in `exp`.

In [None]:
def exp(pl):
    """Point-wise exponential of a plain value or of a `Linearization`.

    Plain inputs are passed to `np.exp`. For a `Linearization` the result is
    a new `Linearization` chained after `pl`, whose forward and backward
    passes both multiply by the exponential of the value held by `pl`.
    """
    if not isinstance(pl, Linearization):
        return np.exp(pl)

    # The value doubles as the diagonal of the Jacobian; it is computed once
    # and shared by both passes through the closure.
    value = exp(*pl.p)

    def jvp(tangent, tape):
        return value * tangent

    def vjp(*cotangents, tape):
        assert len(cotangents) == 1
        (cotangent, ) = cotangents
        return (value * cotangent, )

    return Linearization.chain(pl, value, jvp, vjp)


def add(pl_l, pl_r):
    """Add two operands, each either a plain value or a `Linearization`.

    - Neither is a `Linearization`: returns `pl_l + pl_r`.
    - Both are: the tangents are summed in the forward pass and the
      cotangent is handed to both inputs in the backward pass.
    - Exactly one is: the other operand is a constant, so both passes are
      the identity and only the value is shifted.
    """
    left_is_lin = isinstance(pl_l, Linearization)
    right_is_lin = isinstance(pl_r, Linearization)

    if not (left_is_lin or right_is_lin):
        return pl_l + pl_r

    if left_is_lin and right_is_lin:
        def sum_tangents(t0, t1, tape):
            return t0 + t1

        def fan_out(*t, tape):
            assert len(t) == 1
            # One cotangent tuple per input linearization
            return (t, t)

        assert len(pl_l.p) == len(pl_r.p) == 1
        value = pl_l.p[0] + pl_r.p[0]
        return Linearization.chain((pl_l, pl_r), value, sum_tangents, fan_out)

    # Exactly one operand is linearized; the other is a constant
    lin, const = (pl_l, pl_r) if left_is_lin else (pl_r, pl_l)

    def pass_tangent(*t, tape):
        assert len(t) == 1
        return t[0]

    def pass_cotangent(*t, tape):
        assert len(t) == 1
        return t

    assert len(lin.p) == 1
    return Linearization.chain(lin, lin.p[0] + const, pass_tangent, pass_cotangent)


def cost(p):
    """Evaluate `exp(exp(p) - 25) + exp(p)` using the `exp`/`add` primitives."""
    exp_p = exp(p)
    shifted = add(exp_p, -25)
    return add(exp(shifted), exp_p)


# Scalar check of `cost`: linearize at `p0` and apply the Jacobian to `t`
p0, t = 3.14, 2.
p0lin = Linearization(p0)
j = cost(p0lin)
# Forward pass, backward pass, and JAX's JVP of the same expression for comparison
j(t), j.T(t), jax.jvp(lambda p: jax.numpy.exp(jax.numpy.exp(p) - 25) + jax.numpy.exp(p), (p0, ), (t, ))[1].item()

In [None]:
# Array-valued check: linearize exp(exp(.)) at a vector of nine points
p0 = 1e-2 * np.arange(0, 9)
ones = np.ones((9, ))
# Plain (non-linearized) evaluation; `y` is not used again within this cell
y = exp(p0)
j = exp(exp(Linearization(p0)))
# Forward pass applied to `p0` and backward pass applied to a vector of ones
j(p0), j.T(ones)

In [None]:
# JAX reference for exp(exp(.)); uses `p0` and `ones` as defined in the preceding cell
_, jj_T = jax.vjp(lambda x: jax.numpy.exp(jax.numpy.exp(x)), p0.astype(float))
_, jj_at_p0 = jax.jvp(lambda x: jax.numpy.exp(jax.numpy.exp(x)), (p0.astype(float), ), (p0, ))
# JVP with tangent `p0` and VJP with cotangent `ones`
jj_at_p0, jj_T(ones)

How best to rematerialize in the Jacobian is a difficult question.
The approach taken here is a flexible albeit lazy one: We let the user do it explicitly when defining the Jacobian.
However, the design of building forward and backwards pass in general incentivizes sharing memory between both passes.
Furthermore, it is trivial to hoist out constants via closures.

In [None]:
def weighted_reduction(pl, n=32, n_cols=3):
    """Reduce the rows of `pl` (reshaped to `(n_cols, -1)`) into `n` outputs.

    Output `i` is the sum over row `i % n_cols` of the reshaped input,
    multiplied element-wise by a weight array of ones.

    Parameters
    ----------
    pl : array or Linearization
        Plain array, or a `Linearization` holding exactly one array.
    n : int
        Number of outputs.
    n_cols : int
        Leading dimension the input is reshaped to.

    Returns
    -------
    array of shape `(n, )` for plain input, otherwise a `Linearization`
    chained after `pl`. The linearized branch uses the same `n` and `n_cols`
    for the value, the forward pass and the backward pass.
    """
    if isinstance(pl, Linearization):
        def fwd(*t, tape):
            # The reduction is linear, so its forward pass is the function
            # itself, evaluated with the caller's `n` and `n_cols`.
            return weighted_reduction(*t, n=n, n_cols=n_cols)

        def bwd(*t, tape):
            assert len(t) == 1 and len(pl.p) == 1
            p_shp = pl.p[0].reshape(n_cols, -1).shape
            t_T, indices = np.zeros(p_shp), np.arange(n) % p_shp[0]
            for i, idx in enumerate(indices):
                # Recomputed on every iteration, mirroring the plain branch
                super_expensive_weights = np.ones(p_shp[1:])
                # Transpose of the forward sum: scatter cotangent `i` back
                # onto the row it was reduced from
                t_T[idx] += t[0][i] * super_expensive_weights
            return (t_T.reshape(pl.p[0].shape), )

        value = weighted_reduction(*pl.p, n=n, n_cols=n_cols)
        return Linearization.chain(pl, value, fwd, bwd)
    p = pl.reshape(n_cols, -1)
    y, indices = np.zeros((n, )), np.arange(n) % p.shape[0]
    for i, idx in enumerate(indices):
        super_expensive_weights = np.ones(p.shape[1:])
        y[i] = np.sum(p[idx] * super_expensive_weights)
    return y


# Try `weighted_reduction` on a 12-element vector, both plain and linearized
p0lin = Linearization(np.arange(12, dtype=float))
y2 = weighted_reduction(*p0lin.p)
j = weighted_reduction(p0lin)
print(y2)
# Forward pass at the linearization point and backward pass applied to `y2`
j(*p0lin.p), j.T(y2, )

In [None]:
def f(p):
    """Apply `exp` to `p` and feed the result to `weighted_reduction`."""
    exp_p = exp(p)
    return weighted_reduction(exp_p)

# Compose two primitives: evaluate `f` plainly to obtain the output shape ...
p0 = 1e-2 * np.arange(0, 9)
f0 = f(p0)
f_ones = np.ones(f0.shape)

# ... then linearize `f` at `p0` and run the forward and backward pass
j = f(Linearization(p0))
j(p0), j.T(f_ones)

In [None]:
def sum(pl):
    """Sum all entries of a plain array or of a `Linearization`.

    NOTE: this intentionally keeps the notebook's name and therefore shadows
    the builtin `sum` from here on.

    Plain inputs go through `np.sum`. For a `Linearization` the forward pass
    sums the tangent and the backward pass multiplies the cotangent with an
    array of ones shaped like the value held by `pl`.
    """
    if not isinstance(pl, Linearization):
        return np.sum(pl)

    def jvp(*tangents, tape):
        # Summation is linear: the forward pass is the function itself
        return sum(*tangents)

    def vjp(*cotangents, tape):
        assert len(cotangents) == 1 and len(pl.p) == 1
        spread = np.ones(pl.p[0].shape)
        return (cotangents[0] * spread, )

    total = sum(*pl.p)
    return Linearization.chain(pl, total, jvp, vjp)

# Linearize `sum` at the vector 0..8; note that `p0` is a Linearization in this cell
p0 = Linearization(np.arange(0, 9, dtype=float))
j = sum(p0)
# Forward pass at the linearization point and backward pass with cotangent 1
j(*p0.p), j.T(1.)

In [None]:
def pow(pl, exponent):
    """Raise `pl` to a constant `exponent`.

    NOTE: this intentionally keeps the notebook's name and therefore shadows
    the builtin `pow` from here on.

    Plain inputs evaluate `pl**exponent`. For a `Linearization` base, both
    passes multiply by `exponent * base**(exponent - 1)`, which is computed
    once and shared through the closure.

    Raises
    ------
    NotImplementedError
        If `exponent` is a `Linearization`.
    """
    if isinstance(exponent, Linearization):
        raise NotImplementedError()
    if not isinstance(pl, Linearization):
        return pl**exponent

    slope = exponent * pow(*pl.p, exponent - 1)

    def jvp(*tangents, tape):
        assert len(tangents) == 1
        return slope * tangents[0]

    def vjp(*cotangents, tape):
        assert len(cotangents) == 1
        return (slope * cotangents[0], )

    value = pow(*pl.p, exponent)
    return Linearization.chain(pl, value, jvp, vjp)


# Check `pow` on x**3 at x = 3: plain evaluation, then the linearized version
y = pow(3., 3)
j = pow(Linearization(3.), 3)
# Backward pass with cotangent 4, next to JAX's VJP of the same function
j.T(4.), jax.vjp(lambda x: x**3, 3.)[1](4.)

In [None]:
# Module-level constant read by `h` below (it enters as `-data`)
data = 0.

def h(p):
    """Sum of squares of `weighted_reduction(exp(p)) - data` plus sum of squares of `p`.

    `data` is the module-level global, looked up when `h` is called.
    """
    residual = add(-data, weighted_reduction(exp(p)))
    residual_term = sum(pow(residual, 2))
    parameter_term = sum(pow(p, 2))
    return add(residual_term, parameter_term)

# Linearize the full composition `h` at the vector 0..8
p0 = Linearization(np.arange(0, 9, dtype=float))
j = h(p0)
# Forward pass at the linearization point and backward pass with cotangent 1
j(*p0.p), j.T(1.)