# SCRIPT TO GENERATE COVARIATES

## This script should only be run once

To run this notebook, the following files must be present in the project folder:
- GENCODE GTF: Obtain from: https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_46/gencode.v46.annotation.gtf.gz (Check for newer versions)

In order to obtain this file, run this:

In [None]:
import requests
GENCODE_GTF = "https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_46/gencode.v46.annotation.gtf.gz"

if not Path("/mnt/project/WGS_Javier/WGS_QC/gencode.v46.annotation.gtf.gz").exists():
    response = requests.get(GENCODE_GTF)
    if response.status_code == 200:
        with open(Path("/tmp/gencode.v46.annotation.gtf.gz"), "wb") as file:
            file.write(response.content)

!dx upload /tmp/gencode.v46.annotation.gtf.gz --path /WGS_Javier/WGS_QC/

This file should be decompressed.
To do so, use Swiss-army-knife with the following command: gzip -d gencode.v46.annotation.gtf.gz

Once completed, a new Jupyter Notebook should be initialized so we can access this file


- PVCF BLOCKS: Obtain from: https://biobank.ndph.ox.ac.uk/ukb/ukb/auxdata/dragen_pvcf_coordinates.zip
The file needs parsing, but an already-parsed copy is available in data/misc


#### Initialization 
##### Load packages

In [None]:
import hail as hl
from pathlib import Path
from datetime import datetime
import pyspark
import dxpy
import subprocess
import pandas as pd
from src.matrixtables import import_mt, smart_split_multi_mt

In [None]:
# Constants
# DATABASE: name of the Spark/DNAX database created during initialization.
# REFERENCE_GENOME: default reference genome handed to hl.init.
# PROJ_NAME: prefix used for the Hail log file and the checkpoint files.
DATABASE = "matrix_tables"
REFERENCE_GENOME = "GRCh38"
PROJ_NAME = "GIPR_test"

# RAP
# Both values are interpolated into the pVCF file names
# (ukb<FIELD_ID>_c<chrom>_b<block>_<VCF_VERSION>.vcf.gz).
VCF_VERSION = "v1"
FIELD_ID = 24310

# Paths
# VCF_DIR and INTERVAL_FILE are relative paths that get joined onto BULK_DIR.
# NOTE(review): MISC_DIR is not referenced by the visible cells, which hard-code
# /mnt/project/WGS_Javier/WGS_QC instead -- confirm which location is correct.
BULK_DIR = Path("/mnt/project/Bulk")
VCF_DIR = Path("DRAGEN WGS/DRAGEN population level WGS variants, pVCF format 500k release")
INTERVAL_FILE = Path("Exome sequences/Exome OQFE CRAM files/helper_files/xgen_plus_spikein.GRCh38.bed")
MISC_DIR = Path("/mnt/project/WGS_QC/")

# Genes
# Gene symbols whose intervals are looked up in the GENCODE GTF.
GENES = ["GIPR"]

In [None]:
# Make sure the scratch directory exists.
Path("/tmp").resolve().mkdir(parents=True, exist_ok=True)

# Absolute path of the Hail log: ../hail_logs/<PROJ_NAME>_<HHMM>.log
log_stamp = datetime.now().strftime('%H%M')
LOG_FILE = str(Path("../hail_logs", f"{PROJ_NAME}_{log_stamp}.log").resolve())

# Spark init
sc = pyspark.SparkContext()
spark = pyspark.sql.SparkSession(sc)

# Create database in DNAX (no-op when it already exists) and look up its object id
spark.sql(f"CREATE DATABASE IF NOT EXISTS {DATABASE} LOCATION 'dnax://'")
mt_database = dxpy.find_one_data_object(name=DATABASE)["id"]

# Hail init, reusing the Spark context created above
hl.init(sc=sc, default_reference=REFERENCE_GENOME, log=LOG_FILE)

In [None]:
# Look up the genomic intervals of the requested gene symbols in the GENCODE GTF.
gene_interval = hl.experimental.get_gene_intervals(
    gtf_file="file:///mnt/project/WGS_Javier/WGS_QC/gencode.v46.annotation.gtf",
    gene_symbols=GENES,
    reference_genome="GRCh38",
)
# Display the intervals as the cell output.
gene_interval

In [None]:
# Get DRAGEN pVCF blocks (TSV with a header row)
blocks = hl.import_table(
    "file:///mnt/project/WGS_Javier/WGS_QC/dragen_pvcf_blocks.tsv",
    no_header=False,
)

# Chromosomes coded as 23/24 are renamed to X/Y, then prefixed with "chr".
blocks = blocks.annotate(
    Chromosome=blocks.Chromosome.replace("23", "X").replace("24", "Y")
)
blocks = blocks.annotate(
    region=hl.str("").join([hl.str("chr"), blocks.Chromosome])
)

# Build the locus interval covered by each block and key the table by it.
block_interval = hl.locus_interval(
    blocks.region,
    hl.int32(blocks.Starting_Position),
    hl.int32(blocks.Ending_Position),
    reference_genome="GRCh38",
)
blocks = blocks.annotate(interval=block_interval)
blocks = blocks.key_by("interval")

In [None]:
# Keep only the pVCF blocks whose interval overlaps at least one gene interval.
overlaps_any_gene = hl.any(
    lambda gene_iv: blocks.interval.overlaps(gene_iv), gene_interval
)
gb = blocks.filter(overlaps_any_gene)

# Preview the selected blocks.
gb.show()

In [None]:
# GET VCF FILES
# Collect block id and region together in a single pass, so each block stays
# paired with its own chromosome and only one job is run.
block_rows = gb.select("f2", "region").collect()
if not block_rows:
    # Fail early with a clear message rather than inside import_vcf.
    raise ValueError(f"No DRAGEN pVCF block overlaps the requested genes: {GENES}")

# One pVCF per block: <bulk>/<vcf dir>/<chrN>/ukb<field>_c<N>_b<block>_<version>.vcf.gz
vcf_root = BULK_DIR / VCF_DIR
vcf_files = [
    f"file://{vcf_root}/{row.region}/ukb{FIELD_ID}_c{row.region.replace('chr', '')}_b{row.f2}_{VCF_VERSION}.vcf.gz"
    for row in block_rows
]

mt = hl.import_vcf(
    vcf_files,
    drop_samples=False,
    reference_genome="GRCh38",
    array_elements_required=False,
    force_bgz=True,
)



In [None]:
# Restrict the matrix table to rows that fall inside the gene intervals.
mt = hl.filter_intervals(mt, gene_interval, keep=True)

In [None]:
# Only exome capture region: load the capture BED and keep the rows whose
# locus falls inside one of its intervals.
interval_table = hl.import_bed(
    f"file://{BULK_DIR / INTERVAL_FILE}", reference_genome="GRCh38"
)
in_capture_region = hl.is_defined(interval_table[mt.locus])
mt = mt.filter_rows(in_capture_region)

In [None]:
# Report how many variants are left after the interval filters.
n_variants = mt.count_rows()
print(f"{n_variants} variants after interval filtering")

In [None]:
# First checkpoint: write the filtered matrix table to /tmp and read it back.
stage = "FIRST"
checkpoint_name = f"{PROJ_NAME}.{stage}.cp.mt"
checkpoint_file = f"/tmp/{checkpoint_name}"
mt = mt.checkpoint(checkpoint_file, overwrite=True)

In [None]:
# Multi allele filtering: drop sites with more than 6 alleles, then split the
# remaining multi-allelic sites with the project helper.
allele_count = mt.alleles.length()
mt = smart_split_multi_mt(mt.filter_rows(allele_count <= 6))

n_split = mt.count_rows()
print(f"{n_split} variants with not more than 6 alleles after splitting")

In [None]:
# Variant effect predictor, configured from the JSON file next to the notebook
VEP_JSON = Path("GRCh38_VEP.json").resolve()
mt = hl.vep(mt, f"file:{VEP_JSON}")

# Promote the fields of the VEP struct to row fields, then do the same for the
# first transcript consequence only.
mt = mt.annotate_rows(**mt.vep)
mt = mt.annotate_rows(**mt.transcript_consequences[0])

# protCons = <first amino acid><protein_end><last amino acid>;
# varid is the string form of locus + alleles.
aa_parts = mt.amino_acids.split("/")
mt = mt.annotate_rows(
    protCons=aa_parts[0] + hl.str(mt.protein_end) + aa_parts[-1],
    varid=hl.variant_str(mt.locus, mt.alleles),
)

# The nested VEP fields are no longer needed once flattened.
mt = mt.drop("vep", "transcript_consequences", "vep_proc_id")

In [None]:
# Second checkpoint: persist the VEP-annotated matrix table to /tmp and read it back.
stage = "SECOND"
checkpoint_name = f"{PROJ_NAME}.{stage}.cp.mt"
checkpoint_file = f"/tmp/{checkpoint_name}"
mt = mt.checkpoint(checkpoint_file, overwrite=True)

In [None]:
# Remove samples
SAMPLES_TO_REMOVE_FILE = Path('samples_to_remove.tsv').resolve()
!hadoop fs -put {SAMPLES_TO_REMOVE_FILE} /tmp

samples_to_remove = hl.import_table(f"/tmp/{SAMPLES_TO_REMOVE_FILE.name}", key="eid")
samples_to_remove.count()

mt = mt.anti_join_cols(samples_to_remove)

print(f"Samples remaining after hard filtering samples: {mt.count_cols()} ")

##### Filtering

In [None]:
# NOTE: the allele-balance (AB) filter below is disabled because it is not working yet

#mt = mt.annotate_entries(AB=(mt.LAD[1] / hl.sum(mt.LAD)))

#filter_condition_ab = (
#    (mt.GT.is_hom_ref() & (mt.AB <= 0.1))
#    | (mt.GT.is_het() & (mt.AB >= 0.25) & (mt.AB <= 0.75))
#    | (mt.GT.is_hom_var() & (mt.AB >= 0.9))
#)

#mt = mt.filter_entries(filter_condition_ab)

In [None]:
# Compute per-variant QC metrics, then keep only variants that pass every
# threshold: mean GQ >= 20, call rate >= 0.95 and at least one non-ref call.
mt = hl.variant_qc(mt)
vqc = mt.variant_qc
# Disabled depth filter kept for reference:
#mt = mt.filter_rows(mt.variant_qc.dp_stats.mean >= 12)
mt = mt.filter_rows(
    (vqc.gq_stats.mean >= 20) & (vqc.call_rate >= 0.95) & (vqc.n_non_ref > 0)
)

In [None]:
# Third checkpoint: persist the QC-filtered matrix table to /tmp and read it back.
stage = "THIRD"
checkpoint_name = f"{PROJ_NAME}.{stage}.cp.mt"
checkpoint_file = f"/tmp/{checkpoint_name}"
mt = mt.checkpoint(checkpoint_file, overwrite=True)

In [None]:
# Build a flat per-variant table: the annotation fields plus every variant_qc metric.
qt = mt.rows()
qt = qt.select(
    "varid",
    "protCons",
    "most_severe_consequence",
    "protein_end",
    "protein_start",
    "amino_acids",
    "gene_id",
    "transcript_id",
    **qt.variant_qc.flatten(),
)

# Replace the per-allele arrays with their element at index 1.
alt_stats = {name: qt[name][1] for name in ("AC", "AF", "homozygote_count")}
qt = qt.annotate(**alt_stats)

# Un-key the table so the locus/alleles key fields can be dropped.
qt = qt.key_by().drop("locus", "alleles")

qt.show(5)

##### Export 

In [None]:
# Export the per-variant QC table as TSV, then merge the exported output into
# a single local file one directory above the notebook.
qt.export("/tmp/variant_qc.tsv")
!hadoop fs -getmerge /tmp/variant_qc.tsv ../variant_qc.tsv

In [None]:
# BGEN file (output prefix; no extension is given here)
BGEN_FILE = "/tmp/GIPR_test"

# Hard-call genotype probabilities, indexed by the number of alternate alleles:
# 0 -> [1, 0, 0], 1 -> [0, 1, 0], 2 -> [0, 0, 1].
GPs = hl.literal([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
mt = mt.annotate_entries(GP=GPs[mt.GT.n_alt_alleles()])

# The variant string is used as both the variant id and the rsid.
hl.export_bgen(mt, "file:" + BGEN_FILE, gp=mt.GP, varid=mt.varid, rsid=mt.varid)

In [None]:
# ANNOTATIONS file: one header-less row per variant with varid, gene and annotation
ANNOTATIONS_FILE = "/tmp/GIPR_test.annotations"

annotation_rows = mt.select_rows(
    varid=mt.varid,
    gene=mt.gene_id,
    annotation=mt.protCons,
).rows()

# Re-key by varid so the original locus/alleles key fields can be dropped.
annotations = annotation_rows.key_by("varid").drop("locus", "alleles")

annotations.export("file:" + ANNOTATIONS_FILE, header=False)

In [None]:
# SETLIST file: a single tab-separated line with
# <gene_id> <contig> <smallest position> <comma-separated variant ids>
SETLIST_FILE = "/tmp/GIPR_test.setlist"

# Fetch everything needed in one pass over the rows (the row key fields
# locus/alleles come along with the selected fields).
set_rows = mt.rows().select("varid", "gene_id").collect()
if not set_rows:
    # Without this check the indexing below would raise an opaque IndexError.
    raise ValueError("No variants left after filtering; cannot write the set list")

position = min(row.locus.position for row in set_rows)
names = [row.varid for row in set_rows]
names_str = ",".join(names)

# Gene id and contig are taken from the first row.
first_row = set_rows[0]
line = f"{first_row.gene_id}\t{first_row.locus.contig}\t{position}\t{names_str}"

with open(SETLIST_FILE, "w", encoding="utf-8") as f:
    f.write(line)

In [None]:
# Show the final matrix table dimensions as the cell output.
mt.count()

In [None]:
bgen_file = BGEN_FILE + ".bgen"
sample_file = BGEN_FILE + ".sample"

!dx upload bgen_file sample_file ANNOTATIONS_FILE SETLIST_FILE --path WGS_Javier/WGS_QC/Output/