In [29]:
import jax
import jax.numpy as jnp
import math as mt
from jax import grad



def f(a):
    """Evaluate p(x, y) = x**2 * y + y**3 * sin(x) at a point.

    Args:
        a: Indexable pair whose first entry is x and second entry is y.

    Returns:
        The value of p at (x, y).
    """
    x, y = a[0], a[1]
    polynomial_term = (x**2) * y
    sine_term = (y**3) * jnp.sin(x)
    return polynomial_term + sine_term


# Our explicit gradient function
def g_f(x, y):
    """Return the hand-derived gradient of f(x, y) = x**2 * y + y**3 * sin(x).

    The trigonometric terms use ``jnp.cos`` / ``jnp.sin`` rather than the
    ``math`` module, so the function accepts Python floats, JAX scalars and
    arrays alike, and can be traced by JAX transformations such as
    ``jax.jit``. ``math.cos`` / ``math.sin`` only accept values that can be
    converted to a single Python float.

    Args:
        x: First coordinate (scalar or array).
        y: Second coordinate (scalar or array, broadcastable with ``x``).

    Returns:
        A JAX array stacking ``[df/dx, df/dy]`` along the leading axis.
    """
    df_dx = 2 * x * y + (y**3) * jnp.cos(x)
    df_dy = x**2 + 3 * (y**2) * jnp.sin(x)
    return jnp.array([df_dx, df_dy])

# Gradient of f obtained by JAX automatic differentiation.
grad_f = grad(f)

# Three sample (x, y) points at which the two gradients are compared.
# NOTE(review): `input` shadows the Python builtin of the same name; the name
# is kept here in case other cells refer to it.
input = jnp.array(
    [
        [0.2, 0.3],
        [2.4, 3.6],
        [4.4, 2.1],
    ]
)

for x in input:
    # Each row unpacks into its x and y coordinates for the explicit formula.
    print("Explicit Gradient Function: ", g_f(*x))
    # The doubled newline reproduces the blank separator line between points.
    print("JAX Gradient Function: ", grad_f(x), end="\n\n")


Explicit Gradient Function:  [0.1464618  0.09364072]
JAX Gradient Function:  [0.1464618  0.09364072]

Explicit Gradient Function:  [-17.123838  32.022003]
JAX Gradient Function:  [-17.123838  32.022003]

Explicit Gradient Function:  [15.633791   6.7703066]
JAX Gradient Function:  [15.633791   6.7703066]

