In [None]:
import numpy as np
from collections import defaultdict, namedtuple
import jax

In [None]:
def identity(*t, tape):
    """Identity linear map: return the tangent/cotangent tuple ``t`` unchanged.

    ``tape`` is accepted (and ignored) so the signature matches every other
    forward/backward function used by :class:`Linearization`.
    """
    return t


# Backward-compatible alias for the original (misspelled) name.
identiy = identity


class Linearization:
    """A linear map together with its transpose, attached to a primal point.

    Parameters
    ----------
    p : primal value the map was linearized at (``None`` for maps built by ``.T``).
    fwd : callable ``(*t, tape) -> tuple`` applying the linear map.
    bwd : callable ``(*t, tape) -> tuple`` applying the transposed map; it records
        gradient contributions in ``tape``.
    _wrt : root Linearization whose accumulated gradient is read back from the
        tape by an outermost transposed call (defaults to ``self``).
    """

    def __init__(self, p, fwd=identity, bwd=identity, _wrt=None):
        self.p = p
        self._fwd, self._bwd = fwd, bwd
        # Reference initial parameters w.r.t. which we want to differentiate
        self._wrt = self if _wrt is None else _wrt

    def __call__(self, *t, tape=None):
        """Apply the map to ``t``.

        Nested calls (``tape`` given) delegate straight to ``fwd``.  The
        outermost call creates a fresh tape: if nothing gets written to it, the
        single output is returned unwrapped; otherwise the gradient accumulated
        for ``_wrt`` is returned as a 1-tuple.
        """
        if tape is not None:
            return self._fwd(*t, tape=tape)
        # Outermost call: `tape` serves as shared dict for gradient accumulation.
        tape = defaultdict(lambda: 0.)
        out = self._fwd(*t, tape=tape)
        if len(tape) == 0:
            assert len(out) == 1
            return out[0]
        return (tape[self._wrt], )

    @property
    def T(self):
        """Transposed map: forward and backward functions swapped, same ``_wrt``."""
        return self.__class__(None, self._bwd, self._fwd, _wrt=self._wrt)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.p}, {self._fwd}, {self._bwd})"

In [None]:
def exp(pl):
    """Elementwise exponential that also propagates Linearizations.

    For a plain value this is ``np.exp(pl)``.  For a :class:`Linearization` it
    returns a new Linearization at ``exp(pl.p)``.  Its forward map scales the
    incoming tangent by ``exp(pl.p)``.  Its transpose scales the cotangent the
    same way before pushing it through ``pl.T``.
    """
    if not isinstance(pl, Linearization):
        return np.exp(pl)

    # d/dx exp(x) = exp(x): the primal output doubles as the Jacobian diagonal.
    scale = exp(pl.p)

    def fwd(*tangents, tape):
        inner = pl(*tangents, tape=tape)
        assert len(inner) == 1
        return (scale * inner[0], )

    def bwd(*cotangents, tape):
        assert len(cotangents) == 1
        pulled = pl.T(scale * cotangents[0], tape=tape)
        assert len(pulled) == 1
        # Accumulate (never overwrite): `pl` may feed several operations.
        tape[pl] += pulled[0]
        return pulled

    return Linearization(scale, fwd, bwd, _wrt=pl._wrt)


def add(pl_l, pl_r):
    """Add two operands, either of which may be a :class:`Linearization`.

    * Neither is linear: plain ``pl_l + pl_r``.
    * One is linear: the other operand is a constant, so the linear map is that
      of the linear operand, at the shifted primal point.
    * Both are linear: the forward map sums the tangents of both branches
      (each branch receives the same input tangent).  The transpose sends the
      cotangent into both branches and returns the sum of what they return, as
      a 1-tuple like every other backward map.

    Both operands of a linear + linear sum must share the same ``_wrt`` root.
    """
    l_lin = isinstance(pl_l, Linearization)
    r_lin = isinstance(pl_r, Linearization)

    if l_lin and r_lin:
        def fwd(*t, tape):
            # Both branches must see the *same* incoming tangent `t`.
            t_l = pl_l(*t, tape=tape)
            assert len(t_l) == 1
            t_r = pl_r(*t, tape=tape)
            assert len(t_r) == 1
            # Out-of-place sum: a branch output may alias the caller's tangent
            # array (e.g. the identity root), so `+=` would mutate it.
            return (t_l[0] + t_r[0], )

        def bwd(*t, tape):
            assert len(t) == 1 and pl_l._wrt is pl_r._wrt
            tape[pl_l] += t[0]
            tape[pl_r] += t[0]
            c_l = pl_l.T(*t, tape=tape)
            assert len(c_l) == 1
            c_r = pl_r.T(*t, tape=tape)
            assert len(c_r) == 1
            # A 1-tuple keeps ops that assert `len(t) == 1` (e.g. `exp`) composable.
            return (c_l[0] + c_r[0], )

        return Linearization(pl_l.p + pl_r.p, fwd, bwd, _wrt=pl_l._wrt)

    if l_lin or r_lin:
        lin = pl_l if l_lin else pl_r

        def bwd(*t, tape):
            t = lin.T(*t, tape=tape)
            assert len(t) == 1
            tape[lin] += t[0]
            return t

        # Keep the operand order of the primal sum.
        primal = pl_l.p + pl_r if l_lin else pl_l + pl_r.p
        # Adding a constant does not change the linear map: reuse `lin` as fwd.
        return Linearization(primal, lin, bwd, _wrt=lin._wrt)

    return pl_l + pl_r


# Toy scalar objective cost(p) = exp(p) + p, built from the ops above so it
# also accepts a Linearization.
def cost(p): return add(exp(p), p)

x = 3.14
# Root Linearization at the primal point x (the parameters we differentiate w.r.t.).
p0lin = Linearization(np.array(x))
j = cost(p0lin)
# j(t) applies the forward map to tangent t; j.T(t) applies the transposed map
# and returns a 1-tuple with the gradient w.r.t. p0lin (compare with the JAX
# reference in the next cell).
j(2.), j.T(2.)

In [None]:
# JAX reference: VJP of exp(p) + p at x applied to the cotangent 2.
jax.vjp(lambda p: jax.numpy.exp(p) + p, x)[1](2.)

In [None]:
p0 = 1e-2 * np.arange(0, 9)
ones = np.ones((9, ))
y = exp(p0)  # plain (non-linearized) evaluation; not used in this cell
j = exp(exp(Linearization(p0)))
# Forward map of exp(exp(p)) at p0 applied to tangent p0, and its transpose
# applied to the cotangent `ones`.
j(p0), j.T(ones)

In [None]:
# JAX reference for exp(exp(x)) at p0: VJP function (applied to `ones`) and
# JVP with tangent p0.
_, jj_T = jax.vjp(lambda x: jax.numpy.exp(jax.numpy.exp(x)), p0.astype(float))
_, jj_at_p0 = jax.jvp(lambda x: jax.numpy.exp(jax.numpy.exp(x)), (p0.astype(float), ), (p0, ))
jj_at_p0, jj_T(ones)

In [None]:
def weighted_reduction(pl, n=32, n_cols=3):
    """Reduce ``pl`` to ``n`` weighted row sums; also propagates Linearizations.

    ``pl`` is reshaped to ``(n_cols, -1)`` (so ``n_cols`` is the number of rows
    of that view; its size must be divisible by ``n_cols``).  Output ``i`` is
    the weighted sum of row ``i % n_cols``.  The weights are all ones and stand
    in for an expensive computation.

    For a :class:`Linearization` the result is a new Linearization.  Its forward
    map applies the same (linear) reduction to the tangent.  Its transpose
    scatters each output cotangent back onto its row.  The resulting cotangent
    is returned in the input's original shape.
    """
    if isinstance(pl, Linearization):
        p_arr = np.asarray(pl.p)
        in_shape = p_arr.shape
        rows_shape = p_arr.reshape(n_cols, -1).shape

        def fwd(*t, tape):
            t = pl(*t, tape=tape)
            assert len(t) == 1
            # Use the same n / n_cols as the primal computation.
            return (weighted_reduction(t[0], n=n, n_cols=n_cols), )

        def bwd(*t, tape):
            assert len(t) == 1
            t_T = np.zeros(rows_shape)
            for i, idx in enumerate(np.arange(n) % rows_shape[0]):
                super_expensive_weights = np.ones(rows_shape[1:])
                t_T[idx] += t[0][i] * super_expensive_weights
            # Record the cotangent in the input's own shape so it can be
            # accumulated with contributions from other consumers of `pl`.
            t_T = t_T.reshape(in_shape)
            tape[pl] += t_T
            return pl.T(t_T, tape=tape)

        return Linearization(weighted_reduction(pl.p, n=n, n_cols=n_cols),
                             fwd, bwd, _wrt=pl._wrt)

    p = np.asarray(pl).reshape(n_cols, -1)
    y, indices = np.zeros((n, )), np.arange(n) % p.shape[0]
    for i, idx in enumerate(indices):
        super_expensive_weights = np.ones(p.shape[1:])
        y[i] = np.sum(p[idx] * super_expensive_weights)
    return y


# Root Linearization at a 12-element vector (viewed as 3 x 4 by the default n_cols=3).
p0lin = Linearization(np.arange(12, dtype=float))
y2 = weighted_reduction(p0lin.p)  # plain reduction, reused below as a cotangent
j = weighted_reduction(p0lin)
j(p0lin.p), j.T(y2, )

In [None]:
# Composite map f(p) = weighted_reduction(exp(p)).
def f(p): return weighted_reduction(exp(p))

p0 = 1e-2 * np.arange(0, 9)
f0 = f(p0)
# Cotangent of ones with the same shape as f's output.
f_ones = np.ones(f0.shape)

j = f(Linearization(p0))
j(p0), j.T(f_ones)

In [None]:
def sum(pl):
    """Sum of all elements of ``pl``; also propagates Linearizations.

    Shadows the builtin ``sum`` within this notebook.  For a
    :class:`Linearization` the forward map sums the tangent.  The transpose
    broadcasts the scalar cotangent back to the input's shape.
    """
    if isinstance(pl, Linearization):
        def fwd(*t, tape):
            t = pl(*t, tape=tape)
            assert len(t) == 1
            return (sum(t[0]), )

        def bwd(*t, tape):
            assert len(t) == 1
            # np.shape also handles scalar primals (no `.shape` attribute).
            t = t[0] * np.ones(np.shape(pl.p))
            # Accumulate instead of overwrite: `pl` may also feed other operations.
            tape[pl] += t
            return pl.T(t, tape=tape)

        return Linearization(sum(pl.p), fwd, bwd, _wrt=pl._wrt)
    return np.sum(pl)

# Note: p0 is redefined here (0..8 as floats) and is reused by later cells.
p0 = np.arange(0, 9, dtype=float)
j = sum(Linearization(p0))
# Forward: sum of the tangent; transpose: cotangent 1. broadcast to p0's shape.
j(p0), j.T(1.)

In [None]:
def pow(pl, exponent):
    """``pl ** exponent`` for a constant exponent; also propagates Linearizations.

    Shadows the builtin ``pow`` within this notebook.

    Raises
    ------
    NotImplementedError
        If ``exponent`` is a :class:`Linearization`.
    """
    if isinstance(exponent, Linearization):
        raise NotImplementedError("pow: a Linearization exponent is not supported")
    if not isinstance(pl, Linearization):
        return pl**exponent

    # d/dx x**e = e * x**(e - 1), evaluated at the primal point.
    yl = exponent * pow(pl.p, exponent - 1)

    def fwd(*t, tape):
        t = pl(*t, tape=tape)
        assert len(t) == 1
        return (yl * t[0], )

    def bwd(*t, tape):
        assert len(t) == 1
        t = yl * t[0]
        # Accumulate instead of overwrite: `pl` may also feed other operations
        # (e.g. `h` uses p both in exp(p) and in pow(p, 2)).
        tape[pl] += t
        return pl.T(t, tape=tape)

    return Linearization(pow(pl.p, exponent), fwd, bwd, _wrt=pl._wrt)


y = pow(3., 3)  # plain evaluation, not used in this cell
j = pow(Linearization(3.), 3)
# Transposed map of x**3 at x = 3 applied to cotangent 4., next to the JAX reference.
j.T(4.), jax.vjp(lambda x: x**3, 3.)[1](4.)

In [None]:
# Constant subtracted from the reduction inside h (zero here).
data = 0.


def h(p):
    """Toy objective: sum(weighted_reduction(exp(p)) - data) + sum(p ** 2)."""
    return add(sum(add(-data, weighted_reduction(exp(p)))), sum(pow(p, 2)))


j = h(Linearization(p0))
# Only a cell's last expression is displayed, so show both results together.
j(p0), j.T(1.)

In [None]:
# Gradient of sum(x + x) at p0 with cotangent 1.: JAX reference, then this notebook's AD.
print(jax.vjp(lambda x: (x + x).sum(), p0)[1](1.))
print((lambda x: sum(add(x, x)))(Linearization(p0)).T(1.))