## Section 0: JAX Readme - https://github.com/google/jax#installation

Installation for NVIDIA GPUs on Linux (x86_64).

Please see the README/documentation for alternative installation options (for example, CPU-only or TPU).

Installing the jax module and jaxlib with CUDA 12 support:

In [2]:
!pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html


## Section 1: Understanding JAX Basics

1.1 Importing JAX:

In [None]:
import jax
import jax.numpy as jnp

1.2 Checking whether JAX is using the GPU:

In [None]:
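# jax.default_backend() returns the platform of the default backend, e.g. "gpu" or "cpu"
# (it replaces the deprecated jax.lib.xla_bridge.get_backend().platform)
print(jax.default_backend())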

In [None]:
devices = jax.devices()
for device in devices:
    print(device)

1.3 Number of CUDA devices:

In [None]:
# Number of devices on the default backend (the CUDA GPUs when the GPU backend is active)
print(jax.device_count())
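As a further check, the minimal sketch below shows where arrays actually live. It assumes a jaxlib recent enough to provide the `jax.Array.devices()` method; on a CPU-only machine the same code runs and simply reports CPU devices. `jax.device_put` is used here only to illustrate explicit placement on the first visible device.

```python
import jax
import jax.numpy as jnp

# Platform name of the default backend: "gpu" with the CUDA jaxlib, otherwise "cpu" or "tpu".
backend = jax.default_backend()
print("default backend:", backend)

# Number of devices visible to the default backend.
n_devices = jax.device_count()
print("device count:", n_devices)

# New arrays are committed to the default device (the first GPU if one is available).
x = jnp.arange(4.0)
print("x lives on:", x.devices())

# jax.device_put explicitly places data on a chosen device.
first_device = jax.devices()[0]
y = jax.device_put(x, first_device)
print("y lives on:", y.devices())
```

If `backend` prints `cpu` on a GPU machine, the CUDA-enabled jaxlib is most likely not the one being imported.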

## Section 2: Computation using JAX 

2.1 JAX's jax.numpy module mirrors the API of NumPy, the widely used library for numerical computing:

In [None]:
import numpy as np

# Create a NumPy array
numpy_array = np.array([1, 2, 3])

# Convert it to a JAX array
jax_array = jnp.array(numpy_array)

# Perform operations on the JAX array
result = jax_array * 2

print(result)

2.2 As the example shows, JAX arrays can be used much like NumPy arrays, so anyone already familiar with NumPy can get started with JAX quickly.
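The similarity is close but not total. The sketch below shows a few common NumPy-style operations on a JAX array, converting back to NumPy, and one important difference: JAX arrays are immutable, so in-place assignment fails and the functional `.at[...].set(...)` update is used instead. The variable names are illustrative only.

```python
import numpy as np
import jax.numpy as jnp

jax_array = jnp.array([1, 2, 3])

# Many NumPy functions have a jax.numpy counterpart with the same name.
total = jnp.sum(jax_array)
mean = jnp.mean(jax_array)

# np.asarray converts the JAX array into an ordinary NumPy array (moving it to host memory if needed).
back_to_numpy = np.asarray(jax_array)

# JAX arrays are immutable, so NumPy-style in-place assignment raises a TypeError.
in_place_failed = False
try:
    jax_array[0] = 10
except TypeError:
    in_place_failed = True

# The functional alternative returns a new, updated array and leaves the original unchanged.
updated = jax_array.at[0].set(10)

print(total, mean)
print(back_to_numpy, type(back_to_numpy))
print(jax_array, updated)
```

Immutability is what lets JAX safely trace and transform functions (grad, jit, vmap), which is why updates return new arrays rather than modifying existing ones.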

2.3 One of JAX's key features is automatic differentiation: jax.grad takes a Python function and returns a new function that computes its gradient.

In [None]:
from jax import grad

# Define a simple function
def f(x):
    return x**2 + 3*x + 1

# Compute the derivative of f with respect to x
df_dx = grad(f)

# Evaluate the derivative at x = 2.0
# (grad requires a floating-point input; passing the integer 2 raises a TypeError)
result = df_dx(2.0)

print(result)
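To confirm that `grad` returns the exact derivative rather than a numerical approximation, the sketch below compares it with the hand-derived f'(x) = 2x + 3 at a few arbitrarily chosen points, and also uses `jax.value_and_grad`, which returns f(x) and f'(x) from a single call. The helper `analytic_df` is introduced here purely for the comparison.

```python
import jax

def f(x):
    return x**2 + 3*x + 1

# jax.value_and_grad evaluates f and its derivative in a single call.
value, slope = jax.value_and_grad(f)(2.0)
print(value, slope)  # expected: f(2) = 11 and f'(2) = 2*2 + 3 = 7

# Hand-derived derivative, used only to check the automatic one.
def analytic_df(x):
    return 2 * x + 3

df_dx = jax.grad(f)
points = [-1.5, 0.0, 2.0, 3.25]
max_error = max(abs(float(df_dx(p)) - analytic_df(p)) for p in points)
print("max |grad - analytic|:", max_error)
```

value_and_grad is convenient in training loops, where both the loss and its gradient are needed at every step.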

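2.4 Because jax.grad returns an ordinary function, it can be applied repeatedly to compute higher-order derivatives, which are used in many machine learning and optimization algorithms (for example, second-order optimization methods).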
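A minimal sketch of higher-order differentiation, reusing the quadratic f from 2.3: nesting `jax.grad` gives f''(x) = 2 and f'''(x) = 0. For a function of a vector, `jax.hessian` builds the full matrix of second derivatives; the function `g` below is a hypothetical example chosen so its Hessian is easy to check by hand.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x**2 + 3*x + 1

# Nesting grad gives higher-order derivatives of a scalar function.
d2f = jax.grad(jax.grad(f))   # f''(x) = 2
d3f = jax.grad(d2f)           # f'''(x) = 0
second = d2f(2.0)
third = d3f(2.0)
print(second, third)

# Hypothetical vector function: g(v) = v0^2 + v1^2 + v0*v1
def g(v):
    return jnp.sum(v**2) + v[0] * v[1]

# jax.hessian returns the matrix of second partial derivatives, here [[2, 1], [1, 2]].
H = jax.hessian(g)(jnp.array([1.0, 2.0]))
print(H)
```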

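2.5 Like NumPy, jax.numpy operations act on whole arrays at once, so a function written for scalars can be applied element-wise to an array without a Python loop. On a GPU these array operations run in parallel, which can substantially speed up numerical work.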

In [None]:
# Define a function that operates element-wise on an array
def elementwise_func(x):
    return x**2 + 3*x + 1

# Apply the function to an array using JAX
input_array = jnp.array([1, 2, 3, 4])
result = elementwise_func(input_array)

print(result)
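Two JAX transformations build on this. `jax.jit` compiles a function with XLA so repeated calls run as fused, optimized kernels, and `jax.vmap` maps a function over a batch dimension automatically. The sketch below assumes floating-point inputs, because `jax.grad` (used inside `vmap` to get the derivative at every element) requires them; the names `fast_func` and `batched_derivative` are illustrative.

```python
import jax
import jax.numpy as jnp

def elementwise_func(x):
    return x**2 + 3*x + 1

# jax.jit traces the function once and compiles it with XLA for later calls.
fast_func = jax.jit(elementwise_func)

input_array = jnp.array([1.0, 2.0, 3.0, 4.0])
compiled_result = fast_func(input_array)

# grad only accepts scalar-output functions; vmap applies it to every element of the batch.
batched_derivative = jax.vmap(jax.grad(elementwise_func))(input_array)

print(compiled_result)     # x**2 + 3x + 1 at each element
print(batched_derivative)  # derivative 2x + 3 at each element
```

The first call to a jitted function includes compilation time, so speedups are usually measured on subsequent calls.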
