<div style="background-color:lightgreen; padding: 10px; font-size: 24px;">
    
__PRS Calculator:__ Hail MatrixTable
</div>

<div style="background-color:lightgrey; padding: 10px;; font-size: 18px;">  
    
__Author:__ Ahmed Khattab  
        __Scripps Research__
    
</div>

<div style="background-color:lightblue; padding: 10px; font-size: 16px;"> 
    
__Introduction__

This notebook demonstrates how to calculate Polygenic Risk Scores (PRS) using the Hail MatrixTable (MT) data structure.


__Using Hail MT__ 

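Hail is a scalable genomic data analysis framework that runs on Apache Spark. Its MatrixTable stores variants as rows, samples as columns, and genotype calls as entries, so a PRS can be computed as a per-sample weighted sum over variants.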

__Resources used__

Cost when running: $72.91 per hour  

Main node: 4 CPUs, 15 GB RAM, 150 GB disk  
Workers (300): 4 CPUs, 15 GB RAM, 150 GB disk  

Time and cost: __$49.8 / 38 min__  

</div>

In [1]:
import datetime

# Record the start time
start_time = datetime.datetime.now()

# Split it into date and time components
current_date = start_time.date()
current_time = start_time.time()

# Format the current date
formatted_start_date = current_date.strftime("%Y-%m-%d")

# Format the current time
formatted_start_time = current_time.strftime("%H:%M:%S")

# Print the formatted date and time separately
print("Start date:", formatted_start_date)
print("Start time:", formatted_start_time)

Start date: 2024-06-21
Start time: 16:01:31


# Import Hail

In [2]:
import os
import pandas as pd
import gcsfs

In [3]:
import hail as hl
hl.init(tmp_dir='hail_temp/')
# Set the default reference genome after initialization;
# passing default_reference to hl.init is deprecated
hl.default_reference('GRCh38')


Reading spark-defaults.conf to determine GCS requester pays configuration. This is deprecated. Please use `hailctl config set gcs_requester_pays/project` and `hailctl config set gcs_requester_pays/buckets`.

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
Running on Apache Spark version 3.3.0
SparkUI available at http://all-of-us-7093-m.c.terra-vpc-sc-e098d676.internal:35153
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.130-bea04d9c79b5
LOGGING: writing to /home/jupyter/workspaces/duplicateoftype2diabetesriskprediction/hail-20240621-1601-0.2.130-bea04d9c79b5.log


# Define Bucket

In [4]:
bucket = os.getenv("WORKSPACE_BUCKET")

# Read Hail MT

In [5]:
mt_wgs_path = os.getenv("WGS_ACAF_THRESHOLD_MULTI_HAIL_PATH")

In [6]:
mt = hl.read_matrix_table(mt_wgs_path)

In [7]:
mt.count()

(48314438, 245394)

In [8]:
# To reduce the MT size, keep only the GT field (The only field we need for PRS calculation)
mt = mt.select_entries("GT") 

# Drop Flagged srWGS samples

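AoU provides tables listing samples that are flagged for the srWGS SNP and Indel joint callset. The file used here, `relatedness_flagged_samples.tsv`, sits under the `aux/relatedness` directory, so it lists samples flagged by the relatedness analysis rather than by the sample outlier QC.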

__Read more:__ https://support.researchallofus.org/hc/en-us/articles/4614687617556-How-the-All-of-Us-Genomic-data-are-organized#h_01GY7QZR2QYFDKGK89TCHSJSA7

In [9]:
# Read flagged samples
flagged_samples_path = "gs://fc-aou-datasets-controlled/v7/wgs/short_read/snpindel/aux/relatedness/relatedness_flagged_samples.tsv"

In [10]:
!gsutil -u $$GOOGLE_PROJECT cat $flagged_samples_path > flagged_samples.cvs

In [11]:
# Import flagged samples into a hail table
flagged_samples = hl.import_table(flagged_samples_path, key='sample_id')

2024-06-21 16:03:27.022 Hail: INFO: Reading table without type imputation
  Loading field 'sample_id' as type str (not specified)


In [12]:
# Preview the first lines of the flagged samples file (pass the path, not the Hail Table object)
!gsutil -u $$GOOGLE_PROJECT cat $flagged_samples_path | head -n 3


In [13]:
# Drop flagged samples from the main Hail MT
mt = mt.anti_join_cols(flagged_samples)

In [14]:
mt.count()

(48314438, 230019)
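
As a plain-Python illustration of what the anti-join does to the column (sample) axis, the sketch below drops every sample whose ID appears in a flagged list. The function name `anti_join_samples` and the toy IDs are made up for this example; Hail performs the equivalent operation on the distributed MT.

```python
def anti_join_samples(sample_ids, flagged_ids):
    """Keep the samples whose ID is NOT in flagged_ids, preserving order."""
    flagged = set(flagged_ids)
    return [s for s in sample_ids if s not in flagged]

# Toy IDs, hypothetical and for illustration only
samples = ["1001", "1002", "1003", "1004"]
flagged = ["1002", "1004", "9999"]  # "9999" is not among the samples, so it has no effect
kept = anti_join_samples(samples, flagged)
print(kept)

# Column counts reported by mt.count() before and after the anti-join
removed = 245394 - 230019
print(removed)
```

The difference between the two `mt.count()` results is 15,375 samples, which agrees with the number of rows in the flagged samples file read into pandas later in the notebook.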

# Define The Sample Intended for PRS Calculation

This is a pre-selected cohort of all participants with both WGS and EHR data available.

In [15]:
import pandas
import os

# This query represents dataset "participants_with_WGS_EHR_phenotypes_020524" for domain "person" and was generated for All of Us Controlled Tier Dataset v7
dataset_16967016_person_sql = """
    SELECT
        person.person_id,
        person.gender_concept_id,
        p_gender_concept.concept_name as gender,
        person.birth_datetime as date_of_birth,
        person.race_concept_id,
        p_race_concept.concept_name as race,
        person.ethnicity_concept_id,
        p_ethnicity_concept.concept_name as ethnicity,
        person.sex_at_birth_concept_id,
        p_sex_at_birth_concept.concept_name as sex_at_birth 
    FROM
        `""" + os.environ["WORKSPACE_CDR"] + """.person` person 
    LEFT JOIN
        `""" + os.environ["WORKSPACE_CDR"] + """.concept` p_gender_concept 
            ON person.gender_concept_id = p_gender_concept.concept_id 
    LEFT JOIN
        `""" + os.environ["WORKSPACE_CDR"] + """.concept` p_race_concept 
            ON person.race_concept_id = p_race_concept.concept_id 
    LEFT JOIN
        `""" + os.environ["WORKSPACE_CDR"] + """.concept` p_ethnicity_concept 
            ON person.ethnicity_concept_id = p_ethnicity_concept.concept_id 
    LEFT JOIN
        `""" + os.environ["WORKSPACE_CDR"] + """.concept` p_sex_at_birth_concept 
            ON person.sex_at_birth_concept_id = p_sex_at_birth_concept.concept_id  
    WHERE
        person.PERSON_ID IN (
            SELECT
                distinct person_id  
            FROM
                `""" + os.environ["WORKSPACE_CDR"] + """.cb_search_person` cb_search_person  
            WHERE
                cb_search_person.person_id IN (
                    SELECT
                        person_id 
                    FROM
                        `""" + os.environ["WORKSPACE_CDR"] + """.cb_search_person` p 
                    WHERE
                        has_ehr_data = 1 
                ) 
                AND cb_search_person.person_id IN (
                    SELECT
                        person_id 
                    FROM
                        `""" + os.environ["WORKSPACE_CDR"] + """.cb_search_person` p 
                    WHERE
                        has_whole_genome_variant = 1 
                ) 
            )"""

dataset_16967016_person_df = pandas.read_gbq(
    dataset_16967016_person_sql,
    dialect="standard",
    use_bqstorage_api=("BIGQUERY_STORAGE_API_ENABLED" in os.environ),
    progress_bar_type="tqdm_notebook")

Downloading:   0%|          | 0/206173 [00:00<?, ?rows/s]

In [16]:
dataset_16967016_person_df['person_id'].nunique()

206173

In [17]:
# Read the flagged sample IDs saved earlier (single-column file, so the default CSV parser works)
flag_s = pd.read_csv('flagged_samples.cvs')

In [18]:
flag_s.shape

(15375, 1)

In [19]:
flag_s.rename(columns={'sample_id': 'person_id'}, inplace=True)

In [20]:
# Merge the two DataFrames on the person_id column
aou_ids = pd.merge(dataset_16967016_person_df[['person_id']], flag_s, on='person_id', how='left', indicator=True)

# Filter to keep only rows where the merge indicator is left-only (i.e., rows present only in dataset_16967016_person_df[['person_id']])
aou_ids = aou_ids[aou_ids['_merge'] == 'left_only']

# Drop the merge indicator column
aou_ids.drop(columns=['_merge'], inplace=True)

In [21]:
aou_ids.shape

(193835, 1)

In [22]:
# Convert the person IDs to a set of strings, so they can be compared with the MT sample key `s`
subset_sample_ids_set = set(map(str, aou_ids['person_id'].tolist()))

In [23]:
len(subset_sample_ids_set)

193835
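
The `map(str, ...)` step matters if, as the conversion suggests, `person_id` comes back from the query as an integer while the MT sample key `s` is a string. A minimal sketch with hypothetical IDs shows how the membership test silently finds nothing when the types differ:

```python
person_ids = [1000001, 1000002, 1000003]            # hypothetical integer IDs from a SQL query
mt_sample_keys = ["1000002", "1000003", "1000004"]  # hypothetical string keys, like mt.s

ids_as_int = set(person_ids)
ids_as_str = set(map(str, person_ids))

# An integer set never contains a string key, even when the digits agree
matches_int = [s for s in mt_sample_keys if s in ids_as_int]
matches_str = [s for s in mt_sample_keys if s in ids_as_str]
print(matches_int)
print(matches_str)
```

Checking `mt.count()` after the filter, as the next cells do, is a cheap way to catch this kind of mismatch.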

In [24]:
# Filter samples: keep only those in the selected cohort
mt = mt.filter_cols(hl.literal(subset_sample_ids_set).contains(mt.s))

In [25]:
mt.count()

(48314438, 193835)

# Prepare PRS Weight Table

We are using PGS004859 (www.pgscatalog.org/score/PGS004859/)

In [26]:
# read table
prs_df = pd.read_csv('prs_scores/PGS004859_Deutsch_AJ_PRS_1108235_Type_2_diabetes_T2D_Diabetes_Care_2023.GRCh37_to_GRCh38.csv')

# Add the column names expected by the Hail import step
prs_df['contig'] = 'chr' + prs_df['chr'].astype(str)
prs_df['position'] = prs_df['bp']
# Build the ID with vectorized string operations instead of a row-wise apply
prs_df['variant_id'] = prs_df['contig'] + ':' + prs_df['position'].astype(str)

hail_df_fp = f"{bucket}/prs_calculator_tutorial/PGS004859_example/PGS004859_weights_tabel.csv"
prs_df.to_csv(hail_df_fp, index=False)

In [27]:
with gcsfs.GCSFileSystem().open(f'{bucket}/prs_calculator_tutorial/PGS004859_example/PGS004859_weights_tabel.csv', 'rb') as gcs_file:
    PGS004859_weights_tabel = pd.read_csv(gcs_file)

In [28]:
PGS004859_weights_tabel.shape

(1107156, 12)

In [29]:
PGS004859_weights_tabel.head()

| | chr | bp | rs_number | effect_allele | noneffect_allele | weight | additive | recessive | dominant | contig | position | variant_id |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 818802 | 1:754182:G:A:rs3131969 | G | A | 6e-06 | 1 | 0 | 0 | chr1 | 818802 | chr1:818802 |
| 1 | 1 | 833068 | 1:768448:A:G:rs12562034 | A | G | 1.2e-05 | 1 | 0 | 0 | chr1 | 833068 | chr1:833068 |
| 2 | 1 | 1104646 | 1:1040026:C:T:rs6671356 | C | T | 8.6e-05 | 1 | 0 | 0 | chr1 | 1104646 | chr1:1104646 |
| 3 | 1 | 1106320 | 1:1041700:G:A:rs6604968 | G | A | 8.8e-05 | 1 | 0 | 0 | chr1 | 1106320 | chr1:1106320 |
| 4 | 1 | 1113575 | 1:1048955:G:A:rs4970405 | G | A | 1.9e-05 | 1 | 0 | 0 | chr1 | 1113575 | chr1:1113575 |
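
The column preparation above can be sketched without pandas to make the transformation explicit: the chromosome gets a `chr` prefix to match GRCh38 contig names, the base-pair coordinate becomes `position`, and the two are joined into `variant_id`. The helper name `add_hail_columns` and the two-row inline CSV are assumptions for illustration; the rows reuse values from the table shown above.

```python
import csv
import io

# Two rows in the same layout as the weight file (subset of columns)
raw = io.StringIO(
    "chr,bp,effect_allele,noneffect_allele,weight\n"
    "1,818802,G,A,6e-06\n"
    "1,833068,A,G,1.2e-05\n"
)

def add_hail_columns(row):
    """Return a copy of row with contig, position and variant_id added."""
    out = dict(row)
    out["contig"] = "chr" + str(row["chr"])
    out["position"] = int(row["bp"])
    out["variant_id"] = f'{out["contig"]}:{out["position"]}'
    return out

rows = [add_hail_columns(r) for r in csv.DictReader(raw)]
for r in rows:
    print(r["variant_id"], r["weight"])
```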


# PRS Calculator

In [30]:
import pandas as pd
import os
import hail as hl
import gcsfs

def calculate_effect_allele_count(hail_mt):
    """Count copies of the PRS effect allele carried by each genotype.

    Note: is_het() and is_hom_var() do not say which alternate allele is
    carried, so at multi-allelic sites this rule can miscount the effect allele.
    """
    effect_allele = hail_mt.prs_info['effect_allele']
    ref_allele = hail_mt.alleles[0]

    # Set of alternate alleles at this site
    alt_alleles_set = hl.set(hail_mt.alleles[1:])

    is_effect_allele_ref = ref_allele == effect_allele
    is_effect_allele_alt = alt_alleles_set.contains(effect_allele)

    # 2 copies for a matching homozygote, 1 for a heterozygote, otherwise 0
    return hl.case() \
        .when(hail_mt.GT.is_hom_ref() & is_effect_allele_ref, 2) \
        .when(hail_mt.GT.is_hom_var() & is_effect_allele_alt, 2) \
        .when(hail_mt.GT.is_het() & is_effect_allele_ref, 1) \
        .when(hail_mt.GT.is_het() & is_effect_allele_alt, 1) \
        .default(0)

def calculate_final_prs(mt, prs_identifier, pgs_weight_path, output_path):
    
    print("")
    print("#####################")
    print(f"      {prs_identifier}")
    print("#####################")
    
    
    # Construct paths
    bucket = os.getenv("WORKSPACE_BUCKET")
    
    PGS_path             = f'{bucket}/{pgs_weight_path}'
    interval_fp          = f"{bucket}/{output_path}/interval/{prs_identifier}_interval.tsv"
    hail_fp              = f'{bucket}/{output_path}/hail/' 
    gc_csv_fp            = f'{bucket}/{output_path}/score/{prs_identifier}_scores.csv' 
    gc_found_csv_fp      = f'{bucket}/{output_path}/score/{prs_identifier}_found_in_aou.csv' 
        
    # Step 1: Read PRS weight table from a file in GCS and import as Hail Table
    prs_table = hl.import_table(PGS_path,
                                types={"variant_id":"tstr",
                                        "rsid":"tstr",
                                        "weight":"tfloat",
                                        "contig":"tstr",
                                        "position":"tint32",
                                        "effect_allele":"tstr",
                                        "noneffect_allele":"tstr",
                                        "additive":"tint32",
                                        "recessive":"tint32",
                                        "dominant":"tint32"},
                                delimiter=',')
    prs_table = prs_table.annotate(locus=hl.locus(prs_table.contig, prs_table.position))
    prs_table = prs_table.key_by('locus')

    # Step 2: Semi-join to keep only variants in the weights DataFrame
    mt_prs = mt.semi_join_rows(prs_table)
    
    # Step 3: Annotate the MatrixTable rows with the PRS information
    mt_prs = mt_prs.annotate_rows(prs_info=prs_table[mt_prs.locus])

    # Step 4: Calculate effect allele count and multiply by variant weight in a single step
    effect_allele_count_expr = calculate_effect_allele_count(mt_prs)
    mt_prs = mt_prs.annotate_entries(
        effect_allele_count=effect_allele_count_expr,
        weighted_count=effect_allele_count_expr * mt_prs.prs_info['weight'])

    # Step 5: Sum the weighted counts per sample and count the number of variants with weights per sample
    mt_prs = mt_prs.annotate_cols(
        sum_weights=hl.agg.sum(mt_prs.weighted_count),
        N_variants=hl.agg.count_where(hl.is_defined(mt_prs.weighted_count)))
    
    # Step 6: Extract found variants
    found_variants_table = mt_prs.filter_rows(hl.is_defined(mt_prs.prs_info)).rows()
    found_prs_info_df = found_variants_table.select(found_variants_table.prs_info).to_pandas()
    print(f"save as: {gc_found_csv_fp}")
    found_prs_info_df.to_csv(gc_found_csv_fp, header=True, index=False, sep=',')

    # Step 7: Write the PRS scores to a Hail Table
    mt_prs.key_cols_by().cols().write(hail_fp)
    
    # Step 8: Export the Hail Table to a CSV file
    prs_scores_ht = hl.read_table(hail_fp)
    prs_scores_ht.export(gc_csv_fp, header=True, delimiter=',')
    print(f"save as: {gc_csv_fp}")
    
    return 
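
The effect allele count is the core of the calculation, so it may help to see it on plain Python values. The sketch below is not the function used in this notebook: it is an alternative, hypothetical formulation that looks up the index of the effect allele and counts how often that index occurs in the genotype. It gives the same result at biallelic sites and stays well defined at multi-allelic ones. The name `effect_allele_dosage` and the tuple encoding of genotypes are assumptions for illustration.

```python
def effect_allele_dosage(alleles, genotype, effect_allele):
    """Count copies of effect_allele in one genotype.

    alleles:   allele strings with the reference first, e.g. ["G", "A"]
    genotype:  tuple of allele indices, e.g. (0, 1), or None if missing
    """
    if genotype is None:
        return None  # keep missing calls missing
    if effect_allele not in alleles:
        return 0     # effect allele is not present at this site
    idx = alleles.index(effect_allele)
    return sum(1 for allele_index in genotype if allele_index == idx)

# Biallelic site: effect allele is the reference, sample is hom-ref
print(effect_allele_dosage(["G", "A"], (0, 0), "G"))       # 2
# Biallelic site: effect allele is the alternate, sample is het
print(effect_allele_dosage(["G", "A"], (0, 1), "A"))       # 1
# Multi-allelic site: sample is homozygous for a different alternate allele
print(effect_allele_dosage(["G", "A", "T"], (2, 2), "A"))  # 0
```

In the last case a rule based only on "homozygous variant and effect allele is some alternate allele" would report 2, which is why the index-based count can be worth considering when the input MT keeps multi-allelic sites unsplit.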

In [31]:
%time calculate_final_prs(mt, 'PGS004859', 'prs_calculator_tutorial/PGS004859_example/PGS004859_weights_tabel.csv', 'prs_calculator_tutorial/prs_calculator_hail_mt/PGS004859_aou')


#####################
      PGS004859
#####################


2024-06-21 16:04:30.189 Hail: INFO: Reading table without type imputation
  Loading field 'chr' as type str (not specified)
  Loading field 'bp' as type str (not specified)
  Loading field 'rs_number' as type str (not specified)
  Loading field 'effect_allele' as type str (user-supplied)
  Loading field 'noneffect_allele' as type str (user-supplied)
  Loading field 'weight' as type float64 (user-supplied)
  Loading field 'additive' as type int32 (user-supplied)
  Loading field 'recessive' as type int32 (user-supplied)
  Loading field 'dominant' as type int32 (user-supplied)
  Loading field 'contig' as type str (user-supplied)
  Loading field 'position' as type int32 (user-supplied)
  Loading field 'variant_id' as type str (user-supplied)
2024-06-21 16:04:44.996 Hail: INFO: Coerced sorted dataset
2024-06-21 16:04:50.402 Hail: INFO: Coerced sorted dataset
2024-06-21 16:05:07.780 Hail: INFO: wrote table with 1107156 rows in 1 partition to hail_t

save as: gs://fc-secure-e5684327-e720-41ed-979a-b9ae6477b844/prs_calculator_tutorial/prs_calculator_hail_mt/PGS004859_aou/score/PGS004859_found_in_aou.csv


2024-06-21 16:23:01.505 Hail: INFO: Coerced sorted dataset
2024-06-21 16:23:06.313 Hail: INFO: Coerced sorted dataset
2024-06-21 16:23:16.733 Hail: INFO: wrote table with 1107156 rows in 1 partition to hail_temp//__iruid_11738-13mWxU99f6l8rpswfZv7nh
2024-06-21 16:23:25.020 Hail: INFO: wrote table with 1107156 rows in 1 partition to hail_temp//__iruid_12154-xk9Y41yXWfZhosJKaTx0NY
2024-06-21 16:41:56.262 Hail: INFO: wrote table with 193835 rows in 1200 partitions to gs://fc-secure-e5684327-e720-41ed-979a-b9ae6477b844/prs_calculator_tutorial/prs_calculator_hail_mt/PGS004859_aou/hail/
2024-06-21 16:42:05.346 Hail: INFO: merging 1201 files totalling 5.0M...
2024-06-21 16:42:10.992 Hail: INFO: while writing:
    gs://fc-secure-e5684327-e720-41ed-979a-b9ae6477b844/prs_calculator_tutorial/prs_calculator_hail_mt/PGS004859_aou/score/PGS004859_scores.csv
  merge time: 5.637s


save as: gs://fc-secure-e5684327-e720-41ed-979a-b9ae6477b844/prs_calculator_tutorial/prs_calculator_hail_mt/PGS004859_aou/score/PGS004859_scores.csv
CPU times: user 1min 17s, sys: 3.89 s, total: 1min 21s
Wall time: 37min 45s
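
Step 5 of `calculate_final_prs` reduces each sample's column to two numbers: the sum of weight times effect allele count, and the number of variants that contributed. A small pure-Python sketch of that aggregation follows. Missing genotypes are skipped, which mirrors how Hail's `hl.agg.sum` ignores missing values. The function name `score_sample` and the dosages are hypothetical; the weights are the first three from the weight table shown earlier.

```python
def score_sample(dosages, weights):
    """Return (sum of weight * dosage, number of non-missing variants)."""
    total = 0.0
    n_variants = 0
    for dosage, weight in zip(dosages, weights):
        if dosage is None:
            continue  # a missing genotype adds nothing and is not counted
        total += dosage * weight
        n_variants += 1
    return total, n_variants

weights = [6e-06, 1.2e-05, 8.6e-05]  # first three weights from the weight table
dosages = [2, None, 1]               # hypothetical effect allele counts for one sample

sum_weights, n_variants = score_sample(dosages, weights)
print(sum_weights, n_variants)
```

Keeping `N_variants` next to the score makes it possible to spot samples whose score is based on noticeably fewer called variants than the rest.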


In [32]:
import datetime

# Get the current date and time again
end_time = datetime.datetime.now()

# Record the end time
current_date = end_time.date()
current_time = end_time.time()

# Format the current date
formatted_end_date = current_date.strftime("%Y-%m-%d")

# Format the current time
formatted_end_time = current_time.strftime("%H:%M:%S")

# Print the formatted end date and time separately
print("End date:", formatted_end_date)
print("End time:", formatted_end_time)

End date: 2024-06-21
End time: 16:42:14
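
The start and end timestamps printed by the notebook can be combined with the hourly cluster rate from the introduction to get a rough cost estimate. This is a back-of-the-envelope sketch that assumes cost scales linearly with wall-clock time; the constant name `HOURLY_RATE_USD` is introduced here for illustration.

```python
from datetime import datetime

HOURLY_RATE_USD = 72.91  # cluster cost per hour, from the introduction

# Timestamps printed by the first and last cells of this run
start = datetime.strptime("2024-06-21 16:01:31", "%Y-%m-%d %H:%M:%S")
end = datetime.strptime("2024-06-21 16:42:14", "%Y-%m-%d %H:%M:%S")

elapsed = end - start
elapsed_seconds = int(elapsed.total_seconds())
estimated_cost = HOURLY_RATE_USD * elapsed_seconds / 3600

print(f"Elapsed: {elapsed}")
print(f"Estimated cost: ${estimated_cost:.2f}")
```

The estimate covers only the time between the first and last cell, so it can differ slightly from the summary figure given in the introduction.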

