## Section 1: Computation using JAX 

1.1 JAX is designed to be compatible with NumPy, a widely used library for numerical operations:

In [None]:
import numpy as np
import jax
import jax.numpy as jnp

# Create a NumPy array
numpy_array = np.array([1, 2, 3])

# Convert it to a JAX array
jax_array = jnp.array(numpy_array)

# Perform operations on the JAX array
result = jax_array * 2

print(result)

1.2 As the example above shows, `jax.numpy` mirrors the NumPy API, so existing NumPy experience carries over directly to JAX.
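
To make the API overlap concrete, here is a small sketch in which one function body runs against either library. The helper name `poly_sum` and its `xp` parameter are hypothetical, introduced only for this illustration; they are not part of NumPy or JAX.

```python
import numpy as np
import jax.numpy as jnp

# Hypothetical helper: `xp` is whichever array module the caller passes in.
def poly_sum(xp, x):
    return xp.sum(x**2 + 3 * x + 1)

values = [1.0, 2.0, 3.0]

np_out = poly_sum(np, np.array(values))     # computed by NumPy
jnp_out = poly_sum(jnp, jnp.array(values))  # computed by JAX

# A JAX array can be converted back to a NumPy array when needed.
back_to_numpy = np.asarray(jnp_out)

print(np_out, jnp_out, type(back_to_numpy))
```

The two results agree because `sum`, `**`, `*`, and `+` behave the same way in both modules for this input.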

1.3 One of the key features of JAX is its efficient automatic differentiation capabilities. This allows us to compute gradients effortlessly:

In [None]:
from jax import grad

# Define a simple function
def f(x):
    return x**2 + 3*x + 1

# Compute the derivative of f with respect to x
df_dx = grad(f)

# Evaluate the derivative at x = 2.0
# (grad requires a floating-point input; an integer such as 2 raises a TypeError)
result = df_dx(2.0)

print(result)

1.4 Because `grad` returns an ordinary Python function, it can be applied repeatedly to obtain higher-order derivatives, which are used in, for example, second-order optimization methods in machine learning.
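
A minimal sketch of higher-order differentiation, using the same polynomial as the previous cell (redefined here so the block runs on its own). Since f(x) = x^2 + 3x + 1, the first derivative is 2x + 3 and the second derivative is the constant 2. The names `d2f_dx2`, `first`, and `second` are chosen for this illustration only.

```python
from jax import grad

def f(x):
    return x**2 + 3 * x + 1

# First derivative: 2x + 3
df_dx = grad(f)

# Second derivative: apply grad to the function that grad returned
d2f_dx2 = grad(df_dx)

x = 2.0  # floating-point input, as grad requires
first = df_dx(x)     # 2*2 + 3 = 7
second = d2f_dx2(x)  # constant 2

print(first, second)
```

Nesting `grad` further works the same way; for this quadratic, every derivative beyond the second is zero.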

1.5 Like NumPy, JAX applies arithmetic to whole arrays at once, so an element-wise function needs no explicit Python loop:

In [None]:
# Define a function that operates element-wise on an array
def elementwise_func(x):
    return x**2 + 3*x + 1

# Apply the function to an array using JAX
input_array = jnp.array([1, 2, 3, 4])
result = elementwise_func(input_array)

print(result)
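
The cell above relies on element-wise broadcasting. For functions that are written for a single example and do not broadcast on their own, JAX also provides `jax.vmap`, which maps a function over a batch axis. The sketch below is one possible illustration; the function `dot_pair` and the sample arrays are made up for this example.

```python
import jax
import jax.numpy as jnp

# Written for ONE pair of 1-D vectors; returns a scalar.
def dot_pair(v, w):
    return jnp.dot(v, w)

# vmap maps dot_pair over the leading axis of both arguments.
batched_dot = jax.vmap(dot_pair)

vs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
ws = jnp.array([[5.0, 6.0], [7.0, 8.0]])

# One dot product per row: 1*5 + 2*6 = 17 and 3*7 + 4*8 = 53
out = batched_dot(vs, ws)

print(out)
```

By default `vmap` maps over axis 0 of every argument; the `in_axes` parameter selects a different axis, or `None` for an argument that should not be batched.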
