In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generates Figure 4 from Cosentino et al. Nature Genetics 2023

This notebook builds the GWAS comparison subfigures for Figure 4 of the ML-based
COPD manuscript (Cosentino et al. Nature Genetics 2023).

This notebook also generates the following Extended Data and Supplementary GWAS
comparison Figures:

-   Extended Data Figure 5: Statistical power comparison of ML-based COPD with
    Hobbs et al. Nature Genetics 2017 COPD GWAS.
-   Extended Data Figure 6: Statistical power comparison of ML-based COPD
    without MRB COPD cases with Hobbs et al. Nature Genetics 2017 COPD GWAS.
-   Extended Data Figure 7: Statistical power comparison of binarized ML-based
    COPD with Sakornsakolpat et al. Nature Genetics 2019.
-   Extended Data Figure 8: Statistical power comparison of binarized ML-based
    COPD matching GOLD prevalence with Sakornsakolpat et al. Nature
    Genetics 2019.
-   Supplementary Figure 12: Statistical power comparison of ML-based COPD with
    GBMI COPD excluding UKB.
-   Supplementary Figure 14: Statistical power comparison of binarized ML-based
    COPD with medical-record-based COPD labels.
-   Supplementary Figure 15: Statistical power comparison of proxy-GOLD with
    Sakornsakolpat et al. Nature Genetics 2019.
-   Supplementary Figure 17: Statistical power comparison of proxy-GOLD label
    using BOLT-LMM vs Regenie.

**Important: Generating all figures requires populating the
`ASSOC_RESULTS_BASE_DIR` directory with expected data files.** Each GWAS
requires a hits, loci, and filtered GWAS results file
(`{}.association_results.annotated_hits.tsv`, `{}.association_results.loci.tsv`,
and `{}.association_results.filtered.tsv`, respectively). The notebook expects
the following filename prefixes for each GWAS:

-   `ml_based_copd`: Our ML-based COPD liability score.
-   `ml_based_copd_no_mrb_cases`: Our ML-based COPD liability score with
    medical-record-based COPD cases removed.
-   `hobbs_natgen_2017`: Hobbs et al. Nature Genetics 2017.
-   `ml_based_copd_binarized_gold_prev`: ML-based COPD binarized to match proxy
    GOLD prevalence.
-   `sakornsakolpat_natgen_2019`: Sakornsakolpat et al. Nature Genetics 2019.
-   `ml_based_copd_binarized`: ML-based COPD binarized to a 50-50 case-control
    split.
-   `mrb_labels_copd`: Medical-record-based COPD.
-   `spiro_gold_copd`: Proxy GOLD COPD.
-   `gbmi_excluding_ukb_copd`: Global Biobank Meta-analysis Initiative (GBMI)
    COPD GWAS excluding UKB.
-   `spiro_gold_copd_regenie`: Proxy GOLD COPD run with Regenie.
-   `spiro_gold_copd_bolt`: Proxy GOLD COPD run with BOLT-LMM.

We provide these files in our
[GitHub repository](https://github.com/Google-Health/genomics-research/tree/main/ml-based-copd)
for all GWAS *except* those from external data sources: `hobbs_natgen_2017`,
`sakornsakolpat_natgen_2019`, and `gbmi_excluding_ukb_copd`. See the
corresponding manuscripts for details on how to access each data source. Once
downloaded, simply convert the summary statistics to match the schema outlined
in the provided files.

In [None]:
import collections
import copy
import csv
import dataclasses
import decimal
import io
import tempfile
import os
import typing
from typing import Any, Dict, Generator, List, Sequence, Union, NamedTuple, Optional, Tuple

from absl import logging
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib
from matplotlib import rcParams
import matplotlib.pyplot as plt

In [None]:
def set_matplotib_settings():
  """Applies the notebook-wide matplotlib rcParams used for every figure."""
  figure_settings = (
      ('text.usetex', 'False'),
      ('font.family', 'Helvetica'),
      ('savefig.dpi', 300),
      ('savefig.transparent', True),
      ('font.size', 7),
      # Note: these font types keep text editable in Adobe Illustrator.
      ('pdf.fonttype', 42),
      ('ps.fonttype', 42),
  )
  # Assign one key at a time so each value goes through rcParams' own
  # item-assignment path, in the listed order.
  for param_name, param_value in figure_settings:
    rcParams[param_name] = param_value


set_matplotib_settings()

In [None]:
# Root directory that holds the annotated hits, loci, and filtered results
# files for every GWAS. Replace the placeholder with a real location.
ASSOC_RESULTS_BASE_DIR = '/path/to/association_results'

# Filename suffix for each kind of per-GWAS results file; the GWAS-specific
# prefix is prepended to these.
ASSOC_RESULTS_HITS_SUFFIX = '.association_results.annotated_hits.tsv'
ASSOC_RESULTS_LOCI_SUFFIX = '.association_results.loci.tsv'
ASSOC_RESULTS_FILTERED_SUFFIX = '.association_results.filtered.tsv'


# Cumulative genome-wide start offset of each chromosome, used to lay variants
# out along a single Manhattan-plot x axis. Each offset is the summed length of
# all preceding chromosomes. Chromosome sizes are available from UCSC; for the
# Genome Reference Consortium Human Reference 37 see:
# https://hgdownload.cse.ucsc.edu/goldenPath/hg19/bigZips/hg19.chrom.sizes
# NOTE(review): the trailing '$' key looks like an end-of-genome sentinel (the
# offset just past chromosome 22) -- confirm against the plotting code.
CHROM_OFFSETS = collections.OrderedDict({
    '1': 0,
    '2': 249250621,
    '3': 492449994,
    '4': 690472424,
    '5': 881626700,
    '6': 1062541960,
    '7': 1233657027,
    '8': 1392795690,
    '9': 1539159712,
    '10': 1680373143,
    '11': 1815907890,
    '12': 1950914406,
    '13': 2084766301,
    '14': 2199936179,
    '15': 2307285719,
    '16': 2409817111,
    '17': 2500171864,
    '18': 2581367074,
    '19': 2659444322,
    '20': 2718573305,
    '21': 2781598825,
    '22': 2829728720,
    '$': 2881033286,
})

## Utilities for parsing annotated hits, loci, and filtered GWAS results.

In [None]:
ArbitraryPrecisionReal = Union[float, decimal.Decimal]


def pvalue_str_to_numeric(pvalue_str: str) -> ArbitraryPrecisionReal:
  """Parses a p-value string, using Decimal only when a float underflows.

  Args:
    pvalue_str: Textual p-value as written by GWAS software, e.g. '0.05' or
      '1e-500'.

  Returns:
    The float value when it is non-zero. When the float parse yields zero, the
    string is re-parsed as a Decimal: a non-zero Decimal is returned as-is,
    and a true zero is returned as the float 0.0.
  """
  as_float = float(pvalue_str)
  if as_float == 0:
    # Analytically computed p-values can be too small for a float (e.g.
    # '1e-500' parses to 0.0). Re-parse with arbitrary precision to tell such
    # underflow apart from a value that was genuinely written as zero.
    as_decimal = decimal.Decimal(pvalue_str)
    return as_decimal if as_decimal != 0 else 0.0
  return as_float


def _validate_assoc_result(
    vid: str,
    chrom: str,
    bp: int,
    ref: str,
    alt: str,
    rsid: str,
    affx: str,
    eff: str,
    alt_freq: float,
    num_indv: int,
    src: str,
    info_score: float,
    p_hwe_pop: ArbitraryPrecisionReal,
    p: ArbitraryPrecisionReal,
    beta: float,
    se: float,
    ctrl_cnt: Optional[int],
    case_cnt: Optional[int],
) -> None:
  """Raises ValueError if inputs are not a valid AssocResult.

  Checks run in the order written below and the first violation raises. The
  `chrom`, `rsid`, `affx`, and `beta` arguments are accepted without any
  validation; they are part of the signature so the function can be called
  with the full set of AssocResult fields.

  Args:
    vid: Variant ID; at most 20 characters.
    chrom: Chromosome name (not validated).
    bp: Base pair position; must be positive.
    ref: Reference allele; must differ from `alt`.
    alt: Alternate allele; must differ from `ref`.
    rsid: The rsid of the variant (not validated).
    affx: The Affymetrix ID of the variant (not validated).
    eff: Effect allele; must equal `ref` or `alt`.
    alt_freq: Alternate allele frequency; must be in [0, 1].
    num_indv: Number of individuals; must be at least 1.
    src: Variant source; must be 'Genotyped' or 'Imputed'.
    info_score: INFO score in [0, 1]; must be exactly 1 for genotyped variants.
    p_hwe_pop: Hardy-Weinberg p-value; must be in [0, 1].
    p: Association p-value; must be in [0, 1].
    beta: Effect size (not validated).
    se: Standard error; must be non-negative (zero is accepted).
    ctrl_cnt: Control count, or None. Must be None iff `case_cnt` is None, and
      non-negative otherwise.
    case_cnt: Case count, or None. Must be None iff `ctrl_cnt` is None, and
      non-negative otherwise.

  Raises:
    ValueError: If any of the constraints above is violated.
  """
  # 20 is derived from l.g.medgen.util.variant_tools._MAX_VARIANT_ID_LENGTH
  # We avoid its dependency since it brings in Nucleus as well.
  if len(vid) > 20:
    raise ValueError(f'Unexpected variant id: {vid}')
  if bp <= 0:
    raise ValueError(f'Invalid non-positive bp: {bp}')
  if ref == alt:
    raise ValueError(f'Ref and alt must be distinct: {ref} vs {alt}')
  if eff not in [ref, alt]:
    raise ValueError(
        f'Effect allele must be one of ref or alt: {eff}, [{ref}, {alt}]'
    )
  if not 0 <= alt_freq <= 1:
    raise ValueError(f'Alt allele frequency must be in [0, 1]: {alt_freq}')
  if num_indv < 1:
    raise ValueError(f'Number of individuals must be positive: {num_indv}')
  if src not in ['Genotyped', 'Imputed']:
    raise ValueError(f'Unexpected source: {src}')
  if not 0 <= info_score <= 1:
    raise ValueError(f'INFO score must be in [0, 1]: {info_score}')
  if src == 'Genotyped' and info_score != 1:
    raise ValueError(
        f'INFO score for genotyped variants must be 1: {info_score}'
    )
  if not 0 <= p_hwe_pop <= 1:
    raise ValueError(f'HWE p-value must be in [0, 1]: {p_hwe_pop}')
  if not 0 <= p <= 1:
    raise ValueError(f'Association p-value must be in [0, 1]: {p}')
  # Only negative values are rejected, so the message says "non-negative":
  # a standard error of exactly zero passes this check.
  if se < 0:
    raise ValueError(f'Standard error must be non-negative: {se}')
  # The two counts travel together: either both present or both absent.
  if (ctrl_cnt is None) != (case_cnt is None):
    raise ValueError(
        'ctrl_cnt and case_cnt must either both be None or neither'
        f' be None: {ctrl_cnt} vs {case_cnt}'
    )
  if ctrl_cnt is not None and ctrl_cnt < 0:
    raise ValueError(f'Non-None ctrl_cnt must be non-negative: {ctrl_cnt}')
  if case_cnt is not None and case_cnt < 0:
    raise ValueError(f'Non-None case_cnt must be non-negative: {case_cnt}')


class AssocResult(
    typing.NamedTuple(
        'AssocResult',
        [
            ('vid', str),
            ('chrom', str),
            ('bp', int),
            ('ref', str),
            ('alt', str),
            ('rsid', str),
            ('affx', str),
            ('eff', str),
            ('alt_freq', float),
            ('num_indv', int),
            ('src', str),
            ('info_score', float),
            ('p_hwe_pop', ArbitraryPrecisionReal),
            ('p', ArbitraryPrecisionReal),
            ('beta', float),
            ('se', float),
            ('ctrl_cnt', Optional[int]),
            ('case_cnt', Optional[int]),
        ],
    )
):
  """Immutable, validated record of one variant's association statistics.

  One instance corresponds to one tab-separated line of an
  "association_results.raw.tsv" or "association_results.filtered.tsv" file
  produced by the GWAS pipeline. Every instance is checked by
  `_validate_assoc_result` at construction time.

  Attributes:
    vid: Variant ID.
    chrom: Chromosome, as a string.
    bp: 1-based base pair position of the variant.
    ref: Reference allele. In [ACGT]+.
    alt: Alternate allele. In [ACGT]+.
    rsid: The variant's rsid, or normalization.EMPTY_FIELD_PLACEHOLDER when
      there is none.
    affx: The variant's Affymetrix ID, or
      normalization.EMPTY_FIELD_PLACEHOLDER when there is none.
    eff: Effect allele that `beta` refers to; one of `ref` or `alt`.
    alt_freq: Frequency of the `alt` allele.
    num_indv: Number of individuals contributing to the association result.
    src: Whether the variant was genotyped directly or imputed.
    info_score: INFO score of the variant; 1.0 for genotyped variants.
    p_hwe_pop: Population Hardy-Weinberg p-value of the variant.
    p: Association p-value.
    beta: Association effect size.
    se: Standard error of the effect size estimate.
    ctrl_cnt: Number of controls in the association result. Only applicable to
      PLINK runs of case/control phenotypes; None otherwise.
    case_cnt: Number of cases in the association result. Only applicable to
      PLINK runs of case/control phenotypes; None otherwise.
  """

  __slots__ = ()

  def __new__(
      cls,
      vid: str,
      chrom: str,
      bp: int,
      ref: str,
      alt: str,
      rsid: str,
      affx: str,
      eff: str,
      alt_freq: float,
      num_indv: int,
      src: str,
      info_score: float,
      p_hwe_pop: ArbitraryPrecisionReal,
      p: ArbitraryPrecisionReal,
      beta: float,
      se: float,
      ctrl_cnt: Optional[int],
      case_cnt: Optional[int],
  ) -> 'AssocResult':
    """Validates the field values, then builds the tuple from them."""
    # Gather the fields once so validation and construction receive exactly
    # the same keyword arguments.
    field_values = dict(
        vid=vid,
        chrom=chrom,
        bp=bp,
        ref=ref,
        alt=alt,
        rsid=rsid,
        affx=affx,
        eff=eff,
        alt_freq=alt_freq,
        num_indv=num_indv,
        src=src,
        info_score=info_score,
        p_hwe_pop=p_hwe_pop,
        p=p,
        beta=beta,
        se=se,
        ctrl_cnt=ctrl_cnt,
        case_cnt=case_cnt,
    )
    _validate_assoc_result(**field_values)
    return super().__new__(cls, **field_values)

  @classmethod
  def from_line(cls, line) -> 'AssocResult':
    """Parses one tab-separated results line into an AssocResult.

    Args:
      line: A line with 17 fields, or 19 when control/case counts are present.

    Returns:
      The AssocResult described by the line.

    Raises:
      ValueError: If the line does not have 17 or 19 fields, or a field cannot
        be converted or fails validation.
    """
    tokens = line.strip().split('\t')
    has_counts = len(tokens) == 19
    if not has_counts and len(tokens) != 17:
      raise ValueError(
          f'Expected 17 or 19 tokens in AssocResult, got {len(tokens)}: {line}'
      )

    ctrl_cnt = case_cnt = None
    if has_counts:
      ctrl_cnt = int(tokens[17])
      case_cnt = int(tokens[18])

    # The 13th column (index 12) is the deprecated p_hwe_coh and is ignored.
    (
        vid,
        chrom,
        bp,
        ref,
        alt,
        rsid,
        affx,
        eff,
        alt_freq,
        num_indv,
        src,
        info_score,
        _,
        p_hwe_pop,
        p,
        beta,
        se,
    ) = tokens[:17]
    return cls(
        vid=vid,
        chrom=chrom,
        bp=int(bp),
        ref=ref,
        alt=alt,
        rsid=rsid,
        affx=affx,
        eff=eff,
        alt_freq=float(alt_freq),
        num_indv=int(num_indv),
        src=src,
        info_score=float(info_score),
        p_hwe_pop=pvalue_str_to_numeric(p_hwe_pop),
        p=pvalue_str_to_numeric(p),
        beta=float(beta),
        se=float(se),
        ctrl_cnt=ctrl_cnt,
        case_cnt=case_cnt,
    )

  def __str__(self) -> str:
    """Serializes the record as a tab-separated line.

    Note: `from_line(str(x))` is not guaranteed to reproduce `x` exactly,
    because floats may not survive the text round trip bit-for-bit.
    """
    # P-values are lower-cased so scientific notation uses a lowercase 'e',
    # which is what our pipelines write out.
    p_hwe_pop_text = str(self.p_hwe_pop).lower()
    p_text = str(self.p).lower()
    # Control/case counts are appended only when they are applicable.
    if self.ctrl_cnt is not None:
      count_fields = [str(self.ctrl_cnt), str(self.case_cnt)]
    else:
      count_fields = []
    return '\t'.join(
        [
            self.vid,
            self.chrom,
            str(self.bp),
            self.ref,
            self.alt,
            self.rsid,
            self.affx,
            self.eff,
            str(self.alt_freq),
            str(self.num_indv),
            self.src,
            str(self.info_score),
            # Placeholder for the old, unused P_HWE_COH column.
            '.',
            p_hwe_pop_text,
            p_text,
            str(self.beta),
            str(self.se),
        ]
        + count_fields
    )

  def columns(self) -> List[str]:
    """Returns the header names matching the fields written by `__str__`."""
    header = [
        'VID',
        'CHR',
        'BP',
        'REF',
        'ALT',
        'RS',
        'AFFX',
        'EFF',
        'AAF',
        'NUM_INDV',
        'SRC',
        'INFO_SCORE',
        'P_HWE_COH',
        'P_HWE_POP',
        'P',
        'BETA',
        'SE',
    ]
    if self.ctrl_cnt is not None:
      header.extend(['CTRL_CNT', 'CASE_CNT'])
    return header

  @property
  def maf(self) -> float:
    """Returns the minor allele frequency."""
    alt_frequency = self.alt_freq
    return min(alt_frequency, 1 - alt_frequency)

  @property
  def name(self) -> str:
    """Returns the rsid when set, else a 'chrom:bp_ref_alt' description."""
    has_rsid = bool(self.rsid) and self.rsid != '.'
    if not has_rsid:
      return f'{self.chrom}:{self.bp}_{self.ref}_{self.alt}'
    return self.rsid


def _validate_annotated_locus(
    assoc_result: AssocResult,
    cluster_left_bp: int,
    cluster_right_bp: int,
    num_variants: int,
    num_hits: int,
    cytoband: str,
    gene_context: str,
    closest_genes: str,
    replication_variants: 'collections.OrderedDict[str, str]',
) -> None:
  """Raises ValueError when the fields cannot form a valid AnnotatedLocus.

  Only the cluster bounds and the two counts are checked; `cytoband`,
  `gene_context`, `closest_genes`, and `replication_variants` are accepted
  unchanged.

  Raises:
    ValueError: If the index variant position lies outside
      [cluster_left_bp, cluster_right_bp], if `num_hits` is below 1, or if
      `num_variants` is smaller than `num_hits`.
  """
  index_bp = assoc_result.bp
  # The index variant must sit inside the cluster span (bounds inclusive).
  if not cluster_left_bp <= index_bp:
    raise ValueError(
        'Cluster left bp must not be after locus index bp: '
        f'{cluster_left_bp} vs {assoc_result.bp}'
    )
  if not cluster_right_bp >= index_bp:
    raise ValueError(
        'Cluster right bp must not be before locus index bp: '
        f'{cluster_right_bp} vs {assoc_result.bp}'
    )
  if num_hits < 1:
    raise ValueError(f'Num hits in locus must be at least 1: {num_hits}')
  if num_hits > num_variants:
    raise ValueError(
        'Num variants in locus must be >= num hits: '
        f'{num_variants} vs {num_hits}'
    )


def _load_replication_variants(
    column_names: List[str], values: List[str]
) -> collections.OrderedDict:
  if len(column_names) != len(values):
    raise ValueError(
        f'Non-matching number of columns and values: {column_names}, {values}'
    )
  return collections.OrderedDict(zip(column_names, values))


class AnnotatedLocus(
    typing.NamedTuple(
        'AnnotatedLocus',
        [
            ('assoc_result', AssocResult),
            ('cluster_left_bp', int),
            ('cluster_right_bp', int),
            ('num_variants', int),
            ('num_hits', int),
            ('cytoband', str),
            ('gene_context', str),
            ('closest_genes', str),
            ('replication_variants', 'collections.OrderedDict[str, str]'),
        ],
    )
):
  """Container class for GWAS pipeline annotated hit results.

  Instances are validated by `_validate_annotated_locus` on construction. The
  fields of the contained `assoc_result` are also exposed directly as
  read-only properties (`vid`, `chrom`, `bp`, ..., `maf`, `name`).

  Attributes:
    assoc_result: A constituent AssocResult for the index variant. Its
      attributes are available to simplify access.
    cluster_left_bp: The position of the leftmost significant variant within the
      locus cluster.
    cluster_right_bp: The position of the rightmost significant variant within
      the locus cluster.
    num_variants: The number of significant variants contained in the cluster(s)
      represented by this AnnotatedLocus.
    num_hits: The number of independent GWAS signals (clusters) present in this
      AnnotatedLocus.
    cytoband: The cytoband location of the index variant.
    gene_context: The context of the index variant relative to nearby genes.
    closest_genes: The closest gene(s) (if located within one or more gene
      bodies), otherwise the closest leftmost and rightmost genes within 1 MB.
    replication_variants: A collections.OrderedDict with keys being the
      replication keyword column name in the
      "association_results.annotated_hits.tsv" file and the value being a
      colon-delimited string containing the IDs of the variants in the GWAS
      catalog with this keyword that are replicated by this variant.
  """

  __slots__ = ()

  def __new__(
      cls,
      assoc_result: AssocResult,
      cluster_left_bp: int,
      cluster_right_bp: int,
      num_variants: int,
      num_hits: int,
      cytoband: str,
      gene_context: str,
      closest_genes: str,
      replication_variants: 'collections.OrderedDict[str, str]',
  ) -> 'AnnotatedLocus':
    """Validates the field values, then builds the tuple from them."""
    _validate_annotated_locus(
        assoc_result=assoc_result,
        cluster_left_bp=cluster_left_bp,
        cluster_right_bp=cluster_right_bp,
        num_variants=num_variants,
        num_hits=num_hits,
        cytoband=cytoband,
        gene_context=gene_context,
        closest_genes=closest_genes,
        replication_variants=replication_variants,
    )
    return super().__new__(
        cls,
        assoc_result=assoc_result,
        cluster_left_bp=cluster_left_bp,
        cluster_right_bp=cluster_right_bp,
        num_variants=num_variants,
        num_hits=num_hits,
        cytoband=cytoband,
        gene_context=gene_context,
        closest_genes=closest_genes,
        replication_variants=replication_variants,
    )

  @classmethod
  def from_line(cls, line, header_tokens: List[str]) -> 'AnnotatedLocus':
    """Returns an instance of the class from its string representation.

    Args:
      line: The line to be converted into an AnnotatedLocus.
      header_tokens: The ordered list of header fields for the record.

    Returns:
      An AnnotatedLocus representing the line.

    Raises:
      ValueError: If the number of fields does not match the header, the last
        header field is not 'LOCUSZOOM_ID', a replication column name does not
        end in '_HITS', or the parsed values fail validation.
    """
    tokens = line.strip().split('\t')
    if len(tokens) != len(header_tokens):
      raise ValueError(
          'Number of tokens does not match header. '
          f'Header ({len(header_tokens)}): {header_tokens}. '
          f'Tokens ({len(tokens)}): {tokens}.'
      )
    if header_tokens[-1] != 'LOCUSZOOM_ID':
      raise ValueError(
          'Last header field expected to be deprecated '
          f'"LOCUSZOOM_ID", found: {header_tokens[-1]}.'
      )

    # There are either 17 or 19 tokens in the AssocResult (19 iff it is a binary
    # phenotype run using PLINK).
    num_assoc_tokens = 19 if 'CTRL_CNT' in header_tokens else 17
    assoc_tokens = tokens[:num_assoc_tokens]
    assoc_result = AssocResult.from_line('\t'.join(assoc_tokens))

    # This relies on the convention that the replication names start immediately
    # after the 'CLOSEST_GENES' column.
    annot_start_ix = header_tokens.index('CLOSEST_GENES') + 1
    # Slice off the final field since it is the deprecated LOCUSZOOM_ID field.
    annot_names = header_tokens[annot_start_ix:-1]
    if not all(name.endswith('_HITS') for name in annot_names):
      raise ValueError(
          'Expected all annotation names to end in "_HITS", found:'
          f' {annot_names}.'
      )
    replication_variant_strs = tokens[annot_start_ix:-1]
    replication_variants = _load_replication_variants(
        annot_names, replication_variant_strs
    )

    field_lookup = dict(zip(header_tokens, tokens))
    # Note: We drop the 'CLUSTER_ID' field that is only present in
    # distance-based clustering as it was not being referenced elsewhere anyway.
    return cls(
        assoc_result=assoc_result,
        cluster_left_bp=int(field_lookup['CLUSTER_LEFT']),
        cluster_right_bp=int(field_lookup['CLUSTER_RIGHT']),
        # This field is optionally generated depending on whether multiple
        # independent hits were merged in the writing of the output. If not,
        # this is presumed to be a single independent hit.
        num_hits=int(field_lookup.get('NUM_CLUSTERS', 1)),
        num_variants=int(field_lookup['CLUSTER_SIZE']),
        cytoband=field_lookup['CYTOBAND'],
        gene_context=field_lookup['GENE_CONTEXT'],
        closest_genes=field_lookup['CLOSEST_GENES'],
        replication_variants=replication_variants,
    )

  def __str__(self) -> str:
    """Returns a tab-separated line whose fields line up with `columns()`."""
    return '\t'.join([
        str(self.assoc_result),
        str(self.num_hits),
        str(self.cluster_left_bp),
        str(self.cluster_right_bp),
        str(self.num_variants),
        self.cytoband,
        self.gene_context,
        self.closest_genes,
        # Splat the replication values as individual fields. Pre-joining them
        # into one element would insert a spurious empty field when there are
        # no replication columns, giving one more field than `columns()` and
        # a line that `from_line` rejects.
        *self.replication_variants.values(),
        '.',  # Deprecated LocusZoom ID.
    ])

  def columns(self) -> List[str]:
    """Returns the list of columns of a TSV file representing this."""
    return (
        self.assoc_result.columns()
        + [
            'NUM_CLUSTERS',
            'CLUSTER_LEFT',
            'CLUSTER_RIGHT',
            'CLUSTER_SIZE',
            'CYTOBAND',
            'GENE_CONTEXT',
            'CLOSEST_GENES',
        ]
        + list(self.replication_variants.keys())
        + ['LOCUSZOOM_ID']
    )

  def merge(
      self, other: 'AnnotatedLocus', strict: bool = True
  ) -> 'AnnotatedLocus':
    """Returns an AnnotatedLocus consisting of the merger of the two.

    The "index" position retained is based on the p-value of association -- the
    more significant variant is kept (ties keep `self`). Assumptions of the
    merge are that the independent "hits" associated with each of the inputs are
    disjoint, so that total variants and total hits can just be summed together.
    This assumption is true for both our LD-based and distance-based clustering
    implementations.

    Args:
      other: The other AnnotatedLocus to merge with.
      strict: If True, raises an exception when replication keys don't match.
        Otherwise we merge them.

    Returns:
      An AnnotatedLocus that is the merger of the two.

    Raises:
      ValueError: The hits cannot be safely merged.
    """
    if self.chrom != other.chrom:
      raise ValueError(
          f'Cannot merge hits on separate chromosomes: {self} vs {other}'
      )

    self_replication_keys = list(self.replication_variants)
    other_replication_keys = list(other.replication_variants)
    # Human-readable location of `self`, used only in messages below.
    chrom_pos = (
        f'{self.assoc_result.chrom}:'
        f'{self.cluster_left_bp}-{self.cluster_right_bp}'
    )

    if self_replication_keys == other_replication_keys:
      merged_replication_keys = self_replication_keys
    elif strict:
      raise ValueError(
          f'Cannot merge hits with different replications at {chrom_pos}: '
          f'{self_replication_keys} vs {other_replication_keys}.'
      )
    else:
      # Using print to minimize dependencies of the docker scripts.
      print(
          f'Merging distinct replication keys at {chrom_pos}: '
          f'{self_replication_keys} and {other_replication_keys}'
      )
      # Concatenate & deduplicate while preserving the order.
      merged_replication_keys = list(
          dict.fromkeys(self_replication_keys + other_replication_keys)
      )

    # NOTE(review): `_merge_replication_variants` is not defined in the part of
    # the notebook visible here -- confirm it is defined in another cell. It is
    # called with None for a key that is missing from one of the two loci.
    replication_variants = collections.OrderedDict()
    for key in merged_replication_keys:
      replication_variants[key] = _merge_replication_variants(
          self.replication_variants.get(key),
          other.replication_variants.get(key),
      )

    index_hit = self if self.p <= other.p else other
    return AnnotatedLocus(
        # Note: It's fine to use the underlying AssocResult object since it's
        # immutable.
        assoc_result=index_hit.assoc_result,
        cluster_left_bp=min(self.cluster_left_bp, other.cluster_left_bp),
        cluster_right_bp=max(self.cluster_right_bp, other.cluster_right_bp),
        num_variants=self.num_variants + other.num_variants,
        num_hits=self.num_hits + other.num_hits,
        cytoband=index_hit.cytoband,
        gene_context=index_hit.gene_context,
        closest_genes=index_hit.closest_genes,
        replication_variants=replication_variants,
    )

  def overlaps(self, other: 'AnnotatedLocus') -> bool:
    """Returns True if and only if `self` overlaps `other`.

    Two loci overlap when they are on the same chromosome and their
    [cluster_left_bp, cluster_right_bp] spans share at least one position
    (bounds are inclusive).
    """
    return self.chrom == other.chrom and min(
        self.cluster_right_bp, other.cluster_right_bp
    ) >= max(self.cluster_left_bp, other.cluster_left_bp)

  # Convenience methods for accessing the contained AssocResult.

  @property
  def vid(self) -> str:
    return self.assoc_result.vid

  @property
  def chrom(self) -> str:
    return self.assoc_result.chrom

  @property
  def bp(self) -> int:
    return self.assoc_result.bp

  @property
  def ref(self) -> str:
    return self.assoc_result.ref

  @property
  def alt(self) -> str:
    return self.assoc_result.alt

  @property
  def rsid(self) -> str:
    return self.assoc_result.rsid

  @property
  def affx(self) -> str:
    return self.assoc_result.affx

  @property
  def eff(self) -> str:
    return self.assoc_result.eff

  @property
  def alt_freq(self) -> float:
    return self.assoc_result.alt_freq

  @property
  def num_indv(self) -> int:
    return self.assoc_result.num_indv

  @property
  def src(self) -> str:
    return self.assoc_result.src

  @property
  def info_score(self) -> float:
    return self.assoc_result.info_score

  @property
  def p_hwe_pop(self) -> ArbitraryPrecisionReal:
    return self.assoc_result.p_hwe_pop

  @property
  def p(self) -> ArbitraryPrecisionReal:
    return self.assoc_result.p

  @property
  def beta(self) -> float:
    return self.assoc_result.beta

  @property
  def se(self) -> float:
    return self.assoc_result.se

  @property
  def maf(self) -> float:
    return self.assoc_result.maf

  @property
  def name(self) -> str:
    return self.assoc_result.name


class AnnotatedLoci:
  """Class representing a file of AnnotatedLocus records.

  This can be either the independent hits identified by LD-based clumping or
  further reduced to distance-based loci.
  """

  def __init__(self, chrom_locus_map: Dict[str, List[AnnotatedLocus]]):
    """Constructor.

    This is not usually instantiated directly, with preference for `from_file`.

    Args:
      chrom_locus_map: collections.OrderedDict mapping from chromosome name to a
        list of AnnotatedLocus records on that chromosome sorted by position.
        The mapping is stored by reference, not copied.
    """
    self._chrom_locus_map = chrom_locus_map

  @classmethod
  def from_file(
      cls, path_or_buffer: Union[str, os.PathLike, io.TextIOBase]
  ) -> 'AnnotatedLoci':
    """Returns an AnnotatedLoci object from the file.

    Args:
      path_or_buffer: Path of the TSV file to read, or an already-open text
        buffer. The handle is closed once reading finishes, including when a
        buffer was supplied by the caller.

    Returns:
      The loci in the file, grouped by chromosome in order of first appearance
      and sorted by index variant position within each chromosome. An empty
      file or a blank header yields an empty object.
    """
    if isinstance(path_or_buffer, (str, os.PathLike)):
      handle = open(path_or_buffer, 'rt')
    else:
      handle = path_or_buffer

    chrom_locus_map = collections.OrderedDict()
    with handle:
      try:
        header = next(handle)
      except StopIteration:
        # Completely empty file, this should be treated the same as a header
        # that is just blank (what the GWAS pipeline emits).
        header = ''

      if not header.strip():
        # There are no hits at all. Just return an empty object.
        return cls(chrom_locus_map=chrom_locus_map)

      header_tokens = header.strip().split('\t')
      for line in handle:
        locus = AnnotatedLocus.from_line(line, header_tokens)
        if locus.chrom not in chrom_locus_map:
          chrom_locus_map[locus.chrom] = []
        chrom_locus_map[locus.chrom].append(locus)

    # Order the results by position.
    for chrom in chrom_locus_map.keys():
      chrom_locus_map[chrom] = sorted(
          chrom_locus_map[chrom], key=lambda locus: locus.bp
      )
    return cls(chrom_locus_map=chrom_locus_map)

  def chroms(self) -> List[str]:
    """Returns the chromosome names, in the order stored."""
    return list(self._chrom_locus_map.keys())

  def loci_counts_per_chrom(self) -> Dict[str, int]:
    """Returns the number of loci in each chromosome."""
    return {c: len(self._chrom_locus_map[c]) for c in self._chrom_locus_map}

  def get_loci_in_chrom(self, chrom: str) -> List[AnnotatedLocus]:
    """Returns a deep copy of all loci in the given chromosome.

    An unknown chromosome yields an empty list.
    """
    if chrom in self._chrom_locus_map:
      return copy.deepcopy(self._chrom_locus_map[chrom])
    else:
      return []

  def to_file(
      self, path_or_buffer: Union[str, os.PathLike, io.TextIOBase]
  ) -> None:
    """Writes the loci as a TSV file with a header row.

    Args:
      path_or_buffer: Destination path, or an already-open text buffer. A file
        opened here from a path is always closed before returning; a buffer
        supplied by the caller is left open for the caller to manage.
    """
    opened_here = isinstance(path_or_buffer, (str, os.PathLike))
    handle = open(path_or_buffer, 'wt') if opened_here else path_or_buffer

    try:
      if not self._chrom_locus_map:
        # Since we have no loci, we don't know the proper set of columns. Just
        # write an empty file instead (this is the current behavior of the GWAS
        # pipeline anyway). Opening the handle above already created the file,
        # so there is nothing more to write.
        return

      # By construction, all keys present in the map have non-empty values.
      locus = next(iter(self._chrom_locus_map.values()))[0]
      header_cols = locus.columns()

      handle.write('\t'.join(header_cols) + '\n')
      for chrom_loci in self._chrom_locus_map.values():
        for locus in chrom_loci:
          # Every locus must share the header taken from the first one.
          assert locus.columns() == header_cols
          handle.write(str(locus) + '\n')
    finally:
      # Close (and thereby flush) only what this method opened, so the data is
      # on disk when the call returns and no file descriptor is leaked.
      if opened_here:
        handle.close()

  def merge_by_distance(
      self, max_cluster_separation: int, strict: bool = True
  ) -> 'AnnotatedLoci':
    """Returns an AnnotatedLoci object with variants merged by distance.

    Args:
      max_cluster_separation: The maximum distance in bp between clusters to
        allow to be merged together.
      strict: If True, raises an exception when replication keys don't match.
        Otherwise we merge them.

    Returns:
      An AnnotatedLoci object with the hits merged by distance.
    """
    merged_locus_map = collections.OrderedDict()
    for chrom, unmerged_loci in self._chrom_locus_map.items():
      merged_locus_map[chrom] = _merge_chrom_loci_by_distance(
          unmerged_loci, max_cluster_separation, strict=strict
      )
    return AnnotatedLoci(chrom_locus_map=merged_locus_map)

  def all_loci(self) -> Generator[AnnotatedLocus, None, None]:
    """Yields all AnnotatedLocus records in order."""
    for chrom_loci in self._chrom_locus_map.values():
      yield from chrom_loci

  def items(self) -> Generator[Tuple[str, List[AnnotatedLocus]], None, None]:
    """Yields (chrom, List[AnnotatedLocus]) elements."""
    yield from self._chrom_locus_map.items()

  def __len__(self) -> int:
    """Returns the total number of loci."""
    return sum(len(chrom_loci) for chrom_loci in self._chrom_locus_map.values())

  def union(self, others: Sequence['AnnotatedLoci']) -> 'AnnotatedLoci':
    """Returns all loci in this and other AnnotatedLoci.

    Args:
      others: One or more AnnotatedLoci objects to combine with `self`.

    Returns:
      A new AnnotatedLoci holding deep copies of every locus. Chromosomes are
      ordered by sorting their names as strings. Within a chromosome the loci
      of `self` come first, followed by those of each element of `others` in
      turn; the combined list is not re-sorted by position.

    Raises:
      ValueError: If `others` is empty.
    """
    if not others:
      raise ValueError('`others` cannot be empty.')
    merged_locus_map = collections.OrderedDict()
    new_chroms_set = set(self.chroms())
    for other in others:
      new_chroms_set |= set(other.chroms())
    new_chroms = sorted(new_chroms_set)

    for chrom in new_chroms:
      # `get_loci_in_chrom` returns a fresh copy, so extending it is safe.
      new_loci = self.get_loci_in_chrom(chrom)
      for other in others:
        new_loci += other.get_loci_in_chrom(chrom)
      merged_locus_map[chrom] = new_loci
    return AnnotatedLoci(chrom_locus_map=merged_locus_map)

  def replication_results(
      self, other: 'AnnotatedLoci'
  ) -> Tuple['AnnotatedLoci', 'AnnotatedLoci']:
    """Returns the loci that are shared with `other` and distinct from `other`.

    This function is not symmetric, since a single locus in one `AnnotatedLoci`
    object may replicate zero, one, or multiple loci in another.

    Args:
      other: The other AnnotatedLoci object used to determine replication
        status.

    Returns:
      A tuple of (replicated, unique) loci from `self`. The two sets are
      mutually exclusive and every locus in `self` is in one of the two outputs.
    """
    replicated = collections.OrderedDict()
    unique = collections.OrderedDict()
    for chrom, chrom_loci in self.items():
      # Loci on the chromosome of interest that are replicated in `other`.
      chrom_replicated = []
      # Loci on the chromosome of interest that are distinct/unique from the
      # loci in `other`.
      chrom_unique = []
      # pylint: disable=protected-access
      other_chrom_loci = other._chrom_locus_map.get(chrom, [])
      # pylint: enable=protected-access
      for locus in chrom_loci:
        if any(locus.overlaps(other_locus) for other_locus in other_chrom_loci):
          chrom_replicated.append(locus)
        else:
          chrom_unique.append(locus)

      # Only chromosomes with at least one locus get an entry, which keeps the
      # "all keys have non-empty values" property that `to_file` relies on.
      if chrom_replicated:
        replicated[chrom] = copy.deepcopy(chrom_replicated)
      if chrom_unique:
        unique[chrom] = copy.deepcopy(chrom_unique)

    return AnnotatedLoci(replicated), AnnotatedLoci(unique)


def _merge_chrom_loci_by_distance(
    loci: List[AnnotatedLocus], max_cluster_separation: int, strict: bool
) -> List[AnnotatedLocus]:
  """Merges neighbouring loci whose clusters lie close to each other.

  Args:
    loci: The loci to merge.
    max_cluster_separation: The largest gap in bp (left edge of the next
      cluster minus right edge of the running cluster) at which two clusters
      are still merged.
    strict: Passed through to `AnnotatedLocus.merge`. If True, an exception is
      raised when replication keys don't match; otherwise they are merged.

  Returns:
    The merged loci, sorted by index variant position.
  """
  if not loci:
    return []

  # Sweep clusters from left to right. Ordering by cluster bounds can differ
  # from ordering by index position, since clusters may be defined using LD.
  pending = list(loci)
  pending.sort(key=lambda item: (item.cluster_left_bp, item.cluster_right_bp))

  # The last element of `merged` is the cluster currently being grown.
  merged = [pending[0]]
  for candidate in pending[1:]:
    gap = candidate.cluster_left_bp - merged[-1].cluster_right_bp
    if gap <= max_cluster_separation:
      merged[-1] = merged[-1].merge(candidate, strict=strict)
    else:
      merged.append(candidate)

  merged.sort(key=lambda item: item.bp)
  return merged


class GwasResults:
  """Class that contains results from a GWAS pipeline run.

  It has both an AnnotatedLoci object and all underlying association results
  that pass quality control. These underlying association results are guaranteed
  to be a superset of the variants present in the `AnnotatedLoci` object.
  """

  # Columns that every per-chromosome association DataFrame must contain.
  REQUIRED_ASSOC_COLS = frozenset(['VID', 'CHR', 'BP', 'P', 'BETA', 'SE'])

  def __init__(
      self, loci: AnnotatedLoci, chrom_assoc_df: Dict[str, pd.DataFrame]
  ):
    """Constructor.

    This is typically not called directly; prefer `GwasResults.from_files` or
    `GwasResults.from_gwas_pipeline_dir` instead.

    Args:
      loci: `AnnotatedLoci` object. This is the set of "GWAS hits" for the GWAS;
        see its class description for full details.
      chrom_assoc_df: Mapping from chromosome name to DataFrame of all variants
        that pass quality control for the GWAS (typically a chromosome-sharded
        version of the "association_results.filtered.tsv" file).

    Raises:
      ValueError: `chrom_assoc_df` is empty.
      ValueError: DataFrames in `chrom_assoc_df` do not have identical columns.
      ValueError: DataFrames in `chrom_assoc_df` do not contain all
        `REQUIRED_ASSOC_COLS` columns.
      ValueError: `loci` contains some variants not present in `chrom_assoc_df`.
    """
    if not chrom_assoc_df:
      # Without this guard the column-consistency check below would fire with
      # a misleading "must have the same columns: []" message.
      raise ValueError('`chrom_assoc_df` must contain at least one DataFrame.')

    chrom_cols = list({frozenset(df.columns) for df in chrom_assoc_df.values()})
    if len(chrom_cols) != 1:
      raise ValueError(
          f'All assoc dataframes must have the same columns: {chrom_cols}'
      )
    missing_cols = GwasResults.REQUIRED_ASSOC_COLS - chrom_cols[0]
    if missing_cols:
      # Name the columns that are absent; the present ones are given as
      # context only.
      raise ValueError(
          'assoc_df missing some required columns: '
          f'{sorted(missing_cols)} (found: {sorted(chrom_cols[0])})'
      )

    loci_vids = {locus.vid for locus in loci.all_loci()}
    assoc_vids = set()
    for df in chrom_assoc_df.values():
      assoc_vids.update(df['VID'])
    missing_vids = loci_vids - assoc_vids
    if missing_vids:
      raise ValueError(f'{len(missing_vids)} VIDs from loci not in assoc_df.')

    self._loci = loci
    self._chrom_assoc_df = chrom_assoc_df

  @classmethod
  def from_files(
      cls,
      annotated_loci_path_or_buffer: Union[str, os.PathLike, io.TextIOBase],
      filtered_results_path_or_buffer: Union[str, os.PathLike, io.TextIOBase],
      columns: Optional[Sequence[str]] = None,
  ) -> 'GwasResults':
    """Returns a GwasResults object from the input filepaths.

    A results file opened here from a path is closed before returning. A
    buffer supplied by the caller is read but left open, since the caller owns
    it.

    Args:
      annotated_loci_path_or_buffer: Path to the annotated hits (typically named
        "association_results.annotated_hits.tsv" or
        "association_results.annotated_loci.tsv").
      filtered_results_path_or_buffer: Path to the filtered results (typically
        named "association_results.filtered.tsv").
      columns: List of columns of the filtered results to retain in the data. We
        always keep required columns `REQUIRED_ASSOC_COLS`; this field can be
        used to specify additional columns if they should be retained.

    Returns:
      GwasResults object representing the GWAS.
    """
    loci = AnnotatedLoci.from_file(annotated_loci_path_or_buffer)

    cols = set(GwasResults.REQUIRED_ASSOC_COLS)
    if columns:
      cols.update(columns)

    owns_handle = isinstance(
        filtered_results_path_or_buffer, (str, os.PathLike)
    )
    if owns_handle:
      handle = open(filtered_results_path_or_buffer, mode='rt')
    else:
      handle = filtered_results_path_or_buffer
    try:
      assoc_df = pd.read_csv(handle, delimiter='\t', usecols=sorted(cols))
    finally:
      # Only close what was opened here; never close the caller's buffer.
      if owns_handle:
        handle.close()

    # Chromosome names are used as string dictionary keys.
    assoc_df['CHR'] = assoc_df['CHR'].astype(str)
    # Shard by chromosome in a single pass, keeping first-appearance order.
    chrom_assoc_df = {
        chrom: chrom_df.copy(deep=True)
        for chrom, chrom_df in assoc_df.groupby('CHR', sort=False)
    }
    return cls(loci=loci, chrom_assoc_df=chrom_assoc_df)

  @classmethod
  def from_gwas_pipeline_dir(
      cls, gwas_dir: str, columns: Optional[Sequence[str]] = None
  ) -> 'GwasResults':
    """Convenience method for loading from a GWAS pipeline output directory.

    Reads "association_results.annotated_hits.tsv" and
    "association_results.filtered.tsv" from the "analysis" subdirectory of
    `gwas_dir`.

    Args:
      gwas_dir: The directory in which GWAS pipeline outputs are written.
      columns: List of columns of the filtered results to retain in the data. We
        always keep required columns `REQUIRED_ASSOC_COLS`; this field can be
        used to specify additional columns if they should be retained.

    Returns:
      GwasResults object representing the GWAS.
    """
    annotated_loci_file = os.path.join(
        gwas_dir, 'analysis', 'association_results.annotated_hits.tsv'
    )
    filtered_results_file = os.path.join(
        gwas_dir, 'analysis', 'association_results.filtered.tsv'
    )
    return cls.from_files(
        annotated_loci_path_or_buffer=annotated_loci_file,
        filtered_results_path_or_buffer=filtered_results_file,
        columns=columns,
    )

  def merge_by_distance(self, max_cluster_separation: int) -> 'GwasResults':
    """Returns a GwasResults object with variants merged by distance.

    The association DataFrames are deep-copied, so the returned object does
    not share them with `self`.

    Args:
      max_cluster_separation: The maximum distance in bp between clusters to
        allow to be merged together.

    Returns:
      A GwasResults object with the hits merged by distance.
    """
    return GwasResults(
        loci=self._loci.merge_by_distance(
            max_cluster_separation=max_cluster_separation
        ),
        chrom_assoc_df=copy.deepcopy(self._chrom_assoc_df),
    )


def join_two_gwases(
    res1: GwasResults,
    res2: GwasResults,
    how: str,
    columns: Optional[Sequence[str]] = None,
    suffixes: Tuple[str, str] = ('_x', '_y'),
) -> pd.DataFrame:
  """Returns a DataFrame containing the results of joining two `GwasResult`s.

  The goal of joining is to compare information about equivalent variants across
  two `GwasResult`s. Typically this is performed to understand whether one GWAS
  has better power than another (by comparing p-values) or to determine
  consistency between estimated effect sizes (by comparing effect sizes). This
  join can be performed at the raw association result level (i.e. by joining all
  pairs of variants that have satisfied QC to produce valid results), but this
  can be visually misleading when one or few clumps of variants in strong LD
  dominate the results. Consequently, we also support joining of just a subset
  of variants (the independent loci identified by one or both GWAS). This is
  conceptually equivalent to joining results from all variants and then
  filtering those results to just the subset of interest (but is implemented by
  filtering and then joining for computational efficiency).

  There are four supported methods for joining results:
    * 'left_loci': All loci from the left (i.e. res1) result are included. This
        is a typically useful incantation, and reflects direct comparison of
        variants highlighted in a particular study.
    * 'either_loci': Loci present in either of the results are included. This
        is a fuller comparison of two GWASs than using 'left_loci', but can
        cause the same locus to be represented multiple times (if in the same
        region, one variant is designated the "locus representative" in one
        GWAS but a different variant is the locus representative for the other).
    * 'all_variants': Joins all variants that passed QC metrics in both GWASs.
    * 'best_loci': Currently not implemented. Its goal is to have the benefit of
        the 'either_loci' method but avoid double-counting.

  Filtering by 'left_loci' or 'either_loci' means just keeping the specific
  variants that correspond to the AnnotatedLocus objects within the desired
  GWAS(s) for the join.

  Note that the total number of variants in the result may be slightly smaller
  than requested, as variants must have passed QC checks in both GWASs to be
  included.

  Args:
    res1: One `GwasResults` object to compare (the "left" result).
    res2: The other `GwasResults` object to compare.
    how: How the join should be performed. Must be in [left_loci, either_loci,
      all_variants, best_loci] (see above for details).
    columns: Columns to keep in the output DataFrame.
    suffixes: Suffixes to append to the left and right shared columns in the
      output DataFrame (see pandas.merge documentation for details).

  Returns:
    A pd.DataFrame containing genome-wide joined variants.
  """
  if columns is None:
    columns = ['P', 'BETA', 'SE']
  else:
    columns = list(columns)  # Make a copy as we may need to modify it below.

  # Validate input data.
  if 'VID' in columns and how == 'best_loci':
    raise ValueError(
        '"VID" cannot be in columns when using '
        f'"best_loci" join method: {columns}.'
    )
  if not set(columns) - {'VID'}:
    raise ValueError(f'At least one non-VID column must be present: {columns}.')
  if how not in ['left_loci', 'either_loci', 'best_loci', 'all_variants']:
    raise ValueError(f'Unsupported GWAS join method: "{how}".')
  if not any(suffixes):
    raise ValueError(f'At least one suffix must be non-empty: {suffixes}')
  if not len(suffixes) == len(set(suffixes)) == 2:
    raise ValueError(f'Suffixes must be unique: {suffixes}')

  if how == 'best_loci':
    # This must be implemented independently, as it is a "fuzzy" join where the
    # left and right variant values returned do not necessarily come from the
    # same variants.
    raise NotImplementedError('Join using "best_loci" is not yet implemented.')

  # For all other join types, we select the same variants in each GwasResult.
  # So, we can just find the set of VIDs per chromosome that need to be kept,
  # and share the logic of actually performing the join that includes those
  # VIDs.
  # The following code blocks are logically distinct; the first identifies the
  # particular VIDs to retain, and the next does the actual joining based on
  # those VIDs. If these functionalities show independent utility we can factor
  # them into separate functions.
  chrom_vids = collections.defaultdict(set)
  # pylint: disable=protected-access
  if how == 'left_loci':
    for chrom, loci in res1._loci.items():
      chrom_vids[chrom].update({locus.vid for locus in loci})
  elif how == 'either_loci':
    for gwas_results in [res1, res2]:
      for chrom, loci in gwas_results._loci.items():
        chrom_vids[chrom].update({locus.vid for locus in loci})
  elif how == 'all_variants':
    for chrom in set(res1._chrom_assoc_df) & set(res2._chrom_assoc_df):
      res1_vids = set(res1._chrom_assoc_df[chrom]['VID'])
      res2_vids = set(res2._chrom_assoc_df[chrom]['VID'])
      chrom_vids[chrom] = res1_vids & res2_vids
  else:
    assert False, 'Programming error -- should not reach this clause.'

  # We do the actual join using VID as the key. If it is not requested as an
  # output field, we need to add it and then remove at the end.
  vid_requested = 'VID' in columns
  if not vid_requested:
    columns.append('VID')

  chrom_results = []
  for chrom, vids in chrom_vids.items():
    if vids and chrom in res1._chrom_assoc_df and chrom in res2._chrom_assoc_df:
      # Trim both dataframes to only the VIDs and columns to retain.
      left_full_df = res1._chrom_assoc_df[chrom]
      left_df = left_full_df.loc[left_full_df['VID'].isin(vids), columns]
      right_full_df = res2._chrom_assoc_df[chrom]
      right_df = right_full_df.loc[right_full_df['VID'].isin(vids), columns]
      joined_chrom_df = pd.merge(
          left_df,
          right_df,
          how='inner',
          on='VID',
          suffixes=suffixes,
          validate='one_to_one',
      )
      if not vid_requested:
        joined_chrom_df.drop('VID', axis='columns', inplace=True)
      chrom_results.append(joined_chrom_df)
  return pd.concat(chrom_results, ignore_index=True)

## Utilities for loading annotated hits, loci, and filtered GWAS results.

In [None]:
@dataclasses.dataclass
class Gwas:
  """Represents an individual GWAS.

  All three input files are located as `<base_dir>/<gwas_id><suffix>` and are
  loaded eagerly when the object is constructed.

  Attributes:
    gwas_id: A unique GWAS identifier.
    gwas_label: A label used in figures when plotting GWAS results.
    base_dir: A base directory containing annotated hits, loci, and results.
    hits_suffix: The suffix of the hits file.
    loci_suffix: The suffix of the loci file.
    results_suffix: The suffix of the filtered results file.
    hits: An `AnnotatedLoci` containing all hits.
    loci: An `AnnotatedLoci` containing all loci.
    gwas_results: A `GwasResults` built from the hits file and the filtered
      results file.
  """

  gwas_id: str
  gwas_label: str
  base_dir: str = ASSOC_RESULTS_BASE_DIR
  hits_suffix: str = ASSOC_RESULTS_HITS_SUFFIX
  loci_suffix: str = ASSOC_RESULTS_LOCI_SUFFIX
  results_suffix: str = ASSOC_RESULTS_FILTERED_SUFFIX
  hits: AnnotatedLoci = dataclasses.field(init=False)
  loci: AnnotatedLoci = dataclasses.field(init=False)
  gwas_results: GwasResults = dataclasses.field(init=False)

  @property
  def hits_path(self) -> str:
    """A path to the GWAS's annotated hits."""
    return os.path.join(self.base_dir, f'{self.gwas_id}{self.hits_suffix}')

  @property
  def loci_path(self) -> str:
    """A path to the GWAS's annotated loci."""
    return os.path.join(self.base_dir, f'{self.gwas_id}{self.loci_suffix}')

  @property
  def results_path(self) -> str:
    """A path to the GWAS's filtered results."""
    return os.path.join(self.base_dir, f'{self.gwas_id}{self.results_suffix}')

  def __post_init__(self):
    """Checks that every input file exists, then loads all three of them.

    Raises:
      FileNotFoundError: The hits, loci, or filtered results file is missing.
    """
    # Check every input up front so that a missing results file is reported
    # before any time is spent parsing the hits and loci files.
    for path in (self.hits_path, self.loci_path, self.results_path):
      if not os.path.exists(path):
        raise FileNotFoundError(path)
    self.hits = AnnotatedLoci.from_file(self.hits_path)
    self.loci = AnnotatedLoci.from_file(self.loci_path)
    self.gwas_results = GwasResults.from_files(
        self.hits_path,
        self.results_path,
        columns=['VID', 'P', 'SE', 'BETA', 'EFF'],
    )

  def __str__(self) -> str:
    return f'{self.__class__.__name__}({self.gwas_id})'

## Utilities for generating plots and figures from GWAS results.

In [None]:
def gwas_comparison_pvalue_scatter_inset(
    df,
    x_col: str,
    y_col: str,
    xlabel: str = '',
    ylabel: str = '',
    p_thresh: float = 5e-8,
    legend_x_desc: str = '',
    legend_y_desc: str = '',
    inset_axes: Optional[Tuple[float, float]] = None,
    ax: matplotlib.axes.Axes = None,
) -> matplotlib.axes.Axes:
  """Plots a scatterplot of -log_10(P) from the two joined GWASes.

  Points are coloured by whether they reach `p_thresh` in both GWASes, only
  one of them, or neither, and the legend reports the count in each group.
  A dashed box marks the threshold and a dashed diagonal marks y=x.

  Args:
    df: pd.DataFrame containing p-value columns for two GWASes.
    x_col: The name of `df`s p-value column to plot on the x-axis.
    y_col: The name of `df`s p-value column to plot on the y-axis.
    xlabel: The label to give the x-axis.
    ylabel: The label to give the y-axis.
    p_thresh: The p-value threshold to annotate on the plot.
    legend_x_desc: Legend text for points significant only on the x-axis. If
      empty, 'Only x-axis significant' is used.
    legend_y_desc: Legend text for points significant only on the y-axis. If
      empty, 'Only y-axis significant' is used.
    inset_axes: Optional (min, max) range in -log_10(P) units. When given, a
      zoomed inset showing that range on both axes is added.
    ax: matplotlib.Axes on which to plot. If unspecified, uses the current axes.

  Returns:
    The axes on which the plot is added.

  Raises:
    ValueError: The columns to plot are not present in the dataframe.
  """
  if not all(col in df.columns for col in (x_col, y_col)):
    raise ValueError(f'{x_col} and {y_col} must be in {df.columns}')

  if ax is None:
    ax = plt.gca()

  neg_log_px = -np.log10(df[x_col])
  neg_log_py = -np.log10(df[y_col])
  neg_log_thresh = -np.log10(p_thresh)

  # Both directions of each comparison are computed explicitly so that rows
  # with NaN values fall into none of the four groups.
  x_sig = neg_log_px >= neg_log_thresh
  x_not_sig = neg_log_px < neg_log_thresh
  y_sig = neg_log_py >= neg_log_thresh
  y_not_sig = neg_log_py < neg_log_thresh

  # (mask, colour, legend text) in drawing and legend order.
  groups = [
      (x_sig & y_sig, '#1f77b4', 'Both significant'),
      (
          x_not_sig & y_sig,
          '#2ca02c',
          legend_y_desc or 'Only y-axis significant',
      ),
      (
          x_sig & y_not_sig,
          '#ff7f0e',
          legend_x_desc or 'Only x-axis significant',
      ),
      (x_not_sig & y_not_sig, '#84878a', 'Neither significant'),
  ]

  def draw_groups(target_ax):
    """Scatters every non-empty group; returns legend handles and labels."""
    artists = []
    captions = []
    for mask, color, description in groups:
      count = mask.sum()
      if not count:
        continue
      artists.append(
          target_ax.scatter(
              neg_log_px[mask], neg_log_py[mask], alpha=0.3, c=color
          )
      )
      captions.append(f'{description} (N={count})')
    return artists, captions

  legend_handles, legend_labels = draw_groups(ax)

  # Use the same upper limit on both axes so the plot is square.
  upper = np.ceil(max(neg_log_px.max(), neg_log_py.max()))
  ax.set_xlim([0, upper])
  ax.set_ylim([0, upper])

  ax.legend(handles=legend_handles, labels=legend_labels)

  # Bounding box at the significance threshold, plus the y=x line.
  ax.hlines(
      y=neg_log_thresh,
      xmin=0,
      xmax=neg_log_thresh,
      linestyles='dashed',
      colors='r',
  )
  ax.vlines(
      x=neg_log_thresh,
      ymin=0,
      ymax=neg_log_thresh,
      linestyles='dashed',
      colors='r',
  )
  ax.plot([0, upper], [0, upper], 'r--', linewidth=1)

  ax.set_xlabel(xlabel, fontsize=20)
  ax.set_ylabel(ylabel, fontsize=20)
  for get_tick_labels in (ax.get_xticklabels, ax.get_yticklabels):
    for tick_label in get_tick_labels():
      tick_label.update({'fontsize': 16})

  if inset_axes:
    inset_lo, inset_hi = inset_axes
    axins = ax.inset_axes([0.55, 0.07, 0.3, 0.3])
    draw_groups(axins)
    # Only draw the threshold box when it falls inside the zoomed range.
    if inset_lo < neg_log_thresh < inset_hi:
      axins.hlines(
          y=neg_log_thresh,
          xmin=inset_lo,
          xmax=neg_log_thresh,
          linestyles='dashed',
          colors='r',
      )
      axins.vlines(
          x=neg_log_thresh,
          ymin=inset_lo,
          ymax=neg_log_thresh,
          linestyles='dashed',
          colors='r',
      )
    axins.plot([inset_lo, inset_hi], [inset_lo, inset_hi], 'r--', linewidth=1)

    axins.set_xlim(inset_lo, inset_hi)
    axins.set_ylim(inset_lo, inset_hi)
    ax.indicate_inset_zoom(axins, edgecolor='black')

  return ax


def gwas_comparison_effect_size_scatter_inset(
    df,
    x_col: str,
    y_col: str,
    xlabel: str = '',
    ylabel: str = '',
    p_thresh: float = 5e-8,
    inset_axes: Optional[Tuple[float, float]] = None,
    ax: matplotlib.axes.Axes = None,
) -> matplotlib.axes.Axes:
  """Plots a scatterplot of effect sizes from the two joined GWASes.

  A linear regression line is overlaid and the squared Pearson correlation
  between the two columns is written in the top-left of the data range.

  Args:
    df: pd.DataFrame containing effect size columns for two GWASes.
    x_col: The name of `df`s effect size column to plot on the x-axis.
    y_col: The name of `df`s effect size column to plot on the y-axis.
    xlabel: The label to give the x-axis.
    ylabel: The label to give the y-axis.
    p_thresh: Unused; accepted for signature parity with
      `gwas_comparison_pvalue_scatter_inset`.
    inset_axes: Unused; accepted for signature parity with
      `gwas_comparison_pvalue_scatter_inset`.
    ax: matplotlib.Axes on which to plot. If unspecified, uses the current axes.

  Returns:
    The axes on which the plot is added.

  Raises:
    ValueError: The columns to plot are not present in the dataframe.
  """
  if any(col not in df.columns for col in [x_col, y_col]):
    raise ValueError(f'{x_col} and {y_col} must be in {df.columns}')

  if ax is None:
    ax = plt.gca()

  ax.set_xlabel(xlabel, fontsize=20)
  ax.set_ylabel(ylabel, fontsize=20)
  for xtick in ax.get_xticklabels():
    xtick.update({'fontsize': 16})
  for ytick in ax.get_yticklabels():
    ytick.update({'fontsize': 16})

  # Squared off-diagonal entry of the 2x2 correlation matrix.
  r2 = df[[x_col, y_col]].corr().iloc[0, 1] ** 2

  # Everything below draws on `ax` explicitly. Going through pyplot's
  # "current axes" would put the plot on the wrong panel whenever `ax` is not
  # the most recently created axes.
  sns.regplot(
      x=x_col,
      y=y_col,
      data=df,
      line_kws={'color': 'red'},
      scatter_kws={'color': 'black'},
      ax=ax,
  )
  ax.scatter(df[x_col], df[y_col], alpha=0.3, c='black')
  # regplot labels the axes with the column names; restore the requested text.
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.text(
      np.min(df[x_col]), np.max(df[y_col]), f'$R^2$={r2:.2f}', fontsize=20
  )
  return ax


def plot_scatter(
    y_gwas: GwasResults,
    x_gwas: GwasResults,
    pheno: str,
    x_axis_method: str,
    y_axis_method: str,
    ax,
    pvalue: bool = True,
) -> None:
  """Plots a p-value or effect size comparison between two GWAS.

  The two results are joined on the loci of either GWAS. Before plotting, the
  y-axis effect size is negated for every variant whose `EFF` value differs
  between the two GWAS, so both effect sizes refer to the same `EFF` value.

  Args:
    y_gwas: The `GwasResults` shown on the y-axis.
    x_gwas: The `GwasResults` shown on the x-axis.
    pheno: Phenotype name used as the prefix of both axis labels.
    x_axis_method: Description of the x-axis GWAS, used in the axis label and
      legend.
    y_axis_method: Description of the y-axis GWAS, used in the axis label and
      legend.
    ax: The matplotlib axes to draw on.
    pvalue: If True, plots -log10 p-values; otherwise plots effect sizes.
  """
  merged_df = join_two_gwases(
      x_gwas,
      y_gwas,
      columns=['VID', 'P', 'SE', 'BETA', 'EFF'],
      how='either_loci',
  )
  # Keep BETA_y where both GWAS report the same EFF value; flip its sign
  # otherwise.
  same_eff = merged_df['EFF_x'] == merged_df['EFF_y']
  merged_df['BETA_y'] = merged_df['BETA_y'].where(
      same_eff, -merged_df['BETA_y']
  )

  if pvalue:
    gwas_comparison_pvalue_scatter_inset(
        merged_df,
        'P_x',
        'P_y',
        xlabel=f'{pheno} -log(p_value) {x_axis_method}',
        ylabel=f'{pheno} -log(p_value) {y_axis_method}',
        legend_x_desc=f'Only {x_axis_method} significant',
        legend_y_desc=f'Only {y_axis_method} significant',
        inset_axes=(5, 10),
        ax=ax,
    )
  else:
    gwas_comparison_effect_size_scatter_inset(
        merged_df,
        'BETA_x',
        'BETA_y',
        xlabel=f'{pheno} Effect size {x_axis_method}',
        ylabel=f'{pheno} Effect size {y_axis_method}',
        inset_axes=(5, 10),
        ax=ax,
    )


def plot_manhattan(
    ax,
    main_gwas: Gwas,
    overlap_gwas: Gwas,
    chrom_offsets: collections.OrderedDict[str, int] = CHROM_OFFSETS,
    p_cutoff: float = 0.001,
    sig_cutoff: float = 5e-8,
    max_y: float = 170,
    min_y: float = -5,
    min_x: float = -1.5e8,
    max_x: float = 3.025e9,
    seed: int = 23,
):
  """Plots the Manhattan plot for GWAS.

  Variants of `main_gwas` with P <= `p_cutoff` are drawn as small points in
  two colours chosen by chromosome number parity. Lead variants of
  `main_gwas` loci are drawn larger: blue when the locus overlaps a locus of
  `overlap_gwas`, red otherwise. Hits of `main_gwas` are annotated with their
  closest genes, except hits that are not replicated and have
  -log10(P) < 20.

  Args:
    ax: The matplotlib axes to draw on.
    main_gwas: The GWAS whose association results and hits are plotted.
    overlap_gwas: The GWAS whose loci decide which `main_gwas` loci count as
      replicated.
    chrom_offsets: Mapping from chromosome name to the x-axis offset added to
      a variant's base-pair position. Tick marks are placed midway between
      consecutive offsets and labelled 1 through 22, so this presumably holds
      23 entries (22 autosomes plus a final end offset) -- TODO confirm.
    p_cutoff: Variants with a p-value above this value are not drawn.
    sig_cutoff: The p-value at which the dashed significance line is drawn.
    max_y: Upper y-axis limit.
    min_y: Lower y-axis limit.
    min_x: Lower x-axis limit.
    max_x: Upper x-axis limit.
    seed: Unused. Retained so that existing callers passing it keep working.
  """
  # Compute replication across the two GWAS.
  replicated_loci, unique_loci = main_gwas.loci.replication_results(
      overlap_gwas.loci
  )
  replicated_vids = {locus.vid for locus in replicated_loci.all_loci()}
  unique_vids = {locus.vid for locus in unique_loci.all_loci()}

  # Plot SNPS above the P value cutoff, coloring hits if novel or shared.
  chrom_colors = ['#c2a5cf', '#a6dba0']
  # Background variants.
  xs, ys, cs = [], [], []
  # Highlighted lead variants.
  xhs, yhs, chs = [], [], []
  # pylint: disable=protected-access
  for chrom_df in main_gwas.gwas_results._chrom_assoc_df.values():
    # pylint: enable=protected-access
    chrom_df = chrom_df[chrom_df['P'] <= p_cutoff]
    chrom_df = chrom_df.assign(LOGP=-np.log10(chrom_df['P']))
    chrom_df = chrom_df[['VID', 'CHR', 'BP', 'LOGP']]
    for record in chrom_df.to_dict('records'):
      x = chrom_offsets[str(record['CHR'])] + record['BP']
      y = record['LOGP']
      if record['VID'] in unique_vids:
        xhs.append(x)
        yhs.append(y)
        chs.append('#d73027')
      elif record['VID'] in replicated_vids:
        xhs.append(x)
        yhs.append(y)
        chs.append('#4575b4')
      else:
        xs.append(x)
        ys.append(y)
        # Alternate the two background colours between adjacent chromosomes.
        cs.append(chrom_colors[int(record['CHR']) % len(chrom_colors)])
  # The dense background is rasterized; the highlighted points stay vector.
  ax.scatter(xs, ys, c=cs, s=1, rasterized=True)
  ax.scatter(xhs, yhs, c=chs, s=10)

  # Plot the significance cutoff line.
  ax.axhline(y=-np.log10(sig_cutoff), linestyle='--', color='r', linewidth=0.5)

  # Annotate hits with their closest genes, in genomic order.
  for hit in sorted(
      main_gwas.hits.all_loci(),
      key=lambda x: (int(x.assoc_result.chrom), int(x.assoc_result.bp)),
  ):
    x_hit = chrom_offsets[hit.assoc_result.chrom] + hit.assoc_result.bp
    y_hit = -np.log10(hit.assoc_result.p)
    is_common = hit.assoc_result.vid in replicated_vids
    # Weaker hits are only labelled when they are replicated.
    if y_hit < 20 and not is_common:
      continue
    c_text = '#08306b' if is_common else 'k'
    if y_hit < 50:
      # Rotated label placed 20 units above the point, joined by a thin line.
      ax.annotate(
          hit.closest_genes,
          xy=(x_hit, y_hit),
          xytext=(x_hit, y_hit + 20),
          style='oblique',
          rotation=90,
          color=c_text,
          ha='center',
          va='center',
          arrowprops={
              'facecolor': 'black',
              'width': 0,
              'headwidth': 0,
              'shrink': 0.0,
              'lw': 0.5,
          },
          bbox={'fc': 'none', 'ec': 'k', 'pad': 2, 'lw': 0},
      )
    else:
      # Horizontal label placed just above the point.
      ax.annotate(
          hit.closest_genes,
          xy=(x_hit, y_hit),
          xytext=(x_hit, y_hit + 2),
          ha='center',
          color=c_text,
          style='oblique',
      )

  # Plot offsets, ticks, axis labels, etc.
  offsets = np.asarray(list(chrom_offsets.values()))
  # One tick per chromosome, midway between consecutive offsets.
  xticks = (offsets[1:] + offsets[:-1]) / 2
  ax.set_xticks(xticks)
  ax.set_xticklabels([index for index in range(1, 23)], fontsize=16)
  ax.set_yticks([0, 30, 60, 90, 120])
  ax.set_yticklabels([0, 30, 60, 90, 120], fontsize=18)
  ax.set_ylim([min_y, max_y])
  ax.set_xlim([min_x, max_x])
  ax.set_xlabel('Chromosomes', fontsize=20)
  ax.set_ylabel(r'$-\log_{10}(P)$', fontsize=20)


def _add_panel_label(ax: plt.Axes, label: str) -> None:
  """Writes a panel letter just above the top-left corner of `ax`."""
  ax.text(
      0,
      1.15,
      label,
      transform=ax.transAxes,
      fontsize=30,
      va='top',
      ha='right',
  )


def _build_gwas_comparison_figure(
    gwas_x: Gwas,
    gwas_y: Gwas,
    pheno: str,
    p_value_ax: plt.Axes,
    effect_size_ax: plt.Axes,
    panel_labels: Tuple[str, str] = ('a', 'b'),
) -> None:
  """Draws the p-value and effect size comparison panels for two GWAS.

  Args:
    gwas_x: The GWAS shown on the x-axis of both panels.
    gwas_y: The GWAS shown on the y-axis of both panels.
    pheno: Phenotype name used in the axis labels.
    p_value_ax: Axes that receive the p-value comparison.
    effect_size_ax: Axes that receive the effect size comparison.
    panel_labels: Panel letters for the p-value and effect size panels, in
      that order. Override when these panels are not the first two of a
      figure.
  """
  p_value_label, effect_size_label = panel_labels
  panels = (
      (p_value_ax, p_value_label, True),
      (effect_size_ax, effect_size_label, False),
  )
  for panel_ax, panel_label, is_pvalue in panels:
    plot_scatter(
        x_gwas=gwas_x.gwas_results,
        x_axis_method=gwas_x.gwas_label,
        y_gwas=gwas_y.gwas_results,
        y_axis_method=gwas_y.gwas_label,
        ax=panel_ax,
        pheno=pheno,
        pvalue=is_pvalue,
    )
    _add_panel_label(panel_ax, panel_label)

  for panel_ax, _, _ in panels:
    panel_ax.spines[['right', 'top']].set_visible(False)
    panel_ax.grid(False)


def build_gwas_comparison_figure(
    gwas_x: Gwas,
    gwas_y: Gwas,
    pheno: str = 'COPD',
) -> plt.Figure:
  """Constructs a GWAS p-value and effect size comparison between two GWAS.

  Args:
    gwas_x: The GWAS shown on the x-axis of both panels.
    gwas_y: The GWAS shown on the y-axis of both panels.
    pheno: Phenotype name used in the axis labels.

  Returns:
    A figure with the p-value panel (a) on the left and the effect size
    panel (b) on the right.
  """
  fig = plt.figure(figsize=(16, 8), constrained_layout=True)
  spec = fig.add_gridspec(1, 2)

  # Plot the scatter comparison of P values and of effect sizes.
  ax00 = fig.add_subplot(spec[0, 0])
  ax01 = fig.add_subplot(spec[0, 1])
  _build_gwas_comparison_figure(gwas_x, gwas_y, pheno, ax00, ax01)
  return fig


def build_figure_4(
    main_gwas: Gwas,
    overlap_gwas: Gwas,
    pheno: str = 'COPD',
) -> plt.Figure:
  """Constructs Figure 4 from the manuscript.

  Args:
    main_gwas: The GWAS shown in the Manhattan plot and on the y-axis of the
      comparison panels.
    overlap_gwas: The GWAS used to judge replication in the Manhattan plot and
      shown on the x-axis of the comparison panels.
    pheno: Phenotype name used in the comparison panels' axis labels.

  Returns:
    A figure with the Manhattan plot (a) across the top row and the p-value
    (b) and effect size (c) comparisons on the bottom row.
  """
  fig = plt.figure(figsize=(16, 16), constrained_layout=True)
  spec = fig.add_gridspec(2, 2)

  # Plot the manhattan plot.
  ax0 = fig.add_subplot(spec[0, :])
  plot_manhattan(ax0, main_gwas, overlap_gwas)
  _add_panel_label(ax0, 'a')
  ax0.spines[['right', 'top']].set_visible(False)
  ax0.grid(False)

  # Plot the variant-in-hits comparison. Panel 'a' is taken by the Manhattan
  # plot, so these two panels continue the lettering with 'b' and 'c'.
  ax10 = fig.add_subplot(spec[1, 0])
  ax11 = fig.add_subplot(spec[1, 1])
  _build_gwas_comparison_figure(
      overlap_gwas, main_gwas, pheno, ax10, ax11, panel_labels=('b', 'c')
  )

  return fig

In [None]:
# Identifiers of the GWAS used in the following analyses. Each ID is also the
# filename prefix of that GWAS's input files (see `Gwas`).
GWAS_ML_BASED_COPD = 'ml_based_copd'
GWAS_ML_BASED_COPD_NO_MRB_CASES = 'ml_based_copd_no_mrb_cases'
GWAS_HOBBS_NATGEN_2017 = 'hobbs_natgen_2017'
GWAS_ML_BASED_COPD_BINARIZED_GOLD_PREV = 'ml_based_copd_binarized_gold_prev'
GWAS_SAKORNSAKOLPAT_NATGEN_2019 = 'sakornsakolpat_natgen_2019'
GWAS_ML_BASED_COPD_BINARIZED = 'ml_based_copd_binarized'
GWAS_MRB_LABELS_COPD = 'mrb_labels_copd'
GWAS_SPIRO_GOLD_COPD = 'spiro_gold_copd'
GWAS_GBMI_EXCLUDING_UKB_COPD = 'gbmi_excluding_ukb_copd'
GWAS_SPIRO_GOLD_COPD_REGENIE = 'spiro_gold_copd_regenie'
GWAS_SPIRO_GOLD_COPD_BOLT = 'spiro_gold_copd_bolt'

# A mapping of GWAS IDs to GWAS axis label. Every GWAS used below must have an
# entry here; the insertion order defines the order of `GWAS_IDS`.
GWAS_ID_TO_GWAS_LABEL: dict[str, str] = {
    GWAS_ML_BASED_COPD: 'ML-based',
    GWAS_ML_BASED_COPD_NO_MRB_CASES: 'ML-based without MRB COPD',
    GWAS_HOBBS_NATGEN_2017: 'Hobbs et al. 2017 NG',
    GWAS_ML_BASED_COPD_BINARIZED_GOLD_PREV: (
        'ML-based Binarized GOLD Prevalence'
    ),
    GWAS_SAKORNSAKOLPAT_NATGEN_2019: 'Sakornsakolpat et al. 2019 NG',
    GWAS_ML_BASED_COPD_BINARIZED: 'ML-based Binarized',
    GWAS_MRB_LABELS_COPD: 'MRB',
    GWAS_SPIRO_GOLD_COPD: 'Proxy UKB GOLD',
    GWAS_GBMI_EXCLUDING_UKB_COPD: 'GBMI Excluding UKB',
    GWAS_SPIRO_GOLD_COPD_REGENIE: 'Proxy UKB GOLD (Regenie)',
    GWAS_SPIRO_GOLD_COPD_BOLT: 'Proxy UKB GOLD (BOLT-LMM)',
}

# The set of GWAS IDs, in the order they are listed in the label mapping.
GWAS_IDS = tuple(GWAS_ID_TO_GWAS_LABEL)

# A mapping of GWAS IDs to GWAS results. Constructing each `Gwas` loads its
# hits, loci, and filtered results from disk.
g_gwas_id_to_gwas: dict[str, Gwas] = {
    gwas_id: Gwas(gwas_id=gwas_id, gwas_label=gwas_label)
    for gwas_id, gwas_label in GWAS_ID_TO_GWAS_LABEL.items()
}

## Main Figures

Figure 4: ML-based COPD discovers 67 novel association loci.

In [None]:
# ML-based COPD is the main GWAS; the Sakornsakolpat et al. 2019 loci decide
# which ML-based loci count as replicated.
g_fig_4 = build_figure_4(
    main_gwas=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD],
    overlap_gwas=g_gwas_id_to_gwas[GWAS_SAKORNSAKOLPAT_NATGEN_2019],
)

# Write the figure to a PDF in the working directory.
g_fig_4.savefig(
    'figure_4.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
# `%download_file` is a notebook magic defined outside this cell; presumably it
# downloads the saved PDF to the local machine -- TODO confirm.
%download_file figure_4.pdf

## Extended Data Figures

Extended Data Figure 5: Statistical power comparison of ML-based COPD with Hobbs
et al. Nature Genetics 2017 COPD GWAS.

In [None]:
# x-axis: Hobbs et al. 2017; y-axis: ML-based COPD.
g_ext_data_fig_5 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_HOBBS_NATGEN_2017],
    gwas_y=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD],
)

# Write the figure to a PDF in the working directory.
g_ext_data_fig_5.savefig(
    'extended_data_figure_5.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file extended_data_figure_5.pdf

Extended Data Figure 6: Statistical power comparison of ML-based COPD without
MRB COPD cases with Hobbs et al. Nature Genetics 2017 COPD GWAS.

In [None]:
# x-axis: Hobbs et al. 2017; y-axis: ML-based COPD without MRB COPD cases.
g_ext_data_fig_6 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_HOBBS_NATGEN_2017],
    gwas_y=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD_NO_MRB_CASES],
)

# Write the figure to a PDF in the working directory.
g_ext_data_fig_6.savefig(
    'extended_data_figure_6.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file extended_data_figure_6.pdf

Extended Data Figure 7: Statistical power comparison of binarized ML-based COPD
with Sakornsakolpat et al. Nature Genetics 2019.

In [None]:
# x-axis: Sakornsakolpat et al. 2019; y-axis: binarized ML-based COPD.
g_ext_data_fig_7 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_SAKORNSAKOLPAT_NATGEN_2019],
    gwas_y=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD_BINARIZED],
)

# Write the figure to a PDF in the working directory.
g_ext_data_fig_7.savefig(
    'extended_data_figure_7.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file extended_data_figure_7.pdf

Extended Data Figure 8: Statistical power comparison of binarized ML-based COPD
matching GOLD prevalence with Sakornsakolpat et al. Nature Genetics 2019.

In [None]:
# x-axis: Sakornsakolpat et al. 2019; y-axis: ML-based COPD binarized to match
# GOLD prevalence.
g_ext_data_fig_8 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_SAKORNSAKOLPAT_NATGEN_2019],
    gwas_y=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD_BINARIZED_GOLD_PREV],
)

# Write the figure to a PDF in the working directory.
g_ext_data_fig_8.savefig(
    'extended_data_figure_8.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file extended_data_figure_8.pdf

## Supplementary Figures

Supplementary Figure 12: Statistical power comparison of ML-based COPD with GBMI
COPD excluding UKB.

In [None]:
# x-axis: GBMI COPD excluding UKB; y-axis: ML-based COPD.
g_suppl_fig_12 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_GBMI_EXCLUDING_UKB_COPD],
    gwas_y=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD],
)

# Write the figure to a PDF in the working directory.
g_suppl_fig_12.savefig(
    'supplementary_figure_12.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file supplementary_figure_12.pdf

Supplementary Figure 14: Statistical power comparison of binarized ML-based COPD
with medical-record-based COPD labels.

In [None]:
# x-axis: medical-record-based (MRB) COPD labels; y-axis: binarized ML-based
# COPD.
g_suppl_fig_14 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_MRB_LABELS_COPD],
    gwas_y=g_gwas_id_to_gwas[GWAS_ML_BASED_COPD_BINARIZED],
)

# Write the figure to a PDF in the working directory.
g_suppl_fig_14.savefig(
    'supplementary_figure_14.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',

)
%download_file supplementary_figure_14.pdf

Supplementary Figure 15: Statistical power comparison of proxy-GOLD with
Sakornsakolpat et al. Nature Genetics 2019.

In [None]:
# x-axis: Sakornsakolpat et al. 2019; y-axis: proxy UKB GOLD COPD.
g_suppl_fig_15 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_SAKORNSAKOLPAT_NATGEN_2019],
    gwas_y=g_gwas_id_to_gwas[GWAS_SPIRO_GOLD_COPD],
)

# Write the figure to a PDF in the working directory.
g_suppl_fig_15.savefig(
    'supplementary_figure_15.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file supplementary_figure_15.pdf

Supplementary Figure 17: Statistical power comparison of proxy-GOLD label using
BOLT-LMM vs Regenie.

In [None]:
# x-axis: proxy UKB GOLD COPD run with BOLT-LMM; y-axis: the same label run
# with Regenie.
g_suppl_fig_17 = build_gwas_comparison_figure(
    gwas_x=g_gwas_id_to_gwas[GWAS_SPIRO_GOLD_COPD_BOLT],
    gwas_y=g_gwas_id_to_gwas[GWAS_SPIRO_GOLD_COPD_REGENIE],
)

# Write the figure to a PDF in the working directory.
g_suppl_fig_17.savefig(
    'supplementary_figure_17.pdf',
    dpi=300,
    format='pdf',
    bbox_inches='tight',
)
%download_file supplementary_figure_17.pdf