In [None]:
# J - JIT compilation - Just In Time
# A - Automatic differentiation
# X - XLA (Accelerated linear algebra)

# JAX as NumPy

import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

print("Array a:", a)
print("Array b:", b)
print("Sum of a and b:", a + b)
print("Dot product of a and b:", jnp.dot(a, b))
print("Element-wise multiplication of a and b:", a * b)
print("Sine of a:", jnp.sin(a))
print("Exponential of b:", jnp.exp(b))
print("Mean of a:", jnp.mean(a))
print("Standard deviation of b:", jnp.std(b))
print("Reshaped a to (3,1):", a.reshape((3, 1)))
print("Transpose of b reshaped to (3,1):", b.reshape((3, 1)).T)
print("Stacked arrays a and b vertically:\n", jnp.vstack((a, b)))
print("Stacked arrays a and b horizontally:\n", jnp.hstack((a, b)))
print("Concatenated arrays a and b:", jnp.concatenate((a, b)))
print("Maximum value in a:", jnp.max(a))
print("Minimum value in b:", jnp.min(b))    
print("Sum of all elements in a:", jnp.sum(a))
print("Cumulative sum of a:", jnp.cumsum(a))
print("Unique elements in b:", jnp.unique(b))
print("Sorted a:", jnp.sort(a))
print("Where a > 2:", jnp.where(a > 2))
import numpy as np
print("Convert JAX array a to NumPy array:", np.array(a))   
print("Convert NumPy array back to JAX array:", jnp.array(np.array(a)))
# This script demonstrates basic usage of JAX as a NumPy replacement.
# It covers array creation, arithmetic operations, mathematical functions,
# statistical functions, reshaping, stacking, concatenation, and conversion
# between JAX arrays and NumPy arrays.
# JAX is designed for high-performance numerical computing and can
# leverage GPU/TPU acceleration, automatic differentiation, and JIT compilation.
# It provides a NumPy-like API for ease of use.
# JAX arrays are immutable, meaning that operations on them return new arrays
# rather than modifying the original arrays in place.
# Example: a[1] = 10.0  (This will raise an error)

Array a: [1. 2. 3.]
Array b: [4. 5. 6.]
Sum of a and b: [5. 7. 9.]
Dot product of a and b: 32.0
Element-wise multiplication of a and b: [ 4. 10. 18.]
Sine of a: [0.84147096 0.9092974  0.14112   ]
Exponential of b: [ 54.598152 148.41316  403.4288  ]
Mean of a: 2.0
Standard deviation of b: 0.8164966
Reshaped a to (3,1): [[1.]
 [2.]
 [3.]]
Transpose of b reshaped to (3,1): [[4. 5. 6.]]
Stacked arrays a and b vertically:
 [[1. 2. 3.]
 [4. 5. 6.]]
Stacked arrays a and b horizontally:
 [1. 2. 3. 4. 5. 6.]
Concatenated arrays a and b: [1. 2. 3. 4. 5. 6.]
Maximum value in a: 3.0
Minimum value in b: 4.0
Sum of all elements in a: 6.0
Cumulative sum of a: [1. 3. 6.]
Unique elements in b: [4. 5. 6.]
Sorted a: [1. 2. 3.]
Where a > 2: (Array([2], dtype=int32),)
Convert JAX array a to NumPy array: [1. 2. 3.]
Convert NumPy array back to JAX array: [1. 2. 3.]


In [8]:
# JIT Compilation

import time

@jax.jit
def collatz(x):
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

arr = jnp.arange(1, 1000001)

#you can warm up the JIT compiler by running the function once before timing    
_ = collatz(arr).block_until_ready()

start = time.time()
result = collatz(arr).block_until_ready()
end = time.time()
print("Time taken for JIT-compiled Collatz computation:", end - start, "seconds")

# JAX has asynchronous execution by default.
# To ensure that all computations are complete before measuring time, jax.block_until_ready() can be used, here: result.block_until_ready()

# jaxpr - JAX's Intermediate Representation. You can inspect the jaxpr of a JIT-compiled function to see the low-level operations that JAX generates.

print("JAXPR for the JIT-compiled Collatz function:")
print(jax.make_jaxpr(collatz)(arr))

Time taken for JIT-compiled Collatz computation: 0.0010349750518798828 seconds
JAXPR for the JIT-compiled Collatz function:
let _where = { [34;1mlambda [39;22m; a[35m:bool[1000000][39m b[35m:i32[1000000][39m c[35m:i32[1000000][39m. [34;1mlet
    [39;22md[35m:i32[1000000][39m = select_n a c b
  [34;1min [39;22m(d,) } in
{ [34;1mlambda [39;22m; e[35m:i32[1000000][39m. [34;1mlet
    [39;22mf[35m:i32[1000000][39m = pjit[
      name=collatz
      jaxpr={ [34;1mlambda [39;22m; e[35m:i32[1000000][39m. [34;1mlet
          [39;22mg[35m:i32[1000000][39m = pjit[
            name=remainder
            jaxpr={ [34;1mlambda [39;22m; e[35m:i32[1000000][39m h[35m:i32[][39m. [34;1mlet
                [39;22mi[35m:i32[][39m = convert_element_type[new_dtype=int32 weak_type=False] h
                j[35m:bool[][39m = eq i 0:i32[]
                k[35m:i32[][39m = pjit[
                  name=_where
                  jaxpr={ [34;1mlambda [39;22m; j[35m:bool[]