In [None]:
# @title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# Copyright 2022 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Generates Figure 2 from Cosentino et al. Nature Genetics 2023

This notebook assumes that there exists a labels TSV file stored at
`LABELS_FILEPATH`. This TSV must contain the following columns:

-   `eid`: A unique individual identifier.
-   `blow_ratio`: FEV1/FVC ratio.
-   `blow_fev1_pct_predicted`: FEV1 percent predicted.
-   `copd_spiro_gold`: Proxy GOLD 2-4 status.
-   `copd_sr_src`: COPD status sourced from self-reported data.
-   `copd_hesin_src`: COPD status sourced from HESIN data.
-   `copd_gp_src`: COPD status sourced from general practitioner notes data.

In [None]:
import string
from typing import Optional, Sequence, Tuple

import matplotlib.pyplot as plt
from matplotlib import rcParams
from matplotlib import ticker
from matplotlib import transforms
import numpy as np
import pandas as pd
import seaborn as sns
import sklearn.metrics

In [None]:
def set_matplotib_settings():
  """Applies the global seaborn/matplotlib styling used for Figure 2.

  Note: `sns.set_style` only honors style-related rc keys (e.g. `axes.grid`,
  `font.family`) and silently discards any other key in its rc dict, so
  non-style settings are assigned to `rcParams` directly.
  """
  sns.set_palette('deep')
  sns.set_style(
      'ticks',
      {
          'axes.grid': False,
          'font.family': ['Helvetica'],
      },
  )
  # Previously passed to `sns.set_style`, which dropped it; set it directly.
  rcParams['legend.frameon'] = False
  # `text.usetex` was also passed to `sns.set_style` and therefore never took
  # effect. It is intentionally left unset: the figure labels use mathtext's
  # `\mathregular`, which is not a standard LaTeX command.
  rcParams['savefig.dpi'] = 300
  rcParams['savefig.transparent'] = True
  rcParams['font.size'] = 7


set_matplotib_settings()

In [None]:
# Consecutive series steps are TIME_SCALE seconds apart; 0.01 means volume was
# sampled every 10 milliseconds.
TIME_SCALE = 0.01

# Series elements are recorded in mL; multiplying by VOLUME_SCALE (0.001)
# converts them to L.
VOLUME_SCALE = 0.001

# Length of the final ML input representation. Blows shorter than
# `MAX_NUM_POINTS` are right-padded to `MAX_NUM_POINTS` using the last value,
# while blows longer than `MAX_NUM_POINTS` are truncated to `MAX_NUM_POINTS`.
MAX_NUM_POINTS = 1000

# Axis ranges for time (seconds) and volume (liters).
TIME_MIN, TIME_MAX = 0, 10
VOLUME_MIN, VOLUME_MAX = 0, 6

# Axis range for flow (liters per second), obtained by multiplying the volume
# range by FLOW_SCALE.
FLOW_SCALE = 2
FLOW_MIN, FLOW_MAX = VOLUME_MIN * FLOW_SCALE, VOLUME_MAX * FLOW_SCALE

# Figure styling: axis padding as a fraction of the axis range, and the bin
# counts handed to the tick locators.
BASE_PADDING = 0.1
NUM_Y_TICKS = 4
NUM_X_TICKS = 6

# Color palette.
BLUE = '#8ab4f8'
RED = '#f28b82'
YELLOW = '#fdd65d'
GREEN = '#81c995'
GRAY = '#DADCE0'
DARK_GRAY = '#80868B'
BLACK = '#202124'

# The publicly available spirometry demo curve from UKB from field 3066:
# https://biobank.ctsu.ox.ac.uk/crystal/field.cgi?id=3066
UKB_DEMO_SPIRO_3066 = np.asarray([
    0,
    3,
    10,
    25,
    54,
    101,
    169,
    258,
    363,
    478,
    589,
    689,
    785,
    879,
    970,
    1059,
    1147,
    1234,
    1320,
    1403,
    1486,
    1569,
    1650,
    1730,
    1809,
    1888,
    1965,
    2040,
    2116,
    2188,
    2261,
    2331,
    2400,
    2465,
    2532,
    2595,
    2658,
    2720,
    2780,
    2838,
    2894,
    2948,
    3001,
    3052,
    3102,
    3151,
    3197,
    3243,
    3287,
    3329,
    3371,
    3412,
    3451,
    3490,
    3527,
    3564,
    3600,
    3635,
    3670,
    3703,
    3736,
    3769,
    3800,
    3831,
    3861,
    3890,
    3918,
    3947,
    3974,
    4001,
    4028,
    4054,
    4080,
    4105,
    4130,
    4154,
    4179,
    4202,
    4226,
    4249,
    4271,
    4292,
    4312,
    4332,
    4351,
    4371,
    4390,
    4408,
    4426,
    4444,
    4461,
    4478,
    4495,
    4512,
    4528,
    4544,
    4560,
    4575,
    4590,
    4604,
    4619,
    4633,
    4647,
    4661,
    4675,
    4689,
    4703,
    4716,
    4729,
    4742,
    4755,
    4767,
    4779,
    4791,
    4802,
    4812,
    4822,
    4831,
    4840,
    4849,
    4857,
    4866,
    4874,
    4882,
    4890,
    4898,
    4906,
    4914,
    4921,
    4929,
    4936,
    4944,
    4951,
    4958,
    4966,
    4973,
    4980,
    4987,
    4994,
    5000,
    5007,
    5013,
    5020,
    5026,
    5033,
    5039,
    5045,
    5051,
    5057,
    5063,
    5069,
    5075,
    5081,
    5087,
    5092,
    5098,
    5104,
    5109,
    5114,
    5119,
    5125,
    5130,
    5134,
    5139,
    5144,
    5148,
    5153,
    5157,
    5161,
    5166,
    5170,
    5174,
    5178,
    5182,
    5186,
    5190,
    5194,
    5198,
    5202,
    5205,
    5209,
    5213,
    5216,
    5220,
    5223,
    5226,
    5230,
    5233,
    5236,
    5240,
    5243,
    5246,
    5250,
    5253,
    5256,
    5259,
    5262,
    5264,
    5267,
    5270,
    5273,
    5276,
    5279,
    5283,
    5286,
    5289,
    5292,
    5295,
    5298,
    5300,
    5303,
    5306,
    5308,
    5311,
    5314,
    5316,
    5319,
    5321,
    5323,
    5326,
    5328,
    5331,
    5333,
    5335,
    5338,
    5340,
    5343,
    5345,
    5348,
    5350,
    5352,
    5355,
    5357,
    5360,
    5362,
    5365,
    5367,
    5369,
    5372,
    5374,
    5377,
    5379,
    5381,
    5384,
    5386,
    5388,
    5390,
    5391,
    5393,
    5395,
    5397,
    5399,
    5401,
    5403,
    5404,
    5406,
    5408,
    5410,
    5412,
    5413,
    5415,
    5417,
    5419,
    5420,
    5422,
    5424,
    5426,
    5427,
    5429,
    5431,
    5432,
    5434,
    5436,
    5438,
    5439,
    5441,
    5443,
    5444,
    5446,
    5447,
    5449,
    5450,
    5452,
    5453,
    5455,
    5456,
    5457,
    5459,
    5460,
    5461,
    5462,
    5463,
    5464,
    5466,
    5467,
    5468,
    5470,
    5471,
    5473,
    5474,
    5476,
    5477,
    5478,
    5480,
    5481,
    5482,
    5484,
    5485,
    5486,
    5487,
    5489,
    5490,
    5491,
    5492,
    5493,
    5494,
    5496,
    5497,
    5498,
    5499,
    5500,
    5501,
    5502,
    5503,
    5504,
    5505,
    5506,
    5507,
    5508,
    5509,
    5510,
    5510,
    5511,
    5512,
    5513,
    5514,
    5515,
    5515,
    5516,
    5517,
    5519,
    5520,
    5521,
    5523,
    5524,
    5525,
    5527,
    5529,
    5530,
    5532,
    5533,
    5535,
    5536,
    5537,
    5539,
    5540,
    5541,
    5543,
    5544,
    5545,
    5545,
    5546,
    5547,
    5548,
    5549,
    5549,
    5550,
    5551,
    5552,
    5552,
    5553,
    5554,
    5554,
    5555,
    5556,
    5557,
    5557,
    5558,
    5559,
    5560,
    5560,
    5561,
    5562,
    5562,
    5563,
    5564,
    5564,
    5565,
    5565,
    5566,
    5567,
    5567,
    5568,
    5569,
    5570,
    5571,
    5572,
    5573,
    5574,
    5576,
    5577,
    5578,
    5579,
    5580,
    5582,
    5583,
    5584,
    5585,
    5587,
    5588,
    5589,
    5590,
    5591,
    5591,
    5592,
    5593,
    5594,
    5595,
    5596,
    5596,
    5597,
    5598,
    5598,
    5599,
    5600,
    5601,
    5601,
    5602,
    5603,
    5603,
    5604,
    5605,
    5606,
    5606,
    5607,
    5608,
    5608,
    5609,
    5609,
    5609,
    5610,
    5611,
    5611,
    5612,
    5613,
    5613,
    5614,
    5615,
    5616,
    5616,
    5617,
    5618,
    5618,
    5619,
    5620,
    5621,
    5622,
    5623,
    5624,
    5624,
    5625,
    5626,
    5626,
    5627,
    5628,
    5628,
    5629,
    5629,
    5630,
    5630,
    5631,
    5632,
    5632,
    5633,
    5633,
    5634,
    5635,
    5635,
    5636,
    5637,
    5637,
    5638,
    5639,
    5639,
    5640,
    5641,
    5642,
    5642,
    5643,
    5644,
    5645,
    5645,
    5646,
    5647,
    5647,
    5648,
    5649,
    5649,
    5650,
    5651,
    5651,
    5652,
    5652,
    5653,
    5654,
    5654,
    5655,
    5656,
    5656,
    5657,
    5658,
    5658,
    5659,
    5660,
    5660,
    5661,
    5661,
    5662,
    5663,
    5663,
    5664,
    5664,
    5665,
    5665,
    5666,
    5666,
    5667,
    5667,
    5668,
    5668,
    5669,
    5669,
    5670,
    5670,
    5670,
    5671,
    5671,
    5672,
    5672,
    5672,
    5673,
    5673,
    5673,
    5673,
    5674,
    5674,
    5674,
    5675,
    5676,
    5676,
    5677,
    5677,
    5678,
    5678,
    5679,
    5679,
    5680,
    5681,
    5681,
    5682,
    5683,
    5683,
    5684,
    5684,
    5685,
    5686,
    5686,
    5687,
    5687,
    5688,
    5688,
    5688,
    5689,
    5689,
    5690,
    5690,
    5690,
    5691,
    5691,
    5692,
    5692,
    5692,
    5693,
    5693,
    5694,
    5694,
    5694,
    5695,
    5695,
    5695,
    5696,
    5696,
    5696,
    5696,
    5696,
    5696,
    5697,
    5697,
    5698,
    5698,
    5698,
    5699,
    5699,
    5699,
    5699,
    5700,
    5700,
    5700,
    5701,
    5701,
    5702,
    5702,
    5703,
    5703,
    5704,
    5704,
    5705,
    5705,
    5706,
    5706,
    5707,
    5707,
    5708,
    5709,
    5709,
    5710,
    5710,
    5711,
    5711,
    5712,
    5712,
    5712,
    5713,
    5713,
    5713,
    5714,
    5714,
    5714,
    5715,
    5715,
    5716,
    5716,
    5716,
    5717,
    5717,
    5717,
    5718,
    5718,
    5719,
    5719,
    5720,
    5720,
    5721,
    5721,
    5721,
    5722,
    5722,
    5722,
    5723,
    5723,
    5723,
    5723,
    5724,
    5724,
    5724,
    5725,
    5725,
    5725,
    5726,
    5726,
    5726,
    5727,
    5727,
    5728,
    5728,
    5729,
    5729,
    5729,
    5730,
    5730,
    5731,
    5732,
    5732,
    5733,
    5733,
    5734,
    5735,
    5735,
    5735,
    5736,
    5736,
    5736,
    5737,
    5737,
    5737,
    5738,
    5738,
    5738,
    5739,
    5739,
    5739,
    5739,
    5740,
    5740,
    5740,
    5741,
    5741,
    5741,
    5741,
    5741,
    5741,
    5742,
    5742,
    5742,
    5742,
    5742,
    5742,
    5742,
    5742,
    5742,
    5742,
    5741,
    5741,
    5740,
    5740,
    5740,
    5740,
    5739,
    5739,
    5739,
    5739,
    5739,
    5739,
    5740,
    5740,
    5740,
    5741,
    5742,
    5742,
    5743,
    5743,
    5744,
    5745,
    5745,
    5745,
    5746,
    5746,
    5747,
    5747,
    5748,
    5748,
    5748,
    5748,
    5748,
    5748,
    5749,
    5749,
    5749,
    5749,
    5749,
    5749,
    5749,
    5750,
    5750,
    5750,
    5750,
    5750,
    5751,
    5751,
    5751,
    5752,
    5752,
    5753,
    5753,
    5754,
    5754,
    5754,
    5755,
    5755,
    5756,
    5756,
    5756,
    5757,
    5757,
    5757,
    5758,
    5758,
    5758,
    5758,
    5759,
    5759,
    5759,
    5759,
    5759,
    5759,
    5759,
    5759,
    5759,
    5760,
    5760,
    5760,
    5761,
    5761,
    5761,
    5762,
    5762,
    5763,
    5763,
    5763,
    5764,
    5764,
    5764,
    5765,
    5765,
    5766,
    5766,
    5766,
    5767,
    5767,
    5767,
    5767,
    5767,
    5768,
    5768,
    5768,
    5768,
    5769,
    5769,
    5769,
    5770,
    5770,
    5770,
    5770,
    5770,
    5771,
    5771,
    5771,
    5771,
    5771,
    5772,
    5772,
    5772,
    5773,
    5773,
    5773,
    5774,
    5774,
    5774,
    5775,
    5775,
    5775,
    5776,
    5776,
    5777,
    5777,
    5777,
    5778,
    5778,
    5778,
    5778,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5780,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5779,
    5780,
    5780,
    5780,
    5780,
    5781,
    5781,
    5781,
    5782,
    5782,
    5782,
    5783,
    5783,
    5783,
    5784,
    5784,
    5784,
    5785,
    5785,
    5785,
    5785,
    5785,
    5786,
    5786,
    5786,
    5786,
    5786,
    5786,
    5786,
    5787,
    5787,
    5787,
    5788,
    5788,
    5788,
    5789,
    5789,
    5789,
    5790,
    5790,
    5790,
    5791,
    5791,
    5792,
    5792,
    5792,
    5793,
    5793,
    5793,
    5794,
    5794,
    5795,
    5795,
    5795,
    5796,
    5796,
    5796,
    5797,
    5797,
    5798,
    5798,
    5798,
    5798,
    5798,
    5799,
    5799,
    5799,
    5799,
    5800,
    5800,
    5800,
    5801,
    5801,
    5801,
    5801,
    5802,
    5802,
    5802,
    5802,
    5803,
    5803,
    5803,
    5803,
    5803,
    5803,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5804,
    5803,
    5804,
    5804,
    5804,
    5804,
    5804,
    5805,
    5805,
    5805,
    5805,
    5806,
    5806,
    5806,
    5806,
    5806,
    5806,
    5806,
    5806,
    5806,
    5806,
    5807,
    5807,
    5807,
    5807,
    5808,
    5808,
    5809,
    5809,
    5809,
    5810,
    5810,
    5810,
    5811,
    5811,
    5812,
    5812,
    5813,
    5813,
    5813,
    5814,
    5814,
    5815,
    5815,
    5815,
    5815,
    5816,
    5816,
    5816,
    5816,
    5817,
    5817,
    5817,
    5817,
    5817,
    5817,
    5817,
    5818,
    5818,
    5818,
    5818,
    5818,
    5818,
    5818,
    5819,
    5819,
    5819,
    5819,
    5819,
    5819,
    5819,
    5819,
    5819,
    5819,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5819,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5820,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5821,
    5820,
    5820,
    5820,
    5819,
    5819,
    5818,
    5818,
    5818,
    5817,
    5817,
    5817,
    5816,
    5816,
    5816,
    5816,
    5815,
    5815,
    5815,
    5816,
    5816,
    5816,
    5817,
    5817,
    5818,
    5819,
    5819,
    5820,
    5821,
    5822,
    5823,
    5823,
    5824,
    5825,
    5826,
    5827,
    5827,
    5828,
    5828,
    5829,
    5829,
    5829,
    5830,
    5830,
    5831,
    5831,
    5831,
    5831,
    5831,
    5832,
    5831,
    5832,
    5832,
    5832,
    5832,
    5832,
    5832,
    5832,
    5833,
    5833,
    5833,
    5833,
    5833,
    5833,
    5833,
    5834,
    5834,
    5834,
    5834,
    5834,
    5835,
    5835,
    5835,
    5835,
    5835,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5836,
    5835,
    5835,
    5835,
    5835,
    5834,
    5834,
    5834,
    5834,
    5833,
    5833,
    5833,
    5833,
    5833,
    5832,
    5832,
    5832,
    5832,
    5832,
    5832,
    5832,
    5832,
    5831,
])

## Plotting functions and data utilities

In [None]:
def compute_time(max_num_points: int, time_scale: float) -> np.ndarray:
  """Builds the timestamps of a uniformly sampled curve.

  Args:
    max_num_points: How many samples the curve has.
    time_scale: Seconds between consecutive samples.

  Returns:
    A float32 array `[0, time_scale, 2 * time_scale, ...]` holding
    `max_num_points` entries.
  """
  # Sample indices 0, 1, ..., max_num_points - 1 as float32.
  sample_indices = np.linspace(
      0, max_num_points, num=max_num_points, endpoint=False, dtype=np.float32
  )
  return sample_indices * time_scale


def compute_volume(series: np.ndarray, volume_scale: float) -> np.ndarray:
  """Rescales a raw spirometry `series` into a float32 volume curve.

  Args:
    series: Raw volume samples (UKB field 3066 records them in mL).
    volume_scale: Multiplier applied to every sample; 0.001 converts mL to L.

  Returns:
    `series * volume_scale` as a float32 array.
  """
  # Use the caller-supplied scale (not the module-level VOLUME_SCALE) so that
  # a `volume_scale` passed by callers such as `derive_base_curves` is honored.
  return (np.asarray(series) * volume_scale).astype(np.float32)


def compute_flow(volume: np.ndarray, time_scale: float) -> np.ndarray:
  """Computes flow as the per-second first difference of `volume`.

  `flow[0]` is defined as 0 and `flow[i] = (volume[i] - volume[i - 1]) /
  time_scale` for `i >= 1`, so the output has the same length as `volume`.

  Note: This should be run before zero-padding `volume`; otherwise the jump
  back to zero produces large negative flow values.

  Args:
    volume: Volume samples, one every `time_scale` seconds.
    time_scale: Seconds between consecutive samples; flow is therefore in
      volume units per second.

  Returns:
    A floating-point array of flow values with the same length as `volume`.
  """
  volume = np.asarray(volume)
  if volume.size == 0:
    # Without this guard the prepended 0 would make the output one element
    # longer than the (empty) input.
    return np.zeros(0, dtype=np.float64)
  return np.concatenate(([0.0], np.diff(volume) / time_scale))


def derive_base_curves(
    series: np.ndarray,
    max_num_points: int = MAX_NUM_POINTS,
    volume_scale: float = VOLUME_SCALE,
    time_scale: float = TIME_SCALE,
) -> dict[str, np.ndarray]:
  """Converts the provided spirometry series to time, volume, and flow curves.

  Series longer than `max_num_points` are truncated and shorter series are
  right-padded with their last value, so each curve spans `max_num_points`
  samples.

  Args:
    series: Raw volume samples, one every `time_scale` seconds.
    max_num_points: Number of samples in the returned curves.
    volume_scale: Forwarded to `compute_volume`.
    time_scale: Seconds between consecutive samples.

  Returns:
    A dict with the keys 'time', 'volume', and 'flow'.

  Raises:
    ValueError: If `series` is empty (there is no last value to pad with).
  """
  series = np.asarray(series)
  if series.size == 0:
    raise ValueError('`series` must contain at least one point.')
  series = series[:max_num_points]
  num_pad = max_num_points - len(series)
  if num_pad > 0:
    # Edge padding repeats the final volume, so the padded region has zero
    # flow rather than the large negative flow that zero-padding would cause.
    series = np.pad(series, (0, num_pad), mode='edge')
  record = {
      'time': compute_time(max_num_points, time_scale),
      'volume': compute_volume(series, volume_scale),
  }
  record['flow'] = compute_flow(record['volume'], time_scale)
  return record


def _plot_spirogram(
    y: np.ndarray,
    y_min: float,
    y_max: float,
    y_label: str,
    y_units: Optional[str],
    num_y_ticks: int,
    x: np.ndarray,
    x_min: float,
    x_max: float,
    x_label: str,
    x_units: Optional[str],
    num_x_ticks: int,
    padding: float,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Draws `y` against `x` as a single green spirogram line.

  Args:
    y: Values for the vertical axis.
    y_min: Lower bound of the y range (before padding) used for the y limits.
    y_max: Upper bound of the y range (before padding) used for the y limits.
    y_label: Name shown on the y axis.
    y_units: Units appended to the y axis name in parentheses, if non-empty.
    num_y_ticks: `nbins` passed to the y axis `MaxNLocator`.
    x: Values for the horizontal axis.
    x_min: Lower bound of the x range (before padding) used for the x limits.
    x_max: Upper bound of the x range (before padding) used for the x limits.
    x_label: Name shown on the x axis.
    x_units: Units appended to the x axis name in parentheses, if non-empty.
    num_x_ticks: `nbins` passed to the x axis `MaxNLocator`.
    padding: Extra margin on each side, as a fraction of the axis range.
    ax: Axis to draw on; a new one is created when omitted.

  Returns:
    The axis holding the spirogram.
  """
  if ax is None:
    ax = plt.axes()

  def axis_title(name: str, units: Optional[str]) -> str:
    # Units are shown only when they are non-empty.
    if not units:
      return name
    return f'{name} ({units})'

  ax.set_xlabel(axis_title(x_label, x_units))
  ax.set_ylabel(axis_title(y_label, y_units))

  # Cap the number of tick intervals on each axis.
  for axis, nbins in ((ax.xaxis, num_x_ticks), (ax.yaxis, num_y_ticks)):
    axis.set_major_locator(ticker.MaxNLocator(nbins))

  # Fixed, padded limits keep grids aligned across plots sharing the same units.
  x_margin = padding * (x_max - x_min)
  y_margin = padding * (y_max - y_min)
  ax.set_xlim((x_min - x_margin, x_max + x_margin))
  ax.set_ylim((y_min - y_margin, y_max + y_margin))

  sns.lineplot(
      x=x,
      y=y,
      label=f'{y_label} vs {x_label}',
      color=GREEN,
      legend=False,
      ax=ax,
  )
  return ax


def plot_vt_spirogram(
    volume: np.ndarray,
    time: np.ndarray,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Draws volume (L) against time (s) with the shared spirogram styling.

  Args:
    volume: Spirogram volume samples.
    time: Time at which each volume sample was taken.
    ax: Axis to draw on; a new one is created when omitted.

  Returns:
    The axis holding the spirogram.
  """
  time_axis = dict(
      x=time,
      x_min=TIME_MIN,
      x_max=TIME_MAX,
      x_label='Time',
      x_units='s',
      num_x_ticks=NUM_X_TICKS,
  )
  volume_axis = dict(
      y=volume,
      y_min=VOLUME_MIN,
      y_max=VOLUME_MAX,
      y_label='Volume',
      y_units='L',
      num_y_ticks=NUM_Y_TICKS,
  )
  return _plot_spirogram(
      **volume_axis, **time_axis, padding=BASE_PADDING, ax=ax
  )


def plot_ft_spirogram(
    flow: np.ndarray,
    time: np.ndarray,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Plots a flow-time spirogram (flow in L/s vs time in s) on the given axis.

  Args:
    flow: A series of values denoting spirogram flow.
    time: A series of values denoting the time at which flow was sampled.
    ax: An optional axis on which the spirogram is plotted; if not specified, a
      new axis is created.

  Returns:
    The axis on which the spirogram was plotted.
  """
  return _plot_spirogram(
      y=flow,
      y_min=FLOW_MIN,
      y_max=FLOW_MAX,
      y_label='Flow',
      y_units='L/s',
      num_y_ticks=NUM_Y_TICKS,
      x=time,
      x_min=TIME_MIN,
      x_max=TIME_MAX,
      x_label='Time',
      x_units='s',
      num_x_ticks=NUM_X_TICKS,
      padding=BASE_PADDING,
      ax=ax,
  )


def plot_fv_spirogram(
    flow: np.ndarray,
    volume: np.ndarray,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Plots a flow-volume spirogram (flow in L/s vs volume in L) on the axis.

  Args:
    flow: A series of values denoting spirogram flow.
    volume: A series of values denoting spirogram volume.
    ax: An optional axis on which the spirogram is plotted; if not specified, a
      new axis is created.

  Returns:
    The axis on which the spirogram was plotted.
  """
  return _plot_spirogram(
      y=flow,
      y_min=FLOW_MIN,
      y_max=FLOW_MAX,
      y_label='Flow',
      y_units='L/s',
      num_y_ticks=NUM_Y_TICKS,
      x=volume,
      x_min=VOLUME_MIN,
      x_max=VOLUME_MAX,
      x_label='Volume',
      x_units='L',
      num_x_ticks=NUM_X_TICKS,
      padding=BASE_PADDING,
      ax=ax,
  )


def build_source_df(labels_df: pd.DataFrame) -> pd.DataFrame:
  """Builds a per-individual table of COPD source flags and spirometry metrics.

  Rows without a `blow_ratio` (i.e. without valid blow data) are dropped. A
  float `healthy` column is appended: 1.0 when none of the GP, HESIN, or
  self-report source columns equals 1, otherwise 0.0.

  Args:
    labels_df: A labels dataframe from which COPD source and spirometry metrics
      are fetched. It is not modified.

  Returns:
    A new dataframe with `eid`, the three COPD source columns, `blow_ratio`,
    `blow_fev1_pct_predicted`, and `healthy`, keeping `labels_df`'s index.
  """
  copd_source_cols = ['copd_gp_src', 'copd_hesin_src', 'copd_sr_src']
  kept_cols = [
      'eid',
      *copd_source_cols,
      'blow_ratio',
      'blow_fev1_pct_predicted',
  ]

  with_valid_blow = labels_df.dropna(subset=['blow_ratio'])
  result = with_valid_blow.loc[:, kept_cols].copy()
  # Missing source values compare unequal to 1, so they count as "no COPD".
  any_copd = result[copd_source_cols].eq(1).any(axis=1)
  result['healthy'] = (~any_copd).astype(float)
  return result


def plot_source_ecdf(
    source_df: pd.DataFrame,
    spiro_col: str,
    spiro_label: str,
    vertical_line: Optional[float],
    plot_legend: bool = False,
    padding: float = BASE_PADDING,
    color_order: Sequence[str] = (BLUE, RED, YELLOW, GREEN),
    source_order: Sequence[Tuple[str, str]] = (
        ('copd_gp_src', 'GP COPD'),
        ('copd_hesin_src', 'HESIN COPD'),
        ('copd_sr_src', 'Self-report COPD'),
        ('healthy', 'Healthy'),
    ),
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Plots an ECDF of a spirometry measure for each COPD source group.

  Each group is the subset of rows whose group column equals 1.

  Args:
    source_df: A dataframe containing COPD source and spirometry measures.
    spiro_col: The spirometry measure column.
    spiro_label: The spirometry column's plot label.
    vertical_line: x coordinate of a dashed reference line (e.g. 0.7 for the
      FEV1/FVC ratio); pass None to omit the line.
    plot_legend: Whether to include the legend to the right of the ECDF.
    padding: Margin added on each side, as a fraction of the autoscaled axis
      range.
    color_order: Colors, one per entry of `source_order`.
    source_order: `(column, label)` pairs, in the order they are plotted.
    ax: An optional axis on which the ECDF is plotted; if not specified, a new
      axis is created.

  Returns:
    The axis on which the ECDF was plotted.

  Raises:
    ValueError: If `source_df` lacks `spiro_col` or any `source_order` column,
      or if `color_order` and `source_order` differ in length.
  """
  required_columns = {spiro_col, *(column for column, _ in source_order)}
  if not required_columns.issubset(source_df.columns):
    raise ValueError('`source_df` missing desired columns.')
  if len(color_order) != len(source_order):
    raise ValueError('`color_order` and `source_order` are mismatched.')

  if ax is None:
    ax = plt.axes()
  for (column, label), color in zip(source_order, color_order):
    sns.ecdfplot(
        source_df[source_df[column] == 1],
        x=spiro_col,
        label=label,
        color=color,
        ax=ax,
    )
  ax.set_xlabel(spiro_label)
  if plot_legend:
    legend = ax.legend(
        bbox_to_anchor=(1.2, 0.5),
        loc='center left',
        borderaxespad=0,
        frameon=False,
    )
    # Left-align legend entries; prefer the public API when this matplotlib
    # version provides it, otherwise fall back to the private attribute.
    if hasattr(legend, 'set_alignment'):
      legend.set_alignment('left')
    else:
      legend._legend_box.align = 'left'
  if vertical_line is not None:
    ax.axvline(
        vertical_line,
        0,
        1.0,
        color=DARK_GRAY,
        linestyle='--',
        dashes=(5, 1),
    )
  # Pad the autoscaled limits so the grid is aligned for all units.
  x_min, x_max = ax.get_xlim()
  x_padding = (x_max - x_min) * padding
  ax.set_xlim((x_min - x_padding, x_max + x_padding))
  y_min, y_max = ax.get_ylim()
  y_padding = (y_max - y_min) * padding
  ax.set_ylim((y_min - y_padding, y_max + y_padding))
  return ax


def plot_label_contigency(
    labels_df: pd.DataFrame,
    col_a: str,
    label_a: str,
    col_b: str,
    label_b: str,
    plot_cbar: bool = False,
    ax: Optional[plt.Axes] = None,
) -> plt.Axes:
  """Plots a 2x2 contingency table of two binary labels.

  Only individuals for whom both labels are present are counted. Both columns
  are expected to be binary (0 = control, 1 = case), matching the
  'Control'/'Case' tick labels. Each cell shows its count and its percentage
  of all counted individuals.

  Args:
    labels_df: A dataframe containing `eid` and the target labels.
    col_a: The column containing values for the vertical axis.
    label_a: `col_a`'s label value.
    col_b: The column containing values for the horizontal axis.
    label_b: `col_b`'s label value.
    plot_cbar: Whether to include the cbar to the right of the contingency
      table.
    ax: An optional axis on which the table is plotted; if not specified, a
      new axis is created.

  Returns:
    The axis on which the contingency table was plotted.
  """
  if ax is None:
    ax = plt.axes()

  # Keep only individuals for which both labels are present.
  contingency_df = labels_df[['eid', col_a, col_b]].dropna()
  # Pin the label order so the matrix is always 2x2 (index 0 = control,
  # 1 = case) even when one class is absent; otherwise sklearn returns a
  # smaller matrix and the 2x2 annotation reshape fails.
  cf_matrix = sklearn.metrics.confusion_matrix(
      contingency_df[col_a].values,
      contingency_df[col_b].values,
      labels=[0, 1],
  )
  total = cf_matrix.sum()
  # Avoid a 0/0 division when no individual has both labels.
  fractions = cf_matrix / total if total else np.zeros(cf_matrix.shape)
  # Counts are integers: `d` never switches to scientific notation (`g` does
  # at one million).
  annotations = np.asarray([
      f'{count:d}\n({fraction:0.2%})'
      for count, fraction in zip(cf_matrix.flatten(), fractions.flatten())
  ]).reshape(cf_matrix.shape)
  sns.heatmap(
      cf_matrix,
      square=True,
      cmap=sns.light_palette(BLACK, as_cmap=True),
      fmt='',
      annot=annotations,
      xticklabels=['Control', 'Case'],
      yticklabels=['Control', 'Case'],
      cbar=plot_cbar,
      ax=ax,
  )
  ax.set_yticklabels(labels=ax.get_yticklabels(), va='center')
  ax.set_ylabel(label_a)
  ax.set_xlabel(label_b)
  return ax

## Load data

We first derive time, volume, and flow curves from the publicly available demo
blow in UKB field 3066.

In [None]:
# Time, volume, and flow curves for the UKB field 3066 demo blow.
spirometry_base_curves = derive_base_curves(series=UKB_DEMO_SPIRO_3066)

We then load the labels TSV as a dataframe. The TSV must contain every column
listed in `REQUIRED_COLUMNS` (the columns described at the top of this notebook).

In [None]:
LABELS_FILEPATH = '/path/to/labels.tsv'
REQUIRED_COLUMNS = {
    # A unique individual identifier.
    'eid',
    # FEV1/FVC ratio.
    'blow_ratio',
    # FEV1 percent predicted.
    'blow_fev1_pct_predicted',
    # Proxy GOLD 2-4 status.
    'copd_spiro_gold',
    # COPD status sourced from self-reported data.
    'copd_sr_src',
    # COPD status sourced from HESIN data.
    'copd_hesin_src',
    # COPD status sourced from general practitioner notes data.
    'copd_gp_src',
}


with open(LABELS_FILEPATH, mode='r') as f:
  labels_df = pd.read_csv(f, sep='\t', index_col=None)

# Fail loudly (an `assert` would be stripped under `python -O`) and name the
# missing columns so a malformed TSV is easy to fix.
missing_columns = REQUIRED_COLUMNS - set(labels_df.columns)
if missing_columns:
  raise ValueError(
      f'Labels TSV at {LABELS_FILEPATH!r} is missing required columns: '
      f'{sorted(missing_columns)}'
  )
source_df = build_source_df(labels_df)

## Plot Figure 2

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(9, 9), dpi=300)

# Row 1: spirograms of the demo blow.
plot_vt_spirogram(
    time=spirometry_base_curves['time'],
    volume=spirometry_base_curves['volume'],
    ax=axes[0][0],
)
plot_ft_spirogram(
    time=spirometry_base_curves['time'],
    flow=spirometry_base_curves['flow'],
    ax=axes[0][1],
)
plot_fv_spirogram(
    flow=spirometry_base_curves['flow'],
    volume=spirometry_base_curves['volume'],
    ax=axes[0][2],
)

# Row 2: ECDFs of spirometry measures by COPD source, with dashed lines at 0.7
# (ratio) and 80 (percent predicted). The right-hand panel is removed; the
# legend added after `tight_layout` below is anchored to the right of the
# middle panel.
axes[1][2].remove()
plot_source_ecdf(
    source_df,
    'blow_ratio',
    # Raw strings: `\m` is an invalid escape sequence in a plain literal.
    r'$\mathregular{FEV_1}$/FVC Ratio',
    0.7,
    ax=axes[1][0],
)
plot_source_ecdf(
    source_df,
    'blow_fev1_pct_predicted',
    r'$\mathregular{FEV_1}$ Percent Predicted',
    80,
    ax=axes[1][1],
)

# Row 3: contingency tables of GOLD labels against each COPD source.
plot_label_contigency(
    labels_df,
    col_a='copd_spiro_gold',
    label_a='GOLD Labels',
    col_b='copd_sr_src',
    label_b='Self-report COPD',
    ax=axes[2][0],
)
plot_label_contigency(
    labels_df,
    col_a='copd_spiro_gold',
    label_a='GOLD Labels',
    col_b='copd_hesin_src',
    label_b='HESIN COPD',
    ax=axes[2][1],
)
plot_label_contigency(
    labels_df,
    col_a='copd_spiro_gold',
    label_a='GOLD Labels',
    col_b='copd_gp_src',
    label_b='GP COPD',
    ax=axes[2][2],
)

# Note: We add panel labels before `tight_layout` so that spacing is preserved.
# Each letter sits 20 pt left of and 7 pt above its axis' top-left corner.
panel_label_offset = transforms.ScaledTranslation(
    -20 / 72, 7 / 72, fig.dpi_scale_trans
)
labeled_axes = fig.get_axes()
for ax_label, ax in zip(string.ascii_lowercase, labeled_axes):
  ax.text(
      0.0,
      1.0,
      ax_label,
      transform=ax.transAxes + panel_label_offset,
      fontsize=8,
      va='bottom',
      fontfamily='Helvetica',
      weight='bold',
  )
  ax.spines[['right', 'top']].set_visible(False)


plt.tight_layout()

# Note: The legend needs to be set after `tight_layout()` is called to prevent
# the squishing of plot widths.
legend = axes[1][1].legend(
    bbox_to_anchor=(1.2, 0.5),
    loc='center left',
    borderaxespad=0,
    frameon=False,
)

In [None]:
# Export the assembled Figure 2 as a 300 dpi PDF.
fig.savefig('figure_2.pdf', dpi=300)
# NOTE(review): `%download_file` is not a built-in IPython magic; presumably it
# is provided by the environment this notebook was authored in. Confirm, or
# replace it with that environment's download mechanism.
%download_file figure_2.pdf