In [None]:
import jax
print(jax.__version__)

0.4.30


What is a gradient:   
https://www.youtube.com/watch?v=6zgBUZuC-p8&list=PLg5nrpKdkk2DpW_a-kuHU_FsVPPaU447J

In [None]:
# Automatic Differentiation
# Location 329
# First code example
import jax

"""
Jax's 'grad' function calculates the gradient of a function.
The function takes two arguments:
- the target function to calculate the gradient of
- the index of the argument to calculate the gradient with respect to

The result of this function ('jax.grad') is a new function ('simple_function_grad') that calculates the gradient of the target function with respect to the specified argument.
The result of this new function is the derivative of the target function at the specified point.
"""

# define the target function to verify its gradient
def simple_function(x):
  """Linear target function f(x) = 2x + 9 used to demonstrate `jax.grad`."""
  # Alternative target to experiment with: jax.numpy.sin(x)
  doubled = 2*x
  return doubled + 9

# Build a new function that returns d(simple_function)/dx at a given x
simple_function_grad = jax.grad(simple_function)

# Evaluate the gradient at a specific point - the derivative of the function at that point
result = simple_function_grad(2.0)  # 2.0 is the point at which the gradient is evaluated
print(f"Gradient at x=2.0: {result}") # simple_function is 2*x + 9, so its derivative is 2 at every x.

Gradient at x=2.0: 2.0


In [None]:
"""
Matrix Multiplication using XLA
"""
import jax.numpy as jnp
import jax

def mathmul(A, B):
  """Return the matrix product of A and B using the `@` operator (no jax.jit)."""
  product = A @ B
  return product

@jax.jit
def mathmul_jit(A, B):
  """Same product as `mathmul`, with the whole function compiled by jax.jit."""
  product = A @ B
  return product

# Two small 2x2 integer matrices used as inputs for both variants.
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])

# Unoptimized matrix multiplication (plain function, not wrapped in jax.jit)
C = mathmul(A, B)
print(f"Unoptimized matrix multiplication: {C}")

# XLA optimized matrix multiplication (same body, decorated with @jax.jit)
# The printed product is expected to match the one above.
D = mathmul_jit(A, B)
print(f"XLA optimized matrix multiplication: {D}")


Unoptimized matrix multiplication: [[19 22]
 [43 50]]
XLA optimized matrix multiplication: [[19 22]
 [43 50]]


In [None]:
"""
Timing and comparing functions
"""

from time import perf_counter

def compare(func, n, A, M):
  """Time `n` chained applications of `func` and return the elapsed seconds.

  Starting from `M`, the update `M = func(A, M)` is repeated `n` times, so
  every call consumes the result of the previous one.

  Args:
    func: Callable taking `(A, M)` and returning the next `M`.
    n: Number of times `func` is applied; with 0 only an empty loop is timed.
    A: First argument, passed unchanged to every call.
    M: Second argument for the first call.

  Returns:
    Elapsed wall-clock time in seconds, measured with `perf_counter`.
  """

  def time_func(func, n, A, M):
    # Feed each result back in as the next M so the calls form a chain.
    for _ in range(n):
      M = func(A, M)
    return M

  t1_start = perf_counter()
  # Use the caller-supplied M, not a notebook-level global.
  final_M = time_func(func, n, A, M)
  # Results exposing block_until_ready() (e.g. JAX arrays, which are dispatched
  # asynchronously) may still be computing; wait before stopping the clock.
  wait_until_ready = getattr(final_M, "block_until_ready", None)
  if callable(wait_until_ready):
    wait_until_ready()
  # Stop the stopwatch / counter
  t1_stop = perf_counter()

  # Report which function was measured, outside of the timed region.
  print(f"used_func: {func}")

  return (t1_stop-t1_start)

In [None]:
# Inputs for the benchmark: A is reused on every call, B seeds the chained product.
A = jnp.array([[1, 2], [3, 4]])
B = jnp.array([[5, 6], [7, 8]])
# Number of chained calls per measurement.
n = 1000000

# Time the plain function and the jax.jit-compiled one on identical inputs.
unoptimized_function_time = compare(mathmul, n, A, B)
optimized_function_time = compare(mathmul_jit, n, A, B)

print(f"unoptimized_function_time: {unoptimized_function_time} ")
print(f"optimized_function_time: {optimized_function_time} ")

# NOTE(review): this compares two wall-clock measurements, so the outcome can
# vary between runs and machines; a failure here does not imply wrong results.
assert(optimized_function_time < unoptimized_function_time)
print(f"Optimization value: {unoptimized_function_time - optimized_function_time}")

used_func: <function mathmul at 0x00000203866D5580>
used_func: <PjitFunction of <function mathmul_jit at 0x00000203866D5620>>
unoptimized_function_time: 4.521084500000143 
optimized_function_time: 3.287428900002851 
Optimization value: 1.2336555999972916


## eye (one-hot vector)

In [None]:
import numpy as np
"""
one-hot vector
Each number in the target vector is converted into 1 that is placed in its value placement, 
while all the rest of the row are zeros.
np.eye(max-value+1)[vector]

"""
# Labels to encode; the largest label determines the number of columns.
v = np.array([1, 4, 2, 1, 0, 1, 3, 2])
# Row k of the identity matrix is the one-hot vector for label k, so indexing
# the identity matrix with v selects one such row per label.
np.eye(np.max(v)+1)[v]

array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 1., 0., 0.]])

#### one-hot without eye

https://wandb.ai/mostafaibrahim17/ml-articles/reports/One-Hot-Encoding-Creating-a-NumPy-Array-Using-Weights-Biases--Vmlldzo2MzQzNTQ5#:~:text=Array%20in%20NumPy-,To%20generate%20one-hot%20encodings%20for%20an%20array%20in%20NumPy,to%20its%20category%20to%201.




In [None]:
# Same labels as in the np.eye example.
v = np.array([1, 4, 2, 1, 0, 1, 3, 2])
# One row per label and one column per possible value (0 .. v.max()), all zeros.
m = np.zeros((v.size, v.max() + 1))
m

array([[0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0.]])

In [None]:
# Row indices 0 .. v.size-1: one index per label in v.
np.arange(v.size)

array([0, 1, 2, 3, 4, 5, 6, 7])

In [None]:
# Pair each row index with its label: for every i, set m[i, v[i]] to 1.
# This modifies m in place.
m[np.arange(v.size), v] = 1
m

array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 1., 0., 0.]])

m[np.arange(v.size), v] = 1      
m[ [0, 1, 2, 3, 4, 5, 6, 7], [1, 4, 2, 1, 0, 1, 3, 2] ] = 1      
m[0,1] = 1      
m[1,4] = 1      
m[2,2] = 1

## CH3 Coding Challenge  
Matrix Power and XLA optimization    
(location 419)

In [1]:
import jax
import jax.numpy as jnp
from time import perf_counter

# Challenge inputs: a 2x2 matrix and the exponent it is raised to.
matrix = jnp.array([[2,3], [1,4]])
power = 10

# Calculate matrix power
def matrix_power(_matrix, _power):
  """Raise `_matrix` to the `_power`-th power by repeated matrix products.

  Starts from an identity matrix sized by the first dimension of `_matrix`
  and right-multiplies it by `_matrix` once per step. `_power` is consumed
  by `range`, so it must be a concrete integer; an empty range returns the
  identity unchanged.
  """
  side = _matrix.shape[0]
  accumulated = jnp.eye(side)
  for _step in range(_power):
    accumulated = accumulated @ _matrix
  return accumulated

# Unoptimized matrix power (matrix_power called directly, without jax.jit)
t1_start = perf_counter()
# JAX dispatches work asynchronously; block_until_ready() waits for the result
# so the timer covers the computation and not only the time to enqueue it.
power_result_unoptimized = matrix_power(matrix, power).block_until_ready()
t1_stop = perf_counter()

print(f"power_result_unoptimized: {power_result_unoptimized}")

"""
XLA optimized matrix power using @jax.jit

Due to the following:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
https://jax.readthedocs.io/en/latest/faq.html#faq-different-kinds-of-jax-values
https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError
https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array

If I do the following:
@jax.jit
def matrix_power_jit(matrix, power):
  result = jnp.eye(matrix.shape[0])
  for _ in range(power):
    result = result @ matrix
  return result

power_result_optimized = matrix_power_jit(matrix, 10)

I get the error:
"This concrete value was not available in Python because it depends on the value of the argument power."

So instead of using @jax.jit, I call the same function with the following:
matrix_power_jit = jax.jit(matrix_power, static_argnums=(1,))
power_result_optimized = matrix_power_jit(matrix, power)

This is from:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
"""

# Optimized matrix power

# Define the jit function.
# `_power` drives a Python `range(...)` loop inside matrix_power, so it must be
# a concrete value while JAX traces the function. Declaring it static makes it
# a trace- and compile-time constant (the name is the target function's
# argument name). The positional equivalent is static_argnums=(1,).
# https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html
matrix_power_jit = jax.jit(matrix_power, static_argnames=('_power',))

# Call the jit function. This first call also pays for tracing and compilation.
# block_until_ready() waits for the asynchronously dispatched result so the
# timer covers the computation and not only the time to enqueue it.
t2_start = perf_counter()
power_result_optimized = matrix_power_jit(matrix, power).block_until_ready()
t2_stop = perf_counter()

print(f"power_result_optimized: {power_result_optimized}")

# Name the durations once instead of repeating the subtractions below.
unoptimized_time = t1_stop - t1_start
first_jit_time = t2_stop - t2_start

print(f"Unoptimized matrix power time: {unoptimized_time}")
print(f"Optimized matrix power time: {first_jit_time}")


# On a large power value the compilation of the jit function takes longer than
# the execution of the unoptimized function.
# NOTE: these checks compare wall-clock measurements and can vary between runs
# and machines; the messages make a failure self-explanatory.
if power > 100:
  assert first_jit_time > unoptimized_time, (
      "expected the first jit call (including compilation) to be slower "
      "than the unoptimized call for a large power")
else:
  assert unoptimized_time > first_jit_time, (
      "expected the first jit call to be faster than the unoptimized call")
print(f"Optimization value: {unoptimized_time - first_jit_time}")


# Call the jit function again: same static `_power`, so the function compiled
# by the first call is reused and only the execution is timed.
t3_start = perf_counter()
power_result_optimized_2 = matrix_power_jit(matrix, power).block_until_ready()
t3_stop = perf_counter()

cached_jit_time = t3_stop - t3_start

print(f"power_result_optimized_2: {power_result_optimized_2}")
print(f"Optimized matrix power time: {cached_jit_time}")

assert unoptimized_time > cached_jit_time, (
    "expected the already-compiled jit call to be faster than the unoptimized call")

print(f"Optimization second time: {unoptimized_time - cached_jit_time}")

power_result_unoptimized: [[2441407. 7324218.]
 [2441406. 7324219.]]
power_result_optimized: [[2441407. 7324218.]
 [2441406. 7324219.]]
Unoptimized matrix power time: 0.08302760000060516
Optimized matrix power time: 0.03413060000002588
Optimization value: 0.04889700000057928
power_result_optimized_2: [[2441407. 7324218.]
 [2441406. 7324219.]]
Optimized matrix power time: 5.2800000048591755e-05
Optimization second time: 0.08297480000055657
