In [1]:
from folx.api import *

In [23]:
def sum_fwd_lap(x: FwdLaplArray, dependencies, n_el: int, dim: int = 3) -> FwdLaplArray:
    """Sum a FwdLaplArray over its leading axis, returning a dense jacobian.

    Hand-written counterpart of ``forward_laplacian(lambda a: a.sum(0))``; the
    notebook compares the two further down.

    Args:
        x: Array to reduce. Its jacobian data is read as ``(n_deps * dim, *x.shape)``:
            ``n_deps`` consecutive blocks of ``dim`` rows, one block per dependency.
        dependencies: Integer array of shape ``(x.shape[0], n_deps)``. Block ``d`` of
            element ``i`` is accumulated into output block ``dependencies[i, d]``;
            repeated indices are summed (scatter-add semantics of ``.at[].add``).
        n_el: Number of output blocks; entries of ``dependencies`` should lie in
            ``range(n_el)``. (Presumably the number of electrons — confirm.)
        dim: Rows per block. Defaults to 3, the value previously hard-coded
            (presumably the spatial coordinates of each dependency).

    Returns:
        FwdLaplArray with ``x = x.x.sum(0)``, ``laplacian = x.laplacian.sum(0)`` and a
        dense (no ``x0_idx``) jacobian of shape ``(n_el * dim, *x.shape[1:])``.
    """
    jac = jnp.asarray(x.jacobian.data)
    # (n_deps * dim, n, *rest) -> (n_deps, dim, n, *rest) -> (n_deps, n, dim, *rest), so
    # it lines up element-for-element with out_jac[dependencies.T] of the same shape.
    # Using -1 lets n_deps come from the data instead of assuming n_deps == n_el.
    jac = jnp.swapaxes(jac.reshape((-1, dim, *jac.shape[1:])), 1, 2)
    # Accumulate in the jacobian's dtype, not x.x's: an integer-valued x must not
    # truncate the derivatives.
    out_jac = jnp.zeros((n_el, dim, *x.shape[1:]), jac.dtype)
    out_jac = out_jac.at[dependencies.T].add(jac)
    out_jac = out_jac.reshape(n_el * dim, *x.shape[1:])
    return FwdLaplArray(
        x=x.x.sum(0),
        jacobian=FwdJacobian(out_jac),
        laplacian=x.laplacian.sum(0),
    )


In [28]:
# Toy input: 3 elements, each depending on all N_EL electrons with 3 jacobian rows per
# electron, so the jacobian data has N_EL * 3 = 12 rows. The RNG is seeded so that
# "Restart & Run All" reproduces the outputs below.
N_EL = 4
rng = np.random.default_rng(0)
dependencies = jnp.tile(jnp.arange(N_EL), (3, 1))  # shape (3, N_EL), rows [0, 1, 2, 3]
x = FwdLaplArray(
    x=jnp.array([1., 2., 3.]),
    jacobian=FwdJacobian(rng.normal(size=(N_EL * 3, 3))),
    laplacian=jnp.array([1., 1., 1.]),  # float, consistent with x and the jacobian
)

In [32]:
# Same data as `x`, but with an explicit index map: x0_idx[k, i] == k for every
# element i, i.e. the identity over the jacobian's rows. Derived from the jacobian's
# shape instead of a hand-typed literal so it stays in sync with `x`.
# NOTE(review): presumably forward_laplacian needs x0_idx to be set — confirm.
n_jac_rows = x.jacobian.data.shape[0]
x0_idx = np.broadcast_to(
    np.arange(n_jac_rows).reshape((-1,) + (1,) * x.x.ndim),
    x.jacobian.data.shape,
).copy()  # copy: broadcast_to returns a read-only view
x_hat = FwdLaplArray(
    x=x.x,
    jacobian=FwdJacobian(x.jacobian.data, x0_idx=x0_idx),
    laplacian=x.laplacian,
)

In [33]:
summed = sum_fwd_lap(x, dependencies, n_el=4)  # dense jacobian of x.sum(0)
summed.jacobian.data

Array([-3.8336732 , -1.2847297 , -0.3524232 ,  0.8147084 ,  0.9420111 ,
       -2.8378847 , -1.7232125 , -1.0635946 , -2.4166512 , -0.16242611,
       -0.38889828,  1.3375041 ], dtype=float32)

In [34]:
from folx import forward_laplacian

In [36]:
# Cross-check the hand-written `sum_fwd_lap` against folx's generic forward Laplacian.
# The two can differ by float32 rounding (the recorded outputs disagree in the last
# digit of one entry), which may exceed assert_allclose's default rtol=1e-7.
folx_sum = forward_laplacian(lambda arr: arr.sum(0), disable_jit=True)(x_hat)
np.testing.assert_allclose(
    folx_sum.jacobian.data,
    sum_fwd_lap(x, dependencies, 4).jacobian.data,
    rtol=1e-5,
)
folx_sum.jacobian.data

Array([-3.8336732 , -1.2847297 , -0.3524232 ,  0.8147084 ,  0.9420111 ,
       -2.8378847 , -1.7232125 , -1.0635946 , -2.4166515 , -0.16242611,
       -0.38889828,  1.3375041 ], dtype=float32)