In [1]:
import jax, jax.numpy as jnp
import jax.profiler as profiler

In [43]:
def ps(x, q=0.7) :
    # R^n \to \R^n
    print('Compile')
    return x**q 

def agg(x) : 
    # R^n \to \R
    return (x**2).sum()

In [44]:
psj = jax.checkpoint(jax.jit(ps))

### Gradient initial condition

In [45]:
def loss_fn(x, ps_fun, agg_fun, n_iter=5) : 
    for i in range(n_iter) :
        x = ps_fun(x)
    return agg_fun(x)

In [46]:
d = 3
x = jnp.arange(1, d+1, 1, dtype='float32')

In [47]:
loss_and_grad = jax.value_and_grad(loss_fn, argnums=0)
loss, grad = loss_and_grad(x, psj, agg, n_iter=10)
jax.block_until_ready((loss, grad));

Compile


In [48]:
grad

Array([0.05649504, 0.02937562, 0.02003752], dtype=float32)

### Correction setup 

In [54]:
def corr(x, theta) : 
    return x * theta

In [55]:
def loss_theta_fn(x, theta, ps_fun, agg_fun, corr_fun, n_iter=5) : 
    for i in range(n_iter) :
        x = ps_fun(x)
        x = corr_fun(x, theta)
    return agg_fun(x)

In [56]:
theta = jnp.array([0.2])
loss_theta_fn(x, theta, psj, agg, corr)

Array(0.00049276, dtype=float32)

In [57]:
loss_and_grad = jax.value_and_grad(loss_theta_fn, argnums=1)
loss, grad = loss_and_grad(x, theta, psj, agg, corr, n_iter=10)

In [58]:
grad

Array([0.00297975], dtype=float32)