## 1. Import packages

In [None]:
# Import packages
import hail as hl
from bokeh.io import output_notebook,show
import gnomad.utils.vep

## 2. Import data

In [None]:
# Import the gnomAD v3.1.2 genome sites table and keep only its first 100,000
# rows as a lightweight development subset.
ht = hl.read_table(
    'gs://gcp-public-data--gnomad/release/3.1.2/ht/genomes/gnomad.genomes.v3.1.2.sites.ht'
).head(100000)

# Mutation rates from the gnomAD paper (supplementary dataset 10); column types
# are imputed and the .gz file is read as block-gzipped (force_bgz).
ht_mu = hl.import_table(
    'data/supplementary_dataset_10_mutation_rates.tsv.gz',
    delimiter='\t',
    impute=True,
    force_bgz=True,
)

# Precomputed sequence-context table from gnomad_methods
# (https://broadinstitute.github.io/gnomad_methods/api_reference/utils/vep.html?highlight=context#gnomad.utils.vep.get_vep_context).
# Keep only `context`, sliced to characters [2:5] of the stored string, i.e. the
# trinucleotide around the variant if the stored context spans -3/+3.
context_table = gnomad.utils.vep.get_vep_context("GRCh38").ht()
context_table_parsed = context_table.select(context=context_table.context[2:5])

### Show the data structure

In [None]:
# First three rows of the gnomAD sites subset
ht.show(n=3)

In [None]:
# Preview of the mutation-rate table: mutation rate per trinucleotide context,
# together with its methylation level
ht_mu.show(n=3)

In [None]:
# Precomputed context table from gnomad_methods. Its original context string
# (-3/+3 around the variant) has already been sliced to characters [2:5], so the
# `context` shown here is the trinucleotide (-1/+1) centred on the variant.
context_table_parsed.show(3)

## 3. Add context field to main data

In [None]:
# Number of variants before the trinucleotide context is added
n_variants_before_context = ht.count()
n_variants_before_context

In [None]:
# Left join on (locus, alleles): every variant of ht is kept and gains the
# `context` field, which is missing where the context table has no matching row.
ht_by_variant = ht.key_by('locus', 'alleles')
context_by_variant = context_table_parsed.key_by('locus', 'alleles')
ht = ht_by_variant.join(context_by_variant, how='left')

In [None]:
# Number of rows after the context join (compare with the count before it)
n_variants_after_context = ht.count()
n_variants_after_context

## 4. Add mutation rates for added contexts

In [None]:
# Split the alleles array into reference and alternate allele
# (assumes ht holds biallelic/split sites — only alleles[1] is used as alt).
ht = ht.annotate(ref=ht.alleles[0], alt=ht.alleles[1])

# Attach the mutation rate for each variant's (context, ref, alt).
# ht_mu also carries a methylation level, so it can hold several rows per
# (context, ref, alt). A Table.join would emit one output row per matching ht_mu
# row and silently duplicate those variants, inflating the variant/singleton
# counts used downstream. An index lookup keeps exactly one row per variant
# (the same lookup pattern used for mu_snp later in this notebook).
# TODO(review): methylation level is not matched here, so where ht_mu has several
# levels for a context the rate attached is just one of them — confirm whether a
# per-site methylation annotation is available to include in the lookup key.
ht_mu_keyed = ht_mu.key_by("context", "ref", "alt")
ht = ht.annotate(**ht_mu_keyed[ht.context, ht.ref, ht.alt])

In [None]:
# Row count after adding the mutation rates. Compare with the count before this
# step: a left join never multiplies rows because of duplicate keys in ht itself,
# so any increase means ht_mu holds several rows per (context, ref, alt)
# (e.g. one per methylation level) and the affected variants were duplicated.
ht.count()

In [None]:
# Build one string key "<context>_<ref>_<alt>" (used later to look up mutation
# rates per context) and re-key the variant table on it.
ht = ht.annotate(
    context_ref_alt=ht.context + '_' + ht.ref + '_' + ht.alt,
).key_by('context_ref_alt')

In [None]:
# Rows sharing the same context_ref_alt key can come from completely different loci
ht.show(n=20)

#### *From this point on, tables are keyed and grouped by `context_ref_alt`*

## 5. Train linear model on synonymous variants for mutational class correction

In [None]:
# Keep only variants whose most severe consequence is synonymous
ht_syn = ht.filter(ht.vep.most_severe_consequence == "synonymous_variant")

# Per context_ref_alt: number of synonymous variants and number of singletons.
# Both are counted in a single aggregation, so contexts without any singleton get
# N_singletons = 0 instead of a missing value (which would make ps missing and
# silently drop those contexts from the regression).
ht_syn_ps = ht_syn.group_by(ht_syn.context_ref_alt).aggregate(
    N_variants=hl.agg.count(),
    # Summing the per-variant singleton flag counts singletons; missing flags are
    # skipped by hl.agg.sum.
    N_singletons=hl.agg.sum(ht_syn.info.singleton),
)

# Proportion of singletons (ps) per context
ht_syn_ps = ht_syn_ps.annotate(ps=ht_syn_ps.N_singletons / ht_syn_ps.N_variants)

### Show input table for regression

In [None]:
# Number of synonymous contexts before the mutation rates are attached
n_syn_contexts_before_mu = ht_syn_ps.count()
n_syn_contexts_before_mu

In [None]:
# Attach the mutation rate to every synonymous context. ht is keyed by the
# non-unique context_ref_alt; an index lookup returns a single row per key, so
# ht_syn_ps stays at one row per context (a join would duplicate contexts).
# NOTE(review): when several ht rows share a key, which row's mu_snp is returned
# is up to Hail — confirm those rows always agree.
mu_lookup = ht.select(ht.mu_snp)
ht_syn_ps = ht_syn_ps.annotate(**mu_lookup[ht_syn_ps.context_ref_alt])

ht_syn_ps.show(n=3)

In [None]:
# Number of synonymous contexts after the mutation rates are attached
n_syn_contexts_after_mu = ht_syn_ps.count()
n_syn_contexts_after_mu

### Perform regression

In [None]:
# Weighted least-squares fit of the per-context singleton proportion on the
# mutation rate, ps ~ beta0 + beta1 * mu_snp, with each context weighted by its
# number of synonymous variants.
syn_fit = hl.agg.linreg(
    y=ht_syn_ps.ps,
    x=[1, ht_syn_ps.mu_snp],  # the constant 1 provides the intercept term
    weight=ht_syn_ps.N_variants,
)
ht_syn_lm = ht_syn_ps.aggregate(syn_fit.beta)

# Coefficients: [intercept, slope on mu_snp]
ht_syn_lm

## 6. Predict expected number of variants for each context

### For testing purposes focus on `upstream_gene_variant`

### The code below will eventually be turned into a regression function in `/utils/utils.py`

In [None]:
# Functional class to compute MAPS for (upstream_gene_variant while testing)
consequence_of_interest = "upstream_gene_variant"

# Keep only variants of the selected functional class
ht_reg_table = ht.filter(ht.vep.most_severe_consequence == consequence_of_interest)

# Per context_ref_alt: number of variants and number of singletons, counted in a
# single aggregation so N_singletons is 0 (not missing) for contexts without singletons
ht_reg_table_ps = ht_reg_table.group_by(ht_reg_table.context_ref_alt).aggregate(
    N_variants=hl.agg.count(),
    N_singletons=hl.agg.sum(ht_reg_table.info.singleton),
)

# Observed proportion of singletons (ps) per context
ht_reg_table_ps = ht_reg_table_ps.annotate(
    ps=ht_reg_table_ps.N_singletons / ht_reg_table_ps.N_variants
)

# Attach mu_snp per context. ht is keyed by the non-unique context_ref_alt; the
# index lookup returns a single row per key, so contexts are not duplicated
# (a join would duplicate them).
ht_reg_table_ps = ht_reg_table_ps.annotate(
    **ht.select(ht.mu_snp)[ht_reg_table_ps.context_ref_alt]
)

# Expected singletons from the synonymous model, where ht_syn_lm = [intercept, slope]:
# (intercept + slope * mu_snp) * N_variants
ht_reg_table_ps_lm = ht_reg_table_ps.annotate(
    expected_singletons=(ht_reg_table_ps.mu_snp * ht_syn_lm[1] + ht_syn_lm[0])
    * ht_reg_table_ps.N_variants
)

# hl.agg.sum skips missing values: a context without a mutation rate would add
# observed singletons and variants but no expected singletons, biasing MAPS.
# Keep only contexts whose expectation is defined so all sums cover the same rows.
ht_reg_table_ps_lm = ht_reg_table_ps_lm.filter(
    hl.is_defined(ht_reg_table_ps_lm.expected_singletons)
)

# Label every context with the consequence so they can be summed into one row
ht_reg_table_ps_lm_cons = ht_reg_table_ps_lm.annotate(consequence=consequence_of_interest)

# Sum observed singletons, expected singletons and variants over all contexts
ht_reg_table_ps_lm_cons_agg = ht_reg_table_ps_lm_cons.group_by("consequence").aggregate(
    N_singletons=hl.agg.sum(ht_reg_table_ps_lm_cons.N_singletons),
    expected_singletons=hl.agg.sum(ht_reg_table_ps_lm_cons.expected_singletons),
    N_variants=hl.agg.sum(ht_reg_table_ps_lm_cons.N_variants),
)

# Aggregated singleton proportion and MAPS = (observed - expected singletons) / N_variants
ht_reg_table_ps_lm_cons_agg_MAPS = ht_reg_table_ps_lm_cons_agg.annotate(
    ps_agg=ht_reg_table_ps_lm_cons_agg.N_singletons / ht_reg_table_ps_lm_cons_agg.N_variants,
    maps=(
        ht_reg_table_ps_lm_cons_agg.N_singletons
        - ht_reg_table_ps_lm_cons_agg.expected_singletons
    )
    / ht_reg_table_ps_lm_cons_agg.N_variants,
)

# Standard error of MAPS: sqrt(ps_agg * (1 - ps_agg) / N_variants). Annotate the
# MAPS table itself: that is where ps_agg lives, and it keeps the maps column.
ht_reg_table_ps_lm_cons_agg_MAPS = ht_reg_table_ps_lm_cons_agg_MAPS.annotate(
    maps_sem=(
        ht_reg_table_ps_lm_cons_agg_MAPS.ps_agg
        * (1 - ht_reg_table_ps_lm_cons_agg_MAPS.ps_agg)
        / ht_reg_table_ps_lm_cons_agg_MAPS.N_variants
    ) ** 0.5
)

In [None]:
# Final MAPS table for the selected consequence
ht_reg_table_ps_lm_cons_agg_MAPS.show(n=20)

In [None]:
## function definition for the future
#def Regress(name):
#    """
#    Regress ...
#    """
    