In [5]:
import jax
import jax.numpy as jnp
import sympy as sp

# Analytical Gradient
def analytical_gradient(x, y):
    """Return the hand-derived gradient of f(x, y) = x**2 * y + y**3 * sin(x).

    Args:
        x: First coordinate (Python scalar or JAX array).
        y: Second coordinate (Python scalar or JAX array).

    Returns:
        A 2-tuple ``(df/dx, df/dy)`` where
        df/dx = 2*x*y + y**3 * cos(x) and df/dy = x**2 + 3*y**2 * sin(x).
    """
    cos_x = jnp.cos(x)
    sin_x = jnp.sin(x)
    partial_x = 2*x*y + y**3 * cos_x
    partial_y = x**2 + 3*y**2 * sin_x
    return partial_x, partial_y

# Using JAX to calculate the gradient
def f(x, y):
    """Scalar test function f(x, y) = x**2 * y + y**3 * sin(x).

    Written with ``jnp`` operations so that ``jax.grad`` can differentiate it.
    """
    polynomial_part = x**2 * y
    trig_part = y**3 * jnp.sin(x)
    return polynomial_part + trig_part

# JAX reverse-mode gradient of f with respect to both positional arguments;
# grad_f(x, y) returns the tuple (df/dx, df/dy).
grad_f = jax.grad(f, argnums=(0, 1))

# List of (x, y) points at which to compare the three gradient methods.
points = [(8, 0), (3, 5), (7, 4), (5, 1), (3, 4), (1, 6)]

# Symbolic reference gradient via SymPy. The expression mirrors f above and
# does not depend on the evaluation point, so it is built once here instead
# of being re-derived on every loop iteration.
x, y = sp.symbols('x y')
f_sympy = x**2 * y + y**3 * sp.sin(x)
grad_f_sympy = [sp.diff(f_sympy, var) for var in (x, y)]

# Evaluate and print the gradient from each method at every point.
# The JAX/jnp results are single precision (JAX's default dtype), so they can
# differ from SymPy's evalf values around the 7th significant digit.
for x_val, y_val in points:
    # Cast to float: jax.grad only differentiates floating-point inputs.
    grad_values_jax = grad_f(float(x_val), float(y_val))
    grad_values_analytical = analytical_gradient(x_val, y_val)
    grad_values_sympy = [g.evalf(subs={x: x_val, y: y_val}) for g in grad_f_sympy]

    print(f"Point (x={x_val}, y={y_val}):")
    print(f"Gradient using JAX wrt x: {grad_values_jax[0]}")
    print(f"Gradient using JAX wrt y: {grad_values_jax[1]}")
    print(f"Analytical Gradient wrt x: {grad_values_analytical[0]}")
    print(f"Analytical Gradient wrt y: {grad_values_analytical[1]}")
    print(f"Gradient using SymPy wrt x: {grad_values_sympy[0]}")
    # Trailing newline leaves one blank line between consecutive points' output.
    print(f"Gradient using SymPy wrt y: {grad_values_sympy[1]}\n")


Point (x=8, y=0):
Gradient using JAX wrt x: 0.0
Gradient using JAX wrt y: 64.0
Analytical Gradient wrt x: 0.0
Analytical Gradient wrt y: 64.0
Gradient using SymPy wrt x: 0

Gradient using SymPy wrt y: 64.0000000000000

Point (x=3, y=5):
Gradient using JAX wrt x: -93.74906158447266
Gradient using JAX wrt y: 19.583999633789062
Analytical Gradient wrt x: -93.74906158447266
Analytical Gradient wrt y: 19.583999633789062
Gradient using SymPy wrt x: -93.7490620750557

Gradient using SymPy wrt y: 19.5840006044900

Point (x=7, y=4):
Gradient using JAX wrt x: 104.24974060058594
Gradient using JAX wrt y: 80.53535461425781
Analytical Gradient wrt x: 104.24974060058594
Analytical Gradient wrt y: 80.53535461425781
Gradient using SymPy wrt x: 104.249744277972

Gradient using SymPy wrt y: 80.5353567385019

Point (x=5, y=1):
Gradient using JAX wrt x: 10.283661842346191
Gradient using JAX wrt y: 22.123226165771484
Analytical Gradient wrt x: 10.283661842346191
Analytical Gradient wrt y: 22.12322616577148