In [1]:
import jax

# Record which JAX release this notebook was executed against.
print("{}".format(jax.__version__))

0.4.30


In [2]:
# Build a 1-D JAX array of integers; the element-wise sine below promotes it to float.
x = jax.numpy.array([1, 2, 3])

# Element-wise sine of x, then shift every entry up by 2.
y = 2 + jax.numpy.sin(x)
print(y)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[2.841471  2.9092975 2.14112  ]


In [3]:
# A 2x2 integer matrix and its element-wise square root (result is floating point).
M = jax.numpy.array([[1, 2], [3, 4]])
M_sqrt = jax.numpy.sqrt(M)
print(M, M_sqrt, sep="\n")

[[1 2]
 [3 4]]
[[1.        1.4142135]
 [1.7320508 2.       ]]


In [4]:
# Code challenge, location: 283
def array_manipulation_challenge(A):
  """Square ``A`` element-wise, take the running total, and average it.

  Args:
    A: a JAX (or NumPy) array of numbers.

  Returns:
    Tuple ``(squares, running_sum, running_sum_mean)``:
      - ``squares``: element-wise square of ``A``.
      - ``running_sum``: cumulative sum of ``squares`` (taken over the
        flattened array, since no axis is given).
      - ``running_sum_mean``: mean of ``running_sum``.
  """
  squares = jax.numpy.square(A)
  running_sum = squares.cumsum()
  return squares, running_sum, running_sum.mean()

A = jax.numpy.array([1, 2, 3, 4, 5])
A_sqr, A_sqr_cumulative_sum, A_sqr_cumulative_sum_mean = array_manipulation_challenge(A)

# Show the input followed by each intermediate result, one per line.
for name, value in (
    ("A", A),
    ("A_sqr", A_sqr),
    ("A_sqr_cumulative_sum", A_sqr_cumulative_sum),
    ("A_sqr_cumulative_sum_mean", A_sqr_cumulative_sum_mean),
):
  print(f"{name}: {value}")

A: [1 2 3 4 5]
A_sqr: [ 1  4  9 16 25]
A_sqr_cumulative_sum: [ 1  5 14 30 55]
A_sqr_cumulative_sum_mean: 21.0


What is a gradient? See this introductory video:  
https://www.youtube.com/watch?v=6zgBUZuC-p8&list=PLg5nrpKdkk2DpW_a-kuHU_FsVPPaU447J

In [5]:
# Automatic Differentiation
# Location 329
# First code example
import jax

"""
Jax's 'grad' function calculates the gradient of a function.
The function takes two arguments:
- the target function to calculate the gradient of
- the index of the argument to calculate the gradient with respect to

The result of this function ('jax.grad') is a new function ('simple_function_grad') that calculates the gradient of the target function with respect to the specified argument.
The result of this new function is the derivative of the target function at the specified point.
"""

# define the target function to verify its gradient
def simple_function(x):
  return 2*x + 9
  # return jax.numpy.sin(x)

# calculate the gradient of the target function
simple_function_grad = jax.grad(simple_function)

# Evaluate the gradient at a specific point - the derivative of the function at that point
result = simple_function_grad(2.0)  # 3.0 is the point at which the gradient is evaluated
print(f"Gradient at x=2.0: {result}") # The derivative of the function at the specified point.

Gradient at x=2.0: 2.0


In [6]:
"""
Matrix Multiplication using XLA
"""
import jax.numpy as jnp
import jax

def mathmul(A, B):
  """Return the matrix product ``A @ B``, dispatched eagerly (not jit-compiled)."""
  product = A @ B
  return product

@jax.jit
def mathmul_jit(A, B):
  """Matrix product ``A @ B`` compiled by XLA through ``jax.jit``.

  The first call for a given input shape/dtype traces and compiles the
  function; subsequent calls with matching inputs reuse the compiled code.
  """
  product = A @ B
  return product

A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# Compute the same product two ways: eager dispatch vs. the jit-compiled version.
C = mathmul(A, B)
D = mathmul_jit(A, B)

print(f"Unoptimized matrix multiplication: {C}")
print(f"XLA optimized matrix multiplication: {D}")


Unoptimized matrix multiplication: [[19 22]
 [43 50]]
XLA optimized matrix multiplication: [[19 22]
 [43 50]]


In [23]:
"""
Timing and comparing functions
"""

from time import perf_counter

def compare(func, n, A, M):
  """Time ``n`` chained applications of ``func`` and return the elapsed seconds.

  Each iteration feeds the previous result back in (``M = func(A, M)``), so
  the calls form a dependency chain.

  Args:
    func: callable taking ``(A, M)`` and returning the next ``M``.
    n: number of iterations to run (0 runs nothing).
    A: left operand passed unchanged to every call.
    M: initial right operand.

  Returns:
    Wall-clock seconds (a ``perf_counter`` difference) spent running the chain.
    For a ``jax.jit``-compiled ``func`` the first call also compiles, so
    compilation time is included in the measurement.
  """

  def run_chain(func, n, A, M):
    # Apply func n times, threading each result into the next call.
    for _ in range(n):
      M = func(A, M)
    return M

  t1_start = perf_counter()
  # Use the caller's M (the original mistakenly read a global `B` here).
  final_M = run_chain(func, n, A, M)
  # JAX dispatches work asynchronously: wait for the final result so the timer
  # measures the computation itself, not just the enqueueing of it.
  if hasattr(final_M, "block_until_ready"):
    final_M.block_until_ready()
  t1_stop = perf_counter()

  # Report outside the timed region so console I/O does not skew the timing.
  print(f"used_func: {func}")

  return t1_stop - t1_start

In [25]:
# Operands for the timing benchmark.
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
n = 1_000_000  # chained multiplications per timing run

# Time the eager version first, then the jit-compiled one (whose first call
# also pays the XLA compilation cost).
unoptimized_function_time = compare(mathmul, n, A, B)
optimized_function_time = compare(mathmul_jit, n, A, B)

for timing_label, timing_seconds in (
    ("unoptimized_function_time", unoptimized_function_time),
    ("optimized_function_time", optimized_function_time),
):
  print(f"{timing_label}: {timing_seconds} ")

# NOTE(review): wall-clock timings are noisy, so this check can fail
# spuriously on a busy machine.
assert optimized_function_time < unoptimized_function_time
time_saved = unoptimized_function_time - optimized_function_time
print(f"Optimization value: {time_saved}")

used_func: <function mathmul at 0x7faa746222a0>
used_func: <PjitFunction of <function mathmul_jit at 0x7faa74622520>>
unoptimized_function_time: 3.3267515279999316 
optimized_function_time: 2.424925336999877 
Optimization value: 0.9018261910000547
