In [16]:
## Section 1: Understanding JAX Basics

# 1.1 Importing JAX:

In [1]:
import jax
import jax.numpy as jnp

In [None]:
# JAX is designed to be compatible with NumPy, a widely used library for numerical operations.

In [2]:
import numpy as np

# Create a NumPy array
numpy_array = np.array([1, 2, 3])

# Convert it to a JAX array
jax_array = jnp.array(numpy_array)

# Perform operations on the JAX array
result = jax_array * 2

print(result)

[2 4 6]


In [None]:
# As can be seen above, we can use JAX in a similar way to NumPy, making it easy with NumPy to get started with JAX.

In [4]:
# Lets check whether JAX is using the CPU or the GPU:
devices = jax.devices()
for device in devices:
    print(device)

TFRT_CPU_0


In [5]:
from jax.lib import xla_bridge
print(xla_bridge.get_backend().platform)

cpu


In [None]:
# One of the key features of JAX is its efficient automatic differentiation capabilities. This allows us to compute gradients effortlessly.

In [None]:
from jax import grad

# Define a simple function
def f(x):
    return x**2 + 3*x + 1

# Compute the derivative of f with respect to x
df_dx = grad(f)

# Evaluate the derivative at x = 2
result = df_dx(2)

print(result)

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int32. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

In [None]:
# JAX also allows us to compute higher-order gradients with ease, which is crucial for machine learning algorithms.

In [None]:
# JAX enables efficient vectorized computations, which can greatly accelerate numerical operations.

In [None]:
# Define a function that operates element-wise on an array
def elementwise_func(x):
    return x**2 + 3*x + 1

# Apply the function to an array using JAX
input_array = jnp.array([1, 2, 3, 4])
result = elementwise_func(input_array)

print(result)


[ 5 11 19 29]
