In [1]:
# import ceviche_challenges
# from ceviche_challenges import units as u
# from ceviche_challenges.model_base import _wavelengths_nm_to_omegas

from ceviche import viz, fdfd_ez
from ceviche import jacobian

import autograd
import autograd.numpy as npa
import jax
import jax.numpy as jnp

import numpy as np
import matplotlib.pyplot as plt

# Seed numpy's global RNG so the random test inputs drawn below are reproducible.
SEED = 42
np.random.seed(SEED)

# Autograd version: f, g, g2 and the composition h = g ∘ f

In [22]:
def f_ad(x):
    """Split ``x`` into its cosine and sine using autograd-traceable ops.

    Args:
        x: scalar or array accepted by ``autograd.numpy``.

    Returns:
        Tuple ``(cos(x), sin(x))``, evaluated elementwise.
    """
    cosine = npa.cos(x)
    sine = npa.sin(x)
    return cosine, sine

def g_ad(x1, x2):
    """Combine two inputs by addition (autograd section).

    Returns:
        ``x1 + x2``, with the usual numpy broadcasting for arrays.
    """
    summed = x1 + x2
    return summed

def g2_ad(x1, x2):
    """Sum of elementwise squares, ``x1**2 + x2**2``, via ``autograd.numpy``."""
    first_sq = npa.square(x1)
    second_sq = npa.square(x2)
    return first_sq + second_sq

# def g2_ad(x):
#     # return npa.sqrt(npa.sum(npa.square(npa.abs(x))))
#     return npa.sum(npa.square(x))

def h_ad(x):
    """Composition ``g_ad(f_ad(x))``, i.e. ``cos(x) + sin(x)`` elementwise."""
    cos_x, sin_x = f_ad(x)
    return g_ad(cos_x, sin_x)

# JAX version: the same functions written with `jax.numpy`

In [23]:
def f_jax(x):
    """Split ``x`` into its cosine and sine using ``jax.numpy``.

    Args:
        x: scalar or array accepted by ``jax.numpy``.

    Returns:
        Tuple ``(cos(x), sin(x))``, evaluated elementwise.
    """
    cosine = jnp.cos(x)
    sine = jnp.sin(x)
    return cosine, sine

def g_jax(x1, x2):
    """Combine two inputs by addition (JAX section).

    Returns:
        ``x1 + x2``, with the usual broadcasting for arrays.
    """
    summed = x1 + x2
    return summed

def g2_jax(x1, x2):
    """Sum of elementwise squares, ``x1**2 + x2**2``, via ``jax.numpy``."""
    first_sq = jnp.square(x1)
    second_sq = jnp.square(x2)
    return first_sq + second_sq

def h_jax(x):
    """Composition ``g_jax(f_jax(x))``, i.e. ``cos(x) + sin(x)`` elementwise."""
    cos_x, sin_x = f_jax(x)
    return g_jax(cos_x, sin_x)

In [24]:
# Two independent 3x3 inputs drawn from numpy's global RNG; x is drawn before
# xbis so the values match the seeded sequence.
x = np.random.rand(3, 3)
xbis = np.random.rand(3, 3)

# Autograd copies of the inputs.
x_ad, xbis_ad = (npa.array(arr) for arr in (x, xbis))

# JAX copies. NOTE: unless `jax_enable_x64` is enabled, jnp.array stores these
# float64 inputs as float32, so every JAX result below has float32 precision.
x_jax, xbis_jax = (jnp.array(arr) for arr in (x, xbis))

# Comparison of the Jacobians of g and g2

In [25]:
# Forward pass of the composite h on the JAX inputs.
y_jax = h_jax(x_jax)

# Forward-mode Jacobians of g and g2 with respect to their FIRST argument only
# (argnums=0 is the jacfwd default); xbis_jax is held fixed.
Jac_G, Jac_G2 = (
    jax.jacfwd(fn, argnums=0)(x_jax, xbis_jax) for fn in (g_jax, g2_jax)
)

In [26]:
# Forward pass of the composite h on the autograd inputs.
y_ad = h_ad(x_ad)

# Autograd Jacobians of g and g2 with respect to their FIRST argument only
# (argnum=0 is autograd's default); xbis_ad is held fixed.
Jac_ad_G, Jac_ad_G2 = (
    autograd.jacobian(fn, argnum=0)(x_ad, xbis_ad) for fn in (g_ad, g2_ad)
)

In [27]:
# Both frameworks must agree on the Jacobians of g and g2 (w.r.t. the first
# argument). JAX keeps x_jax in float32 unless `jax_enable_x64` is on, while
# autograd works in float64. Exact `==` only passed because JAX silently casts
# the numpy operand down to float32 inside `__eq__`. Compare plain numpy arrays
# instead, with a tolerance tied to the coarser of the two dtypes.
for _name, _jac_jax, _jac_ad in (
    ("g", Jac_G, Jac_ad_G),
    ("g2", Jac_G2, Jac_ad_G2),
):
    _jac_jax = np.asarray(_jac_jax)
    _jac_ad = np.asarray(_jac_ad)
    _rtol = 10 * max(np.finfo(_jac_jax.dtype).eps, np.finfo(_jac_ad.dtype).eps)
    # assert_allclose also fails on a shape mismatch.
    np.testing.assert_allclose(
        _jac_jax,
        _jac_ad,
        rtol=_rtol,
        atol=0,
        err_msg=f"Jacobian of {_name} differs between JAX and autograd",
    )

# Comparison of the Jacobians of h = g ∘ f

In [32]:
# Jacobians of the composite h = g(f(x)) from each framework:
# JAX forward mode, and autograd's jacobian as the reference.
Jac_H, Jac_ad_H = jax.jacfwd(h_jax)(x_jax), autograd.jacobian(h_ad)(x_ad)

In [41]:
# Relative discrepancy between the JAX and autograd Jacobians of h, measured
# as max|J_jax - J_ad| / ||J_jax||_F.
# - abs() is essential: without it, any negative difference passes trivially.
# - JAX defaults to float32 (eps ~1.2e-7), so a fixed 1e-7 bound sits right at
#   float32 round-off. Keep 1e-7 in float64 and allow a few ulps otherwise.
_jac_h_jax = np.asarray(Jac_H)
_jac_h_ad = np.asarray(Jac_ad_H)
assert _jac_h_jax.shape == _jac_h_ad.shape, (_jac_h_jax.shape, _jac_h_ad.shape)
_rel_err_h = np.abs(_jac_h_jax - _jac_h_ad).max() / np.linalg.norm(_jac_h_jax)
_tol_h = max(1e-7, 10 * np.finfo(_jac_h_jax.dtype).eps)
assert _rel_err_h < _tol_h, (
    f"Jacobian of h: relative error {_rel_err_h:.3e} >= tolerance {_tol_h:.3e}"
)