# SCRIPT TO PERFORM QUALITY CONTROL ON WHOLE-GENOME SEQUENCING DATA

To run this notebook, the following files must be present in the project folder:
- GENCODE GTF: Run Scripts/WGS/01_get_gencode_annotation.sh. Obtain from: https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_46/gencode.v46.annotation.gtf.gz (Check for newer versions).

Once completed, start a new Jupyter Notebook session so that the file becomes accessible, or unmount and remount the project (https://community.ukbiobank.ac.uk/hc/en-gb/community/posts/16019592366365-It-seems-that-the-recently-dx-uploaded-files-does-not-show-up-on-mnt-project-until-I-re-start-the-whole-Jupyter-Lab-VM).


- PVCF BLOCKS: Run Notebooks/WGS/DragenBlockProcessing.ipynb. Obtain from: https://biobank.ndph.ox.ac.uk/ukb/ukb/auxdata/dragen_pvcf_coordinates.zip 
The downloaded file needs parsing; an already-parsed copy is available in data/misc.

- Samples to remove file: Run Notebooks/WGS/01_QC_Samples.ipynb

#### Initialization 
##### Load packages


In [None]:
import dxpy
import pyspark

import hail as hl
from pathlib import Path
from datetime import datetime

from src.matrixtables import smart_split_multi_mt

In [None]:
# Constants
DATABASE = "matrix_tables"
REFERENCE_GENOME = "GRCh38"
PROJ_NAME = "GIPR"

# NOTE(review): this only guarantees that /tmp exists; the ../hail_logs
# directory used for LOG_FILE below is not created here.
Path("/tmp").resolve().mkdir(parents=True, exist_ok=True)

# Absolute path of the Hail log file: ../hail_logs/<PROJ_NAME>_<HHMM>.log,
# where HHMM is the hour and minute at which this cell runs.
log_stamp = datetime.now().strftime("%H%M")
LOG_FILE = str(Path("../hail_logs", f"{PROJ_NAME}_{log_stamp}.log").resolve())

#### Hail and spark configuration

In [None]:
# Start Spark: a single context wrapped in a SQL session
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)

# Make sure the DNAX-backed database exists, then look up the ID of the
# data object that carries the same name
create_db_sql = f"CREATE DATABASE IF NOT EXISTS {DATABASE} LOCATION 'dnax://'"
spark.sql(create_db_sql)
mt_database = dxpy.find_one_data_object(name=DATABASE)["id"]

# Start Hail on the Spark context created above, logging to LOG_FILE
hl.init(log=LOG_FILE, default_reference=REFERENCE_GENOME, sc=sc)

#### Variables

In [None]:
# --- RAP identifiers ---
# Field 24310: DRAGEN population level WGS variants, pVCF format 500k release
FIELD_ID = 24310
VCF_VERSION = "v1"

# --- Paths ---
BULK_DIR = Path("/mnt/project") / "Bulk"

# --- Genes to process ---
GENES = ["GIPR"]

### Quality control

#### Gene intervals and blocks 

In [None]:
# Get the genomic intervals of the genes in GENES from the GENCODE GTF.
# The reference genome comes from the REFERENCE_GENOME constant (the same
# value Hail was initialised with) instead of a repeated string literal.
gene_interval = hl.experimental.get_gene_intervals(
    gene_symbols=GENES,
    reference_genome=REFERENCE_GENOME,
    gtf_file="file:///mnt/project/WGS_Javier/WGS_QC/gencode.v46.annotation.gtf",
)
# Bare expression so the notebook displays the intervals
gene_interval

In [None]:
# Get DRAGEN pVCF blocks: one row per block, keyed by the block's locus interval
blocks = hl.import_table("file:///mnt/project/WGS_Javier/WGS_QC/dragen_pvcf_blocks.tsv", no_header=False)

# The table codes the sex chromosomes as 23 / 24; rename them to X / Y and
# build the contig name by prefixing "chr"
chrom_expr = blocks.Chromosome.replace("23", "X").replace("24", "Y")
region_expr = hl.str("chr") + chrom_expr

blocks = blocks.annotate(
    Chromosome=chrom_expr,
    region=region_expr,
    # Start/end columns are converted to int32 before building the interval.
    # NOTE(review): hl.locus_interval is called with its default bound
    # inclusivity - confirm this matches how Ending_Position is defined.
    interval=hl.locus_interval(
        region_expr,
        hl.int32(blocks.Starting_Position),
        hl.int32(blocks.Ending_Position),
        reference_genome=REFERENCE_GENOME,
    ),
).key_by("interval")

In [None]:
# Keep only the pVCF blocks whose interval overlaps at least one gene interval
overlaps_any_gene = hl.any(
    lambda gene_iv: blocks.interval.overlaps(gene_iv),
    gene_interval,
)
gb = blocks.filter(overlaps_any_gene)
gb.show()

#### Import vcf files of specific blocks

In [None]:
# Folder (relative to BULK_DIR) that holds one sub-folder of pVCF files per contig
VCF_DIR = Path("DRAGEN WGS/DRAGEN population level WGS variants, pVCF format 500k release")

# Fetch block number and contig together in a single action, so each
# (Block, region) pair is guaranteed to come from the same table row.
block_rows = gb.select("Block", "region").collect()

# One pVCF path per selected block:
# <contig>/ukb<FIELD_ID>_c<contig without "chr">_b<block>_<VCF_VERSION>.vcf.gz
vcf_files = [
    f"file://{BULK_DIR / VCF_DIR}/{row.region}/ukb{FIELD_ID}_c{row.region.replace('chr', '')}_b{row.Block}_{VCF_VERSION}.vcf.gz"
    for row in block_rows
]

mt = hl.import_vcf(
    vcf_files,
    drop_samples=False,
    reference_genome=REFERENCE_GENOME,
    array_elements_required=False,
    force_bgz=True,
)

In [None]:
# Restrict the matrix table to the gene intervals and report what is left
mt = hl.filter_intervals(mt, gene_interval)
n_variants_in_genes = mt.count_rows()
print(f"{n_variants_in_genes} variants after gene filtering")

In [None]:
# Checkpoint 1: write the gene-restricted matrix table and continue from the
# written copy (an existing checkpoint at the same path is overwritten)
stage = "FIRST"
checkpoint_file = f"/tmp/{PROJ_NAME}.{stage}.cp.mt"
mt = mt.checkpoint(output=checkpoint_file, overwrite=True)

#### Multi-allele filtering

In [None]:
# Keep variants that list at most 6 alleles; anything with more is dropped
n_alleles = hl.len(mt.alleles)
mt = mt.filter_rows(n_alleles <= 6)
n_variants_kept = mt.count_rows()
print(f"{n_variants_kept} variants with not more than 6 alleles")

In [None]:
# Split multi-allelic variants with the project helper, then report the new
# number of rows
mt = smart_split_multi_mt(mt)
n_variants_split = mt.count_rows()
print(f"{n_variants_split} variants after multi-allele splitting")

#### Quality control filtering

In [None]:
# Drop every genotype entry whose FT field is not "PASS"
passes_ft = mt.FT == "PASS"
mt = mt.filter_entries(passes_ft)

# Keep only the variants that still have at least one defined genotype call
has_defined_gt = hl.agg.any(hl.is_defined(mt.GT))
mt = mt.filter_rows(has_defined_gt)

n_variants_pass = mt.count_rows()
print(f"{n_variants_pass} variants after entries quality filtering")
n_samples_pass = mt.count_cols()
print(f"{n_samples_pass} samples after entries quality filtering")

In [None]:
# Add per-variant (entry_stats_row) and per-sample (entry_stats_col) counts
# and fractions of filtered entries
mt = mt.compute_entry_filter_stats(
    row_field="entry_stats_row",
    col_field="entry_stats_col",
)

In [None]:
row_fraction_threshold = 0.95

# Per variant: share of entries that were NOT filtered. A variant is kept
# only when that share is strictly greater than the threshold.
row_fraction_kept = 1 - mt.entry_stats_row.fraction_filtered
mt = mt.filter_rows(row_fraction_kept > row_fraction_threshold)

n_variants_fraction = mt.count_rows()
print(f"{n_variants_fraction} variants after fraction-based filtering")

In [None]:
col_fraction_threshold = 0.95

# Per sample: share of entries that were NOT filtered. A sample is kept
# only when that share is strictly greater than the threshold.
col_fraction_kept = 1 - mt.entry_stats_col.fraction_filtered
mt = mt.filter_cols(col_fraction_kept > col_fraction_threshold)

n_samples_fraction = mt.count_cols()
print(f"{n_samples_fraction} samples after fraction-based filtering")

#### Remove samples from 01_QC_Samples.ipynb

In [None]:
# Table of samples to exclude, keyed by its "eid" column
samples_to_remove = hl.import_table(
    "file:///mnt/project/WGS_Javier/Data/Input_regenie/samples_to_remove.tsv",
    key="eid",
)

# Drop every column (sample) whose key is found in that table
mt = mt.anti_join_cols(samples_to_remove)

# Keep only variants for which at least one remaining sample carries an
# alternate allele
carried_by_someone = hl.agg.any(mt.GT.n_alt_alleles() > 0)
mt = mt.filter_rows(carried_by_someone)

n_samples_left = mt.count_cols()
print(f"Samples remaining after removing samples from QC samples: {n_samples_left} ")
n_variants_left = mt.count_rows()
print(f"Variants remaining after removing samples from QC samples: {n_variants_left} ")

In [None]:
# Checkpoint 2: write the sample- and entry-filtered matrix table and continue
# from the written copy (an existing checkpoint at the same path is overwritten)
stage = "SECOND"
checkpoint_file = f"/tmp/{PROJ_NAME}.{stage}.cp.mt"
mt = mt.checkpoint(output=checkpoint_file, overwrite=True)

#### Variant Effect Predictor (VEP)

In [None]:
# Absolute path of the VEP configuration file located in the working directory
VEP_JSON = (Path.cwd() / "GRCh38_VEP.json").resolve()

In [None]:
# Annotate the variants with VEP using the local JSON configuration file
vep_config_uri = f"file:{VEP_JSON}"
mt = hl.vep(mt, config=vep_config_uri, block_size=100)

In [None]:
# Sanity check: the transcript used below (the first VEP transcript
# consequence of each variant) must carry a MANE Select annotation.
#
# `transcript_consequences.mane_select` is an array with one value per
# transcript, so testing that array with hl.is_defined only shows that the
# array exists. The first element has to be inspected instead.
transcripts = mt.vep.transcript_consequences
first_mane_select = hl.or_missing(
    hl.len(transcripts) > 0,  # guard: element 0 is only read from non-empty arrays
    transcripts.mane_select[0],
)
is_MANE = mt.aggregate_rows(hl.agg.all(hl.is_defined(first_mane_select)))
assert is_MANE, "Selected transcript may not be MANE Select. Check manually."

# protCons = <first residue listed in amino_acids><protein_end><last residue
# listed in amino_acids>, all taken from the first transcript consequence.
# varid = string form of locus and alleles.
first_amino_acids = transcripts.amino_acids[0].split("/")
mt = mt.annotate_rows(
    protCons=first_amino_acids[0]
    + hl.str(transcripts.protein_end[0])
    + first_amino_acids[-1],
    varid=hl.variant_str(mt.locus, mt.alleles),
)

In [None]:
# Checkpoint 3: write the VEP-annotated matrix table and continue from the
# written copy (an existing checkpoint at the same path is overwritten)
stage = "THIRD"
checkpoint_file = f"/tmp/{PROJ_NAME}.{stage}.cp.mt"
mt = mt.checkpoint(output=checkpoint_file, overwrite=True)

### Filtering

In [None]:
# Add per-variant QC metrics under mt.variant_qc
mt = hl.variant_qc(mt)

# A variant is kept only when all three QC conditions hold:
# at least one non-reference call, mean GQ >= 20, call rate >= 0.95
qc = mt.variant_qc
mt = mt.filter_rows(
    (qc.n_non_ref > 0) & (qc.gq_stats.mean >= 20) & (qc.call_rate >= 0.95)
)

# Drop variants whose most severe consequence is one of these terms
for excluded_consequence in (
    "intron_variant",
    "downstream_gene_variant",
    "upstream_gene_variant",
):
    mt = mt.filter_rows(mt.vep.most_severe_consequence != excluded_consequence)

In [None]:
# Report how many variants passed the QC and consequence filters
n_variants_after_qc = mt.count_rows()
print(f"{n_variants_after_qc} variants after quality filtering")

In [None]:
# Fourth checkpoint: write the fully filtered matrix table and continue from
# the written copy (an existing checkpoint at the same path is overwritten)
stage = "FORTH"
checkpoint_file = f"/tmp/{PROJ_NAME}.{stage}.cp.mt"
mt = mt.checkpoint(output=checkpoint_file, overwrite=True)

### Formatting

In [None]:
# Row table of the variants, exploded on the VEP transcript consequences
qt = mt.rows()
qt = qt.explode(qt.vep.transcript_consequences)

# Summary columns: identifiers, consequence details and the flattened
# variant_qc metrics
consequence = qt.vep.transcript_consequences
qt = qt.select(
    varid=qt.varid,
    protCons=qt.protCons,
    most_severe_consequence=qt.vep.most_severe_consequence,
    protein_end=consequence.protein_end,
    protein_start=consequence.protein_start,
    amino_acids=consequence.amino_acids,
    gene_id=consequence.gene_id,
    transcript_id=consequence.transcript_id,
    **qt.variant_qc.flatten(),
)

# Replace the per-allele arrays with their element at index 1
second_allele_values = {
    field: qt[field][1] for field in ("AC", "AF", "homozygote_count")
}
qt = qt.annotate(**second_allele_values)

# Remove the key, then drop the locus and alleles columns
qt = qt.key_by()
qt = qt.drop("locus", "alleles")

qt.show(5)

In [None]:
# Count the rows of the summary table per distinct most_severe_consequence
rows_per_consequence = hl.agg.group_by(qt.most_severe_consequence, hl.agg.count())
consequence_counts = qt.aggregate(rows_per_consequence)

print(consequence_counts)

#### Export 

In [None]:
# Write the summary table to /tmp/variant_qc.tsv.
# NOTE(review): the path carries no scheme and is read back with `hadoop fs`
# on the next line, so it presumably lives on the cluster file system - confirm.
qt.export("/tmp/variant_qc.tsv")
# Merge the exported output into one local file, one directory above the
# working directory
!hadoop fs -getmerge /tmp/variant_qc.tsv ../variant_qc.tsv
# Upload the merged file to the project output folder
!dx upload ../variant_qc.tsv --path /WGS_Javier/WGS_QC/Output/

In [None]:
# BGEN export (BGEN_FILE is the output path without extension)
BGEN_FILE = "/tmp/GIPR"

# Lookup table: row i is the probability triple assigned to a genotype with
# i alternate alleles (all of the probability on that genotype)
GPs = hl.literal([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

n_alt = mt.GT.n_alt_alleles()
mt = mt.annotate_entries(GP=GPs[n_alt])

# varid is used both as variant ID and as rsID
hl.export_bgen(
    mt=mt,
    output="file:" + BGEN_FILE,
    gp=mt.GP,
    varid=mt.varid,
    rsid=mt.varid,
)

In [None]:
# Annotations export: one header-less row per variant with the columns
# varid, gene, annotation
ANNOTATIONS_FILE = "/tmp/GIPR.annotations"

# Use protCons as the annotation when it is defined, otherwise fall back to
# the most severe consequence
annotation_expr = hl.coalesce(mt.protCons, mt.vep.most_severe_consequence)

annotated_rows = mt.select_rows(
    varid=mt.varid,
    gene=mt.vep.transcript_consequences.gene_symbol[0],
    annotation=annotation_expr,
).rows()

# Re-key by varid so the locus and alleles columns can be dropped
annotations = annotated_rows.key_by("varid").drop("locus", "alleles")

annotations.export("file:" + ANNOTATIONS_FILE, header=False)

In [None]:
# SETLIST file: a single tab-separated line
#   <gene symbol>  <contig>  <smallest variant position>  <comma-separated varids>
SETLIST_FILE = "/tmp/GIPR.setlist"

# Fetch everything the line needs in one pass over the rows. The gene symbol
# is element 0 of the per-transcript array, the same value the annotations
# file uses, so both files name the set identically. Collecting the array
# itself and formatting it would write a Python list such as "['GIPR']".
set_rows = (
    mt.select_rows(
        varid=mt.varid,
        gene=mt.vep.transcript_consequences.gene_symbol[0],
    )
    .rows()
    .collect()
)

if not set_rows:
    raise ValueError("No variants left in the matrix table; cannot write the set list.")

position = min(row.locus.position for row in set_rows)
names = [row.varid for row in set_rows]
names_str = ",".join(names)

# Gene symbol and contig are taken from the first row
gene_symbol = set_rows[0].gene
contig = set_rows[0].locus.contig
line = f"{gene_symbol}\t{contig}\t{position}\t{names_str}"

with open(SETLIST_FILE, "w", encoding="utf-8") as f:
    # End the line with a newline so the file is a well-formed text file
    f.write(line + "\n")

In [None]:
# Names of the two BGEN export files: BGEN_FILE plus the ".bgen" and
# ".sample" suffixes
bgen_file = BGEN_FILE + ".bgen"
sample_file = BGEN_FILE + ".sample"

# Upload the BGEN, sample, annotations and setlist files to the project output
# folder; the $name tokens are filled in from the Python variables above
!dx upload $bgen_file $sample_file $ANNOTATIONS_FILE $SETLIST_FILE --path /WGS_Javier/WGS_QC/Output/