<a href="https://colab.research.google.com/github/USCbiostats/PM570-Colab/blob/main/Lecture-4.PopStructureGWAS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!rm -r /content/PM570-Colab/
!git clone https://github.com/USCbiostats/PM570-Colab.git
!pip install pandas_plink
!wget https://github.com/mancusolab/sushie/raw/main/data/plink/EUR.bed
!wget https://github.com/mancusolab/sushie/raw/main/data/plink/EUR.bim
!wget https://github.com/mancusolab/sushie/raw/main/data/plink/EUR.fam

rm: cannot remove '/content/PM570-Colab/': No such file or directory
Cloning into 'PM570-Colab'...
remote: Enumerating objects: 191, done.[K
remote: Counting objects: 100% (191/191), done.[K
remote: Compressing objects: 100% (142/142), done.[K
remote: Total 191 (delta 98), reused 124 (delta 44), pack-reused 0[K
Receiving objects: 100% (191/191), 46.80 KiB | 1.80 MiB/s, done.
Resolving deltas: 100% (98/98), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pandas_plink
  Downloading pandas_plink-2.2.9-cp39-cp39-manylinux2010_x86_64.whl (100 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.4/100.4 KB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Collecting zstandard>=0.13.0
  Downloading zstandard-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m25.5 MB/s[0m eta [36m0:00:00[0m
Collecting Dep

In [2]:
import sys
sys.path.append('/content/PM570-Colab/')

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla
import jax.scipy.stats as stats

# lets make sure we're using 64bit precision to not lose accuracy
# in our GWAS results
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)

from sim import geno, trait
from util import gwas

N = 5000
P = 10_000
PROP = 0.1
H2G = 0.1

key = rdm.PRNGKey(0)
key, geno_key, trait_key = rdm.split(key, 3)

# simulate genotype w/o LD
X = geno.naive_sim_genotype(N, P, geno_key)

# center and standardized genotype
X = X - jnp.mean(X, axis=0)
X = X / jnp.std(X, axis=0)

# compute GRM
A = X @ X.T / P

# compute eigendecomposition of A = U @ D @ U.T
D, U = jnpla.eigh(A)
D = D.astype(float)
U = U.astype(float)

# simulate phenotype using genotype data
y = trait.naive_trait_sim(X, PROP, H2G, trait_key)
y = y - jnp.mean(y)
y = y / jnp.std(y)


def normal_h2g_likelihood(params: jnp.ndarray, y: jnp.ndarray, A: jnp.ndarray) -> float:
  """ evaluate the likelhood under the linear mixed model of
      y ~ N(0, A s2g + I s2e) =>
      y ~ N(0, V) for V = A s2g + I s2e

  Args:
    params: the variance components [s2g, s2e]
    y: phenotype
    A: GRM
  
  Returns:
    float: the neg log likelihood
  """
  n = len(y)
  V = params[0] * A + params[1] * jnp.eye(n)
  L = jnpla.cholesky(V)

  # rotate y to independent basis
  # inv(L) @ y => N(0, inv(L) @ V @ inv(L).T)
  #            =  N(0, inv(L) @ L @ L.T @ inv(L).T)
  #            =  N(0, I @ I) = N(0, I)
  y_r = jspla.cho_solve((L, True), y)

  return -jnp.sum(stats.norm.logpdf(y_r, loc=0., scale=1.))


  return _convert_element_type(operand, new_dtype, weak_type=False)


In [4]:
# let's use gradient descent to infer h2g
nll_vandg = jax.jit(jax.value_and_grad(normal_h2g_likelihood))
step_size = 1e-1
loss = 10000
max_iter = 10
tol = 1e-3

# init params
params = 0.5 * jnp.ones(2)
for idx in range(max_iter):
  loss_i, nllgrad = nll_vandg(params, y, A)
  print(f"Iter = {idx} | Params = {params} | nLL = {loss_i}")
  params = params - step_size * nllgrad
  if jnp.fabs(loss_i - loss) < tol:
    break
  loss = loss_i

print(f"Var components = {params}")

Iter = 0 | Params = [0.5 0.5] | nLL = 7953.708847219411
Iter = 1 | Params = [451.25347347 893.35299901] | nLL = 4594.69422967721


KeyboardInterrupt: ignored

In [6]:
EPS = 1e-5

def normal_h2g_likelihood_fast(params: jnp.ndarray, Uty: jnp.ndarray, D: jnp.ndarray) -> float:
  """ Implements the Fast-LMM idea from Lippert et al. Nat Meth 2011.
  
  Evaluate the likelhood under the linear mixed model of
      y ~ N(0, A s2g + I s2e) =>
      y ~ N(0, U D Ut s2g + I s2e); recall that inv(U) = Ut; recall U @ Ut = I
      Ut @ y ~ N(0, Ut [U D Ut s2g + I s2e ] U) =>
             ~ N(0, Ut U D Ut U s2g + Ut U s2e) =>
             ~ N(0, I D I s2g + I s2e) =>
             ~ N(0, D s2g + I s2e)


  Args:
    params: the variance components [s2g, s2e]
    Uty: phenotype rotated from eigenvectors of A
    D: Eigenvalues of A
  
  Returns:
    float: the neg log likelihood
  """
  s2g = params[0]
  s2e = params[1]

  v = s2g * D + s2e

  return -jnp.sum(stats.norm.logpdf(Uty, loc=0., scale=jnp.sqrt(v + EPS)))
  

# let's use gradient descent to infer h2g
nll_vandg = jax.jit(jax.value_and_grad(normal_h2g_likelihood_fast))
step_size = 1e-1
loss = 10000
max_iter = 10
tol = 1e-3

# init params
params = 0.5 * jnp.ones(2)
Uty = U.T @ y
for idx in range(max_iter):
  loss_i, nllgrad = nll_vandg(params, Uty, D)
  print(f"Iter = {idx} | Params = {params} | nLL = {loss_i}")
  params = params - step_size * nllgrad
  # keep parameters in valid variance space
  params = jnp.where(params < 0, 0.001, params)
  if jnp.fabs(loss_i - loss) < tol:
    break
  loss = loss_i

print(f"Var components = {params}")
print(f"h2g = {params[0] / sum(params)}")

Iter = 0 | Params = [0.5 0.5] | nLL = 7195.181142440659
Iter = 1 | Params = [1.00000000e-03 5.56020455e+01] | nLL = 14685.249574307632
Iter = 2 | Params = [1.00000000e-03 5.11867513e+01] | nLL = 14482.283585533994
Iter = 3 | Params = [1.00000000e-03 4.63981842e+01] | nLL = 14241.778088869592
Iter = 4 | Params = [1.00000000e-03 4.11262826e+01] | nLL = 13947.159367682083
Iter = 5 | Params = [1.00000000e-03 3.51953956e+01] | nLL = 13568.08246044856
Iter = 6 | Params = [1.00000000e-03 2.82942073e+01] | nLL = 13039.778653982601
Iter = 7 | Params = [1.00000000e-03 1.97710491e+01] | nLL = 12181.807890310214
Iter = 8 | Params = [1.00000000e-03 7.76643433e+00] | nLL = 10041.399937863378
Iter = 9 | Params = [0.001 0.001] | nLL = 1355381.4447830908
Var components = [52836981.01163648 82731541.8832707 ]
h2g = 0.3897437243053518


In [12]:
from functools import partial
from util.optimization import newton_cg

# init params
params = 0.5 * jnp.ones(2)
Uty = U.T @ y

loss_f = jax.jit(partial(normal_h2g_likelihood_fast, Uty=Uty, D=D))
_, loss, num_iter, params = newton_cg(loss_f, params, step_size, max_iter)


print(f"Number of iterations = {num_iter} | loss = {loss}")
print(f"Var components = {params}")
print(f"hat(h2g) = {params[0] / sum(params)}")
print(f"h2g = {H2G}")

Number of iterations = 3 | loss = 7089.511819831592
Var components = [0.09400057 0.90467083]
hat(h2g) = 0.09412562464453532
h2g = 0.1
