# Preparation

This section includes imports and functions.

In [1]:
import dataclasses
from typing import Dict, List, Optional, Sequence, Union

import abc
from typing import Callable

import numpy as np
import pandas as pd
import scipy.stats
import sklearn
import sklearn.metrics
from sklearn import metrics

In [2]:
# Signature shared by the metric functions in this file: the label array is
# passed first, the prediction array second, and a single number is returned.
BootstrappableFn = Callable[[np.ndarray, np.ndarray], float]

# Values that encode the two classes of a binary label array: controls are `0`
# and cases are `1`.
BINARY_LABEL_CONTROL = 0
BINARY_LABEL_CASE = 1

class Metric(abc.ABC):
  """Abstract base that pairs a metric function with a display name.

  Instances are callable: a call first runs `_validate` on the inputs and then
  forwards them unchanged to the wrapped function. Concrete subclasses must
  override `_validate`; they may delegate to the shape check implemented here.

  Attributes:
    name: The metric's name.
  """

  def __init__(self, name: str, fn: BootstrappableFn) -> None:
    """Stores the metric's name and the function it wraps.

    Args:
      name: The metric's name.
      fn: The function evaluated when the instance is called. It is invoked as
        `fn(y_true, y_pred)` with the label array first and the prediction
        array second.
    """
    self._fn: BootstrappableFn = fn
    self._name: str = name

  @property
  def name(self) -> str:
    """The name given to this `Metric` at construction."""
    return self._name

  @abc.abstractmethod
  def _validate(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
    """Checks that labels and predictions have the same leading dimension.

    Note: The entry `y_pred[i, ...]` is expected to be the prediction for the
    label `y_true[i]`, hence only the first dimension is compared.

    Args:
      y_true: The ground truth label targets.
      y_pred: The target predictions.

    Raises:
      ValueError: If the first dimension of `y_true` and `y_pred` do not match.
    """
    num_true = y_true.shape[0]
    num_pred = y_pred.shape[0]
    if num_true == num_pred:
      return
    raise ValueError('`y_true` and `y_pred` first dimension mismatch: '
                     f'{num_true} != {num_pred}')

  def __call__(self, y_true: np.ndarray, y_pred: np.ndarray) -> float:
    """Validates the inputs and evaluates the wrapped function on them.

    Args:
      y_true: The ground truth label values.
      y_pred: The target predictions.

    Returns:
      The value returned by the wrapped function for `(y_true, y_pred)`.
    """
    # Validation runs first so that the wrapped function never sees inputs
    # rejected by the subclass.
    self._validate(y_true, y_pred)
    outcome = self._fn(y_true, y_pred)
    return outcome

  def __str__(self) -> str:
    """Returns the metric's name."""
    return self.name


class ContinuousMetric(Metric):
  """A `Metric` for functions evaluated on continuous labels.

  The only validation applied is the base-class check that labels and
  predictions share the same first dimension.

  Attributes:
    name: The metric's name.
  """

  # Note: The override is needed because `Metric._validate` is declared as an
  # @abc.abstractmethod; it only forwards to the base implementation.
  def _validate(  # pylint: disable=useless-super-delegation
      self, y_true: np.ndarray, y_pred: np.ndarray
  ) -> None:
    """Runs the base-class shape check on `y_true` and `y_pred`.

    Args:
      y_true: The ground truth label values.
      y_pred: The target predictions.

    Raises:
      ValueError: If the first dimension of `y_true` and `y_pred` do not match.
    """
    return super()._validate(y_true, y_pred)


class BinaryMetric(Metric):
  """A `Metric` for functions evaluated on binary labels.

  In addition to the base-class shape check, `y_true` must pass
  `is_valid_binary_label`: every label is one of the two binary class values
  and both classes occur at least once.

  Attributes:
    name: The metric's name.
  """

  def _validate(self, y_true: np.ndarray, y_pred: np.ndarray) -> None:
    """Runs the shape check and then the binary-label check.

    Args:
      y_true: The ground truth label values.
      y_pred: The target predictions.

    Raises:
      ValueError: If the first dimension of `y_true` and `y_pred` do not match.
      ValueError: If `is_valid_binary_label(y_true)` is false, i.e. some label
        is not a binary class value or one of the two classes is absent.
    """
    super()._validate(y_true, y_pred)
    if is_valid_binary_label(y_true):
      return
    # Only the last fragment is an f-string, so the constant names appear
    # verbatim in the message and only `y_true` is interpolated.
    message = ('`y_true` labels must be in `{BINARY_LABEL_CONTROL, '
               'BINARY_LABEL_CASE}` and have at least one element from '
               f'each class; found: {y_true}')
    raise ValueError(message)


def is_binary(metric: Metric) -> bool:
  """Whether `metric` is a metric computed with binary `y_true` labels.

  Args:
    metric: The metric to inspect.

  Returns:
    True if `metric` is an instance of `BinaryMetric`, False otherwise.
  """
  return isinstance(metric, BinaryMetric)


def is_valid_binary_label(array: np.ndarray) -> bool:
  """Checks that `array` is usable as a binary label array for bootstrapping.

  An array is considered valid when both conditions hold:
    * every element equals `BINARY_LABEL_CONTROL` or `BINARY_LABEL_CASE`;
    * each of the two classes occurs at least once.

  Args:
    array: A numpy array.

  Returns:
    Whether `array` is a "valid" binary label array.
  """
  case_mask = array == BINARY_LABEL_CASE
  control_mask = array == BINARY_LABEL_CONTROL
  has_case = np.any(case_mask)
  has_control = np.any(control_mask)
  # Any element that is neither a case nor a control makes the array invalid.
  only_binary_values = np.all(np.logical_or(case_mask, control_mask))
  return has_case and has_control and only_binary_values


def pearsonr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Returns the Pearson correlation coefficient of labels and predictions.

  Args:
    y_true: The ground truth label values.
    y_pred: The target predictions.

  Returns:
    The correlation statistic reported by `scipy.stats.pearsonr`.
  """
  # Note: Element 0 is the coefficient; the p value that follows is ignored.
  return scipy.stats.pearsonr(y_true, y_pred)[0]


def pearsonr_squared(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Returns the squared Pearson correlation of labels and predictions.

  Args:
    y_true: The ground truth label values.
    y_pred: The target predictions.

  Returns:
    The value of `pearsonr(y_true, y_pred)` raised to the second power.
  """
  coefficient = pearsonr(y_true, y_pred)
  return coefficient**2


def spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Returns the Spearman rank correlation of labels and predictions.

  Args:
    y_true: The ground truth label values.
    y_pred: The target predictions.

  Returns:
    The correlation statistic reported by `scipy.stats.spearmanr`.
  """
  # Note: Element 0 is the coefficient; the p value that follows is ignored.
  return scipy.stats.spearmanr(y_true, y_pred)[0]


def count(y_true: np.ndarray, y_pred: np.ndarray) -> float:
  """Returns the number of samples in `y_true`.

  Args:
    y_true: The ground truth label values.
    y_pred: The target predictions; only used to check the sample count.

  Returns:
    The length of `y_true`.

  Raises:
    ValueError: If the first dimension of `y_true` and `y_pred` do not match.
  """
  num_true, num_pred = y_true.shape[0], y_pred.shape[0]
  if num_true != num_pred:
    raise ValueError('`y_true` and `y_pred` first dimension mismatch: '
                     f'{num_true} != {num_pred}')
  return len(y_true)


def frequency_between(y_true: np.ndarray, y_pred: np.ndarray,
                      percentile_lower: int, percentile_upper: int) -> float:
  """Computes the positive class frequency within a percentile interval.

  The percentile bounds are converted to prediction thresholds with
  `np.percentile(y_pred, ...)`. Samples whose prediction lies at or above the
  lower threshold and strictly below the upper threshold are selected (the
  upper threshold is not applied when `percentile_upper` is 100), and the mean
  of the selected `y_true` values is returned.

  Note: If no prediction falls inside the interval (possible when many
  predictions are tied), the mean of the empty selection is NaN.

  Args:
    y_true: Ground truth (correct) target values.
    y_pred: Estimated targets as returned by a classifier.
    percentile_lower: The lower bound (inclusive) of percentile. 0 to include
      all samples.
    percentile_upper: The upper bound (inclusive for 100, exclusive for all
      other values) of percentile. 100 to include all samples.

  Returns:
    A [0.0, 1.0] float corresponding to the positive class frequency within
    the percentile interval.

  Raises:
    ValueError: If `percentile_lower` is not in `[0, 100)`, if
      `percentile_upper` is not in `(0, 100]`, if `percentile_lower` is not
      strictly less than `percentile_upper`, or if `y_true` and `y_pred` have
      different lengths.
  """
  if not 0 <= percentile_lower < 100:
    raise ValueError('`percentile_lower` must be in range `[0, 100)`: '
                     f'{percentile_lower}')
  if not 0 < percentile_upper <= 100:
    raise ValueError('`percentile_upper` must be in range `(0, 100]`: '
                     f'{percentile_upper}')
  # An empty or inverted interval selects no samples and would silently
  # produce NaN, so it is rejected as an invalid range.
  if percentile_lower >= percentile_upper:
    raise ValueError('`percentile_lower` must be less than '
                     f'`percentile_upper`: {percentile_lower} >= '
                     f'{percentile_upper}')
  # Explicit check instead of `assert`, which is stripped when running with -O.
  if len(y_true) != len(y_pred):
    raise ValueError('`y_true` and `y_pred` first dimension mismatch: '
                     f'{len(y_true)} != {len(y_pred)}')

  pred_lower_percentile, pred_upper_percentile = np.percentile(
      a=y_pred, q=[percentile_lower, percentile_upper])
  mask = (y_pred >= pred_lower_percentile)
  if percentile_upper != 100:
    # The upper bound is exclusive unless it is the 100th percentile, in which
    # case the maximum prediction must stay included.
    mask = mask & (y_pred < pred_upper_percentile)
  return np.mean(y_true[mask])


def frequency(y_true: np.ndarray,
              y_pred: np.ndarray,
              top_percentile: int = 100) -> float:
  """Computes the positive class frequency among the top-ranked predictions.

  The samples whose predictions fall in the top `top_percentile` percent of
  `y_pred` are selected and the positive class frequency of their `y_true`
  labels is returned. The work is delegated to `frequency_between` with the
  interval `[100 - top_percentile, 100]`; `top_percentile=100` therefore
  covers all samples.

  Args:
    y_true: Ground truth (correct) target values.
    y_pred: Estimated targets as returned by a classifier.
    top_percentile: The size, in percent, of the top prediction range used for
      the frequency calculation. 100 indicates using all samples.

  Returns:
    A [0.0, 1.0] float corresponding to the positive class frequency in the top
    percentile.

  Raises:
    ValueError: `top_percentile` is not in range `(0, 100]`.
  """
  in_range = 0 < top_percentile <= 100
  if not in_range:
    raise ValueError('`top_percentile` must be in range `(0, 100]`: '
                     f'{top_percentile}')

  lower_bound = 100 - top_percentile
  return frequency_between(
      y_true, y_pred, percentile_lower=lower_bound, percentile_upper=100)


def frequency_fn(top_percentile: int) -> BootstrappableFn:
  """Returns a function that computes `frequency` at `top_percentile`.

  Args:
    top_percentile: The value passed as `top_percentile` to `frequency` on
      every call of the returned function.

  Returns:
    A two-argument `(y_true, y_pred)` function suitable for wrapping in a
    `Metric`.
  """

  # The closure binds `top_percentile` so the result takes only the two arrays.
  def _frequency(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return frequency(y_true, y_pred, top_percentile)

  return _frequency


def frequency_between_fn(percentile_lower: int,
                         percentile_upper: int) -> BootstrappableFn:
  """Returns a function that computes `frequency` in a percentile interval.

  Args:
    percentile_lower: The value passed as `percentile_lower` to
      `frequency_between` on every call of the returned function.
    percentile_upper: The value passed as `percentile_upper` to
      `frequency_between` on every call of the returned function.

  Returns:
    A two-argument `(y_true, y_pred)` function suitable for wrapping in a
    `Metric`.
  """

  # The closure binds both bounds so the result takes only the two arrays.
  def _freq_between(y_true: np.ndarray, y_pred: np.ndarray) -> float:
    return frequency_between(
        y_true,
        y_pred,
        percentile_lower=percentile_lower,
        percentile_upper=percentile_upper)

  return _freq_between

In [3]:
# Type alias for one bootstrap sample: a numpy array of indices that is used
# to index label and prediction arrays.
IndexSample = np.ndarray


@dataclasses.dataclass(frozen=True, eq=False, order=False)
class NamedArray:
  """An immutable pairing of a numpy array with a non-empty name.

  Attributes:
    name: The array name.
    values: A numpy array.

  Raises:
    ValueError: If `name` is empty.
  """

  name: str
  values: np.ndarray

  def __post_init__(self):
    """Rejects instances created without a name."""
    if self.name:
      return
    raise ValueError('`name` must be specified.')

  def __len__(self) -> int:
    """Returns the length of `values`."""
    return len(self.values)

  def __str__(self) -> str:
    """Returns `<class name>(<name>)`, using the concrete class name."""
    class_name = type(self).__name__
    return f'{class_name}({self.name})'


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class Label(NamedArray):
  """Represents a named numpy array of ground truth label targets.

  This subclass adds no fields or behavior to `NamedArray`; it exists so that
  label arrays have a type distinct from prediction arrays.

  Attributes:
    name: The label name.
    values: A numpy array containing ground truth label targets.
  """


@dataclasses.dataclass(frozen=True, eq=False, order=False)
class Prediction(NamedArray):
  """A named numpy array of predictions tagged with the producing model.

  Note: The inherited fields come first, so positional construction is
  `Prediction(name, values, model_name)`.

  Attributes:
    name: The name of the predictions (e.g., the prediction column).
    values: A numpy array containing model predictions.
    model_name: The name of the model that generated the predictions.

  Raises:
    ValueError: If `name` or `model_name` is empty.
  """

  model_name: str

  def __post_init__(self):
    """Applies the `NamedArray` checks and requires a model name."""
    super().__post_init__()
    if self.model_name:
      return
    raise ValueError('`model_name` must be specified.')

  def __str__(self) -> str:
    """Returns `<class name>(<model_name>.<name>)`."""
    qualified_name = f'{self.model_name}.{self.name}'
    return f'{type(self).__name__}({qualified_name})'


@dataclasses.dataclass(frozen=True, eq=False, order=False)
class SampleMean:
  """An estimate of a population mean derived from a sample.

  Attributes:
    mean: The mean of a given sample.
    stddev: The standard deviation of the sample mean.
    num_samples: The number of samples used to calculate `mean` and `stddev`.

  Raises:
    ValueError: If `num_samples` is not >= `1`.
    ValueError: If `stddev` is not `0` when `num_samples` is `1`.
  """

  mean: float
  stddev: float
  num_samples: int

  def __post_init__(self):
    """Validates the sample count and its consistency with `stddev`."""
    if self.num_samples < 1:
      raise ValueError(f'`num_samples` must be >= `1`: {self.num_samples}')

    # A single sample has no spread, so any nonzero deviation is inconsistent.
    is_single_sample = self.num_samples == 1
    if is_single_sample and self.stddev != 0.0:
      raise ValueError(
          f'`stddev` must be `0` if `num_samples` is `1`: {self.stddev:0.4f}'
      )

  def __str__(self) -> str:
    """Returns `<mean> (SD=<stddev>, n=<num_samples>)` with 4 decimals."""
    details = f'SD={self.stddev:0.4f}, n={self.num_samples}'
    return f'{self.mean:0.4f} ({details})'


@dataclasses.dataclass(frozen=True, eq=False, order=False)
class ConfidenceInterval(SampleMean):
  """A `SampleMean` extended with a confidence interval (CI).

  Attributes:
    mean: The mean of a given sample.
    stddev: The standard deviation of the sample mean.
    num_samples: The number of samples used to calculate `mean` and `stddev`.
    level: The confidence level at which the CI is calculated (e.g., 95).
    ci_lower: The lower limit of the `level` confidence interval.
    ci_upper: The upper limit of the `level` confidence interval.

  Raises:
    ValueError: If `num_samples` is not >= `1`.
    ValueError: If `stddev` is not `0` when `num_samples` is `1`.
    ValueError: If `level` is not in range (0, 100].
    ValueError: If `ci_lower` or `ci_upper` does not match `mean` when
      `num_samples` is `1`.
  """

  level: float
  ci_lower: float
  ci_upper: float

  def __post_init__(self):
    """Applies the `SampleMean` checks, then validates the CI fields."""
    super().__post_init__()
    if not 0 < self.level <= 100:
      raise ValueError(f'`level` must be in range (0, 100]: {self.level:0.2f}')

    if self.num_samples != 1:
      return
    # With a single sample the interval must collapse onto the mean.
    lower_differs = self.ci_lower != self.mean
    if lower_differs or (self.ci_upper != self.mean):
      raise ValueError(
          '`ci_lower` and `ci_upper` must match `mean` if `num_samples` is '
          f'1: mean={self.mean:0.4f}, ci_lower={self.ci_lower:0.4f}, '
          f'ci_upper={self.ci_upper:0.4f}'
      )

  def __str__(self) -> str:
    """Returns the mean, SD, sample count, and CI as a one-line summary."""
    summary = (
        f'{self.mean:0.4f} (SD={self.stddev:0.4f}, n={self.num_samples}, '
    )
    interval = (
        f'{self.level:0>6.2f}% CI=[{self.ci_lower:0.4f}, '
        f'{self.ci_upper:0.4f}])'
    )
    return summary + interval


@dataclasses.dataclass(frozen=True, eq=False, order=False)
class Result:
  """A bootstrapped metric result for a single model.

  Attributes:
    model_name: The model's name.
    prediction_name: The model's prediction name (e.g., the model head's name or
      the label name used in training).
    metric_name: The metric's name.
    ci: A confidence interval describing the distribution of metric samples.

  Raises:
    ValueError: If `model_name`, `prediction_name`, or `metric_name` is empty.
  """

  model_name: str
  prediction_name: str
  metric_name: str
  ci: ConfidenceInterval

  def __post_init__(self):
    """Requires every name field to be non-empty."""
    # Checked in declaration order so the first missing name is reported.
    for field_name in ('model_name', 'prediction_name', 'metric_name'):
      if not getattr(self, field_name):
        raise ValueError(f'`{field_name}` must be specified.')

  def __str__(self) -> str:
    """Returns `<model>.<prediction>: <metric>: <ci>`."""
    source = f'{self.model_name}.{self.prediction_name}'
    return f'{source}: {self.metric_name}: {self.ci}'


@dataclasses.dataclass(frozen=True, eq=False, order=False)
class PairedResult:
  """A paired bootstrapped metric result comparing two models.

  Attributes:
    model_name_a: The first model's name.
    prediction_name_a: The first model's prediction name (e.g., the model head's
      name or the label name used in training).
    model_name_b: The second model's name.
    prediction_name_b: The second model's prediction name (e.g., the model
      head's name or the label name used in training).
    metric_name: The metric's name.
    ci: A confidence interval describing the distribution of differences between
      the first and second models' metric samples.

  Raises:
    ValueError: If any of the name fields is empty.
  """

  model_name_a: str
  prediction_name_a: str
  model_name_b: str
  prediction_name_b: str
  metric_name: str
  ci: ConfidenceInterval

  def __post_init__(self):
    """Requires every name field to be non-empty."""
    # Checked in declaration order so the first missing name is reported.
    required_fields = (
        'model_name_a',
        'prediction_name_a',
        'model_name_b',
        'prediction_name_b',
        'metric_name',
    )
    for field_name in required_fields:
      if not getattr(self, field_name):
        raise ValueError(f'`{field_name}` must be specified.')

  def __str__(self) -> str:
    """Returns `(<a> - <b>): <metric>: <ci>` with model-qualified names."""
    source_a = f'{self.model_name_a}.{self.prediction_name_a}'
    source_b = f'{self.model_name_b}.{self.prediction_name_b}'
    return f'({source_a} - {source_b}): {self.metric_name}: {self.ci}'


def _reverse_paired_result(paired_result: PairedResult) -> PairedResult:
  """Derives the "(b - a)" result from an "(a - b)" `PairedResult`.

  Args:
    paired_result: The "(a - b)" result to reverse.

  Returns:
    A new `PairedResult` with the two models swapped, the mean negated, and
    the negated interval limits exchanged. The standard deviation, sample
    count, confidence level, and metric name are carried over unchanged.
  """
  forward_ci = paired_result.ci
  # Negating the differences flips the interval, so the negated upper limit
  # becomes the new lower limit and vice versa.
  backward_ci = ConfidenceInterval(
      mean=-forward_ci.mean,
      stddev=forward_ci.stddev,
      num_samples=forward_ci.num_samples,
      level=forward_ci.level,
      ci_lower=-forward_ci.ci_upper,
      ci_upper=-forward_ci.ci_lower,
  )
  return PairedResult(
      model_name_a=paired_result.model_name_b,
      prediction_name_a=paired_result.prediction_name_b,
      model_name_b=paired_result.model_name_a,
      prediction_name_b=paired_result.prediction_name_a,
      metric_name=paired_result.metric_name,
      ci=backward_ci,
  )


def _compute_confidence_interval(
    samples: np.ndarray,
    ci_level: float,
) -> ConfidenceInterval:
  """Summarizes samples as a mean, standard deviation, and percentile CI.

  The interval is the central `ci_level` percent of the samples: its limits
  are the `(100 - ci_level) / 2` and `100 - (100 - ci_level) / 2` percentiles.

  Args:
    samples: A bootstrapped array of observed sample values.
    ci_level: The confidence level/width of the desired confidence interval.

  Returns:
    A `ConfidenceInterval` containing the mean, standard deviation, and the
    `ci_level`% confidence interval for the observed sample values.
  """
  mean = np.mean(samples, axis=0)
  stddev = np.std(samples, axis=0)

  # Split the excluded probability mass evenly between the two tails.
  tail = (100 - ci_level) / 2
  ci_lower, ci_upper = np.percentile(a=samples, q=[tail, 100 - tail], axis=0)

  return ConfidenceInterval(
      mean=mean,
      stddev=stddev,
      num_samples=len(samples),
      level=ci_level,
      ci_lower=ci_lower,
      ci_upper=ci_upper,
  )


def _generate_sample_indices(
    label: Label,
    is_binary: bool,
    num_bootstrap: int,
    seed: int,
) -> List[IndexSample]:
  """Returns a list of `num_bootstrap` randomly sampled bootstrap indices.

  Each index sample has as many entries as `label` and is drawn uniformly with
  replacement from `[0, len(label))`. When `is_binary` is set, draws whose
  resampled labels fail `is_valid_binary_label` are discarded and redrawn.

  Args:
    label: The ground truth label targets.
    is_binary: Whether to generate valid binary samples; i.e., each index sample
      contains at least one index corresponding to a label from each class.
    num_bootstrap: The number of bootstrap indices to generate.
    seed: The random seed; set prior to generating bootstrap indices.

  Returns:
    A list of `num_bootstrap` bootstrap sample indices.

  Raises:
    ValueError: If `is_binary` is set, `num_bootstrap` is positive, and `label`
      lacks a case or a control, so that no valid binary sample exists.
  """
  rng = np.random.default_rng(seed)
  num_observations = len(label)

  # Without at least one case and one control in `label`, every resample is
  # rejected and the loop below would never terminate; fail fast instead.
  if is_binary and num_bootstrap > 0:
    has_case = np.any(label.values == BINARY_LABEL_CASE)
    has_control = np.any(label.values == BINARY_LABEL_CONTROL)
    if not (has_case and has_control):
      raise ValueError(
          'Cannot generate valid binary bootstrap samples: `label` must '
          f'contain at least one case and one control: {label}'
      )

  sample_indices = []
  while len(sample_indices) < num_bootstrap:
    index = rng.integers(0, high=num_observations, size=num_observations)
    sample_true = label.values[index]
    # If computing a binary metric, skip indices that result in invalid labels.
    if is_binary and not is_valid_binary_label(sample_true):
      continue
    sample_indices.append(index)
  return sample_indices


def _compute_metric_samples(
    metric: Metric,
    label: Label,
    predictions: Sequence[Prediction],
    sample_indices: Sequence[np.ndarray],
) -> Dict[str, np.ndarray]:
  """Evaluates `metric` on every bootstrap sample of every `Prediction`.

  Note: Labels and predictions are assumed to be aligned, i.e. the value at
  index `i` of a `Prediction` belongs to the label value at index `i`. Each
  entry of `sample_indices` is applied to both arrays.

  Args:
    metric: An instance of a bootstrappable `Metric`; used to compute samples.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    sample_indices: An array of bootstrap sample indices. If empty, returns the
      single value computed on the entire dataset for each prediction.

  Returns:
    A mapping of model names to the corresponding metric samples array.
  """
  if not sample_indices:
    # No bootstrapping requested: one value per model on the full dataset.
    return {
        prediction.model_name: np.asarray(
            [metric(label.values, prediction.values)]
        )
        for prediction in predictions
    }

  samples_by_model = {prediction.model_name: [] for prediction in predictions}
  for index in sample_indices:
    # The resampled labels are shared by all predictions for this draw.
    resampled_labels = label.values[index]
    for prediction in predictions:
      samples_by_model[prediction.model_name].append(
          metric(resampled_labels, prediction.values[index])
      )

  return {
      model_name: np.asarray(values)
      for model_name, values in samples_by_model.items()
  }


def _compute_all_metric_samples(
    metrics: Sequence[Metric],
    contains_binary_metric: bool,
    label: Label,
    predictions: Sequence[Prediction],
    num_bootstrap: int,
    seed: int,
) -> Dict[str, Dict[str, np.ndarray]]:
  """Computes bootstrap samples of every `Metric` for every `Prediction`.

  A single set of bootstrap indices is generated and reused for all metrics.

  Args:
    metrics: A sequence of a bootstrappable `Metric` instances.
    contains_binary_metric: Whether the set of metrics contains a binary metric.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    num_bootstrap: The number of bootstrap iterations.
    seed: The random seed; set prior to generating bootstrap indices.

  Returns:
    A mapping of metric names to model-sample dictionaries.
  """
  shared_indices = _generate_sample_indices(
      label, contains_binary_metric, num_bootstrap, seed
  )
  return {
      metric.name: _compute_metric_samples(
          metric=metric,
          label=label,
          predictions=predictions,
          sample_indices=shared_indices,
      )
      for metric in metrics
  }


def _process_metric_samples(
    metric: Metric,
    predictions: Sequence[Prediction],
    model_names_to_metric_samples: Dict[str, np.ndarray],
    ci_level: float,
) -> List[Result]:
  """Builds one `Result` per prediction from its metric samples.

  Args:
    metric: The metric the samples were computed with; supplies the name.
    predictions: The predictions to report on, in output order.
    model_names_to_metric_samples: The metric samples keyed by model name.
    ci_level: The confidence level/width of the desired confidence interval.

  Returns:
    A list with one `Result` for each entry of `predictions`.
  """
  return [
      Result(
          model_name=prediction.model_name,
          prediction_name=prediction.name,
          metric_name=metric.name,
          ci=_compute_confidence_interval(
              model_names_to_metric_samples[prediction.model_name], ci_level
          ),
      )
      for prediction in predictions
  ]


def _process_metric_samples_paired(
    metric: Metric,
    predictions: Sequence[Prediction],
    model_names_to_metric_samples: Dict[str, np.ndarray],
    ci_level: float,
) -> List[PairedResult]:
  """Builds `PairedResult`s for every pair of predictions.

  For each unordered pair `(a, b)` the confidence interval of the
  element-wise sample difference `a - b` is computed, and the reversed
  `b - a` result is derived from it.

  Args:
    metric: The metric the samples were computed with; supplies the name.
    predictions: The predictions to compare.
    model_names_to_metric_samples: The metric samples keyed by model name.
    ci_level: The confidence level/width of the desired confidence interval.

  Returns:
    A list holding, for each pair, the `a - b` result followed directly by the
    `b - a` result.
  """
  paired_results = []
  num_predictions = len(predictions)
  for first in range(num_predictions - 1):
    prediction_a = predictions[first]
    samples_a = model_names_to_metric_samples[prediction_a.model_name]
    for second in range(first + 1, num_predictions):
      prediction_b = predictions[second]
      samples_b = model_names_to_metric_samples[prediction_b.model_name]
      forward = PairedResult(
          model_name_a=prediction_a.model_name,
          prediction_name_a=prediction_a.name,
          model_name_b=prediction_b.model_name,
          prediction_name_b=prediction_b.name,
          metric_name=metric.name,
          ci=_compute_confidence_interval(samples_a - samples_b, ci_level),
      )
      # The reverse comparison follows from the forward one without
      # recomputing percentiles.
      paired_results.extend([forward, _reverse_paired_result(forward)])
  return paired_results


def _bootstrap(
    metrics: Sequence[Metric],
    contains_binary_metric: bool,
    label: Label,
    predictions: Sequence[Prediction],
    num_bootstrap: int,
    ci_level: float,
    seed: int,
) -> Dict[str, List[Result]]:
  """Bootstraps every metric for every model and summarizes the samples.

  Args:
    metrics: A sequence of a bootstrappable `Metric` instances.
    contains_binary_metric: Whether the set of metrics contains a binary metric.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    num_bootstrap: The number of bootstrap iterations.
    ci_level: The confidence level/width of the desired confidence interval.
    seed: The random seed; set prior to generating bootstrap indices.

  Returns:
    A dictionary mapping metric names to a list of `Result`s containing the mean
    metric values of each model over `num_bootstrap` bootstrapping iterations.
  """
  samples_by_metric = _compute_all_metric_samples(
      metrics, contains_binary_metric, label, predictions, num_bootstrap, seed
  )
  return {
      metric.name: _process_metric_samples(
          metric=metric,
          predictions=predictions,
          model_names_to_metric_samples=samples_by_metric[metric.name],
          ci_level=ci_level,
      )
      for metric in metrics
  }


def _paired_bootstrap(
    metrics: Sequence[Metric],
    contains_binary_metric: bool,
    label: Label,
    predictions: Sequence[Prediction],
    num_bootstrap: int,
    ci_level: float,
    seed: int,
) -> Dict[str, List[PairedResult]]:
  """Bootstraps every metric and summarizes pairwise model differences.

  Args:
    metrics: A sequence of a bootstrappable `Metric` instances.
    contains_binary_metric: Whether the set of metrics contains a binary metric.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    num_bootstrap: The number of bootstrap iterations.
    ci_level: The confidence level/width of the desired confidence interval.
    seed: The random seed; set prior to generating bootstrap indices.

  Returns:
    A dictionary mapping metric names to `PairedResult`s containing the mean
    metric difference between models over `num_bootstrap` bootstrapping
    iterations.
  """
  samples_by_metric = _compute_all_metric_samples(
      metrics, contains_binary_metric, label, predictions, num_bootstrap, seed
  )
  return {
      metric.name: _process_metric_samples_paired(
          metric=metric,
          predictions=predictions,
          model_names_to_metric_samples=samples_by_metric[metric.name],
          ci_level=ci_level,
      )
      for metric in metrics
  }


def _default_binary_metrics() -> List[BinaryMetric]:
  """Returns `PerformanceMetrics`'s default metrics for binary target.

  The list holds the sample count ('num'), ROC AUC ('auc'), average precision
  ('auprc'), and the positive class frequency in the top 100, 10, 5, and 1
  percent of predictions ('freq@100%', 'freq@010%', 'freq@005%', 'freq@001%').
  """
  base_metrics = [
      BinaryMetric('num', count),
      BinaryMetric('auc', sklearn.metrics.roc_auc_score),
      BinaryMetric('auprc', sklearn.metrics.average_precision_score),
  ]
  frequency_metrics = [
      BinaryMetric(f'freq@{percentile:>03}%', frequency_fn(percentile))
      for percentile in (100, 10, 5, 1)
  ]
  return base_metrics + frequency_metrics


def _default_continuous_metrics() -> List[ContinuousMetric]:
  """Returns `PerformanceMetrics`'s default metrics for continuous target.

  The list holds the sample count, the Pearson correlation and its square, the
  Spearman correlation, the mean squared error, and the mean absolute error.
  """
  return [
      ContinuousMetric('num', count),
      ContinuousMetric('pearson', pearsonr),
      ContinuousMetric('pearsonr_squared', pearsonr_squared),
      ContinuousMetric('spearman', spearmanr),
      ContinuousMetric('mse', sklearn.metrics.mean_squared_error),
      ContinuousMetric('mae', sklearn.metrics.mean_absolute_error),
  ]


def _default_metrics(binary_targets: bool) -> List[Metric]:
  """Selects `PerformanceMetrics`'s default metrics for the target type.

  Args:
    binary_targets: Whether the target labels are binary. If false, the returned
      metrics assume continuous labels.

  Returns:
    The result of `_default_binary_metrics()` when `binary_targets` is truthy,
    otherwise the result of `_default_continuous_metrics()`.
  """
  return (
      _default_binary_metrics()
      if binary_targets
      else _default_continuous_metrics()
  )


class PerformanceMetrics:
  """A named collection of invocable, bootstrappable `Metric`s.

  Applies the given `Metric` functions to ground truth labels and predictions.
  `Metric`s can be evaluated with and without bootstrapping.

  When no explicit `metrics` are given, the metric set is chosen by
  `default_metrics`:
    * 'binary': number of samples, auc, auprc, and frequency calculations for
      the top 100/10/5/1 percentiles.
    * 'continuous': number of samples, Pearson and Spearman correlations, the
      square of the Pearson correlation, mean squared error (MSE) and mean
      absolute error (MAE).

  TODO(b/199452239): Refactor `PerformanceMetrics` so that the default metric
  set is not parameterized with a string.

  Attributes:
    name: The collection's name; printed as a prefix by the `*_and_print`
      methods.
    metrics: The list of `Metric`s evaluated by this instance.
    contains_binary: Whether `is_binary` holds for at least one metric.

  Raises:
    ValueError: if an item in `metrics` is not of type `Metric`, if metric
      names are not unique, or if neither `metrics` nor a known
      `default_metrics` value is given.
  """

  def __init__(
      self,
      name: str,
      default_metrics: Optional[str] = None,
      metrics: Optional[List[Metric]] = None,
  ) -> None:
    """Initializes the metric collection.

    Args:
      name: The collection's name.
      default_metrics: Either 'binary' or 'continuous'; only consulted when
        `metrics` is None.
      metrics: The `Metric`s to evaluate. When given, `default_metrics` is
        ignored.

    Raises:
      ValueError: If `metrics` is None and `default_metrics` is None or
        unknown, if any item is not a `Metric`, or if two metrics share a
        name.
    """
    if metrics is None:
      if default_metrics is None:
        raise ValueError('`default_metrics` is None and no metric is provided.')
      elif default_metrics == 'binary':
        metrics = _default_metrics(binary_targets=True)
      elif default_metrics == 'continuous':
        metrics = _default_metrics(binary_targets=False)
      else:
        raise ValueError(f'unknown `default_metrics`: {default_metrics}')
    else:
      # Copy so that later mutation of the caller's list (e.g. a shared
      # module-level list) cannot invalidate the checks below or the cached
      # `contains_binary` flag.
      metrics = list(metrics)

    for metric in metrics:
      if not isinstance(metric, Metric):
        raise ValueError('Invalid metric value: must be of class `Metric`.')

    if len(metrics) != len({metric.name for metric in metrics}):
      raise ValueError(f'Metric names must be unique: {metrics}')

    self.name = name
    self.metrics = metrics
    self.contains_binary = any(is_binary(m) for m in metrics)

  def compute(
      self,
      y_true: np.ndarray,
      y_pred: np.ndarray,
      mask: Optional[np.ndarray] = None,
      n_bootstrap: int = 0,
      conf_interval: float = 95,
      seed: int = 42,
  ) -> Dict[str, Result]:
    """Evaluates all metrics using the given labels and predictions.

    Args:
      y_true: Ground truth (correct) target values.
      y_pred: Estimated targets as returned by a classifier.
      mask: A boolean mask; applied to `y_true` and `y_pred`.
      n_bootstrap: An integer denoting the number of bootstrap iterations for
        each evaluation metric.
      conf_interval: A float denoting the confidence level of the interval
        (default 95); passed to the bootstrap as `ci_level`.
      seed: An int denoting the seed for the PRNG.

    Returns:
      A dictionary of bootstrapped metrics keyed on metric name with
      `Result` values.

    Raises:
      ValueError: If the dimensions of `y_true`, `y_pred`, or `mask` do not
      match, or labels are not in {0 , 1}.
    """
    if len(y_true) != len(y_pred):
      raise ValueError('Label and prediction dimensions do not match.')

    if mask is not None:
      if len(mask) != len(y_pred):
        raise ValueError('Label and prediction dimensions do not match mask.')
      y_true = y_true[mask]
      y_pred = y_pred[mask]

    # TODO(b/197539434): Pipe through non-empty names after public api refactor.
    label_name = 'label'
    label = Label(label_name, y_true)
    predictions = [Prediction(label_name, y_pred, 'model')]

    metric_results = _bootstrap(
        self.metrics,
        contains_binary_metric=self.contains_binary,
        label=label,
        predictions=predictions,
        num_bootstrap=n_bootstrap,
        ci_level=conf_interval,
        seed=seed,
    )

    # A single prediction was submitted, so each metric maps to exactly one
    # result; unwrap it.
    # TODO(b/197539434): Remove temporary asserts after public api refactor.
    final_results = {}
    for metric_name, results in metric_results.items():
      assert len(results) == 1
      final_results[metric_name] = results[0]

    return final_results

  def compute_paired(
      self,
      y_true: np.ndarray,
      y_pred_a: np.ndarray,
      y_pred_b: np.ndarray,
      mask: Optional[np.ndarray] = None,
      n_bootstrap: int = 0,
      conf_interval: float = 95,
      seed: int = 42,
  ) -> Dict[str, PairedResult]:
    """Computes a paired bootstrap value for each metric.

    Args:
      y_true: Ground truth (correct) target values.
      y_pred_a: Target predictions from model A; compared to `y_pred_b`.
      y_pred_b: Target predictions from model B; compared to `y_pred_a`.
      mask: A boolean mask; applied to `y_true`, `y_pred_a`, and `y_pred_b`.
      n_bootstrap: An integer denoting the number of bootstrap iterations for
        each evaluation metric.
      conf_interval: A float denoting the confidence level of the interval
        (default 95); passed to the bootstrap as `ci_level`.
      seed: An int denoting the seed for the PRNG.

    Returns:
      A dictionary of paired bootstrapped metrics keyed on metric name with
      `PairedResult` values. For each metric, the returned result is the one
      whose `model_name_a` is model A.

    Raises:
      ValueError: If the dimensions of `y_true`, `y_pred_a`, `y_pred_b` or
      `mask` do not match, or labels are not in {0 , 1}.
    """
    if (len(y_true) != len(y_pred_a)) or (len(y_true) != len(y_pred_b)):
      raise ValueError('Label and prediction dimensions do not match.')

    if mask is not None:
      if len(mask) != len(y_pred_a):
        raise ValueError('Label and prediction dimensions do not match mask.')
      y_true = y_true[mask]
      y_pred_a = y_pred_a[mask]
      y_pred_b = y_pred_b[mask]

    # TODO(b/197539434): Pipe through non-empty names after public api refactor.
    label_name = 'label'
    label = Label(label_name, y_true)
    first_model_name = 'model_a'
    predictions = [
        Prediction(label_name, y_pred_a, first_model_name),
        Prediction(label_name, y_pred_b, 'model_b'),
    ]

    metric_results = _paired_bootstrap(
        self.metrics,
        contains_binary_metric=self.contains_binary,
        label=label,
        predictions=predictions,
        num_bootstrap=n_bootstrap,
        ci_level=conf_interval,
        seed=seed,
    )

    # Two results are expected per metric; keep only the first, which must be
    # the one with model A as `model_name_a`.
    # TODO(b/197539434): Remove temporary asserts after public api refactor.
    final_results = {}
    for metric_name, results in metric_results.items():
      assert len(results) == 2
      assert results[0].model_name_a == first_model_name
      final_results[metric_name] = results[0]

    return final_results

  def _print_results(
      self,
      title: str,
      results: Dict[str, Union[Result, PairedResult]],
  ) -> None:
    """Prints each result object under the current name and given title.

    Args:
      title: Text printed after this collection's name on the header line.
      results: Results keyed on metric name; printed one per line, tab
        indented, in sorted order of the dictionary items.
    """
    print(f'{self.name}: {title}')
    for _, result in sorted(results.items()):
      print(f'\t{result}')

  def compute_and_print(
      self,
      y_true: np.ndarray,
      y_pred: np.ndarray,
      mask: Optional[np.ndarray] = None,
      n_bootstrap: int = 0,
      conf_interval: float = 95,
      seed: int = 42,
      title: str = '',
  ) -> None:
    """Evaluates and pretty-prints metrics using given labels and predictions.

    Args:
      y_true: Ground truth (correct) target values.
      y_pred: Estimated targets as returned by a classifier.
      mask: A boolean mask; applied to `y_true` and `y_pred`.
      n_bootstrap: An integer denoting the number of bootstrap iterations for
        each evaluation metric.
      conf_interval: A float denoting the confidence level of the interval
        (default 95); passed to the bootstrap as `ci_level`.
      seed: An int denoting the seed for the PRNG.
      title: A title appended to the printed evaluation metrics.

    Raises:
      ValueError: If the dimensions of `y_true`, `y_pred`, or `mask` do not
          match (raised by `compute`).
    """
    results = self.compute(
        y_true,
        y_pred,
        mask=mask,
        n_bootstrap=n_bootstrap,
        conf_interval=conf_interval,
        seed=seed,
    )
    self._print_results(title, results)

  def compute_paired_and_print(
      self,
      y_true: np.ndarray,
      y_pred_a: np.ndarray,
      y_pred_b: np.ndarray,
      mask: Optional[np.ndarray] = None,
      n_bootstrap: int = 0,
      conf_interval: float = 95,
      seed: int = 42,
      title: str = '',
      **kwargs,
  ) -> None:
    """Evaluates and pretty-prints paired metrics.

    Args:
      y_true: Ground truth (correct) target values.
      y_pred_a: Target predictions from model A; compared to `y_pred_b`.
      y_pred_b: Target predictions from model B; compared to `y_pred_a`.
      mask: A boolean mask; applied to `y_true`, `y_pred_a`, and `y_pred_b`.
      n_bootstrap: An integer denoting the number of bootstrap iterations for
        each evaluation metric.
      conf_interval: A float denoting the confidence level of the interval
        (default 95); passed to the bootstrap as `ci_level`.
      seed: An int denoting the seed for the PRNG.
      title: A title appended to the printed evaluation metrics.
      **kwargs: Additional keyword arguments forwarded verbatim to
        `compute_paired`. `compute_paired` as defined on this class accepts
        no extra keyword arguments, so passing any raises `TypeError`; the
        parameter is only useful to subclasses that extend `compute_paired`.

    Raises:
      ValueError: If the dimensions of `y_true`, `y_pred_a`, `y_pred_b` or
          `mask` do not match (raised by `compute_paired`).
    """
    results = self.compute_paired(
        y_true,
        y_pred_a,
        y_pred_b,
        mask=mask,
        n_bootstrap=n_bootstrap,
        conf_interval=conf_interval,
        seed=seed,
        **kwargs,
    )
    self._print_results(title, results)

In [4]:
# Number of bootstrap iterations requested by the PRS evaluation helpers.
N_BOOTSTRAP = 300
# Metrics evaluated for every PRS. The names given here are the keys used to
# look up results in `get_prs_eval_info` and `get_prs_paired_eval_info`.
# Note: `metrics` is the `sklearn.metrics` module imported at the top of the
# file.
BOOTSTRAP_METRICS_LIST = [
    BinaryMetric('roc_auc', metrics.roc_auc_score),
    BinaryMetric('pr_auc', metrics.average_precision_score),
    ContinuousMetric('pearsonr', pearsonr),
    BinaryMetric('top10prev', frequency_fn(10)),
]

# Metric names reported by the PRS evaluation helpers, in output column order.
# Each must be the name of a metric in `BOOTSTRAP_METRICS_LIST`.
_PRS_REPORTED_METRICS = ('pearsonr', 'roc_auc', 'pr_auc', 'top10prev')


def _flatten_ci_columns(metric_results) -> Dict[str, object]:
  """Flattens each reported metric's `.ci` into four flat columns.

  Args:
    metric_results: A mapping from metric name to a result object exposing a
      `.ci` attribute with `mean`, `stddev`, `ci_lower` and `ci_upper`.

  Returns:
    A dict with, for every name in `_PRS_REPORTED_METRICS` (in that order),
    the keys `<name>`, `<name>_std`, `<name>_lower` and `<name>_upper`.
  """
  columns = {}
  for metric_name in _PRS_REPORTED_METRICS:
    ci = metric_results[metric_name].ci
    columns[metric_name] = ci.mean
    columns[f'{metric_name}_std'] = ci.stddev
    columns[f'{metric_name}_lower'] = ci.ci_lower
    columns[f'{metric_name}_upper'] = ci.ci_upper
  return columns


def get_prs_eval_info(
    y_true, y_pred, name: str, as_dataframe: bool = False
) -> Union[pd.DataFrame, Dict[str, object]]:
  """Computes bootstrapped evaluation metrics for a single PRS.

  Evaluates `BOOTSTRAP_METRICS_LIST` with `N_BOOTSTRAP` bootstrap iterations
  via `PerformanceMetrics.compute`, using that method's default confidence
  level and seed.

  Args:
    y_true: Ground truth target values.
    y_pred: Predicted scores, aligned with `y_true`.
    name: Label stored under the 'method' key of the output.
    as_dataframe: If True, return a one-row `pd.DataFrame` instead of a dict.

  Returns:
    A dict (or one-row DataFrame) with 'method' followed by the mean, standard
    deviation and lower/upper confidence bounds of each reported metric.
  """
  performance_metrics = PerformanceMetrics(
      'Metrics', metrics=BOOTSTRAP_METRICS_LIST)
  metric_results = performance_metrics.compute(
      y_true=y_true,
      y_pred=y_pred,
      n_bootstrap=N_BOOTSTRAP,
  )
  info = {'method': name}
  info.update(_flatten_ci_columns(metric_results))
  if as_dataframe:
    return pd.DataFrame(info, index=[0])
  return info


def get_prs_paired_eval_info(
    y_true,
    y_pred1,
    y_pred2,
    name1: str,
    name2: str,
    as_dataframe: bool = False,
) -> Union[pd.DataFrame, Dict[str, object]]:
  """Computes paired-bootstrap comparison metrics for two PRSs.

  Evaluates `BOOTSTRAP_METRICS_LIST` with `N_BOOTSTRAP` bootstrap iterations
  via `PerformanceMetrics.compute_paired`, with `y_pred1` as model A and
  `y_pred2` as model B, using that method's default confidence level and seed.

  Args:
    y_true: Ground truth target values.
    y_pred1: Predicted scores of the first PRS, aligned with `y_true`.
    y_pred2: Predicted scores of the second PRS, aligned with `y_true`.
    name1: Label stored under the 'method_a' key of the output.
    name2: Label stored under the 'method_b' key of the output.
    as_dataframe: If True, return a one-row `pd.DataFrame` instead of a dict.

  Returns:
    A dict (or one-row DataFrame) with 'method_a' and 'method_b' followed by
    the mean, standard deviation and lower/upper confidence bounds of each
    reported paired metric.
  """
  performance_metrics = PerformanceMetrics(
      'Metrics', metrics=BOOTSTRAP_METRICS_LIST)
  paired_results = performance_metrics.compute_paired(
      y_true=y_true,
      y_pred_a=y_pred1,
      y_pred_b=y_pred2,
      n_bootstrap=N_BOOTSTRAP,
  )
  info = {'method_a': name1, 'method_b': name2}
  info.update(_flatten_ci_columns(paired_results))
  if as_dataframe:
    return pd.DataFrame(info, index=[0])
  return info

# Simulated data generation

In this example, we generate simulated data (N=1,000) to demonstrate how to use the code above to compute the metrics used in the PRS evaluation part of the paper.

In [5]:
# Seed NumPy's global RNG so the simulated data is reproducible. The draws
# below consume the RNG in order, so their sequence must not be rearranged.
np.random.seed(42)
# First score: 1,000 draws from a standard normal distribution.
individual_prs1 = np.random.normal(size=(1000,))
# Second score: 0.8 x the first score plus 0.2 x fresh standard-normal noise.
individual_prs2 = 0.8 * individual_prs1 + 0.2 * np.random.normal(size=(1000,))
# Continuous value: 0.3 x the first score plus 0.7 x fresh standard-normal
# noise.
individual_phenotype = 0.3 * individual_prs1 + 0.7 * np.random.normal(
    size=(1000,)
)
# Binarize into the phenotype label: 1 where the value is >= 0, otherwise 0.
individual_phenotype = (individual_phenotype >= 0).astype(int)

# Collect both scores and the binary phenotype into one table, one row per
# simulated individual.
data_df = pd.DataFrame({
    'prs1': individual_prs1,
    'prs2': individual_prs2,
    'phenotype': individual_phenotype,
})

In [6]:
# Preview the first rows of the simulated dataset.
data_df.head()

Unnamed: 0,prs1,prs2,phenotype
0,0.496714,0.677242,0
1,-0.138264,0.074315,0
2,0.647689,0.530077,0
3,1.52303,1.089037,1
4,-0.234153,-0.047678,0


# PRS evaluation with bootstrapping

The following code computes all evaluation metrics, namely Pearson R, AUC-ROC, AUC-PR, and top 10% prevalence, together with their 95% confidence intervals obtained by bootstrapping. Note that, given how the simulated data were generated, we expect a Pearson R of roughly 0.3 for `prs1`, and we expect `prs1` to correlate more strongly with the phenotype than `prs2`.

In [7]:
# Bootstrapped metrics for `prs1` against the binary phenotype, returned as a
# one-row DataFrame.
get_prs_eval_info(
    y_true=data_df['phenotype'],
    y_pred=data_df['prs1'],
    name='prs1',
    as_dataframe=True
)

Unnamed: 0,method,pearsonr,pearsonr_std,pearsonr_lower,pearsonr_upper,roc_auc,roc_auc_std,roc_auc_lower,roc_auc_upper,pr_auc,pr_auc_std,pr_auc_lower,pr_auc_upper,top10prev,top10prev_std,top10prev_lower,top10prev_upper
0,prs1,0.333455,0.027456,0.277529,0.387433,0.69263,0.016445,0.65976,0.725288,0.675271,0.022152,0.632141,0.715912,0.770216,0.043321,0.688044,0.85078


In [8]:
# Bootstrapped metrics for `prs2` against the binary phenotype, returned as a
# one-row DataFrame.
get_prs_eval_info(
    y_true=data_df['phenotype'],
    y_pred=data_df['prs2'],
    name='prs2',
    as_dataframe=True
)

Unnamed: 0,method,pearsonr,pearsonr_std,pearsonr_lower,pearsonr_upper,roc_auc,roc_auc_std,roc_auc_lower,roc_auc_upper,pr_auc,pr_auc_std,pr_auc_lower,pr_auc_upper,top10prev,top10prev_std,top10prev_lower,top10prev_upper
0,prs2,0.319189,0.027899,0.260433,0.373947,0.6837,0.016604,0.649911,0.717019,0.664467,0.022454,0.620486,0.706022,0.764624,0.042396,0.671552,0.84


# PRS comparison with paired bootstrapping

The following code snippet compares the performance of `prs1` and `prs2` using paired bootstrapping. Note that the difference is statistically significant at the 95% level if the lower and upper bounds of the paired-bootstrap confidence interval are both positive (implying `prs1` is significantly better than `prs2`) or both negative (implying `prs2` is significantly better than `prs1`).

In [9]:
# Paired-bootstrap comparison with `prs1` as method A and `prs2` as method B,
# returned as a one-row DataFrame.
get_prs_paired_eval_info(
    y_true=data_df['phenotype'],
    y_pred1=data_df['prs1'],
    y_pred2=data_df['prs2'],
    name1='prs1',
    name2='prs2',
    as_dataframe=True)

Unnamed: 0,method_a,method_b,pearsonr,pearsonr_std,pearsonr_lower,pearsonr_upper,roc_auc,roc_auc_std,roc_auc_lower,roc_auc_upper,pr_auc,pr_auc_std,pr_auc_lower,pr_auc_upper,top10prev,top10prev_std,top10prev_lower,top10prev_upper
0,prs1,prs2,0.014266,0.007112,0.000436,0.027211,0.008931,0.004466,0.000157,0.017171,0.010803,0.005761,-0.00061,0.02107,0.005593,0.026971,-0.042589,0.062382
