### Engineering population scale GWAS

Genome-wide association studies (GWAS) correlate genetic variants with a trait or disease of interest.

As cohorts have increased in size to millions, there is a need to robustly engineer GWAS to work at scale.
To that end, we have developed an extensible Spark-native implementation of GWAS using Glow.

This notebook leverages the high performance big-data store [Delta Lake](https://delta.io) and uses [mlflow](https://mlflow.org/) to log parameters, metrics and plots associated with each run.

In [0]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import pyspark.sql.functions as fx
from pyspark.sql.types import StringType
from pyspark.ml.linalg import Vector, Vectors, SparseVector, DenseMatrix
from pyspark.ml.stat import Summarizer
from pyspark.mllib.linalg.distributed import RowMatrix
from pyspark.mllib.util import MLUtils

from dataclasses import dataclass

import mlflow
import glow
# Per the Glow (>= 1.0) docs, register() returns the SparkSession with Glow's SQL functions
# registered; rebind `spark` to it so later cells use that session.
spark = glow.register(spark)

#### Parameters

In [0]:
# Analysis parameters; both are recorded in MLflow so every run keeps its settings.
allele_freq_cutoff = 0.05  # QC keeps variants whose first allele frequency lies in [cutoff, 1 - cutoff]
num_pcs = 5  # number of principal components

mlflow.log_param(key="minor allele frequency cutoff", value=allele_freq_cutoff)
mlflow.log_param(key="principal components", value=num_pcs)

#### Paths

In [0]:
# Inputs: Databricks public genomics datasets
vcf_path = "/databricks-datasets/genomics/1kg-vcfs/ALL.chr22.phase3_shapeit2_mvncall_integrated_v5a.20130502.genotypes.vcf.gz"
genomics_1000g_root = "/databricks-datasets/genomics/1000G"
phenotype_path = f"{genomics_1000g_root}/phenotypes.normalized"
sample_info_path = f"{genomics_1000g_root}/samples/populations_1000_genomes_samples.csv"

# Outputs: all written under one scratch directory, which the clean-up cell deletes
gwas_test_root = "/mnt/gwas_test"
delta_silver_path = f"{gwas_test_root}/snps.delta"
delta_gold_path = f"{gwas_test_root}/snps.qced.delta"
# pandas (not Spark) reads/writes this file, so it uses the local /dbfs path form
principal_components_path = f"/dbfs{gwas_test_root}/pcs.csv"
gwas_results_path = f"{gwas_test_root}/gwas_results.delta"

#### Helper functions

In [0]:
def plot_layout(plot_title, plot_style, xlabel):
  """Apply the shared title, x-label and spine/tick styling to the current matplotlib axes.

  Args:
    plot_title: title text for the current axes.
    plot_style: name passed to plt.style.use (e.g. ggplot, seaborn-colorblind;
      see plt.style.available for the full list).
    xlabel: x-axis label, wrapped in $...$ so matplotlib renders it as mathtext.
  """
  plt.style.use(plot_style)
  plt.title(plot_title)
  plt.xlabel(f'${xlabel}$')
  ax = plt.gca()
  # Open-frame look: hide the right/top spines and keep ticks on the remaining two
  for hidden_side in ('right', 'top'):
    ax.spines[hidden_side].set_visible(False)
  ax.yaxis.set_ticks_position('left')
  ax.xaxis.set_ticks_position('bottom')
  plt.tight_layout()
  
def plot_histogram(df, col, xlabel, xmin, xmax, nbins, plot_title, plot_style, color, vline, out_path):
  """Histogram one column of a Spark DataFrame, save the figure to `out_path`, then show it.

  Args:
    df: Spark DataFrame containing `col`. It is collected to the driver with toPandas(),
      so select only the needed column before calling.
    col: name of the column to plot.
    xlabel: x-axis label, rendered as mathtext by plot_layout.
    xmin, xmax: range spanned by the bins.
    nbins: number of histogram bins (nbins + 1 equally spaced edges from xmin to xmax).
    plot_title: figure title.
    plot_style: matplotlib style name; it is active only while this figure is drawn.
    color: bar color.
    vline: x position of a dashed black reference line, or None to draw no line.
    out_path: path the figure is saved to before it is shown.

  Returns:
    None. The figure is saved and shown as side effects.
  """
  plt.close()
  # Enter the style before the figure/axes are created so it applies to them, and scope
  # it so the global rcParams are restored afterwards (including plot_layout's style.use).
  with plt.style.context(plot_style):
    plt.figure()
    # np.linspace returns bin *edges*: nbins bins need nbins + 1 edges.
    bins = np.linspace(xmin, xmax, nbins + 1)
    values = df.toPandas()[col]
    plt.hist(values, bins, alpha=1, color=color)
    # Compare against None so a reference line at x == 0 is still drawn.
    if vline is not None:
      plt.axvline(x=vline, linestyle='dashed', linewidth=2.0, color='black')
    plot_layout(plot_title, plot_style, xlabel)
    plt.savefig(out_path)
    plt.show()
  
def calculate_pval_bonferroni_cutoff(df, cutoff=0.05):
  """Return the Bonferroni-corrected P value threshold for the rows of `df`.

  Each row of `df` counts as one test, so the threshold is `cutoff / df.count()`.

  Args:
    df: DataFrame (any object with a `count()` method) holding one row per test.
    cutoff: family-wise significance level to be corrected.

  Returns:
    The per-test P value threshold.

  Raises:
    ValueError: if `df` has no rows, because no threshold is defined for zero tests.
  """
  num_tests = df.count()
  if num_tests == 0:
    raise ValueError("cannot compute a Bonferroni cutoff for an empty DataFrame")
  return cutoff / num_tests

def get_sample_info(vcf_df, sample_metadata_df):
  """Index the VCF's samples by position and attach their metadata.

  Sample IDs are read from the first row of `vcf_df` (`genotypes.sampleId`) and each
  gets an `index` column numbering them in that order. The result is inner-joined
  with `sample_metadata_df` on the "Sample" column, so samples without metadata are
  dropped; the remaining samples keep their original index values.

  Args:
    vcf_df: Spark DataFrame with a `genotypes` array-of-structs column containing `sampleId`.
    sample_metadata_df: Spark DataFrame with a "Sample" column.

  Returns:
    Spark DataFrame with "Sample", "index" and the metadata columns.

  Raises:
    ValueError: if `vcf_df` has no rows to read sample IDs from.
  """
  first_row = vcf_df.select("genotypes.sampleId").first()
  if first_row is None:
    raise ValueError("vcf_df is empty; cannot read sample IDs")
  sample_id_list = first_row["sampleId"]
  # On a single partition, monotonically_increasing_id() produces contiguous ids
  # 0..n-1 in list order, so `index` is the sample's position in the genotypes array.
  sample_id_indexed = (
    spark.createDataFrame(sample_id_list, StringType())
    .coalesce(1)
    .withColumnRenamed("value", "Sample")
    .withColumn("index", fx.monotonically_increasing_id())
  )
  return sample_id_indexed.join(sample_metadata_df, "Sample")

### Ingest 1000 Genomes VCF into Delta Lake

Using Glow's VCF reader, which enables variant call format (VCF) files to be read as a Spark Datasource directly from cloud storage,
write genotype data into Delta Lake, a high performance big data store with ACID semantics.
Delta Lake organizes, indexes and compresses data, allowing for performant and reliable computation on genomics data as it grows over time.

In [0]:
# Read the VCF with Glow's "vcf" data source; INFO fields stay together in the
# `attributes` map column instead of being flattened into separate columns.
vcf_view_unsplit = (
  spark.read
  .format("vcf")
  .option("flattenInfoFields", "false")
  .load(vcf_path)
)

Split multiallelic variants into biallelic variants

In [0]:
# Glow transformer that splits multiallelic sites into biallelic rows
vcf_view = glow.transform(
  "split_multiallelics",
  vcf_view_unsplit,
)

In [0]:
# Keep only the first sample's genotype per variant so the preview stays readable
display(vcf_view.withColumn("genotypes", fx.col("genotypes").getItem(0)))

contigName,start,end,names,referenceAllele,alternateAlleles,qual,filters,splitFromMultiAllelic,attributes,INFO_OLD_MULTIALLELIC,genotypes
22,16050074,16050075,List(rs587697622),A,List(G),100.0,List(PASS),False,"Map(AC -> 1, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0.001, AA -> .|||, AF -> 0.000199681, EAS_AF -> 0, AMR_AF -> 0, DP -> 8012, EUR_AF -> 0)",,"List(HG00096, true, List(0, 0))"
22,16050114,16050115,List(rs587755077),G,List(A),100.0,List(PASS),False,"Map(AC -> 32, NS -> 2504, AFR_AF -> 0.0234, VT -> SNP, AN -> 5008, SAS_AF -> 0, AA -> .|||, AF -> 0.00638978, EAS_AF -> 0, AMR_AF -> 0.0014, DP -> 11468, EUR_AF -> 0)",,"List(HG00096, true, List(0, 0))"
22,16050212,16050213,List(rs587654921),C,List(T),100.0,List(PASS),False,"Map(AC -> 38, NS -> 2504, AFR_AF -> 0.0272, VT -> SNP, AN -> 5008, SAS_AF -> 0, AA -> .|||, AF -> 0.00758786, EAS_AF -> 0, AMR_AF -> 0.0014, DP -> 15092, EUR_AF -> 0.001)",,"List(HG00096, true, List(0, 0))"
22,16050318,16050319,List(rs587712275),C,List(T),100.0,List(PASS),False,"Map(AC -> 1, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0, AA -> .|||, AF -> 0.000199681, EAS_AF -> 0, AMR_AF -> 0.0014, DP -> 22609, EUR_AF -> 0)",,"List(HG00096, true, List(0, 0))"
22,16050526,16050527,List(rs587769434),C,List(A),100.0,List(PASS),False,"Map(AC -> 1, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0, AA -> .|||, AF -> 0.000199681, EAS_AF -> 0, AMR_AF -> 0, DP -> 23591, EUR_AF -> 0.001)",,"List(HG00096, true, List(0, 0))"
22,16050567,16050568,List(rs587638893),C,List(A),100.0,List(PASS),False,"Map(AC -> 2, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0, AA -> .|||, AF -> 0.000399361, EAS_AF -> 0.002, AMR_AF -> 0, DP -> 21258, EUR_AF -> 0)",,"List(HG00096, true, List(0, 0))"
22,16050606,16050607,List(rs587720402),G,List(A),100.0,List(PASS),False,"Map(AC -> 5, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0, AA -> .|||, AF -> 0.000998403, EAS_AF -> 0, AMR_AF -> 0.0014, DP -> 20274, EUR_AF -> 0.004)",,"List(HG00096, true, List(0, 0))"
22,16050626,16050627,List(rs587593704),G,List(T),100.0,List(PASS),False,"Map(AC -> 2, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0.001, AA -> .|||, AF -> 0.000399361, EAS_AF -> 0, AMR_AF -> 0.0014, DP -> 21022, EUR_AF -> 0)",,"List(HG00096, true, List(0, 0))"
22,16050645,16050646,List(rs587670191),G,List(T),100.0,List(PASS),False,"Map(AC -> 1, NS -> 2504, AFR_AF -> 0, VT -> SNP, AN -> 5008, SAS_AF -> 0.001, AA -> .|||, AF -> 0.000199681, EAS_AF -> 0, AMR_AF -> 0, DP -> 22073, EUR_AF -> 0)",,"List(HG00096, true, List(0, 0))"
22,16050653,16063474,"List(esv3647175, esv3647176, esv3647177, esv3647178)",A,List(),100.0,List(PASS),True,"Map(AC -> 9,87,599,20, NS -> 2504, AFR_AF -> 0.0061,0.0363,0.0053,0, VT -> SV, AN -> 5008, SAS_AF -> 0,0.0082,0.1094,0.002, AF -> 0.00179712,0.0173722,0.119609,0.00399361, EAS_AF -> 0.001,0.0169,0.2361,0.0099, AMR_AF -> 0,0.0101,0.219,0.0072, DP -> 22545, CS -> DUP_gs, END -> 16063474, EUR_AF -> 0,0.007,0.0944,0.003, SVTYPE -> CNV)",22:16050654:A////,"List(HG00096, true, List(-1, 0))"


##### Note: we compute variant-wise summary stats and Hardy-Weinberg equilibrium P values using `call_summary_stats` & `hardy_weinberg`, which are built into Glow

In [0]:
# Silver table: every input column plus per-variant summary statistics (e.g. alleleFrequencies)
# and Hardy-Weinberg results (e.g. pValueHwe), expanded from structs into top-level columns.
(
  vcf_view
  .select(
    fx.expr("*"),
    glow.expand_struct(glow.call_summary_stats(fx.col("genotypes"))),
    glow.expand_struct(glow.hardy_weinberg(fx.col("genotypes"))),
  )
  .write
  .mode("overwrite")
  .format("delta")
  .save(delta_silver_path)
)

Because Delta Lake stores table metadata directly in its transaction log, we can quickly count the variants in the Delta table and log that number to MLflow

In [0]:
# Number of variants in the silver table, before any QC filtering
num_variants = (
  spark.read
  .format("delta")
  .load(delta_silver_path)
  .count()
)
mlflow.log_metric(key="Number Variants pre-QC", value=num_variants)

#### Run Quality Control

Perform variant-wise filtering on Hardy-Weinberg equilibrium P-values and allele frequency

In [0]:
hwe = (
  spark.read.format("delta")
  .load(delta_silver_path)
  # Minor-allele-frequency filter: first allele frequency within [cutoff, 1 - cutoff] (inclusive)
  .where(fx.col("alleleFrequencies").getItem(0).between(allele_freq_cutoff, 1.0 - allele_freq_cutoff))
  # -log10(P) for plotting; P == 0 would give infinity, so it is replaced by 26
  .withColumn(
    "log10pValueHwe",
    fx.when(fx.col("pValueHwe") == 0, 26).otherwise(-fx.log10(fx.col("pValueHwe"))),
  )
)

In [0]:
# Bonferroni threshold: 0.05 divided by the number of variants that passed the MAF filter
hwe_cutoff = calculate_pval_bonferroni_cutoff(df=hwe, cutoff=0.05)
mlflow.log_param(key="Hardy-Weinberg P value cutoff", value=hwe_cutoff)

In [0]:
# plot_histogram saves and shows the figure itself and returns None, so it is called
# directly instead of being wrapped in display().
# The dashed line marks the Bonferroni HWE cutoff on the -log10(P) scale.
plot_histogram(df=hwe.select("log10pValueHwe"),
               col="log10pValueHwe",
               xlabel='-log_{10}(P)',
               xmin=0,
               xmax=25,
               nbins=50,
               plot_title="hardy-weinberg equilibrium",
               plot_style="ggplot",
               color='#e41a1c',
               vline=-np.log10(hwe_cutoff),
               out_path="/databricks/driver/hwe.png")

In [0]:
# Attach the saved HWE histogram to the MLflow run
mlflow.log_artifact(local_path="/databricks/driver/hwe.png")

In [0]:
# Gold table: the MAF-filtered variants in `hwe` that also pass the Bonferroni-corrected
# HWE test. Reusing `hwe` keeps one definition of the allele-frequency filter; the
# plotting-only log10pValueHwe column is dropped so the schema matches the silver table.
(
  hwe
  .where(fx.col("pValueHwe") >= hwe_cutoff)
  .drop("log10pValueHwe")
  .write
  .mode("overwrite")
  .format("delta")
  .save(delta_gold_path)
)

In [0]:
# Number of variants remaining in the gold table after QC
num_variants = (
  spark.read
  .format("delta")
  .load(delta_gold_path)
  .count()
)
mlflow.log_metric(key="Number Variants post-QC", value=num_variants)

#### Run Principal Component Analysis (PCA)

The principal components are used as covariates to control for ancestry (population structure) in the GWAS.

Note: `array_to_sparse_vector` is a function built into Glow

In [0]:
# One sparse vector per variant holding each sample's alternate-allele count
# (genotype_states); cached because the SVD below reuses it.
vectorized = (
  spark.read.format("delta")
  .load(delta_gold_path)
  .select(
    glow.array_to_sparse_vector(glow.genotype_states(fx.col("genotypes"))).alias("features")
  )
  .cache()
)

#### Use `pyspark.mllib` (`RowMatrix.computeSVD`) to calculate principal components from the sparse genotype vectors

In [0]:
# mllib's RowMatrix needs old-style mllib vectors: rows = variants, columns = samples
matrix = RowMatrix(
  MLUtils.convertVectorColumnsFromML(vectorized, "features")
  .rdd
  .map(lambda row: row["features"])
)
# Top `num_pcs` right singular vectors; V has one row per sample (column of `matrix`).
# NOTE(review): rows are not mean-centered before the SVD, so these are uncentered
# singular vectors rather than textbook PCs — confirm this is intended.
pcs = matrix.computeSVD(k=num_pcs)

In [0]:
# One row per sample, one column per component; the row index (sample position)
# is written as the CSV's first column.
pd.DataFrame(data=pcs.V.toArray()).to_csv(principal_components_path, index=True)

#### Read sample information in and plot out principal components

Here we annotate sample info with principal components by joining both DataFrames on index

Note: indexing is performed using the Spark SQL function `monotonically_increasing_id()` on a single partition (`coalesce(1)`), which yields contiguous indices 0, 1, 2, … in row order

In [0]:
# Spark copy of V: one row per sample, columns pc0 .. pc{num_pcs - 1}
pcs_df = spark.createDataFrame(
  pcs.V.toArray().tolist(),
  [f"pc{i}" for i in range(num_pcs)],
)

In [0]:
sample_metadata = spark.read.csv(sample_info_path, header=True)
sample_info = get_sample_info(vcf_view, sample_metadata)

# Samples that have metadata (get_sample_info inner-joins on "Sample")
sample_count = sample_info.count()
mlflow.log_param(key="number of samples", value=sample_count)

# Number PCA rows 0..n-1 on a single partition so they line up with sample_info's index
pcs_indexed = (
  pcs_df
  .coalesce(1)
  .withColumn("index", fx.monotonically_increasing_id())
)
pcs_with_samples = pcs_indexed.join(sample_info, on="index")

#### Use the display function to create a scatter plot of the first two principal components (`pc0` and `pc1`)

Note: because we are only analyzing chromosome 22, the PCA scatter plot does not distinguish populations as well as whole-genome data would.

In [0]:
# Render the per-sample PCs and metadata; the scatter plot is configured in display()'s chart options
display(pcs_with_samples)

#### Prepare data for GWAS

In [0]:
# Phenotype table; its "values" column is renamed to phenotype_values
bmiPhenotype = (
  spark.read
  .parquet(phenotype_path)
  .withColumnRenamed("values", "phenotype_values")
)

In [0]:
# One row per phenotype value, shown as a single "bmi" column
display(bmiPhenotype.select(fx.explode("phenotype_values").alias("bmi")))

In [0]:
# The CSV's first column is the row index written by DataFrame.to_csv (pandas default);
# read it back as the index so it is not treated as an extra covariate column.
covariate_df = pd.read_csv(principal_components_path, index_col=0)

In [0]:
# Long -> wide: explode each phenotype's value array to one row per (phenotype, sample),
# number the samples within each phenotype, then pivot to one row per sample position
# and one column per phenotype. Indexing by within-phenotype position (rather than the
# global row number) keeps samples aligned when more than one phenotype is present.
phenotype_df = (
  bmiPhenotype.toPandas()
  .explode('phenotype_values')
  .reset_index(drop=True)
  .assign(sample_position=lambda long_df: long_df.groupby('phenotype').cumcount())
  .pivot(index='sample_position', columns='phenotype', values='phenotype_values')
  .rename_axis(index=None, columns=None)
)

In [0]:
# Name of the first phenotype column, recorded in MLflow
phenotype = phenotype_df.columns.tolist()[0]
mlflow.log_param(key="phenotype", value=phenotype)

In [0]:
# QC-passing variants from the gold Delta table
genotypes = spark.read.load(delta_gold_path, format="delta")

#### Run `linear_regression`

Note: `genotype_states` is a utility function in Glow that converts a genotypes array, e.g. `[0,1]`, into an integer containing the number of alternate alleles, e.g. `1`

In [0]:
# Per-variant linear regression of each phenotype on the alternate-allele count,
# with the principal components in covariate_df as covariates.
results = glow.gwas.linear_regression(
  genotypes.select("contigName", "start", "genotypes"),
  phenotype_df,
  covariate_df,
  values_column=glow.genotype_states(fx.col("genotypes")),
)

results.write.save(gwas_results_path, format="delta", mode="overwrite")

#### Show results

In [0]:
# Preview the first 100 GWAS result rows
display(
  spark.read
  .format("delta")
  .load(gwas_results_path)
  .limit(100)
)

#### Load GWAS results into R and plot using `qqman` library

In [0]:
%r
library(SparkR)
# Path must match gwas_results_path in the Python "Paths" cell
gwas_df <- read.df("/mnt/gwas_test/gwas_results.delta", source="delta")
# Rename to CHR/BP/P (qqman's default column names). Glow's `start` is 0-based, so
# add 1 to report the 1-based base-pair position. Alias is applied outermost so the
# column name does not depend on alias propagation through cast().
gwas_results <- select(gwas_df, c(
  alias(cast(gwas_df$contigName, "double"), "CHR"),
  alias(gwas_df$start + 1, "BP"),
  alias(gwas_df$pValue, "P")
))
gwas_results_rdf <- as.data.frame(gwas_results)

In [0]:
%r
# Install qqman from CRAN over HTTPS; a plain-http repository leaves the download
# open to tampering in transit.
install.packages("qqman", repos = "https://cloud.r-project.org")
library(qqman)

In [0]:
%r
# Render the Manhattan plot into a PNG on the driver; a later cell logs it to MLflow
png('/databricks/driver/manhattan.png')
# Reference lines: suggestive at P = 1e-5, genome-wide significance at P = 5e-8
manhattan(gwas_results_rdf, 
          col = c("#228b22", "#6441A5"), 
          chrlabs = NULL,
          suggestiveline = -log10(1e-05), 
          genomewideline = -log10(5e-08),
          highlight = NULL, 
          logp = TRUE, 
          annotatePval = NULL, 
          ylim=c(0,17))
# Close the PNG device so the file is flushed to disk
dev.off()

In [0]:
%r
# Draw the same Manhattan plot again, this time into the notebook output
manhattan(
  gwas_results_rdf,
  col = c("#228b22", "#6441A5"),
  chrlabs = NULL,
  suggestiveline = -log10(1e-05),
  genomewideline = -log10(5e-08),
  highlight = NULL,
  logp = TRUE,
  annotatePval = NULL,
  ylim = c(0, 17)
)

In [0]:
# Attach the saved Manhattan plot to the MLflow run
mlflow.log_artifact(local_path='/databricks/driver/manhattan.png')

In [0]:
%r
# Write the Q-Q plot of GWAS P values to a PNG on the driver; a later cell logs it to MLflow
png('/databricks/driver/qqplot.png')
qq(gwas_results_rdf$P)
# Close the PNG device so the file is flushed to disk
dev.off()

In [0]:
%r
# Show the Q-Q plot of GWAS P values in the notebook output
qq(gwas_results_rdf[["P"]])

In [0]:
# Attach the saved Q-Q plot to the MLflow run
mlflow.log_artifact(local_path='/databricks/driver/qqplot.png')

#### Clean up

In [0]:
# Recursively delete every table and file this notebook wrote under the scratch directory
dbutils.fs.rm("dbfs:/mnt/gwas_test", recurse=True)