# This notebook generates allele-specific (AS) annotations for the allele-specific mode of GATK VQSR, for the following datasets:
1. Nigerian genomes (54-gene): 480 genomes
2. NeuroGAP: 200 South African (SA) + 93 other high-coverage genomes
3. Gambian Genomes Variation (GGV) Project: 394 genomes
4. HGDP + 1KG: 4151 genomes
5. H3Africa (to be confirmed): the GVCFs are missing RAW_MQandDP (or RAW_MQ and MQ_DP); reprocess as per Laura's recommendation

## Annotations checks

Dataset |QUALapprox |VarDP |MQ_DP and RAW_MQ, or RAW_MQandDP |SB |ReadPosRankSum |MQRankSum
-----|-----|-----|-----|-----|-----|-----
Nigerian genomes |9 |9 |RAW_MQandDP |1 |1 |1
NeuroGAP SA |9 |9 |RAW_MQ |1 |1 |1
NeuroGAP Extra |9 |9 |RAW_MQ |1 |1 |1
GGV |9 |9 |RAW_MQandDP |1 |1 |1
HGDP+1kGP |1 |1 |(RAW_MQ, MQ_DP)|1 |1 |1

1 = present, 9 = missing. For the MQ column, the field(s) present in each dataset are listed.

- As the table shows, all datasets except HGDP+1kGP are missing QUALapprox and VarDP. Both can be computed from existing entry fields (VarDP as sum(LAD), QUALapprox as LPL[0]) using the compute_missing_annotations() function in this notebook.
- The NeuroGAP genomes (SA + Extra) are also missing MQ_DP. Below is Laura's guidance on how to "recover" it (compute_missing_annotations() fills MQ_DP from the entry-level DP field):
    - MQ_DP is the depth of reads that were counted for the MQ calculation, and it is approximately equal to the INFO DP for a single sample
    - It follows that if MQ_DP is missing, it is somewhat safe to use the INFO DP from a single sample
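To see why MQ_DP matters, here is a minimal pure-Python sketch of how per-sample `[RAW_MQ, MQ_DP]` pairs combine into a site-level MQ, mirroring the `sqrt(RAW_MQandDP[0]/RAW_MQandDP[1])` step in `_get_info_agg_expr()`. It assumes RAW_MQ is the per-sample sum of squared mapping qualities of the reads counted in MQ_DP (as in GATK's RAW_MQandDP); `combine_mq` is a hypothetical helper, not part of this notebook. Because MQ_DP is the denominator, substituting DP for MQ_DP biases MQ downward if DP counts more reads than MQ_DP would have, and upward if it counts fewer.

```python
import math
from typing import List, Sequence


def combine_mq(raw_mq_and_dp: List[Sequence[float]]) -> float:
    """Combine per-sample [RAW_MQ, MQ_DP] pairs into one MQ value (sketch).

    RAW_MQ is assumed to be the sum of squared mapping qualities of the reads
    counted in MQ_DP. The pairs are summed element-wise first (as
    hl.agg.array_sum does), and only then is the square root taken.
    """
    total_raw_mq = sum(pair[0] for pair in raw_mq_and_dp)
    total_depth = sum(pair[1] for pair in raw_mq_and_dp)
    # Same guard as the notebook's if_else(x[1] > 0, ..., 0)
    if total_depth <= 0:
        return 0.0
    return math.sqrt(total_raw_mq / total_depth)


# Sample 1: 10 reads at MQ 60; sample 2: 4 reads at MQ 30
print(combine_mq([[10 * 60**2, 10], [4 * 30**2, 4]]))
```

Summing before taking the square root weights each sample by its read depth, which is why the depth term has to be on the same scale as the reads behind RAW_MQ.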

## Below is the workflow
1. Generate AS annotations for each dataset separately
    - For the RankSum fields, keep the frequencies (histogram) for each allele. These frequencies are used downstream to approximate the median of the merged datasets.
2. Merge datasets
    - AS_QUALapprox: take the sum when merging across cohorts/batches
    - AS_VarDP: take the sum when merging across cohorts/batches
    - AS_*_RankSum: use histograms/frequencies to approximate median
3. Run VQSR on merged datasets
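A minimal pure-Python sketch of step 2 for a single site, assuming the two cohorts' per-allele arrays are already aligned to the same alternate alleles in the same order (in practice the alleles must be matched first). The dictionary layout and `merge_site` are hypothetical; the field names follow the output of `as_table_cleanup()`. QD is recomputed from the merged sums rather than averaged across cohorts, because an allele with `AS_VarDP == 0` in one cohort has its QD set to 0 there, which would distort an average.

```python
from typing import Dict


def merge_site(a: Dict[str, list], b: Dict[str, list]) -> Dict[str, list]:
    """Merge per-allele annotations for one site from two cohorts (hypothetical layout)."""
    # AS_QUALapprox and AS_VarDP: element-wise sum over alternate alleles
    qual = [x + y for x, y in zip(a["AS_QUALapprox"], b["AS_QUALapprox"])]
    var_dp = [x + y for x, y in zip(a["AS_VarDP"], b["AS_VarDP"])]
    # AS_QD recomputed from the merged sums, set to 0 when AS_VarDP == 0
    qd = [q / d if d > 0 else 0.0 for q, d in zip(qual, var_dp)]
    # RankSum histograms: bin-by-bin sum of frequencies, one histogram per alt allele
    rprs = [
        [x + y for x, y in zip(freq_a, freq_b)]
        for freq_a, freq_b in zip(a["AS_ReadPosRankSum_freq"], b["AS_ReadPosRankSum_freq"])
    ]
    return {"AS_QUALapprox": qual, "AS_VarDP": var_dp, "AS_QD": qd, "AS_ReadPosRankSum_freq": rprs}


# Toy 3-bin histograms for brevity; the real ones have 600 bins
cohort_a = {"AS_QUALapprox": [100, 30], "AS_VarDP": [10, 0], "AS_ReadPosRankSum_freq": [[1, 2, 0], [0, 0, 0]]}
cohort_b = {"AS_QUALapprox": [50, 20], "AS_VarDP": [5, 4], "AS_ReadPosRankSum_freq": [[0, 1, 3], [2, 0, 0]]}
merged = merge_site(cohort_a, cohort_b)
print(merged["AS_QD"])  # [10.0, 12.5]
```

Averaging the per-cohort QDs for the second allele would give (0 + 5.0) / 2 = 2.5 instead of 12.5.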

### RankSum histogram/frequencies
- Below is the binning strategy used:

```python
min_bound = -3
max_bound = 3
hist_increment = 0.01
num_bins = int(abs(min_bound-max_bound)/hist_increment) # 600 bins in total
```

- Because the bin edges are identical across datasets, they are stored once as a table global rather than on every row.
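One way the merged frequencies could be turned back into an approximate median is sketched below. The notebook does not show this downstream step, so `approx_median_from_freq` and its linear interpolation inside the median bin are assumptions. It treats bin `i` as covering `[start + i*width, start + (i+1)*width)`. Values outside `[min_bound, max_bound]` are counted in `hl.agg.hist`'s `n_smaller`/`n_larger`, which are dropped before merging, so they do not contribute to this estimate.

```python
from typing import List, Optional


def approx_median_from_freq(
    bin_freq: List[int], start: float = -3.0, end: float = 3.0
) -> Optional[float]:
    """Approximate the median of values summarised by equal-width bin counts (sketch).

    Finds the bin holding the middle observation and interpolates linearly
    inside it. Returns None for an empty histogram.
    """
    total = sum(bin_freq)
    if total == 0:
        return None
    width = (end - start) / len(bin_freq)
    half = total / 2
    cumulative = 0
    for i, count in enumerate(bin_freq):
        if count > 0 and cumulative + count >= half:
            # Fraction of this bin needed to reach the halfway count
            fraction = (half - cumulative) / count
            return start + (i + fraction) * width
        cumulative += count
    return None  # not reached when total > 0


# Toy 6-bin histogram over [-3, 3] (bin width 1.0); the real histograms have 600 bins
print(approx_median_from_freq([0, 0, 0, 4, 0, 0]))  # 0.5
```

With 600 bins of width 0.01, the interpolation error is at most one bin width, which is small relative to typical RankSum thresholds.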

### Imports

In [1]:
import hail as hl
import logging
import numpy as np
from typing import Union, List, Dict, Optional, Set

from gnomad.utils.annotations import (
    fs_from_sb,
    get_adj_expr,
    get_lowqual_expr,
    pab_max_expr,
    sor_from_sb,
)

logging.basicConfig(
    format="%(asctime)s (%(name)s %(lineno)s): %(message)s",
    datefmt="%m/%d/%Y %I:%M:%S %p",
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

## Functions

### 1.1. Checking for missing annotations and computing them from existing fields

In [2]:
def check_missing_annotations(fields: List[str]):
    """
    Check missing annotations that are required to generate AS annotations
    
    :param fields: list of fields present in the dataset

    :return: list of missing annotations
    """
    
    if 'RAW_MQandDP' in fields:
        required_fields = ['QUALapprox', 'VarDP', 'ReadPosRankSum', 'MQRankSum', 'SB', 'RAW_MQandDP']
    else:
        required_fields = ['QUALapprox', 'VarDP', 'ReadPosRankSum', 'MQRankSum', 'SB', 'MQ_DP', 'RAW_MQ']
            
    return list(set(required_fields).difference(fields))
    

def compute_missing_annotations(vds: hl.vds.variant_dataset.VariantDataset,
                               missing_annotations: List[str]) -> hl.vds.variant_dataset.VariantDataset:
    """
    Compute VarDP and/or QUALapprox if they are missing.
    
    :param mt: Input VDS with missing annotations

    :return: VariantDataset
    """
    
    if ('VarDP' in missing_annotations) and ('LAD' in vds.variant_data.entry):
        logger.info("Computing `VarDP` as sum(LAD)")
        vds.variant_data = vds.variant_data.annotate_entries(gvcf_info=
                                                             vds.variant_data.gvcf_info.annotate(VarDP=
                                                                                                 hl.int32(hl.sum(vds.variant_data.LAD))))
                    
    if ('QUALapprox' in missing_annotations) and ('LPL' in vds.variant_data.entry):
        logger.info("Computing `QUALapprox` as LPL[0]")
        vds.variant_data = vds.variant_data.annotate_entries(gvcf_info=
                                                             vds.variant_data.gvcf_info.annotate(QUALapprox=
                                                                                                 vds.variant_data.LPL[0]))
    if ('MQ_DP' in missing_annotations) and ('DP' in vds.variant_data.entry):
        logger.info("Computing `MQ_DP` as DP")
        vds.variant_data = vds.variant_data.annotate_entries(gvcf_info=
                                                             vds.variant_data.gvcf_info.annotate(MQ_DP=
                                                                                                 vds.variant_data.DP))
                    
    return vds
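For a single entry, the three substitutions in `compute_missing_annotations()` amount to the arithmetic below. This is a plain-Python sketch on one hypothetical entry at a tri-allelic site with made-up values; `derive_missing` is not part of the notebook. It also shows why `LPL[0]` can stand in for a per-sample QUALapprox: PLs are normalised so the most likely genotype has 0, so the hom-ref PL grows with the evidence that the sample is not homozygous reference.

```python
from typing import Any, Dict


def derive_missing(entry: Dict[str, Any]) -> Dict[str, int]:
    """Mirror of the per-entry formulas used in compute_missing_annotations()."""
    return {
        "VarDP": sum(entry["LAD"]),     # depth summed over the sample's local alleles
        "QUALapprox": entry["LPL"][0],  # PL of the local hom-ref genotype (0/0)
        "MQ_DP": entry["DP"],           # approximation used for the NeuroGAP genomes
    }


# Hypothetical non-ref entry: the sample carries the reference and global allele 2
entry = {
    "LA": [0, 2],          # local-to-global allele mapping
    "LAD": [7, 5],         # depth for local alleles 0 (ref) and 1 (global allele 2)
    "LPL": [120, 0, 180],  # PLs for local genotypes 0/0, 0/1, 1/1
    "DP": 13,
}
print(derive_missing(entry))  # {'VarDP': 12, 'QUALapprox': 120, 'MQ_DP': 13}
```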

### 1.2. gnomAD functions (modified to compute histograms for the RankSum annotations)

In [3]:
INFO_AGG_FIELDS = {
    "sum_agg_fields": ["QUALapprox"], # take the sum when merging across cohorts/batches
    "int32_sum_agg_fields": ["VarDP"], # take the sum when merging across cohorts/batches
    "median_agg_fields": ["ReadPosRankSum", "MQRankSum"], # get frequencies by bin for each cohort and compute median after merging
    "array_sum_agg_fields": ["SB", "RAW_MQandDP"],
}

**Modified `_get_info_agg_expr()`: computes histograms for the RankSum fields instead of approximate medians**

In [4]:
def _get_info_agg_expr(
    mt: hl.MatrixTable,
    sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_AGG_FIELDS["sum_agg_fields"],
    int32_sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_AGG_FIELDS["int32_sum_agg_fields"],
    median_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_AGG_FIELDS["median_agg_fields"],
    array_sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.ArrayNumericExpression]
    ] = INFO_AGG_FIELDS["array_sum_agg_fields"],
    prefix: str = "",
    treat_fields_as_allele_specific: bool = False,
) -> Dict[str, hl.expr.Aggregation]:
    """
    Create Aggregators for both site or AS info expression aggregations.

    .. note::

        - If `SB` is specified in array_sum_agg_fields, it will be aggregated as
          `AS_SB_TABLE`, according to GATK standard nomenclature.
        - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for
          the `MQ` calculation and then dropped according to GATK recommendation.
        - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation
          and then dropped according to GATK recommendation.
        - If the fields to be aggregated (`sum_agg_fields`, `int32_sum_agg_fields`,
          `median_agg_fields`) are passed as list of str, then they should correspond
          to entry fields in `mt` or in mt.gvcf_info`.
        - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in
          case of a name clash.

    :param mt: Input MT
    :param sum_agg_fields: Fields to aggregate using sum.
    :param int32_sum_agg_fields: Fields to aggregate using sum using int32.
    :param median_agg_fields: Fields to aggregate into histograms (600 bins over [-3, 3]);
        the bin frequencies are used downstream to approximate the median.
    :param array_sum_agg_fields: Fields to aggregate using element-wise summing over an
        array.
    :param prefix: Optional prefix for the fields. Used for adding 'AS_' in the AS case.
    :param treat_fields_as_allele_specific: Treat info fields as allele-specific. Defaults to False.
    :return: Dictionary of expression names and their corresponding aggregation
        Expression.
    """

    def _agg_list_to_dict(
        mt: hl.MatrixTable, fields: List[str]
    ) -> Dict[str, hl.expr.NumericExpression]:
        out_fields = {}
        if "gvcf_info" in mt.entry:
            out_fields = {f: mt.gvcf_info[f] for f in fields if f in mt.gvcf_info}

        out_fields.update({f: mt[f] for f in fields if f in mt.entry})

        # Check that all fields were found.
        missing_fields = [f for f in fields if f not in out_fields]
        if missing_fields:
            raise ValueError(
                "Could not find the following field(s)in the MT entry schema (or nested"
                " under mt.gvcf_info: {}".format(",".join(missing_fields))
            )

        if treat_fields_as_allele_specific:
            # TODO: Change to use hl.vds.local_to_global when fill_value can accept
            #  missing (error in v0.2.119).
            out_fields = {
                f: hl.bind(
                    lambda x: hl.if_else(f == "AS_SB_TABLE", x, x[1:]),
                    hl.range(hl.len(mt.alleles)).map(
                        lambda i: hl.or_missing(
                            mt.LA.contains(i), out_fields[f][mt.LA.index(i)]
                        )
                    ),
                )
                for f in fields
            }

        return out_fields

    # Map str to expressions where needed.
    if isinstance(sum_agg_fields, list):
        sum_agg_fields = _agg_list_to_dict(mt, sum_agg_fields)

    if isinstance(int32_sum_agg_fields, list):
        int32_sum_agg_fields = _agg_list_to_dict(mt, int32_sum_agg_fields)

    if isinstance(median_agg_fields, list):
        median_agg_fields = _agg_list_to_dict(mt, median_agg_fields)

    if isinstance(array_sum_agg_fields, list):
        array_sum_agg_fields = _agg_list_to_dict(mt, array_sum_agg_fields)
    
    # For median_agg_fields, compute histogram instead of approximating quantiles
    aggs = [
        # (median_agg_fields, lambda x: hl.agg.approx_quantiles(x, 0.5)),
        (median_agg_fields, lambda x: hl.agg.hist(x, -3, 3, 600)),
        (sum_agg_fields, hl.agg.sum),
        (int32_sum_agg_fields, lambda x: hl.int32(hl.agg.sum(x))),
        (array_sum_agg_fields, hl.agg.array_sum),
    ]

    # Create aggregators.
    agg_expr = {}
    for agg_fields, agg_func in aggs:
        for k, expr in agg_fields.items():
            if treat_fields_as_allele_specific:
                # If annotation is of the form 'AS_RAW_*_RankSum' it has a histogram
                # representation where keys give the per-variant rank sum value to one
                # decimal place followed by a comma and the corresponding count for
                # that value, so we want to sum the rank sum value (first element).
                # Rename annotation in the form 'AS_RAW_*_RankSum' to 'AS_*_RankSum'.
                if k.startswith("AS_RAW_") and k.endswith("RankSum"):
                    agg_expr[f"{prefix}{k.replace('_RAW', '')}"] = hl.agg.array_agg(
                        lambda x: agg_func(hl.or_missing(hl.is_defined(x), x[0])), expr
                    )
                else:
                    agg_expr[f"{prefix}{k}"] = hl.agg.array_agg(
                        lambda x: agg_func(x), expr
                    )
            else:
                agg_expr[f"{prefix}{k}"] = agg_func(expr)

    if treat_fields_as_allele_specific:
        prefix = "AS_"

    # Handle annotations combinations and casting for specific annotations
    # If RAW_MQandDP is in agg_expr or if both MQ_DP and RAW_MQ are, compute MQ instead
    mq_tuple = None
    if f"{prefix}RAW_MQandDP" in agg_expr:
        logger.info(
            "Computing %sMQ as sqrt(%sRAW_MQandDP[0]/%sRAW_MQandDP[1]). "
            "Note that %sMQ will be set to 0 if %sRAW_MQandDP[1] == 0.",
            *[prefix] * 5,
        )
        mq_tuple = agg_expr.pop(f"{prefix}RAW_MQandDP")
    elif "AS_RAW_MQ" in agg_expr and treat_fields_as_allele_specific:
        logger.info(
            "Computing AS_MQ as sqrt(AS_RAW_MQ[i]/AD[i+1]). "
            "Note that AS_MQ will be set to 0 if AS_RAW_MQ == 0."
        )
        ad_expr = hl.vds.local_to_global(
            mt.LAD, mt.LA, hl.len(mt.alleles), fill_value=0, number="R"
        )
        mq_tuple = hl.zip(agg_expr.pop("AS_RAW_MQ"), hl.agg.array_sum(ad_expr[1:]))
    elif f"{prefix}RAW_MQ" in agg_expr and f"{prefix}MQ_DP" in agg_expr:
        logger.info(
            "Computing %sMQ as sqrt(%sRAW_MQ/%sMQ_DP). "
            "Note that MQ will be set to 0 if %sRAW_MQ == 0.",
            *[prefix] * 4,
        )
        mq_tuple = (agg_expr.pop(f"{prefix}RAW_MQ"), agg_expr.pop(f"{prefix}MQ_DP"))

    if mq_tuple is not None:
        if treat_fields_as_allele_specific:
            agg_expr[f"{prefix}MQ"] = mq_tuple.map(
                lambda x: hl.if_else(x[1] > 0, hl.sqrt(x[0] / x[1]), 0)
            )
        else:
            agg_expr[f"{prefix}MQ"] = hl.if_else(
                mq_tuple[1] > 0, hl.sqrt(mq_tuple[0] / mq_tuple[1]), 0
            )

    # If both VarDP and QUALapprox are present, also compute QD.
    if f"{prefix}VarDP" in agg_expr and f"{prefix}QUALapprox" in agg_expr:
        logger.info(
            "Computing %sQD as %sQUALapprox/%sVarDP. "
            "Note that %sQD will be set to 0 if %sVarDP == 0.",
            *[prefix] * 5,
        )
        var_dp = agg_expr[f"{prefix}VarDP"]
        qual_approx = agg_expr[f"{prefix}QUALapprox"]
        if treat_fields_as_allele_specific:
            agg_expr[f"{prefix}QD"] = hl.map(
                lambda x: hl.if_else(x[1] > 0, x[0] / x[1], 0),
                hl.zip(qual_approx, var_dp),
            )
        else:
            agg_expr[f"{prefix}QD"] = hl.if_else(var_dp > 0, qual_approx / var_dp, 0)

    # SB needs to be cast to int32 for FS down the line.
    if f"{prefix}SB" in agg_expr:
        agg_expr[f"{prefix}SB"] = agg_expr[f"{prefix}SB"].map(lambda x: hl.int32(x))

    # SB needs to be cast to int32 for FS down the line.
    if "AS_SB_TABLE" in agg_expr:
        agg_expr["AS_SB_TABLE"] = agg_expr["AS_SB_TABLE"].map(
            lambda x: x.map(lambda y: hl.int32(y))
        )

    return agg_expr

**We do not import `get_as_info_expr()` directly from gnomAD methods, because the gnomAD version calls the original `_get_info_agg_expr()` rather than the modified version above**

In [5]:
def get_as_info_expr(
    mt: hl.MatrixTable,
    sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_AGG_FIELDS["sum_agg_fields"],
    int32_sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_AGG_FIELDS["int32_sum_agg_fields"],
    median_agg_fields: Union[
        List[str], Dict[str, hl.expr.NumericExpression]
    ] = INFO_AGG_FIELDS["median_agg_fields"],
    array_sum_agg_fields: Union[
        List[str], Dict[str, hl.expr.ArrayNumericExpression]
    ] = INFO_AGG_FIELDS["array_sum_agg_fields"],
    alt_alleles_range_array_field: str = "alt_alleles_range_array",
    treat_fields_as_allele_specific: bool = False,
) -> hl.expr.StructExpression:
    """
    Return an allele-specific annotation Struct containing typical VCF INFO fields from GVCF INFO fields stored in the MT entries.

    .. note::

        - If `SB` is specified in array_sum_agg_fields, it will be aggregated as
          `AS_SB_TABLE`, according to GATK standard nomenclature.
        - If `RAW_MQandDP` is specified in array_sum_agg_fields, it will be used for
          the `MQ` calculation and then dropped according to GATK recommendation.
        - If `RAW_MQ` and `MQ_DP` are given, they will be used for the `MQ` calculation
          and then dropped according to GATK recommendation.
        - If the fields to be aggregated (`sum_agg_fields`, `int32_sum_agg_fields`,
          `median_agg_fields`) are passed as list of str, then they should correspond
          to entry fields in `mt` or in `mt.gvcf_info`.
        - Priority is given to entry fields in `mt` over those in `mt.gvcf_info` in
          case of a name clash.
        - If `treat_fields_as_allele_specific` is False, it's expected that there is a
          single value for each entry field to be aggregated. Then when performing the
          aggregation per global alternate allele, that value is included in the
          aggregation if the global allele is present in the entry's list of local
          alleles. If `treat_fields_as_allele_specific` is True, it's expected that
          each entry field to be aggregated has one value per local allele, and each
          of those is mapped to a global allele for aggregation.

    :param mt: Input Matrix Table
    :param sum_agg_fields: Fields to aggregate using sum.
    :param int32_sum_agg_fields: Fields to aggregate using sum using int32.
    :param median_agg_fields: Fields to aggregate into histograms (600 bins over [-3, 3]);
        the bin frequencies are used downstream to approximate the median.
    :param array_sum_agg_fields: Fields to aggregate using array sum.
    :param alt_alleles_range_array_field: Annotation containing an array of the range
        of alternate alleles e.g., `hl.range(1, hl.len(mt.alleles))`
    :param treat_fields_as_allele_specific: Treat info fields as allele-specific.
        Defaults to False.
    :return: Expression containing the AS info fields
    """
    if "DP" in list(sum_agg_fields) + list(int32_sum_agg_fields):
        logger.warning(
            "`DP` was included in allele-specific aggregation, however `DP` is"
            " typically not aggregated by allele; `VarDP` is.Note that the resulting"
            " `AS_DP` field will NOT include reference genotypes."
        )

    agg_expr = _get_info_agg_expr(
        mt=mt,
        sum_agg_fields=sum_agg_fields,
        int32_sum_agg_fields=int32_sum_agg_fields,
        median_agg_fields=median_agg_fields,
        array_sum_agg_fields=array_sum_agg_fields,
        prefix="" if treat_fields_as_allele_specific else "AS_",
        treat_fields_as_allele_specific=treat_fields_as_allele_specific,
    )

    if alt_alleles_range_array_field not in mt.row or mt[
        alt_alleles_range_array_field
    ].dtype != hl.dtype("array<int32>"):
        msg = (
            f"'get_as_info_expr' expected a row field '{alt_alleles_range_array_field}'"
            " of type array<int32>"
        )
        logger.error(msg)
        raise ValueError(msg)

    if not treat_fields_as_allele_specific:
        # Modify aggregations to aggregate per allele
        agg_expr = {
            f: hl.agg.array_agg(
                lambda ai: hl.agg.filter(mt.LA.contains(ai), expr),
                mt[alt_alleles_range_array_field],
            )
            for f, expr in agg_expr.items()
        }

    # Run aggregations
    info = hl.struct(**agg_expr)

    # Add FS and SOR if SB is present.
    if "AS_SB_TABLE" in info or "AS_SB" in info:
        # Rename AS_SB to AS_SB_TABLE if present and add SB Ax2 aggregation logic.
        if "AS_SB" in agg_expr:
            if "AS_SB_TABLE" in agg_expr:
                logger.warning(
                    "Both `AS_SB` and `AS_SB_TABLE` were specified for aggregation."
                    " `AS_SB` will be used for aggregation."
                )
            as_sb_table = hl.array(
                [
                    info.AS_SB.filter(lambda x: hl.is_defined(x)).fold(
                        lambda i, j: i[:2] + j[:2], [0, 0]
                    )  # ref
                ]
            ).extend(
                info.AS_SB.map(lambda x: x[2:])  # each alt
            )
        else:
            as_sb_table = info.AS_SB_TABLE
        info = info.annotate(
            AS_SB_TABLE=as_sb_table,
            AS_FS=hl.range(1, hl.len(mt.alleles)).map(
                lambda i: fs_from_sb(as_sb_table[0].extend(as_sb_table[i]))
            ),
            AS_SOR=hl.range(1, hl.len(mt.alleles)).map(
                lambda i: sor_from_sb(as_sb_table[0].extend(as_sb_table[i]))
            ),
        )

    return info
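When `treat_fields_as_allele_specific` is False (the quasi-AS mode that `default_compute_info()` uses by default, and the one used for every dataset below), each sample contributes its single site-level value to every alternate allele listed in its `LA`. The sketch below reproduces that `hl.agg.array_agg` / `hl.agg.filter(mt.LA.contains(ai), ...)` logic in plain Python for a sum aggregation; `quasi_as_sum` and the toy entries are hypothetical.

```python
from typing import List, Tuple


def quasi_as_sum(n_alleles: int, entries: List[Tuple[List[int], float]]) -> List[float]:
    """Per-allele sums of site-level values, one result per global alternate allele.

    entries holds (LA, value) pairs for the non-ref entries at one site.
    """
    return [
        # Include an entry whenever its local alleles contain global allele ai
        sum(value for la, value in entries if ai in la)
        for ai in range(1, n_alleles)  # same range as alt_alleles_range_array
    ]


# Tri-allelic site: the third sample is het for both alternate alleles
entries = [([0, 1], 10), ([0, 2], 7), ([1, 2], 4)]
print(quasi_as_sum(3, entries))  # [14, 11]
```

The third sample's value is counted toward both alleles, which is why these quasi-AS values only approximate true allele-specific annotations.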

In [6]:
def default_compute_info(
    mt: hl.MatrixTable,
    site_annotations: bool = False,
    as_annotations: bool = False,
    # Set to True by default to prevent a breaking change.
    quasi_as_annotations: bool = True,
    n_partitions: int = 5000,
    lowqual_indel_phred_het_prior: int = 40,
    ac_filter_groups: Optional[Dict[str, hl.Expression]] = None,
) -> hl.Table:
    """
    Compute a HT with the typical GATK allele-specific (AS) info fields as well as ACs and lowqual fields.

    .. note::

        - This table doesn't split multi-allelic sites.
        - At least one of `site_annotations`, `as_annotations` or `quasi_as_annotations`
          must be True.

    :param mt: Input MatrixTable. Note that this table should be filtered to nonref sites.
    :param site_annotations: Whether to generate site level info fields. Default is False.
    :param as_annotations: Whether to generate allele-specific info fields using
        allele-specific annotations in gvcf_info. Default is False.
    :param quasi_as_annotations: Whether to generate allele-specific info fields using
        non-allele-specific annotations in gvcf_info, but performing per allele
        aggregations. This method can be used in cases where genotype data doesn't
        contain allele-specific annotations to approximate allele-specific annotations.
        Default is True.
    :param n_partitions: Number of desired partitions for output Table. Default is 5000.
    :param lowqual_indel_phred_het_prior: Phred-scaled prior for a het genotype at a
        site with a low quality indel. Default is 40. We use 1/10k bases (phred=40) to
        be more consistent with the filtering used by Broad's Data Sciences Platform
        for VQSR.
    :param ac_filter_groups: Optional dictionary of sample filter expressions to compute
        additional groupings of ACs. Default is None.
    :return: Table with info fields
    :rtype: Table
    """
    if not site_annotations and not as_annotations and not quasi_as_annotations:
        raise ValueError(
            "At least one of `site_annotations`, `as_annotations`, or "
            "`quasi_as_annotations` must be True!"
        )

    # Add a temporary annotation for allele count groupings.
    ac_filter_groups = {"": True, **(ac_filter_groups or {})}
    mt = mt.annotate_cols(_ac_filter_groups=ac_filter_groups)

    # Move gvcf info entries out from nested struct.
    mt = mt.transmute_entries(**mt.gvcf_info)

    # Adding alt_alleles_range_array as a required annotation for
    # get_as_info_expr to reduce memory usage.
    mt = mt.annotate_rows(alt_alleles_range_array=hl.range(1, hl.len(mt.alleles)))

    info_expr = None
    quasi_info_expr = None

    # Compute quasi-AS info expr.
    if quasi_as_annotations:
        info_expr = get_as_info_expr(mt)

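    # Compute AS info expr using gvcf_info allele specific annotations.
    # NOTE: AS_INFO_AGG_FIELDS is not defined or imported in this notebook, so this
    # branch only works once it is defined; the datasets below use the default
    # quasi-AS mode (as_annotations=False).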
    if as_annotations:
        if info_expr is not None:
            quasi_info_expr = info_expr
        info_expr = get_as_info_expr(
            mt,
            **AS_INFO_AGG_FIELDS,
            treat_fields_as_allele_specific=True,
        )

    if info_expr is not None:
        # Add allele specific pab_max
        info_expr = info_expr.annotate(
            AS_pab_max=pab_max_expr(mt.LGT, mt.LAD, mt.LA, hl.len(mt.alleles))
        )

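    # NOTE: get_site_info_expr is not defined or imported in this notebook; it must be
    # provided before calling with site_annotations=True.
    if site_annotations: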
        site_expr = get_site_info_expr(mt)
        if info_expr is None:
            info_expr = site_expr
        else:
            info_expr = info_expr.annotate(**site_expr)

    # Add 'AC' and 'AC_raw' for each allele count filter group requested.
    # First compute ACs for each non-ref allele, grouped by adj.
    grp_ac_expr = {
        f: hl.agg.array_agg(
            lambda ai: hl.agg.filter(
                mt.LA.contains(ai) & mt._ac_filter_groups[f],
                hl.agg.group_by(
                    get_adj_expr(mt.LGT, mt.GQ, mt.DP, mt.LAD),
                    hl.agg.sum(
                        mt.LGT.one_hot_alleles(mt.LA.map(lambda x: hl.str(x)))[
                            mt.LA.index(ai)
                        ]
                    ),
                ),
            ),
            mt.alt_alleles_range_array,
        )
        for f in ac_filter_groups
    }

    # Then, for each non-ref allele, compute
    # 'AC' as the adj group
    # 'AC_raw' as the sum of adj and non-adj groups
    info_expr = info_expr.annotate(
        **{
            f"AC{'_' + f if f else f}_raw": grp.map(
                lambda i: hl.int32(i.get(True, 0) + i.get(False, 0))
            )
            for f, grp in grp_ac_expr.items()
        },
        **{
            f"AC{'_' + f if f else f}": grp.map(lambda i: hl.int32(i.get(True, 0)))
            for f, grp in grp_ac_expr.items()
        },
    )

    ann_expr = {"info": info_expr}
    if quasi_info_expr is not None:
        ann_expr["quasi_info"] = quasi_info_expr

    info_ht = mt.select_rows(**ann_expr).rows()

    # Add AS lowqual flag
    info_ht = info_ht.annotate(
        AS_lowqual=get_lowqual_expr(
            info_ht.alleles,
            info_ht.info.AS_QUALapprox,
            indel_phred_het_prior=lowqual_indel_phred_het_prior,
        )
    )

    if site_annotations:
        # Add lowqual flag
        info_ht = info_ht.annotate(
            lowqual=get_lowqual_expr(
                info_ht.alleles,
                info_ht.info.QUALapprox,
                indel_phred_het_prior=lowqual_indel_phred_het_prior,
            )
        )

    return info_ht.naive_coalesce(n_partitions)
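The `AC`/`AC_raw` block in `default_compute_info()` groups each allele's count by the adj flag from `get_adj_expr()`: `AC` keeps the adj (`True`) group and `AC_raw` adds both groups. Below is a plain-Python sketch of the same bookkeeping with hypothetical inputs, where each genotype is given as local allele indices and the adj flag is supplied directly rather than computed from GQ, DP and LAD.

```python
from typing import List, Tuple


def ac_and_ac_raw(
    n_alleles: int, genotypes: List[Tuple[List[int], List[int], bool]]
) -> Tuple[List[int], List[int]]:
    """Return (AC, AC_raw) per global alternate allele.

    genotypes holds (LA, local allele indices of LGT, is_adj) per entry.
    """
    ac = [0] * (n_alleles - 1)
    ac_raw = [0] * (n_alleles - 1)
    for la, lgt, is_adj in genotypes:
        for ai in range(1, n_alleles):
            if ai not in la:
                continue
            # Count copies of the local index that maps to global allele ai
            copies = lgt.count(la.index(ai))
            ac_raw[ai - 1] += copies
            if is_adj:
                ac[ai - 1] += copies
    return ac, ac_raw


genotypes = [
    ([0, 1], [0, 1], True),   # het for allele 1, passes adj
    ([0, 2], [1, 1], False),  # hom for allele 2, fails adj
    ([1, 2], [0, 1], True),   # het 1/2, passes adj
]
print(ac_and_ac_raw(3, genotypes))  # ([2, 1], [2, 3])
```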


### 1.3. Keeping only the histogram frequencies after computing AS annotations

In [7]:
def as_table_cleanup(as_table: hl.Table,
                    start: int = -3,
                    end: int = 3,
                    step: float = 0.01
                    ) -> hl.Table:
    """
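    Keep only the bin frequencies (bin_freq) of the AS_*_RankSum histograms and drop the other
    fields returned by hl.agg.hist() (bin_edges, n_smaller, n_larger) to save space.
    The bin edges are stored once as a global instead.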
    
    :param as_table: Input HT with AS annotations, including AS_*_RankSum
    :param start: Start of histogram range.
    :param end: End of histogram range.
    :param step: Bin width.

    :return: Hail Table with AS_*_RankSum fields containing only frequencies
    """
    
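    # np.arange excludes `end`, so this stores the lower edge of each bin
    # (the per-row bin_edges from hl.agg.hist() also include the upper edge, `end`)
    as_table = as_table.annotate_globals(bin_edges = [hl.float64(i) for i in np.arange(start, end, step)])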
    
    # keep only the bin frequencies of each AS_*_RankSum histogram, in new AS_*_RankSum_freq fields
    as_table = as_table.annotate(info = as_table.info.annotate(AS_ReadPosRankSum_freq = as_table.info.AS_ReadPosRankSum['bin_freq']))
    as_table = as_table.annotate(info = as_table.info.annotate(AS_MQRankSum_freq = as_table.info.AS_MQRankSum['bin_freq']))
    
    # drop the full AS_*_RankSum histogram structs to save space
    as_table = as_table.annotate(info = as_table.info.drop('AS_ReadPosRankSum', 'AS_MQRankSum'))
    
    return as_table
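For reference, here is a pure-Python mimic of the struct that `hl.agg.hist()` returns per allele, assuming Hail's documented binning (bins closed on the left and open on the right, except the last bin, which also includes `end`; out-of-range values counted in `n_smaller`/`n_larger`). It makes explicit what `as_table_cleanup()` discards: only `bin_freq` is kept, so values outside `[start, end]` no longer appear anywhere. `hist_like` is a hypothetical helper, not a Hail function.

```python
from typing import Dict, List


def hist_like(values: List[float], start: float, end: float, bins: int) -> Dict[str, object]:
    """Pure-Python imitation of the fields returned by hl.agg.hist() (sketch)."""
    width = (end - start) / bins
    bin_freq = [0] * bins
    n_smaller = n_larger = 0
    for v in values:
        if v < start:
            n_smaller += 1
        elif v > end:
            n_larger += 1
        else:
            # min() puts v == end into the last bin and keeps the index in range
            bin_freq[min(int((v - start) // width), bins - 1)] += 1
    return {
        "bin_edges": [start + i * width for i in range(bins + 1)],
        "bin_freq": bin_freq,
        "n_smaller": n_smaller,
        "n_larger": n_larger,
    }


h = hist_like([-4.0, -3.0, -0.5, 0.0, 2.5, 3.0, 7.0], start=-3, end=3, bins=6)
kept = h["bin_freq"]  # the only part as_table_cleanup() keeps
print(kept, h["n_smaller"], h["n_larger"])  # [1, 0, 1, 1, 0, 2] 1 1
```

Here seven values go in but only five are represented in `kept`.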
    
    

## Datasets

### 1. Nigerian genomes

In [8]:
vds_nig = hl.vds.read_vds('gs://nigeria-54gene/combined-gvcfs/nigeria_54gene_merged_gvcfs.vds')

Initializing Hail with default parameters...
Running on Apache Spark version 3.3.0
SparkUI available at http://gnomaf-annotations-m.c.diverse-pop-seq-ref.internal:40749
Welcome to
     __  __     <>__
    / /_/ /__  __/ /
   / __  / _ `/ / /
  /_/ /_/\_,_/_/_/   version 0.2.115-10932c754edb
LOGGING: writing to /home/hail/hail-20230829-2016-0.2.115-10932c754edb.log
2023-08-29 20:16:41.030 Hail: WARN: You are reading a VDS written with an older version of Hail.
  Hail now supports much faster interval filters on VDS, but you'll need to run either
  `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the
  existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`.


In [9]:
vds_nig.variant_data.count()

(94990636, 480)

In [10]:
missing_nigeria = check_missing_annotations(list(vds_nig.variant_data.gvcf_info) + list(vds_nig.variant_data.entry))
missing_nigeria

['VarDP', 'QUALapprox']

In [11]:
# computing missing annotations
vds_nig = compute_missing_annotations(vds_nig, missing_nigeria)

# data to be passed to default_compute_info should be filtered to nonref sites
mt_nig = vds_nig.variant_data.filter_entries(vds_nig.variant_data.LGT.is_non_ref(), keep=True)

08/29/2023 08:17:09 PM (__main__ 29): Computing `VarDP` as sum(LAD)
08/29/2023 08:17:09 PM (__main__ 35): Computing `QUALapprox` as LPL[0]


In [12]:
# compute AS annotations
as_ht_nigeria = default_compute_info(mt_nig)

# only keep frequencies for the RankSum annotations
as_ht_nigeria = as_table_cleanup(as_ht_nigeria)

# as_ht_nigeria.filter(as_ht_nigeria.info.AS_ReadPosRankSum_freq[0][0] > 0).show(n=10)

08/29/2023 08:17:31 PM (__main__ 131): Computing AS_MQ as sqrt(AS_RAW_MQandDP[0]/AS_RAW_MQandDP[1]). Note that AS_MQ will be set to 0 if AS_RAW_MQandDP[1] == 0.
08/29/2023 08:17:31 PM (__main__ 166): Computing AS_QD as AS_QUALapprox/AS_VarDP. Note that AS_QD will be set to 0 if AS_VarDP == 0.


### 2.1 NeuroGAP SA

In [14]:
vds_sa = hl.vds.read_vds('gs://neurogap-highcov-genomes/SA-genomes/south_african_genomes_merged_gvcfs.vds')

2023-08-29 20:18:09.854 Hail: WARN: You are reading a VDS written with an older version of Hail.
  Hail now supports much faster interval filters on VDS, but you'll need to run either
  `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the
  existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`.


In [15]:
vds_sa.variant_data.count()

(126525458, 200)

In [16]:
missing_sa = check_missing_annotations(list(vds_sa.variant_data.gvcf_info) + list(vds_sa.variant_data.entry))
missing_sa

['MQ_DP', 'VarDP', 'QUALapprox']

In [17]:
# computing missing annotations
vds_sa = compute_missing_annotations(vds_sa, missing_sa)

# data to be passed to default_compute_info should be filtered to nonref sites
mt_sa = vds_sa.variant_data.filter_entries(vds_sa.variant_data.LGT.is_non_ref(), keep=True)

# RAW_MQ and MQ_DP are separate annotations here, but INFO_AGG_FIELDS (used by default_compute_info)
# expects RAW_MQandDP, so pack them into a two-element array [RAW_MQ, MQ_DP]
mt_sa = mt_sa.annotate_entries(gvcf_info = mt_sa.gvcf_info.annotate(RAW_MQandDP =
                                                        hl.array([mt_sa.gvcf_info.RAW_MQ, mt_sa.gvcf_info.MQ_DP])))

08/29/2023 08:19:02 PM (__main__ 29): Computing `VarDP` as sum(LAD)
08/29/2023 08:19:02 PM (__main__ 35): Computing `QUALapprox` as LPL[0]
08/29/2023 08:19:02 PM (__main__ 40): Computing `MQ_DP` as DP


In [18]:
# compute AS annotations
as_ht_sa = default_compute_info(mt_sa)

# only keep frequencies for the RankSum annotations
as_ht_sa = as_table_cleanup(as_ht_sa)

# as_ht_sa.filter(as_ht_sa.info.AS_ReadPosRankSum_freq[0][0] > 0).show(n=10)

08/29/2023 08:19:35 PM (__main__ 131): Computing AS_MQ as sqrt(AS_RAW_MQandDP[0]/AS_RAW_MQandDP[1]). Note that AS_MQ will be set to 0 if AS_RAW_MQandDP[1] == 0.
08/29/2023 08:19:35 PM (__main__ 166): Computing AS_QD as AS_QUALapprox/AS_VarDP. Note that AS_QD will be set to 0 if AS_VarDP == 0.


### 2.2 NeuroGAP Extra

In [19]:
vds_neuro_extra = hl.vds.read_vds('gs://neurogap-highcov-genomes/NeuroGAP-extra-93-genomes/neurogap_highcov_all_sites_93_genomes_merged_gvcfs.vds')


2023-08-29 20:19:47.061 Hail: WARN: You are reading a VDS written with an older version of Hail.
  Hail now supports much faster interval filters on VDS, but you'll need to run either
  `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the
  existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`.


In [20]:
vds_neuro_extra.variant_data.count()

(88618041, 93)

In [21]:
missing_neuro_extra = check_missing_annotations(list(vds_neuro_extra.variant_data.gvcf_info) + list(vds_neuro_extra.variant_data.entry))
missing_neuro_extra

['MQ_DP', 'VarDP', 'QUALapprox']

In [22]:
# computing missing annotations
vds_neuro_extra = compute_missing_annotations(vds_neuro_extra, missing_neuro_extra)

# data to be passed to default_compute_info should be filtered to nonref sites
mt_neuro_extra = vds_neuro_extra.variant_data.filter_entries(vds_neuro_extra.variant_data.LGT.is_non_ref(),
                                                             keep=True)

# RAW_MQ and MQ_DP are separate annotations here, but INFO_AGG_FIELDS (used by default_compute_info)
# expects RAW_MQandDP, so pack them into a two-element array [RAW_MQ, MQ_DP]
mt_neuro_extra = mt_neuro_extra.annotate_entries(gvcf_info =
                                                 mt_neuro_extra.gvcf_info.annotate(RAW_MQandDP =
                                                        hl.array([mt_neuro_extra.gvcf_info.RAW_MQ,
                                                                  mt_neuro_extra.gvcf_info.MQ_DP])))

08/29/2023 08:21:32 PM (__main__ 29): Computing `VarDP` as sum(LAD)
08/29/2023 08:21:32 PM (__main__ 35): Computing `QUALapprox` as LPL[0]
08/29/2023 08:21:32 PM (__main__ 40): Computing `MQ_DP` as DP
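The log lines above list the rules `compute_missing_annotations()` applies. `VarDP` comes from `sum(LAD)`, `QUALapprox` from `LPL[0]`, and `MQ_DP` from the per-sample `DP`, which is the approximation recommended at the top of the notebook. Below is a per-entry sketch of those rules in plain Python. The `fill_missing` helper and the dict layout are hypothetical. Storing the results under `gvcf_info` is an assumption, based on the `gvcf_info.MQ_DP` reference in the cell above:

```python
from typing import Any, Dict, List


def fill_missing(entry: Dict[str, Any], missing: List[str]) -> Dict[str, Any]:
    """Hypothetical helper: derive missing gvcf_info fields for one sample entry."""
    gvcf_info = dict(entry.get("gvcf_info", {}))
    if "VarDP" in missing:
        # depth over the variant genotype: sum of local allele depths
        gvcf_info["VarDP"] = sum(entry["LAD"])
    if "QUALapprox" in missing:
        # PL[0] is the phred-scaled likelihood of the hom-ref genotype
        gvcf_info["QUALapprox"] = entry["LPL"][0]
    if "MQ_DP" in missing:
        # per-sample DP as a stand-in for MQ_DP (an approximation, not exact)
        gvcf_info["MQ_DP"] = entry["DP"]
    filled = dict(entry)  # copy so the input entry is left untouched
    filled["gvcf_info"] = gvcf_info
    return filled


entry = {"LAD": [12, 9], "LPL": [210, 0, 300], "DP": 22, "gvcf_info": {"RAW_MQ": 79200.0}}
print(fill_missing(entry, ["MQ_DP", "VarDP", "QUALapprox"])["gvcf_info"])
# {'RAW_MQ': 79200.0, 'VarDP': 21, 'QUALapprox': 210, 'MQ_DP': 22}
```

For a single sample, GATK's `QUALapprox` (a sum of PL[0] values) reduces to that sample's PL[0]. `DP` and `sum(LAD)` need not agree, because DP can include reads that are not informative for any allele.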


In [23]:
# compute AS annotations
as_ht_neuro_extra = default_compute_info(mt_neuro_extra)

# only keep frequencies for the RankSum annotations
as_ht_neuro_extra = as_table_cleanup(as_ht_neuro_extra)

# as_ht_neuro_extra.filter(as_ht_neuro_extra.info.AS_ReadPosRankSum_freq[0][0] > 0).show(n=10)

08/29/2023 08:21:52 PM (__main__ 131): Computing AS_MQ as sqrt(AS_RAW_MQandDP[0]/AS_RAW_MQandDP[1]). Note that AS_MQ will be set to 0 if AS_RAW_MQandDP[1] == 0.
08/29/2023 08:21:52 PM (__main__ 166): Computing AS_QD as AS_QUALapprox/AS_VarDP. Note that AS_QD will be set to 0 if AS_VarDP == 0.
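`as_table_cleanup()` keeps only the per-allele RankSum frequencies, so that medians can be approximated after merging. The binning scheme defined at the top of the notebook uses 600 bins from -3 to 3 in steps of 0.01. The following minimal pure-Python sketch shows the idea: bin the values, sum the histograms elementwise across cohorts, and read the median off the cumulative counts. Clipping out-of-range values into the edge bins and returning the bin midpoint are assumptions of this sketch. The notebook's helpers may not do either:

```python
from typing import List, Optional, Sequence

MIN_BOUND, MAX_BOUND, INCREMENT = -3.0, 3.0, 0.01
NUM_BINS = int(round((MAX_BOUND - MIN_BOUND) / INCREMENT))  # 600 bins


def to_histogram(values: Sequence[float]) -> List[int]:
    """Count values per bin; values outside [-3, 3) are clipped to the edge bins."""
    counts = [0] * NUM_BINS
    for v in values:
        idx = int((v - MIN_BOUND) / INCREMENT)
        counts[min(max(idx, 0), NUM_BINS - 1)] += 1
    return counts


def merge_histograms(*hists: Sequence[int]) -> List[int]:
    """Merging cohorts is an elementwise sum of their bin counts."""
    return [sum(bins) for bins in zip(*hists)]


def approx_median(counts: Sequence[int]) -> Optional[float]:
    """Midpoint of the bin holding the (lower) median; None if the histogram is empty."""
    total = sum(counts)
    if total == 0:
        return None
    cumulative = 0
    for i, c in enumerate(counts):
        cumulative += c
        if cumulative >= total / 2:
            return MIN_BOUND + (i + 0.5) * INCREMENT
    return None  # not reached when total > 0


# Hypothetical ReadPosRankSum values for one allele in two cohorts
cohort_a = to_histogram([-0.495, 0.205, 0.215])
cohort_b = to_histogram([1.005, 1.105])
print(approx_median(merge_histograms(cohort_a, cohort_b)))  # close to 0.215
```

With an odd number of values inside [-3, 3), the returned midpoint lies within half a bin width (0.005) of the true median. With an even count, the sketch returns the bin of the lower middle value.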


### 3. GGV

In [24]:
vds_ggv = hl.vds.read_vds('gs://gnomaf/gambian-genomes/COMBINED_GVCFS/gambian_genomes_merged_gvcfs.vds')

2023-08-29 20:22:03.645 Hail: WARN: You are reading a VDS written with an older version of Hail.
  Hail now supports much faster interval filters on VDS, but you'll need to run either
  `hl.vds.truncate_reference_blocks(vds, ...)` and write a copy (see docs) or patch the
  existing VDS in place with `hl.vds.store_ref_block_max_length(vds_path)`.


In [25]:
vds_ggv.variant_data.count()

(61164017, 394)

In [26]:
missing_ggv = check_missing_annotations(list(vds_ggv.variant_data.gvcf_info) + list(vds_ggv.variant_data.entry))
missing_ggv

['VarDP', 'QUALapprox']

In [27]:
# computing missing annotations
vds_ggv = compute_missing_annotations(vds_ggv, missing_ggv)

# data passed to default_compute_info should be filtered to non-ref genotype entries (het or hom-var)
mt_ggv = vds_ggv.variant_data.filter_entries(vds_ggv.variant_data.LGT.is_non_ref(), keep=True)

08/29/2023 08:23:25 PM (__main__ 29): Computing `VarDP` as sum(LAD)
08/29/2023 08:23:25 PM (__main__ 35): Computing `QUALapprox` as LPL[0]
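Each dataset is filtered with `LGT.is_non_ref()` before `default_compute_info` runs. In Hail, `is_non_ref()` is true for calls that carry at least one non-reference allele (het or hom-var). Below is a plain-Python sketch of that predicate, which represents calls as tuples of local allele indices (a hypothetical representation). This sketch also drops missing calls:

```python
from typing import List, Optional, Sequence, Tuple

Call = Optional[Tuple[int, ...]]  # None stands for a missing call


def is_non_ref(call: Call) -> Optional[bool]:
    """True if any allele index is non-zero (het or hom-var); None if missing."""
    if call is None:
        return None
    return any(allele != 0 for allele in call)


def keep_non_ref(calls: Sequence[Call]) -> List[Tuple[int, ...]]:
    """Keep only entries whose predicate is True; hom-ref and missing are dropped."""
    return [c for c in calls if is_non_ref(c) is True]


print(keep_non_ref([(0, 0), (0, 1), (1, 1), None, (0, 2)]))
# [(0, 1), (1, 1), (0, 2)]
```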


In [28]:
# compute AS annotations
as_ht_ggv = default_compute_info(mt_ggv)

# only keep frequencies for the RankSum annotations
as_ht_ggv = as_table_cleanup(as_ht_ggv)

# as_ht_ggv.filter(as_ht_ggv.info.AS_ReadPosRankSum_freq[0][0] > 0).show(n=10)

08/29/2023 08:23:40 PM (__main__ 131): Computing AS_MQ as sqrt(AS_RAW_MQandDP[0]/AS_RAW_MQandDP[1]). Note that AS_MQ will be set to 0 if AS_RAW_MQandDP[1] == 0.
08/29/2023 08:23:40 PM (__main__ 166): Computing AS_QD as AS_QUALapprox/AS_VarDP. Note that AS_QD will be set to 0 if AS_VarDP == 0.


### 4. HGDP+1kGP

In [29]:
mt_hgdp_tgp = hl.read_matrix_table('gs://gcp-public-data--gnomad/release/3.1.2/mt/genomes/gnomad.genomes.v3.1.2.hgdp_1kg_subset_sparse.mt')


In [30]:
mt_hgdp_tgp.count()

(1538715096, 4151)

In [31]:
missing_hgdp_tgp = check_missing_annotations(list(mt_hgdp_tgp.entry.gvcf_info) + list(mt_hgdp_tgp.entry))
missing_hgdp_tgp

[]

In [32]:
# data passed to default_compute_info should be filtered to non-ref genotype entries (het or hom-var)
mt_hgdp_tgp = mt_hgdp_tgp.filter_entries(mt_hgdp_tgp.LGT.is_non_ref(), keep=True)

# default_compute_info looks for a combined RAW_MQandDP array, but this dataset stores
# RAW_MQ and MQ_DP separately
mt_hgdp_tgp = mt_hgdp_tgp.annotate_entries(
    gvcf_info=mt_hgdp_tgp.gvcf_info.annotate(
        RAW_MQandDP=hl.array([mt_hgdp_tgp.gvcf_info.RAW_MQ,
                              mt_hgdp_tgp.gvcf_info.MQ_DP])))

In [33]:
# compute AS annotations
as_ht_hgdp_tgp = default_compute_info(mt_hgdp_tgp)

# only keep frequencies for the RankSum annotations
as_ht_hgdp_tgp = as_table_cleanup(as_ht_hgdp_tgp)

# as_ht_hgdp_tgp.filter(as_ht_hgdp_tgp.info.AS_ReadPosRankSum_freq[0][0] > 0).show(n=10)

08/29/2023 08:24:56 PM (__main__ 131): Computing AS_MQ as sqrt(AS_RAW_MQandDP[0]/AS_RAW_MQandDP[1]). Note that AS_MQ will be set to 0 if AS_RAW_MQandDP[1] == 0.
08/29/2023 08:24:56 PM (__main__ 166): Computing AS_QD as AS_QUALapprox/AS_VarDP. Note that AS_QD will be set to 0 if AS_VarDP == 0.


### Check that all datasets contain the same AS annotations, then export

In [34]:
nigeria_as = list(as_ht_nigeria.row) + list(as_ht_nigeria.info)
neurogap_sa_as = list(as_ht_sa.row) + list(as_ht_sa.info)
neurogap_extra_as = list(as_ht_neuro_extra.row) + list(as_ht_neuro_extra.info)
ggv_as = list(as_ht_ggv.row) + list(as_ht_ggv.info)
hgdp_tgp_as = list(as_ht_hgdp_tgp.row) + list(as_ht_hgdp_tgp.info)

assert nigeria_as == neurogap_sa_as == neurogap_extra_as == ggv_as == hgdp_tgp_as
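The `assert` above compares ordered lists of field names. It therefore fails on a reordering as well as on a genuinely missing field, and it does not say which dataset differs. The following hypothetical pure-Python helper reports the differences instead. The field names in the example are made up for illustration:

```python
from typing import Any, Dict, List


def schema_report(schemas: Dict[str, List[str]]) -> Dict[str, Dict[str, Any]]:
    """Compare every dataset's field list with the first dataset's list."""
    names = list(schemas)
    reference = schemas[names[0]]
    report: Dict[str, Dict[str, Any]] = {}
    for name in names[1:]:
        fields = schemas[name]
        missing = [f for f in reference if f not in fields]
        extra = [f for f in fields if f not in reference]
        reordered = not missing and not extra and fields != reference
        if missing or extra or reordered:
            report[name] = {"missing": missing, "extra": extra, "reordered": reordered}
    return report


example = {
    "nigeria": ["locus", "alleles", "info", "AS_QD", "AS_MQ"],
    "ggv": ["locus", "alleles", "info", "AS_MQ", "AS_QD"],
    "hgdp_tgp": ["locus", "alleles", "info", "AS_QD"],
}
print(schema_report(example))
# {'ggv': {'missing': [], 'extra': [], 'reordered': True},
#  'hgdp_tgp': {'missing': ['AS_MQ'], 'extra': [], 'reordered': False}}
```

An empty report means every dataset matches the first one, so `report = schema_report(...)` followed by `assert not report, report` would print the differences when the check fails.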

In [36]:
as_ht_nigeria.write('gs://gnomaf/AS_annotations/nigeria_54gene_merged_gvcfs.ht')
as_ht_sa.write('gs://gnomaf/AS_annotations/neurogap_south_africa_genomes_merged_gvcfs.ht')
as_ht_neuro_extra.write('gs://gnomaf/AS_annotations/neurogap_highcov_all_sites_93_genomes_merged_gvcfs.ht')
as_ht_ggv.write('gs://gnomaf/AS_annotations/gambian_genomes_merged_gvcfs.ht')
as_ht_hgdp_tgp.write('gs://gnomaf/AS_annotations/gnomad.genomes.v3.1.2.hgdp_1kg_subset.ht')

2023-08-29 20:48:20.220 Hail: INFO: wrote table with 94990636 rows in 4790 partitions to gs://gnomaf/AS_annotations/nigeria_54gene_merged_gvcfs.ht
2023-08-29 20:57:49.972 Hail: INFO: wrote table with 126525458 rows in 4800 partitions to gs://gnomaf/AS_annotations/neurogap_south_africa_genomes_merged_gvcfs.ht
2023-08-29 21:04:19.262 Hail: INFO: wrote table with 88618041 rows in 2586 partitions to gs://gnomaf/AS_annotations/neurogap_highcov_all_sites_93_genomes_merged_gvcfs.ht
2023-08-29 21:20:37.209 Hail: INFO: wrote table with 61164017 rows in 4847 partitions to gs://gnomaf/AS_annotations/gambian_genomes_merged_gvcfs.ht
2023-08-29 22:33:22.019 Hail: INFO: wrote table with 1538715096 rows in 5000 partitions to gs://gnomaf/AS_annotations/gnomad.genomes.v3.1.2.hgdp_1kg_subset.ht
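The merge step planned at the top of the notebook sums AS_QUALapprox and AS_VarDP across cohorts. `RAW_MQandDP` is additive as well, because GATK defines it as the sum of squared mapping qualities together with the read depth. Below is a minimal pure-Python sketch of that merge for a single locus. The record layout, the per-alt-allele indexing and the `merge_site` name are assumptions, not the notebook's eventual implementation. Ratios are recomputed after summing rather than averaged, so larger cohorts carry proportionally more weight:

```python
import math
from typing import Dict, List


def merge_site(cohorts: List[dict]) -> dict:
    """Merge additive per-alt-allele AS annotations for one locus across cohorts.

    Each cohort record is a hypothetical dict with an `alleles` list (ref first)
    and per-alt arrays AS_QUALapprox, AS_VarDP and AS_RAW_MQandDP ([RAW_MQ, DP]).
    """
    ref = cohorts[0]["alleles"][0]
    sums: Dict[str, List[float]] = {}  # alt -> [qual, var_dp, raw_mq, mq_dp]
    for cohort in cohorts:
        if cohort["alleles"][0] != ref:
            raise ValueError("reference alleles differ; normalize alleles before merging")
        for i, alt in enumerate(cohort["alleles"][1:]):
            s = sums.setdefault(alt, [0.0, 0.0, 0.0, 0.0])
            s[0] += cohort["AS_QUALapprox"][i]
            s[1] += cohort["AS_VarDP"][i]
            s[2] += cohort["AS_RAW_MQandDP"][i][0]
            s[3] += cohort["AS_RAW_MQandDP"][i][1]
    alts = list(sums)  # alt alleles in first-seen order
    return {
        "alleles": [ref] + alts,
        "AS_QUALapprox": [sums[a][0] for a in alts],
        "AS_VarDP": [sums[a][1] for a in alts],
        # ratios are recomputed from the summed values, with the same zero guards
        "AS_QD": [0.0 if sums[a][1] == 0 else sums[a][0] / sums[a][1] for a in alts],
        "AS_MQ": [0.0 if sums[a][3] == 0 else math.sqrt(sums[a][2] / sums[a][3]) for a in alts],
    }


cohort_1 = {"alleles": ["A", "G"], "AS_QUALapprox": [300], "AS_VarDP": [20],
            "AS_RAW_MQandDP": [[72000.0, 20]]}
cohort_2 = {"alleles": ["A", "G", "T"], "AS_QUALapprox": [100, 50], "AS_VarDP": [5, 5],
            "AS_RAW_MQandDP": [[18000.0, 5], [12500.0, 5]]}
merged = merge_site([cohort_1, cohort_2])
print(merged["alleles"], merged["AS_QD"], merged["AS_MQ"])
# ['A', 'G', 'T'] [16.0, 10.0] [60.0, 50.0]
```

The sketch does not normalize alleles across datasets (for example, differing reference allele lengths at the same locus). It simply raises an error when the reference alleles disagree.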
