In [None]:
import numpy as np
import jax

In [None]:
class Linearization:
    """A linear map paired with the primal value it was linearized at.

    Attributes (set in ``__init__``):
        p: primal value this linearization is attached to (may be ``None``).
        _fwd: forward (JVP-style) linear map, applied by ``__call__``.
        _bwd: transposed (VJP-style) linear map, exposed through ``.T``.

    Both maps default to the identity, so ``Linearization(x)`` seeds a
    linearization of the identity function at ``x``.
    """

    def __init__(self, p, fwd=lambda t: t, bwd=lambda t: t):
        self.p = p
        self._fwd, self._bwd = fwd, bwd

    def __call__(self, t):
        """Apply the forward linear map to tangent ``t``."""
        return self._fwd(t)

    @property
    def T(self):
        """Transposed linearization: forward and backward maps swapped.

        The primal ``p`` is carried over unchanged so that ``lin.T.T`` has the
        same primal and maps as ``lin`` (transposition round-trips).
        """
        return self.__class__(self.p, self._bwd, self._fwd)

    def __repr__(self):
        return f"{self.__class__.__name__}({self.p}, {self._fwd}, {self._bwd})"


def exp(pl):
    """Elementwise exponential that also propagates a ``Linearization``.

    For a non-``Linearization`` input, returns ``np.exp(pl)``.
    For a ``Linearization``, returns a new ``Linearization`` whose primal is
    ``exp(pl.p)``. Its maps apply the chain rule with the elementwise
    derivative ``exp(pl.p)``:
      forward:  t -> exp(p) * pl(t)
      backward: t -> pl.T(exp(p) * t)
    """
    if not isinstance(pl, Linearization):
        return np.exp(pl)

    exp_p = exp(pl.p)

    def fwd(tangent):
        # Push the tangent through the inner map, then scale by exp(p).
        return exp_p * pl(tangent)

    def bwd(cotangent):
        # Scale by exp(p) first, then pull back through the inner transpose.
        return pl.T(exp_p * cotangent)

    return Linearization(exp_p, fwd, bwd)

In [None]:
# Evaluate exp on a small grid, then pull a vector of ones back through
# exp(exp(x)); elementwise this is exp(exp(x)) * exp(x).
x0 = np.arange(0, 9) * 1e-2
ones = np.ones_like(x0)  # cotangent matching x0's shape and dtype
y = exp(x0)
j = exp(exp(Linearization(x0)))
print(y)
print(j.T(ones))

In [None]:
# Cross-check the hand-written reverse pass against JAX's VJP of the same
# composite exp(exp(x)).
_, jj_T = jax.vjp(lambda x: jax.numpy.exp(jax.numpy.exp(x)), x0.astype(float))
# The pullback returns one cotangent per primal argument: a 1-tuple here.
(jax_ct,) = jj_T(ones)
# JAX computes in float32 unless x64 mode is enabled, so compare with a tolerance.
np.testing.assert_allclose(jax_ct, exp(exp(Linearization(x0))).T(ones), rtol=1e-5)
jax_ct

In [None]:
def weighted_reduction(pl, n=32, n_cols=3):
    """Reduce ``pl`` to a length-``n`` vector of weighted row sums.

    ``pl`` is reshaped to ``(n_cols, -1)``. Output entry ``i`` is the
    weighted sum of row ``i % n_cols``. The weights are rebuilt inside the
    loop (all ones here); presumably a placeholder for a costly per-row
    weight computation.

    If ``pl`` is a ``Linearization``, this returns a ``Linearization`` of the
    same (linear) map around ``pl.p``, using the same ``n`` and ``n_cols``
    for the primal, the forward map and the backward map.

    Args:
        pl: ndarray whose size is divisible by ``n_cols``, or a Linearization.
        n: length of the output vector.
        n_cols: number of rows ``pl`` is reshaped into before reducing.
    Returns:
        ndarray of shape ``(n,)``, or a Linearization wrapping one.
    """
    if isinstance(pl, Linearization):
        def fwd(t):
            # The map is linear, so its forward derivative is the map itself;
            # n / n_cols must match the primal call.
            return weighted_reduction(pl(t), n=n, n_cols=n_cols)

        def bwd(t):
            # Transpose of the reduction: scatter-add t[i] * weights back onto
            # row i % n_cols, then reshape to the primal's shape.
            p_shp = pl.p.reshape(n_cols, -1).shape
            x_T, indices = np.zeros(p_shp), np.arange(n) % p_shp[0]
            for i, idx in enumerate(indices):
                super_expensive_weights = np.ones(p_shp[1:])
                x_T[idx] += t[i] * super_expensive_weights
            return pl.T(x_T.reshape(pl.p.shape))

        return Linearization(weighted_reduction(pl.p, n=n, n_cols=n_cols), fwd, bwd)
    p = pl.reshape(n_cols, -1)
    y, indices = np.zeros((n, )), np.arange(n) % p.shape[0]
    for i, idx in enumerate(indices):
        super_expensive_weights = np.ones(p.shape[1:])
        y[i] = np.sum(p[idx] * super_expensive_weights)
    return y

In [None]:
# weighted_reduction is linear, so the forward map of its linearization is the
# reduction itself: applying it to y must reproduce y2 exactly.
y2 = weighted_reduction(y)
j = weighted_reduction(Linearization(y))
np.testing.assert_array_equal(j(y), y2)
j(y)

In [None]:
def f(x):
    """Example composite ``weighted_reduction(exp(x))``.

    Accepts a plain array or a ``Linearization``, because both ``exp`` and
    ``weighted_reduction`` dispatch on ``Linearization``.
    """
    exp_x = exp(x)
    return weighted_reduction(exp_x)

# Same grid as before, redefined so this cell does not rely on earlier state.
x0 = np.arange(0, 9) * 1e-2
f0 = f(x0)
f_ones = np.ones_like(f0)  # cotangent with f0's shape and dtype

j = f(Linearization(x0))
# Forward map applied to x0, and transposed map applied to a cotangent of ones.
j(x0), j.T(f_ones)

In [None]:
def sum(pl):
    """Sum of all elements, also propagating a ``Linearization``.

    NOTE: this module-level definition shadows the builtin ``sum`` for the
    rest of the notebook.

    For a non-``Linearization`` input, returns ``np.sum(pl)``.
    For a ``Linearization``, the forward map sums the pushed-forward tangent.
    The backward map broadcasts the (scalar) cotangent to ``pl.p``'s shape
    before pulling it back through ``pl.T``.
    """
    if not isinstance(pl, Linearization):
        return np.sum(pl)

    total = sum(pl.p)

    def fwd(tangent):
        return sum(pl(tangent))

    def bwd(cotangent):
        # Transpose of a full reduction: every element receives the cotangent.
        return pl.T(cotangent * np.ones(pl.p.shape))

    return Linearization(total, fwd, bwd)


# Seed the scalar output of sum(j) with 1.0 and pull it back: with j from the
# previous cell, this is the gradient of sum(f(x)) at x0.
sum(j).T(1.0)