<a href="https://colab.research.google.com/github/USCbiostats/PM520/blob/main/Lab_2_Sampling.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import jax
import jax.numpy as jnp
import jax.random as rdm

In [None]:
# random variables in JAX
key = rdm.PRNGKey(0)

key, y_key = rdm.split(key)
N = 500
mu_y = 50
std_y = 100
y = mu_y + std_y * rdm.normal(y_key, shape=(N,)) # y_i ~ N(mu_y, std_y^2)
print(f"mean(y) = {jnp.mean(y)}")

P = 100
key, x_key = rdm.split(key)
X = rdm.normal(x_key, shape=(N,P))
print(f"shape(x) = {X.shape}")

mean(y) = 55.9890251159668
shape(x) = (500, 100)


# Simulations under a linear model
Goal: given $n=500$ and $p=100$, simulate a random normal matrix $X$ of size $n \times p$, simulate $p$ effect sizes $\beta$ from a standard normal distribution, and compute $g = X \beta$ and $s^2_g := \text{var}(g)$. Lastly, compute
$y = g + \epsilon$, where $\epsilon_i \sim N(0, \sigma^2_e)$, $\sigma^2_e := \frac{1 - r^2}{r^2} \cdot s^2_g$, and $r^2 = 0.5$.
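
To see why this choice of $\sigma^2_e$ targets the desired $r^2$, here is a short derivation. It assumes $g$ and $\epsilon$ are uncorrelated, which holds in expectation but only approximately in any finite sample:

$$
\frac{s^2_g}{s^2_g + \sigma^2_e}
= \frac{s^2_g}{s^2_g + \frac{1 - r^2}{r^2} s^2_g}
= \frac{1}{1 + \frac{1 - r^2}{r^2}}
= r^2
$$

With $r^2 = 0.5$ this reduces to $\sigma^2_e = s^2_g$, so the genetic and environmental components contribute equally to var($y$).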

In [None]:
# our PRNG seed to initialize state for sampling and replication
seed = 0

# defined quantities
N = 500
P = 100
r2 = 0.5

# create initial key
key = rdm.PRNGKey(seed)

key, x_key = rdm.split(key)
X = rdm.normal(x_key, shape=(N, P))

key, b_key = rdm.split(key)
beta = rdm.normal(b_key, shape=(P,))

# g = jnp.dot(X, beta)
g = X @ beta
s2g = jnp.var(g)

# back out what s2e is, such that s2g / (s2g + s2e) == r2 (i.e., the target h2)
s2e = (1 - r2) / r2 * s2g
key, y_key = rdm.split(key)

# add env noise to g, but scale such that var(e) == s2e
y = g + jnp.sqrt(s2e) * rdm.normal(y_key, shape=(N,))

# compute basic stats and sanity check against specified h2
print(f"mean(y) = {jnp.mean(y)} | var(y) = {jnp.var(y)}")
print(f"hat(h2) = {s2g / jnp.var(y)} | true h2 = {r2}")

mean(y) = -0.3141368627548218 | var(y) = 185.59474182128906
hat(h2) = 0.4755840599536896 | true h2 = 0.5


Let's wrap the above functionality into a function called `sim_data` that takes as its arguments
  1. a PRNGKey
  2. N
  3. P
  4. h2

and returns `y`, `X`, `beta`.

In [None]:
def sim_data(key, N, P, h2):
  """
  Simulate a trait under a linear model;
  """
  pass

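
One possible way to fill in `sim_data` is sketched below. It simply repackages the simulation cell from earlier; the three-way key split and the name `e_key` are choices made here for illustration, so the draws will not match the earlier output exactly.

```python
import jax.numpy as jnp
import jax.random as rdm


def sim_data(key, N, P, h2):
    """Simulate (y, X, beta) under y = X @ beta + e with var(g) / var(y) ~= h2."""
    # one independent subkey per random draw (assumed splitting scheme)
    x_key, b_key, e_key = rdm.split(key, 3)

    X = rdm.normal(x_key, shape=(N, P))
    beta = rdm.normal(b_key, shape=(P,))

    g = X @ beta
    s2g = jnp.var(g)

    # choose s2e such that s2g / (s2g + s2e) == h2
    s2e = (1 - h2) / h2 * s2g
    y = g + jnp.sqrt(s2e) * rdm.normal(e_key, shape=(N,))
    return y, X, beta


y, X, beta = sim_data(rdm.PRNGKey(0), N=500, P=100, h2=0.5)
print(f"hat(h2) = {jnp.var(X @ beta) / jnp.var(y)}")
```

Because the function takes the key as an argument and never touches global state, calling it twice with the same key reproduces the same data set.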


In [None]:
def compute_stat(y, x):
  """
  Compute OLS between y and x;
  Return hat{beta}_x and se(hat{beta}_x)
  """
  pass
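
A minimal sketch of `compute_stat` follows, assuming a simple regression of `y` on a single `x` with an intercept (the stub does not say whether an intercept is intended, so that part is an assumption). Centering both variables absorbs the intercept, and the residual variance uses $n - 2$ degrees of freedom.

```python
import jax.numpy as jnp


def compute_stat(y, x):
    """Simple OLS of y on x (with intercept); returns (beta_hat, se)."""
    n = y.shape[0]
    # centering both variables absorbs the intercept
    xc = x - jnp.mean(x)
    yc = y - jnp.mean(y)

    sxx = jnp.sum(xc ** 2)
    beta_hat = jnp.sum(xc * yc) / sxx

    # residual variance with n - 2 degrees of freedom (slope + intercept)
    resid = yc - xc * beta_hat
    s2 = jnp.sum(resid ** 2) / (n - 2)
    se = jnp.sqrt(s2 / sxx)
    return beta_hat, se


# toy check on a noiseless line y = 1 + 2x: slope should be 2, se close to 0
x = jnp.arange(10.0)
beta_hat, se = compute_stat(1.0 + 2.0 * x, x)
print(beta_hat, se)
```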

# Compute stats!
Let's loop over each column in `X` and compute its _marginal_ statistics using the above function.

In [None]:
# import our stats library
import jax.scipy.stats as stats

beta_hats = []
ses = []

for i in range(P):
  beta_i, se_i = compute_stat(y, X[:, i])
  beta_hats.append(beta_i)
  ses.append(se_i)

beta_hats = jnp.array(beta_hats)
ses = jnp.array(ses)
zscores = beta_hats / ses
pvalues = 2 * stats.norm.sf(jnp.abs(zscores))
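
The Python loop over columns can also be expressed with `jax.vmap`, which maps a function over an array axis. The sketch below assumes a `compute_stat` along the lines of a simple OLS with intercept (an illustrative implementation, not necessarily the intended solution) and uses small toy data so that it runs on its own.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm
import jax.scipy.stats as stats


def compute_stat(y, x):
    """Illustrative simple OLS of y on x (with intercept)."""
    n = y.shape[0]
    xc = x - jnp.mean(x)
    yc = y - jnp.mean(y)
    sxx = jnp.sum(xc ** 2)
    beta_hat = jnp.sum(xc * yc) / sxx
    resid = yc - xc * beta_hat
    se = jnp.sqrt(jnp.sum(resid ** 2) / (n - 2) / sxx)
    return beta_hat, se


# map compute_stat over the columns of X (axis 1) while holding y fixed
compute_marginals = jax.vmap(compute_stat, in_axes=(None, 1))

# toy data, only to make the example runnable
key = rdm.PRNGKey(0)
key, x_key, e_key = rdm.split(key, 3)
X = rdm.normal(x_key, shape=(200, 5))
y = X[:, 0] + rdm.normal(e_key, shape=(200,))

beta_hats, ses = compute_marginals(y, X)
zscores = beta_hats / ses
pvalues = 2 * stats.norm.sf(jnp.abs(zscores))
```

Since `compute_stat` returns a tuple, the mapped function returns a tuple of arrays, each with one entry per column of `X`.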

# Permutation tests wow!
Let's write a function `perm` that takes as input
  1. PRNGKey,
  2. y,
  3. X,
  4. zscores, the observed z-scores
  5. B, the number of permutations to perform

and returns empirical/permutation p-values for each effect.

In [1]:
def compute_marginals(y, X):
  """
  Compute the marginal statistics for each column in X against y;
  Return tuple of (beta_hat, ses) where each is a JAX array
  """
  pass

def perm(key, y, X, zscores, B):
  """
  Compute a permutation test for each marginal effect over X;
  Returns the permutation/empirical p-value for each marginal effect
  as a single jax array.
  """
  pass

# split key for fun
key, p_key = rdm.split(key)

# compute marginals under observed
beta_hat, ses = compute_marginals(y, X)

# compute zscores and pvalues under normality
zscores = beta_hat / ses
pvalues = 2 * stats.norm.sf(jnp.abs(zscores))

# compute empirical pvalues
B = 100
perm(p_key, y, X, zscores, B)

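
One way the two stubs could be filled in is sketched below. Several choices here are assumptions rather than the intended solution: the marginal regressions include an intercept, the test is two-sided on $|z|$, the null is generated by shuffling `y`, and the p-value uses an add-one correction so it is never exactly zero. The toy data exist only to make the block runnable.

```python
import jax
import jax.numpy as jnp
import jax.random as rdm


def compute_marginals(y, X):
    """Marginal OLS (with intercept) of y on each column of X."""
    n = y.shape[0]
    Xc = X - jnp.mean(X, axis=0)
    yc = y - jnp.mean(y)
    sxx = jnp.sum(Xc ** 2, axis=0)
    beta_hat = Xc.T @ yc / sxx
    # residuals of each single-column fit, one column per predictor
    resid = yc[:, None] - Xc * beta_hat
    s2 = jnp.sum(resid ** 2, axis=0) / (n - 2)
    return beta_hat, jnp.sqrt(s2 / sxx)


def perm(key, y, X, zscores, B):
    """Two-sided permutation p-values, one per column of X."""
    def one_perm(k):
        # shuffling y breaks any association with X, giving one null draw
        y_perm = rdm.permutation(k, y)
        b, s = compute_marginals(y_perm, X)
        return jnp.abs(b / s)

    # one subkey per permutation; result has shape (B, P)
    null_abs_z = jax.vmap(one_perm)(rdm.split(key, B))
    exceed = jnp.sum(null_abs_z >= jnp.abs(zscores), axis=0)
    # add-one correction keeps the p-value strictly positive
    return (exceed + 1) / (B + 1)


# toy data: only the first column has a true effect
key = rdm.PRNGKey(0)
key, x_key, e_key, p_key = rdm.split(key, 4)
X = rdm.normal(x_key, shape=(200, 5))
y = X[:, 0] + rdm.normal(e_key, shape=(200,))

beta_hat, ses = compute_marginals(y, X)
zscores = beta_hat / ses
perm_pvalues = perm(p_key, y, X, zscores, B=100)
print(perm_pvalues)
```

With `B` permutations the smallest attainable p-value is `1 / (B + 1)`, so a larger `B` is needed to resolve very small p-values.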