## Day 001: JAX Installation and Basic Array Operations

In [3]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad

In [4]:
# 1. Verify JAX installation and backend
# default_backend() reports the platform name of the default XLA backend (e.g. "cpu").
print(
    f"JAX version: {jax.__version__}",
    f"JAX backend: {jax.default_backend()}",
    sep="\n",
)

JAX version: 0.6.1
JAX backend: cpu


In [5]:
# 2. Create your first JAX array
# jnp.array mirrors np.array: a list of Python floats becomes a float32 JAX array
# (JAX uses 32-bit types by default unless 64-bit mode is enabled).
my_jax_array = jnp.array([1.0, 2.0, 3.0, 4.0])
print(
    "\nMy first JAX array:",
    my_jax_array,
    f"Type: {type(my_jax_array)}",
    f"Shape: {my_jax_array.shape}",
    f"Dtype: {my_jax_array.dtype}",
    sep="\n",
)


My first JAX array:
[1. 2. 3. 4.]
Type: <class 'jaxlib._jax.ArrayImpl'>
Shape: (4,)
Dtype: float32


In [6]:
# 3. Perform basic arithmetic operations
# By default JAX executes operations eagerly, one at a time (jax.jit compiles whole
# functions instead). Each operator broadcasts the scalar across the array, as in NumPy.
print("\nBasic arithmetic operations:")

sum_result = my_jax_array + 10.0
product_result = my_jax_array * 2.5
division_result = my_jax_array / 2.0

print(
    f"Array + 10: {sum_result}",
    f"Array * 2.5: {product_result}",
    f"Array / 2.0: {division_result}",
    sep="\n",
)


Basic arithmetic operations:
Array + 10: [11. 12. 13. 14.]
Array * 2.5: [ 2.5  5.   7.5 10. ]
Array / 2.0: [0.5 1.  1.5 2. ]


In [7]:
# Dot product via jnp.dot: for two 1-D arrays this is the inner product
# (the same value `my_jax_array @ another_array` would give).
another_array = jnp.array([5.0, 6.0, 7.0, 8.0])
dot_product = jnp.dot(my_jax_array, another_array)

# Element-wise product: same shape as the inputs, no reduction
elementwise_mult = my_jax_array * another_array

# Sum of squares: square each element, then reduce over all of them
sum_of_squares = (my_jax_array**2).sum()

print(f"Dot product: {dot_product}")
print(f"Element-wise multiplication: {elementwise_mult}")
print(f"Sum of squares: {sum_of_squares}")

Dot product: 70.0
Element-wise multiplication: [ 5. 12. 21. 32.]
Sum of squares: 30.0


In [8]:
# Types of Matrices in JAX and NumPy
# jax.numpy exposes the same constructors as NumPy for common structured matrices.

# Identity matrix: ones on the main diagonal, zeros elsewhere
identity_matrix = jnp.identity(3)
print("\nIdentity matrix:")
print(identity_matrix)

# Diagonal matrix from the integers 1..3 (an integer array, so no decimal points)
diagonal_matrix = jnp.diag(jnp.arange(1, 4))
print("\nDiagonal matrix:")
print(diagonal_matrix)

# Random matrix: 3x3 standard-normal draws. JAX randomness is explicit — the same
# PRNG key (seed 0) always reproduces the same values.
random_matrix = jax.random.normal(jax.random.PRNGKey(0), shape=(3, 3))
print("\nRandom matrix (3x3):")
print(random_matrix)



Identity matrix:
[[1. 0. 0.]
 [0. 1. 0.]
 [0. 0. 1.]]

Diagonal matrix:
[[1 0 0]
 [0 2 0]
 [0 0 3]]

Random matrix (3x3):
[[ 1.6226422   2.0252647  -0.43359444]
 [-0.07861735  0.1760909  -0.97208923]
 [-0.49529874  0.4943786   0.6643493 ]]


In [9]:
# Row Matrix: a single row, shape (1, 3)
row_np, row_jax = np.array([[1, 2, 3]]), jnp.array([[1, 2, 3]])

In [10]:
# Bare last expression: Jupyter renders the NumPy array's repr as the cell output
row_np

array([[1, 2, 3]])

In [11]:
# JAX arrays render with an `Array(...)` repr that also shows the dtype
row_jax

Array([[1, 2, 3]], dtype=int32)

In [12]:
# Column Matrix: a single column, shape (3, 1)
col_np, col_jax = (
    np.array([[1], [2], [3]]),
    jnp.array([[1], [2], [3]]),
)

col_np

array([[1],
       [2],
       [3]])

In [13]:
# Same column matrix as a JAX Array; its repr states the dtype explicitly
col_jax

Array([[1],
       [2],
       [3]], dtype=int32)

In [14]:
# Square Matrix: as many rows as columns (here 2x2)
square_np, square_jax = np.array([[1, 2], [3, 4]]), jnp.array([[1, 2], [3, 4]])

square_np

array([[1, 2],
       [3, 4]])

In [15]:
# Display the JAX version of the 2x2 square matrix
square_jax

Array([[1, 2],
       [3, 4]], dtype=int32)

In [16]:
# Identity Matrix: np.identity / jnp.identity are square-only shorthands for eye(n)
identity_np, identity_jax = np.identity(3), jnp.identity(3)

identity_np

array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]])

In [17]:
# The JAX repr shows dtype=float32: JAX uses 32-bit floats by default
identity_jax

Array([[1., 0., 0.],
       [0., 1., 0.],
       [0., 0., 1.]], dtype=float32)

In [18]:
# Zero Matrix: every entry is 0.0 (floating-point dtype by default)
zero_np, zero_jax = np.zeros((3, 3)), jnp.zeros((3, 3))

zero_np

array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]])

In [19]:
# Display the JAX zero matrix (its repr includes dtype=float32)
zero_jax

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [20]:
# Diagonal Matrix: the values 1..3 on the main diagonal, zeros everywhere else
diagonal_np = np.diag(np.arange(1, 4))
diagonal_jax = jnp.diag(jnp.arange(1, 4))

diagonal_np

array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]])

In [21]:
# Display the JAX diagonal matrix (integer input, so dtype=int32 in the repr)
diagonal_jax

Array([[1, 0, 0],
       [0, 2, 0],
       [0, 0, 3]], dtype=int32)

In [22]:
# Upper Triangular Matrix: triu keeps the main diagonal and everything above it,
# zeroing the entries below. The source matrix is 1..9 laid out row by row.
upper_triangular_np = np.triu(np.arange(1, 10).reshape(3, 3))
upper_triangular_jax = jnp.triu(jnp.arange(1, 10).reshape(3, 3))

upper_triangular_np

array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9]])

In [23]:
# Display the JAX upper-triangular result; it matches the NumPy one above
upper_triangular_jax

Array([[1, 2, 3],
       [0, 5, 6],
       [0, 0, 9]], dtype=int32)

In [24]:
# Lower Triangular Matrix: tril keeps the main diagonal and everything below it,
# zeroing the entries above. The source matrix is 1..9 laid out row by row.
lower_triangular_np = np.tril(np.arange(1, 10).reshape(3, 3))
lower_triangular_jax = jnp.tril(jnp.arange(1, 10).reshape(3, 3))

lower_triangular_np

array([[1, 0, 0],
       [4, 5, 0],
       [7, 8, 9]])

In [25]:
# Display the JAX lower-triangular result; it matches the NumPy one above
lower_triangular_jax

Array([[1, 0, 0],
       [4, 5, 0],
       [7, 8, 9]], dtype=int32)

In [26]:
# Symmetric Matrix: a matrix equal to its own transpose (A == A.T)
symmetric_np = np.array([[1, 2, 3], [2, 4, 5], [3, 5, 6]])
symmetric_jax = jnp.array([[1, 2, 3], [2, 4, 5], [3, 5, 6]])

# Check the defining property so a typo in the literal cannot slip through silently
assert np.array_equal(symmetric_np, symmetric_np.T), "symmetric_np is not symmetric"
assert bool(jnp.array_equal(symmetric_jax, symmetric_jax.T)), "symmetric_jax is not symmetric"

symmetric_np

array([[1, 2, 3],
       [2, 4, 5],
       [3, 5, 6]])

In [27]:
# Display the JAX symmetric matrix
symmetric_jax

Array([[1, 2, 3],
       [2, 4, 5],
       [3, 5, 6]], dtype=int32)

In [28]:
# Skew-Symmetric Matrix: its transpose equals its negation (A.T == -A),
# which forces every diagonal entry to be zero
skew_symmetric_np = np.array([[0, -2, 3], [2, 0, -1], [-3, 1, 0]])
skew_symmetric_jax = jnp.array([[0, -2, 3], [2, 0, -1], [-3, 1, 0]])

# Check the defining property so a sign typo in the literal cannot slip through silently
assert np.array_equal(skew_symmetric_np.T, -skew_symmetric_np), "skew_symmetric_np is not skew-symmetric"
assert bool(jnp.array_equal(skew_symmetric_jax.T, -skew_symmetric_jax)), "skew_symmetric_jax is not skew-symmetric"

skew_symmetric_np

array([[ 0, -2,  3],
       [ 2,  0, -1],
       [-3,  1,  0]])

In [29]:
# Display the JAX skew-symmetric matrix
skew_symmetric_jax

Array([[ 0, -2,  3],
       [ 2,  0, -1],
       [-3,  1,  0]], dtype=int32)

In [30]:
# Zero matrix with the dtype spelled out rather than relying on JAX's float32 default
a = jnp.zeros(shape=(3, 3), dtype=jnp.float32)
a

Array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)

In [31]:
# The half-open range [0, 6): integers 0 through 5
b = jnp.arange(0, 6)
print(b)

[0 1 2 3 4 5]


In [32]:
# Mean Squared Error loss. Later cells differentiate an MSE loss with jax.grad,
# which is JAX's automatic differentiation; this cell only evaluates it.
def mse_loss(predictions, targets):
    """Mean Squared Error loss.

    Args:
        predictions: array of predicted values.
        targets: array of target values, broadcast-compatible with ``predictions``.

    Returns:
        Scalar JAX array: the mean of the squared element-wise differences.
    """
    return jnp.mean((predictions - targets) ** 2)

# Example usage of the loss function
predictions = jnp.array([0.5, 0.2, 0.8])
targets = jnp.array([0.0, 0.0, 1.0])
loss = mse_loss(predictions, targets)
print("\nMean Squared Error Loss:")
print(loss)


Mean Squared Error Loss:
0.11


In [33]:
# JAX's automatic differentiation
# jax.grad needs a function that returns a scalar; this MSE reduces to a single value.
def mse_loss(preds, labels):
    """Return the mean of the squared differences between ``preds`` and ``labels``."""
    residual = preds - labels
    return (residual ** 2).mean()

print(
    "\nMean Squared Error Loss Function:",
    mse_loss(jnp.array([1.0, 2.0]), jnp.array([0.0, 1.5])),
)


Mean Squared Error Loss Function: 0.625


In [34]:
# Loss gradient
# jax.grad returns a new function computing the derivative with respect to
# positional argument 0 — here, the predictions.
mse_grad_fn = jax.grad(mse_loss, argnums=0)
print(
    "\nGradient of the MSE loss function with respect to predictions:",
    mse_grad_fn(jnp.array([1.0, 2.0]), jnp.array([0.0, 1.5])),
)


Gradient of the MSE loss function with respect to predictions: [1.  0.5]


In [35]:
# Function Transformation
# make_jaxpr traces mse_loss on these example inputs and returns its jaxpr, the
# intermediate representation that JAX transformations such as grad and jit work on.
exmp_pred, exmp_labels = jnp.array([1.0, 2.0]), jnp.array([0.0, 1.5])

# Show the JAX computation trace (JAXPR) as the cell's displayed result
jax.make_jaxpr(mse_loss)(exmp_pred, exmp_labels)


{ lambda ; a:f32[2] b:f32[2]. let
    c:f32[2] = sub a b
    d:f32[2] = integer_pow[y=2] c
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 2.0:f32[]
  in (f,) }

In [None]:
# Just-In-Time Compilation (JIT)

# Sample input batch
x_batch = jnp.array([1.0, 2.0, 3.0, 4.0])
y_batch = jnp.array([0.5, 2.5, 2.5, 5.0])

# JIT-compiled MSE Loss function
@jax.jit
def mse_loss(x, y):
    return jnp.mean((x - y) ** 2)

# Run JIT once to trigger compilation
mse_loss(x_batch, y_batch).block_until_ready()

# Benchmark using Jupyter magic (works only in notebooks)
%timeit mse_loss(x_batch, y_batch).block_until_ready()


5.79 μs ± 107 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [41]:
%timeit mse_jit(x_batch, y_batch).block_until_ready()  # Benchmarking the JIT-compiled function

5.9 μs ± 96.6 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [31]:
# Automatic Differentiation
# Rebinds `mse_loss` to a plain (undecorated) Python function and recomputes the
# example loss; the gradient cell that follows differentiates this definition.
def mse_loss(predictions, targets):
    """Mean Squared Error Loss Function: average of the squared element-wise errors."""
    squared_error = (predictions - targets) ** 2
    return jnp.mean(squared_error)

# Example usage of the loss function
predictions, targets = jnp.array([0.5, 0.2, 0.8]), jnp.array([0.0, 0.0, 1.0])
loss = mse_loss(predictions, targets)
print("\nMean Squared Error Loss:", loss, sep="\n")


Mean Squared Error Loss:
0.11


In [34]:
# The loss function can be differentiated using JAX's grad function
# (`grad` comes from the notebook's top import cell, not a mid-notebook import).
loss_grad = grad(mse_loss)  # derivative w.r.t. the first argument (predictions)
gradients = loss_grad(predictions, targets)
print("\nGradients of the loss function:")
print(gradients)


Gradients of the loss function:
[ 0.33333334  0.13333334 -0.13333333]
