<a href="https://colab.research.google.com/github/USCbiostats/PM570-Colab/blob/main/Lecture-2.JaxPandasCrashCourse.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## JAX and `jax.numpy`
`jax` is a Google-developed library for automatic differentiation of Python and NumPy code. It also runs fast thanks to "Just-In-Time" (JIT) compilation with XLA (Accelerated Linear Algebra), a domain-specific compiler for linear algebra. Hence JAX = **J**IT + **A**utoDiff + **X**LA.

Let's practice importing JAX and using `jax.numpy`, JAX's implementation of the `numpy` API. `numpy` is a Python library for n-dimensional arrays. Using JAX's version lets us take advantage of all of JAX's features.
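Because `jax.numpy` mirrors the NumPy API, most NumPy calls carry over unchanged. The sketch below assumes nothing beyond the two libraries and compares the same operations side by side. It also shows one difference worth knowing early: JAX arrays are immutable, so in-place updates are written with `.at[...].set(...)`, which returns an updated copy.

```python
import numpy as np
import jax.numpy as jnp

# the same NumPy-style calls work on both libraries
x_np = np.arange(6).reshape(2, 3)
x_jnp = jnp.arange(6).reshape(2, 3)

# reductions, transposes, and matrix products use identical syntax
row_sums_np = x_np.sum(axis=1)
row_sums_jnp = x_jnp.sum(axis=1)
gram_np = x_np @ x_np.T
gram_jnp = x_jnp @ x_jnp.T
print(np.allclose(gram_np, gram_jnp))  # True

# JAX arrays are immutable: `.at[...].set(...)` returns a new array
y_jnp = x_jnp.at[0, 0].set(100)
print(x_jnp[0, 0], y_jnp[0, 0])  # the original array is unchanged
```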

In [1]:
import jax
import jax.numpy as jnp
import jax.random as rdm

# let's practice some numpy tricks
x = jnp.arange(9)
y = jnp.ones(9)
print(f"x = {x} | y = {y}")

z = x + y
print(f"z = {z} | x + 1 = {x + 1}")

P = 4
i = jnp.eye(P)
a = 2 * jnp.ones(P)
print(f"i = {i} | a = {a}")

# is this mat/vec mult?
b = i * a
print(f"b = {b}")

A = jnp.array([[5., 1], [1, 5]])
a = 2 * jnp.ones(2)
print(f"A = {A}")
b = A * a
print(f"b = {b}")

# nope! `*` is elementwise (with broadcasting), so b is a matrix;
# matrix-vector multiplication returns a vector: use `@` or `jnp.dot`
b = A @ a
print(f"b = {b}")
b = jnp.dot(A, a)
print(f"b = {b}")



x = [0 1 2 3 4 5 6 7 8] | y = [1. 1. 1. 1. 1. 1. 1. 1. 1.]
z = [1. 2. 3. 4. 5. 6. 7. 8. 9.] | x + 1 = [1 2 3 4 5 6 7 8 9]
i = [[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]] | a = [2. 2. 2. 2.]
b = [[2. 0. 0. 0.]
 [0. 2. 0. 0.]
 [0. 0. 2. 0.]
 [0. 0. 0. 2.]]
A = [[5. 1.]
 [1. 5.]]
b = [[10.  2.]
 [ 2. 10.]]
b = [12. 12.]
b = [12. 12.]
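The cell above uses `a = 2 * jnp.ones(...)`, which makes elementwise scaling and matrix-vector multiplication hard to tell apart. The sketch below uses an illustrative, unequal vector `a` (an assumption, not from the cell above) to show the broadcasting rule: `A * a` scales column `j` of `A` by `a[j]`, while `A @ a` multiplies and then sums each row.

```python
import jax.numpy as jnp

A = jnp.array([[5., 1.], [1., 5.]])
a = jnp.array([2., 3.])  # unequal entries so the two operations differ visibly

# `*` broadcasts `a` across the rows of `A`: column j is scaled by a[j]
elementwise = A * a
same_as = A * a[None, :]  # explicit broadcast: shape (1, 2) -> (2, 2)

# `@` contracts over the shared axis: b_i = sum_j A[i, j] * a[j]
matvec = A @ a
manual = jnp.sum(A * a, axis=1)  # multiply elementwise, then sum each row

print(elementwise)  # [[10.  3.]
                    #  [ 2. 15.]]
print(matvec)       # [13. 17.]
```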


In [2]:
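# random variables in JAX: randomness is driven by explicit PRNG keys
key = rdm.PRNGKey(0)

# split off a fresh subkey for each draw rather than reusing `key`
key, y_key = rdm.split(key)
N = 5
mu_y = 50
std_y = 100
y = mu_y + std_y * rdm.normal(y_key, shape=(N,)) # y_i ~ N(mu_y, std_y^2)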
print(f"y = {y}")



y = [ -95.8194  -154.7044   254.73392  166.84094  -47.58364]
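JAX has no hidden global random state. Every draw is a deterministic function of the key passed in. The sketch below (key seed and shapes chosen arbitrarily) shows why the cells split keys: reusing a key reproduces the same numbers, while a split subkey gives new ones.

```python
import jax.random as rdm

key = rdm.PRNGKey(0)

# reusing a key reproduces exactly the same draws
draw1 = rdm.normal(key, shape=(3,))
draw2 = rdm.normal(key, shape=(3,))

# splitting yields a fresh key; use each subkey for one draw only
key, subkey = rdm.split(key)
draw3 = rdm.normal(subkey, shape=(3,))

print((draw1 == draw2).all())  # True: same key, same numbers
print((draw1 == draw3).all())  # False: new key, new numbers
```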


In [3]:
import jax.scipy.stats as stats

N = 500_000

key = rdm.PRNGKey(0)

# simulate genotype: x = number of minor alleles (0, 1, or 2)
freq = 0.1
key, h1_key, h2_key = rdm.split(key, 3)
h1 = rdm.bernoulli(h1_key, freq, shape=(N,)).astype(int)
h2 = rdm.bernoulli(h2_key, freq, shape=(N,)).astype(int)
x = h1 + h2
x = x - 2 * freq # center by the expected allele count, E[x] = 2 * freq

# simulate phenotype as a function of genotype
h2g = 1e-4 # fraction of phenotypic variance explained by the SNP
key, b_key = rdm.split(key)
beta = rdm.normal(b_key)
g = x * beta
s2g = jnp.var(g)
s2e = ((1 / h2g) - 1) * s2g # h2g = s2g / (s2g + s2e) => s2e = s2g * (1 - h2g) / h2g

# phenotype = genetic component + env noise
key, y_key = rdm.split(key)
y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
y = y - jnp.mean(y)

#print(f"y = {y}")

beta_hat =  (x.T @ y) / (x.T @ x) # (x.T x)^-1 x.T y
s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
se_beta = jnp.sqrt(s2e_hat / (x.T @ x))

print(f" beta = {beta} | hat(beta) = {beta_hat} | se(hat(beta)) = {se_beta}")
z = beta_hat / se_beta
print(f"z = {z} | p-value = {2*stats.norm.cdf(-jnp.fabs(z))}")

 beta = -0.1399441361427307 | hat(beta) = -0.1588948369026184 | se(hat(beta)) = 0.019815918058156967
z = -8.018545150756836 | p-value = 1.0700497342445428e-15
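With a single centered predictor and no intercept, `(x.T x)^-1 x.T y` reduces to the scalar `(x @ y) / (x @ x)`. The sketch below checks that shortcut against JAX's general least-squares solver `jnp.linalg.lstsq` on a small simulated dataset. The sample size, effect size 0.2, and seed are arbitrary assumptions.

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

key_x, key_e = rdm.split(rdm.PRNGKey(1))
N = 1_000
x = rdm.normal(key_x, shape=(N,))
x = x - x.mean()  # centered predictor, as in the cell above
y = 0.2 * x + rdm.normal(key_e, shape=(N,))
y = y - y.mean()

# closed form for one centered predictor without an intercept
beta_hat = (x @ y) / (x @ x)

# general least squares on the N x 1 design matrix
beta_lstsq, *_ = jnp.linalg.lstsq(x[:, None], y)

# Wald z-test, as in the cell above
s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
se_beta = jnp.sqrt(s2e_hat / (x @ x))
p_val = 2 * stats.norm.cdf(-jnp.abs(beta_hat / se_beta))  # two-sided p-value
print(beta_hat, beta_lstsq[0], p_val)
```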


## LAX and functional design patterns

`jax.lax.scan` scans along the leading axis of an `ndarray` while carrying state along with it. In plain Python, the semantics of `scan` look like this:
```python
import numpy as np

def scan(func, init, xs, length=None):
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    carry, y = func(carry, x)
    ys.append(y)
  return carry, np.stack(ys)
```

Let's see how this can be useful for GWAS...
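Before moving to GWAS, a small sketch helps separate the two outputs of `scan`. It computes a running (cumulative) sum. The carry is the total so far, and the per-step output is the total after each element. The step function `cumsum_step` is a hypothetical example, run once with `lax.scan` and once with the Python loop from the pseudocode.

```python
import numpy as np
import jax.numpy as jnp
import jax.lax as lax

def cumsum_step(carry, x):
  # carry: running total so far; x: the current element
  new_total = carry + x
  return new_total, new_total  # (next carry, per-step output)

xs = jnp.arange(1.0, 6.0)  # [1, 2, 3, 4, 5]
final_total, running = lax.scan(cumsum_step, jnp.array(0.0), xs)

# the same computation written as the Python loop in the pseudocode
carry, ys = 0.0, []
for x in np.arange(1.0, 6.0):
  carry, y = cumsum_step(carry, x)
  ys.append(y)

print(final_total)  # 15.0
print(running)      # [ 1.  3.  6. 10. 15.]
```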

In [7]:
import jax.lax as lax

# simulate geno + pheno
N = 10_000
P = 10000

key = rdm.PRNGKey(0)

# simulate genotype
freq = 0.1
key, h1_key, h2_key = rdm.split(key, 3)
h1 = rdm.bernoulli(h1_key, freq, shape=(N,P)).astype(int) 
h2 = rdm.bernoulli(h2_key, freq, shape=(N,P)).astype(int) 
X = h1 + h2
X = X - 2 * freq

# simulate phenotype as a function of genotype
h2g = 0.3
key, b_key = rdm.split(key)
beta = rdm.normal(b_key, shape=(P,))
G = X @ beta
s2g = jnp.var(G)
s2e = ((1 / h2g) - 1) * s2g # h2g = s2g / (s2g + s2e) => s2e = s2g * (1 - h2g) / h2g

# phenotype = genetic component + env noise
key, y_key = rdm.split(key)
y = G + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
y = y - jnp.mean(y)

# let's write a function that performs OLS between a single variant and y, and
# reports the effect size estimate, its SE, and p-value

# scan threads a 'carry' through every step to keep state going forward.
# here we don't need to change any state; all we need to
# do is keep passing the phenotype along
def ols_scan(y, x):
  xtx = x.T @ x
  beta_hat =  (x.T @ y) / (xtx) # (x.T x)^-1 x.T y
  s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
  se_beta = jnp.sqrt(s2e_hat / (xtx))
  p_val = 2*stats.norm.cdf(-jnp.fabs(beta_hat / se_beta))

  # scan requires we return updated state (i.e. same y)
  # along with the result
  return y, jnp.array([beta_hat, se_beta, p_val])

_, gwas_res = lax.scan(ols_scan, y, X.T)

# print first 5 results...
print("BETA\tSE\tPVal")
print(gwas_res[:5])

BETA	SE	PVal
[[-0.89536387  1.8326901   0.6251591 ]
 [ 0.8770723   1.8177588   0.62944937]
 [ 2.8826997   1.8519931   0.11957997]
 [ 1.3135461   1.8424423   0.47588444]
 [ 0.20249538  1.8436264   0.91253996]]
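Each variant's regression is independent, and `y` never changes. So the same statistics can also be computed for all variants at once with a few vectorized matrix operations. The sketch below runs on small simulated data, with `N`, `P`, and the seed as arbitrary assumptions, and checks that the two approaches agree. The design trade-off: the vectorized form skips the sequential loop but materializes an `N x P` residual matrix. `scan` avoids that by handling one column at a time.

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.lax as lax
import jax.scipy.stats as stats

# small toy data: N samples x P centered "genotypes"
N, P = 200, 5
kx, ky = rdm.split(rdm.PRNGKey(0))
X = rdm.normal(kx, shape=(N, P))
X = X - X.mean(axis=0)
y = rdm.normal(ky, shape=(N,))
y = y - y.mean()

def ols_scan(y, x):
  # same per-variant OLS as above; y is carried through unchanged
  xtx = x @ x
  beta_hat = (x @ y) / xtx
  s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
  se_beta = jnp.sqrt(s2e_hat / xtx)
  p_val = 2 * stats.norm.cdf(-jnp.abs(beta_hat / se_beta))
  return y, jnp.array([beta_hat, se_beta, p_val])

_, res_scan = lax.scan(ols_scan, y, X.T)

# vectorized version: all P regressions at once
xtx = jnp.sum(X ** 2, axis=0)                          # (P,)
beta_vec = (X.T @ y) / xtx                             # (P,)
resid = y[:, None] - X * beta_vec                      # (N, P)
se_vec = jnp.sqrt(jnp.mean(resid ** 2, axis=0) / xtx)  # (P,)

print(jnp.allclose(res_scan[:, 0], beta_vec, rtol=1e-4, atol=1e-6))
```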


## Just-in-time compilation and function decorators
Let's see if we can use JIT to speed up our GWAS scan. To do that, we'll first review Python "decorators", which wrap a function to modify its behavior in some way.
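A decorator is just a function that takes a function and returns a new one. Writing `@deco` above a definition is shorthand for `f = deco(f)`. The sketch below uses a hypothetical decorator `shout`, unrelated to JAX, to show that both spellings produce the same wrapped function. `functools.wraps` keeps the original function's name and docstring on the wrapper.

```python
import functools

def shout(func):
  # hypothetical decorator: wraps `func` and upper-cases its result
  @functools.wraps(func)
  def wrapper(*args, **kwargs):
    result = func(*args, **kwargs)
    return str(result).upper()
  return wrapper

def greet(name):
  return f"hello, {name}"

greet_loud = shout(greet)  # explicit wrapping

@shout  # decorator syntax: equivalent to greet2 = shout(greet2)
def greet2(name):
  return f"hello, {name}"

print(greet_loud("jax"))  # HELLO, JAX
print(greet2("jax"))      # HELLO, JAX
```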

In [5]:
# JIT warm up

def my_func(x):
  return jnp.sum(x ** 2)

# `jax.jit` takes as input a function and returns the JIT-compiled function
my_func_jit = jax.jit(my_func)

# results should be the same (this first call also triggers compilation)
is_same = jnp.allclose(my_func(jnp.ones(4)), my_func_jit(jnp.ones(4)))
print(f"Results are same? {is_same}")

# JAX dispatches work asynchronously, so we call `block_until_ready()` to time
# the actual computation rather than just the dispatch
%timeit my_func(jnp.ones(4)).block_until_ready() # time the un-JIT'd version
%timeit my_func_jit(jnp.ones(4)).block_until_ready() # time the JIT'd version

# results are computed faster by the JIT-compiled function! We did no extra work
# except wrap our function using a JAX command. Now let's see how to
# use the decorator syntax to handle that automatically for us

@jax.jit
def my_new_func(x):
  return jnp.sum(x ** 2)

# the @jax.jit above the function definition informs the Python interpreter
# to "decorate" `my_new_func` with the `jax.jit` function, which will automatically
# wrap my_new_func in the JIT compiled version. That is, anytime we call `my_new_func`
# we're actually calling the same thing as `jax.jit(my_new_func)`
%timeit my_new_func(jnp.ones(4)).block_until_ready()

# the average time is similar to the above `my_func_jit` which shows that we're
# calling the JIT'd version. 

Results are same? True
620 µs ± 81.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
309 µs ± 9.42 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
312 µs ± 6.77 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [8]:
# Now let's apply this to our GWAS scan. We'll skip
# the decorator syntax for now to illustrate the speedup for our scan.

def gwas_scan_slow(X, y):
  _, gwas_res = lax.scan(ols_scan, y, X.T)
  return gwas_res

gwas_scan_fast = jax.jit(gwas_scan_slow)
%timeit gwas_scan_slow(X, y).block_until_ready()
%timeit gwas_scan_fast(X, y).block_until_ready()

# the speedup here is marginal, likely because `lax.scan` is a single JAX
# primitive whose whole loop is already compiled by XLA even outside `jax.jit`;
# the outer `jit` mostly removes some dispatch overhead around it

2.51 s ± 75.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
2.48 s ± 74.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Pandas
Pandas is a Python library for tabular, data-frame-like data structures.
Let's convert our GWAS results to a DataFrame for easier manipulation.

In [None]:
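import numpy as np
import pandas as pd

# convert the JAX array to a NumPy array so pandas builds numeric (float) columns
df_gwas = pd.DataFrame(np.asarray(gwas_res), columns=["BETA", "SE", "PVal"])
print(df_gwas.head())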

         BETA         SE        PVal
0  -11.807892   31.55952   0.7082951
1    8.070707  31.307985   0.7965734
2   24.326462   31.32242  0.43736708
3   -0.281418   32.09645   0.9930043
4   14.883338  31.868835  0.64048654
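Once the results are in a DataFrame, typical GWAS post-processing becomes one-liners: derived columns, sorting by p-value, and filtering. The sketch below uses a small simulated stand-in for `gwas_res` (the values are random draws, not real results) and hypothetical SNP labels. The 0.05 threshold is only illustrative.

```python
import numpy as np
import pandas as pd
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

# toy stand-in for `gwas_res` (simulated, not real results): one row per SNP
P = 8
k_beta, k_se = rdm.split(rdm.PRNGKey(0))
beta = rdm.normal(k_beta, shape=(P,))
se = 0.5 + rdm.uniform(k_se, shape=(P,))
pval = 2 * stats.norm.cdf(-jnp.abs(beta / se))
res = jnp.stack([beta, se, pval], axis=1)  # shape (P, 3), like gwas_res

df = pd.DataFrame(np.asarray(res), columns=["BETA", "SE", "PVal"])
df["SNP"] = [f"snp{i}" for i in range(P)]  # hypothetical SNP labels

# derived columns are computed column-wise
df["Z"] = df["BETA"] / df["SE"]
df["NegLog10P"] = -np.log10(df["PVal"])

top3 = df.sort_values("PVal").head(3)  # strongest associations first
hits = df[df["PVal"] < 0.05]           # boolean-mask filtering
print(top3[["SNP", "BETA", "SE", "PVal"]])
```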


# VMAP, or how I learned to stop worrying and love vectorization
Say we have a function written for a single input of a particular shape, and we would like to apply it to many "batches" of data without rewriting it or looping in Python. How could we do that?

In comes `jax.vmap`.
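A minimal sketch of the idea, before applying it to OLS: `dot` below is a hypothetical function written for two single vectors. `jax.vmap` turns it into one that handles a whole batch. `in_axes=(0, None)` says to map over the rows of the first argument and pass the second argument unchanged to every call.

```python
import jax
import jax.numpy as jnp

def dot(u, v):
  # written for single vectors of shape (D,)
  return jnp.sum(u * v)

U = jnp.arange(12.0).reshape(4, 3)  # a batch of 4 vectors
v = jnp.ones(3)                      # one vector shared by every call

# map over axis 0 of U; broadcast v unchanged
batched_dot = jax.vmap(dot, in_axes=(0, None))
batched = batched_dot(U, v)
print(batched)  # [ 3. 12. 21. 30.]

# the equivalent Python loop
loop = jnp.array([dot(u, v) for u in U])
```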

In [13]:
def ols_single(y: jnp.ndarray, x: jnp.ndarray) -> jnp.ndarray:
  xtx = x.T @ x
  beta_hat =  (x.T @ y) / (xtx) # (x.T x)^-1 x.T y
  s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
  se_beta = jnp.sqrt(s2e_hat / (xtx))
  p_val = 2*stats.norm.cdf(-jnp.fabs(beta_hat / se_beta))

  # no carry is needed here (unlike `ols_scan`); just return the
  # estimate, its SE, and p-value
  return jnp.array([beta_hat, se_beta, p_val])

results = []
for i in range(10):
  results.append(ols_single(y, X.T[i]))
results = jnp.array(results)
print(f"Results = {results}")

Results = [[-0.89536387  1.8326901   0.6251591 ]
 [ 0.8770723   1.8177588   0.62944937]
 [ 2.8826997   1.8519931   0.11957997]
 [ 1.3135461   1.8424423   0.47588444]
 [ 0.20249538  1.8436264   0.91253996]
 [-2.5835547   1.8322238   0.15852053]
 [-2.4758806   1.8603319   0.18322818]
 [ 0.33134592  1.8274202   0.85611725]
 [-0.9843171   1.8259833   0.58984447]
 [-0.7290425   1.8216625   0.68900394]]


In [14]:
# subset to the first 10 SNPs
X_sub = X[:,:10]

# call vmap on our OLS function: in_axes=(None, 0) means y is not mapped
# (the same y is passed to every call) while the genotype argument is mapped
# along its axis 0; out_axes=0 stacks the per-SNP outputs along axis 0
ols_multi = jax.vmap(ols_single, in_axes=(None, 0), out_axes=0)
results = ols_multi(y, X_sub.T)
print(f"Results = {results}")

Results = [[-0.8953643   1.8326901   0.62515897]
 [ 0.8770692   1.8177588   0.62945056]
 [ 2.882697    1.8519931   0.11958028]
 [ 1.3135438   1.8424423   0.47588527]
 [ 0.20249653  1.8436264   0.9125395 ]
 [-2.5835583   1.8322238   0.15851992]
 [-2.4758773   1.8603319   0.18322876]
 [ 0.33134153  1.8274201   0.8561191 ]
 [-0.9843188   1.8259833   0.58984387]
 [-0.72904265  1.8216625   0.6890038 ]]


In [15]:
# we can combine this with JIT!
ols_multi_fast = jax.jit(ols_multi)

%timeit ols_multi(y, X_sub.T).block_until_ready()
%timeit ols_multi_fast(y, X_sub.T).block_until_ready()

26.7 ms ± 545 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
449 µs ± 9.12 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


# Autograd
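Automatic differentiation (autodiff) computes derivatives of computer code algorithmically. It applies the chain rule to the elementary operations a program is built from, so the derivatives are exact up to floating-point error, unlike finite-difference approximations.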
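To make "exact up to floating point" concrete, the sketch below compares three derivatives of an arbitrary example function, `f(x) = sin(x) * x**2`, at `x0 = 1.5`: the hand-derived product rule, `jax.grad`, and a central finite difference with an assumed step `h = 1e-2`. The autodiff value matches the hand-derived one to float32 precision. The finite difference carries an approximation error that depends on `h`.

```python
import jax
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x ** 2

x0 = 1.5
exact = jnp.cos(x0) * x0 ** 2 + 2 * x0 * jnp.sin(x0)  # product rule by hand
auto = jax.grad(f)(x0)                                # autodiff via the chain rule

h = 1e-2
fd = (f(x0 + h) - f(x0 - h)) / (2 * h)  # central finite difference, O(h^2) error
print(auto, exact, fd)
```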

In [20]:
def sos(x: jnp.ndarray) -> float:
  return jnp.sum(x ** 2)

x = 2 * jnp.ones(5)
sos(x)

def sos_handcoded_deriv(x: jnp.ndarray) -> jnp.ndarray:
  return 2 * x

# we can just use `jax.grad` to get the gradient function automatically
sos_prime = jax.grad(sos)
print(f"Are derivatives the same? {jnp.allclose(sos_prime(x), sos_handcoded_deriv(x))}")

sos_hess = jax.hessian(sos)
print(f"Hessian = {sos_hess(x)}")

Are derivatives the same? True
Hessian = [[2. 0. 0. 0. 0.]
 [0. 2. 0. 0. 0.]
 [0. 0. 2. 0. 0.]
 [0. 0. 0. 2. 0.]
 [0. 0. 0. 0. 2.]]
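The same tools connect back to the GWAS cells. The OLS estimate minimizes the squared-error loss, so `jax.grad` of that loss should be (numerically) zero at `(x @ y) / (x @ x)`. For this quadratic loss, `jax.hessian` should equal `x @ x`. The sketch below checks both on simulated data. The sample size, effect size, and seed are arbitrary assumptions.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm

kx, ke = rdm.split(rdm.PRNGKey(0))
x = rdm.normal(kx, shape=(500,))
y = 0.3 * x + rdm.normal(ke, shape=(500,))

def sse(beta):
  # half the sum of squared errors for the no-intercept model y = x * beta
  return 0.5 * jnp.sum((y - x * beta) ** 2)

beta_hat = (x @ y) / (x @ x)            # closed form used in the GWAS cells
grad_at_hat = jax.grad(sse)(beta_hat)    # ~0 up to float32 rounding
curvature = jax.hessian(sse)(beta_hat)   # equals x @ x for this quadratic loss
print(grad_at_hat, curvature, x @ x)
```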


In [None]:
# inv(A) @ v: we are solving for the direction d in A d = v
# for symmetric positive-definite A, d = inv(A) @ v is also the solution to the
# optimization problem min_d 0.5 * d.T @ A @ d - d.T @ v (its gradient A d - v = 0)

# TODO: later
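For a symmetric positive-definite `A`, the quadratic `0.5 * d.T @ A @ d - d.T @ v` has gradient `A @ d - v`, so its unique minimizer satisfies `A d = v`. The sketch below checks this numerically. It reuses the matrix `A` from the first cell and runs plain gradient descent driven by `jax.grad`, then compares against `jnp.linalg.solve`. The vector `v`, the step size 0.1, and the iteration count are illustrative assumptions. The step must stay below `2 / (largest eigenvalue of A)` for gradient descent to converge.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[5., 1.], [1., 5.]])  # symmetric positive definite (eigenvalues 4 and 6)
v = jnp.array([1., 2.])

def quad(d):
  # 0.5 d^T A d - d^T v; its gradient is A d - v
  return 0.5 * d @ A @ d - d @ v

grad_quad = jax.grad(quad)

d = jnp.zeros(2)
step = 0.1  # assumed step size; needs step < 2 / 6 here
for _ in range(200):
  d = d - step * grad_quad(d)

d_solve = jnp.linalg.solve(A, v)  # direct solve of A d = v
print(d, d_solve)
```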