In [25]:
import jax
import jax.numpy as jnp

From question 4 we calculated the gradients analytically which came out to be:

 grad_x = $2xy + y^{3}cos(x)$

 grad_y = $x^{2} + 3y^{2}sin(x)$


In [28]:
def f(x, y):
    return x**2 * y + y**3 * jnp.sin(x)

grad_f = jax.grad(f, argnums=(0, 1))

In [32]:
# List of points to evaluate
points = [(2.0, 3.0), (5.0, 6.0)]

# Evaluate gradients at each point using JAX
for x_, y_ in points:
    grad_x, grad_y = grad_f(x_, y_)

# Analytical gradient computation
analytical_grad_x = 2 * x_ * y_ + y_**3 * jnp.cos(x_)
analytical_grad_y = x_**2 + 3 * y_**2 * jnp.sin(x_)

print(f"Gradient with respect to x (JAX): {grad_x}")
print(f"Gradient with respect to y (JAX): {grad_y}")

print(f"Analytical gradient with respect to x: {analytical_grad_x}")
print(f"Analytical gradient with respect to y: {analytical_grad_y}")



Gradient with respect to x (JAX): 121.27103424072266
Gradient with respect to y (JAX): -78.56382751464844
Analytical gradient with respect to x: 121.27103424072266
Analytical gradient with respect to y: -78.56382751464844
