 Copyright 2023 Google LLC.
   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at
       http://www.apache.org/licenses/LICENSE-2.0
   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

# bbviPRS Implementation Example

### Author: nickfurlotte@google.com
### Date: <INSERT DATE>

In this colab, we provide an example implementation of black box variational
inference for PRS (bbviPRS).

## Define constants

In [None]:
# Whether to use TPUs in the colab. If False, then will default to CPU.
USE_TPU = True
# The p-value threshold bbviprs uses to select SNPs from the sumstats results.
BBVIPRS_PVALUE_THRESH = 1e-2
# How many optimization steps to take.
BBVIPRS_OPTIMIZATION_STEPS = 200
# Learning rate for the bbviprs optimization.
BBVIPRS_LEARNING_RATE = 0.001

## Installs

In [None]:
!pip install bed-reader

## Imports

In [None]:
import contextlib 
from typing import Callable, Optional, Tuple
import warnings

import numpy as np
import pandas as pd
import bed_reader
import tensorflow as tf
import tensorflow_probability as tfp
import matplotlib.pyplot as plt

tfd = tfp.distributions
tfpl = tfp.layers

warnings.simplefilter('ignore')

## Initialize TPU

In [None]:
# Connect to and initialize the TPU system when USE_TPU is set.
# `resolver` is only bound inside this branch; the training cell reads it
# only under the same USE_TPU flag.
if USE_TPU:
  resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')
  tf.config.experimental_connect_to_cluster(resolver)
  # This is the TPU initialization code that has to be at the beginning.
  tf.tpu.experimental.initialize_tpu_system(resolver)
  # List the logical TPU devices that are now visible to TensorFlow.
  print('All devices: ', tf.config.list_logical_devices('TPU'))

## Download input files used in LDPred2 tutorial

Find the [tutorial here](https://privefl.github.io/bigsnpr/articles/LDpred2.html)

In [None]:
!wget https://github.com/privefl/bigsnpr/raw/master/data-raw/public-data3.zip
!unzip public-data3.zip

## Process data for analysis

We follow processing steps similar to the LDPred2 tutorial referenced above.
This essentially means that we harmonize the data between the sumstats
and the genotype data by matching the SNP sets and making sure that
SNPs and effect sizes match in direction. In addition, we perform a simple
Z-score normalization on the SNP matrix.

### Read bed and grab SNP matrix and phenotype

In [None]:
# Open the PLINK BED file that was unzipped into tmp-data/.
bed_file = bed_reader.open_bed('tmp-data/public-data3.bed')
# Load the genotype matrix into memory. Its columns are later labelled with
# `bed_file.sid`, i.e. one column per SNP.
snp_matrix = np.array(bed_file.read())
# Phenotype values stored alongside the genotypes, cast to float.
pheno = bed_file.pheno.astype(float)
print(snp_matrix.shape)

### Read sumstat file and reorder data so that SNPs match

In [None]:
sumstats = pd.read_csv('tmp-data/public-data3-sumstats.txt')
sumstats.head()

In [None]:
# rsids present in both the genotype data and the summary statistics.
# np.intersect1d returns them sorted and de-duplicated.
common_rsid = np.intersect1d(bed_file.sid, sumstats.rsid)
# Re-order the sumstats rows so that they follow `common_rsid`.
sumstat_common = sumstats.set_index('rsid').loc[common_rsid].reset_index()
# Select and re-order the genotype columns the same way.
snp_matrix_common = pd.DataFrame(snp_matrix, columns=bed_file.sid)[
    common_rsid
].values

# One sumstats row per genotype column; this fails if an rsid occurs more
# than once in the sumstats.
assert sumstat_common.shape[0] == snp_matrix_common.shape[1]

### Check and fix SNP direction

In [None]:
# Note: In BED files the het for the first allele is encoded as 1.
bed_snp_order = (
    pd.DataFrame(
        [bed_file.allele_1, bed_file.allele_2],
        columns=bed_file.sid,
        index=['a1', 'a0'],
    ).T.loc[common_rsid]
)[['a0', 'a1']]
sumstat_snp_order = sumstat_common[['rsid', 'a0', 'a1']].set_index('rsid')

# Sanity check that the rows are in the same order.
assert np.all(bed_snp_order.index == sumstat_snp_order.index)

### Check that we see the same numbers as reported by the LDPred2 tutorial.

"45,337 variants have been matched; 22,758 were flipped and 15,092 were reversed."

In [None]:
# Complementary base for each nucleotide, used to express an allele on the
# opposite strand.
_FLIP_MAP = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C'}


def check_snps(
    sumstat_order: Optional[pd.DataFrame] = None,
    bed_order: Optional[pd.DataFrame] = None,
):
  """Evaluate the two SNP encodings to decide if they need to be adjusted.

  The two tables are walked in lockstep, row by row. For each SNP:

  * If the sumstats and BED allele pairs share no allele at all, the BED
    alleles are complemented via `_FLIP_MAP` and the SNP counts as flipped.
  * If, after that optional flip, the `(a0, a1)` pairs are still not identical,
    the SNP counts as reversed and its rsid is recorded.

  The input frames are never modified: the flip is applied to a local copy of
  the allele pair rather than to the row yielded by pandas.

  Args:
    sumstat_order: Frame indexed by rsid with columns `a0` and `a1`. Defaults
      to the notebook-level `sumstat_snp_order`.
    bed_order: Frame with columns `a0` and `a1` whose rows are in the same
      order as `sumstat_order`. Defaults to the notebook-level `bed_snp_order`.

  Returns:
    A tuple `(flip_count, reverse_count, reverse_rsids)`.

  Raises:
    KeyError: If a SNP that needs flipping has an allele that is not one of
      A, C, G or T.
  """
  if sumstat_order is None:
    sumstat_order = sumstat_snp_order
  if bed_order is None:
    bed_order = bed_snp_order

  allele_columns = ['a0', 'a1']
  rows = zip(
      sumstat_order.index,
      sumstat_order[allele_columns].to_numpy(),
      bed_order[allele_columns].to_numpy(),
  )

  flip_count = 0
  reverse_rsids = []
  for rsid, sumstat_alleles, bed_alleles in rows:
    sumstat_pair = tuple(sumstat_alleles)
    bed_pair = tuple(bed_alleles)
    # No allele in common means the two sources report opposite strands.
    if not set(sumstat_pair) & set(bed_pair):
      flip_count += 1
      bed_pair = tuple(_FLIP_MAP[allele] for allele in bed_pair)
    # Anything that is not an exact (a0, a1) match needs its effect reversed.
    if sumstat_pair != bed_pair:
      reverse_rsids.append(rsid)
  return flip_count, len(reverse_rsids), reverse_rsids


flip_count, reverse_count, reverse_rsids = check_snps()

print(f'Flip count: {flip_count}, Reverse count: {reverse_count}')

Cool, those numbers match. Flipping doesn't change the direction of effect;
only reversing does, so we kept track of the SNPs that need to be reversed.
Then we simply change the direction of their effects in the sumstats.

In [None]:
# Boolean mask over the sumstats rows marking the SNPs flagged for reversal.
# np.isin replaces the deprecated np.in1d; for 1-D input they are equivalent.
reverse_mask = np.isin(sumstat_common.rsid, reverse_rsids)
# Negate the effect size of every reversed SNP and leave the rest untouched.
sumstat_common = sumstat_common.assign(
    beta=lambda d: np.where(reverse_mask, -d.beta, d.beta)
)

### Perform simple normalization on SNP matrix

In [None]:
# Per-SNP (column-wise) mean and standard deviation of the genotype matrix.
snp_mean, snp_sd = (
    snp_matrix_common.mean(axis=0),
    np.std(snp_matrix_common, axis=0),
)
# Z-score each SNP column. A column with zero standard deviation would
# divide by zero here.
snps = (snp_matrix_common - snp_mean) / snp_sd

# Should look like 0,1 or close.
# NOTE(review): the mean is taken per column (axis=0) but the variance is
# taken per row (axis=1); if the intent is to verify the column-wise
# standardization, axis=0 would give exactly 1 -- confirm.
snps.mean(axis=0).mean(), snps.var(axis=1).mean()

Usually you would also perform QC, but it is not required here.

## Prep data for running bbviPRS

### For bbviPRS-select we will set a p-value threshold and only fit the model over those SNPs.

In [None]:
snp_mask = sumstat_common.p <= BBVIPRS_PVALUE_THRESH
num_snps = snp_mask.sum()
print(f'Total number of SNPs selected for VI: {num_snps}')

In [None]:
snp_betas = sumstat_common.beta[snp_mask]
snp_betas_se = sumstat_common.beta_se[snp_mask]

In [None]:
# Keep only the normalized genotype columns of the selected SNPs.
# Note that this rebinds `snp_matrix`, which previously held the raw matrix
# read from the BED file.
snp_matrix = snps[:, snp_mask]
print(snp_matrix.shape)

### Compute LD matrix

In [None]:
ld_matrix = np.corrcoef(snp_matrix.T)
print(ld_matrix.shape)

## bbviPRS

### A little setup

In [None]:
# GWAS effect sizes divided by their standard errors.
betas_normed = snp_betas / snp_betas_se
# float32 tensors that ViLayer.loss_fn reads as module-level globals.
gwas_betas = tf.convert_to_tensor(betas_normed, dtype=tf.float32)
ld_matrix_tensor = tf.convert_to_tensor(ld_matrix, dtype=tf.float32)

### Define our prior and posterior functions

In [None]:
def get_vi_mixture_prior(
    size: int, scale1: float, scale2: float, mixture_prob: float
) -> tf.keras.Model:
  """Build a keras model that emits a two-component Normal mixture prior.

  Every SNP effect gets the same prior: with probability `mixture_prob` it is
  drawn from Normal(0, scale1), otherwise from Normal(0, scale2). The per-SNP
  mixtures are wrapped in `tfd.Independent` so that the result is a single
  distribution over the whole effect vector.

  Args:
    size: The number of SNP effect sizes.
    scale1: The standard deviation of the first normal component.
    scale2: The standard deviation of the second normal component.
    mixture_prob: The probability that an effect comes from the first
      component.

  Returns:
    A keras model whose output is the prior distribution. The value passed
    to the model is ignored.
  """
  num_components = 2

  def _repeat_per_snp(row):
    # One identical parameter row per SNP, stored as float32.
    return np.array([row for _ in range(size)]).astype(np.float32)

  probs = _repeat_per_snp([mixture_prob, 1.0 - mixture_prob])
  locs = _repeat_per_snp(np.zeros((num_components,)))
  scales = _repeat_per_snp([scale1, scale2])

  def build_distribution(_):
    per_snp_mixture = tfd.MixtureSameFamily(
        mixture_distribution=tfd.Categorical(probs=probs),
        components_distribution=tfd.Normal(loc=locs, scale=scales),
    )
    return tfd.Independent(per_snp_mixture, reinterpreted_batch_ndims=1)

  prior_layer = tfp.layers.DistributionLambda(build_distribution)
  return tf.keras.Sequential([prior_layer])

The function above returns a `keras` model which, when called, returns a distribution object. The distribution in this case is a mixture where
a fraction `mixture_prob` of the effects come from a Normal with mean zero and standard deviation `scale1`, while the remaining `1 - mixture_prob` come from a Normal with
mean zero and standard deviation `scale2`. We can use this to mimic the behavior of
the standard LDPred mixture prior. Here is an example.

It is a little strange to return a model that returns a distribution, but we do this to make the larger model building easier.

In [None]:
prior_model = get_vi_mixture_prior(num_snps, 0.5, 1e-7, 0.20)
sample_beta = prior_model(0).sample()

_ = plt.figure(figsize=(10, 8))
_ = plt.hist(sample_beta.numpy(), 100)
_ = plt.title(
    'Sample from a mixture prior where 80% of the effects are '
    'clustered around zero'
)

In the above plot, you can see that the majority of effects are effectively zero,
while the minority is spread far away from zero. As a result of this
induced distribution, the prior will enforce sparsity of non-zero effects.

In [None]:
def get_vi_posterior(
    size: int, beta_init: Optional[tf.Tensor] = None
) -> tf.keras.Model:
  """Build a keras model that emits the trainable surrogate posterior.

  The posterior is an independent Normal per SNP effect. Its parameters live
  in a single trainable variable of shape `[size, 2]`: column 0 holds the
  means and column 1 holds the unconstrained scales, which are passed through
  a softplus to keep them positive.

  Args:
    size: The number of SNP effect sizes.
    beta_init: Initial value for the posterior means. Defaults to zeros.

  Returns:
    A keras model whose output is the posterior distribution. The value
    passed to the model is ignored by the variable layer.
  """
  if beta_init is None:
    beta_init = tf.zeros(size)

  # softplus(log(expm1(1.0))) == 1.0, so every scale starts at one.
  unconstrained_unit_scale = np.log(np.expm1(1.0))
  parameter_initializer = tfp.layers.BlockwiseInitializer(
      [
          tf.keras.initializers.Constant(beta_init),
          tf.keras.initializers.Constant(unconstrained_unit_scale),
      ],
      sizes=[1, 1],
  )
  variational_parameters = tfp.layers.VariableLayer(
      shape=[size, 2],
      dtype=tf.float32,
      initializer=parameter_initializer,
  )

  def build_distribution(t):
    loc = t[..., 0]
    scale = tf.math.softplus(t[..., 1])
    return tfd.Independent(
        tfd.Normal(loc=loc, scale=scale),
        reinterpreted_batch_ndims=1,
    )

  return tf.keras.Sequential([
      variational_parameters,
      tfp.layers.DistributionLambda(build_distribution),
  ])

In [None]:
posterior_model = get_vi_posterior(num_snps)
sample_beta = posterior_model(0).sample()

_ = plt.figure(figsize=(10, 8))
_ = plt.hist(sample_beta.numpy(), 100)
_ = plt.title('Sample from an initial surrogate posterior that is Normal(m, V)')

### Create the model for optimization.

We first create a `keras.Layer` to compute the loss function and then we wrap
that inside of a larger `keras.Model` object. Again, this seems a little odd,
but this setup makes it easier to expand the model later to incorporate more
complexities (such as additional data sources).

In [None]:
class ViLayer(tf.keras.layers.Layer):
  """Keras layer that owns the prior, the posterior and the bbviPRS loss.

  The layer ignores its input. Calling it registers the loss from `loss_fn`
  via `add_loss` and returns the current posterior mean.

  `loss_fn` reads the notebook-level globals `ld_matrix_tensor`, `gwas_betas`,
  `_NUM_KL_APPROX_SAMPLES` and `_KL_WEIGHT`, so those must be defined before
  the layer is called.
  """

  def __init__(
      self,
      size: int,
      prior_mixture: float,
      prior_scale_1: float,
      prior_scale_2: float,
      beta_init: Optional[tf.Tensor] = None,
  ):
    """Create the prior and posterior sub-models.

    Args:
      size: The number of SNP effect sizes.
      prior_mixture: Probability that an effect comes from the first prior
        component.
      prior_scale_1: Standard deviation of the first prior component.
      prior_scale_2: Standard deviation of the second prior component.
      beta_init: Optional initial value for the posterior means.
    """
    super().__init__()
    self.prior_model = get_vi_mixture_prior(
        size, prior_scale_1, prior_scale_2, prior_mixture
    )
    self.posterior_model = get_vi_posterior(size, beta_init)
    # Per-SNP normalizer applied to the KL term in `loss_fn`.
    self.log_prob_norm = tf.convert_to_tensor(1.0 / size, dtype=tf.float32)
    self.size = size

  def loss_fn(self):
    r"""Implements a squared error loss with a VI objective.

    We use a generative model approach to optimize the posterior distribution.
    A single $\beta$ is sampled from the current posterior and mapped to
    $\tilde{\beta} = LD\_Matrix \cdot \beta$. The first loss term is the sum
    of squared differences between $\tilde{\beta}$ and the GWAS effects,
    divided by the number of SNPs.

    The second term is a Monte Carlo estimate of the KL divergence from the
    posterior to the prior: the mean of
    $\log q(\beta) - \log p(\beta)$ over `_NUM_KL_APPROX_SAMPLES` posterior
    samples. It is scaled by both `_KL_WEIGHT` and `self.log_prob_norm`.

    Returns:
      A scalar tensor holding the total loss.
    """
    prior = self.prior_model(0)
    posterior = self.posterior_model(0)
    beta_sample = posterior.sample()
    beta_transform = tf.tensordot(ld_matrix_tensor, beta_sample, 1)
    kl_samples = posterior.sample(_NUM_KL_APPROX_SAMPLES)
    reconstruction_loss = tf.divide(
        tf.reduce_sum(tf.square(tf.subtract(gwas_betas, beta_transform))),
        self.size,
    )
    kl_estimate = tf.reduce_mean(
        posterior.log_prob(kl_samples) - prior.log_prob(kl_samples)
    )
    return reconstruction_loss + tf.multiply(
        _KL_WEIGHT,
        tf.multiply(self.log_prob_norm, kl_estimate),
    )

  def call(self, inputs):
    """Register the VI loss and return the posterior mean; `inputs` is unused."""
    self.add_loss(self.loss_fn())
    return self.posterior_model(0).mean()

We found it useful to initialize the posterior mean with a reasonable value.
One obvious choice is the infinitesimal model solution.

In [None]:
def compute_inf_model_posterior(
    ld_matrix: tf.Tensor, beta: tf.Tensor, scale: float, diag: bool = False
) -> Tuple[tf.Tensor, tf.Tensor]:
  r"""Compute mean and variance for the infinitesimal model posterior.

  Given a GWAS effect size vector beta and assuming a normal prior on the
  true underlying effect size ($N(0, scale**2)$), the posterior
  P(SNP effect sizes | GWAS effect sizes) is $N(\mu, \Sigma)$ with
  $\Sigma = (LDMatrix + I*1/scale**2)^{-1}$ and $\mu = \Sigma \beta$.

  When `diag` is True the LD matrix is treated as the identity, so the
  posterior variance collapses to the scalar $1 / (1 + 1/scale**2)$ and the
  mean is that scalar times beta.

  Args:
    ld_matrix: The NxN matrix of SNP correlations.
    beta: A vector of length N representing the GWAS effects. A trailing
      singleton dimension is squeezed away.
    scale: The stddev of the prior distribution - N(0, scale**2).
    diag: Whether the LDMatrix is a diagonal matrix.

  Returns:
    The mean and variance of the analytical posterior. The mean is a tensor
    of shape [N]. The variance is of shape [N, N], or a scalar tensor when
    `diag` is True.
  """
  size = ld_matrix.shape[0]
  # Sometimes beta is of shape [N], sometimes it is of shape [N, 1].
  beta = tf.squeeze(beta)

  if diag:
    shrinkage = tf.convert_to_tensor(
        1.0 / (1.0 + 1 / scale**2), dtype=tf.float32
    )
    return shrinkage * beta, shrinkage

  prior_precision = tf.linalg.diag(tf.ones(size) * 1 / scale**2)
  posterior_cov = tf.linalg.inv((ld_matrix + prior_precision))
  posterior_mean = tf.linalg.matvec(posterior_cov, beta)
  return posterior_mean, posterior_cov

Since our dataset is small we can easily look at a range of shrinkage parameters.

In [None]:
# GWAS sample size, taken from the first sumstats row (positional lookup so it
# does not depend on the index labels).
sumstat_samples_size = sumstat_common['N'].iloc[0]

# Each value is listed once; the original grid contained 0.1 twice (as 1e-1
# and 0.1), which fitted and printed the same model twice.
for h in [1e-2, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 1.0]:
  # Prior stddev implied by this value of h.
  scale = np.sqrt(num_snps / (sumstat_samples_size * h))

  inf_post_mean, inf_post_var = compute_inf_model_posterior(
      ld_matrix_tensor, gwas_betas, scale=scale.astype(np.float32)
  )
  # NOTE: `beta_init` stays bound after the loop (to the h=1.0 solution) and
  # is what the training loop later uses to initialize the posterior mean.
  beta_init = inf_post_mean

  # Undo the division by the standard error before scoring individuals.
  snp_effects = (beta_init * snp_betas_se)[:, None]
  prs = np.dot(snp_matrix, snp_effects).flatten()
  pearson_corr = np.corrcoef(prs, pheno)[0, 1]

  print(f'Pearson correlation for inf model (h={h}): {pearson_corr: 0.4f}.')

### Optimization params

In [None]:
# Weight on the KL term; read as a global by ViLayer.loss_fn.
_KL_WEIGHT = tf.convert_to_tensor(1.0 / num_snps, dtype=tf.float32)
# Number of posterior samples ViLayer.loss_fn draws to estimate the KL term.
_NUM_KL_APPROX_SAMPLES = 50

# Hyperparameter grid for the prior: candidate values for the first scale,
# a fixed second scale, and log-spaced mixture probabilities from 1e-5 to 1e-2.
scales_1 = np.arange(1e-5, 1.0, step=0.05)
scale_2 = 1e-7
mixture_probs = np.logspace(-5, -2, 5)

print(f'Total number of models to fit: {len(scales_1) * len(mixture_probs)}.')

### Training Loop

In this case, we don't have any input to the model, so we create a mock input to
satisfy the Keras API. Additionally, our loss function is completely encapsulated
in the `ViLayer`, so we provide a loss that returns zero. Again, this setup is
useful if you want to expand the functionality later.

In [None]:
# Placeholder Keras input; ViLayer.call does not use its inputs.
mock_input = tf.keras.layers.Input(shape=(1,))
# A single dummy (feature, label) example, repeated indefinitely so that
# `fit` can run for any number of steps.
features = [[1.0]]
labels = [[1.0]]
mock_training_data = (
    tf.data.Dataset.from_tensor_slices((features, labels)).cache().repeat()
)
# Accumulates one (scale_1, scale_2, mixture_prob, pearson) tuple per fit.
results = []


def _create_model(num_snps, scale_1, scale_2, mixture_prob, beta_init):
  """Build a ViLayer and wrap it in a keras model fed by the mock input.

  The prior hyperparameters are passed by keyword because ViLayer's
  positional order is (size, prior_mixture, prior_scale_1, prior_scale_2),
  which differs from the order of this function's arguments. Passing them
  positionally would silently use `scale_1` as the mixture probability.

  Args:
    num_snps: The number of SNP effect sizes.
    scale_1: Standard deviation of the first prior component.
    scale_2: Standard deviation of the second prior component.
    mixture_prob: Probability that an effect comes from the first component.
    beta_init: Initial value for the posterior means.

  Returns:
    A tuple `(layer, model)` with the ViLayer and the keras model around it.
  """
  layer = ViLayer(
      num_snps,
      prior_mixture=mixture_prob,
      prior_scale_1=scale_1,
      prior_scale_2=scale_2,
      beta_init=beta_init,
  )
  return layer, tf.keras.Model(inputs=mock_input, outputs=layer(mock_input))


def empty_context():
  """Return a context manager that does nothing on enter or exit."""
  return contextlib.nullcontext()


# Grid search over the prior hyperparameters: fit one model per
# (scale_1, mixture_prob) pair and record its in-sample Pearson correlation.
strategy = tf.distribute.TPUStrategy(resolver) if USE_TPU else None
# We need a context if running on TPU, otherwise just use a no-op context.
context = strategy.scope if strategy else empty_context

for scale_1 in scales_1:
  for mixture_prob in mixture_probs:
    print(
        f'Fitting model scale_1={scale_1}, scale_2={scale_2}, '
        f'mixture_prob={mixture_prob}.'
    )
    # Model creation and compilation happen inside the strategy scope.
    with context() as _:
      # Re-seed before every fit so each configuration starts identically.
      tf.random.set_seed(1234)
      # `beta_init` is the value left bound by the last iteration of the
      # infinitesimal-model sweep; it initializes the posterior means.
      layer, mdl = _create_model(
          num_snps, scale_1, scale_2, mixture_prob, beta_init
      )
      opt = tf.keras.optimizers.Adam(learning_rate=BBVIPRS_LEARNING_RATE)
      # The real loss is added inside ViLayer.call, so the compiled loss is
      # a constant zero.
      mdl.compile(
          optimizer=opt, loss=lambda _, __: 0.00, steps_per_execution=50
      )
    mdl.fit(
        mock_training_data, epochs=1, steps_per_epoch=BBVIPRS_OPTIMIZATION_STEPS
    )

    # Score individuals with the fitted posterior mean, rescaled by the
    # standard errors, and correlate the score with the phenotype.
    snp_effects = (layer.posterior_model(0).mean() * snp_betas_se)[:, None]
    prs = np.dot(snp_matrix, snp_effects).flatten()
    pearson_corr = np.corrcoef(prs, pheno)[0, 1]
    results.append((scale_1, scale_2, mixture_prob, pearson_corr))
    print(f'Pearson correlation : {pearson_corr: 0.4f}.')

#### Compile results

The result dataframe has one row for each of the models (100 or so) fit above.
It shows the hyperparams `scale_1`, `scale_2` and `mixture_prob` used to 
define the prior and the Pearson's correlation estimate that the model
achieved. Note that we are computing the Pearson's correlation in the training
sample, so this could be an overestimate. 

In the LDPred2 tutorial, we see that they achieved a Pearson's correlation
of about 0.49. Given that bbviPRS tends to underperform relative to LDPred2
in some cases, this discrepancy probably makes sense. But it is hard to say,
since we don't know much about the input data. This example should therefore
be read as a proof of principle that illustrates how to implement and run
bbviPRS in a way that is similar to LDPred.

In [None]:
full_result_df = pd.DataFrame(
    results, columns=['scale_1', 'scale_2', 'mixture_prob', 'pearson']
).sort_values(by='pearson', ascending=False)
full_result_df.head()

In [None]:
'''
	scale_1	scale_2	mixture_prob	pearson
47	0.45001	1.000000e-07	0.000316	0.445777
57	0.55001	1.000000e-07	0.000316	0.443240
26	0.25001	1.000000e-07	0.000056	0.441404
62	0.60001	1.000000e-07	0.000316	0.441263
77	0.75001	1.000000e-07	0.000316	0.440966
'''