In [None]:
class Linearization:
    """A linear map carried together with a primal value.

    Attributes
    ----------
    p : the primal value passed at construction (``None`` for a transposed map)
    _fwd : callable applied by ``__call__`` (the forward linear map)
    _bwd : callable used as the forward map of ``.T`` (the transpose)

    Both maps default to the identity.
    """

    def __init__(self, primals, fwd=lambda x: x, bwd=lambda y: y):
        self.p = primals
        self._fwd = fwd
        self._bwd = bwd

    def __call__(self, primals):
        """Apply the forward linear map to ``primals``."""
        forward = self._fwd
        return forward(primals)

    @property
    def T(self):
        """Transposed map: forward and backward swapped, primal set to ``None``."""
        return type(self)(None, self._bwd, self._fwd)

    def __repr__(self):
        name = type(self).__name__
        return f"{name}({self.p}, {self._fwd}, {self._bwd})"


def exp(pl):
    """Elementwise exponential that also propagates ``Linearization`` objects.

    Plain inputs are forwarded to ``np.exp``. For a ``Linearization`` ``pl``
    the result is the linearization of ``exp`` composed after ``pl``:

    * primal:   ``y = exp(pl.p)``
    * forward:  ``t -> y * pl(t)``
    * backward: ``c -> pl.T(c * y)``

    The cotangent has to be scaled by ``y`` *before* it is pulled back through
    ``pl.T``. Scaling afterwards only coincides with the correct result when
    ``pl`` is itself elementwise, and fails outright when ``pl`` changes shape.
    """
    if isinstance(pl, Linearization):
        y = exp(pl.p)

        def fwd(t):
            return y * pl(t)

        def bwd(c):
            # chain rule: (diag(y) J)^T c = J^T (y * c)
            return pl.T(c * y)

        return Linearization(y, fwd, bwd)
    return np.exp(pl)

In [None]:
import numpy as np

x0 = 1e-2 * np.arange(0, 9).reshape(3, 3)
ones = np.ones((3, 3))
y = exp(x0)
j = exp(exp(Linearization(x0)))
print(y)
print(j.T(ones))
# d/dx exp(exp(x)) = exp(x) * exp(exp(x)) elementwise, so pulling back a
# cotangent of ones must reproduce that analytic derivative.
np.testing.assert_allclose(j.T(ones), np.exp(x0) * np.exp(np.exp(x0)))

In [None]:
import jax

# Cross-check against JAX. `jax.vjp` returns (primal_out, pullback), and the
# pullback returns a tuple holding one cotangent per primal input.
_, jj_T = jax.vjp(lambda x: jax.numpy.exp(jax.numpy.exp(x)), x0)
(x0_bar_jax,) = jj_T(ones)
# JAX defaults to 32-bit floats unless x64 mode is enabled, hence the relaxed rtol.
np.testing.assert_allclose(x0_bar_jax, np.exp(x0) * np.exp(np.exp(x0)), rtol=1e-5)
x0_bar_jax

In [None]:
def weighted_reduction(pl, n=32):
    """Reduce leading-axis slices of ``pl`` (chosen cyclically) to ``n`` scalars.

    Output element ``i`` is ``np.sum(pl[i % pl.shape[0]] * w)`` for
    ``i in range(n)``, where ``w`` is a weight array shaped like
    ``pl.shape[1:]`` (all ones here). The map is linear in ``pl``. For a
    ``Linearization``, the forward map therefore reuses this function, and the
    backward map scatter-adds each output cotangent back onto the slice it
    was read from.

    Parameters
    ----------
    pl : ndarray with ``ndim >= 1``, or Linearization
    n : int, default 32
        Number of outputs.

    Returns
    -------
    ndarray of shape ``(n,)`` for array input, otherwise a Linearization.
    """
    if isinstance(pl, Linearization):
        y = weighted_reduction(pl.p, n)

        def fwd(t):
            return weighted_reduction(pl(t), n)

        def bwd(c):
            rows = np.arange(n) % pl.p.shape[0]
            c_T = np.zeros(pl.p.shape)
            for i, row in enumerate(rows):
                # NOTE(review): presumably a stand-in for costly per-slice weights;
                # deliberately rebuilt inside the loop, mirroring the primal pass.
                super_expensive_weights = np.ones(pl.p.shape[1:])
                # transpose of y[i] = sum(x[row] * w): x_T[row] += c[i] * w
                c_T[row] += c[i] * super_expensive_weights
            return pl.T(c_T)

        return Linearization(y, fwd, bwd)

    indices = np.arange(n) % pl.shape[0]
    y = np.zeros((n,))
    for i, idx in enumerate(indices):
        super_expensive_weights = np.ones(pl.shape[1:])
        y[i] = np.sum(pl[idx] * super_expensive_weights)
    return y

In [None]:
y2 = weighted_reduction(y)
j = weighted_reduction(Linearization(y))
# Dot-product (adjoint) test for the hand-written transpose:
# <J v, u> must equal <v, J^T u>; here v = y and u = y2.
np.testing.assert_allclose(np.vdot(j(y), y2), np.vdot(y, j.T(y2)))
j.T(y2)

In [None]:
def f(x):
    """Composite test function ``x -> weighted_reduction(exp(x))``.

    The argument is handed straight to ``exp`` and its result to
    ``weighted_reduction``, so it can be an array or a ``Linearization``.
    """
    inner = exp(x)
    return weighted_reduction(inner)

# Linearize the composite f at x0, then apply the forward map (direction x0)
# and the transpose (cotangent of ones) to it.
x0 = np.arange(9).reshape(3, 3) * 1e-2
f0 = f(x0)
f_ones = np.ones_like(f0)

j = f(Linearization(x0))
(j(x0), j.T(f_ones))

In [None]:
def sum(pl):
    """Sum all entries of ``pl``; also propagates ``Linearization`` objects.

    NOTE: this definition shadows Python's built-in ``sum`` in the notebook
    namespace from here on.

    For a ``Linearization``, the forward map sums the tangent. The backward
    map broadcasts the (scalar) cotangent to the shape of ``pl.p`` before
    pulling it back through ``pl.T``.
    """
    if isinstance(pl, Linearization):
        y = sum(pl.p)

        def fwd(t):
            return sum(pl(t))

        def bwd(c):
            # np.shape (not `.shape`) so Python scalars / lists work as primals too
            return pl.T(c * np.ones(np.shape(pl.p)))

        return Linearization(y, fwd, bwd)
    return np.sum(pl)


# Pull the scalar seed 1.0 back through the linearization of sum(j),
# i.e. the gradient of the summed outputs of j's underlying function.
sum(j).T(1.0)