<a href="https://colab.research.google.com/github/anhquan-truong/PM520/blob/main/Lab_1_Introduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PM520, Lab 1: Introduction to Python
For a concise review of Python, see the excellent resource [Learn Python in Y Minutes](https://learnxinyminutes.com/docs/python/).

## 1. JAX and JAX.Numpy
[JAX](https://github.com/google/jax) is a Google-backed library for automatic differentiation (AutoDiff) of Python code. It also achieves fast runtimes through "Just-In-Time" (JIT) compilation with XLA (Accelerated Linear Algebra), a compiler that targets CPUs, GPUs, and TPUs. Hence JAX = JIT + AutoDiff + XLA.

Before we can use JAX, we need to install it with the Python package manager `pip`. To run `pip` (or any other shell command) from a Colab cell, prefix the command with `!`.


In [None]:
!pip install jax
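To make the "AutoDiff" part of JAX = JIT + AutoDiff + XLA concrete, here is a minimal sketch using `jax.grad`, which takes a scalar-valued function and returns a new function that computes its gradient. The function `f` is a hypothetical example: for f(x) = Σ xᵢ², the gradient is 2x. Note that `jax.grad` expects floating-point inputs.

```python
import jax
import jax.numpy as jnp

def f(x):
    # f(x) = sum_i x_i^2, so df/dx = 2 * x
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)            # a new function computing df/dx
fast_grad_f = jax.jit(grad_f)   # transformations compose: JIT-compile the gradient

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))       # [2. 4. 6.]
print(fast_grad_f(x))  # [2. 4. 6.]
```

Composing `jax.jit` with `jax.grad` is what makes the "JIT + AutoDiff" combination useful in practice.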



Let's practice importing JAX and using its NumPy implementation. NumPy is a Python library for n-dimensional arrays; `jax.numpy` provides the same interface backed by JAX, which lets our array code take advantage of all of JAX's features.
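The practice cell mixes integer and floating-point arrays, so it helps to check dtypes directly. This sketch assumes JAX's default configuration (64-bit mode disabled), under which integers default to `int32` and floats to `float32`; the checks below only test "integer" versus "floating" so they hold either way.

```python
import jax.numpy as jnp

x = jnp.arange(9)  # integer dtype: 0..8 are stored exactly
y = jnp.ones(9)    # floating dtype: printed as 1. with a trailing period
z = x + y          # integer + float promotes to float

print(x.dtype, y.dtype, z.dtype)

# Many decimal fractions have no exact binary floating-point representation;
# printing extra digits exposes the stored float32 approximation of 0.1
print(f"{float(jnp.float32(0.1)):.20f}")
```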

In [11]:
import jax
import jax.numpy as jnp  # submodule: NumPy-compatible API backed by JAX
import jax.random as rdm  # pseudorandom number submodule in JAX

# let's practice some numpy tricks
x = jnp.arange(9)  # `=` assigns a value to a name; arange gives the integers 0..8, which are stored exactly

y = jnp.ones(9)  # ones() returns floats, printed with a trailing period (1.); floats cannot represent every real number exactly (e.g. 0.1)

print(f"x = {x} | y = {y}")
# Prefixing a string with `f` makes it an f-string (formatted string): each
# expression inside the curly brackets {} is evaluated and its value is inserted

# x and y are both 1-D arrays with 9 elements
z = x + y
print(f"z = {z} | x + 1 = {x + 1}")
# adding an integer array to a float array produces a float array (z)
# `x + 1` uses broadcasting: the scalar 1 is stretched to match x's shape
# (a NumPy/JAX array feature, not core Python)


P = 4
i = jnp.eye(P) # identity matrix
a = 2 * jnp.ones(P)
print(f"i = {i} | a = {a}")

# is this mat/vec mult?
b = i * a  # no: `*` is elementwise; a (shape (4,)) is broadcast across each row of the 4x4 matrix i
print(f"b = {b}")

# We can manually construct an array with `array`
A = jnp.array([[5., 1], [1, 5]])  # built from a nested list: 2 rows of 2 elements
a = jnp.array([3, 4])  # built from the list [3, 4]
print(f"A = {A}")
b = A * a  # a is broadcast across each row, so column j is scaled by a[j]: row 1 is [5*3, 1*4] = [15, 4], row 2 is [1*3, 5*4] = [3, 20]
print(f"b = {b}")
b = a * A  # elementwise multiplication is commutative, so the result is the same
print(f"b = {b}")

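# nope! Both results above are matrices, but a mat/vec product yields a vector
b = A @ a  # the `@` operator performs mat/mat or mat/vec multiplication
print(f"b = {b}")
b = jnp.dot(A, a)  # for a 2-D and a 1-D array, jnp.dot is also the mat/vec product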
print(f"b = {b}")

x = [0 1 2 3 4 5 6 7 8] | y = [1. 1. 1. 1. 1. 1. 1. 1. 1.]
z = [1. 2. 3. 4. 5. 6. 7. 8. 9.] | x + 1 = [1 2 3 4 5 6 7 8 9]
i = [[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]] | a = [2. 2. 2. 2.]
b = [[2. 0. 0. 0.]
 [0. 2. 0. 0.]
 [0. 0. 2. 0.]
 [0. 0. 0. 2.]]
A = [[5. 1.]
 [1. 5.]]
b = [[15.  4.]
 [ 3. 20.]]
b = [[15.  4.]
 [ 3. 20.]]
b = [19. 23.]
b = [19. 23.]
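The outputs above show how elementwise broadcasting and mat/vec multiplication relate: summing each row of the broadcast product `A * a` gives exactly `A @ a`. A minimal sketch, reusing the same `A` and `a` (as floats), with `jnp.einsum` as a third way to write the same contraction:

```python
import jax.numpy as jnp

A = jnp.array([[5., 1.], [1., 5.]])
a = jnp.array([3., 4.])

# Broadcasting: entry (i, j) becomes A[i, j] * a[j]
scaled = A * a
# Summing over columns (axis=1) gives sum_j A[i, j] * a[j], the mat/vec product
matvec_by_hand = scaled.sum(axis=1)

print(matvec_by_hand)                  # [19. 23.]
print(A @ a)                           # [19. 23.]
print(jnp.einsum("ij,j->i", A, a))     # [19. 23.]
```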


## 2. Indexing, broadcasting rules, and dot products

Let's practice how to index, slice, and broadcast jax arrays.
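Broadcasting follows NumPy's rules: shapes are compared from the trailing axis backwards, missing leading axes are treated as size 1, and two sizes are compatible when they are equal or one of them is 1. A short sketch of a successful and a failing broadcast (the names `col`, `row`, and `table` are just for illustration):

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape((3, 1))  # shape (3, 1)
row = jnp.arange(4)                  # shape (4,), treated as (1, 4)

# Trailing axes: 1 vs 4 -> 4; leading axes: 3 vs (missing, i.e. 1) -> 3
table = col * 10 + row               # shape (3, 4); table[i, j] = 10 * i + j
print(table.shape)  # (3, 4)
print(table)

# Trailing sizes 3 and 4 are neither equal nor 1, so this cannot broadcast
try:
    jnp.ones(3) + jnp.ones(4)
except (TypeError, ValueError) as err:
    print("cannot broadcast:", type(err).__name__)
```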

In [37]:
# indexing arrays
shape = (3,3)  # a tuple: like a list, but immutable (it cannot be changed after creation)

# list syntax is [], tuple syntax is ()
X = jnp.arange(9).reshape(shape)  # `reshape` is a method of the array returned by jnp.arange
# reshape returns the same 9 values arranged in the requested shape
# (arange + reshape is a quick way to build a test array)

# what is the shape of X?
print(f"shape(x) = {X.shape}")

# indexing and 'slicing'
print(f"X = {X}")
print(f"1st row of x {X[0]}")  # Python indexing starts at 0, so X[0] is the first row
print(f"1st row of x {X[0,:]}")  # the comma separates axes and `:` means "everything" along that axis: row 0, all columns
print(f"1st col of x {X[:,0]}")  # all rows (axis 0), column 0 (axis 1)

# arrays can be n-dimensional and not just vectors/matrices!
X = jnp.arange(27).reshape((3,3,3))  # 'tensor'-like: a multidimensional array
# a stack of three 3x3 matrices; a leading axis like this is often used as a batch dimension.
# The printed array opens with three brackets [[[, one per dimension.
print(f"shape(x) = {X.shape}")
print(f"X = {X}")
print(f"1st matrix of x {X[0]}")  # the first 3x3 matrix
print(f"1st matrix of x {X[0,:]}")  # the same matrix: `:` takes everything along the next axis

print(f"1st row of each matrix of x {X[:,0,:]}")  # all matrices, then row 0, then every column of that row
# 1st colon -> axis 0, `0` -> axis 1, 2nd colon -> axis 2

# arrays can have any number of dimensions, not just 1-D vectors and 2-D matrices!
X = jnp.arange(81).reshape((3,3,3,3))  # a 4-D array
# a 3x3 grid of 3x3 matrices; the printed array opens with four brackets [[[[
print(f"shape(x) = {X.shape}")
print(f"X = {X}")
print(f"1st matrix of x {X[0]}")  # for a 4-D array, X[0] is a 3x3x3 block (three matrices), not a single matrix
print(f"1st matrix of x {X[0,:]}")  # the same block: `:` takes everything along axis 1

print(f"1st row of each matrix of x {X[:,:,0,:]}")  # everything along axes 0 and 1, row 0 of each 3x3 matrix, every column
# 1st colon -> axis 0, 2nd colon -> axis 1, `0` -> axis 2, 3rd colon -> axis 3

shape(x) = (3, 3, 3, 3)
X = [[[[ 0  1  2]
   [ 3  4  5]
   [ 6  7  8]]

  [[ 9 10 11]
   [12 13 14]
   [15 16 17]]

  [[18 19 20]
   [21 22 23]
   [24 25 26]]]


 [[[27 28 29]
   [30 31 32]
   [33 34 35]]

  [[36 37 38]
   [39 40 41]
   [42 43 44]]

  [[45 46 47]
   [48 49 50]
   [51 52 53]]]


 [[[54 55 56]
   [57 58 59]
   [60 61 62]]

  [[63 64 65]
   [66 67 68]
   [69 70 71]]

  [[72 73 74]
   [75 76 77]
   [78 79 80]]]]
1st matrix of x [[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]
1st matrix of x [[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]]

 [[ 9 10 11]
  [12 13 14]
  [15 16 17]]

 [[18 19 20]
  [21 22 23]
  [24 25 26]]]
1st row of each matrix of x [[[ 0  1  2]
  [ 9 10 11]
  [18 19 20]]

 [[27 28 29]
  [36 37 38]
  [45 46 47]]

 [[54 55 56]
  [63 64 65]
  [72 73 74]]]
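Two indexing details are worth knowing for higher-dimensional arrays. First, `...` (Ellipsis) stands for "all remaining axes", so `X[..., 0, :]` selects the same elements as `X[:, :, 0, :]` without counting colons. Second, unlike NumPy arrays, JAX arrays are immutable: in-place assignment raises a `TypeError`, and the functional `.at[...].set(...)` syntax returns an updated copy instead. A minimal sketch:

```python
import jax.numpy as jnp

X = jnp.arange(81).reshape((3, 3, 3, 3))

# `...` expands to as many `:` as needed to cover the leading axes
first_rows = X[:, :, 0, :]
first_rows_ellipsis = X[..., 0, :]
print(jnp.array_equal(first_rows, first_rows_ellipsis))  # True

# In-place assignment is not allowed on JAX arrays
try:
    X[0, 0, 0, 0] = -1
except TypeError:
    print("JAX arrays are immutable")

# .at[...].set(...) returns a new array; X itself is unchanged
Y = X.at[0, 0, 0, 0].set(-1)
print(X[0, 0, 0, 0], Y[0, 0, 0, 0])  # 0 -1
```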


## 3. Just-in-time compilation
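*Just-in-time* (*JIT*) compilation analyzes and compiles code at runtime rather than ahead of time. In JAX, `jax.jit` traces a function the first time it is called with inputs of a given shape and dtype, compiles the traced program with XLA, and reuses that compiled version on later calls with matching inputs.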
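To see what "tracing" produces, `jax.make_jaxpr` returns the intermediate program (a *jaxpr*) that JAX records when it runs a function on abstract inputs that carry only shape and dtype. This is a sketch of an inspection tool, not of `jax.jit`'s internals; `my_func` mirrors the function in the next cell.

```python
import jax
import jax.numpy as jnp

def my_func(x):
    return jnp.sum(x ** 2)

# Trace my_func for a float32 input of shape (3,) and print the recorded program;
# it lists the primitive operations (squaring, then a sum reduction) applied to the input
jaxpr = jax.make_jaxpr(my_func)(jnp.ones(3))
print(jaxpr)
```

This recorded program, rather than the Python source, is what XLA compiles, which is why a JIT-compiled function is specialized to the input shapes and dtypes it was traced with.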

In [44]:
# JIT warm up
def my_func(x):  # Python convention: function names are lowercase with underscores (snake_case)
  return jnp.sum(x ** 2)  # the indented block is the function body; indent consistently
# returns a 0-dimensional array, i.e. a single number

print(my_func(2*jnp.ones(3)))
# `jax.jit` takes as input a function and returns the JIT-compiled function
my_func_jit = jax.jit(my_func)  # compiled by XLA for the default backend (GPU/TPU if available, otherwise CPU): a speed-up with no code changes

# results should be the same
D = 50_000  # underscores in numeric literals are ignored; they only make large numbers easier to read
orig_result = my_func(jnp.ones(4))
jit_result = my_func_jit(jnp.ones(4))
is_same = jnp.allclose(orig_result, jit_result)  # checks elementwise closeness within a tolerance, not exact equality: compiled code may reorder floating-point operations, so results can differ in the last few digits
print(f"Results are same? {is_same}")

%timeit my_func(jnp.ones(D)).block_until_ready()  # runs the operation many times and reports timing statistics
%timeit my_func_jit(jnp.ones(D)).block_until_ready()  # the same measurement for the JIT-compiled version
# JAX dispatches work asynchronously; block_until_ready() waits until the result is
# actually computed, so we time the computation itself and not just its dispatch

# The JIT-compiled function computes results faster, and we did no extra work
# beyond wrapping our function with a JAX command! Next, let's use the
# decorator syntax to do that wrapping automatically

@jax.jit
def my_new_func(x):
  return jnp.sum(x ** 2)

# the @jax.jit above the function definition informs the Python interpreter
# to "decorate" `my_new_func` with the `jax.jit` function, which will automatically
# wrap my_new_func in the JIT compiled version. That is, anytime we call `my_new_func`
# we're actually calling the same thing as `jax.jit(my_new_func)`
%timeit my_new_func(jnp.ones(D)).block_until_ready()

# the average time is similar to the above `my_func_jit` which shows that we're
# calling the JIT'd version.

@jax.jit
def other_func(x):
  return jnp.mean(x ** 2)

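# jax.log_compiles() logs each compilation: every new input shape
# ((4,), then (5,), then (3, 3)) triggers a fresh trace and compile
with jax.log_compiles():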
  print("STARTING")
  four_res = other_func(jnp.ones(4))
  print("NOW 5")
  five_res = other_func(jnp.ones(5))
  print("NOW (3,3)")
  mat_res = other_func(jnp.ones((3,3)))

print(f"{four_res}")
print(f"{five_res}")
print(f"{mat_res}")

Results are same? True
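The `log_compiles` demo shows that a jitted function is re-traced for each new input shape. A Python side effect inside the function makes the same point without logging, because ordinary Python code in a jitted function runs only while JAX traces it, not on cached calls. The counter `trace_count` and the function `counted_func` are hypothetical names for this illustration.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only for illustration

@jax.jit
def counted_func(x):
    global trace_count
    trace_count += 1          # Python side effect: executes only during tracing
    return jnp.mean(x ** 2)

counted_func(jnp.ones(4))     # new shape (4,): trace and compile
counted_func(jnp.ones(4))     # same shape and dtype: reuse the cached compiled version
counted_func(jnp.ones(5))     # new shape (5,): trace and compile again
print(trace_count)            # 2
```

Because side effects like this (and `print`) run only at trace time, they are useful for debugging recompilation but should not be relied on for program logic inside jitted code.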


## 4. JAX control primitives
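Inside a jitted function, inputs are tracers with a known shape and dtype but no concrete value, so a Python `if` on a traced value raises an error. JAX offers array-level alternatives: `jnp.where` computes both branches and selects elementwise, while `jax.lax.cond` chooses one of two functions based on a scalar boolean. A minimal sketch; `relu_where` and `abs_cond` are hypothetical example functions.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def relu_where(x):
    # `if x > 0:` would fail here because x has no concrete value while tracing;
    # jnp.where evaluates both options and picks per element
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def abs_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand): both branches must return the same shape/dtype
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

print(relu_where(jnp.array([-1.0, 2.0])))  # [0. 2.]
print(abs_cond(-3.0))                      # 3.0
```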

In [None]:
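@jax.jit
def slow_ssq(a):
  # Under jit, this Python for-loop is unrolled during tracing (one step per
  # element of `a`), so with 50,000 elements compilation would be very slow
  res = 0.
  for x in a:
    res = res + x ** 2
  return res

a = jnp.arange(50_000).astype(float)

# Instead of calling slow_ssq(a), use a @ a (the sum of squares) as a reference.
# By default this accumulates in float32, so the result is close to, but not
# exactly, the exact integer sum 41,665,416,675,000
print(f"res = {a @ a}")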

import jax.lax as lax

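def my_body_func(i, val):
  # fori_loop calls body(i, carry) for i = 0, ..., len(a) - 1,
  # passing each returned carry into the next step
  x, cur_res = val  # the carry holds the array and the running sum
  return x, cur_res + x[i] ** 2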

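# lax.fori_loop traces the body once and compiles it as a single loop,
# instead of unrolling 50,000 steps like the Python for-loop above
with jax.log_compiles():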
  _, res = lax.fori_loop(0, len(a), my_body_func, (a, 0.))
print(f"res = {res}")



res = 41665456766976.0
res = 41665456766976.0
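Two variations on the same sum of squares, sketched under the assumption that the whole computation is wrapped in `jax.jit`. The first is the `lax.fori_loop` version, except that the loop body reads `a` from the enclosing scope instead of carrying the unchanged array in the loop state. The second uses `lax.scan`, which feeds the elements of `a` to the step function one at a time and returns `(final_carry, stacked_outputs)`. The function names `ssq_fori` and `ssq_scan` are hypothetical, and a small input is used so the exact answer (0² + 1² + ... + 9² = 285) is easy to check.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def ssq_fori(a):
    def body(i, acc):
        # `a` is captured from the enclosing function rather than carried
        return acc + a[i] ** 2
    return lax.fori_loop(0, a.shape[0], body, jnp.zeros((), a.dtype))

@jax.jit
def ssq_scan(a):
    def step(acc, x):
        # return (new carry, per-step output); we need no per-step output
        return acc + x ** 2, None
    total, _ = lax.scan(step, jnp.zeros((), a.dtype), a)
    return total

a = jnp.arange(10.0)
print(ssq_fori(a), ssq_scan(a), a @ a)  # 285.0 285.0 285.0
```

For a plain reduction like this, `a @ a` or `jnp.sum(a ** 2)` is simplest and fastest; the loop primitives pay off when each step depends on the previous one in ways a single vectorized expression cannot capture.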
