# Autodiff of an implicit function

In [25]:
import jax.numpy as jnp
from jax import grad

def f(E, e, M):
    """Residual of Kepler's equation, E - e*sin(E) - M.

    The eccentric anomaly E for eccentricity ``e`` and mean anomaly ``M``
    is the root of this function.
    """
    e_sin_E = e * jnp.sin(E)
    return E - e_sin_E - M


# Partial derivatives of the residual, obtained by automatic
# differentiation: argnums 0 -> d/dE, 1 -> d/de, 2 -> d/dM.
dfdE, dfde, dfdM = (grad(f, argnums=position) for position in range(3))


In [None]:
from functools import partial
from jax import custom_vjp
from jax.lax import while_loop


# Stopping rules for the Newton-Raphson solve: an absolute step tolerance and
# a hard iteration cap. The cap guarantees termination when, e.g. in float32,
# the iterates can end up alternating between two neighbouring representable
# values whose gap (~1e-7) never drops below the tolerance.
_NEWTON_TOL = 1e-10
_NEWTON_MAX_ITER = 100


def _kepler_newton_step(E, e, M):
    """Return one Newton-Raphson update of E for Kepler's equation.

    Uses the closed-form derivative d/dE (E - e*sin(E) - M) = 1 - e*cos(E),
    which (unlike ``grad``) also works elementwise on array inputs.
    """
    return E - (E - e * jnp.sin(E) - M) / (1.0 - e * jnp.cos(E))


@custom_vjp
def eccentric_anomaly_newton_raphson(e, M, Eini):
    """Solve Kepler's equation E - e*sin(E) - M = 0 for E by Newton-Raphson.

    Parameters
    ----------
    e : eccentricity.
    M : mean anomaly.
    Eini : initial guess for the eccentric anomaly; the iteration starts here.

    ``e``, ``M`` and ``Eini`` may be scalars or arrays of one common shape.
    For arrays the loop runs until every element has converged or the
    iteration cap is reached. Mixed broadcasting, such as scalar ``e`` with
    array ``M``, is not supported by the VJP rule below, because it returns
    cotangents shaped like E*.

    Returns
    -------
    The eccentric anomaly E*.

    Derivatives come from the custom VJP (implicit function theorem), so the
    while_loop itself is never differentiated.
    """
    def cond_fun(carry):
        n_iter, E_prev, E = carry
        not_converged = jnp.any(jnp.abs(E - E_prev) > _NEWTON_TOL)
        return jnp.logical_and(not_converged, n_iter < _NEWTON_MAX_ITER)

    def body_fun(carry):
        n_iter, _, E = carry
        return n_iter + 1, E, _kepler_newton_step(E, e, M)

    # Seed the loop with the first Newton step taken from Eini. The +inf
    # "previous iterate" guarantees at least one more step/convergence check.
    # full_like keeps shape and dtype of both carry slots identical, as
    # while_loop requires.
    E_first = _kepler_newton_step(Eini, e, M)
    init = (0, jnp.full_like(E_first, jnp.inf), E_first)
    _, _, E_star = while_loop(cond_fun, body_fun, init)
    return E_star


def eccentric_anomaly_newton_raphson_fwd(e, M, Eini):
    """Forward pass: solve for E* and save what the backward pass needs."""
    E_star = eccentric_anomaly_newton_raphson(e, M, Eini)
    # dF/dE at the solution, where F(E, e, M) = E - e*sin(E) - M.
    dF_dE = 1.0 - e * jnp.cos(E_star)
    # The converged root does not depend on the starting guess, so the
    # cotangent for Eini is zero, shaped and typed like Eini.
    return E_star, (jnp.sin(E_star), dF_dE, jnp.zeros_like(Eini))


def eccentric_anomaly_newton_raphson_bwd(residuals, u):
    """Backward pass via the implicit function theorem.

    With F(E*, e, M) = 0, dE*/dx = -(dF/dx) / (dF/dE). This gives
    dE*/de = sin(E*) / (1 - e*cos(E*)) and dE*/dM = 1 / (1 - e*cos(E*)).
    Returns the cotangents for (e, M, Eini).
    """
    sin_E_star, dF_dE, Eini_cotangent = residuals
    return (sin_E_star * u / dF_dE, u / dF_dE, Eini_cotangent)


eccentric_anomaly_newton_raphson.defvjp(eccentric_anomaly_newton_raphson_fwd, eccentric_anomaly_newton_raphson_bwd)

In [41]:
# Sample problem: eccentricity, mean anomaly, and Newton starting guess.
etest, Mtest, Eini = 0.1, 0.2, 0.3

Estar = eccentric_anomaly_newton_raphson(etest, Mtest, Eini)
print("E_star =", Estar)

# Sensitivities of the solved eccentric anomaly: argnums 0 -> d/de, 1 -> d/dM.
dE_de, dE_dM = (grad(eccentric_anomaly_newton_raphson, argnums=k) for k in (0, 1))

print(dE_de(etest, Mtest, Eini))
print(dE_dM(etest, Mtest, Eini))

E_star = 0.22202006
0.24400184
1.1080891


In [26]:
# Kepler residual evaluated at the solved eccentric anomaly (displayed by the notebook).
f(Estar, e=etest, M=Mtest)

Array(0., dtype=float32, weak_type=True)