In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generates Figure 3 from Cosentino et al. Nature Genetics 2023

This notebook builds the ML model performance comparison and survival analysis
Kaplan-Meier curve subfigures for Figure 3 of the ML-based COPD manuscript
(Cosentino et al. Nature Genetics 2023).

It compares the following models:

-   FEV1/FVC ratio-based risk
-   FEV1 percent predicted-based risk
-   ML-based COPD risk

It evaluates COPD risk predictions against the following labels:

-   `copd_all_srcs_subset`: A subset of `copd` labels where all EIDs have valid
    values from self-report, HESIN, and GP sources.
-   `copd_hesin_primary_after`: Binary label indicating a COPD-related
    hospitalization event after spirometry data collection.
-   `copd_death_primary`: Binary label indicating a COPD-related death.

This notebook assumes that there exists a TSV file containing risk predictions
and labels stored at `DATA_FILEPATH`. This TSV must contain the following
columns:

-   `eid`: A unique individual identifier.
-   `copd_all_srcs_subset`: A subset of `copd` labels where all EIDs have valid
    values from self-report, HESIN, and GP sources.
-   `copd_hesin_primary_after`: Binary label indicating a COPD-related
    hospitalization event after spirometry data collection.
-   `copd_death_primary`: Binary label indicating a COPD-related death (i.e., 0
    if no death or death due to non-COPD causes).
-   `blow_ratio_risk`: A risk score based on FEV1/FVC ratio (i.e., `1 -
    blow_ratio`).
-   `blow_fev1_pct_pred_norm_risk`: A risk score based on FEV1 percent predicted
    (i.e., `1 - normalized(blow_fev1_pct_pred)`).
-   `ml_based_copd`: The ML-based COPD liability score.

**Important: We assume that `DATA_FILEPATH` contains *only* individuals from the
validation holdout set.**

This notebook assumes that there exists a TSV file at `KM_ML_DATA_FILEPATH`
containing raw Kaplan-Meier curve datapoints. This file can be generated by
running the `survival_analysis.R` R script.

In [None]:
import collections
import concurrent.futures
import dataclasses
import enum
import pathlib
import string
from typing import AbstractSet, Callable, Dict, List, Mapping, NamedTuple, Optional, Set, Sequence, Tuple, Type, Union

import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib import transforms
import mpl_toolkits.axes_grid1.inset_locator as inset_locator
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics

In [None]:
def set_matplotib_settings():
  """Applies the seaborn theme and matplotlib rcParams shared by all figures."""
  style_overrides = {
      'axes.grid': True,
      'font.family': ['Helvetica'],
      'text.usetex': True,
      'legend.frameon': False,
  }
  sns.set_palette('deep')
  sns.set_style('ticks', style_overrides)

  # Figure export and base font settings.
  for param_name, param_value in (
      ('savefig.dpi', 300),
      ('savefig.transparent', False),
      ('font.size', 7),
  ):
    rcParams[param_name] = param_value


set_matplotib_settings()

## Bootstrapping and plotting utilities

In [None]:
# Constants denoting the expected case and control values for binary encodings.
# `is_valid_binary_label` compares label arrays against these two values.
BINARY_LABEL_CONTROL = 0
BINARY_LABEL_CASE = 1

# Represents a numpy array of indices for a single bootstrap sample. The indices
# are used to index into the label and prediction value arrays.
IndexSample = np.ndarray


class CurveType(enum.Enum):
  """Denotes the type of a performance curve.

  Members:
    ROC: A receiver operating characteristic curve.
    PR: A precision-recall curve.
  """

  ROC = 'roc'
  PR = 'precision_recall'


class CurveFnResult(NamedTuple):
  """Represents a single performance curve sample.

  `_bootstrap_curve` passes `value_array_x` and `value_array_y` to `np.interp`
  as the known sample points of the curve, so the two arrays must have the same
  length.

  Attributes:
    curve_type: The curve's type.
    value_array_x: The curve's x coordinates.
    value_array_y: The curve's y coordinates.
    threshold_array: The thresholds corresponding to each x-y coordinate.
  """

  curve_type: CurveType
  value_array_x: np.ndarray
  value_array_y: np.ndarray
  threshold_array: np.ndarray


class CurveBootstrapResult(NamedTuple):
  """Represents a bootstrapped curve result.

  NOTE: `ci_upper` is declared before `ci_lower`. `_bootstrap_curve` constructs
  this tuple positionally, so the field order must not be changed without
  updating that call site.

  Attributes:
    curve_type: The curve's type.
    mean: The mean of the curve's y-coordinate values at each point in `interp`.
    stddev: The standard deviation of the curve's y-coordinate values at each
      point in `interp`.
    num_samples: The number of bootstrap curve samples.
    ci_level: The confidence level at which the CI is calculated (e.g., 95).
    ci_upper: The upper limit of the `ci_level` confidence interval.
    ci_lower: The lower limit of the `ci_level` confidence interval.
    interp: The interpolated x-coordinates.
  """

  curve_type: CurveType
  mean: np.ndarray
  stddev: np.ndarray
  num_samples: int
  ci_level: float
  ci_upper: np.ndarray
  ci_lower: np.ndarray
  interp: np.ndarray


# Denotes a bootstrappable function used to compute performance curves. Takes
# the ground truth and predicted value arrays (in that order) and returns a
# `CurveFnResult` holding the curve type, two coordinate arrays, and one
# threshold array.
CurveFn = Callable[[np.ndarray, np.ndarray], CurveFnResult]


class RocCurve(NamedTuple):
  """Container for a ROC curve.

  Note that the y-coordinates (`tpr`) are declared before the x-coordinates
  (`fpr`).

  Attributes:
    tpr: The true positive rates; plotted on the y axis by `plot_roc_curves`.
    fpr: The false positive rates; plotted on the x axis by `plot_roc_curves`.
    err_lower: Optional lower bound of the error band around `tpr`.
    err_upper: Optional upper bound of the error band around `tpr`.
  """

  tpr: np.ndarray
  fpr: np.ndarray
  err_lower: Optional[np.ndarray] = None
  err_upper: Optional[np.ndarray] = None


class PrecisionRecallCurve(NamedTuple):
  """Container for a precision-recall curve.

  Note that the y-coordinates (`precision`) are declared before the
  x-coordinates (`recall`).

  Attributes:
    precision: The precision values; plotted on the y axis by `plot_pr_curves`.
    recall: The recall values; plotted on the x axis by `plot_pr_curves`.
    err_lower: Optional lower bound of the error band around `precision`.
    err_upper: Optional upper bound of the error band around `precision`.
  """

  precision: np.ndarray
  recall: np.ndarray
  err_lower: Optional[np.ndarray] = None
  err_upper: Optional[np.ndarray] = None


class KmCurve(NamedTuple):
  """Container for a Kaplan-Meier curve.

  `df_to_km_curves` fills these fields from the `time`, `prob`, `se`, `lower`,
  and `upper` columns of the Kaplan-Meier TSV, respectively.

  Attributes:
    time: The time points; plotted on the x axis by `plot_km_curves`.
    prob: The survival probabilities; plotted on the y axis by
      `plot_km_curves`.
    std_err: The standard error values read from the `se` column.
    err_lower: The lower bound of the error band around `prob`.
    err_upper: The upper bound of the error band around `prob`.
  """

  time: np.ndarray
  prob: np.ndarray
  std_err: np.ndarray
  err_lower: np.ndarray
  err_upper: np.ndarray


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class NamedArray:
  """A numpy array paired with a non-empty name.

  Attributes:
    name: The array name; must be non-empty.
    values: A numpy array.

  Raises:
    ValueError: If `name` is empty.
  """

  name: str
  values: np.ndarray

  def __post_init__(self):
    """Rejects instances constructed with an empty `name`."""
    if self.name:
      return
    raise ValueError('`name` must be specified.')

  def __len__(self) -> int:
    """Returns the length of the wrapped array."""
    return len(self.values)

  def __str__(self) -> str:
    """Returns `<ClassName>(<name>)`."""
    class_name = type(self).__name__
    return f'{class_name}({self.name})'


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class Label(NamedArray):
  """Represents a named numpy array of ground truth label targets.

  Adds no fields or methods to `NamedArray`; it gives ground truth arrays a
  type that is distinct from `Prediction`.

  Attributes:
    name: The label name.
    values: A numpy array containing ground truth label targets.
  """


@dataclasses.dataclass(eq=False, order=False, frozen=True)
class Prediction(NamedArray):
  """A named numpy array of predictions produced by a specific model.

  The positional field order is `(name, values, model_name)`, since
  `model_name` is declared after the fields inherited from `NamedArray`.

  Attributes:
    model_name: The name of the model that generated the predictions.
    name: The name of the predictions (e.g., the prediction column).
    values: A numpy array containing model predictions.

  Raises:
    ValueError: If `name` or `model_name` is empty.
  """

  model_name: str

  def __post_init__(self):
    """Validates the inherited `name` first, then `model_name`."""
    super().__post_init__()
    if self.model_name:
      return
    raise ValueError('`model_name` must be specified.')

  def __str__(self) -> str:
    """Returns `<ClassName>(<model_name>.<name>)`."""
    qualified_name = f'{self.model_name}.{self.name}'
    return f'{type(self).__name__}({qualified_name})'


def bootstrap_result_to_roc(
    result: CurveBootstrapResult,
) -> RocCurve:
  """Builds a `RocCurve` from a bootstrapped curve result.

  The interpolation grid becomes the false positive rate (x), the mean curve
  becomes the true positive rate (y), and the confidence interval limits become
  the error band.
  """
  return RocCurve(
      fpr=result.interp,
      tpr=result.mean,
      err_upper=result.ci_upper,
      err_lower=result.ci_lower,
  )


def bootstrap_result_to_pr(
    result: CurveBootstrapResult,
) -> PrecisionRecallCurve:
  """Builds a `PrecisionRecallCurve` from a bootstrapped curve result.

  The interpolation grid becomes the recall (x), the mean curve becomes the
  precision (y), and the confidence interval limits become the error band.
  """
  return PrecisionRecallCurve(
      recall=result.interp,
      precision=result.mean,
      err_upper=result.ci_upper,
      err_lower=result.ci_lower,
  )


def df_to_km_curves(km_data_filepath: pathlib.Path) -> Dict[str, KmCurve]:
  """Loads a Kaplan-Meier TSV and returns one `KmCurve` per `group` value.

  Args:
    km_data_filepath: Path to a tab-separated file whose columns are exactly
      `group`, `time`, `prob`, `se`, `lower`, and `upper`.

  Returns:
    A mapping of each unique `group` value to the `KmCurve` built from that
    group's rows. The curve fields hold the group's pandas column slices.

  Raises:
    ValueError: If the file has missing or extra columns.
  """
  with open(str(km_data_filepath), mode='r') as km_file:
    km_df = pd.read_csv(km_file, sep='\t')

  # Require an exact column match: both missing and extra columns are errors.
  expected_cols = {'group', 'time', 'prob', 'se', 'lower', 'upper'}
  col_diff = set(km_df.columns) ^ expected_cols
  if col_diff:
    raise ValueError(f'Unexpected KM data columns: {col_diff}')

  def _rows_to_curve(rows: pd.DataFrame) -> KmCurve:
    return KmCurve(
        time=rows['time'],
        prob=rows['prob'],
        std_err=rows['se'],
        err_lower=rows['lower'],
        err_upper=rows['upper'],
    )

  return {
      group: _rows_to_curve(km_df[km_df['group'] == group].copy())
      for group in km_df['group'].unique()
  }


def l2_distance(
    point_a: Tuple[float, float],
    point_b: Tuple[float, float],
) -> float:
  """Returns the Euclidean (L2) distance between `point_a` and `point_b`."""
  delta = np.asarray(point_a) - np.asarray(point_b)
  return np.linalg.norm(delta)


def plot_roc_curves(
    curves: Mapping[str, RocCurve],
    label_overrides: Optional[Mapping[str, str]] = None,
    color_overrides: Optional[Mapping[str, str]] = None,
    line_width: float = 2,
    xlabel: str = 'False Positive Rate',
    ylabel: str = 'True Positive Rate',
    title: str = '',
    plot_x_eq_y: bool = True,
    plot_marker: bool = True,
    plot_legend: bool = True,
    inset_xlim: Optional[Tuple[int, int]] = None,
    inset_ylim: Optional[Tuple[int, int]] = None,
    inset_zoom: float = 2,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Draws ROC curves, with optional error bands, markers, and zoomed inset.

  Args:
    curves: A mapping of curve identifiers to `RocCurve` values.
    label_overrides: An optional mapping of curve identifiers to the legend
      label used for that curve. Identifiers must match keys of `curves`;
      curves without an override are labeled with their identifier.
    color_overrides: An optional mapping of curve identifiers to the color used
      for that curve. Identifiers must match keys of `curves`; curves without
      an override use the color assigned by matplotlib.
    line_width: The line width of the curves.
    xlabel: The label for the x axis.
    ylabel: The label for the y axis.
    title: The subplot's title.
    plot_x_eq_y: Whether to draw a dashed x=y reference line.
    plot_marker: Whether to draw a marker on the point of each curve that is
      closest to `[0, 1]`.
    plot_legend: Whether to draw a legend to the right of the axis.
    inset_xlim: An optional tuple denoting the inset's x limits; if not
      specified, no inset is added to the axis.
    inset_ylim: An optional tuple denoting the inset's y limits; if not
      specified, no inset is added to the axis.
    inset_zoom: The inset's zoom level; no inset is added unless it is positive.
    ax: An optional axis on which to plot the curves. If not specified, a new
      axis is created.

  Returns:
    The axis on which curves were plotted.
  """
  label_overrides = {} if label_overrides is None else label_overrides
  color_overrides = {} if color_overrides is None else color_overrides
  ax = plt.axes() if ax is None else ax

  ax_inset = None
  if inset_zoom > 0 and inset_xlim and inset_ylim:
    ax_inset = inset_locator.zoomed_inset_axes(
        ax, inset_zoom, loc='lower right'
    )
    ax_inset.set_xlim(*inset_xlim)
    ax_inset.set_ylim(*inset_ylim)
    inset_locator.mark_inset(
        ax, ax_inset, loc1=1, loc2=3, fc='none', ec='0.5', linestyle='--'
    )

  # Error bands and markers are mirrored onto the inset when one exists.
  target_axes = [ax, ax_inset] if ax_inset else [ax]

  for curve_id, roc in curves.items():
    legend_label = label_overrides.get(curve_id, curve_id)
    requested_color = color_overrides.get(curve_id, None)
    lines = ax.plot(
        roc.fpr,
        roc.tpr,
        color=requested_color,
        lw=line_width,
        label=legend_label,
    )

    # Without an override, reuse the color matplotlib picked for the main
    # curve so the band, marker, and inset all match it.
    assert len(lines) == 1
    color = requested_color if requested_color else lines[0].get_color()

    if ax_inset:
      ax_inset.plot(
          roc.fpr, roc.tpr, color=color, lw=line_width, label=legend_label
      )

    if roc.err_lower is not None and roc.err_upper is not None:
      for target in target_axes:
        target.fill_between(
            roc.fpr, roc.err_lower, roc.err_upper, color=color, alpha=0.2
        )

    if plot_marker:
      # Mark the point closest to the ideal corner `[0, 1]`.
      points = list(zip(roc.fpr, roc.tpr))
      distances = [l2_distance((0, 1), point) for point in points]
      marker_x, marker_y = points[np.argmin(distances)]
      for target in target_axes:
        target.plot(marker_x, marker_y, marker='o', color=color, markersize=8)

  axis_limits = [-0.05, 1.05]
  if plot_x_eq_y:
    ax.plot(axis_limits, axis_limits, color='gray', lw=1, linestyle='--')

  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.set_xlim(axis_limits)
  ax.set_ylim(axis_limits)
  if plot_legend:
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
  if ax_inset:
    ax_inset.set_xticklabels([])
    ax_inset.set_yticklabels([])

  return ax


def plot_pr_curves(
    curves: Mapping[str, PrecisionRecallCurve],
    label_overrides: Optional[Mapping[str, str]] = None,
    color_overrides: Optional[Mapping[str, str]] = None,
    line_width: float = 2,
    xlabel: str = 'Recall',
    ylabel: str = 'Precision',
    title: str = '',
    plot_marker: bool = True,
    plot_legend: bool = True,
    inset_xlim: Optional[Tuple[int, int]] = None,
    inset_ylim: Optional[Tuple[int, int]] = None,
    inset_zoom: float = 2,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Draws precision-recall curves, with optional bands, markers, and inset.

  Args:
    curves: A mapping of curve identifiers to `PrecisionRecallCurve` values.
    label_overrides: An optional mapping of curve identifiers to the legend
      label used for that curve. Identifiers must match keys of `curves`;
      curves without an override are labeled with their identifier.
    color_overrides: An optional mapping of curve identifiers to the color used
      for that curve. Identifiers must match keys of `curves`; curves without
      an override use the color assigned by matplotlib.
    line_width: The line width of the curves.
    xlabel: The label for the x axis.
    ylabel: The label for the y axis.
    title: The subplot's title.
    plot_marker: Whether to draw a marker on the point of each curve that is
      closest to `[1, 1]` (i.e., perfect recall and precision).
    plot_legend: Whether to draw a legend to the right of the axis.
    inset_xlim: An optional tuple denoting the inset's x limits; if not
      specified, no inset is added to the axis.
    inset_ylim: An optional tuple denoting the inset's y limits; if not
      specified, no inset is added to the axis.
    inset_zoom: The inset's zoom level; no inset is added unless it is positive.
    ax: An optional axis on which to plot the curves. If not specified, a new
      axis is created.

  Returns:
    The axis on which curves were plotted.
  """
  label_overrides = {} if label_overrides is None else label_overrides
  color_overrides = {} if color_overrides is None else color_overrides
  ax = plt.axes() if ax is None else ax

  ax_inset = None
  if inset_zoom > 0 and inset_xlim and inset_ylim:
    ax_inset = inset_locator.zoomed_inset_axes(
        ax, inset_zoom, loc='upper right'
    )
    ax_inset.set_xlim(*inset_xlim)
    ax_inset.set_ylim(*inset_ylim)
    inset_locator.mark_inset(
        ax, ax_inset, loc1=2, loc2=4, fc='none', ec='0.5', linestyle='--'
    )

  # Error bands and markers are mirrored onto the inset when one exists.
  target_axes = [ax, ax_inset] if ax_inset else [ax]

  for curve_id, pr in curves.items():
    legend_label = label_overrides.get(curve_id, curve_id)
    requested_color = color_overrides.get(curve_id, None)
    lines = ax.plot(
        pr.recall,
        pr.precision,
        color=requested_color,
        lw=line_width,
        label=legend_label,
    )

    # Without an override, reuse the color matplotlib picked for the main
    # curve so the band, marker, and inset all match it.
    assert len(lines) == 1
    color = requested_color if requested_color else lines[0].get_color()

    if ax_inset:
      ax_inset.plot(
          pr.recall, pr.precision, color=color, lw=line_width, label=legend_label
      )

    if pr.err_lower is not None and pr.err_upper is not None:
      for target in target_axes:
        target.fill_between(
            pr.recall, pr.err_lower, pr.err_upper, color=color, alpha=0.2
        )

    if plot_marker:
      # Mark the point closest to the ideal corner `[1, 1]`.
      points = list(zip(pr.recall, pr.precision))
      distances = [l2_distance((1, 1), point) for point in points]
      marker_x, marker_y = points[np.argmin(distances)]
      for target in target_axes:
        target.plot(marker_x, marker_y, marker='o', color=color, markersize=8)

  axis_limits = [-0.05, 1.05]
  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)
  ax.set_xlim(axis_limits)
  ax.set_ylim(axis_limits)
  if plot_legend:
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
  if ax_inset:
    ax_inset.set_xticklabels([])
    ax_inset.set_yticklabels([])

  return ax


def plot_km_curves(
    curves: Mapping[str, KmCurve],
    label_overrides: Optional[Mapping[str, str]] = None,
    color_overrides: Optional[Mapping[str, str]] = None,
    line_width: float = 2,
    xlabel: str = 'Time (days)',
    ylabel: str = 'Survival probability',
    title: str = '',
    plot_legend: bool = True,
    legend_title: Optional[str] = None,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Plots Kaplan-Meier curves on the given axis.

  Args:
    curves: A mapping of curve identifiers to `KmCurve` values.
    label_overrides: An optional mapping of curve identifiers to label overrides
      used when plotting the given curve. The curve identifier must match a
      curve identifier in `curves`.
    color_overrides: An optional mapping of curve identifiers to color overrides
      used when plotting the given curve. The curve identifier must match a
      curve identifier in `curves`.
    line_width: A float denoting the line width of curves.
    xlabel: The label for the x axis.
    ylabel: The label for the y axis.
    title: The subplot's title.
    plot_legend: Whether to plot a legend above the axis.
    legend_title: An optional title for the legend; rendered inline as the
      first legend entry.
    ax: An optional axis on which to plot the curves. If not specified, a new
      axis is created.

  Returns:
    The axis on which curves were plotted.
  """
  if label_overrides is None:
    label_overrides = {}
  if color_overrides is None:
    color_overrides = {}

  if ax is None:
    ax = plt.axes()

  for curve_name, curve in curves.items():
    curve_label = label_overrides.get(curve_name, curve_name)
    curve_color = color_overrides.get(curve_name, None)
    curve_plot = ax.plot(
        curve.time,
        curve.prob,
        color=curve_color,
        lw=line_width,
        label=curve_label,
    )

    # If no curve color was specified, use the color set by matplotlib.
    assert len(curve_plot) == 1
    curve_color = curve_color if curve_color else curve_plot[0].get_color()

    if curve.err_lower is not None and curve.err_upper is not None:
      ax.fill_between(
          curve.time,
          curve.err_lower,
          curve.err_upper,
          color=curve_color,
          alpha=0.2,
      )

  ax.set_title(title)
  ax.set_xlabel(xlabel)
  ax.set_ylabel(ylabel)

  if plot_legend:
    if legend_title is None:
      # If no legend title is given, just plot base labels.
      ax.legend(
          bbox_to_anchor=(0, 1.02, 1, 1),
          mode='expand',
          loc='lower center',
          borderaxespad=0,
          frameon=False,
          ncol=len(curves),
      )
    else:
      # matplotlib doesn't give us an easy way to include a legend title inline
      # with markers, so we create an empty placeholder. The placeholder is
      # created on `ax` rather than through `plt.plot`, which would draw on
      # whichever axes pyplot considers current (possibly creating a new
      # figure). An explicit color keeps it from consuming a color-cycle entry,
      # and it has no label, so it is excluded from the handles gathered below.
      title_handle = ax.plot([], [], marker='', ls='', color='none')[0]
      handles, labels = ax.get_legend_handles_labels()
      ax.legend(
          bbox_to_anchor=(-0.03, 1.03, 1, 1),
          mode='expand',
          loc='lower center',
          borderaxespad=0,
          frameon=False,
          labels=[legend_title] + labels,
          handles=[title_handle] + handles,
          ncol=len(curves) + 1,
      )

  return ax


def bs_curves_to_fig_curves(
    label_to_type_to_id_to_bs: Mapping[
        str,
        Mapping[
            CurveType,
            Mapping[
                str,
                CurveBootstrapResult,
            ],
        ],
    ],
) -> Mapping[
    str,
    Mapping[
        Union[Type[RocCurve], Type[PrecisionRecallCurve]],
        Mapping[
            str,
            Union[
                RocCurve,
                PrecisionRecallCurve,
            ],
        ],
    ],
]:
  """Converts nested bootstrap results into the equivalent figure curves.

  Args:
    label_to_type_to_id_to_bs: A mapping of label column -> `CurveType` ->
      model ID -> `CurveBootstrapResult`.

  Returns:
    A nested `defaultdict` of label column -> figure curve class (`RocCurve` or
    `PrecisionRecallCurve`) -> model ID -> figure curve.

  Raises:
    NotImplementedError: If a curve type is neither ROC nor PR.
  """
  # Maps each supported curve type to its figure class and converter.
  converters = {
      CurveType.ROC: (RocCurve, bootstrap_result_to_roc),
      CurveType.PR: (PrecisionRecallCurve, bootstrap_result_to_pr),
  }
  fig_curves = collections.defaultdict(
      lambda: collections.defaultdict(dict)
  )
  for label_col, type_to_results in label_to_type_to_id_to_bs.items():
    for curve_type, id_to_result in type_to_results.items():
      for model_id, bs_result in id_to_result.items():
        if curve_type not in converters:
          raise NotImplementedError(curve_type)
        curve_cls, convert = converters[curve_type]
        fig_curves[label_col][curve_cls][model_id] = convert(bs_result)
  return fig_curves


def is_valid_binary_label(array: np.ndarray) -> bool:
  """Checks that `array` is usable as a binary label array for bootstrapping.

  An array is valid when every element equals either `BINARY_LABEL_CONTROL` or
  `BINARY_LABEL_CASE` and both classes occur at least once.

  Args:
    array: A numpy array.

  Returns:
    Whether `array` is a "valid" binary label array.
  """
  cases = array == BINARY_LABEL_CASE
  controls = array == BINARY_LABEL_CONTROL
  has_case = np.any(cases)
  has_control = np.any(controls)
  only_binary_values = np.all(cases | controls)
  return has_case and has_control and only_binary_values


def _roc_curve(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> CurveFnResult:
  """Computes a ROC curve with x = false positive rate, y = true positive rate."""
  fpr, tpr, thresholds = sklearn.metrics.roc_curve(y_true, y_pred)
  return CurveFnResult(
      curve_type=CurveType.ROC,
      value_array_x=fpr,
      value_array_y=tpr,
      threshold_array=thresholds,
  )


def _precision_recall_curve(
    y_true: np.ndarray,
    y_pred: np.ndarray,
) -> CurveFnResult:
  """Computes a precision-recall curve with x = recall and y = precision.

  `_bootstrap_curve` interpolates `value_array_y` over `value_array_x` with
  `np.interp`, and `bootstrap_result_to_pr` treats the interpolation grid as
  recall and the interpolated values as precision. The result must therefore
  carry recall as x and precision as y.

  Args:
    y_true: The ground truth binary labels.
    y_pred: The predicted scores.

  Returns:
    A `CurveFnResult` whose coordinates are ordered by increasing recall. The
    thresholds are reversed alongside the coordinates; the first coordinate
    (recall 0, precision 1) has no threshold, so `threshold_array` is one
    element shorter and `threshold_array[i]` pairs with coordinate `i + 1`.
  """
  precision, recall, thresholds = sklearn.metrics.precision_recall_curve(
      y_true,
      y_pred,
  )
  # scikit-learn returns recall in decreasing order, but `np.interp` requires
  # increasing x-coordinates, so every array is reversed.
  return CurveFnResult(
      curve_type=CurveType.PR,
      value_array_x=recall[::-1],
      value_array_y=precision[::-1],
      threshold_array=thresholds[::-1],
  )


def _bootstrap_curve(
    curve_type: CurveType,
    curve_fn: CurveFn,
    label: Label,
    predictions: Sequence[Prediction],
    sample_indices: Sequence[np.ndarray],
    ci_level: float,
    n_interp: int,
) -> Dict[str, CurveBootstrapResult]:
  """Generates a bootstrapped curve for each prediction.

  Each curve sample is interpolated onto `n_interp` evenly spaced x-coordinates
  in [0, 1]; the mean, standard deviation, and percentile confidence interval
  are then computed pointwise across samples.

  Args:
    curve_type: The type of curve generated by `curve_fn`.
    curve_fn: A bootstrappable function for generating `curve_type` curves.
    label: The ground truth label targets.
    predictions: A list of target predictions from a set of models.
    sample_indices: An array of bootstrap sample indices. If empty, returns the
      single value computed on the entire dataset for each prediction.
    ci_level: The confidence level at which the CI is calculated (e.g., 95).
    n_interp: The number of interpolation points for x-coordinates in [0, 1].

  Returns:
    A mapping of model name to its bootstrapped curve.
  """
  curve_samples = {}
  if sample_indices:
    for prediction in predictions:
      curve_samples[prediction.model_name] = []
    for sample_index in sample_indices:
      y_true = label.values[sample_index]
      for prediction in predictions:
        y_pred = prediction.values[sample_index]
        curve_samples[prediction.model_name].append(curve_fn(y_true, y_pred))
  else:
    # No resampling requested: use one curve computed on the full dataset.
    for prediction in predictions:
      full_curve = curve_fn(label.values, prediction.values)
      curve_samples[prediction.model_name] = [full_curve]

  grid = np.linspace(0, 1, n_interp)
  tail = (100 - ci_level) / 2
  percentiles = [tail, 100 - tail]

  results = {}
  for model_name, samples in curve_samples.items():
    interp_values = [
        np.interp(grid, sample.value_array_x, sample.value_array_y)
        for sample in samples
    ]
    lower, upper = np.percentile(a=interp_values, q=percentiles, axis=0)
    results[model_name] = CurveBootstrapResult(
        curve_type=curve_type,
        mean=np.mean(interp_values, axis=0),
        stddev=np.std(interp_values, axis=0),
        num_samples=len(sample_indices),
        ci_level=ci_level,
        # Keep the interval limits within the valid [0, 1] range.
        ci_upper=np.minimum(upper, 1),
        ci_lower=np.maximum(lower, 0),
        interp=grid,
    )

  return results


def _generate_sample_indices(
    label: Label,
    is_binary: bool,
    num_bootstrap: int,
    seed: int,
) -> List[IndexSample]:
  """Returns a list of `num_bootstrap` randomly sampled bootstrap indices.

  Each sample draws `len(label)` indices with replacement.

  Args:
    label: The ground truth label targets.
    is_binary: Whether to generate valid binary samples; i.e., each index sample
      contains at least one index corresponding to a label from each class.
    num_bootstrap: The number of bootstrap indices to generate.
    seed: The random seed used to initialize the random number generator.

  Returns:
    A list of `num_bootstrap` bootstrap sample indices.

  Raises:
    ValueError: If `is_binary` is set, samples are requested, and `label` is
      not itself a valid binary label array.
  """
  # Resampling can only draw values already present in `label`. If the full
  # array is not a valid binary label, no sample can ever be valid and the
  # rejection loop below would never terminate.
  if (
      num_bootstrap > 0
      and is_binary
      and not is_valid_binary_label(label.values)
  ):
    raise ValueError(
        f'Label "{label.name}" must contain only binary values with at least '
        'one case and one control to generate binary bootstrap samples.'
    )

  rng = np.random.default_rng(seed)
  num_observations = len(label)
  sample_indices = []
  while len(sample_indices) < num_bootstrap:
    index = rng.integers(0, high=num_observations, size=num_observations)
    sample_true = label.values[index]
    # If computing a binary metric, skip indices that result in invalid labels.
    if is_binary and not is_valid_binary_label(sample_true):
      continue
    sample_indices.append(index)
  return sample_indices


def _validate_and_mask(
    label: Label,
    predictions: Sequence[Prediction],
    mask: Optional[np.ndarray] = None,
) -> Tuple[Label, List[Prediction]]:
  """Validates bootstrap argument shape and applies the mask if needed.

  Args:
    label: The ground truth label targets.
    predictions: The model predictions; each must be as long as `label`.
    mask: An optional array, as long as `label`, used to index into the label
      and prediction values.

  Returns:
    The label and predictions, restricted to `mask` when one is given. Without
    a mask, the inputs are returned unchanged.

  Raises:
    ValueError: If a prediction's length or the mask's length differs from the
      label's length.
  """
  for prediction in predictions:
    if len(label) != len(prediction):
      raise ValueError('Label and prediction dimensions do not match.')
  if mask is None:
    return label, predictions
  if len(mask) != len(label):
    raise ValueError('Label and prediction dimensions do not match mask.')

  masked_label = Label(label.name, label.values[mask])
  # Each masked prediction keeps its own name; only the values are filtered.
  masked_predictions = [
      Prediction(name=p.name, values=p.values[mask], model_name=p.model_name)
      for p in predictions
  ]
  return masked_label, masked_predictions


def bootstrap_curves(
    label: Label,
    predictions: Sequence[Prediction],
    mask: Optional[np.ndarray] = None,
    n_bootstrap: int = 0,
    conf_interval: float = 95,
    n_interp: int = 1000,
    seed: int = 42,
) -> Dict[CurveType, Dict[str, CurveBootstrapResult]]:
  """Returns bootstrapped ROC and precision-recall curves with CIs.

  Args:
    label: The ground truth label targets.
    predictions: The model predictions to evaluate against `label`.
    mask: An optional mask applied to the label and prediction values.
    n_bootstrap: The number of bootstrap samples; if 0, each curve is computed
      once on the entire (masked) dataset.
    conf_interval: The confidence level of the interval (e.g., 95).
    n_interp: The number of interpolation points for x-coordinates in [0, 1].
    seed: The random seed used when drawing bootstrap samples.

  Returns:
    A mapping of curve type to a mapping of model name to bootstrapped curve.
  """
  label, predictions = _validate_and_mask(label, predictions, mask)
  # The same index samples are shared by both curve types.
  sample_indices = _generate_sample_indices(
      label,
      is_binary=True,
      num_bootstrap=n_bootstrap,
      seed=seed,
  )
  curve_fns = {CurveType.ROC: _roc_curve, CurveType.PR: _precision_recall_curve}

  def _run(type_and_fn):
    curve_type, curve_fn = type_and_fn
    return _bootstrap_curve(
        curve_type=curve_type,
        curve_fn=curve_fn,
        label=label,
        predictions=predictions,
        sample_indices=sample_indices,
        ci_level=conf_interval,
        n_interp=n_interp,
    )

  # Compute each curve type on its own thread.
  with concurrent.futures.ThreadPoolExecutor(len(curve_fns)) as executor:
    results = list(executor.map(_run, curve_fns.items()))
  return dict(zip(curve_fns, results))


def build_ordered_maps(
    labels_df: pd.DataFrame,
    label_cols: Sequence[str],
    model_id_to_preds: Mapping[str, pd.DataFrame],
    model_id_to_pred_col: Mapping[str, str],
    target_eids: Optional[AbstractSet[int]],
) -> Tuple[Dict[str, np.ndarray], Dict[str, np.ndarray]]:
  """Returns maps of labels and model IDs to consistently ordered numpy arrays.

  Performance metrics require that index `i` of every label and prediction
  array refers to the same individual. This function inner-joins the labels
  and each model's predictions on `eid`, optionally restricts the rows to
  `target_eids`, drops rows with a NaN in any label column, and extracts the
  columns from the single joined frame so all arrays share one row order.

  Args:
    labels_df: The labels dataframe.
    label_cols: A list of target label columns from `labels_df`.
    model_id_to_preds: A mapping of model IDs to model predictions.
    model_id_to_pred_col: A mapping of model IDs to the target prediction column
      from the model's corresponding value in `model_id_to_preds`.
    target_eids: A set of EIDs to which labels and predictions are restricted.

  Returns:
    A tuple of dictionaries mapping labels and model IDs to consistently ordered
    numpy arrays.

  Raises:
    ValueError: If a temporary label column name collides with a temporary
      prediction column name.
  """
  # Rename columns so that labels and predictions cannot clash after the join.
  renamed_labels = {col: f'{col}_label' for col in label_cols}
  renamed_preds = {
      model_id: f'{model_id}:{pred_col}'
      for model_id, pred_col in model_id_to_pred_col.items()
  }
  if set(renamed_labels.values()) & set(renamed_preds.values()):
    raise ValueError(f'Column conflict: {renamed_labels} v.s. {renamed_preds}')

  # Join every model's prediction column onto the labels by `eid`.
  merged_df = labels_df[['eid', *label_cols]].rename(columns=renamed_labels)
  for model_id, model_preds in model_id_to_preds.items():
    source_col = model_id_to_pred_col[model_id]
    renamed_pred_df = model_preds[['eid', source_col]].rename(
        columns={source_col: renamed_preds[model_id]}
    )
    merged_df = merged_df.merge(renamed_pred_df, on='eid', how='inner')

  # Restrict to the target EIDs, each of which must have survived the join.
  if target_eids:
    merged_df = merged_df[merged_df['eid'].isin(target_eids)]
    assert len(merged_df) == len(target_eids)

  # Drop rows with missing labels, one label column at a time.
  for tmp_label_col in renamed_labels.values():
    num_missing = merged_df[tmp_label_col].isna().sum()
    print(f'Dropping {num_missing} NaNs from "{tmp_label_col}"')
    merged_df = merged_df.dropna(subset=[tmp_label_col])
  assert (merged_df[list(renamed_labels.values())].notna()).all(axis=None)
  assert (merged_df[list(renamed_preds.values())].notna()).all(axis=None)

  label_to_np = {
      label_col: merged_df[tmp_label_col].to_numpy()
      for label_col, tmp_label_col in renamed_labels.items()
  }
  model_id_to_np = {
      model_id: merged_df[tmp_pred_col].to_numpy()
      for model_id, tmp_pred_col in renamed_preds.items()
  }
  return label_to_np, model_id_to_np


def _build_bootstrap_inputs(
    labels_df: pd.DataFrame,
    label_col: str,
    model_id_to_preds: Mapping[str, pd.DataFrame],
    model_id_to_pred_col: Mapping[str, str],
    target_eids: Optional[AbstractSet[int]],
) -> Tuple[Label, List[Prediction]]:
  """Builds the bootstrap label and per-model predictions in shared EID order.

  Delegates to `build_ordered_maps` for the single `label_col` and wraps the
  returned arrays in `Label` and `Prediction` objects.

  Args:
    labels_df: The labels dataframe.
    label_col: The target label column from `labels_df`.
    model_id_to_preds: A mapping of model IDs to model predictions.
    model_id_to_pred_col: A mapping of model IDs to the target prediction column
      from the model's corresponding value in `model_id_to_preds`.
    target_eids: A set of EIDs to which labels and predictions are restricted.

  Returns:
    A `(label, predictions)` tuple with one `Prediction` per model ID.
  """
  ordered_labels, ordered_preds = build_ordered_maps(
      labels_df,
      [label_col],
      model_id_to_preds,
      model_id_to_pred_col,
      target_eids,
  )
  label = Label(label_col, ordered_labels[label_col])
  # Each prediction is tagged with its source column and its model ID.
  predictions = [
      Prediction(model_id_to_pred_col[model_id], values, model_id)
      for model_id, values in ordered_preds.items()
  ]
  return label, predictions


def _validate_bootstrap_model_pred_args(
    labels_df: pd.DataFrame,
    label_col: str,
    model_id_to_preds: Mapping[str, pd.DataFrame],
    model_id_to_pred_col: Mapping[str, str],
    model_id_to_threshold: Optional[Mapping[str, float]],
    target_eids: Optional[AbstractSet[int]],
) -> None:
  """Ensures `bootstrap_model_preds` args match expected structured."""
  # The label dataframe must contain the target column and EIDs.
  labels_columns = set(labels_df.columns)
  if label_col not in labels_columns:
    raise ValueError(f'Unexpected label column: {label_col}')
  if target_eids and not set(target_eids).issubset(set(labels_df.eid)):
    raise ValueError('Labels dataframe missing target EIDs.')

  # All model maps must contain the same set of model IDs.
  model_ids = set(model_id_to_preds)
  if model_ids != set(model_id_to_pred_col):
    raise ValueError(
        f'Mismatched model IDs in map: {model_id_to_preds} '
        f'v.s. {set(model_id_to_pred_col)}'
    )
  if model_id_to_threshold and model_ids != set(model_id_to_threshold):
    raise ValueError(
        f'Mismatched model IDs in map: {model_id_to_preds} '
        f'v.s. {set(model_id_to_threshold)}'
    )

  # All model dataframes must contain the prediction column and target EIDs.
  for model_id, model_preds in model_id_to_preds.items():
    model_eids = set(model_preds.eid)
    pred_col = model_id_to_pred_col[model_id]
    model_cols = model_preds.columns
    if pred_col not in model_cols:
      raise ValueError(
          f'Unexpected "{model_id}" prediction column: '
          f'"{pred_col}" not in {model_cols}'
      )
    if target_eids and not set(target_eids).issubset(model_eids):
      raise ValueError(f'"{model_id}" dataframe missing target EIDs.')


def bootstrap_curves_model_preds(
    labels_df: pd.DataFrame,
    label_col: str,
    model_id_to_preds: Mapping[str, pd.DataFrame],
    model_id_to_pred_col: Mapping[str, str],
    target_eids: Optional[AbstractSet[int]] = None,
    n_bootstrap: int = 100,
    n_interp: int = 1000,
    seed: int = 42,
) -> Dict[CurveType, Dict[str, CurveBootstrapResult]]:
  """Bootstraps performance curves for every model against a single label.

  The arguments are validated, converted into a label and a list of
  predictions, and then handed to `bootstrap_curves`.

  Args:
    labels_df: The labels dataframe.
    label_col: The target label column from `labels_df`.
    model_id_to_preds: A mapping of model IDs to model predictions.
    model_id_to_pred_col: A mapping of model IDs to the target prediction column
      from the model's corresponding value in `model_id_to_preds`.
    target_eids: A set of EIDs to which labels and predictions are restricted.
    n_bootstrap: The number of bootstrapping iterations.
    n_interp: The number of points to interpolate along curve samples.
    seed: The random seed.

  Returns:
    A dictionary mapping curve types to the corresponding model-curve bootstrap
    results mapping for the label and predictions.
  """
  # No per-model thresholds are passed for curve bootstrapping, so the
  # threshold map is left unset for validation.
  _validate_bootstrap_model_pred_args(
      labels_df=labels_df,
      label_col=label_col,
      model_id_to_preds=model_id_to_preds,
      model_id_to_pred_col=model_id_to_pred_col,
      model_id_to_threshold=None,
      target_eids=target_eids,
  )
  label, preds = _build_bootstrap_inputs(
      labels_df=labels_df,
      label_col=label_col,
      model_id_to_preds=model_id_to_preds,
      model_id_to_pred_col=model_id_to_pred_col,
      target_eids=target_eids,
  )
  curve_type_to_results = bootstrap_curves(
      label,
      preds,
      n_bootstrap=n_bootstrap,
      n_interp=n_interp,
      seed=seed,
  )
  return curve_type_to_results

## Load labels and model predictions

In [None]:
# Placeholder path to the TSV of labels and risk predictions described at the
# top of the notebook; point it at the real file before running.
DATA_FILEPATH = '/path/to/data.tsv'
# Columns read by the rest of the notebook: the individual ID, the three
# evaluation labels, and one risk-score column per compared model.
REQUIRED_COLUMNS = {
    'eid',
    'copd_all_srcs_subset',
    'copd_hesin_primary_after',
    'copd_death_primary',
    'blow_ratio_risk',
    'blow_fev1_pct_pred_norm_risk',
    'ml_based_copd',
}

with open(DATA_FILEPATH, mode='r') as f:
  g_data_df = pd.read_csv(f, sep='\t', index_col=None)

# Name the absent columns so that a malformed input file is easy to diagnose.
assert REQUIRED_COLUMNS.issubset(set(g_data_df.columns)), (
    f'{DATA_FILEPATH} is missing required columns: '
    f'{sorted(REQUIRED_COLUMNS - set(g_data_df.columns))}'
)

In [None]:
# Split the combined dataframe into the separate label / per-model prediction
# inputs that the bootstrap utilities take.
LABELS = [
    'copd_all_srcs_subset',
    'copd_hesin_primary_after',
    'copd_death_primary',
]
MODEL_IDS = [
    'blow_ratio_risk',
    'blow_fev1_pct_pred_norm_risk',
    'ml_based_copd',
]
# Labels for every individual, keyed by `eid`.
g_label_df = g_data_df.loc[:, ['eid'] + LABELS].copy()
# One `(eid, prediction)` dataframe per model.
g_model_id_to_preds = {
    model_id: g_data_df.loc[:, ['eid', model_id]].copy()
    for model_id in MODEL_IDS
}
# Each model's prediction column is named after the model ID itself.
g_model_id_to_pred_col = dict(zip(MODEL_IDS, MODEL_IDS))
# Every individual in the input file is evaluated.
g_validation_eids = set(g_data_df['eid'])

### Load and preprocess Kaplan-Meier curves from survival analysis

In [None]:
# Placeholder path to the TSV of survival-analysis data for the ML-based COPD
# risk model; point it at the real file before running.
KM_ML_DATA_FILEPATH = '/path/to/ml_based_copd_km_data.tsv'
# `df_to_km_curves` is defined outside this cell. Its result is later passed to
# `plot_km_curves` to draw the bottom row of the figure.
km_curves = df_to_km_curves(pathlib.Path(KM_ML_DATA_FILEPATH))

## Build Figure 3

### Bootstrap performance curves

In [None]:
# Run the curve bootstrap once per label; each run covers every model.
g_label_to_type_to_result = {}
for label_col in LABELS:
  g_label_to_type_to_result[label_col] = bootstrap_curves_model_preds(
      labels_df=g_label_df,
      label_col=label_col,
      model_id_to_preds=g_model_id_to_preds,
      model_id_to_pred_col=g_model_id_to_pred_col,
      target_eids=g_validation_eids,
      n_bootstrap=100,
  )

# Convert the bootstrap results into the curve representation that the ROC and
# PR plotting calls below index by label and curve type.
g_label_to_type_to_curves = bs_curves_to_fig_curves(g_label_to_type_to_result)

### Plot subfigures

In [None]:
# Display names used in the legend for each model ID.
g_legend_overrides = {
    'blow_ratio_risk': 'FEV1/FVC Ratio',
    'blow_fev1_pct_pred_norm_risk': 'FEV1 Percent Predicted',
    'ml_based_copd': 'Flow-volume ResNet18',
}

fig, axes = plt.subplots(3, 3, figsize=(9, 9), dpi=300)

# One figure column per label: ROC curves on the top row and PR curves on the
# middle row. Each entry is
# (label, ROC title, PR title, ROC `inset_xlim`, ROC `inset_ylim`).
g_panel_specs = [
    (
        LABELS[0],
        'COPD status ROC curves',
        'COPD status PR curves',
        (0.1, 0.4),
        (0.6, 0.85),
    ),
    (
        LABELS[1],
        'COPD hospitalization ROC curves',
        'COPD hospitalization PR curves',
        (0.05, 0.35),
        (0.7, 0.95),
    ),
    (
        LABELS[2],
        'COPD death ROC curves',
        'COPD death PR curves',
        (0.00, 0.30),
        (0.70, 0.95),
    ),
]
for col_idx, panel_spec in enumerate(g_panel_specs):
  panel_label, roc_title, pr_title, roc_xlim, roc_ylim = panel_spec
  plot_roc_curves(
      g_label_to_type_to_curves[panel_label][RocCurve],
      g_legend_overrides,
      inset_xlim=roc_xlim,
      inset_ylim=roc_ylim,
      ax=axes[0, col_idx],
      plot_legend=False,
      title=roc_title,
  )
  plot_pr_curves(
      g_label_to_type_to_curves[panel_label][PrecisionRecallCurve],
      g_legend_overrides,
      plot_marker=False,
      ax=axes[1, col_idx],
      plot_legend=False,
      title=pr_title,
  )

# Replace the three bottom-row axes with a single axes spanning the whole row
# and draw the KM curves on it.
grid_spec = axes[-1, 0].get_gridspec()
for ax in axes[-1, :]:
  ax.remove()
km_ax = fig.add_subplot(grid_spec[-1, :])
plot_km_curves(km_curves, legend_title='Risk group', ax=km_ax)

# Panel letters are added before `tight_layout` so that spacing is preserved.
# Only axes that have a y-label are lettered; this is a heuristic for picking
# the "major" subplot axes and skipping the insets.
labeled_axes = [ax for ax in fig.get_axes() if ax.get_ylabel()]
for i, ax in enumerate(labeled_axes):
  ax_label = string.ascii_lowercase[i]
  # Offset the letter to the left of and above the axes corner by a fixed
  # physical distance (in inches, scaled by the figure dpi).
  trans = transforms.ScaledTranslation(-20 / 72, 7 / 72, fig.dpi_scale_trans)
  ax.text(
      0.0,
      1.0,
      ax_label,
      transform=ax.transAxes + trans,
      fontsize='8',
      va='bottom',
      fontfamily='Helvetica',
      weight='bold',
  )

  # Per natgen formatting, hide the right and top spines of lettered axes
  # (insets keep theirs).
  ax.spines[['right', 'top']].set_visible(False)

# Per natgen formatting, no axes (insets included) shows a grid.
for ax in fig.get_axes():
  ax.grid(False)

plt.tight_layout()

# A single legend, anchored above the top-middle axes and expanded across the
# anchor box, is added after the layout has been computed.
axes[0, 1].legend(
    loc='lower center',
    bbox_to_anchor=(-0.75, 1.1, 2.5, 1.02),
    borderaxespad=0,
    frameon=False,
    mode='expand',
    ncol=3,
)

In [None]:
# Write the figure built above to a PDF in the working directory, saved with
# the tight bounding box of the figure rather than the full canvas.
fig.savefig('figure_3.pdf', dpi=300, format='pdf', bbox_inches='tight')
# NOTE(review): `%download_file` is not a built-in IPython magic; presumably it
# is provided by the hosting notebook environment. Confirm it is available
# before running this cell elsewhere.
%download_file figure_3.pdf