
# PM520, Lab 1: Introduction to Python
For a thorough review of Python, see the excellent resource [Learn Python in Y Minutes](https://learnxinyminutes.com/docs/python/).

## 1. JAX and JAX.Numpy
[JAX](https://github.com/google/jax) is a Google-backed library for automatic differentiation of Python and NumPy code. It achieves fast runtimes through "Just-In-Time" (JIT) compilation: functions are traced and then compiled by XLA (Accelerated Linear Algebra), a compiler for linear-algebra workloads that targets CPUs, GPUs, and TPUs. Hence JAX = JIT + AutoDiff + XLA.
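As a small illustration of that combination, `jax.grad` turns a Python function into a new function that computes its gradient, and `jax.jit` compiles that gradient function with XLA. This is a sketch that assumes JAX is already installed (installation is covered next). The function `f` is a hypothetical example.

```python
import jax
import jax.numpy as jnp

def f(x):
    # hypothetical example: f(x) = sum of squares, whose gradient is 2 * x
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)           # AutoDiff: a new function computing df/dx
fast_grad_f = jax.jit(grad_f)  # JIT: compile that gradient function with XLA

x = jnp.array([1.0, 2.0, 3.0])
print(grad_f(x))       # [2. 4. 6.]
print(fast_grad_f(x))  # [2. 4. 6.]
```

The two transformations compose, so the same pattern works for any differentiable function built from `jax.numpy` operations.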

Before we can use JAX, we need to install it with `pip`, Python's package manager. To run `pip` (or any other shell command) from a Colab cell, prefix the command with `!`.


```
!pip install jax
```
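To check that the installation worked, you can print the installed version and the devices JAX can see. This is a sketch: the exact output depends on your environment, and a CPU-only runtime typically lists a single CPU device.

```python
import jax

print(jax.__version__)  # installed JAX version string
print(jax.devices())    # devices JAX can run on, e.g. a CPU, GPU, or TPU
```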



Let's practice importing JAX and using its NumPy implementation. NumPy is a Python library for n-dimensional arrays. `jax.numpy` reimplements most of NumPy's API on top of JAX, so array code written with it can use JAX features such as JIT compilation and automatic differentiation.
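A quick side-by-side comparison shows that the two APIs match. It also shows one notable difference: JAX arrays are immutable, so in-place assignment is replaced by the functional `.at[...].set(...)` syntax. This sketch assumes regular NumPy is available, which it is wherever JAX is installed, since JAX depends on it.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.arange(9) ** 2    # regular NumPy array
x_jnp = jnp.arange(9) ** 2  # JAX array built with the same call

# same API, same values (np.asarray converts the JAX array to NumPy)
print(np.array_equal(x_np, np.asarray(x_jnp)))  # True

# `x_jnp[0] = 100` would raise an error because JAX arrays are immutable;
# .at[0].set(100) returns an updated copy and leaves x_jnp unchanged
y = x_jnp.at[0].set(100)
print(x_jnp[0], y[0])
```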

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

# let's practice some numpy tricks
x = jnp.arange(9)   # [0, 1, ..., 8]
y = jnp.ones(9)     # nine 1.0s
print(f"x = {x} | y = {y}")

# arithmetic is elementwise; adding a scalar "broadcasts" it to every entry
z = x + y
print(f"z = {z} | x + 1 = {x + 1}")


P = 4
i = jnp.eye(P)       # P x P identity matrix
a = 2 * jnp.ones(P)  # length-P vector of 2s
print(f"i = {i} | a = {a}")

# is this mat/vec mult?
b = i * a
print(f"b = {b}")

A = jnp.array([[5., 1], [1, 5]])
a = jnp.array([3, 4])
print(f"A = {A}")
b = A * a
print(f"b = {b}")
b = a * A
print(f"b = {b}")

# nope! `*` is elementwise (with broadcasting), so each b above is a matrix.
# A true mat/vec product returns a vector; use `@` or `jnp.dot`
b = A @ a
print(f"b = {b}")
b = jnp.dot(A, a)
print(f"b = {b}")
```
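To make the difference concrete, this small sketch reuses the same `A` and `a`, written as floats here, and relates the two products. `A * a` scales the columns of `A`, while `A @ a` sums each row of that scaled matrix.

```python
import jax.numpy as jnp

A = jnp.array([[5., 1.], [1., 5.]])
a = jnp.array([3., 4.])

# elementwise product: `a` (shape (2,)) is broadcast across the rows of A,
# so column j of A is scaled by a[j]
elementwise = A * a
print(elementwise)  # rows [15, 4] and [3, 20]

# matrix-vector product: entry i is row i of A dotted with a
matvec = A @ a
print(matvec)       # [19. 23.]

# so the mat-vec product equals the row sums of the elementwise product
print(jnp.allclose(matvec, elementwise.sum(axis=1)))  # True
```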

## 2. Indexing, broadcasting rules, and dot products

Let's practice indexing, slicing, and broadcasting JAX arrays.

```python
# indexing arrays
shape = (3, 3)
X = jnp.arange(9).reshape(shape)

# what is the shape of X?
print(f"shape(X) = {X.shape}")

# indexing and 'slicing'
print(f"X = {X}")
print(f"1st row of X {X[0]}")    # a single index selects along the first axis
print(f"1st row of X {X[0,:]}")  # equivalent: `:` keeps all of the second axis
print(f"1st col of X {X[:,0]}")

# arrays can be n-dimensional and not just vectors/matrices!
X = jnp.arange(27).reshape((3, 3, 3))
print(f"shape(X) = {X.shape}")
print(f"X = {X}")
print(f"1st matrix of X {X[0]}")
print(f"1st matrix of X {X[0,:]}")  # same as X[0] and X[0,:,:]
print(f"1st row of each matrix of X {X[:,0,:]}")
```

```text
1st row of each matrix of X [[ 0  1  2]
 [ 9 10 11]
 [18 19 20]]
```
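The broadcasting rules and dot products named in this section's title can be sketched in a few lines. The arrays below are hypothetical examples. Shapes are compared from the trailing axis backwards. Two sizes are compatible if they are equal or one of them is 1, and a missing leading axis counts as size 1.

```python
import jax.numpy as jnp

col = jnp.arange(3).reshape((3, 1))  # shape (3, 1)
row = jnp.arange(4).reshape((1, 4))  # shape (1, 4)

# each size-1 axis is stretched to match the other array, giving shape (3, 4)
grid = col + row
print(grid.shape)  # (3, 4)

X = jnp.arange(9).reshape((3, 3))
v = jnp.array([10, 20, 30])  # shape (3,) is treated as (1, 3)
print(X + v)                 # v is added to every row of X

# trailing sizes 3 and 2 are incompatible, so this raises an error
try:
    X + jnp.ones(2)
    broadcast_failed = False
except (TypeError, ValueError):
    broadcast_failed = True
print(f"broadcast failed? {broadcast_failed}")  # broadcast failed? True

# for 1-D arrays, `@` and jnp.dot both compute the inner (dot) product
u = jnp.array([1., 2., 3.])
w = jnp.array([4., 5., 6.])
print(u @ w, jnp.dot(u, w))  # both equal 1*4 + 2*5 + 3*6 = 32
```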


## 3. Just-in-time compilation
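*Just-in-time* (*JIT*) compilation analyzes and compiles code at runtime rather than ahead of time. The first time a `jax.jit`-wrapped function is called with inputs of a given shape and dtype, JAX traces it, compiles the traced computation with XLA, and caches the result. Later calls with matching shapes and dtypes reuse the compiled code.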
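One way to see what "analyzed" means here is `jax.make_jaxpr`. It traces a function and returns JAX's intermediate representation of the operations it recorded, called a *jaxpr*. This traced program is what `jax.jit` then lowers and compiles with XLA. A minimal sketch, using a hypothetical `sum_of_squares` function:

```python
import jax
import jax.numpy as jnp

def sum_of_squares(x):
    return jnp.sum(x ** 2)

# trace the function with a float32 vector of length 4; nothing is compiled here
jaxpr = jax.make_jaxpr(sum_of_squares)(jnp.ones(4))
print(jaxpr)  # lists the primitive operations it recorded, such as reduce_sum
```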

```python
# JIT warm up
def my_func(x):
  return jnp.sum(x ** 2)

# `jax.jit` takes a function as input and returns its JIT-compiled version
my_func_jit = jax.jit(my_func)

# results should be the same
orig_result = my_func(jnp.ones(4))
jit_result = my_func_jit(jnp.ones(4))
is_same = jnp.allclose(orig_result, jit_result)
print(f"Results are same? {is_same}")

# JAX dispatches work asynchronously, so `.block_until_ready()` makes the timer
# wait for the actual result in both cases. %timeit runs many loops, so the
# one-time compilation cost of the first JIT call is amortized.
D = 50_000
xs = jnp.ones(D)
%timeit my_func(xs).block_until_ready()
%timeit my_func_jit(xs).block_until_ready()

# the JIT-compiled function should typically run faster, and we did no extra work
# except wrap our function using a JAX command! Now let's see how to
# use the decorator syntax to handle that automatically for us

@jax.jit
def my_new_func(x):
  return jnp.sum(x ** 2)

# the @jax.jit above the function definition tells the Python interpreter
# to "decorate" `my_new_func` with `jax.jit`: the name `my_new_func` is bound to
# `jax.jit(<the original function>)`, so every call to `my_new_func` runs the
# JIT-compiled version
%timeit my_new_func(xs).block_until_ready()

# the average time should be similar to `my_func_jit` above, which shows that
# we're calling the JIT'd version.

@jax.jit
def other_func(x):
  return jnp.mean(x ** 2)

# jax.log_compiles() logs a message each time a function is compiled; each new
# input shape triggers a new trace and compilation. The messages go through
# Python's `logging` module, so they are not part of the printed output.
with jax.log_compiles():
  print("STARTING")
  four_res = other_func(jnp.ones(4))
  print("NOW 5")
  five_res = other_func(jnp.ones(5))
  print("NOW (3,3)")
  mat_res = other_func(jnp.ones((3,3)))

print(f"{four_res}")
print(f"{five_res}")
print(f"{mat_res}")
```



```text
STARTING
NOW 5
NOW (3,3)
1.0
1.0
1.0
```
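The compile messages from `jax.log_compiles()` do not appear in the printed output above. A minimal alternative makes the retracing visible directly. Python side effects inside a jitted function run only while JAX traces it, so counting them counts how many times the function is traced. Here that happens once per new input shape. The counter `trace_count` and the function `mean_square` are hypothetical, for illustration only.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, only used to observe tracing

@jax.jit
def mean_square(x):
    global trace_count
    trace_count += 1  # Python side effect: runs while tracing, not on cached calls
    return jnp.mean(x ** 2)

mean_square(jnp.ones(4))       # new shape (4,): traced and compiled
mean_square(jnp.ones(4))       # same shape and dtype: reuses the cached version
mean_square(jnp.ones(5))       # new shape (5,): traced again
mean_square(jnp.ones((3, 3)))  # new shape (3, 3): traced again
print(f"traced {trace_count} times")  # traced 3 times
```

This is also why functions that are called with many different input shapes can spend much of their time recompiling.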


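## 4. JAX control primitives

Inside a `jax.jit`-compiled function, Python control flow runs while the function is being traced. A Python `for` loop is unrolled into one copy of its body per iteration, and a Python `if` cannot branch on a traced value. The `jax.lax` module provides structured control-flow primitives, such as `lax.fori_loop`, `lax.while_loop`, `lax.scan`, and `lax.cond`, that are compiled as a single loop or branch instead.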
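For branching, this minimal sketch contrasts two approaches. `lax.cond` selects a branch at runtime from a traced boolean. A Python `if` fails under `jax.jit` because it needs a concrete `True`/`False` while the function is being traced. The functions `abs_cond` and `abs_if` are hypothetical examples.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def abs_cond(x):
    # lax.cond(pred, true_fun, false_fun, operand): both branches are traced and
    # compiled, and pred picks which one runs
    return lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

@jax.jit
def abs_if(x):
    if x >= 0:  # Python `if` needs a concrete bool, but x is an abstract tracer here
        return x
    return -x

print(abs_cond(jnp.array(3.0)), abs_cond(jnp.array(-2.0)))

try:
    abs_if(jnp.array(3.0))
    if_failed = False
except jax.errors.ConcretizationTypeError:
    # recent JAX versions raise TracerBoolConversionError, a subclass of this error
    if_failed = True
print(f"Python `if` failed under jit? {if_failed}")
```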

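```python
import jax.lax as lax

# A Python `for` loop inside a jitted function is unrolled during tracing:
# calling slow_ssq on 50,000 elements would stage out 50,000 copies of the
# loop body and be very slow to trace and compile, so we don't call it here.
@jax.jit
def slow_ssq(a):
  res = 0.
  for x in a:
    res = res + x ** 2
  return res

# JAX defaults to 32-bit floats, so the results below carry float32 rounding
# error (the exact sum of squares of 0, ..., 49,999 is 41,665,416,675,000)
a = jnp.arange(50_000).astype(float)

# reference answer: the sum of squares is the dot product of a with itself
print(f"res = {a @ a}")

# lax.fori_loop(lower, upper, body_fun, init_val) calls body_fun(i, val) for
# i = lower, ..., upper - 1 and compiles to a single loop instead of unrolling it
def my_body_func(i, val):
  x, cur_res = val
  return x, cur_res + x[i] ** 2

with jax.log_compiles():
  _, res = lax.fori_loop(0, len(a), my_body_func, (a, 0.))
print(f"res = {res}")
```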



```text
res = 41665456766976.0
res = 41665456766976.0
```
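Here are two variations on the loop above, as a sketch. First, the `lax.fori_loop` body can close over the array instead of carrying it in the loop state. Second, `lax.scan` iterates over the leading axis of an array directly and also returns every intermediate result. Both agree with the closed form `(n - 1) * n * (2 * n - 1) / 6`. The names `ssq_fori` and `ssq_scan` are hypothetical. A small `n = 100` is used so that every partial sum is exactly representable in float32.

```python
import jax
import jax.numpy as jnp
from jax import lax

n = 100
a = jnp.arange(n, dtype=jnp.float32)

@jax.jit
def ssq_fori(a):
    init = jnp.zeros((), dtype=a.dtype)  # float32 scalar accumulator
    # the body closes over `a`, so only the running total is carried
    return lax.fori_loop(0, a.shape[0], lambda i, acc: acc + a[i] ** 2, init)

@jax.jit
def ssq_scan(a):
    def step(acc, x):
        acc = acc + x ** 2
        return acc, acc  # (carry for the next step, per-step output stacked by scan)
    init = jnp.zeros((), dtype=a.dtype)
    total, running = lax.scan(step, init, a)
    return total, running

exact = (n - 1) * n * (2 * n - 1) // 6  # closed form: 328350 for n = 100
total, running = ssq_scan(a)
print(ssq_fori(a), total, exact)
print(running[:4])  # running sums 0, 1, 5, 14
```

`lax.scan` is the natural choice when the intermediate values are needed, for example a cumulative sum. `lax.fori_loop` is the simpler choice when only the final result matters.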
