
## JAX and `jax.numpy`
`jax` is a Google-backed library for automatic differentiation of Python/NumPy code. It also delivers fast runtimes by "just-in-time" (JIT) compiling programs with XLA (Accelerated Linear Algebra), a compiler for linear-algebra workloads. Hence JAX = **J**IT + **A**utoDiff + **X**LA.

Let's practice importing JAX and using the `numpy` API implemented by JAX (`jax.numpy`). `numpy` is a Python library for n-dimensional arrays. Using JAX's implementation lets us take advantage of all of JAX's features, such as autodiff, JIT compilation, and running on GPUs/TPUs.
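To make the JIT and AutoDiff parts of the acronym concrete, here is a minimal sketch using `jax.grad` and `jax.jit`. The toy function `f` is purely illustrative and not part of the original notebook.

```python
import jax
import jax.numpy as jnp

# toy scalar-valued function: f(x) = sum_i x_i^2
def f(x):
    return jnp.sum(x ** 2)

# AutoDiff: jax.grad returns a new function that computes df/dx = 2x
grad_f = jax.grad(f)

# JIT: jax.jit traces f once per input shape/dtype and compiles it with XLA
f_jit = jax.jit(f)

x = jnp.arange(3.0)  # float array [0., 1., 2.]
print(f_jit(x))      # 0 + 1 + 4 = 5.0
print(grad_f(x))     # 2 * x = [0. 2. 4.]
```

Note that `jax.grad` only differentiates functions with a scalar output and floating-point inputs, which is why `x` is built with `jnp.arange(3.0)` rather than `jnp.arange(3)`.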

In [2]:
import jax
import jax.numpy as jnp
import jax.random as rdm

# let's practice some numpy tricks
x = jnp.arange(9)
y = jnp.ones(9)
print(f"x = {x} | y = {y}")

z = x + y
print(f"z = {z} | x + 1 = {x + 1}")

P = 4
i = jnp.eye(P)
a = 2 * jnp.ones(P)
print(f"i = {i} | a = {a}")

# is this mat/vec mult?
b = i * a
print(f"b = {b}")

A = jnp.array([[5., 1], [1, 5]])
a = 2 * jnp.ones(2)
print(f"A = {A}")
b = A * a
print(f"b = {b}")

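# nope! `*` is elementwise (a is broadcast across the rows of A), so b is a matrix;
# matrix-vector multiplication needs `@` or jnp.dot and returns a vector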
b = A @ a
print(f"b = {b}")
b = jnp.dot(A, a)
print(f"b = {b}")



x = [0 1 2 3 4 5 6 7 8] | y = [1. 1. 1. 1. 1. 1. 1. 1. 1.]
z = [1. 2. 3. 4. 5. 6. 7. 8. 9.] | x + 1 = [1 2 3 4 5 6 7 8 9]
i = [[1. 0. 0. 0.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 0. 0. 1.]] | a = [2. 2. 2. 2.]
b = [[2. 0. 0. 0.]
 [0. 2. 0. 0.]
 [0. 0. 2. 0.]
 [0. 0. 0. 2.]]
A = [[5. 1.]
 [1. 5.]]
b = [[10.  2.]
 [ 2. 10.]]
b = [12. 12.]
b = [12. 12.]
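The `i * a` and `A * a` results above come from NumPy-style broadcasting: `a` (shape `(2,)`) is treated as a row of shape `(1, 2)` and stretched over every row, so column `j` is scaled by `a[j]`. With the identity matrix this yields `2 * I`. A minimal sketch that checks both behaviors explicitly, reusing `A` and `a` from the cell above:

```python
import jax.numpy as jnp

A = jnp.array([[5., 1.], [1., 5.]])
a = 2 * jnp.ones(2)

# `*` is elementwise: a is broadcast to shape (1, 2), then across the rows,
# so each column j of A is multiplied by a[j]
elementwise = A * a
assert jnp.allclose(elementwise, A * a[None, :])

# `@` (or jnp.dot) is a true matrix-vector product: b_i = sum_j A_ij * a_j,
# i.e. the row sums of the elementwise product
matvec = A @ a
assert jnp.allclose(matvec, jnp.sum(A * a, axis=1))
print(matvec)  # [12. 12.]
```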


In [3]:
# random variables in JAX
key = rdm.PRNGKey(0)

key, y_key = rdm.split(key)
N = 5
mu_y = 50
std_y = 100
y = mu_y + std_y * rdm.normal(y_key, shape=(N,)) # y_i ~ N(mu_y, std_y^2)
print(f"y = {y}")



y = [ -95.8194  -154.7044   254.73392  166.84094  -47.58364]
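Unlike NumPy's global random state, JAX's PRNG is stateless: randomness is fully determined by an explicit key. That is why the cell above calls `rdm.split` before drawing. A minimal sketch of the two rules this relies on:

```python
import jax.numpy as jnp
import jax.random as rdm

key = rdm.PRNGKey(0)

# rule 1: the same key always produces the same draws
d1 = rdm.normal(key, shape=(3,))
d2 = rdm.normal(key, shape=(3,))
assert jnp.array_equal(d1, d2)

# rule 2: for fresh randomness, split the key and use each subkey only once
key, sub1, sub2 = rdm.split(key, 3)
e1 = rdm.normal(sub1, shape=(3,))
e2 = rdm.normal(sub2, shape=(3,))
assert not jnp.array_equal(e1, e2)
```

Reusing a key by accident silently produces correlated "random" numbers, so the `key, subkey = rdm.split(key)` pattern is worth making a habit.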

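The next cell simulates one variant and a phenotype, then tests for association with ordinary least squares. As a worked summary of what that code computes ($f$ is `freq`, $h^2_g$ is `h2g`, and $y$ is mean-centered, so no intercept is fit):

$$
\begin{aligned}
x_i &= h_{i1} + h_{i2} - 2f, \qquad h_{ij} \sim \mathrm{Bernoulli}(f) \\
y_i &= x_i \beta + \epsilon_i, \qquad \epsilon_i \sim \mathcal{N}(0, \sigma^2_e) \\
h^2_g &= \frac{\sigma^2_g}{\sigma^2_g + \sigma^2_e} \implies \sigma^2_e = \left(\frac{1}{h^2_g} - 1\right)\sigma^2_g \\
\hat\beta &= (x^\top x)^{-1} x^\top y, \qquad \hat\sigma^2_e = \frac{1}{N}\sum_{i=1}^{N} (y_i - x_i\hat\beta)^2 \\
\widehat{\mathrm{se}}(\hat\beta) &= \sqrt{\hat\sigma^2_e / x^\top x}, \qquad z = \hat\beta / \widehat{\mathrm{se}}(\hat\beta), \qquad p = 2\,\Phi(-|z|)
\end{aligned}
$$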

In [4]:
import jax.scipy.stats as stats

N = 500_000

key = rdm.PRNGKey(0)

# simulate genotype
freq = 0.1
key, h1_key, h2_key = rdm.split(key, 3)
h1 = rdm.bernoulli(h1_key, freq, shape=(N,)).astype(int) 
h2 = rdm.bernoulli(h2_key, freq, shape=(N,)).astype(int) 
x = h1 + h2
x = x - 2 * freq

# simulate phenotype as a function of genotype
h2g = 1e-4
key, b_key = rdm.split(key)
beta = rdm.normal(b_key)
g = x * beta
s2g = jnp.var(g)
s2e = ((1 / h2g) - 1) * s2g # h2g = s2g / (s2g + s2e) => s2e = (1 / h2g - 1) * s2g

# phenotype = genetic component + env noise
key, y_key = rdm.split(key)
y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
y = y - jnp.mean(y)

#print(f"y = {y}")

beta_hat =  (x.T @ y) / (x.T @ x) # (x.T x)^-1 x.T y
s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
se_beta = jnp.sqrt(s2e_hat / (x.T @ x))

print(f" beta = {beta} | hat(beta) = {beta_hat} | se(hat(beta)) = {se_beta}")
z = beta_hat / se_beta
print(f"z = {z} | p-value = {2*stats.norm.cdf(-jnp.fabs(z))}")

 beta = -0.1399441361427307 | hat(beta) = -0.1588948369026184 | se(hat(beta)) = 0.019815918058156967
z = -8.018545150756836 | p-value = 1.0700497342445428e-15
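As a sanity check on the closed-form estimator, the single-variant OLS estimate should match a general least-squares solver applied to an `(N, 1)` design matrix. A minimal sketch on a smaller toy simulation; the sizes, seed, and true effect `beta = 0.5` are hypothetical choices, not the values used above:

```python
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats

# toy simulation (hypothetical N, seed, and effect size)
key = rdm.PRNGKey(1)
key, x_key, e_key = rdm.split(key, 3)
N, freq, beta = 1_000, 0.1, 0.5

# two Bernoulli haplotypes per person, summed and centered at 2 * freq
h = rdm.bernoulli(x_key, freq, shape=(2, N)).astype(jnp.float32)
x = h.sum(axis=0) - 2 * freq
y = beta * x + rdm.normal(e_key, shape=(N,))
y = y - jnp.mean(y)

# closed-form single-variant OLS, as in the cell above
beta_hat = (x @ y) / (x @ x)
s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
se_beta = jnp.sqrt(s2e_hat / (x @ x))
z = beta_hat / se_beta
p = 2 * stats.norm.cdf(-jnp.abs(z))

# the same estimate from jnp.linalg.lstsq on an (N, 1) design matrix
beta_lstsq, *_ = jnp.linalg.lstsq(x[:, None], y)
assert jnp.allclose(beta_hat, beta_lstsq[0], rtol=1e-3)
print(f"beta_hat = {beta_hat} | se = {se_beta} | z = {z} | p = {p}")
```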


## LAX and functional design patterns

`jax.lax.scan` applies a function along the leading axis of an array (or pytree of arrays) while carrying state from one step to the next. The pseudocode for `scan` looks like this:
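```python
import numpy as np

def scan(func, init, xs, length=None):
  # with no xs, loop `length` times, passing None as each element
  if xs is None:
    xs = [None] * length
  carry = init
  ys = []
  for x in xs:
    # func returns the updated carry and this step's output
    carry, y = func(carry, x)
    ys.append(y)
  # final carry, plus all per-step outputs stacked along a new leading axis
  return carry, np.stack(ys)
```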
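A minimal concrete use of the real `jax.lax.scan`: a running (cumulative) sum, where the carry is the total so far and each step also emits that total as its output. The step function `cumsum_step` is an illustrative name, not part of the notebook.

```python
import jax.numpy as jnp
import jax.lax as lax

# step function: carry is the running total, x is the current element
def cumsum_step(carry, x):
    new_carry = carry + x
    return new_carry, new_carry  # (next state, per-step output)

xs = jnp.array([1., 2., 3., 4.])
init = jnp.zeros(())  # float32 scalar, matching the dtype of xs
total, partial_sums = lax.scan(cumsum_step, init, xs)
print(total)         # final carry: 1 + 2 + 3 + 4 = 10
print(partial_sums)  # stacked per-step outputs: running totals 1, 3, 6, 10
```

The carry must keep the same shape and dtype on every step, which is why `init` is an explicit float32 scalar rather than a Python int.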

Let's see how this can be useful for GWAS...

In [9]:
import jax.lax as lax

# simulate geno + pheno
N = 10_000
P = 1000

key = rdm.PRNGKey(0)

# simulate genotype
freq = 0.1
key, h1_key, h2_key = rdm.split(key, 3)
h1 = rdm.bernoulli(h1_key, freq, shape=(N,P)).astype(int) 
h2 = rdm.bernoulli(h2_key, freq, shape=(N,P)).astype(int) 
X = h1 + h2
X = X - 2 * freq

# simulate phenotype as a function of genotype
h2g = 1e-4
key, b_key = rdm.split(key)
beta = rdm.normal(b_key, shape=(P,))
G = X @ beta
s2g = jnp.var(G)
s2e = ((1 / h2g) - 1) * s2g # h2g = s2g / (s2g + s2e) => s2e = (1 / h2g - 1) * s2g

# phenotype = genetic component + env noise
key, y_key = rdm.split(key)
y = G + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))
y = y - jnp.mean(y)

# lets write a function that performs OLS between a single variant and y, and
# reports the effect size estimate, its SE, and pvalue

# scan threads a 'carry' (state) through every step. here the state never
# changes: all we need to do is keep passing the phenotype y along, while
# each step consumes one variant (one row of X.T, i.e. one column of X)
def ols_scan(y, x):
  xtx = x.T @ x
  beta_hat =  (x.T @ y) / (xtx) # (x.T x)^-1 x.T y
  s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
  se_beta = jnp.sqrt(s2e_hat / (xtx))
  p_val = 2*stats.norm.cdf(-jnp.fabs(beta_hat / se_beta))

  # scan requires we return updated state (i.e. same y)
  # along with the result
  return y, jnp.array([beta_hat, se_beta, p_val])

_, gwas_res = lax.scan(ols_scan, y, X.T)

# print first 5 results...
print("BETA\tSE\tPVal")
print(gwas_res[:5])

BETA	SE	PVal
[[-11.807892    31.55952      0.7082951 ]
 [  8.070707    31.307985     0.7965734 ]
 [ 24.326462    31.32242      0.43736708]
 [ -0.281418    32.09645      0.9930043 ]
 [ 14.883338    31.868835     0.64048654]]
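Because each variant's regression is independent and the carry never changes, `lax.scan` acts here essentially as a loop. A swapped-in alternative (not what the notebook uses) is `jax.vmap`, which vectorizes the per-variant function over the columns of `X`. The sketch below, with a hypothetical helper `ols_single` and hypothetical toy sizes, checks that both give the same table:

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.lax as lax
import jax.scipy.stats as stats

def ols_single(y, x):
    """Hypothetical helper: single-variant OLS -> [beta_hat, se, p-value]."""
    xtx = x @ x
    beta_hat = (x @ y) / xtx
    s2e_hat = jnp.mean((y - x * beta_hat) ** 2)
    se_beta = jnp.sqrt(s2e_hat / xtx)
    p_val = 2 * stats.norm.cdf(-jnp.abs(beta_hat / se_beta))
    return jnp.array([beta_hat, se_beta, p_val])

# toy data (hypothetical sizes): centered genotypes X (N, P) and phenotype y
key = rdm.PRNGKey(0)
x_key, y_key = rdm.split(key)
N, P, freq = 500, 20, 0.1
X = rdm.bernoulli(x_key, freq, shape=(2, N, P)).sum(axis=0) - 2 * freq
y = rdm.normal(y_key, shape=(N,))
y = y - jnp.mean(y)

# scan version: visit variants one at a time, y is the unchanged carry
_, res_scan = lax.scan(lambda c, x: (c, ols_single(c, x)), y, X.T)

# vmap version: map over columns of X (in_axes=1), share y (in_axes=None)
res_vmap = jax.vmap(ols_single, in_axes=(None, 1))(y, X)

assert res_scan.shape == res_vmap.shape == (P, 3)
assert jnp.allclose(res_scan, res_vmap, rtol=1e-4, atol=1e-5)
```

The trade-off is memory versus parallelism: `vmap` materializes batched intermediates (e.g. an `N x P` residual matrix) but processes all variants at once, while `scan` keeps only one variant's intermediates alive per step.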


## Pandas

`pandas` is a Python library for tabular data structures (data frames, similar to R's `data.frame`). Let's convert our GWAS results to a `DataFrame` for easier manipulation.

In [10]:
import pandas as pd

df_gwas = pd.DataFrame(gwas_res, columns=["BETA", "SE", "PVal"])
print(df_gwas.head())

         BETA         SE        PVal
0  -11.807892   31.55952   0.7082951
1    8.070707  31.307985   0.7965734
2   24.326462   31.32242  0.43736708
3   -0.281418   32.09645   0.9930043
4   14.883338  31.868835  0.64048654
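Once the results are in a `DataFrame`, common GWAS post-processing is a few lines of pandas. A minimal sketch, which for self-containment copies the five rows printed above; on the full results the same calls apply to `df_gwas`. The `0.5` cutoff is deliberately loose and hypothetical, chosen only so the toy rows produce a hit.

```python
import numpy as np
import pandas as pd

# the first five rows of the GWAS output printed above
df = pd.DataFrame(
    {
        "BETA": [-11.807892, 8.070707, 24.326462, -0.281418, 14.883338],
        "SE": [31.55952, 31.307985, 31.32242, 32.09645, 31.868835],
        "PVal": [0.7082951, 0.7965734, 0.43736708, 0.9930043, 0.64048654],
    }
)

# derived columns: Wald z-score and -log10 p-value (the Manhattan-plot scale)
df["Z"] = df["BETA"] / df["SE"]
df["NEG_LOG10_P"] = -np.log10(df["PVal"])

# sort variants from most to least significant
df_sorted = df.sort_values("PVal")
print(df_sorted)

# keep variants below a (hypothetical, deliberately loose) p-value threshold
hits = df[df["PVal"] < 0.5]
print(hits)
```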
