<a href="https://colab.research.google.com/github/USCbiostats/PM570-Colab/blob/main/Lecture-4.PopStructureGWAS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!rm -r /content/PM570-Colab/
!git clone https://github.com/USCbiostats/PM570-Colab.git
!pip install pandas_plink
!wget https://github.com/mancusolab/sushie/raw/main/data/plink/EUR.bed
!wget https://github.com/mancusolab/sushie/raw/main/data/plink/EUR.bim
!wget https://github.com/mancusolab/sushie/raw/main/data/plink/EUR.fam

rm: cannot remove '/content/PM570-Colab/': No such file or directory
Cloning into 'PM570-Colab'...
remote: Enumerating objects: 169, done.[K
remote: Counting objects: 100% (169/169), done.[K
remote: Compressing objects: 100% (127/127), done.[K
remote: Total 169 (delta 83), reused 111 (delta 37), pack-reused 0[K
Receiving objects: 100% (169/169), 35.19 KiB | 11.73 MiB/s, done.
Resolving deltas: 100% (83/83), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pandas_plink
  Downloading pandas_plink-2.2.9-cp38-cp38-manylinux2010_x86_64.whl (100 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.8/100.8 KB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
Collecting zstandard>=0.13.0
  Downloading zstandard-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (2.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.6/2.6 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
Collecting De

In [12]:
# let's start simple with no LD
import sys
sys.path.append('/content/PM570-Colab/')

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jnpla
import jax.random as rdm
import jax.scipy.linalg as jspla
import jax.scipy.stats as stats

# lets make sure we're using 64bit precision to not lose accuracy
# in our GWAS results
# again, this only works on startup!
from jax.config import config
config.update("jax_enable_x64", True)

from sim import geno, trait
from util import gwas

N = 5000
P = 10_000
PROP = 0.1
H2G = 0.1

key = rdm.PRNGKey(0)
key, geno_key, trait_key = rdm.split(key, 3)

# simulate genotype w/o LD
X = geno.naive_sim_genotype(N, P, geno_key)

# center and standardized genotype
X = X - jnp.mean(X, axis=0)
X = X / jnp.std(X, axis=0)

# compute GRM
A = X @ X.T / P

# compute eigendecomposition of A = U @ D @ U.T
D, U = jnpla.eig(A)
D = D.astype(float)
U = U.astype(float)

# simulate phenotype using genotype data
y = trait.naive_trait_sim(X, PROP, H2G, trait_key)
y = y - jnp.mean(y)
y = y / jnp.std(y)


def normal_h2g_likelihood(params: jnp.ndarray, y: jnp.ndarray, A: jnp.ndarray) -> float:
  """ evaluate the likelhood under the linear mixed model of
      y ~ N(0, A s2g + I s2e) =>
      y ~ N(0, V) for V = A s2g + I s2e

  Args:
    params: the variance components [s2g, s2e]
    y: phenotype
    A: GRM
  
  Returns:
    float: the neg log likelihood
  """
  n = len(y)
  V = params[0] * A + params[1] * jnp.eye(n)
  L = jnpla.cholesky(V)

  # rotate y to independent basis
  # inv(L) @ y => N(0, inv(L) @ V @ inv(L).T)
  #            =  N(0, inv(L) @ L @ L.T @ inv(L).T)
  #            =  N(0, I @ I) = N(0, I)
  y_r = jspla.cho_solve((L, True), y)

  return -jnp.sum(stats.norm.logpdf(y_r, loc=0., scale=1.))


  return _convert_element_type(operand, new_dtype, weak_type=False)


In [8]:
# let's use gradient descent to infer h2g
nll_vandg = jax.jit(jax.value_and_grad(normal_h2g_likelihood))
step_size = 1e-1
loss = 10000
max_iter = 10
tol = 1e-3

# init params
params = 0.5 * jnp.ones(2)
for idx in range(max_iter):
  loss_i, nllgrad = nll_vandg(params, y, A)
  print(f"Iter = {idx} | Params = {params} | nLL = {loss_i}")
  params = params - step_size * nllgrad
  if jnp.fabs(loss_i - loss) < tol:
    break
  loss = loss_i

print(f"Var components = {params}")

Iter = 0 | Params = [0.5 0.5] | nLL = 7953.708847219411
Iter = 1 | Params = [451.25347347 893.35299901] | nLL = 4594.69422967721
Iter = 2 | Params = [451.25347364 893.35299928] | nLL = 4594.694229677209
Iter = 3 | Params = [451.25347381 893.35299954] | nLL = 4594.694229677208
Iter = 4 | Params = [451.25347398 893.3529998 ] | nLL = 4594.694229677207


KeyboardInterrupt: ignored

In [21]:

def normal_h2g_likelihood_fast(params: jnp.ndarray, Uty: jnp.ndarray, D: jnp.ndarray) -> float:
  """ evaluate the likelhood under the linear mixed model of
      y ~ N(0, A s2g + I s2e) =>
      y ~ N(0, U D Ut s2g + I s2e); recall that inv(U) = Ut; recall U @ Ut = I
      Ut @ y ~ N(0, Ut [U D Ut s2g + I s2e ] U) =>
             ~ N(0, Ut U D Ut U s2g + Ut U s2e) =>
             ~ N(0, I D I s2g + I s2e) =>
             ~ N(0, D s2g + I s2e)


  Args:
    params: the variance components [s2g, s2e]
    Uty: phenotype rotated from eigenvectors of A
    D: Eigenvalues of A
  
  Returns:
    float: the neg log likelihood
  """
  V = params[0] * D + params[1]

  return -jnp.sum(stats.norm.logpdf(Uty, loc=0., scale=V))

# let's use gradient descent to infer h2g
nll_vandg = jax.jit(jax.value_and_grad(normal_h2g_likelihood_fast))
step_size = 1e-1
loss = 10000
max_iter = 100
tol = 1e-3

# init params
params = 0.5 * jnp.ones(2)
Uty = U.T @ y
for idx in range(max_iter):
  loss_i, nllgrad = nll_vandg(params, Uty, D)
  print(f"Iter = {idx} | Params = {params} | nLL = {loss_i}")
  params = params - step_size * nllgrad
  # keep parameters in valid variance space
  params = jnp.where(params < 0, 0.001, params)
  if jnp.fabs(loss_i - loss) < tol:
    break
  loss = loss_i

print(f"Var components = {params}")
print(f"h2g = {params[0] / sum(params)}")

Iter = 0 | Params = [0.5 0.5] | nLL = 7655.494771578404
Iter = 1 | Params = [ 12.84174073 331.76473174] | nLL = 33805.01263084504
Iter = 2 | Params = [ 11.41642052 330.31282274] | nLL = 33763.42653529282
Iter = 3 | Params = [  9.97655877 328.84888424] | nLL = 33721.06550475145
Iter = 4 | Params = [  8.52169326 327.3725797 ] | nLL = 33677.898562023955
Iter = 5 | Params = [  7.05133635 325.88355566] | nLL = 33633.89277097222
Iter = 6 | Params = [  5.56497294 324.38144044] | nLL = 33589.013063800645
Iter = 7 | Params = [  4.06205817 322.86584285] | nLL = 33543.22204853149
Iter = 8 | Params = [  2.54201501 321.33635054] | nLL = 33496.4797938436
Iter = 9 | Params = [  1.00423143 319.79252841] | nLL = 33448.743587957724
Iter = 10 | Params = [1.00000000e-03 3.18233917e+02] | nLL = 33408.66656057784
Iter = 11 | Params = [1.00000000e-03 3.16662766e+02] | nLL = 33383.92027398248
Iter = 12 | Params = [1.0000000e-03 3.1508382e+02] | nLL = 33358.92720577229
Iter = 13 | Params = [1.00000000e-03 3.13

In [36]:
beta = (X.T @ y) / len(y)
g = X @ beta

Array([-0.0181527 , -0.01034132,  0.00923672, ..., -0.02430825,
       -0.01662116,  0.00837149], dtype=float64)